Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 88a86744 authored by Sean Thomas's avatar Sean Thomas
Browse files

Factor out repeated code into function

There was a lot of repeated code and an obstruction to factoring this
out into a function was that the move constructors for hwtrust::Csr and
hwtrust::DiceChain were not functional.

Test: atest libkeymint_remote_prov_support_test
Change-Id: I23fbdf00ab4edfa9d34308201edf2a81451b265e
parent d7e02b58
Loading
Loading
Loading
Loading
+30 −55
Original line number Diff line number Diff line
@@ -862,15 +862,15 @@ ErrMsgOr<std::unique_ptr<cppbor::Array>> verifyProductionCsr(const cppbor::Array
                     allowAnyMode);
}

ErrMsgOr<bool> isCsrWithProperDiceChain(const std::vector<uint8_t>& encodedCsr,
                                        const std::string& instanceName) {
ErrMsgOr<hwtrust::DiceChain> getDiceChain(const std::vector<uint8_t>& encodedCsr, bool isFactory,
                                          bool allowAnyMode, std::string_view instanceName) {
    auto diceChainKind = getDiceChainKind();
    if (!diceChainKind) {
        return diceChainKind.message();
    }

    auto csr = hwtrust::Csr::validate(encodedCsr, *diceChainKind, false /*isFactory*/,
                                      true /*allowAnyMode*/, deviceSuffix(instanceName));
    auto csr = hwtrust::Csr::validate(encodedCsr, *diceChainKind, isFactory, allowAnyMode,
                                      deviceSuffix(instanceName));
    if (!csr.ok()) {
        return csr.error().message();
    }
@@ -880,6 +880,16 @@ ErrMsgOr<bool> isCsrWithProperDiceChain(const std::vector<uint8_t>& encodedCsr,
        return diceChain.error().message();
    }

    return std::move(*diceChain);
}

ErrMsgOr<bool> isCsrWithProperDiceChain(const std::vector<uint8_t>& encodedCsr,
                                        const std::string& instanceName) {
    auto diceChain =
            getDiceChain(encodedCsr, /*isFactory=*/false, /*allowAnyMode=*/true, instanceName);
    if (!diceChain) {
        return diceChain.message();
    }
    return diceChain->IsProper();
}

@@ -898,20 +908,10 @@ ErrMsgOr<bool> compareRootPublicKeysInDiceChains(const std::vector<uint8_t>& enc
                                                 std::string_view instanceName1,
                                                 const std::vector<uint8_t>& encodedCsr2,
                                                 std::string_view instanceName2) {
    auto diceChainKind = getDiceChainKind();
    if (!diceChainKind) {
        return diceChainKind.message();
    }

    auto csr1 = hwtrust::Csr::validate(encodedCsr1, *diceChainKind, false /*isFactory*/,
                                       true /*allowAnyMode*/, deviceSuffix(instanceName1));
    if (!csr1.ok()) {
        return csr1.error().message();
    }

    auto diceChain1 = csr1->getDiceChain();
    if (!diceChain1.ok()) {
        return diceChain1.error().message();
    auto diceChain1 =
            getDiceChain(encodedCsr1, /*isFactory=*/false, /*allowAnyMode=*/true, instanceName1);
    if (!diceChain1) {
        return diceChain1.message();
    }

    auto proper1 = diceChain1->IsProper();
@@ -920,15 +920,10 @@ ErrMsgOr<bool> compareRootPublicKeysInDiceChains(const std::vector<uint8_t>& enc
               hexlify(encodedCsr1);
    }

    auto csr2 = hwtrust::Csr::validate(encodedCsr2, *diceChainKind, false /*isFactory*/,
                                       true /*allowAnyMode*/, deviceSuffix(instanceName2));
    if (!csr2.ok()) {
        return csr2.error().message();
    }

    auto diceChain2 = csr2->getDiceChain();
    if (!diceChain2.ok()) {
        return diceChain2.error().message();
    auto diceChain2 =
            getDiceChain(encodedCsr2, /*isFactory=*/false, /*allowAnyMode=*/true, instanceName2);
    if (!diceChain2) {
        return diceChain2.message();
    }

    auto proper2 = diceChain2->IsProper();
@@ -946,20 +941,10 @@ ErrMsgOr<bool> compareRootPublicKeysInDiceChains(const std::vector<uint8_t>& enc
}

ErrMsgOr<bool> verifyComponentNameInKeyMintDiceChain(const std::vector<uint8_t>& encodedCsr) {
    auto diceChainKind = getDiceChainKind();
    if (!diceChainKind) {
        return diceChainKind.message();
    }

    auto csr = hwtrust::Csr::validate(encodedCsr, *diceChainKind, false /*isFactory*/,
                                      true /*allowAnyMode*/, deviceSuffix(DEFAULT_INSTANCE_NAME));
    if (!csr.ok()) {
        return csr.error().message();
    }

    auto diceChain = csr->getDiceChain();
    if (!diceChain.ok()) {
        return diceChain.error().message();
    auto diceChain = getDiceChain(encodedCsr, /*isFactory=*/false, /*allowAnyMode=*/true,
                                  DEFAULT_INSTANCE_NAME);
    if (!diceChain) {
        return diceChain.message();
    }

    auto satisfied = diceChain->componentNameContains(kKeyMintComponentName);
@@ -972,20 +957,10 @@ ErrMsgOr<bool> verifyComponentNameInKeyMintDiceChain(const std::vector<uint8_t>&

ErrMsgOr<bool> hasNonNormalModeInDiceChain(const std::vector<uint8_t>& encodedCsr,
                                           std::string_view instanceName) {
    auto diceChainKind = getDiceChainKind();
    if (!diceChainKind) {
        return diceChainKind.message();
    }

    auto csr = hwtrust::Csr::validate(encodedCsr, *diceChainKind, false /*isFactory*/,
                                      true /*allowAnyMode*/, deviceSuffix(instanceName));
    if (!csr.ok()) {
        return csr.error().message();
    }

    auto diceChain = csr->getDiceChain();
    if (!diceChain.ok()) {
        return diceChain.error().message();
    auto diceChain =
            getDiceChain(encodedCsr, /*isFactory=*/false, /*allowAnyMode=*/true, instanceName);
    if (!diceChain) {
        return diceChain.message();
    }

    auto hasNonNormalModeInDiceChain = diceChain->hasNonNormalMode();