Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit b595b894 authored by Henri Chataing's avatar Henri Chataing Committed by Gerrit Code Review
Browse files

Merge "snoop_logger_socket: Replace Select() by Poll()" into main

parents 6224ecd8 6d9bbd1a
Loading
Loading
Loading
Loading
+23 −34
Original line number Diff line number Diff line
@@ -48,19 +48,19 @@ constexpr int INCOMING_SOCKET_CONNECTIONS_QUEUE_SIZE_ = 10;

SnoopLoggerSocket::SnoopLoggerSocket(SyscallWrapperInterface* syscall_if, int socket_address,
                                     int socket_port)
    : syscall_if_(syscall_if),
      socket_address_(socket_address),
      socket_port_(socket_port),
      notification_listen_fd_(-1),
      notification_write_fd_(-1),
      listen_socket_(-1),
      fd_max_(-1),
      client_socket_(-1) {
    : syscall_if_(syscall_if), socket_address_(socket_address), socket_port_(socket_port) {
  log::info("address {} port {}", socket_address, socket_port);
  ResetPollFds();
}

SnoopLoggerSocket::~SnoopLoggerSocket() { Cleanup(); }

void SnoopLoggerSocket::ResetPollFds() {
  for (int i = 0; i < kNumPollFd; i++) {
    poll_fds_[i].fd = -1;
  }
}

void SnoopLoggerSocket::Write(int& client_socket, const void* data, size_t length) {
  if (client_socket == -1) {
    return;
@@ -85,10 +85,6 @@ int SnoopLoggerSocket::InitializeCommunications() {
  int self_pipe_fds[2];
  int ret;

  fd_max_ = -1;

  syscall_if_->FDZero(&save_sock_fds_);

  // Set up the communication channel
  ret = syscall_if_->Pipe2(self_pipe_fds, O_NONBLOCK | O_CLOEXEC);
  if (ret < 0) {
@@ -99,12 +95,14 @@ int SnoopLoggerSocket::InitializeCommunications() {
  notification_listen_fd_ = self_pipe_fds[0];
  notification_write_fd_ = self_pipe_fds[1];

  syscall_if_->FDSet(notification_listen_fd_, &save_sock_fds_);
  fd_max_ = notification_listen_fd_;
  ResetPollFds();
  poll_fds_[kNotificationFd].fd = notification_listen_fd_;
  poll_fds_[kNotificationFd].events = POLLIN;

  listen_socket_ = CreateSocket();
  if (listen_socket_ == INVALID_FD) {
    log::error("Unable to create a listen socket.");
    poll_fds_[kNotificationFd].fd = -1;
    SafeCloseSocket(notification_listen_fd_);
    SafeCloseSocket(notification_write_fd_);
    return -1;
@@ -114,20 +112,19 @@ int SnoopLoggerSocket::InitializeCommunications() {
}

bool SnoopLoggerSocket::ProcessIncomingRequest() {
  int ret;
  fd_set sock_fds = save_sock_fds_;

  if ((syscall_if_->Select(fd_max_ + 1, &sock_fds, NULL, NULL, NULL)) == -1) {
    log::error("select failed {}", strerror(syscall_if_->GetErrno()));
    if (syscall_if_->GetErrno() == EINTR) {
      return true;
  if (syscall_if_->Poll(poll_fds_, kNumPollFd, -1) == -1) {
    log::error("Poll failed {}", strerror(syscall_if_->GetErrno()));
    return syscall_if_->GetErrno() == EINTR;
  }

  if (poll_fds_[kNotificationFd].revents) {
    log::warn("exiting from listen_fn_ thread");
    return false;
  }

  if ((listen_socket_ != -1) && syscall_if_->FDIsSet(listen_socket_, &sock_fds)) {
  if (poll_fds_[kSocketFd].revents) {
    int client_socket = -1;
    ret = AcceptIncomingConnection(listen_socket_, client_socket);
    int ret = AcceptIncomingConnection(listen_socket_, client_socket);
    if (ret != 0) {
      // Unrecoverable error, stop the thread.
      return false;
@@ -138,12 +135,7 @@ bool SnoopLoggerSocket::ProcessIncomingRequest() {
    }

    InitializeClientSocket(client_socket);

    ClientSocketConnected(client_socket);
  } else if ((notification_listen_fd_ != -1) &&
             syscall_if_->FDIsSet(notification_listen_fd_, &sock_fds)) {
    log::warn("exting from listen_fn_ thread");
    return false;
  }

  return true;
@@ -154,6 +146,7 @@ void SnoopLoggerSocket::Cleanup() {
  SafeCloseSocket(notification_write_fd_);
  SafeCloseSocket(client_socket_);
  SafeCloseSocket(listen_socket_);
  ResetPollFds();
}

int SnoopLoggerSocket::AcceptIncomingConnection(int listen_socket, int& client_socket) {
@@ -210,11 +203,6 @@ int SnoopLoggerSocket::CreateSocket() {
    return INVALID_FD;
  }

  syscall_if_->FDSet(socket_fd, &save_sock_fds_);
  if (socket_fd > fd_max_) {
    fd_max_ = socket_fd;
  }

  // Enable REUSEADDR
  int enable = 1;
  ret = syscall_if_->Setsockopt(socket_fd, SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(enable));
@@ -245,6 +233,8 @@ int SnoopLoggerSocket::CreateSocket() {
    return INVALID_FD;
  }

  poll_fds_[kSocketFd].fd = socket_fd;
  poll_fds_[kSocketFd].events = POLLIN;
  return socket_fd;
}

@@ -270,7 +260,6 @@ void SnoopLoggerSocket::SafeCloseSocket(int& fd) {
  log::debug("{}", fd);
  if (fd != -1) {
    syscall_if_->Close(fd);
    syscall_if_->FDClr(fd, &save_sock_fds_);
    fd = -1;
  }
}
+15 −9
Original line number Diff line number Diff line
@@ -66,21 +66,27 @@ private:
  int socket_port_;

  // A pair of FD to send information to the listen thread.
  int notification_listen_fd_;
  int notification_write_fd_;
  int notification_listen_fd_{-1};
  int notification_write_fd_{-1};

  // Server socket
  int listen_socket_;

  // Socket FDs for listening for connections
  // and for communitcation with listener thread.
  fd_set save_sock_fds_;
  int fd_max_;
  int listen_socket_{-1};

  // Reference to connected client socket.
  std::mutex client_socket_mutex_;
  int client_socket_;
  std::condition_variable client_socket_cv_;
  int client_socket_{-1};

  enum PollFd {
    kNotificationFd,
    kSocketFd,
    kNumPollFd,
  };

  // Array of FDs for polling.
  struct pollfd poll_fds_[kNumPollFd];

  void ResetPollFds();
};

}  // namespace hal
+42 −65
Original line number Diff line number Diff line
@@ -51,10 +51,7 @@ protected:
    ON_CALL(mock, Bind(Eq(fd), _, _)).WillByDefault(Return(0));
    ON_CALL(mock, Listen(Eq(fd), _)).WillByDefault(Return(0));

    EXPECT_CALL(mock, FDZero);
    EXPECT_CALL(mock, Pipe2(_, _));
    EXPECT_CALL(mock, FDSet(Eq(listen_fd), _));
    EXPECT_CALL(mock, FDSet(Eq(fd), _));
    EXPECT_CALL(mock, Socket);
    EXPECT_CALL(mock, Setsockopt);
    EXPECT_CALL(mock, Bind);
@@ -64,11 +61,8 @@ protected:

    // will be called in destructor
    EXPECT_CALL(mock, Close(Eq(fd)));
    EXPECT_CALL(mock, FDClr(Eq(fd), _));
    EXPECT_CALL(mock, Close(Eq(listen_fd)));
    EXPECT_CALL(mock, FDClr(Eq(listen_fd), _));
    EXPECT_CALL(mock, Close(Eq(write_fd)));
    EXPECT_CALL(mock, FDClr(Eq(write_fd), _));
  }

  void TearDown() override {}
@@ -104,8 +98,6 @@ TEST_F(SnoopLoggerSocketModuleTest, test_CreateSocket_fail_on_Setsockopt) {
  EXPECT_CALL(mock, Socket);
  EXPECT_CALL(mock, Setsockopt);
  EXPECT_CALL(mock, Close);
  EXPECT_CALL(mock, FDClr(Eq(fd), _));
  EXPECT_CALL(mock, FDSet(Eq(fd), _));
  EXPECT_CALL(mock, GetErrno);
  ASSERT_EQ(sls.CreateSocket(), INVALID_FD);
}
@@ -119,8 +111,6 @@ TEST_F(SnoopLoggerSocketModuleTest, test_CreateSocket_fail_on_Bind) {
  EXPECT_CALL(mock, Setsockopt);
  EXPECT_CALL(mock, Bind);
  EXPECT_CALL(mock, Close);
  EXPECT_CALL(mock, FDSet(Eq(fd), _));
  EXPECT_CALL(mock, FDClr(Eq(fd), _));
  EXPECT_CALL(mock, GetErrno);
  ASSERT_EQ(sls.CreateSocket(), INVALID_FD);
}
@@ -136,8 +126,6 @@ TEST_F(SnoopLoggerSocketModuleTest, test_CreateSocket_fail_on_Listen) {
  EXPECT_CALL(mock, Bind);
  EXPECT_CALL(mock, Listen);
  EXPECT_CALL(mock, Close);
  EXPECT_CALL(mock, FDSet(Eq(fd), _));
  EXPECT_CALL(mock, FDClr(Eq(fd), _));
  EXPECT_CALL(mock, GetErrno);
  ASSERT_EQ(sls.CreateSocket(), INVALID_FD);
}
@@ -152,7 +140,6 @@ TEST_F(SnoopLoggerSocketModuleTest, test_CreateSocket_success) {
  EXPECT_CALL(mock, Setsockopt);
  EXPECT_CALL(mock, Bind);
  EXPECT_CALL(mock, Listen);
  EXPECT_CALL(mock, FDSet(Eq(fd), _));
  ASSERT_EQ(sls.CreateSocket(), fd);
}

@@ -170,7 +157,6 @@ TEST_F(SnoopLoggerSocketModuleTest, test_Write_fd_fail_on_Send_ECONNRESET) {

  EXPECT_CALL(mock, Send(Eq(fd), Eq(data), Eq(sizeof(data)), _));
  EXPECT_CALL(mock, Close(Eq(fd)));
  EXPECT_CALL(mock, FDClr(Eq(fd), _));
  EXPECT_CALL(mock, GetErrno);

  sls.Write(fd, data, sizeof(data));
@@ -213,7 +199,6 @@ TEST_F(SnoopLoggerSocketModuleTest, test_Write_success) {
  sls.ClientSocketConnected(client_fd);

  sls.Write(data, sizeof(data));
  EXPECT_CALL(mock, FDClr(Eq(client_fd), _));
}

TEST_F(SnoopLoggerSocketModuleTest, test_Write_fd_fail_on_Send_EINTR) {
@@ -242,8 +227,6 @@ TEST_F(SnoopLoggerSocketModuleTest, test_ClientSocketConnected) {

  EXPECT_CALL(mock, Close(Eq(fd))).Times(1);
  EXPECT_CALL(mock, Close(Eq(fd + 1))).Times(1);
  EXPECT_CALL(mock, FDClr(Eq(fd), _));
  EXPECT_CALL(mock, FDClr(Eq(fd + 1), _));

  sls.ClientSocketConnected(fd);

@@ -264,7 +247,6 @@ TEST_F(SnoopLoggerSocketModuleTest, test_WaitForClientSocketConnected) {
  ASSERT_TRUE(sls.WaitForClientSocketConnected());

  EXPECT_CALL(mock, Close(Eq(fd)));
  EXPECT_CALL(mock, FDClr(Eq(fd), _));
}

TEST_F(SnoopLoggerSocketModuleTest, test_InitializeClientSocket) {
@@ -347,7 +329,6 @@ TEST_F(SnoopLoggerSocketModuleTest, test_InitializeCommunications_fail_on_Pipe2)
  int ret = -9;

  ON_CALL(mock, Pipe2(_, _)).WillByDefault(Invoke([ret](int* /* fds */, int) { return ret; }));
  EXPECT_CALL(mock, FDZero);
  EXPECT_CALL(mock, Pipe2(_, _));

  ASSERT_EQ(sls.InitializeCommunications(), ret);
@@ -363,16 +344,12 @@ TEST_F(SnoopLoggerSocketModuleTest, test_InitializeCommunications_fail_on_Create
    return 0;
  }));

  EXPECT_CALL(mock, FDZero);
  EXPECT_CALL(mock, Pipe2(_, _));
  EXPECT_CALL(mock, FDSet(listen_fd, _));
  EXPECT_CALL(mock, Socket);
  EXPECT_CALL(mock, GetErrno);

  EXPECT_CALL(mock, Close(Eq(listen_fd)));
  EXPECT_CALL(mock, FDClr(Eq(listen_fd), _));
  EXPECT_CALL(mock, Close(Eq(write_fd)));
  EXPECT_CALL(mock, FDClr(Eq(write_fd), _));

  ASSERT_EQ(sls.InitializeCommunications(), -1);
}
@@ -381,54 +358,49 @@ TEST_F(SnoopLoggerSocketModuleTest, test_InitializeCommunications_success) {
  ASSERT_NO_FATAL_FAILURE(InitializeCommunicationsSuccess(sls, mock));
}

TEST_F(SnoopLoggerSocketModuleTest, test_ProcessIncomingRequest_fail_on_Select_EINTR) {
  ON_CALL(mock, Select).WillByDefault(Return(-1));
TEST_F(SnoopLoggerSocketModuleTest, test_ProcessIncomingRequest_fail_on_Poll_EINTR) {
  ON_CALL(mock, Poll).WillByDefault(Return(-1));
  ON_CALL(mock, GetErrno()).WillByDefault(Return(EINTR));

  EXPECT_CALL(mock, Select);
  EXPECT_CALL(mock, Poll);
  EXPECT_CALL(mock, GetErrno).Times(2);
  ASSERT_TRUE(sls.ProcessIncomingRequest());
}

TEST_F(SnoopLoggerSocketModuleTest, test_ProcessIncomingRequest_fail_on_Select_EINVAL) {
  ON_CALL(mock, Select).WillByDefault(Return(-1));
TEST_F(SnoopLoggerSocketModuleTest, test_ProcessIncomingRequest_fail_on_Poll_EINVAL) {
  ON_CALL(mock, Poll).WillByDefault(Return(-1));
  ON_CALL(mock, GetErrno()).WillByDefault(Return(EINVAL));

  EXPECT_CALL(mock, Select);
  EXPECT_CALL(mock, Poll);
  EXPECT_CALL(mock, GetErrno).Times(2);
  ASSERT_FALSE(sls.ProcessIncomingRequest());
}

TEST_F(SnoopLoggerSocketModuleTest, test_ProcessIncomingRequest_no_fds) {
  ON_CALL(mock, Select).WillByDefault(Return(0));

  EXPECT_CALL(mock, Select);
  ASSERT_TRUE(sls.ProcessIncomingRequest());
}

TEST_F(SnoopLoggerSocketModuleTest, test_ProcessIncomingRequest_FDIsSet_false) {
TEST_F(SnoopLoggerSocketModuleTest, test_ProcessIncomingRequest_spurious) {
  ASSERT_NO_FATAL_FAILURE(InitializeCommunicationsSuccess(sls, mock));

  ON_CALL(mock, Select).WillByDefault(Return(0));
  ON_CALL(mock, FDIsSet(fd, _)).WillByDefault(Return(false));
  ON_CALL(mock, FDIsSet(listen_fd, _)).WillByDefault(Return(false));
  ON_CALL(mock, Poll)
          .WillByDefault(Invoke([](struct pollfd* fds, nfds_t /*nfds*/, int /*timeout*/) {
            fds[0].revents = 0;
            fds[1].revents = 0;
            return 0;
          }));

  EXPECT_CALL(mock, Select);
  EXPECT_CALL(mock, FDIsSet(Eq(fd), _));
  EXPECT_CALL(mock, FDIsSet(Eq(listen_fd), _));
  EXPECT_CALL(mock, Poll);
  ASSERT_TRUE(sls.ProcessIncomingRequest());
}

TEST_F(SnoopLoggerSocketModuleTest, test_ProcessIncomingRequest_signal_close) {
  ASSERT_NO_FATAL_FAILURE(InitializeCommunicationsSuccess(sls, mock));

  ON_CALL(mock, Select).WillByDefault(Return(0));
  ON_CALL(mock, FDIsSet(fd, _)).WillByDefault(Return(false));
  ON_CALL(mock, FDIsSet(listen_fd, _)).WillByDefault(Return(true));
  ON_CALL(mock, Poll)
          .WillByDefault(Invoke([](struct pollfd* fds, nfds_t /*nfds*/, int /*timeout*/) {
            fds[0].revents = POLLIN;
            fds[1].revents = 0;
            return 0;
          }));

  EXPECT_CALL(mock, Select);
  EXPECT_CALL(mock, FDIsSet(Eq(fd), _));
  EXPECT_CALL(mock, FDIsSet(Eq(listen_fd), _));
  EXPECT_CALL(mock, Poll);
  ASSERT_FALSE(sls.ProcessIncomingRequest());
}

@@ -436,15 +408,17 @@ TEST_F(SnoopLoggerSocketModuleTest,
       test_ProcessIncomingRequest_signal_incoming_connection_fail_on_accept_exit) {
  ASSERT_NO_FATAL_FAILURE(InitializeCommunicationsSuccess(sls, mock));

  ON_CALL(mock, Select).WillByDefault(Return(0));
  ON_CALL(mock, FDIsSet(fd, _)).WillByDefault(Return(true));
  ON_CALL(mock, FDIsSet(listen_fd, _)).WillByDefault(Return(false));
  ON_CALL(mock, Poll)
          .WillByDefault(Invoke([](struct pollfd* fds, nfds_t /*nfds*/, int /*timeout*/) {
            fds[0].revents = 0;
            fds[1].revents = POLLIN;
            return 0;
          }));

  ON_CALL(mock, Accept(fd, _, _, _)).WillByDefault(Return(INVALID_FD));
  ON_CALL(mock, GetErrno()).WillByDefault(Return(EINVAL));

  EXPECT_CALL(mock, Select);
  EXPECT_CALL(mock, FDIsSet(Eq(fd), _));
  EXPECT_CALL(mock, Poll);
  EXPECT_CALL(mock, Accept(Eq(fd), _, _, _));
  EXPECT_CALL(mock, GetErrno);
  ASSERT_FALSE(sls.ProcessIncomingRequest());
@@ -454,15 +428,17 @@ TEST_F(SnoopLoggerSocketModuleTest,
       test_ProcessIncomingRequest_signal_incoming_connection_fail_on_accept_continue) {
  ASSERT_NO_FATAL_FAILURE(InitializeCommunicationsSuccess(sls, mock));

  ON_CALL(mock, Select).WillByDefault(Return(0));
  ON_CALL(mock, FDIsSet(fd, _)).WillByDefault(Return(true));
  ON_CALL(mock, FDIsSet(listen_fd, _)).WillByDefault(Return(false));
  ON_CALL(mock, Poll)
          .WillByDefault(Invoke([](struct pollfd* fds, nfds_t /*nfds*/, int /*timeout*/) {
            fds[0].revents = 0;
            fds[1].revents = POLLIN;
            return 0;
          }));

  ON_CALL(mock, Accept(fd, _, _, _)).WillByDefault(Return(INVALID_FD));
  ON_CALL(mock, GetErrno()).WillByDefault(Return(ENOMEM));

  EXPECT_CALL(mock, Select);
  EXPECT_CALL(mock, FDIsSet(Eq(fd), _));
  EXPECT_CALL(mock, Poll);
  EXPECT_CALL(mock, Accept(Eq(fd), _, _, _));
  EXPECT_CALL(mock, GetErrno);
  ASSERT_TRUE(sls.ProcessIncomingRequest());
@@ -474,22 +450,23 @@ TEST_F(SnoopLoggerSocketModuleTest,

  int client_fd = 23;

  ON_CALL(mock, Select).WillByDefault(Return(0));
  ON_CALL(mock, FDIsSet(fd, _)).WillByDefault(Return(true));
  ON_CALL(mock, FDIsSet(listen_fd, _)).WillByDefault(Return(false));
  ON_CALL(mock, Poll)
          .WillByDefault(Invoke([](struct pollfd* fds, nfds_t /*nfds*/, int /*timeout*/) {
            fds[0].revents = 0;
            fds[1].revents = POLLIN;
            return 0;
          }));

  ON_CALL(mock, Accept(fd, _, _, _)).WillByDefault(Return(client_fd));
  ON_CALL(mock, GetErrno()).WillByDefault(Return(0));

  EXPECT_CALL(mock, Send(client_fd, _, _, _)).Times(1);

  EXPECT_CALL(mock, Select);
  EXPECT_CALL(mock, FDIsSet(Eq(fd), _));
  EXPECT_CALL(mock, Poll);
  EXPECT_CALL(mock, Accept(Eq(fd), _, _, _));
  ASSERT_TRUE(sls.ProcessIncomingRequest());

  EXPECT_CALL(mock, Close(Eq(client_fd)));
  EXPECT_CALL(mock, FDClr(Eq(client_fd), _));
}

TEST_F(SnoopLoggerSocketModuleTest, test_NotifySocketListener_no_fd) {
+2 −11
Original line number Diff line number Diff line
@@ -94,17 +94,8 @@ int SyscallWrapperImpl::Pipe2(int* pipefd, int flags) {

int SyscallWrapperImpl::GetErrno() const { return errno_; }

void SyscallWrapperImpl::FDSet(int fd, fd_set* set) { FD_SET(fd, set); }

void SyscallWrapperImpl::FDClr(int fd, fd_set* set) { FD_CLR(fd, set); }

bool SyscallWrapperImpl::FDIsSet(int fd, fd_set* set) { return FD_ISSET(fd, set); }

void SyscallWrapperImpl::FDZero(fd_set* set) { FD_ZERO(set); }

int SyscallWrapperImpl::Select(int __nfds, fd_set* __readfds, fd_set* __writefds,
                               fd_set* __exceptfds, struct timeval* __timeout) {
  int ret = select(__nfds, __readfds, __writefds, __exceptfds, __timeout);
int SyscallWrapperImpl::Poll(struct pollfd* fds, nfds_t nfds, int timeout) {
  int ret = poll(fds, nfds, timeout);
  errno_ = errno;
  return ret;
}
+1 −10
Original line number Diff line number Diff line
@@ -48,16 +48,7 @@ class SyscallWrapperImpl : public SyscallWrapperInterface {

  int GetErrno() const;

  void FDSet(int fd, fd_set* set);

  void FDClr(int fd, fd_set* set);

  bool FDIsSet(int fd, fd_set* set);

  void FDZero(fd_set* set);

  int Select(int __nfds, fd_set* __readfds, fd_set* __writefds, fd_set* __exceptfds,
             struct timeval* __timeout);
  int Poll(struct pollfd* fds, nfds_t nfds, int timeout);

private:
  int errno_;
Loading