Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit ed9134c5 authored by Treehugger Robot's avatar Treehugger Robot Committed by Gerrit Code Review
Browse files

Merge "Cert Test: Simplify server-client stream handling"

parents 8521a143 814de0f5
Loading
Loading
Loading
Loading
+175 −0
Original line number Diff line number Diff line
#!/usr/bin/env python3
#
#   Copyright 2019 - The Android Open Source Project
#
#   Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.

from datetime import datetime, timedelta
import logging
from queue import SimpleQueue, Empty

from acts import asserts

from google.protobuf import text_format

from cert.event_callback_stream import EventCallbackStream

class EventAsserts(object):
    """
    A class that handles various asserts with respect to a gRPC unary stream

    This class must be created before an event happens as events in a
    EventCallbackStream is not sticky and will be lost if you don't subscribe
    to them before generating those events.

    When asserting on sequential events, a single EventAsserts object is enough

    When asserting on simultaneous events, you would need multiple EventAsserts
    objects as each EventAsserts object owns a separate queue that is actively
    being popped as asserted events happen
    """
    DEFAULT_TIMEOUT_SECONDS = 3
    DEFAULT_INCREMENTAL_TIMEOUT_SECONDS = 0.1

    def __init__(self, event_callback_stream):
        if event_callback_stream is None:
            raise ValueError("event_callback_stream cannot be None")
        self.event_callback_stream = event_callback_stream
        self.event_queue = SimpleQueue()
        self.callback = lambda event : self.event_queue.put(event)
        self.event_callback_stream.register_callback(self.callback)

    def __del__(self):
        self.event_callback_stream.unregister_callback(self.callback)

    def assert_none(self, timeout=timedelta(seconds=DEFAULT_TIMEOUT_SECONDS)):
        """
        Assert no event happens within timeout period

        :param timeout: a timedelta object
        :return:
        """
        logging.debug("assert_none")
        try:
            event = self.event_queue.get(timeout=timeout.seconds)
            asserts.assert_equal(event, None,
                                 msg=(
                                     "Expected None, but got %s" % text_format.MessageToString(
                                     event, as_one_line=True)))
        except Empty:
            return

    def assert_none_matching(self, match_fn,
                             timeout=timedelta(seconds=DEFAULT_TIMEOUT_SECONDS)):
        """
        Assert no events where match_fn(event) is True happen within timeout
        period

        :param match_fn: return True/False on match_fn(event)
        :param timeout: a timedelta object
        :return:
        """
        logging.debug("assert_none_matching")
        event = None
        iter_count = 0
        timeout_seconds = timeout.seconds
        while timeout_seconds > 0:
            iter_count += 1
            logging.debug("Waiting for event iteration %d" % iter_count)
            try:
                time_before = datetime.now()
                current_event = self.event_queue.get(
                    timeout=timeout_seconds)
                time_elapsed = datetime.now() - time_before
                timeout_seconds -= time_elapsed.seconds
                if match_fn(current_event):
                    event = current_event
            except Empty:
                continue
        logging.debug(
            "Done waiting for event, got %s" % text_format.MessageToString(
                event, as_one_line=True))
        asserts.assert_equal(event, None,
                             msg=(
                                 "Expected None matching, but got %s" % text_format.MessageToString(
                                 event, as_one_line=True)))

    def assert_event_occurs(self, match_fn, at_least_times=1,
                            timeout=timedelta(seconds=DEFAULT_TIMEOUT_SECONDS)):
        """
        Assert at least |at_least_times| instances of events happen where
        match_fn(event) returns True within timeout period

        :param match_fn: returns True/False on match_fn(event)
        :param timeout: a timedelta object
        :param at_least_times: how many times at least a matching event should
                               happen
        :return:
        """
        logging.debug("assert_event_occurs")
        event = []
        iter_count = 0
        timeout_seconds = timeout.seconds
        while len(event) < at_least_times and timeout_seconds > 0:
            iter_count += 1
            logging.debug("Waiting for event iteration %d" % iter_count)
            try:
                time_before = datetime.now()
                current_event = self.event_queue.get(
                    timeout=timeout_seconds)
                time_elapsed = datetime.now() - time_before
                timeout_seconds -= time_elapsed.seconds
                if match_fn(current_event):
                    event.append(current_event)
            except Empty:
                continue
        logging.debug(
            "Done waiting for event, got %s" % text_format.MessageToString(
                event, as_one_line=True))
        asserts.assert_true(len(event) == at_least_times,
                            msg=("Expected at least %d events, but got %d" % at_least_times, len(event)))

    def assert_event_occurs_at_most(self, match_fn, at_most_times,
                                    timeout=timedelta(seconds=DEFAULT_TIMEOUT_SECONDS)):
        """
        Assert at most |at_most_times| instances of events happen where
        match_fn(event) returns True within timeout period

        :param match_fn: returns True/False on match_fn(event)
        :param at_most_times: how many times at most a matching event should
                               happen
        :param timeout:a timedelta object
        :return:
        """
        logging.debug("assert_event_occurs_at_most")
        event = []
        iter_count = 0
        timeout_seconds = timeout.seconds
        while timeout_seconds > 0:
            iter_count += 1
            logging.debug("Waiting for event iteration %d" % iter_count)
            try:
                time_before = datetime.now()
                current_event = self.event_queue.get(
                    timeout=timeout_seconds)
                time_elapsed = datetime.now() - time_before
                timeout_seconds -= time_elapsed.seconds
                if match_fn(current_event):
                    event.append(current_event)
            except Empty:
                continue
        logging.debug(
            "Done waiting for event, got %s" % text_format.MessageToString(
                event, as_one_line=True))
        asserts.assert_true(len(event) <= at_most_times,
                            msg=("Expected at most %d events, but got %d" % at_most_times, len(event)))
 No newline at end of file
+146 −0
Original line number Diff line number Diff line
#!/usr/bin/env python3
#
#   Copyright 2019 - The Android Open Source Project
#
#   Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.

from concurrent.futures import ThreadPoolExecutor
from grpc import RpcError
from grpc._channel import _Rendezvous
import logging


class EventCallbackStream(object):
    """
    A an object that translate a gRPC stream of events to a Python stream of
    callbacks.

    All callbacks are non-sticky. This means that user will only receive callback
    generated after EventCallbackStream is registered and will not receive any
    callback after EventCallbackStream is unregistered

    You would need a new EventCallbackStream and anything that depends on this
    object once shutdown() is called
    """

    def __init__(self, server_stream_call):
        """
        Construct this object, call the |grpc_lambda| and trigger event_callback on
        the thread used to create this object until |destroy| is called when this
        object can no longer be used
        :param server_stream_call: A server stream call object returned from
                                   calling a gRPC server stream RPC API. The
                                   object must support iterator interface (i.e.
                                   next() method) and the grpc.Call interface
                                   so that we can cancel it
        :param event_callback: callback to be invoked with the only argument as
                               the generated event. The callback will be invoked
                               on a separate thread created within this object
        """
        if server_stream_call is None:
            raise ValueError("server_stream_call must not be None")
        self.server_stream_call = server_stream_call
        self.handlers = []
        self.executor = ThreadPoolExecutor()
        self.future = self.executor.submit(EventCallbackStream._event_loop,
                                           self)

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        self.shutdown()
        return True

    def __del__(self):
        self.shutdown()

    def register_callback(self, callback, matcher_fn=None):
        """
        Register a callback to handle events. Event will be handled by callback
        if matcher_fn(event) returns True

        callback and matcher are registered as a tuple. Hence the same callback
        with different matcher are considered two different handler units. Same
        matcher, but different callback are also considered different handling
        unit

        Callback will be invoked on a ThreadPoolExecutor owned by this
        EventCallbackStream

        :param callback: Will be called as callback(event)
        :param matcher_fn: A boolean function that returns True or False when
                           calling matcher_fn(event), if None, all event will
                           be matched
        """
        if callback is None:
            raise ValueError("callback must not be None")
        self.handlers.append((callback, matcher_fn))

    def unregister_callback(self, callback, matcher_fn=None):
        """
        Unregister callback and matcher_fn from the event stream. Both objects
        must match exactly the ones when calling register_callback()

        :param callback: callback used in register_callback()
        :param matcher_fn: matcher_fn used in register_callback()
        :raises ValueError when (callback, matcher_fn) tuple is not found
        """
        if callback is None:
            raise ValueError("callback must not be None")
        self.handlers.remove((callback, matcher_fn))

    def shutdown(self):
        """
        Stop the gRPC lambda so that event_callback will not be invoked after th
        method returns.

        This object will be useless after this call as there is no way to restart
        the gRPC callback. You would have to create a new EventCallbackStream

        :return: None on success, exception object on failure
        """
        while not self.server_stream_call.done():
            self.server_stream_call.cancel()
        exception_for_return = None
        try:
            result = self.future.result()
            if result:
                logging.warning("Inner loop error %s" % result)
                raise result
        except Exception as exp:
            logging.warning("Exception: %s" % (exp))
            exception_for_return = exp
        self.executor.shutdown()
        return exception_for_return

    def _event_loop(self):
        """
        Main loop for consuming the gRPC stream events.
        Blocks until computation is cancelled
        :return: None on success, exception object on failure
        """
        try:
            for event in self.server_stream_call:
                for (callback, matcher_fn) in self.handlers:
                    if not matcher_fn or matcher_fn(event):
                        callback(event)
            return None
        except RpcError as exp:
            if type(exp) is _Rendezvous:
                if exp.cancelled():
                    logging.debug("Cancelled")
                    return None
                else:
                    logging.warning("Not cancelled")
            return exp

system/gd/cert/event_stream.py

deleted100644 → 0
+0 −111
Original line number Diff line number Diff line
#!/usr/bin/env python3
#
#   Copyright 2019 - The Android Open Source Project
#
#   Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.

from acts import asserts

from facade import common_pb2
from datetime import datetime
from datetime import timedelta
from grpc import RpcError
from grpc import StatusCode

class EventStream(object):

  def __init__(self, stream_stub_fn):
    self.stream_stub_fn = stream_stub_fn
    self.event_buffer = []

    self.subscribe_request = common_pb2.EventStreamRequest(
        subscription_mode=common_pb2.SUBSCRIBE,
        fetch_mode=common_pb2.NONE
    )

    self.unsubscribe_request = common_pb2.EventStreamRequest(
        subscription_mode=common_pb2.UNSUBSCRIBE,
        fetch_mode=common_pb2.NONE
    )

    self.fetch_all_current_request = common_pb2.EventStreamRequest(
        subscription_mode=common_pb2.UNCHANGED,
        fetch_mode=common_pb2.ALL_CURRENT
    )

    self.fetch_at_least_one_request = lambda expiration_time : common_pb2.EventStreamRequest(
        subscription_mode=common_pb2.UNCHANGED,
        fetch_mode=common_pb2.AT_LEAST_ONE,
        timeout_ms = int((expiration_time - datetime.now()).total_seconds() * 1000)
    )

  def clear_event_buffer(self):
    self.event_buffer.clear()

  def subscribe(self):
    rpc = self.stream_stub_fn(self.subscribe_request)
    return rpc.result()

  def unsubscribe(self):
    rpc = self.stream_stub_fn(self.unsubscribe_request)
    return rpc.result()

  def assert_none(self):
    response = self.stream_stub_fn(self.fetch_all_current_request)

    try:
      for event in response:
        self.event_buffer.append(event)
    except RpcError:
        pass

    if len(self.event_buffer) != 0:
      asserts.fail("event_buffer is not empty \n%s" % self.event_buffer)

  def assert_none_matching(self, match_fn):
    response = self.stream_stub_fn(self.fetch_all_current_request)

    try:
      for event in response:
        self.event_buffer.append(event)
    except RpcError:
      pass

    for event in self.event_buffer:
      if match_fn(event):
        asserts.fail("event %s occurs" % event)

  def assert_event_occurs(self, match_fn, timeout=timedelta(seconds=3)):
    expiration_time = datetime.now() + timeout

    while len(self.event_buffer):
      element = self.event_buffer.pop(0)
      if match_fn(element):
        return

    while (True):
      if datetime.now() > expiration_time:
        asserts.fail("timeout of %s exceeded" % str(timeout))

      response = self.stream_stub_fn(self.fetch_at_least_one_request(expiration_time))

      try:
        for event in response:
          if (match_fn(event)):
            for remain_event in response:
              self.event_buffer.append(remain_event)
            return
      except RpcError:
        if response.code() == StatusCode.DEADLINE_EXCEEDED:
          continue
        raise
+0 −10
Original line number Diff line number Diff line
@@ -17,7 +17,6 @@
from gd_device_base import GdDeviceBase
from gd_device_base import replace_vars

from cert.event_stream import EventStream
from cert import rootservice_pb2_grpc as cert_rootservice_pb2_grpc
from hal.cert import api_pb2_grpc as hal_cert_pb2_grpc
from hci.cert import api_pb2_grpc as hci_cert_pb2_grpc
@@ -70,12 +69,3 @@ class GdCertDevice(GdDeviceBase):
        self.controller_read_only_property = cert_rootservice_pb2_grpc.ReadOnlyPropertyStub(self.grpc_channel)
        self.hci = hci_cert_pb2_grpc.AclManagerCertStub(self.grpc_channel)
        self.l2cap = l2cap_cert_pb2_grpc.L2capClassicModuleCertStub(self.grpc_channel)

        # Event streams
        self.hal.hci_event_stream = EventStream(self.hal.FetchHciEvent)
        self.hal.hci_acl_stream = EventStream(self.hal.FetchHciAcl)
        self.hal.hci_sco_stream = EventStream(self.hal.FetchHciSco)
        self.hci.connection_complete_stream = EventStream(self.hci.FetchConnectionComplete)
        self.hci.disconnection_stream = EventStream(self.hci.FetchDisconnection)
        self.hci.connection_failed_stream = EventStream(self.hci.FetchConnectionFailed)
        self.hci.acl_stream = EventStream(self.hci.FetchAclData)
+0 −13
Original line number Diff line number Diff line
@@ -17,7 +17,6 @@
from gd_device_base import GdDeviceBase
from gd_device_base import replace_vars

from cert.event_stream import EventStream
from facade import rootservice_pb2_grpc as facade_rootservice_pb2_grpc
from hal import facade_pb2_grpc as hal_facade_pb2_grpc
from hci import facade_pb2_grpc as hci_facade_pb2_grpc
@@ -74,15 +73,3 @@ class GdDevice(GdDeviceBase):
        self.l2cap = l2cap_facade_pb2_grpc.L2capClassicModuleFacadeStub(self.grpc_channel)
        self.hci_le_advertising_manager = le_advertising_manager_facade_pb2_grpc.LeAdvertisingManagerFacadeStub(self.grpc_channel)
        # Event streams
        self.hal.hci_event_stream = EventStream(self.hal.FetchHciEvent)
        self.hal.hci_acl_stream = EventStream(self.hal.FetchHciAcl)
        self.hal.hci_sco_stream = EventStream(self.hal.FetchHciSco)
        self.hci.connection_complete_stream = EventStream(self.hci.FetchConnectionComplete)
        self.hci.disconnection_stream = EventStream(self.hci.FetchDisconnection)
        self.hci.connection_failed_stream = EventStream(self.hci.FetchConnectionFailed)
        self.hci.acl_stream = EventStream(self.hci.FetchAclData)
        self.hci_classic_security.command_complete_stream = EventStream(self.hci_classic_security.FetchCommandCompleteEvent)
        self.l2cap.packet_stream = EventStream(self.l2cap.FetchL2capData)
        self.l2cap.connection_complete_stream = EventStream(self.l2cap.FetchConnectionComplete)
        self.l2cap.connection_close_stream = EventStream(self.l2cap.FetchConnectionClose)
Loading