Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit dd67b94a authored by Steven Moreland's avatar Steven Moreland
Browse files

libbinder: fix RPC setup races

When setting up connections, there are a few cases where we take the
server lock and then we take the session lock. However, when a session
is shutting down, there is one case where we took the session lock and
then the server lock. This is a big no-no, and it was causing a
deadlock in the 'Fds' test (this creates many threads - but it is very
shortlived, the threads are still being setup on the server when the
process shutsdown, hitting the deadlock occassionally).

The solution to this involves keeping a little bit of extra state inside
of RpcSession directly to understand when it's shutting down. Also, we
now fully cleanup sessions before removing them during the shutdown
process.

From this point on, we should always take the server lock and then the
session lock in order to avoid races (never the session and then the
server).

Bug: N/A
Test: binderRpcTest
Test: while $ANDROID_BUILD_TOP/out/host/linux-x86/nativetest/binderRpcTest/binderRpcTest --gtest_filter="*Fd*"; do : ; done
Change-Id: I9144c43939c0640a2ec53f93f6e685ddce4b3e83
parent c18818c6
Loading
Loading
Loading
Loading
+1 −1
Original line number Original line Diff line number Diff line
@@ -369,7 +369,7 @@ bool RpcServer::setupSocketServer(const RpcSocketAddress& addr) {
    return true;
    return true;
}
}


void RpcServer::onSessionLockedAllIncomingThreadsEnded(const sp<RpcSession>& session) {
void RpcServer::onSessionAllIncomingThreadsEnded(const sp<RpcSession>& session) {
    auto id = session->mId;
    auto id = session->mId;
    LOG_ALWAYS_FATAL_IF(id == std::nullopt, "Server sessions must be initialized with ID");
    LOG_ALWAYS_FATAL_IF(id == std::nullopt, "Server sessions must be initialized with ID");
    LOG_RPC_DETAIL("Dropping session with address %s", id->toString().c_str());
    LOG_RPC_DETAIL("Dropping session with address %s", id->toString().c_str());
+30 −8
Original line number Original line Diff line number Diff line
@@ -132,6 +132,7 @@ bool RpcSession::shutdownAndWait(bool wait) {
    if (wait) {
    if (wait) {
        LOG_ALWAYS_FATAL_IF(mShutdownListener == nullptr, "Shutdown listener not installed");
        LOG_ALWAYS_FATAL_IF(mShutdownListener == nullptr, "Shutdown listener not installed");
        mShutdownListener->waitForShutdown(_l);
        mShutdownListener->waitForShutdown(_l);

        LOG_ALWAYS_FATAL_IF(!mThreads.empty(), "Shutdown failed");
        LOG_ALWAYS_FATAL_IF(!mThreads.empty(), "Shutdown failed");
    }
    }


@@ -261,7 +262,7 @@ status_t RpcSession::readId() {
    return OK;
    return OK;
}
}


void RpcSession::WaitForShutdownListener::onSessionLockedAllIncomingThreadsEnded(
void RpcSession::WaitForShutdownListener::onSessionAllIncomingThreadsEnded(
        const sp<RpcSession>& session) {
        const sp<RpcSession>& session) {
    (void)session;
    (void)session;
    mShutdown = true;
    mShutdown = true;
@@ -293,7 +294,13 @@ RpcSession::PreJoinSetupResult RpcSession::preJoinSetup(base::unique_fd fd) {
    // be able to do nested calls (we can't only read from it)
    // be able to do nested calls (we can't only read from it)
    sp<RpcConnection> connection = assignIncomingConnectionToThisThread(std::move(fd));
    sp<RpcConnection> connection = assignIncomingConnectionToThisThread(std::move(fd));


    status_t status = mState->readConnectionInit(connection, sp<RpcSession>::fromExisting(this));
    status_t status;

    if (connection == nullptr) {
        status = DEAD_OBJECT;
    } else {
        status = mState->readConnectionInit(connection, sp<RpcSession>::fromExisting(this));
    }


    return PreJoinSetupResult{
    return PreJoinSetupResult{
            .connection = std::move(connection),
            .connection = std::move(connection),
@@ -360,6 +367,7 @@ void RpcSession::join(sp<RpcSession>&& session, PreJoinSetupResult&& setupResult
    sp<RpcConnection>& connection = setupResult.connection;
    sp<RpcConnection>& connection = setupResult.connection;


    if (setupResult.status == OK) {
    if (setupResult.status == OK) {
        LOG_ALWAYS_FATAL_IF(!connection, "must have connection if setup succeeded");
        JavaThreadAttacher javaThreadAttacher;
        JavaThreadAttacher javaThreadAttacher;
        while (true) {
        while (true) {
            status_t status = session->state()->getAndExecuteCommand(connection, session,
            status_t status = session->state()->getAndExecuteCommand(connection, session,
@@ -375,9 +383,6 @@ void RpcSession::join(sp<RpcSession>&& session, PreJoinSetupResult&& setupResult
              statusToString(setupResult.status).c_str());
              statusToString(setupResult.status).c_str());
    }
    }


    LOG_ALWAYS_FATAL_IF(!session->removeIncomingConnection(connection),
                        "bad state: connection object guaranteed to be in list");

    sp<RpcSession::EventListener> listener;
    sp<RpcSession::EventListener> listener;
    {
    {
        std::lock_guard<std::mutex> _l(session->mMutex);
        std::lock_guard<std::mutex> _l(session->mMutex);
@@ -389,6 +394,12 @@ void RpcSession::join(sp<RpcSession>&& session, PreJoinSetupResult&& setupResult
        listener = session->mEventListener.promote();
        listener = session->mEventListener.promote();
    }
    }


    // done after all cleanup, since session shutdown progresses via callbacks here
    if (connection != nullptr) {
        LOG_ALWAYS_FATAL_IF(!session->removeIncomingConnection(connection),
                            "bad state: connection object guaranteed to be in list");
    }

    session = nullptr;
    session = nullptr;


    if (listener != nullptr) {
    if (listener != nullptr) {
@@ -579,24 +590,35 @@ bool RpcSession::setForServer(const wp<RpcServer>& server, const wp<EventListene


sp<RpcSession::RpcConnection> RpcSession::assignIncomingConnectionToThisThread(unique_fd fd) {
sp<RpcSession::RpcConnection> RpcSession::assignIncomingConnectionToThisThread(unique_fd fd) {
    std::lock_guard<std::mutex> _l(mMutex);
    std::lock_guard<std::mutex> _l(mMutex);

    // Don't accept any more connections, some have shutdown. Usually this
    // happens when new connections are still being established as part of a
    // very short-lived session which shuts down after it already started
    // accepting new connections.
    if (mIncomingConnections.size() < mMaxIncomingConnections) {
        return nullptr;
    }

    sp<RpcConnection> session = sp<RpcConnection>::make();
    sp<RpcConnection> session = sp<RpcConnection>::make();
    session->fd = std::move(fd);
    session->fd = std::move(fd);
    session->exclusiveTid = gettid();
    session->exclusiveTid = gettid();

    mIncomingConnections.push_back(session);
    mIncomingConnections.push_back(session);
    mMaxIncomingConnections = mIncomingConnections.size();


    return session;
    return session;
}
}


bool RpcSession::removeIncomingConnection(const sp<RpcConnection>& connection) {
bool RpcSession::removeIncomingConnection(const sp<RpcConnection>& connection) {
    std::lock_guard<std::mutex> _l(mMutex);
    std::unique_lock<std::mutex> _l(mMutex);
    if (auto it = std::find(mIncomingConnections.begin(), mIncomingConnections.end(), connection);
    if (auto it = std::find(mIncomingConnections.begin(), mIncomingConnections.end(), connection);
        it != mIncomingConnections.end()) {
        it != mIncomingConnections.end()) {
        mIncomingConnections.erase(it);
        mIncomingConnections.erase(it);
        if (mIncomingConnections.size() == 0) {
        if (mIncomingConnections.size() == 0) {
            sp<EventListener> listener = mEventListener.promote();
            sp<EventListener> listener = mEventListener.promote();
            if (listener) {
            if (listener) {
                listener->onSessionLockedAllIncomingThreadsEnded(
                _l.unlock();
                        sp<RpcSession>::fromExisting(this));
                listener->onSessionAllIncomingThreadsEnded(sp<RpcSession>::fromExisting(this));
            }
            }
        }
        }
        return true;
        return true;
+1 −1
Original line number Original line Diff line number Diff line
@@ -156,7 +156,7 @@ private:
    friend sp<RpcServer>;
    friend sp<RpcServer>;
    RpcServer();
    RpcServer();


    void onSessionLockedAllIncomingThreadsEnded(const sp<RpcSession>& session) override;
    void onSessionAllIncomingThreadsEnded(const sp<RpcSession>& session) override;
    void onSessionIncomingThreadEnded() override;
    void onSessionIncomingThreadEnded() override;


    static void establishConnection(sp<RpcServer>&& server, base::unique_fd clientFd);
    static void establishConnection(sp<RpcServer>&& server, base::unique_fd clientFd);
+4 −3
Original line number Original line Diff line number Diff line
@@ -177,19 +177,19 @@ private:


    class EventListener : public virtual RefBase {
    class EventListener : public virtual RefBase {
    public:
    public:
        virtual void onSessionLockedAllIncomingThreadsEnded(const sp<RpcSession>& session) = 0;
        virtual void onSessionAllIncomingThreadsEnded(const sp<RpcSession>& session) = 0;
        virtual void onSessionIncomingThreadEnded() = 0;
        virtual void onSessionIncomingThreadEnded() = 0;
    };
    };


    class WaitForShutdownListener : public EventListener {
    class WaitForShutdownListener : public EventListener {
    public:
    public:
        void onSessionLockedAllIncomingThreadsEnded(const sp<RpcSession>& session) override;
        void onSessionAllIncomingThreadsEnded(const sp<RpcSession>& session) override;
        void onSessionIncomingThreadEnded() override;
        void onSessionIncomingThreadEnded() override;
        void waitForShutdown(std::unique_lock<std::mutex>& lock);
        void waitForShutdown(std::unique_lock<std::mutex>& lock);


    private:
    private:
        std::condition_variable mCv;
        std::condition_variable mCv;
        bool mShutdown = false;
        volatile bool mShutdown = false;
    };
    };


    struct RpcConnection : public RefBase {
    struct RpcConnection : public RefBase {
@@ -297,6 +297,7 @@ private:
    // hint index into clients, ++ when sending an async transaction
    // hint index into clients, ++ when sending an async transaction
    size_t mOutgoingConnectionsOffset = 0;
    size_t mOutgoingConnectionsOffset = 0;
    std::vector<sp<RpcConnection>> mOutgoingConnections;
    std::vector<sp<RpcConnection>> mOutgoingConnections;
    size_t mMaxIncomingConnections = 0;
    std::vector<sp<RpcConnection>> mIncomingConnections;
    std::vector<sp<RpcConnection>> mIncomingConnections;
    std::map<std::thread::id, std::thread> mThreads;
    std::map<std::thread::id, std::thread> mThreads;
};
};