Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 3cc6444f authored by Treehugger Robot's avatar Treehugger Robot Committed by Gerrit Code Review
Browse files

Merge changes Iafd263c7,I29300b6b,If8422986,Ied877cd6,If30d8ca4, ...

* changes:
  Register _handle_control_packet in a single place
  Migrate to use cert_acl instead of cert_acl_handle
  Remove cert_acl_data_stream from _open_channel, as it can be inferred
  Remove cert_acl_handle return value, as it was not really used
  Begin migrating L2capTest to use PyAclManager
  Begin PyL2cap, to make interfacing with l2cap facade easier
  Add Closable, to reduce repeated code
  Move captures & PyHci/PyAclManager to common cert
  Move common setup to setup_test & teardown_test
  Add IEventStream, so other subjects can use emit, etc
  Flesh out PyAclManager a little more
  Start PyAclManager, to simplify interaction with AclManager
parents d49f7fec b4ba004f
Loading
Loading
Loading
Loading
+39 −0
Original line number Diff line number Diff line
#!/usr/bin/env python3
#
#   Copyright 2019 - The Android Open Source Project
#
#   Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.

from abc import ABC, abstractmethod


class Closable(ABC):

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        self.close()
        return traceback is None

    def __del__(self):
        self.close()

    @abstractmethod
    def close(self):
        pass


def safeClose(closable):
    if closable is not None:
        closable.close()
+115 −72
Original line number Diff line number Diff line
@@ -24,14 +24,46 @@ from google.protobuf import text_format
from concurrent.futures import ThreadPoolExecutor
from grpc import RpcError

from abc import ABC, abstractmethod

class EventStream(object):
from cert.closable import Closable


class IEventStream(ABC):

    @abstractmethod
    def get_event_queue(self):
        pass


class FilteringEventStream(IEventStream):

    def __init__(self, stream, filter_fn):
        self.filter_fn = filter_fn
        self.event_queue = SimpleQueue()
        self.stream = stream

        self.stream.register_callback(self.__event_callback, self.filter_fn)

    def __event_callback(self, event):
        self.event_queue.put(event)

    def get_event_queue(self):
        return self.event_queue

    def unregister(self):
        self.stream.unregister(self.__event_callback)


DEFAULT_TIMEOUT_SECONDS = 3


class EventStream(IEventStream, Closable):
    """
    A class that streams events from a gRPC stream, which you can assert on.

    Don't use these asserts directly, use the ones from cert.truth.
    """
    DEFAULT_TIMEOUT_SECONDS = 3

    def __init__(self, server_stream_call):
        if server_stream_call is None:
@@ -43,23 +75,10 @@ class EventStream(object):
        self.executor = ThreadPoolExecutor()
        self.future = self.executor.submit(EventStream._event_loop, self)

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        self.shutdown()
        return traceback is None

    def __del__(self):
        self.shutdown()
    def get_event_queue(self):
        return self.event_queue

    def remaining_time_delta(self, end_time):
        remaining = end_time - datetime.now()
        if remaining < timedelta(milliseconds=0):
            remaining = timedelta(milliseconds=0)
        return remaining

    def shutdown(self):
    def close(self):
        """
        Stop the gRPC lambda so that event_callback will not be invoked after th
        method returns.
@@ -170,7 +189,7 @@ class EventStream(object):
        event = None
        end_time = datetime.now() + timeout
        while event is None and datetime.now() < end_time:
            remaining = self.remaining_time_delta(end_time)
            remaining = static_remaining_time_delta(end_time)
            logging.debug("Waiting for event (%fs remaining)" %
                          (remaining.total_seconds()))
            try:
@@ -202,26 +221,7 @@ class EventStream(object):
                               happen
        :return:
        """
        logging.debug("assert_event_occurs %d %fs" % (at_least_times,
                                                      timeout.total_seconds()))
        event_list = []
        end_time = datetime.now() + timeout
        while len(event_list) < at_least_times and datetime.now() < end_time:
            remaining = self.remaining_time_delta(end_time)
            logging.debug("Waiting for event (%fs remaining)" %
                          (remaining.total_seconds()))
            try:
                current_event = self.event_queue.get(
                    timeout=remaining.total_seconds())
                if match_fn(current_event):
                    event_list.append(current_event)
            except Empty:
                continue
        logging.debug("Done waiting for event")
        asserts.assert_true(
            len(event_list) >= at_least_times,
            msg=("Expected at least %d events, but got %d" % (at_least_times,
                                                              len(event_list))))
        NOT_FOR_YOU_assert_event_occurs(self, match_fn, at_least_times, timeout)

    def assert_event_occurs_at_most(
            self,
@@ -242,7 +242,7 @@ class EventStream(object):
        event_list = []
        end_time = datetime.now() + timeout
        while len(event_list) <= at_most_times and datetime.now() < end_time:
            remaining = self.remaining_time_delta(end_time)
            remaining = static_remaining_time_delta(end_time)
            logging.debug("Waiting for event iteration (%fs remaining)" %
                          (remaining.total_seconds()))
            try:
@@ -263,16 +263,59 @@ class EventStream(object):
            match_fns,
            order_matters,
            timeout=timedelta(seconds=DEFAULT_TIMEOUT_SECONDS)):
        NOT_FOR_YOU_assert_all_events_occur(self, match_fns, order_matters,
                                            timeout)


def static_remaining_time_delta(end_time):
    remaining = end_time - datetime.now()
    if remaining < timedelta(milliseconds=0):
        remaining = timedelta(milliseconds=0)
    return remaining


def NOT_FOR_YOU_assert_event_occurs(
        istream,
        match_fn,
        at_least_times=1,
        timeout=timedelta(seconds=DEFAULT_TIMEOUT_SECONDS)):
    logging.debug("assert_event_occurs %d %fs" % (at_least_times,
                                                  timeout.total_seconds()))
    event_list = []
    end_time = datetime.now() + timeout
    while len(event_list) < at_least_times and datetime.now() < end_time:
        remaining = static_remaining_time_delta(end_time)
        logging.debug(
            "Waiting for event (%fs remaining)" % (remaining.total_seconds()))
        try:
            current_event = istream.get_event_queue().get(
                timeout=remaining.total_seconds())
            if match_fn(current_event):
                event_list.append(current_event)
        except Empty:
            continue
    logging.debug("Done waiting for event")
    asserts.assert_true(
        len(event_list) >= at_least_times,
        msg=("Expected at least %d events, but got %d" % (at_least_times,
                                                          len(event_list))))


def NOT_FOR_YOU_assert_all_events_occur(
        istream,
        match_fns,
        order_matters,
        timeout=timedelta(seconds=DEFAULT_TIMEOUT_SECONDS)):
    logging.debug("assert_all_events_occur %fs" % timeout.total_seconds())
    pending_matches = list(match_fns)
    matched_order = []
    end_time = datetime.now() + timeout
    while len(pending_matches) > 0 and datetime.now() < end_time:
            remaining = self.remaining_time_delta(end_time)
            logging.debug("Waiting for event (%fs remaining)" %
                          (remaining.total_seconds()))
        remaining = static_remaining_time_delta(end_time)
        logging.debug(
            "Waiting for event (%fs remaining)" % (remaining.total_seconds()))
        try:
                current_event = self.event_queue.get(
            current_event = istream.get_event_queue().get(
                timeout=remaining.total_seconds())
            for match_fn in pending_matches:
                if match_fn(current_event):
@@ -283,8 +326,8 @@ class EventStream(object):
    logging.debug("Done waiting for event")
    asserts.assert_true(
        len(matched_order) == len(match_fns),
            msg=("Expected at least %d events, but got %d" %
                 (len(match_fns), len(matched_order))))
        msg=("Expected at least %d events, but got %d" % (len(match_fns),
                                                          len(matched_order))))
    if order_matters:
        correct_order = True
        i = 0
+100 −0
Original line number Diff line number Diff line
#!/usr/bin/env python3
#
#   Copyright 2020 - The Android Open Source Project
#
#   Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.

from google.protobuf import empty_pb2 as empty_proto
from cert.event_stream import EventStream
from cert.event_stream import FilteringEventStream
from cert.event_stream import IEventStream
from cert.captures import ReadBdAddrCompleteCapture
from cert.captures import ConnectionCompleteCapture
from cert.captures import ConnectionRequestCapture
from cert.closable import Closable
from cert.closable import safeClose
from bluetooth_packets_python3 import hci_packets
from cert.truth import assertThat
from hci.facade import facade_pb2 as hci_facade
from hci.facade import acl_manager_facade_pb2 as acl_manager_facade


class PyAclManagerAclConnection(IEventStream, Closable):

    def __init__(self, device, acl_stream, remote_addr, handle):
        self.device = device
        self.handle = handle
        # todo enable filtering after sorting out handles
        self.our_acl_stream = FilteringEventStream(acl_stream, None)

        if remote_addr:
            remote_addr_bytes = bytes(
                remote_addr,
                'utf8') if type(remote_addr) is str else bytes(remote_addr)
            self.connection_event_stream = EventStream(
                self.device.hci_acl_manager.CreateConnection(
                    acl_manager_facade.ConnectionMsg(
                        address_type=int(
                            hci_packets.AddressType.PUBLIC_DEVICE_ADDRESS),
                        address=remote_addr_bytes)))
        else:
            self.connection_event_stream = None

    def close(self):
        safeClose(self.connection_event_stream)

    def wait_for_connection_complete(self):
        connection_complete = ConnectionCompleteCapture()
        assertThat(self.connection_event_stream).emits(connection_complete)
        self.handle = connection_complete.get().GetConnectionHandle()

    def send(self, data):
        self.device.hci_acl_manager.SendAclData(
            acl_manager_facade.AclData(handle=self.handle, payload=bytes(data)))

    def get_event_queue(self):
        return self.our_acl_stream.get_event_queue()


class PyAclManager(Closable):

    def __init__(self, device):
        self.device = device

        self.acl_stream = EventStream(
            self.device.hci_acl_manager.FetchAclData(empty_proto.Empty()))
        self.incoming_connection_stream = None

    def close(self):
        safeClose(self.acl_stream)
        safeClose(self.incoming_connection_stream)

    # temporary, until everyone is migrated
    def get_acl_stream(self):
        return self.acl_stream

    def listen_for_incoming_connections(self):
        self.incoming_connection_stream = EventStream(
            self.device.hci_acl_manager.FetchIncomingConnection(
                empty_proto.Empty()))

    def initiate_connection(self, remote_addr):
        return PyAclManagerAclConnection(self.device, self.acl_stream,
                                         remote_addr, None)

    def accept_connection(self):
        connection_complete = ConnectionCompleteCapture()
        assertThat(self.incoming_connection_stream).emits(connection_complete)
        handle = connection_complete.get().GetConnectionHandle()
        return PyAclManagerAclConnection(self.device, self.acl_stream, None,
                                         handle)
+19 −24
Original line number Diff line number Diff line
@@ -16,24 +16,29 @@

from google.protobuf import empty_pb2 as empty_proto
from cert.event_stream import EventStream
from captures import ReadBdAddrCompleteCapture
from captures import ConnectionCompleteCapture
from captures import ConnectionRequestCapture
from cert.event_stream import FilteringEventStream
from cert.event_stream import IEventStream
from cert.closable import Closable
from cert.closable import safeClose
from cert.captures import ReadBdAddrCompleteCapture
from cert.captures import ConnectionCompleteCapture
from cert.captures import ConnectionRequestCapture
from bluetooth_packets_python3 import hci_packets
from cert.truth import assertThat
from hci.facade import facade_pb2 as hci_facade


class PyHciAclConnection(object):
class PyHciAclConnection(IEventStream):

    def __init__(self, handle, acl_stream, device):
        self.handle = handle
        self.acl_stream = acl_stream
        self.handle = int(handle)
        self.device = device
        # todo, handle we got is 0, so doesn't match - fix before enabling filtering
        self.our_acl_stream = FilteringEventStream(acl_stream, None)

    def send(self, pb_flag, b_flag, data):
        acl_msg = hci_facade.AclMsg(
            handle=int(self.handle),
            handle=self.handle,
            packet_boundary_flag=int(pb_flag),
            broadcast_flag=int(b_flag),
            data=data)
@@ -47,8 +52,11 @@ class PyHciAclConnection(object):
        self.send(hci_packets.PacketBoundaryFlag.CONTINUING_FRAGMENT,
                  hci_packets.BroadcastFlag.POINT_TO_POINT, bytes(data))

    def get_event_queue(self):
        return self.our_acl_stream.get_event_queue()

class PyHci(object):

class PyHci(Closable):

    def __init__(self, device):
        self.device = device
@@ -63,26 +71,13 @@ class PyHci(object):
        self.acl_stream = EventStream(
            self.device.hci.FetchAclPackets(empty_proto.Empty()))

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        self.clean_up()
        return traceback is None

    def __del__(self):
        self.clean_up()

    def clean_up(self):
        self.event_stream.shutdown()
        self.acl_stream.shutdown()
    def close(self):
        safeClose(self.event_stream)
        safeClose(self.acl_stream)

    def get_event_stream(self):
        return self.event_stream

    def get_acl_stream(self):
        return self.acl_stream

    def send_command_with_complete(self, command):
        self.device.hci.send_command_with_complete(command)

Loading