Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 50b49429 authored by Treehugger Robot's avatar Treehugger Robot Committed by Gerrit Code Review
Browse files

Merge "Make StopWatchingFileDescriptor deterministic."

parents 45af8f05 9393b324
Loading
Loading
Loading
Loading
+9 −13
Original line number Diff line number Diff line
@@ -96,7 +96,7 @@ class AsyncManager::AsyncFdWatcher {
  int WatchFdForNonBlockingReads(int file_descriptor, const ReadCallback& on_read_fd_ready_callback) {
    // add file descriptor and callback
    {
      std::unique_lock<std::mutex> guard(internal_mutex_);
      std::unique_lock<std::recursive_mutex> guard(internal_mutex_);
      watched_shared_fds_[file_descriptor] = on_read_fd_ready_callback;
    }

@@ -114,7 +114,7 @@ class AsyncManager::AsyncFdWatcher {
  }

  void StopWatchingFileDescriptor(int file_descriptor) {
    std::unique_lock<std::mutex> guard(internal_mutex_);
    std::unique_lock<std::recursive_mutex> guard(internal_mutex_);
    watched_shared_fds_.erase(file_descriptor);
  }

@@ -138,7 +138,7 @@ class AsyncManager::AsyncFdWatcher {
    }

    {
      std::unique_lock<std::mutex> guard(internal_mutex_);
      std::unique_lock<std::recursive_mutex> guard(internal_mutex_);
      watched_shared_fds_.clear();
    }

@@ -188,7 +188,7 @@ class AsyncManager::AsyncFdWatcher {

    // add watched FDs to the set
    {
      std::unique_lock<std::mutex> guard(internal_mutex_);
      std::unique_lock<std::recursive_mutex> guard(internal_mutex_);
      for (auto& fdp : watched_shared_fds_) {
        FD_SET(fdp.first, &read_fds);
        nfds = std::max(fdp.first, nfds);
@@ -211,17 +211,13 @@ class AsyncManager::AsyncFdWatcher {

  // check all file descriptors and call callbacks if necesary
  void runAppropriateCallbacks(fd_set& read_fds) {
    // not a good idea to call a callback while holding the FD lock,
    // nor to release the lock while traversing the map
    std::vector<decltype(watched_shared_fds_)::value_type> fds;
    {
      std::unique_lock<std::mutex> guard(internal_mutex_);
    std::unique_lock<std::recursive_mutex> guard(internal_mutex_);
    for (auto& fdc : watched_shared_fds_) {
      if (FD_ISSET(fdc.first, &read_fds)) {
        fds.push_back(fdc);
      }
    }
    }
    for (auto& p : fds) {
      p.second(p.first);
    }
@@ -256,7 +252,7 @@ class AsyncManager::AsyncFdWatcher {

  std::atomic_bool running_{false};
  std::thread thread_;
  std::mutex internal_mutex_;
  std::recursive_mutex internal_mutex_;

  std::map<int, ReadCallback> watched_shared_fds_;

+208 −24
Original line number Diff line number Diff line
@@ -16,18 +16,51 @@

#include "model/setup/async_manager.h"

#include <gtest/gtest.h>
#include <cstdint>
#include <cstring>
#include <vector>
#include <netdb.h>
#include <netinet/in.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <unistd.h>
#include <fcntl.h>        // for fcntl, F_SETFL, O_NONBLOCK
#include <gtest/gtest.h>  // for Message, TestPartResult, SuiteApi...
#include <netdb.h>        // for gethostbyname, h_addr, hostent
#include <netinet/in.h>   // for sockaddr_in, in_addr, INADDR_ANY
#include <stdio.h>        // for printf
#include <sys/socket.h>   // for socket, AF_INET, accept, bind
#include <sys/types.h>    // for in_addr_t
#include <time.h>         // for NULL, size_t
#include <unistd.h>       // for close, write, read

#include <condition_variable>  // for condition_variable
#include <cstdint>             // for uint16_t
#include <cstring>             // for memset, strcmp, strcpy, strlen
#include <mutex>               // for mutex
#include <ratio>               // for ratio
#include <string>              // for string
#include <tuple>               // for tuple

#include "osi/include/osi.h"  // for OSI_NO_INTR

namespace test_vendor_lib {

class Event {
 public:
  void set(bool set = true) {
    std::unique_lock<std::mutex> lk(m_);
    set_ = set;
    cv_.notify_all();
  }

  void reset() { set(false); }

  bool wait_for(std::chrono::microseconds timeout) {
    std::unique_lock<std::mutex> lk(m_);
    return cv_.wait_for(lk, timeout, [&] { return set_; });
  }

  bool operator*() { return set_; }

 private:
  std::mutex m_;
  std::condition_variable cv_;
  bool set_{false};
};

class AsyncManagerSocketTest : public ::testing::Test {
 public:
  static const uint16_t kPort = 6111;
@@ -39,16 +72,16 @@ class AsyncManagerSocketTest : public ::testing::Test {

 protected:
  int StartServer() {
    struct sockaddr_in serv_addr;
    struct sockaddr_in serv_addr = {};
    int fd = socket(AF_INET, SOCK_STREAM, 0);
    EXPECT_FALSE(fd < 0);

    memset(&serv_addr, 0, sizeof(serv_addr));
    serv_addr.sin_family = AF_INET;
    serv_addr.sin_addr.s_addr = INADDR_ANY;
    serv_addr.sin_port = htons(kPort);
    int reuse_flag = 1;
    EXPECT_FALSE(setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &reuse_flag, sizeof(reuse_flag)) < 0);
    EXPECT_FALSE(setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &reuse_flag,
                            sizeof(reuse_flag)) < 0);
    EXPECT_FALSE(bind(fd, (sockaddr*)&serv_addr, sizeof(serv_addr)) < 0);

    listen(fd, 1);
@@ -66,9 +99,19 @@ class AsyncManagerSocketTest : public ::testing::Test {
    return connection_fd;
  }

  std::tuple<int, int> ConnectSocketPair() {
    int cli = ConnectClient();
    WriteFromClient(cli);
    AwaitServerResponse(cli);
    int ser = connection_fd_;
    connection_fd_ = -1;
    return {cli, ser};
  }

  void ReadIncomingMessage(int fd) {
    int n = TEMP_FAILURE_RETRY(read(fd, server_buffer_, kBufferSize - 1));
    ASSERT_FALSE(n < 0);
    int n;
    OSI_NO_INTR(n = read(fd, server_buffer_, kBufferSize - 1));
    ASSERT_GE(n, 0) << strerror(errno);

    if (n == 0) {  // got EOF
      async_manager_.StopWatchingFileDescriptor(fd);
@@ -84,9 +127,10 @@ class AsyncManagerSocketTest : public ::testing::Test {
    socket_fd_ = StartServer();

    async_manager_.WatchFdForNonBlockingReads(socket_fd_, [this](int fd) {
      int connection_fd = AcceptConnection(fd);
      connection_fd_ = AcceptConnection(fd);

      async_manager_.WatchFdForNonBlockingReads(connection_fd, [this](int fd) { ReadIncomingMessage(fd); });
      async_manager_.WatchFdForNonBlockingReads(
          connection_fd_, [this](int fd) { ReadIncomingMessage(fd); });
    });
  }

@@ -98,11 +142,11 @@ class AsyncManagerSocketTest : public ::testing::Test {

  int ConnectClient() {
    int socket_cli_fd = socket(AF_INET, SOCK_STREAM, 0);
    EXPECT_FALSE(socket_cli_fd < 0);
    EXPECT_GE(socket_cli_fd, 0) << strerror(errno);

    struct hostent* server;
    server = gethostbyname("localhost");
    EXPECT_FALSE(server == NULL);
    EXPECT_FALSE(server == NULL) << strerror(errno);

    struct sockaddr_in serv_addr;
    memset((void*)&serv_addr, 0, sizeof(serv_addr));
@@ -110,8 +154,9 @@ class AsyncManagerSocketTest : public ::testing::Test {
    serv_addr.sin_addr.s_addr = *(reinterpret_cast<in_addr_t*>(server->h_addr));
    serv_addr.sin_port = htons(kPort);

    int result = connect(socket_cli_fd, (struct sockaddr*)&serv_addr, sizeof(serv_addr));
    EXPECT_FALSE(result < 0);
    int result =
        connect(socket_cli_fd, (struct sockaddr*)&serv_addr, sizeof(serv_addr));
    EXPECT_GE(result, 0) << strerror(errno);

    return socket_cli_fd;
  }
@@ -119,17 +164,18 @@ class AsyncManagerSocketTest : public ::testing::Test {
  void WriteFromClient(int socket_cli_fd) {
    strcpy(client_buffer_, "1");
    int n = write(socket_cli_fd, client_buffer_, strlen(client_buffer_));
    ASSERT_TRUE(n > 0);
    ASSERT_GT(n, 0) << strerror(errno);
  }

  void AwaitServerResponse(int socket_cli_fd) {
    int n = read(socket_cli_fd, client_buffer_, 1);
    ASSERT_TRUE(n > 0);
    ASSERT_GT(n, 0) << strerror(errno);
  }

 private:
 protected:
  AsyncManager async_manager_;
  int socket_fd_;
  int connection_fd_;
  char server_buffer_[kBufferSize];
  char client_buffer_[kBufferSize];
};
@@ -144,6 +190,144 @@ TEST_F(AsyncManagerSocketTest, TestOneConnection) {
  close(socket_cli_fd);
}

TEST_F(AsyncManagerSocketTest, CanUnsubscribeInCallback) {
  int socket_cli_fd = ConnectClient();
  WriteFromClient(socket_cli_fd);
  AwaitServerResponse(socket_cli_fd);
  fcntl(connection_fd_, F_SETFL, O_NONBLOCK);

  std::string data('x', 32);

  bool stopped = false;
  async_manager_.WatchFdForNonBlockingReads(connection_fd_, [&](int fd) {
    async_manager_.StopWatchingFileDescriptor(fd);
    char buf[32];
    while (read(fd, buf, sizeof(buf)) > 0)
      ;
    stopped = true;
  });

  while (!stopped) {
    write(socket_cli_fd, data.data(), data.size());
  }

  SUCCEED();
  close(socket_cli_fd);
}

TEST_F(AsyncManagerSocketTest, NoEventsAfterUnsubscribe) {
  // This tests makes sure the AsyncManager never fires an event
  // after calling StopWatchingFileDescriptor.
  using clock = std::chrono::system_clock;
  using namespace std::chrono_literals;

  clock::time_point time_fast_called;
  clock::time_point time_slow_called;
  clock::time_point time_stopped_listening;

  int round = 0;
  auto [slow_cli_fd, slow_s_fd] = ConnectSocketPair();
  fcntl(slow_s_fd, F_SETFL, O_NONBLOCK);

  auto [fast_cli_fd, fast_s_fd] = ConnectSocketPair();
  fcntl(fast_s_fd, F_SETFL, O_NONBLOCK);

  std::string data(1, 'x');

  // The idea here is as follows:
  // We want to make sure that an unsubscribed callback never gets called.
  // This is to make sure we can safely do things like this:
  //
  // class Foo {
  //   Foo(int fd, AsyncManager* am) : fd_(fd), am_(am) {
  //     am_->WatchFdForNonBlockingReads(
  //         fd, [&](int fd) { printf("This shouldn't crash! %p\n", this); });
  //   }
  //   ~Foo() { am_->StopWatchingFileDescriptor(fd_); }
  //
  //   AsyncManager* am_;
  //   int fd_;
  // };
  //
  // We are going to force a failure as follows:
  //
  // The slow callback needs to be called first, if it does not we cannot
  // force failure, so we have to try multiple times.
  //
  // t1, is the thread doing the loop.
  // t2, is the async manager handler thread.
  //
  // t1 will block until the slowcallback.
  // t2 will now block (for at most 250 ms).
  // t1 will unsubscribe the fast callback.
  // 2 cases:
  //   with bug:
  //      - t1 takes a timestamp, unblocks t2,
  //      - t2 invokes the fast callback, and gets a timestamp.
  //      - Now the unsubscribe time is before the callback time.
  //   without bug.:
  //      - t1 locks un unsusbcribe in asyn manager
  //      - t2 unlocks due to timeout,
  //      - t2 invokes the fast callback, and gets a timestamp.
  //      - t1 is unlocked and gets a timestamp.
  //      - Now the unsubscribe time is after the callback time..

  do {
    Event unblock_slow, inslow, infast;
    time_fast_called = {};
    time_slow_called = {};
    time_stopped_listening = {};
    printf("round: %d\n", round++);

    // Register fd events
    async_manager_.WatchFdForNonBlockingReads(slow_s_fd, [&](int /*fd*/) {
      if (*inslow) return;
      time_slow_called = clock::now();
      printf("slow: %lld\n",
             time_slow_called.time_since_epoch().count() % 10000);
      inslow.set();
      unblock_slow.wait_for(25ms);
    });

    async_manager_.WatchFdForNonBlockingReads(fast_s_fd, [&](int /*fd*/) {
      if (*infast) return;
      time_fast_called = clock::now();
      printf("fast: %lld\n",
             time_fast_called.time_since_epoch().count() % 10000);
      infast.set();
    });

    // Generate fd events
    write(fast_cli_fd, data.data(), data.size());
    write(slow_cli_fd, data.data(), data.size());

    // Block in the right places.
    if (inslow.wait_for(25ms)) {
      async_manager_.StopWatchingFileDescriptor(fast_s_fd);
      time_stopped_listening = clock::now();
      printf("stop: %lld\n",
             time_stopped_listening.time_since_epoch().count() % 10000);
      unblock_slow.set();
    }

    infast.wait_for(25ms);

    // Unregister.
    async_manager_.StopWatchingFileDescriptor(fast_s_fd);
    async_manager_.StopWatchingFileDescriptor(slow_s_fd);
  } while (time_fast_called < time_slow_called);

  // fast before stop listening.
  ASSERT_LT(time_fast_called.time_since_epoch().count(),
            time_stopped_listening.time_since_epoch().count());

  // Cleanup
  close(fast_cli_fd);
  close(fast_s_fd);
  close(slow_cli_fd);
  close(slow_s_fd);
}

TEST_F(AsyncManagerSocketTest, TestRepeatedConnections) {
  static const int num_connections = 30;
  for (int i = 0; i < num_connections; i++) {
@@ -159,7 +343,7 @@ TEST_F(AsyncManagerSocketTest, TestMultipleConnections) {
  int socket_cli_fd[num_connections];
  for (int i = 0; i < num_connections; i++) {
    socket_cli_fd[i] = ConnectClient();
    EXPECT_TRUE(socket_cli_fd[i] > 0);
    ASSERT_TRUE(socket_cli_fd[i] > 0);
    WriteFromClient(socket_cli_fd[i]);
  }
  for (int i = 0; i < num_connections; i++) {