Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit f5524db2 authored by TreeHugger Robot's avatar TreeHugger Robot Committed by Android (Google) Code Review
Browse files

Merge "libpdx_uds: Allow to create Endpoint/ClientChannel from a socket pair" into oc-dev

parents 20c35e42 6eefa42d
Loading
Loading
Loading
Loading
+28 −14
Original line number Diff line number Diff line
@@ -39,32 +39,42 @@ std::string ClientChannelFactory::GetEndpointPath(
ClientChannelFactory::ClientChannelFactory(const std::string& endpoint_path)
    : endpoint_path_{GetEndpointPath(endpoint_path)} {}

ClientChannelFactory::ClientChannelFactory(LocalHandle socket)
    : socket_{std::move(socket)} {}

std::unique_ptr<pdx::ClientChannelFactory> ClientChannelFactory::Create(
    const std::string& endpoint_path) {
  return std::unique_ptr<pdx::ClientChannelFactory>{
      new ClientChannelFactory{endpoint_path}};
}

std::unique_ptr<pdx::ClientChannelFactory> ClientChannelFactory::Create(
    LocalHandle socket) {
  return std::unique_ptr<pdx::ClientChannelFactory>{
      new ClientChannelFactory{std::move(socket)}};
}

Status<std::unique_ptr<pdx::ClientChannel>> ClientChannelFactory::Connect(
    int64_t timeout_ms) const {
  Status<void> status;

  LocalHandle socket_fd{socket(AF_UNIX, SOCK_STREAM, 0)};
  if (!socket_fd) {
  bool connected = socket_.IsValid();
  if (!connected) {
    socket_.Reset(socket(AF_UNIX, SOCK_STREAM, 0));
    LOG_ALWAYS_FATAL_IF(
        endpoint_path_.empty(),
        "ClientChannelFactory::Connect: unspecified socket path");
  }

  if (!socket_) {
    ALOGE("ClientChannelFactory::Connect: socket error: %s", strerror(errno));
    return ErrorStatus(errno);
  }

  sockaddr_un remote;
  remote.sun_family = AF_UNIX;
  strncpy(remote.sun_path, endpoint_path_.c_str(), sizeof(remote.sun_path));
  remote.sun_path[sizeof(remote.sun_path) - 1] = '\0';

  bool use_timeout = (timeout_ms >= 0);
  auto now = steady_clock::now();
  auto time_end = now + std::chrono::milliseconds{timeout_ms};

  bool connected = false;
  int max_eaccess = 5;  // Max number of times to retry when EACCES returned.
  while (!connected) {
    int64_t timeout = -1;
@@ -74,6 +84,10 @@ Status<std::unique_ptr<pdx::ClientChannel>> ClientChannelFactory::Connect(
      if (timeout < 0)
        return ErrorStatus(ETIMEDOUT);
    }
    sockaddr_un remote;
    remote.sun_family = AF_UNIX;
    strncpy(remote.sun_path, endpoint_path_.c_str(), sizeof(remote.sun_path));
    remote.sun_path[sizeof(remote.sun_path) - 1] = '\0';
    ALOGD("ClientChannelFactory: Waiting for endpoint at %s", remote.sun_path);
    status = WaitForEndpoint(endpoint_path_, timeout);
    if (!status)
@@ -81,7 +95,7 @@ Status<std::unique_ptr<pdx::ClientChannel>> ClientChannelFactory::Connect(

    ALOGD("ClientChannelFactory: Connecting to %s", remote.sun_path);
    int ret = RETRY_EINTR(connect(
        socket_fd.Get(), reinterpret_cast<sockaddr*>(&remote), sizeof(remote)));
        socket_.Get(), reinterpret_cast<sockaddr*>(&remote), sizeof(remote)));
    if (ret == -1) {
      ALOGD("ClientChannelFactory: Connect error %d: %s", errno,
            strerror(errno));
@@ -107,20 +121,20 @@ Status<std::unique_ptr<pdx::ClientChannel>> ClientChannelFactory::Connect(
      }
    } else {
      connected = true;
      ALOGD("ClientChannelFactory: Connected successfully to %s...",
            remote.sun_path);
    }
    if (use_timeout)
      now = steady_clock::now();
  }  // while (!connected)

  ALOGD("ClientChannelFactory: Connected successfully to %s...",
        remote.sun_path);
  RequestHeader<BorrowedHandle> request;
  InitRequest(&request, opcodes::CHANNEL_OPEN, 0, 0, false);
  status = SendData(socket_fd.Borrow(), request);
  status = SendData(socket_.Borrow(), request);
  if (!status)
    return ErrorStatus(status.error());
  ResponseHeader<LocalHandle> response;
  status = ReceiveData(socket_fd.Borrow(), &response);
  status = ReceiveData(socket_.Borrow(), &response);
  if (!status)
    return ErrorStatus(status.error());
  int ref = response.ret_code;
@@ -129,7 +143,7 @@ Status<std::unique_ptr<pdx::ClientChannel>> ClientChannelFactory::Connect(

  LocalHandle event_fd = std::move(response.file_descriptors[ref]);
  return ClientChannel::Create(ChannelManager::Get().CreateHandle(
      std::move(socket_fd), std::move(event_fd)));
      std::move(socket_), std::move(event_fd)));
}

}  // namespace uds
+4 −1
Original line number Diff line number Diff line
@@ -275,6 +275,7 @@ Status<void> ReceivePayload::Receive(const BorrowedHandle& socket_fd,
    return ret;

  if (preamble.magic != kMagicPreamble) {
    ALOGE("ReceivePayload::Receive: Message header is invalid");
    ret.SetError(EIO);
    return ret;
  }
@@ -319,8 +320,10 @@ Status<void> ReceivePayload::Receive(const BorrowedHandle& socket_fd,
    cmsg = CMSG_NXTHDR(&msg, cmsg);
  }

  if (cred && !cred_available)
  if (cred && !cred_available) {
    ALOGE("ReceivePayload::Receive: Failed to obtain message credentials");
    ret.SetError(EIO);
  }

  return ret;
}
+3 −0
Original line number Diff line number Diff line
@@ -13,6 +13,7 @@ class ClientChannelFactory : public pdx::ClientChannelFactory {
 public:
  static std::unique_ptr<pdx::ClientChannelFactory> Create(
      const std::string& endpoint_path);
  static std::unique_ptr<pdx::ClientChannelFactory> Create(LocalHandle socket);

  Status<std::unique_ptr<pdx::ClientChannel>> Connect(
      int64_t timeout_ms) const override;
@@ -22,7 +23,9 @@ class ClientChannelFactory : public pdx::ClientChannelFactory {

 private:
  explicit ClientChannelFactory(const std::string& endpoint_path);
  explicit ClientChannelFactory(LocalHandle socket);

  mutable LocalHandle socket_;
  std::string endpoint_path_;
};

+11 −0
Original line number Diff line number Diff line
@@ -97,6 +97,14 @@ class Endpoint : public pdx::Endpoint {
  static std::unique_ptr<Endpoint> CreateAndBindSocket(
      const std::string& endpoint_path, bool blocking = kDefaultBlocking);

  // Helper method to create an endpoint from an existing socket FD.
  // Mostly helpful for tests.
  static std::unique_ptr<Endpoint> CreateFromSocketFd(LocalHandle socket_fd);

  // Test helper method to register a new channel identified by |channel_fd|
  // socket file descriptor.
  Status<void> RegisterNewChannelForTests(LocalHandle channel_fd);

  int epoll_fd() const { return epoll_fd_.Get(); }

 private:
@@ -109,6 +117,9 @@ class Endpoint : public pdx::Endpoint {
  // This class must be instantiated using Create() static methods above.
  Endpoint(const std::string& endpoint_path, bool blocking,
           bool use_init_socket_fd = true);
  Endpoint(LocalHandle socket_fd);

  void Init(LocalHandle socket_fd);

  Endpoint(const Endpoint&) = delete;
  void operator=(const Endpoint&) = delete;
+43 −13
Original line number Diff line number Diff line
@@ -161,9 +161,16 @@ Endpoint::Endpoint(const std::string& endpoint_path, bool blocking,
        bind(fd.Get(), reinterpret_cast<sockaddr*>(&local), sizeof(local));
    CHECK_EQ(ret, 0) << "Endpoint::Endpoint: bind error: " << strerror(errno);
  }
  CHECK_EQ(listen(fd.Get(), kMaxBackLogForSocketListen), 0)
      << "Endpoint::Endpoint: listen error: " << strerror(errno);
  Init(std::move(fd));
}

Endpoint::Endpoint(LocalHandle socket_fd) { Init(std::move(socket_fd)); }

void Endpoint::Init(LocalHandle socket_fd) {
  if (socket_fd) {
    CHECK_EQ(listen(socket_fd.Get(), kMaxBackLogForSocketListen), 0)
        << "Endpoint::Endpoint: listen error: " << strerror(errno);
  }
  cancel_event_fd_.Reset(eventfd(0, EFD_CLOEXEC | EFD_NONBLOCK));
  CHECK(cancel_event_fd_.IsValid())
      << "Endpoint::Endpoint: Failed to create event fd: " << strerror(errno);
@@ -172,24 +179,27 @@ Endpoint::Endpoint(const std::string& endpoint_path, bool blocking,
  CHECK(epoll_fd_.IsValid())
      << "Endpoint::Endpoint: Failed to create epoll fd: " << strerror(errno);

  if (socket_fd) {
    epoll_event socket_event;
    socket_event.events = EPOLLIN | EPOLLRDHUP | EPOLLONESHOT;
  socket_event.data.fd = fd.Get();
    socket_event.data.fd = socket_fd.Get();
    int ret = epoll_ctl(epoll_fd_.Get(), EPOLL_CTL_ADD, socket_fd.Get(),
                        &socket_event);
    CHECK_EQ(ret, 0)
        << "Endpoint::Endpoint: Failed to add socket fd to epoll fd: "
        << strerror(errno);
  }

  epoll_event cancel_event;
  cancel_event.events = EPOLLIN;
  cancel_event.data.fd = cancel_event_fd_.Get();

  int ret = epoll_ctl(epoll_fd_.Get(), EPOLL_CTL_ADD, fd.Get(), &socket_event);
  CHECK_EQ(ret, 0)
      << "Endpoint::Endpoint: Failed to add socket fd to epoll fd: "
      << strerror(errno);
  ret = epoll_ctl(epoll_fd_.Get(), EPOLL_CTL_ADD, cancel_event_fd_.Get(),
  int ret = epoll_ctl(epoll_fd_.Get(), EPOLL_CTL_ADD, cancel_event_fd_.Get(),
                      &cancel_event);
  CHECK_EQ(ret, 0)
      << "Endpoint::Endpoint: Failed to add cancel event fd to epoll fd: "
      << strerror(errno);
  socket_fd_ = std::move(fd);
  socket_fd_ = std::move(socket_fd);
}

void* Endpoint::AllocateMessageState() { return new MessageState; }
@@ -199,6 +209,9 @@ void Endpoint::FreeMessageState(void* state) {
}

Status<void> Endpoint::AcceptConnection(Message* message) {
  if (!socket_fd_)
    return ErrorStatus(EBADF);

  sockaddr_un remote;
  socklen_t addrlen = sizeof(remote);
  LocalHandle channel_fd{accept4(socket_fd_.Get(),
@@ -515,7 +528,7 @@ Status<void> Endpoint::MessageReceive(Message* message) {
    return ErrorStatus{ESHUTDOWN};
  }

  if (event.data.fd == socket_fd_.Get()) {
  if (socket_fd_ && event.data.fd == socket_fd_.Get()) {
    auto status = AcceptConnection(message);
    if (!status)
      return status;
@@ -680,6 +693,23 @@ std::unique_ptr<Endpoint> Endpoint::CreateAndBindSocket(
      new Endpoint(endpoint_path, blocking, false));
}

std::unique_ptr<Endpoint> Endpoint::CreateFromSocketFd(LocalHandle socket_fd) {
  return std::unique_ptr<Endpoint>(new Endpoint(std::move(socket_fd)));
}

Status<void> Endpoint::RegisterNewChannelForTests(LocalHandle channel_fd) {
  int optval = 1;
  if (setsockopt(channel_fd.Get(), SOL_SOCKET, SO_PASSCRED, &optval,
                 sizeof(optval)) == -1) {
    ALOGE(
        "Endpoint::RegisterNewChannelForTests: Failed to enable the receiving"
        "of the credentials for channel %d: %s",
        channel_fd.Get(), strerror(errno));
    return ErrorStatus(errno);
  }
  return OnNewChannel(std::move(channel_fd));
}

}  // namespace uds
}  // namespace pdx
}  // namespace android