Loading system/gd/security/internal/security_manager_impl.cc +19 −15 Original line number Original line Diff line number Diff line Loading @@ -33,29 +33,33 @@ namespace bluetooth { namespace security { namespace security { namespace internal { namespace internal { void SecurityManagerImpl::DispatchPairingHandler(record::SecurityRecord& record, bool locally_initiated) { void SecurityManagerImpl::DispatchPairingHandler( std::shared_ptr<record::SecurityRecord> record, bool locally_initiated) { common::OnceCallback<void(hci::Address, PairingResultOrFailure)> callback = common::OnceCallback<void(hci::Address, PairingResultOrFailure)> callback = common::BindOnce(&SecurityManagerImpl::OnPairingHandlerComplete, common::Unretained(this)); common::BindOnce(&SecurityManagerImpl::OnPairingHandlerComplete, common::Unretained(this)); auto entry = pairing_handler_map_.find(record.GetPseudoAddress().GetAddress()); auto entry = pairing_handler_map_.find(record->GetPseudoAddress().GetAddress()); if (entry != pairing_handler_map_.end()) { if (entry != pairing_handler_map_.end()) { LOG_WARN("Device already has a pairing handler, and is in the middle of pairing!"); LOG_WARN("Device already has a pairing handler, and is in the middle of pairing!"); return; return; } } std::shared_ptr<pairing::PairingHandler> pairing_handler = nullptr; std::shared_ptr<pairing::PairingHandler> pairing_handler = nullptr; switch (record.GetPseudoAddress().GetAddressType()) { switch (record->GetPseudoAddress().GetAddressType()) { case hci::AddressType::PUBLIC_DEVICE_ADDRESS: { case hci::AddressType::PUBLIC_DEVICE_ADDRESS: { std::shared_ptr<record::SecurityRecord> record_copy = std::make_shared<record::SecurityRecord>(record.GetPseudoAddress()); pairing_handler = std::make_shared<security::pairing::ClassicPairingHandler>( pairing_handler = std::make_shared<security::pairing::ClassicPairingHandler>( security_manager_channel_, record_copy, security_handler_, std::move(callback), user_interface_, security_manager_channel_, user_interface_handler_, "TODO: grab device name properly"); record, security_handler_, std::move(callback), user_interface_, user_interface_handler_, "TODO: grab device name properly"); break; break; } } default: default: ASSERT_LOG(false, "Pairing type %hhu not implemented!", record.GetPseudoAddress().GetAddressType()); ASSERT_LOG(false, "Pairing type %hhu not implemented!", record->GetPseudoAddress().GetAddressType()); } } auto new_entry = std::pair<hci::Address, std::shared_ptr<pairing::PairingHandler>>( auto new_entry = std::pair<hci::Address, std::shared_ptr<pairing::PairingHandler>>( record.GetPseudoAddress().GetAddress(), pairing_handler); record->GetPseudoAddress().GetAddress(), pairing_handler); pairing_handler_map_.insert(std::move(new_entry)); pairing_handler_map_.insert(std::move(new_entry)); pairing_handler->Initiate(locally_initiated, this->local_io_capability_, this->local_oob_data_present_, pairing_handler->Initiate(locally_initiated, this->local_io_capability_, this->local_oob_data_present_, this->local_authentication_requirements_); this->local_authentication_requirements_); Loading @@ -69,8 +73,8 @@ void SecurityManagerImpl::Init() { } } void SecurityManagerImpl::CreateBond(hci::AddressWithType device) { void SecurityManagerImpl::CreateBond(hci::AddressWithType device) { record::SecurityRecord& record = security_database_.FindOrCreate(device); auto record = security_database_.FindOrCreate(device); if (record.IsBonded()) { if (record->IsBonded()) { NotifyDeviceBonded(device); NotifyDeviceBonded(device); } else { } else { // Dispatch pairing handler, if we are calling create we are the initiator // Dispatch pairing handler, if we are calling create we are the initiator Loading @@ -79,8 +83,8 @@ void SecurityManagerImpl::CreateBond(hci::AddressWithType device) { } } void SecurityManagerImpl::CreateBondLe(hci::AddressWithType address) { void SecurityManagerImpl::CreateBondLe(hci::AddressWithType address) { record::SecurityRecord& record = security_database_.FindOrCreate(address); auto record = security_database_.FindOrCreate(address); if (record.IsBonded()) { if (record->IsBonded()) { NotifyDeviceBondFailed(address, PairingFailure("Already bonded")); NotifyDeviceBondFailed(address, PairingFailure("Already bonded")); return; return; } } Loading Loading @@ -581,10 +585,10 @@ void SecurityManagerImpl::InternalEnforceSecurityPolicy( switch (policy) { switch (policy) { case l2cap::classic::SecurityPolicy::BEST: case l2cap::classic::SecurityPolicy::BEST: case l2cap::classic::SecurityPolicy::AUTHENTICATED_ENCRYPTED_TRANSPORT: case l2cap::classic::SecurityPolicy::AUTHENTICATED_ENCRYPTED_TRANSPORT: result = record.IsAuthenticated() && record.RequiresMitmProtection() && record.IsEncryptionRequired(); result = record->IsAuthenticated() && record->RequiresMitmProtection() && record->IsEncryptionRequired(); break; break; case l2cap::classic::SecurityPolicy::ENCRYPTED_TRANSPORT: case l2cap::classic::SecurityPolicy::ENCRYPTED_TRANSPORT: result = record.IsAuthenticated() && record.IsEncryptionRequired(); result = record->IsAuthenticated() && record->IsEncryptionRequired(); break; break; case l2cap::classic::SecurityPolicy::_SDP_ONLY_NO_SECURITY_WHATSOEVER_PLAINTEXT_TRANSPORT_OK: case l2cap::classic::SecurityPolicy::_SDP_ONLY_NO_SECURITY_WHATSOEVER_PLAINTEXT_TRANSPORT_OK: result = true; result = true; Loading system/gd/security/internal/security_manager_impl.h +1 −1 Original line number Original line Diff line number Diff line Loading @@ -190,7 +190,7 @@ class SecurityManagerImpl : public channel::ISecurityManagerChannelListener, pub template <class T> template <class T> void HandleEvent(T packet); void HandleEvent(T packet); void DispatchPairingHandler(record::SecurityRecord& record, bool locally_initiated); void DispatchPairingHandler(std::shared_ptr<record::SecurityRecord> record, bool locally_initiated); void OnL2capRegistrationCompleteLe(l2cap::le::FixedChannelManager::RegistrationResult result, void OnL2capRegistrationCompleteLe(l2cap::le::FixedChannelManager::RegistrationResult result, std::unique_ptr<l2cap::le::FixedChannelService> le_smp_service); std::unique_ptr<l2cap::le::FixedChannelService> le_smp_service); void OnSmpCommandLe(hci::AddressWithType device); void OnSmpCommandLe(hci::AddressWithType device); Loading system/gd/security/record/security_record_database.h +14 −12 Original line number Original line Diff line number Diff line Loading @@ -16,6 +16,9 @@ #pragma once #pragma once #include <set> #include "hci/address_with_type.h" #include "security/record/security_record.h" #include "security/record/security_record.h" namespace bluetooth { namespace bluetooth { Loading @@ -24,16 +27,17 @@ namespace record { class SecurityRecordDatabase { class SecurityRecordDatabase { public: public: using iterator = std::vector<record::SecurityRecord>::iterator; using iterator = std::set<std::shared_ptr<SecurityRecord>>::iterator; record::SecurityRecord& FindOrCreate(hci::AddressWithType address) { std::shared_ptr<SecurityRecord> FindOrCreate(hci::AddressWithType address) { auto it = Find(address); auto it = Find(address); // Security record check // Security record check if (it != records_.end()) return *it; if (it != records_.end()) return *it; // No security record, create one // No security record, create one records_.emplace_back(address); auto record_ptr = std::make_shared<SecurityRecord>(address); return records_.back(); records_.insert(record_ptr); return record_ptr; } } void Remove(const hci::AddressWithType& address) { void Remove(const hci::AddressWithType& address) { Loading @@ -42,22 +46,20 @@ class SecurityRecordDatabase { // No record exists // No record exists if (it == records_.end()) return; if (it == records_.end()) return; record::SecurityRecord& last = records_.back(); records_.erase(it); *it = std::move(last); records_.pop_back(); } } iterator Find(hci::AddressWithType address) { iterator Find(hci::AddressWithType address) { for (auto it = records_.begin(); it != records_.end(); ++it) { for (auto it = records_.begin(); it != records_.end(); ++it) { record::SecurityRecord& record = *it; std::shared_ptr<SecurityRecord> record = *it; if (record.identity_address_.has_value() && record.identity_address_.value() == address) return it; if (record->identity_address_.has_value() && record->identity_address_.value() == address) return it; if (record.GetPseudoAddress() == address) return it; if (record->GetPseudoAddress() == address) return it; if (record.irk.has_value() && address.IsRpaThatMatchesIrk(record.irk.value())) return it; if (record->irk.has_value() && address.IsRpaThatMatchesIrk(record->irk.value())) return it; } } return records_.end(); return records_.end(); } } std::vector<record::SecurityRecord> records_; std::set<std::shared_ptr<SecurityRecord>> records_; }; }; } // namespace record } // namespace record Loading Loading
system/gd/security/internal/security_manager_impl.cc +19 −15 Original line number Original line Diff line number Diff line Loading @@ -33,29 +33,33 @@ namespace bluetooth { namespace security { namespace security { namespace internal { namespace internal { void SecurityManagerImpl::DispatchPairingHandler(record::SecurityRecord& record, bool locally_initiated) { void SecurityManagerImpl::DispatchPairingHandler( std::shared_ptr<record::SecurityRecord> record, bool locally_initiated) { common::OnceCallback<void(hci::Address, PairingResultOrFailure)> callback = common::OnceCallback<void(hci::Address, PairingResultOrFailure)> callback = common::BindOnce(&SecurityManagerImpl::OnPairingHandlerComplete, common::Unretained(this)); common::BindOnce(&SecurityManagerImpl::OnPairingHandlerComplete, common::Unretained(this)); auto entry = pairing_handler_map_.find(record.GetPseudoAddress().GetAddress()); auto entry = pairing_handler_map_.find(record->GetPseudoAddress().GetAddress()); if (entry != pairing_handler_map_.end()) { if (entry != pairing_handler_map_.end()) { LOG_WARN("Device already has a pairing handler, and is in the middle of pairing!"); LOG_WARN("Device already has a pairing handler, and is in the middle of pairing!"); return; return; } } std::shared_ptr<pairing::PairingHandler> pairing_handler = nullptr; std::shared_ptr<pairing::PairingHandler> pairing_handler = nullptr; switch (record.GetPseudoAddress().GetAddressType()) { switch (record->GetPseudoAddress().GetAddressType()) { case hci::AddressType::PUBLIC_DEVICE_ADDRESS: { case hci::AddressType::PUBLIC_DEVICE_ADDRESS: { std::shared_ptr<record::SecurityRecord> record_copy = std::make_shared<record::SecurityRecord>(record.GetPseudoAddress()); pairing_handler = std::make_shared<security::pairing::ClassicPairingHandler>( pairing_handler = std::make_shared<security::pairing::ClassicPairingHandler>( security_manager_channel_, record_copy, security_handler_, std::move(callback), user_interface_, security_manager_channel_, user_interface_handler_, "TODO: grab device name properly"); record, security_handler_, std::move(callback), user_interface_, user_interface_handler_, "TODO: grab device name properly"); break; break; } } default: default: ASSERT_LOG(false, "Pairing type %hhu not implemented!", record.GetPseudoAddress().GetAddressType()); ASSERT_LOG(false, "Pairing type %hhu not implemented!", record->GetPseudoAddress().GetAddressType()); } } auto new_entry = std::pair<hci::Address, std::shared_ptr<pairing::PairingHandler>>( auto new_entry = std::pair<hci::Address, std::shared_ptr<pairing::PairingHandler>>( record.GetPseudoAddress().GetAddress(), pairing_handler); record->GetPseudoAddress().GetAddress(), pairing_handler); pairing_handler_map_.insert(std::move(new_entry)); pairing_handler_map_.insert(std::move(new_entry)); pairing_handler->Initiate(locally_initiated, this->local_io_capability_, this->local_oob_data_present_, pairing_handler->Initiate(locally_initiated, this->local_io_capability_, this->local_oob_data_present_, this->local_authentication_requirements_); this->local_authentication_requirements_); Loading @@ -69,8 +73,8 @@ void SecurityManagerImpl::Init() { } } void SecurityManagerImpl::CreateBond(hci::AddressWithType device) { void SecurityManagerImpl::CreateBond(hci::AddressWithType device) { record::SecurityRecord& record = security_database_.FindOrCreate(device); auto record = security_database_.FindOrCreate(device); if (record.IsBonded()) { if (record->IsBonded()) { NotifyDeviceBonded(device); NotifyDeviceBonded(device); } else { } else { // Dispatch pairing handler, if we are calling create we are the initiator // Dispatch pairing handler, if we are calling create we are the initiator Loading @@ -79,8 +83,8 @@ void SecurityManagerImpl::CreateBond(hci::AddressWithType device) { } } void SecurityManagerImpl::CreateBondLe(hci::AddressWithType address) { void SecurityManagerImpl::CreateBondLe(hci::AddressWithType address) { record::SecurityRecord& record = security_database_.FindOrCreate(address); auto record = security_database_.FindOrCreate(address); if (record.IsBonded()) { if (record->IsBonded()) { NotifyDeviceBondFailed(address, PairingFailure("Already bonded")); NotifyDeviceBondFailed(address, PairingFailure("Already bonded")); return; return; } } Loading Loading @@ -581,10 +585,10 @@ void SecurityManagerImpl::InternalEnforceSecurityPolicy( switch (policy) { switch (policy) { case l2cap::classic::SecurityPolicy::BEST: case l2cap::classic::SecurityPolicy::BEST: case l2cap::classic::SecurityPolicy::AUTHENTICATED_ENCRYPTED_TRANSPORT: case l2cap::classic::SecurityPolicy::AUTHENTICATED_ENCRYPTED_TRANSPORT: result = record.IsAuthenticated() && record.RequiresMitmProtection() && record.IsEncryptionRequired(); result = record->IsAuthenticated() && record->RequiresMitmProtection() && record->IsEncryptionRequired(); break; break; case l2cap::classic::SecurityPolicy::ENCRYPTED_TRANSPORT: case l2cap::classic::SecurityPolicy::ENCRYPTED_TRANSPORT: result = record.IsAuthenticated() && record.IsEncryptionRequired(); result = record->IsAuthenticated() && record->IsEncryptionRequired(); break; break; case l2cap::classic::SecurityPolicy::_SDP_ONLY_NO_SECURITY_WHATSOEVER_PLAINTEXT_TRANSPORT_OK: case l2cap::classic::SecurityPolicy::_SDP_ONLY_NO_SECURITY_WHATSOEVER_PLAINTEXT_TRANSPORT_OK: result = true; result = true; Loading
system/gd/security/internal/security_manager_impl.h +1 −1 Original line number Original line Diff line number Diff line Loading @@ -190,7 +190,7 @@ class SecurityManagerImpl : public channel::ISecurityManagerChannelListener, pub template <class T> template <class T> void HandleEvent(T packet); void HandleEvent(T packet); void DispatchPairingHandler(record::SecurityRecord& record, bool locally_initiated); void DispatchPairingHandler(std::shared_ptr<record::SecurityRecord> record, bool locally_initiated); void OnL2capRegistrationCompleteLe(l2cap::le::FixedChannelManager::RegistrationResult result, void OnL2capRegistrationCompleteLe(l2cap::le::FixedChannelManager::RegistrationResult result, std::unique_ptr<l2cap::le::FixedChannelService> le_smp_service); std::unique_ptr<l2cap::le::FixedChannelService> le_smp_service); void OnSmpCommandLe(hci::AddressWithType device); void OnSmpCommandLe(hci::AddressWithType device); Loading
system/gd/security/record/security_record_database.h +14 −12 Original line number Original line Diff line number Diff line Loading @@ -16,6 +16,9 @@ #pragma once #pragma once #include <set> #include "hci/address_with_type.h" #include "security/record/security_record.h" #include "security/record/security_record.h" namespace bluetooth { namespace bluetooth { Loading @@ -24,16 +27,17 @@ namespace record { class SecurityRecordDatabase { class SecurityRecordDatabase { public: public: using iterator = std::vector<record::SecurityRecord>::iterator; using iterator = std::set<std::shared_ptr<SecurityRecord>>::iterator; record::SecurityRecord& FindOrCreate(hci::AddressWithType address) { std::shared_ptr<SecurityRecord> FindOrCreate(hci::AddressWithType address) { auto it = Find(address); auto it = Find(address); // Security record check // Security record check if (it != records_.end()) return *it; if (it != records_.end()) return *it; // No security record, create one // No security record, create one records_.emplace_back(address); auto record_ptr = std::make_shared<SecurityRecord>(address); return records_.back(); records_.insert(record_ptr); return record_ptr; } } void Remove(const hci::AddressWithType& address) { void Remove(const hci::AddressWithType& address) { Loading @@ -42,22 +46,20 @@ class SecurityRecordDatabase { // No record exists // No record exists if (it == records_.end()) return; if (it == records_.end()) return; record::SecurityRecord& last = records_.back(); records_.erase(it); *it = std::move(last); records_.pop_back(); } } iterator Find(hci::AddressWithType address) { iterator Find(hci::AddressWithType address) { for (auto it = records_.begin(); it != records_.end(); ++it) { for (auto it = records_.begin(); it != records_.end(); ++it) { record::SecurityRecord& record = *it; std::shared_ptr<SecurityRecord> record = *it; if (record.identity_address_.has_value() && record.identity_address_.value() == address) return it; if (record->identity_address_.has_value() && record->identity_address_.value() == address) return it; if (record.GetPseudoAddress() == address) return it; if (record->GetPseudoAddress() == address) return it; if (record.irk.has_value() && address.IsRpaThatMatchesIrk(record.irk.value())) return it; if (record->irk.has_value() && address.IsRpaThatMatchesIrk(record->irk.value())) return it; } } return records_.end(); return records_.end(); } } std::vector<record::SecurityRecord> records_; std::set<std::shared_ptr<SecurityRecord>> records_; }; }; } // namespace record } // namespace record Loading