Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit f7d9e63d authored by Treehugger Robot's avatar Treehugger Robot Committed by Gerrit Code Review
Browse files

Merge "SM: SecurityRecordDatabase updates"

parents aacd3ca4 a99eb9b7
Loading
Loading
Loading
Loading
+19 −15
Original line number Original line Diff line number Diff line
@@ -33,29 +33,33 @@ namespace bluetooth {
namespace security {
namespace security {
namespace internal {
namespace internal {


void SecurityManagerImpl::DispatchPairingHandler(record::SecurityRecord& record, bool locally_initiated) {
void SecurityManagerImpl::DispatchPairingHandler(
    std::shared_ptr<record::SecurityRecord> record, bool locally_initiated) {
  common::OnceCallback<void(hci::Address, PairingResultOrFailure)> callback =
  common::OnceCallback<void(hci::Address, PairingResultOrFailure)> callback =
      common::BindOnce(&SecurityManagerImpl::OnPairingHandlerComplete, common::Unretained(this));
      common::BindOnce(&SecurityManagerImpl::OnPairingHandlerComplete, common::Unretained(this));
  auto entry = pairing_handler_map_.find(record.GetPseudoAddress().GetAddress());
  auto entry = pairing_handler_map_.find(record->GetPseudoAddress().GetAddress());
  if (entry != pairing_handler_map_.end()) {
  if (entry != pairing_handler_map_.end()) {
    LOG_WARN("Device already has a pairing handler, and is in the middle of pairing!");
    LOG_WARN("Device already has a pairing handler, and is in the middle of pairing!");
    return;
    return;
  }
  }
  std::shared_ptr<pairing::PairingHandler> pairing_handler = nullptr;
  std::shared_ptr<pairing::PairingHandler> pairing_handler = nullptr;
  switch (record.GetPseudoAddress().GetAddressType()) {
  switch (record->GetPseudoAddress().GetAddressType()) {
    case hci::AddressType::PUBLIC_DEVICE_ADDRESS: {
    case hci::AddressType::PUBLIC_DEVICE_ADDRESS: {
      std::shared_ptr<record::SecurityRecord> record_copy =
          std::make_shared<record::SecurityRecord>(record.GetPseudoAddress());
      pairing_handler = std::make_shared<security::pairing::ClassicPairingHandler>(
      pairing_handler = std::make_shared<security::pairing::ClassicPairingHandler>(
          security_manager_channel_, record_copy, security_handler_, std::move(callback), user_interface_,
          security_manager_channel_,
          user_interface_handler_, "TODO: grab device name properly");
          record,
          security_handler_,
          std::move(callback),
          user_interface_,
          user_interface_handler_,
          "TODO: grab device name properly");
      break;
      break;
    }
    }
    default:
    default:
      ASSERT_LOG(false, "Pairing type %hhu not implemented!", record.GetPseudoAddress().GetAddressType());
      ASSERT_LOG(false, "Pairing type %hhu not implemented!", record->GetPseudoAddress().GetAddressType());
  }
  }
  auto new_entry = std::pair<hci::Address, std::shared_ptr<pairing::PairingHandler>>(
  auto new_entry = std::pair<hci::Address, std::shared_ptr<pairing::PairingHandler>>(
      record.GetPseudoAddress().GetAddress(), pairing_handler);
      record->GetPseudoAddress().GetAddress(), pairing_handler);
  pairing_handler_map_.insert(std::move(new_entry));
  pairing_handler_map_.insert(std::move(new_entry));
  pairing_handler->Initiate(locally_initiated, this->local_io_capability_, this->local_oob_data_present_,
  pairing_handler->Initiate(locally_initiated, this->local_io_capability_, this->local_oob_data_present_,
                            this->local_authentication_requirements_);
                            this->local_authentication_requirements_);
@@ -69,8 +73,8 @@ void SecurityManagerImpl::Init() {
}
}


void SecurityManagerImpl::CreateBond(hci::AddressWithType device) {
void SecurityManagerImpl::CreateBond(hci::AddressWithType device) {
  record::SecurityRecord& record = security_database_.FindOrCreate(device);
  auto record = security_database_.FindOrCreate(device);
  if (record.IsBonded()) {
  if (record->IsBonded()) {
    NotifyDeviceBonded(device);
    NotifyDeviceBonded(device);
  } else {
  } else {
    // Dispatch pairing handler, if we are calling create we are the initiator
    // Dispatch pairing handler, if we are calling create we are the initiator
@@ -79,8 +83,8 @@ void SecurityManagerImpl::CreateBond(hci::AddressWithType device) {
}
}


void SecurityManagerImpl::CreateBondLe(hci::AddressWithType address) {
void SecurityManagerImpl::CreateBondLe(hci::AddressWithType address) {
  record::SecurityRecord& record = security_database_.FindOrCreate(address);
  auto record = security_database_.FindOrCreate(address);
  if (record.IsBonded()) {
  if (record->IsBonded()) {
    NotifyDeviceBondFailed(address, PairingFailure("Already bonded"));
    NotifyDeviceBondFailed(address, PairingFailure("Already bonded"));
    return;
    return;
  }
  }
@@ -581,10 +585,10 @@ void SecurityManagerImpl::InternalEnforceSecurityPolicy(
  switch (policy) {
  switch (policy) {
    case l2cap::classic::SecurityPolicy::BEST:
    case l2cap::classic::SecurityPolicy::BEST:
    case l2cap::classic::SecurityPolicy::AUTHENTICATED_ENCRYPTED_TRANSPORT:
    case l2cap::classic::SecurityPolicy::AUTHENTICATED_ENCRYPTED_TRANSPORT:
      result = record.IsAuthenticated() && record.RequiresMitmProtection() && record.IsEncryptionRequired();
      result = record->IsAuthenticated() && record->RequiresMitmProtection() && record->IsEncryptionRequired();
      break;
      break;
    case l2cap::classic::SecurityPolicy::ENCRYPTED_TRANSPORT:
    case l2cap::classic::SecurityPolicy::ENCRYPTED_TRANSPORT:
      result = record.IsAuthenticated() && record.IsEncryptionRequired();
      result = record->IsAuthenticated() && record->IsEncryptionRequired();
      break;
      break;
    case l2cap::classic::SecurityPolicy::_SDP_ONLY_NO_SECURITY_WHATSOEVER_PLAINTEXT_TRANSPORT_OK:
    case l2cap::classic::SecurityPolicy::_SDP_ONLY_NO_SECURITY_WHATSOEVER_PLAINTEXT_TRANSPORT_OK:
      result = true;
      result = true;
+1 −1
Original line number Original line Diff line number Diff line
@@ -190,7 +190,7 @@ class SecurityManagerImpl : public channel::ISecurityManagerChannelListener, pub
  template <class T>
  template <class T>
  void HandleEvent(T packet);
  void HandleEvent(T packet);


  void DispatchPairingHandler(record::SecurityRecord& record, bool locally_initiated);
  void DispatchPairingHandler(std::shared_ptr<record::SecurityRecord> record, bool locally_initiated);
  void OnL2capRegistrationCompleteLe(l2cap::le::FixedChannelManager::RegistrationResult result,
  void OnL2capRegistrationCompleteLe(l2cap::le::FixedChannelManager::RegistrationResult result,
                                     std::unique_ptr<l2cap::le::FixedChannelService> le_smp_service);
                                     std::unique_ptr<l2cap::le::FixedChannelService> le_smp_service);
  void OnSmpCommandLe(hci::AddressWithType device);
  void OnSmpCommandLe(hci::AddressWithType device);
+14 −12
Original line number Original line Diff line number Diff line
@@ -16,6 +16,9 @@


#pragma once
#pragma once


#include <set>

#include "hci/address_with_type.h"
#include "security/record/security_record.h"
#include "security/record/security_record.h"


namespace bluetooth {
namespace bluetooth {
@@ -24,16 +27,17 @@ namespace record {


class SecurityRecordDatabase {
class SecurityRecordDatabase {
 public:
 public:
  using iterator = std::vector<record::SecurityRecord>::iterator;
  using iterator = std::set<std::shared_ptr<SecurityRecord>>::iterator;


  record::SecurityRecord& FindOrCreate(hci::AddressWithType address) {
  std::shared_ptr<SecurityRecord> FindOrCreate(hci::AddressWithType address) {
    auto it = Find(address);
    auto it = Find(address);
    // Security record check
    // Security record check
    if (it != records_.end()) return *it;
    if (it != records_.end()) return *it;


    // No security record, create one
    // No security record, create one
    records_.emplace_back(address);
    auto record_ptr = std::make_shared<SecurityRecord>(address);
    return records_.back();
    records_.insert(record_ptr);
    return record_ptr;
  }
  }


  void Remove(const hci::AddressWithType& address) {
  void Remove(const hci::AddressWithType& address) {
@@ -42,22 +46,20 @@ class SecurityRecordDatabase {
    // No record exists
    // No record exists
    if (it == records_.end()) return;
    if (it == records_.end()) return;


    record::SecurityRecord& last = records_.back();
    records_.erase(it);
    *it = std::move(last);
    records_.pop_back();
  }
  }


  iterator Find(hci::AddressWithType address) {
  iterator Find(hci::AddressWithType address) {
    for (auto it = records_.begin(); it != records_.end(); ++it) {
    for (auto it = records_.begin(); it != records_.end(); ++it) {
      record::SecurityRecord& record = *it;
      std::shared_ptr<SecurityRecord> record = *it;
      if (record.identity_address_.has_value() && record.identity_address_.value() == address) return it;
      if (record->identity_address_.has_value() && record->identity_address_.value() == address) return it;
      if (record.GetPseudoAddress() == address) return it;
      if (record->GetPseudoAddress() == address) return it;
      if (record.irk.has_value() && address.IsRpaThatMatchesIrk(record.irk.value())) return it;
      if (record->irk.has_value() && address.IsRpaThatMatchesIrk(record->irk.value())) return it;
    }
    }
    return records_.end();
    return records_.end();
  }
  }


  std::vector<record::SecurityRecord> records_;
  std::set<std::shared_ptr<SecurityRecord>> records_;
};
};


}  // namespace record
}  // namespace record