Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 406e3c4f authored by Myles Watson's avatar Myles Watson Committed by android-build-merger
Browse files

Bluetooth: Add a timeout in async_fd_watcher am: 7d42dcad

am: 92add17f

Change-Id: I4d3b7421997df9b9647e62294e58bfd783d6635a
parents 11239d83 92add17f
Loading
Loading
Loading
Loading
+2 −0
Original line number Diff line number Diff line
@@ -38,7 +38,9 @@ cc_library_shared {
cc_test_host {
    name: "bluetooth-vendor-interface-unit-tests",
    srcs: [
        "async_fd_watcher.cc",
        "bluetooth_address.cc",
        "test/async_fd_watcher_unittest.cc",
        "test/bluetooth_address_test.cc",
        "test/properties.cc",
    ],
+43 −8
Original line number Diff line number Diff line
@@ -50,6 +50,20 @@ int AsyncFdWatcher::WatchFdForNonBlockingReads(
  return 0;
}

int AsyncFdWatcher::ConfigureTimeout(
    const std::chrono::milliseconds timeout,
    const TimeoutCallback& on_timeout_callback) {
  // Add timeout and callback
  {
    std::unique_lock<std::mutex> guard(timeout_mutex_);
    timeout_cb_ = on_timeout_callback;
    timeout_ms_ = timeout;
  }

  notifyThread();
  return 0;
}

void AsyncFdWatcher::StopWatchingFileDescriptor() { stopThread(); }

AsyncFdWatcher::~AsyncFdWatcher() {}
@@ -86,6 +100,11 @@ int AsyncFdWatcher::stopThread() {
    read_fd_ = -1;
  }

  {
    std::unique_lock<std::mutex> guard(timeout_mutex_);
    timeout_cb_ = nullptr;
  }

  return 0;
}

@@ -104,21 +123,37 @@ void AsyncFdWatcher::ThreadRoutine() {
    FD_SET(notification_listen_fd_, &read_fds);
    FD_SET(read_fd_, &read_fds);

    // Wait until there is data available to read on some FD
    struct timeval timeout;
    struct timeval* timeout_ptr = NULL;
    if (timeout_ms_ > std::chrono::milliseconds(0)) {
      timeout.tv_sec = timeout_ms_.count() / 1000;
      timeout.tv_usec = (timeout_ms_.count() % 1000) * 1000;
      timeout_ptr = &timeout;
    }

    // Wait until there is data available to read on some FD.
    int nfds = std::max(notification_listen_fd_, read_fd_);
    int retval = select(nfds + 1, &read_fds, NULL, NULL, NULL);
    if (retval <= 0) continue;  // there was some error or a timeout
    int retval = select(nfds + 1, &read_fds, NULL, NULL, timeout_ptr);

    // Read data from the notification FD
    // There was some error.
    if (retval < 0) continue;

    // Timeout.
    if (retval == 0) {
      std::unique_lock<std::mutex> guard(timeout_mutex_);
      if (timeout_ms_ > std::chrono::milliseconds(0) && timeout_cb_)
        timeout_cb_();
      continue;
    }

    // Read data from the notification FD.
    if (FD_ISSET(notification_listen_fd_, &read_fds)) {
      char buffer[] = {0};
      TEMP_FAILURE_RETRY(read(notification_listen_fd_, buffer, 1));
      continue;
    }

    // Make sure we're still running
    if (!running_) break;

    // Invoke the data ready callback if appropriate
    // Invoke the data ready callback if appropriate.
    if (FD_ISSET(read_fd_, &read_fds)) {
      std::unique_lock<std::mutex> guard(internal_mutex_);
      if (cb_) cb_(read_fd_);
+6 −0
Original line number Diff line number Diff line
@@ -26,6 +26,7 @@ namespace V1_0 {
namespace implementation {

using ReadCallback = std::function<void(int)>;
using TimeoutCallback = std::function<void(void)>;

class AsyncFdWatcher {
 public:
@@ -34,6 +35,8 @@ class AsyncFdWatcher {

  int WatchFdForNonBlockingReads(int file_descriptor,
                                 const ReadCallback& on_read_fd_ready_callback);
  int ConfigureTimeout(const std::chrono::milliseconds timeout,
                       const TimeoutCallback& on_timeout_callback);
  void StopWatchingFileDescriptor();

 private:
@@ -48,11 +51,14 @@ class AsyncFdWatcher {
  std::atomic_bool running_{false};
  std::thread thread_;
  std::mutex internal_mutex_;
  std::mutex timeout_mutex_;

  int read_fd_;
  int notification_listen_fd_;
  int notification_write_fd_;
  ReadCallback cb_;
  TimeoutCallback timeout_cb_;
  std::chrono::milliseconds timeout_ms_;
};


+295 −0
Original line number Diff line number Diff line
//
// Copyright 2017 The Android Open Source Project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//

#include "async_fd_watcher.h"
#include <gtest/gtest.h>
#include <cstdint>
#include <cstring>
#include <vector>

#include <netdb.h>
#include <netinet/in.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <unistd.h>
#include <utils/Log.h>

namespace android {
namespace hardware {
namespace bluetooth {
namespace V1_0 {
namespace implementation {

class AsyncFdWatcherSocketTest : public ::testing::Test {
 public:
  static const uint16_t kPort = 6111;
  static const size_t kBufferSize = 16;

  bool CheckBufferEquals() {
    return strcmp(server_buffer_, client_buffer_) == 0;
  }

 protected:
  int StartServer() {
    ALOGD("%s", __func__);
    struct sockaddr_in serv_addr;
    int fd = socket(AF_INET, SOCK_STREAM, 0);
    EXPECT_FALSE(fd < 0);

    memset(&serv_addr, 0, sizeof(serv_addr));
    serv_addr.sin_family = AF_INET;
    serv_addr.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
    serv_addr.sin_port = htons(kPort);
    int reuse_flag = 1;
    EXPECT_FALSE(setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &reuse_flag,
                            sizeof(reuse_flag)) < 0);
    EXPECT_FALSE(bind(fd, (sockaddr*)&serv_addr, sizeof(serv_addr)) <
                 0);

    ALOGD("%s before listen", __func__);
    listen(fd, 1);
    return fd;
  }

  int AcceptConnection(int fd) {
    ALOGD("%s", __func__);
    struct sockaddr_in cli_addr;
    memset(&cli_addr, 0, sizeof(cli_addr));
    socklen_t clilen = sizeof(cli_addr);

    int connection_fd = accept(fd, (struct sockaddr*)&cli_addr, &clilen);
    EXPECT_FALSE(connection_fd < 0);

    return connection_fd;
  }

  void ReadIncomingMessage(int fd) {
    ALOGD("%s", __func__);
    int n = TEMP_FAILURE_RETRY(read(fd, server_buffer_, kBufferSize - 1));
    EXPECT_FALSE(n < 0);

    if (n == 0)  // got EOF
      ALOGD("%s: EOF", __func__);
    else
      ALOGD("%s: Got something", __func__);
      n = write(fd, "1", 1);
  }

  void SetUp() override {
    ALOGD("%s", __func__);
    memset(server_buffer_, 0, kBufferSize);
    memset(client_buffer_, 0, kBufferSize);
  }

  void ConfigureServer() {
    socket_fd_ = StartServer();

    conn_watcher_.WatchFdForNonBlockingReads(socket_fd_, [this](int fd) {
      int connection_fd = AcceptConnection(fd);
      ALOGD("%s: Conn_watcher fd = %d", __func__, fd);

      conn_watcher_.ConfigureTimeout(std::chrono::seconds(0), [this]() { bool connection_timeout_cleared = false; ASSERT_TRUE(connection_timeout_cleared); });

      ALOGD("%s: 3", __func__);
      async_fd_watcher_.WatchFdForNonBlockingReads(
          connection_fd, [this](int fd) { ReadIncomingMessage(fd); });

      // Time out if it takes longer than a second.
      SetTimeout(std::chrono::seconds(1));
    });
    conn_watcher_.ConfigureTimeout(std::chrono::seconds(1), [this]() { bool connection_timeout = true; ASSERT_FALSE(connection_timeout); });
  }

  void CleanUpServer() {
    async_fd_watcher_.StopWatchingFileDescriptor();
    conn_watcher_.StopWatchingFileDescriptor();
    close(socket_fd_);
  }

  void TearDown() override {
    ALOGD("%s 3", __func__);
    EXPECT_TRUE(CheckBufferEquals());
  }

  void OnTimeout() {
    ALOGD("%s", __func__);
    timed_out_ = true;
  }

  void ClearTimeout() {
    ALOGD("%s", __func__);
    timed_out_ = false;
  }

  bool TimedOut() {
    ALOGD("%s %d", __func__, timed_out_? 1 : 0);
    return timed_out_;
  }

  void SetTimeout(std::chrono::milliseconds timeout_ms) {
    ALOGD("%s", __func__);
    async_fd_watcher_.ConfigureTimeout(timeout_ms, [this]() { OnTimeout(); });
    ClearTimeout();
  }

  int ConnectClient() {
    ALOGD("%s", __func__);
    int socket_cli_fd = socket(AF_INET, SOCK_STREAM, 0);
    EXPECT_FALSE(socket_cli_fd < 0);

    struct sockaddr_in serv_addr;
    memset((void*)&serv_addr, 0, sizeof(serv_addr));
    serv_addr.sin_family = AF_INET;
    serv_addr.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
    serv_addr.sin_port = htons(kPort);

    int result =
        connect(socket_cli_fd, (struct sockaddr*)&serv_addr, sizeof(serv_addr));
    EXPECT_FALSE(result < 0);

    return socket_cli_fd;
  }

  void WriteFromClient(int socket_cli_fd) {
    ALOGD("%s", __func__);
    strcpy(client_buffer_, "1");
    int n = write(socket_cli_fd, client_buffer_, strlen(client_buffer_));
    EXPECT_TRUE(n > 0);
  }

  void AwaitServerResponse(int socket_cli_fd) {
    ALOGD("%s", __func__);
    int n = read(socket_cli_fd, client_buffer_, 1);
    ALOGD("%s done", __func__);
    EXPECT_TRUE(n > 0);
  }

 private:
  AsyncFdWatcher async_fd_watcher_;
  AsyncFdWatcher conn_watcher_;
  int socket_fd_;
  char server_buffer_[kBufferSize];
  char client_buffer_[kBufferSize];
  bool timed_out_;
};

// Use a single AsyncFdWatcher to signal a connection to the server socket.
TEST_F(AsyncFdWatcherSocketTest, Connect) {
  int socket_fd = StartServer();

  AsyncFdWatcher conn_watcher;
  conn_watcher.WatchFdForNonBlockingReads(socket_fd, [this](int fd) {
    int connection_fd = AcceptConnection(fd);
    close(connection_fd);
  });

  // Fail if the client doesn't connect within 1 second.
  conn_watcher.ConfigureTimeout(std::chrono::seconds(1), [this]() {
     bool connection_timeout = true;
     ASSERT_FALSE(connection_timeout);
  });

  ConnectClient();
  conn_watcher.StopWatchingFileDescriptor();
  close(socket_fd);
}

// Use a single AsyncFdWatcher to signal a connection to the server socket.
TEST_F(AsyncFdWatcherSocketTest, TimedOutConnect) {
  int socket_fd = StartServer();
  bool timed_out = false;
  bool* timeout_ptr = &timed_out;

  AsyncFdWatcher conn_watcher;
  conn_watcher.WatchFdForNonBlockingReads(socket_fd, [this](int fd) {
    int connection_fd = AcceptConnection(fd);
    close(connection_fd);
  });

  // Set the timeout flag after 100ms.
  conn_watcher.ConfigureTimeout(std::chrono::milliseconds(100), [this, timeout_ptr]() { *timeout_ptr = true; });
  EXPECT_FALSE(timed_out);
  sleep(1);
  EXPECT_TRUE(timed_out);
  conn_watcher.StopWatchingFileDescriptor();
  close(socket_fd);
}

// Use two AsyncFdWatchers to set up a server socket.
TEST_F(AsyncFdWatcherSocketTest, ClientServer) {
  ConfigureServer();
  int socket_cli_fd = ConnectClient();

  WriteFromClient(socket_cli_fd);

  AwaitServerResponse(socket_cli_fd);

  close(socket_cli_fd);
  CleanUpServer();
}

// Use two AsyncFdWatchers to set up a server socket, which times out.
TEST_F(AsyncFdWatcherSocketTest, TimeOutTest) {
  ConfigureServer();
  int socket_cli_fd = ConnectClient();

  while (!TimedOut()) sleep(1);

  close(socket_cli_fd);
  CleanUpServer();
}

// Use two AsyncFdWatchers to set up a server socket, which times out.
TEST_F(AsyncFdWatcherSocketTest, RepeatedTimeOutTest) {
  ConfigureServer();
  int socket_cli_fd = ConnectClient();
  ClearTimeout();

  // Time out when there are no writes.
  EXPECT_FALSE(TimedOut());
  sleep(2);
  EXPECT_TRUE(TimedOut());
  ClearTimeout();

  // Don't time out when there is a write.
  WriteFromClient(socket_cli_fd);
  AwaitServerResponse(socket_cli_fd);
  EXPECT_FALSE(TimedOut());
  ClearTimeout();

  // Time out when the write is late.
  sleep(2);
  WriteFromClient(socket_cli_fd);
  AwaitServerResponse(socket_cli_fd);
  EXPECT_TRUE(TimedOut());
  ClearTimeout();

  // Time out when there is a pause after a write.
  WriteFromClient(socket_cli_fd);
  sleep(2);
  AwaitServerResponse(socket_cli_fd);
  EXPECT_TRUE(TimedOut());
  ClearTimeout();

  close(socket_cli_fd);
  CleanUpServer();
}

} // namespace implementation
} // namespace V1_0
} // namespace bluetooth
} // namespace hardware
} // namespace android