Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit bdc7074d authored by David Brazdil's avatar David Brazdil Committed by Automerger Merge Worker
Browse files

Merge "RpcBinder: Add AF_UNIX socketpair transport" am: 08178a72 am: 6b7d4592 am: 41260c33

parents efc4b287 41260c33
Loading
Loading
Loading
Loading
+99 −0
Original line number Diff line number Diff line
@@ -18,6 +18,7 @@

#include <android-base/file.h>
#include <binder/RpcTransportRaw.h>
#include <log/log.h>
#include <string.h>

using android::base::ErrnoError;
@@ -25,6 +26,9 @@ using android::base::Result;

namespace android {

// Linux kernel supports up to 253 (from SCM_MAX_FD) for unix sockets.
constexpr size_t kMaxFdsPerMsg = 253;

Result<void> setNonBlocking(android::base::borrowed_fd fd) {
    int flags = TEMP_FAILURE_RETRY(fcntl(fd.get(), F_GETFL));
    if (flags == -1) {
@@ -63,4 +67,99 @@ std::unique_ptr<RpcTransportCtxFactory> makeDefaultRpcTransportCtxFactory() {
    return RpcTransportCtxFactoryRaw::make();
}

int sendMessageOnSocket(
        const RpcTransportFd& socket, iovec* iovs, int niovs,
        const std::vector<std::variant<base::unique_fd, base::borrowed_fd>>* ancillaryFds) {
    if (ancillaryFds != nullptr && !ancillaryFds->empty()) {
        if (ancillaryFds->size() > kMaxFdsPerMsg) {
            errno = EINVAL;
            return -1;
        }

        // CMSG_DATA is not necessarily aligned, so we copy the FDs into a buffer and then
        // use memcpy.
        int fds[kMaxFdsPerMsg];
        for (size_t i = 0; i < ancillaryFds->size(); i++) {
            fds[i] = std::visit([](const auto& fd) { return fd.get(); }, ancillaryFds->at(i));
        }
        const size_t fdsByteSize = sizeof(int) * ancillaryFds->size();

        alignas(struct cmsghdr) char msgControlBuf[CMSG_SPACE(sizeof(int) * kMaxFdsPerMsg)];

        msghdr msg{
                .msg_iov = iovs,
                .msg_iovlen = static_cast<decltype(msg.msg_iovlen)>(niovs),
                .msg_control = msgControlBuf,
                .msg_controllen = sizeof(msgControlBuf),
        };

        cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
        cmsg->cmsg_level = SOL_SOCKET;
        cmsg->cmsg_type = SCM_RIGHTS;
        cmsg->cmsg_len = CMSG_LEN(fdsByteSize);
        memcpy(CMSG_DATA(cmsg), fds, fdsByteSize);

        msg.msg_controllen = CMSG_SPACE(fdsByteSize);
        return TEMP_FAILURE_RETRY(sendmsg(socket.fd.get(), &msg, MSG_NOSIGNAL | MSG_CMSG_CLOEXEC));
    }

    msghdr msg{
            .msg_iov = iovs,
            // posix uses int, glibc uses size_t.  niovs is a
            // non-negative int and can be cast to either.
            .msg_iovlen = static_cast<decltype(msg.msg_iovlen)>(niovs),
    };
    return TEMP_FAILURE_RETRY(sendmsg(socket.fd.get(), &msg, MSG_NOSIGNAL));
}

int receiveMessageFromSocket(
        const RpcTransportFd& socket, iovec* iovs, int niovs,
        std::vector<std::variant<base::unique_fd, base::borrowed_fd>>* ancillaryFds) {
    if (ancillaryFds != nullptr) {
        int fdBuffer[kMaxFdsPerMsg];
        alignas(struct cmsghdr) char msgControlBuf[CMSG_SPACE(sizeof(fdBuffer))];

        msghdr msg{
                .msg_iov = iovs,
                .msg_iovlen = static_cast<decltype(msg.msg_iovlen)>(niovs),
                .msg_control = msgControlBuf,
                .msg_controllen = sizeof(msgControlBuf),
        };
        ssize_t processSize = TEMP_FAILURE_RETRY(recvmsg(socket.fd.get(), &msg, MSG_NOSIGNAL));
        if (processSize < 0) {
            return -1;
        }

        for (cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); cmsg != nullptr; cmsg = CMSG_NXTHDR(&msg, cmsg)) {
            if (cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SCM_RIGHTS) {
                // NOTE: It is tempting to reinterpret_cast, but cmsg(3) explicitly asks
                // application devs to memcpy the data to ensure memory alignment.
                size_t dataLen = cmsg->cmsg_len - CMSG_LEN(0);
                LOG_ALWAYS_FATAL_IF(dataLen > sizeof(fdBuffer)); // validity check
                memcpy(fdBuffer, CMSG_DATA(cmsg), dataLen);
                size_t fdCount = dataLen / sizeof(int);
                ancillaryFds->reserve(ancillaryFds->size() + fdCount);
                for (size_t i = 0; i < fdCount; i++) {
                    ancillaryFds->emplace_back(base::unique_fd(fdBuffer[i]));
                }
                break;
            }
        }

        if (msg.msg_flags & MSG_CTRUNC) {
            errno = EPIPE;
            return -1;
        }
        return processSize;
    }
    msghdr msg{
            .msg_iov = iovs,
            // posix uses int, glibc uses size_t.  niovs is a
            // non-negative int and can be cast to either.
            .msg_iovlen = static_cast<decltype(msg.msg_iovlen)>(niovs),
    };

    return TEMP_FAILURE_RETRY(recvmsg(socket.fd.get(), &msg, MSG_NOSIGNAL));
}

} // namespace android
+8 −0
Original line number Diff line number Diff line
@@ -33,4 +33,12 @@ status_t dupFileDescriptor(int oldFd, int* newFd);

std::unique_ptr<RpcTransportCtxFactory> makeDefaultRpcTransportCtxFactory();

int sendMessageOnSocket(
        const RpcTransportFd& socket, iovec* iovs, int niovs,
        const std::vector<std::variant<base::unique_fd, base::borrowed_fd>>* ancillaryFds);

int receiveMessageFromSocket(
        const RpcTransportFd& socket, iovec* iovs, int niovs,
        std::vector<std::variant<base::unique_fd, base::borrowed_fd>>* ancillaryFds);

} // namespace android
+61 −11
Original line number Diff line number Diff line
@@ -37,6 +37,7 @@
#include "OS.h"
#include "RpcSocketAddress.h"
#include "RpcState.h"
#include "RpcTransportUtils.h"
#include "RpcWireFormat.h"
#include "Utils.h"

@@ -61,6 +62,10 @@ sp<RpcServer> RpcServer::make(std::unique_ptr<RpcTransportCtxFactory> rpcTranspo
    return sp<RpcServer>::make(std::move(ctx));
}

status_t RpcServer::setupUnixDomainSocketBootstrapServer(unique_fd bootstrapFd) {
    return setupExternalServer(std::move(bootstrapFd), &RpcServer::recvmsgSocketConnection);
}

status_t RpcServer::setupUnixDomainServer(const char* path) {
    return setupSocketServer(UnixSocketAddress(path));
}
@@ -177,11 +182,50 @@ void RpcServer::start() {
    rpcJoinIfSingleThreaded(*mJoinThread);
}

status_t RpcServer::acceptSocketConnection(const RpcServer& server, RpcTransportFd* out) {
    RpcTransportFd clientSocket(unique_fd(TEMP_FAILURE_RETRY(
            accept4(server.mServer.fd.get(), nullptr, nullptr, SOCK_CLOEXEC | SOCK_NONBLOCK))));
    if (clientSocket.fd < 0) {
        int savedErrno = errno;
        ALOGE("Could not accept4 socket: %s", strerror(savedErrno));
        return -savedErrno;
    }

    *out = std::move(clientSocket);
    return OK;
}

status_t RpcServer::recvmsgSocketConnection(const RpcServer& server, RpcTransportFd* out) {
    int zero = 0;
    iovec iov{&zero, sizeof(zero)};
    std::vector<std::variant<base::unique_fd, base::borrowed_fd>> fds;

    if (receiveMessageFromSocket(server.mServer, &iov, 1, &fds) < 0) {
        int savedErrno = errno;
        ALOGE("Failed recvmsg: %s", strerror(savedErrno));
        return -savedErrno;
    }
    if (fds.size() != 1) {
        ALOGE("Expected exactly one fd from recvmsg, got %zu", fds.size());
        return -EINVAL;
    }

    unique_fd fd(std::move(std::get<unique_fd>(fds.back())));
    if (auto res = setNonBlocking(fd); !res.ok()) {
        ALOGE("Failed setNonBlocking: %s", res.error().message().c_str());
        return res.error().code() == 0 ? UNKNOWN_ERROR : -res.error().code();
    }

    *out = RpcTransportFd(std::move(fd));
    return OK;
}

void RpcServer::join() {

    {
        RpcMutexLockGuard _l(mLock);
        LOG_ALWAYS_FATAL_IF(!mServer.fd.ok(), "RpcServer must be setup to join.");
        LOG_ALWAYS_FATAL_IF(mAcceptFn == nullptr, "RpcServer must have an accept() function");
        LOG_ALWAYS_FATAL_IF(mShutdownTrigger != nullptr, "Already joined");
        mJoinThreadRunning = true;
        mShutdownTrigger = FdTrigger::make();
@@ -192,20 +236,19 @@ void RpcServer::join() {
    while ((status = mShutdownTrigger->triggerablePoll(mServer, POLLIN)) == OK) {
        std::array<uint8_t, kRpcAddressSize> addr;
        static_assert(addr.size() >= sizeof(sockaddr_storage), "kRpcAddressSize is too small");

        socklen_t addrLen = addr.size();
        RpcTransportFd clientSocket(unique_fd(TEMP_FAILURE_RETRY(
                accept4(mServer.fd.get(), reinterpret_cast<sockaddr*>(addr.data()), &addrLen,
                        SOCK_CLOEXEC | SOCK_NONBLOCK))));

        LOG_ALWAYS_FATAL_IF(addrLen > static_cast<socklen_t>(sizeof(sockaddr_storage)),
                            "Truncated address");

        if (clientSocket.fd < 0) {
            ALOGE("Could not accept4 socket: %s", strerror(errno));
        RpcTransportFd clientSocket;
        if (mAcceptFn(*this, &clientSocket) != OK) {
            continue;
        }
        if (getpeername(clientSocket.fd.get(), reinterpret_cast<sockaddr*>(addr.data()),
                        &addrLen)) {
            ALOGE("Could not getpeername socket: %s", strerror(errno));
            continue;
        }
        LOG_RPC_DETAIL("accept4 on fd %d yields fd %d", mServer.fd.get(), clientSocket.fd.get());

        LOG_RPC_DETAIL("accept on fd %d yields fd %d", mServer.fd.get(), clientSocket.fd.get());

        {
            RpcMutexLockGuard _l(mLock);
@@ -550,16 +593,23 @@ unique_fd RpcServer::releaseServer() {
    return std::move(mServer.fd);
}

status_t RpcServer::setupExternalServer(base::unique_fd serverFd) {
status_t RpcServer::setupExternalServer(
        base::unique_fd serverFd,
        std::function<status_t(const RpcServer&, RpcTransportFd*)>&& acceptFn) {
    RpcMutexLockGuard _l(mLock);
    if (mServer.fd.ok()) {
        ALOGE("Each RpcServer can only have one server.");
        return INVALID_OPERATION;
    }
    mServer = std::move(serverFd);
    mAcceptFn = std::move(acceptFn);
    return OK;
}

status_t RpcServer::setupExternalServer(base::unique_fd serverFd) {
    return setupExternalServer(std::move(serverFd), &RpcServer::acceptSocketConnection);
}

bool RpcServer::hasActiveRequests() {
    RpcMutexLockGuard _l(mLock);
    for (const auto& [_, session] : mSessions) {
+29 −0
Original line number Diff line number Diff line
@@ -41,6 +41,7 @@
#include "OS.h"
#include "RpcSocketAddress.h"
#include "RpcState.h"
#include "RpcTransportUtils.h"
#include "RpcWireFormat.h"
#include "Utils.h"

@@ -147,6 +148,34 @@ status_t RpcSession::setupUnixDomainClient(const char* path) {
    return setupSocketClient(UnixSocketAddress(path));
}

status_t RpcSession::setupUnixDomainSocketBootstrapClient(unique_fd bootstrapFd) {
    mBootstrapTransport =
            mCtx->newTransport(RpcTransportFd(std::move(bootstrapFd)), mShutdownTrigger.get());
    return setupClient([&](const std::vector<uint8_t>& sessionId, bool incoming) {
        int socks[2];
        if (socketpair(AF_UNIX, SOCK_STREAM | SOCK_CLOEXEC | SOCK_NONBLOCK, 0, socks) < 0) {
            int savedErrno = errno;
            ALOGE("Failed socketpair: %s", strerror(savedErrno));
            return -savedErrno;
        }
        unique_fd clientFd(socks[0]), serverFd(socks[1]);

        int zero = 0;
        iovec iov{&zero, sizeof(zero)};
        std::vector<std::variant<base::unique_fd, base::borrowed_fd>> fds;
        fds.push_back(std::move(serverFd));

        status_t status = mBootstrapTransport->interruptableWriteFully(mShutdownTrigger.get(), &iov,
                                                                       1, std::nullopt, &fds);
        if (status != OK) {
            ALOGE("Failed to send fd over bootstrap transport: %s", strerror(-status));
            return status;
        }

        return initAndAddConnection(RpcTransportFd(std::move(clientFd)), sessionId, incoming);
    });
}

status_t RpcSession::setupVsockClient(unsigned int cid, unsigned int port) {
    return setupSocketClient(VsockSocketAddress(cid, port));
}
+5 −102
Original line number Diff line number Diff line
@@ -23,6 +23,7 @@
#include <binder/RpcTransportRaw.h>

#include "FdTrigger.h"
#include "OS.h"
#include "RpcState.h"
#include "RpcTransportUtils.h"

@@ -30,9 +31,6 @@ namespace android {

namespace {

// Linux kernel supports up to 253 (from SCM_MAX_FD) for unix sockets.
constexpr size_t kMaxFdsPerMsg = 253;

// RpcTransport with TLS disabled.
class RpcTransportRaw : public RpcTransport {
public:
@@ -63,57 +61,9 @@ public:
            override {
        bool sentFds = false;
        auto send = [&](iovec* iovs, int niovs) -> ssize_t {
            if (ancillaryFds != nullptr && !ancillaryFds->empty() && !sentFds) {
                if (ancillaryFds->size() > kMaxFdsPerMsg) {
                    // This shouldn't happen because we check the FD count in RpcState.
                    ALOGE("Saw too many file descriptors in RpcTransportCtxRaw: %zu (max is %zu). "
                          "Aborting session.",
                          ancillaryFds->size(), kMaxFdsPerMsg);
                    errno = EINVAL;
                    return -1;
                }

                // CMSG_DATA is not necessarily aligned, so we copy the FDs into a buffer and then
                // use memcpy.
                int fds[kMaxFdsPerMsg];
                for (size_t i = 0; i < ancillaryFds->size(); i++) {
                    fds[i] = std::visit([](const auto& fd) { return fd.get(); },
                                        ancillaryFds->at(i));
                }
                const size_t fdsByteSize = sizeof(int) * ancillaryFds->size();

                alignas(struct cmsghdr) char msgControlBuf[CMSG_SPACE(sizeof(int) * kMaxFdsPerMsg)];

                msghdr msg{
                        .msg_iov = iovs,
                        .msg_iovlen = static_cast<decltype(msg.msg_iovlen)>(niovs),
                        .msg_control = msgControlBuf,
                        .msg_controllen = sizeof(msgControlBuf),
                };

                cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
                cmsg->cmsg_level = SOL_SOCKET;
                cmsg->cmsg_type = SCM_RIGHTS;
                cmsg->cmsg_len = CMSG_LEN(fdsByteSize);
                memcpy(CMSG_DATA(cmsg), fds, fdsByteSize);

                msg.msg_controllen = CMSG_SPACE(fdsByteSize);

                ssize_t processedSize = TEMP_FAILURE_RETRY(
                        sendmsg(mSocket.fd.get(), &msg, MSG_NOSIGNAL | MSG_CMSG_CLOEXEC));
                if (processedSize > 0) {
                    sentFds = true;
                }
                return processedSize;
            }

            msghdr msg{
                    .msg_iov = iovs,
                    // posix uses int, glibc uses size_t.  niovs is a
                    // non-negative int and can be cast to either.
                    .msg_iovlen = static_cast<decltype(msg.msg_iovlen)>(niovs),
            };
            return TEMP_FAILURE_RETRY(sendmsg(mSocket.fd.get(), &msg, MSG_NOSIGNAL));
            int ret = sendMessageOnSocket(mSocket, iovs, niovs, sentFds ? nullptr : ancillaryFds);
            sentFds |= ret > 0;
            return ret;
        };
        return interruptableReadOrWrite(mSocket, fdTrigger, iovs, niovs, send, "sendmsg", POLLOUT,
                                        altPoll);
@@ -124,54 +74,7 @@ public:
            const std::optional<android::base::function_ref<status_t()>>& altPoll,
            std::vector<std::variant<base::unique_fd, base::borrowed_fd>>* ancillaryFds) override {
        auto recv = [&](iovec* iovs, int niovs) -> ssize_t {
            if (ancillaryFds != nullptr) {
                int fdBuffer[kMaxFdsPerMsg];
                alignas(struct cmsghdr) char msgControlBuf[CMSG_SPACE(sizeof(fdBuffer))];

                msghdr msg{
                        .msg_iov = iovs,
                        .msg_iovlen = static_cast<decltype(msg.msg_iovlen)>(niovs),
                        .msg_control = msgControlBuf,
                        .msg_controllen = sizeof(msgControlBuf),
                };
                ssize_t processSize =
                        TEMP_FAILURE_RETRY(recvmsg(mSocket.fd.get(), &msg, MSG_NOSIGNAL));
                if (processSize < 0) {
                    return -1;
                }

                for (cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); cmsg != nullptr;
                     cmsg = CMSG_NXTHDR(&msg, cmsg)) {
                    if (cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SCM_RIGHTS) {
                        // NOTE: It is tempting to reinterpret_cast, but cmsg(3) explicitly asks
                        // application devs to memcpy the data to ensure memory alignment.
                        size_t dataLen = cmsg->cmsg_len - CMSG_LEN(0);
                        LOG_ALWAYS_FATAL_IF(dataLen > sizeof(fdBuffer)); // sanity check
                        memcpy(fdBuffer, CMSG_DATA(cmsg), dataLen);
                        size_t fdCount = dataLen / sizeof(int);
                        ancillaryFds->reserve(ancillaryFds->size() + fdCount);
                        for (size_t i = 0; i < fdCount; i++) {
                            ancillaryFds->emplace_back(base::unique_fd(fdBuffer[i]));
                        }
                        break;
                    }
                }

                if (msg.msg_flags & MSG_CTRUNC) {
                    ALOGE("msg was truncated. Aborting session.");
                    errno = EPIPE;
                    return -1;
                }

                return processSize;
            }
            msghdr msg{
                    .msg_iov = iovs,
                    // posix uses int, glibc uses size_t.  niovs is a
                    // non-negative int and can be cast to either.
                    .msg_iovlen = static_cast<decltype(msg.msg_iovlen)>(niovs),
            };
            return TEMP_FAILURE_RETRY(recvmsg(mSocket.fd.get(), &msg, MSG_NOSIGNAL));
            return receiveMessageFromSocket(mSocket, iovs, niovs, ancillaryFds);
        };
        return interruptableReadOrWrite(mSocket, fdTrigger, iovs, niovs, recv, "recvmsg", POLLIN,
                                        altPoll);
Loading