Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit b85a71d7 authored by Sean Thomas's avatar Sean Thomas
Browse files

Expose more functionality from hwtrust

Instead of having parsing code in remote_prov_utils and hwtrust, we
should expose more of the values that are parsed in hwtrust and use
that interface to implement the verification of a CSR.

The aim of this CL is to preserve the functionality that previously
existed in remote_prov_utils and, also, add more tests.

Test: atest libkeymint_remote_prov_support_test
Change-Id: Id5408a425f28ea99052ba954c34441ed9307a5d2
parent 3a840454
Loading
Loading
Loading
Loading
+5 −0
Original line number Diff line number Diff line
@@ -29,6 +29,11 @@ namespace aidl::android::hardware::security::keymint::remote_prov {
using bytevec = std::vector<uint8_t>;
using namespace cppcose;

constexpr std::string_view kErrorChallengeMismatch = "challenges do not match";
constexpr std::string_view kErrorUdsCertsAreRequired = "UdsCerts are required";
constexpr std::string_view kErrorKeysToSignMismatch = "KeysToSign do not match";
constexpr std::string_view kErrorDiceChainIsDegenerate = "DICE chain is degenerate";

extern bytevec kTestMacKey;

// The Google root key for the Endpoint Encryption Key chain, encoded as COSE_Sign1
+63 −338
Original line number Diff line number Diff line
@@ -123,37 +123,6 @@ ErrMsgOr<std::tuple<bytevec, bytevec>> getAffineCoordinates(const bytevec& pubKe
    return std::make_tuple(std::move(pubX), std::move(pubY));
}

ErrMsgOr<bytevec> getRawPublicKey(const EVP_PKEY_Ptr& pubKey) {
    if (pubKey.get() == nullptr) {
        return "pkey is null.";
    }
    int keyType = EVP_PKEY_base_id(pubKey.get());
    switch (keyType) {
        case EVP_PKEY_EC: {
            int nid = EVP_PKEY_bits(pubKey.get()) == 384 ? NID_secp384r1 : NID_X9_62_prime256v1;
            auto ecKey = EC_KEY_Ptr(EVP_PKEY_get1_EC_KEY(pubKey.get()));
            if (ecKey.get() == nullptr) {
                return "Failed to get ec key";
          }
          return ecKeyGetPublicKey(ecKey.get(), nid);
        }
        case EVP_PKEY_ED25519: {
            bytevec rawPubKey;
            size_t rawKeySize = 0;
            if (!EVP_PKEY_get_raw_public_key(pubKey.get(), NULL, &rawKeySize)) {
                return "Failed to get raw public key.";
            }
            rawPubKey.resize(rawKeySize);
            if (!EVP_PKEY_get_raw_public_key(pubKey.get(), rawPubKey.data(), &rawKeySize)) {
                return "Failed to get raw public key.";
            }
            return rawPubKey;
        }
        default:
            return "Unknown key type.";
    }
}

ErrMsgOr<std::tuple<bytevec, bytevec>> generateEc256KeyPair() {
    auto ec_key = EC_KEY_Ptr(EC_KEY_new());
    if (ec_key.get() == nullptr) {
@@ -166,7 +135,7 @@ ErrMsgOr<std::tuple<bytevec, bytevec>> generateEc256KeyPair() {
    }

    if (EC_KEY_set_group(ec_key.get(), group.get()) != 1 ||
        EC_KEY_generate_key(ec_key.get()) != 1 || EC_KEY_check_key(ec_key.get()) < 0) {
        EC_KEY_generate_key(ec_key.get()) != 1 || EC_KEY_check_key(ec_key.get()) != 1) {
        return "Error generating key";
    }

@@ -331,17 +300,22 @@ bytevec getProdEekChain(int32_t supportedEekCurve) {
    return chain.encode();
}

bool maybeOverrideAllowAnyMode(bool allowAnyMode) {
    // Use ro.build.type instead of ro.debuggable because ro.debuggable=1 for VTS testing
    std::string build_type = ::android::base::GetProperty("ro.build.type", "");
    if (!build_type.empty() && build_type != "user") {
        return true;
    }
    return allowAnyMode;
}

ErrMsgOr<std::vector<BccEntryData>> validateBcc(const cppbor::Array* bcc,
                                                hwtrust::DiceChain::Kind kind, bool allowAnyMode,
                                                bool allowDegenerate,
                                                const std::string& instanceName) {
    auto encodedBcc = bcc->encode();

    // Use ro.build.type instead of ro.debuggable because ro.debuggable=1 for VTS testing
    std::string build_type = ::android::base::GetProperty("ro.build.type", "");
    if (!build_type.empty() && build_type != "user") {
        allowAnyMode = true;
    }
    allowAnyMode = maybeOverrideAllowAnyMode(allowAnyMode);

    auto chain =
            hwtrust::DiceChain::Verify(encodedBcc, kind, allowAnyMode, deviceSuffix(instanceName));
@@ -779,230 +753,6 @@ ErrMsgOr<std::vector<BccEntryData>> verifyProductionProtectedData(
                               /*isFactory=*/false, allowAnyMode);
}

ErrMsgOr<X509_Ptr> parseX509Cert(const std::vector<uint8_t>& cert) {
    CRYPTO_BUFFER_Ptr certBuf(CRYPTO_BUFFER_new(cert.data(), cert.size(), nullptr));
    if (!certBuf.get()) {
        return "Failed to create crypto buffer.";
    }
    X509_Ptr result(X509_parse_from_buffer(certBuf.get()));
    if (!result.get()) {
        return "Failed to parse certificate.";
    }
    return result;
}

std::string getX509IssuerName(const X509_Ptr& cert) {
    char* name = X509_NAME_oneline(X509_get_issuer_name(cert.get()), nullptr, 0);
    std::string result(name);
    OPENSSL_free(name);
    return result;
}

std::string getX509SubjectName(const X509_Ptr& cert) {
    char* name = X509_NAME_oneline(X509_get_subject_name(cert.get()), nullptr, 0);
    std::string result(name);
    OPENSSL_free(name);
    return result;
}

// Validates the certificate chain and returns the leaf public key.
ErrMsgOr<bytevec> validateCertChain(const cppbor::Array& chain) {
    bytevec rawPubKey;
    for (size_t i = 0; i < chain.size(); ++i) {
        // Root must be self-signed.
        size_t signingCertIndex = (i > 0) ? i - 1 : i;
        auto& keyCertItem = chain[i];
        auto& signingCertItem = chain[signingCertIndex];
        if (!keyCertItem || !keyCertItem->asBstr()) {
            return "Key certificate must be a Bstr.";
        }
        if (!signingCertItem || !signingCertItem->asBstr()) {
            return "Signing certificate must be a Bstr.";
        }

        auto keyCert = parseX509Cert(keyCertItem->asBstr()->value());
        if (!keyCert) {
            return keyCert.message();
        }
        auto signingCert = parseX509Cert(signingCertItem->asBstr()->value());
        if (!signingCert) {
            return signingCert.message();
        }

        EVP_PKEY_Ptr pubKey(X509_get_pubkey(keyCert->get()));
        if (!pubKey.get()) {
            return "Failed to get public key.";
        }
        EVP_PKEY_Ptr signingPubKey(X509_get_pubkey(signingCert->get()));
        if (!signingPubKey.get()) {
            return "Failed to get signing public key.";
        }

        if (!X509_verify(keyCert->get(), signingPubKey.get())) {
            return "Verification of certificate " + std::to_string(i) +
                   " faile. OpenSSL error string: " + ERR_error_string(ERR_get_error(), NULL);
        }

        auto certIssuer = getX509IssuerName(*keyCert);
        auto signerSubj = getX509SubjectName(*signingCert);
        if (certIssuer != signerSubj) {
            return "Certificate " + std::to_string(i) + " has wrong issuer. Signer subject is " +
                   signerSubj + " Issuer subject is " + certIssuer;
        }
        if (i == chain.size() - 1) {
            auto key = getRawPublicKey(pubKey);
            if (!key) return key.moveMessage();
            rawPubKey = key.moveValue();
        }
    }
    return rawPubKey;
}

std::optional<std::string> validateUdsCerts(const cppbor::Map& udsCerts,
                                            const bytevec& udsCoseKeyBytes) {
    for (const auto& [signerName, udsCertChain] : udsCerts) {
        if (!signerName || !signerName->asTstr()) {
            return "Signer Name must be a Tstr.";
        }
        if (!udsCertChain || !udsCertChain->asArray()) {
            return "UDS certificate chain must be an Array.";
        }
        if (udsCertChain->asArray()->size() < 2) {
            return "UDS certificate chain must have at least two entries: root and leaf.";
        }

        auto leafPubKey = validateCertChain(*udsCertChain->asArray());
        if (!leafPubKey) {
            return leafPubKey.message();
        }
        auto coseKey = CoseKey::parse(udsCoseKeyBytes);
        if (!coseKey) {
            return coseKey.moveMessage();
        }
        auto curve = coseKey->getIntValue(CoseKey::CURVE);
        if (!curve) {
            return "CoseKey must contain curve.";
        }
        bytevec udsPub;
        if (curve == CoseKeyCurve::P256 || curve == CoseKeyCurve::P384) {
            auto pubKey = coseKey->getEcPublicKey();
            if (!pubKey) {
                return pubKey.moveMessage();
            }
            // convert public key to uncompressed form by prepending 0x04 at begin.
            pubKey->insert(pubKey->begin(), 0x04);
            udsPub = pubKey.moveValue();
        } else if (curve == CoseKeyCurve::ED25519) {
            auto& pubkey = coseKey->getMap().get(cppcose::CoseKey::PUBKEY_X);
            if (!pubkey || !pubkey->asBstr()) {
                return "Invalid public key.";
            }
            udsPub = pubkey->asBstr()->value();
        } else {
            return "Unknown curve.";
        }
        if (*leafPubKey != udsPub) {
            return "Leaf public key in UDS certificate chain doesn't match UDS public key.";
        }
    }
    return std::nullopt;
}

ErrMsgOr<std::unique_ptr<cppbor::Array>> parseAndValidateCsrPayload(
        const cppbor::Array& keysToSign, const std::vector<uint8_t>& csrPayload,
        const RpcHardwareInfo& rpcHardwareInfo, bool isFactory) {
    auto [parsedCsrPayload, _, errMsg] = cppbor::parse(csrPayload);
    if (!parsedCsrPayload) {
        return errMsg;
    }

    std::unique_ptr<cppbor::Array> parsed(parsedCsrPayload.release()->asArray());
    if (!parsed) {
        return "CSR payload is not a CBOR array.";
    }

    if (parsed->size() != 4U) {
        return "CSR payload must contain version, certificate type, device info, keys. "
               "However, the parsed CSR payload has " +
               std::to_string(parsed->size()) + " entries.";
    }

    auto signedVersion = parsed->get(0)->asUint();
    auto signedCertificateType = parsed->get(1)->asTstr();
    auto signedDeviceInfo = parsed->get(2)->asMap();
    auto signedKeys = parsed->get(3)->asArray();

    if (!signedVersion || signedVersion->value() != 3U) {
        return "CSR payload version must be an unsigned integer and must be equal to 3.";
    }
    if (!signedCertificateType) {
        // Certificate type is allowed to be extendend by vendor, i.e. we can't
        // enforce its value.
        return "Certificate type must be a Tstr.";
    }
    if (!signedDeviceInfo) {
        return "Device info must be an Map.";
    }
    if (!signedKeys) {
        return "Keys must be an Array.";
    }

    auto result =
            parseAndValidateDeviceInfo(signedDeviceInfo->encode(), rpcHardwareInfo, isFactory);
    if (!result) {
        return result.message();
    }

    if (signedKeys->encode() != keysToSign.encode()) {
        return "Signed keys do not match.";
    }

    return std::move(parsed);
}

ErrMsgOr<bytevec> parseAndValidateAuthenticatedRequestSignedPayload(
        const std::vector<uint8_t>& signedPayload, const std::vector<uint8_t>& challenge) {
    auto [parsedSignedPayload, _, errMsg] = cppbor::parse(signedPayload);
    if (!parsedSignedPayload) {
        return errMsg;
    }
    if (!parsedSignedPayload->asArray()) {
        return "SignedData payload is not a CBOR array.";
    }
    if (parsedSignedPayload->asArray()->size() != 2U) {
        return "SignedData payload must contain the challenge and request. However, the parsed "
               "SignedData payload has " +
               std::to_string(parsedSignedPayload->asArray()->size()) + " entries.";
    }

    auto signedChallenge = parsedSignedPayload->asArray()->get(0)->asBstr();
    auto signedRequest = parsedSignedPayload->asArray()->get(1)->asBstr();

    if (!signedChallenge) {
        return "Challenge must be a Bstr.";
    }

    if (challenge.size() > 64) {
        return "Challenge size must be between 0 and 64 bytes inclusive. "
               "However, challenge is " +
               std::to_string(challenge.size()) + " bytes long.";
    }

    auto challengeBstr = cppbor::Bstr(challenge);
    if (*signedChallenge != challengeBstr) {
        return "Signed challenge does not match."
               "\n  Actual: " +
               cppbor::prettyPrint(signedChallenge->asBstr(), 64 /* maxBStrSize */) +
               "\nExpected: " + cppbor::prettyPrint(&challengeBstr, 64 /* maxBStrSize */);
    }

    if (!signedRequest) {
        return "Request must be a Bstr.";
    }

    return signedRequest->value();
}

ErrMsgOr<hwtrust::DiceChain::Kind> getDiceChainKind() {
    int vendor_api_level = ::android::base::GetIntProperty("ro.vendor.api_level", -1);
    if (vendor_api_level <= __ANDROID_API_T__) {
@@ -1018,104 +768,79 @@ ErrMsgOr<hwtrust::DiceChain::Kind> getDiceChainKind() {
    }
}

ErrMsgOr<bytevec> parseAndValidateAuthenticatedRequest(const std::vector<uint8_t>& request,
                                                       const std::vector<uint8_t>& challenge,
                                                       const std::string& instanceName,
                                                       bool allowAnyMode = false,
                                                       bool allowDegenerate = true,
                                                       bool requireUdsCerts = false) {
    auto [parsedRequest, _, csrErrMsg] = cppbor::parse(request);
    if (!parsedRequest) {
        return csrErrMsg;
    }
    if (!parsedRequest->asArray()) {
        return "AuthenticatedRequest is not a CBOR array.";
ErrMsgOr<std::unique_ptr<cppbor::Array>> verifyCsr(
        const cppbor::Array& keysToSign, const std::vector<uint8_t>& encodedCsr,
        const RpcHardwareInfo& rpcHardwareInfo, const std::string& instanceName,
        const std::vector<uint8_t>& challenge, bool isFactory, bool allowAnyMode = false,
        bool allowDegenerate = true, bool requireUdsCerts = false) {
    if (rpcHardwareInfo.versionNumber != 3) {
        return "Remotely provisioned component version (" +
               std::to_string(rpcHardwareInfo.versionNumber) +
               ") does not match expected version (3).";
    }
    if (parsedRequest->asArray()->size() != 4U) {
        return "AuthenticatedRequest must contain version, UDS certificates, DICE chain, and "
               "signed data. However, the parsed AuthenticatedRequest has " +
               std::to_string(parsedRequest->asArray()->size()) + " entries.";

    auto diceChainKind = getDiceChainKind();
    if (!diceChainKind) {
        return diceChainKind.message();
    }

    auto version = parsedRequest->asArray()->get(0)->asUint();
    auto udsCerts = parsedRequest->asArray()->get(1)->asMap();
    auto diceCertChain = parsedRequest->asArray()->get(2)->asArray();
    auto signedData = parsedRequest->asArray()->get(3)->asArray();
    allowAnyMode = maybeOverrideAllowAnyMode(allowAnyMode);

    if (!version || version->value() != 1U) {
        return "AuthenticatedRequest version must be an unsigned integer and must be equal to 1.";
    }
    auto csr = hwtrust::Csr::validate(encodedCsr, *diceChainKind, isFactory, allowAnyMode,
                                      deviceSuffix(instanceName));

    if (!udsCerts) {
        return "AuthenticatedRequest UdsCerts must be a Map.";
    if (!csr.ok()) {
        return csr.error().message();
    }
    if (requireUdsCerts && udsCerts->size() == 0) {
        return "AuthenticatedRequest UdsCerts must not be empty.";

    if (!allowDegenerate) {
        auto diceChain = csr->getDiceChain();
        if (!diceChain.ok()) {
            return diceChain.error().message();
        }
    if (!diceCertChain) {
        return "AuthenticatedRequest DiceCertChain must be an Array.";

        if (!diceChain->IsProper()) {
            return kErrorDiceChainIsDegenerate;
        }
    if (!signedData) {
        return "AuthenticatedRequest SignedData must be an Array.";
    }

    // DICE chain is [ pubkey, + DiceChainEntry ].
    auto diceChainKind = getDiceChainKind();
    if (!diceChainKind) {
        return diceChainKind.message();
    if (requireUdsCerts && !csr->hasUdsCerts()) {
        return kErrorUdsCertsAreRequired;
    }

    auto diceContents =
            validateBcc(diceCertChain, *diceChainKind, allowAnyMode, allowDegenerate, instanceName);
    if (!diceContents) {
        return diceContents.message() + "\n" + prettyPrint(diceCertChain);
    auto equalChallenges = csr->compareChallenge(challenge);
    if (!equalChallenges.ok()) {
        return equalChallenges.error().message();
    }

    if (!diceCertChain->get(0)->asMap()) {
        return "AuthenticatedRequest The first entry in DiceCertChain must be a Map.";
    }
    auto udsPub = diceCertChain->get(0)->asMap()->encode();
    auto error = validateUdsCerts(*udsCerts, udsPub);
    if (error) {
        return *error;
    if (!*equalChallenges) {
        return kErrorChallengeMismatch;
    }

    if (diceContents->empty()) {
        return "AuthenticatedRequest DiceContents must not be empty.";
    }
    auto& kmDiceKey = diceContents->back().pubKey;
    auto signedPayload = verifyAndParseCoseSign1(signedData, kmDiceKey, /*aad=*/{});
    if (!signedPayload) {
        return signedPayload.message();
    auto equalKeysToSign = csr->compareKeysToSign(keysToSign.encode());
    if (!equalKeysToSign.ok()) {
        return equalKeysToSign.error().message();
    }

    auto payload = parseAndValidateAuthenticatedRequestSignedPayload(*signedPayload, challenge);
    if (!payload) {
        return payload.message();
    if (!*equalKeysToSign) {
        return kErrorKeysToSignMismatch;
    }

    return payload;
    auto csrPayload = csr->getCsrPayload();
    if (!csrPayload) {
        return csrPayload.error().message();
    }

ErrMsgOr<std::unique_ptr<cppbor::Array>> verifyCsr(
        const cppbor::Array& keysToSign, const std::vector<uint8_t>& csr,
        const RpcHardwareInfo& rpcHardwareInfo, const std::string& instanceName,
        const std::vector<uint8_t>& challenge, bool isFactory, bool allowAnyMode = false,
        bool allowDegenerate = true, bool requireUdsCerts = false) {
    if (rpcHardwareInfo.versionNumber != 3) {
        return "Remotely provisioned component version (" +
               std::to_string(rpcHardwareInfo.versionNumber) +
               ") does not match expected version (3).";
    auto [csrPayloadDecoded, _, errMsg] = cppbor::parse(*csrPayload);
    if (!csrPayloadDecoded) {
        return errMsg;
    }

    auto csrPayload = parseAndValidateAuthenticatedRequest(
            csr, challenge, instanceName, allowAnyMode, allowDegenerate, requireUdsCerts);

    if (!csrPayload) {
        return csrPayload.message();
    if (!csrPayloadDecoded->asArray()) {
        return "CSR payload is not an array.";
    }

    return parseAndValidateCsrPayload(keysToSign, *csrPayload, rpcHardwareInfo, isFactory);
    return std::unique_ptr<cppbor::Array>(csrPayloadDecoded.release()->asArray());
}

ErrMsgOr<std::unique_ptr<cppbor::Array>> verifyFactoryCsr(
@@ -1143,8 +868,8 @@ ErrMsgOr<bool> isCsrWithProperDiceChain(const std::vector<uint8_t>& encodedCsr,
        return diceChainKind.message();
    }

    auto csr = hwtrust::Csr::validate(encodedCsr, *diceChainKind, false /*allowAnyMode*/,
                                      deviceSuffix(instanceName));
    auto csr = hwtrust::Csr::validate(encodedCsr, *diceChainKind, false /*isFactory*/,
                                      false /*allowAnyMode*/, deviceSuffix(instanceName));
    if (!csr.ok()) {
        return csr.error().message();
    }
+96 −8

File changed.

Preview size limit exceeded, changes collapsed.

+1 −0
Original line number Diff line number Diff line
@@ -999,6 +999,7 @@ TEST_P(CertificateRequestV2Test, DeviceInfo) {

    std::unique_ptr<cppbor::Array> csrPayload = std::move(*result);
    ASSERT_TRUE(csrPayload);
    ASSERT_TRUE(csrPayload->size() > 2);

    auto deviceInfo = csrPayload->get(2)->asMap();
    ASSERT_TRUE(deviceInfo);