Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit b1ce80ca authored by Yifan Hong's avatar Yifan Hong
Browse files

binder: Add tests for using pre-signed certificates.

Test: binderRpcTest
Fixes: 199344157
Change-Id: I0f9d8ce3d4fadecd197d87393f689bdeb35dbc56
parent ff73aa94
Loading
Loading
Loading
Loading
+10 −0
Original line number Diff line number Diff line
@@ -70,4 +70,14 @@ status_t RpcAuthSelfSigned::configure(SSL_CTX* ctx) {
    return OK;
}

status_t RpcAuthPreSigned::configure(SSL_CTX* ctx) {
    if (!SSL_CTX_use_PrivateKey(ctx, mPkey.get())) {
        return INVALID_OPERATION;
    }
    if (!SSL_CTX_use_certificate(ctx, mCert.get())) {
        return INVALID_OPERATION;
    }
    return OK;
}

} // namespace android
+11 −0
Original line number Diff line number Diff line
@@ -35,4 +35,15 @@ private:
    const uint32_t mValidSeconds;
};

class RpcAuthPreSigned : public RpcAuth {
public:
    RpcAuthPreSigned(bssl::UniquePtr<EVP_PKEY> pkey, bssl::UniquePtr<X509> cert)
          : mPkey(std::move(pkey)), mCert(std::move(cert)) {}
    status_t configure(SSL_CTX* ctx) override;

private:
    bssl::UniquePtr<EVP_PKEY> mPkey;
    bssl::UniquePtr<X509> mCert;
};

} // namespace android
+128 −63
Original line number Diff line number Diff line
@@ -31,6 +31,7 @@
#include <binder/ProcessState.h>
#include <binder/RpcServer.h>
#include <binder/RpcSession.h>
#include <binder/RpcTlsUtils.h>
#include <binder/RpcTransport.h>
#include <binder/RpcTransportRaw.h>
#include <binder/RpcTransportTls.h>
@@ -1439,37 +1440,10 @@ TEST(BinderRpc, Java) {
INSTANTIATE_TEST_CASE_P(BinderRpc, BinderRpcSimple, ::testing::ValuesIn(RpcSecurityValues()),
                        BinderRpcSimple::PrintTestParam);

class RpcTransportTest
      : public ::testing::TestWithParam<
                std::tuple<SocketType, RpcSecurity, std::optional<RpcCertificateFormat>>> {
class RpcTransportTestUtils {
public:
    using Param = std::tuple<SocketType, RpcSecurity, std::optional<RpcCertificateFormat>>;
    using ConnectToServer = std::function<base::unique_fd()>;
    static inline std::string PrintParamInfo(const testing::TestParamInfo<ParamType>& info) {
        auto [socketType, rpcSecurity, certificateFormat] = info.param;
        auto ret = PrintToString(socketType) + "_" + newFactory(rpcSecurity)->toCString();
        if (certificateFormat.has_value()) ret += "_" + PrintToString(*certificateFormat);
        return ret;
    }
    static std::vector<ParamType> getRpcTranportTestParams() {
        std::vector<RpcTransportTest::ParamType> ret;
        for (auto socketType : testSocketTypes(false /* hasPreconnected */)) {
            for (auto rpcSecurity : RpcSecurityValues()) {
                switch (rpcSecurity) {
                    case RpcSecurity::RAW: {
                        ret.emplace_back(socketType, rpcSecurity, std::nullopt);
                    } break;
                    case RpcSecurity::TLS: {
                        ret.emplace_back(socketType, rpcSecurity, RpcCertificateFormat::PEM);
                        ret.emplace_back(socketType, rpcSecurity, RpcCertificateFormat::DER);
                    } break;
                }
            }
        }
        return ret;
    }
    void TearDown() override {
        for (auto& server : mServers) server->shutdownAndWait();
    }

    // A server that handles client socket connections.
    class Server {
@@ -1477,8 +1451,10 @@ public:
        explicit Server() {}
        Server(Server&&) = default;
        ~Server() { shutdownAndWait(); }
        [[nodiscard]] AssertionResult setUp() {
            auto [socketType, rpcSecurity, certificateFormat] = GetParam();
        [[nodiscard]] AssertionResult setUp(
                const Param& param,
                std::unique_ptr<RpcAuth> auth = std::make_unique<RpcAuthSelfSigned>()) {
            auto [socketType, rpcSecurity, certificateFormat] = param;
            auto rpcServer = RpcServer::make(newFactory(rpcSecurity));
            rpcServer->iUnderstandThisCodeIsExperimentalAndIWillNotUseItInProduction();
            switch (socketType) {
@@ -1529,7 +1505,7 @@ public:
            }
            mFd = rpcServer->releaseServer();
            if (!mFd.ok()) return AssertionFailure() << "releaseServer returns invalid fd";
            mCtx = newFactory(rpcSecurity, mCertVerifier)->newServerCtx();
            mCtx = newFactory(rpcSecurity, mCertVerifier, std::move(auth))->newServerCtx();
            if (mCtx == nullptr) return AssertionFailure() << "newServerCtx";
            mSetup = true;
            return AssertionSuccess();
@@ -1608,8 +1584,8 @@ public:
    public:
        explicit Client(ConnectToServer connectToServer) : mConnectToServer(connectToServer) {}
        Client(Client&&) = default;
        [[nodiscard]] AssertionResult setUp() {
            auto [socketType, rpcSecurity, certificateFormat] = GetParam();
        [[nodiscard]] AssertionResult setUp(const Param& param) {
            auto [socketType, rpcSecurity, certificateFormat] = param;
            mFdTrigger = FdTrigger::make();
            mCtx = newFactory(rpcSecurity, mCertVerifier)->newClientCtx();
            if (mCtx == nullptr) return AssertionFailure() << "newClientCtx";
@@ -1662,8 +1638,9 @@ public:

    // Make A trust B.
    template <typename A, typename B>
    status_t trust(A* a, B* b) {
        auto [socketType, rpcSecurity, certificateFormat] = GetParam();
    static status_t trust(RpcSecurity rpcSecurity,
                          std::optional<RpcCertificateFormat> certificateFormat, const A& a,
                          const B& b) {
        if (rpcSecurity != RpcSecurity::TLS) return OK;
        LOG_ALWAYS_FATAL_IF(!certificateFormat.has_value());
        auto bCert = b->getCtx()->getCertificate(*certificateFormat);
@@ -1671,15 +1648,48 @@ public:
    }

    static constexpr const char* kMessage = "hello";
    std::vector<std::unique_ptr<Server>> mServers;
};

class RpcTransportTest : public testing::TestWithParam<RpcTransportTestUtils::Param> {
public:
    using Server = RpcTransportTestUtils::Server;
    using Client = RpcTransportTestUtils::Client;
    static inline std::string PrintParamInfo(const testing::TestParamInfo<ParamType>& info) {
        auto [socketType, rpcSecurity, certificateFormat] = info.param;
        auto ret = PrintToString(socketType) + "_" + newFactory(rpcSecurity)->toCString();
        if (certificateFormat.has_value()) ret += "_" + PrintToString(*certificateFormat);
        return ret;
    }
    static std::vector<ParamType> getRpcTranportTestParams() {
        std::vector<ParamType> ret;
        for (auto socketType : testSocketTypes(false /* hasPreconnected */)) {
            for (auto rpcSecurity : RpcSecurityValues()) {
                switch (rpcSecurity) {
                    case RpcSecurity::RAW: {
                        ret.emplace_back(socketType, rpcSecurity, std::nullopt);
                    } break;
                    case RpcSecurity::TLS: {
                        ret.emplace_back(socketType, rpcSecurity, RpcCertificateFormat::PEM);
                        ret.emplace_back(socketType, rpcSecurity, RpcCertificateFormat::DER);
                    } break;
                }
            }
        }
        return ret;
    }
    template <typename A, typename B>
    status_t trust(const A& a, const B& b) {
        auto [socketType, rpcSecurity, certificateFormat] = GetParam();
        return RpcTransportTestUtils::trust(rpcSecurity, certificateFormat, a, b);
    }
};

TEST_P(RpcTransportTest, GoodCertificate) {
    auto server = mServers.emplace_back(std::make_unique<Server>()).get();
    ASSERT_TRUE(server->setUp());
    auto server = std::make_unique<Server>();
    ASSERT_TRUE(server->setUp(GetParam()));

    Client client(server->getConnectToServerFn());
    ASSERT_TRUE(client.setUp());
    ASSERT_TRUE(client.setUp(GetParam()));

    ASSERT_EQ(OK, trust(&client, server));
    ASSERT_EQ(OK, trust(server, &client));
@@ -1689,13 +1699,13 @@ TEST_P(RpcTransportTest, GoodCertificate) {
}

TEST_P(RpcTransportTest, MultipleClients) {
    auto server = mServers.emplace_back(std::make_unique<Server>()).get();
    ASSERT_TRUE(server->setUp());
    auto server = std::make_unique<Server>();
    ASSERT_TRUE(server->setUp(GetParam()));

    std::vector<Client> clients;
    for (int i = 0; i < 2; i++) {
        auto& client = clients.emplace_back(server->getConnectToServerFn());
        ASSERT_TRUE(client.setUp());
        ASSERT_TRUE(client.setUp(GetParam()));
        ASSERT_EQ(OK, trust(&client, server));
        ASSERT_EQ(OK, trust(server, &client));
    }
@@ -1707,11 +1717,11 @@ TEST_P(RpcTransportTest, MultipleClients) {
TEST_P(RpcTransportTest, UntrustedServer) {
    auto [socketType, rpcSecurity, certificateFormat] = GetParam();

    auto untrustedServer = mServers.emplace_back(std::make_unique<Server>()).get();
    ASSERT_TRUE(untrustedServer->setUp());
    auto untrustedServer = std::make_unique<Server>();
    ASSERT_TRUE(untrustedServer->setUp(GetParam()));

    Client client(untrustedServer->getConnectToServerFn());
    ASSERT_TRUE(client.setUp());
    ASSERT_TRUE(client.setUp(GetParam()));

    ASSERT_EQ(OK, trust(untrustedServer, &client));

@@ -1724,14 +1734,14 @@ TEST_P(RpcTransportTest, UntrustedServer) {
}
TEST_P(RpcTransportTest, MaliciousServer) {
    auto [socketType, rpcSecurity, certificateFormat] = GetParam();
    auto validServer = mServers.emplace_back(std::make_unique<Server>()).get();
    ASSERT_TRUE(validServer->setUp());
    auto validServer = std::make_unique<Server>();
    ASSERT_TRUE(validServer->setUp(GetParam()));

    auto maliciousServer = mServers.emplace_back(std::make_unique<Server>()).get();
    ASSERT_TRUE(maliciousServer->setUp());
    auto maliciousServer = std::make_unique<Server>();
    ASSERT_TRUE(maliciousServer->setUp(GetParam()));

    Client client(maliciousServer->getConnectToServerFn());
    ASSERT_TRUE(client.setUp());
    ASSERT_TRUE(client.setUp(GetParam()));

    ASSERT_EQ(OK, trust(&client, validServer));
    ASSERT_EQ(OK, trust(validServer, &client));
@@ -1747,11 +1757,11 @@ TEST_P(RpcTransportTest, MaliciousServer) {

TEST_P(RpcTransportTest, UntrustedClient) {
    auto [socketType, rpcSecurity, certificateFormat] = GetParam();
    auto server = mServers.emplace_back(std::make_unique<Server>()).get();
    ASSERT_TRUE(server->setUp());
    auto server = std::make_unique<Server>();
    ASSERT_TRUE(server->setUp(GetParam()));

    Client client(server->getConnectToServerFn());
    ASSERT_TRUE(client.setUp());
    ASSERT_TRUE(client.setUp(GetParam()));

    ASSERT_EQ(OK, trust(&client, server));

@@ -1766,13 +1776,13 @@ TEST_P(RpcTransportTest, UntrustedClient) {

TEST_P(RpcTransportTest, MaliciousClient) {
    auto [socketType, rpcSecurity, certificateFormat] = GetParam();
    auto server = mServers.emplace_back(std::make_unique<Server>()).get();
    ASSERT_TRUE(server->setUp());
    auto server = std::make_unique<Server>();
    ASSERT_TRUE(server->setUp(GetParam()));

    Client validClient(server->getConnectToServerFn());
    ASSERT_TRUE(validClient.setUp());
    ASSERT_TRUE(validClient.setUp(GetParam()));
    Client maliciousClient(server->getConnectToServerFn());
    ASSERT_TRUE(maliciousClient.setUp());
    ASSERT_TRUE(maliciousClient.setUp(GetParam()));

    ASSERT_EQ(OK, trust(&validClient, server));
    ASSERT_EQ(OK, trust(&maliciousClient, server));
@@ -1790,7 +1800,7 @@ TEST_P(RpcTransportTest, Trigger) {
    std::condition_variable writeCv;
    bool shouldContinueWriting = false;
    auto serverPostConnect = [&](RpcTransport* serverTransport, FdTrigger* fdTrigger) {
        std::string message(kMessage);
        std::string message(RpcTransportTestUtils::kMessage);
        auto status =
                serverTransport->interruptableWriteFully(fdTrigger, message.data(), message.size());
        if (status != OK) return AssertionFailure() << statusToString(status);
@@ -1810,12 +1820,12 @@ TEST_P(RpcTransportTest, Trigger) {
        return AssertionSuccess();
    };

    auto server = mServers.emplace_back(std::make_unique<Server>()).get();
    ASSERT_TRUE(server->setUp());
    auto server = std::make_unique<Server>();
    ASSERT_TRUE(server->setUp(GetParam()));

    // Set up client
    Client client(server->getConnectToServerFn());
    ASSERT_TRUE(client.setUp());
    ASSERT_TRUE(client.setUp(GetParam()));

    // Exchange keys
    ASSERT_EQ(OK, trust(&client, server));
@@ -1828,7 +1838,7 @@ TEST_P(RpcTransportTest, Trigger) {
    ASSERT_TRUE(client.setUpTransport());
    // read the first message. This ensures that server has finished handshake and start handling
    // client fd. Server thread should pause at writeCv.wait_for().
    ASSERT_TRUE(client.readMessage(kMessage));
    ASSERT_TRUE(client.readMessage(RpcTransportTestUtils::kMessage));
    // Trigger server shutdown after server starts handling client FD. This ensures that the second
    // write is on an FdTrigger that has been shut down.
    server->shutdown();
@@ -1848,6 +1858,61 @@ INSTANTIATE_TEST_CASE_P(BinderRpc, RpcTransportTest,
                        ::testing::ValuesIn(RpcTransportTest::getRpcTranportTestParams()),
                        RpcTransportTest::PrintParamInfo);

class RpcTransportTlsKeyTest
      : public testing::TestWithParam<std::tuple<SocketType, RpcCertificateFormat, RpcKeyFormat>> {
public:
    template <typename A, typename B>
    status_t trust(const A& a, const B& b) {
        auto [socketType, certificateFormat, keyFormat] = GetParam();
        return RpcTransportTestUtils::trust(RpcSecurity::TLS, certificateFormat, a, b);
    }
    static std::string PrintParamInfo(const testing::TestParamInfo<ParamType>& info) {
        auto [socketType, certificateFormat, keyFormat] = info.param;
        auto ret = PrintToString(socketType) + "_certificate_" + PrintToString(certificateFormat) +
                "_key_" + PrintToString(keyFormat);
        return ret;
    };
};

TEST_P(RpcTransportTlsKeyTest, PreSignedCertificate) {
    auto [socketType, certificateFormat, keyFormat] = GetParam();

    std::vector<uint8_t> pkeyData, certData;
    {
        auto pkey = makeKeyPairForSelfSignedCert();
        ASSERT_NE(nullptr, pkey);
        auto cert = makeSelfSignedCert(pkey.get(), kCertValidSeconds);
        ASSERT_NE(nullptr, cert);
        pkeyData = serializeUnencryptedPrivatekey(pkey.get(), keyFormat);
        certData = serializeCertificate(cert.get(), certificateFormat);
    }

    auto desPkey = deserializeUnencryptedPrivatekey(pkeyData, keyFormat);
    auto desCert = deserializeCertificate(certData, certificateFormat);
    auto auth = std::make_unique<RpcAuthPreSigned>(std::move(desPkey), std::move(desCert));
    auto utilsParam =
            std::make_tuple(socketType, RpcSecurity::TLS, std::make_optional(certificateFormat));

    auto server = std::make_unique<RpcTransportTestUtils::Server>();
    ASSERT_TRUE(server->setUp(utilsParam, std::move(auth)));

    RpcTransportTestUtils::Client client(server->getConnectToServerFn());
    ASSERT_TRUE(client.setUp(utilsParam));

    ASSERT_EQ(OK, trust(&client, server));
    ASSERT_EQ(OK, trust(server, &client));

    server->start();
    client.run();
}

INSTANTIATE_TEST_CASE_P(
        BinderRpc, RpcTransportTlsKeyTest,
        testing::Combine(testing::ValuesIn(testSocketTypes(false /* hasPreconnected*/)),
                         testing::Values(RpcCertificateFormat::PEM, RpcCertificateFormat::DER),
                         testing::Values(RpcKeyFormat::PEM, RpcKeyFormat::DER)),
        RpcTransportTlsKeyTest::PrintParamInfo);

} // namespace android

int main(int argc, char** argv) {