Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 189169e5 authored by Rahul Arya's avatar Rahul Arya
Browse files

[Private GATT] Add support for MTU Exchange

Snoop MTU_REQ/RSP packets from legacy stack, and use them to track the
MTU used in the isolated server.

Bug: 255880936
Test: unit

Change-Id: Ifcaa35be47abdbf714b592318184701645b55800
parent e8b68ea9
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -7,6 +7,7 @@ pub mod channel;
pub mod ffi;
pub mod ids;
pub mod mocks;
mod mtu;
pub mod opcode_types;
pub mod server;

+54 −4
Original line number Diff line number Diff line
@@ -7,13 +7,13 @@ use log::{error, info, trace};

use crate::{
    do_in_rust_thread,
    gatt::server::att_server_bearer::AttServerBearer,
    packets::{OwnedAttView, OwnedPacket},
    packets::{AttOpcode, OwnedAttView, OwnedPacket},
};

use super::{
    ffi::{InterceptAction, StoreCallbacksFromRust},
    ids::{AdvertiserId, ConnectionId, ServerId, TransportIndex},
    mtu::MtuEvent,
    opcode_types::{classify_opcode, OperationType},
};

@@ -32,7 +32,14 @@ pub struct Arbiter {
pub fn initialize_arbiter() {
    *ARBITER.lock().unwrap() = Some(Arbiter::new());

    StoreCallbacksFromRust(on_le_connect, on_le_disconnect, intercept_packet);
    StoreCallbacksFromRust(
        on_le_connect,
        on_le_disconnect,
        intercept_packet,
        |tcb_idx| on_mtu_event(TransportIndex(tcb_idx), MtuEvent::OutgoingRequest),
        |tcb_idx, mtu| on_mtu_event(TransportIndex(tcb_idx), MtuEvent::IncomingResponse(mtu)),
        |tcb_idx, mtu| on_mtu_event(TransportIndex(tcb_idx), MtuEvent::IncomingRequest(mtu)),
    );
}

/// Acquire the mutex holding the Arbiter and provide a mutable reference to the
@@ -93,6 +100,12 @@ impl Arbiter {

        let att = OwnedAttView::try_parse(packet).ok()?;

        if att.view().get_opcode() == AttOpcode::EXCHANGE_MTU_REQUEST {
            // special case: this server opcode is handled by legacy stack, and we snoop
            // on its handling, since the MTU is shared between the client + server
            return None;
        }

        match classify_opcode(att.view().get_opcode()) {
            OperationType::Command | OperationType::Request | OperationType::Confirmation => {
                Some((att, conn_id))
@@ -127,6 +140,11 @@ impl Arbiter {
        info!("processing disconnection on transport {tcb_idx:?}");
        self.transport_to_owned_connection.remove(&tcb_idx)
    }

    /// Look up the conn_id for a given tcb_idx, if present
    pub fn get_conn_id(&self, tcb_idx: TransportIndex) -> Option<ConnectionId> {
        self.transport_to_owned_connection.get(&tcb_idx).copied()
    }
}

fn on_le_connect(tcb_idx: u8, advertiser: u8) {
@@ -168,13 +186,30 @@ fn intercept_packet(tcb_idx: u8, packet: Vec<u8>) -> InterceptAction {
    }
}

fn on_mtu_event(tcb_idx: TransportIndex, event: MtuEvent) {
    if let Some(conn_id) = with_arbiter(|arbiter| arbiter.get_conn_id(tcb_idx)) {
        do_in_rust_thread(move |modules| {
            let Some(bearer) = modules.gatt_module.get_bearer(conn_id) else {
                error!("Bearer for {conn_id:?} not found");
                return;
            };
            if let Err(err) = bearer.handle_mtu_event(event) {
                error!("{err:?}")
            }
        });
    }
}

#[cfg(test)]
mod test {
    use super::*;

    use crate::{
        gatt::ids::AttHandle,
        packets::{AttBuilder, AttOpcode, AttReadRequestBuilder, Serializable},
        packets::{
            AttBuilder, AttExchangeMtuRequestBuilder, AttOpcode, AttReadRequestBuilder,
            Serializable,
        },
    };

    const TCB_IDX: TransportIndex = TransportIndex(1);
@@ -329,6 +364,21 @@ mod test {
        assert!(out.is_none());
    }

    #[test]
    fn test_mtu_bypass() {
        let mut arbiter = Arbiter::new();
        arbiter.associate_server_with_advertiser(SERVER_ID, ADVERTISER_ID);
        arbiter.on_le_connect(TCB_IDX, ADVERTISER_ID);
        let packet = AttBuilder {
            opcode: AttOpcode::EXCHANGE_MTU_REQUEST,
            _child_: AttExchangeMtuRequestBuilder { mtu: 64 }.into(),
        };

        let out = arbiter.try_parse_att_server_packet(TCB_IDX, packet.to_vec().unwrap().into());

        assert!(out.is_none());
    }

    #[test]
    fn test_packet_bypass_when_not_isolated() {
        let mut arbiter = Arbiter::new();
+3 −0
Original line number Diff line number Diff line
@@ -149,6 +149,9 @@ mod inner {
            on_le_connect: fn(tcb_idx: u8, advertiser: u8),
            on_le_disconnect: fn(tcb_idx: u8),
            intercept_packet: fn(tcb_idx: u8, packet: Vec<u8>) -> InterceptAction,
            on_outgoing_mtu_req: fn(tcb_idx: u8),
            on_incoming_mtu_resp: fn(tcb_idx: u8, mtu: usize),
            on_incoming_mtu_req: fn(tcb_idx: u8, mtu: usize),
        );

        /// Send an outgoing packet on the specified tcb_idx
+251 −0
Original line number Diff line number Diff line
//! The MTU on an ATT bearer is determined either by L2CAP (if EATT) or by the
//! ATT_EXCHANGE_MTU procedure (if on an unenhanced bearer).
//!
//! In the latter case, the MTU may be either (1) unset, (2) pending, or (3)
//! set. If the MTU is pending, ATT notifications/indications may not be sent.
//! Refer to Core Spec 5.3 Vol 3F 3.4.2 MTU exchange for full details.

use std::{cell::Cell, future::Future};

use anyhow::{bail, Result};
use log::info;
use tokio::sync::OwnedMutexGuard;

use crate::core::shared_mutex::SharedMutex;

/// An MTU event that we have snooped
pub enum MtuEvent {
    /// We have sent an MTU_REQ
    OutgoingRequest,
    /// We have received an MTU_RESP
    IncomingResponse(usize),
    /// We have received an MTU_REQ (and will immediately reply)
    IncomingRequest(usize),
}

/// The state of MTU negotiation on an unenhanced ATT bearer
pub struct AttMtu {
    /// The MTU we have committed to (i.e. sent a REQ and got a RESP, or
    /// vice-versa)
    previous_mtu: Cell<usize>,
    /// The MTU we have committed or are about to commit to (if a REQ is
    /// pending)
    stable_mtu: SharedMutex<usize>,
    /// Lock guard held if we are currrently performing MTU negotiation
    pending_exchange: Cell<Option<OwnedMutexGuard<usize>>>,
}

// NOTE: this is only true for ATT, not EATT
const DEFAULT_ATT_MTU: usize = 23;

impl AttMtu {
    /// Constructor
    pub fn new() -> Self {
        Self {
            previous_mtu: Cell::new(DEFAULT_ATT_MTU),
            stable_mtu: SharedMutex::new(DEFAULT_ATT_MTU),
            pending_exchange: Cell::new(None),
        }
    }

    /// Get the most recently negotiated MTU, or the default (if an MTU_REQ is
    /// outstanding and we get an ATT_REQ)
    pub fn snapshot_or_default(&self) -> usize {
        self.stable_mtu.try_lock().as_deref().cloned().unwrap_or_else(|_| self.previous_mtu.get())
    }

    /// Get the most recently negotiated MTU, or block if negotiation is ongoing
    /// (i.e. if an MTU_REQ is outstanding)
    pub fn snapshot(&self) -> impl Future<Output = Option<usize>> {
        let pending_snapshot = self.stable_mtu.lock();
        async move { pending_snapshot.await.as_deref().cloned() }
    }

    /// Handle an MtuEvent and update the stored MTU
    pub fn handle_event(&self, event: MtuEvent) -> Result<()> {
        match event {
            MtuEvent::OutgoingRequest => self.on_outgoing_request(),
            MtuEvent::IncomingResponse(mtu) => self.on_incoming_response(mtu),
            MtuEvent::IncomingRequest(mtu) => {
                self.on_incoming_request(mtu);
                Ok(())
            }
        }
    }

    fn on_outgoing_request(&self) -> Result<()> {
        let Ok(pending_mtu) = self.stable_mtu.try_lock() else {
          bail!("Sent ATT_EXCHANGE_MTU_REQ while an existing MTU exchange is taking place");
        };
        info!("Sending MTU_REQ, pausing indications/notifications");
        self.pending_exchange.replace(Some(pending_mtu));
        Ok(())
    }

    fn on_incoming_response(&self, mtu: usize) -> Result<()> {
        let Some(mut pending_exchange) = self.pending_exchange.take() else {
            bail!("Got ATT_EXCHANGE_MTU_RESP when transaction not taking place");
        };
        info!("Got an MTU_RESP of {mtu}");
        *pending_exchange = mtu;
        // note: since MTU_REQ can be sent at most once, this is a no-op, as the
        // stable_mtu will never again be blocked we do it anyway for clarity
        self.previous_mtu.set(mtu);
        Ok(())
    }

    fn on_incoming_request(&self, mtu: usize) {
        self.previous_mtu.set(mtu);
        if let Ok(mut stable_mtu) = self.stable_mtu.try_lock() {
            info!("Accepted an MTU_REQ of {mtu:?}");
            *stable_mtu = mtu;
        } else {
            info!("Accepted an MTU_REQ while our own MTU_REQ was outstanding")
        }
    }
}

#[cfg(test)]
mod test {
    use crate::utils::task::{block_on_locally, try_await};

    use super::*;

    const NEW_MTU: usize = 51;
    const ANOTHER_NEW_MTU: usize = 52;

    #[test]
    fn test_default_mtu() {
        let mtu = AttMtu::new();

        let stable_value = mtu.snapshot_or_default();
        let latest_value = tokio_test::block_on(mtu.snapshot()).unwrap();

        assert_eq!(stable_value, DEFAULT_ATT_MTU);
        assert_eq!(latest_value, DEFAULT_ATT_MTU);
    }

    #[test]
    fn test_guaranteed_mtu_during_client_negotiation() {
        // arrange
        let mtu = AttMtu::new();

        // act: send an MTU_REQ and validate snapshotted value
        mtu.handle_event(MtuEvent::OutgoingRequest).unwrap();
        let stable_value = mtu.snapshot_or_default();

        // assert: we use the default MTU for requests handled
        // while our request is pending
        assert_eq!(stable_value, DEFAULT_ATT_MTU);
    }

    #[test]
    fn test_mtu_blocking_snapshot_during_client_negotiation() {
        block_on_locally(async move {
            // arrange
            let mtu = AttMtu::new();

            // act: send an MTU_REQ
            mtu.handle_event(MtuEvent::OutgoingRequest).unwrap();
            // take snapshot of pending future
            let pending_mtu = try_await(mtu.snapshot()).await.unwrap_err();
            // resolve MTU_REQ
            mtu.handle_event(MtuEvent::IncomingResponse(NEW_MTU)).unwrap();

            // assert: that the snapshot resolved with the NEW_MTU
            assert_eq!(pending_mtu.await.unwrap(), NEW_MTU);
        });
    }

    #[test]
    fn test_receive_mtu_request() {
        block_on_locally(async move {
            // arrange
            let mtu = AttMtu::new();

            // act: receive an MTU_REQ
            mtu.handle_event(MtuEvent::IncomingRequest(NEW_MTU)).unwrap();
            // take snapshot
            let snapshot = mtu.snapshot().await;

            // assert: that the snapshot resolved with the NEW_MTU
            assert_eq!(snapshot.unwrap(), NEW_MTU);
        });
    }

    #[test]
    fn test_client_then_server_negotiation() {
        block_on_locally(async move {
            // arrange
            let mtu = AttMtu::new();

            // act: send an MTU_REQ
            mtu.handle_event(MtuEvent::OutgoingRequest).unwrap();
            // receive an MTU_RESP
            mtu.handle_event(MtuEvent::IncomingResponse(NEW_MTU)).unwrap();
            // receive an MTU_REQ
            mtu.handle_event(MtuEvent::IncomingRequest(ANOTHER_NEW_MTU)).unwrap();
            // take snapshot
            let snapshot = mtu.snapshot().await;

            // assert: that the snapshot resolved with ANOTHER_NEW_MTU
            assert_eq!(snapshot.unwrap(), ANOTHER_NEW_MTU);
        });
    }

    #[test]
    fn test_server_negotiation_then_pending_client_default_value() {
        block_on_locally(async move {
            // arrange
            let mtu = AttMtu::new();

            // act: receive an MTU_REQ
            mtu.handle_event(MtuEvent::IncomingRequest(NEW_MTU)).unwrap();
            // send a MTU_REQ
            mtu.handle_event(MtuEvent::OutgoingRequest).unwrap();
            // take snapshot for requests
            let snapshot = mtu.snapshot_or_default();

            // assert: that the snapshot resolved to NEW_MTU
            assert_eq!(snapshot, NEW_MTU);
        });
    }

    #[test]
    fn test_server_negotiation_then_pending_client_finalized_value() {
        block_on_locally(async move {
            // arrange
            let mtu = AttMtu::new();

            // act: receive an MTU_REQ
            mtu.handle_event(MtuEvent::IncomingRequest(NEW_MTU)).unwrap();
            // send a MTU_REQ
            mtu.handle_event(MtuEvent::OutgoingRequest).unwrap();
            // take snapshot of pending future
            let snapshot = try_await(mtu.snapshot()).await.unwrap_err();
            // receive MTU_RESP
            mtu.handle_event(MtuEvent::IncomingResponse(ANOTHER_NEW_MTU)).unwrap();

            // assert: that the snapshot resolved to ANOTHER_NEW_MTU
            assert_eq!(snapshot.await.unwrap(), ANOTHER_NEW_MTU);
        });
    }

    #[test]
    fn test_mtu_dropped_while_pending() {
        block_on_locally(async move {
            // arrange
            let mtu = AttMtu::new();

            // act: send a MTU_REQ
            mtu.handle_event(MtuEvent::OutgoingRequest).unwrap();
            // take snapshot and store pending future
            let pending_mtu = try_await(mtu.snapshot()).await.unwrap_err();
            // drop the mtu (when the bearer closes)
            drop(mtu);

            // assert: that the snapshot resolves to None since the bearer is gone
            assert!(pending_mtu.await.is_none());
        });
    }
}
+123 −9
Original line number Diff line number Diff line
@@ -4,6 +4,7 @@

use std::{cell::Cell, future::Future};

use anyhow::Result;
use log::{error, trace, warn};
use tokio::task::spawn_local;

@@ -14,6 +15,7 @@ use crate::{
    },
    gatt::{
        ids::AttHandle,
        mtu::{AttMtu, MtuEvent},
        opcode_types::{classify_opcode, OperationType},
    },
    packets::{
@@ -34,8 +36,6 @@ enum AttRequestState<T: AttDatabase> {
    Pending(Option<OwnedHandle<()>>),
}

const DEFAULT_ATT_MTU: usize = 23;

/// The errors that can occur while trying to send a packet
#[derive(Debug)]
pub enum SendError {
@@ -51,7 +51,7 @@ pub enum SendError {
pub struct AttServerBearer<T: AttDatabase> {
    // general
    send_packet: Box<dyn Fn(AttBuilder) -> Result<(), SerializeError>>,
    mtu: Cell<usize>,
    mtu: AttMtu,

    // request state
    curr_request: Cell<AttRequestState<T>>,
@@ -71,7 +71,7 @@ impl<T: AttDatabase + Clone + 'static> AttServerBearer<T> {
        let (indication_handler, pending_confirmation) = IndicationHandler::new(db.clone());
        Self {
            send_packet: Box::new(send_packet),
            mtu: Cell::new(DEFAULT_ATT_MTU),
            mtu: AttMtu::new(),

            curr_request: AttRequestState::Idle(AttRequestHandler::new(db)).into(),

@@ -116,27 +116,43 @@ impl<T: AttDatabase + Clone + 'static> WeakBoxRef<'_, AttServerBearer<T>> {
        trace!("sending indication for handle {handle:?}");

        let locked_indication_handler = self.indication_handler.lock();
        let pending_mtu = self.mtu.snapshot();
        let this = self.downgrade();

        async move {
            locked_indication_handler
            // first wait until we are at the head of the queue and are ready to send
            // indications
            let mut indication_handler = locked_indication_handler
                .await
                .ok_or_else(|| {
                    warn!("indication for handle {handle:?} cancelled while queued since the connection dropped");
                    IndicationError::SendError(SendError::ConnectionDropped)
                })?
                .send(handle, data, |packet| this.try_send_packet(packet))
                })?;
            // then, if MTU negotiation is taking place, wait for it to complete
            let mtu = pending_mtu
                .await
                .ok_or_else(|| {
                    warn!("indication for handle {handle:?} cancelled while waiting for MTU exchange to complete since the connection dropped");
                    IndicationError::SendError(SendError::ConnectionDropped)
                })?;
            // finally, send, and wait for a response
            indication_handler.send(handle, data, mtu, |packet| this.try_send_packet(packet)).await
        }
    }

    /// Handle a snooped MTU event, to update the MTU we use for our various
    /// operations
    pub fn handle_mtu_event(&self, mtu_event: MtuEvent) -> Result<()> {
        self.mtu.handle_event(mtu_event)
    }

    fn handle_request(&self, packet: AttView<'_>) {
        let curr_request = self.curr_request.replace(AttRequestState::Pending(None));
        self.curr_request.replace(match curr_request {
            AttRequestState::Idle(mut request_handler) => {
                // even if the MTU is updated afterwards, 5.3 3F 3.4.2.2 states that the
                // request-time MTU should be used
                let mtu = self.mtu.get();
                let mtu = self.mtu.snapshot_or_default();
                let packet = packet.to_owned_packet();
                let this = self.downgrade();
                let task = spawn_local(async move {
@@ -220,7 +236,7 @@ mod test {
        },
        utils::{
            packet::{build_att_data, build_att_view_or_crash},
            task::block_on_locally,
            task::{block_on_locally, try_await},
        },
    };

@@ -557,4 +573,102 @@ mod test {
            ));
        });
    }

    #[test]
    fn test_single_indication_pending_mtu() {
        block_on_locally(async {
            // arrange: pending MTU negotiation
            let (conn, mut rx) = open_connection();
            conn.as_ref().handle_mtu_event(MtuEvent::OutgoingRequest).unwrap();

            // act: try to send an indication with a large payload size
            let _ =
                try_await(conn.as_ref().send_indication(
                    VALID_HANDLE,
                    AttAttributeDataChild::RawData((1..50).collect()),
                ))
                .await;
            // then resolve the MTU negotiation with a large MTU
            conn.as_ref().handle_mtu_event(MtuEvent::IncomingResponse(100)).unwrap();

            // assert: the indication was sent
            assert_eq!(rx.recv().await.unwrap().opcode, AttOpcode::HANDLE_VALUE_INDICATION);
        });
    }

    #[test]
    fn test_single_indication_pending_mtu_fail() {
        block_on_locally(async {
            // arrange: pending MTU negotiation
            let (conn, _) = open_connection();
            conn.as_ref().handle_mtu_event(MtuEvent::OutgoingRequest).unwrap();

            // act: try to send an indication with a large payload size
            let pending_mtu =
                try_await(conn.as_ref().send_indication(
                    VALID_HANDLE,
                    AttAttributeDataChild::RawData((1..50).collect()),
                ))
                .await
                .unwrap_err();
            // then resolve the MTU negotiation with a small MTU
            conn.as_ref().handle_mtu_event(MtuEvent::IncomingResponse(32)).unwrap();

            // assert: the indication failed to send
            assert!(matches!(pending_mtu.await, Err(IndicationError::DataExceedsMtu { .. })));
        });
    }

    #[test]
    fn test_server_transaction_pending_mtu() {
        block_on_locally(async {
            // arrange: pending MTU negotiation
            let (conn, mut rx) = open_connection();
            conn.as_ref().handle_mtu_event(MtuEvent::OutgoingRequest).unwrap();

            // act: send server packet
            conn.as_ref().handle_packet(
                build_att_view_or_crash(AttReadRequestBuilder {
                    attribute_handle: VALID_HANDLE.into(),
                })
                .view(),
            );

            // assert: that we reply even while the MTU req is outstanding
            assert_eq!(rx.recv().await.unwrap().opcode, AttOpcode::READ_RESPONSE);
        });
    }

    #[test]
    fn test_queued_indication_pending_mtu_uses_mtu_on_dequeue() {
        block_on_locally(async {
            // arrange: an outstanding indication
            let (conn, mut rx) = open_connection();
            let _ =
                try_await(conn.as_ref().send_indication(
                    VALID_HANDLE,
                    AttAttributeDataChild::RawData([1, 2, 3].into()),
                ))
                .await;
            rx.recv().await.unwrap(); // flush rx_queue

            // act: enqueue an indication with a large payload
            let _ =
                try_await(conn.as_ref().send_indication(
                    VALID_HANDLE,
                    AttAttributeDataChild::RawData((1..50).collect()),
                ))
                .await;
            // then perform MTU negotiation to upgrade to a large MTU
            conn.as_ref().handle_mtu_event(MtuEvent::OutgoingRequest).unwrap();
            conn.as_ref().handle_mtu_event(MtuEvent::IncomingResponse(512)).unwrap();
            // finally resolve the first indication, so the second indication can be sent
            conn.as_ref().handle_packet(
                build_att_view_or_crash(AttHandleValueConfirmationBuilder {}).view(),
            );

            // assert: the second indication successfully sent (so it used the new MTU)
            assert_eq!(rx.recv().await.unwrap().opcode, AttOpcode::HANDLE_VALUE_INDICATION);
        });
    }
}
Loading