Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 52a8b895 authored by Treehugger Robot's avatar Treehugger Robot Committed by Gerrit Code Review
Browse files

Merge "ACL packet recombination"

parents b74d8d5c 694027e9
Loading
Loading
Loading
Loading
+2 −2
Original line number Diff line number Diff line
@@ -42,7 +42,7 @@ constexpr uint8_t kH4HeaderSize = 1;
constexpr uint8_t kHciAclHeaderSize = 4;
constexpr uint8_t kHciScoHeaderSize = 3;
constexpr uint8_t kHciEvtHeaderSize = 2;
constexpr int kBufSize = 1024;
constexpr int kBufSize = 1024 + 4 + 1;  // DeviceProperties::acl_data_packet_size_ + ACL header + H4 header

int ConnectToRootCanal(const std::string& server, int port) {
  int socket_fd = socket(AF_INET, SOCK_STREAM, 0);
@@ -238,7 +238,7 @@ class HciHalHostRootcanal : public HciHal {
      ASSERT_LOG(received_size != -1, "Can't receive from socket: %s", strerror(errno));
      ASSERT_LOG(received_size == kHciAclHeaderSize, "malformed ACL header received");

      uint16_t hci_acl_data_total_length = buf[4] * 256 + buf[3];
      uint16_t hci_acl_data_total_length = (buf[4] << 8) + buf[3];
      int payload_size;
      RUN_NO_INTR(payload_size = recv(sock_fd_, buf + kH4HeaderSize + kHciAclHeaderSize, hci_acl_data_total_length, 0));
      ASSERT_LOG(payload_size != -1, "Can't receive from socket: %s", strerror(errno));
+97 −32
Original line number Diff line number Diff line
@@ -36,10 +36,36 @@ constexpr size_t kMaxQueuedPacketsPerConnection = 10;
using common::Bind;
using common::BindOnce;

namespace {
class PacketViewForRecombination : public packet::PacketView<kLittleEndian> {
 public:
  PacketViewForRecombination(const PacketView& packetView) : PacketView(packetView) {}
  void AppendPacketView(packet::PacketView<kLittleEndian> to_append) {
    Append(to_append);
  }
};

constexpr int kL2capBasicFrameHeaderSize = 4;

// Per spec 5.1 Vol 2 Part B 5.3, ACL link shall carry L2CAP data. Therefore, an ACL packet shall contain L2CAP PDU.
// This function returns the PDU size of the L2CAP data if it's a starting packet. Returns 0 if it's invalid.
uint16_t GetL2capPduSize(AclPacketView packet) {
  auto l2cap_payload = packet.GetPayload();
  if (l2cap_payload.size() < kL2capBasicFrameHeaderSize) {
    LOG_ERROR("Controller sent an invalid L2CAP starting packet!");
    return 0;
  }
  return (l2cap_payload.at(1) << 8) + l2cap_payload.at(0);
}

}  // namespace

struct AclManager::acl_connection {
  acl_connection(AddressWithType address_with_type) : address_with_type_(address_with_type) {}
  acl_connection(AddressWithType address_with_type, os::Handler* handler)
      : address_with_type_(address_with_type), handler_(handler) {}
  friend AclConnection;
  AddressWithType address_with_type_;
  os::Handler* handler_;
  std::unique_ptr<AclConnection::Queue> queue_ = std::make_unique<AclConnection::Queue>(10);
  bool is_disconnected_ = false;
  ErrorCode disconnect_reason_;
@@ -54,9 +80,70 @@ struct AclManager::acl_connection {
  bool is_registered_ = false;
  // Credits: Track the number of packets which have been sent to the controller
  uint16_t number_of_sent_packets_ = 0;
  bool enqueue_registered_{false};
  PacketViewForRecombination recombination_stage_{std::make_shared<std::vector<uint8_t>>()};
  int remaining_sdu_continuation_packet_size_ = 0;
  bool enqueue_registered_ = false;
  std::queue<packet::PacketView<kLittleEndian>> incoming_queue_;

  std::unique_ptr<packet::PacketView<kLittleEndian>> on_incoming_data_ready() {
    auto packet = incoming_queue_.front();
    incoming_queue_.pop();
    if (incoming_queue_.empty()) {
      auto queue_end = queue_->GetDownEnd();
      queue_end->UnregisterEnqueue();
      enqueue_registered_ = false;
    }
    return std::make_unique<PacketView<kLittleEndian>>(packet);
  }

  void on_incoming_packet(AclPacketView packet) {
    // TODO: What happens if the connection is stalled and fills up?
    PacketView<kLittleEndian> payload = packet.GetPayload();
    auto payload_size = payload.size();
    auto packet_boundary_flag = packet.GetPacketBoundaryFlag();
    if (packet_boundary_flag == PacketBoundaryFlag::FIRST_NON_AUTOMATICALLY_FLUSHABLE) {
      LOG_ERROR("Controller is not allowed to send FIRST_NON_AUTOMATICALLY_FLUSHABLE to host except loopback mode");
      return;
    }
    if (packet_boundary_flag == PacketBoundaryFlag::CONTINUING_FRAGMENT) {
      if (remaining_sdu_continuation_packet_size_ < payload_size) {
        LOG_WARN("Remote sent unexpected L2CAP PDU. Drop the entire L2CAP PDU");
        recombination_stage_ = PacketViewForRecombination(std::make_shared<std::vector<uint8_t>>());
        remaining_sdu_continuation_packet_size_ = 0;
        return;
      }
      remaining_sdu_continuation_packet_size_ -= payload_size;
      recombination_stage_.AppendPacketView(payload);
      if (remaining_sdu_continuation_packet_size_ != 0) {
        return;
      } else {
        payload = recombination_stage_;
      }
    } else if (packet_boundary_flag == PacketBoundaryFlag::FIRST_AUTOMATICALLY_FLUSHABLE) {
      if (recombination_stage_.size() > 0) {
        LOG_ERROR("Controller sent a starting packet without finishing previous packet. Drop previous one.");
      }
      auto l2cap_pdu_size = GetL2capPduSize(packet);
      remaining_sdu_continuation_packet_size_ = l2cap_pdu_size - (payload_size - kL2capBasicFrameHeaderSize);
      if (remaining_sdu_continuation_packet_size_ > 0) {
        recombination_stage_ = payload;
        return;
      }
    }
    if (incoming_queue_.size() > kMaxQueuedPacketsPerConnection) {
      LOG_ERROR("Dropping packet due to congestion from remote:%s", address_with_type_.ToString().c_str());
      return;
    }

    incoming_queue_.push(payload);
    if (!enqueue_registered_) {
      enqueue_registered_ = true;
      auto queue_end = queue_->GetDownEnd();
      queue_end->RegisterEnqueue(
          handler_, common::Bind(&AclManager::acl_connection::on_incoming_data_ready, common::Unretained(this)));
    }
  }

  void call_disconnect_callback() {
    disconnect_handler_->Post(BindOnce(std::move(on_disconnect_callback_), disconnect_reason_));
  }
@@ -225,17 +312,6 @@ struct AclManager::impl {
    return std::unique_ptr<AclPacketBuilder>(raw_pointer);
  }

  std::unique_ptr<packet::PacketView<kLittleEndian>> OnIncomingReadReady(AclManager::acl_connection* connection) {
    auto packet = connection->incoming_queue_.front();
    connection->incoming_queue_.pop();
    if (connection->incoming_queue_.empty()) {
      auto queue_end = connection->queue_->GetDownEnd();
      queue_end->UnregisterEnqueue();
      connection->enqueue_registered_ = false;
    }
    return std::make_unique<PacketView<kLittleEndian>>(packet);
  }

  void dequeue_and_route_acl_packet_to_connection() {
    auto packet = hci_queue_end_->TryDequeue();
    ASSERT(packet != nullptr);
@@ -252,23 +328,8 @@ struct AclManager::impl {
      LOG_INFO("Dropping packet of size %zu to unknown connection 0x%0hx", packet->size(), handle);
      return;
    }
    // TODO: What happens if the connection is stalled and fills up?
    // TODO hsz: define enqueue callback
    auto queue_end = connection_pair->second.queue_->GetDownEnd();
    PacketView<kLittleEndian> payload = packet->GetPayload();

    if (connection_pair->second.incoming_queue_.size() > kMaxQueuedPacketsPerConnection) {
      LOG_INFO("Dropping packet due to congestion from remote:%s",
               connection_pair->second.address_with_type_.ToString().c_str());
      return;
    }

    connection_pair->second.incoming_queue_.push(payload);
    if (!connection_pair->second.enqueue_registered_) {
      connection_pair->second.enqueue_registered_ = true;
      queue_end->RegisterEnqueue(handler_, common::Bind(&AclManager::impl::OnIncomingReadReady,
                                                        common::Unretained(this), &connection_pair->second));
    }
    connection_pair->second.on_incoming_packet(*packet);
  }

  void on_incoming_connection(EventPacketView packet) {
@@ -328,7 +389,8 @@ struct AclManager::impl {
    // TODO: Check and save other connection parameters
    uint16_t handle = connection_complete.GetConnectionHandle();
    ASSERT(acl_connections_.count(handle) == 0);
    acl_connections_.emplace(handle, address_with_type);
    acl_connections_.emplace(std::piecewise_construct, std::forward_as_tuple(handle),
                             std::forward_as_tuple(address_with_type, handler_));
    if (acl_connections_.size() == 1 && fragments_to_send_.size() == 0) {
      start_round_robin();
    }
@@ -361,7 +423,8 @@ struct AclManager::impl {
    // TODO: Check and save other connection parameters
    uint16_t handle = connection_complete.GetConnectionHandle();
    ASSERT(acl_connections_.count(handle) == 0);
    acl_connections_.emplace(handle, reporting_address_with_type);
    acl_connections_.emplace(std::piecewise_construct, std::forward_as_tuple(handle),
                             std::forward_as_tuple(reporting_address_with_type, handler_));
    if (acl_connections_.size() == 1 && fragments_to_send_.size() == 0) {
      start_round_robin();
    }
@@ -386,7 +449,9 @@ struct AclManager::impl {
    }
    uint16_t handle = connection_complete.GetConnectionHandle();
    ASSERT(acl_connections_.count(handle) == 0);
    acl_connections_.emplace(handle, AddressWithType{address, AddressType::PUBLIC_DEVICE_ADDRESS});
    acl_connections_.emplace(
        std::piecewise_construct, std::forward_as_tuple(handle),
        std::forward_as_tuple(AddressWithType{address, AddressType::PUBLIC_DEVICE_ADDRESS}, handler_));
    if (acl_connections_.size() == 1 && fragments_to_send_.size() == 0) {
      start_round_robin();
    }
+2 −0
Original line number Diff line number Diff line
@@ -54,6 +54,8 @@ PacketView<kLittleEndian> GetPacketView(std::unique_ptr<packet::BasePacketBuilde
std::unique_ptr<BasePacketBuilder> NextPayload(uint16_t handle) {
  static uint32_t packet_number = 1;
  auto payload = std::make_unique<RawBuilder>();
  payload->AddOctets2(6);  // L2CAP PDU size
  payload->AddOctets2(2);  // L2CAP CID
  payload->AddOctets2(handle);
  payload->AddOctets4(packet_number++);
  return std::move(payload);
+110 −3
Original line number Diff line number Diff line
@@ -159,9 +159,11 @@ class AclManagerTest(GdFacadeOnlyBaseTestClass):

                self.enqueue_acl_data(
                    cert_handle, hci_packets.PacketBoundaryFlag.
                    FIRST_NON_AUTOMATICALLY_FLUSHABLE,
                    FIRST_AUTOMATICALLY_FLUSHABLE,
                    hci_packets.BroadcastFlag.POINT_TO_POINT,
                    bytes(b'This is just SomeAclData from the Cert'))
                    bytes(
                        b'\x26\x00\x07\x00This is just SomeAclData from the Cert'
                    ))

                # DUT gets a connection complete event and sends and receives
                connection_event_asserts = EventAsserts(connection_event_stream)
@@ -172,7 +174,8 @@ class AclManagerTest(GdFacadeOnlyBaseTestClass):
                    acl_manager_facade.AclData(
                        handle=handle,
                        payload=bytes(
                            b'This is just SomeMoreAclData from the DUT')))
                            b'\x29\x00\x07\x00This is just SomeMoreAclData from the DUT'
                        )))

                acl_data_asserts = EventAsserts(acl_data_stream)
                cert_acl_data_asserts = EventAsserts(cert_acl_data_stream)
@@ -180,3 +183,107 @@ class AclManagerTest(GdFacadeOnlyBaseTestClass):
                    lambda packet: b'SomeMoreAclData' in packet.data)
                acl_data_asserts.assert_event_occurs(
                    lambda packet: b'SomeAclData' in packet.payload)

    def test_recombination_l2cap_packet(self):
        self.register_for_event(hci_packets.EventCode.CONNECTION_REQUEST)
        self.register_for_event(hci_packets.EventCode.CONNECTION_COMPLETE)
        with EventCallbackStream(self.cert_device.hci.FetchEvents(empty_proto.Empty())) as cert_hci_event_stream, \
            EventCallbackStream(self.cert_device.hci.FetchAclPackets(empty_proto.Empty())) as cert_acl_data_stream, \
            EventCallbackStream(self.device_under_test.hci_acl_manager.FetchAclData(empty_proto.Empty())) as acl_data_stream:

            # CERT Enables scans and gets its address
            self.enqueue_hci_command(
                hci_packets.WriteScanEnableBuilder(
                    hci_packets.ScanEnable.INQUIRY_AND_PAGE_SCAN), True)

            cert_address = None

            def get_address_from_complete(packet):
                packet_bytes = packet.event
                if b'\x0e\x0a\x01\x09\x10' in packet_bytes:
                    nonlocal cert_address
                    addr_view = hci_packets.ReadBdAddrCompleteView(
                        hci_packets.CommandCompleteView(
                            hci_packets.EventPacketView(
                                bt_packets.PacketViewLittleEndian(
                                    list(packet_bytes)))))
                    cert_address = addr_view.GetBdAddr()
                    return True
                return False

            self.enqueue_hci_command(hci_packets.ReadBdAddrBuilder(), True)

            cert_hci_event_asserts = EventAsserts(cert_hci_event_stream)
            cert_hci_event_asserts.assert_event_occurs(
                get_address_from_complete)

            with EventCallbackStream(
                    self.device_under_test.hci_acl_manager.CreateConnection(
                        acl_manager_facade.ConnectionMsg(
                            address_type=int(
                                hci_packets.AddressType.PUBLIC_DEVICE_ADDRESS),
                            address=bytes(cert_address,
                                          'utf8')))) as connection_event_stream:
                connection_request = None

                def get_connect_request(packet):
                    if b'\x04\x0a' in packet.event:
                        nonlocal connection_request
                        connection_request = hci_packets.ConnectionRequestView(
                            hci_packets.EventPacketView(
                                bt_packets.PacketViewLittleEndian(
                                    list(packet.event))))
                        return True
                    return False

                # Cert Accepts
                cert_hci_event_asserts.assert_event_occurs(get_connect_request)
                self.enqueue_hci_command(
                    hci_packets.AcceptConnectionRequestBuilder(
                        connection_request.GetBdAddr(),
                        hci_packets.AcceptConnectionRequestRole.REMAIN_SLAVE),
                    False)

                # Cert gets ConnectionComplete with a handle and sends ACL data
                handle = 0xfff

                def get_handle(packet):
                    packet_bytes = packet.event
                    if b'\x03\x0b\x00' in packet_bytes:
                        nonlocal handle
                        cc_view = hci_packets.ConnectionCompleteView(
                            hci_packets.EventPacketView(
                                bt_packets.PacketViewLittleEndian(
                                    list(packet_bytes))))
                        handle = cc_view.GetConnectionHandle()
                        return True
                    return False

                cert_hci_event_asserts.assert_event_occurs(get_handle)
                cert_handle = handle

                acl_data_asserts = EventAsserts(acl_data_stream)

                self.enqueue_acl_data(
                    cert_handle, hci_packets.PacketBoundaryFlag.
                    FIRST_AUTOMATICALLY_FLUSHABLE,
                    hci_packets.BroadcastFlag.POINT_TO_POINT,
                    bytes(b'\x06\x00\x07\x00Hello'))
                self.enqueue_acl_data(
                    cert_handle,
                    hci_packets.PacketBoundaryFlag.CONTINUING_FRAGMENT,
                    hci_packets.BroadcastFlag.POINT_TO_POINT, bytes(b'!'))
                self.enqueue_acl_data(
                    cert_handle, hci_packets.PacketBoundaryFlag.
                    FIRST_AUTOMATICALLY_FLUSHABLE,
                    hci_packets.BroadcastFlag.POINT_TO_POINT,
                    bytes(b'\x88\x13\x07\x00' + b'Hello' * 1000))

                # DUT gets a connection complete event and sends and receives
                connection_event_asserts = EventAsserts(connection_event_stream)
                connection_event_asserts.assert_event_occurs(get_handle)

                acl_data_asserts.assert_event_occurs(
                    lambda packet: b'Hello!' in packet.payload)
                acl_data_asserts.assert_event_occurs(
                    lambda packet: b'Hello' * 1000 in packet.payload)
+21 −10
Original line number Diff line number Diff line
@@ -139,11 +139,9 @@ ErrorCode LinkLayerController::SendAclToRemote(
      std::make_unique<bluetooth::packet::RawBuilder>();
  std::vector<uint8_t> payload_bytes(acl_payload.begin(), acl_payload.end());

  constexpr auto pb_flag_controller_to_host =
      bluetooth::hci::PacketBoundaryFlag::FIRST_AUTOMATICALLY_FLUSHABLE;
  uint16_t first_two_bytes =
      static_cast<uint16_t>(acl_packet.GetHandle()) +
      (static_cast<uint16_t>(pb_flag_controller_to_host) << 12) +
      (static_cast<uint16_t>(acl_packet.GetPacketBoundaryFlag()) << 12) +
      (static_cast<uint16_t>(acl_packet.GetBroadcastFlag()) << 14);
  raw_builder_ptr->AddOctets2(first_two_bytes);
  raw_builder_ptr->AddOctets2(static_cast<uint16_t>(payload_bytes.size()));
@@ -303,15 +301,28 @@ void LinkLayerController::IncomingAclPacket(

  std::vector<uint8_t> payload_data(acl_view.GetPayload().begin(),
                                    acl_view.GetPayload().end());
  uint16_t acl_buffer_size = properties_.GetAclDataPacketSize();
  int num_packets =
      (payload_data.size() + acl_buffer_size - 1) / acl_buffer_size;

  auto pb_flag_controller_to_host = acl_view.GetPacketBoundaryFlag();
  for (int i = 0; i < num_packets; i++) {
    size_t start_index = acl_buffer_size * i;
    size_t end_index =
        std::min(start_index + acl_buffer_size, payload_data.size());
    std::vector<uint8_t> fragment(&payload_data[start_index],
                                  &payload_data[end_index]);
    std::unique_ptr<bluetooth::packet::RawBuilder> raw_builder_ptr =
      std::make_unique<bluetooth::packet::RawBuilder>(payload_data);

        std::make_unique<bluetooth::packet::RawBuilder>(fragment);
    auto acl_packet = bluetooth::hci::AclPacketBuilder::Create(
      local_handle, acl_view.GetPacketBoundaryFlag(),
      acl_view.GetBroadcastFlag(), std::move(raw_builder_ptr));
        local_handle, pb_flag_controller_to_host, acl_view.GetBroadcastFlag(),
        std::move(raw_builder_ptr));
    pb_flag_controller_to_host =
        bluetooth::hci::PacketBoundaryFlag::CONTINUING_FRAGMENT;

    send_acl_(std::move(acl_packet));
  }
}

void LinkLayerController::IncomingRemoteNameRequest(
    model::packets::LinkLayerPacketView packet) {