Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit d3275232 authored by Zach Johnson's avatar Zach Johnson Committed by Gerrit Code Review
Browse files

Merge changes I56018319,Id324c5b4,I73e50e6c

* changes:
  Add captures, to clean up test logic and make tests easer to read
  Add multi-matcher assert on EventStream, and expose through truth
  Merge EventCallbackStream & EventAsserts into EventStream
parents 7aad4e4b 637e2603
Loading
Loading
Loading
Loading
+41 −0
Original line number Diff line number Diff line
#!/usr/bin/env python3
#
#   Copyright 2020 - The Android Open Source Project
#
#   Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.


class Capture(object):
    """
    Wrap a match function and use in its place, to capture the value
    that matched. Specify an optional |capture_fn| to transform the
    captured value.
    """

    def __init__(self, match_fn, capture_fn=None):
        self._match_fn = match_fn
        self._capture_fn = capture_fn
        self._value = None

    def __call__(self, obj):
        if self._match_fn(obj) != True:
            return False

        if self._capture_fn is not None:
            self._value = self._capture_fn(obj)
        else:
            self._value = obj
        return True

    def get(self):
        return self._value
+83 −64
Original line number Diff line number Diff line
@@ -20,8 +20,7 @@ import time
from mobly import asserts
from datetime import datetime, timedelta
from acts.base_test import BaseTestClass
from cert.event_callback_stream import EventCallbackStream
from cert.event_asserts import EventAsserts
from cert.event_stream import EventStream
from cert.truth import assertThat

# Test packet nesting
@@ -96,48 +95,40 @@ class CertSelfTest(BaseTestClass):
        return True

    def test_assert_none_passes(self):
        with EventCallbackStream(FetchEvents(events=[],
                                             delay_ms=50)) as event_stream:
            event_asserts = EventAsserts(event_stream)
            event_asserts.assert_none(timeout=timedelta(milliseconds=10))
        with EventStream(FetchEvents(events=[], delay_ms=50)) as event_stream:
            event_stream.assert_none(timeout=timedelta(milliseconds=10))

    def test_assert_none_passes_after_one_second(self):
        with EventCallbackStream(FetchEvents([1],
                                             delay_ms=1500)) as event_stream:
            event_asserts = EventAsserts(event_stream)
            event_asserts.assert_none(timeout=timedelta(seconds=1.0))
        with EventStream(FetchEvents([1], delay_ms=1500)) as event_stream:
            event_stream.assert_none(timeout=timedelta(seconds=1.0))

    def test_assert_none_fails(self):
        try:
            with EventCallbackStream(FetchEvents(events=[17],
            with EventStream(FetchEvents(events=[17],
                                         delay_ms=50)) as event_stream:
                event_asserts = EventAsserts(event_stream)
                event_asserts.assert_none(timeout=timedelta(seconds=1))
                event_stream.assert_none(timeout=timedelta(seconds=1))
        except Exception as e:
            logging.debug(e)
            return True  # Failed as expected
        return False

    def test_assert_none_matching_passes(self):
        with EventCallbackStream(FetchEvents(events=[1, 2, 3],
        with EventStream(FetchEvents(events=[1, 2, 3],
                                     delay_ms=50)) as event_stream:
            event_asserts = EventAsserts(event_stream)
            event_asserts.assert_none_matching(
            event_stream.assert_none_matching(
                lambda data: data.value_ == 4, timeout=timedelta(seconds=0.15))

    def test_assert_none_matching_passes_after_1_second(self):
        with EventCallbackStream(
                FetchEvents(events=[1, 2, 3, 4], delay_ms=400)) as event_stream:
            event_asserts = EventAsserts(event_stream)
            event_asserts.assert_none_matching(
        with EventStream(FetchEvents(events=[1, 2, 3, 4],
                                     delay_ms=400)) as event_stream:
            event_stream.assert_none_matching(
                lambda data: data.value_ == 4, timeout=timedelta(seconds=1))

    def test_assert_none_matching_fails(self):
        try:
            with EventCallbackStream(
                    FetchEvents(events=[1, 2, 3], delay_ms=50)) as event_stream:
                event_asserts = EventAsserts(event_stream)
                event_asserts.assert_none_matching(
            with EventStream(FetchEvents(events=[1, 2, 3],
                                         delay_ms=50)) as event_stream:
                event_stream.assert_none_matching(
                    lambda data: data.value_ == 2, timeout=timedelta(seconds=1))
        except Exception as e:
            logging.debug(e)
@@ -145,28 +136,24 @@ class CertSelfTest(BaseTestClass):
        return False

    def test_assert_occurs_at_least_passes(self):
        with EventCallbackStream(
                FetchEvents(events=[1, 2, 3, 1, 2, 3],
        with EventStream(FetchEvents(events=[1, 2, 3, 1, 2, 3],
                                     delay_ms=40)) as event_stream:
            event_asserts = EventAsserts(event_stream)
            event_asserts.assert_event_occurs(
            event_stream.assert_event_occurs(
                lambda data: data.value_ == 1,
                timeout=timedelta(milliseconds=300),
                at_least_times=2)

    def test_assert_occurs_passes(self):
        with EventCallbackStream(FetchEvents(events=[1, 2, 3],
        with EventStream(FetchEvents(events=[1, 2, 3],
                                     delay_ms=50)) as event_stream:
            event_asserts = EventAsserts(event_stream)
            event_asserts.assert_event_occurs(
            event_stream.assert_event_occurs(
                lambda data: data.value_ == 1, timeout=timedelta(seconds=1))

    def test_assert_occurs_fails(self):
        try:
            with EventCallbackStream(
                    FetchEvents(events=[1, 2, 3], delay_ms=50)) as event_stream:
                event_asserts = EventAsserts(event_stream)
                event_asserts.assert_event_occurs(
            with EventStream(FetchEvents(events=[1, 2, 3],
                                         delay_ms=50)) as event_stream:
                event_stream.assert_event_occurs(
                    lambda data: data.value_ == 4, timeout=timedelta(seconds=1))
        except Exception as e:
            logging.debug(e)
@@ -174,21 +161,18 @@ class CertSelfTest(BaseTestClass):
        return False

    def test_assert_occurs_at_most_passes(self):
        with EventCallbackStream(FetchEvents(events=[1, 2, 3, 4],
        with EventStream(FetchEvents(events=[1, 2, 3, 4],
                                     delay_ms=50)) as event_stream:
            event_asserts = EventAsserts(event_stream)
            event_asserts.assert_event_occurs_at_most(
            event_stream.assert_event_occurs_at_most(
                lambda data: data.value_ < 4,
                timeout=timedelta(seconds=1),
                at_most_times=3)

    def test_assert_occurs_at_most_fails(self):
        try:
            with EventCallbackStream(
                    FetchEvents(events=[1, 2, 3, 4],
            with EventStream(FetchEvents(events=[1, 2, 3, 4],
                                         delay_ms=50)) as event_stream:
                event_asserts = EventAsserts(event_stream)
                event_asserts.assert_event_occurs_at_most(
                event_stream.assert_event_occurs_at_most(
                    lambda data: data.value_ > 1,
                    timeout=timedelta(seconds=1),
                    at_most_times=2)
@@ -286,24 +270,21 @@ class CertSelfTest(BaseTestClass):
        return False

    def test_assertThat_eventStream_emits_passes(self):
        with EventCallbackStream(FetchEvents(events=[1, 2, 3],
        with EventStream(FetchEvents(events=[1, 2, 3],
                                     delay_ms=50)) as event_stream:
            event_asserts = EventAsserts(event_stream)
            assertThat(event_asserts).emits(lambda data: data.value_ == 1)
            assertThat(event_stream).emits(lambda data: data.value_ == 1)

    def test_assertThat_eventStream_emits_then_passes(self):
        with EventCallbackStream(FetchEvents(events=[1, 2, 3],
        with EventStream(FetchEvents(events=[1, 2, 3],
                                     delay_ms=50)) as event_stream:
            event_asserts = EventAsserts(event_stream)
            assertThat(event_asserts).emits(lambda data: data.value_ == 1).then(
            assertThat(event_stream).emits(lambda data: data.value_ == 1).then(
                lambda data: data.value_ == 3)

    def test_assertThat_eventStream_emits_fails(self):
        try:
            with EventCallbackStream(
                    FetchEvents(events=[1, 2, 3], delay_ms=50)) as event_stream:
                event_asserts = EventAsserts(event_stream)
                assertThat(event_asserts).emits(lambda data: data.value_ == 4)
            with EventStream(FetchEvents(events=[1, 2, 3],
                                         delay_ms=50)) as event_stream:
                assertThat(event_stream).emits(lambda data: data.value_ == 4)
        except Exception as e:
            logging.debug(e)
            return True  # Failed as expected
@@ -311,13 +292,51 @@ class CertSelfTest(BaseTestClass):

    def test_assertThat_eventStream_emits_then_fails(self):
        try:
            with EventCallbackStream(
                    FetchEvents(events=[1, 2, 3], delay_ms=50)) as event_stream:
                event_asserts = EventAsserts(event_stream)
                assertThat(event_asserts).emits(
            with EventStream(FetchEvents(events=[1, 2, 3],
                                         delay_ms=50)) as event_stream:
                assertThat(event_stream).emits(
                    lambda data: data.value_ == 1).emits(
                        lambda data: data.value_ == 4)
        except Exception as e:
            logging.debug(e)
            return True  # Failed as expected
        return False

    def test_assertThat_eventStream_emitsInOrder_passes(self):
        with EventStream(FetchEvents(events=[1, 2, 3],
                                     delay_ms=50)) as event_stream:
            assertThat(event_stream).emits(
                lambda data: data.value_ == 1,
                lambda data: data.value_ == 2).inOrder()

    def test_assertThat_eventStream_emitsInAnyOrder_passes(self):
        with EventStream(FetchEvents(events=[1, 2, 3],
                                     delay_ms=50)) as event_stream:
            assertThat(event_stream).emits(
                lambda data: data.value_ == 2,
                lambda data: data.value_ == 1).inAnyOrder().then(
                    lambda data: data.value_ == 3)

    def test_assertThat_eventStream_emitsInOrder_fails(self):
        try:
            with EventStream(FetchEvents(events=[1, 2, 3],
                                         delay_ms=50)) as event_stream:
                assertThat(event_stream).emits(
                    lambda data: data.value_ == 2,
                    lambda data: data.value_ == 1).inOrder()
        except Exception as e:
            logging.debug(e)
            return True  # Failed as expected
        return False

    def test_assertThat_eventStream_emitsInAnyOrder_fails(self):
        try:
            with EventStream(FetchEvents(events=[1, 2, 3],
                                         delay_ms=50)) as event_stream:
                assertThat(event_stream).emits(
                    lambda data: data.value_ == 4,
                    lambda data: data.value_ == 1).inAnyOrder()
        except Exception as e:
            logging.debug(e)
            return True  # Failed as expected
        return False
+0 −147
Original line number Diff line number Diff line
#!/usr/bin/env python3
#
#   Copyright 2019 - The Android Open Source Project
#
#   Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.

from concurrent.futures import ThreadPoolExecutor
from grpc import RpcError
import logging


class EventCallbackStream(object):
    """
    A an object that translate a gRPC stream of events to a Python stream of
    callbacks.

    All callbacks are non-sticky. This means that user will only receive callback
    generated after EventCallbackStream is registered and will not receive any
    callback after EventCallbackStream is unregistered

    You would need a new EventCallbackStream and anything that depends on this
    object once shutdown() is called
    """

    def __init__(self, server_stream_call):
        """
        Construct this object, call the |grpc_lambda| and trigger event_callback on
        the thread used to create this object until |destroy| is called when this
        object can no longer be used
        :param server_stream_call: A server stream call object returned from
                                   calling a gRPC server stream RPC API. The
                                   object must support iterator interface (i.e.
                                   next() method) and the grpc.Call interface
                                   so that we can cancel it
        :param event_callback: callback to be invoked with the only argument as
                               the generated event. The callback will be invoked
                               on a separate thread created within this object
        """
        if server_stream_call is None:
            raise ValueError("server_stream_call must not be None")
        self.server_stream_call = server_stream_call
        self.handlers = []
        self.executor = ThreadPoolExecutor()
        self.future = self.executor.submit(EventCallbackStream._event_loop,
                                           self)

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        self.shutdown()
        if traceback is None:
            return True
        else:
            return False

    def __del__(self):
        self.shutdown()

    def register_callback(self, callback, matcher_fn=None):
        """
        Register a callback to handle events. Event will be handled by callback
        if matcher_fn(event) returns True

        callback and matcher are registered as a tuple. Hence the same callback
        with different matcher are considered two different handler units. Same
        matcher, but different callback are also considered different handling
        unit

        Callback will be invoked on a ThreadPoolExecutor owned by this
        EventCallbackStream

        :param callback: Will be called as callback(event)
        :param matcher_fn: A boolean function that returns True or False when
                           calling matcher_fn(event), if None, all event will
                           be matched
        """
        if callback is None:
            raise ValueError("callback must not be None")
        self.handlers.append((callback, matcher_fn))

    def unregister_callback(self, callback, matcher_fn=None):
        """
        Unregister callback and matcher_fn from the event stream. Both objects
        must match exactly the ones when calling register_callback()

        :param callback: callback used in register_callback()
        :param matcher_fn: matcher_fn used in register_callback()
        :raises ValueError when (callback, matcher_fn) tuple is not found
        """
        if callback is None:
            raise ValueError("callback must not be None")
        self.handlers.remove((callback, matcher_fn))

    def shutdown(self):
        """
        Stop the gRPC lambda so that event_callback will not be invoked after th
        method returns.

        This object will be useless after this call as there is no way to restart
        the gRPC callback. You would have to create a new EventCallbackStream

        :return: None on success, exception object on failure
        """
        while not self.server_stream_call.done():
            self.server_stream_call.cancel()
        exception_for_return = None
        try:
            result = self.future.result()
            if result:
                logging.warning("Inner loop error %s" % result)
                raise result
        except Exception as exp:
            logging.warning("Exception: %s" % (exp))
            exception_for_return = exp
        self.executor.shutdown()
        return exception_for_return

    def _event_loop(self):
        """
        Main loop for consuming the gRPC stream events.
        Blocks until computation is cancelled
        :return: None on success, exception object on failure
        """
        try:
            for event in self.server_stream_call:
                for (callback, matcher_fn) in self.handlers:
                    if not matcher_fn or matcher_fn(event):
                        callback(event)
            return None
        except RpcError as exp:
            if self.server_stream_call.cancelled():
                logging.debug("Cancelled")
                return None
            else:
                logging.warning("Some RPC error not due to cancellation")
            return exp
+140 −18
Original line number Diff line number Diff line
@@ -21,34 +21,37 @@ from queue import SimpleQueue, Empty
from mobly import asserts

from google.protobuf import text_format
from concurrent.futures import ThreadPoolExecutor
from grpc import RpcError


class EventAsserts(object):
class EventStream(object):
    """
    A class that handles various asserts with respect to a gRPC unary stream
    A class that streams events from a gRPC stream, which you can assert on.

    This class must be created before an event happens as events in a
    EventCallbackStream is not sticky and will be lost if you don't subscribe
    to them before generating those events.

    When asserting on sequential events, a single EventAsserts object is enough

    When asserting on simultaneous events, you would need multiple EventAsserts
    objects as each EventAsserts object owns a separate queue that is actively
    being popped as asserted events happen
    Don't use these asserts directly, use the ones from cert.truth.
    """
    DEFAULT_TIMEOUT_SECONDS = 3

    def __init__(self, event_callback_stream):
        if event_callback_stream is None:
            raise ValueError("event_callback_stream cannot be None")
        self.event_callback_stream = event_callback_stream
    def __init__(self, server_stream_call):
        if server_stream_call is None:
            raise ValueError("server_stream_call cannot be None")

        self.server_stream_call = server_stream_call
        self.event_queue = SimpleQueue()
        self.callback = lambda event: self.event_queue.put(event)
        self.event_callback_stream.register_callback(self.callback)
        self.handlers = []
        self.executor = ThreadPoolExecutor()
        self.future = self.executor.submit(EventStream._event_loop, self)

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        self.shutdown()
        return traceback is None

    def __del__(self):
        self.event_callback_stream.unregister_callback(self.callback)
        self.shutdown()

    def remaining_time_delta(self, end_time):
        remaining = end_time - datetime.now()
@@ -56,6 +59,86 @@ class EventAsserts(object):
            remaining = timedelta(milliseconds=0)
        return remaining

    def shutdown(self):
        """
        Stop the gRPC lambda so that event_callback will not be invoked after th
        method returns.

        This object will be useless after this call as there is no way to restart
        the gRPC callback. You would have to create a new EventStream

        :return: None on success, exception object on failure
        """
        while not self.server_stream_call.done():
            self.server_stream_call.cancel()
        exception_for_return = None
        try:
            result = self.future.result()
            if result:
                logging.warning("Inner loop error %s" % result)
                raise result
        except Exception as exp:
            logging.warning("Exception: %s" % (exp))
            exception_for_return = exp
        self.executor.shutdown()
        return exception_for_return

    def register_callback(self, callback, matcher_fn=None):
        """
        Register a callback to handle events. Event will be handled by callback
        if matcher_fn(event) returns True

        callback and matcher are registered as a tuple. Hence the same callback
        with different matcher are considered two different handler units. Same
        matcher, but different callback are also considered different handling
        unit

        Callback will be invoked on a ThreadPoolExecutor owned by this
        EventStream

        :param callback: Will be called as callback(event)
        :param matcher_fn: A boolean function that returns True or False when
                           calling matcher_fn(event), if None, all event will
                           be matched
        """
        if callback is None:
            raise ValueError("callback must not be None")
        self.handlers.append((callback, matcher_fn))

    def unregister_callback(self, callback, matcher_fn=None):
        """
        Unregister callback and matcher_fn from the event stream. Both objects
        must match exactly the ones when calling register_callback()

        :param callback: callback used in register_callback()
        :param matcher_fn: matcher_fn used in register_callback()
        :raises ValueError when (callback, matcher_fn) tuple is not found
        """
        if callback is None:
            raise ValueError("callback must not be None")
        self.handlers.remove((callback, matcher_fn))

    def _event_loop(self):
        """
        Main loop for consuming the gRPC stream events.
        Blocks until computation is cancelled
        :return: None on success, exception object on failure
        """
        try:
            for event in self.server_stream_call:
                self.event_queue.put(event)
                for (callback, matcher_fn) in self.handlers:
                    if not matcher_fn or matcher_fn(event):
                        callback(event)
            return None
        except RpcError as exp:
            if self.server_stream_call.cancelled():
                logging.debug("Cancelled")
                return None
            else:
                logging.warning("Some RPC error not due to cancellation")
            return exp

    def assert_none(self, timeout=timedelta(seconds=DEFAULT_TIMEOUT_SECONDS)):
        """
        Assert no event happens within timeout period
@@ -174,3 +257,42 @@ class EventAsserts(object):
            len(event_list) <= at_most_times,
            msg=("Expected at most %d events, but got %d" % (at_most_times,
                                                             len(event_list))))

    def assert_all_events_occur(
            self,
            match_fns,
            order_matters,
            timeout=timedelta(seconds=DEFAULT_TIMEOUT_SECONDS)):
        logging.debug("assert_all_events_occur %fs" % timeout.total_seconds())
        pending_matches = list(match_fns)
        matched_order = []
        end_time = datetime.now() + timeout
        while len(pending_matches) > 0 and datetime.now() < end_time:
            remaining = self.remaining_time_delta(end_time)
            logging.debug("Waiting for event (%fs remaining)" %
                          (remaining.total_seconds()))
            try:
                current_event = self.event_queue.get(
                    timeout=remaining.total_seconds())
                for match_fn in pending_matches:
                    if match_fn(current_event):
                        pending_matches.remove(match_fn)
                        matched_order.append(match_fn)
            except Empty:
                continue
        logging.debug("Done waiting for event")
        asserts.assert_true(
            len(matched_order) == len(match_fns),
            msg=("Expected at least %d events, but got %d" %
                 (len(match_fns), len(matched_order))))
        if order_matters:
            correct_order = True
            i = 0
            while i < len(match_fns):
                if match_fns[i] is not matched_order[i]:
                    correct_order = False
                    break
                i += 1
            asserts.assert_true(
                correct_order, "Events not received in correct order %s %s" %
                (match_fns, matched_order))
+35 −8

File changed.

Preview size limit exceeded, changes collapsed.

Loading