Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 20ba1129 authored by Zach Johnson's avatar Zach Johnson
Browse files

Add IEventStream, so other subjects can use emit, etc

And use in PyHci & PyAclManager. Can't quite turn on filtering
on the ACL messages, due to the handles being wrong.

Test: cert/run --host --test_filter=AclManagerTest
Change-Id: I30ef332b06ae553553757337cd66da91f8debe3a
parent 40232190
Loading
Loading
Loading
Loading
+112 −61
Original line number Diff line number Diff line
@@ -24,14 +24,44 @@ from google.protobuf import text_format
from concurrent.futures import ThreadPoolExecutor
from grpc import RpcError

from abc import ABC, abstractmethod

class EventStream(object):

class IEventStream(ABC):

    @abstractmethod
    def get_event_queue(self):
        pass


class FilteringEventStream(IEventStream):

    def __init__(self, stream, filter_fn):
        self.filter_fn = filter_fn
        self.event_queue = SimpleQueue()
        self.stream = stream

        self.stream.register_callback(self.__event_callback, self.filter_fn)

    def __event_callback(self, event):
        self.event_queue.put(event)

    def get_event_queue(self):
        return self.event_queue

    def unregister(self):
        self.stream.unregister(self.__event_callback)


DEFAULT_TIMEOUT_SECONDS = 3


class EventStream(IEventStream):
    """
    A class that streams events from a gRPC stream, which you can assert on.

    Don't use these asserts directly, use the ones from cert.truth.
    """
    DEFAULT_TIMEOUT_SECONDS = 3

    def __init__(self, server_stream_call):
        if server_stream_call is None:
@@ -53,11 +83,8 @@ class EventStream(object):
    def __del__(self):
        self.shutdown()

    def remaining_time_delta(self, end_time):
        remaining = end_time - datetime.now()
        if remaining < timedelta(milliseconds=0):
            remaining = timedelta(milliseconds=0)
        return remaining
    def get_event_queue(self):
        return self.event_queue

    def shutdown(self):
        """
@@ -170,7 +197,7 @@ class EventStream(object):
        event = None
        end_time = datetime.now() + timeout
        while event is None and datetime.now() < end_time:
            remaining = self.remaining_time_delta(end_time)
            remaining = static_remaining_time_delta(end_time)
            logging.debug("Waiting for event (%fs remaining)" %
                          (remaining.total_seconds()))
            try:
@@ -202,26 +229,7 @@ class EventStream(object):
                               happen
        :return:
        """
        logging.debug("assert_event_occurs %d %fs" % (at_least_times,
                                                      timeout.total_seconds()))
        event_list = []
        end_time = datetime.now() + timeout
        while len(event_list) < at_least_times and datetime.now() < end_time:
            remaining = self.remaining_time_delta(end_time)
            logging.debug("Waiting for event (%fs remaining)" %
                          (remaining.total_seconds()))
            try:
                current_event = self.event_queue.get(
                    timeout=remaining.total_seconds())
                if match_fn(current_event):
                    event_list.append(current_event)
            except Empty:
                continue
        logging.debug("Done waiting for event")
        asserts.assert_true(
            len(event_list) >= at_least_times,
            msg=("Expected at least %d events, but got %d" % (at_least_times,
                                                              len(event_list))))
        NOT_FOR_YOU_assert_event_occurs(self, match_fn, at_least_times, timeout)

    def assert_event_occurs_at_most(
            self,
@@ -242,7 +250,7 @@ class EventStream(object):
        event_list = []
        end_time = datetime.now() + timeout
        while len(event_list) <= at_most_times and datetime.now() < end_time:
            remaining = self.remaining_time_delta(end_time)
            remaining = static_remaining_time_delta(end_time)
            logging.debug("Waiting for event iteration (%fs remaining)" %
                          (remaining.total_seconds()))
            try:
@@ -263,16 +271,59 @@ class EventStream(object):
            match_fns,
            order_matters,
            timeout=timedelta(seconds=DEFAULT_TIMEOUT_SECONDS)):
        NOT_FOR_YOU_assert_all_events_occur(self, match_fns, order_matters,
                                            timeout)


def static_remaining_time_delta(end_time):
    remaining = end_time - datetime.now()
    if remaining < timedelta(milliseconds=0):
        remaining = timedelta(milliseconds=0)
    return remaining


def NOT_FOR_YOU_assert_event_occurs(
        istream,
        match_fn,
        at_least_times=1,
        timeout=timedelta(seconds=DEFAULT_TIMEOUT_SECONDS)):
    logging.debug("assert_event_occurs %d %fs" % (at_least_times,
                                                  timeout.total_seconds()))
    event_list = []
    end_time = datetime.now() + timeout
    while len(event_list) < at_least_times and datetime.now() < end_time:
        remaining = static_remaining_time_delta(end_time)
        logging.debug(
            "Waiting for event (%fs remaining)" % (remaining.total_seconds()))
        try:
            current_event = istream.get_event_queue().get(
                timeout=remaining.total_seconds())
            if match_fn(current_event):
                event_list.append(current_event)
        except Empty:
            continue
    logging.debug("Done waiting for event")
    asserts.assert_true(
        len(event_list) >= at_least_times,
        msg=("Expected at least %d events, but got %d" % (at_least_times,
                                                          len(event_list))))


def NOT_FOR_YOU_assert_all_events_occur(
        istream,
        match_fns,
        order_matters,
        timeout=timedelta(seconds=DEFAULT_TIMEOUT_SECONDS)):
    logging.debug("assert_all_events_occur %fs" % timeout.total_seconds())
    pending_matches = list(match_fns)
    matched_order = []
    end_time = datetime.now() + timeout
    while len(pending_matches) > 0 and datetime.now() < end_time:
            remaining = self.remaining_time_delta(end_time)
            logging.debug("Waiting for event (%fs remaining)" %
                          (remaining.total_seconds()))
        remaining = static_remaining_time_delta(end_time)
        logging.debug(
            "Waiting for event (%fs remaining)" % (remaining.total_seconds()))
        try:
                current_event = self.event_queue.get(
            current_event = istream.get_event_queue().get(
                timeout=remaining.total_seconds())
            for match_fn in pending_matches:
                if match_fn(current_event):
@@ -283,8 +334,8 @@ class EventStream(object):
    logging.debug("Done waiting for event")
    asserts.assert_true(
        len(matched_order) == len(match_fns),
            msg=("Expected at least %d events, but got %d" %
                 (len(match_fns), len(matched_order))))
        msg=("Expected at least %d events, but got %d" % (len(match_fns),
                                                          len(matched_order))))
    if order_matters:
        correct_order = True
        i = 0
+10 −8
Original line number Diff line number Diff line
@@ -20,7 +20,9 @@ from mobly.asserts import assert_true
from mobly.asserts import assert_false

from mobly import signals
from cert.event_stream import EventStream
from cert.event_stream import IEventStream
from cert.event_stream import NOT_FOR_YOU_assert_event_occurs
from cert.event_stream import NOT_FOR_YOU_assert_all_events_occur

import sys, traceback

@@ -63,7 +65,7 @@ class EventStreamSubject(ObjectSubject):
        if len(match_fns) == 0:
            raise signals.TestFailure("Must specify a match function")
        elif len(match_fns) == 1:
            self._value.assert_event_occurs(match_fns[0])
            NOT_FOR_YOU_assert_event_occurs(self._value, match_fns[0])
            return EventStreamContinuationSubject(self._value)
        else:
            return MultiMatchStreamSubject(self._value, match_fns)
@@ -76,13 +78,13 @@ class MultiMatchStreamSubject(object):
        self._match_fns = match_fns

    def inAnyOrder(self):
        self._stream.assert_all_events_occur(
            self._match_fns, order_matters=False)
        NOT_FOR_YOU_assert_all_events_occur(
            self._stream, self._match_fns, order_matters=False)
        return EventStreamContinuationSubject(self._stream)

    def inOrder(self):
        self._stream.assert_all_events_occur(
            self._match_fns, order_matters=True)
        NOT_FOR_YOU_assert_all_events_occur(
            self._stream, self._match_fns, order_matters=True)
        return EventStreamContinuationSubject(self._stream)


@@ -95,7 +97,7 @@ class EventStreamContinuationSubject(ObjectSubject):
        if len(match_fns) == 0:
            raise signals.TestFailure("Must specify a match function")
        elif len(match_fns) == 1:
            self._value.assert_event_occurs(match_fns[0])
            NOT_FOR_YOU_assert_event_occurs(self._value, match_fns[0])
            return EventStreamContinuationSubject(self._value)
        else:
            return MultiMatchStreamSubject(self._value, match_fns)
@@ -116,7 +118,7 @@ class BooleanSubject(ObjectSubject):
def assertThat(subject):
    if type(subject) is bool:
        return BooleanSubject(subject)
    elif isinstance(subject, EventStream):
    elif isinstance(subject, IEventStream):
        return EventStreamSubject(subject)
    else:
        return ObjectSubject(subject)
+5 −5
Original line number Diff line number Diff line
@@ -58,9 +58,9 @@ class AclManagerTest(GdFacadeOnlyBaseTestClass):
                    b'\x29\x00\x07\x00This is just SomeMoreAclData from the DUT'
                )

                assertThat(cert_hci.get_acl_stream()).emits(
                assertThat(cert_acl).emits(
                    lambda packet: b'SomeMoreAclData' in packet.data)
                assertThat(dut_acl_manager.get_acl_stream()).emits(
                assertThat(dut_acl).emits(
                    lambda packet: b'SomeAclData' in packet.payload)

    def test_cert_connects(self):
@@ -85,9 +85,9 @@ class AclManagerTest(GdFacadeOnlyBaseTestClass):
            cert_acl.send_first(
                b'\x26\x00\x07\x00This is just SomeAclData from the Cert')

            assertThat(cert_hci.get_acl_stream()).emits(
            assertThat(cert_acl).emits(
                lambda packet: b'SomeMoreAclData' in packet.data)
            assertThat(dut_acl_manager.get_acl_stream()).emits(
            assertThat(dut_acl).emits(
                lambda packet: b'SomeAclData' in packet.payload)

    def test_recombination_l2cap_packet(self):
@@ -105,6 +105,6 @@ class AclManagerTest(GdFacadeOnlyBaseTestClass):

                dut_acl.wait_for_connection_complete()

                assertThat(dut_acl_manager.get_acl_stream()).emits(
                assertThat(dut_acl).emits(
                    lambda packet: b'Hello!' in packet.payload,
                    lambda packet: b'Hello' * 200 in packet.payload).inOrder()
+13 −7
Original line number Diff line number Diff line
@@ -16,6 +16,8 @@

from google.protobuf import empty_pb2 as empty_proto
from cert.event_stream import EventStream
from cert.event_stream import FilteringEventStream
from cert.event_stream import IEventStream
from captures import ReadBdAddrCompleteCapture
from captures import ConnectionCompleteCapture
from captures import ConnectionRequestCapture
@@ -25,11 +27,13 @@ from hci.facade import facade_pb2 as hci_facade
from hci.facade import acl_manager_facade_pb2 as acl_manager_facade


class PyAclManagerAclConnection(object):
class PyAclManagerAclConnection(IEventStream):

    def __init__(self, device, remote_addr, handle):
    def __init__(self, device, acl_stream, remote_addr, handle):
        self.device = device
        self.handle = handle
        # todo enable filtering after sorting out handles
        self.our_acl_stream = FilteringEventStream(acl_stream, None)

        if remote_addr:
            self.connection_event_stream = EventStream(
@@ -64,6 +68,9 @@ class PyAclManagerAclConnection(object):
        self.device.hci_acl_manager.SendAclData(
            acl_manager_facade.AclData(handle=self.handle, payload=bytes(data)))

    def get_event_queue(self):
        return self.our_acl_stream.get_event_queue()


class PyAclManager(object):

@@ -94,14 +101,13 @@ class PyAclManager(object):
            self.device.hci_acl_manager.FetchIncomingConnection(
                empty_proto.Empty()))

    def get_acl_stream(self):
        return self.acl_stream

    def initiate_connection(self, remote_addr):
        return PyAclManagerAclConnection(self.device, remote_addr, None)
        return PyAclManagerAclConnection(self.device, self.acl_stream,
                                         remote_addr, None)

    def accept_connection(self):
        connection_complete = ConnectionCompleteCapture()
        assertThat(self.incoming_connection_stream).emits(connection_complete)
        handle = connection_complete.get().GetConnectionHandle()
        return PyAclManagerAclConnection(self.device, None, handle)
        return PyAclManagerAclConnection(self.device, self.acl_stream, None,
                                         handle)
+10 −7
Original line number Diff line number Diff line
@@ -16,6 +16,8 @@

from google.protobuf import empty_pb2 as empty_proto
from cert.event_stream import EventStream
from cert.event_stream import FilteringEventStream
from cert.event_stream import IEventStream
from captures import ReadBdAddrCompleteCapture
from captures import ConnectionCompleteCapture
from captures import ConnectionRequestCapture
@@ -24,16 +26,17 @@ from cert.truth import assertThat
from hci.facade import facade_pb2 as hci_facade


class PyHciAclConnection(object):
class PyHciAclConnection(IEventStream):

    def __init__(self, handle, acl_stream, device):
        self.handle = handle
        self.acl_stream = acl_stream
        self.handle = int(handle)
        self.device = device
        # todo, handle we got is 0, so doesn't match - fix before enabling filtering
        self.our_acl_stream = FilteringEventStream(acl_stream, None)

    def send(self, pb_flag, b_flag, data):
        acl_msg = hci_facade.AclMsg(
            handle=int(self.handle),
            handle=self.handle,
            packet_boundary_flag=int(pb_flag),
            broadcast_flag=int(b_flag),
            data=data)
@@ -47,6 +50,9 @@ class PyHciAclConnection(object):
        self.send(hci_packets.PacketBoundaryFlag.CONTINUING_FRAGMENT,
                  hci_packets.BroadcastFlag.POINT_TO_POINT, bytes(data))

    def get_event_queue(self):
        return self.our_acl_stream.get_event_queue()


class PyHci(object):

@@ -80,9 +86,6 @@ class PyHci(object):
    def get_event_stream(self):
        return self.event_stream

    def get_acl_stream(self):
        return self.acl_stream

    def send_command_with_complete(self, command):
        self.device.hci.send_command_with_complete(command)