Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 9eb03a94 authored by Zach Johnson's avatar Zach Johnson
Browse files

rusty-gd: allow parting out injected values

this way we can inject only what we actually care about

Bug: 171749953
Tag: #gd-refactor
Test: gd/cert/run --rhost SimpleHalTest
Change-Id: I3798084eabb7a8ffcc9a48b982715c792d64ff8d
parent d442edb7
Loading
Loading
Loading
Loading
+69 −22
Original line number Diff line number Diff line
@@ -5,7 +5,10 @@ use proc_macro::TokenStream;
use quote::{format_ident, quote};
use syn::parse::{Parse, ParseStream, Result};
use syn::punctuated::Punctuated;
use syn::{braced, parse, parse_macro_input, FnArg, Ident, ItemFn, Token, Type, DeriveInput, Path};
use syn::{
    braced, parse, parse_macro_input, DeriveInput, Fields, FnArg, Ident, ItemFn, ItemStruct, Path,
    Token, Type,
};

/// Defines a provider function, with generated helper that implicitly fetches argument instances from the registry
#[proc_macro_attribute]
@@ -57,6 +60,7 @@ enum ModuleEntry {
struct ProviderDef {
    ty: Type,
    ident: Ident,
    parts: bool,
}

impl Parse for ModuleDef {
@@ -75,30 +79,34 @@ impl Parse for ModuleDef {
                        panic!("providers specified more than once");
                    }
                    providers = value;
                },
                }
                ModuleEntry::Submodules(value) => {
                    if !submodules.is_empty() {
                        panic!("submodules specified more than once");
                    }
                    submodules = value;
                },
                }
            }
        Ok(ModuleDef {
            name,
            providers,
            submodules,
        })
        }
        Ok(ModuleDef { name, providers, submodules })
    }
}

impl Parse for ProviderDef {
    fn parse(input: ParseStream) -> Result<Self> {
        let parts = input.peek3(Token![=>]);
        if parts {
            match input.parse::<Ident>()?.to_string().as_str() {
                "parts" => {}
                keyword => panic!("expected 'parts', got '{}'", keyword),
            }
        }

        // A provider definition follows this format: <Type> -> <function name>
        let ty = input.parse()?;
        input.parse::<Token![=>]>()?;
        let ident = input.parse()?;
        Ok(ProviderDef { ty, ident })
        Ok(ProviderDef { ty, ident, parts })
    }
}

@@ -108,16 +116,12 @@ impl Parse for ModuleEntry {
            "providers" => {
                let entries;
                braced!(entries in input);
                Ok(ModuleEntry::Providers(
                    entries.parse_terminated(ProviderDef::parse)?,
                ))
                Ok(ModuleEntry::Providers(entries.parse_terminated(ProviderDef::parse)?))
            }
            "submodules" => {
                let entries;
                braced!(entries in input);
                Ok(ModuleEntry::Submodules(
                    entries.parse_terminated(Path::parse)?,
                ))
                Ok(ModuleEntry::Submodules(entries.parse_terminated(Path::parse)?))
            }
            keyword => {
                panic!("unexpected keyword: {}", keyword);
@@ -131,20 +135,30 @@ impl Parse for ModuleEntry {
pub fn module(item: TokenStream) -> TokenStream {
    let module = parse_macro_input!(item as ModuleDef);
    let init_ident = module.name.clone();
    let types = module.providers.iter().map(|p| p.ty.clone());
    let provider_idents = module
        .providers
        .iter()
        .map(|p| format_ident!("__gddi_{}_injected", p.ident.clone()));
    let providers = module.providers.iter();
    let types = providers.clone().map(|p| p.ty.clone());
    let provider_idents =
        providers.clone().map(|p| format_ident!("__gddi_{}_injected", p.ident.clone()));
    let parting_functions = providers.filter_map(|p| match &p.ty {
        Type::Path(ty) if p.parts => Some(format_ident!(
            "__gddi_part_out_{}",
            ty.path.get_ident().unwrap().to_string().to_lowercase()
        )),
        _ => None,
    });
    let submodule_idents = module.submodules.iter();
    let emitted_code = quote! {
        #[doc(hidden)]
        #[allow(missing_docs)]
        pub fn #init_ident(builder: gddi::RegistryBuilder) -> gddi::RegistryBuilder {
            // Register all providers on this module
            builder#(.register_provider::<#types>(Box::new(#provider_idents)))*
            let ret = builder#(.register_provider::<#types>(Box::new(#provider_idents)))*
            // Register all submodules on this module
            #(.register_module(#submodule_idents))*
            #(.register_module(#submodule_idents))*;

            #(let ret = #parting_functions(ret);)*

            ret
        }
    };
    emitted_code.into()
@@ -160,3 +174,36 @@ pub fn derive_nop_stop(item: TokenStream) -> TokenStream {
    };
    emitted_code.into()
}

/// Generates the code necessary to split up a type into its components
#[proc_macro_attribute]
pub fn part_out(_attr: TokenStream, item: TokenStream) -> TokenStream {
    let struct_: ItemStruct = parse(item).expect("can only be applied to struct definitions");
    let struct_ident = struct_.ident.clone();
    let fields = match struct_.fields.clone() {
        Fields::Named(f) => f,
        _ => panic!("can only be applied to structs with named fields"),
    }
    .named;

    let field_names = fields.iter().map(|f| f.ident.clone().expect("field without a name"));
    let field_types = fields.iter().map(|f| f.ty.clone());

    let fn_ident = format_ident!("__gddi_part_out_{}", struct_ident.to_string().to_lowercase());

    let emitted_code = quote! {
        #struct_

        fn #fn_ident(builder: gddi::RegistryBuilder) -> gddi::RegistryBuilder {
            builder#(.register_provider::<#field_types>(Box::new(
                |registry: std::sync::Arc<gddi::Registry>| -> std::pin::Pin<gddi::ProviderFutureBox> {
                    Box::pin(async move {
                        Box::new(async move {
                            registry.get::<#struct_ident>().await.#field_names
                        }) as Box<dyn std::any::Any>
                    })
                })))*
        }
    };
    emitted_code.into()
}
+4 −10
Original line number Diff line number Diff line
@@ -7,7 +7,7 @@ use std::pin::Pin;
use std::sync::Arc;
use tokio::sync::Mutex;

pub use gddi_macros::{module, provides, Stoppable};
pub use gddi_macros::{module, part_out, provides, Stoppable};

type InstanceBox = Box<dyn Any + Send + Sync>;
/// A box around a future for a provider that is safe to send between threads
@@ -46,9 +46,7 @@ impl Default for RegistryBuilder {
impl RegistryBuilder {
    /// Creates a new RegistryBuilder
    pub fn new() -> Self {
        RegistryBuilder {
            providers: HashMap::new(),
        }
        RegistryBuilder { providers: HashMap::new() }
    }

    /// Registers a module with this registry
@@ -61,8 +59,7 @@ impl RegistryBuilder {

    /// Registers a provider function with this registry
    pub fn register_provider<T: 'static>(mut self, f: ProviderFnBox) -> Self {
        self.providers
            .insert(TypeId::of::<T>(), Provider { f: Arc::new(f) });
        self.providers.insert(TypeId::of::<T>(), Provider { f: Arc::new(f) });

        self
    }
@@ -84,10 +81,7 @@ impl Registry {
        {
            let instances = self.instances.lock().await;
            if let Some(value) = instances.get(&typeid) {
                return value
                    .downcast_ref::<T>()
                    .expect("was not correct type")
                    .clone();
                return value.downcast_ref::<T>().expect("was not correct type").clone();
            }
        }

+3 −3
Original line number Diff line number Diff line
//! Loads info from the controller at startup

use crate::{Address, Hci};
use crate::{Address, CommandSender};
use bt_packets::hci::{
    Enable, ErrorCode, LeMaximumDataLength, LeReadBufferSizeV1Builder, LeReadBufferSizeV2Builder,
    LeReadConnectListSizeBuilder, LeReadLocalSupportedFeaturesBuilder,
@@ -34,7 +34,7 @@ macro_rules! assert_success {
}

#[provides]
async fn provide_controller(mut hci: Hci) -> Arc<ControllerExports> {
async fn provide_controller(mut hci: CommandSender) -> Arc<ControllerExports> {
    assert_success!(hci.send(LeSetEventMaskBuilder { le_event_mask: 0x0000000000021e7f }));
    assert_success!(hci.send(SetEventMaskBuilder { event_mask: 0x3dbfffffffffffff }));
    assert_success!(
@@ -167,7 +167,7 @@ async fn provide_controller(mut hci: Hci) -> Arc<ControllerExports> {
    })
}

async fn read_features(hci: &mut Hci) -> SupportedFeatures {
async fn read_features(hci: &mut CommandSender) -> SupportedFeatures {
    let mut features = Vec::new();
    let mut page_number: u8 = 0;
    let mut max_page_number: u8 = 1;
+23 −12
Original line number Diff line number Diff line
//! HCI layer facade

use crate::Hci;
use crate::{EventRegistry, HciForAcl, RawCommandSender};
use bt_common::GrpcFacade;
use bt_facade_proto::common::Data;
use bt_facade_proto::empty::Empty;
@@ -26,11 +26,18 @@ module! {
}

#[provides]
async fn provide_facade(hci: Hci, rt: Arc<Runtime>) -> HciFacadeService {
async fn provide_facade(
    commands: RawCommandSender,
    events: EventRegistry,
    acl: HciForAcl,
    rt: Arc<Runtime>,
) -> HciFacadeService {
    let (from_hci_evt_tx, to_grpc_evt_rx) = channel::<EventPacket>(10);
    let (from_hci_le_evt_tx, to_grpc_le_evt_rx) = channel::<LeMetaEventPacket>(10);
    HciFacadeService {
        hci,
        commands,
        events,
        acl,
        rt,
        from_hci_evt_tx,
        to_grpc_evt_rx: Arc::new(Mutex::new(to_grpc_evt_rx)),
@@ -42,7 +49,9 @@ async fn provide_facade(hci: Hci, rt: Arc<Runtime>) -> HciFacadeService {
/// HCI layer facade service
#[derive(Clone, Stoppable)]
pub struct HciFacadeService {
    hci: Hci,
    commands: RawCommandSender,
    events: EventRegistry,
    acl: HciForAcl,
    rt: Arc<Runtime>,
    from_hci_evt_tx: Sender<EventPacket>,
    to_grpc_evt_rx: Arc<Mutex<Receiver<EventPacket>>>,
@@ -59,16 +68,18 @@ impl GrpcFacade for HciFacadeService {
impl HciFacade for HciFacadeService {
    fn send_command(&mut self, _ctx: RpcContext<'_>, mut data: Data, sink: UnarySink<Empty>) {
        self.rt
            .block_on(self.hci.send_raw(CommandPacket::parse(&data.take_payload()).unwrap()))
            .block_on(self.commands.send(CommandPacket::parse(&data.take_payload()).unwrap()))
            .unwrap();
        sink.success(Empty::default());
    }

    fn request_event(&mut self, _ctx: RpcContext<'_>, req: EventRequest, sink: UnarySink<Empty>) {
        self.rt.block_on(self.hci.register_event_handler(
        self.rt.block_on(
            self.events.register(
                EventCode::from_u32(req.get_code()).unwrap(),
                self.from_hci_evt_tx.clone(),
        ));
            ),
        );
        sink.success(Empty::default());
    }

@@ -78,7 +89,7 @@ impl HciFacade for HciFacadeService {
        req: EventRequest,
        sink: UnarySink<Empty>,
    ) {
        self.rt.block_on(self.hci.register_le_event_handler(
        self.rt.block_on(self.events.register_le(
            SubeventCode::from_u32(req.get_code()).unwrap(),
            self.from_hci_le_evt_tx.clone(),
        ));
@@ -86,7 +97,7 @@ impl HciFacade for HciFacadeService {
    }

    fn send_acl(&mut self, _ctx: RpcContext<'_>, mut packet: Data, sink: UnarySink<Empty>) {
        let acl_tx = self.hci.acl_tx.clone();
        let acl_tx = self.acl.tx.clone();
        self.rt.block_on(async move {
            acl_tx.send(AclPacket::parse(&packet.take_payload()).unwrap()).await.unwrap();
        });
@@ -133,7 +144,7 @@ impl HciFacade for HciFacadeService {
        _req: Empty,
        mut resp: ServerStreamingSink<Data>,
    ) {
        let acl_rx = self.hci.acl_rx.clone();
        let acl_rx = self.acl.rx.clone();

        self.rt.spawn(async move {
            while let Some(data) = acl_rx.lock().await.recv().await {
+55 −25
Original line number Diff line number Diff line
@@ -21,7 +21,7 @@ use bt_packets::hci::{
    LeMetaEventPacket, ResetBuilder, SubeventCode,
};
use error::Result;
use gddi::{module, provides, Stoppable};
use gddi::{module, part_out, provides, Stoppable};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
@@ -37,10 +37,19 @@ module! {
        controller::controller_module,
    },
    providers {
        Hci => provide_hci,
        parts Hci => provide_hci,
    },
}

#[part_out]
#[derive(Clone, Stoppable)]
struct Hci {
    raw_commands: RawCommandSender,
    commands: CommandSender,
    events: EventRegistry,
    acl: HciForAcl,
}

#[provides]
async fn provide_hci(hal: Hal, rt: Arc<Runtime>) -> Hci {
    let (cmd_tx, cmd_rx) = channel::<QueuedCommand>(10);
@@ -55,15 +64,20 @@ async fn provide_hci(hal: Hal, rt: Arc<Runtime>) -> Hci {
        cmd_rx,
    ));

    let mut hci =
        Hci { cmd_tx, evt_handlers, le_evt_handlers, acl_tx: hal.acl_tx, acl_rx: hal.acl_rx };
    let raw_commands = RawCommandSender { cmd_tx };
    let mut commands = CommandSender { raw: raw_commands.clone() };

    assert!(
        hci.send(ResetBuilder {}).await.get_status() == ErrorCode::Success,
        commands.send(ResetBuilder {}).await.get_status() == ErrorCode::Success,
        "reset did not complete successfully"
    );

    hci
    Hci {
        raw_commands,
        commands,
        events: EventRegistry { evt_handlers, le_evt_handlers },
        acl: HciForAcl { tx: hal.acl_tx, rx: hal.acl_rx },
    }
}

#[derive(Debug)]
@@ -72,39 +86,59 @@ struct QueuedCommand {
    fut: oneshot::Sender<EventPacket>,
}

/// HCI interface
/// Sends raw commands. Only useful for facades & shims, or wrapped as a CommandSender.
#[derive(Clone, Stoppable)]
pub struct Hci {
pub struct RawCommandSender {
    cmd_tx: Sender<QueuedCommand>,
    evt_handlers: Arc<Mutex<HashMap<EventCode, Sender<EventPacket>>>>,
    le_evt_handlers: Arc<Mutex<HashMap<SubeventCode, Sender<LeMetaEventPacket>>>>,
    /// Transmit end of a channel used to send ACL data
    pub acl_tx: Sender<AclPacket>,
    /// Receive end of a channel used to receive ACL data
    pub acl_rx: Arc<Mutex<Receiver<AclPacket>>>,
}

impl Hci {
impl RawCommandSender {
    /// Send a command, but does not automagically associate the expected returning event type.
    ///
    /// Only really useful for facades & shims.
    pub async fn send_raw(&mut self, cmd: CommandPacket) -> Result<EventPacket> {
    pub async fn send(&mut self, cmd: CommandPacket) -> Result<EventPacket> {
        let (tx, rx) = oneshot::channel::<EventPacket>();
        self.cmd_tx.send(QueuedCommand { cmd, fut: tx }).await?;
        let event = rx.await?;
        Ok(event)
    }
}

/// Sends commands to the controller
#[derive(Clone, Stoppable)]
pub struct CommandSender {
    raw: RawCommandSender,
}

impl CommandSender {
    /// Send a command to the controller, getting an expected response back
    pub async fn send<T: Into<CommandPacket> + CommandExpectations>(
        &mut self,
        cmd: T,
    ) -> T::ResponseType {
        T::_to_response_type(self.send_raw(cmd.into()).await.unwrap())
        T::_to_response_type(self.raw.send(cmd.into()).await.unwrap())
    }
}

/// Exposes the ACL send/receive interface
#[derive(Clone, Stoppable)]
pub struct HciForAcl {
    /// Transmit end
    pub tx: Sender<AclPacket>,
    /// Receive end
    pub rx: Arc<Mutex<Receiver<AclPacket>>>,
}

/// Provides ability to register and unregister for HCI events
#[derive(Clone, Stoppable)]
pub struct EventRegistry {
    evt_handlers: Arc<Mutex<HashMap<EventCode, Sender<EventPacket>>>>,
    le_evt_handlers: Arc<Mutex<HashMap<SubeventCode, Sender<LeMetaEventPacket>>>>,
}

impl EventRegistry {
    /// Indicate interest in specific HCI events
    pub async fn register_event_handler(&mut self, code: EventCode, sender: Sender<EventPacket>) {
    pub async fn register(&mut self, code: EventCode, sender: Sender<EventPacket>) {
        match code {
            EventCode::CommandStatus
            | EventCode::CommandComplete
@@ -123,16 +157,12 @@ impl Hci {
    }

    /// Remove interest in specific HCI events
    pub async fn unregister_event_handler(&mut self, code: EventCode) {
    pub async fn unregister(&mut self, code: EventCode) {
        self.evt_handlers.lock().await.remove(&code);
    }

    /// Indicate interest in specific LE events
    pub async fn register_le_event_handler(
        &mut self,
        code: SubeventCode,
        sender: Sender<LeMetaEventPacket>,
    ) {
    pub async fn register_le(&mut self, code: SubeventCode, sender: Sender<LeMetaEventPacket>) {
        assert!(
            self.le_evt_handlers.lock().await.insert(code, sender).is_none(),
            "A handler for {:?} is already registered",
@@ -141,7 +171,7 @@ impl Hci {
    }

    /// Remove interest in specific LE events
    pub async fn unregister_le_event_handler(&mut self, code: SubeventCode) {
    pub async fn unregister_le(&mut self, code: SubeventCode) {
        self.le_evt_handlers.lock().await.remove(&code);
    }
}
Loading