Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 3f15ef46 authored by Jakub Pawlowski's avatar Jakub Pawlowski
Browse files

SecurityRecordDatabase

Store SecurityRecords directly in vector, rather than shared_ptr
Move management of SecurityRecord into separate unit -
SecurityRecordDatabase.

Bug: 142341141
Change-Id: I0cc2dd8a7ddcf5a01117f0ebf7bd68111a93a2c5
parent 9fe8c849
Loading
Loading
Loading
Loading
+15 −34
Original line number Original line Diff line number Diff line
@@ -28,46 +28,29 @@ namespace bluetooth {
namespace security {
namespace security {
namespace internal {
namespace internal {


std::shared_ptr<bluetooth::security::record::SecurityRecord> SecurityManagerImpl::CreateSecurityRecord(
void SecurityManagerImpl::DispatchPairingHandler(record::SecurityRecord& record, bool locally_initiated) {
    hci::Address address) {
  hci::AddressWithType device(address, hci::AddressType::PUBLIC_DEVICE_ADDRESS);
  // Security record check
  auto entry = security_record_map_.find(device.GetAddress());
  if (entry == security_record_map_.end()) {
    LOG_INFO("No security record for device: %s ", device.ToString().c_str());
    // Create one
    std::shared_ptr<security::record::SecurityRecord> record =
        std::make_shared<security::record::SecurityRecord>(device);
    auto new_entry = std::pair<hci::Address, std::shared_ptr<security::record::SecurityRecord>>(
        record->GetPseudoAddress().GetAddress(), record);
    // Keep track of it
    security_record_map_.insert(new_entry);
    return record;
  }
  return entry->second;
}

void SecurityManagerImpl::DispatchPairingHandler(std::shared_ptr<security::record::SecurityRecord> record,
                                                 bool locally_initiated) {
  common::OnceCallback<void(hci::Address, PairingResultOrFailure)> callback =
  common::OnceCallback<void(hci::Address, PairingResultOrFailure)> callback =
      common::BindOnce(&SecurityManagerImpl::OnPairingHandlerComplete, common::Unretained(this));
      common::BindOnce(&SecurityManagerImpl::OnPairingHandlerComplete, common::Unretained(this));
  auto entry = pairing_handler_map_.find(record->GetPseudoAddress().GetAddress());
  auto entry = pairing_handler_map_.find(record.GetPseudoAddress().GetAddress());
  if (entry != pairing_handler_map_.end()) {
  if (entry != pairing_handler_map_.end()) {
    LOG_WARN("Device already has a pairing handler, and is in the middle of pairing!");
    LOG_WARN("Device already has a pairing handler, and is in the middle of pairing!");
    return;
    return;
  }
  }
  std::shared_ptr<pairing::PairingHandler> pairing_handler = nullptr;
  std::shared_ptr<pairing::PairingHandler> pairing_handler = nullptr;
  switch (record->GetPseudoAddress().GetAddressType()) {
  switch (record.GetPseudoAddress().GetAddressType()) {
    case hci::AddressType::PUBLIC_DEVICE_ADDRESS:
    case hci::AddressType::PUBLIC_DEVICE_ADDRESS: {
      std::shared_ptr<record::SecurityRecord> record_copy =
          std::make_shared<record::SecurityRecord>(record.GetPseudoAddress());
      pairing_handler = std::make_shared<security::pairing::ClassicPairingHandler>(
      pairing_handler = std::make_shared<security::pairing::ClassicPairingHandler>(
          l2cap_classic_module_->GetFixedChannelManager(), security_manager_channel_, record, security_handler_,
          l2cap_classic_module_->GetFixedChannelManager(), security_manager_channel_, record_copy, security_handler_,
          std::move(callback), listeners_);
          std::move(callback), listeners_);
      break;
      break;
    }
    default:
    default:
      ASSERT_LOG(false, "Pairing type %hhu not implemented!", record->GetPseudoAddress().GetAddressType());
      ASSERT_LOG(false, "Pairing type %hhu not implemented!", record.GetPseudoAddress().GetAddressType());
  }
  }
  auto new_entry = std::pair<hci::Address, std::shared_ptr<pairing::PairingHandler>>(
  auto new_entry = std::pair<hci::Address, std::shared_ptr<pairing::PairingHandler>>(
      record->GetPseudoAddress().GetAddress(), pairing_handler);
      record.GetPseudoAddress().GetAddress(), pairing_handler);
  pairing_handler_map_.insert(std::move(new_entry));
  pairing_handler_map_.insert(std::move(new_entry));
  pairing_handler->Initiate(locally_initiated, pairing::kDefaultIoCapability, pairing::kDefaultOobDataPresent,
  pairing_handler->Initiate(locally_initiated, pairing::kDefaultIoCapability, pairing::kDefaultOobDataPresent,
                            pairing::kDefaultAuthenticationRequirements);
                            pairing::kDefaultAuthenticationRequirements);
@@ -81,8 +64,8 @@ void SecurityManagerImpl::Init() {
}
}


void SecurityManagerImpl::CreateBond(hci::AddressWithType device) {
void SecurityManagerImpl::CreateBond(hci::AddressWithType device) {
  auto record = CreateSecurityRecord(device.GetAddress());
  record::SecurityRecord& record = security_database_.FindOrCreate(device);
  if (record->IsBonded()) {
  if (record.IsBonded()) {
    NotifyDeviceBonded(device);
    NotifyDeviceBonded(device);
  } else {
  } else {
    // Dispatch pairing handler, if we are calling create we are the initiator
    // Dispatch pairing handler, if we are calling create we are the initiator
@@ -106,10 +89,7 @@ void SecurityManagerImpl::CancelBond(hci::AddressWithType device) {


void SecurityManagerImpl::RemoveBond(hci::AddressWithType device) {
void SecurityManagerImpl::RemoveBond(hci::AddressWithType device) {
  CancelBond(device);
  CancelBond(device);
  auto entry = security_record_map_.find(device.GetAddress());
  security_database_.Remove(device);
  if (entry != security_record_map_.end()) {
    security_record_map_.erase(entry);
  }
  // Signal disconnect
  // Signal disconnect
  // Remove security record
  // Remove security record
  // Signal Remove from database
  // Signal Remove from database
@@ -168,7 +148,8 @@ void SecurityManagerImpl::HandleEvent(T packet) {
    auto event = hci::EventPacketView::Create(std::move(packet));
    auto event = hci::EventPacketView::Create(std::move(packet));
    ASSERT_LOG(event.IsValid(), "Received invalid packet");
    ASSERT_LOG(event.IsValid(), "Received invalid packet");
    const hci::EventCode code = event.GetEventCode();
    const hci::EventCode code = event.GetEventCode();
    auto record = CreateSecurityRecord(bd_addr);
    auto record =
        security_database_.FindOrCreate(hci::AddressWithType{bd_addr, hci::AddressType::PUBLIC_DEVICE_ADDRESS});
    switch (code) {
    switch (code) {
      case hci::EventCode::LINK_KEY_REQUEST:
      case hci::EventCode::LINK_KEY_REQUEST:
        DispatchPairingHandler(record, true);
        DispatchPairingHandler(record, true);
+3 −3
Original line number Original line Diff line number Diff line
@@ -26,6 +26,7 @@
#include "security/channel/security_manager_channel.h"
#include "security/channel/security_manager_channel.h"
#include "security/pairing/classic_pairing_handler.h"
#include "security/pairing/classic_pairing_handler.h"
#include "security/record/security_record.h"
#include "security/record/security_record.h"
#include "security/security_record_database.h"


namespace bluetooth {
namespace bluetooth {
namespace security {
namespace security {
@@ -122,8 +123,7 @@ class SecurityManagerImpl : public channel::ISecurityManagerChannelListener {
  template <class T>
  template <class T>
  void HandleEvent(T packet);
  void HandleEvent(T packet);


  std::shared_ptr<record::SecurityRecord> CreateSecurityRecord(hci::Address address);
  void DispatchPairingHandler(record::SecurityRecord& record, bool locally_initiated);
  void DispatchPairingHandler(std::shared_ptr<record::SecurityRecord> record, bool locally_initiated);
  void OnL2capRegistrationCompleteLe(l2cap::le::FixedChannelManager::RegistrationResult result,
  void OnL2capRegistrationCompleteLe(l2cap::le::FixedChannelManager::RegistrationResult result,
                                     std::unique_ptr<l2cap::le::FixedChannelService> le_smp_service);
                                     std::unique_ptr<l2cap::le::FixedChannelService> le_smp_service);
  void OnConnectionOpenLe(std::unique_ptr<l2cap::le::FixedChannel> channel);
  void OnConnectionOpenLe(std::unique_ptr<l2cap::le::FixedChannel> channel);
@@ -137,7 +137,7 @@ class SecurityManagerImpl : public channel::ISecurityManagerChannelListener {
  std::unique_ptr<l2cap::le::FixedChannelManager> l2cap_manager_le_;
  std::unique_ptr<l2cap::le::FixedChannelManager> l2cap_manager_le_;
  hci::LeSecurityInterface* hci_security_interface_le_ __attribute__((unused));
  hci::LeSecurityInterface* hci_security_interface_le_ __attribute__((unused));
  channel::SecurityManagerChannel* security_manager_channel_;
  channel::SecurityManagerChannel* security_manager_channel_;
  std::unordered_map<hci::Address, std::shared_ptr<record::SecurityRecord>> security_record_map_;
  SecurityRecordDatabase security_database_;
  std::unordered_map<hci::Address, std::shared_ptr<pairing::PairingHandler>> pairing_handler_map_;
  std::unordered_map<hci::Address, std::shared_ptr<pairing::PairingHandler>> pairing_handler_map_;
};
};
}  // namespace internal
}  // namespace internal
+7 −4
Original line number Original line Diff line number Diff line
@@ -41,6 +41,8 @@ class SecurityRecord {
 public:
 public:
  explicit SecurityRecord(hci::AddressWithType address) : pseudo_address_(address), state_(PAIRING) {}
  explicit SecurityRecord(hci::AddressWithType address) : pseudo_address_(address), state_(PAIRING) {}


  SecurityRecord& operator=(const SecurityRecord& other) = default;

  /**
  /**
   * Returns true if Link Keys are stored persistently
   * Returns true if Link Keys are stored persistently
   */
   */
@@ -72,15 +74,16 @@ class SecurityRecord {


 private:
 private:
  /* First address we have ever seen this device with, that we used to create bond */
  /* First address we have ever seen this device with, that we used to create bond */
  const hci::AddressWithType pseudo_address_;
  hci::AddressWithType pseudo_address_;

  /* Identity Address */
  std::optional<hci::AddressWithType> identity_address_;


  BondState state_;
  BondState state_;
  std::array<uint8_t, 16> link_key_ = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
  std::array<uint8_t, 16> link_key_ = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
  hci::KeyType key_type_ = hci::KeyType::DEBUG_COMBINATION;
  hci::KeyType key_type_ = hci::KeyType::DEBUG_COMBINATION;


 public:
  /* Identity Address */
  std::optional<hci::AddressWithType> identity_address_;

  std::optional<crypto_toolbox::Octet16> ltk;
  std::optional<crypto_toolbox::Octet16> ltk;
  std::optional<uint16_t> ediv;
  std::optional<uint16_t> ediv;
  std::optional<std::array<uint8_t, 8>> rand;
  std::optional<std::array<uint8_t, 8>> rand;
+63 −0
Original line number Original line Diff line number Diff line
/*
 * Copyright 2019 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "security/record/security_record.h"

namespace bluetooth {
namespace security {

class SecurityRecordDatabase {
 public:
  using iterator = std::vector<record::SecurityRecord>::iterator;

  record::SecurityRecord& FindOrCreate(hci::AddressWithType address) {
    auto it = Find(address);
    // Security record check
    if (it != records_.end()) return *it;

    // No security record, create one
    records_.emplace_back(address);
    return records_.back();
  }

  void Remove(const hci::AddressWithType& address) {
    auto it = Find(address);

    // No record exists
    if (it == records_.end()) return;

    record::SecurityRecord& last = records_.back();
    *it = std::move(last);
    records_.pop_back();
  }

  iterator Find(hci::AddressWithType address) {
    for (auto it = records_.begin(); it != records_.end(); ++it) {
      record::SecurityRecord& record = *it;
      if (record.identity_address_.has_value() && record.identity_address_.value() == address) return it;
      if (record.GetPseudoAddress() == address) return it;
      if (record.irk.has_value() && address.IsRpaThatMatchesIrk(record.irk.value())) return it;
    }
    return records_.end();
  }

  std::vector<record::SecurityRecord> records_;
};

}  // namespace security
}  // namespace bluetooth
 No newline at end of file