Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 6eefa42d authored by Alex Vakulenko's avatar Alex Vakulenko
Browse files

libpdx_uds: Allow to create Endpoint/ClientChannel from a socket pair

This is important to enable Service/Client operation in unit tests.
Being able to create a pair of Unix domain sockets and construct both
Service and Client so that they can talk to each other without having
to create a physical socket file is convenient.

This change makes it possible to create an instance of Endpoint and
ClientChannel classes based just on a pair of sockets (Endpoint does
take another socket to simulate the main endpoint FD to accept incoming
connection on, but it is not used for this, only the shutdown events
are received from that main socket. Endpoint uses the channel FD to
perform actual communication with the client).

Bug: 37443070
Test: `libpdx_uds_tests` pass
Change-Id: Ifa1a9d03b97bd90282a04715c2105ad37a8de936
parent 7038c239
Loading
Loading
Loading
Loading
+28 −14
Original line number Diff line number Diff line
@@ -39,32 +39,42 @@ std::string ClientChannelFactory::GetEndpointPath(
ClientChannelFactory::ClientChannelFactory(const std::string& endpoint_path)
    : endpoint_path_{GetEndpointPath(endpoint_path)} {}

ClientChannelFactory::ClientChannelFactory(LocalHandle socket)
    : socket_{std::move(socket)} {}

std::unique_ptr<pdx::ClientChannelFactory> ClientChannelFactory::Create(
    const std::string& endpoint_path) {
  return std::unique_ptr<pdx::ClientChannelFactory>{
      new ClientChannelFactory{endpoint_path}};
}

std::unique_ptr<pdx::ClientChannelFactory> ClientChannelFactory::Create(
    LocalHandle socket) {
  return std::unique_ptr<pdx::ClientChannelFactory>{
      new ClientChannelFactory{std::move(socket)}};
}

Status<std::unique_ptr<pdx::ClientChannel>> ClientChannelFactory::Connect(
    int64_t timeout_ms) const {
  Status<void> status;

  LocalHandle socket_fd{socket(AF_UNIX, SOCK_STREAM, 0)};
  if (!socket_fd) {
  bool connected = socket_.IsValid();
  if (!connected) {
    socket_.Reset(socket(AF_UNIX, SOCK_STREAM, 0));
    LOG_ALWAYS_FATAL_IF(
        endpoint_path_.empty(),
        "ClientChannelFactory::Connect: unspecified socket path");
  }

  if (!socket_) {
    ALOGE("ClientChannelFactory::Connect: socket error: %s", strerror(errno));
    return ErrorStatus(errno);
  }

  sockaddr_un remote;
  remote.sun_family = AF_UNIX;
  strncpy(remote.sun_path, endpoint_path_.c_str(), sizeof(remote.sun_path));
  remote.sun_path[sizeof(remote.sun_path) - 1] = '\0';

  bool use_timeout = (timeout_ms >= 0);
  auto now = steady_clock::now();
  auto time_end = now + std::chrono::milliseconds{timeout_ms};

  bool connected = false;
  int max_eaccess = 5;  // Max number of times to retry when EACCES returned.
  while (!connected) {
    int64_t timeout = -1;
@@ -74,6 +84,10 @@ Status<std::unique_ptr<pdx::ClientChannel>> ClientChannelFactory::Connect(
      if (timeout < 0)
        return ErrorStatus(ETIMEDOUT);
    }
    sockaddr_un remote;
    remote.sun_family = AF_UNIX;
    strncpy(remote.sun_path, endpoint_path_.c_str(), sizeof(remote.sun_path));
    remote.sun_path[sizeof(remote.sun_path) - 1] = '\0';
    ALOGD("ClientChannelFactory: Waiting for endpoint at %s", remote.sun_path);
    status = WaitForEndpoint(endpoint_path_, timeout);
    if (!status)
@@ -81,7 +95,7 @@ Status<std::unique_ptr<pdx::ClientChannel>> ClientChannelFactory::Connect(

    ALOGD("ClientChannelFactory: Connecting to %s", remote.sun_path);
    int ret = RETRY_EINTR(connect(
        socket_fd.Get(), reinterpret_cast<sockaddr*>(&remote), sizeof(remote)));
        socket_.Get(), reinterpret_cast<sockaddr*>(&remote), sizeof(remote)));
    if (ret == -1) {
      ALOGD("ClientChannelFactory: Connect error %d: %s", errno,
            strerror(errno));
@@ -107,20 +121,20 @@ Status<std::unique_ptr<pdx::ClientChannel>> ClientChannelFactory::Connect(
      }
    } else {
      connected = true;
      ALOGD("ClientChannelFactory: Connected successfully to %s...",
            remote.sun_path);
    }
    if (use_timeout)
      now = steady_clock::now();
  }  // while (!connected)

  ALOGD("ClientChannelFactory: Connected successfully to %s...",
        remote.sun_path);
  RequestHeader<BorrowedHandle> request;
  InitRequest(&request, opcodes::CHANNEL_OPEN, 0, 0, false);
  status = SendData(socket_fd.Borrow(), request);
  status = SendData(socket_.Borrow(), request);
  if (!status)
    return ErrorStatus(status.error());
  ResponseHeader<LocalHandle> response;
  status = ReceiveData(socket_fd.Borrow(), &response);
  status = ReceiveData(socket_.Borrow(), &response);
  if (!status)
    return ErrorStatus(status.error());
  int ref = response.ret_code;
@@ -129,7 +143,7 @@ Status<std::unique_ptr<pdx::ClientChannel>> ClientChannelFactory::Connect(

  LocalHandle event_fd = std::move(response.file_descriptors[ref]);
  return ClientChannel::Create(ChannelManager::Get().CreateHandle(
      std::move(socket_fd), std::move(event_fd)));
      std::move(socket_), std::move(event_fd)));
}

}  // namespace uds
+4 −1
Original line number Diff line number Diff line
@@ -275,6 +275,7 @@ Status<void> ReceivePayload::Receive(const BorrowedHandle& socket_fd,
    return ret;

  if (preamble.magic != kMagicPreamble) {
    ALOGE("ReceivePayload::Receive: Message header is invalid");
    ret.SetError(EIO);
    return ret;
  }
@@ -319,8 +320,10 @@ Status<void> ReceivePayload::Receive(const BorrowedHandle& socket_fd,
    cmsg = CMSG_NXTHDR(&msg, cmsg);
  }

  if (cred && !cred_available)
  if (cred && !cred_available) {
    ALOGE("ReceivePayload::Receive: Failed to obtain message credentials");
    ret.SetError(EIO);
  }

  return ret;
}
+3 −0
Original line number Diff line number Diff line
@@ -13,6 +13,7 @@ class ClientChannelFactory : public pdx::ClientChannelFactory {
 public:
  static std::unique_ptr<pdx::ClientChannelFactory> Create(
      const std::string& endpoint_path);
  static std::unique_ptr<pdx::ClientChannelFactory> Create(LocalHandle socket);

  Status<std::unique_ptr<pdx::ClientChannel>> Connect(
      int64_t timeout_ms) const override;
@@ -22,7 +23,9 @@ class ClientChannelFactory : public pdx::ClientChannelFactory {

 private:
  explicit ClientChannelFactory(const std::string& endpoint_path);
  explicit ClientChannelFactory(LocalHandle socket);

  mutable LocalHandle socket_;
  std::string endpoint_path_;
};

+11 −0
Original line number Diff line number Diff line
@@ -97,6 +97,14 @@ class Endpoint : public pdx::Endpoint {
  static std::unique_ptr<Endpoint> CreateAndBindSocket(
      const std::string& endpoint_path, bool blocking = kDefaultBlocking);

  // Helper method to create an endpoint from an existing socket FD.
  // Mostly helpful for tests.
  static std::unique_ptr<Endpoint> CreateFromSocketFd(LocalHandle socket_fd);

  // Test helper method to register a new channel identified by |channel_fd|
  // socket file descriptor.
  Status<void> RegisterNewChannelForTests(LocalHandle channel_fd);

  int epoll_fd() const { return epoll_fd_.Get(); }

 private:
@@ -109,6 +117,9 @@ class Endpoint : public pdx::Endpoint {
  // This class must be instantiated using Create() static methods above.
  Endpoint(const std::string& endpoint_path, bool blocking,
           bool use_init_socket_fd = true);
  Endpoint(LocalHandle socket_fd);

  void Init(LocalHandle socket_fd);

  Endpoint(const Endpoint&) = delete;
  void operator=(const Endpoint&) = delete;
+43 −13
Original line number Diff line number Diff line
@@ -161,9 +161,16 @@ Endpoint::Endpoint(const std::string& endpoint_path, bool blocking,
        bind(fd.Get(), reinterpret_cast<sockaddr*>(&local), sizeof(local));
    CHECK_EQ(ret, 0) << "Endpoint::Endpoint: bind error: " << strerror(errno);
  }
  CHECK_EQ(listen(fd.Get(), kMaxBackLogForSocketListen), 0)
      << "Endpoint::Endpoint: listen error: " << strerror(errno);
  Init(std::move(fd));
}

Endpoint::Endpoint(LocalHandle socket_fd) { Init(std::move(socket_fd)); }

void Endpoint::Init(LocalHandle socket_fd) {
  if (socket_fd) {
    CHECK_EQ(listen(socket_fd.Get(), kMaxBackLogForSocketListen), 0)
        << "Endpoint::Endpoint: listen error: " << strerror(errno);
  }
  cancel_event_fd_.Reset(eventfd(0, EFD_CLOEXEC | EFD_NONBLOCK));
  CHECK(cancel_event_fd_.IsValid())
      << "Endpoint::Endpoint: Failed to create event fd: " << strerror(errno);
@@ -172,24 +179,27 @@ Endpoint::Endpoint(const std::string& endpoint_path, bool blocking,
  CHECK(epoll_fd_.IsValid())
      << "Endpoint::Endpoint: Failed to create epoll fd: " << strerror(errno);

  if (socket_fd) {
    epoll_event socket_event;
    socket_event.events = EPOLLIN | EPOLLRDHUP | EPOLLONESHOT;
  socket_event.data.fd = fd.Get();
    socket_event.data.fd = socket_fd.Get();
    int ret = epoll_ctl(epoll_fd_.Get(), EPOLL_CTL_ADD, socket_fd.Get(),
                        &socket_event);
    CHECK_EQ(ret, 0)
        << "Endpoint::Endpoint: Failed to add socket fd to epoll fd: "
        << strerror(errno);
  }

  epoll_event cancel_event;
  cancel_event.events = EPOLLIN;
  cancel_event.data.fd = cancel_event_fd_.Get();

  int ret = epoll_ctl(epoll_fd_.Get(), EPOLL_CTL_ADD, fd.Get(), &socket_event);
  CHECK_EQ(ret, 0)
      << "Endpoint::Endpoint: Failed to add socket fd to epoll fd: "
      << strerror(errno);
  ret = epoll_ctl(epoll_fd_.Get(), EPOLL_CTL_ADD, cancel_event_fd_.Get(),
  int ret = epoll_ctl(epoll_fd_.Get(), EPOLL_CTL_ADD, cancel_event_fd_.Get(),
                      &cancel_event);
  CHECK_EQ(ret, 0)
      << "Endpoint::Endpoint: Failed to add cancel event fd to epoll fd: "
      << strerror(errno);
  socket_fd_ = std::move(fd);
  socket_fd_ = std::move(socket_fd);
}

void* Endpoint::AllocateMessageState() { return new MessageState; }
@@ -199,6 +209,9 @@ void Endpoint::FreeMessageState(void* state) {
}

Status<void> Endpoint::AcceptConnection(Message* message) {
  if (!socket_fd_)
    return ErrorStatus(EBADF);

  sockaddr_un remote;
  socklen_t addrlen = sizeof(remote);
  LocalHandle channel_fd{accept4(socket_fd_.Get(),
@@ -515,7 +528,7 @@ Status<void> Endpoint::MessageReceive(Message* message) {
    return ErrorStatus{ESHUTDOWN};
  }

  if (event.data.fd == socket_fd_.Get()) {
  if (socket_fd_ && event.data.fd == socket_fd_.Get()) {
    auto status = AcceptConnection(message);
    if (!status)
      return status;
@@ -680,6 +693,23 @@ std::unique_ptr<Endpoint> Endpoint::CreateAndBindSocket(
      new Endpoint(endpoint_path, blocking, false));
}

std::unique_ptr<Endpoint> Endpoint::CreateFromSocketFd(LocalHandle socket_fd) {
  return std::unique_ptr<Endpoint>(new Endpoint(std::move(socket_fd)));
}

Status<void> Endpoint::RegisterNewChannelForTests(LocalHandle channel_fd) {
  int optval = 1;
  if (setsockopt(channel_fd.Get(), SOL_SOCKET, SO_PASSCRED, &optval,
                 sizeof(optval)) == -1) {
    ALOGE(
        "Endpoint::RegisterNewChannelForTests: Failed to enable the receiving"
        "of the credentials for channel %d: %s",
        channel_fd.Get(), strerror(errno));
    return ErrorStatus(errno);
  }
  return OnNewChannel(std::move(channel_fd));
}

}  // namespace uds
}  // namespace pdx
}  // namespace android