Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 09aa736e authored by Alex Vakulenko's avatar Alex Vakulenko
Browse files

libpdx_uds: Fix RPC channel ID allocation to not recycle values as often

The value of channel ID for PDX service on UDS transport was actual file
descriptor value for the data socket. Since channels are open and closed
constantly, it is quite often that channel ID is being reused immediately,
so it is almost impossible to use `cid` as unique identifier for objects
being passed around.

Instead, we now use monotonically growing channel ID value.

To prevent the possibility of using channel ID as the socket FD and vice
versa, changed all helper functions that used to take socket_fd as an int
to explicitly use BorrowedHandle which disables implicit conversion from
int (and hence makes it impossible to mistakenly pass in the channel ID).

Bug: 37082296
Test: `m -j32` succeeds for sailfish-eng
      Ran libpdx_uds_tests on the device -> all pass
      Device boots and CubeSea and VrHome render correctly with vr_flinger

Change-Id: Ibb8dfee4d6c3f4b6120c0b6e20a253f1b9307c19
parent 0c91fbbc
Loading
Loading
Loading
Loading
+12 −10
Original line number Diff line number Diff line
@@ -67,7 +67,7 @@ struct TransactionState {
  ResponseHeader<LocalHandle> response;
};

Status<void> ReadAndDiscardData(int socket_fd, size_t size) {
Status<void> ReadAndDiscardData(const BorrowedHandle& socket_fd, size_t size) {
  while (size > 0) {
    // If there is more data to read in the message than the buffers provided
    // by the caller, read and discard the extra data from the socket.
@@ -83,9 +83,10 @@ Status<void> ReadAndDiscardData(int socket_fd, size_t size) {
  return ErrorStatus(EIO);
}

Status<void> SendRequest(int socket_fd, TransactionState* transaction_state,
                         int opcode, const iovec* send_vector,
                         size_t send_count, size_t max_recv_len) {
Status<void> SendRequest(const BorrowedHandle& socket_fd,
                         TransactionState* transaction_state, int opcode,
                         const iovec* send_vector, size_t send_count,
                         size_t max_recv_len) {
  size_t send_len = CountVectorSize(send_vector, send_count);
  InitRequest(&transaction_state->request, opcode, send_len, max_recv_len,
              false);
@@ -95,7 +96,8 @@ Status<void> SendRequest(int socket_fd, TransactionState* transaction_state,
  return status;
}

Status<void> ReceiveResponse(int socket_fd, TransactionState* transaction_state,
Status<void> ReceiveResponse(const BorrowedHandle& socket_fd,
                             TransactionState* transaction_state,
                             const iovec* receive_vector, size_t receive_count,
                             size_t max_recv_len) {
  auto status = ReceiveData(socket_fd, &transaction_state->response);
@@ -164,7 +166,7 @@ Status<void> ClientChannel::SendImpulse(int opcode, const void* buffer,

  InitRequest(&request, opcode, length, 0, true);
  memcpy(request.impulse_payload.data(), buffer, length);
  return SendData(channel_handle_.value(), request);
  return SendData(BorrowedHandle{channel_handle_.value()}, request);
}

Status<int> ClientChannel::SendAndReceive(void* transaction_state, int opcode,
@@ -182,11 +184,11 @@ Status<int> ClientChannel::SendAndReceive(void* transaction_state, int opcode,
  auto* state = static_cast<TransactionState*>(transaction_state);
  size_t max_recv_len = CountVectorSize(receive_vector, receive_count);

  auto status = SendRequest(channel_handle_.value(), state, opcode, send_vector,
                            send_count, max_recv_len);
  auto status = SendRequest(BorrowedHandle{channel_handle_.value()}, state,
                            opcode, send_vector, send_count, max_recv_len);
  if (status) {
    status = ReceiveResponse(channel_handle_.value(), state, receive_vector,
                             receive_count, max_recv_len);
    status = ReceiveResponse(BorrowedHandle{channel_handle_.value()}, state,
                             receive_vector, receive_count, max_recv_len);
  }
  if (!result.PropagateError(status)) {
    const int return_code = state->response.ret_code;
+2 −2
Original line number Diff line number Diff line
@@ -111,11 +111,11 @@ Status<std::unique_ptr<pdx::ClientChannel>> ClientChannelFactory::Connect(
        remote.sun_path);
  RequestHeader<BorrowedHandle> request;
  InitRequest(&request, opcodes::CHANNEL_OPEN, 0, 0, false);
  status = SendData(socket_fd.Get(), request);
  status = SendData(socket_fd.Borrow(), request);
  if (!status)
    return ErrorStatus(status.error());
  ResponseHeader<LocalHandle> response;
  status = ReceiveData(socket_fd.Get(), &response);
  status = ReceiveData(socket_fd.Borrow(), &response);
  if (!status)
    return ErrorStatus(status.error());
  int ref = response.ret_code;
+27 −18
Original line number Diff line number Diff line
@@ -26,18 +26,19 @@ struct MessagePreamble {
  uint32_t fd_count{0};
};

Status<void> SendPayload::Send(int socket_fd) {
Status<void> SendPayload::Send(const BorrowedHandle& socket_fd) {
  return Send(socket_fd, nullptr);
}

Status<void> SendPayload::Send(int socket_fd, const ucred* cred) {
Status<void> SendPayload::Send(const BorrowedHandle& socket_fd,
                               const ucred* cred) {
  MessagePreamble preamble;
  preamble.magic = kMagicPreamble;
  preamble.data_size = buffer_.size();
  preamble.fd_count = file_handles_.size();

  ssize_t ret =
      RETRY_EINTR(send(socket_fd, &preamble, sizeof(preamble), MSG_NOSIGNAL));
  ssize_t ret = RETRY_EINTR(
      send(socket_fd.Get(), &preamble, sizeof(preamble), MSG_NOSIGNAL));
  if (ret < 0)
    return ErrorStatus(errno);
  if (ret != sizeof(preamble))
@@ -71,7 +72,7 @@ Status<void> SendPayload::Send(int socket_fd, const ucred* cred) {
    }
  }

  ret = RETRY_EINTR(sendmsg(socket_fd, &msg, MSG_NOSIGNAL));
  ret = RETRY_EINTR(sendmsg(socket_fd.Get(), &msg, MSG_NOSIGNAL));
  if (ret < 0)
    return ErrorStatus(errno);
  if (static_cast<size_t>(ret) != buffer_.size())
@@ -125,14 +126,15 @@ Status<ChannelReference> SendPayload::PushChannelHandle(
  return ErrorStatus{EOPNOTSUPP};
}

Status<void> ReceivePayload::Receive(int socket_fd) {
Status<void> ReceivePayload::Receive(const BorrowedHandle& socket_fd) {
  return Receive(socket_fd, nullptr);
}

Status<void> ReceivePayload::Receive(int socket_fd, ucred* cred) {
Status<void> ReceivePayload::Receive(const BorrowedHandle& socket_fd,
                                     ucred* cred) {
  MessagePreamble preamble;
  ssize_t ret =
      RETRY_EINTR(recv(socket_fd, &preamble, sizeof(preamble), MSG_WAITALL));
  ssize_t ret = RETRY_EINTR(
      recv(socket_fd.Get(), &preamble, sizeof(preamble), MSG_WAITALL));
  if (ret < 0)
    return ErrorStatus(errno);
  else if (ret == 0)
@@ -157,7 +159,7 @@ Status<void> ReceivePayload::Receive(int socket_fd, ucred* cred) {
    msg.msg_control = alloca(msg.msg_controllen);
  }

  ret = RETRY_EINTR(recvmsg(socket_fd, &msg, MSG_WAITALL));
  ret = RETRY_EINTR(recvmsg(socket_fd.Get(), &msg, MSG_WAITALL));
  if (ret < 0)
    return ErrorStatus(errno);
  else if (ret == 0)
@@ -219,8 +221,10 @@ bool ReceivePayload::GetChannelHandle(ChannelReference /*ref*/,
  return false;
}

Status<void> SendData(int socket_fd, const void* data, size_t size) {
  ssize_t size_written = RETRY_EINTR(send(socket_fd, data, size, MSG_NOSIGNAL));
Status<void> SendData(const BorrowedHandle& socket_fd, const void* data,
                      size_t size) {
  ssize_t size_written =
      RETRY_EINTR(send(socket_fd.Get(), data, size, MSG_NOSIGNAL));
  if (size_written < 0)
    return ErrorStatus(errno);
  if (static_cast<size_t>(size_written) != size)
@@ -228,11 +232,13 @@ Status<void> SendData(int socket_fd, const void* data, size_t size) {
  return {};
}

Status<void> SendDataVector(int socket_fd, const iovec* data, size_t count) {
Status<void> SendDataVector(const BorrowedHandle& socket_fd, const iovec* data,
                            size_t count) {
  msghdr msg = {};
  msg.msg_iov = const_cast<iovec*>(data);
  msg.msg_iovlen = count;
  ssize_t size_written = RETRY_EINTR(sendmsg(socket_fd, &msg, MSG_NOSIGNAL));
  ssize_t size_written =
      RETRY_EINTR(sendmsg(socket_fd.Get(), &msg, MSG_NOSIGNAL));
  if (size_written < 0)
    return ErrorStatus(errno);
  if (static_cast<size_t>(size_written) != CountVectorSize(data, count))
@@ -240,8 +246,10 @@ Status<void> SendDataVector(int socket_fd, const iovec* data, size_t count) {
  return {};
}

Status<void> ReceiveData(int socket_fd, void* data, size_t size) {
  ssize_t size_read = RETRY_EINTR(recv(socket_fd, data, size, MSG_WAITALL));
Status<void> ReceiveData(const BorrowedHandle& socket_fd, void* data,
                         size_t size) {
  ssize_t size_read =
      RETRY_EINTR(recv(socket_fd.Get(), data, size, MSG_WAITALL));
  if (size_read < 0)
    return ErrorStatus(errno);
  else if (size_read == 0)
@@ -251,11 +259,12 @@ Status<void> ReceiveData(int socket_fd, void* data, size_t size) {
  return {};
}

Status<void> ReceiveDataVector(int socket_fd, const iovec* data, size_t count) {
Status<void> ReceiveDataVector(const BorrowedHandle& socket_fd,
                               const iovec* data, size_t count) {
  msghdr msg = {};
  msg.msg_iov = const_cast<iovec*>(data);
  msg.msg_iovlen = count;
  ssize_t size_read = RETRY_EINTR(recvmsg(socket_fd, &msg, MSG_WAITALL));
  ssize_t size_read = RETRY_EINTR(recvmsg(socket_fd.Get(), &msg, MSG_WAITALL));
  if (size_read < 0)
    return ErrorStatus(errno);
  else if (size_read == 0)
+16 −12
Original line number Diff line number Diff line
@@ -25,8 +25,8 @@ namespace uds {

class SendPayload : public MessageWriter, public OutputResourceMapper {
 public:
  Status<void> Send(int socket_fd);
  Status<void> Send(int socket_fd, const ucred* cred);
  Status<void> Send(const BorrowedHandle& socket_fd);
  Status<void> Send(const BorrowedHandle& socket_fd, const ucred* cred);

  // MessageWriter
  void* GetNextWriteBufferSection(size_t size) override;
@@ -50,8 +50,8 @@ class SendPayload : public MessageWriter, public OutputResourceMapper {

class ReceivePayload : public MessageReader, public InputResourceMapper {
 public:
  Status<void> Receive(int socket_fd);
  Status<void> Receive(int socket_fd, ucred* cred);
  Status<void> Receive(const BorrowedHandle& socket_fd);
  Status<void> Receive(const BorrowedHandle& socket_fd, ucred* cred);

  // MessageReader
  BufferSection GetNextReadBufferSection() override;
@@ -111,25 +111,27 @@ class ResponseHeader {
};

template <typename T>
inline Status<void> SendData(int socket_fd, const T& data) {
inline Status<void> SendData(const BorrowedHandle& socket_fd, const T& data) {
  SendPayload payload;
  rpc::Serialize(data, &payload);
  return payload.Send(socket_fd);
}

template <typename FileHandleType>
inline Status<void> SendData(int socket_fd,
inline Status<void> SendData(const BorrowedHandle& socket_fd,
                             const RequestHeader<FileHandleType>& request) {
  SendPayload payload;
  rpc::Serialize(request, &payload);
  return payload.Send(socket_fd, &request.cred);
}

Status<void> SendData(int socket_fd, const void* data, size_t size);
Status<void> SendDataVector(int socket_fd, const iovec* data, size_t count);
Status<void> SendData(const BorrowedHandle& socket_fd, const void* data,
                      size_t size);
Status<void> SendDataVector(const BorrowedHandle& socket_fd, const iovec* data,
                            size_t count);

template <typename T>
inline Status<void> ReceiveData(int socket_fd, T* data) {
inline Status<void> ReceiveData(const BorrowedHandle& socket_fd, T* data) {
  ReceivePayload payload;
  Status<void> status = payload.Receive(socket_fd);
  if (status && rpc::Deserialize(data, &payload) != rpc::ErrorCode::NO_ERROR)
@@ -138,7 +140,7 @@ inline Status<void> ReceiveData(int socket_fd, T* data) {
}

template <typename FileHandleType>
inline Status<void> ReceiveData(int socket_fd,
inline Status<void> ReceiveData(const BorrowedHandle& socket_fd,
                                RequestHeader<FileHandleType>* request) {
  ReceivePayload payload;
  Status<void> status = payload.Receive(socket_fd, &request->cred);
@@ -147,8 +149,10 @@ inline Status<void> ReceiveData(int socket_fd,
  return status;
}

Status<void> ReceiveData(int socket_fd, void* data, size_t size);
Status<void> ReceiveDataVector(int socket_fd, const iovec* data, size_t count);
Status<void> ReceiveData(const BorrowedHandle& socket_fd, void* data,
                         size_t size);
Status<void> ReceiveDataVector(const BorrowedHandle& socket_fd,
                               const iovec* data, size_t count);

size_t CountVectorSize(const iovec* data, size_t count);
void InitRequest(android::pdx::uds::RequestHeader<BorrowedHandle>* request,
+14 −10
Original line number Diff line number Diff line
@@ -117,18 +117,20 @@ class Endpoint : public pdx::Endpoint {
    return next_message_id_.fetch_add(1, std::memory_order_relaxed);
  }

  void BuildCloseMessage(int channel_id, Message* message);
  void BuildCloseMessage(int32_t channel_id, Message* message);

  Status<void> AcceptConnection(Message* message);
  Status<void> ReceiveMessageForChannel(int channel_id, Message* message);
  Status<void> ReceiveMessageForChannel(const BorrowedHandle& channel_fd,
                                        Message* message);
  Status<void> OnNewChannel(LocalHandle channel_fd);
  Status<ChannelData*> OnNewChannelLocked(LocalHandle channel_fd,
                                          Channel* channel_state);
  Status<void> CloseChannelLocked(int channel_id);
  Status<void> ReenableEpollEvent(int fd);
  Channel* GetChannelState(int channel_id);
  int GetChannelSocketFd(int channel_id);
  int GetChannelEventFd(int channel_id);
  Status<std::pair<int32_t, ChannelData*>> OnNewChannelLocked(
      LocalHandle channel_fd, Channel* channel_state);
  Status<void> CloseChannelLocked(int32_t channel_id);
  Status<void> ReenableEpollEvent(const BorrowedHandle& channel_fd);
  Channel* GetChannelState(int32_t channel_id);
  BorrowedHandle GetChannelSocketFd(int32_t channel_id);
  BorrowedHandle GetChannelEventFd(int32_t channel_id);
  int32_t GetChannelId(const BorrowedHandle& channel_fd);

  std::string endpoint_path_;
  bool is_blocking_;
@@ -137,7 +139,9 @@ class Endpoint : public pdx::Endpoint {
  LocalHandle epoll_fd_;

  mutable std::mutex channel_mutex_;
  std::map<int, ChannelData> channels_;
  std::map<int32_t, ChannelData> channels_;
  std::map<int, int32_t> channel_fd_to_id_;
  int32_t last_channel_id_{0};

  Service* service_{nullptr};
  std::atomic<uint32_t> next_message_id_;
Loading