
Commit cbfb18e1 authored by Steven Moreland, committed by Gerrit Code Review

Merge "libbinder: RPC allow RpcSession to be reusable"

parents cdb2c8f3 27a8bc7c
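
What the change does: RpcSession previously asserted that it was only set up once, and a failed client setup left the object in a state that could not be retried. This merge makes setupClient() roll the session back to a freshly constructed state on any failure, so the same object can be reused. A minimal sketch of the pattern this enables, assuming libbinder's RpcSession::make() and setupUnixDomainClient() entry points (the socket path is invented):

    // Hypothetical caller; not part of this diff.
    sp<RpcSession> session = RpcSession::make();
    status_t status = session->setupUnixDomainClient("/dev/socket/example");
    if (status != OK) {
        // With this change, a failed setup resets the session, so the same
        // object may be set up again (to retry, or against another server).
        status = session->setupUnixDomainClient("/dev/socket/example");
    }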
RpcSession.cpp: +68 −37
@@ -28,6 +28,7 @@
 
 #include <android-base/hex.h>
 #include <android-base/macros.h>
+#include <android-base/scopeguard.h>
 #include <android_runtime/vm.h>
 #include <binder/BpBinder.h>
 #include <binder/Parcel.h>
@@ -54,13 +55,13 @@ using base::unique_fd;
 RpcSession::RpcSession(std::unique_ptr<RpcTransportCtx> ctx) : mCtx(std::move(ctx)) {
     LOG_RPC_DETAIL("RpcSession created %p", this);
 
-    mState = std::make_unique<RpcState>();
+    mRpcBinderState = std::make_unique<RpcState>();
 }
 RpcSession::~RpcSession() {
     LOG_RPC_DETAIL("RpcSession destroyed %p", this);
 
     std::lock_guard<std::mutex> _l(mMutex);
-    LOG_ALWAYS_FATAL_IF(mIncomingConnections.size() != 0,
+    LOG_ALWAYS_FATAL_IF(mThreadState.mIncomingConnections.size() != 0,
                         "Should not be able to destroy a session with servers in use.");
 }
 
@@ -77,10 +78,12 @@ sp<RpcSession> RpcSession::make(std::unique_ptr<RpcTransportCtxFactory> rpcTrans
 
 void RpcSession::setMaxThreads(size_t threads) {
     std::lock_guard<std::mutex> _l(mMutex);
-    LOG_ALWAYS_FATAL_IF(!mOutgoingConnections.empty() || !mIncomingConnections.empty(),
+    LOG_ALWAYS_FATAL_IF(!mThreadState.mOutgoingConnections.empty() ||
+                                !mThreadState.mIncomingConnections.empty(),
                         "Must set max threads before setting up connections, but has %zu client(s) "
                         "and %zu server(s)",
-                        mOutgoingConnections.size(), mIncomingConnections.size());
+                        mThreadState.mOutgoingConnections.size(),
+                        mThreadState.mIncomingConnections.size());
     mMaxThreads = threads;
 }
 
@@ -194,11 +197,11 @@ bool RpcSession::shutdownAndWait(bool wait) {
         LOG_ALWAYS_FATAL_IF(mShutdownListener == nullptr, "Shutdown listener not installed");
         mShutdownListener->waitForShutdown(_l, sp<RpcSession>::fromExisting(this));
 
-        LOG_ALWAYS_FATAL_IF(!mThreads.empty(), "Shutdown failed");
+        LOG_ALWAYS_FATAL_IF(!mThreadState.mThreads.empty(), "Shutdown failed");
     }
 
     _l.unlock();
-    mState->clear();
+    mRpcBinderState->clear();
 
     return true;
 }
@@ -260,11 +263,11 @@ void RpcSession::WaitForShutdownListener::onSessionIncomingThreadEnded() {
 
 void RpcSession::WaitForShutdownListener::waitForShutdown(std::unique_lock<std::mutex>& lock,
                                                           const sp<RpcSession>& session) {
-    while (session->mIncomingConnections.size() > 0) {
+    while (session->mThreadState.mIncomingConnections.size() > 0) {
         if (std::cv_status::timeout == mCv.wait_for(lock, std::chrono::seconds(1))) {
             ALOGE("Waiting for RpcSession to shut down (1s w/o progress): %zu incoming connections "
                   "still.",
-                  session->mIncomingConnections.size());
+                  session->mThreadState.mIncomingConnections.size());
         }
     }
 }
@@ -274,7 +277,7 @@ void RpcSession::preJoinThreadOwnership(std::thread thread) {
 
     {
         std::lock_guard<std::mutex> _l(mMutex);
-        mThreads[thread.get_id()] = std::move(thread);
+        mThreadState.mThreads[thread.get_id()] = std::move(thread);
     }
 }
 
@@ -289,7 +292,8 @@ RpcSession::PreJoinSetupResult RpcSession::preJoinSetup(
     if (connection == nullptr) {
         status = DEAD_OBJECT;
     } else {
-        status = mState->readConnectionInit(connection, sp<RpcSession>::fromExisting(this));
+        status =
+                mRpcBinderState->readConnectionInit(connection, sp<RpcSession>::fromExisting(this));
     }
 
     return PreJoinSetupResult{
@@ -376,10 +380,10 @@ void RpcSession::join(sp<RpcSession>&& session, PreJoinSetupResult&& setupResult
     sp<RpcSession::EventListener> listener;
     {
         std::lock_guard<std::mutex> _l(session->mMutex);
-        auto it = session->mThreads.find(std::this_thread::get_id());
-        LOG_ALWAYS_FATAL_IF(it == session->mThreads.end());
+        auto it = session->mThreadState.mThreads.find(std::this_thread::get_id());
+        LOG_ALWAYS_FATAL_IF(it == session->mThreadState.mThreads.end());
         it->second.detach();
-        session->mThreads.erase(it);
+        session->mThreadState.mThreads.erase(it);
 
         listener = session->mEventListener.promote();
     }
@@ -410,12 +414,34 @@ status_t RpcSession::setupClient(const std::function<status_t(const std::vector<
                                                               bool incoming)>& connectAndInit) {
     {
         std::lock_guard<std::mutex> _l(mMutex);
-        LOG_ALWAYS_FATAL_IF(mOutgoingConnections.size() != 0,
+        LOG_ALWAYS_FATAL_IF(mThreadState.mOutgoingConnections.size() != 0,
                             "Must only setup session once, but already has %zu clients",
-                            mOutgoingConnections.size());
+                            mThreadState.mOutgoingConnections.size());
    }
 
     if (auto status = initShutdownTrigger(); status != OK) return status;
 
+    auto oldProtocolVersion = mProtocolVersion;
+    auto cleanup = base::ScopeGuard([&] {
+        // if any threads are started, shut them down
+        (void)shutdownAndWait(true);
+
+        mShutdownListener = nullptr;
+        mEventListener.clear();
+
+        mId.clear();
+
+        mShutdownTrigger = nullptr;
+        mRpcBinderState = std::make_unique<RpcState>();
+
+        // protocol version may have been downgraded - if we reuse this object
+        // to connect to another server, force that server to request a
+        // downgrade again
+        mProtocolVersion = oldProtocolVersion;
+
+        mThreadState = {};
+    });
+
     if (status_t status = connectAndInit({}, false /*incoming*/); status != OK) return status;
 
     {
@@ -464,6 +490,8 @@ status_t RpcSession::setupClient(const std::function<status_t(const std::vector<
         if (status_t status = connectAndInit(mId, true /*incoming*/); status != OK) return status;
     }
 
+    cleanup.Disable();
+
     return OK;
 }
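For context on the pattern above: base::ScopeGuard (from the newly included android-base/scopeguard.h) runs its lambda when the guard goes out of scope, unless Disable() is called first. That is what makes every early "return status" inside setupClient() safe: the guard fires and resets the session. A self-contained sketch of the same idea; the step functions and resetState() are invented for illustration:

    #include <android-base/scopeguard.h>

    status_t setupExample() {
        auto cleanup = android::base::ScopeGuard([&] {
            resetState(); // hypothetical helper: undo any partial setup
        });

        if (status_t status = stepOne(); status != OK) return status; // guard fires
        if (status_t status = stepTwo(); status != OK) return status; // guard fires

        cleanup.Disable(); // success path: keep the state, skip the lambda
        return OK;
    }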


@@ -634,12 +662,12 @@ status_t RpcSession::addOutgoingConnection(std::unique_ptr<RpcTransport> rpcTran
         std::lock_guard<std::mutex> _l(mMutex);
         connection->rpcTransport = std::move(rpcTransport);
         connection->exclusiveTid = gettid();
-        mOutgoingConnections.push_back(connection);
+        mThreadState.mOutgoingConnections.push_back(connection);
     }
 
     status_t status = OK;
     if (init) {
-        mState->sendConnectionInit(connection, sp<RpcSession>::fromExisting(this));
+        mRpcBinderState->sendConnectionInit(connection, sp<RpcSession>::fromExisting(this));
     }
 
     {
@@ -671,9 +699,9 @@ sp<RpcSession::RpcConnection> RpcSession::assignIncomingConnectionToThisThread(
         std::unique_ptr<RpcTransport> rpcTransport) {
     std::lock_guard<std::mutex> _l(mMutex);
 
-    if (mIncomingConnections.size() >= mMaxThreads) {
+    if (mThreadState.mIncomingConnections.size() >= mMaxThreads) {
         ALOGE("Cannot add thread to session with %zu threads (max is set to %zu)",
-              mIncomingConnections.size(), mMaxThreads);
+              mThreadState.mIncomingConnections.size(), mMaxThreads);
         return nullptr;
     }
 
@@ -681,7 +709,7 @@ sp<RpcSession::RpcConnection> RpcSession::assignIncomingConnectionToThisThread(
     // happens when new connections are still being established as part of a
     // very short-lived session which shuts down after it already started
     // accepting new connections.
-    if (mIncomingConnections.size() < mMaxIncomingConnections) {
+    if (mThreadState.mIncomingConnections.size() < mThreadState.mMaxIncomingConnections) {
         return nullptr;
     }
 
@@ -689,18 +717,19 @@ sp<RpcSession::RpcConnection> RpcSession::assignIncomingConnectionToThisThread(
     session->rpcTransport = std::move(rpcTransport);
     session->exclusiveTid = gettid();
 
-    mIncomingConnections.push_back(session);
-    mMaxIncomingConnections = mIncomingConnections.size();
+    mThreadState.mIncomingConnections.push_back(session);
+    mThreadState.mMaxIncomingConnections = mThreadState.mIncomingConnections.size();
 
     return session;
 }
 
 bool RpcSession::removeIncomingConnection(const sp<RpcConnection>& connection) {
     std::unique_lock<std::mutex> _l(mMutex);
-    if (auto it = std::find(mIncomingConnections.begin(), mIncomingConnections.end(), connection);
-        it != mIncomingConnections.end()) {
-        mIncomingConnections.erase(it);
-        if (mIncomingConnections.size() == 0) {
+    if (auto it = std::find(mThreadState.mIncomingConnections.begin(),
+                            mThreadState.mIncomingConnections.end(), connection);
+        it != mThreadState.mIncomingConnections.end()) {
+        mThreadState.mIncomingConnections.erase(it);
+        if (mThreadState.mIncomingConnections.size() == 0) {
             sp<EventListener> listener = mEventListener.promote();
             if (listener) {
                 _l.unlock();
@@ -725,7 +754,7 @@ status_t RpcSession::ExclusiveConnection::find(const sp<RpcSession>& session, Co
     pid_t tid = gettid();
     std::unique_lock<std::mutex> _l(session->mMutex);
 
-    session->mWaitingThreads++;
+    session->mThreadState.mWaitingThreads++;
     while (true) {
         sp<RpcConnection> exclusive;
         sp<RpcConnection> available;
@@ -733,8 +762,8 @@ status_t RpcSession::ExclusiveConnection::find(const sp<RpcSession>& session, Co
         // CHECK FOR DEDICATED CLIENT SOCKET
         //
         // A server/looper should always use a dedicated connection if available
-        findConnection(tid, &exclusive, &available, session->mOutgoingConnections,
-                       session->mOutgoingConnectionsOffset);
+        findConnection(tid, &exclusive, &available, session->mThreadState.mOutgoingConnections,
+                       session->mThreadState.mOutgoingConnectionsOffset);
 
         // WARNING: this assumes a server cannot request its client to send
         // a transaction, as mIncomingConnections is excluded below.
@@ -747,8 +776,9 @@ status_t RpcSession::ExclusiveConnection::find(const sp<RpcSession>& session, Co
         // command. So, we move to considering the second available thread
         // for subsequent calls.
         if (use == ConnectionUse::CLIENT_ASYNC && (exclusive != nullptr || available != nullptr)) {
-            session->mOutgoingConnectionsOffset = (session->mOutgoingConnectionsOffset + 1) %
-                    session->mOutgoingConnections.size();
+            session->mThreadState.mOutgoingConnectionsOffset =
+                    (session->mThreadState.mOutgoingConnectionsOffset + 1) %
+                    session->mThreadState.mOutgoingConnections.size();
         }
 
         // USE SERVING SOCKET (e.g. nested transaction)
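The CLIENT_ASYNC branch above advances mOutgoingConnectionsOffset so consecutive async transactions start their connection search at different places in the pool. The arithmetic is plain round-robin; an isolated, runnable sketch with an invented connection pool:

    #include <cstdio>
    #include <vector>

    int main() {
        std::vector<const char*> pool = {"c0", "c1", "c2"}; // stand-in connections
        size_t offset = 0; // hint: where the next async call starts looking
        for (int i = 0; i < 5; i++) {
            std::printf("async call %d starts at %s\n", i, pool[offset]);
            offset = (offset + 1) % pool.size(); // same wrap-around as the diff
        }
        // prints c0, c1, c2, c0, c1 - successive calls spread across the pool
    }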
@@ -756,7 +786,7 @@ status_t RpcSession::ExclusiveConnection::find(const sp<RpcSession>& session, Co
             sp<RpcConnection> exclusiveIncoming;
             // server connections are always assigned to a thread
             findConnection(tid, &exclusiveIncoming, nullptr /*available*/,
-                           session->mIncomingConnections, 0 /* index hint */);
+                           session->mThreadState.mIncomingConnections, 0 /* index hint */);
 
             // asynchronous calls cannot be nested, we currently allow ref count
             // calls to be nested (so that you can use this without having extra
@@ -785,19 +815,20 @@ status_t RpcSession::ExclusiveConnection::find(const sp<RpcSession>& session, Co
             break;
         }
 
-        if (session->mOutgoingConnections.size() == 0) {
+        if (session->mThreadState.mOutgoingConnections.size() == 0) {
             ALOGE("Session has no client connections. This is required for an RPC server to make "
                   "any non-nested (e.g. oneway or on another thread) calls. Use: %d. Server "
                   "connections: %zu",
-                  static_cast<int>(use), session->mIncomingConnections.size());
+                  static_cast<int>(use), session->mThreadState.mIncomingConnections.size());
             return WOULD_BLOCK;
         }
 
         LOG_RPC_DETAIL("No available connections (have %zu clients and %zu servers). Waiting...",
-                       session->mOutgoingConnections.size(), session->mIncomingConnections.size());
+                       session->mThreadState.mOutgoingConnections.size(),
+                       session->mThreadState.mIncomingConnections.size());
         session->mAvailableConnectionCv.wait(_l);
     }
-    session->mWaitingThreads--;
+    session->mThreadState.mWaitingThreads--;
 
     return OK;
 }
@@ -836,7 +867,7 @@ RpcSession::ExclusiveConnection::~ExclusiveConnection() {
     if (!mReentrant && mConnection != nullptr) {
         std::unique_lock<std::mutex> _l(mSession->mMutex);
         mConnection->exclusiveTid = std::nullopt;
-        if (mSession->mWaitingThreads > 0) {
+        if (mSession->mThreadState.mWaitingThreads > 0) {
             _l.unlock();
             mSession->mAvailableConnectionCv.notify_one();
         }
RpcSession.h: +12 −9
@@ -168,7 +168,7 @@ public:
     sp<RpcServer> server();
 
     // internal only
-    const std::unique_ptr<RpcState>& state() { return mState; }
+    const std::unique_ptr<RpcState>& state() { return mRpcBinderState; }
 
 private:
     friend sp<RpcSession>;
@@ -303,7 +303,7 @@ private:
 
     std::unique_ptr<FdTrigger> mShutdownTrigger;
 
-    std::unique_ptr<RpcState> mState;
+    std::unique_ptr<RpcState> mRpcBinderState;
 
     std::mutex mMutex; // for all below
 
@@ -311,6 +311,8 @@ private:
     std::optional<uint32_t> mProtocolVersion;
 
     std::condition_variable mAvailableConnectionCv; // for mWaitingThreads
+
+    struct ThreadState {
         size_t mWaitingThreads = 0;
         // hint index into clients, ++ when sending an async transaction
         size_t mOutgoingConnectionsOffset = 0;
@@ -318,6 +320,7 @@ private:
         size_t mMaxIncomingConnections = 0;
         std::vector<sp<RpcConnection>> mIncomingConnections;
         std::map<std::thread::id, std::thread> mThreads;
+    } mThreadState;
 };
 
 } // namespace android
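
A note on the refactor itself: collecting the per-session connection bookkeeping into one ThreadState struct is what allows the single "mThreadState = {};" statement in the cleanup guard to reset all of that state at once, instead of clearing each member by hand. The mechanic in isolation (a toy struct, not the real one):

    #include <cassert>
    #include <vector>

    struct State { // stand-in for RpcSession::ThreadState
        size_t waiting = 0;
        std::vector<int> connections;
    };

    int main() {
        State s;
        s.waiting = 3;
        s.connections.push_back(42);

        s = {}; // value-initialize: every member back to its declared default

        assert(s.waiting == 0 && s.connections.empty());
        return 0;
    }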