Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit f4736f27 authored by Hansong Zhang's avatar Hansong Zhang
Browse files

Cert: Let EventStream subscribe() block

In EventStream, let subscribe() and unsubscribe() block. Also simplify
the code.

Test: cert/run_cert.sh
Change-Id: I73eaed21fc77a114c8b678a0c86775f125a35cd1
parent b4240ed5
Loading
Loading
Loading
Loading
+29 −33
Original line number Original line Diff line number Diff line
@@ -24,38 +24,45 @@ from grpc import StatusCode


class EventStream(object):
class EventStream(object):


  event_buffer = []

  def __init__(self, stream_stub_fn):
  def __init__(self, stream_stub_fn):
    self.stream_stub_fn = stream_stub_fn
    self.stream_stub_fn = stream_stub_fn
    self.event_buffer = []


  def clear_event_buffer(self):
    self.subscribe_request = common_pb2.EventStreamRequest(
    self.event_buffer.clear()

  def subscribe(self):
    return self.stream_stub_fn(
        common_pb2.EventStreamRequest(
        subscription_mode=common_pb2.SUBSCRIBE,
        subscription_mode=common_pb2.SUBSCRIBE,
        fetch_mode=common_pb2.NONE
        fetch_mode=common_pb2.NONE
    )
    )
    )


  def unsubscribe(self):
    self.unsubscribe_request = common_pb2.EventStreamRequest(
    return self.stream_stub_fn(
        common_pb2.EventStreamRequest(
        subscription_mode=common_pb2.UNSUBSCRIBE,
        subscription_mode=common_pb2.UNSUBSCRIBE,
        fetch_mode=common_pb2.NONE
        fetch_mode=common_pb2.NONE
    )
    )
    )


  def assert_none(self):
    self.fetch_all_current_request = common_pb2.EventStreamRequest(
    response = self.stream_stub_fn(
        subscription_mode=common_pb2.UNCHANGED,
        common_pb2.EventStreamRequest(
            subscription_mode=common_pb2.NONE,
        fetch_mode=common_pb2.ALL_CURRENT
        fetch_mode=common_pb2.ALL_CURRENT
    )
    )

    self.fetch_at_least_one_request = lambda expiration_time : common_pb2.EventStreamRequest(
        subscription_mode=common_pb2.UNCHANGED,
        fetch_mode=common_pb2.AT_LEAST_ONE,
        timeout_ms = int((expiration_time - datetime.now()).total_seconds() * 1000)
    )
    )


  def clear_event_buffer(self):
    self.event_buffer.clear()

  def subscribe(self):
    rpc = self.stream_stub_fn(self.subscribe_request)
    return rpc.result()

  def unsubscribe(self):
    rpc = self.stream_stub_fn(self.unsubscribe_request)
    return rpc.result()

  def assert_none(self):
    response = self.stream_stub_fn(self.fetch_all_current_request)

    try:
    try:
      for event in response:
      for event in response:
        self.event_buffer.append(event)
        self.event_buffer.append(event)
@@ -66,12 +73,7 @@ class EventStream(object):
      asserts.fail("event_buffer is not empty \n%s" % self.event_buffer)
      asserts.fail("event_buffer is not empty \n%s" % self.event_buffer)


  def assert_none_matching(self, match_fn):
  def assert_none_matching(self, match_fn):
    response = self.stream_stub_fn(
    response = self.stream_stub_fn(self.fetch_all_current_request)
        common_pb2.EventStreamRequest(
            subscription_mode=common_pb2.NONE,
            fetch_mode=common_pb2.ALL_CURRENT
        )
    )


    try:
    try:
      for event in response:
      for event in response:
@@ -95,13 +97,7 @@ class EventStream(object):
      if datetime.now() > expiration_time:
      if datetime.now() > expiration_time:
        asserts.fail("timeout of %s exceeded" % str(timeout))
        asserts.fail("timeout of %s exceeded" % str(timeout))


      response = self.stream_stub_fn(
      response = self.stream_stub_fn(self.fetch_at_least_one_request(expiration_time))
          common_pb2.EventStreamRequest(
              subscription_mode=common_pb2.NONE,
              fetch_mode=common_pb2.AT_LEAST_ONE,
              timeout_ms = int((expiration_time - datetime.now()).total_seconds() * 1000)
          )
      )


      try:
      try:
        for event in response:
        for event in response:
+12 −10
Original line number Original line Diff line number Diff line
@@ -52,6 +52,8 @@ class SimpleHalTest(GdBaseTestClass):
        self.device_under_test.hal.SendHciResetCommand(empty_pb2.Empty())
        self.device_under_test.hal.SendHciResetCommand(empty_pb2.Empty())
        self.cert_device.hal.SendHciResetCommand(empty_pb2.Empty())
        self.cert_device.hal.SendHciResetCommand(empty_pb2.Empty())


        self.hci_event_stream = self.device_under_test.hal.hci_event_stream

    def teardown_test(self):
    def teardown_test(self):
        self.device_under_test.rootservice.StopStack(
        self.device_under_test.rootservice.StopStack(
            facade_rootservice_pb2.StopStackRequest()
            facade_rootservice_pb2.StopStackRequest()
@@ -61,11 +63,11 @@ class SimpleHalTest(GdBaseTestClass):
        )
        )


    def test_none_event(self):
    def test_none_event(self):
        self.device_under_test.hal.hci_event_stream.clear_event_buffer()
        self.hci_event_stream.clear_event_buffer()


        self.device_under_test.hal.hci_event_stream.subscribe()
        self.hci_event_stream.subscribe()
        self.device_under_test.hal.hci_event_stream.assert_none()
        self.hci_event_stream.assert_none()
        self.device_under_test.hal.hci_event_stream.unsubscribe()
        self.hci_event_stream.unsubscribe()


    def test_example(self):
    def test_example(self):
        response = self.device_under_test.hal.SetLoopbackMode(
        response = self.device_under_test.hal.SetLoopbackMode(
@@ -77,7 +79,7 @@ class SimpleHalTest(GdBaseTestClass):
            hal_facade_pb2.LoopbackModeSettings(enable=True)
            hal_facade_pb2.LoopbackModeSettings(enable=True)
        )
        )


        self.device_under_test.hal.hci_event_stream.subscribe()
        self.hci_event_stream.subscribe()


        self.device_under_test.hal.SendHciCommand(
        self.device_under_test.hal.SendHciCommand(
            hal_facade_pb2.HciCommandPacket(
            hal_facade_pb2.HciCommandPacket(
@@ -85,13 +87,13 @@ class SimpleHalTest(GdBaseTestClass):
            )
            )
        )
        )


        self.device_under_test.hal.hci_event_stream.assert_event_occurs(
        self.hci_event_stream.assert_event_occurs(
            lambda packet: packet.payload == b'\x19\x08\x01\x04\x053\x8b\x9e0\x01'
            lambda packet: packet.payload == b'\x19\x08\x01\x04\x053\x8b\x9e0\x01'
        )
        )
        self.device_under_test.hal.hci_event_stream.unsubscribe()
        self.hci_event_stream.unsubscribe()


    def test_inquiry_from_dut(self):
    def test_inquiry_from_dut(self):
        self.device_under_test.hal.hci_event_stream.subscribe()
        self.hci_event_stream.subscribe()


        self.cert_device.hal.SetScanMode(
        self.cert_device.hal.SetScanMode(
            hal_cert_pb2.ScanModeSettings(mode=3)
            hal_cert_pb2.ScanModeSettings(mode=3)
@@ -99,8 +101,8 @@ class SimpleHalTest(GdBaseTestClass):
        self.device_under_test.hal.SetInquiry(
        self.device_under_test.hal.SetInquiry(
            hal_facade_pb2.InquirySettings(length=0x30, num_responses=0xff)
            hal_facade_pb2.InquirySettings(length=0x30, num_responses=0xff)
        )
        )
        self.device_under_test.hal.hci_event_stream.assert_event_occurs(
        self.hci_event_stream.assert_event_occurs(
            lambda packet: b'\x02\x0f' in packet.payload
            lambda packet: b'\x02\x0f' in packet.payload
            # Expecting an HCI Event (code 0x02, length 0x0f)
            # Expecting an HCI Event (code 0x02, length 0x0f)
        )
        )
        self.device_under_test.hal.hci_event_stream.unsubscribe()
        self.hci_event_stream.unsubscribe()