Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 3af3ce54 authored by Zach Johnson's avatar Zach Johnson
Browse files

Rework RegisterCompletedAclPacketsCallback to use ContextualCallback

Also, clean up binding in hci/controller.cc

Test: cert/run --host
Test: atest bluetooth_test_gd
Bug: 156859507
Tag: #gd-refactor
Change-Id: I068b2568f1c9702e0c69c6ea356ccba09c2298a1
parent 50933b2c
Loading
Loading
Loading
Loading
+2 −3
Original line number Diff line number Diff line
@@ -31,8 +31,7 @@ RoundRobinScheduler::RoundRobinScheduler(os::Handler* handler, Controller* contr
  le_max_acl_packet_credits_ = le_buffer_size.total_num_le_packets_;
  le_acl_packet_credits_ = le_max_acl_packet_credits_;
  le_hci_mtu_ = le_buffer_size.le_data_packet_length_;
  controller_->RegisterCompletedAclPacketsCallback(
      common::Bind(&RoundRobinScheduler::incoming_acl_credits, common::Unretained(this)), handler_);
  controller_->RegisterCompletedAclPacketsCallback(handler->BindOn(this, &RoundRobinScheduler::incoming_acl_credits));
}

RoundRobinScheduler::~RoundRobinScheduler() {
+17 −18
Original line number Diff line number Diff line
@@ -53,23 +53,15 @@ class TestController : public Controller {
    return le_buffer_size;
  }

  void RegisterCompletedAclPacketsCallback(common::Callback<void(uint16_t /* handle */, uint16_t /* num_packets */)> cb,
                                           os::Handler* handler) {
    acl_credits_handler_ = handler;
  void RegisterCompletedAclPacketsCallback(CompletedAclPacketsCallback cb) {
    acl_credits_callback_ = cb;
  }

  std::future<void> SendCompletedAclPacketsCallback(uint16_t handle, uint16_t credits) {
    auto promise = std::make_unique<std::promise<void>>();
    auto future = promise->get_future();
    acl_credits_handler_->Post(Bind(acl_credits_callback_, handle, credits));
    acl_credits_handler_->Post(common::BindOnce(
        [](std::unique_ptr<std::promise<void>> promise) mutable { promise->set_value(); }, std::move(promise)));
    return future;
  void SendCompletedAclPacketsCallback(uint16_t handle, uint16_t credits) {
    acl_credits_callback_.Invoke(handle, credits);
  }

  void UnregisterCompletedAclPacketsCallback() {
    acl_credits_handler_ = nullptr;
    acl_credits_callback_ = {};
  }

@@ -79,8 +71,7 @@ class TestController : public Controller {
  const uint16_t le_hci_mtu_ = 27;

 private:
  Handler* acl_credits_handler_;
  Callback<void(uint16_t, uint16_t)> acl_credits_callback_;
  CompletedAclPacketsCallback acl_credits_callback_;
};

class RoundRobinSchedulerTest : public ::testing::Test {
@@ -103,6 +94,14 @@ class RoundRobinSchedulerTest : public ::testing::Test {
    delete thread_;
  }

  void sync_handler() {
    std::promise<void> promise;
    auto future = promise.get_future();
    handler_->BindOnceOn(&promise, &std::promise<void>::set_value).Invoke();
    auto status = future.wait_for(std::chrono::milliseconds(3));
    EXPECT_EQ(status, std::future_status::ready);
  }

  void EnqueueAclUpEnd(AclConnection::QueueUpEnd* queue_up_end, std::vector<uint8_t> packet) {
    if (enqueue_promise_ != nullptr) {
      enqueue_future_->wait();
@@ -248,8 +247,8 @@ TEST_F(RoundRobinSchedulerTest, do_not_register_when_credits_is_zero) {
  ASSERT_EQ(round_robin_scheduler_->GetCredits(), 0);

  SetPacketFuture(5);
  auto future = controller_->SendCompletedAclPacketsCallback(0x01, 10);
  future.wait();
  controller_->SendCompletedAclPacketsCallback(0x01, 10);
  sync_handler();
  packet_future_->wait();
  for (uint8_t i = 10; i < 15; i++) {
    std::vector<uint8_t> packet = {0x01, 0x02, 0x03, i};
@@ -261,8 +260,8 @@ TEST_F(RoundRobinSchedulerTest, do_not_register_when_credits_is_zero) {
}

TEST_F(RoundRobinSchedulerTest, reveived_completed_callback_with_unknown_handle) {
  auto future = controller_->SendCompletedAclPacketsCallback(0x00, 1);
  future.wait();
  controller_->SendCompletedAclPacketsCallback(0x00, 1);
  sync_handler();
  EXPECT_EQ(round_robin_scheduler_->GetCredits(), controller_->max_acl_packet_credits_);
  EXPECT_EQ(round_robin_scheduler_->GetLeCredits(), controller_->le_max_acl_packet_credits_);
}
+4 −7
Original line number Diff line number Diff line
@@ -70,15 +70,13 @@ std::unique_ptr<AclPacketBuilder> NextAclPacket(uint16_t handle) {

class TestController : public Controller {
 public:
  void RegisterCompletedAclPacketsCallback(common::Callback<void(uint16_t /* handle */, uint16_t /* packets */)> cb,
                                           os::Handler* handler) override {
  void RegisterCompletedAclPacketsCallback(
      common::ContextualCallback<void(uint16_t /* handle */, uint16_t /* packets */)> cb) override {
    acl_cb_ = cb;
    acl_cb_handler_ = handler;
  }

  void UnregisterCompletedAclPacketsCallback() override {
    acl_cb_ = {};
    acl_cb_handler_ = nullptr;
  }

  uint16_t GetControllerAclPacketLength() const override {
@@ -101,14 +99,13 @@ class TestController : public Controller {
  }

  void CompletePackets(uint16_t handle, uint16_t packets) {
    acl_cb_handler_->Post(common::BindOnce(acl_cb_, handle, packets));
    acl_cb_.Invoke(handle, packets);
  }

  uint16_t acl_buffer_length_ = 1024;
  uint16_t total_acl_buffers_ = 2;
  uint64_t le_local_supported_features_ = 0;
  common::Callback<void(uint16_t /* handle */, uint16_t /* packets */)> acl_cb_;
  os::Handler* acl_cb_handler_ = nullptr;
  common::ContextualCallback<void(uint16_t /* handle */, uint16_t /* packets */)> acl_cb_;

 protected:
  void Start() override {}
+28 −48
Original line number Diff line number Diff line
@@ -20,19 +20,11 @@
#include <memory>
#include <utility>

#include "common/bind.h"
#include "common/callback.h"
#include "hci/hci_layer.h"

namespace bluetooth {
namespace hci {

using common::Bind;
using common::BindOnce;
using common::Callback;
using common::Closure;
using common::OnceCallback;
using common::OnceClosure;
using os::Handler;

struct Controller::impl {
@@ -107,7 +99,7 @@ struct Controller::impl {
  }

  void NumberOfCompletedPackets(EventPacketView event) {
    if (acl_credits_handler_ == nullptr) {
    if (acl_credits_callback_.IsEmpty()) {
      LOG_WARN("Received event when AclManager is not listening");
      return;
    }
@@ -116,32 +108,18 @@ struct Controller::impl {
    for (auto completed_packets : complete_view.GetCompletedPackets()) {
      uint16_t handle = completed_packets.connection_handle_;
      uint16_t credits = completed_packets.host_num_of_completed_packets_;
      acl_credits_handler_->Post(Bind(acl_credits_callback_, handle, credits));
      acl_credits_callback_.Invoke(handle, credits);
    }
  }

  void RegisterCompletedAclPacketsCallback(Callback<void(uint16_t /* handle */, uint16_t /* packets */)> cb,
                                           Handler* handler) {
    module_.GetHandler()->Post(common::BindOnce(&impl::register_completed_acl_packets_callback,
                                                common::Unretained(this), cb, common::Unretained(handler)));
  }

  void register_completed_acl_packets_callback(Callback<void(uint16_t /* handle */, uint16_t /* packets */)> cb,
                                               Handler* handler) {
    ASSERT(acl_credits_handler_ == nullptr);
    acl_credits_callback_ = cb;
    acl_credits_handler_ = handler;
  }

  void UnregisterCompletedAclPacketsCallback() {
    module_.GetHandler()->Post(
        common::BindOnce(&impl::unregister_completed_acl_packets_callback, common::Unretained(this)));
  void register_completed_acl_packets_callback(CompletedAclPacketsCallback callback) {
    ASSERT(acl_credits_callback_.IsEmpty());
    acl_credits_callback_ = callback;
  }

  void unregister_completed_acl_packets_callback() {
    ASSERT(acl_credits_handler_ != nullptr);
    ASSERT(!acl_credits_callback_.IsEmpty());
    acl_credits_callback_ = {};
    acl_credits_handler_ = nullptr;
  }

  void read_local_name_complete_handler(CommandCompleteView view) {
@@ -702,8 +680,7 @@ struct Controller::impl {

  HciLayer* hci_;

  Callback<void(uint16_t, uint16_t)> acl_credits_callback_;
  Handler* acl_credits_handler_ = nullptr;
  CompletedAclPacketsCallback acl_credits_callback_{};
  LocalVersionInformation local_version_information_;
  std::array<uint8_t, 64> local_supported_commands_;
  uint64_t local_supported_features_;
@@ -728,13 +705,12 @@ Controller::Controller() : impl_(std::make_unique<impl>(*this)) {}

Controller::~Controller() = default;

void Controller::RegisterCompletedAclPacketsCallback(Callback<void(uint16_t /* handle */, uint16_t /* packets */)> cb,
                                                     Handler* handler) {
  impl_->RegisterCompletedAclPacketsCallback(cb, handler);  // TODO hsz: why here?
void Controller::RegisterCompletedAclPacketsCallback(CompletedAclPacketsCallback cb) {
  CallOn(impl_.get(), &impl::register_completed_acl_packets_callback, cb);
}

void Controller::UnregisterCompletedAclPacketsCallback() {
  impl_->UnregisterCompletedAclPacketsCallback();  // TODO hsz: why here?
  CallOn(impl_.get(), &impl::unregister_completed_acl_packets_callback);
}

std::string Controller::GetControllerLocalName() const {
@@ -785,41 +761,41 @@ Address Controller::GetControllerMacAddress() const {
}

void Controller::SetEventMask(uint64_t event_mask) {
  GetHandler()->Post(common::BindOnce(&impl::set_event_mask, common::Unretained(impl_.get()), event_mask));
  CallOn(impl_.get(), &impl::set_event_mask, event_mask);
}

void Controller::Reset() {
  GetHandler()->Post(common::BindOnce(&impl::reset, common::Unretained(impl_.get())));
  CallOn(impl_.get(), &impl::reset);
}

void Controller::SetEventFilterClearAll() {
  std::unique_ptr<SetEventFilterClearAllBuilder> packet = SetEventFilterClearAllBuilder::Create();
  GetHandler()->Post(common::BindOnce(&impl::set_event_filter, common::Unretained(impl_.get()), std::move(packet)));
  CallOn(impl_.get(), &impl::set_event_filter, std::move(packet));
}

void Controller::SetEventFilterInquiryResultAllDevices() {
  std::unique_ptr<SetEventFilterInquiryResultAllDevicesBuilder> packet =
      SetEventFilterInquiryResultAllDevicesBuilder::Create();
  GetHandler()->Post(common::BindOnce(&impl::set_event_filter, common::Unretained(impl_.get()), std::move(packet)));
  CallOn(impl_.get(), &impl::set_event_filter, std::move(packet));
}

void Controller::SetEventFilterInquiryResultClassOfDevice(ClassOfDevice class_of_device,
                                                          ClassOfDevice class_of_device_mask) {
  std::unique_ptr<SetEventFilterInquiryResultClassOfDeviceBuilder> packet =
      SetEventFilterInquiryResultClassOfDeviceBuilder::Create(class_of_device, class_of_device_mask);
  GetHandler()->Post(common::BindOnce(&impl::set_event_filter, common::Unretained(impl_.get()), std::move(packet)));
  CallOn(impl_.get(), &impl::set_event_filter, std::move(packet));
}

void Controller::SetEventFilterInquiryResultAddress(Address address) {
  std::unique_ptr<SetEventFilterInquiryResultAddressBuilder> packet =
      SetEventFilterInquiryResultAddressBuilder::Create(address);
  GetHandler()->Post(common::BindOnce(&impl::set_event_filter, common::Unretained(impl_.get()), std::move(packet)));
  CallOn(impl_.get(), &impl::set_event_filter, std::move(packet));
}

void Controller::SetEventFilterConnectionSetupAllDevices(AutoAcceptFlag auto_accept_flag) {
  std::unique_ptr<SetEventFilterConnectionSetupAllDevicesBuilder> packet =
      SetEventFilterConnectionSetupAllDevicesBuilder::Create(auto_accept_flag);
  GetHandler()->Post(common::BindOnce(&impl::set_event_filter, common::Unretained(impl_.get()), std::move(packet)));
  CallOn(impl_.get(), &impl::set_event_filter, std::move(packet));
}

void Controller::SetEventFilterConnectionSetupClassOfDevice(ClassOfDevice class_of_device,
@@ -828,30 +804,34 @@ void Controller::SetEventFilterConnectionSetupClassOfDevice(ClassOfDevice class_
  std::unique_ptr<SetEventFilterConnectionSetupClassOfDeviceBuilder> packet =
      SetEventFilterConnectionSetupClassOfDeviceBuilder::Create(class_of_device, class_of_device_mask,
                                                                auto_accept_flag);
  GetHandler()->Post(common::BindOnce(&impl::set_event_filter, common::Unretained(impl_.get()), std::move(packet)));
  CallOn(impl_.get(), &impl::set_event_filter, std::move(packet));
}

void Controller::SetEventFilterConnectionSetupAddress(Address address, AutoAcceptFlag auto_accept_flag) {
  std::unique_ptr<SetEventFilterConnectionSetupAddressBuilder> packet =
      SetEventFilterConnectionSetupAddressBuilder::Create(address, auto_accept_flag);
  GetHandler()->Post(common::BindOnce(&impl::set_event_filter, common::Unretained(impl_.get()), std::move(packet)));
  CallOn(impl_.get(), &impl::set_event_filter, std::move(packet));
}

void Controller::WriteLocalName(std::string local_name) {
  impl_->local_name_ = local_name;
  GetHandler()->Post(common::BindOnce(&impl::write_local_name, common::Unretained(impl_.get()), local_name));
  CallOn(impl_.get(), &impl::write_local_name, local_name);
}

void Controller::HostBufferSize(uint16_t host_acl_data_packet_length, uint8_t host_synchronous_data_packet_length,
                                uint16_t host_total_num_acl_data_packets,
                                uint16_t host_total_num_synchronous_data_packets) {
  GetHandler()->Post(common::BindOnce(&impl::host_buffer_size, common::Unretained(impl_.get()),
                                      host_acl_data_packet_length, host_synchronous_data_packet_length,
                                      host_total_num_acl_data_packets, host_total_num_synchronous_data_packets));
  CallOn(
      impl_.get(),
      &impl::host_buffer_size,
      host_acl_data_packet_length,
      host_synchronous_data_packet_length,
      host_total_num_acl_data_packets,
      host_total_num_synchronous_data_packets);
}

void Controller::LeSetEventMask(uint64_t le_event_mask) {
  GetHandler()->Post(common::BindOnce(&impl::le_set_event_mask, common::Unretained(impl_.get()), le_event_mask));
  CallOn(impl_.get(), &impl::le_set_event_mask, le_event_mask);
}

LeBufferSize Controller::GetControllerLeBufferSize() const {
+4 −3
Original line number Diff line number Diff line
@@ -16,7 +16,7 @@

#pragma once

#include "common/callback.h"
#include "common/contextual_callback.h"
#include "hci/address.h"
#include "hci/hci_packets.h"
#include "module.h"
@@ -31,8 +31,9 @@ class Controller : public Module {
  virtual ~Controller();
  DISALLOW_COPY_AND_ASSIGN(Controller);

  virtual void RegisterCompletedAclPacketsCallback(
      common::Callback<void(uint16_t /* handle */, uint16_t /* num_packets */)> cb, os::Handler* handler);
  using CompletedAclPacketsCallback =
      common::ContextualCallback<void(uint16_t /* handle */, uint16_t /* num_packets */)>;
  virtual void RegisterCompletedAclPacketsCallback(CompletedAclPacketsCallback cb);

  virtual void UnregisterCompletedAclPacketsCallback();

Loading