Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit ecf937dd authored by Yifan Hong's avatar Yifan Hong
Browse files

binder: RpcServer / RpcSession add API for certs.

These APIs call into RpcTransportCtx::getCertificate
and RpcTransportClientCtx::addTrustedPeerCertificate,
respectively.

For RpcSession, peer (server) certificates are fixed when
it is constructed.

Test: binderRpcTest
Bug: 195166979
Change-Id: I0d1bd93042895aeb3ab1de4fe6b9d90e73d0d116
parent 588d59c6
Loading
Loading
Loading
Loading
+18 −8
Original line number Diff line number Diff line
@@ -39,8 +39,7 @@ namespace android {
using base::ScopeGuard;
using base::unique_fd;

RpcServer::RpcServer(std::unique_ptr<RpcTransportCtxFactory> rpcTransportCtxFactory)
      : mRpcTransportCtxFactory(std::move(rpcTransportCtxFactory)) {}
RpcServer::RpcServer(std::unique_ptr<RpcTransportCtx> ctx) : mCtx(std::move(ctx)) {}
RpcServer::~RpcServer() {
    (void)shutdown();
}
@@ -49,7 +48,9 @@ sp<RpcServer> RpcServer::make(std::unique_ptr<RpcTransportCtxFactory> rpcTranspo
    // Default is without TLS.
    if (rpcTransportCtxFactory == nullptr)
        rpcTransportCtxFactory = RpcTransportCtxFactoryRaw::make();
    return sp<RpcServer>::make(std::move(rpcTransportCtxFactory));
    auto ctx = rpcTransportCtxFactory->newServerCtx();
    if (ctx == nullptr) return nullptr;
    return sp<RpcServer>::make(std::move(ctx));
}

void RpcServer::iUnderstandThisCodeIsExperimentalAndIWillNotUseItInProduction() {
@@ -138,6 +139,20 @@ sp<IBinder> RpcServer::getRootObject() {
    return ret;
}

std::string RpcServer::getCertificate(CertificateFormat format) {
    std::lock_guard<std::mutex> _l(mLock);
    return mCtx->getCertificate(format);
}

status_t RpcServer::addTrustedPeerCertificate(CertificateFormat format, std::string_view cert) {
    std::lock_guard<std::mutex> _l(mLock);
    // Ensure that join thread is not running or shutdown trigger is not set up. In either case,
    // it means there are child threads running. It is invalid to add trusted peer certificates
    // after join thread and/or child threads are running to avoid race condition.
    if (mJoinThreadRunning || mShutdownTrigger != nullptr) return INVALID_OPERATION;
    return mCtx->addTrustedPeerCertificate(format, cert);
}

static void joinRpcServer(sp<RpcServer>&& thiz) {
    thiz->join();
}
@@ -159,10 +174,6 @@ void RpcServer::join() {
        mJoinThreadRunning = true;
        mShutdownTrigger = FdTrigger::make();
        LOG_ALWAYS_FATAL_IF(mShutdownTrigger == nullptr, "Cannot create join signaler");

        mCtx = mRpcTransportCtxFactory->newServerCtx();
        LOG_ALWAYS_FATAL_IF(mCtx == nullptr, "Unable to create RpcTransportCtx with %s sockets",
                            mRpcTransportCtxFactory->toCString());
    }

    status_t status;
@@ -229,7 +240,6 @@ bool RpcServer::shutdown() {
    LOG_RPC_DETAIL("Finished waiting on shutdown.");

    mShutdownTrigger = nullptr;
    mCtx = nullptr;
    return true;
}

+27 −20
Original line number Diff line number Diff line
@@ -49,8 +49,7 @@ namespace android {

using base::unique_fd;

RpcSession::RpcSession(std::unique_ptr<RpcTransportCtxFactory> rpcTransportCtxFactory)
      : mRpcTransportCtxFactory(std::move(rpcTransportCtxFactory)) {
RpcSession::RpcSession(std::unique_ptr<RpcTransportCtx> ctx) : mCtx(std::move(ctx)) {
    LOG_RPC_DETAIL("RpcSession created %p", this);

    mState = std::make_unique<RpcState>();
@@ -63,11 +62,26 @@ RpcSession::~RpcSession() {
                        "Should not be able to destroy a session with servers in use.");
}

sp<RpcSession> RpcSession::make(std::unique_ptr<RpcTransportCtxFactory> rpcTransportCtxFactory) {
sp<RpcSession> RpcSession::make() {
    // Default is without TLS.
    if (rpcTransportCtxFactory == nullptr)
        rpcTransportCtxFactory = RpcTransportCtxFactoryRaw::make();
    return sp<RpcSession>::make(std::move(rpcTransportCtxFactory));
    return make(RpcTransportCtxFactoryRaw::make(), std::nullopt, std::nullopt);
}

sp<RpcSession> RpcSession::make(std::unique_ptr<RpcTransportCtxFactory> rpcTransportCtxFactory,
                                std::optional<CertificateFormat> serverCertificateFormat,
                                std::optional<std::string> serverCertificate) {
    auto ctx = rpcTransportCtxFactory->newClientCtx();
    if (ctx == nullptr) return nullptr;
    LOG_ALWAYS_FATAL_IF(serverCertificateFormat.has_value() != serverCertificate.has_value());
    if (serverCertificateFormat.has_value() && serverCertificate.has_value()) {
        status_t status =
                ctx->addTrustedPeerCertificate(*serverCertificateFormat, *serverCertificate);
        if (status != OK) {
            ALOGE("Cannot add trusted server certificate: %s", statusToString(status).c_str());
            return nullptr;
        }
    }
    return sp<RpcSession>::make(std::move(ctx));
}

void RpcSession::setMaxThreads(size_t threads) {
@@ -155,12 +169,7 @@ status_t RpcSession::addNullDebuggingClient() {
        return -savedErrno;
    }

    auto ctx = mRpcTransportCtxFactory->newClientCtx();
    if (ctx == nullptr) {
        ALOGE("Unable to create RpcTransportCtx for null debugging client");
        return NO_MEMORY;
    }
    auto server = ctx->newTransport(std::move(serverFd), mShutdownTrigger.get());
    auto server = mCtx->newTransport(std::move(serverFd), mShutdownTrigger.get());
    if (server == nullptr) {
        ALOGE("Unable to set up RpcTransport");
        return UNKNOWN_ERROR;
@@ -529,15 +538,9 @@ status_t RpcSession::setupOneSocketConnection(const RpcSocketAddress& addr,
status_t RpcSession::initAndAddConnection(unique_fd fd, const RpcAddress& sessionId,
                                          bool incoming) {
    LOG_ALWAYS_FATAL_IF(mShutdownTrigger == nullptr);
    auto ctx = mRpcTransportCtxFactory->newClientCtx();
    if (ctx == nullptr) {
        ALOGE("Unable to create client RpcTransportCtx with %s sockets",
              mRpcTransportCtxFactory->toCString());
        return NO_MEMORY;
    }
    auto server = ctx->newTransport(std::move(fd), mShutdownTrigger.get());
    auto server = mCtx->newTransport(std::move(fd), mShutdownTrigger.get());
    if (server == nullptr) {
        ALOGE("Unable to set up RpcTransport in %s context", mRpcTransportCtxFactory->toCString());
        ALOGE("%s: Unable to set up RpcTransport", __PRETTY_FUNCTION__);
        return UNKNOWN_ERROR;
    }

@@ -692,6 +695,10 @@ bool RpcSession::removeIncomingConnection(const sp<RpcConnection>& connection) {
    return false;
}

std::string RpcSession::getCertificate(CertificateFormat format) {
    return mCtx->getCertificate(format);
}

status_t RpcSession::ExclusiveConnection::find(const sp<RpcSession>& session, ConnectionUse use,
                                               ExclusiveConnection* connection) {
    connection->mSession = session;
+13 −3
Original line number Diff line number Diff line
@@ -133,6 +133,17 @@ public:
    void setRootObjectWeak(const wp<IBinder>& binder);
    sp<IBinder> getRootObject();

    /**
     * See RpcTransportCtx::getCertificate
     */
    std::string getCertificate(CertificateFormat);

    /**
     * See RpcTransportCtx::addTrustedPeerCertificate.
     * Thread-safe. This is only possible before the server is join()-ing.
     */
    status_t addTrustedPeerCertificate(CertificateFormat, std::string_view cert);

    /**
     * Runs join() in a background thread. Immediately returns.
     */
@@ -170,7 +181,7 @@ public:

private:
    friend sp<RpcServer>;
    explicit RpcServer(std::unique_ptr<RpcTransportCtxFactory> rpcTransportCtxFactory);
    explicit RpcServer(std::unique_ptr<RpcTransportCtx> ctx);

    void onSessionAllIncomingThreadsEnded(const sp<RpcSession>& session) override;
    void onSessionIncomingThreadEnded() override;
@@ -178,7 +189,7 @@ private:
    static void establishConnection(sp<RpcServer>&& server, base::unique_fd clientFd);
    status_t setupSocketServer(const RpcSocketAddress& address);

    const std::unique_ptr<RpcTransportCtxFactory> mRpcTransportCtxFactory;
    const std::unique_ptr<RpcTransportCtx> mCtx;
    bool mAgreedExperimental = false;
    size_t mMaxThreads = 1;
    std::optional<uint32_t> mProtocolVersion;
@@ -193,7 +204,6 @@ private:
    std::map<RpcAddress, sp<RpcSession>> mSessions;
    std::unique_ptr<FdTrigger> mShutdownTrigger;
    std::condition_variable mShutdownCv;
    std::unique_ptr<RpcTransportCtx> mCtx;
};

} // namespace android
+16 −4
Original line number Diff line number Diff line
@@ -51,8 +51,15 @@ constexpr uint32_t RPC_WIRE_PROTOCOL_VERSION = RPC_WIRE_PROTOCOL_VERSION_EXPERIM
 */
class RpcSession final : public virtual RefBase {
public:
    static sp<RpcSession> make(
            std::unique_ptr<RpcTransportCtxFactory> rpcTransportCtxFactory = nullptr);
    // Create an RpcSession with default configuration (raw sockets).
    static sp<RpcSession> make();

    // Create an RpcSession with the given configuration. |serverCertificateFormat| and
    // |serverCertificate| must have values or be nullopt simultaneously. If they have values, set
    // server certificate.
    static sp<RpcSession> make(std::unique_ptr<RpcTransportCtxFactory> rpcTransportCtxFactory,
                               std::optional<CertificateFormat> serverCertificateFormat,
                               std::optional<std::string> serverCertificate);

    /**
     * Set the maximum number of threads allowed to be made (for things like callbacks).
@@ -124,6 +131,11 @@ public:
     */
    status_t getRemoteMaxThreads(size_t* maxThreads);

    /**
     * See RpcTransportCtx::getCertificate
     */
    std::string getCertificate(CertificateFormat);

    /**
     * Shuts down the service.
     *
@@ -159,7 +171,7 @@ private:
    friend sp<RpcSession>;
    friend RpcServer;
    friend RpcState;
    explicit RpcSession(std::unique_ptr<RpcTransportCtxFactory> rpcTransportCtxFactory);
    explicit RpcSession(std::unique_ptr<RpcTransportCtx> ctx);

    class EventListener : public virtual RefBase {
    public:
@@ -259,7 +271,7 @@ private:
        bool mReentrant = false;
    };

    const std::unique_ptr<RpcTransportCtxFactory> mRpcTransportCtxFactory;
    const std::unique_ptr<RpcTransportCtx> mCtx;

    // On the other side of a session, for each of mOutgoingConnections here, there should
    // be one of mIncomingConnections on the other side (and vice versa).
+4 −2
Original line number Diff line number Diff line
@@ -522,7 +522,8 @@ public:
        status_t status;

        for (size_t i = 0; i < options.numSessions; i++) {
            sp<RpcSession> session = RpcSession::make(newFactory(rpcSecurity));
            sp<RpcSession> session =
                    RpcSession::make(newFactory(rpcSecurity), std::nullopt, std::nullopt);
            session->setMaxThreads(options.numIncomingConnections);

            switch (socketType) {
@@ -1207,7 +1208,8 @@ static bool testSupportVsockLoopback() {
    }
    server->start();

    sp<RpcSession> session = RpcSession::make(RpcTransportCtxFactoryRaw::make());
    sp<RpcSession> session =
            RpcSession::make(RpcTransportCtxFactoryRaw::make(), std::nullopt, std::nullopt);
    status_t status = session->setupVsockClient(VMADDR_CID_LOCAL, vsockPort);
    while (!server->shutdown()) usleep(10000);
    ALOGE("Detected vsock loopback supported: %s", statusToString(status).c_str());
Loading