Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit c752bf99 authored by Rahul Arya's avatar Rahul Arya Committed by Automerger Merge Worker
Browse files

Merge "[Private GATT] Refactor GattDatastore" am: ac88d4cb

parents 6770cbde ac88d4cb
Loading
Loading
Loading
Loading
+353 −4
Original line number Diff line number Diff line
@@ -7,6 +7,7 @@ mod callback_transaction_manager;
pub use callback_transaction_manager::{CallbackResponseError, CallbackTransactionManager};

use async_trait::async_trait;
use log::warn;

use crate::packets::{AttAttributeDataChild, AttAttributeDataView, AttErrorCode};

@@ -28,7 +29,6 @@ pub trait GattCallbacks {
        handle: AttHandle,
        attr_type: AttributeBackingType,
        offset: u32,
        is_long: bool,
    );

    /// Invoked when a client tries to write a characteristic/descriptor.
@@ -40,9 +40,7 @@ pub trait GattCallbacks {
        trans_id: TransactionId,
        handle: AttHandle,
        attr_type: AttributeBackingType,
        offset: u32,
        need_response: bool,
        is_prepare: bool,
        write_type: GattWriteType,
        value: AttAttributeDataView,
    );

@@ -53,11 +51,90 @@ pub trait GattCallbacks {
        conn_id: ConnectionId,
        result: Result<(), IndicationError>,
    );

    /// Execute or cancel any prepared writes
    fn on_execute(
        &self,
        conn_id: ConnectionId,
        trans_id: TransactionId,
        decision: TransactionDecision,
    );
}

/// The various write types available (requests + commands)
#[derive(PartialEq, Eq, Clone, Copy, Debug)]
pub enum GattWriteType {
    /// Reliable, expects a response (WRITE_REQ or PREPARE_WRITE_REQ)
    Request(GattWriteRequestType),
    /// Unreliable, no response required (WRITE_CMD)
    Command,
}

/// The types of write requests (that need responses)
#[derive(PartialEq, Eq, Clone, Copy, Debug)]
pub enum GattWriteRequestType {
    /// Atomic (WRITE_REQ)
    Request,
    /// Transactional, should not be committed yet (PREPARE_WRITE_REQ)
    Prepare {
        /// The byte offset at which to write
        offset: u32,
    },
}

/// Whether to commit or cancel a transaction
#[derive(Clone, Copy, Debug)]
pub enum TransactionDecision {
    /// Commit all pending writes
    Execute,
    /// Discard all pending writes
    Cancel,
}

/// This interface is an "async" version of the above, and is passed directly
/// into the GattModule
#[async_trait(?Send)]
pub trait RawGattDatastore {
    /// Read a characteristic from the specified connection at the given handle.
    async fn read(
        &self,
        conn_id: ConnectionId,
        handle: AttHandle,
        offset: u32,
        attr_type: AttributeBackingType,
    ) -> Result<AttAttributeDataChild, AttErrorCode>;

    /// Write data to a given characteristic on the specified connection.
    async fn write(
        &self,
        conn_id: ConnectionId,
        handle: AttHandle,
        attr_type: AttributeBackingType,
        write_type: GattWriteRequestType,
        data: AttAttributeDataView<'_>,
    ) -> Result<(), AttErrorCode>;

    /// Write data to a given characteristic on the specified connection, without waiting
    /// for a response from the upper layer.
    fn write_no_response(
        &self,
        conn_id: ConnectionId,
        handle: AttHandle,
        attr_type: AttributeBackingType,
        data: AttAttributeDataView<'_>,
    );

    /// Execute or cancel any prepared writes
    async fn execute(
        &self,
        conn_id: ConnectionId,
        decision: TransactionDecision,
    ) -> Result<(), AttErrorCode>;
}

/// This interface simplifies the interface of RawGattDatastore by rejecting all unsupported
/// operations, rather than requiring clients to do so.
#[async_trait(?Send)]
pub trait GattDatastore {
    /// Read a characteristic from the specified connection at the given handle.
    async fn read(
@@ -76,3 +153,275 @@ pub trait GattDatastore {
        data: AttAttributeDataView<'_>,
    ) -> Result<(), AttErrorCode>;
}

#[async_trait(?Send)]
impl<T: GattDatastore + ?Sized> RawGattDatastore for T {
    /// Read a characteristic from the specified connection at the given handle.
    async fn read(
        &self,
        conn_id: ConnectionId,
        handle: AttHandle,
        offset: u32,
        attr_type: AttributeBackingType,
    ) -> Result<AttAttributeDataChild, AttErrorCode> {
        if offset != 0 {
            warn!("got read blob request for non-long attribute {handle:?}");
            return Err(AttErrorCode::ATTRIBUTE_NOT_LONG);
        }
        self.read(conn_id, handle, attr_type).await
    }

    /// Write data to a given characteristic on the specified connection.
    async fn write(
        &self,
        conn_id: ConnectionId,
        handle: AttHandle,
        attr_type: AttributeBackingType,
        write_type: GattWriteRequestType,
        data: AttAttributeDataView<'_>,
    ) -> Result<(), AttErrorCode> {
        match write_type {
            GattWriteRequestType::Prepare { .. } => {
                warn!("got prepare write attempt to characteristic {handle:?} not supporting write_without_response");
                Err(AttErrorCode::WRITE_REQUEST_REJECTED)
            }
            GattWriteRequestType::Request => self.write(conn_id, handle, attr_type, data).await,
        }
    }

    fn write_no_response(
        &self,
        conn_id: ConnectionId,
        handle: AttHandle,
        _: AttributeBackingType,
        _: AttAttributeDataView<'_>,
    ) {
        // silently drop, since there's no way to return an error
        warn!("got write command on {conn_id:?} to characteristic {handle:?} not supporting write_without_response");
    }

    /// Execute or cancel any prepared writes
    async fn execute(&self, _: ConnectionId, _: TransactionDecision) -> Result<(), AttErrorCode> {
        // we never do prepared writes, so who cares
        return Ok(());
    }
}

#[cfg(test)]
mod test {
    use tokio::{sync::mpsc::error::TryRecvError, task::spawn_local};

    use crate::{
        gatt::mocks::mock_datastore::{MockDatastore, MockDatastoreEvents},
        packets::OwnedAttAttributeDataView,
        utils::{
            packet::{build_att_data, build_view_or_crash},
            task::block_on_locally,
        },
    };

    use super::*;

    const CONN_ID: ConnectionId = ConnectionId(1);
    const HANDLE: AttHandle = AttHandle(1);
    const DATA: [u8; 4] = [1, 2, 3, 4];

    #[test]
    fn test_regular_read_invoke() {
        block_on_locally(async {
            // arrange
            let (datastore, mut rx) = MockDatastore::new();

            // act: send read request
            spawn_local(async move {
                RawGattDatastore::read(
                    &datastore,
                    CONN_ID,
                    HANDLE,
                    0,
                    AttributeBackingType::Characteristic,
                )
                .await
            });
            let resp = rx.recv().await.unwrap();

            // assert: got read event
            assert!(matches!(
                resp,
                MockDatastoreEvents::Read(CONN_ID, HANDLE, AttributeBackingType::Characteristic, _)
            ));
        });
    }

    #[test]
    fn test_regular_read_response() {
        block_on_locally(async {
            // arrange
            let (datastore, mut rx) = MockDatastore::new();

            // act: send read request
            let pending = spawn_local(async move {
                RawGattDatastore::read(
                    &datastore,
                    CONN_ID,
                    HANDLE,
                    0,
                    AttributeBackingType::Characteristic,
                )
                .await
            });
            let resp = rx.recv().await.unwrap();
            let MockDatastoreEvents::Read(_, _, _, resp) = resp else {
                unreachable!();
            };
            resp.send(Err(AttErrorCode::APPLICATION_ERROR)).unwrap();

            // assert: got the supplied response
            assert_eq!(pending.await.unwrap(), Err(AttErrorCode::APPLICATION_ERROR));
        });
    }

    #[test]
    fn test_rejected_read_blob() {
        // arrange
        let (datastore, mut rx) = MockDatastore::new();

        // act: send read blob request
        let resp = block_on_locally(RawGattDatastore::read(
            &datastore,
            CONN_ID,
            HANDLE,
            1,
            AttributeBackingType::Characteristic,
        ));

        // assert: got the correct error code
        assert_eq!(resp, Err(AttErrorCode::ATTRIBUTE_NOT_LONG));
        // assert: no pending events
        assert_eq!(rx.try_recv().unwrap_err(), TryRecvError::Empty);
    }

    fn make_data() -> OwnedAttAttributeDataView {
        build_view_or_crash(build_att_data(AttAttributeDataChild::RawData(DATA.into())))
    }

    #[test]
    fn test_write_request_invoke() {
        block_on_locally(async {
            // arrange
            let (datastore, mut rx) = MockDatastore::new();

            // act: send write request
            spawn_local(async move {
                RawGattDatastore::write(
                    &datastore,
                    CONN_ID,
                    HANDLE,
                    AttributeBackingType::Characteristic,
                    GattWriteRequestType::Request,
                    make_data().view(),
                )
                .await
            });
            let resp = rx.recv().await.unwrap();

            // assert: got write event
            assert!(matches!(
                resp,
                MockDatastoreEvents::Write(
                    CONN_ID,
                    HANDLE,
                    AttributeBackingType::Characteristic,
                    _,
                    _
                )
            ));
        });
    }

    #[test]
    fn test_write_request_response() {
        block_on_locally(async {
            // arrange
            let (datastore, mut rx) = MockDatastore::new();

            // act: send write request
            let pending = spawn_local(async move {
                RawGattDatastore::write(
                    &datastore,
                    CONN_ID,
                    HANDLE,
                    AttributeBackingType::Characteristic,
                    GattWriteRequestType::Request,
                    make_data().view(),
                )
                .await
            });
            let resp = rx.recv().await.unwrap();
            let MockDatastoreEvents::Write(_, _, _, _, resp) = resp else {
                unreachable!();
            };
            resp.send(Err(AttErrorCode::APPLICATION_ERROR)).unwrap();

            // assert: got the supplied response
            assert_eq!(pending.await.unwrap(), Err(AttErrorCode::APPLICATION_ERROR));
        });
    }

    #[test]
    fn test_rejected_prepared_write() {
        // arrange
        let (datastore, mut rx) = MockDatastore::new();

        // act: send prepare write request
        let resp = block_on_locally(RawGattDatastore::write(
            &datastore,
            CONN_ID,
            HANDLE,
            AttributeBackingType::Characteristic,
            GattWriteRequestType::Prepare { offset: 1 },
            make_data().view(),
        ));

        // assert: got the correct error code
        assert_eq!(resp, Err(AttErrorCode::WRITE_REQUEST_REJECTED));
        // assert: no event sent up
        assert_eq!(rx.try_recv().unwrap_err(), TryRecvError::Empty);
    }

    #[test]
    fn test_dropped_write_command() {
        // arrange
        let (datastore, mut rx) = MockDatastore::new();

        // act: send write command
        RawGattDatastore::write_no_response(
            &datastore,
            CONN_ID,
            HANDLE,
            AttributeBackingType::Characteristic,
            make_data().view(),
        );

        // assert: no event sent up
        assert_eq!(rx.try_recv().unwrap_err(), TryRecvError::Empty);
    }

    #[test]
    fn test_execute_noop() {
        // arrange
        let (datastore, mut rx) = MockDatastore::new();

        // act: send execute request
        let resp = block_on_locally(RawGattDatastore::execute(
            &datastore,
            CONN_ID,
            TransactionDecision::Execute,
        ));

        // assert: succeeds trivially
        assert!(resp.is_ok());
        // assert: no event sent up
        assert_eq!(rx.try_recv().unwrap_err(), TryRecvError::Empty);
    }
}
+79 −56
Original line number Diff line number Diff line
use std::{cell::RefCell, collections::HashMap, rc::Rc, time::Duration};

use async_trait::async_trait;
use log::{error, trace, warn};
use log::{trace, warn};
use tokio::{sync::oneshot, time::timeout};

use crate::{
@@ -12,7 +12,10 @@ use crate::{
    packets::{AttAttributeDataChild, AttAttributeDataView, AttErrorCode},
};

use super::{AttributeBackingType, GattDatastore};
use super::{
    AttributeBackingType, GattWriteRequestType, GattWriteType, RawGattDatastore,
    TransactionDecision,
};

struct PendingTransaction {
    response: oneshot::Sender<Result<AttAttributeDataChild, AttErrorCode>>,
@@ -25,33 +28,6 @@ struct PendingTransactionWatcher {
    rx: oneshot::Receiver<Result<AttAttributeDataChild, AttErrorCode>>,
}

enum PendingTransactionError {
    SenderDropped,
    Timeout,
}

impl PendingTransactionWatcher {
    /// Wait for the transaction to resolve, or to hit the timeout. If the
    /// timeout is reached, clean up state related to transaction watching.
    async fn wait(
        self,
        manager: &CallbackTransactionManager,
    ) -> Result<Result<AttAttributeDataChild, AttErrorCode>, PendingTransactionError> {
        match timeout(TIMEOUT, self.rx).await {
            Ok(Ok(result)) => Ok(result),
            Ok(Err(_)) => Err(PendingTransactionError::SenderDropped),
            Err(_) => {
                manager
                    .pending_transactions
                    .borrow_mut()
                    .pending_transactions
                    .remove(&(self.conn_id, self.trans_id));
                Err(PendingTransactionError::Timeout)
            }
        }
    }
}

/// This struct converts the asynchronus read/write operations of GattDatastore
/// into the callback-based interface expected by JNI
pub struct CallbackTransactionManager {
@@ -113,41 +89,57 @@ impl CallbackTransactionManager {
}

impl PendingTransactionsState {
    fn start_new_transaction(&mut self, conn_id: ConnectionId) -> PendingTransactionWatcher {
    fn alloc_transaction_id(&mut self) -> TransactionId {
        let trans_id = TransactionId(self.next_transaction_id);
        self.next_transaction_id = self.next_transaction_id.wrapping_add(1);
        trans_id
    }

    fn start_new_transaction(&mut self, conn_id: ConnectionId) -> PendingTransactionWatcher {
        let trans_id = self.alloc_transaction_id();
        let (tx, rx) = oneshot::channel();
        self.pending_transactions.insert((conn_id, trans_id), PendingTransaction { response: tx });
        PendingTransactionWatcher { conn_id, trans_id, rx }
    }
}

impl PendingTransactionWatcher {
    /// Wait for the transaction to resolve, or to hit the timeout. If the
    /// timeout is reached, clean up state related to transaction watching.
    async fn wait(
        self,
        manager: &CallbackTransactionManager,
    ) -> Result<AttAttributeDataChild, AttErrorCode> {
        if let Ok(Ok(result)) = timeout(TIMEOUT, self.rx).await {
            result
        } else {
            manager
                .pending_transactions
                .borrow_mut()
                .pending_transactions
                .remove(&(self.conn_id, self.trans_id));
            warn!("no response received from Java after timeout - returning UNLIKELY_ERROR");
            Err(AttErrorCode::UNLIKELY_ERROR)
        }
    }
}

#[async_trait(?Send)]
impl GattDatastore for CallbackTransactionManager {
impl RawGattDatastore for CallbackTransactionManager {
    async fn read(
        &self,
        conn_id: ConnectionId,
        handle: AttHandle,
        offset: u32,
        attr_type: AttributeBackingType,
    ) -> Result<AttAttributeDataChild, AttErrorCode> {
        let pending_transaction =
            self.pending_transactions.borrow_mut().start_new_transaction(conn_id);
        let trans_id = pending_transaction.trans_id;

        self.callbacks.on_server_read(conn_id, trans_id, handle, attr_type, 0, false);
        self.callbacks.on_server_read(conn_id, trans_id, handle, attr_type, offset);

        match pending_transaction.wait(self).await {
            Ok(value) => value,
            Err(PendingTransactionError::SenderDropped) => {
                warn!("sender side of {trans_id:?} dropped / timed out while handling request - most likely this response will not be sent over the air");
                Err(AttErrorCode::UNLIKELY_ERROR)
            }
            Err(PendingTransactionError::Timeout) => {
                warn!("no response received from Java after timeout - returning UNLIKELY_ERROR");
                Err(AttErrorCode::UNLIKELY_ERROR)
            }
        }
        pending_transaction.wait(self).await
    }

    async fn write(
@@ -155,25 +147,56 @@ impl GattDatastore for CallbackTransactionManager {
        conn_id: ConnectionId,
        handle: AttHandle,
        attr_type: AttributeBackingType,
        write_type: GattWriteRequestType,
        data: AttAttributeDataView<'_>,
    ) -> Result<(), AttErrorCode> {
        let pending_transaction =
            self.pending_transactions.borrow_mut().start_new_transaction(conn_id);
        let trans_id = pending_transaction.trans_id;

        self.callbacks.on_server_write(conn_id, trans_id, handle, attr_type, 0, true, false, data);
        self.callbacks.on_server_write(
            conn_id,
            trans_id,
            handle,
            attr_type,
            GattWriteType::Request(write_type),
            data,
        );

        match pending_transaction.wait(self).await {
            Ok(value) => value.map(|_| ()), // the data passed back is irrelevant for write
            // requests
            Err(PendingTransactionError::SenderDropped) => {
                error!("the CallbackTransactionManager dropped the sender TX without sending it");
                Err(AttErrorCode::UNLIKELY_ERROR)
            }
            Err(PendingTransactionError::Timeout) => {
                warn!("no response received from Java after timeout - returning UNLIKELY_ERROR");
                Err(AttErrorCode::UNLIKELY_ERROR)
            }
        // the data passed back is irrelevant for write requests
        pending_transaction.wait(self).await.map(|_| ())
    }

    fn write_no_response(
        &self,
        conn_id: ConnectionId,
        handle: AttHandle,
        attr_type: AttributeBackingType,
        data: AttAttributeDataView<'_>,
    ) {
        let trans_id = self.pending_transactions.borrow_mut().alloc_transaction_id();
        self.callbacks.on_server_write(
            conn_id,
            trans_id,
            handle,
            attr_type,
            GattWriteType::Command,
            data,
        );
    }

    async fn execute(
        &self,
        conn_id: ConnectionId,
        decision: TransactionDecision,
    ) -> Result<(), AttErrorCode> {
        let pending_transaction =
            self.pending_transactions.borrow_mut().start_new_transaction(conn_id);
        let trans_id = pending_transaction.trans_id;

        self.callbacks.on_execute(conn_id, trans_id, decision);

        // the data passed back is irrelevant for execute requests
        pending_transaction.wait(self).await.map(|_| ())
    }
}
+36 −11
Original line number Diff line number Diff line
@@ -22,6 +22,7 @@ use crate::{

use super::{
    arbiter::{self, with_arbiter},
    callbacks::{GattWriteRequestType, GattWriteType, TransactionDecision},
    channel::AttTransport,
    ids::{AdvertiserId, AttHandle, ConnectionId, ServerId, TransactionId, TransportIndex},
    server::{
@@ -97,6 +98,10 @@ mod inner {
            value: &[u8],
        );

        /// This callback is invoked when executing / cancelling a write
        #[cxx_name = "OnExecute"]
        fn on_execute(self: &GattServerCallbacks, conn_id: u16, trans_id: u32, execute: bool);

        /// This callback is invoked when an indication has been sent and the
        /// peer device has confirmed it, or if some error occurred.
        #[cxx_name = "OnIndicationSentConfirmation"]
@@ -190,12 +195,15 @@ impl GattCallbacks for GattCallbacksImpl {
        handle: AttHandle,
        attr_type: AttributeBackingType,
        offset: u32,
        is_long: bool,
    ) {
        self.0
            .as_ref()
            .unwrap()
            .on_server_read(conn_id.0, trans_id.0, handle.0, attr_type, offset, is_long);
        self.0.as_ref().unwrap().on_server_read(
            conn_id.0,
            trans_id.0,
            handle.0,
            attr_type,
            offset,
            offset != 0,
        );
    }

    fn on_server_write(
@@ -204,9 +212,7 @@ impl GattCallbacks for GattCallbacksImpl {
        trans_id: TransactionId,
        handle: AttHandle,
        attr_type: AttributeBackingType,
        offset: u32,
        need_response: bool,
        is_prepare: bool,
        write_type: GattWriteType,
        value: AttAttributeDataView,
    ) {
        self.0.as_ref().unwrap().on_server_write(
@@ -214,9 +220,12 @@ impl GattCallbacks for GattCallbacksImpl {
            trans_id.0,
            handle.0,
            attr_type,
            offset,
            need_response,
            is_prepare,
            match write_type {
                GattWriteType::Request(GattWriteRequestType::Prepare { offset }) => offset,
                _ => 0,
            },
            matches!(write_type, GattWriteType::Request { .. }),
            matches!(write_type, GattWriteType::Request(GattWriteRequestType::Prepare { .. })),
            &value.get_raw_payload().collect::<Vec<_>>(),
        );
    }
@@ -234,6 +243,22 @@ impl GattCallbacks for GattCallbacksImpl {
            },
        )
    }

    fn on_execute(
        &self,
        conn_id: ConnectionId,
        trans_id: TransactionId,
        decision: TransactionDecision,
    ) {
        self.0.as_ref().unwrap().on_execute(
            conn_id.0,
            trans_id.0,
            match decision {
                TransactionDecision::Execute => true,
                TransactionDecision::Cancel => false,
            },
        )
    }
}

/// Implementation of AttTransport wrapping the corresponding C++ method
+14 −0
Original line number Diff line number Diff line
@@ -121,5 +121,19 @@ void GattServerCallbacks::OnIndicationSentConfirmation(uint16_t conn_id,
                   base::Bind(callbacks.indication_sent_cb, conn_id, status));
}

void GattServerCallbacks::OnExecute(uint16_t conn_id, uint32_t trans_id,
                                    bool execute) const {
  auto addr = AddressOfConnection(conn_id);
  if (!addr.has_value()) {
    LOG_WARN("Dropping server execute write since connection %d not found",
             conn_id);
    return;
  }

  do_in_jni_thread(
      FROM_HERE, base::Bind(callbacks.request_exec_write_cb, conn_id, trans_id,
                            addr.value(), execute));
}

}  // namespace gatt
}  // namespace bluetooth
+2 −0
Original line number Diff line number Diff line
@@ -50,6 +50,8 @@ class GattServerCallbacks {

  void OnIndicationSentConfirmation(uint16_t conn_id, int status) const;

  void OnExecute(uint16_t conn_id, uint32_t trans_id, bool execute) const;

 private:
  const btgatt_server_callbacks_t& callbacks;
};
Loading