Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit cf25dea9 authored by Rahul Arya's avatar Rahul Arya
Browse files

[Private GATT] Refactor to remove Rc<>

This CL introduces the abstraction of SharedBox<>, which
wraps Rc<> to avoid accidental circular references / leaks.

Without this, we could leak AttServerBearers by e.g. storing
an Rc<> to them in a spawned future, or by actually having a
circular reference.

Bug: 255880936
Test: unit

Change-Id: If60bd51d4ecda799de465d2e54a2976cf4e6793a
parent efac4fed
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
//! Shared data-types and utility methods go here.

mod ffi;
pub mod shared_box;
pub mod uuid;

use std::{rc::Rc, thread};
+90 −0
Original line number Diff line number Diff line
//! Wrapper around Rc<> to make ownership clearer
//!
//! The idea is to have ownership represented by a SharedBox<T>.
//! Temporary ownership can be held using a WeakBox<T>, which should
//! not be held across async points. This reduces the risk of accidental
//! lifetime extension.

use std::{
    ops::Deref,
    rc::{Rc, Weak},
};

/// A Box<> where static "weak" references to the contents can be taken,
/// and fallibly upgraded at a later point. Unlike Rc<>, weak references
/// cannot be upgraded back to owning references, so ownership remains clear
/// and reference cycles avoided.
#[derive(Debug)]
pub struct SharedBox<T: ?Sized>(Rc<T>);

impl<T> SharedBox<T> {
    /// Constructor
    pub fn new(t: T) -> Self {
        Self(t.into())
    }

    /// Produce a weak reference to the contents
    pub fn downgrade(&self) -> WeakBox<T> {
        WeakBox(Rc::downgrade(&self.0))
    }

    /// Produce an upgraded weak reference to the contents
    pub fn as_ref(&self) -> WeakBoxRef<T> {
        WeakBoxRef(self.0.deref(), Rc::downgrade(&self.0))
    }
}

impl<T> From<T> for SharedBox<T> {
    fn from(value: T) -> Self {
        Self(value.into())
    }
}

impl<T> Deref for SharedBox<T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        self.0.deref()
    }
}

/// A weak reference to the contents within a SharedBox<>
pub struct WeakBox<T>(Weak<T>);

impl<T> WeakBox<T> {
    /// Fallibly upgrade to a strong reference, passed into the supplied closure.
    /// The strong reference is not passed into the closure to avoid accidental
    /// lifetime extension.
    ///
    /// Note: reference-counting is used so that, if the passed-in closure drops
    /// the SharedBox<>, the strong reference remains safe. But please don't
    /// do that!
    pub fn with<U>(&self, f: impl FnOnce(Option<WeakBoxRef<T>>) -> U) -> U {
        f(self.0.upgrade().as_deref().map(|x| WeakBoxRef(x, self.0.clone())))
    }
}

impl<T> Clone for WeakBox<T> {
    fn clone(&self) -> Self {
        Self(self.0.clone())
    }
}

/// A strong reference to the contents within a SharedBox<>.
pub struct WeakBoxRef<'a, T>(&'a T, Weak<T>);

impl<'a, T> WeakBoxRef<'a, T> {
    /// Downgrade to a weak reference (with static lifetime) to the contents
    /// within the underlying SharedBox<>
    pub fn downgrade(&self) -> WeakBox<T> {
        WeakBox(self.1.clone())
    }
}

impl<'a, T> Deref for WeakBoxRef<'a, T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        self.0
    }
}
+7 −5
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@

use std::{collections::HashMap, sync::Mutex};

use log::{error, info};
use log::{error, info, trace};

use crate::{
    do_in_rust_thread,
@@ -81,7 +81,7 @@ impl Arbiter {
    }

    /// Test to see if a buffer contains a valid ATT packet with an opcode we
    /// are interested in intercepting
    /// are interested in intercepting (those intended for servers)
    pub fn try_parse_att_server_packet(
        &self,
        tcb_idx: TransportIndex,
@@ -158,9 +158,11 @@ fn intercept_packet(tcb_idx: u8, packet: Vec<u8>) -> InterceptAction {
        arbiter.try_parse_att_server_packet(TransportIndex(tcb_idx), packet.into_boxed_slice())
    }) {
        do_in_rust_thread(move |modules| {
            info!("pushing packet to GATT");
            if let Err(err) = modules.gatt_module.handle_packet(conn_id, att.view()) {
                error!("{:?}", err.context("failed to push packet to GATT"))
            trace!("pushing packet to GATT");
            if let Some(bearer) = modules.gatt_module.get_bearer(conn_id) {
                bearer.handle_packet(att.view())
            } else {
                error!("{conn_id:?} closed, bearer does not exist");
            }
        });
        InterceptAction::Drop
+14 −16
Original line number Diff line number Diff line
@@ -13,8 +13,8 @@ mod test;
use std::{collections::HashMap, rc::Rc};

use crate::{
    core::shared_box::{SharedBox, WeakBoxRef},
    gatt::{ids::ConnectionId, server::gatt_database::GattDatabase},
    packets::AttView,
};

use self::{
@@ -30,8 +30,8 @@ use log::info;
#[allow(missing_docs)]
pub struct GattModule {
    connection_bearers:
        HashMap<ConnectionId, Rc<AttServerBearer<AttDatabaseImpl<dyn GattDatastore>>>>,
    databases: HashMap<ServerId, Rc<GattDatabase<dyn GattDatastore>>>,
        HashMap<ConnectionId, SharedBox<AttServerBearer<AttDatabaseImpl<dyn GattDatastore>>>>,
    databases: HashMap<ServerId, SharedBox<GattDatabase<dyn GattDatastore>>>,
    datastore: Rc<dyn GattDatastore>,
    transport: Rc<dyn AttTransport>,
}
@@ -58,7 +58,8 @@ impl GattModule {
            conn_id,
            AttServerBearer::new(database.get_att_database(conn_id), move |packet| {
                transport.send_packet(conn_id.get_tcb_idx(), packet)
            }),
            })
            .into(),
        );
        Ok(())
    }
@@ -70,15 +71,6 @@ impl GattModule {
        self.datastore.remove_connection(conn_id);
    }

    /// Handle an incoming ATT packet
    pub fn handle_packet(&mut self, conn_id: ConnectionId, packet: AttView<'_>) -> Result<()> {
        self.connection_bearers
            .get(&conn_id)
            .ok_or_else(|| anyhow!("dropping ATT packet for unregistered connection"))?
            .handle_packet(packet);
        Ok(())
    }

    /// Register a new GATT service on a given server
    pub fn register_gatt_service(
        &mut self,
@@ -116,12 +108,18 @@ impl GattModule {
    /// Close a GATT server
    pub fn close_gatt_server(&mut self, server_id: ServerId) -> Result<()> {
        let old = self.databases.remove(&server_id);
        let Some(old) = old else {
        if old.is_none() {
            bail!("GATT server {server_id:?} did not exist")
        };

        old.clear_all_services();

        Ok(())
    }

    /// Get an ATT bearer for a particular connection
    pub fn get_bearer(
        &self,
        conn_id: ConnectionId,
    ) -> Option<WeakBoxRef<AttServerBearer<AttDatabaseImpl<dyn GattDatastore>>>> {
        self.connection_bearers.get(&conn_id).map(|x| x.as_ref())
    }
}
+33 −27
Original line number Diff line number Diff line
@@ -2,15 +2,13 @@
//! It handles ATT transactions and unacknowledged operations, backed by an
//! AttDatabase (that may in turn be backed by an upper-layer protocol)

use std::{
    cell::Cell,
    rc::{Rc, Weak},
};
use std::cell::Cell;

use log::{error, trace, warn};
use tokio::task::spawn_local;

use crate::{
    core::shared_box::WeakBoxRef,
    gatt::ids::AttHandle,
    packets::{
        AttBuilder, AttChild, AttErrorCode, AttErrorResponseBuilder, AttView, Packet,
@@ -29,8 +27,8 @@ enum AttTransaction<T: AttDatabase> {
const DEFAULT_ATT_MTU: usize = 23;

/// This represents a single ATT bearer (currently, always the unenhanced fixed
/// channel on LE) The AttTransaction ensures that only one transaction can take
/// place at a time
/// channel on LE) The AttTransactionHandler ensures that only one transaction
/// can take place at a time
pub struct AttServerBearer<T: AttDatabase> {
    curr_operation: Cell<AttTransaction<T>>,
    send_packet: Box<dyn Fn(AttBuilder) -> Result<(), SerializeError>>,
@@ -43,37 +41,40 @@ impl<T: AttDatabase + 'static> AttServerBearer<T> {
    pub fn new(
        db: T,
        send_packet: impl Fn(AttBuilder) -> Result<(), SerializeError> + 'static,
    ) -> Rc<Self> {
    ) -> Self {
        Self {
            curr_operation: AttTransaction::Idle(AttTransactionHandler::new(db)).into(),
            send_packet: Box::new(send_packet),
            mtu: Cell::new(DEFAULT_ATT_MTU),
        }
        .into()
    }
}

impl<T: AttDatabase + 'static> WeakBoxRef<'_, AttServerBearer<T>> {
    /// Handle an incoming packet, and send outgoing packets as appropriate
    /// using the owned ATT channel.
    pub fn handle_packet(self: &Rc<Self>, packet: AttView<'_>) {
    pub fn handle_packet(&self, packet: AttView<'_>) {
        let curr_operation = self.curr_operation.replace(AttTransaction::Pending(None));
        self.clone().curr_operation.replace(match curr_operation {
        self.curr_operation.replace(match curr_operation {
            AttTransaction::Idle(mut request_handler) => {
                // even if the MTU is updated afterwards, 5.3 3F 3.4.2.2 states that the request-time MTU should be used
                // even if the MTU is updated afterwards, 5.3 3F 3.4.2.2 states that the
                // request-time MTU should be used
                let mtu = self.mtu.get();
                let this = Rc::downgrade(self);
                let packet = packet.to_owned_packet();
                let this = self.downgrade();
                let task = spawn_local(async move {
                    trace!("starting ATT transaction");
                    let reply = request_handler.process_packet(packet.view(), mtu).await;
                    match Weak::upgrade(&this) {
                    this.with(|this| {
                    match this {
                        None => {
                            warn!("callback returned after disconnect");
                        }
                        Some(this) => {
                            trace!("sending reply packet");
                            if let Err(err) = this.send_response(reply) {
                            if let Err(err) = this.send_packet(reply) {
                                error!("serializer failure {err:?}, dropping packet and sending failed reply");
                                this.send_response(AttErrorResponseBuilder {
                                this.send_packet(AttErrorResponseBuilder {
                                    opcode_in_error: packet.view().get_opcode(),
                                    handle_in_error: AttHandle(0).into(),
                                    error_code: AttErrorCode::UNLIKELY_ERROR,
@@ -83,6 +84,7 @@ impl<T: AttDatabase + 'static> AttServerBearer<T> {
                            this.curr_operation.replace(AttTransaction::Idle(request_handler));
                        }
                    }
                    })
                });
                AttTransaction::Pending(Some(task.into()))
            }
@@ -94,7 +96,7 @@ impl<T: AttDatabase + 'static> AttServerBearer<T> {
        });
    }

    fn send_response(&self, packet: impl Into<AttChild>) -> Result<(), SerializeError> {
    fn send_packet(&self, packet: impl Into<AttChild>) -> Result<(), SerializeError> {
        let child = packet.into();
        let packet = AttBuilder { opcode: HACK_child_to_opcode(&child), _child_: child };
        (self.send_packet)(packet)
@@ -103,12 +105,14 @@ impl<T: AttDatabase + 'static> AttServerBearer<T> {

#[cfg(test)]
mod test {
    use std::rc::Rc;

    use tokio::sync::mpsc::{error::TryRecvError, unbounded_channel, UnboundedReceiver};

    use super::*;

    use crate::{
        core::uuid::Uuid,
        core::{shared_box::SharedBox, uuid::Uuid},
        gatt::{
            callbacks::GattDatastore,
            ids::ConnectionId,
@@ -136,7 +140,8 @@ mod test {

    const CONN_ID: ConnectionId = ConnectionId(1);

    fn open_connection() -> (Rc<AttServerBearer<TestAttDatabase>>, UnboundedReceiver<AttBuilder>) {
    fn open_connection(
    ) -> (SharedBox<AttServerBearer<TestAttDatabase>>, UnboundedReceiver<AttBuilder>) {
        let db = TestAttDatabase::new(vec![(
            AttAttribute {
                handle: VALID_HANDLE,
@@ -149,7 +154,8 @@ mod test {
        let conn = AttServerBearer::new(db, move |packet| {
            tx.send(packet).unwrap();
            Ok(())
        });
        })
        .into();
        (conn, rx)
    }

@@ -157,7 +163,7 @@ mod test {
    fn test_single_transaction() {
        block_on_locally(async {
            let (conn, mut rx) = open_connection();
            conn.handle_packet(
            conn.as_ref().handle_packet(
                build_att_view_or_crash(AttReadRequestBuilder {
                    attribute_handle: VALID_HANDLE.into(),
                })
@@ -172,7 +178,7 @@ mod test {
    fn test_sequential_transactions() {
        block_on_locally(async {
            let (conn, mut rx) = open_connection();
            conn.handle_packet(
            conn.as_ref().handle_packet(
                build_att_view_or_crash(AttReadRequestBuilder {
                    attribute_handle: INVALID_HANDLE.into(),
                })
@@ -181,7 +187,7 @@ mod test {
            assert_eq!(rx.recv().await.unwrap().opcode, AttOpcode::ERROR_RESPONSE);
            assert_eq!(rx.try_recv(), Err(TryRecvError::Empty));

            conn.handle_packet(
            conn.as_ref().handle_packet(
                build_att_view_or_crash(AttReadRequestBuilder {
                    attribute_handle: VALID_HANDLE.into(),
                })
@@ -200,7 +206,7 @@ mod test {
        let datastore = Rc::new(datastore);
        datastore.add_connection(CONN_ID);
        data_rx.blocking_recv().unwrap(); // ignore AddConnection() event
        let db = Rc::new(GattDatabase::new(datastore));
        let db = SharedBox::new(GattDatabase::new(datastore));
        db.add_service_with_handles(GattServiceWithHandle {
            handle: AttHandle(1),
            type_: Uuid::new(1),
@@ -219,11 +225,11 @@ mod test {
        })
        .unwrap();
        let (tx, mut rx) = unbounded_channel();
        let send_response = move |packet| {
        let send_packet = move |packet| {
            tx.send(packet).unwrap();
            Ok(())
        };
        let conn = AttServerBearer::new(db.get_att_database(CONN_ID), send_response);
        let conn = SharedBox::new(AttServerBearer::new(db.get_att_database(CONN_ID), send_packet));
        let data = AttAttributeDataChild::RawData([1, 2].into());

        // act: send two read requests before replying to either read
@@ -232,12 +238,12 @@ mod test {
            let req1 = build_att_view_or_crash(AttReadRequestBuilder {
                attribute_handle: VALID_HANDLE.into(),
            });
            conn.handle_packet(req1.view());
            conn.as_ref().handle_packet(req1.view());
            // second request
            let req2 = build_att_view_or_crash(AttReadRequestBuilder {
                attribute_handle: ANOTHER_VALID_HANDLE.into(),
            });
            conn.handle_packet(req2.view());
            conn.as_ref().handle_packet(req2.view());
            // handle first reply
            let MockDatastoreEvents::ReadCharacteristic(CONN_ID, VALID_HANDLE, data_resp) =
                data_rx.recv().await.unwrap() else {
Loading