Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 04ebca20 authored by Chris Manton's avatar Chris Manton
Browse files

le_impl: Add set privacy address unittest

Bug: 242901573
Test: atest bluetooth_test_gd_unit bluetooth_test_gd
Tag: #refactor
BYPASS_LONG_LINES_REASON: Bluetooth likes 120 lines
Ignore-AOSP-First: Cherry-pick

Merged-In: I1431763a2f9193e70f351684329d1d57d2d4a031
Change-Id: I1431763a2f9193e70f351684329d1d57d2d4a031
parent 572a11b3
Loading
Loading
Loading
Loading
+134 −8
Original line number Original line Diff line number Diff line
@@ -20,9 +20,11 @@
#include <gtest/gtest.h>
#include <gtest/gtest.h>


#include <chrono>
#include <chrono>
#include <mutex>


#include "common/bidi_queue.h"
#include "common/bidi_queue.h"
#include "common/callback.h"
#include "common/callback.h"
#include "common/testing/log_capture.h"
#include "hci/acl_manager.h"
#include "hci/acl_manager.h"
#include "hci/acl_manager/le_connection_management_callbacks.h"
#include "hci/acl_manager/le_connection_management_callbacks.h"
#include "hci/address_with_type.h"
#include "hci/address_with_type.h"
@@ -32,12 +34,27 @@
#include "os/log.h"
#include "os/log.h"
#include "packet/raw_builder.h"
#include "packet/raw_builder.h"


using namespace bluetooth;
using namespace std::chrono_literals;
using namespace std::chrono_literals;


using ::bluetooth::common::BidiQueue;
using ::bluetooth::common::BidiQueue;
using ::bluetooth::common::Callback;
using ::bluetooth::common::Callback;
using ::bluetooth::os::Handler;
using ::bluetooth::os::Handler;
using ::bluetooth::os::Thread;
using ::bluetooth::os::Thread;
using ::bluetooth::testing::LogCapture;

namespace {
constexpr char kFixedAddress[] = "c0:aa:bb:cc:dd:ee";
constexpr char kRemoteAddress[] = "00:11:22:33:44:55";
[[maybe_unused]] constexpr bool kAddToFilterAcceptList = true;
[[maybe_unused]] constexpr bool kSkipFilterAcceptList = !kAddToFilterAcceptList;
[[maybe_unused]] constexpr bool kIsDirectConnection = true;
[[maybe_unused]] constexpr bool kIsBackgroundConnection = !kIsDirectConnection;
[[maybe_unused]] constexpr ::bluetooth::crypto_toolbox::Octet16 kRotationIrk = {};
[[maybe_unused]] constexpr std::chrono::milliseconds kMinimumRotationTime(14 * 1000);
[[maybe_unused]] constexpr std::chrono::milliseconds kMaximumRotationTime(16 * 1000);

}  // namespace


namespace bluetooth {
namespace bluetooth {
namespace hci {
namespace hci {
@@ -101,6 +118,8 @@ class TestController : public Controller {
};
};


class TestHciLayer : public HciLayer {
class TestHciLayer : public HciLayer {
  // This is a springboard class that converts from `AclCommandBuilder`
  // to `ComandBuilder` for use in the hci layer.
  template <typename T>
  template <typename T>
  class CommandInterfaceImpl : public CommandInterface<T> {
  class CommandInterfaceImpl : public CommandInterface<T> {
   public:
   public:
@@ -122,6 +141,7 @@ class TestHciLayer : public HciLayer {
  void EnqueueCommand(
  void EnqueueCommand(
      std::unique_ptr<CommandBuilder> command,
      std::unique_ptr<CommandBuilder> command,
      common::ContextualOnceCallback<void(CommandStatusView)> on_status) override {
      common::ContextualOnceCallback<void(CommandStatusView)> on_status) override {
    const std::lock_guard<std::mutex> lock(command_queue_mutex_);
    command_queue_.push(std::move(command));
    command_queue_.push(std::move(command));
    command_status_callbacks.push_back(std::move(on_status));
    command_status_callbacks.push_back(std::move(on_status));
    if (command_promise_ != nullptr) {
    if (command_promise_ != nullptr) {
@@ -133,6 +153,7 @@ class TestHciLayer : public HciLayer {
  void EnqueueCommand(
  void EnqueueCommand(
      std::unique_ptr<CommandBuilder> command,
      std::unique_ptr<CommandBuilder> command,
      common::ContextualOnceCallback<void(CommandCompleteView)> on_complete) override {
      common::ContextualOnceCallback<void(CommandCompleteView)> on_complete) override {
    const std::lock_guard<std::mutex> lock(command_queue_mutex_);
    command_queue_.push(std::move(command));
    command_queue_.push(std::move(command));
    command_complete_callbacks.push_back(std::move(on_complete));
    command_complete_callbacks.push_back(std::move(on_complete));
    if (command_promise_ != nullptr) {
    if (command_promise_ != nullptr) {
@@ -141,6 +162,31 @@ class TestHciLayer : public HciLayer {
    }
    }
  }
  }


  std::unique_ptr<CommandBuilder> DequeueCommand() {
    const std::lock_guard<std::mutex> lock(command_queue_mutex_);
    auto packet = std::move(command_queue_.front());
    command_queue_.pop();
    return std::move(packet);
  }

  std::shared_ptr<std::vector<uint8_t>> DequeueCommandBytes() {
    auto command = DequeueCommand();
    auto bytes = std::make_shared<std::vector<uint8_t>>();
    packet::BitInserter bi(*bytes);
    command->Serialize(bi);
    return bytes;
  }

  bool IsPacketQueueEmpty() const {
    const std::lock_guard<std::mutex> lock(command_queue_mutex_);
    return command_queue_.empty();
  }

  size_t NumberOfQueuedCommands() const {
    const std::lock_guard<std::mutex> lock(command_queue_mutex_);
    return command_queue_.size();
  }

 public:
 public:
  void SetCommandFuture() {
  void SetCommandFuture() {
    ASSERT_LOG(command_promise_ == nullptr, "Promises, Promises, ... Only one at a time.");
    ASSERT_LOG(command_promise_ == nullptr, "Promises, Promises, ... Only one at a time.");
@@ -159,7 +205,7 @@ class TestHciLayer : public HciLayer {


  CommandView GetCommand(OpCode op_code) {
  CommandView GetCommand(OpCode op_code) {
    if (!command_queue_.empty()) {
    if (!command_queue_.empty()) {
      std::lock_guard<std::mutex> lock(mutex_);
      std::lock_guard<std::mutex> lock(command_queue_mutex_);
      if (command_future_ != nullptr) {
      if (command_future_ != nullptr) {
        command_future_.reset();
        command_future_.reset();
        command_promise_.reset();
        command_promise_.reset();
@@ -168,7 +214,7 @@ class TestHciLayer : public HciLayer {
      auto result = command_future_->wait_for(std::chrono::milliseconds(1000));
      auto result = command_future_->wait_for(std::chrono::milliseconds(1000));
      EXPECT_NE(std::future_status::timeout, result);
      EXPECT_NE(std::future_status::timeout, result);
    }
    }
    std::lock_guard<std::mutex> lock(mutex_);
    std::lock_guard<std::mutex> lock(command_queue_mutex_);
    ASSERT_LOG(
    ASSERT_LOG(
        !command_queue_.empty(), "Expecting command %s but command queue was empty", OpCodeText(op_code).c_str());
        !command_queue_.empty(), "Expecting command %s but command queue was empty", OpCodeText(op_code).c_str());
    CommandView command_packet_view = GetLastCommand();
    CommandView command_packet_view = GetLastCommand();
@@ -222,9 +268,9 @@ class TestHciLayer : public HciLayer {
  std::list<common::ContextualOnceCallback<void(CommandStatusView)>> command_status_callbacks;
  std::list<common::ContextualOnceCallback<void(CommandStatusView)>> command_status_callbacks;
  common::ContextualCallback<void(LeMetaEventView)> le_event_handler_;
  common::ContextualCallback<void(LeMetaEventView)> le_event_handler_;
  std::queue<std::unique_ptr<CommandBuilder>> command_queue_;
  std::queue<std::unique_ptr<CommandBuilder>> command_queue_;
  mutable std::mutex command_queue_mutex_;
  std::unique_ptr<std::promise<void>> command_promise_;
  std::unique_ptr<std::promise<void>> command_promise_;
  std::unique_ptr<std::future<void>> command_future_;
  std::unique_ptr<std::future<void>> command_future_;
  mutable std::mutex mutex_;
  CommandInterfaceImpl<AclCommandBuilder> le_acl_connection_manager_interface_{*this};
  CommandInterfaceImpl<AclCommandBuilder> le_acl_connection_manager_interface_{*this};
};
};


@@ -242,6 +288,15 @@ class LeImplTest : public ::testing::Test {
    le_impl_ = new le_impl(hci_layer_, controller_, handler_, round_robin_scheduler_, true);
    le_impl_ = new le_impl(hci_layer_, controller_, handler_, round_robin_scheduler_, true);
    le_impl_->handle_register_le_callbacks(&mock_le_connection_callbacks_, handler_);
    le_impl_->handle_register_le_callbacks(&mock_le_connection_callbacks_, handler_);


    Address address;
    Address::FromString(kFixedAddress, address);
    fixed_address_ = AddressWithType(address, AddressType::PUBLIC_DEVICE_ADDRESS);

    Address::FromString(kRemoteAddress, address);
    remote_public_address_ = AddressWithType(address, AddressType::PUBLIC_DEVICE_ADDRESS);
  }

  void set_random_device_address_policy() {
    // Set address policy
    // Set address policy
    hci_layer_->SetCommandFuture();
    hci_layer_->SetCommandFuture();
    hci::Address address;
    hci::Address address;
@@ -279,7 +334,7 @@ class LeImplTest : public ::testing::Test {
    std::promise<void> promise;
    std::promise<void> promise;
    auto future = promise.get_future();
    auto future = promise.get_future();
    handler_->BindOnceOn(&promise, &std::promise<void>::set_value).Invoke();
    handler_->BindOnceOn(&promise, &std::promise<void>::set_value).Invoke();
    auto status = future.wait_for(10ms);
    auto status = future.wait_for(2s);
    ASSERT_EQ(status, std::future_status::ready);
    ASSERT_EQ(status, std::future_status::ready);
  }
  }


@@ -313,6 +368,16 @@ class LeImplTest : public ::testing::Test {
    MOCK_METHOD(void, OnLeConnectFail, (AddressWithType, ErrorCode reason), (override));
    MOCK_METHOD(void, OnLeConnectFail, (AddressWithType, ErrorCode reason), (override));
  } mock_le_connection_callbacks_;
  } mock_le_connection_callbacks_;


 protected:
  void set_privacy_policy_for_initiator_address(
      const AddressWithType& address, const LeAddressManager::AddressPolicy& policy) {
    le_impl_->set_privacy_policy_for_initiator_address(
        policy, address, kRotationIrk, kMinimumRotationTime, kMaximumRotationTime);
  }

  AddressWithType fixed_address_;
  AddressWithType remote_public_address_;

  uint16_t packet_count_;
  uint16_t packet_count_;
  std::unique_ptr<std::promise<void>> packet_promise_;
  std::unique_ptr<std::promise<void>> packet_promise_;
  std::unique_ptr<std::future<void>> packet_future_;
  std::unique_ptr<std::future<void>> packet_future_;
@@ -370,6 +435,8 @@ TEST_F(LeImplTest, remove_device_from_connect_list) {
}
}


TEST_F(LeImplTest, connection_complete_with_periperal_role) {
TEST_F(LeImplTest, connection_complete_with_periperal_role) {
  set_random_device_address_policy();

  // Create connection
  // Create connection
  hci_layer_->SetCommandFuture();
  hci_layer_->SetCommandFuture();
  le_impl_->create_le_connection(
  le_impl_->create_le_connection(
@@ -388,7 +455,7 @@ TEST_F(LeImplTest, connection_complete_with_periperal_role) {
  hci::Address remote_address;
  hci::Address remote_address;
  Address::FromString("D0:05:04:03:02:01", remote_address);
  Address::FromString("D0:05:04:03:02:01", remote_address);
  hci::AddressWithType address_with_type(remote_address, hci::AddressType::PUBLIC_DEVICE_ADDRESS);
  hci::AddressWithType address_with_type(remote_address, hci::AddressType::PUBLIC_DEVICE_ADDRESS);
  EXPECT_CALL(mock_le_connection_callbacks_, OnLeConnectSuccess(address_with_type, testing::_));
  EXPECT_CALL(mock_le_connection_callbacks_, OnLeConnectSuccess(address_with_type, ::testing::_));
  hci_layer_->IncomingLeMetaEvent(LeConnectionCompleteBuilder::Create(
  hci_layer_->IncomingLeMetaEvent(LeConnectionCompleteBuilder::Create(
      ErrorCode::SUCCESS,
      ErrorCode::SUCCESS,
      0x0041,
      0x0041,
@@ -406,6 +473,8 @@ TEST_F(LeImplTest, connection_complete_with_periperal_role) {
}
}


TEST_F(LeImplTest, enhanced_connection_complete_with_periperal_role) {
TEST_F(LeImplTest, enhanced_connection_complete_with_periperal_role) {
  set_random_device_address_policy();

  controller_->AddSupported(OpCode::LE_EXTENDED_CREATE_CONNECTION);
  controller_->AddSupported(OpCode::LE_EXTENDED_CREATE_CONNECTION);
  // Create connection
  // Create connection
  hci_layer_->SetCommandFuture();
  hci_layer_->SetCommandFuture();
@@ -425,7 +494,7 @@ TEST_F(LeImplTest, enhanced_connection_complete_with_periperal_role) {
  hci::Address remote_address;
  hci::Address remote_address;
  Address::FromString("D0:05:04:03:02:01", remote_address);
  Address::FromString("D0:05:04:03:02:01", remote_address);
  hci::AddressWithType address_with_type(remote_address, hci::AddressType::PUBLIC_DEVICE_ADDRESS);
  hci::AddressWithType address_with_type(remote_address, hci::AddressType::PUBLIC_DEVICE_ADDRESS);
  EXPECT_CALL(mock_le_connection_callbacks_, OnLeConnectSuccess(address_with_type, testing::_));
  EXPECT_CALL(mock_le_connection_callbacks_, OnLeConnectSuccess(address_with_type, ::testing::_));
  hci_layer_->IncomingLeMetaEvent(LeEnhancedConnectionCompleteBuilder::Create(
  hci_layer_->IncomingLeMetaEvent(LeEnhancedConnectionCompleteBuilder::Create(
      ErrorCode::SUCCESS,
      ErrorCode::SUCCESS,
      0x0041,
      0x0041,
@@ -445,6 +514,8 @@ TEST_F(LeImplTest, enhanced_connection_complete_with_periperal_role) {
}
}


TEST_F(LeImplTest, connection_complete_with_central_role) {
TEST_F(LeImplTest, connection_complete_with_central_role) {
  set_random_device_address_policy();

  hci::Address remote_address;
  hci::Address remote_address;
  Address::FromString("D0:05:04:03:02:01", remote_address);
  Address::FromString("D0:05:04:03:02:01", remote_address);
  hci::AddressWithType address_with_type(remote_address, hci::AddressType::PUBLIC_DEVICE_ADDRESS);
  hci::AddressWithType address_with_type(remote_address, hci::AddressType::PUBLIC_DEVICE_ADDRESS);
@@ -462,7 +533,7 @@ TEST_F(LeImplTest, connection_complete_with_central_role) {
  ASSERT_EQ(ConnectabilityState::ARMED, le_impl_->connectability_state_);
  ASSERT_EQ(ConnectabilityState::ARMED, le_impl_->connectability_state_);


  // Receive connection complete of outgoing connection (Role::CENTRAL)
  // Receive connection complete of outgoing connection (Role::CENTRAL)
  EXPECT_CALL(mock_le_connection_callbacks_, OnLeConnectSuccess(address_with_type, testing::_));
  EXPECT_CALL(mock_le_connection_callbacks_, OnLeConnectSuccess(address_with_type, ::testing::_));
  hci_layer_->IncomingLeMetaEvent(LeConnectionCompleteBuilder::Create(
  hci_layer_->IncomingLeMetaEvent(LeConnectionCompleteBuilder::Create(
      ErrorCode::SUCCESS,
      ErrorCode::SUCCESS,
      0x0041,
      0x0041,
@@ -480,6 +551,8 @@ TEST_F(LeImplTest, connection_complete_with_central_role) {
}
}


TEST_F(LeImplTest, enhanced_connection_complete_with_central_role) {
TEST_F(LeImplTest, enhanced_connection_complete_with_central_role) {
  set_random_device_address_policy();

  controller_->AddSupported(OpCode::LE_EXTENDED_CREATE_CONNECTION);
  controller_->AddSupported(OpCode::LE_EXTENDED_CREATE_CONNECTION);
  hci::Address remote_address;
  hci::Address remote_address;
  Address::FromString("D0:05:04:03:02:01", remote_address);
  Address::FromString("D0:05:04:03:02:01", remote_address);
@@ -498,7 +571,7 @@ TEST_F(LeImplTest, enhanced_connection_complete_with_central_role) {
  ASSERT_EQ(ConnectabilityState::ARMED, le_impl_->connectability_state_);
  ASSERT_EQ(ConnectabilityState::ARMED, le_impl_->connectability_state_);


  // Receive connection complete of outgoing connection (Role::CENTRAL)
  // Receive connection complete of outgoing connection (Role::CENTRAL)
  EXPECT_CALL(mock_le_connection_callbacks_, OnLeConnectSuccess(address_with_type, testing::_));
  EXPECT_CALL(mock_le_connection_callbacks_, OnLeConnectSuccess(address_with_type, ::testing::_));
  hci_layer_->IncomingLeMetaEvent(LeEnhancedConnectionCompleteBuilder::Create(
  hci_layer_->IncomingLeMetaEvent(LeEnhancedConnectionCompleteBuilder::Create(
      ErrorCode::SUCCESS,
      ErrorCode::SUCCESS,
      0x0041,
      0x0041,
@@ -517,6 +590,59 @@ TEST_F(LeImplTest, enhanced_connection_complete_with_central_role) {
  ASSERT_EQ(ConnectabilityState::DISARMED, le_impl_->connectability_state_);
  ASSERT_EQ(ConnectabilityState::DISARMED, le_impl_->connectability_state_);
}
}


TEST_F(LeImplTest, register_with_address_manager__AddressPolicyNotSet) {
  bluetooth::common::InitFlags::SetAllForTesting();
  auto log_capture = std::make_unique<LogCapture>();

  std::promise<void> promise;
  auto future = promise.get_future();
  handler_->Post(common::BindOnce(
      [](struct le_impl* le_impl, os::Handler* handler, std::promise<void> promise) {
        le_impl->register_with_address_manager();
        handler->Post(common::BindOnce([](std::promise<void> promise) { promise.set_value(); }, std::move(promise)));
      },
      le_impl_,
      handler_,
      std::move(promise)));

  // Let |LeAddressManager::register_client| execute on handler
  auto status = future.wait_for(2s);
  ASSERT_EQ(status, std::future_status::ready);

  handler_->Post(common::BindOnce(
      [](struct le_impl* le_impl) {
        ASSERT_TRUE(le_impl->address_manager_registered);
        ASSERT_TRUE(le_impl->pause_connection);
      },
      le_impl_));

  std::promise<void> promise2;
  auto future2 = promise2.get_future();
  handler_->Post(common::BindOnce(
      [](struct le_impl* le_impl, os::Handler* handler, std::promise<void> promise) {
        le_impl->ready_to_unregister = true;
        le_impl->check_for_unregister();
        ASSERT_FALSE(le_impl->address_manager_registered);
        ASSERT_FALSE(le_impl->pause_connection);
        handler->Post(common::BindOnce([](std::promise<void> promise) { promise.set_value(); }, std::move(promise)));
      },
      le_impl_,
      handler_,
      std::move(promise2)));

  // Let |LeAddressManager::unregister_client| execute on handler
  auto status2 = future2.wait_for(2s);
  ASSERT_EQ(status2, std::future_status::ready);

  handler_->Post(common::BindOnce(
      [](std::unique_ptr<LogCapture> log_capture) {
        log_capture->Sync();
        ASSERT_TRUE(log_capture->Rewind()->Find("address policy isn't set yet"));
        ASSERT_TRUE(log_capture->Rewind()->Find("Client unregistered"));
      },
      std::move(log_capture)));
}

}  // namespace acl_manager
}  // namespace acl_manager
}  // namespace hci
}  // namespace hci
}  // namespace bluetooth
}  // namespace bluetooth
+9 −0
Original line number Original line Diff line number Diff line
@@ -169,8 +169,10 @@ void LeAddressManager::register_client(LeAddressManagerCallback* callback) {
      address_policy_ == AddressPolicy::USE_NON_RESOLVABLE_ADDRESS) {
      address_policy_ == AddressPolicy::USE_NON_RESOLVABLE_ADDRESS) {
      if (registered_clients_.size() == 1) {
      if (registered_clients_.size() == 1) {
        schedule_rotate_random_address();
        schedule_rotate_random_address();
        LOG_INFO("Scheduled address rotation for first client registered");
      }
      }
  }
  }
  LOG_INFO("Client registered");
}
}


void LeAddressManager::Unregister(LeAddressManagerCallback* callback) {
void LeAddressManager::Unregister(LeAddressManagerCallback* callback) {
@@ -185,9 +187,11 @@ void LeAddressManager::unregister_client(LeAddressManagerCallback* callback) {
      ack_resume(callback);
      ack_resume(callback);
    }
    }
    registered_clients_.erase(callback);
    registered_clients_.erase(callback);
    LOG_INFO("Client unregistered");
  }
  }
  if (registered_clients_.empty() && address_rotation_alarm_ != nullptr) {
  if (registered_clients_.empty() && address_rotation_alarm_ != nullptr) {
    address_rotation_alarm_->Cancel();
    address_rotation_alarm_->Cancel();
    LOG_INFO("Cancelled address rotation alarm");
  }
  }
}
}


@@ -243,12 +247,14 @@ void LeAddressManager::push_command(Command command) {


void LeAddressManager::ack_pause(LeAddressManagerCallback* callback) {
void LeAddressManager::ack_pause(LeAddressManagerCallback* callback) {
  if (registered_clients_.find(callback) == registered_clients_.end()) {
  if (registered_clients_.find(callback) == registered_clients_.end()) {
    LOG_INFO("No clients registered to ack pause");
    return;
    return;
  }
  }
  registered_clients_.find(callback)->second = ClientState::PAUSED;
  registered_clients_.find(callback)->second = ClientState::PAUSED;
  for (auto client : registered_clients_) {
  for (auto client : registered_clients_) {
    switch (client.second) {
    switch (client.second) {
      case ClientState::PAUSED:
      case ClientState::PAUSED:
        LOG_INFO("Client already in paused state");
        break;
        break;
      case ClientState::WAITING_FOR_PAUSE:
      case ClientState::WAITING_FOR_PAUSE:
        // make sure all client paused
        // make sure all client paused
@@ -260,6 +266,8 @@ void LeAddressManager::ack_pause(LeAddressManagerCallback* callback) {
        client.second = ClientState::WAITING_FOR_PAUSE;
        client.second = ClientState::WAITING_FOR_PAUSE;
        client.first->OnPause();
        client.first->OnPause();
        return;
        return;
      default:
        LOG_ERROR("Found client in unexpected state:%u", client.second);
    }
    }
  }
  }


@@ -275,6 +283,7 @@ void LeAddressManager::resume_registered_clients() {
    return;
    return;
  }
  }


  LOG_INFO("Resuming registered clients");
  for (auto& client : registered_clients_) {
  for (auto& client : registered_clients_) {
    client.second = ClientState::WAITING_FOR_RESUME;
    client.second = ClientState::WAITING_FOR_RESUME;
    client.first->OnResume();
    client.first->OnResume();