Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 8886840d authored by Treehugger Robot's avatar Treehugger Robot Committed by Gerrit Code Review
Browse files

Merge "Add LE related fields into SecurityRecord"

parents d2f7b24c 31d1b418
Loading
Loading
Loading
Loading
+6 −6
Original line number Diff line number Diff line
@@ -39,7 +39,7 @@ std::shared_ptr<bluetooth::security::record::SecurityRecord> SecurityManagerImpl
    std::shared_ptr<security::record::SecurityRecord> record =
        std::make_shared<security::record::SecurityRecord>(device);
    auto new_entry = std::pair<hci::Address, std::shared_ptr<security::record::SecurityRecord>>(
        record->GetDevice().GetAddress(), record);
        record->GetPseudoAddress().GetAddress(), record);
    // Keep track of it
    security_record_map_.insert(new_entry);
    return record;
@@ -51,23 +51,23 @@ void SecurityManagerImpl::DispatchPairingHandler(std::shared_ptr<security::recor
                                                 bool locally_initiated) {
  common::OnceCallback<void(hci::Address, PairingResultOrFailure)> callback =
      common::BindOnce(&SecurityManagerImpl::OnPairingHandlerComplete, common::Unretained(this));
  auto entry = pairing_handler_map_.find(record->GetDevice().GetAddress());
  auto entry = pairing_handler_map_.find(record->GetPseudoAddress().GetAddress());
  if (entry != pairing_handler_map_.end()) {
    LOG_WARN("Device already has a pairing handler, and is in the middle of pairing!");
    return;
  }
  std::shared_ptr<pairing::PairingHandler> pairing_handler = nullptr;
  switch (record->GetDevice().GetAddressType()) {
  switch (record->GetPseudoAddress().GetAddressType()) {
    case hci::AddressType::PUBLIC_DEVICE_ADDRESS:
      pairing_handler = std::make_shared<security::pairing::ClassicPairingHandler>(
          l2cap_classic_module_->GetFixedChannelManager(), security_manager_channel_, record, security_handler_,
          std::move(callback));
      break;
    default:
      ASSERT_LOG(false, "Pairing type %hhu not implemented!", record->GetDevice().GetAddressType());
      ASSERT_LOG(false, "Pairing type %hhu not implemented!", record->GetPseudoAddress().GetAddressType());
  }
  auto new_entry = std::pair<hci::Address, std::shared_ptr<pairing::PairingHandler>>(record->GetDevice().GetAddress(),
                                                                                     pairing_handler);
  auto new_entry = std::pair<hci::Address, std::shared_ptr<pairing::PairingHandler>>(
      record->GetPseudoAddress().GetAddress(), pairing_handler);
  pairing_handler_map_.insert(std::move(new_entry));
  pairing_handler->Initiate(locally_initiated, pairing::kDefaultIoCapability, pairing::kDefaultOobDataPresent,
                            pairing::kDefaultAuthenticationRequirements);
+25 −25
Original line number Diff line number Diff line
@@ -26,12 +26,12 @@ void ClassicPairingHandler::OnRegistrationComplete(
    std::unique_ptr<l2cap::classic::FixedChannelService> fixed_channel_service) {
  fixed_channel_service_ = std::move(fixed_channel_service);
  fixed_channel_manager_->ConnectServices(
      GetRecord()->GetDevice().GetAddress(),
      GetRecord()->GetPseudoAddress().GetAddress(),
      common::Bind(&ClassicPairingHandler::OnConnectionFail, common::Unretained(this)), security_handler_);
}

void ClassicPairingHandler::OnUnregistered() {
  std::move(complete_callback_).Run(GetRecord()->GetDevice().GetAddress(), last_status_);
  std::move(complete_callback_).Run(GetRecord()->GetPseudoAddress().GetAddress(), last_status_);
}

void ClassicPairingHandler::OnConnectionOpen(std::unique_ptr<l2cap::classic::FixedChannel> fixed_channel) {
@@ -92,20 +92,20 @@ void ClassicPairingHandler::OnReceive(hci::MasterLinkKeyCompleteView packet) {
void ClassicPairingHandler::OnReceive(hci::PinCodeRequestView packet) {
  ASSERT(packet.IsValid());
  LOG_INFO("Received: %s", hci::EventCodeText(packet.GetEventCode()).c_str());
  ASSERT_LOG(GetRecord()->GetDevice().GetAddress() == packet.GetBdAddr(), "Address mismatch");
  ASSERT_LOG(GetRecord()->GetPseudoAddress().GetAddress() == packet.GetBdAddr(), "Address mismatch");
}

void ClassicPairingHandler::OnReceive(hci::LinkKeyRequestView packet) {
  ASSERT(packet.IsValid());
  // TODO(optedoblivion): Add collision detection here
  LOG_INFO("Received: %s", hci::EventCodeText(packet.GetEventCode()).c_str());
  ASSERT_LOG(GetRecord()->GetDevice().GetAddress() == packet.GetBdAddr(), "Address mismatch");
  ASSERT_LOG(GetRecord()->GetPseudoAddress().GetAddress() == packet.GetBdAddr(), "Address mismatch");
  if (GetRecord()->IsBonded() || GetRecord()->IsPaired()) {
    auto packet =
        hci::LinkKeyRequestReplyBuilder::Create(GetRecord()->GetDevice().GetAddress(), GetRecord()->GetLinkKey());
    auto packet = hci::LinkKeyRequestReplyBuilder::Create(GetRecord()->GetPseudoAddress().GetAddress(),
                                                          GetRecord()->GetLinkKey());
    this->GetChannel()->SendCommand(std::move(packet));
  } else {
    auto packet = hci::LinkKeyRequestNegativeReplyBuilder::Create(GetRecord()->GetDevice().GetAddress());
    auto packet = hci::LinkKeyRequestNegativeReplyBuilder::Create(GetRecord()->GetPseudoAddress().GetAddress());
    this->GetChannel()->SendCommand(std::move(packet));
  }
}
@@ -113,26 +113,26 @@ void ClassicPairingHandler::OnReceive(hci::LinkKeyRequestView packet) {
void ClassicPairingHandler::OnReceive(hci::LinkKeyNotificationView packet) {
  ASSERT(packet.IsValid());
  LOG_INFO("Received: %s", hci::EventCodeText(packet.GetEventCode()).c_str());
  ASSERT_LOG(GetRecord()->GetDevice().GetAddress() == packet.GetBdAddr(), "Address mismatch");
  ASSERT_LOG(GetRecord()->GetPseudoAddress().GetAddress() == packet.GetBdAddr(), "Address mismatch");
  GetRecord()->SetLinkKey(packet.GetLinkKey(), packet.GetKeyType());
}

void ClassicPairingHandler::OnReceive(hci::IoCapabilityRequestView packet) {
  ASSERT(packet.IsValid());
  LOG_INFO("Received: %s", hci::EventCodeText(packet.GetEventCode()).c_str());
  ASSERT_LOG(GetRecord()->GetDevice().GetAddress() == packet.GetBdAddr(), "Address mismatch");
  ASSERT_LOG(GetRecord()->GetPseudoAddress().GetAddress() == packet.GetBdAddr(), "Address mismatch");
  hci::IoCapability io_capability = local_io_capability_;
  hci::OobDataPresent oob_present = hci::OobDataPresent::NOT_PRESENT;
  hci::AuthenticationRequirements authentication_requirements = local_authentication_requirements_;
  auto reply_packet = hci::IoCapabilityRequestReplyBuilder::Create(GetRecord()->GetDevice().GetAddress(), io_capability,
                                                                   oob_present, authentication_requirements);
  auto reply_packet = hci::IoCapabilityRequestReplyBuilder::Create(
      GetRecord()->GetPseudoAddress().GetAddress(), io_capability, oob_present, authentication_requirements);
  this->GetChannel()->SendCommand(std::move(reply_packet));
}

void ClassicPairingHandler::OnReceive(hci::IoCapabilityResponseView packet) {
  ASSERT(packet.IsValid());
  LOG_INFO("Received: %s", hci::EventCodeText(packet.GetEventCode()).c_str());
  ASSERT_LOG(GetRecord()->GetDevice().GetAddress() == packet.GetBdAddr(), "Address mismatch");
  ASSERT_LOG(GetRecord()->GetPseudoAddress().GetAddress() == packet.GetBdAddr(), "Address mismatch");

  // Using local variable until device database pointer is ready
  remote_io_capability_ = packet.GetIoCapability();
@@ -142,7 +142,7 @@ void ClassicPairingHandler::OnReceive(hci::IoCapabilityResponseView packet) {
void ClassicPairingHandler::OnReceive(hci::SimplePairingCompleteView packet) {
  ASSERT(packet.IsValid());
  LOG_INFO("Received: %s", hci::EventCodeText(packet.GetEventCode()).c_str());
  ASSERT_LOG(GetRecord()->GetDevice().GetAddress() == packet.GetBdAddr(), "Address mismatch");
  ASSERT_LOG(GetRecord()->GetPseudoAddress().GetAddress() == packet.GetBdAddr(), "Address mismatch");
  Cancel();
}

@@ -164,13 +164,13 @@ void ClassicPairingHandler::OnReceive(hci::EncryptionKeyRefreshCompleteView pack
void ClassicPairingHandler::OnReceive(hci::RemoteOobDataRequestView packet) {
  ASSERT(packet.IsValid());
  LOG_INFO("Received: %s", hci::EventCodeText(packet.GetEventCode()).c_str());
  ASSERT_LOG(GetRecord()->GetDevice().GetAddress() == packet.GetBdAddr(), "Address mismatch");
  ASSERT_LOG(GetRecord()->GetPseudoAddress().GetAddress() == packet.GetBdAddr(), "Address mismatch");
}

void ClassicPairingHandler::OnReceive(hci::UserPasskeyNotificationView packet) {
  ASSERT(packet.IsValid());
  LOG_INFO("Received: %s", hci::EventCodeText(packet.GetEventCode()).c_str());
  ASSERT_LOG(GetRecord()->GetDevice().GetAddress() == packet.GetBdAddr(), "Address mismatch");
  ASSERT_LOG(GetRecord()->GetPseudoAddress().GetAddress() == packet.GetBdAddr(), "Address mismatch");
}

void ClassicPairingHandler::OnReceive(hci::KeypressNotificationView packet) {
@@ -204,7 +204,7 @@ void ClassicPairingHandler::OnReceive(hci::KeypressNotificationView packet) {
void ClassicPairingHandler::OnReceive(hci::UserConfirmationRequestView packet) {
  ASSERT(packet.IsValid());
  LOG_INFO("Received: %s", hci::EventCodeText(packet.GetEventCode()).c_str());
  ASSERT_LOG(GetRecord()->GetDevice().GetAddress() == packet.GetBdAddr(), "Address mismatch");
  ASSERT_LOG(GetRecord()->GetPseudoAddress().GetAddress() == packet.GetBdAddr(), "Address mismatch");
  // if locally_initialized, use default, otherwise us remote io caps
  hci::IoCapability initiator_io_capability = (locally_initiated_) ? local_io_capability_ : remote_io_capability_;
  hci::IoCapability responder_io_capability = (!locally_initiated_) ? local_io_capability_ : remote_io_capability_;
@@ -216,13 +216,13 @@ void ClassicPairingHandler::OnReceive(hci::UserConfirmationRequestView packet) {
          // NumericComparison, Both auto confirm
          LOG_INFO("Numeric Comparison: A and B auto confirm");
          GetChannel()->SendCommand(
              hci::UserConfirmationRequestReplyBuilder::Create(GetRecord()->GetDevice().GetAddress()));
              hci::UserConfirmationRequestReplyBuilder::Create(GetRecord()->GetPseudoAddress().GetAddress()));
          // Unauthenticated
          break;
        case hci::IoCapability::DISPLAY_YES_NO:
          // NumericComparison, Initiator auto confirm, Responder display
          GetChannel()->SendCommand(
              hci::UserConfirmationRequestReplyBuilder::Create(GetRecord()->GetDevice().GetAddress()));
              hci::UserConfirmationRequestReplyBuilder::Create(GetRecord()->GetPseudoAddress().GetAddress()));
          LOG_INFO("Numeric Comparison: A auto confirm");
          // Unauthenticated
          break;
@@ -236,7 +236,7 @@ void ClassicPairingHandler::OnReceive(hci::UserConfirmationRequestView packet) {
          // NumericComparison, Both auto confirm
          LOG_INFO("Numeric Comparison: A and B auto confirm");
          GetChannel()->SendCommand(
              hci::UserConfirmationRequestReplyBuilder::Create(GetRecord()->GetDevice().GetAddress()));
              hci::UserConfirmationRequestReplyBuilder::Create(GetRecord()->GetPseudoAddress().GetAddress()));
          // Unauthenticated
          break;
      }
@@ -293,7 +293,7 @@ void ClassicPairingHandler::OnReceive(hci::UserConfirmationRequestView packet) {
          // NumericComparison, both auto confirm
          LOG_INFO("Numeric Comparison: A and B auto confirm");
          GetChannel()->SendCommand(
              hci::UserConfirmationRequestReplyBuilder::Create(GetRecord()->GetDevice().GetAddress()));
              hci::UserConfirmationRequestReplyBuilder::Create(GetRecord()->GetPseudoAddress().GetAddress()));
          // Unauthenticated
          break;
      }
@@ -304,28 +304,28 @@ void ClassicPairingHandler::OnReceive(hci::UserConfirmationRequestView packet) {
          // NumericComparison, both auto confirm
          LOG_INFO("Numeric Comparison: A and B auto confirm");
          GetChannel()->SendCommand(
              hci::UserConfirmationRequestReplyBuilder::Create(GetRecord()->GetDevice().GetAddress()));
              hci::UserConfirmationRequestReplyBuilder::Create(GetRecord()->GetPseudoAddress().GetAddress()));
          // Unauthenticated
          break;
        case hci::IoCapability::DISPLAY_YES_NO:
          // NumericComparison, Initiator auto confirm, Responder Yes/No confirm, no show conf val
          LOG_INFO("Numeric Comparison: A auto confirm");
          GetChannel()->SendCommand(
              hci::UserConfirmationRequestReplyBuilder::Create(GetRecord()->GetDevice().GetAddress()));
              hci::UserConfirmationRequestReplyBuilder::Create(GetRecord()->GetPseudoAddress().GetAddress()));
          // Unauthenticated
          break;
        case hci::IoCapability::KEYBOARD_ONLY:
          // NumericComparison, both auto confirm
          LOG_INFO("Numeric Comparison: A and B auto confirm");
          GetChannel()->SendCommand(
              hci::UserConfirmationRequestReplyBuilder::Create(GetRecord()->GetDevice().GetAddress()));
              hci::UserConfirmationRequestReplyBuilder::Create(GetRecord()->GetPseudoAddress().GetAddress()));
          // Unauthenticated
          break;
        case hci::IoCapability::NO_INPUT_NO_OUTPUT:
          // NumericComparison, both auto confirm
          LOG_INFO("Numeric Comparison: A and B auto confirm");
          GetChannel()->SendCommand(
              hci::UserConfirmationRequestReplyBuilder::Create(GetRecord()->GetDevice().GetAddress()));
              hci::UserConfirmationRequestReplyBuilder::Create(GetRecord()->GetPseudoAddress().GetAddress()));
          // Unauthenticated
          break;
      }
@@ -335,7 +335,7 @@ void ClassicPairingHandler::OnReceive(hci::UserConfirmationRequestView packet) {

void ClassicPairingHandler::OnReceive(hci::UserPasskeyRequestView packet) {
  ASSERT(packet.IsValid());
  ASSERT_LOG(GetRecord()->GetDevice().GetAddress() == packet.GetBdAddr(), "Address mismatch");
  ASSERT_LOG(GetRecord()->GetPseudoAddress().GetAddress() == packet.GetBdAddr(), "Address mismatch");
}

}  // namespace pairing
+16 −4
Original line number Diff line number Diff line
@@ -21,6 +21,7 @@
#include <memory>
#include <utility>

#include "crypto_toolbox/crypto_toolbox.h"
#include "hci/address_with_type.h"

namespace bluetooth {
@@ -31,7 +32,7 @@ enum BondState { NOT_BONDED, PAIRING, PAIRED, BONDED };

class SecurityRecord {
 public:
  explicit SecurityRecord(hci::AddressWithType device) : device_(device), state_(NOT_BONDED) {}
  explicit SecurityRecord(hci::AddressWithType address) : pseudo_address_(address), state_(NOT_BONDED) {}

  /**
   * Returns true if the device is bonded to another device
@@ -64,15 +65,26 @@ class SecurityRecord {
    return key_type_;
  }

  hci::AddressWithType GetDevice() {
    return device_;
  hci::AddressWithType GetPseudoAddress() {
    return pseudo_address_;
  }

 private:
  const hci::AddressWithType device_;
  /* First address we have ever seen this device with, that we used to create bond */
  const hci::AddressWithType pseudo_address_;

  /* Identity Address */
  std::optional<hci::AddressWithType> identity_address_;

  BondState state_;
  std::array<uint8_t, 16> link_key_ = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
  hci::KeyType key_type_ = hci::KeyType::DEBUG_COMBINATION;

  std::optional<crypto_toolbox::Octet16> ltk;
  std::optional<uint16_t> ediv;
  std::optional<std::array<uint8_t, 8>> rand;
  std::optional<crypto_toolbox::Octet16> irk;
  std::optional<crypto_toolbox::Octet16> signature_key;
};

}  // namespace record