Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 021a9a4a authored by Josh Gao's avatar Josh Gao Committed by Gerrit Code Review
Browse files

Merge "Switch LocalSocket to android::base::{Send,Receive}FileDescriptorVector."

parents 0b2044a9 79e3be8a
Loading
Loading
Loading
Loading
+45 −129
Original line number Diff line number Diff line
@@ -33,14 +33,16 @@
#include <unistd.h>
#include <sys/ioctl.h>

#include <android-base/cmsg.h>
#include <android-base/macros.h>
#include <cutils/sockets.h>
#include <netinet/tcp.h>
#include <nativehelper/ScopedUtfChars.h>

namespace android {
using android::base::ReceiveFileDescriptorVector;
using android::base::SendFileDescriptorVector;

template <typename T>
void UNUSED(T t) {}
namespace android {

static jfieldID field_inboundFileDescriptors;
static jfieldID field_outboundFileDescriptors;
@@ -118,43 +120,40 @@ socket_bind_local (JNIEnv *env, jobject object, jobject fileDescriptor,
}

/**
 * Processes ancillary data, handling only
 * SCM_RIGHTS. Creates appropriate objects and sets appropriate
 * fields in the LocalSocketImpl object. Returns 0 on success
 * or -1 if an exception was thrown.
 * Reads data from a socket into buf, processing any ancillary data
 * and adding it to thisJ.
 *
 * Returns the length of normal data read, or -1 if an exception has
 * been thrown in this function.
 */
static int socket_process_cmsg(JNIEnv *env, jobject thisJ, struct msghdr * pMsg)
static ssize_t socket_read_all(JNIEnv *env, jobject thisJ, int fd,
        void *buffer, size_t len)
{
    struct cmsghdr *cmsgptr;
    ssize_t ret;
    std::vector<android::base::unique_fd> received_fds;

    for (cmsgptr = CMSG_FIRSTHDR(pMsg);
            cmsgptr != NULL; cmsgptr = CMSG_NXTHDR(pMsg, cmsgptr)) {
    ret = ReceiveFileDescriptorVector(fd, buffer, len, 64, &received_fds);

        if (cmsgptr->cmsg_level != SOL_SOCKET) {
            continue;
    if (ret < 0) {
        if (errno == EPIPE) {
            // Treat this as an end of stream
            return 0;
        }

        if (cmsgptr->cmsg_type == SCM_RIGHTS) {
            int *pDescriptors = (int *)CMSG_DATA(cmsgptr);
            jobjectArray fdArray;
            int count
                = ((cmsgptr->cmsg_len - CMSG_LEN(0)) / sizeof(int));

            if (count < 0) {
                jniThrowException(env, "java/io/IOException",
                    "invalid cmsg length");
        jniThrowIOException(env, errno);
        return -1;
    }

            fdArray = env->NewObjectArray(count, class_FileDescriptor, NULL);
    if (received_fds.size() > 0) {
        jobjectArray fdArray = env->NewObjectArray(received_fds.size(), class_FileDescriptor, NULL);

        if (fdArray == NULL) {
            // NewObjectArray has thrown.
            return -1;
        }

            for (int i = 0; i < count; i++) {
                jobject fdObject
                        = jniCreateFileDescriptor(env, pDescriptors[i]);
        for (size_t i = 0; i < received_fds.size(); i++) {
            jobject fdObject = jniCreateFileDescriptor(env, received_fds[i].get());

            if (env->ExceptionCheck()) {
                return -1;
@@ -167,69 +166,12 @@ static int socket_process_cmsg(JNIEnv *env, jobject thisJ, struct msghdr * pMsg)
            }
        }

            env->SetObjectField(thisJ, field_inboundFileDescriptors, fdArray);

            if (env->ExceptionCheck()) {
                return -1;
            }
        }
        for (auto &fd : received_fds) {
            // The fds are stored in java.io.FileDescriptors now.
            static_cast<void>(fd.release());
        }

    return 0;
}

/**
 * Reads data from a socket into buf, processing any ancillary data
 * and adding it to thisJ.
 *
 * Returns the length of normal data read, or -1 if an exception has
 * been thrown in this function.
 */
static ssize_t socket_read_all(JNIEnv *env, jobject thisJ, int fd,
        void *buffer, size_t len)
{
    ssize_t ret;
    struct msghdr msg;
    struct iovec iv;
    unsigned char *buf = (unsigned char *)buffer;
    // Enough buffer for a pile of fd's. We throw an exception if
    // this buffer is too small.
    struct cmsghdr cmsgbuf[2*sizeof(cmsghdr) + 0x100];

    memset(&msg, 0, sizeof(msg));
    memset(&iv, 0, sizeof(iv));

    iv.iov_base = buf;
    iv.iov_len = len;

    msg.msg_iov = &iv;
    msg.msg_iovlen = 1;
    msg.msg_control = cmsgbuf;
    msg.msg_controllen = sizeof(cmsgbuf);

    ret = TEMP_FAILURE_RETRY(recvmsg(fd, &msg, MSG_NOSIGNAL | MSG_CMSG_CLOEXEC));

    if (ret < 0 && errno == EPIPE) {
        // Treat this as an end of stream
        return 0;
    }

    if (ret < 0) {
        jniThrowIOException(env, errno);
        return -1;
    }

    if ((msg.msg_flags & (MSG_CTRUNC | MSG_OOB | MSG_ERRQUEUE)) != 0) {
        // To us, any of the above flags are a fatal error

        jniThrowException(env, "java/io/IOException",
                "Unexpected error or truncation during recvmsg()");

        return -1;
    }

    if (ret >= 0) {
        socket_process_cmsg(env, thisJ, &msg);
        env->SetObjectField(thisJ, field_inboundFileDescriptors, fdArray);
    }

    return ret;
@@ -243,7 +185,6 @@ static ssize_t socket_read_all(JNIEnv *env, jobject thisJ, int fd,
static int socket_write_all(JNIEnv *env, jobject object, int fd,
        void *buf, size_t len)
{
    ssize_t ret;
    struct msghdr msg;
    unsigned char *buffer = (unsigned char *)buf;
    memset(&msg, 0, sizeof(msg));
@@ -256,14 +197,11 @@ static int socket_write_all(JNIEnv *env, jobject object, int fd,
        return -1;
    }

    struct cmsghdr *cmsg;
    int countFds = outboundFds == NULL ? 0 : env->GetArrayLength(outboundFds);
    int fds[countFds];
    char msgbuf[CMSG_SPACE(countFds)];
    std::vector<int> fds;

    // Add any pending outbound file descriptors to the message
    if (outboundFds != NULL) {

        if (env->ExceptionCheck()) {
            return -1;
        }
@@ -274,47 +212,25 @@ static int socket_write_all(JNIEnv *env, jobject object, int fd,
                return -1;
            }

            fds[i] = jniGetFDFromFileDescriptor(env, fdObject);
            fds.push_back(jniGetFDFromFileDescriptor(env, fdObject));
            if (env->ExceptionCheck()) {
                return -1;
            }
        }

        // See "man cmsg" really
        msg.msg_control = msgbuf;
        msg.msg_controllen = sizeof msgbuf;
        cmsg = CMSG_FIRSTHDR(&msg);
        cmsg->cmsg_level = SOL_SOCKET;
        cmsg->cmsg_type = SCM_RIGHTS;
        cmsg->cmsg_len = CMSG_LEN(sizeof fds);
        memcpy(CMSG_DATA(cmsg), fds, sizeof fds);
    }

    // We only write our msg_control during the first write
    while (len > 0) {
        struct iovec iv;
        memset(&iv, 0, sizeof(iv));

        iv.iov_base = buffer;
        iv.iov_len = len;
    ssize_t rc = SendFileDescriptorVector(fd, buffer, len, fds);

        msg.msg_iov = &iv;
        msg.msg_iovlen = 1;

        do {
            ret = sendmsg(fd, &msg, MSG_NOSIGNAL);
        } while (ret < 0 && errno == EINTR);

        if (ret < 0) {
    while (rc != len) {
        if (rc == -1) {
            jniThrowIOException(env, errno);
            return -1;
        }

        buffer += ret;
        len -= ret;
        buffer += rc;
        len -= rc;

        // Wipes out any msg_control too
        memset(&msg, 0, sizeof(msg));
        rc = send(fd, buffer, len, MSG_NOSIGNAL);
    }

    return 0;