Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit ea077ced authored by Rahul Arya's avatar Rahul Arya Committed by Automerger Merge Worker
Browse files

Merge "[GATT Server] Use TransportIndex instead of conn_id in API" am: cc9d7d27 am: a9f5fc11

parents e3bbb24e a9f5fc11
Loading
Loading
Loading
Loading
+4 −0
Original line number Diff line number Diff line
//! Build file to generate packets
//!
//! Run `cargo install .` in `tools/pdl` to ensure `pdl` is in your
//! path.
use std::{
    env,
    fs::File,
+48 −27
Original line number Diff line number Diff line
@@ -90,13 +90,15 @@ impl Arbiter {
    }

    /// Test to see if a buffer contains a valid ATT packet with an opcode we
    /// are interested in intercepting (those intended for servers)
    /// are interested in intercepting (those intended for servers that are isolated)
    pub fn try_parse_att_server_packet(
        &self,
        tcb_idx: TransportIndex,
        packet: Box<[u8]>,
    ) -> Option<(OwnedAttView, ConnectionId)> {
        let conn_id = *self.transport_to_owned_connection.get(&tcb_idx)?;
    ) -> Option<OwnedAttView> {
        if !self.transport_to_owned_connection.contains_key(&tcb_idx) {
            return None;
        }

        let att = OwnedAttView::try_parse(packet).ok()?;

@@ -108,7 +110,7 @@ impl Arbiter {

        match classify_opcode(att.view().get_opcode()) {
            OperationType::Command | OperationType::Request | OperationType::Confirmation => {
                Some((att, conn_id))
                Some(att)
            }
            _ => None,
        }
@@ -135,10 +137,10 @@ impl Arbiter {
        Some(conn_id)
    }

    /// Handle a disconnection and return the disconnected conn_id, if any
    pub fn on_le_disconnect(&mut self, tcb_idx: TransportIndex) -> Option<ConnectionId> {
    /// Handle a disconnection, if any, and return whether the disconnection was registered
    pub fn on_le_disconnect(&mut self, tcb_idx: TransportIndex) -> bool {
        info!("processing disconnection on transport {tcb_idx:?}");
        self.transport_to_owned_connection.remove(&tcb_idx)
        self.transport_to_owned_connection.remove(&tcb_idx).is_some()
    }

    /// Look up the conn_id for a given tcb_idx, if present
@@ -160,10 +162,10 @@ fn on_le_connect(tcb_idx: u8, advertiser: u8) {
}

fn on_le_disconnect(tcb_idx: u8) {
    if let Some(conn_id) = with_arbiter(|arbiter| arbiter.on_le_disconnect(TransportIndex(tcb_idx)))
    {
    let tcb_idx = TransportIndex(tcb_idx);
    if with_arbiter(|arbiter| arbiter.on_le_disconnect(tcb_idx)) {
        do_in_rust_thread(move |modules| {
            if let Err(err) = modules.gatt_module.on_le_disconnect(conn_id) {
            if let Err(err) = modules.gatt_module.on_le_disconnect(tcb_idx) {
                error!("{err:?}")
            }
        })
@@ -171,15 +173,16 @@ fn on_le_disconnect(tcb_idx: u8) {
}

fn intercept_packet(tcb_idx: u8, packet: Vec<u8>) -> InterceptAction {
    if let Some((att, conn_id)) = with_arbiter(|arbiter| {
        arbiter.try_parse_att_server_packet(TransportIndex(tcb_idx), packet.into_boxed_slice())
    let tcb_idx = TransportIndex(tcb_idx);
    if let Some(att) = with_arbiter(|arbiter| {
        arbiter.try_parse_att_server_packet(tcb_idx, packet.into_boxed_slice())
    }) {
        do_in_rust_thread(move |modules| {
            trace!("pushing packet to GATT");
            if let Some(bearer) = modules.gatt_module.get_bearer(conn_id) {
            if let Some(bearer) = modules.gatt_module.get_bearer(tcb_idx) {
                bearer.handle_packet(att.view())
            } else {
                error!("{conn_id:?} closed, bearer does not exist");
                error!("Bearer for {tcb_idx:?} not found");
            }
        });
        InterceptAction::Drop
@@ -189,10 +192,10 @@ fn intercept_packet(tcb_idx: u8, packet: Vec<u8>) -> InterceptAction {
}

fn on_mtu_event(tcb_idx: TransportIndex, event: MtuEvent) {
    if let Some(conn_id) = with_arbiter(|arbiter| arbiter.get_conn_id(tcb_idx)) {
    if with_arbiter(|arbiter| arbiter.get_conn_id(tcb_idx)).is_some() {
        do_in_rust_thread(move |modules| {
            let Some(bearer) = modules.gatt_module.get_bearer(conn_id) else {
                error!("Bearer for {conn_id:?} not found");
            let Some(bearer) = modules.gatt_module.get_bearer(tcb_idx) else {
                error!("Bearer for {tcb_idx:?} not found");
                return;
            };
            if let Err(err) = bearer.handle_mtu_event(event) {
@@ -215,12 +218,13 @@ mod test {
    };

    const TCB_IDX: TransportIndex = TransportIndex(1);
    const ADVERTISER_ID: AdvertiserId = AdvertiserId(2);
    const SERVER_ID: ServerId = ServerId(3);
    const ANOTHER_TCB_IDX: TransportIndex = TransportIndex(2);
    const ADVERTISER_ID: AdvertiserId = AdvertiserId(3);
    const SERVER_ID: ServerId = ServerId(4);

    const CONN_ID: ConnectionId = ConnectionId::new(TCB_IDX, SERVER_ID);

    const ANOTHER_ADVERTISER_ID: AdvertiserId = AdvertiserId(4);
    const ANOTHER_ADVERTISER_ID: AdvertiserId = AdvertiserId(5);

    #[test]
    fn test_non_isolated_connect() {
@@ -256,9 +260,9 @@ mod test {
        let mut arbiter = Arbiter::new();
        arbiter.on_le_connect(TCB_IDX, ADVERTISER_ID);

        let conn_id = arbiter.on_le_disconnect(TCB_IDX);
        let ok = arbiter.on_le_disconnect(TCB_IDX);

        assert!(conn_id.is_none())
        assert!(!ok)
    }

    #[test]
@@ -267,9 +271,9 @@ mod test {
        arbiter.associate_server_with_advertiser(SERVER_ID, ADVERTISER_ID);
        arbiter.on_le_connect(TCB_IDX, ADVERTISER_ID);

        let conn_id = arbiter.on_le_disconnect(TCB_IDX);
        let ok = arbiter.on_le_disconnect(TCB_IDX);

        assert_eq!(conn_id, Some(CONN_ID));
        assert!(ok)
    }

    #[test]
@@ -348,7 +352,7 @@ mod test {

        let out = arbiter.try_parse_att_server_packet(TCB_IDX, packet.to_vec().unwrap().into());

        assert!(matches!(out, Some((_, CONN_ID))));
        assert!(out.is_some());
    }

    #[test]
@@ -396,6 +400,23 @@ mod test {
        assert!(out.is_none());
    }

    #[test]
    fn test_packet_bypass_when_different_connection() {
        let mut arbiter = Arbiter::new();
        arbiter.associate_server_with_advertiser(SERVER_ID, ADVERTISER_ID);
        arbiter.on_le_connect(TCB_IDX, ADVERTISER_ID);
        arbiter.on_le_connect(ANOTHER_TCB_IDX, ANOTHER_ADVERTISER_ID);
        let packet = AttBuilder {
            opcode: AttOpcode::READ_REQUEST,
            _child_: AttReadRequestBuilder { attribute_handle: AttHandle(1).into() }.into(),
        };

        let out =
            arbiter.try_parse_att_server_packet(ANOTHER_TCB_IDX, packet.to_vec().unwrap().into());

        assert!(out.is_none());
    }

    #[test]
    fn test_packet_capture_when_isolated_after_advertiser_closes() {
        let mut arbiter = Arbiter::new();
@@ -409,7 +430,7 @@ mod test {

        let out = arbiter.try_parse_att_server_packet(TCB_IDX, packet.to_vec().unwrap().into());

        assert!(matches!(out, Some((_, CONN_ID))));
        assert!(out.is_some());
    }

    #[test]
@@ -425,7 +446,7 @@ mod test {

        let out = arbiter.try_parse_att_server_packet(TCB_IDX, packet.to_vec().unwrap().into());

        assert!(matches!(out, Some((_, CONN_ID))));
        assert!(out.is_some());
    }

    #[test]
+26 −26
Original line number Diff line number Diff line
@@ -13,7 +13,7 @@ use crate::packets::{AttAttributeDataChild, AttAttributeDataView, AttErrorCode};

use super::{
    ffi::AttributeBackingType,
    ids::{AttHandle, ConnectionId, TransactionId},
    ids::{AttHandle, ConnectionId, TransactionId, TransportIndex},
    server::IndicationError,
};

@@ -98,7 +98,7 @@ pub trait RawGattDatastore {
    /// Read a characteristic from the specified connection at the given handle.
    async fn read(
        &self,
        conn_id: ConnectionId,
        tcb_idx: TransportIndex,
        handle: AttHandle,
        offset: u32,
        attr_type: AttributeBackingType,
@@ -107,7 +107,7 @@ pub trait RawGattDatastore {
    /// Write data to a given characteristic on the specified connection.
    async fn write(
        &self,
        conn_id: ConnectionId,
        tcb_idx: TransportIndex,
        handle: AttHandle,
        attr_type: AttributeBackingType,
        write_type: GattWriteRequestType,
@@ -118,7 +118,7 @@ pub trait RawGattDatastore {
    /// for a response from the upper layer.
    fn write_no_response(
        &self,
        conn_id: ConnectionId,
        tcb_idx: TransportIndex,
        handle: AttHandle,
        attr_type: AttributeBackingType,
        data: AttAttributeDataView<'_>,
@@ -127,7 +127,7 @@ pub trait RawGattDatastore {
    /// Execute or cancel any prepared writes
    async fn execute(
        &self,
        conn_id: ConnectionId,
        tcb_idx: TransportIndex,
        decision: TransactionDecision,
    ) -> Result<(), AttErrorCode>;
}
@@ -139,7 +139,7 @@ pub trait GattDatastore {
    /// Read a characteristic from the specified connection at the given handle.
    async fn read(
        &self,
        conn_id: ConnectionId,
        tcb_idx: TransportIndex,
        handle: AttHandle,
        attr_type: AttributeBackingType,
    ) -> Result<AttAttributeDataChild, AttErrorCode>;
@@ -147,7 +147,7 @@ pub trait GattDatastore {
    /// Write data to a given characteristic on the specified connection.
    async fn write(
        &self,
        conn_id: ConnectionId,
        tcb_idx: TransportIndex,
        handle: AttHandle,
        attr_type: AttributeBackingType,
        data: AttAttributeDataView<'_>,
@@ -159,7 +159,7 @@ impl<T: GattDatastore + ?Sized> RawGattDatastore for T {
    /// Read a characteristic from the specified connection at the given handle.
    async fn read(
        &self,
        conn_id: ConnectionId,
        tcb_idx: TransportIndex,
        handle: AttHandle,
        offset: u32,
        attr_type: AttributeBackingType,
@@ -168,13 +168,13 @@ impl<T: GattDatastore + ?Sized> RawGattDatastore for T {
            warn!("got read blob request for non-long attribute {handle:?}");
            return Err(AttErrorCode::ATTRIBUTE_NOT_LONG);
        }
        self.read(conn_id, handle, attr_type).await
        self.read(tcb_idx, handle, attr_type).await
    }

    /// Write data to a given characteristic on the specified connection.
    async fn write(
        &self,
        conn_id: ConnectionId,
        tcb_idx: TransportIndex,
        handle: AttHandle,
        attr_type: AttributeBackingType,
        write_type: GattWriteRequestType,
@@ -182,26 +182,26 @@ impl<T: GattDatastore + ?Sized> RawGattDatastore for T {
    ) -> Result<(), AttErrorCode> {
        match write_type {
            GattWriteRequestType::Prepare { .. } => {
                warn!("got prepare write attempt to characteristic {handle:?} not supporting write_without_response");
                warn!("got prepare write attempt on {tcb_idx:?} to characteristic {handle:?} not supporting write_without_response");
                Err(AttErrorCode::WRITE_REQUEST_REJECTED)
            }
            GattWriteRequestType::Request => self.write(conn_id, handle, attr_type, data).await,
            GattWriteRequestType::Request => self.write(tcb_idx, handle, attr_type, data).await,
        }
    }

    fn write_no_response(
        &self,
        conn_id: ConnectionId,
        tcb_idx: TransportIndex,
        handle: AttHandle,
        _: AttributeBackingType,
        _: AttAttributeDataView<'_>,
    ) {
        // silently drop, since there's no way to return an error
        warn!("got write command on {conn_id:?} to characteristic {handle:?} not supporting write_without_response");
        warn!("got write command on {tcb_idx:?} to characteristic {handle:?} not supporting write_without_response");
    }

    /// Execute or cancel any prepared writes
    async fn execute(&self, _: ConnectionId, _: TransactionDecision) -> Result<(), AttErrorCode> {
    async fn execute(&self, _: TransportIndex, _: TransactionDecision) -> Result<(), AttErrorCode> {
        // we never do prepared writes, so who cares
        return Ok(());
    }
@@ -222,7 +222,7 @@ mod test {

    use super::*;

    const CONN_ID: ConnectionId = ConnectionId(1);
    const TCB_IDX: TransportIndex = TransportIndex(1);
    const HANDLE: AttHandle = AttHandle(1);
    const DATA: [u8; 4] = [1, 2, 3, 4];

@@ -236,7 +236,7 @@ mod test {
            spawn_local(async move {
                RawGattDatastore::read(
                    &datastore,
                    CONN_ID,
                    TCB_IDX,
                    HANDLE,
                    0,
                    AttributeBackingType::Characteristic,
@@ -248,7 +248,7 @@ mod test {
            // assert: got read event
            assert!(matches!(
                resp,
                MockDatastoreEvents::Read(CONN_ID, HANDLE, AttributeBackingType::Characteristic, _)
                MockDatastoreEvents::Read(TCB_IDX, HANDLE, AttributeBackingType::Characteristic, _)
            ));
        });
    }
@@ -263,7 +263,7 @@ mod test {
            let pending = spawn_local(async move {
                RawGattDatastore::read(
                    &datastore,
                    CONN_ID,
                    TCB_IDX,
                    HANDLE,
                    0,
                    AttributeBackingType::Characteristic,
@@ -289,7 +289,7 @@ mod test {
        // act: send read blob request
        let resp = block_on_locally(RawGattDatastore::read(
            &datastore,
            CONN_ID,
            TCB_IDX,
            HANDLE,
            1,
            AttributeBackingType::Characteristic,
@@ -315,7 +315,7 @@ mod test {
            spawn_local(async move {
                RawGattDatastore::write(
                    &datastore,
                    CONN_ID,
                    TCB_IDX,
                    HANDLE,
                    AttributeBackingType::Characteristic,
                    GattWriteRequestType::Request,
@@ -329,7 +329,7 @@ mod test {
            assert!(matches!(
                resp,
                MockDatastoreEvents::Write(
                    CONN_ID,
                    TCB_IDX,
                    HANDLE,
                    AttributeBackingType::Characteristic,
                    _,
@@ -349,7 +349,7 @@ mod test {
            let pending = spawn_local(async move {
                RawGattDatastore::write(
                    &datastore,
                    CONN_ID,
                    TCB_IDX,
                    HANDLE,
                    AttributeBackingType::Characteristic,
                    GattWriteRequestType::Request,
@@ -376,7 +376,7 @@ mod test {
        // act: send prepare write request
        let resp = block_on_locally(RawGattDatastore::write(
            &datastore,
            CONN_ID,
            TCB_IDX,
            HANDLE,
            AttributeBackingType::Characteristic,
            GattWriteRequestType::Prepare { offset: 1 },
@@ -397,7 +397,7 @@ mod test {
        // act: send write command
        RawGattDatastore::write_no_response(
            &datastore,
            CONN_ID,
            TCB_IDX,
            HANDLE,
            AttributeBackingType::Characteristic,
            make_data().view(),
@@ -415,7 +415,7 @@ mod test {
        // act: send execute request
        let resp = block_on_locally(RawGattDatastore::execute(
            &datastore,
            CONN_ID,
            TCB_IDX,
            TransactionDecision::Execute,
        ));

+58 −21
Original line number Diff line number Diff line
@@ -6,7 +6,7 @@ use tokio::{sync::oneshot, time::timeout};

use crate::{
    gatt::{
        ids::{AttHandle, ConnectionId, TransactionId},
        ids::{AttHandle, ConnectionId, ServerId, TransactionId, TransportIndex},
        GattCallbacks,
    },
    packets::{AttAttributeDataChild, AttAttributeDataView, AttErrorCode},
@@ -48,7 +48,7 @@ const TIMEOUT: Duration = Duration::from_secs(15);
/// The cause of a failure to dispatch a call to send_response()
#[derive(Debug, PartialEq, Eq)]
pub enum CallbackResponseError {
    /// The TransactionId supplied was invalid
    /// The TransactionId supplied was invalid for the specified connection
    NonExistentTransaction(TransactionId),
    /// The TransactionId was valid but has since terminated
    ListenerHungUp(TransactionId),
@@ -86,6 +86,11 @@ impl CallbackTransactionManager {
            Err(CallbackResponseError::NonExistentTransaction(trans_id))
        }
    }

    /// Get an impl GattDatastore tied to a particular server
    pub fn get_datastore(self: &Rc<Self>, server_id: ServerId) -> impl RawGattDatastore {
        GattDatastoreImpl { callback_transaction_manager: self.clone(), server_id }
    }
}

impl PendingTransactionsState {
@@ -124,37 +129,58 @@ impl PendingTransactionWatcher {
    }
}

struct GattDatastoreImpl {
    callback_transaction_manager: Rc<CallbackTransactionManager>,
    server_id: ServerId,
}

#[async_trait(?Send)]
impl RawGattDatastore for CallbackTransactionManager {
impl RawGattDatastore for GattDatastoreImpl {
    async fn read(
        &self,
        conn_id: ConnectionId,
        tcb_idx: TransportIndex,
        handle: AttHandle,
        offset: u32,
        attr_type: AttributeBackingType,
    ) -> Result<AttAttributeDataChild, AttErrorCode> {
        let pending_transaction =
            self.pending_transactions.borrow_mut().start_new_transaction(conn_id);
        let conn_id = ConnectionId::new(tcb_idx, self.server_id);

        let pending_transaction = self
            .callback_transaction_manager
            .pending_transactions
            .borrow_mut()
            .start_new_transaction(conn_id);
        let trans_id = pending_transaction.trans_id;

        self.callbacks.on_server_read(conn_id, trans_id, handle, attr_type, offset);
        self.callback_transaction_manager.callbacks.on_server_read(
            ConnectionId::new(tcb_idx, self.server_id),
            trans_id,
            handle,
            attr_type,
            offset,
        );

        pending_transaction.wait(self).await
        pending_transaction.wait(&self.callback_transaction_manager).await
    }

    async fn write(
        &self,
        conn_id: ConnectionId,
        tcb_idx: TransportIndex,
        handle: AttHandle,
        attr_type: AttributeBackingType,
        write_type: GattWriteRequestType,
        data: AttAttributeDataView<'_>,
    ) -> Result<(), AttErrorCode> {
        let pending_transaction =
            self.pending_transactions.borrow_mut().start_new_transaction(conn_id);
        let conn_id = ConnectionId::new(tcb_idx, self.server_id);

        let pending_transaction = self
            .callback_transaction_manager
            .pending_transactions
            .borrow_mut()
            .start_new_transaction(conn_id);
        let trans_id = pending_transaction.trans_id;

        self.callbacks.on_server_write(
        self.callback_transaction_manager.callbacks.on_server_write(
            conn_id,
            trans_id,
            handle,
@@ -164,18 +190,24 @@ impl RawGattDatastore for CallbackTransactionManager {
        );

        // the data passed back is irrelevant for write requests
        pending_transaction.wait(self).await.map(|_| ())
        pending_transaction.wait(&self.callback_transaction_manager).await.map(|_| ())
    }

    fn write_no_response(
        &self,
        conn_id: ConnectionId,
        tcb_idx: TransportIndex,
        handle: AttHandle,
        attr_type: AttributeBackingType,
        data: AttAttributeDataView<'_>,
    ) {
        let trans_id = self.pending_transactions.borrow_mut().alloc_transaction_id();
        self.callbacks.on_server_write(
        let conn_id = ConnectionId::new(tcb_idx, self.server_id);

        let trans_id = self
            .callback_transaction_manager
            .pending_transactions
            .borrow_mut()
            .alloc_transaction_id();
        self.callback_transaction_manager.callbacks.on_server_write(
            conn_id,
            trans_id,
            handle,
@@ -187,16 +219,21 @@ impl RawGattDatastore for CallbackTransactionManager {

    async fn execute(
        &self,
        conn_id: ConnectionId,
        tcb_idx: TransportIndex,
        decision: TransactionDecision,
    ) -> Result<(), AttErrorCode> {
        let pending_transaction =
            self.pending_transactions.borrow_mut().start_new_transaction(conn_id);
        let conn_id = ConnectionId::new(tcb_idx, self.server_id);

        let pending_transaction = self
            .callback_transaction_manager
            .pending_transactions
            .borrow_mut()
            .start_new_transaction(conn_id);
        let trans_id = pending_transaction.trans_id;

        self.callbacks.on_execute(conn_id, trans_id, decision);
        self.callback_transaction_manager.callbacks.on_execute(conn_id, trans_id, decision);

        // the data passed back is irrelevant for execute requests
        pending_transaction.wait(self).await.map(|_| ())
        pending_transaction.wait(&self.callback_transaction_manager).await.map(|_| ())
    }
}
+2 −2
Original line number Diff line number Diff line
@@ -392,7 +392,7 @@ fn add_service(server_id: u8, service_records: Vec<GattRecord>) {
                let ok = modules.gatt_module.register_gatt_service(
                    server_id,
                    service.clone(),
                    modules.gatt_incoming_callbacks.clone(),
                    modules.gatt_incoming_callbacks.get_datastore(server_id),
                );
                match ok {
                    Ok(_) => info!(
@@ -477,7 +477,7 @@ fn send_indication(_server_id: u8, handle: u16, conn_id: u16, value: &[u8]) {
    trace!("send_indication {handle:?}, {conn_id:?}");

    do_in_rust_thread(move |modules| {
        let Some(bearer) = modules.gatt_module.get_bearer(conn_id) else {
        let Some(bearer) = modules.gatt_module.get_bearer(conn_id.get_tcb_idx()) else {
            error!("connection {conn_id:?} does not exist");
            return;
        };
Loading