Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit a335edca authored by Frederick Mayle's avatar Frederick Mayle Committed by Automerger Merge Worker
Browse files

Merge "binder: Add FD support to RPC Binder" am: e8b659c3 am: a5532800 am: 61d6f208

parents 0cba7cd5 61d6f208
Loading
Loading
Loading
Loading
+267 −66
Original line number Diff line number Diff line
@@ -40,6 +40,7 @@
#include <binder/Status.h>
#include <binder/TextOutput.h>

#include <android-base/scopeguard.h>
#include <cutils/ashmem.h>
#include <cutils/compiler.h>
#include <utils/Flattenable.h>
@@ -152,6 +153,10 @@ static void release_object(const sp<ProcessState>& proc, const flat_binder_objec
    ALOGE("Invalid object type 0x%08x", obj.hdr.type);
}

static int toRawFd(const std::variant<base::unique_fd, base::borrowed_fd>& v) {
    return std::visit([](const auto& fd) { return fd.get(); }, v);
}

Parcel::RpcFields::RpcFields(const sp<RpcSession>& session) : mSession(session) {
    LOG_ALWAYS_FATAL_IF(mSession == nullptr);
}
@@ -530,6 +535,63 @@ status_t Parcel::appendFrom(const Parcel* parcel, size_t offset, size_t len) {
                }
            }
        }
    } else {
        auto* rpcFields = maybeRpcFields();
        LOG_ALWAYS_FATAL_IF(rpcFields == nullptr);
        auto* otherRpcFields = parcel->maybeRpcFields();
        if (otherRpcFields == nullptr) {
            return BAD_TYPE;
        }
        if (rpcFields->mSession != otherRpcFields->mSession) {
            return BAD_TYPE;
        }

        const size_t savedDataPos = mDataPos;
        base::ScopeGuard scopeGuard = [&]() { mDataPos = savedDataPos; };

        rpcFields->mObjectPositions.reserve(otherRpcFields->mObjectPositions.size());
        if (otherRpcFields->mFds != nullptr) {
            if (rpcFields->mFds == nullptr) {
                rpcFields->mFds = std::make_unique<decltype(rpcFields->mFds)::element_type>();
            }
            rpcFields->mFds->reserve(otherRpcFields->mFds->size());
        }
        for (size_t i = 0; i < otherRpcFields->mObjectPositions.size(); i++) {
            const binder_size_t objPos = otherRpcFields->mObjectPositions[i];
            if (offset <= objPos && objPos < offset + len) {
                size_t newDataPos = objPos - offset + startPos;
                rpcFields->mObjectPositions.push_back(newDataPos);

                mDataPos = newDataPos;
                int32_t objectType;
                if (status_t status = readInt32(&objectType); status != OK) {
                    return status;
                }
                if (objectType != RpcFields::TYPE_NATIVE_FILE_DESCRIPTOR) {
                    continue;
                }

                if (!mAllowFds) {
                    return FDS_NOT_ALLOWED;
                }

                // Read FD, duplicate, and add to list.
                int32_t fdIndex;
                if (status_t status = readInt32(&fdIndex); status != OK) {
                    return status;
                }
                const auto& oldFd = otherRpcFields->mFds->at(fdIndex);
                // To match kernel binder behavior, we always dup, even if the
                // FD was unowned in the source parcel.
                rpcFields->mFds->emplace_back(
                        base::unique_fd(fcntl(toRawFd(oldFd), F_DUPFD_CLOEXEC, 0)));
                // Fixup the index in the data.
                mDataPos = newDataPos + 4;
                if (status_t status = writeInt32(rpcFields->mFds->size() - 1); status != OK) {
                    return status;
                }
            }
        }
    }

    return err;
@@ -584,7 +646,7 @@ void Parcel::restoreAllowFds(bool lastValue)
bool Parcel::hasFileDescriptors() const
{
    if (const auto* rpcFields = maybeRpcFields()) {
        return false;
        return rpcFields->mFds != nullptr && !rpcFields->mFds->empty();
    }
    auto* kernelFields = maybeKernelFields();
    if (!kernelFields->mFdsKnown) {
@@ -621,11 +683,7 @@ std::vector<sp<IBinder>> Parcel::debugReadAllStrongBinders() const {
std::vector<int> Parcel::debugReadAllFileDescriptors() const {
    std::vector<int> ret;

    const auto* kernelFields = maybeKernelFields();
    if (kernelFields == nullptr) {
        return ret;
    }

    if (const auto* kernelFields = maybeKernelFields()) {
        size_t initPosition = dataPosition();
        for (size_t i = 0; i < kernelFields->mObjectsSize; i++) {
            binder_size_t offset = kernelFields->mObjects[i];
@@ -639,16 +697,17 @@ std::vector<int> Parcel::debugReadAllFileDescriptors() const {
            LOG_ALWAYS_FATAL_IF(fd == -1);
            ret.push_back(fd);
        }

        setDataPosition(initPosition);
    } else if (const auto* rpcFields = maybeRpcFields(); rpcFields && rpcFields->mFds) {
        for (const auto& fd : *rpcFields->mFds) {
            ret.push_back(toRawFd(fd));
        }
    }

    return ret;
}

status_t Parcel::hasFileDescriptorsInRange(size_t offset, size_t len, bool* result) const {
    const auto* kernelFields = maybeKernelFields();
    if (kernelFields == nullptr) {
        return BAD_TYPE;
    }
    if (len > INT32_MAX || offset > INT32_MAX) {
        // Don't accept size_t values which may have come from an inadvertent conversion from a
        // negative int.
@@ -659,6 +718,7 @@ status_t Parcel::hasFileDescriptorsInRange(size_t offset, size_t len, bool* resu
        return BAD_VALUE;
    }
    *result = false;
    if (const auto* kernelFields = maybeKernelFields()) {
        for (size_t i = 0; i < kernelFields->mObjectsSize; i++) {
            size_t pos = kernelFields->mObjects[i];
            if (pos < offset) continue;
@@ -669,12 +729,24 @@ status_t Parcel::hasFileDescriptorsInRange(size_t offset, size_t len, bool* resu
                    continue;
                }
            }
        const flat_binder_object* flat = reinterpret_cast<const flat_binder_object*>(mData + pos);
            const flat_binder_object* flat =
                    reinterpret_cast<const flat_binder_object*>(mData + pos);
            if (flat->hdr.type == BINDER_TYPE_FD) {
                *result = true;
                break;
            }
        }
    } else if (const auto* rpcFields = maybeRpcFields()) {
        for (uint32_t pos : rpcFields->mObjectPositions) {
            if (offset <= pos && pos < limit) {
                const auto* type = reinterpret_cast<const RpcFields::ObjectType*>(mData + pos);
                if (*type == RpcFields::TYPE_NATIVE_FILE_DESCRIPTOR) {
                    *result = true;
                    break;
                }
            }
        }
    }
    return NO_ERROR;
}

@@ -1293,11 +1365,40 @@ status_t Parcel::writeNativeHandle(const native_handle* handle)
    return err;
}

status_t Parcel::writeFileDescriptor(int fd, bool takeOwnership)
{
    if (isForRpc()) {
        ALOGE("Cannot write file descriptor to remote binder.");
        return BAD_TYPE;
status_t Parcel::writeFileDescriptor(int fd, bool takeOwnership) {
    if (auto* rpcFields = maybeRpcFields()) {
        std::variant<base::unique_fd, base::borrowed_fd> fdVariant;
        if (takeOwnership) {
            fdVariant = base::unique_fd(fd);
        } else {
            fdVariant = base::borrowed_fd(fd);
        }
        if (!mAllowFds) {
            return FDS_NOT_ALLOWED;
        }
        switch (rpcFields->mSession->getFileDescriptorTransportMode()) {
            case RpcSession::FileDescriptorTransportMode::NONE: {
                return FDS_NOT_ALLOWED;
            }
            case RpcSession::FileDescriptorTransportMode::UNIX: {
                if (rpcFields->mFds == nullptr) {
                    rpcFields->mFds = std::make_unique<decltype(rpcFields->mFds)::element_type>();
                }
                size_t dataPos = mDataPos;
                if (dataPos > UINT32_MAX) {
                    return NO_MEMORY;
                }
                if (status_t err = writeInt32(RpcFields::TYPE_NATIVE_FILE_DESCRIPTOR); err != OK) {
                    return err;
                }
                if (status_t err = writeInt32(rpcFields->mFds->size()); err != OK) {
                    return err;
                }
                rpcFields->mObjectPositions.push_back(dataPos);
                rpcFields->mFds->push_back(std::move(fdVariant));
                return OK;
            }
        }
    }

    flat_binder_object obj;
@@ -2038,8 +2139,31 @@ native_handle* Parcel::readNativeHandle() const
    return h;
}

int Parcel::readFileDescriptor() const
{
int Parcel::readFileDescriptor() const {
    if (const auto* rpcFields = maybeRpcFields()) {
        if (!std::binary_search(rpcFields->mObjectPositions.begin(),
                                rpcFields->mObjectPositions.end(), mDataPos)) {
            ALOGW("Attempt to read file descriptor from Parcel %p at offset %zu that is not in the "
                  "object list",
                  this, mDataPos);
            return BAD_TYPE;
        }

        int32_t objectType = readInt32();
        if (objectType != RpcFields::TYPE_NATIVE_FILE_DESCRIPTOR) {
            return BAD_TYPE;
        }

        int32_t fdIndex = readInt32();
        if (rpcFields->mFds == nullptr || fdIndex < 0 ||
            static_cast<size_t>(fdIndex) >= rpcFields->mFds->size()) {
            ALOGE("RPC Parcel contains invalid file descriptor index. index=%d fd_count=%zu",
                  fdIndex, rpcFields->mFds ? rpcFields->mFds->size() : 0);
            return BAD_VALUE;
        }
        return toRawFd(rpcFields->mFds->at(fdIndex));
    }

    const flat_binder_object* flat = readObject(true);

    if (flat && flat->hdr.type == BINDER_TYPE_FD) {
@@ -2049,8 +2173,7 @@ int Parcel::readFileDescriptor() const
    return BAD_TYPE;
}

int Parcel::readParcelFileDescriptor() const
{
int Parcel::readParcelFileDescriptor() const {
    int32_t hasComm = readInt32();
    int fd = readFileDescriptor();
    if (hasComm != 0) {
@@ -2270,10 +2393,7 @@ const flat_binder_object* Parcel::readObject(bool nullMetaData) const
}

void Parcel::closeFileDescriptors() {
    auto* kernelFields = maybeKernelFields();
    if (kernelFields == nullptr) {
        return;
    }
    if (auto* kernelFields = maybeKernelFields()) {
        size_t i = kernelFields->mObjectsSize;
        if (i > 0) {
            // ALOGI("Closing file descriptors for %zu objects...", i);
@@ -2287,6 +2407,9 @@ void Parcel::closeFileDescriptors() {
                close(flat->handle);
            }
        }
    } else if (auto* rpcFields = maybeRpcFields()) {
        rpcFields->mFds.reset();
    }
}

uintptr_t Parcel::ipcData() const
@@ -2363,8 +2486,11 @@ void Parcel::ipcSetDataReference(const uint8_t* data, size_t dataSize, const bin
    scanForFds();
}

void Parcel::rpcSetDataReference(const sp<RpcSession>& session, const uint8_t* data,
                                 size_t dataSize, release_func relFunc) {
status_t Parcel::rpcSetDataReference(const sp<RpcSession>& session, const uint8_t* data,
                                     size_t dataSize, const uint32_t* objectTable,
                                     size_t objectTableSize,
                                     std::vector<base::unique_fd> ancillaryFds,
                                     release_func relFunc) {
    // this code uses 'mOwner == nullptr' to understand whether it owns memory
    LOG_ALWAYS_FATAL_IF(relFunc == nullptr, "must provide cleanup function");

@@ -2373,9 +2499,32 @@ void Parcel::rpcSetDataReference(const sp<RpcSession>& session, const uint8_t* d
    freeData();
    markForRpc(session);

    auto* rpcFields = maybeRpcFields();
    LOG_ALWAYS_FATAL_IF(rpcFields == nullptr); // guaranteed by markForRpc.

    mData = const_cast<uint8_t*>(data);
    mDataSize = mDataCapacity = dataSize;
    mOwner = relFunc;

    if (objectTableSize != ancillaryFds.size()) {
        ALOGE("objectTableSize=%zu ancillaryFds.size=%zu", objectTableSize, ancillaryFds.size());
        freeData(); // don't leak mData
        return BAD_VALUE;
    }

    rpcFields->mObjectPositions.reserve(objectTableSize);
    for (size_t i = 0; i < objectTableSize; i++) {
        rpcFields->mObjectPositions.push_back(objectTable[i]);
    }
    if (!ancillaryFds.empty()) {
        rpcFields->mFds = std::make_unique<decltype(rpcFields->mFds)::element_type>();
        rpcFields->mFds->reserve(ancillaryFds.size());
        for (auto& fd : ancillaryFds) {
            rpcFields->mFds->push_back(std::move(fd));
        }
    }

    return OK;
}

void Parcel::print(TextOutput& to, uint32_t /*flags*/) const
@@ -2558,6 +2707,9 @@ status_t Parcel::restartWrite(size_t desired)
        kernelFields->mObjectsSorted = false;
        kernelFields->mHasFds = false;
        kernelFields->mFdsKnown = true;
    } else if (auto* rpcFields = maybeRpcFields()) {
        rpcFields->mObjectPositions.clear();
        rpcFields->mFds.reset();
    }
    mAllowFds = true;

@@ -2573,18 +2725,27 @@ status_t Parcel::continueWrite(size_t desired)
    }

    auto* kernelFields = maybeKernelFields();
    auto* rpcFields = maybeRpcFields();

    // If shrinking, first adjust for any objects that appear
    // after the new data size.
    size_t objectsSize = kernelFields ? kernelFields->mObjectsSize : 0;
    if (kernelFields && desired < mDataSize) {
    size_t objectsSize =
            kernelFields ? kernelFields->mObjectsSize : rpcFields->mObjectPositions.size();
    if (desired < mDataSize) {
        if (desired == 0) {
            objectsSize = 0;
        } else {
            if (kernelFields) {
                while (objectsSize > 0) {
                    if (kernelFields->mObjects[objectsSize - 1] < desired) break;
                    objectsSize--;
                }
            } else {
                while (objectsSize > 0) {
                    if (rpcFields->mObjectPositions[objectsSize - 1] < desired) break;
                    objectsSize--;
                }
            }
        }
    }

@@ -2604,7 +2765,7 @@ status_t Parcel::continueWrite(size_t desired)
        }
        binder_size_t* objects = nullptr;

        if (objectsSize) {
        if (kernelFields && objectsSize) {
            objects = (binder_size_t*)calloc(objectsSize, sizeof(binder_size_t));
            if (!objects) {
                free(data);
@@ -2620,6 +2781,12 @@ status_t Parcel::continueWrite(size_t desired)
            acquireObjects();
            kernelFields->mObjectsSize = oldObjectsSize;
        }
        if (rpcFields) {
            if (status_t status = truncateRpcObjects(objectsSize); status != OK) {
                free(data);
                return status;
            }
        }

        if (mData) {
            memcpy(data, mData, mDataSize < desired ? mDataSize : desired);
@@ -2678,6 +2845,11 @@ status_t Parcel::continueWrite(size_t desired)
            kernelFields->mNextObjectHint = 0;
            kernelFields->mObjectsSorted = false;
        }
        if (rpcFields) {
            if (status_t status = truncateRpcObjects(objectsSize); status != OK) {
                return status;
            }
        }

        // We own the data, so we can just do a realloc().
        if (desired > mDataCapacity) {
@@ -2734,6 +2906,35 @@ status_t Parcel::continueWrite(size_t desired)
    return NO_ERROR;
}

status_t Parcel::truncateRpcObjects(size_t newObjectsSize) {
    auto* rpcFields = maybeRpcFields();
    if (newObjectsSize == 0) {
        rpcFields->mObjectPositions.clear();
        if (rpcFields->mFds) {
            rpcFields->mFds->clear();
        }
        return OK;
    }
    while (rpcFields->mObjectPositions.size() > newObjectsSize) {
        uint32_t pos = rpcFields->mObjectPositions.back();
        rpcFields->mObjectPositions.pop_back();
        const auto type = *reinterpret_cast<const RpcFields::ObjectType*>(mData + pos);
        if (type == RpcFields::TYPE_NATIVE_FILE_DESCRIPTOR) {
            const auto fdIndex =
                    *reinterpret_cast<const int32_t*>(mData + pos + sizeof(RpcFields::ObjectType));
            if (rpcFields->mFds == nullptr || fdIndex < 0 ||
                static_cast<size_t>(fdIndex) >= rpcFields->mFds->size()) {
                ALOGE("RPC Parcel contains invalid file descriptor index. index=%d fd_count=%zu",
                      fdIndex, rpcFields->mFds ? rpcFields->mFds->size() : 0);
                return BAD_VALUE;
            }
            // In practice, this always removes the last element.
            rpcFields->mFds->erase(rpcFields->mFds->begin() + fdIndex);
        }
    }
    return OK;
}

void Parcel::initState()
{
    LOG_ALLOC("Parcel %p: initState", this);
+22 −3
Original line number Diff line number Diff line
@@ -122,6 +122,14 @@ void RpcServer::setProtocolVersion(uint32_t version) {
    mProtocolVersion = version;
}

void RpcServer::setSupportedFileDescriptorTransportModes(
        const std::vector<RpcSession::FileDescriptorTransportMode>& modes) {
    mSupportedFileDescriptorTransportModes.reset();
    for (RpcSession::FileDescriptorTransportMode mode : modes) {
        mSupportedFileDescriptorTransportModes.set(static_cast<size_t>(mode));
    }
}

void RpcServer::setRootObject(const sp<IBinder>& binder) {
    std::lock_guard<std::mutex> _l(mLock);
    mRootObjectFactory = nullptr;
@@ -292,7 +300,7 @@ void RpcServer::establishConnection(sp<RpcServer>&& server, base::unique_fd clie
    if (status == OK) {
        iovec iov{&header, sizeof(header)};
        status = client->interruptableReadFully(server->mShutdownTrigger.get(), &iov, 1,
                                                std::nullopt);
                                                std::nullopt, /*enableAncillaryFds=*/false);
        if (status != OK) {
            ALOGE("Failed to read ID for client connecting to RPC server: %s",
                  statusToString(status).c_str());
@@ -307,7 +315,7 @@ void RpcServer::establishConnection(sp<RpcServer>&& server, base::unique_fd clie
                sessionId.resize(header.sessionIdSize);
                iovec iov{sessionId.data(), sessionId.size()};
                status = client->interruptableReadFully(server->mShutdownTrigger.get(), &iov, 1,
                                                        std::nullopt);
                                                        std::nullopt, /*enableAncillaryFds=*/false);
                if (status != OK) {
                    ALOGE("Failed to read session ID for client connecting to RPC server: %s",
                          statusToString(status).c_str());
@@ -338,7 +346,7 @@ void RpcServer::establishConnection(sp<RpcServer>&& server, base::unique_fd clie

            iovec iov{&response, sizeof(response)};
            status = client->interruptableWriteFully(server->mShutdownTrigger.get(), &iov, 1,
                                                     std::nullopt);
                                                     std::nullopt, nullptr);
            if (status != OK) {
                ALOGE("Failed to send new session response: %s", statusToString(status).c_str());
                // still need to cleanup before we can return
@@ -396,6 +404,17 @@ void RpcServer::establishConnection(sp<RpcServer>&& server, base::unique_fd clie
            session->setMaxIncomingThreads(server->mMaxThreads);
            if (!session->setProtocolVersion(protocolVersion)) return;

            if (server->mSupportedFileDescriptorTransportModes.test(
                        header.fileDescriptorTransportMode)) {
                session->setFileDescriptorTransportMode(
                        static_cast<RpcSession::FileDescriptorTransportMode>(
                                header.fileDescriptorTransportMode));
            } else {
                ALOGE("Rejecting connection: FileDescriptorTransportMode is not supported: %hhu",
                      header.fileDescriptorTransportMode);
                return;
            }

            // if null, falls back to server root
            sp<IBinder> sessionSpecificRoot;
            if (server->mRootObjectFactory != nullptr) {
+14 −4
Original line number Diff line number Diff line
@@ -129,6 +129,14 @@ std::optional<uint32_t> RpcSession::getProtocolVersion() {
    return mProtocolVersion;
}

void RpcSession::setFileDescriptorTransportMode(FileDescriptorTransportMode mode) {
    mFileDescriptorTransportMode = mode;
}

RpcSession::FileDescriptorTransportMode RpcSession::getFileDescriptorTransportMode() {
    return mFileDescriptorTransportMode;
}

status_t RpcSession::setupUnixDomainClient(const char* path) {
    return setupSocketClient(UnixSocketAddress(path));
}
@@ -606,6 +614,7 @@ status_t RpcSession::initAndAddConnection(unique_fd fd, const std::vector<uint8_
    RpcConnectionHeader header{
            .version = mProtocolVersion.value_or(RPC_WIRE_PROTOCOL_VERSION),
            .options = 0,
            .fileDescriptorTransportMode = static_cast<uint8_t>(mFileDescriptorTransportMode),
            .sessionIdSize = static_cast<uint16_t>(sessionId.size()),
    };

@@ -614,8 +623,8 @@ status_t RpcSession::initAndAddConnection(unique_fd fd, const std::vector<uint8_
    }

    iovec headerIov{&header, sizeof(header)};
    auto sendHeaderStatus =
            server->interruptableWriteFully(mShutdownTrigger.get(), &headerIov, 1, std::nullopt);
    auto sendHeaderStatus = server->interruptableWriteFully(mShutdownTrigger.get(), &headerIov, 1,
                                                            std::nullopt, nullptr);
    if (sendHeaderStatus != OK) {
        ALOGE("Could not write connection header to socket: %s",
              statusToString(sendHeaderStatus).c_str());
@@ -625,8 +634,9 @@ status_t RpcSession::initAndAddConnection(unique_fd fd, const std::vector<uint8_
    if (sessionId.size() > 0) {
        iovec sessionIov{const_cast<void*>(static_cast<const void*>(sessionId.data())),
                         sessionId.size()};
        auto sendSessionIdStatus = server->interruptableWriteFully(mShutdownTrigger.get(),
                                                                   &sessionIov, 1, std::nullopt);
        auto sendSessionIdStatus =
                server->interruptableWriteFully(mShutdownTrigger.get(), &sessionIov, 1,
                                                std::nullopt, nullptr);
        if (sendSessionIdStatus != OK) {
            ALOGE("Could not write session ID ('%s') to socket: %s",
                  base::HexString(sessionId.data(), sessionId.size()).c_str(),
+190 −87

File changed.

Preview size limit exceeded, changes collapsed.

+7 −1
Original line number Diff line number Diff line
@@ -181,7 +181,9 @@ private:
    [[nodiscard]] status_t rpcSend(
            const sp<RpcSession::RpcConnection>& connection, const sp<RpcSession>& session,
            const char* what, iovec* iovs, int niovs,
            const std::optional<android::base::function_ref<status_t()>>& altPoll);
            const std::optional<android::base::function_ref<status_t()>>& altPoll,
            const std::vector<std::variant<base::unique_fd, base::borrowed_fd>>* ancillaryFds =
                    nullptr);
    [[nodiscard]] status_t rpcRec(const sp<RpcSession::RpcConnection>& connection,
                                  const sp<RpcSession>& session, const char* what, iovec* iovs,
                                  int niovs);
@@ -201,6 +203,10 @@ private:
                                            const sp<RpcSession>& session,
                                            const RpcWireHeader& command);

    // Whether `parcel` is compatible with `session`.
    [[nodiscard]] static status_t validateParcel(const sp<RpcSession>& session,
                                                 const Parcel& parcel, std::string* errorMsg);

    struct BinderNode {
        // Two cases:
        // A - local binder we are serving
Loading