Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit ba65d9aa authored by Zach Johnson's avatar Zach Johnson
Browse files

Add Closable, to reduce repeated code

Test: cert/run --host
Change-Id: Ia80f8e398fe4608ecd8d8a5e0b050681378ddf11
parent c8b6f7f4
Loading
Loading
Loading
Loading
+39 −0
Original line number Diff line number Diff line
#!/usr/bin/env python3
#
#   Copyright 2019 - The Android Open Source Project
#
#   Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.

from abc import ABC, abstractmethod


class Closable(ABC):

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        self.close()
        return traceback is None

    def __del__(self):
        self.close()

    @abstractmethod
    def close(self):
        pass


def safeClose(closable):
    if closable is not None:
        closable.close()
+4 −12
Original line number Diff line number Diff line
@@ -26,6 +26,8 @@ from grpc import RpcError

from abc import ABC, abstractmethod

from cert.closable import Closable


class IEventStream(ABC):

@@ -56,7 +58,7 @@ class FilteringEventStream(IEventStream):
DEFAULT_TIMEOUT_SECONDS = 3


class EventStream(IEventStream):
class EventStream(IEventStream, Closable):
    """
    A class that streams events from a gRPC stream, which you can assert on.

@@ -73,20 +75,10 @@ class EventStream(IEventStream):
        self.executor = ThreadPoolExecutor()
        self.future = self.executor.submit(EventStream._event_loop, self)

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        self.shutdown()
        return traceback is None

    def __del__(self):
        self.shutdown()

    def get_event_queue(self):
        return self.event_queue

    def shutdown(self):
    def close(self):
        """
        Stop the gRPC lambda so that event_callback will not be invoked after th
        method returns.
+9 −29
Original line number Diff line number Diff line
@@ -21,13 +21,15 @@ from cert.event_stream import IEventStream
from cert.captures import ReadBdAddrCompleteCapture
from cert.captures import ConnectionCompleteCapture
from cert.captures import ConnectionRequestCapture
from cert.closable import Closable
from cert.closable import safeClose
from bluetooth_packets_python3 import hci_packets
from cert.truth import assertThat
from hci.facade import facade_pb2 as hci_facade
from hci.facade import acl_manager_facade_pb2 as acl_manager_facade


class PyAclManagerAclConnection(IEventStream):
class PyAclManagerAclConnection(IEventStream, Closable):

    def __init__(self, device, acl_stream, remote_addr, handle):
        self.device = device
@@ -45,19 +47,8 @@ class PyAclManagerAclConnection(IEventStream):
        else:
            self.connection_event_stream = None

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        self.clean_up()
        return traceback is None

    def __del__(self):
        self.clean_up()

    def clean_up(self):
        if self.connection_event_stream is not None:
            self.connection_event_stream.shutdown()
    def close(self):
        safeClose(self.connection_event_stream)

    def wait_for_connection_complete(self):
        connection_complete = ConnectionCompleteCapture()
@@ -72,7 +63,7 @@ class PyAclManagerAclConnection(IEventStream):
        return self.our_acl_stream.get_event_queue()


class PyAclManager(object):
class PyAclManager(Closable):

    def __init__(self, device):
        self.device = device
@@ -81,20 +72,9 @@ class PyAclManager(object):
            self.device.hci_acl_manager.FetchAclData(empty_proto.Empty()))
        self.incoming_connection_stream = None

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        self.clean_up()
        return traceback is None

    def __del__(self):
        self.clean_up()

    def clean_up(self):
        self.acl_stream.shutdown()
        if self.incoming_connection_stream is not None:
            self.incoming_connection_stream.shutdown()
    def close(self):
        safeClose(self.acl_stream)
        safeClose(self.incoming_connection_stream)

    def listen_for_incoming_connections(self):
        self.incoming_connection_stream = EventStream(
+6 −14
Original line number Diff line number Diff line
@@ -18,6 +18,8 @@ from google.protobuf import empty_pb2 as empty_proto
from cert.event_stream import EventStream
from cert.event_stream import FilteringEventStream
from cert.event_stream import IEventStream
from cert.closable import Closable
from cert.closable import safeClose
from cert.captures import ReadBdAddrCompleteCapture
from cert.captures import ConnectionCompleteCapture
from cert.captures import ConnectionRequestCapture
@@ -54,7 +56,7 @@ class PyHciAclConnection(IEventStream):
        return self.our_acl_stream.get_event_queue()


class PyHci(object):
class PyHci(Closable):

    def __init__(self, device):
        self.device = device
@@ -69,19 +71,9 @@ class PyHci(object):
        self.acl_stream = EventStream(
            self.device.hci.FetchAclPackets(empty_proto.Empty()))

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        self.clean_up()
        return traceback is None

    def __del__(self):
        self.clean_up()

    def clean_up(self):
        self.event_stream.shutdown()
        self.acl_stream.shutdown()
    def close(self):
        safeClose(self.event_stream)
        safeClose(self.acl_stream)

    def get_event_stream(self):
        return self.event_stream
+2 −2
Original line number Diff line number Diff line
@@ -47,8 +47,8 @@ class AclManagerTest(GdFacadeOnlyBaseTestClass):
        self.dut_acl_manager = PyAclManager(self.dut)

    def teardown_test(self):
        self.cert_hci.clean_up()
        self.dut_acl_manager.clean_up()
        self.cert_hci.close()
        self.dut_acl_manager.close()
        super().teardown_test()

    def test_dut_connects(self):