Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 637e2603 authored by Zach Johnson's avatar Zach Johnson
Browse files

Add captures, to clean up test logic and make tests easer to read

Test: cert/run --host
Change-Id: I560183190070bdb032960ea0b43da8ea6b8b4ffd
parent 9e4a9a55
Loading
Loading
Loading
Loading
+41 −0
Original line number Diff line number Diff line
#!/usr/bin/env python3
#
#   Copyright 2020 - The Android Open Source Project
#
#   Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.


class Capture(object):
    """
    Wrap a match function and use in its place, to capture the value
    that matched. Specify an optional |capture_fn| to transform the
    captured value.
    """

    def __init__(self, match_fn, capture_fn=None):
        self._match_fn = match_fn
        self._capture_fn = capture_fn
        self._value = None

    def __call__(self, obj):
        if self._match_fn(obj) != True:
            return False

        if self._capture_fn is not None:
            self._value = self._capture_fn(obj)
        else:
            self._value = obj
        return True

    def get(self):
        return self._value
+36 −116
Original line number Diff line number Diff line
@@ -28,6 +28,9 @@ from hci.facade import controller_facade_pb2 as controller_facade
from hci.facade import facade_pb2 as hci_facade
import bluetooth_packets_python3 as bt_packets
from bluetooth_packets_python3 import hci_packets
from captures import ReadBdAddrCompleteCapture
from captures import ConnectionCompleteCapture
from captures import ConnectionRequestCapture


class AclManagerTest(GdFacadeOnlyBaseTestClass):
@@ -69,24 +72,11 @@ class AclManagerTest(GdFacadeOnlyBaseTestClass):
                hci_packets.WriteScanEnableBuilder(
                    hci_packets.ScanEnable.INQUIRY_AND_PAGE_SCAN), True)

            cert_address = None

            def get_address_from_complete(packet):
                packet_bytes = packet.event
                if b'\x0e\x0a\x01\x09\x10' in packet_bytes:
                    nonlocal cert_address
                    addr_view = hci_packets.ReadBdAddrCompleteView(
                        hci_packets.CommandCompleteView(
                            hci_packets.EventPacketView(
                                bt_packets.PacketViewLittleEndian(
                                    list(packet_bytes)))))
                    cert_address = addr_view.GetBdAddr()
                    return True
                return False

            self.enqueue_hci_command(hci_packets.ReadBdAddrBuilder(), True)

            assertThat(cert_hci_event_stream).emits(get_address_from_complete)
            read_bd_addr = ReadBdAddrCompleteCapture()
            assertThat(cert_hci_event_stream).emits(read_bd_addr)
            cert_address = read_bd_addr.get().GetBdAddr()

            with EventStream(
                    self.dut.hci_acl_manager.CreateConnection(
@@ -96,43 +86,20 @@ class AclManagerTest(GdFacadeOnlyBaseTestClass):
                            address=bytes(cert_address,
                                          'utf8')))) as connection_event_stream:

                connection_request = None

                def get_connect_request(packet):
                    if b'\x04\x0a' in packet.event:
                        nonlocal connection_request
                        connection_request = hci_packets.ConnectionRequestView(
                            hci_packets.EventPacketView(
                                bt_packets.PacketViewLittleEndian(
                                    list(packet.event))))
                        return True
                    return False

                # Cert Accepts
                assertThat(cert_hci_event_stream).emits(get_connect_request)
                connection_request = ConnectionRequestCapture()
                assertThat(cert_hci_event_stream).emits(connection_request)

                self.enqueue_hci_command(
                    hci_packets.AcceptConnectionRequestBuilder(
                        connection_request.GetBdAddr(),
                        connection_request.get().GetBdAddr(),
                        hci_packets.AcceptConnectionRequestRole.REMAIN_SLAVE),
                    False)

                # Cert gets ConnectionComplete with a handle and sends ACL data
                handle = 0xfff

                def get_handle(packet):
                    packet_bytes = packet.event
                    if b'\x03\x0b\x00' in packet_bytes:
                        nonlocal handle
                        cc_view = hci_packets.ConnectionCompleteView(
                            hci_packets.EventPacketView(
                                bt_packets.PacketViewLittleEndian(
                                    list(packet_bytes))))
                        handle = cc_view.GetConnectionHandle()
                        return True
                    return False

                assertThat(cert_hci_event_stream).emits(get_handle)
                cert_handle = handle
                connection_complete = ConnectionCompleteCapture()
                assertThat(cert_hci_event_stream).emits(connection_complete)
                cert_handle = connection_complete.get().GetConnectionHandle()

                self.enqueue_acl_data(
                    cert_handle, hci_packets.PacketBoundaryFlag.
@@ -143,12 +110,13 @@ class AclManagerTest(GdFacadeOnlyBaseTestClass):
                    ))

                # DUT gets a connection complete event and sends and receives
                handle = 0xfff
                connection_event_stream.assert_event_occurs(get_handle)
                connection_complete = ConnectionCompleteCapture()
                connection_event_stream.assert_event_occurs(connection_complete)
                dut_handle = connection_complete.get().GetConnectionHandle()

                self.dut.hci_acl_manager.SendAclData(
                    acl_manager_facade.AclData(
                        handle=handle,
                        handle=dut_handle,
                        payload=bytes(
                            b'\x29\x00\x07\x00This is just SomeMoreAclData from the DUT'
                        )))
@@ -186,34 +154,21 @@ class AclManagerTest(GdFacadeOnlyBaseTestClass):
                    hci_packets.CreateConnectionRoleSwitch.ALLOW_ROLE_SWITCH),
                False)

            conn_handle = 0xfff

            def get_handle(packet):
                packet_bytes = packet.event
                if b'\x03\x0b\x00' in packet_bytes:
                    nonlocal conn_handle
                    cc_view = hci_packets.ConnectionCompleteView(
                        hci_packets.EventPacketView(
                            bt_packets.PacketViewLittleEndian(
                                list(packet_bytes))))
                    conn_handle = cc_view.GetConnectionHandle()
                    return True
                return False

            # DUT gets a connection request
            incoming_connection_stream.assert_event_occurs(get_handle)
            connection_complete = ConnectionCompleteCapture()
            assertThat(incoming_connection_stream).emits(connection_complete)
            dut_handle = connection_complete.get().GetConnectionHandle()

            self.dut.hci_acl_manager.SendAclData(
                acl_manager_facade.AclData(
                    handle=conn_handle,
                    handle=dut_handle,
                    payload=bytes(
                        b'\x29\x00\x07\x00This is just SomeMoreAclData from the DUT'
                    )))

            conn_handle = 0xfff

            assertThat(cert_hci_event_stream).emits(get_handle)
            cert_handle = conn_handle
            connection_complete = ConnectionCompleteCapture()
            assertThat(cert_hci_event_stream).emits(connection_complete)
            cert_handle = connection_complete.get().GetConnectionHandle()

            self.enqueue_acl_data(
                cert_handle,
@@ -241,24 +196,11 @@ class AclManagerTest(GdFacadeOnlyBaseTestClass):
                hci_packets.WriteScanEnableBuilder(
                    hci_packets.ScanEnable.INQUIRY_AND_PAGE_SCAN), True)

            cert_address = None

            def get_address_from_complete(packet):
                packet_bytes = packet.event
                if b'\x0e\x0a\x01\x09\x10' in packet_bytes:
                    nonlocal cert_address
                    addr_view = hci_packets.ReadBdAddrCompleteView(
                        hci_packets.CommandCompleteView(
                            hci_packets.EventPacketView(
                                bt_packets.PacketViewLittleEndian(
                                    list(packet_bytes)))))
                    cert_address = addr_view.GetBdAddr()
                    return True
                return False

            self.enqueue_hci_command(hci_packets.ReadBdAddrBuilder(), True)

            assertThat(cert_hci_event_stream).emits(get_address_from_complete)
            read_bd_addr = ReadBdAddrCompleteCapture()
            assertThat(cert_hci_event_stream).emits(read_bd_addr)
            cert_address = read_bd_addr.get().GetBdAddr()

            with EventStream(
                    self.dut.hci_acl_manager.CreateConnection(
@@ -268,43 +210,19 @@ class AclManagerTest(GdFacadeOnlyBaseTestClass):
                            address=bytes(cert_address,
                                          'utf8')))) as connection_event_stream:

                connection_request = None

                def get_connect_request(packet):
                    if b'\x04\x0a' in packet.event:
                        nonlocal connection_request
                        connection_request = hci_packets.ConnectionRequestView(
                            hci_packets.EventPacketView(
                                bt_packets.PacketViewLittleEndian(
                                    list(packet.event))))
                        return True
                    return False

                # Cert Accepts
                assertThat(cert_hci_event_stream).emits(get_connect_request)
                connection_request = ConnectionRequestCapture()
                assertThat(cert_hci_event_stream).emits(connection_request)
                self.enqueue_hci_command(
                    hci_packets.AcceptConnectionRequestBuilder(
                        connection_request.GetBdAddr(),
                        connection_request.get().GetBdAddr(),
                        hci_packets.AcceptConnectionRequestRole.REMAIN_SLAVE),
                    False)

                # Cert gets ConnectionComplete with a handle and sends ACL data
                handle = 0xfff

                def get_handle(packet):
                    packet_bytes = packet.event
                    if b'\x03\x0b\x00' in packet_bytes:
                        nonlocal handle
                        cc_view = hci_packets.ConnectionCompleteView(
                            hci_packets.EventPacketView(
                                bt_packets.PacketViewLittleEndian(
                                    list(packet_bytes))))
                        handle = cc_view.GetConnectionHandle()
                        return True
                    return False

                assertThat(cert_hci_event_stream).emits(get_handle)
                cert_handle = handle
                connection_complete = ConnectionCompleteCapture()
                assertThat(cert_hci_event_stream).emits(connection_complete)
                cert_handle = connection_complete.get().GetConnectionHandle()

                self.enqueue_acl_data(
                    cert_handle, hci_packets.PacketBoundaryFlag.
@@ -322,7 +240,9 @@ class AclManagerTest(GdFacadeOnlyBaseTestClass):
                    bytes(b'\xe8\x03\x07\x00' + b'Hello' * 200))

                # DUT gets a connection complete event and sends and receives
                connection_event_stream.assert_event_occurs(get_handle)
                connection_complete = ConnectionCompleteCapture()
                connection_event_stream.assert_event_occurs(connection_complete)
                dut_handle = connection_complete.get().GetConnectionHandle()

                assertThat(acl_data_stream).emits(
                    lambda packet: b'Hello!' in packet.payload).then(
+44 −0
Original line number Diff line number Diff line
#!/usr/bin/env python3
#
#   Copyright 2020 - The Android Open Source Project
#
#   Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.

import bluetooth_packets_python3 as bt_packets
from bluetooth_packets_python3 import hci_packets
from cert.capture import Capture


def ReadBdAddrCompleteCapture():
    return Capture(lambda packet: b'\x0e\x0a\x01\x09\x10' in packet.event,
      lambda packet: hci_packets.ReadBdAddrCompleteView(
                  hci_packets.CommandCompleteView(
                            hci_packets.EventPacketView(
                                bt_packets.PacketViewLittleEndian(
                                    list(packet.event))))))


def ConnectionRequestCapture():
    return Capture(lambda packet: b'\x04\x0a' in packet.event,
      lambda packet: hci_packets.ConnectionRequestView(
                            hci_packets.EventPacketView(
                                bt_packets.PacketViewLittleEndian(
                                    list(packet.event)))))


def ConnectionCompleteCapture():
    return Capture(lambda packet: b'\x03\x0b\x00' in packet.event,
      lambda packet: hci_packets.ConnectionCompleteView(
                            hci_packets.EventPacketView(
                                bt_packets.PacketViewLittleEndian(
                                    list(packet.event)))))