Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 05edb428 authored by Treehugger Robot's avatar Treehugger Robot Committed by Gerrit Code Review
Browse files

Merge changes I7528d0c6,I995352c8

* changes:
  Facade: Separate ACL streams per handle
  Acl Cert: Don't unregister on disconnect
parents e93387c3 43ce5650
Loading
Loading
Loading
Loading
+13 −25
Original line number Diff line number Diff line
@@ -27,66 +27,54 @@ from hci.facade import acl_manager_facade_pb2 as acl_manager_facade

class PyAclManagerAclConnection(IEventStream, Closable):

    def __init__(self, device, acl_stream, remote_addr, handle, event_stream):
        self.device = device
    def __init__(self, acl_manager, remote_addr, handle, event_stream):
        self.acl_manager = acl_manager
        self.handle = handle
        # todo enable filtering after sorting out handles
        # self.our_acl_stream = FilteringEventStream(acl_stream, None)
        self.our_acl_stream = acl_stream
        self.remote_addr = remote_addr
        self.connection_event_stream = event_stream
        self.acl_stream = EventStream(self.acl_manager.FetchAclData(acl_manager_facade.HandleMsg(handle=self.handle)))

    def disconnect(self, reason):
        packet_bytes = bytes(hci_packets.DisconnectBuilder(self.handle, reason).Serialize())
        self.device.hci_acl_manager.ConnectionCommand(acl_manager_facade.ConnectionCommandMsg(packet=packet_bytes))
        self.acl_manager.ConnectionCommand(acl_manager_facade.ConnectionCommandMsg(packet=packet_bytes))

    def close(self):
        safeClose(self.connection_event_stream)

    def wait_for_connection_complete(self):
        connection_complete = HciCaptures.ConnectionCompleteCapture()
        assertThat(self.connection_event_stream).emits(connection_complete)
        self.handle = connection_complete.get().GetConnectionHandle()
        safeClose(self.acl_stream)

    def wait_for_disconnection_complete(self):
        disconnection_complete = HciCaptures.DisconnectionCompleteCapture()
        assertThat(self.connection_event_stream).isNotNone()
        assertThat(self.connection_event_stream).emits(disconnection_complete)
        self.disconnect_reason = disconnection_complete.get().GetReason()

    def send(self, data):
        self.device.hci_acl_manager.SendAclData(acl_manager_facade.AclData(handle=self.handle, payload=bytes(data)))
        self.acl_manager.SendAclData(acl_manager_facade.AclData(handle=self.handle, payload=bytes(data)))

    def get_event_queue(self):
        return self.our_acl_stream.get_event_queue()
        return self.acl_stream.get_event_queue()


class PyAclManager(Closable):
class PyAclManager:

    def __init__(self, device):
        self.device = device
        self.acl_stream = EventStream(self.device.hci_acl_manager.FetchAclData(empty_proto.Empty()))
        self.acl_manager = device.hci_acl_manager
        self.incoming_connection_event_stream = None
        self.outgoing_connection_event_stream = None

    def close(self):
        safeClose(self.acl_stream)
        safeClose(self.incoming_connection_event_stream)
        safeClose(self.outgoing_connection_event_stream)

    # temporary, until everyone is migrated
    def get_acl_stream(self):
        return self.acl_stream

    def listen_for_an_incoming_connection(self):
        assertThat(self.incoming_connection_event_stream).isNone()
        self.incoming_connection_event_stream = EventStream(
            self.device.hci_acl_manager.FetchIncomingConnection(empty_proto.Empty()))
            self.acl_manager.FetchIncomingConnection(empty_proto.Empty()))

    def initiate_connection(self, remote_addr):
        assertThat(self.outgoing_connection_event_stream).isNone()
        remote_addr_bytes = bytes(remote_addr, 'utf8') if type(remote_addr) is str else bytes(remote_addr)
        self.outgoing_connection_event_stream = EventStream(
            self.device.hci_acl_manager.CreateConnection(acl_manager_facade.ConnectionMsg(address=remote_addr_bytes)))
            self.acl_manager.CreateConnection(acl_manager_facade.ConnectionMsg(address=remote_addr_bytes)))

    def complete_connection(self, event_stream):
        connection_complete = HciCaptures.ConnectionCompleteCapture()
@@ -94,7 +82,7 @@ class PyAclManager(Closable):
        complete = connection_complete.get()
        handle = complete.GetConnectionHandle()
        address = complete.GetBdAddr()
        return PyAclManagerAclConnection(self.device, self.acl_stream, address, handle, event_stream)
        return PyAclManagerAclConnection(self.acl_manager, address, handle, event_stream)

    def complete_incoming_connection(self):
        assertThat(self.incoming_connection_event_stream).isNotNone()
+14 −17
Original line number Diff line number Diff line
@@ -45,7 +45,6 @@ class AclManagerTest(GdBaseTestClass):

    def teardown_test(self):
        self.cert_hci.close()
        self.dut_acl_manager.close()
        super().teardown_test()

    def test_dut_connects(self):
@@ -67,8 +66,7 @@ class AclManagerTest(GdBaseTestClass):

        self.dut_acl_manager.listen_for_an_incoming_connection()
        self.cert_hci.initiate_connection(dut_address)
        dut_acl = self.dut_acl_manager.complete_incoming_connection()

        with self.dut_acl_manager.complete_incoming_connection() as dut_acl:
            cert_acl = self.cert_hci.complete_connection()

            dut_acl.send(b'\x29\x00\x07\x00This is just SomeMoreAclData from the DUT')
@@ -84,8 +82,7 @@ class AclManagerTest(GdBaseTestClass):

        self.dut_acl_manager.listen_for_an_incoming_connection()
        self.cert_hci.initiate_connection(dut_address)
        dut_acl = self.dut_acl_manager.complete_incoming_connection()

        with self.dut_acl_manager.complete_incoming_connection() as dut_acl:
            cert_acl = self.cert_hci.complete_connection()

            dut_acl.send(b'\x29\x00\x07\x00This is just SomeMoreAclData from the DUT')
+25 −30
Original line number Diff line number Diff line
@@ -51,7 +51,7 @@ class AclManagerFacadeService : public AclManagerFacade::Service, public Connect

  ~AclManagerFacadeService() override {
    std::unique_lock<std::mutex> lock(acl_connections_mutex_);
    for (auto connection : acl_connections_) {
    for (auto& connection : acl_connections_) {
      connection.second.connection_->GetAclQueueEnd()->UnregisterDequeue();
    }
  }
@@ -276,9 +276,8 @@ class AclManagerFacadeService : public AclManagerFacade::Service, public Connect
      std::unique_lock<std::mutex> lock(acl_connections_mutex_);
      auto connection = acl_connections_.find(request->handle());
      if (connection == acl_connections_.end()) {
        LOG_ERROR("Invalid handle");
        return ::grpc::Status(::grpc::StatusCode::INVALID_ARGUMENT, "Invalid handle");
      } else {
      }
      // TODO: This is unsafe because connection may have gone
      connection->second.connection_->GetAclQueueEnd()->RegisterEnqueue(
          facade_handler_,
@@ -287,7 +286,6 @@ class AclManagerFacadeService : public AclManagerFacade::Service, public Connect
              common::Unretained(this),
              common::Unretained(request),
              common::Passed(std::move(promise))));
      }
      auto status = future.wait_for(std::chrono::milliseconds(1000));
      if (status != std::future_status::ready) {
        return ::grpc::Status(::grpc::StatusCode::RESOURCE_EXHAUSTED, "Can't send packet");
@@ -307,10 +305,12 @@ class AclManagerFacadeService : public AclManagerFacade::Service, public Connect
  }

  ::grpc::Status FetchAclData(
      ::grpc::ServerContext* context,
      const ::google::protobuf::Empty* request,
      ::grpc::ServerWriter<AclData>* writer) override {
    return pending_acl_data_.RunLoop(context, writer);
      ::grpc::ServerContext* context, const HandleMsg* request, ::grpc::ServerWriter<AclData>* writer) override {
    auto connection = acl_connections_.find(request->handle());
    if (connection == acl_connections_.end()) {
      return ::grpc::Status(::grpc::StatusCode::INVALID_ARGUMENT, "Invalid handle");
    }
    return connection->second.pending_acl_data_.RunLoop(context, writer);
  }

  static inline uint16_t to_handle(uint32_t current_request) {
@@ -326,26 +326,29 @@ class AclManagerFacadeService : public AclManagerFacade::Service, public Connect

  void on_incoming_acl(std::shared_ptr<ClassicAclConnection> connection, uint16_t handle) {
    auto packet = connection->GetAclQueueEnd()->TryDequeue();
    auto connection_tracker = acl_connections_.find(handle);
    ASSERT_LOG(connection_tracker != acl_connections_.end(), "handle %d", handle);
    AclData acl_data;
    acl_data.set_handle(handle);
    acl_data.set_payload(std::string(packet->begin(), packet->end()));
    pending_acl_data_.OnIncomingEvent(acl_data);
    connection_tracker->second.pending_acl_data_.OnIncomingEvent(acl_data);
  }

  void OnConnectSuccess(std::unique_ptr<ClassicAclConnection> connection) override {
    std::unique_lock<std::mutex> lock(acl_connections_mutex_);
    auto addr = connection->GetAddress();
    std::shared_ptr<ClassicAclConnection> shared_connection = std::move(connection);
    uint16_t handle = to_handle(current_connection_request_);
    acl_connections_.emplace(std::pair(
        handle,
        Connection(handle, shared_connection, per_connection_events_[current_connection_request_], facade_handler_)));
    auto remote_address = shared_connection->GetAddress().ToString();
    acl_connections_.emplace(
        std::piecewise_construct,
        std::forward_as_tuple(handle),
        std::forward_as_tuple(
            handle, shared_connection, per_connection_events_[current_connection_request_], facade_handler_));
    shared_connection->GetAclQueueEnd()->RegisterDequeue(
        facade_handler_,
        common::Bind(&AclManagerFacadeService::on_incoming_acl, common::Unretained(this), shared_connection, handle));
    auto callbacks = acl_connections_.find(handle)->second.GetCallbacks();
    shared_connection->RegisterCallbacks(callbacks, facade_handler_);
    auto addr = shared_connection->GetAddress();
    std::unique_ptr<BasePacketBuilder> builder =
        ConnectionCompleteBuilder::Create(ErrorCode::SUCCESS, handle, addr, LinkType::ACL, Enable::DISABLED);
    ConnectionEvent success;
@@ -488,16 +491,8 @@ class AclManagerFacadeService : public AclManagerFacade::Service, public Connect
      LOG_DEBUG("OnReadClockComplete clock:%d, accuracy:%d", clock, accuracy);
    }

    void on_incoming_acl() {
      auto packet = connection_->GetAclQueueEnd()->TryDequeue();
      LOG_INFO("Discarding packet of length %zu after disconnect", packet->size());
    }

    void OnDisconnection(ErrorCode reason) override {
      LOG_DEBUG("OnDisconnection reason: %s", ErrorCodeText(reason).c_str());
      connection_->GetAclQueueEnd()->UnregisterDequeue();
      connection_->GetAclQueueEnd()->RegisterDequeue(
          facade_handler_, common::Bind(&Connection::on_incoming_acl, common::Unretained(this)));
      std::unique_ptr<BasePacketBuilder> builder =
          DisconnectionCompleteBuilder::Create(ErrorCode::SUCCESS, handle_, reason);
      ConnectionEvent disconnection;
@@ -507,6 +502,7 @@ class AclManagerFacadeService : public AclManagerFacade::Service, public Connect
    uint16_t handle_;
    std::shared_ptr<ClassicAclConnection> connection_;
    std::shared_ptr<::bluetooth::grpc::GrpcEventQueue<ConnectionEvent>> event_stream_;
    ::bluetooth::grpc::GrpcEventQueue<AclData> pending_acl_data_{"FetchAclData"};
    ::bluetooth::os::Handler* facade_handler_;
  };

@@ -515,7 +511,6 @@ class AclManagerFacadeService : public AclManagerFacade::Service, public Connect
  ::bluetooth::os::Handler* facade_handler_;
  mutable std::mutex acl_connections_mutex_;
  std::map<uint16_t, Connection> acl_connections_;
  ::bluetooth::grpc::GrpcEventQueue<AclData> pending_acl_data_{"FetchAclData"};
  std::vector<std::shared_ptr<::bluetooth::grpc::GrpcEventQueue<ConnectionEvent>>> per_connection_events_;
  uint32_t current_connection_request_{0};
};
+1 −2
Original line number Diff line number Diff line
@@ -13,8 +13,7 @@ service AclManagerFacade {
  rpc ConnectionCommand(ConnectionCommandMsg) returns (google.protobuf.Empty) {}
  rpc SwitchRole(RoleMsg) returns (google.protobuf.Empty) {}
  rpc SendAclData(AclData) returns (google.protobuf.Empty) {}
  // TODO: Take a HandleMsg to get AclData
  rpc FetchAclData(google.protobuf.Empty) returns (stream AclData) {}
  rpc FetchAclData(HandleMsg) returns (stream AclData) {}
  rpc FetchIncomingConnection(google.protobuf.Empty) returns (stream ConnectionEvent) {}
}

+5 −9
Original line number Diff line number Diff line
@@ -192,8 +192,8 @@ class CertL2cap(Closable, IHasBehaviors):
        self._acl_manager.initiate_connection(remote_addr)
        self._acl = self._acl_manager.complete_outgoing_connection()
        self.control_channel = CertL2capChannel(
            self._device, 1, 1, self._get_acl_stream(), self._acl, control_channel=None)
        self._get_acl_stream().register_callback(self._handle_control_packet)
            self._device, 1, 1, self._acl.acl_stream, self._acl, control_channel=None)
        self._acl.acl_stream.register_callback(self._handle_control_packet)

    def open_channel(self, signal_id, psm, scid, fcs=None):
        self.control_channel.send(l2cap_packets.ConnectionRequestBuilder(signal_id, psm, scid))
@@ -201,7 +201,7 @@ class CertL2cap(Closable, IHasBehaviors):
        response = L2capCaptures.ConnectionResponse(scid)
        assertThat(self.control_channel).emits(response)
        channel = CertL2capChannel(self._device, scid,
                                   response.get().GetDestinationCid(), self._get_acl_stream(), self._acl,
                                   response.get().GetDestinationCid(), self._acl.acl_stream, self._acl,
                                   self.control_channel, fcs)
        self.scid_to_channel[scid] = channel

@@ -216,8 +216,7 @@ class CertL2cap(Closable, IHasBehaviors):
        dcid = request.get().GetSourceCid()
        if scid is None or scid in self.scid_to_channel:
            scid = dcid
        channel = CertL2capChannel(self._device, scid, dcid, self._get_acl_stream(), self._acl, self.control_channel,
                                   fcs)
        channel = CertL2capChannel(self._device, scid, dcid, self._acl.acl_stream, self._acl, self.control_channel, fcs)
        self.scid_to_channel[scid] = channel

        connection_response = l2cap_packets.ConnectionResponseBuilder(
@@ -237,7 +236,7 @@ class CertL2cap(Closable, IHasBehaviors):
        sid = request.get().GetIdentifier()
        dcid = request.get().GetSourceCid()
        scid = dcid
        channel = CertL2capChannel(self._device, scid, dcid, self._get_acl_stream(), self._acl, self.control_channel)
        channel = CertL2capChannel(self._device, scid, dcid, self._acl.acl_stream, self._acl, self.control_channel)
        self.scid_to_channel[scid] = channel

        # Connection response and config request combo packet
@@ -256,9 +255,6 @@ class CertL2cap(Closable, IHasBehaviors):
    def get_control_channel(self):
        return self.control_channel

    def _get_acl_stream(self):
        return self._acl_manager.get_acl_stream()

    # Disable ERTM when exchange extened feature
    def claim_ertm_unsupported(self):
        self.support_ertm = False