Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 3c400807 authored by Yifan Hong's avatar Yifan Hong Committed by Gerrit Code Review
Browse files

Merge "binder: RpcServer / RpcSession add API for certs."

parents 060d7f33 ecf937dd
Loading
Loading
Loading
Loading
+18 −8
Original line number Original line Diff line number Diff line
@@ -39,8 +39,7 @@ namespace android {
using base::ScopeGuard;
using base::ScopeGuard;
using base::unique_fd;
using base::unique_fd;


RpcServer::RpcServer(std::unique_ptr<RpcTransportCtxFactory> rpcTransportCtxFactory)
RpcServer::RpcServer(std::unique_ptr<RpcTransportCtx> ctx) : mCtx(std::move(ctx)) {}
      : mRpcTransportCtxFactory(std::move(rpcTransportCtxFactory)) {}
RpcServer::~RpcServer() {
RpcServer::~RpcServer() {
    (void)shutdown();
    (void)shutdown();
}
}
@@ -49,7 +48,9 @@ sp<RpcServer> RpcServer::make(std::unique_ptr<RpcTransportCtxFactory> rpcTranspo
    // Default is without TLS.
    // Default is without TLS.
    if (rpcTransportCtxFactory == nullptr)
    if (rpcTransportCtxFactory == nullptr)
        rpcTransportCtxFactory = RpcTransportCtxFactoryRaw::make();
        rpcTransportCtxFactory = RpcTransportCtxFactoryRaw::make();
    return sp<RpcServer>::make(std::move(rpcTransportCtxFactory));
    auto ctx = rpcTransportCtxFactory->newServerCtx();
    if (ctx == nullptr) return nullptr;
    return sp<RpcServer>::make(std::move(ctx));
}
}


void RpcServer::iUnderstandThisCodeIsExperimentalAndIWillNotUseItInProduction() {
void RpcServer::iUnderstandThisCodeIsExperimentalAndIWillNotUseItInProduction() {
@@ -138,6 +139,20 @@ sp<IBinder> RpcServer::getRootObject() {
    return ret;
    return ret;
}
}


std::string RpcServer::getCertificate(CertificateFormat format) {
    std::lock_guard<std::mutex> _l(mLock);
    return mCtx->getCertificate(format);
}

status_t RpcServer::addTrustedPeerCertificate(CertificateFormat format, std::string_view cert) {
    std::lock_guard<std::mutex> _l(mLock);
    // Ensure that join thread is not running or shutdown trigger is not set up. In either case,
    // it means there are child threads running. It is invalid to add trusted peer certificates
    // after join thread and/or child threads are running to avoid race condition.
    if (mJoinThreadRunning || mShutdownTrigger != nullptr) return INVALID_OPERATION;
    return mCtx->addTrustedPeerCertificate(format, cert);
}

static void joinRpcServer(sp<RpcServer>&& thiz) {
static void joinRpcServer(sp<RpcServer>&& thiz) {
    thiz->join();
    thiz->join();
}
}
@@ -159,10 +174,6 @@ void RpcServer::join() {
        mJoinThreadRunning = true;
        mJoinThreadRunning = true;
        mShutdownTrigger = FdTrigger::make();
        mShutdownTrigger = FdTrigger::make();
        LOG_ALWAYS_FATAL_IF(mShutdownTrigger == nullptr, "Cannot create join signaler");
        LOG_ALWAYS_FATAL_IF(mShutdownTrigger == nullptr, "Cannot create join signaler");

        mCtx = mRpcTransportCtxFactory->newServerCtx();
        LOG_ALWAYS_FATAL_IF(mCtx == nullptr, "Unable to create RpcTransportCtx with %s sockets",
                            mRpcTransportCtxFactory->toCString());
    }
    }


    status_t status;
    status_t status;
@@ -229,7 +240,6 @@ bool RpcServer::shutdown() {
    LOG_RPC_DETAIL("Finished waiting on shutdown.");
    LOG_RPC_DETAIL("Finished waiting on shutdown.");


    mShutdownTrigger = nullptr;
    mShutdownTrigger = nullptr;
    mCtx = nullptr;
    return true;
    return true;
}
}


+27 −20
Original line number Original line Diff line number Diff line
@@ -49,8 +49,7 @@ namespace android {


using base::unique_fd;
using base::unique_fd;


RpcSession::RpcSession(std::unique_ptr<RpcTransportCtxFactory> rpcTransportCtxFactory)
RpcSession::RpcSession(std::unique_ptr<RpcTransportCtx> ctx) : mCtx(std::move(ctx)) {
      : mRpcTransportCtxFactory(std::move(rpcTransportCtxFactory)) {
    LOG_RPC_DETAIL("RpcSession created %p", this);
    LOG_RPC_DETAIL("RpcSession created %p", this);


    mState = std::make_unique<RpcState>();
    mState = std::make_unique<RpcState>();
@@ -63,11 +62,26 @@ RpcSession::~RpcSession() {
                        "Should not be able to destroy a session with servers in use.");
                        "Should not be able to destroy a session with servers in use.");
}
}


sp<RpcSession> RpcSession::make(std::unique_ptr<RpcTransportCtxFactory> rpcTransportCtxFactory) {
sp<RpcSession> RpcSession::make() {
    // Default is without TLS.
    // Default is without TLS.
    if (rpcTransportCtxFactory == nullptr)
    return make(RpcTransportCtxFactoryRaw::make(), std::nullopt, std::nullopt);
        rpcTransportCtxFactory = RpcTransportCtxFactoryRaw::make();
}
    return sp<RpcSession>::make(std::move(rpcTransportCtxFactory));

sp<RpcSession> RpcSession::make(std::unique_ptr<RpcTransportCtxFactory> rpcTransportCtxFactory,
                                std::optional<CertificateFormat> serverCertificateFormat,
                                std::optional<std::string> serverCertificate) {
    auto ctx = rpcTransportCtxFactory->newClientCtx();
    if (ctx == nullptr) return nullptr;
    LOG_ALWAYS_FATAL_IF(serverCertificateFormat.has_value() != serverCertificate.has_value());
    if (serverCertificateFormat.has_value() && serverCertificate.has_value()) {
        status_t status =
                ctx->addTrustedPeerCertificate(*serverCertificateFormat, *serverCertificate);
        if (status != OK) {
            ALOGE("Cannot add trusted server certificate: %s", statusToString(status).c_str());
            return nullptr;
        }
    }
    return sp<RpcSession>::make(std::move(ctx));
}
}


void RpcSession::setMaxThreads(size_t threads) {
void RpcSession::setMaxThreads(size_t threads) {
@@ -155,12 +169,7 @@ status_t RpcSession::addNullDebuggingClient() {
        return -savedErrno;
        return -savedErrno;
    }
    }


    auto ctx = mRpcTransportCtxFactory->newClientCtx();
    auto server = mCtx->newTransport(std::move(serverFd), mShutdownTrigger.get());
    if (ctx == nullptr) {
        ALOGE("Unable to create RpcTransportCtx for null debugging client");
        return NO_MEMORY;
    }
    auto server = ctx->newTransport(std::move(serverFd), mShutdownTrigger.get());
    if (server == nullptr) {
    if (server == nullptr) {
        ALOGE("Unable to set up RpcTransport");
        ALOGE("Unable to set up RpcTransport");
        return UNKNOWN_ERROR;
        return UNKNOWN_ERROR;
@@ -531,15 +540,9 @@ status_t RpcSession::setupOneSocketConnection(const RpcSocketAddress& addr,
status_t RpcSession::initAndAddConnection(unique_fd fd, const RpcAddress& sessionId,
status_t RpcSession::initAndAddConnection(unique_fd fd, const RpcAddress& sessionId,
                                          bool incoming) {
                                          bool incoming) {
    LOG_ALWAYS_FATAL_IF(mShutdownTrigger == nullptr);
    LOG_ALWAYS_FATAL_IF(mShutdownTrigger == nullptr);
    auto ctx = mRpcTransportCtxFactory->newClientCtx();
    auto server = mCtx->newTransport(std::move(fd), mShutdownTrigger.get());
    if (ctx == nullptr) {
        ALOGE("Unable to create client RpcTransportCtx with %s sockets",
              mRpcTransportCtxFactory->toCString());
        return NO_MEMORY;
    }
    auto server = ctx->newTransport(std::move(fd), mShutdownTrigger.get());
    if (server == nullptr) {
    if (server == nullptr) {
        ALOGE("Unable to set up RpcTransport in %s context", mRpcTransportCtxFactory->toCString());
        ALOGE("%s: Unable to set up RpcTransport", __PRETTY_FUNCTION__);
        return UNKNOWN_ERROR;
        return UNKNOWN_ERROR;
    }
    }


@@ -694,6 +697,10 @@ bool RpcSession::removeIncomingConnection(const sp<RpcConnection>& connection) {
    return false;
    return false;
}
}


std::string RpcSession::getCertificate(CertificateFormat format) {
    return mCtx->getCertificate(format);
}

status_t RpcSession::ExclusiveConnection::find(const sp<RpcSession>& session, ConnectionUse use,
status_t RpcSession::ExclusiveConnection::find(const sp<RpcSession>& session, ConnectionUse use,
                                               ExclusiveConnection* connection) {
                                               ExclusiveConnection* connection) {
    connection->mSession = session;
    connection->mSession = session;
+13 −3
Original line number Original line Diff line number Diff line
@@ -133,6 +133,17 @@ public:
    void setRootObjectWeak(const wp<IBinder>& binder);
    void setRootObjectWeak(const wp<IBinder>& binder);
    sp<IBinder> getRootObject();
    sp<IBinder> getRootObject();


    /**
     * See RpcTransportCtx::getCertificate
     */
    std::string getCertificate(CertificateFormat);

    /**
     * See RpcTransportCtx::addTrustedPeerCertificate.
     * Thread-safe. This is only possible before the server is join()-ing.
     */
    status_t addTrustedPeerCertificate(CertificateFormat, std::string_view cert);

    /**
    /**
     * Runs join() in a background thread. Immediately returns.
     * Runs join() in a background thread. Immediately returns.
     */
     */
@@ -170,7 +181,7 @@ public:


private:
private:
    friend sp<RpcServer>;
    friend sp<RpcServer>;
    explicit RpcServer(std::unique_ptr<RpcTransportCtxFactory> rpcTransportCtxFactory);
    explicit RpcServer(std::unique_ptr<RpcTransportCtx> ctx);


    void onSessionAllIncomingThreadsEnded(const sp<RpcSession>& session) override;
    void onSessionAllIncomingThreadsEnded(const sp<RpcSession>& session) override;
    void onSessionIncomingThreadEnded() override;
    void onSessionIncomingThreadEnded() override;
@@ -178,7 +189,7 @@ private:
    static void establishConnection(sp<RpcServer>&& server, base::unique_fd clientFd);
    static void establishConnection(sp<RpcServer>&& server, base::unique_fd clientFd);
    status_t setupSocketServer(const RpcSocketAddress& address);
    status_t setupSocketServer(const RpcSocketAddress& address);


    const std::unique_ptr<RpcTransportCtxFactory> mRpcTransportCtxFactory;
    const std::unique_ptr<RpcTransportCtx> mCtx;
    bool mAgreedExperimental = false;
    bool mAgreedExperimental = false;
    size_t mMaxThreads = 1;
    size_t mMaxThreads = 1;
    std::optional<uint32_t> mProtocolVersion;
    std::optional<uint32_t> mProtocolVersion;
@@ -193,7 +204,6 @@ private:
    std::map<RpcAddress, sp<RpcSession>> mSessions;
    std::map<RpcAddress, sp<RpcSession>> mSessions;
    std::unique_ptr<FdTrigger> mShutdownTrigger;
    std::unique_ptr<FdTrigger> mShutdownTrigger;
    std::condition_variable mShutdownCv;
    std::condition_variable mShutdownCv;
    std::unique_ptr<RpcTransportCtx> mCtx;
};
};


} // namespace android
} // namespace android
+16 −4
Original line number Original line Diff line number Diff line
@@ -51,8 +51,15 @@ constexpr uint32_t RPC_WIRE_PROTOCOL_VERSION = RPC_WIRE_PROTOCOL_VERSION_EXPERIM
 */
 */
class RpcSession final : public virtual RefBase {
class RpcSession final : public virtual RefBase {
public:
public:
    static sp<RpcSession> make(
    // Create an RpcSession with default configuration (raw sockets).
            std::unique_ptr<RpcTransportCtxFactory> rpcTransportCtxFactory = nullptr);
    static sp<RpcSession> make();

    // Create an RpcSession with the given configuration. |serverCertificateFormat| and
    // |serverCertificate| must have values or be nullopt simultaneously. If they have values, set
    // server certificate.
    static sp<RpcSession> make(std::unique_ptr<RpcTransportCtxFactory> rpcTransportCtxFactory,
                               std::optional<CertificateFormat> serverCertificateFormat,
                               std::optional<std::string> serverCertificate);


    /**
    /**
     * Set the maximum number of threads allowed to be made (for things like callbacks).
     * Set the maximum number of threads allowed to be made (for things like callbacks).
@@ -124,6 +131,11 @@ public:
     */
     */
    status_t getRemoteMaxThreads(size_t* maxThreads);
    status_t getRemoteMaxThreads(size_t* maxThreads);


    /**
     * See RpcTransportCtx::getCertificate
     */
    std::string getCertificate(CertificateFormat);

    /**
    /**
     * Shuts down the service.
     * Shuts down the service.
     *
     *
@@ -159,7 +171,7 @@ private:
    friend sp<RpcSession>;
    friend sp<RpcSession>;
    friend RpcServer;
    friend RpcServer;
    friend RpcState;
    friend RpcState;
    explicit RpcSession(std::unique_ptr<RpcTransportCtxFactory> rpcTransportCtxFactory);
    explicit RpcSession(std::unique_ptr<RpcTransportCtx> ctx);


    class EventListener : public virtual RefBase {
    class EventListener : public virtual RefBase {
    public:
    public:
@@ -259,7 +271,7 @@ private:
        bool mReentrant = false;
        bool mReentrant = false;
    };
    };


    const std::unique_ptr<RpcTransportCtxFactory> mRpcTransportCtxFactory;
    const std::unique_ptr<RpcTransportCtx> mCtx;


    // On the other side of a session, for each of mOutgoingConnections here, there should
    // On the other side of a session, for each of mOutgoingConnections here, there should
    // be one of mIncomingConnections on the other side (and vice versa).
    // be one of mIncomingConnections on the other side (and vice versa).
+4 −2
Original line number Original line Diff line number Diff line
@@ -522,7 +522,8 @@ public:
        status_t status;
        status_t status;


        for (size_t i = 0; i < options.numSessions; i++) {
        for (size_t i = 0; i < options.numSessions; i++) {
            sp<RpcSession> session = RpcSession::make(newFactory(rpcSecurity));
            sp<RpcSession> session =
                    RpcSession::make(newFactory(rpcSecurity), std::nullopt, std::nullopt);
            session->setMaxThreads(options.numIncomingConnections);
            session->setMaxThreads(options.numIncomingConnections);


            switch (socketType) {
            switch (socketType) {
@@ -1207,7 +1208,8 @@ static bool testSupportVsockLoopback() {
    }
    }
    server->start();
    server->start();


    sp<RpcSession> session = RpcSession::make(RpcTransportCtxFactoryRaw::make());
    sp<RpcSession> session =
            RpcSession::make(RpcTransportCtxFactoryRaw::make(), std::nullopt, std::nullopt);
    status_t status = session->setupVsockClient(VMADDR_CID_LOCAL, vsockPort);
    status_t status = session->setupVsockClient(VMADDR_CID_LOCAL, vsockPort);
    while (!server->shutdown()) usleep(10000);
    while (!server->shutdown()) usleep(10000);
    ALOGE("Detected vsock loopback supported: %s", statusToString(status).c_str());
    ALOGE("Detected vsock loopback supported: %s", statusToString(status).c_str());
Loading