Loading libs/binder/RpcTransportTls.cpp +1 −1 Original line number Diff line number Diff line Loading @@ -347,7 +347,7 @@ status_t RpcTransportTls::isTriggered(FdTrigger* fdTrigger) { ALOGE("%s: %s", __PRETTY_FUNCTION__, ret.error().message().c_str()); return ret.error().code() == 0 ? UNKNOWN_ERROR : -ret.error().code(); } return OK; return *ret ? -ECANCELED : OK; } status_t RpcTransportTls::interruptableWriteFully(FdTrigger* fdTrigger, const void* data, Loading libs/binder/tests/binderRpcTest.cpp +120 −27 Original line number Diff line number Diff line Loading @@ -53,6 +53,7 @@ #include "RpcCertificateVerifierSimple.h" using namespace std::chrono_literals; using namespace std::placeholders; using testing::AssertionFailure; using testing::AssertionResult; using testing::AssertionSuccess; Loading Loading @@ -1444,7 +1445,7 @@ public: PrintToString(certificateFormat); } void TearDown() override { for (auto& server : mServers) server->shutdown(); for (auto& server : mServers) server->shutdownAndWait(); } // A server that handles client socket connections. Loading @@ -1452,7 +1453,7 @@ public: public: explicit Server() {} Server(Server&&) = default; ~Server() { shutdown(); } ~Server() { shutdownAndWait(); } [[nodiscard]] AssertionResult setUp() { auto [socketType, rpcSecurity, certificateFormat] = GetParam(); auto rpcServer = RpcServer::make(newFactory(rpcSecurity)); Loading Loading @@ -1536,17 +1537,17 @@ public: ASSERT_TRUE(acceptedFd.ok()); auto serverTransport = mCtx->newTransport(std::move(acceptedFd), mFdTrigger.get()); if (serverTransport == nullptr) return; // handshake failed std::string message(kMessage); ASSERT_EQ(OK, serverTransport->interruptableWriteFully(mFdTrigger.get(), message.data(), message.size())); ASSERT_TRUE(mPostConnect(serverTransport.get(), mFdTrigger.get())); } void shutdown() { mFdTrigger->trigger(); if (mThread != nullptr) { mThread->join(); mThread = nullptr; void shutdownAndWait() { shutdown(); join(); } void shutdown() { mFdTrigger->trigger(); } void setPostConnect( std::function<AssertionResult(RpcTransport*, FdTrigger* fdTrigger)> fn) { mPostConnect = std::move(fn); } private: Loading @@ -1558,6 +1559,26 @@ public: std::shared_ptr<RpcCertificateVerifierSimple> mCertVerifier = std::make_shared<RpcCertificateVerifierSimple>(); bool mSetup = false; // The function invoked after connection and handshake. By default, it is // |defaultPostConnect| that sends |kMessage| to the client. std::function<AssertionResult(RpcTransport*, FdTrigger* fdTrigger)> mPostConnect = Server::defaultPostConnect; void join() { if (mThread != nullptr) { mThread->join(); mThread = nullptr; } } static AssertionResult defaultPostConnect(RpcTransport* serverTransport, FdTrigger* fdTrigger) { std::string message(kMessage); auto status = serverTransport->interruptableWriteFully(fdTrigger, message.data(), message.size()); if (status != OK) return AssertionFailure() << statusToString(status); return AssertionSuccess(); } }; class Client { Loading @@ -1566,8 +1587,6 @@ public: Client(Client&&) = default; [[nodiscard]] AssertionResult setUp() { auto [socketType, rpcSecurity, certificateFormat] = GetParam(); mFd = mConnectToServer(); if (!mFd.ok()) return AssertionFailure() << "Cannot connect to server"; mFdTrigger = FdTrigger::make(); mCtx = newFactory(rpcSecurity, mCertVerifier)->newClientCtx(); if (mCtx == nullptr) return AssertionFailure() << "newClientCtx"; Loading @@ -1577,24 +1596,35 @@ public: std::shared_ptr<RpcCertificateVerifierSimple> getCertVerifier() const { return mCertVerifier; } void run(bool handshakeOk = true, bool readOk = true) { auto clientTransport = mCtx->newTransport(std::move(mFd), mFdTrigger.get()); if (clientTransport == nullptr) { ASSERT_FALSE(handshakeOk) << "newTransport returns nullptr, but it shouldn't"; return; // connect() and do handshake bool setUpTransport() { mFd = mConnectToServer(); if (!mFd.ok()) return AssertionFailure() << "Cannot connect to server"; mClientTransport = mCtx->newTransport(std::move(mFd), mFdTrigger.get()); return mClientTransport != nullptr; } ASSERT_TRUE(handshakeOk) << "newTransport does not return nullptr, but it should"; std::string expectedMessage(kMessage); AssertionResult readMessage(const std::string& expectedMessage = kMessage) { LOG_ALWAYS_FATAL_IF(mClientTransport == nullptr, "setUpTransport not called or failed"); std::string readMessage(expectedMessage.size(), '\0'); status_t readStatus = clientTransport->interruptableReadFully(mFdTrigger.get(), readMessage.data(), mClientTransport->interruptableReadFully(mFdTrigger.get(), readMessage.data(), readMessage.size()); if (readOk) { ASSERT_EQ(OK, readStatus); ASSERT_EQ(readMessage, expectedMessage); } else { ASSERT_NE(OK, readStatus); if (readStatus != OK) { return AssertionFailure() << statusToString(readStatus); } if (readMessage != expectedMessage) { return AssertionFailure() << "Expected " << expectedMessage << ", actual " << readMessage; } return AssertionSuccess(); } void run(bool handshakeOk = true, bool readOk = true) { if (!setUpTransport()) { ASSERT_FALSE(handshakeOk) << "newTransport returns nullptr, but it shouldn't"; return; } ASSERT_TRUE(handshakeOk) << "newTransport does not return nullptr, but it should"; ASSERT_EQ(readOk, readMessage()); } private: Loading @@ -1604,6 +1634,7 @@ public: std::unique_ptr<RpcTransportCtx> mCtx; std::shared_ptr<RpcCertificateVerifierSimple> mCertVerifier = std::make_shared<RpcCertificateVerifierSimple>(); std::unique_ptr<RpcTransport> mClientTransport; }; // Make A trust B. Loading Loading @@ -1729,6 +1760,68 @@ TEST_P(RpcTransportTest, MaliciousClient) { maliciousClient.run(true, readOk); } TEST_P(RpcTransportTest, Trigger) { std::string msg2 = ", world!"; std::mutex writeMutex; std::condition_variable writeCv; bool shouldContinueWriting = false; auto serverPostConnect = [&](RpcTransport* serverTransport, FdTrigger* fdTrigger) { std::string message(kMessage); auto status = serverTransport->interruptableWriteFully(fdTrigger, message.data(), message.size()); if (status != OK) return AssertionFailure() << statusToString(status); { std::unique_lock<std::mutex> lock(writeMutex); if (!writeCv.wait_for(lock, 3s, [&] { return shouldContinueWriting; })) { return AssertionFailure() << "write barrier not cleared in time!"; } } status = serverTransport->interruptableWriteFully(fdTrigger, msg2.data(), msg2.size()); if (status != -ECANCELED) return AssertionFailure() << "When FdTrigger is shut down, interruptableWriteFully " "should return -ECANCELLED, but it is " << statusToString(status); return AssertionSuccess(); }; auto server = mServers.emplace_back(std::make_unique<Server>()).get(); ASSERT_TRUE(server->setUp()); // Set up client Client client(server->getConnectToServerFn()); ASSERT_TRUE(client.setUp()); // Exchange keys ASSERT_EQ(OK, trust(&client, server)); ASSERT_EQ(OK, trust(server, &client)); server->setPostConnect(serverPostConnect); // Start server server->start(); // connect() to server and do handshake ASSERT_TRUE(client.setUpTransport()); // read the first message. This confirms that server has finished handshake and start handling // client fd. Server thread should pause at waitForWriteBarrier. ASSERT_TRUE(client.readMessage(kMessage)); // Trigger server shutdown after server starts handling client FD. This ensures that the second // write is on an FdTrigger that has been shut down. server->shutdown(); // Continues server thread to write the second message. { std::unique_lock<std::mutex> lock(writeMutex); shouldContinueWriting = true; lock.unlock(); writeCv.notify_all(); } // After this line, server thread unblocks and attempts to write the second message, but // shutdown is triggered, so write should failed with -ECANCELLED. See |serverPostConnect|. // On the client side, second read fails with DEAD_OBJECT ASSERT_FALSE(client.readMessage(msg2)); } std::vector<RpcCertificateFormat> testRpcCertificateFormats() { return { RpcCertificateFormat::PEM, Loading Loading
libs/binder/RpcTransportTls.cpp +1 −1 Original line number Diff line number Diff line Loading @@ -347,7 +347,7 @@ status_t RpcTransportTls::isTriggered(FdTrigger* fdTrigger) { ALOGE("%s: %s", __PRETTY_FUNCTION__, ret.error().message().c_str()); return ret.error().code() == 0 ? UNKNOWN_ERROR : -ret.error().code(); } return OK; return *ret ? -ECANCELED : OK; } status_t RpcTransportTls::interruptableWriteFully(FdTrigger* fdTrigger, const void* data, Loading
libs/binder/tests/binderRpcTest.cpp +120 −27 Original line number Diff line number Diff line Loading @@ -53,6 +53,7 @@ #include "RpcCertificateVerifierSimple.h" using namespace std::chrono_literals; using namespace std::placeholders; using testing::AssertionFailure; using testing::AssertionResult; using testing::AssertionSuccess; Loading Loading @@ -1444,7 +1445,7 @@ public: PrintToString(certificateFormat); } void TearDown() override { for (auto& server : mServers) server->shutdown(); for (auto& server : mServers) server->shutdownAndWait(); } // A server that handles client socket connections. Loading @@ -1452,7 +1453,7 @@ public: public: explicit Server() {} Server(Server&&) = default; ~Server() { shutdown(); } ~Server() { shutdownAndWait(); } [[nodiscard]] AssertionResult setUp() { auto [socketType, rpcSecurity, certificateFormat] = GetParam(); auto rpcServer = RpcServer::make(newFactory(rpcSecurity)); Loading Loading @@ -1536,17 +1537,17 @@ public: ASSERT_TRUE(acceptedFd.ok()); auto serverTransport = mCtx->newTransport(std::move(acceptedFd), mFdTrigger.get()); if (serverTransport == nullptr) return; // handshake failed std::string message(kMessage); ASSERT_EQ(OK, serverTransport->interruptableWriteFully(mFdTrigger.get(), message.data(), message.size())); ASSERT_TRUE(mPostConnect(serverTransport.get(), mFdTrigger.get())); } void shutdown() { mFdTrigger->trigger(); if (mThread != nullptr) { mThread->join(); mThread = nullptr; void shutdownAndWait() { shutdown(); join(); } void shutdown() { mFdTrigger->trigger(); } void setPostConnect( std::function<AssertionResult(RpcTransport*, FdTrigger* fdTrigger)> fn) { mPostConnect = std::move(fn); } private: Loading @@ -1558,6 +1559,26 @@ public: std::shared_ptr<RpcCertificateVerifierSimple> mCertVerifier = std::make_shared<RpcCertificateVerifierSimple>(); bool mSetup = false; // The function invoked after connection and handshake. By default, it is // |defaultPostConnect| that sends |kMessage| to the client. std::function<AssertionResult(RpcTransport*, FdTrigger* fdTrigger)> mPostConnect = Server::defaultPostConnect; void join() { if (mThread != nullptr) { mThread->join(); mThread = nullptr; } } static AssertionResult defaultPostConnect(RpcTransport* serverTransport, FdTrigger* fdTrigger) { std::string message(kMessage); auto status = serverTransport->interruptableWriteFully(fdTrigger, message.data(), message.size()); if (status != OK) return AssertionFailure() << statusToString(status); return AssertionSuccess(); } }; class Client { Loading @@ -1566,8 +1587,6 @@ public: Client(Client&&) = default; [[nodiscard]] AssertionResult setUp() { auto [socketType, rpcSecurity, certificateFormat] = GetParam(); mFd = mConnectToServer(); if (!mFd.ok()) return AssertionFailure() << "Cannot connect to server"; mFdTrigger = FdTrigger::make(); mCtx = newFactory(rpcSecurity, mCertVerifier)->newClientCtx(); if (mCtx == nullptr) return AssertionFailure() << "newClientCtx"; Loading @@ -1577,24 +1596,35 @@ public: std::shared_ptr<RpcCertificateVerifierSimple> getCertVerifier() const { return mCertVerifier; } void run(bool handshakeOk = true, bool readOk = true) { auto clientTransport = mCtx->newTransport(std::move(mFd), mFdTrigger.get()); if (clientTransport == nullptr) { ASSERT_FALSE(handshakeOk) << "newTransport returns nullptr, but it shouldn't"; return; // connect() and do handshake bool setUpTransport() { mFd = mConnectToServer(); if (!mFd.ok()) return AssertionFailure() << "Cannot connect to server"; mClientTransport = mCtx->newTransport(std::move(mFd), mFdTrigger.get()); return mClientTransport != nullptr; } ASSERT_TRUE(handshakeOk) << "newTransport does not return nullptr, but it should"; std::string expectedMessage(kMessage); AssertionResult readMessage(const std::string& expectedMessage = kMessage) { LOG_ALWAYS_FATAL_IF(mClientTransport == nullptr, "setUpTransport not called or failed"); std::string readMessage(expectedMessage.size(), '\0'); status_t readStatus = clientTransport->interruptableReadFully(mFdTrigger.get(), readMessage.data(), mClientTransport->interruptableReadFully(mFdTrigger.get(), readMessage.data(), readMessage.size()); if (readOk) { ASSERT_EQ(OK, readStatus); ASSERT_EQ(readMessage, expectedMessage); } else { ASSERT_NE(OK, readStatus); if (readStatus != OK) { return AssertionFailure() << statusToString(readStatus); } if (readMessage != expectedMessage) { return AssertionFailure() << "Expected " << expectedMessage << ", actual " << readMessage; } return AssertionSuccess(); } void run(bool handshakeOk = true, bool readOk = true) { if (!setUpTransport()) { ASSERT_FALSE(handshakeOk) << "newTransport returns nullptr, but it shouldn't"; return; } ASSERT_TRUE(handshakeOk) << "newTransport does not return nullptr, but it should"; ASSERT_EQ(readOk, readMessage()); } private: Loading @@ -1604,6 +1634,7 @@ public: std::unique_ptr<RpcTransportCtx> mCtx; std::shared_ptr<RpcCertificateVerifierSimple> mCertVerifier = std::make_shared<RpcCertificateVerifierSimple>(); std::unique_ptr<RpcTransport> mClientTransport; }; // Make A trust B. Loading Loading @@ -1729,6 +1760,68 @@ TEST_P(RpcTransportTest, MaliciousClient) { maliciousClient.run(true, readOk); } TEST_P(RpcTransportTest, Trigger) { std::string msg2 = ", world!"; std::mutex writeMutex; std::condition_variable writeCv; bool shouldContinueWriting = false; auto serverPostConnect = [&](RpcTransport* serverTransport, FdTrigger* fdTrigger) { std::string message(kMessage); auto status = serverTransport->interruptableWriteFully(fdTrigger, message.data(), message.size()); if (status != OK) return AssertionFailure() << statusToString(status); { std::unique_lock<std::mutex> lock(writeMutex); if (!writeCv.wait_for(lock, 3s, [&] { return shouldContinueWriting; })) { return AssertionFailure() << "write barrier not cleared in time!"; } } status = serverTransport->interruptableWriteFully(fdTrigger, msg2.data(), msg2.size()); if (status != -ECANCELED) return AssertionFailure() << "When FdTrigger is shut down, interruptableWriteFully " "should return -ECANCELLED, but it is " << statusToString(status); return AssertionSuccess(); }; auto server = mServers.emplace_back(std::make_unique<Server>()).get(); ASSERT_TRUE(server->setUp()); // Set up client Client client(server->getConnectToServerFn()); ASSERT_TRUE(client.setUp()); // Exchange keys ASSERT_EQ(OK, trust(&client, server)); ASSERT_EQ(OK, trust(server, &client)); server->setPostConnect(serverPostConnect); // Start server server->start(); // connect() to server and do handshake ASSERT_TRUE(client.setUpTransport()); // read the first message. This confirms that server has finished handshake and start handling // client fd. Server thread should pause at waitForWriteBarrier. ASSERT_TRUE(client.readMessage(kMessage)); // Trigger server shutdown after server starts handling client FD. This ensures that the second // write is on an FdTrigger that has been shut down. server->shutdown(); // Continues server thread to write the second message. { std::unique_lock<std::mutex> lock(writeMutex); shouldContinueWriting = true; lock.unlock(); writeCv.notify_all(); } // After this line, server thread unblocks and attempts to write the second message, but // shutdown is triggered, so write should failed with -ECANCELLED. See |serverPostConnect|. // On the client side, second read fails with DEAD_OBJECT ASSERT_FALSE(client.readMessage(msg2)); } std::vector<RpcCertificateFormat> testRpcCertificateFormats() { return { RpcCertificateFormat::PEM, Loading