Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit c68b880e authored by Myles Watson's avatar Myles Watson
Browse files

AclManagerTest: Make connections symmetric

Incoming and outgoing connections should have their
own event_stream.

Bug: 145832107
Tag: #gd-refactor
Test: cert/run --host
Change-Id: Ibeb3cd3ec6f598f1fe80e71eb6844776525ef56c
parent 4ff425de
Loading
Loading
Loading
Loading
+18 −6
Original line number Diff line number Diff line
@@ -27,20 +27,26 @@ class HalCaptures(object):
    @staticmethod
    def ReadBdAddrCompleteCapture():
        return Capture(
            lambda packet: b'\x0e\x0a\x01\x09\x10' in packet.payload, lambda packet: hci_packets.ReadBdAddrCompleteView(
            lambda packet: packet.payload[0:5] == b'\x0e\x0a\x01\x09\x10', lambda packet: hci_packets.ReadBdAddrCompleteView(
                hci_packets.CommandCompleteView(
                    hci_packets.EventPacketView(bt_packets.PacketViewLittleEndian(list(packet.payload))))))

    @staticmethod
    def ConnectionRequestCapture():
        return Capture(
            lambda packet: b'\x04\x0a' in packet.payload, lambda packet: hci_packets.ConnectionRequestView(
            lambda packet: packet.payload[0:2] == b'\x04\x0a', lambda packet: hci_packets.ConnectionRequestView(
                hci_packets.EventPacketView(bt_packets.PacketViewLittleEndian(list(packet.payload)))))

    @staticmethod
    def ConnectionCompleteCapture():
        return Capture(
            lambda packet: b'\x03\x0b\x00' in packet.payload, lambda packet: hci_packets.ConnectionCompleteView(
            lambda packet: packet.payload[0:3] == b'\x03\x0b\x00', lambda packet: hci_packets.ConnectionCompleteView(
                hci_packets.EventPacketView(bt_packets.PacketViewLittleEndian(list(packet.payload)))))

    @staticmethod
    def DisconnectionCompleteCapture():
        return Capture(
            lambda packet: packet.payload[0:2] == b'\x05\x04', lambda packet: hci_packets.DisconnectionCompleteView(
                hci_packets.EventPacketView(bt_packets.PacketViewLittleEndian(list(packet.payload)))))

    @staticmethod
@@ -57,20 +63,26 @@ class HciCaptures(object):
    @staticmethod
    def ReadBdAddrCompleteCapture():
        return Capture(
            lambda packet: b'\x0e\x0a\x01\x09\x10' in packet.event, lambda packet: hci_packets.ReadBdAddrCompleteView(
            lambda packet: packet.event[0:5] == b'\x0e\x0a\x01\x09\x10', lambda packet: hci_packets.ReadBdAddrCompleteView(
                hci_packets.CommandCompleteView(
                    hci_packets.EventPacketView(bt_packets.PacketViewLittleEndian(list(packet.event))))))

    @staticmethod
    def ConnectionRequestCapture():
        return Capture(
            lambda packet: b'\x04\x0a' in packet.event, lambda packet: hci_packets.ConnectionRequestView(
            lambda packet: packet.event[0:2] == b'\x04\x0a', lambda packet: hci_packets.ConnectionRequestView(
                hci_packets.EventPacketView(bt_packets.PacketViewLittleEndian(list(packet.event)))))

    @staticmethod
    def ConnectionCompleteCapture():
        return Capture(
            lambda packet: b'\x03\x0b\x00' in packet.event, lambda packet: hci_packets.ConnectionCompleteView(
            lambda packet: packet.event[0:3] == b'\x03\x0b\x00', lambda packet: hci_packets.ConnectionCompleteView(
                hci_packets.EventPacketView(bt_packets.PacketViewLittleEndian(list(packet.event)))))

    @staticmethod
    def DisconnectionCompleteCapture():
        return Capture(
            lambda packet: packet.event[0:2] == b'\x05\x04', lambda packet: hci_packets.DisconnectionCompleteView(
                hci_packets.EventPacketView(bt_packets.PacketViewLittleEndian(list(packet.event)))))

    @staticmethod
+41 −19
Original line number Diff line number Diff line
@@ -27,20 +27,17 @@ from hci.facade import acl_manager_facade_pb2 as acl_manager_facade

class PyAclManagerAclConnection(IEventStream, Closable):

    def __init__(self, device, acl_stream, remote_addr, handle):
    def __init__(self, device, acl_stream, remote_addr, handle, event_stream):
        self.device = device
        self.handle = handle
        # todo enable filtering after sorting out handles
        # self.our_acl_stream = FilteringEventStream(acl_stream, None)
        self.our_acl_stream = acl_stream
        self.connection_event_stream = event_stream

        if remote_addr:
            remote_addr_bytes = bytes(remote_addr, 'utf8') if type(remote_addr) is str else bytes(remote_addr)
            self.connection_event_stream = EventStream(
                self.device.hci_acl_manager.CreateConnection(
                    acl_manager_facade.ConnectionMsg(address=remote_addr_bytes)))
        else:
            self.connection_event_stream = None
    def disconnect(self, reason):
        packet_bytes = bytes(hci_packets.DisconnectBuilder(self.handle, reason).Serialize())
        self.device.hci_acl_manager.ConnectionCommand(acl_manager_facade.ConnectionCommandMsg(packet=packet_bytes))

    def close(self):
        safeClose(self.connection_event_stream)
@@ -50,6 +47,12 @@ class PyAclManagerAclConnection(IEventStream, Closable):
        assertThat(self.connection_event_stream).emits(connection_complete)
        self.handle = connection_complete.get().GetConnectionHandle()

    def wait_for_disconnection_complete(self):
        disconnection_complete = HciCaptures.DisconnectionCompleteCapture()
        assertThat(self.connection_event_stream).isNotNone()
        assertThat(self.connection_event_stream).emits(disconnection_complete)
        self.disconnect_reason = disconnection_complete.get().GetReason()

    def send(self, data):
        self.device.hci_acl_manager.SendAclData(acl_manager_facade.AclData(handle=self.handle, payload=bytes(data)))

@@ -61,27 +64,46 @@ class PyAclManager(Closable):

    def __init__(self, device):
        self.device = device

        self.acl_stream = EventStream(self.device.hci_acl_manager.FetchAclData(empty_proto.Empty()))
        self.incoming_connection_stream = None
        self.incoming_connection_event_stream = None
        self.outgoing_connection_event_stream = None

    def close(self):
        safeClose(self.acl_stream)
        safeClose(self.incoming_connection_stream)
        safeClose(self.incoming_connection_event_stream)
        safeClose(self.outgoing_connection_event_stream)

    # temporary, until everyone is migrated
    def get_acl_stream(self):
        return self.acl_stream

    def listen_for_incoming_connections(self):
        self.incoming_connection_stream = EventStream(
    def listen_for_an_incoming_connection(self):
        assertThat(self.incoming_connection_event_stream).isNone()
        self.incoming_connection_event_stream = EventStream(
            self.device.hci_acl_manager.FetchIncomingConnection(empty_proto.Empty()))

    def initiate_connection(self, remote_addr):
        return PyAclManagerAclConnection(self.device, self.acl_stream, remote_addr, None)
        assertThat(self.outgoing_connection_event_stream).isNone()
        remote_addr_bytes = bytes(remote_addr, 'utf8') if type(remote_addr) is str else bytes(remote_addr)
        self.outgoing_connection_event_stream = EventStream(
            self.device.hci_acl_manager.CreateConnection(acl_manager_facade.ConnectionMsg(address=remote_addr_bytes)))

    def accept_connection(self):
    def complete_connection(self, event_stream):
        connection_complete = HciCaptures.ConnectionCompleteCapture()
        assertThat(self.incoming_connection_stream).emits(connection_complete)
        handle = connection_complete.get().GetConnectionHandle()
        return PyAclManagerAclConnection(self.device, self.acl_stream, None, handle)
        assertThat(event_stream).emits(connection_complete)
        complete = connection_complete.get()
        handle = complete.GetConnectionHandle()
        address = complete.GetBdAddr()
        return PyAclManagerAclConnection(self.device, self.acl_stream, address, handle, event_stream)

    def complete_incoming_connection(self):
        assertThat(self.incoming_connection_event_stream).isNotNone()
        event_stream = self.incoming_connection_event_stream
        self.incoming_connection_event_stream = None
        return self.complete_connection(event_stream)

    def complete_outgoing_connection(self):
        assertThat(self.outgoing_connection_event_stream).isNotNone()
        event_stream = self.outgoing_connection_event_stream
        self.outgoing_connection_event_stream = None
        return self.complete_connection(event_stream)
+27 −11
Original line number Diff line number Diff line
@@ -52,12 +52,10 @@ class AclManagerTest(GdBaseTestClass):
        self.cert_hci.enable_inquiry_and_page_scan()
        cert_address = self.cert_hci.read_own_address()

        with self.dut_acl_manager.initiate_connection(cert_address) as dut_acl:
        self.dut_acl_manager.initiate_connection(cert_address)
        cert_acl = self.cert_hci.accept_connection()
        with self.dut_acl_manager.complete_outgoing_connection() as dut_acl:
            cert_acl.send_first(b'\x26\x00\x07\x00This is just SomeAclData from the Cert')

            dut_acl.wait_for_connection_complete()

            dut_acl.send(b'\x29\x00\x07\x00This is just SomeMoreAclData from the DUT')

            assertThat(cert_acl).emits(lambda packet: b'SomeMoreAclData' in packet.data)
@@ -67,10 +65,26 @@ class AclManagerTest(GdBaseTestClass):
        dut_address = self.dut.hci_controller.GetMacAddressSimple()
        self.dut.neighbor.EnablePageScan(neighbor_facade.EnableMsg(enabled=True))

        self.dut_acl_manager.listen_for_incoming_connections()
        self.dut_acl_manager.listen_for_an_incoming_connection()
        self.cert_hci.initiate_connection(dut_address)
        dut_acl = self.dut_acl_manager.complete_incoming_connection()

        dut_acl = self.dut_acl_manager.accept_connection()
        cert_acl = self.cert_hci.complete_connection()

        dut_acl.send(b'\x29\x00\x07\x00This is just SomeMoreAclData from the DUT')

        cert_acl.send_first(b'\x26\x00\x07\x00This is just SomeAclData from the Cert')

        assertThat(cert_acl).emits(lambda packet: b'SomeMoreAclData' in packet.data)
        assertThat(dut_acl).emits(lambda packet: b'SomeAclData' in packet.payload)

    def test_cert_connects_disconnects(self):
        dut_address = self.dut.hci_controller.GetMacAddressSimple()
        self.dut.neighbor.EnablePageScan(neighbor_facade.EnableMsg(enabled=True))

        self.dut_acl_manager.listen_for_an_incoming_connection()
        self.cert_hci.initiate_connection(dut_address)
        dut_acl = self.dut_acl_manager.complete_incoming_connection()

        cert_acl = self.cert_hci.complete_connection()

@@ -81,17 +95,19 @@ class AclManagerTest(GdBaseTestClass):
        assertThat(cert_acl).emits(lambda packet: b'SomeMoreAclData' in packet.data)
        assertThat(dut_acl).emits(lambda packet: b'SomeAclData' in packet.payload)

        dut_acl.disconnect(hci_packets.DisconnectReason.REMOTE_USER_TERMINATED_CONNECTION)
        dut_acl.wait_for_disconnection_complete()

    def test_recombination_l2cap_packet(self):
        self.cert_hci.enable_inquiry_and_page_scan()
        cert_address = self.cert_hci.read_own_address()

        with self.dut_acl_manager.initiate_connection(cert_address) as dut_acl:
        self.dut_acl_manager.initiate_connection(cert_address)
        cert_acl = self.cert_hci.accept_connection()
        with self.dut_acl_manager.complete_outgoing_connection() as dut_acl:
            cert_acl.send_first(b'\x06\x00\x07\x00Hello')
            cert_acl.send_continuing(b'!')
            cert_acl.send_first(b'\xe8\x03\x07\x00' + b'Hello' * 200)

            dut_acl.wait_for_connection_complete()

            assertThat(dut_acl).emits(lambda packet: b'Hello!' in packet.payload,
                                      lambda packet: b'Hello' * 200 in packet.payload).inOrder()
+239 −38

File changed.

Preview size limit exceeded, changes collapsed.

+22 −0
Original line number Diff line number Diff line
@@ -8,8 +8,12 @@ service AclManagerFacade {
  rpc CreateConnection(ConnectionMsg) returns (stream ConnectionEvent) {}
  rpc CancelConnection(ConnectionMsg) returns (google.protobuf.Empty) {}
  rpc Disconnect(HandleMsg) returns (google.protobuf.Empty) {}
  rpc WriteDefaultLinkPolicySettings(PolicyMsg) returns (google.protobuf.Empty) {}
  rpc AuthenticationRequested(HandleMsg) returns (google.protobuf.Empty) {}
  rpc ConnectionCommand(ConnectionCommandMsg) returns (google.protobuf.Empty) {}
  rpc SwitchRole(RoleMsg) returns (google.protobuf.Empty) {}
  rpc SendAclData(AclData) returns (google.protobuf.Empty) {}
  // TODO: Take a HandleMsg to get AclData
  rpc FetchAclData(google.protobuf.Empty) returns (stream AclData) {}
  rpc FetchIncomingConnection(google.protobuf.Empty) returns (stream ConnectionEvent) {}
}
@@ -22,6 +26,24 @@ message ConnectionMsg {
  bytes address = 1;
}

message PolicyMsg {
  uint32 policy = 1;
}

enum NewRole {
  MASTER = 0;
  SLAVE = 1;
}

message RoleMsg {
  bytes address = 1;
  NewRole role = 2;
}

message ConnectionCommandMsg {
  bytes packet = 1;
}

message ConnectionEvent {
  bytes event = 1;
}
Loading