Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 67519325 authored by Yifan Hong's avatar Yifan Hong
Browse files

binder: TLS checks trigger properly.

Previously, libbinder_tls ignores the result of
isTriggeredPolled() by always returning OK regardless
of whether the shutdown trigger is triggered or not, causing
program to continue when it shouldn't be. Return the status
properly like FdTrigger::triggerablePoll:

- If any error during poll() return the code
- If shutdown, return -ECANCELED (new in this CL for TLS)
- Otherwise return OK

Refactor RpcTransportTest so that we can add a new test
to check that trigerablePoll() returns -ECANCELED in the
above case.

Test: binderRpcTest
Fixes: 199309623

Change-Id: Ia545ba71cc10be5c46f722a5d3e699f89e1bc70c
parent e07d273c
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -347,7 +347,7 @@ status_t RpcTransportTls::isTriggered(FdTrigger* fdTrigger) {
        ALOGE("%s: %s", __PRETTY_FUNCTION__, ret.error().message().c_str());
        return ret.error().code() == 0 ? UNKNOWN_ERROR : -ret.error().code();
    }
    return OK;
    return *ret ? -ECANCELED : OK;
}

status_t RpcTransportTls::interruptableWriteFully(FdTrigger* fdTrigger, const void* data,
+117 −24
Original line number Diff line number Diff line
@@ -53,6 +53,7 @@
#include "RpcCertificateVerifierSimple.h"

using namespace std::chrono_literals;
using namespace std::placeholders;
using testing::AssertionFailure;
using testing::AssertionResult;
using testing::AssertionSuccess;
@@ -1536,17 +1537,17 @@ public:
            ASSERT_TRUE(acceptedFd.ok());
            auto serverTransport = mCtx->newTransport(std::move(acceptedFd), mFdTrigger.get());
            if (serverTransport == nullptr) return; // handshake failed
            std::string message(kMessage);
            ASSERT_EQ(OK,
                      serverTransport->interruptableWriteFully(mFdTrigger.get(), message.data(),
                                                               message.size()));
            ASSERT_TRUE(mPostConnect(serverTransport.get(), mFdTrigger.get()));
        }
        void shutdownAndWait() {
            mFdTrigger->trigger();
            if (mThread != nullptr) {
                mThread->join();
                mThread = nullptr;
            shutdown();
            join();
        }
        void shutdown() { mFdTrigger->trigger(); }

        void setPostConnect(
                std::function<AssertionResult(RpcTransport*, FdTrigger* fdTrigger)> fn) {
            mPostConnect = std::move(fn);
        }

    private:
@@ -1558,6 +1559,26 @@ public:
        std::shared_ptr<RpcCertificateVerifierSimple> mCertVerifier =
                std::make_shared<RpcCertificateVerifierSimple>();
        bool mSetup = false;
        // The function invoked after connection and handshake. By default, it is
        // |defaultPostConnect| that sends |kMessage| to the client.
        std::function<AssertionResult(RpcTransport*, FdTrigger* fdTrigger)> mPostConnect =
                Server::defaultPostConnect;

        void join() {
            if (mThread != nullptr) {
                mThread->join();
                mThread = nullptr;
            }
        }

        static AssertionResult defaultPostConnect(RpcTransport* serverTransport,
                                                  FdTrigger* fdTrigger) {
            std::string message(kMessage);
            auto status = serverTransport->interruptableWriteFully(fdTrigger, message.data(),
                                                                   message.size());
            if (status != OK) return AssertionFailure() << statusToString(status);
            return AssertionSuccess();
        }
    };

    class Client {
@@ -1566,8 +1587,6 @@ public:
        Client(Client&&) = default;
        [[nodiscard]] AssertionResult setUp() {
            auto [socketType, rpcSecurity, certificateFormat] = GetParam();
            mFd = mConnectToServer();
            if (!mFd.ok()) return AssertionFailure() << "Cannot connect to server";
            mFdTrigger = FdTrigger::make();
            mCtx = newFactory(rpcSecurity, mCertVerifier)->newClientCtx();
            if (mCtx == nullptr) return AssertionFailure() << "newClientCtx";
@@ -1577,24 +1596,35 @@ public:
        std::shared_ptr<RpcCertificateVerifierSimple> getCertVerifier() const {
            return mCertVerifier;
        }
        void run(bool handshakeOk = true, bool readOk = true) {
            auto clientTransport = mCtx->newTransport(std::move(mFd), mFdTrigger.get());
            if (clientTransport == nullptr) {
                ASSERT_FALSE(handshakeOk) << "newTransport returns nullptr, but it shouldn't";
                return;
        // connect() and do handshake
        bool setUpTransport() {
            mFd = mConnectToServer();
            if (!mFd.ok()) return AssertionFailure() << "Cannot connect to server";
            mClientTransport = mCtx->newTransport(std::move(mFd), mFdTrigger.get());
            return mClientTransport != nullptr;
        }
            ASSERT_TRUE(handshakeOk) << "newTransport does not return nullptr, but it should";
            std::string expectedMessage(kMessage);
        AssertionResult readMessage(const std::string& expectedMessage = kMessage) {
            LOG_ALWAYS_FATAL_IF(mClientTransport == nullptr, "setUpTransport not called or failed");
            std::string readMessage(expectedMessage.size(), '\0');
            status_t readStatus =
                    clientTransport->interruptableReadFully(mFdTrigger.get(), readMessage.data(),
                    mClientTransport->interruptableReadFully(mFdTrigger.get(), readMessage.data(),
                                                             readMessage.size());
            if (readOk) {
                ASSERT_EQ(OK, readStatus);
                ASSERT_EQ(readMessage, expectedMessage);
            } else {
                ASSERT_NE(OK, readStatus);
            if (readStatus != OK) {
                return AssertionFailure() << statusToString(readStatus);
            }
            if (readMessage != expectedMessage) {
                return AssertionFailure()
                        << "Expected " << expectedMessage << ", actual " << readMessage;
            }
            return AssertionSuccess();
        }
        void run(bool handshakeOk = true, bool readOk = true) {
            if (!setUpTransport()) {
                ASSERT_FALSE(handshakeOk) << "newTransport returns nullptr, but it shouldn't";
                return;
            }
            ASSERT_TRUE(handshakeOk) << "newTransport does not return nullptr, but it should";
            ASSERT_EQ(readOk, readMessage());
        }

    private:
@@ -1604,6 +1634,7 @@ public:
        std::unique_ptr<RpcTransportCtx> mCtx;
        std::shared_ptr<RpcCertificateVerifierSimple> mCertVerifier =
                std::make_shared<RpcCertificateVerifierSimple>();
        std::unique_ptr<RpcTransport> mClientTransport;
    };

    // Make A trust B.
@@ -1729,6 +1760,68 @@ TEST_P(RpcTransportTest, MaliciousClient) {
    maliciousClient.run(true, readOk);
}

TEST_P(RpcTransportTest, Trigger) {
    std::string msg2 = ", world!";
    std::mutex writeMutex;
    std::condition_variable writeCv;
    bool shouldContinueWriting = false;
    auto serverPostConnect = [&](RpcTransport* serverTransport, FdTrigger* fdTrigger) {
        std::string message(kMessage);
        auto status =
                serverTransport->interruptableWriteFully(fdTrigger, message.data(), message.size());
        if (status != OK) return AssertionFailure() << statusToString(status);

        {
            std::unique_lock<std::mutex> lock(writeMutex);
            if (!writeCv.wait_for(lock, 3s, [&] { return shouldContinueWriting; })) {
                return AssertionFailure() << "write barrier not cleared in time!";
            }
        }

        status = serverTransport->interruptableWriteFully(fdTrigger, msg2.data(), msg2.size());
        if (status != -ECANCELED)
            return AssertionFailure() << "When FdTrigger is shut down, interruptableWriteFully "
                                         "should return -ECANCELLED, but it is "
                                      << statusToString(status);
        return AssertionSuccess();
    };

    auto server = mServers.emplace_back(std::make_unique<Server>()).get();
    ASSERT_TRUE(server->setUp());

    // Set up client
    Client client(server->getConnectToServerFn());
    ASSERT_TRUE(client.setUp());

    // Exchange keys
    ASSERT_EQ(OK, trust(&client, server));
    ASSERT_EQ(OK, trust(server, &client));

    server->setPostConnect(serverPostConnect);

    // Start server
    server->start();
    // connect() to server and do handshake
    ASSERT_TRUE(client.setUpTransport());
    // read the first message. This confirms that server has finished handshake and start handling
    // client fd. Server thread should pause at waitForWriteBarrier.
    ASSERT_TRUE(client.readMessage(kMessage));
    // Trigger server shutdown after server starts handling client FD. This ensures that the second
    // write is on an FdTrigger that has been shut down.
    server->shutdown();
    // Continues server thread to write the second message.
    {
        std::unique_lock<std::mutex> lock(writeMutex);
        shouldContinueWriting = true;
        lock.unlock();
        writeCv.notify_all();
    }
    // After this line, server thread unblocks and attempts to write the second message, but
    // shutdown is triggered, so write should failed with -ECANCELLED. See |serverPostConnect|.
    // On the client side, second read fails with DEAD_OBJECT
    ASSERT_FALSE(client.readMessage(msg2));
}

std::vector<RpcCertificateFormat> testRpcCertificateFormats() {
    return {
            RpcCertificateFormat::PEM,