Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 1a235859 authored by Yifan Hong's avatar Yifan Hong Committed by Steven Moreland
Browse files

Add RpcServer::shutdown.

The function terminates any existing execution of join().

After this CL, join() is only allowed to be called in one thread.

Test: binderLibTest
Change-Id: I5f1abbb39ee42a8f94b7394a702a152701537e7e
parent 022f9945
Loading
Loading
Loading
Loading
+63 −4
Original line number Diff line number Diff line
@@ -16,19 +16,21 @@

#define LOG_TAG "RpcServer"

#include <poll.h>
#include <sys/socket.h>
#include <sys/un.h>

#include <thread>
#include <vector>

#include <android-base/macros.h>
#include <android-base/scopeguard.h>
#include <binder/Parcel.h>
#include <binder/RpcServer.h>
#include <log/log.h>
#include "RpcState.h"

#include "RpcSocketAddress.h"
#include "RpcState.h"
#include "RpcWireFormat.h"

namespace android {
@@ -99,7 +101,7 @@ bool RpcServer::setupInetServer(unsigned int port, unsigned int* assignedPort) {

void RpcServer::setMaxThreads(size_t threads) {
    LOG_ALWAYS_FATAL_IF(threads <= 0, "RpcServer is useless without threads");
    LOG_ALWAYS_FATAL_IF(mStarted, "must be called before started");
    LOG_ALWAYS_FATAL_IF(mJoinThreadRunning, "Cannot set max threads while running");
    mMaxThreads = threads;
}

@@ -126,16 +128,61 @@ sp<IBinder> RpcServer::getRootObject() {
    return ret;
}

std::unique_ptr<RpcServer::FdTrigger> RpcServer::FdTrigger::make() {
    auto ret = std::make_unique<RpcServer::FdTrigger>();
    if (!android::base::Pipe(&ret->mRead, &ret->mWrite)) return nullptr;
    return ret;
}

void RpcServer::FdTrigger::trigger() {
    mWrite.reset();
}

void RpcServer::join() {
    LOG_ALWAYS_FATAL_IF(!mAgreedExperimental, "no!");

    {
        std::lock_guard<std::mutex> _l(mLock);
        LOG_ALWAYS_FATAL_IF(!mServer.ok(), "RpcServer must be setup to join.");
        LOG_ALWAYS_FATAL_IF(mShutdownTrigger != nullptr, "Already joined");
        mJoinThreadRunning = true;
        mShutdownTrigger = FdTrigger::make();
        LOG_ALWAYS_FATAL_IF(mShutdownTrigger == nullptr, "Cannot create join signaler");
    }

    while (true) {
        (void)acceptOne();
        pollfd pfd[]{{.fd = mServer.get(), .events = POLLIN, .revents = 0},
                     {.fd = mShutdownTrigger->readFd().get(), .events = POLLHUP, .revents = 0}};
        int ret = TEMP_FAILURE_RETRY(poll(pfd, arraysize(pfd), -1));
        if (ret < 0) {
            ALOGE("Could not poll socket: %s", strerror(errno));
            continue;
        }
        if (ret == 0) {
            continue;
        }
        if (pfd[1].revents & POLLHUP) {
            LOG_RPC_DETAIL("join() exiting because shutdown requested.");
            break;
        }

        (void)acceptOneNoCheck();
    }

    {
        std::lock_guard<std::mutex> _l(mLock);
        mJoinThreadRunning = false;
    }
    mShutdownCv.notify_all();
}

bool RpcServer::acceptOne() {
    LOG_ALWAYS_FATAL_IF(!mAgreedExperimental, "no!");
    LOG_ALWAYS_FATAL_IF(!hasServer(), "RpcServer must be setup to join.");
    LOG_ALWAYS_FATAL_IF(!hasServer(), "RpcServer must be setup to acceptOne.");
    return acceptOneNoCheck();
}

bool RpcServer::acceptOneNoCheck() {
    unique_fd clientFd(
            TEMP_FAILURE_RETRY(accept4(mServer.get(), nullptr, nullptr /*length*/, SOCK_CLOEXEC)));

@@ -156,6 +203,18 @@ bool RpcServer::acceptOne() {
    return true;
}

bool RpcServer::shutdown() {
    LOG_ALWAYS_FATAL_IF(!mAgreedExperimental, "no!");
    std::unique_lock<std::mutex> _l(mLock);
    if (mShutdownTrigger == nullptr) return false;

    mShutdownTrigger->trigger();
    while (mJoinThreadRunning) mShutdownCv.wait(_l);

    mShutdownTrigger = nullptr;
    return true;
}

std::vector<sp<RpcSession>> RpcServer::listSessions() {
    std::lock_guard<std::mutex> _l(mLock);
    std::vector<sp<RpcSession>> sessions;
+33 −2
Original line number Diff line number Diff line
@@ -119,10 +119,21 @@ public:
    /**
     * You must have at least one client session before calling this.
     *
     * TODO(b/185167543): way to shut down?
     * If a client needs to actively terminate join, call shutdown() in a separate thread.
     *
     * At any given point, there can only be one thread calling join().
     */
    void join();

    /**
     * Shut down any existing join(). Return true if successfully shut down, false otherwise
     * (e.g. no join() is running). Will wait for the server to be fully
     * shutdown.
     *
     * TODO(b/185167543): wait for sessions to shutdown as well
     */
    [[nodiscard]] bool shutdown();

    /**
     * Accept one connection on this server. You must have at least one client
     * session before calling this.
@@ -142,14 +153,31 @@ public:
    void onSessionTerminating(const sp<RpcSession>& session);

private:
    /** This is not a pipe. */
    struct FdTrigger {
        static std::unique_ptr<FdTrigger> make();
        /**
         * poll() on this fd for POLLHUP to get notification when trigger is called
         */
        base::borrowed_fd readFd() const { return mRead; }
        /**
         * Close the write end of the pipe so that the read end receives POLLHUP.
         */
        void trigger();

    private:
        base::unique_fd mWrite;
        base::unique_fd mRead;
    };

    friend sp<RpcServer>;
    RpcServer();

    void establishConnection(sp<RpcServer>&& session, base::unique_fd clientFd);
    bool setupSocketServer(const RpcSocketAddress& address);
    [[nodiscard]] bool acceptOneNoCheck();

    bool mAgreedExperimental = false;
    bool mStarted = false; // TODO(b/185167543): support dynamically added clients
    size_t mMaxThreads = 1;
    base::unique_fd mServer; // socket we are accepting sessions on

@@ -159,6 +187,9 @@ private:
    wp<IBinder> mRootObjectWeak;
    std::map<int32_t, sp<RpcSession>> mSessions;
    int32_t mSessionIdCounter = 0;
    bool mJoinThreadRunning = false;
    std::unique_ptr<FdTrigger> mShutdownTrigger;
    std::condition_variable mShutdownCv;
};

} // namespace android
+50 −0
Original line number Diff line number Diff line
@@ -40,6 +40,8 @@
#include "../RpcState.h"   // for debugging
#include "../vm_sockets.h" // for VMADDR_*

using namespace std::chrono_literals;

namespace android {

TEST(BinderRpcParcel, EntireParcelFormatted) {
@@ -970,6 +972,54 @@ TEST_P(BinderRpcServerRootObject, WeakRootObject) {
INSTANTIATE_TEST_CASE_P(BinderRpc, BinderRpcServerRootObject,
                        ::testing::Combine(::testing::Bool(), ::testing::Bool()));

class OneOffSignal {
public:
    // If notify() was previously called, or is called within |duration|, return true; else false.
    template <typename R, typename P>
    bool wait(std::chrono::duration<R, P> duration) {
        std::unique_lock<std::mutex> lock(mMutex);
        return mCv.wait_for(lock, duration, [this] { return mValue; });
    }
    void notify() {
        std::unique_lock<std::mutex> lock(mMutex);
        mValue = true;
        lock.unlock();
        mCv.notify_all();
    }

private:
    std::mutex mMutex;
    std::condition_variable mCv;
    bool mValue = false;
};

TEST(BinderRpc, Shutdown) {
    auto addr = allocateSocketAddress();
    unlink(addr.c_str());
    auto server = RpcServer::make();
    server->iUnderstandThisCodeIsExperimentalAndIWillNotUseItInProduction();
    ASSERT_TRUE(server->setupUnixDomainServer(addr.c_str()));
    auto joinEnds = std::make_shared<OneOffSignal>();

    // If things are broken and the thread never stops, don't block other tests. Because the thread
    // may run after the test finishes, it must not access the stack memory of the test. Hence,
    // shared pointers are passed.
    std::thread([server, joinEnds] {
        server->join();
        joinEnds->notify();
    }).detach();

    bool shutdown = false;
    for (int i = 0; i < 10 && !shutdown; i++) {
        usleep(300 * 1000); // 300ms; total 3s
        if (server->shutdown()) shutdown = true;
    }
    ASSERT_TRUE(shutdown) << "server->shutdown() never returns true";

    ASSERT_TRUE(joinEnds->wait(2s))
            << "After server->shutdown() returns true, join() did not stop after 2s";
}

} // namespace android

int main(int argc, char** argv) {