Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 43bf1607 authored by Zach Johnson's avatar Zach Johnson
Browse files

rusty-gd: some facade cleanup

Bug: 171749953
Tag: #gd-refactor
Test: gd/cert/run --rhost SimpleHalTest
Change-Id: Ice6866028f9ea6e1b94d7aa6a240258731d7aa32
parent ab6c298f
Loading
Loading
Loading
Loading
+272 −0
Original line number Diff line number Diff line
//! Macros simplifying grpc service definitions

extern crate proc_macro;
use proc_macro::{TokenStream, TokenTree};
use quote::{format_ident, quote, quote_spanned};
use syn::parse::{Parse, ParseStream, Result};
use syn::{
    braced, parenthesized, parse_macro_input, parse_quote, Block, Expr, FnArg, Ident,
    ImplItemMethod, PatType, Path, Receiver, Token, Type,
};

/// provices shortcut syntax for defining proto-based rpc services
#[proc_macro]
pub fn grpc_service(item: TokenStream) -> TokenStream {
    let service = parse_macro_input!(item as ServiceDef);
    let grpc_trait = service.grpc_trait.clone();
    let struct_ = service.struct_.clone();
    let struct_for_facade = service.struct_.clone();
    let functions = service.items.iter().map(|i| match i {
        ServiceItem::Raw(inner) => Some(inner.clone()),
        ServiceItem::Rpc(inner) => Some(inner.clone().generate_fn()),
    });

    let mut grpc_path = grpc_trait.clone().segments;
    let grpc_create_fn = format_ident!(
        "create_{}",
        to_snake_case(&grpc_path.pop().unwrap().into_value().ident.to_string().as_str())
    );

    let emitted_code = quote! {
        impl #grpc_trait for #struct_ {
            #(#functions)*
        }

        impl bt_common::GrpcFacade for #struct_for_facade {
            fn into_grpc(self) -> grpcio::Service {
                #grpc_path#grpc_create_fn(self)
            }
        }
    };

    emitted_code.into()
}

struct ServiceDef {
    grpc_trait: Path,
    struct_: Type,
    items: Vec<ServiceItem>,
}

enum ServiceItem {
    Raw(ImplItemMethod),
    Rpc(RpcItem),
}

#[derive(Clone)]
struct RpcItem {
    name: Ident,
    input: PatType,
    output: RpcReturnType,
    unimplemented: bool,
    drain: Option<Box<Expr>>,
    code: Option<Block>,
}

impl RpcItem {
    fn generate_fn(self) -> ImplItemMethod {
        let name = self.name;
        let input = self.input;
        let output = self.output.type_;
        let tokens = match (self.drain, self.code) {
            (Some(drain), None) if self.output.stream => {
                quote_spanned! {
                    name.span()=>
                    fn #name(&mut self, _ctx: grpcio::RpcContext<'_>, #input, mut sink: grpcio::ServerStreamingSink<#output>) {
                        let stream = #drain.clone();
                        self.rt.spawn(async move {
                            while let Some(item) = stream.lock().await.recv().await {
                                sink.send((item.to_proto(), grpcio::WriteFlags::default())).await.unwrap();
                            }
                        });
                    }
                }
            }
            (Some(drain), None) if !self.output.stream => {
                let input_pat = input.clone().pat;
                quote_spanned! {
                    name.span()=>
                    fn #name(&mut self, _ctx: grpcio::RpcContext<'_>, #input, sink: grpcio::UnarySink<#output>) {
                        let channel = #drain.clone();
                        self.rt.block_on(async move {
                            channel.send(#input_pat.to_packet()).await.unwrap();
                        });
                        sink.success(Empty::default());
                    }
                }
            }
            (None, Some(code)) if !self.output.stream => {
                let tokens = quote! { #code };
                let rewritten_code = syn::parse::<Block>(replace_self(tokens.into())).unwrap();
                quote_spanned! {
                    name.span()=>
                    fn #name(&mut self, _ctx: grpcio::RpcContext<'_>, #input, sink: grpcio::UnarySink<#output>) {
                        let mut ___implicit_self___ = self.clone();
                        self.rt.block_on(async move {
                            #rewritten_code
                        });
                        sink.success(Empty::default());
                    }
                }
            }
            (None, None) if self.unimplemented => {
                let sink_type = format_ident!(
                    "{}",
                    if self.output.stream { "ServerStreamingSink" } else { "UnarySink" }
                );
                quote_spanned! {
                    name.span()=>
                    fn #name(&mut self, _ctx: grpcio::RpcContext<'_>, #input, _sink: grpcio::#sink_type<#output>) {
                        unimplemented!();
                    }
                }
            }
            (_, _) => {
                let sink_type = format_ident!(
                    "{}",
                    if self.output.stream { "ServerStreamingSink" } else { "UnarySink" }
                );
                quote_spanned! {
                    name.span()=>
                    fn #name(&mut self, _ctx: grpcio::RpcContext<'_>, #input, _sink: grpcio::#sink_type<#output>) {
                        compile_error!("support for this syntax is not supported yet");
                    }
                }
            }
        };

        syn::parse(tokens.into()).unwrap()
    }
}

#[derive(Clone)]
struct RpcReturnType {
    type_: Type,
    stream: bool,
}

impl Parse for ServiceDef {
    fn parse(input: ParseStream) -> Result<Self> {
        let _impl_token: Token![impl] = input.parse()?;
        let grpc_trait = input.parse()?;
        let _for: Token![for] = input.parse()?;
        let struct_ = input.parse()?;
        let content;
        braced!(content in input);

        let mut items = Vec::new();
        while !content.is_empty() {
            items.push(content.parse()?);
        }

        Ok(ServiceDef { grpc_trait, struct_, items })
    }
}

impl Parse for ServiceItem {
    fn parse(input: ParseStream) -> Result<Self> {
        if input.peek(Token![fn]) {
            Ok(ServiceItem::Raw(input.parse()?))
        } else {
            match input.parse::<Ident>()?.to_string().as_str() {
                "rpc" => Ok(ServiceItem::Rpc(input.parse()?)),
                keyword => panic!("unexpected keyword {}", keyword),
            }
        }
    }
}

impl Parse for RpcItem {
    fn parse(input: ParseStream) -> Result<Self> {
        let name: Ident = input.parse()?;
        let rpc_input;
        parenthesized!(rpc_input in input);
        let receiver: Receiver = rpc_input.parse()?;
        if receiver.mutability.is_none() {
            panic!("self should be mutable");
        }
        if receiver.reference.is_none() {
            panic!("self should be by reference");
        }
        let rpc_input: FnArg = if rpc_input.is_empty() {
            parse_quote! {
                _arg: Empty
            }
        } else {
            rpc_input.parse::<Token![,]>()?;
            rpc_input.parse()?
        };
        let rpc_input = match rpc_input {
            FnArg::Receiver(r) => panic!("did not expect {:?}", r),
            FnArg::Typed(t) => t,
        };

        let output = if input.peek(Token![->]) {
            let _arrow: Token![->] = input.parse()?;
            let stream = input.peek2(Ident);
            if stream && input.parse::<Ident>()?.to_string().as_str() != "stream" {
                panic!("expected \'stream\' keyword");
            }
            RpcReturnType { type_: input.parse()?, stream }
        } else {
            RpcReturnType {
                type_: parse_quote! {
                    Empty
                },
                stream: false,
            }
        };
        let (unimplemented, drain, code) = if input.peek(Token![=>]) {
            input.parse::<Token![=>]>()?;
            match input.parse::<Ident>()?.to_string().as_str() {
                "unimplemented" => {
                    input.parse::<Token![!]>()?;
                    let contents;
                    parenthesized!(contents in input);
                    if !contents.is_empty() {
                        panic!("expected empty unimplemented!()");
                    }
                    (true, None, None)
                }
                "drains" if output.stream => (false, Some(input.parse()?), None),
                "into" if !output.stream => (false, Some(input.parse()?), None),
                keyword => panic!("unexpected keyword {}", keyword),
            }
        } else {
            (false, None, Some(input.parse()?))
        };

        Ok(RpcItem { name, input: rpc_input, output, unimplemented, drain, code })
    }
}

fn to_snake_case(s: &str) -> String {
    let mut output = String::default();
    let mut first = true;
    for c in s.chars() {
        if c.is_uppercase() && !first {
            output.push('_');
        }
        output.push_str(&c.to_lowercase().to_string());
        first = false;
    }

    output
}

fn replace_self(stream: TokenStream) -> TokenStream {
    stream
        .into_iter()
        .map(|tt| match tt {
            TokenTree::Ident(i) if i.to_string() == "self" => {
                TokenTree::Ident(proc_macro::Ident::new("___implicit_self___", i.span()))
            }
            TokenTree::Group(g) => {
                let mut group = proc_macro::Group::new(g.delimiter(), replace_self(g.stream()));
                group.set_span(g.span());
                TokenTree::Group(group)
            }
            other => other,
        })
        .collect()
}
+9 −19
Original line number Diff line number Diff line
@@ -5,13 +5,12 @@ use bt_common::GrpcFacade;
use bt_facade_proto::common::Data;
use bt_facade_proto::empty::Empty;
use bt_facade_proto::hal_facade_grpc::{create_hci_hal_facade, HciHalFacade};
use bt_packets::hci;
use bt_packets::hci::{AclPacket, CommandPacket};
use futures::sink::SinkExt;
use gddi::{module, provides, Stoppable};
use grpcio::*;
use std::sync::Arc;
use tokio::runtime::Runtime;
use tokio::sync::{mpsc, Mutex};

module! {
    hal_facade_module,
@@ -22,23 +21,14 @@ module! {

#[provides]
async fn provide_facade(hal_exports: HalExports, rt: Arc<Runtime>) -> HciHalFacadeService {
    HciHalFacadeService {
        rt,
        cmd_tx: hal_exports.cmd_tx,
        evt_rx: hal_exports.evt_rx,
        acl_tx: hal_exports.acl_tx,
        acl_rx: hal_exports.acl_rx,
    }
    HciHalFacadeService { rt, hal_exports }
}

/// HCI HAL facade service
#[derive(Clone, Stoppable)]
pub struct HciHalFacadeService {
    rt: Arc<Runtime>,
    cmd_tx: mpsc::Sender<hci::CommandPacket>,
    evt_rx: Arc<Mutex<mpsc::Receiver<hci::EventPacket>>>,
    acl_tx: mpsc::Sender<hci::AclPacket>,
    acl_rx: Arc<Mutex<mpsc::Receiver<hci::AclPacket>>>,
    hal_exports: HalExports,
}

impl GrpcFacade for HciHalFacadeService {
@@ -49,17 +39,17 @@ impl GrpcFacade for HciHalFacadeService {

impl HciHalFacade for HciHalFacadeService {
    fn send_command(&mut self, _ctx: RpcContext<'_>, mut data: Data, sink: UnarySink<Empty>) {
        let cmd_tx = self.cmd_tx.clone();
        let cmd_tx = self.hal_exports.cmd_tx.clone();
        self.rt.block_on(async move {
            cmd_tx.send(hci::CommandPacket::parse(&data.take_payload()).unwrap()).await.unwrap();
            cmd_tx.send(CommandPacket::parse(&data.take_payload()).unwrap()).await.unwrap();
        });
        sink.success(Empty::default());
    }

    fn send_acl(&mut self, _ctx: RpcContext<'_>, mut data: Data, sink: UnarySink<Empty>) {
        let acl_tx = self.acl_tx.clone();
        let acl_tx = self.hal_exports.acl_tx.clone();
        self.rt.block_on(async move {
            acl_tx.send(hci::AclPacket::parse(&data.take_payload()).unwrap()).await.unwrap();
            acl_tx.send(AclPacket::parse(&data.take_payload()).unwrap()).await.unwrap();
        });
        sink.success(Empty::default());
    }
@@ -78,7 +68,7 @@ impl HciHalFacade for HciHalFacadeService {
        _: Empty,
        mut sink: ServerStreamingSink<Data>,
    ) {
        let evt_rx = self.evt_rx.clone();
        let evt_rx = self.hal_exports.evt_rx.clone();
        self.rt.spawn(async move {
            while let Some(event) = evt_rx.lock().await.recv().await {
                let mut output = Data::default();
@@ -89,7 +79,7 @@ impl HciHalFacade for HciHalFacadeService {
    }

    fn stream_acl(&mut self, _ctx: RpcContext<'_>, _: Empty, mut sink: ServerStreamingSink<Data>) {
        let acl_rx = self.acl_rx.clone();
        let acl_rx = self.hal_exports.acl_rx.clone();
        self.rt.spawn(async move {
            while let Some(acl) = acl_rx.lock().await.recv().await {
                let mut output = Data::default();
+17 −15
Original line number Diff line number Diff line
@@ -6,7 +6,9 @@ use bt_facade_proto::common::Data;
use bt_facade_proto::empty::Empty;
use bt_facade_proto::hci_facade::EventRequest;
use bt_facade_proto::hci_facade_grpc::{create_hci_layer_facade, HciLayerFacade};
use bt_packets::hci;
use bt_packets::hci::{
    AclPacket, CommandPacket, EventCode, EventPacket, LeMetaEventPacket, SubeventCode,
};
use futures::sink::SinkExt;
use gddi::{module, provides, Stoppable};
use grpcio::*;
@@ -25,8 +27,8 @@ module! {

#[provides]
async fn provide_facade(hci_exports: HciExports, rt: Arc<Runtime>) -> HciLayerFacadeService {
    let (from_hci_evt_tx, to_grpc_evt_rx) = channel::<hci::EventPacket>(10);
    let (from_hci_le_evt_tx, to_grpc_le_evt_rx) = channel::<hci::LeMetaEventPacket>(10);
    let (from_hci_evt_tx, to_grpc_evt_rx) = channel::<EventPacket>(10);
    let (from_hci_le_evt_tx, to_grpc_le_evt_rx) = channel::<LeMetaEventPacket>(10);
    HciLayerFacadeService {
        hci_exports,
        rt,
@@ -42,10 +44,10 @@ async fn provide_facade(hci_exports: HciExports, rt: Arc<Runtime>) -> HciLayerFa
pub struct HciLayerFacadeService {
    hci_exports: HciExports,
    rt: Arc<Runtime>,
    from_hci_evt_tx: Sender<hci::EventPacket>,
    to_grpc_evt_rx: Arc<Mutex<Receiver<hci::EventPacket>>>,
    from_hci_le_evt_tx: Sender<hci::LeMetaEventPacket>,
    to_grpc_le_evt_rx: Arc<Mutex<Receiver<hci::LeMetaEventPacket>>>,
    from_hci_evt_tx: Sender<EventPacket>,
    to_grpc_evt_rx: Arc<Mutex<Receiver<EventPacket>>>,
    from_hci_le_evt_tx: Sender<LeMetaEventPacket>,
    to_grpc_le_evt_rx: Arc<Mutex<Receiver<LeMetaEventPacket>>>,
}

impl GrpcFacade for HciLayerFacadeService {
@@ -63,7 +65,7 @@ impl HciLayerFacade for HciLayerFacadeService {
    ) {
        self.rt
            .block_on(
                self.hci_exports.send_raw(hci::CommandPacket::parse(&data.take_payload()).unwrap()),
                self.hci_exports.send_raw(CommandPacket::parse(&data.take_payload()).unwrap()),
            )
            .unwrap();
        sink.success(Empty::default());
@@ -77,15 +79,15 @@ impl HciLayerFacade for HciLayerFacadeService {
    ) {
        self.rt
            .block_on(
                self.hci_exports.send_raw(hci::CommandPacket::parse(&data.take_payload()).unwrap()),
                self.hci_exports.send_raw(CommandPacket::parse(&data.take_payload()).unwrap()),
            )
            .unwrap();
        sink.success(Empty::default());
    }

    fn request_event(&mut self, _ctx: RpcContext<'_>, code: EventRequest, sink: UnarySink<Empty>) {
    fn request_event(&mut self, _ctx: RpcContext<'_>, req: EventRequest, sink: UnarySink<Empty>) {
        self.rt.block_on(self.hci_exports.register_event_handler(
            hci::EventCode::from_u32(code.get_code()).unwrap(),
            EventCode::from_u32(req.get_code()).unwrap(),
            self.from_hci_evt_tx.clone(),
        ));
        sink.success(Empty::default());
@@ -94,11 +96,11 @@ impl HciLayerFacade for HciLayerFacadeService {
    fn request_le_subevent(
        &mut self,
        _ctx: RpcContext<'_>,
        code: EventRequest,
        req: EventRequest,
        sink: UnarySink<Empty>,
    ) {
        self.rt.block_on(self.hci_exports.register_le_event_handler(
            hci::SubeventCode::from_u32(code.get_code()).unwrap(),
            SubeventCode::from_u32(req.get_code()).unwrap(),
            self.from_hci_le_evt_tx.clone(),
        ));
        sink.success(Empty::default());
@@ -107,7 +109,7 @@ impl HciLayerFacade for HciLayerFacadeService {
    fn send_acl(&mut self, _ctx: RpcContext<'_>, mut packet: Data, sink: UnarySink<Empty>) {
        let acl_tx = self.hci_exports.acl_tx.clone();
        self.rt.block_on(async move {
            acl_tx.send(hci::AclPacket::parse(&packet.take_payload()).unwrap()).await.unwrap();
            acl_tx.send(AclPacket::parse(&packet.take_payload()).unwrap()).await.unwrap();
        });
        sink.success(Empty::default());
    }
@@ -139,7 +141,7 @@ impl HciLayerFacade for HciLayerFacadeService {

        self.rt.spawn(async move {
            while let Some(event) = evt_rx.lock().await.recv().await {
                let mut evt = LeSubevent::default();
                let mut evt = Data::default();
                evt.set_payload(event.to_vec());
                resp.send((evt, WriteFlags::default())).await.unwrap();
            }