Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 826367f2 authored by Steven Moreland's avatar Steven Moreland
Browse files

libbinder: Session ID implemented directly.

In preparation for removing RpcAddress.

Bug: 182940634
Test: binderRpcTest (w & w/o LOG_RPC_DETAIL)
Change-Id: I945e650bbab9f8df4f785b689983b62c59bb8674
parent 33e1d32c
Loading
Loading
Loading
Loading
+38 −12
Original line number Diff line number Diff line
@@ -23,6 +23,8 @@
#include <thread>
#include <vector>

#include <android-base/file.h>
#include <android-base/hex.h>
#include <android-base/scopeguard.h>
#include <binder/Parcel.h>
#include <binder/RpcServer.h>
@@ -290,17 +292,29 @@ void RpcServer::establishConnection(sp<RpcServer>&& server, base::unique_fd clie
        }
    }

    std::vector<uint8_t> sessionId;
    if (status == OK) {
        if (header.sessionIdSize > 0) {
            sessionId.resize(header.sessionIdSize);
            status = client->interruptableReadFully(server->mShutdownTrigger.get(),
                                                    sessionId.data(), sessionId.size());
            if (status != OK) {
                ALOGE("Failed to read session ID for client connecting to RPC server: %s",
                      statusToString(status).c_str());
                // still need to cleanup before we can return
            }
        }
    }

    bool incoming = false;
    uint32_t protocolVersion = 0;
    RpcAddress sessionId = RpcAddress::zero();
    bool requestingNewSession = false;

    if (status == OK) {
        incoming = header.options & RPC_CONNECTION_OPTION_INCOMING;
        protocolVersion = std::min(header.version,
                                   server->mProtocolVersion.value_or(RPC_WIRE_PROTOCOL_VERSION));
        sessionId = RpcAddress::fromRawEmbedded(&header.sessionId);
        requestingNewSession = sessionId.isZero();
        requestingNewSession = sessionId.empty();

        if (requestingNewSession) {
            RpcNewSessionResponse response{
@@ -342,15 +356,26 @@ void RpcServer::establishConnection(sp<RpcServer>&& server, base::unique_fd clie
                return;
            }

            // Uniquely identify session at the application layer. Even if a
            // client/server use the same certificates, if they create multiple
            // sessions, we still want to distinguish between them.
            constexpr size_t kSessionIdSize = 32;
            sessionId.resize(kSessionIdSize);
            size_t tries = 0;
            do {
                // don't block if there is some entropy issue
                if (tries++ > 5) {
                    ALOGE("Cannot find new address: %s", sessionId.toString().c_str());
                    ALOGE("Cannot find new address: %s",
                          base::HexString(sessionId.data(), sessionId.size()).c_str());
                    return;
                }

                sessionId = RpcAddress::random(true /*forServer*/);
                base::unique_fd fd(TEMP_FAILURE_RETRY(
                        open("/dev/urandom", O_RDONLY | O_CLOEXEC | O_NOFOLLOW)));
                if (!base::ReadFully(fd, sessionId.data(), sessionId.size())) {
                    ALOGE("Could not read from /dev/urandom to create session ID");
                    return;
                }
            } while (server->mSessions.end() != server->mSessions.find(sessionId));

            session = RpcSession::make();
@@ -370,7 +395,7 @@ void RpcServer::establishConnection(sp<RpcServer>&& server, base::unique_fd clie
            auto it = server->mSessions.find(sessionId);
            if (it == server->mSessions.end()) {
                ALOGE("Cannot add thread, no record of session with ID %s",
                      sessionId.toString().c_str());
                      base::HexString(sessionId.data(), sessionId.size()).c_str());
                return;
            }
            session = it->second;
@@ -432,16 +457,17 @@ status_t RpcServer::setupSocketServer(const RpcSocketAddress& addr) {
}

void RpcServer::onSessionAllIncomingThreadsEnded(const sp<RpcSession>& session) {
    auto id = session->mId;
    LOG_ALWAYS_FATAL_IF(id == std::nullopt, "Server sessions must be initialized with ID");
    LOG_RPC_DETAIL("Dropping session with address %s", id->toString().c_str());
    const std::vector<uint8_t>& id = session->mId;
    LOG_ALWAYS_FATAL_IF(id.empty(), "Server sessions must be initialized with ID");
    LOG_RPC_DETAIL("Dropping session with address %s",
                   base::HexString(id.data(), id.size()).c_str());

    std::lock_guard<std::mutex> _l(mLock);
    auto it = mSessions.find(*id);
    auto it = mSessions.find(id);
    LOG_ALWAYS_FATAL_IF(it == mSessions.end(), "Bad state, unknown session id %s",
                        id->toString().c_str());
                        base::HexString(id.data(), id.size()).c_str());
    LOG_ALWAYS_FATAL_IF(it->second != session, "Bad state, session has id mismatch %s",
                        id->toString().c_str());
                        base::HexString(id.data(), id.size()).c_str());
    (void)mSessions.erase(it);
}

+36 −19
Original line number Diff line number Diff line
@@ -26,6 +26,7 @@

#include <string_view>

#include <android-base/hex.h>
#include <android-base/macros.h>
#include <android_runtime/vm.h>
#include <binder/Parcel.h>
@@ -143,7 +144,7 @@ status_t RpcSession::setupInetClient(const char* addr, unsigned int port) {
}

status_t RpcSession::setupPreconnectedClient(unique_fd fd, std::function<unique_fd()>&& request) {
    return setupClient([&](const RpcAddress& sessionId, bool incoming) -> status_t {
    return setupClient([&](const std::vector<uint8_t>& sessionId, bool incoming) -> status_t {
        // std::move'd from fd becomes -1 (!ok())
        if (!fd.ok()) {
            fd = request();
@@ -244,12 +245,11 @@ status_t RpcSession::readId() {
                                                ConnectionUse::CLIENT, &connection);
    if (status != OK) return status;

    mId = RpcAddress::zero();
    status = state()->getSessionId(connection.get(), sp<RpcSession>::fromExisting(this),
                                   &mId.value());
    status = state()->getSessionId(connection.get(), sp<RpcSession>::fromExisting(this), &mId);
    if (status != OK) return status;

    LOG_RPC_DETAIL("RpcSession %p has id %s", this, mId->toString().c_str());
    LOG_RPC_DETAIL("RpcSession %p has id %s", this,
                   base::HexString(mId.data(), mId.size()).c_str());
    return OK;
}

@@ -408,8 +408,8 @@ sp<RpcServer> RpcSession::server() {
    return server;
}

status_t RpcSession::setupClient(
        const std::function<status_t(const RpcAddress& sessionId, bool incoming)>& connectAndInit) {
status_t RpcSession::setupClient(const std::function<status_t(const std::vector<uint8_t>& sessionId,
                                                              bool incoming)>& connectAndInit) {
    {
        std::lock_guard<std::mutex> _l(mMutex);
        LOG_ALWAYS_FATAL_IF(mOutgoingConnections.size() != 0,
@@ -418,8 +418,7 @@ status_t RpcSession::setupClient(
    }
    if (auto status = initShutdownTrigger(); status != OK) return status;

    if (status_t status = connectAndInit(RpcAddress::zero(), false /*incoming*/); status != OK)
        return status;
    if (status_t status = connectAndInit({}, false /*incoming*/); status != OK) return status;

    {
        ExclusiveConnection connection;
@@ -460,26 +459,25 @@ status_t RpcSession::setupClient(

    // we've already setup one client
    for (size_t i = 0; i + 1 < numThreadsAvailable; i++) {
        if (status_t status = connectAndInit(mId.value(), false /*incoming*/); status != OK)
            return status;
        if (status_t status = connectAndInit(mId, false /*incoming*/); status != OK) return status;
    }

    for (size_t i = 0; i < mMaxThreads; i++) {
        if (status_t status = connectAndInit(mId.value(), true /*incoming*/); status != OK)
            return status;
        if (status_t status = connectAndInit(mId, true /*incoming*/); status != OK) return status;
    }

    return OK;
}

status_t RpcSession::setupSocketClient(const RpcSocketAddress& addr) {
    return setupClient([&](const RpcAddress& sessionId, bool incoming) {
    return setupClient([&](const std::vector<uint8_t>& sessionId, bool incoming) {
        return setupOneSocketConnection(addr, sessionId, incoming);
    });
}

status_t RpcSession::setupOneSocketConnection(const RpcSocketAddress& addr,
                                              const RpcAddress& sessionId, bool incoming) {
                                              const std::vector<uint8_t>& sessionId,
                                              bool incoming) {
    for (size_t tries = 0; tries < 5; tries++) {
        if (tries > 0) usleep(10000);

@@ -537,7 +535,7 @@ status_t RpcSession::setupOneSocketConnection(const RpcSocketAddress& addr,
    return UNKNOWN_ERROR;
}

status_t RpcSession::initAndAddConnection(unique_fd fd, const RpcAddress& sessionId,
status_t RpcSession::initAndAddConnection(unique_fd fd, const std::vector<uint8_t>& sessionId,
                                          bool incoming) {
    LOG_ALWAYS_FATAL_IF(mShutdownTrigger == nullptr);
    auto server = mCtx->newTransport(std::move(fd), mShutdownTrigger.get());
@@ -548,13 +546,20 @@ status_t RpcSession::initAndAddConnection(unique_fd fd, const RpcAddress& sessio

    LOG_RPC_DETAIL("Socket at client with RpcTransport %p", server.get());

    if (sessionId.size() > std::numeric_limits<uint16_t>::max()) {
        ALOGE("Session ID too big %zu", sessionId.size());
        return BAD_VALUE;
    }

    RpcConnectionHeader header{
            .version = mProtocolVersion.value_or(RPC_WIRE_PROTOCOL_VERSION),
            .options = 0,
            .sessionIdSize = static_cast<uint16_t>(sessionId.size()),
    };
    memcpy(&header.sessionId, &sessionId.viewRawEmbedded(), sizeof(RpcWireAddress));

    if (incoming) header.options |= RPC_CONNECTION_OPTION_INCOMING;
    if (incoming) {
        header.options |= RPC_CONNECTION_OPTION_INCOMING;
    }

    auto sendHeaderStatus =
            server->interruptableWriteFully(mShutdownTrigger.get(), &header, sizeof(header));
@@ -564,6 +569,18 @@ status_t RpcSession::initAndAddConnection(unique_fd fd, const RpcAddress& sessio
        return sendHeaderStatus;
    }

    if (sessionId.size() > 0) {
        auto sendSessionIdStatus =
                server->interruptableWriteFully(mShutdownTrigger.get(), sessionId.data(),
                                                sessionId.size());
        if (sendSessionIdStatus != OK) {
            ALOGE("Could not write session ID ('%s') to socket: %s",
                  base::HexString(sessionId.data(), sessionId.size()).c_str(),
                  statusToString(sendSessionIdStatus).c_str());
            return sendSessionIdStatus;
        }
    }

    LOG_RPC_DETAIL("Socket at client: header sent");

    if (incoming) {
@@ -636,7 +653,7 @@ status_t RpcSession::addOutgoingConnection(std::unique_ptr<RpcTransport> rpcTran
}

bool RpcSession::setForServer(const wp<RpcServer>& server, const wp<EventListener>& eventListener,
                              const RpcAddress& sessionId) {
                              const std::vector<uint8_t>& sessionId) {
    LOG_ALWAYS_FATAL_IF(mForServer != nullptr);
    LOG_ALWAYS_FATAL_IF(server == nullptr);
    LOG_ALWAYS_FATAL_IF(mEventListener != nullptr);
+3 −3
Original line number Diff line number Diff line
@@ -397,7 +397,7 @@ status_t RpcState::getMaxThreads(const sp<RpcSession::RpcConnection>& connection
}

status_t RpcState::getSessionId(const sp<RpcSession::RpcConnection>& connection,
                                const sp<RpcSession>& session, RpcAddress* sessionIdOut) {
                                const sp<RpcSession>& session, std::vector<uint8_t>* sessionIdOut) {
    Parcel data;
    data.markForRpc(session);
    Parcel reply;
@@ -410,7 +410,7 @@ status_t RpcState::getSessionId(const sp<RpcSession::RpcConnection>& connection,
        return status;
    }

    return sessionIdOut->readFromParcel(reply);
    return reply.readByteVector(sessionIdOut);
}

status_t RpcState::transact(const sp<RpcSession::RpcConnection>& connection,
@@ -792,7 +792,7 @@ processTransactInternalTailCall:
                    // for client connections, this should always report the value
                    // originally returned from the server, so this is asserting
                    // that it exists
                    replyStatus = session->mId.value().writeToParcel(&reply);
                    replyStatus = reply.writeByteVector(session->mId);
                    break;
                }
                default: {
+1 −1
Original line number Diff line number Diff line
@@ -73,7 +73,7 @@ public:
    status_t getMaxThreads(const sp<RpcSession::RpcConnection>& connection,
                           const sp<RpcSession>& session, size_t* maxThreadsOut);
    status_t getSessionId(const sp<RpcSession::RpcConnection>& connection,
                          const sp<RpcSession>& session, RpcAddress* sessionIdOut);
                          const sp<RpcSession>& session, std::vector<uint8_t>* sessionIdOut);

    [[nodiscard]] status_t transact(const sp<RpcSession::RpcConnection>& connection,
                                    const sp<IBinder>& address, uint32_t code, const Parcel& data,
+6 −7
Original line number Diff line number Diff line
@@ -20,9 +20,7 @@ namespace android {
#pragma clang diagnostic push
#pragma clang diagnostic error "-Wpadded"

enum : uint8_t {
    RPC_CONNECTION_OPTION_INCOMING = 0x1, // default is outgoing
};
constexpr uint8_t RPC_CONNECTION_OPTION_INCOMING = 0x1; // default is outgoing

constexpr uint64_t RPC_WIRE_ADDRESS_OPTION_CREATED = 1 << 0; // distinguish from '0' address
constexpr uint64_t RPC_WIRE_ADDRESS_OPTION_FOR_SERVER = 1 << 1;
@@ -39,12 +37,13 @@ static_assert(sizeof(RpcWireAddress) == 40);
 */
struct RpcConnectionHeader {
    uint32_t version; // maximum supported by caller
    uint8_t reserver0[4];
    RpcWireAddress sessionId;
    uint8_t options;
    uint8_t reserved1[7];
    uint8_t reservered[9];
    // Follows is sessionIdSize bytes.
    // if size is 0, this is requesting a new session.
    uint16_t sessionIdSize;
};
static_assert(sizeof(RpcConnectionHeader) == 56);
static_assert(sizeof(RpcConnectionHeader) == 16);

/**
 * In response to an RpcConnectionHeader which corresponds to a new session,
Loading