Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 559a1bff authored by Zach Johnson's avatar Zach Johnson
Browse files

Add default handling for behaviors

Default to crash as a safe choice, but allow
overriding to ignore or to specific default handlers
if the user chooses.

Test: cert/run --host
Change-Id: I1860997d0049a1eef59e3a254184dbe2fdef5171
parent fe071f52
Loading
Loading
Loading
Loading
+23 −0
Original line number Diff line number Diff line
@@ -15,6 +15,7 @@
#   limitations under the License.

from abc import ABC, abstractmethod
from mobly import signals

from cert.truth import assertThat

@@ -35,11 +36,16 @@ def when(has_behaviors):
    return has_behaviors.get_behaviors()


def IGNORE_UNHANDLED(obj):
    pass


class SingleArgumentBehavior(object):

    def __init__(self, reply_stage_factory):
        self._reply_stage_factory = reply_stage_factory
        self._instances = []
        self.set_default_to_crash()

    def begin(self, matcher):
        return PersistenceStage(self, matcher, self._reply_stage_factory)
@@ -47,10 +53,27 @@ class SingleArgumentBehavior(object):
    def append(self, behavior_instance):
        self._instances.append(behavior_instance)

    def set_default(self, fn):
        assertThat(fn).isNotNone()
        self._default_fn = fn

    def set_default_to_crash(self):
        self._default_fn = None

    def set_default_to_ignore(self):
        self._default_fn = IGNORE_UNHANDLED

    def run(self, obj):
        for instance in self._instances:
            if instance.try_run(obj):
                return
        if self._default_fn is not None:
            self._default_fn(obj)
        else:
            raise signals.TestFailure(
                "%s: behavior for %s went unhandled" %
                (self._reply_stage_factory().__class__.__name__, obj),
                extras=None)


class PersistenceStage(object):
+40 −0
Original line number Diff line number Diff line
@@ -122,10 +122,14 @@ class ObjectWithBehaviors(IHasBehaviors):
        self.behaviors = TestBehaviors(self)
        self.count = 0
        self.captured = []
        self.unhandled_count = 0

    def get_behaviors(self):
        return self.behaviors

    def increment_unhandled(self):
        self.unhandled_count += 1


class CertSelfTest(BaseTestClass):

@@ -526,6 +530,7 @@ class CertSelfTest(BaseTestClass):

    def test_fluent_behavior_simple(self):
        thing = ObjectWithBehaviors()

        when(thing).test_request(anything()).then().increment_count()

        thing.behaviors.test_request_behavior.run("A")
@@ -535,6 +540,9 @@ class CertSelfTest(BaseTestClass):

    def test_fluent_behavior__then_single__captures_one(self):
        thing = ObjectWithBehaviors()

        thing.behaviors.test_request_behavior.set_default_to_ignore()

        when(thing).test_request(anything()).then().increment_count()

        thing.behaviors.test_request_behavior.run("A")
@@ -546,6 +554,7 @@ class CertSelfTest(BaseTestClass):

    def test_fluent_behavior__then_times__captures_all(self):
        thing = ObjectWithBehaviors()

        when(thing).test_request(anything()).then(times=3).increment_count()

        thing.behaviors.test_request_behavior.run("A")
@@ -557,6 +566,7 @@ class CertSelfTest(BaseTestClass):

    def test_fluent_behavior__always__captures_all(self):
        thing = ObjectWithBehaviors()

        when(thing).test_request(anything()).always().increment_count()

        thing.behaviors.test_request_behavior.run("A")
@@ -568,6 +578,8 @@ class CertSelfTest(BaseTestClass):

    def test_fluent_behavior__matcher__captures_relevant(self):
        thing = ObjectWithBehaviors()
        thing.behaviors.test_request_behavior.set_default_to_ignore()

        when(thing).test_request(
            lambda obj: obj == "B").always().increment_count()

@@ -580,6 +592,8 @@ class CertSelfTest(BaseTestClass):

    def test_fluent_behavior__then_repeated__captures_relevant(self):
        thing = ObjectWithBehaviors()
        thing.behaviors.test_request_behavior.set_default_to_ignore()

        when(thing).test_request(
            anything()).then().increment_count().increment_count()

@@ -592,6 +606,8 @@ class CertSelfTest(BaseTestClass):

    def test_fluent_behavior__fallback__captures_relevant(self):
        thing = ObjectWithBehaviors()
        thing.behaviors.test_request_behavior.set_default_to_ignore()

        when(thing).test_request(lambda obj: obj == "B").then(
            times=1).increment_count()
        when(thing).test_request(
@@ -606,3 +622,27 @@ class CertSelfTest(BaseTestClass):

        assertThat(thing.count).isEqualTo(3)
        assertThat(thing.captured).isEqualTo(["B", "C", "C"])

    def test_fluent_behavior__default_unhandled_crash(self):
        thing = ObjectWithBehaviors()

        when(thing).test_request(anything()).then().increment_count()

        thing.behaviors.test_request_behavior.run("A")
        try:
            thing.behaviors.test_request_behavior.run("A")
        except Exception as e:
            logging.debug(e)
            return True  # Failed as expected
        return False

    def test_fluent_behavior__set_default_works(self):
        thing = ObjectWithBehaviors()
        thing.behaviors.test_request_behavior.set_default(
            lambda obj: thing.increment_unhandled())

        when(thing).test_request(anything()).then().increment_count()

        thing.behaviors.test_request_behavior.run("A")
        thing.behaviors.test_request_behavior.run("A")
        assertThat(thing.unhandled_count).isEqualTo(1)