Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit ec72f555 authored by Zach Johnson's avatar Zach Johnson
Browse files

Start migrating to channel abstractions

Filter out packets by channel id, so matchers don't have to
know about them and tests can be clearer.

Test: cert/run --host --test_filter=L2capTest
Change-Id: Id2a749f3b3b543a8b5a9a17c8db16cb97f447a33
parent f52a0fc9
Loading
Loading
Loading
Loading
+4 −2
Original line number Diff line number Diff line
@@ -43,10 +43,12 @@ class FilteringEventStream(IEventStream):
        self.event_queue = SimpleQueue()
        self.stream = stream

        self.stream.register_callback(self.__event_callback, self.filter_fn)
        self.stream.register_callback(
            self.__event_callback,
            lambda packet: self.filter_fn(packet) is not None)

    def __event_callback(self, event):
        self.event_queue.put(event)
        self.event_queue.put(self.filter_fn(event))

    def get_event_queue(self):
        return self.event_queue
+32 −18
Original line number Diff line number Diff line
@@ -18,6 +18,7 @@ import bluetooth_packets_python3 as bt_packets
from bluetooth_packets_python3 import l2cap_packets
from bluetooth_packets_python3.l2cap_packets import CommandCode
from bluetooth_packets_python3.l2cap_packets import ConnectionResponseResult
import logging


class L2capMatchers(object):
@@ -51,12 +52,25 @@ class L2capMatchers(object):
        return lambda packet: L2capMatchers._is_control_frame_with_code(packet, CommandCode.COMMAND_REJECT)

    @staticmethod
    def SupervisoryFrame(scid, req_seq=None, f=None, s=None, p=None):
        return lambda packet: L2capMatchers._is_matching_supervisory_frame(packet, scid, req_seq, f, s, p)
    def SupervisoryFrame(req_seq=None, f=None, s=None, p=None):
        return lambda packet: L2capMatchers._is_matching_supervisory_frame(packet, req_seq, f, s, p)

    @staticmethod
    def InformationFrame(scid, tx_seq=None, payload=None):
        return lambda packet: L2capMatchers._is_matching_information_frame(packet, scid, tx_seq, payload)
    def InformationFrame(tx_seq=None, payload=None):
        return lambda packet: L2capMatchers._is_matching_information_frame(packet, tx_seq, payload)

    @staticmethod
    def Data(payload):
        return lambda packet: packet.GetPayload().GetBytes() == payload

    # this is a hack - should be removed
    @staticmethod
    def PartialData(payload):
        return lambda packet: payload in packet.GetPayload().GetBytes()

    @staticmethod
    def ExtractBasicFrame(scid):
        return lambda packet: L2capMatchers._basic_frame_for(packet, scid)

    @staticmethod
    def _basic_frame(packet):
@@ -66,28 +80,29 @@ class L2capMatchers(object):
            bt_packets.PacketViewLittleEndian(list(packet.payload)))

    @staticmethod
    def _information_frame(packet, scid):
    def _basic_frame_for(packet, scid):
        frame = L2capMatchers._basic_frame(packet)
        if frame.GetChannelId() != scid:
            return None
        standard_frame = l2cap_packets.StandardFrameView(frame)
        return frame

    @staticmethod
    def _information_frame(packet):
        standard_frame = l2cap_packets.StandardFrameView(packet)
        if standard_frame.GetFrameType() != l2cap_packets.FrameType.I_FRAME:
            return None
        return l2cap_packets.EnhancedInformationFrameView(standard_frame)

    @staticmethod
    def _supervisory_frame(packet, scid):
        frame = L2capMatchers._basic_frame(packet)
        if frame.GetChannelId() != scid:
            return None
        standard_frame = l2cap_packets.StandardFrameView(frame)
    def _supervisory_frame(packet):
        standard_frame = l2cap_packets.StandardFrameView(packet)
        if standard_frame.GetFrameType() != l2cap_packets.FrameType.S_FRAME:
            return None
        return l2cap_packets.EnhancedSupervisoryFrameView(standard_frame)

    @staticmethod
    def _is_matching_information_frame(packet, scid, tx_seq, payload):
        frame = L2capMatchers._information_frame(packet, scid)
    def _is_matching_information_frame(packet, tx_seq, payload):
        frame = L2capMatchers._information_frame(packet)
        if frame is None:
            return False
        if tx_seq is not None and frame.GetTxSeq() != tx_seq:
@@ -98,8 +113,8 @@ class L2capMatchers(object):
        return True

    @staticmethod
    def _is_matching_supervisory_frame(packet, scid, req_seq, f, s, p):
        frame = L2capMatchers._supervisory_frame(packet, scid)
    def _is_matching_supervisory_frame(packet, req_seq, f, s, p):
        frame = L2capMatchers._supervisory_frame(packet)
        if frame is None:
            return False
        if req_seq is not None and frame.GetReqSeq() != req_seq:
@@ -114,10 +129,9 @@ class L2capMatchers(object):

    @staticmethod
    def _control_frame(packet):
        frame = L2capMatchers._basic_frame(packet)
        if frame is None or frame.GetChannelId() != 1:
        if packet.GetChannelId() != 1:
            return None
        return l2cap_packets.ControlView(frame.GetPayload())
        return l2cap_packets.ControlView(packet.GetPayload())

    @staticmethod
    def _control_frame_with_code(packet, code):
+28 −19
Original line number Diff line number Diff line
@@ -28,14 +28,21 @@ from cert.matchers import L2capMatchers

class CertL2capChannel(IEventStream):

    def __init__(self, device, scid, acl_stream):
    def __init__(self, device, scid, acl_stream, acl):
        self._device = device
        self._scid = scid
        self._our_acl_view = acl_stream
        self._acl_stream = acl_stream
        self._acl = acl
        self._our_acl_view = FilteringEventStream(
            acl_stream, L2capMatchers.ExtractBasicFrame(scid))

    def get_event_queue(self):
        return self._our_acl_view.get_event_queue()

    def send(self, packet):
        frame = l2cap_packets.BasicFrameBuilder(self._scid, packet)
        self._acl.send(frame.Serialize())


class CertL2cap(Closable):

@@ -75,24 +82,26 @@ class CertL2cap(Closable):
    def connect_acl(self, remote_addr):
        self._acl = self._acl_manager.initiate_connection(remote_addr)
        self._acl.wait_for_connection_complete()
        self.control_channel = CertL2capChannel(self._device, 1,
                                                self.get_acl_stream(),
                                                self._acl)
        self.get_acl_stream().register_callback(self._handle_control_packet)

    def open_channel(self, signal_id, psm, scid):
        # what is the 1 here for?
        open_channel = l2cap_packets.BasicFrameBuilder(
            1, l2cap_packets.ConnectionRequestBuilder(signal_id, psm, scid))
        self.send_acl(open_channel)
        self.control_channel.send(
            l2cap_packets.ConnectionRequestBuilder(signal_id, psm, scid))

        assertThat(self._acl).emits(L2capMatchers.ConnectionResponse(scid))
        return CertL2capChannel(self._device, scid, self.get_acl_stream())
        assertThat(self.control_channel).emits(
            L2capMatchers.ConnectionResponse(scid))
        return CertL2capChannel(self._device, scid, self.get_acl_stream(),
                                self._acl)

    # prefer to use channel abstraction instead, if at all possible
    def send_acl(self, packet):
        self._acl.send(packet.Serialize())

    def send_control_packet(self, packet):
        frame = l2cap_packets.BasicFrameBuilder(1, packet)
        self.send_acl(frame)
    def get_control_channel(self):
        return self.control_channel

    # temporary until clients migrated
    def get_acl_stream(self):
@@ -149,7 +158,7 @@ class CertL2cap(Closable):
            sid, cid, cid, l2cap_packets.ConnectionResponseResult.SUCCESS,
            l2cap_packets.ConnectionResponseStatus.
            NO_FURTHER_INFORMATION_AVAILABLE)
        self.send_control_packet(connection_response)
        self.control_channel.send(connection_response)
        return True

    def _on_connection_response_default(self, l2cap_control_view):
@@ -168,7 +177,7 @@ class CertL2cap(Closable):

        config_request = l2cap_packets.ConfigurationRequestBuilder(
            sid + 1, dcid, l2cap_packets.Continuation.END, options)
        self.send_control_packet(config_request)
        self.control_channel.send(config_request)
        return True

    def _on_connection_response_configuration_request_with_unknown_options_and_hint(
@@ -204,7 +213,7 @@ class CertL2cap(Closable):
        config_response = l2cap_packets.ConfigurationResponseBuilder(
            sid, self.scid_to_dcid.get(dcid, 0), l2cap_packets.Continuation.END,
            l2cap_packets.ConfigurationResponseResult.SUCCESS, [])
        self.send_control_packet(config_response)
        self.control_channel.send(config_response)

    def _on_configuration_request_unacceptable_parameters(
            self, l2cap_control_view):
@@ -225,7 +234,7 @@ class CertL2cap(Closable):
            sid, self.scid_to_dcid.get(dcid, 0), l2cap_packets.Continuation.END,
            l2cap_packets.ConfigurationResponseResult.UNACCEPTABLE_PARAMETERS,
            [mtu_opt, fcs_opt, rfc_opt])
        self.send_control_packet(config_response)
        self.control_channel.send(config_response)

    def _on_configuration_response_default(self, l2cap_control_view):
        configuration_response = l2cap_packets.ConfigurationResponseView(
@@ -240,7 +249,7 @@ class CertL2cap(Closable):
        dcid = disconnection_request.GetDestinationCid()
        disconnection_response = l2cap_packets.DisconnectionResponseBuilder(
            sid, dcid, scid)
        self.send_control_packet(disconnection_response)
        self.control_channel.send(disconnection_response)

    def _on_disconnection_response_default(self, l2cap_control_view):
        disconnection_response = l2cap_packets.DisconnectionResponseView(
@@ -254,18 +263,18 @@ class CertL2cap(Closable):
        if information_type == l2cap_packets.InformationRequestInfoType.CONNECTIONLESS_MTU:
            response = l2cap_packets.InformationResponseConnectionlessMtuBuilder(
                sid, l2cap_packets.InformationRequestResult.SUCCESS, 100)
            self.send_control_packet(response)
            self.control_channel.send(response)
            return
        if information_type == l2cap_packets.InformationRequestInfoType.EXTENDED_FEATURES_SUPPORTED:
            response = l2cap_packets.InformationResponseExtendedFeaturesBuilder(
                sid, l2cap_packets.InformationRequestResult.SUCCESS, 0, 0, 0, 1,
                0, 1, 0, 0, 0, 0)
            self.send_control_packet(response)
            self.control_channel.send(response)
            return
        if information_type == l2cap_packets.InformationRequestInfoType.FIXED_CHANNELS_SUPPORTED:
            response = l2cap_packets.InformationResponseFixedChannelsBuilder(
                sid, l2cap_packets.InformationRequestResult.SUCCESS, 2)
            self.send_control_packet(response)
            self.control_channel.send(response)
            return

    def _on_information_response_default(self, l2cap_control_view):
+178 −215

File changed.

Preview size limit exceeded, changes collapsed.

+13 −5
Original line number Diff line number Diff line
@@ -72,11 +72,19 @@ PYBIND11_MODULE(bluetooth_packets_python3, m) {
      m, "PacketStructBigEndian");
  py::class_<Iterator<kLittleEndian>>(m, "IteratorLittleEndian");
  py::class_<Iterator<!kLittleEndian>>(m, "IteratorBigEndian");
  py::class_<PacketView<kLittleEndian>>(m, "PacketViewLittleEndian").def(py::init([](std::vector<uint8_t> bytes) {
  py::class_<PacketView<kLittleEndian>>(m, "PacketViewLittleEndian")
      .def(py::init([](std::vector<uint8_t> bytes) {
        // Make a copy
        auto bytes_shared = std::make_shared<std::vector<uint8_t>>(bytes);
        return std::make_unique<PacketView<kLittleEndian>>(bytes_shared);
  }));
      }))
      .def("GetBytes", [](const PacketView<kLittleEndian> view) {
        std::string result;
        for (auto it = view.begin(); it != view.end(); it++) {
          result += *it;
        }
        return py::bytes(result);
      });
  py::class_<PacketView<!kLittleEndian>>(m, "PacketViewBigEndian").def(py::init([](std::vector<uint8_t> bytes) {
    // Make a copy
    auto bytes_shared = std::make_shared<std::vector<uint8_t>>(bytes);