Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit dd33d1d7 authored by Automerger Merge Worker's avatar Automerger Merge Worker
Browse files

Merge changes Iafd263c7,I29300b6b,If8422986,Ied877cd6,If30d8ca4, ... am: 3cc6444f am: 14050d9d

Change-Id: Ieb5ad133f9a1a0e37d3adda48d7cfcb1ae66c099
parents df60905c 14050d9d
Loading
Loading
Loading
Loading
+39 −0
Original line number Diff line number Diff line
#!/usr/bin/env python3
#
#   Copyright 2019 - The Android Open Source Project
#
#   Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.

from abc import ABC, abstractmethod


class Closable(ABC):

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        self.close()
        return traceback is None

    def __del__(self):
        self.close()

    @abstractmethod
    def close(self):
        pass


def safeClose(closable):
    if closable is not None:
        closable.close()
+115 −72
Original line number Diff line number Diff line
@@ -24,14 +24,46 @@ from google.protobuf import text_format
from concurrent.futures import ThreadPoolExecutor
from grpc import RpcError

from abc import ABC, abstractmethod

class EventStream(object):
from cert.closable import Closable


class IEventStream(ABC):

    @abstractmethod
    def get_event_queue(self):
        pass


class FilteringEventStream(IEventStream):

    def __init__(self, stream, filter_fn):
        self.filter_fn = filter_fn
        self.event_queue = SimpleQueue()
        self.stream = stream

        self.stream.register_callback(self.__event_callback, self.filter_fn)

    def __event_callback(self, event):
        self.event_queue.put(event)

    def get_event_queue(self):
        return self.event_queue

    def unregister(self):
        self.stream.unregister(self.__event_callback)


DEFAULT_TIMEOUT_SECONDS = 3


class EventStream(IEventStream, Closable):
    """
    A class that streams events from a gRPC stream, which you can assert on.

    Don't use these asserts directly, use the ones from cert.truth.
    """
    DEFAULT_TIMEOUT_SECONDS = 3

    def __init__(self, server_stream_call):
        if server_stream_call is None:
@@ -43,23 +75,10 @@ class EventStream(object):
        self.executor = ThreadPoolExecutor()
        self.future = self.executor.submit(EventStream._event_loop, self)

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        self.shutdown()
        return traceback is None

    def __del__(self):
        self.shutdown()
    def get_event_queue(self):
        return self.event_queue

    def remaining_time_delta(self, end_time):
        remaining = end_time - datetime.now()
        if remaining < timedelta(milliseconds=0):
            remaining = timedelta(milliseconds=0)
        return remaining

    def shutdown(self):
    def close(self):
        """
        Stop the gRPC lambda so that event_callback will not be invoked after th
        method returns.
@@ -170,7 +189,7 @@ class EventStream(object):
        event = None
        end_time = datetime.now() + timeout
        while event is None and datetime.now() < end_time:
            remaining = self.remaining_time_delta(end_time)
            remaining = static_remaining_time_delta(end_time)
            logging.debug("Waiting for event (%fs remaining)" %
                          (remaining.total_seconds()))
            try:
@@ -202,26 +221,7 @@ class EventStream(object):
                               happen
        :return:
        """
        logging.debug("assert_event_occurs %d %fs" % (at_least_times,
                                                      timeout.total_seconds()))
        event_list = []
        end_time = datetime.now() + timeout
        while len(event_list) < at_least_times and datetime.now() < end_time:
            remaining = self.remaining_time_delta(end_time)
            logging.debug("Waiting for event (%fs remaining)" %
                          (remaining.total_seconds()))
            try:
                current_event = self.event_queue.get(
                    timeout=remaining.total_seconds())
                if match_fn(current_event):
                    event_list.append(current_event)
            except Empty:
                continue
        logging.debug("Done waiting for event")
        asserts.assert_true(
            len(event_list) >= at_least_times,
            msg=("Expected at least %d events, but got %d" % (at_least_times,
                                                              len(event_list))))
        NOT_FOR_YOU_assert_event_occurs(self, match_fn, at_least_times, timeout)

    def assert_event_occurs_at_most(
            self,
@@ -242,7 +242,7 @@ class EventStream(object):
        event_list = []
        end_time = datetime.now() + timeout
        while len(event_list) <= at_most_times and datetime.now() < end_time:
            remaining = self.remaining_time_delta(end_time)
            remaining = static_remaining_time_delta(end_time)
            logging.debug("Waiting for event iteration (%fs remaining)" %
                          (remaining.total_seconds()))
            try:
@@ -263,16 +263,59 @@ class EventStream(object):
            match_fns,
            order_matters,
            timeout=timedelta(seconds=DEFAULT_TIMEOUT_SECONDS)):
        NOT_FOR_YOU_assert_all_events_occur(self, match_fns, order_matters,
                                            timeout)


def static_remaining_time_delta(end_time):
    remaining = end_time - datetime.now()
    if remaining < timedelta(milliseconds=0):
        remaining = timedelta(milliseconds=0)
    return remaining


def NOT_FOR_YOU_assert_event_occurs(
        istream,
        match_fn,
        at_least_times=1,
        timeout=timedelta(seconds=DEFAULT_TIMEOUT_SECONDS)):
    logging.debug("assert_event_occurs %d %fs" % (at_least_times,
                                                  timeout.total_seconds()))
    event_list = []
    end_time = datetime.now() + timeout
    while len(event_list) < at_least_times and datetime.now() < end_time:
        remaining = static_remaining_time_delta(end_time)
        logging.debug(
            "Waiting for event (%fs remaining)" % (remaining.total_seconds()))
        try:
            current_event = istream.get_event_queue().get(
                timeout=remaining.total_seconds())
            if match_fn(current_event):
                event_list.append(current_event)
        except Empty:
            continue
    logging.debug("Done waiting for event")
    asserts.assert_true(
        len(event_list) >= at_least_times,
        msg=("Expected at least %d events, but got %d" % (at_least_times,
                                                          len(event_list))))


def NOT_FOR_YOU_assert_all_events_occur(
        istream,
        match_fns,
        order_matters,
        timeout=timedelta(seconds=DEFAULT_TIMEOUT_SECONDS)):
    logging.debug("assert_all_events_occur %fs" % timeout.total_seconds())
    pending_matches = list(match_fns)
    matched_order = []
    end_time = datetime.now() + timeout
    while len(pending_matches) > 0 and datetime.now() < end_time:
            remaining = self.remaining_time_delta(end_time)
            logging.debug("Waiting for event (%fs remaining)" %
                          (remaining.total_seconds()))
        remaining = static_remaining_time_delta(end_time)
        logging.debug(
            "Waiting for event (%fs remaining)" % (remaining.total_seconds()))
        try:
                current_event = self.event_queue.get(
            current_event = istream.get_event_queue().get(
                timeout=remaining.total_seconds())
            for match_fn in pending_matches:
                if match_fn(current_event):
@@ -283,8 +326,8 @@ class EventStream(object):
    logging.debug("Done waiting for event")
    asserts.assert_true(
        len(matched_order) == len(match_fns),
            msg=("Expected at least %d events, but got %d" %
                 (len(match_fns), len(matched_order))))
        msg=("Expected at least %d events, but got %d" % (len(match_fns),
                                                          len(matched_order))))
    if order_matters:
        correct_order = True
        i = 0
+100 −0
Original line number Diff line number Diff line
#!/usr/bin/env python3
#
#   Copyright 2020 - The Android Open Source Project
#
#   Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.

from google.protobuf import empty_pb2 as empty_proto
from cert.event_stream import EventStream
from cert.event_stream import FilteringEventStream
from cert.event_stream import IEventStream
from cert.captures import ReadBdAddrCompleteCapture
from cert.captures import ConnectionCompleteCapture
from cert.captures import ConnectionRequestCapture
from cert.closable import Closable
from cert.closable import safeClose
from bluetooth_packets_python3 import hci_packets
from cert.truth import assertThat
from hci.facade import facade_pb2 as hci_facade
from hci.facade import acl_manager_facade_pb2 as acl_manager_facade


class PyAclManagerAclConnection(IEventStream, Closable):

    def __init__(self, device, acl_stream, remote_addr, handle):
        self.device = device
        self.handle = handle
        # todo enable filtering after sorting out handles
        self.our_acl_stream = FilteringEventStream(acl_stream, None)

        if remote_addr:
            remote_addr_bytes = bytes(
                remote_addr,
                'utf8') if type(remote_addr) is str else bytes(remote_addr)
            self.connection_event_stream = EventStream(
                self.device.hci_acl_manager.CreateConnection(
                    acl_manager_facade.ConnectionMsg(
                        address_type=int(
                            hci_packets.AddressType.PUBLIC_DEVICE_ADDRESS),
                        address=remote_addr_bytes)))
        else:
            self.connection_event_stream = None

    def close(self):
        safeClose(self.connection_event_stream)

    def wait_for_connection_complete(self):
        connection_complete = ConnectionCompleteCapture()
        assertThat(self.connection_event_stream).emits(connection_complete)
        self.handle = connection_complete.get().GetConnectionHandle()

    def send(self, data):
        self.device.hci_acl_manager.SendAclData(
            acl_manager_facade.AclData(handle=self.handle, payload=bytes(data)))

    def get_event_queue(self):
        return self.our_acl_stream.get_event_queue()


class PyAclManager(Closable):

    def __init__(self, device):
        self.device = device

        self.acl_stream = EventStream(
            self.device.hci_acl_manager.FetchAclData(empty_proto.Empty()))
        self.incoming_connection_stream = None

    def close(self):
        safeClose(self.acl_stream)
        safeClose(self.incoming_connection_stream)

    # temporary, until everyone is migrated
    def get_acl_stream(self):
        return self.acl_stream

    def listen_for_incoming_connections(self):
        self.incoming_connection_stream = EventStream(
            self.device.hci_acl_manager.FetchIncomingConnection(
                empty_proto.Empty()))

    def initiate_connection(self, remote_addr):
        return PyAclManagerAclConnection(self.device, self.acl_stream,
                                         remote_addr, None)

    def accept_connection(self):
        connection_complete = ConnectionCompleteCapture()
        assertThat(self.incoming_connection_stream).emits(connection_complete)
        handle = connection_complete.get().GetConnectionHandle()
        return PyAclManagerAclConnection(self.device, self.acl_stream, None,
                                         handle)
+19 −24
Original line number Diff line number Diff line
@@ -16,24 +16,29 @@

from google.protobuf import empty_pb2 as empty_proto
from cert.event_stream import EventStream
from captures import ReadBdAddrCompleteCapture
from captures import ConnectionCompleteCapture
from captures import ConnectionRequestCapture
from cert.event_stream import FilteringEventStream
from cert.event_stream import IEventStream
from cert.closable import Closable
from cert.closable import safeClose
from cert.captures import ReadBdAddrCompleteCapture
from cert.captures import ConnectionCompleteCapture
from cert.captures import ConnectionRequestCapture
from bluetooth_packets_python3 import hci_packets
from cert.truth import assertThat
from hci.facade import facade_pb2 as hci_facade


class PyHciAclConnection(object):
class PyHciAclConnection(IEventStream):

    def __init__(self, handle, acl_stream, device):
        self.handle = handle
        self.acl_stream = acl_stream
        self.handle = int(handle)
        self.device = device
        # todo, handle we got is 0, so doesn't match - fix before enabling filtering
        self.our_acl_stream = FilteringEventStream(acl_stream, None)

    def send(self, pb_flag, b_flag, data):
        acl_msg = hci_facade.AclMsg(
            handle=int(self.handle),
            handle=self.handle,
            packet_boundary_flag=int(pb_flag),
            broadcast_flag=int(b_flag),
            data=data)
@@ -47,8 +52,11 @@ class PyHciAclConnection(object):
        self.send(hci_packets.PacketBoundaryFlag.CONTINUING_FRAGMENT,
                  hci_packets.BroadcastFlag.POINT_TO_POINT, bytes(data))

    def get_event_queue(self):
        return self.our_acl_stream.get_event_queue()

class PyHci(object):

class PyHci(Closable):

    def __init__(self, device):
        self.device = device
@@ -63,26 +71,13 @@ class PyHci(object):
        self.acl_stream = EventStream(
            self.device.hci.FetchAclPackets(empty_proto.Empty()))

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        self.clean_up()
        return traceback is None

    def __del__(self):
        self.clean_up()

    def clean_up(self):
        self.event_stream.shutdown()
        self.acl_stream.shutdown()
    def close(self):
        safeClose(self.event_stream)
        safeClose(self.acl_stream)

    def get_event_stream(self):
        return self.event_stream

    def get_acl_stream(self):
        return self.acl_stream

    def send_command_with_complete(self, command):
        self.device.hci.send_command_with_complete(command)

Loading