Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit aec43de2 authored by Chienyuan's avatar Chienyuan
Browse files

gd HCI: LE Address rotation

Tag: #gd-refactor
Bug: 152348535
Test: gd/cert/run --host
Test: atest bluetooth_test_gd
Change-Id: Icb2c85118af85383c536bd92378e261cab449591
parent a6102c7d
Loading
Loading
Loading
Loading
+2 −0
Original line number Diff line number Diff line
@@ -13,6 +13,7 @@ filegroup {
        "device.cc",
        "device_database.cc",
        "hci_layer.cc",
        "le_address_rotator.cc",
        "le_advertising_manager.cc",
        "le_scanning_manager.cc",
    ],
@@ -33,6 +34,7 @@ filegroup {
        "dual_device_test.cc",
        "hci_layer_test.cc",
        "hci_packets_test.cc",
        "le_address_rotator_test.cc",
        "le_advertising_manager_test.cc",
        "le_scanning_manager_test.cc",
    ],
+14 −0
Original line number Diff line number Diff line
@@ -145,6 +145,16 @@ void AclManager::SetLeInitiatorAddress(AddressWithType initiator_address) {
      common::BindOnce(&le_impl::set_le_initiator_address, common::Unretained(pimpl_->le_impl_), initiator_address));
}

void AclManager::SetPrivacyPolicyForInitiatorAddress(LeAddressRotator::AddressPolicy address_policy,
                                                     AddressWithType fixed_address,
                                                     crypto_toolbox::Octet16 rotation_irk,
                                                     std::chrono::milliseconds minimum_rotation_time,
                                                     std::chrono::milliseconds maximum_rotation_time) {
  GetHandler()->Post(common::BindOnce(&le_impl::set_privacy_policy_for_initiator_address,
                                      common::Unretained(pimpl_->le_impl_), address_policy, fixed_address, rotation_irk,
                                      minimum_rotation_time, maximum_rotation_time));
}

void AclManager::CancelConnect(Address address) {
  GetHandler()->Post(BindOnce(&classic_impl::cancel_connect, common::Unretained(pimpl_->classic_impl_), address));
}
@@ -173,6 +183,10 @@ void AclManager::SetSecurityModule(security::SecurityModule* security_module) {
      BindOnce(&classic_impl::set_security_module, common::Unretained(pimpl_->classic_impl_), security_module));
}

LeAddressRotator* AclManager::GetLeAddressRotator() {
  return pimpl_->le_impl_->le_address_rotator_;
}

void AclManager::ListDependencies(ModuleList* list) {
  list->add<HciLayer>();
  list->add<Controller>();
+7 −0
Original line number Diff line number Diff line
@@ -26,6 +26,7 @@
#include "hci/address_with_type.h"
#include "hci/hci_layer.h"
#include "hci/hci_packets.h"
#include "hci/le_address_rotator.h"
#include "module.h"
#include "os/handler.h"

@@ -60,6 +61,10 @@ class AclManager : public Module {
  virtual void CreateLeConnection(AddressWithType address_with_type);

  virtual void SetLeInitiatorAddress(AddressWithType initiator_address);
  virtual void SetPrivacyPolicyForInitiatorAddress(LeAddressRotator::AddressPolicy address_policy,
                                                   AddressWithType fixed_address, crypto_toolbox::Octet16 rotation_irk,
                                                   std::chrono::milliseconds minimum_rotation_time,
                                                   std::chrono::milliseconds maximum_rotation_time);

  // Generates OnConnectFail with error code "terminated by local host 0x16" if cancelled, or OnConnectSuccess if not
  // successfully cancelled and already connected
@@ -73,6 +78,8 @@ class AclManager : public Module {
  // In order to avoid circular dependency use setter rather than module dependency.
  virtual void SetSecurityModule(security::SecurityModule* security_module);

  virtual LeAddressRotator* GetLeAddressRotator();

  static const ModuleFactory Factory;

 protected:
+93 −69
Original line number Diff line number Diff line
@@ -21,6 +21,7 @@
#include "hci/acl_manager/assembler.h"
#include "hci/acl_manager/disconnector_for_le.h"
#include "hci/acl_manager/round_robin_scheduler.h"
#include "hci/le_address_rotator.h"
#include "os/alarm.h"
#include "os/rand.h"

@@ -41,7 +42,7 @@ struct le_acl_connection {
  LeConnectionManagementCallbacks* le_connection_management_callbacks_ = nullptr;
};

struct le_impl {
struct le_impl : public bluetooth::hci::LeAddressRotatorCallback {
  le_impl(HciLayer* hci_layer, Controller* controller, os::Handler* handler, RoundRobinScheduler* round_robin_scheduler,
          DisconnectorForLe* disconnector)
      : hci_layer_(hci_layer), controller_(controller), round_robin_scheduler_(round_robin_scheduler),
@@ -51,24 +52,17 @@ struct le_impl {
    handler_ = handler;
    le_acl_connection_interface_ = hci_layer_->GetLeAclConnectionInterface(
        handler_->BindOn(this, &le_impl::on_le_event), handler_->BindOn(this, &le_impl::on_le_disconnect));
    le_initiator_address_ =
        AddressWithType(Address{{0x00, 0x11, 0xFF, 0xFF, 0x33, 0x22}}, AddressType::RANDOM_DEVICE_ADDRESS);

    if (le_initiator_address_.GetAddressType() == AddressType::RANDOM_DEVICE_ADDRESS) {
      address_rotation_alarm_ = std::make_unique<os::Alarm>(handler_);
      RotateRandomAddress();
    }
    le_address_rotator_ = new LeAddressRotator(common::Bind(&le_impl::SetRandomAddress, common::Unretained(this)),
                                               handler_, controller->GetControllerMacAddress());
    le_address_rotator_->Register(this);
  }

  ~le_impl() {
    for (auto subevent_code : LeConnectionManagementEvents) {
      hci_layer_->UnregisterLeEventHandler(subevent_code);
    }
    // Address might have been already canceled if public address was used
    if (address_rotation_alarm_) {
      address_rotation_alarm_->Cancel();
      address_rotation_alarm_.reset();
    }
    le_address_rotator_->Unregister(this);
    delete le_address_rotator_;
    le_acl_connections_.clear();
  }

@@ -115,9 +109,14 @@ struct le_impl {
    auto peer_address_type = connection_complete.GetPeerAddressType();
    // TODO: find out which address and type was used to initiate the connection
    AddressWithType remote_address(address, peer_address_type);
    AddressWithType local_address = le_initiator_address_;
    AddressWithType local_address = le_address_rotator_->GetCurrentAddress();
    on_common_le_connection_complete(remote_address);
    if (status != ErrorCode::SUCCESS) {
    if (status == ErrorCode::UNKNOWN_CONNECTION &&
        canceled_connections_.find(remote_address) != canceled_connections_.end()) {
      // connection canceled by LeAddressRotator.OnPause(), will auto reconnect by LeAddressRotator.OnResume()
      return;
    } else if (status != ErrorCode::SUCCESS) {
      canceled_connections_.erase(remote_address);
      le_client_handler_->Post(common::BindOnce(&LeConnectionCallbacks::OnLeConnectFail,
                                                common::Unretained(le_client_callbacks_), remote_address, status));
      return;
@@ -150,12 +149,17 @@ struct le_impl {
    auto peer_address_type = connection_complete.GetPeerAddressType();
    auto peer_resolvable_address = connection_complete.GetPeerResolvablePrivateAddress();
    AddressWithType remote_address(address, peer_address_type);
    AddressWithType local_address = le_initiator_address_;
    AddressWithType local_address = le_address_rotator_->GetCurrentAddress();
    if (!peer_resolvable_address.IsEmpty()) {
      remote_address = AddressWithType(peer_resolvable_address, AddressType::RANDOM_DEVICE_ADDRESS);
    }
    on_common_le_connection_complete(remote_address);
    if (status != ErrorCode::SUCCESS) {
    if (status == ErrorCode::UNKNOWN_CONNECTION &&
        canceled_connections_.find(remote_address) != canceled_connections_.end()) {
      // connection canceled by LeAddressRotator.OnPause(), will auto reconnect by LeAddressRotator.OnResume()
      return;
    } else if (status != ErrorCode::SUCCESS) {
      canceled_connections_.erase(remote_address);
      le_client_handler_->Post(common::BindOnce(&LeConnectionCallbacks::OnLeConnectFail,
                                                common::Unretained(le_client_callbacks_), remote_address, status));
      return;
@@ -201,6 +205,22 @@ struct le_impl {
        complete_view.GetConnInterval(), complete_view.GetConnLatency(), complete_view.GetSupervisionTimeout());
  }

  void on_le_set_random_address_complete(CommandCompleteView view) {
    auto complete_view = LeSetRandomAddressCompleteView::Create(view);
    if (!complete_view.IsValid()) {
      LOG_ERROR("Received on_le_set_random_address_complete with invalid packet");
      le_address_rotator_->OnLeSetRandomAddressComplete(false);
      return;
    } else if (complete_view.GetStatus() != ErrorCode::SUCCESS) {
      auto status = complete_view.GetStatus();
      std::string error_code = ErrorCodeText(status);
      LOG_ERROR("Received on_le_set_random_address_complete with error code %s", error_code.c_str());
      le_address_rotator_->OnLeSetRandomAddressComplete(false);
      return;
    }
    le_address_rotator_->OnLeSetRandomAddressComplete(true);
  }

  std::chrono::milliseconds GetNextPrivateAddrressIntervalMs() {
    /* 7 minutes minimum, 15 minutes maximum for random address refreshing */
    const uint64_t interval_min_ms = (7 * 60 * 1000);
@@ -209,44 +229,10 @@ struct le_impl {
    return std::chrono::milliseconds(interval_min_ms + os::GenerateRandom() % interval_random_part_max_ms);
  }

  /* This function generates Resolvable Private Address (RPA) from Identity
   * Resolving Key |irk| and |prand|*/
  hci::Address GenerateRpa(const Octet16& irk, std::array<uint8_t, 8> prand) {
    /* most significant bit, bit7, bit6 is 01 to be resolvable random */
    constexpr uint8_t BLE_RESOLVE_ADDR_MSB = 0x40;
    constexpr uint8_t BLE_RESOLVE_ADDR_MASK = 0xc0;
    prand[2] &= (~BLE_RESOLVE_ADDR_MASK);
    prand[2] |= BLE_RESOLVE_ADDR_MSB;

    hci::Address address;
    address.address[3] = prand[0];
    address.address[4] = prand[1];
    address.address[5] = prand[2];

    /* encrypt with IRK */
    Octet16 p = crypto_toolbox::aes_128(irk, prand.data(), 3);

    /* set hash to be LSB of rpAddress */
    address.address[0] = p[0];
    address.address[1] = p[1];
    address.address[2] = p[2];
    return address;
  }

  void RotateRandomAddress() {
    // TODO: we must stop advertising, conection initiation, and scanning before calling SetRandomAddress.
    // TODO: ensure this is called before first connection initiation.
    // TODO: obtain proper IRK
    Octet16 irk = {} /* TODO: = BTM_GetDeviceIDRoot() */;
    std::array<uint8_t, 8> random = os::GenerateRandom<8>();
    hci::Address address = GenerateRpa(irk, random);

    hci_layer_->EnqueueCommand(hci::LeSetRandomAddressBuilder::Create(address),
                               handler_->BindOnce(check_command_complete<LeSetRandomAddressCompleteView>));

    le_initiator_address_ = AddressWithType(address, AddressType::RANDOM_DEVICE_ADDRESS);
    address_rotation_alarm_->Schedule(BindOnce(&le_impl::RotateRandomAddress, common::Unretained(this)),
                                      GetNextPrivateAddrressIntervalMs());
  void SetRandomAddress(Address address) {
    hci_layer_->EnqueueCommand(
        hci::LeSetRandomAddressBuilder::Create(address),
        handler_->BindOnce(&le_impl::on_le_set_random_address_complete, common::Unretained(this)));
  }

  void create_le_connection(AddressWithType address_with_type) {
@@ -255,13 +241,19 @@ struct le_impl {
    uint16_t le_scan_interval = 0x0060;
    uint16_t le_scan_window = 0x0030;
    InitiatorFilterPolicy initiator_filter_policy = InitiatorFilterPolicy::USE_PEER_ADDRESS;
    OwnAddressType own_address_type = static_cast<OwnAddressType>(le_initiator_address_.GetAddressType());
    OwnAddressType own_address_type =
        static_cast<OwnAddressType>(le_address_rotator_->GetCurrentAddress().GetAddressType());
    uint16_t conn_interval_min = 0x0018;
    uint16_t conn_interval_max = 0x0028;
    uint16_t conn_latency = 0x0000;
    uint16_t supervision_timeout = 0x001f4;
    ASSERT(le_client_callbacks_ != nullptr);

    if (pause_connection) {
      canceled_connections_.insert(address_with_type);
      return;
    }

    connecting_le_.insert(address_with_type);

    // TODO: make features check nicer, like HCI_LE_EXTENDED_ADVERTISING_SUPPORTED
@@ -298,21 +290,19 @@ struct le_impl {
  }

  void set_le_initiator_address(AddressWithType le_initiator_address) {
    le_initiator_address_ = le_initiator_address;

    if (le_initiator_address_.GetAddressType() != AddressType::RANDOM_DEVICE_ADDRESS) {
    if (le_initiator_address.GetAddressType() != AddressType::RANDOM_DEVICE_ADDRESS) {
      // Usually controllers provide vendor-specific way to override public address. Implement it if it's ever needed.
      LOG_ALWAYS_FATAL("Don't know how to use this type of address");
    }

    if (address_rotation_alarm_) {
      address_rotation_alarm_->Cancel();
      address_rotation_alarm_.reset();
    le_address_rotator_->SetAddress(le_initiator_address);
  }

    // TODO: we must stop advertising, conection initiation, and scanning before calling SetRandomAddress.
    hci_layer_->EnqueueCommand(hci::LeSetRandomAddressBuilder::Create(le_initiator_address_.GetAddress()),
                               handler_->BindOnce([](CommandCompleteView status) {}));
  void set_privacy_policy_for_initiator_address(LeAddressRotator::AddressPolicy address_policy,
                                                AddressWithType fixed_address, crypto_toolbox::Octet16 rotation_irk,
                                                std::chrono::milliseconds minimum_rotation_time,
                                                std::chrono::milliseconds maximum_rotation_time) {
    le_address_rotator_->SetPrivacyPolicyForInitiatorAddress(address_policy, fixed_address, rotation_irk,
                                                             minimum_rotation_time, maximum_rotation_time);
  }

  void handle_register_le_callbacks(LeConnectionCallbacks* callbacks, os::Handler* handler) {
@@ -328,20 +318,54 @@ struct le_impl {
    return connection->second;
  }

  void OnPause() override {
    if (pause_connection) {
      le_address_rotator_->AckPause(this);
      return;
    }

    pause_connection = true;
    if (!connecting_le_.empty()) {
      canceled_connections_ = connecting_le_;
      le_acl_connection_interface_->EnqueueCommand(
          LeCreateConnectionCancelBuilder::Create(),
          handler_->BindOnce(&le_impl::on_create_connection_cancel_complete, common::Unretained(this)));
    } else {
      le_address_rotator_->AckPause(this);
    }
  }

  void on_create_connection_cancel_complete(CommandCompleteView view) {
    auto complete_view = CreateConnectionCancelCompleteView::Create(view);
    ASSERT(complete_view.IsValid());
    ASSERT(complete_view.GetStatus() == ErrorCode::SUCCESS);
    le_address_rotator_->AckPause(this);
  }

  void OnResume() override {
    pause_connection = false;
    for (auto address_with_type : canceled_connections_) {
      create_le_connection(address_with_type);
    }
    canceled_connections_.clear();
    le_address_rotator_->AckResume(this);
  }

  static constexpr uint16_t kMinimumCeLength = 0x0002;
  static constexpr uint16_t kMaximumCeLength = 0x0C00;
  HciLayer* hci_layer_ = nullptr;
  Controller* controller_ = nullptr;
  os::Handler* handler_ = nullptr;
  RoundRobinScheduler* round_robin_scheduler_ = nullptr;
  LeAddressRotator* le_address_rotator_ = nullptr;
  LeAclConnectionInterface* le_acl_connection_interface_ = nullptr;
  LeConnectionCallbacks* le_client_callbacks_ = nullptr;
  os::Handler* le_client_handler_ = nullptr;
  std::map<uint16_t, le_acl_connection> le_acl_connections_;
  std::set<AddressWithType> connecting_le_;
  AddressWithType le_initiator_address_{Address{}, AddressType::RANDOM_DEVICE_ADDRESS};
  std::unique_ptr<os::Alarm> address_rotation_alarm_;
  std::set<AddressWithType> canceled_connections_;
  DisconnectorForLe* disconnector_;
  bool pause_connection = false;
};

}  // namespace acl_manager
+5 −0
Original line number Diff line number Diff line
@@ -70,6 +70,11 @@ class MockAclManager : public AclManager {
  MOCK_METHOD(void, CreateConnection, (Address address), (override));
  MOCK_METHOD(void, CreateLeConnection, (AddressWithType address_with_type), (override));
  MOCK_METHOD(void, CancelConnect, (Address address), (override));
  MOCK_METHOD(void, SetPrivacyPolicyForInitiatorAddress,
              (LeAddressRotator::AddressPolicy address_policy, AddressWithType fixed_address,
               crypto_toolbox::Octet16 rotation_irk, std::chrono::milliseconds minimum_rotation_time,
               std::chrono::milliseconds maximum_rotation_time),
              (override));
};

}  // namespace testing
Loading