Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit d618b567 authored by Alex Vakulenko's avatar Alex Vakulenko
Browse files

libpdx_uds: Fix send/receive over socket to handle signal interrupts

Previous implementation of send/receive didn't account for the fact that
send/receive operation might be interrupted by a signal and transfer
fewer bytes than requested.

Fix this by repeatedly calling send/recv until all the requested data
is transferred over sockets.

Also added a number of unit tests for send/receive functions.

Bug: 37427314
Test: `m -j32` succeeds for Sailfish.
      `libpdx_uds_tests` pass on device

Change-Id: Ib8f78967af3c218d9f18fb3dfe8953c35800540b
parent 5330710a
Loading
Loading
Loading
Loading
+2 −0
Original line number Original line Diff line number Diff line
@@ -35,10 +35,12 @@ cc_test {
        "-Werror",
        "-Werror",
    ],
    ],
    srcs: [
    srcs: [
        "ipc_helper_tests.cpp",
        "remote_method_tests.cpp",
        "remote_method_tests.cpp",
        "service_framework_tests.cpp",
        "service_framework_tests.cpp",
    ],
    ],
    static_libs: [
    static_libs: [
        "libgmock",
        "libpdx_uds",
        "libpdx_uds",
        "libpdx",
        "libpdx",
    ],
    ],
+168 −63
Original line number Original line Diff line number Diff line
@@ -18,6 +18,150 @@ namespace android {
namespace pdx {
namespace pdx {
namespace uds {
namespace uds {


namespace {

// Default implementations of Send/Receive interfaces to use standard socket
// send/sendmsg/recv/recvmsg functions.
class SocketSender : public SendInterface {
 public:
  ssize_t Send(int socket_fd, const void* data, size_t size,
               int flags) override {
    return send(socket_fd, data, size, flags);
  }
  ssize_t SendMessage(int socket_fd, const msghdr* msg, int flags) override {
    return sendmsg(socket_fd, msg, flags);
  }
} g_socket_sender;

class SocketReceiver : public RecvInterface {
 public:
  ssize_t Receive(int socket_fd, void* data, size_t size, int flags) override {
    return recv(socket_fd, data, size, flags);
  }
  ssize_t ReceiveMessage(int socket_fd, msghdr* msg, int flags) override {
    return recvmsg(socket_fd, msg, flags);
  }
} g_socket_receiver;

}  // anonymous namespace

// Helper wrappers around send()/sendmsg() which repeat send() calls on data
// that was not sent with the initial call to send/sendmsg. This is important to
// handle transmissions interrupted by signals.
Status<void> SendAll(SendInterface* sender, const BorrowedHandle& socket_fd,
                     const void* data, size_t size) {
  Status<void> ret;
  const uint8_t* ptr = static_cast<const uint8_t*>(data);
  while (size > 0) {
    ssize_t size_written =
        RETRY_EINTR(sender->Send(socket_fd.Get(), ptr, size, MSG_NOSIGNAL));
    if (size_written < 0) {
      ret.SetError(errno);
      ALOGE("SendAll: Failed to send data over socket: %s",
            ret.GetErrorMessage().c_str());
      break;
    }
    size -= size_written;
    ptr += size_written;
  }
  return ret;
}

Status<void> SendMsgAll(SendInterface* sender, const BorrowedHandle& socket_fd,
                        const msghdr* msg) {
  Status<void> ret;
  ssize_t sent_size =
      RETRY_EINTR(sender->SendMessage(socket_fd.Get(), msg, MSG_NOSIGNAL));
  if (sent_size < 0) {
    ret.SetError(errno);
    ALOGE("SendMsgAll: Failed to send data over socket: %s",
          ret.GetErrorMessage().c_str());
    return ret;
  }

  ssize_t chunk_start_offset = 0;
  for (size_t i = 0; i < msg->msg_iovlen; i++) {
    ssize_t chunk_end_offset = chunk_start_offset + msg->msg_iov[i].iov_len;
    if (sent_size < chunk_end_offset) {
      size_t offset_within_chunk = sent_size - chunk_start_offset;
      size_t data_size = msg->msg_iov[i].iov_len - offset_within_chunk;
      const uint8_t* chunk_base =
          static_cast<const uint8_t*>(msg->msg_iov[i].iov_base);
      ret = SendAll(sender, socket_fd, chunk_base + offset_within_chunk,
                    data_size);
      if (!ret)
        break;
      sent_size += data_size;
    }
    chunk_start_offset = chunk_end_offset;
  }
  return ret;
}

// Helper wrappers around recv()/recvmsg() which repeat recv() calls on data
// that was not received with the initial call to recvmsg(). This is important
// to handle transmissions interrupted by signals as well as the case when
// initial data did not arrive in a single chunk over the socket (e.g. socket
// buffer was full at the time of transmission, and only portion of initial
// message was sent and the rest was blocked until the buffer was cleared by the
// receiving side).
Status<void> RecvMsgAll(RecvInterface* receiver,
                        const BorrowedHandle& socket_fd, msghdr* msg) {
  Status<void> ret;
  ssize_t size_read = RETRY_EINTR(receiver->ReceiveMessage(
      socket_fd.Get(), msg, MSG_WAITALL | MSG_CMSG_CLOEXEC));
  if (size_read < 0) {
    ret.SetError(errno);
    ALOGE("RecvMsgAll: Failed to receive data from socket: %s",
          ret.GetErrorMessage().c_str());
    return ret;
  } else if (size_read == 0) {
    ret.SetError(ESHUTDOWN);
    ALOGW("RecvMsgAll: Socket has been shut down");
    return ret;
  }

  ssize_t chunk_start_offset = 0;
  for (size_t i = 0; i < msg->msg_iovlen; i++) {
    ssize_t chunk_end_offset = chunk_start_offset + msg->msg_iov[i].iov_len;
    if (size_read < chunk_end_offset) {
      size_t offset_within_chunk = size_read - chunk_start_offset;
      size_t data_size = msg->msg_iov[i].iov_len - offset_within_chunk;
      uint8_t* chunk_base = static_cast<uint8_t*>(msg->msg_iov[i].iov_base);
      ret = RecvAll(receiver, socket_fd, chunk_base + offset_within_chunk,
                    data_size);
      if (!ret)
        break;
      size_read += data_size;
    }
    chunk_start_offset = chunk_end_offset;
  }
  return ret;
}

Status<void> RecvAll(RecvInterface* receiver, const BorrowedHandle& socket_fd,
                     void* data, size_t size) {
  Status<void> ret;
  uint8_t* ptr = static_cast<uint8_t*>(data);
  while (size > 0) {
    ssize_t size_read = RETRY_EINTR(receiver->Receive(
        socket_fd.Get(), ptr, size, MSG_WAITALL | MSG_CMSG_CLOEXEC));
    if (size_read < 0) {
      ret.SetError(errno);
      ALOGE("RecvAll: Failed to receive data from socket: %s",
            ret.GetErrorMessage().c_str());
      break;
    } else if (size_read == 0) {
      ret.SetError(ESHUTDOWN);
      ALOGW("RecvAll: Socket has been shut down");
      break;
    }
    size -= size_read;
    ptr += size_read;
  }
  return ret;
}

uint32_t kMagicPreamble = 0x7564736d;  // 'udsm'.
uint32_t kMagicPreamble = 0x7564736d;  // 'udsm'.


struct MessagePreamble {
struct MessagePreamble {
@@ -32,17 +176,14 @@ Status<void> SendPayload::Send(const BorrowedHandle& socket_fd) {


Status<void> SendPayload::Send(const BorrowedHandle& socket_fd,
Status<void> SendPayload::Send(const BorrowedHandle& socket_fd,
                               const ucred* cred) {
                               const ucred* cred) {
  SendInterface* sender = sender_ ? sender_ : &g_socket_sender;
  MessagePreamble preamble;
  MessagePreamble preamble;
  preamble.magic = kMagicPreamble;
  preamble.magic = kMagicPreamble;
  preamble.data_size = buffer_.size();
  preamble.data_size = buffer_.size();
  preamble.fd_count = file_handles_.size();
  preamble.fd_count = file_handles_.size();

  Status<void> ret = SendAll(sender, socket_fd, &preamble, sizeof(preamble));
  ssize_t ret = RETRY_EINTR(
  if (!ret)
      send(socket_fd.Get(), &preamble, sizeof(preamble), MSG_NOSIGNAL));
    return ret;
  if (ret < 0)
    return ErrorStatus(errno);
  if (ret != sizeof(preamble))
    return ErrorStatus(EIO);


  msghdr msg = {};
  msghdr msg = {};
  iovec recv_vect = {buffer_.data(), buffer_.size()};
  iovec recv_vect = {buffer_.data(), buffer_.size()};
@@ -72,12 +213,7 @@ Status<void> SendPayload::Send(const BorrowedHandle& socket_fd,
    }
    }
  }
  }


  ret = RETRY_EINTR(sendmsg(socket_fd.Get(), &msg, MSG_NOSIGNAL));
  return SendMsgAll(sender, socket_fd, &msg);
  if (ret < 0)
    return ErrorStatus(errno);
  if (static_cast<size_t>(ret) != buffer_.size())
    return ErrorStatus(EIO);
  return {};
}
}


// MessageWriter
// MessageWriter
@@ -132,15 +268,16 @@ Status<void> ReceivePayload::Receive(const BorrowedHandle& socket_fd) {


Status<void> ReceivePayload::Receive(const BorrowedHandle& socket_fd,
Status<void> ReceivePayload::Receive(const BorrowedHandle& socket_fd,
                                     ucred* cred) {
                                     ucred* cred) {
  RecvInterface* receiver = receiver_ ? receiver_ : &g_socket_receiver;
  MessagePreamble preamble;
  MessagePreamble preamble;
  ssize_t ret = RETRY_EINTR(
  Status<void> ret = RecvAll(receiver, socket_fd, &preamble, sizeof(preamble));
      recv(socket_fd.Get(), &preamble, sizeof(preamble), MSG_WAITALL));
  if (!ret)
  if (ret < 0)
    return ret;
    return ErrorStatus(errno);

  else if (ret == 0)
  if (preamble.magic != kMagicPreamble) {
    return ErrorStatus(ESHUTDOWN);
    ret.SetError(EIO);
  else if (ret != sizeof(preamble) || preamble.magic != kMagicPreamble)
    return ret;
    return ErrorStatus(EIO);
  }


  buffer_.resize(preamble.data_size);
  buffer_.resize(preamble.data_size);
  file_handles_.clear();
  file_handles_.clear();
@@ -159,13 +296,9 @@ Status<void> ReceivePayload::Receive(const BorrowedHandle& socket_fd,
    msg.msg_control = alloca(msg.msg_controllen);
    msg.msg_control = alloca(msg.msg_controllen);
  }
  }


  ret = RETRY_EINTR(recvmsg(socket_fd.Get(), &msg, MSG_WAITALL));
  ret = RecvMsgAll(receiver, socket_fd, &msg);
  if (ret < 0)
  if (!ret)
    return ErrorStatus(errno);
    return ret;
  else if (ret == 0)
    return ErrorStatus(ESHUTDOWN);
  else if (static_cast<uint32_t>(ret) != preamble.data_size)
    return ErrorStatus(EIO);


  bool cred_available = false;
  bool cred_available = false;
  file_handles_.reserve(preamble.fd_count);
  file_handles_.reserve(preamble.fd_count);
@@ -186,11 +319,10 @@ Status<void> ReceivePayload::Receive(const BorrowedHandle& socket_fd,
    cmsg = CMSG_NXTHDR(&msg, cmsg);
    cmsg = CMSG_NXTHDR(&msg, cmsg);
  }
  }


  if (cred && !cred_available) {
  if (cred && !cred_available)
    return ErrorStatus(EIO);
    ret.SetError(EIO);
  }


  return {};
  return ret;
}
}


// MessageReader
// MessageReader
@@ -223,13 +355,7 @@ bool ReceivePayload::GetChannelHandle(ChannelReference /*ref*/,


Status<void> SendData(const BorrowedHandle& socket_fd, const void* data,
Status<void> SendData(const BorrowedHandle& socket_fd, const void* data,
                      size_t size) {
                      size_t size) {
  ssize_t size_written =
  return SendAll(&g_socket_sender, socket_fd, data, size);
      RETRY_EINTR(send(socket_fd.Get(), data, size, MSG_NOSIGNAL));
  if (size_written < 0)
    return ErrorStatus(errno);
  if (static_cast<size_t>(size_written) != size)
    return ErrorStatus(EIO);
  return {};
}
}


Status<void> SendDataVector(const BorrowedHandle& socket_fd, const iovec* data,
Status<void> SendDataVector(const BorrowedHandle& socket_fd, const iovec* data,
@@ -237,26 +363,12 @@ Status<void> SendDataVector(const BorrowedHandle& socket_fd, const iovec* data,
  msghdr msg = {};
  msghdr msg = {};
  msg.msg_iov = const_cast<iovec*>(data);
  msg.msg_iov = const_cast<iovec*>(data);
  msg.msg_iovlen = count;
  msg.msg_iovlen = count;
  ssize_t size_written =
  return SendMsgAll(&g_socket_sender, socket_fd, &msg);
      RETRY_EINTR(sendmsg(socket_fd.Get(), &msg, MSG_NOSIGNAL));
  if (size_written < 0)
    return ErrorStatus(errno);
  if (static_cast<size_t>(size_written) != CountVectorSize(data, count))
    return ErrorStatus(EIO);
  return {};
}
}


Status<void> ReceiveData(const BorrowedHandle& socket_fd, void* data,
Status<void> ReceiveData(const BorrowedHandle& socket_fd, void* data,
                         size_t size) {
                         size_t size) {
  ssize_t size_read =
  return RecvAll(&g_socket_receiver, socket_fd, data, size);
      RETRY_EINTR(recv(socket_fd.Get(), data, size, MSG_WAITALL));
  if (size_read < 0)
    return ErrorStatus(errno);
  else if (size_read == 0)
    return ErrorStatus(ESHUTDOWN);
  else if (static_cast<size_t>(size_read) != size)
    return ErrorStatus(EIO);
  return {};
}
}


Status<void> ReceiveDataVector(const BorrowedHandle& socket_fd,
Status<void> ReceiveDataVector(const BorrowedHandle& socket_fd,
@@ -264,14 +376,7 @@ Status<void> ReceiveDataVector(const BorrowedHandle& socket_fd,
  msghdr msg = {};
  msghdr msg = {};
  msg.msg_iov = const_cast<iovec*>(data);
  msg.msg_iov = const_cast<iovec*>(data);
  msg.msg_iovlen = count;
  msg.msg_iovlen = count;
  ssize_t size_read = RETRY_EINTR(recvmsg(socket_fd.Get(), &msg, MSG_WAITALL));
  return RecvMsgAll(&g_socket_receiver, socket_fd, &msg);
  if (size_read < 0)
    return ErrorStatus(errno);
  else if (size_read == 0)
    return ErrorStatus(ESHUTDOWN);
  else if (static_cast<size_t>(size_read) != CountVectorSize(data, count))
    return ErrorStatus(EIO);
  return {};
}
}


size_t CountVectorSize(const iovec* vector, size_t count) {
size_t CountVectorSize(const iovec* vector, size_t count) {
+365 −0
Original line number Original line Diff line number Diff line
#include "uds/ipc_helper.h"

#include <gmock/gmock.h>
#include <gtest/gtest.h>

using testing::Return;
using testing::SetErrnoAndReturn;
using testing::_;

using android::pdx::BorrowedHandle;
using android::pdx::uds::SendInterface;
using android::pdx::uds::RecvInterface;
using android::pdx::uds::SendAll;
using android::pdx::uds::SendMsgAll;
using android::pdx::uds::RecvAll;
using android::pdx::uds::RecvMsgAll;

namespace {

// Useful constants for tests.
static constexpr intptr_t kPtr = 1234;
static constexpr int kSocketFd = 5678;
static const BorrowedHandle kSocket{kSocketFd};

// Helper functions to construct test data pointer values.
void* IntToPtr(intptr_t value) { return reinterpret_cast<void*>(value); }
const void* IntToConstPtr(intptr_t value) {
  return reinterpret_cast<const void*>(value);
}

// Mock classes for SendInterface/RecvInterface.
class MockSender : public SendInterface {
 public:
  MOCK_METHOD4(Send, ssize_t(int socket_fd, const void* data, size_t size,
                             int flags));
  MOCK_METHOD3(SendMessage,
               ssize_t(int socket_fd, const msghdr* msg, int flags));
};

class MockReceiver : public RecvInterface {
 public:
  MOCK_METHOD4(Receive,
               ssize_t(int socket_fd, void* data, size_t size, int flags));
  MOCK_METHOD3(ReceiveMessage, ssize_t(int socket_fd, msghdr* msg, int flags));
};

// Test case classes.
class SendTest : public testing::Test {
 public:
  SendTest() {
    ON_CALL(sender_, Send(_, _, _, _))
        .WillByDefault(SetErrnoAndReturn(EIO, -1));
    ON_CALL(sender_, SendMessage(_, _, _))
        .WillByDefault(SetErrnoAndReturn(EIO, -1));
  }

 protected:
  MockSender sender_;
};

class RecvTest : public testing::Test {
 public:
  RecvTest() {
    ON_CALL(receiver_, Receive(_, _, _, _))
        .WillByDefault(SetErrnoAndReturn(EIO, -1));
    ON_CALL(receiver_, ReceiveMessage(_, _, _))
        .WillByDefault(SetErrnoAndReturn(EIO, -1));
  }

 protected:
  MockReceiver receiver_;
};

class MessageTestBase : public testing::Test {
 public:
  MessageTestBase() {
    memset(&msg_, 0, sizeof(msg_));
    msg_.msg_iovlen = data_.size();
    msg_.msg_iov = data_.data();
  }

 protected:
  static constexpr intptr_t kPtr1 = kPtr;
  static constexpr intptr_t kPtr2 = kPtr + 200;
  static constexpr intptr_t kPtr3 = kPtr + 1000;

  MockSender sender_;
  msghdr msg_;
  std::vector<iovec> data_{
      {IntToPtr(kPtr1), 100}, {IntToPtr(kPtr2), 200}, {IntToPtr(kPtr3), 300}};
};

class SendMessageTest : public MessageTestBase {
 public:
  SendMessageTest() {
    ON_CALL(sender_, Send(_, _, _, _))
        .WillByDefault(SetErrnoAndReturn(EIO, -1));
    ON_CALL(sender_, SendMessage(_, _, _))
        .WillByDefault(SetErrnoAndReturn(EIO, -1));
  }

 protected:
  MockSender sender_;
};

class RecvMessageTest : public MessageTestBase {
 public:
  RecvMessageTest() {
    ON_CALL(receiver_, Receive(_, _, _, _))
        .WillByDefault(SetErrnoAndReturn(EIO, -1));
    ON_CALL(receiver_, ReceiveMessage(_, _, _))
        .WillByDefault(SetErrnoAndReturn(EIO, -1));
  }

 protected:
  MockReceiver receiver_;
};

// Actual tests.

// SendAll
TEST_F(SendTest, Complete) {
  EXPECT_CALL(sender_, Send(kSocketFd, IntToConstPtr(kPtr), 100, MSG_NOSIGNAL))
      .WillOnce(Return(100));

  auto status = SendAll(&sender_, kSocket, IntToConstPtr(kPtr), 100);
  EXPECT_TRUE(status);
}

TEST_F(SendTest, Signal) {
  EXPECT_CALL(sender_, Send(kSocketFd, IntToConstPtr(kPtr), 100, MSG_NOSIGNAL))
      .WillOnce(Return(20));
  EXPECT_CALL(sender_,
              Send(kSocketFd, IntToConstPtr(kPtr + 20), 80, MSG_NOSIGNAL))
      .WillOnce(Return(40));
  EXPECT_CALL(sender_,
              Send(kSocketFd, IntToConstPtr(kPtr + 60), 40, MSG_NOSIGNAL))
      .WillOnce(Return(40));

  auto status = SendAll(&sender_, kSocket, IntToConstPtr(kPtr), 100);
  EXPECT_TRUE(status);
}

TEST_F(SendTest, Eintr) {
  EXPECT_CALL(sender_, Send(kSocketFd, IntToConstPtr(kPtr), 100, MSG_NOSIGNAL))
      .WillOnce(SetErrnoAndReturn(EINTR, -1))
      .WillOnce(Return(100));

  auto status = SendAll(&sender_, kSocket, IntToConstPtr(kPtr), 100);
  EXPECT_TRUE(status);
}

TEST_F(SendTest, Error) {
  EXPECT_CALL(sender_, Send(kSocketFd, IntToConstPtr(kPtr), 100, MSG_NOSIGNAL))
      .WillOnce(SetErrnoAndReturn(EIO, -1));

  auto status = SendAll(&sender_, kSocket, IntToConstPtr(kPtr), 100);
  ASSERT_FALSE(status);
  EXPECT_EQ(EIO, status.error());
}

TEST_F(SendTest, Error2) {
  EXPECT_CALL(sender_, Send(kSocketFd, IntToConstPtr(kPtr), 100, MSG_NOSIGNAL))
      .WillOnce(Return(50));
  EXPECT_CALL(sender_,
              Send(kSocketFd, IntToConstPtr(kPtr + 50), 50, MSG_NOSIGNAL))
      .WillOnce(SetErrnoAndReturn(EIO, -1));

  auto status = SendAll(&sender_, kSocket, IntToConstPtr(kPtr), 100);
  ASSERT_FALSE(status);
  EXPECT_EQ(EIO, status.error());
}

// RecvAll
TEST_F(RecvTest, Complete) {
  EXPECT_CALL(receiver_, Receive(kSocketFd, IntToPtr(kPtr), 100,
                                 MSG_WAITALL | MSG_CMSG_CLOEXEC))
      .WillOnce(Return(100));

  auto status = RecvAll(&receiver_, kSocket, IntToPtr(kPtr), 100);
  EXPECT_TRUE(status);
}

TEST_F(RecvTest, Signal) {
  EXPECT_CALL(receiver_, Receive(kSocketFd, IntToPtr(kPtr), 100, _))
      .WillOnce(Return(20));
  EXPECT_CALL(receiver_, Receive(kSocketFd, IntToPtr(kPtr + 20), 80, _))
      .WillOnce(Return(40));
  EXPECT_CALL(receiver_, Receive(kSocketFd, IntToPtr(kPtr + 60), 40, _))
      .WillOnce(Return(40));

  auto status = RecvAll(&receiver_, kSocket, IntToPtr(kPtr), 100);
  EXPECT_TRUE(status);
}

TEST_F(RecvTest, Eintr) {
  EXPECT_CALL(receiver_, Receive(kSocketFd, IntToPtr(kPtr), 100, _))
      .WillOnce(SetErrnoAndReturn(EINTR, -1))
      .WillOnce(Return(100));

  auto status = RecvAll(&receiver_, kSocket, IntToPtr(kPtr), 100);
  EXPECT_TRUE(status);
}

TEST_F(RecvTest, Error) {
  EXPECT_CALL(receiver_, Receive(kSocketFd, IntToPtr(kPtr), 100, _))
      .WillOnce(SetErrnoAndReturn(EIO, -1));

  auto status = RecvAll(&receiver_, kSocket, IntToPtr(kPtr), 100);
  ASSERT_FALSE(status);
  EXPECT_EQ(EIO, status.error());
}

TEST_F(RecvTest, Error2) {
  EXPECT_CALL(receiver_, Receive(kSocketFd, IntToPtr(kPtr), 100, _))
      .WillOnce(Return(30));
  EXPECT_CALL(receiver_, Receive(kSocketFd, IntToPtr(kPtr + 30), 70, _))
      .WillOnce(SetErrnoAndReturn(EIO, -1));

  auto status = RecvAll(&receiver_, kSocket, IntToPtr(kPtr), 100);
  ASSERT_FALSE(status);
  EXPECT_EQ(EIO, status.error());
}

// SendMsgAll
TEST_F(SendMessageTest, Complete) {
  EXPECT_CALL(sender_, SendMessage(kSocketFd, &msg_, MSG_NOSIGNAL))
      .WillOnce(Return(600));

  auto status = SendMsgAll(&sender_, kSocket, &msg_);
  EXPECT_TRUE(status);
}

TEST_F(SendMessageTest, Partial) {
  EXPECT_CALL(sender_, SendMessage(kSocketFd, &msg_, _)).WillOnce(Return(70));
  EXPECT_CALL(sender_, Send(kSocketFd, IntToConstPtr(kPtr1 + 70), 30, _))
      .WillOnce(Return(30));
  EXPECT_CALL(sender_, Send(kSocketFd, IntToConstPtr(kPtr2), 200, _))
      .WillOnce(Return(190));
  EXPECT_CALL(sender_, Send(kSocketFd, IntToConstPtr(kPtr2 + 190), 10, _))
      .WillOnce(Return(10));
  EXPECT_CALL(sender_, Send(kSocketFd, IntToConstPtr(kPtr3), 300, _))
      .WillOnce(Return(300));

  auto status = SendMsgAll(&sender_, kSocket, &msg_);
  EXPECT_TRUE(status);
}

TEST_F(SendMessageTest, Partial2) {
  EXPECT_CALL(sender_, SendMessage(kSocketFd, &msg_, _)).WillOnce(Return(310));
  EXPECT_CALL(sender_, Send(kSocketFd, IntToConstPtr(kPtr3 + 10), 290, _))
      .WillOnce(Return(290));

  auto status = SendMsgAll(&sender_, kSocket, &msg_);
  EXPECT_TRUE(status);
}

TEST_F(SendMessageTest, Eintr) {
  EXPECT_CALL(sender_, SendMessage(kSocketFd, &msg_, _))
      .WillOnce(SetErrnoAndReturn(EINTR, -1))
      .WillOnce(Return(70));
  EXPECT_CALL(sender_, Send(kSocketFd, IntToConstPtr(kPtr1 + 70), 30, _))
      .WillOnce(SetErrnoAndReturn(EINTR, -1))
      .WillOnce(Return(30));
  EXPECT_CALL(sender_, Send(kSocketFd, IntToConstPtr(kPtr2), 200, _))
      .WillOnce(Return(200));
  EXPECT_CALL(sender_, Send(kSocketFd, IntToConstPtr(kPtr3), 300, _))
      .WillOnce(Return(300));

  auto status = SendMsgAll(&sender_, kSocket, &msg_);
  EXPECT_TRUE(status);
}

TEST_F(SendMessageTest, Error) {
  EXPECT_CALL(sender_, SendMessage(kSocketFd, &msg_, _))
      .WillOnce(SetErrnoAndReturn(EBADF, -1));

  auto status = SendMsgAll(&sender_, kSocket, &msg_);
  ASSERT_FALSE(status);
  EXPECT_EQ(EBADF, status.error());
}

TEST_F(SendMessageTest, Error2) {
  EXPECT_CALL(sender_, SendMessage(kSocketFd, &msg_, _)).WillOnce(Return(20));
  EXPECT_CALL(sender_, Send(kSocketFd, IntToConstPtr(kPtr1 + 20), 80, _))
      .WillOnce(SetErrnoAndReturn(EBADF, -1));

  auto status = SendMsgAll(&sender_, kSocket, &msg_);
  ASSERT_FALSE(status);
  EXPECT_EQ(EBADF, status.error());
}

// RecvMsgAll
TEST_F(RecvMessageTest, Complete) {
  EXPECT_CALL(receiver_,
              ReceiveMessage(kSocketFd, &msg_, MSG_WAITALL | MSG_CMSG_CLOEXEC))
      .WillOnce(Return(600));

  auto status = RecvMsgAll(&receiver_, kSocket, &msg_);
  EXPECT_TRUE(status);
}

TEST_F(RecvMessageTest, Partial) {
  EXPECT_CALL(receiver_, ReceiveMessage(kSocketFd, &msg_, _))
      .WillOnce(Return(70));
  EXPECT_CALL(receiver_, Receive(kSocketFd, IntToPtr(kPtr1 + 70), 30, _))
      .WillOnce(Return(30));
  EXPECT_CALL(receiver_, Receive(kSocketFd, IntToPtr(kPtr2), 200, _))
      .WillOnce(Return(190));
  EXPECT_CALL(receiver_, Receive(kSocketFd, IntToPtr(kPtr2 + 190), 10, _))
      .WillOnce(Return(10));
  EXPECT_CALL(receiver_, Receive(kSocketFd, IntToPtr(kPtr3), 300, _))
      .WillOnce(Return(300));

  auto status = RecvMsgAll(&receiver_, kSocket, &msg_);
  EXPECT_TRUE(status);
}

TEST_F(RecvMessageTest, Partial2) {
  EXPECT_CALL(receiver_, ReceiveMessage(kSocketFd, &msg_, _))
      .WillOnce(Return(310));
  EXPECT_CALL(receiver_, Receive(kSocketFd, IntToPtr(kPtr3 + 10), 290, _))
      .WillOnce(Return(290));

  auto status = RecvMsgAll(&receiver_, kSocket, &msg_);
  EXPECT_TRUE(status);
}

TEST_F(RecvMessageTest, Eintr) {
  EXPECT_CALL(receiver_, ReceiveMessage(kSocketFd, &msg_, _))
      .WillOnce(SetErrnoAndReturn(EINTR, -1))
      .WillOnce(Return(70));
  EXPECT_CALL(receiver_, Receive(kSocketFd, IntToPtr(kPtr1 + 70), 30, _))
      .WillOnce(SetErrnoAndReturn(EINTR, -1))
      .WillOnce(Return(30));
  EXPECT_CALL(receiver_, Receive(kSocketFd, IntToPtr(kPtr2), 200, _))
      .WillOnce(Return(200));
  EXPECT_CALL(receiver_, Receive(kSocketFd, IntToPtr(kPtr3), 300, _))
      .WillOnce(Return(300));

  auto status = RecvMsgAll(&receiver_, kSocket, &msg_);
  EXPECT_TRUE(status);
}

TEST_F(RecvMessageTest, Error) {
  EXPECT_CALL(receiver_, ReceiveMessage(kSocketFd, &msg_, _))
      .WillOnce(SetErrnoAndReturn(EBADF, -1));

  auto status = RecvMsgAll(&receiver_, kSocket, &msg_);
  ASSERT_FALSE(status);
  EXPECT_EQ(EBADF, status.error());
}

TEST_F(RecvMessageTest, Error2) {
  EXPECT_CALL(receiver_, ReceiveMessage(kSocketFd, &msg_, _))
      .WillOnce(Return(20));
  EXPECT_CALL(receiver_, Receive(kSocketFd, IntToPtr(kPtr1 + 20), 80, _))
      .WillOnce(SetErrnoAndReturn(EBADF, -1));

  auto status = RecvMsgAll(&receiver_, kSocket, &msg_);
  ASSERT_FALSE(status);
  EXPECT_EQ(EBADF, status.error());
}

}  // namespace
+36 −0
Original line number Original line Diff line number Diff line
@@ -14,6 +14,38 @@ namespace android {
namespace pdx {
namespace pdx {
namespace uds {
namespace uds {


// Test interfaces used for unit-testing payload sending/receiving over sockets.
class SendInterface {
 public:
  virtual ssize_t Send(int socket_fd, const void* data, size_t size,
                       int flags) = 0;
  virtual ssize_t SendMessage(int socket_fd, const msghdr* msg, int flags) = 0;

 protected:
  virtual ~SendInterface() = default;
};

class RecvInterface {
 public:
  virtual ssize_t Receive(int socket_fd, void* data, size_t size,
                          int flags) = 0;
  virtual ssize_t ReceiveMessage(int socket_fd, msghdr* msg, int flags) = 0;

 protected:
  virtual ~RecvInterface() = default;
};

// Helper methods that allow to send/receive data through abstract interfaces.
// Useful for mocking out the underlying socket I/O.
Status<void> SendAll(SendInterface* sender, const BorrowedHandle& socket_fd,
                     const void* data, size_t size);
Status<void> SendMsgAll(SendInterface* sender, const BorrowedHandle& socket_fd,
                        const msghdr* msg);
Status<void> RecvAll(RecvInterface* receiver, const BorrowedHandle& socket_fd,
                     void* data, size_t size);
Status<void> RecvMsgAll(RecvInterface* receiver,
                        const BorrowedHandle& socket_fd, msghdr* msg);

#define RETRY_EINTR(fnc_call)                 \
#define RETRY_EINTR(fnc_call)                 \
  ([&]() -> decltype(fnc_call) {              \
  ([&]() -> decltype(fnc_call) {              \
    decltype(fnc_call) result;                \
    decltype(fnc_call) result;                \
@@ -25,6 +57,7 @@ namespace uds {


class SendPayload : public MessageWriter, public OutputResourceMapper {
class SendPayload : public MessageWriter, public OutputResourceMapper {
 public:
 public:
  SendPayload(SendInterface* sender = nullptr) : sender_{sender} {}
  Status<void> Send(const BorrowedHandle& socket_fd);
  Status<void> Send(const BorrowedHandle& socket_fd);
  Status<void> Send(const BorrowedHandle& socket_fd, const ucred* cred);
  Status<void> Send(const BorrowedHandle& socket_fd, const ucred* cred);


@@ -44,12 +77,14 @@ class SendPayload : public MessageWriter, public OutputResourceMapper {
      const RemoteChannelHandle& handle) override;
      const RemoteChannelHandle& handle) override;


 private:
 private:
  SendInterface* sender_;
  ByteBuffer buffer_;
  ByteBuffer buffer_;
  std::vector<int> file_handles_;
  std::vector<int> file_handles_;
};
};


class ReceivePayload : public MessageReader, public InputResourceMapper {
class ReceivePayload : public MessageReader, public InputResourceMapper {
 public:
 public:
  ReceivePayload(RecvInterface* receiver = nullptr) : receiver_{receiver} {}
  Status<void> Receive(const BorrowedHandle& socket_fd);
  Status<void> Receive(const BorrowedHandle& socket_fd);
  Status<void> Receive(const BorrowedHandle& socket_fd, ucred* cred);
  Status<void> Receive(const BorrowedHandle& socket_fd, ucred* cred);


@@ -64,6 +99,7 @@ class ReceivePayload : public MessageReader, public InputResourceMapper {
                        LocalChannelHandle* handle) override;
                        LocalChannelHandle* handle) override;


 private:
 private:
  RecvInterface* receiver_;
  ByteBuffer buffer_;
  ByteBuffer buffer_;
  std::vector<LocalHandle> file_handles_;
  std::vector<LocalHandle> file_handles_;
  size_t read_pos_{0};
  size_t read_pos_{0};