Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 739d98c8 authored by Treehugger Robot's avatar Treehugger Robot Committed by Gerrit Code Review
Browse files

Merge "[rkp_factory_tool] enforce the presence of UDS certs" into main

parents 65dc3587 e2307105
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -126,6 +126,7 @@ cc_test {
    ],
    shared_libs: [
        "libbase",
        "libbinder_ndk",
        "libcppbor",
        "libcppcose_rkp",
        "libcrypto",
+52 −0
Original line number Diff line number Diff line
/*
 * Copyright 2024 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include <aidl/android/hardware/security/keymint/IRemotelyProvisionedComponent.h>
#include <aidl/android/hardware/security/keymint/RpcHardwareInfo.h>
#include <android-base/properties.h>
#include <gmock/gmock.h>
#include <gtest/gtest.h>

#include <cstdint>

namespace aidl::android::hardware::security::keymint::remote_prov {

using ::ndk::ScopedAStatus;

class MockIRemotelyProvisionedComponent : public IRemotelyProvisionedComponentDefault {
  public:
    MOCK_METHOD(ScopedAStatus, getHardwareInfo, (RpcHardwareInfo * _aidl_return), (override));
    MOCK_METHOD(ScopedAStatus, generateEcdsaP256KeyPair,
                (bool in_testMode, MacedPublicKey* out_macedPublicKey,
                 std::vector<uint8_t>* _aidl_return),
                (override));
    MOCK_METHOD(ScopedAStatus, generateCertificateRequest,
                (bool in_testMode, const std::vector<MacedPublicKey>& in_keysToSign,
                 const std::vector<uint8_t>& in_endpointEncryptionCertChain,
                 const std::vector<uint8_t>& in_challenge, DeviceInfo* out_deviceInfo,
                 ProtectedData* out_protectedData, std::vector<uint8_t>* _aidl_return),
                (override));
    MOCK_METHOD(ScopedAStatus, generateCertificateRequestV2,
                (const std::vector<MacedPublicKey>& in_keysToSign,
                 const std::vector<uint8_t>& in_challenge, std::vector<uint8_t>* _aidl_return),
                (override));
    MOCK_METHOD(ScopedAStatus, getInterfaceVersion, (int32_t* _aidl_return), (override));
    MOCK_METHOD(ScopedAStatus, getInterfaceHash, (std::string * _aidl_return), (override));
};

}  // namespace aidl::android::hardware::security::keymint::remote_prov
 No newline at end of file
+10 −1
Original line number Diff line number Diff line
@@ -94,6 +94,13 @@ const std::string DEFAULT_INSTANCE_NAME =
const std::string RKPVM_INSTANCE_NAME =
        "android.hardware.security.keymint.IRemotelyProvisionedComponent/avf";

/**
 * Returns the portion of an instance name after the /
 * e.g. for "android.hardware.security.keymint.IRemotelyProvisionedComponent/avf",
 * it returns "avf".
 */
std::string deviceSuffix(const std::string& name);

struct EekChain {
    bytevec chain;
    bytevec last_pubkey;
@@ -184,7 +191,9 @@ ErrMsgOr<std::vector<BccEntryData>> verifyProductionProtectedData(
ErrMsgOr<std::unique_ptr<cppbor::Array>> verifyFactoryCsr(
        const cppbor::Array& keysToSign, const std::vector<uint8_t>& csr,
        IRemotelyProvisionedComponent* provisionable, const std::string& instanceName,
        const std::vector<uint8_t>& challenge, bool allowDegenerate = true);
        const std::vector<uint8_t>& challenge, bool allowDegenerate = true,
        bool requireUdsCerts = false);

/**
 * Verify the CSR as if the device is a final production sample.
 */
+46 −30
Original line number Diff line number Diff line
@@ -52,8 +52,8 @@ using EVP_PKEY_CTX_Ptr = bssl::UniquePtr<EVP_PKEY_CTX>;
using X509_Ptr = bssl::UniquePtr<X509>;
using CRYPTO_BUFFER_Ptr = bssl::UniquePtr<CRYPTO_BUFFER>;

std::string device_suffix(const std::string& name) {
    size_t pos = name.find('/');
std::string deviceSuffix(const std::string& name) {
    size_t pos = name.rfind('/');
    if (pos == std::string::npos) {
        return name;
    }
@@ -344,15 +344,18 @@ ErrMsgOr<std::vector<BccEntryData>> validateBcc(const cppbor::Array* bcc,
    }

    auto chain =
            hwtrust::DiceChain::Verify(encodedBcc, kind, allowAnyMode, device_suffix(instanceName));
    if (!chain.ok()) return chain.error().message();

            hwtrust::DiceChain::Verify(encodedBcc, kind, allowAnyMode, deviceSuffix(instanceName));
    if (!chain.ok()) {
        return chain.error().message();
    }
    if (!allowDegenerate && !chain->IsProper()) {
        return "DICE chain is degenerate";
    }

    auto keys = chain->CosePublicKeys();
    if (!keys.ok()) return keys.error().message();
    if (!keys.ok()) {
        return keys.error().message();
    }
    std::vector<BccEntryData> result;
    for (auto& key : *keys) {
        result.push_back({std::move(key)});
@@ -857,7 +860,8 @@ ErrMsgOr<bytevec> validateCertChain(const cppbor::Array& chain) {
    return rawPubKey;
}

std::string validateUdsCerts(const cppbor::Map& udsCerts, const bytevec& udsCoseKeyBytes) {
std::optional<std::string> validateUdsCerts(const cppbor::Map& udsCerts,
                                            const bytevec& udsCoseKeyBytes) {
    for (const auto& [signerName, udsCertChain] : udsCerts) {
        if (!signerName || !signerName->asTstr()) {
            return "Signer Name must be a Tstr.";
@@ -874,8 +878,9 @@ std::string validateUdsCerts(const cppbor::Map& udsCerts, const bytevec& udsCose
            return leafPubKey.message();
        }
        auto coseKey = CoseKey::parse(udsCoseKeyBytes);
        if (!coseKey) return coseKey.moveMessage();

        if (!coseKey) {
            return coseKey.moveMessage();
        }
        auto curve = coseKey->getIntValue(CoseKey::CURVE);
        if (!curve) {
            return "CoseKey must contain curve.";
@@ -883,7 +888,9 @@ std::string validateUdsCerts(const cppbor::Map& udsCerts, const bytevec& udsCose
        bytevec udsPub;
        if (curve == CoseKeyCurve::P256 || curve == CoseKeyCurve::P384) {
            auto pubKey = coseKey->getEcPublicKey();
            if (!pubKey) return pubKey.moveMessage();
            if (!pubKey) {
                return pubKey.moveMessage();
            }
            // convert public key to uncompressed form by prepending 0x04 at begin.
            pubKey->insert(pubKey->begin(), 0x04);
            udsPub = pubKey.moveValue();
@@ -900,7 +907,7 @@ std::string validateUdsCerts(const cppbor::Map& udsCerts, const bytevec& udsCose
            return "Leaf public key in UDS certificate chain doesn't match UDS public key.";
        }
    }
    return "";
    return std::nullopt;
}

ErrMsgOr<std::unique_ptr<cppbor::Array>> parseAndValidateCsrPayload(
@@ -1016,7 +1023,8 @@ ErrMsgOr<bytevec> parseAndValidateAuthenticatedRequest(const std::vector<uint8_t
                                                       const std::vector<uint8_t>& challenge,
                                                       const std::string& instanceName,
                                                       bool allowAnyMode = false,
                                                       bool allowDegenerate = true) {
                                                       bool allowDegenerate = true,
                                                       bool requireUdsCerts = false) {
    auto [parsedRequest, _, csrErrMsg] = cppbor::parse(request);
    if (!parsedRequest) {
        return csrErrMsg;
@@ -1038,8 +1046,12 @@ ErrMsgOr<bytevec> parseAndValidateAuthenticatedRequest(const std::vector<uint8_t
    if (!version || version->value() != 1U) {
        return "AuthenticatedRequest version must be an unsigned integer and must be equal to 1.";
    }

    if (!udsCerts) {
        return "AuthenticatedRequest UdsCerts must be an Map.";
        return "AuthenticatedRequest UdsCerts must be a Map.";
    }
    if (requireUdsCerts && udsCerts->size() == 0) {
        return "AuthenticatedRequest UdsCerts must not be empty.";
    }
    if (!diceCertChain) {
        return "AuthenticatedRequest DiceCertChain must be an Array.";
@@ -1060,15 +1072,20 @@ ErrMsgOr<bytevec> parseAndValidateAuthenticatedRequest(const std::vector<uint8_t
        return diceContents.message() + "\n" + prettyPrint(diceCertChain);
    }

    if (!diceCertChain->get(0)->asMap()) {
        return "AuthenticatedRequest The first entry in DiceCertChain must be a Map.";
    }
    auto udsPub = diceCertChain->get(0)->asMap()->encode();
    auto& kmDiceKey = diceContents->back().pubKey;

    auto error = validateUdsCerts(*udsCerts, udsPub);
    if (!error.empty()) {
        return error;
    if (error) {
        return *error;
    }

    auto signedPayload = verifyAndParseCoseSign1(signedData, kmDiceKey, {} /* aad */);
    if (diceContents->empty()) {
        return "AuthenticatedRequest DiceContents must not be empty.";
    }
    auto& kmDiceKey = diceContents->back().pubKey;
    auto signedPayload = verifyAndParseCoseSign1(signedData, kmDiceKey, /*aad=*/{});
    if (!signedPayload) {
        return signedPayload.message();
    }
@@ -1081,13 +1098,11 @@ ErrMsgOr<bytevec> parseAndValidateAuthenticatedRequest(const std::vector<uint8_t
    return payload;
}

ErrMsgOr<std::unique_ptr<cppbor::Array>> verifyCsr(const cppbor::Array& keysToSign,
                                                   const std::vector<uint8_t>& csr,
                                                   IRemotelyProvisionedComponent* provisionable,
                                                   const std::string& instanceName,
                                                   const std::vector<uint8_t>& challenge,
                                                   bool isFactory, bool allowAnyMode = false,
                                                   bool allowDegenerate = true) {
ErrMsgOr<std::unique_ptr<cppbor::Array>> verifyCsr(
        const cppbor::Array& keysToSign, const std::vector<uint8_t>& csr,
        IRemotelyProvisionedComponent* provisionable, const std::string& instanceName,
        const std::vector<uint8_t>& challenge, bool isFactory, bool allowAnyMode = false,
        bool allowDegenerate = true, bool requireUdsCerts = false) {
    RpcHardwareInfo info;
    provisionable->getHardwareInfo(&info);
    if (info.versionNumber != 3) {
@@ -1095,8 +1110,9 @@ ErrMsgOr<std::unique_ptr<cppbor::Array>> verifyCsr(const cppbor::Array& keysToSi
               ") does not match expected version (3).";
    }

    auto csrPayload = parseAndValidateAuthenticatedRequest(csr, challenge, instanceName,
                                                           allowAnyMode, allowDegenerate);
    auto csrPayload = parseAndValidateAuthenticatedRequest(
            csr, challenge, instanceName, allowAnyMode, allowDegenerate, requireUdsCerts);

    if (!csrPayload) {
        return csrPayload.message();
    }
@@ -1107,9 +1123,9 @@ ErrMsgOr<std::unique_ptr<cppbor::Array>> verifyCsr(const cppbor::Array& keysToSi
ErrMsgOr<std::unique_ptr<cppbor::Array>> verifyFactoryCsr(
        const cppbor::Array& keysToSign, const std::vector<uint8_t>& csr,
        IRemotelyProvisionedComponent* provisionable, const std::string& instanceName,
        const std::vector<uint8_t>& challenge, bool allowDegenerate) {
        const std::vector<uint8_t>& challenge, bool allowDegenerate, bool requireUdsCerts) {
    return verifyCsr(keysToSign, csr, provisionable, instanceName, challenge, /*isFactory=*/true,
                     /*allowAnyMode=*/false, allowDegenerate);
                     /*allowAnyMode=*/false, allowDegenerate, requireUdsCerts);
}

ErrMsgOr<std::unique_ptr<cppbor::Array>> verifyProductionCsr(
@@ -1153,7 +1169,7 @@ ErrMsgOr<bool> isCsrWithProperDiceChain(const std::vector<uint8_t>& csr,

    auto encodedDiceChain = diceCertChain->encode();
    auto chain = hwtrust::DiceChain::Verify(encodedDiceChain, *diceChainKind,
                                            /*allowAnyMode=*/false, device_suffix(instanceName));
                                            /*allowAnyMode=*/false, deviceSuffix(instanceName));
    if (!chain.ok()) return chain.error().message();
    return chain->IsProper();
}
+326 −0

File changed.

Preview size limit exceeded, changes collapsed.