Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit c7b5554a authored by Jack He's avatar Jack He
Browse files

HCI: Use unique_ptr to pass AclConnection

* Use std::unique_ptr to pass AclConnection so that the AclConnection
  object can be mocked
* Add mocks for AclConnection and AclManager

Test: bluetooth_test_gd
Bug: 139700781
Change-Id: If7403207843d356330b6bd4875683df7966623e9
parent 753f25dc
Loading
Loading
Loading
Loading
+3 −1
Original line number Diff line number Diff line
@@ -238,7 +238,7 @@ struct AclManager::impl {
      if (acl_connections_.size() == 1 && packet_to_send_ == nullptr) {
        start_round_robin();
      }
      AclConnection connection_proxy{&acl_manager_, handle, address};
      std::unique_ptr<AclConnection> connection_proxy(new AclConnection(&acl_manager_, handle, address));
      client_handler_->Post(common::BindOnce(&ConnectionCallbacks::OnConnectSuccess,
                                             common::Unretained(client_callbacks_), std::move(connection_proxy)));
    } else {
@@ -451,5 +451,7 @@ void AclManager::Stop() {

const ModuleFactory AclManager::Factory = ModuleFactory([]() { return new AclManager(); });

AclManager::~AclManager() = default;

}  // namespace hci
}  // namespace bluetooth
+17 −9
Original line number Diff line number Diff line
@@ -33,9 +33,10 @@ class AclManager;

class AclConnection {
 public:
  AclConnection() : manager_(nullptr) {}
  AclConnection() : manager_(nullptr), handle_(0), address_(common::Address::kEmpty){};
  virtual ~AclConnection() = default;

  common::Address GetAddress() const {
  virtual common::Address GetAddress() const {
    return address_;
  }

@@ -47,10 +48,10 @@ class AclConnection {
  using QueueUpEnd = common::BidiQueueEnd<BasePacketBuilder, PacketView<kLittleEndian>>;
  using QueueDownEnd = common::BidiQueueEnd<PacketView<kLittleEndian>, BasePacketBuilder>;
  QueueUpEnd* GetAclQueueEnd() const;
  void RegisterDisconnectCallback(common::OnceCallback<void(ErrorCode)> on_disconnect, os::Handler* handler);
  bool Disconnect(DisconnectReason);
  virtual void RegisterDisconnectCallback(common::OnceCallback<void(ErrorCode)> on_disconnect, os::Handler* handler);
  virtual bool Disconnect(DisconnectReason reason);
  // Ask AclManager to clean me up. Must invoke after on_disconnect is called
  void Finish();
  virtual void Finish();

  // TODO: API to change link settings ... ?

@@ -61,13 +62,14 @@ class AclConnection {
  AclManager* manager_;
  uint16_t handle_;
  common::Address address_;
  DISALLOW_COPY_AND_ASSIGN(AclConnection);
};

class ConnectionCallbacks {
 public:
  virtual ~ConnectionCallbacks() = default;
  // Invoked when controller sends Connection Complete event with Success error code
  virtual void OnConnectSuccess(AclConnection /* , initiated_by_local ? */) = 0;
  virtual void OnConnectSuccess(std::unique_ptr<AclConnection> /* , initiated_by_local ? */) = 0;
  // Invoked when controller sends Connection Complete event with non-Success error code
  virtual void OnConnectFail(common::Address, ErrorCode reason) = 0;
};
@@ -75,17 +77,22 @@ class ConnectionCallbacks {
class AclManager : public Module {
 public:
  AclManager();
  // NOTE: It is necessary to forward declare a default destructor that overrides the base class one, because
  // "struct impl" is forwarded declared in .cc and compiler needs a concrete definition of "struct impl" when
  // compiling AclManager's destructor. Hence we need to forward declare the destructor for AclManager to delay
  // compiling AclManager's destructor until it starts linking the .cc file.
  ~AclManager() override;

  // Returns true if callbacks are successfully registered. Should register only once when user module starts.
  // Generates OnConnectSuccess when an incoming connection is established.
  bool RegisterCallbacks(ConnectionCallbacks* callbacks, os::Handler* handler);
  virtual bool RegisterCallbacks(ConnectionCallbacks* callbacks, os::Handler* handler);

  // Generates OnConnectSuccess if connected, or OnConnectFail otherwise
  void CreateConnection(common::Address address);
  virtual void CreateConnection(common::Address address);

  // Generates OnConnectFail with error code "terminated by local host 0x16" if cancelled, or OnConnectSuccess if not
  // successfully cancelled and already connected
  void CancelConnect(common::Address address);
  virtual void CancelConnect(common::Address address);

  static const ModuleFactory Factory;

@@ -103,6 +110,7 @@ class AclManager : public Module {
  std::unique_ptr<impl> pimpl_;

  struct acl_connection;
  DISALLOW_COPY_AND_ASSIGN(AclManager);
};

}  // namespace hci
+45 −0
Original line number Diff line number Diff line
/*
 * Copyright 2019 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once

#include "hci/acl_manager.h"

#include <gmock/gmock.h>

// Unit test interfaces
namespace bluetooth {
namespace hci {
namespace testing {

class MockAclConnection : public AclConnection {
 public:
  MOCK_METHOD(common::Address, GetAddress, (), (const, override));
  MOCK_METHOD(void, RegisterDisconnectCallback,
              (common::OnceCallback<void(ErrorCode)> on_disconnect, os::Handler* handler), (override));
  MOCK_METHOD(bool, Disconnect, (DisconnectReason reason), (override));
  MOCK_METHOD(void, Finish, (), (override));
};

class MockAclManager : public AclManager {
 public:
  MOCK_METHOD(bool, RegisterCallbacks, (ConnectionCallbacks * callbacks, os::Handler* handler), (override));
  MOCK_METHOD(void, CreateConnection, (common::Address address), (override));
  MOCK_METHOD(void, CancelConnect, (common::Address address), (override));
};

}  // namespace testing
}  // namespace hci
}  // namespace bluetooth
 No newline at end of file
+25 −22
Original line number Diff line number Diff line
@@ -218,12 +218,12 @@ class AclManagerNoCallbacksTest : public ::testing::Test {
    return mock_connection_callback_.connection_promise_->get_future();
  }

  AclConnection& GetLastConnection() {
  std::shared_ptr<AclConnection> GetLastConnection() {
    return mock_connection_callback_.connections_.back();
  }

  void SendAclData(uint16_t handle, AclConnection connection) {
    auto queue_end = connection.GetAclQueueEnd();
  void SendAclData(uint16_t handle, std::shared_ptr<AclConnection> connection) {
    auto queue_end = connection->GetAclQueueEnd();
    std::promise<void> promise;
    auto future = promise.get_future();
    queue_end->RegisterEnqueue(client_handler_,
@@ -240,8 +240,9 @@ class AclManagerNoCallbacksTest : public ::testing::Test {

  class MockConnectionCallback : public ConnectionCallbacks {
   public:
    void OnConnectSuccess(AclConnection connection) override {
      connections_.push_back(connection);
    void OnConnectSuccess(std::unique_ptr<AclConnection> connection) override {
      // Convert to std::shared_ptr during push_back()
      connections_.push_back(std::move(connection));
      if (connection_promise_ != nullptr) {
        connection_promise_->set_value();
        connection_promise_.reset();
@@ -249,7 +250,7 @@ class AclManagerNoCallbacksTest : public ::testing::Test {
    }
    MOCK_METHOD2(OnConnectFail, void(Address, ErrorCode reason));

    std::list<AclConnection> connections_;
    std::list<std::shared_ptr<AclConnection>> connections_;
    std::unique_ptr<std::promise<void>> connection_promise_;
  } mock_connection_callback_;
};
@@ -299,8 +300,8 @@ TEST_F(AclManagerTest, invoke_registered_callback_connection_complete_success) {
  auto first_connection_status = first_connection.wait_for(kTimeout);
  ASSERT_EQ(first_connection_status, std::future_status::ready);

  AclConnection& connection = GetLastConnection();
  ASSERT_EQ(connection.GetAddress(), remote);
  std::shared_ptr<AclConnection> connection = GetLastConnection();
  ASSERT_EQ(connection->GetAddress(), remote);
}

TEST_F(AclManagerTest, invoke_registered_callback_connection_complete_fail) {
@@ -341,12 +342,12 @@ TEST_F(AclManagerTest, invoke_registered_callback_disconnection_complete) {
  auto first_connection_status = first_connection.wait_for(kTimeout);
  ASSERT_EQ(first_connection_status, std::future_status::ready);

  AclConnection& connection = GetLastConnection();
  std::shared_ptr<AclConnection> connection = GetLastConnection();

  // Register the disconnect handler
  std::promise<ErrorCode> promise;
  auto future = promise.get_future();
  connection.RegisterDisconnectCallback(
  connection->RegisterDisconnectCallback(
      common::BindOnce([](std::promise<ErrorCode> promise, ErrorCode reason) { promise.set_value(reason); },
                       std::move(promise)),
      client_handler_);
@@ -380,12 +381,12 @@ TEST_F(AclManagerTest, acl_connection_finish_after_disconnected) {
  auto first_connection_status = first_connection.wait_for(kTimeout);
  ASSERT_EQ(first_connection_status, std::future_status::ready);

  AclConnection& connection = GetLastConnection();
  std::shared_ptr<AclConnection> connection = GetLastConnection();

  // Register the disconnect handler
  std::promise<ErrorCode> promise;
  auto future = promise.get_future();
  connection.RegisterDisconnectCallback(
  connection->RegisterDisconnectCallback(
      common::BindOnce([](std::promise<ErrorCode> promise, ErrorCode reason) { promise.set_value(reason); },
                       std::move(promise)),
      client_handler_);
@@ -397,7 +398,7 @@ TEST_F(AclManagerTest, acl_connection_finish_after_disconnected) {
  ASSERT_EQ(disconnection_status, std::future_status::ready);
  ASSERT_EQ(ErrorCode::REMOTE_DEVICE_TERMINATED_CONNECTION_POWER_OFF, future.get());

  connection.Finish();
  connection->Finish();
}

TEST_F(AclManagerTest, acl_send_data_one_connection) {
@@ -419,15 +420,16 @@ TEST_F(AclManagerTest, acl_send_data_one_connection) {
  auto first_connection_status = first_connection.wait_for(kTimeout);
  ASSERT_EQ(first_connection_status, std::future_status::ready);

  AclConnection& connection = GetLastConnection();
  std::shared_ptr<AclConnection> connection = GetLastConnection();

  // Register the disconnect handler
  connection.RegisterDisconnectCallback(common::Bind([](AclConnection conn, ErrorCode) { conn.Finish(); }, connection),
  connection->RegisterDisconnectCallback(
      common::Bind([](std::shared_ptr<AclConnection> conn, ErrorCode) { conn->Finish(); }, connection),
      client_handler_);

  // Send a packet from HCI
  test_hci_layer_->IncomingAclData(handle);
  auto queue_end = connection.GetAclQueueEnd();
  auto queue_end = connection->GetAclQueueEnd();

  std::unique_ptr<PacketView<kLittleEndian>> received;
  do {
@@ -445,7 +447,7 @@ TEST_F(AclManagerTest, acl_send_data_one_connection) {
  SendAclData(handle, connection);

  sent_packet = test_hci_layer_->OutgoingAclData();
  connection.Disconnect(DisconnectReason::AUTHENTICATION_FAILURE);
  connection->Disconnect(DisconnectReason::AUTHENTICATION_FAILURE);
}

TEST_F(AclManagerTest, acl_send_data_credits) {
@@ -466,11 +468,12 @@ TEST_F(AclManagerTest, acl_send_data_credits) {
  auto first_connection_status = first_connection.wait_for(kTimeout);
  ASSERT_EQ(first_connection_status, std::future_status::ready);

  AclConnection& connection = GetLastConnection();
  std::shared_ptr<AclConnection> connection = GetLastConnection();

  // Register the disconnect handler
  connection.RegisterDisconnectCallback(
      common::BindOnce([](AclConnection conn, ErrorCode) { conn.Finish(); }, std::move(connection)), client_handler_);
  connection->RegisterDisconnectCallback(
      common::BindOnce([](std::shared_ptr<AclConnection> conn, ErrorCode) { conn->Finish(); }, connection),
      client_handler_);

  // Use all the credits
  for (uint16_t credits = 0; credits < test_controller_->total_acl_buffers_; credits++) {
@@ -489,7 +492,7 @@ TEST_F(AclManagerTest, acl_send_data_credits) {

  auto after_credits_sent_packet = test_hci_layer_->OutgoingAclData();

  connection.Disconnect(DisconnectReason::AUTHENTICATION_FAILURE);
  connection->Disconnect(DisconnectReason::AUTHENTICATION_FAILURE);
}

}  // namespace
+20 −19
Original line number Diff line number Diff line
@@ -91,7 +91,7 @@ class AclManagerFacadeService : public AclManagerFacade::Service, public ::bluet
      LOG_ERROR("Invalid address");
      return ::grpc::Status(::grpc::StatusCode::INVALID_ARGUMENT, "Invalid address");
    } else {
      connection->second.Disconnect(DisconnectReason::REMOTE_USER_TERMINATED_CONNECTION);
      connection->second->Disconnect(DisconnectReason::REMOTE_USER_TERMINATED_CONNECTION);
      return ::grpc::Status::OK;
    }
  }
@@ -101,7 +101,7 @@ class AclManagerFacadeService : public AclManagerFacade::Service, public ::bluet
    std::unique_lock<std::mutex> lock(mutex_);
    std::promise<void> promise;
    auto future = promise.get_future();
    acl_connections_[request->remote().address()].GetAclQueueEnd()->RegisterEnqueue(
    acl_connections_[request->remote().address()]->GetAclQueueEnd()->RegisterEnqueue(
        facade_handler_, common::Bind(&AclManagerFacadeService::enqueue_packet, common::Unretained(this),
                                      common::Unretained(request), common::Passed(std::move(promise))));
    future.wait();
@@ -109,7 +109,7 @@ class AclManagerFacadeService : public AclManagerFacade::Service, public ::bluet
  }

  std::unique_ptr<BasePacketBuilder> enqueue_packet(const AclData* request, std::promise<void> promise) {
    acl_connections_[request->remote().address()].GetAclQueueEnd()->UnregisterEnqueue();
    acl_connections_[request->remote().address()]->GetAclQueueEnd()->UnregisterEnqueue();
    std::string req_string = request->payload();
    std::unique_ptr<RawBuilder> packet = std::make_unique<RawBuilder>();
    packet->AddOctets(std::vector<uint8_t>(req_string.begin(), req_string.end()));
@@ -130,7 +130,7 @@ class AclManagerFacadeService : public AclManagerFacade::Service, public ::bluet
      return;
    }

    auto packet = connection->second.GetAclQueueEnd()->TryDequeue();
    auto packet = connection->second->GetAclQueueEnd()->TryDequeue();
    auto acl_packet = AclPacketView::Create(*packet);
    AclData acl_data;
    acl_data.mutable_remote()->set_address(address);
@@ -139,14 +139,15 @@ class AclManagerFacadeService : public AclManagerFacade::Service, public ::bluet
    acl_stream_.OnIncomingEvent(acl_data);
  }

  void OnConnectSuccess(::bluetooth::hci::AclConnection connection) override {
  void OnConnectSuccess(std::unique_ptr<::bluetooth::hci::AclConnection> connection) override {
    std::unique_lock<std::mutex> lock(mutex_);
    auto addr = connection.GetAddress();
    acl_connections_.emplace(addr.ToString(), connection);
    connection.RegisterDisconnectCallback(
    auto addr = connection->GetAddress();
    std::shared_ptr<::bluetooth::hci::AclConnection> shared_connection = std::move(connection);
    acl_connections_.emplace(addr.ToString(), shared_connection);
    shared_connection->RegisterDisconnectCallback(
        common::BindOnce(&AclManagerFacadeService::on_disconnect, common::Unretained(this), addr.ToString()),
        facade_handler_);
    connection_complete_stream_.OnIncomingEvent(connection);
    connection_complete_stream_.OnIncomingEvent(shared_connection);
  }

  void on_disconnect(std::string address, ErrorCode code) {
@@ -189,14 +190,14 @@ class AclManagerFacadeService : public AclManagerFacade::Service, public ::bluet
  ::bluetooth::os::Handler* facade_handler_;

  class ConnectionCompleteStreamCallback
      : public ::bluetooth::grpc::GrpcEventStreamCallback<ConnectionEvent, AclConnection> {
      : public ::bluetooth::grpc::GrpcEventStreamCallback<ConnectionEvent, std::shared_ptr<AclConnection>> {
   public:
    void OnWriteResponse(ConnectionEvent* response, AclConnection const& connection) override {
      response->mutable_remote()->set_address(connection.GetAddress().ToString());
      response->set_connection_handle(connection.GetHandle());
    void OnWriteResponse(ConnectionEvent* response, const std::shared_ptr<AclConnection>& connection) override {
      response->mutable_remote()->set_address(connection->GetAddress().ToString());
      response->set_connection_handle(connection->GetHandle());
    }
  } connection_complete_stream_callback_;
  ::bluetooth::grpc::GrpcEventStream<ConnectionEvent, AclConnection> connection_complete_stream_{
  ::bluetooth::grpc::GrpcEventStream<ConnectionEvent, std::shared_ptr<AclConnection>> connection_complete_stream_{
      &connection_complete_stream_callback_};

  class ConnectionFailedStreamCallback
@@ -226,7 +227,7 @@ class AclManagerFacadeService : public AclManagerFacade::Service, public ::bluet
    ~AclStreamCallback() {
      if (subscribed_) {
        for (const auto& connection : service_->acl_connections_) {
          connection.second.GetAclQueueEnd()->UnregisterDequeue();
          connection.second->GetAclQueueEnd()->UnregisterDequeue();
        }
        subscribed_ = false;
      }
@@ -238,8 +239,8 @@ class AclManagerFacadeService : public AclManagerFacade::Service, public ::bluet
        return;
      }
      for (const auto& connection : service_->acl_connections_) {
        auto remote_address = connection.second.GetAddress().ToString();
        connection.second.GetAclQueueEnd()->RegisterDequeue(
        auto remote_address = connection.second->GetAddress().ToString();
        connection.second->GetAclQueueEnd()->RegisterDequeue(
            service_->facade_handler_,
            common::Bind(&AclManagerFacadeService::on_incoming_acl, common::Unretained(service_), remote_address));
      }
@@ -252,7 +253,7 @@ class AclManagerFacadeService : public AclManagerFacade::Service, public ::bluet
        return;
      }
      for (const auto& connection : service_->acl_connections_) {
        connection.second.GetAclQueueEnd()->UnregisterDequeue();
        connection.second->GetAclQueueEnd()->UnregisterDequeue();
      }
      subscribed_ = false;
    }
@@ -267,7 +268,7 @@ class AclManagerFacadeService : public AclManagerFacade::Service, public ::bluet
  } acl_stream_callback_{this};
  ::bluetooth::grpc::GrpcEventStream<AclData, AclData> acl_stream_{&acl_stream_callback_};

  std::map<std::string, AclConnection> acl_connections_;
  std::map<std::string, std::shared_ptr<AclConnection>> acl_connections_;
};

void AclManagerFacadeModule::ListDependencies(ModuleList* list) {