Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit aa87e1d0 authored by Steven Moreland's avatar Steven Moreland Committed by Automerger Merge Worker
Browse files

Merge "libbinder: allow externally created connections" am: 91c6a9f5 am: ecf935cd

Original change: https://android-review.googlesource.com/c/platform/frameworks/native/+/1785497

Change-Id: I6bb274e43669e1a4abcbaf8b49383cdd44796b78
parents f3eabb3b ecf935cd
Loading
Loading
Loading
Loading
+70 −50
Original line number Original line Diff line number Diff line
@@ -126,6 +126,17 @@ bool RpcSession::setupInetClient(const char* addr, unsigned int port) {
    return false;
    return false;
}
}


bool RpcSession::setupPreconnectedClient(unique_fd fd, std::function<unique_fd()>&& request) {
    return setupClient([&](const RpcAddress& sessionId, bool incoming) {
        // std::move'd from fd becomes -1 (!ok())
        if (!fd.ok()) {
            fd = request();
            if (!fd.ok()) return false;
        }
        return initAndAddConnection(std::move(fd), sessionId, incoming);
    });
}

bool RpcSession::addNullDebuggingClient() {
bool RpcSession::addNullDebuggingClient() {
    // Note: only works on raw sockets.
    // Note: only works on raw sockets.
    unique_fd serverFd(TEMP_FAILURE_RETRY(open("/dev/null", O_WRONLY | O_CLOEXEC)));
    unique_fd serverFd(TEMP_FAILURE_RETRY(open("/dev/null", O_WRONLY | O_CLOEXEC)));
@@ -464,7 +475,8 @@ sp<RpcServer> RpcSession::server() {
    return server;
    return server;
}
}


bool RpcSession::setupSocketClient(const RpcSocketAddress& addr) {
bool RpcSession::setupClient(
        const std::function<bool(const RpcAddress& sessionId, bool incoming)>& connectAndInit) {
    {
    {
        std::lock_guard<std::mutex> _l(mMutex);
        std::lock_guard<std::mutex> _l(mMutex);
        LOG_ALWAYS_FATAL_IF(mOutgoingConnections.size() != 0,
        LOG_ALWAYS_FATAL_IF(mOutgoingConnections.size() != 0,
@@ -472,7 +484,7 @@ bool RpcSession::setupSocketClient(const RpcSocketAddress& addr) {
                            mOutgoingConnections.size());
                            mOutgoingConnections.size());
    }
    }


    if (!setupOneSocketConnection(addr, RpcAddress::zero(), false /*incoming*/)) return false;
    if (!connectAndInit(RpcAddress::zero(), false /*incoming*/)) return false;


    {
    {
        ExclusiveConnection connection;
        ExclusiveConnection connection;
@@ -491,37 +503,42 @@ bool RpcSession::setupSocketClient(const RpcSocketAddress& addr) {
    // TODO(b/186470974): first risk of blocking
    // TODO(b/186470974): first risk of blocking
    size_t numThreadsAvailable;
    size_t numThreadsAvailable;
    if (status_t status = getRemoteMaxThreads(&numThreadsAvailable); status != OK) {
    if (status_t status = getRemoteMaxThreads(&numThreadsAvailable); status != OK) {
        ALOGE("Could not get max threads after initial session to %s: %s", addr.toString().c_str(),
        ALOGE("Could not get max threads after initial session setup: %s",
              statusToString(status).c_str());
              statusToString(status).c_str());
        return false;
        return false;
    }
    }


    if (status_t status = readId(); status != OK) {
    if (status_t status = readId(); status != OK) {
        ALOGE("Could not get session id after initial session to %s; %s", addr.toString().c_str(),
        ALOGE("Could not get session id after initial session setup: %s",
              statusToString(status).c_str());
              statusToString(status).c_str());
        return false;
        return false;
    }
    }


    // we've already setup one client
    for (size_t i = 0; i + 1 < numThreadsAvailable; i++) {
        // TODO(b/189955605): shutdown existing connections?
        if (!setupOneSocketConnection(addr, mId.value(), false /*incoming*/)) return false;
    }

    // TODO(b/189955605): we should add additional sessions dynamically
    // TODO(b/189955605): we should add additional sessions dynamically
    // instead of all at once - the other side should be responsible for setting
    // instead of all at once - the other side should be responsible for setting
    // up additional connections. We need to create at least one (unless 0 are
    // up additional connections. We need to create at least one (unless 0 are
    // requested to be set) in order to allow the other side to reliably make
    // requested to be set) in order to allow the other side to reliably make
    // any requests at all.
    // any requests at all.


    // we've already setup one client
    for (size_t i = 0; i + 1 < numThreadsAvailable; i++) {
        if (!connectAndInit(mId.value(), false /*incoming*/)) return false;
    }

    for (size_t i = 0; i < mMaxThreads; i++) {
    for (size_t i = 0; i < mMaxThreads; i++) {
        if (!setupOneSocketConnection(addr, mId.value(), true /*incoming*/)) return false;
        if (!connectAndInit(mId.value(), true /*incoming*/)) return false;
    }
    }


    return true;
    return true;
}
}


bool RpcSession::setupOneSocketConnection(const RpcSocketAddress& addr, const RpcAddress& id,
bool RpcSession::setupSocketClient(const RpcSocketAddress& addr) {
    return setupClient([&](const RpcAddress& sessionId, bool incoming) {
        return setupOneSocketConnection(addr, sessionId, incoming);
    });
}

bool RpcSession::setupOneSocketConnection(const RpcSocketAddress& addr, const RpcAddress& sessionId,
                                          bool incoming) {
                                          bool incoming) {
    for (size_t tries = 0; tries < 5; tries++) {
    for (size_t tries = 0; tries < 5; tries++) {
        if (tries > 0) usleep(10000);
        if (tries > 0) usleep(10000);
@@ -547,52 +564,55 @@ bool RpcSession::setupOneSocketConnection(const RpcSocketAddress& addr, const Rp
        }
        }
        LOG_RPC_DETAIL("Socket at %s client with fd %d", addr.toString().c_str(), serverFd.get());
        LOG_RPC_DETAIL("Socket at %s client with fd %d", addr.toString().c_str(), serverFd.get());


        return initAndAddConnection(std::move(serverFd), sessionId, incoming);
    }

    ALOGE("Ran out of retries to connect to %s", addr.toString().c_str());
    return false;
}

bool RpcSession::initAndAddConnection(unique_fd fd, const RpcAddress& sessionId, bool incoming) {
    auto ctx = mRpcTransportCtxFactory->newClientCtx();
    auto ctx = mRpcTransportCtxFactory->newClientCtx();
    if (ctx == nullptr) {
    if (ctx == nullptr) {
        ALOGE("Unable to create client RpcTransportCtx with %s sockets",
        ALOGE("Unable to create client RpcTransportCtx with %s sockets",
              mRpcTransportCtxFactory->toCString());
              mRpcTransportCtxFactory->toCString());
        return false;
        return false;
    }
    }
        auto server = ctx->newTransport(std::move(serverFd));
    auto server = ctx->newTransport(std::move(fd));
    if (server == nullptr) {
    if (server == nullptr) {
            ALOGE("Unable to set up RpcTransport for %s", addr.toString().c_str());
        ALOGE("Unable to set up RpcTransport in %s context", mRpcTransportCtxFactory->toCString());
        return false;
        return false;
    }
    }


        LOG_RPC_DETAIL("Socket at %s client with RpcTransport %p", addr.toString().c_str(),
    LOG_RPC_DETAIL("Socket at client with RpcTransport %p", server.get());
                       server.get());


    RpcConnectionHeader header{
    RpcConnectionHeader header{
            .version = mProtocolVersion.value_or(RPC_WIRE_PROTOCOL_VERSION),
            .version = mProtocolVersion.value_or(RPC_WIRE_PROTOCOL_VERSION),
            .options = 0,
            .options = 0,
    };
    };
        memcpy(&header.sessionId, &id.viewRawEmbedded(), sizeof(RpcWireAddress));
    memcpy(&header.sessionId, &sessionId.viewRawEmbedded(), sizeof(RpcWireAddress));


    if (incoming) header.options |= RPC_CONNECTION_OPTION_INCOMING;
    if (incoming) header.options |= RPC_CONNECTION_OPTION_INCOMING;


    auto sentHeader = server->send(&header, sizeof(header));
    auto sentHeader = server->send(&header, sizeof(header));
    if (!sentHeader.ok()) {
    if (!sentHeader.ok()) {
            ALOGE("Could not write connection header to socket at %s: %s", addr.toString().c_str(),
        ALOGE("Could not write connection header to socket: %s",
              sentHeader.error().message().c_str());
              sentHeader.error().message().c_str());
        return false;
        return false;
    }
    }
    if (*sentHeader != sizeof(header)) {
    if (*sentHeader != sizeof(header)) {
            ALOGE("Could not write connection header to socket at %s: sent %zd bytes, expected %zd",
        ALOGE("Could not write connection header to socket: sent %zd bytes, expected %zd",
                  addr.toString().c_str(), *sentHeader, sizeof(header));
              *sentHeader, sizeof(header));
        return false;
        return false;
    }
    }


        LOG_RPC_DETAIL("Socket at %s client: header sent", addr.toString().c_str());
    LOG_RPC_DETAIL("Socket at client: header sent");


    if (incoming) {
    if (incoming) {
        return addIncomingConnection(std::move(server));
        return addIncomingConnection(std::move(server));
    } else {
    } else {
            return addOutgoingConnection(std::move(server), true);
        return addOutgoingConnection(std::move(server), true /*init*/);
        }
    }
    }

    ALOGE("Ran out of retries to connect to %s", addr.toString().c_str());
    return false;
}
}


bool RpcSession::addIncomingConnection(std::unique_ptr<RpcTransport> rpcTransport) {
bool RpcSession::addIncomingConnection(std::unique_ptr<RpcTransport> rpcTransport) {
+18 −1
Original line number Original line Diff line number Diff line
@@ -90,6 +90,19 @@ public:
     */
     */
    [[nodiscard]] bool setupInetClient(const char* addr, unsigned int port);
    [[nodiscard]] bool setupInetClient(const char* addr, unsigned int port);


    /**
     * Starts talking to an RPC server which has already been connected to. This
     * is expected to be used when another process has permission to connect to
     * a binder RPC service, but this process only has permission to talk to
     * that service.
     *
     * For convenience, if 'fd' is -1, 'request' will be called.
     *
     * For future compatibility, 'request' should not reference any stack data.
     */
    [[nodiscard]] bool setupPreconnectedClient(base::unique_fd fd,
                                               std::function<base::unique_fd()>&& request);

    /**
    /**
     * For debugging!
     * For debugging!
     *
     *
@@ -240,9 +253,13 @@ private:
    // join on thread passed to preJoinThreadOwnership
    // join on thread passed to preJoinThreadOwnership
    static void join(sp<RpcSession>&& session, PreJoinSetupResult&& result);
    static void join(sp<RpcSession>&& session, PreJoinSetupResult&& result);


    [[nodiscard]] bool setupClient(
            const std::function<bool(const RpcAddress& sessionId, bool incoming)>& connectAndInit);
    [[nodiscard]] bool setupSocketClient(const RpcSocketAddress& address);
    [[nodiscard]] bool setupSocketClient(const RpcSocketAddress& address);
    [[nodiscard]] bool setupOneSocketConnection(const RpcSocketAddress& address,
    [[nodiscard]] bool setupOneSocketConnection(const RpcSocketAddress& address,
                                                const RpcAddress& sessionId, bool server);
                                                const RpcAddress& sessionId, bool incoming);
    [[nodiscard]] bool initAndAddConnection(base::unique_fd fd, const RpcAddress& sessionId,
                                            bool incoming);
    [[nodiscard]] bool addIncomingConnection(std::unique_ptr<RpcTransport> rpcTransport);
    [[nodiscard]] bool addIncomingConnection(std::unique_ptr<RpcTransport> rpcTransport);
    [[nodiscard]] bool addOutgoingConnection(std::unique_ptr<RpcTransport> rpcTransport, bool init);
    [[nodiscard]] bool addOutgoingConnection(std::unique_ptr<RpcTransport> rpcTransport, bool init);
    [[nodiscard]] bool setForServer(const wp<RpcServer>& server,
    [[nodiscard]] bool setForServer(const wp<RpcServer>& server,
+27 −1
Original line number Original line Diff line number Diff line
@@ -42,6 +42,7 @@
#include <sys/prctl.h>
#include <sys/prctl.h>
#include <unistd.h>
#include <unistd.h>


#include "../RpcSocketAddress.h" // for testing preconnected clients
#include "../RpcState.h"   // for debugging
#include "../RpcState.h"   // for debugging
#include "../vm_sockets.h" // for VMADDR_*
#include "../vm_sockets.h" // for VMADDR_*


@@ -409,12 +410,15 @@ struct BinderRpcTestProcessSession {
};
};


enum class SocketType {
enum class SocketType {
    PRECONNECTED,
    UNIX,
    UNIX,
    VSOCK,
    VSOCK,
    INET,
    INET,
};
};
static inline std::string PrintToString(SocketType socketType) {
static inline std::string PrintToString(SocketType socketType) {
    switch (socketType) {
    switch (socketType) {
        case SocketType::PRECONNECTED:
            return "preconnected_uds";
        case SocketType::UNIX:
        case SocketType::UNIX:
            return "unix_domain_socket";
            return "unix_domain_socket";
        case SocketType::VSOCK:
        case SocketType::VSOCK:
@@ -427,6 +431,20 @@ static inline std::string PrintToString(SocketType socketType) {
    }
    }
}
}


static base::unique_fd connectToUds(const char* addrStr) {
    UnixSocketAddress addr(addrStr);
    base::unique_fd serverFd(
            TEMP_FAILURE_RETRY(socket(addr.addr()->sa_family, SOCK_STREAM | SOCK_CLOEXEC, 0)));
    int savedErrno = errno;
    CHECK(serverFd.ok()) << "Could not create socket " << addrStr << ": " << strerror(savedErrno);

    if (0 != TEMP_FAILURE_RETRY(connect(serverFd.get(), addr.addr(), addr.addrSize()))) {
        int savedErrno = errno;
        LOG(FATAL) << "Could not connect to socket " << addrStr << ": " << strerror(savedErrno);
    }
    return serverFd;
}

class BinderRpc : public ::testing::TestWithParam<std::tuple<SocketType, RpcSecurity>> {
class BinderRpc : public ::testing::TestWithParam<std::tuple<SocketType, RpcSecurity>> {
public:
public:
    struct Options {
    struct Options {
@@ -463,6 +481,8 @@ public:
                    unsigned int outPort = 0;
                    unsigned int outPort = 0;


                    switch (socketType) {
                    switch (socketType) {
                        case SocketType::PRECONNECTED:
                            [[fallthrough]];
                        case SocketType::UNIX:
                        case SocketType::UNIX:
                            CHECK(server->setupUnixDomainServer(addr.c_str())) << addr;
                            CHECK(server->setupUnixDomainServer(addr.c_str())) << addr;
                            break;
                            break;
@@ -501,6 +521,12 @@ public:
            session->setMaxThreads(options.numIncomingConnections);
            session->setMaxThreads(options.numIncomingConnections);


            switch (socketType) {
            switch (socketType) {
                case SocketType::PRECONNECTED:
                    if (session->setupPreconnectedClient({}, [=]() {
                            return connectToUds(addr.c_str());
                        }))
                        goto success;
                    break;
                case SocketType::UNIX:
                case SocketType::UNIX:
                    if (session->setupUnixDomainClient(addr.c_str())) goto success;
                    if (session->setupUnixDomainClient(addr.c_str())) goto success;
                    break;
                    break;
@@ -1176,7 +1202,7 @@ static bool testSupportVsockLoopback() {
}
}


static std::vector<SocketType> testSocketTypes() {
static std::vector<SocketType> testSocketTypes() {
    std::vector<SocketType> ret = {SocketType::UNIX, SocketType::INET};
    std::vector<SocketType> ret = {SocketType::PRECONNECTED, SocketType::UNIX, SocketType::INET};


    static bool hasVsockLoopback = testSupportVsockLoopback();
    static bool hasVsockLoopback = testSupportVsockLoopback();