Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 9393b324 authored by Erwin Jansen's avatar Erwin Jansen
Browse files

Make StopWatchingFileDescriptor deterministic.

The AsyncManager allows for registering and unregistering watchers for
sockets. Previously it was possible to receive a callback even though
unregister was called from a separate thread.

This makes it very difficult to properly manage object lifetime for an
object that took a dependency on a callback, as a callback could arrive
while the object was destroyed.

We now introduce a recursive_mutex to protect the subscriptions. This
guarantees:

- We will block unsubscription until all relevant callbacks have
  finished.
- We can still stop watching from the callback thread.

This guarantees that we will never get a callback after unsubscription.

Includes 2 unit tests:

1. Guarantee that we do not deadlock in unsubscription from a callback.
2. Attempt to validate that a callback never happens after unsubscribe.

The 2nd test will fail when the recursive mutex is removed.

Test: 2 additional unit tests.
Change-Id: I2b9cdb0c31f88e67bbaca0667bd8822f69096995
parent 50bbf0ad
Loading
Loading
Loading
Loading
+9 −13
Original line number Diff line number Diff line
@@ -96,7 +96,7 @@ class AsyncManager::AsyncFdWatcher {
  int WatchFdForNonBlockingReads(int file_descriptor, const ReadCallback& on_read_fd_ready_callback) {
    // add file descriptor and callback
    {
      std::unique_lock<std::mutex> guard(internal_mutex_);
      std::unique_lock<std::recursive_mutex> guard(internal_mutex_);
      watched_shared_fds_[file_descriptor] = on_read_fd_ready_callback;
    }

@@ -114,7 +114,7 @@ class AsyncManager::AsyncFdWatcher {
  }

  void StopWatchingFileDescriptor(int file_descriptor) {
    std::unique_lock<std::mutex> guard(internal_mutex_);
    std::unique_lock<std::recursive_mutex> guard(internal_mutex_);
    watched_shared_fds_.erase(file_descriptor);
  }

@@ -138,7 +138,7 @@ class AsyncManager::AsyncFdWatcher {
    }

    {
      std::unique_lock<std::mutex> guard(internal_mutex_);
      std::unique_lock<std::recursive_mutex> guard(internal_mutex_);
      watched_shared_fds_.clear();
    }

@@ -188,7 +188,7 @@ class AsyncManager::AsyncFdWatcher {

    // add watched FDs to the set
    {
      std::unique_lock<std::mutex> guard(internal_mutex_);
      std::unique_lock<std::recursive_mutex> guard(internal_mutex_);
      for (auto& fdp : watched_shared_fds_) {
        FD_SET(fdp.first, &read_fds);
        nfds = std::max(fdp.first, nfds);
@@ -211,17 +211,13 @@ class AsyncManager::AsyncFdWatcher {

  // check all file descriptors and call callbacks if necesary
  void runAppropriateCallbacks(fd_set& read_fds) {
    // not a good idea to call a callback while holding the FD lock,
    // nor to release the lock while traversing the map
    std::vector<decltype(watched_shared_fds_)::value_type> fds;
    {
      std::unique_lock<std::mutex> guard(internal_mutex_);
    std::unique_lock<std::recursive_mutex> guard(internal_mutex_);
    for (auto& fdc : watched_shared_fds_) {
      if (FD_ISSET(fdc.first, &read_fds)) {
        fds.push_back(fdc);
      }
    }
    }
    for (auto& p : fds) {
      p.second(p.first);
    }
@@ -256,7 +252,7 @@ class AsyncManager::AsyncFdWatcher {

  std::atomic_bool running_{false};
  std::thread thread_;
  std::mutex internal_mutex_;
  std::recursive_mutex internal_mutex_;

  std::map<int, ReadCallback> watched_shared_fds_;

+208 −24
Original line number Diff line number Diff line
@@ -16,18 +16,51 @@

#include "model/setup/async_manager.h"

#include <gtest/gtest.h>
#include <cstdint>
#include <cstring>
#include <vector>
#include <netdb.h>
#include <netinet/in.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <unistd.h>
#include <fcntl.h>        // for fcntl, F_SETFL, O_NONBLOCK
#include <gtest/gtest.h>  // for Message, TestPartResult, SuiteApi...
#include <netdb.h>        // for gethostbyname, h_addr, hostent
#include <netinet/in.h>   // for sockaddr_in, in_addr, INADDR_ANY
#include <stdio.h>        // for printf
#include <sys/socket.h>   // for socket, AF_INET, accept, bind
#include <sys/types.h>    // for in_addr_t
#include <time.h>         // for NULL, size_t
#include <unistd.h>       // for close, write, read

#include <condition_variable>  // for condition_variable
#include <cstdint>             // for uint16_t
#include <cstring>             // for memset, strcmp, strcpy, strlen
#include <mutex>               // for mutex
#include <ratio>               // for ratio
#include <string>              // for string
#include <tuple>               // for tuple

#include "osi/include/osi.h"  // for OSI_NO_INTR

namespace test_vendor_lib {

class Event {
 public:
  void set(bool set = true) {
    std::unique_lock<std::mutex> lk(m_);
    set_ = set;
    cv_.notify_all();
  }

  void reset() { set(false); }

  bool wait_for(std::chrono::microseconds timeout) {
    std::unique_lock<std::mutex> lk(m_);
    return cv_.wait_for(lk, timeout, [&] { return set_; });
  }

  bool operator*() { return set_; }

 private:
  std::mutex m_;
  std::condition_variable cv_;
  bool set_{false};
};

class AsyncManagerSocketTest : public ::testing::Test {
 public:
  static const uint16_t kPort = 6111;
@@ -39,16 +72,16 @@ class AsyncManagerSocketTest : public ::testing::Test {

 protected:
  int StartServer() {
    struct sockaddr_in serv_addr;
    struct sockaddr_in serv_addr = {};
    int fd = socket(AF_INET, SOCK_STREAM, 0);
    EXPECT_FALSE(fd < 0);

    memset(&serv_addr, 0, sizeof(serv_addr));
    serv_addr.sin_family = AF_INET;
    serv_addr.sin_addr.s_addr = INADDR_ANY;
    serv_addr.sin_port = htons(kPort);
    int reuse_flag = 1;
    EXPECT_FALSE(setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &reuse_flag, sizeof(reuse_flag)) < 0);
    EXPECT_FALSE(setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &reuse_flag,
                            sizeof(reuse_flag)) < 0);
    EXPECT_FALSE(bind(fd, (sockaddr*)&serv_addr, sizeof(serv_addr)) < 0);

    listen(fd, 1);
@@ -66,9 +99,19 @@ class AsyncManagerSocketTest : public ::testing::Test {
    return connection_fd;
  }

  std::tuple<int, int> ConnectSocketPair() {
    int cli = ConnectClient();
    WriteFromClient(cli);
    AwaitServerResponse(cli);
    int ser = connection_fd_;
    connection_fd_ = -1;
    return {cli, ser};
  }

  void ReadIncomingMessage(int fd) {
    int n = TEMP_FAILURE_RETRY(read(fd, server_buffer_, kBufferSize - 1));
    ASSERT_FALSE(n < 0);
    int n;
    OSI_NO_INTR(n = read(fd, server_buffer_, kBufferSize - 1));
    ASSERT_GE(n, 0) << strerror(errno);

    if (n == 0) {  // got EOF
      async_manager_.StopWatchingFileDescriptor(fd);
@@ -84,9 +127,10 @@ class AsyncManagerSocketTest : public ::testing::Test {
    socket_fd_ = StartServer();

    async_manager_.WatchFdForNonBlockingReads(socket_fd_, [this](int fd) {
      int connection_fd = AcceptConnection(fd);
      connection_fd_ = AcceptConnection(fd);

      async_manager_.WatchFdForNonBlockingReads(connection_fd, [this](int fd) { ReadIncomingMessage(fd); });
      async_manager_.WatchFdForNonBlockingReads(
          connection_fd_, [this](int fd) { ReadIncomingMessage(fd); });
    });
  }

@@ -98,11 +142,11 @@ class AsyncManagerSocketTest : public ::testing::Test {

  int ConnectClient() {
    int socket_cli_fd = socket(AF_INET, SOCK_STREAM, 0);
    EXPECT_FALSE(socket_cli_fd < 0);
    EXPECT_GE(socket_cli_fd, 0) << strerror(errno);

    struct hostent* server;
    server = gethostbyname("localhost");
    EXPECT_FALSE(server == NULL);
    EXPECT_FALSE(server == NULL) << strerror(errno);

    struct sockaddr_in serv_addr;
    memset((void*)&serv_addr, 0, sizeof(serv_addr));
@@ -110,8 +154,9 @@ class AsyncManagerSocketTest : public ::testing::Test {
    serv_addr.sin_addr.s_addr = *(reinterpret_cast<in_addr_t*>(server->h_addr));
    serv_addr.sin_port = htons(kPort);

    int result = connect(socket_cli_fd, (struct sockaddr*)&serv_addr, sizeof(serv_addr));
    EXPECT_FALSE(result < 0);
    int result =
        connect(socket_cli_fd, (struct sockaddr*)&serv_addr, sizeof(serv_addr));
    EXPECT_GE(result, 0) << strerror(errno);

    return socket_cli_fd;
  }
@@ -119,17 +164,18 @@ class AsyncManagerSocketTest : public ::testing::Test {
  void WriteFromClient(int socket_cli_fd) {
    strcpy(client_buffer_, "1");
    int n = write(socket_cli_fd, client_buffer_, strlen(client_buffer_));
    ASSERT_TRUE(n > 0);
    ASSERT_GT(n, 0) << strerror(errno);
  }

  void AwaitServerResponse(int socket_cli_fd) {
    int n = read(socket_cli_fd, client_buffer_, 1);
    ASSERT_TRUE(n > 0);
    ASSERT_GT(n, 0) << strerror(errno);
  }

 private:
 protected:
  AsyncManager async_manager_;
  int socket_fd_;
  int connection_fd_;
  char server_buffer_[kBufferSize];
  char client_buffer_[kBufferSize];
};
@@ -144,6 +190,144 @@ TEST_F(AsyncManagerSocketTest, TestOneConnection) {
  close(socket_cli_fd);
}

TEST_F(AsyncManagerSocketTest, CanUnsubscribeInCallback) {
  int socket_cli_fd = ConnectClient();
  WriteFromClient(socket_cli_fd);
  AwaitServerResponse(socket_cli_fd);
  fcntl(connection_fd_, F_SETFL, O_NONBLOCK);

  std::string data('x', 32);

  bool stopped = false;
  async_manager_.WatchFdForNonBlockingReads(connection_fd_, [&](int fd) {
    async_manager_.StopWatchingFileDescriptor(fd);
    char buf[32];
    while (read(fd, buf, sizeof(buf)) > 0)
      ;
    stopped = true;
  });

  while (!stopped) {
    write(socket_cli_fd, data.data(), data.size());
  }

  SUCCEED();
  close(socket_cli_fd);
}

TEST_F(AsyncManagerSocketTest, NoEventsAfterUnsubscribe) {
  // This tests makes sure the AsyncManager never fires an event
  // after calling StopWatchingFileDescriptor.
  using clock = std::chrono::system_clock;
  using namespace std::chrono_literals;

  clock::time_point time_fast_called;
  clock::time_point time_slow_called;
  clock::time_point time_stopped_listening;

  int round = 0;
  auto [slow_cli_fd, slow_s_fd] = ConnectSocketPair();
  fcntl(slow_s_fd, F_SETFL, O_NONBLOCK);

  auto [fast_cli_fd, fast_s_fd] = ConnectSocketPair();
  fcntl(fast_s_fd, F_SETFL, O_NONBLOCK);

  std::string data(1, 'x');

  // The idea here is as follows:
  // We want to make sure that an unsubscribed callback never gets called.
  // This is to make sure we can safely do things like this:
  //
  // class Foo {
  //   Foo(int fd, AsyncManager* am) : fd_(fd), am_(am) {
  //     am_->WatchFdForNonBlockingReads(
  //         fd, [&](int fd) { printf("This shouldn't crash! %p\n", this); });
  //   }
  //   ~Foo() { am_->StopWatchingFileDescriptor(fd_); }
  //
  //   AsyncManager* am_;
  //   int fd_;
  // };
  //
  // We are going to force a failure as follows:
  //
  // The slow callback needs to be called first, if it does not we cannot
  // force failure, so we have to try multiple times.
  //
  // t1, is the thread doing the loop.
  // t2, is the async manager handler thread.
  //
  // t1 will block until the slowcallback.
  // t2 will now block (for at most 250 ms).
  // t1 will unsubscribe the fast callback.
  // 2 cases:
  //   with bug:
  //      - t1 takes a timestamp, unblocks t2,
  //      - t2 invokes the fast callback, and gets a timestamp.
  //      - Now the unsubscribe time is before the callback time.
  //   without bug.:
  //      - t1 locks un unsusbcribe in asyn manager
  //      - t2 unlocks due to timeout,
  //      - t2 invokes the fast callback, and gets a timestamp.
  //      - t1 is unlocked and gets a timestamp.
  //      - Now the unsubscribe time is after the callback time..

  do {
    Event unblock_slow, inslow, infast;
    time_fast_called = {};
    time_slow_called = {};
    time_stopped_listening = {};
    printf("round: %d\n", round++);

    // Register fd events
    async_manager_.WatchFdForNonBlockingReads(slow_s_fd, [&](int /*fd*/) {
      if (*inslow) return;
      time_slow_called = clock::now();
      printf("slow: %lld\n",
             time_slow_called.time_since_epoch().count() % 10000);
      inslow.set();
      unblock_slow.wait_for(25ms);
    });

    async_manager_.WatchFdForNonBlockingReads(fast_s_fd, [&](int /*fd*/) {
      if (*infast) return;
      time_fast_called = clock::now();
      printf("fast: %lld\n",
             time_fast_called.time_since_epoch().count() % 10000);
      infast.set();
    });

    // Generate fd events
    write(fast_cli_fd, data.data(), data.size());
    write(slow_cli_fd, data.data(), data.size());

    // Block in the right places.
    if (inslow.wait_for(25ms)) {
      async_manager_.StopWatchingFileDescriptor(fast_s_fd);
      time_stopped_listening = clock::now();
      printf("stop: %lld\n",
             time_stopped_listening.time_since_epoch().count() % 10000);
      unblock_slow.set();
    }

    infast.wait_for(25ms);

    // Unregister.
    async_manager_.StopWatchingFileDescriptor(fast_s_fd);
    async_manager_.StopWatchingFileDescriptor(slow_s_fd);
  } while (time_fast_called < time_slow_called);

  // fast before stop listening.
  ASSERT_LT(time_fast_called.time_since_epoch().count(),
            time_stopped_listening.time_since_epoch().count());

  // Cleanup
  close(fast_cli_fd);
  close(fast_s_fd);
  close(slow_cli_fd);
  close(slow_s_fd);
}

TEST_F(AsyncManagerSocketTest, TestRepeatedConnections) {
  static const int num_connections = 30;
  for (int i = 0; i < num_connections; i++) {
@@ -159,7 +343,7 @@ TEST_F(AsyncManagerSocketTest, TestMultipleConnections) {
  int socket_cli_fd[num_connections];
  for (int i = 0; i < num_connections; i++) {
    socket_cli_fd[i] = ConnectClient();
    EXPECT_TRUE(socket_cli_fd[i] > 0);
    ASSERT_TRUE(socket_cli_fd[i] > 0);
    WriteFromClient(socket_cli_fd[i]);
  }
  for (int i = 0; i < num_connections; i++) {