Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 51c44a96 authored by Steven Moreland's avatar Steven Moreland
Browse files

libbinder: support server-specific session

When multiple clients connect to a server, we need a way to tell these
clients apart. Having a per-client root object is the easiest way to do
this (the alternative, using getCalling* like is used in binder, isn't
so great because it requires global/thread local place, but given that
many RpcSession objects can be created, and these can also be used in
conjunction with kernel binder, it is complicated figuring out exactly
where to call getCalling*).

Bug: 199259751
Test: binderRpcTest
Change-Id: I5727db618b5ea138bfa19e75ed915f6a6991518e
parent 269a5d6d
Loading
Loading
Loading
Loading
+33 −5
Original line number Diff line number Diff line
@@ -127,14 +127,23 @@ void RpcServer::setProtocolVersion(uint32_t version) {

void RpcServer::setRootObject(const sp<IBinder>& binder) {
    std::lock_guard<std::mutex> _l(mLock);
    mRootObjectFactory = nullptr;
    mRootObjectWeak = mRootObject = binder;
}

void RpcServer::setRootObjectWeak(const wp<IBinder>& binder) {
    std::lock_guard<std::mutex> _l(mLock);
    mRootObject.clear();
    mRootObjectFactory = nullptr;
    mRootObjectWeak = binder;
}
void RpcServer::setPerSessionRootObject(
        std::function<sp<IBinder>(const sockaddr*, socklen_t)>&& makeObject) {
    std::lock_guard<std::mutex> _l(mLock);
    mRootObject.clear();
    mRootObjectWeak.clear();
    mRootObjectFactory = std::move(makeObject);
}

sp<IBinder> RpcServer::getRootObject() {
    std::lock_guard<std::mutex> _l(mLock);
@@ -174,8 +183,14 @@ void RpcServer::join() {

    status_t status;
    while ((status = mShutdownTrigger->triggerablePoll(mServer, POLLIN)) == OK) {
        unique_fd clientFd(TEMP_FAILURE_RETRY(
                accept4(mServer.get(), nullptr, nullptr /*length*/, SOCK_CLOEXEC | SOCK_NONBLOCK)));
        sockaddr_storage addr;
        socklen_t addrLen = sizeof(addr);

        unique_fd clientFd(
                TEMP_FAILURE_RETRY(accept4(mServer.get(), reinterpret_cast<sockaddr*>(&addr),
                                           &addrLen, SOCK_CLOEXEC | SOCK_NONBLOCK)));

        LOG_ALWAYS_FATAL_IF(addrLen > static_cast<socklen_t>(sizeof(addr)), "Truncated address");

        if (clientFd < 0) {
            ALOGE("Could not accept4 socket: %s", strerror(errno));
@@ -187,7 +202,7 @@ void RpcServer::join() {
            std::lock_guard<std::mutex> _l(mLock);
            std::thread thread =
                    std::thread(&RpcServer::establishConnection, sp<RpcServer>::fromExisting(this),
                                std::move(clientFd));
                                std::move(clientFd), addr, addrLen);
            mConnectingThreads[thread.get_id()] = std::move(thread);
        }
    }
@@ -257,7 +272,8 @@ size_t RpcServer::numUninitializedSessions() {
    return mConnectingThreads.size();
}

void RpcServer::establishConnection(sp<RpcServer>&& server, base::unique_fd clientFd) {
void RpcServer::establishConnection(sp<RpcServer>&& server, base::unique_fd clientFd,
                                    const sockaddr_storage addr, socklen_t addrLen) {
    // TODO(b/183988761): cannot trust this simple ID
    LOG_ALWAYS_FATAL_IF(!server->mAgreedExperimental, "no!");

@@ -383,11 +399,23 @@ void RpcServer::establishConnection(sp<RpcServer>&& server, base::unique_fd clie
            session = RpcSession::make();
            session->setMaxIncomingThreads(server->mMaxThreads);
            if (!session->setProtocolVersion(protocolVersion)) return;

            // if null, falls back to server root
            sp<IBinder> sessionSpecificRoot;
            if (server->mRootObjectFactory != nullptr) {
                sessionSpecificRoot =
                        server->mRootObjectFactory(reinterpret_cast<const sockaddr*>(&addr),
                                                   addrLen);
                if (sessionSpecificRoot == nullptr) {
                    ALOGE("Warning: server returned null from root object factory");
                }
            }

            if (!session->setForServer(server,
                                       sp<RpcServer::EventListener>::fromExisting(
                                               static_cast<RpcServer::EventListener*>(
                                                       server.get())),
                                       sessionId)) {
                                       sessionId, sessionSpecificRoot)) {
                ALOGE("Failed to attach server to session");
                return;
            }
+3 −1
Original line number Diff line number Diff line
@@ -700,7 +700,8 @@ status_t RpcSession::addOutgoingConnection(std::unique_ptr<RpcTransport> rpcTran
}

bool RpcSession::setForServer(const wp<RpcServer>& server, const wp<EventListener>& eventListener,
                              const std::vector<uint8_t>& sessionId) {
                              const std::vector<uint8_t>& sessionId,
                              const sp<IBinder>& sessionSpecificRoot) {
    LOG_ALWAYS_FATAL_IF(mForServer != nullptr);
    LOG_ALWAYS_FATAL_IF(server == nullptr);
    LOG_ALWAYS_FATAL_IF(mEventListener != nullptr);
@@ -713,6 +714,7 @@ bool RpcSession::setForServer(const wp<RpcServer>& server, const wp<EventListene
    mId = sessionId;
    mForServer = server;
    mEventListener = eventListener;
    mSessionSpecificRootObject = sessionSpecificRoot;
    return true;
}

+3 −1
Original line number Diff line number Diff line
@@ -870,7 +870,9 @@ processTransactInternalTailCall:
                    if (server) {
                        switch (transaction->code) {
                            case RPC_SPECIAL_TRANSACT_GET_ROOT: {
                                replyStatus = reply.writeStrongBinder(server->getRootObject());
                                sp<IBinder> root = session->mSessionSpecificRootObject
                                        ?: server->getRootObject();
                                replyStatus = reply.writeStrongBinder(root);
                                break;
                            }
                            default: {
+7 −1
Original line number Diff line number Diff line
@@ -130,6 +130,10 @@ public:
     * Holds a weak reference to the root object.
     */
    void setRootObjectWeak(const wp<IBinder>& binder);
    /**
     * Allows a root object to be created for each session
     */
    void setPerSessionRootObject(std::function<sp<IBinder>(const sockaddr*, socklen_t)>&& object);
    sp<IBinder> getRootObject();

    /**
@@ -179,7 +183,8 @@ private:
    void onSessionAllIncomingThreadsEnded(const sp<RpcSession>& session) override;
    void onSessionIncomingThreadEnded() override;

    static void establishConnection(sp<RpcServer>&& server, base::unique_fd clientFd);
    static void establishConnection(sp<RpcServer>&& server, base::unique_fd clientFd,
                                    const sockaddr_storage addr, socklen_t addrLen);
    status_t setupSocketServer(const RpcSocketAddress& address);

    const std::unique_ptr<RpcTransportCtx> mCtx;
@@ -194,6 +199,7 @@ private:
    std::map<std::thread::id, std::thread> mConnectingThreads;
    sp<IBinder> mRootObject;
    wp<IBinder> mRootObjectWeak;
    std::function<sp<IBinder>(const sockaddr*, socklen_t)> mRootObjectFactory;
    std::map<std::vector<uint8_t>, sp<RpcSession>> mSessions;
    std::unique_ptr<FdTrigger> mShutdownTrigger;
    std::condition_variable mShutdownCv;
+6 −1
Original line number Diff line number Diff line
@@ -256,7 +256,8 @@ private:
                                                 bool init);
    [[nodiscard]] bool setForServer(const wp<RpcServer>& server,
                                    const wp<RpcSession::EventListener>& eventListener,
                                    const std::vector<uint8_t>& sessionId);
                                    const std::vector<uint8_t>& sessionId,
                                    const sp<IBinder>& sessionSpecificRoot);
    sp<RpcConnection> assignIncomingConnectionToThisThread(
            std::unique_ptr<RpcTransport> rpcTransport);
    [[nodiscard]] bool removeIncomingConnection(const sp<RpcConnection>& connection);
@@ -313,6 +314,10 @@ private:
    sp<WaitForShutdownListener> mShutdownListener; // used for client sessions
    wp<EventListener> mEventListener; // mForServer if server, mShutdownListener if client

    // session-specific root object (if a different root is used for each
    // session)
    sp<IBinder> mSessionSpecificRootObject;

    std::vector<uint8_t> mId;

    std::unique_ptr<FdTrigger> mShutdownTrigger;
Loading