Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 0fbf2976 authored by Yifan Hong's avatar Yifan Hong Committed by Automerger Merge Worker
Browse files

Merge changes from topic "binder-tls-trigger" am: 53195f9d am: dd326395...

Merge changes from topic "binder-tls-trigger" am: 53195f9d am: dd326395 am: 1fbb5b6f am: 5fccda13

Original change: https://android-review.googlesource.com/c/platform/frameworks/native/+/1825993

Change-Id: I57a6b2267c01f34ce75fc40b9c840dbe2f1b5511
parents d355e392 5fccda13
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -347,7 +347,7 @@ status_t RpcTransportTls::isTriggered(FdTrigger* fdTrigger) {
        ALOGE("%s: %s", __PRETTY_FUNCTION__, ret.error().message().c_str());
        return ret.error().code() == 0 ? UNKNOWN_ERROR : -ret.error().code();
    }
    return OK;
    return *ret ? -ECANCELED : OK;
}

status_t RpcTransportTls::interruptableWriteFully(FdTrigger* fdTrigger, const void* data,
+120 −27
Original line number Diff line number Diff line
@@ -53,6 +53,7 @@
#include "RpcCertificateVerifierSimple.h"

using namespace std::chrono_literals;
using namespace std::placeholders;
using testing::AssertionFailure;
using testing::AssertionResult;
using testing::AssertionSuccess;
@@ -1444,7 +1445,7 @@ public:
                PrintToString(certificateFormat);
    }
    void TearDown() override {
        for (auto& server : mServers) server->shutdown();
        for (auto& server : mServers) server->shutdownAndWait();
    }

    // A server that handles client socket connections.
@@ -1452,7 +1453,7 @@ public:
    public:
        explicit Server() {}
        Server(Server&&) = default;
        ~Server() { shutdown(); }
        ~Server() { shutdownAndWait(); }
        [[nodiscard]] AssertionResult setUp() {
            auto [socketType, rpcSecurity, certificateFormat] = GetParam();
            auto rpcServer = RpcServer::make(newFactory(rpcSecurity));
@@ -1536,17 +1537,17 @@ public:
            ASSERT_TRUE(acceptedFd.ok());
            auto serverTransport = mCtx->newTransport(std::move(acceptedFd), mFdTrigger.get());
            if (serverTransport == nullptr) return; // handshake failed
            std::string message(kMessage);
            ASSERT_EQ(OK,
                      serverTransport->interruptableWriteFully(mFdTrigger.get(), message.data(),
                                                               message.size()));
            ASSERT_TRUE(mPostConnect(serverTransport.get(), mFdTrigger.get()));
        }
        void shutdown() {
            mFdTrigger->trigger();
            if (mThread != nullptr) {
                mThread->join();
                mThread = nullptr;
        void shutdownAndWait() {
            shutdown();
            join();
        }
        void shutdown() { mFdTrigger->trigger(); }

        void setPostConnect(
                std::function<AssertionResult(RpcTransport*, FdTrigger* fdTrigger)> fn) {
            mPostConnect = std::move(fn);
        }

    private:
@@ -1558,6 +1559,26 @@ public:
        std::shared_ptr<RpcCertificateVerifierSimple> mCertVerifier =
                std::make_shared<RpcCertificateVerifierSimple>();
        bool mSetup = false;
        // The function invoked after connection and handshake. By default, it is
        // |defaultPostConnect| that sends |kMessage| to the client.
        std::function<AssertionResult(RpcTransport*, FdTrigger* fdTrigger)> mPostConnect =
                Server::defaultPostConnect;

        void join() {
            if (mThread != nullptr) {
                mThread->join();
                mThread = nullptr;
            }
        }

        static AssertionResult defaultPostConnect(RpcTransport* serverTransport,
                                                  FdTrigger* fdTrigger) {
            std::string message(kMessage);
            auto status = serverTransport->interruptableWriteFully(fdTrigger, message.data(),
                                                                   message.size());
            if (status != OK) return AssertionFailure() << statusToString(status);
            return AssertionSuccess();
        }
    };

    class Client {
@@ -1566,8 +1587,6 @@ public:
        Client(Client&&) = default;
        [[nodiscard]] AssertionResult setUp() {
            auto [socketType, rpcSecurity, certificateFormat] = GetParam();
            mFd = mConnectToServer();
            if (!mFd.ok()) return AssertionFailure() << "Cannot connect to server";
            mFdTrigger = FdTrigger::make();
            mCtx = newFactory(rpcSecurity, mCertVerifier)->newClientCtx();
            if (mCtx == nullptr) return AssertionFailure() << "newClientCtx";
@@ -1577,24 +1596,35 @@ public:
        std::shared_ptr<RpcCertificateVerifierSimple> getCertVerifier() const {
            return mCertVerifier;
        }
        void run(bool handshakeOk = true, bool readOk = true) {
            auto clientTransport = mCtx->newTransport(std::move(mFd), mFdTrigger.get());
            if (clientTransport == nullptr) {
                ASSERT_FALSE(handshakeOk) << "newTransport returns nullptr, but it shouldn't";
                return;
        // connect() and do handshake
        bool setUpTransport() {
            mFd = mConnectToServer();
            if (!mFd.ok()) return AssertionFailure() << "Cannot connect to server";
            mClientTransport = mCtx->newTransport(std::move(mFd), mFdTrigger.get());
            return mClientTransport != nullptr;
        }
            ASSERT_TRUE(handshakeOk) << "newTransport does not return nullptr, but it should";
            std::string expectedMessage(kMessage);
        AssertionResult readMessage(const std::string& expectedMessage = kMessage) {
            LOG_ALWAYS_FATAL_IF(mClientTransport == nullptr, "setUpTransport not called or failed");
            std::string readMessage(expectedMessage.size(), '\0');
            status_t readStatus =
                    clientTransport->interruptableReadFully(mFdTrigger.get(), readMessage.data(),
                    mClientTransport->interruptableReadFully(mFdTrigger.get(), readMessage.data(),
                                                             readMessage.size());
            if (readOk) {
                ASSERT_EQ(OK, readStatus);
                ASSERT_EQ(readMessage, expectedMessage);
            } else {
                ASSERT_NE(OK, readStatus);
            if (readStatus != OK) {
                return AssertionFailure() << statusToString(readStatus);
            }
            if (readMessage != expectedMessage) {
                return AssertionFailure()
                        << "Expected " << expectedMessage << ", actual " << readMessage;
            }
            return AssertionSuccess();
        }
        void run(bool handshakeOk = true, bool readOk = true) {
            if (!setUpTransport()) {
                ASSERT_FALSE(handshakeOk) << "newTransport returns nullptr, but it shouldn't";
                return;
            }
            ASSERT_TRUE(handshakeOk) << "newTransport does not return nullptr, but it should";
            ASSERT_EQ(readOk, readMessage());
        }

    private:
@@ -1604,6 +1634,7 @@ public:
        std::unique_ptr<RpcTransportCtx> mCtx;
        std::shared_ptr<RpcCertificateVerifierSimple> mCertVerifier =
                std::make_shared<RpcCertificateVerifierSimple>();
        std::unique_ptr<RpcTransport> mClientTransport;
    };

    // Make A trust B.
@@ -1729,6 +1760,68 @@ TEST_P(RpcTransportTest, MaliciousClient) {
    maliciousClient.run(true, readOk);
}

TEST_P(RpcTransportTest, Trigger) {
    std::string msg2 = ", world!";
    std::mutex writeMutex;
    std::condition_variable writeCv;
    bool shouldContinueWriting = false;
    auto serverPostConnect = [&](RpcTransport* serverTransport, FdTrigger* fdTrigger) {
        std::string message(kMessage);
        auto status =
                serverTransport->interruptableWriteFully(fdTrigger, message.data(), message.size());
        if (status != OK) return AssertionFailure() << statusToString(status);

        {
            std::unique_lock<std::mutex> lock(writeMutex);
            if (!writeCv.wait_for(lock, 3s, [&] { return shouldContinueWriting; })) {
                return AssertionFailure() << "write barrier not cleared in time!";
            }
        }

        status = serverTransport->interruptableWriteFully(fdTrigger, msg2.data(), msg2.size());
        if (status != -ECANCELED)
            return AssertionFailure() << "When FdTrigger is shut down, interruptableWriteFully "
                                         "should return -ECANCELLED, but it is "
                                      << statusToString(status);
        return AssertionSuccess();
    };

    auto server = mServers.emplace_back(std::make_unique<Server>()).get();
    ASSERT_TRUE(server->setUp());

    // Set up client
    Client client(server->getConnectToServerFn());
    ASSERT_TRUE(client.setUp());

    // Exchange keys
    ASSERT_EQ(OK, trust(&client, server));
    ASSERT_EQ(OK, trust(server, &client));

    server->setPostConnect(serverPostConnect);

    // Start server
    server->start();
    // connect() to server and do handshake
    ASSERT_TRUE(client.setUpTransport());
    // read the first message. This confirms that server has finished handshake and start handling
    // client fd. Server thread should pause at waitForWriteBarrier.
    ASSERT_TRUE(client.readMessage(kMessage));
    // Trigger server shutdown after server starts handling client FD. This ensures that the second
    // write is on an FdTrigger that has been shut down.
    server->shutdown();
    // Continues server thread to write the second message.
    {
        std::unique_lock<std::mutex> lock(writeMutex);
        shouldContinueWriting = true;
        lock.unlock();
        writeCv.notify_all();
    }
    // After this line, server thread unblocks and attempts to write the second message, but
    // shutdown is triggered, so write should failed with -ECANCELLED. See |serverPostConnect|.
    // On the client side, second read fails with DEAD_OBJECT
    ASSERT_FALSE(client.readMessage(msg2));
}

std::vector<RpcCertificateFormat> testRpcCertificateFormats() {
    return {
            RpcCertificateFormat::PEM,