Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 0c8189de authored by Martin Geisler's avatar Martin Geisler
Browse files

pdl: Generate simpler code for PDL structs

Before, we treated packets and structs the same: for a packet or
struct Foo, we would generate three Rust structs:

  struct Foo
  struct FooData
  struct FooBuilder

This doesn’t match the old bluetooth_packetgen compiler: it would only
generate a single Rust struct with public fields in case of a PDL
struct.

We now do the same for the new compiler. We reuse most of the
generator code from the old data struct generator since this is the
one that most closely resemble the struct we want.

Tag: #feature
Bug: 230475725
Test: atest pdl_tests pdl_rust_generator_tests_{le,be}
Change-Id: Idf2ae9cc1347d2f6b1a5b9b7c29f14e225e7080c
parent 4546d4a6
Loading
Loading
Loading
Loading
+142 −30
Original line number Diff line number Diff line
@@ -16,13 +16,11 @@ use std::path::Path;

use crate::parser::ast as parser_ast;

mod declarations;
mod parser;
mod preamble;
mod serializer;
mod types;

use declarations::FieldDeclarations;
use parser::FieldParser;
use serializer::FieldSerializer;

@@ -59,6 +57,7 @@ pub fn mask_bits(n: usize, suffix: &str) -> syn::LitInt {
fn generate_packet_size_getter(
    scope: &lint::Scope<'_>,
    fields: &[&parser_ast::Field],
    is_packet: bool,
) -> (usize, proc_macro2::TokenStream) {
    let mut constant_width = 0;
    let mut dynamic_widths = Vec::new();
@@ -71,9 +70,17 @@ fn generate_packet_size_getter(

        let decl = field.declaration(scope);
        dynamic_widths.push(match &field.desc {
            ast::FieldDesc::Payload { .. } | ast::FieldDesc::Body { .. } => quote! {
            ast::FieldDesc::Payload { .. } | ast::FieldDesc::Body { .. } => {
                if is_packet {
                    quote! {
                        self.child.get_total_size()
            },
                    }
                } else {
                    quote! {
                        self.payload.len()
                    }
                }
            }
            ast::FieldDesc::Typedef { id, .. } => {
                let id = format_ident!("{id}");
                quote!(self.#id.get_size())
@@ -191,33 +198,31 @@ fn generate_data_struct(
    endianness: ast::EndiannessValue,
    id: &str,
) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
    let packet_scope = &scope.scopes[&scope.typedef[id]];
    let id_data = format_ident!("{id}Data");

    let fields_with_ids =
        packet_scope.fields.iter().filter(|f| f.id().is_some()).collect::<Vec<_>>();
    let field_names =
        fields_with_ids.iter().map(|f| format_ident!("{}", f.id().unwrap())).collect::<Vec<_>>();
    let decl = scope.typedef[id];
    let packet_scope = &scope.scopes[&decl];
    let is_packet = matches!(&decl.desc, ast::DeclDesc::Packet { .. });

    let span = format_ident!("bytes");
    let serializer_span = format_ident!("buffer");
    let mut field_declarations = FieldDeclarations::new(scope, id);
    let mut field_parser = FieldParser::new(scope, endianness, id, &span);
    let mut field_serializer = FieldSerializer::new(scope, endianness, id, &serializer_span);
    for field in &packet_scope.fields {
        field_declarations.add(field);
        field_parser.add(field);
        field_serializer.add(field);
    }
    field_declarations.done();
    field_parser.done();

    let parse_fields = find_constrained_parent_fields(scope, id).collect::<Vec<_>>();
    let parse_field_names =
        parse_fields.iter().map(|f| format_ident!("{}", f.id().unwrap())).collect::<Vec<_>>();
    let parse_field_types = parse_fields.iter().map(|f| types::rust_type(f));
    let (parse_arg_names, parse_arg_types) = if is_packet {
        let fields = find_constrained_parent_fields(scope, id).collect::<Vec<_>>();
        let names = fields.iter().map(|f| format_ident!("{}", f.id().unwrap())).collect::<Vec<_>>();
        let types = fields.iter().map(|f| types::rust_type(f)).collect::<Vec<_>>();
        (names, types)
    } else {
        (Vec::new(), Vec::new()) // No extra arguments to parse in structs.
    };

    let (constant_width, packet_size) = generate_packet_size_getter(scope, &packet_scope.fields);
    let (constant_width, packet_size) =
        generate_packet_size_getter(scope, &packet_scope.fields, is_packet);
    let conforms = if constant_width == 0 {
        quote! { true }
    } else {
@@ -225,30 +230,58 @@ fn generate_data_struct(
        quote! { #span.len() >= #constant_width }
    };

    let visibility = if is_packet { quote!() } else { quote!(pub) };
    let has_payload = packet_scope.payload.is_some();
    let children = get_packet_children(scope, id);
    let has_children_or_payload = !children.is_empty() || has_payload;
    let child_field = has_children_or_payload.then(|| quote!(child));
    let struct_name = if is_packet { format_ident!("{id}Data") } else { format_ident!("{id}") };
    let fields_with_ids =
        packet_scope.fields.iter().filter(|f| f.id().is_some()).collect::<Vec<_>>();
    let mut field_names =
        fields_with_ids.iter().map(|f| format_ident!("{}", f.id().unwrap())).collect::<Vec<_>>();
    let mut field_types = fields_with_ids.iter().map(|f| types::rust_type(f)).collect::<Vec<_>>();
    if has_children_or_payload {
        if is_packet {
            field_names.push(format_ident!("child"));
            let field_type = format_ident!("{id}DataChild");
            field_types.push(quote!(#field_type));
        } else {
            field_names.push(format_ident!("payload"));
            field_types.push(quote!(Vec<u8>));
        }
    }

    let data_struct_decl = quote! {
        #[derive(Debug, Clone, PartialEq, Eq)]
        #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
        pub struct #id_data {
            #field_declarations
        pub struct #struct_name {
            #(#visibility #field_names: #field_types,)*
        }
    };

    let data_struct_impl = quote! {
        impl #id_data {
        impl #struct_name {
            fn conforms(#span: &[u8]) -> bool {
                #conforms
            }

            fn parse(mut #span: &mut Cell<&[u8]> #(, #parse_field_names: #parse_field_types)*) -> Result<Self> {
            #visibility fn parse(
                #span: &[u8] #(, #parse_arg_names: #parse_arg_types)*
            ) -> Result<Self> {
                let mut cell = Cell::new(#span);
                let packet = Self::parse_inner(&mut cell #(, #parse_arg_names)*)?;
                if !cell.get().is_empty() {
                    return Err(Error::InvalidPacketError);
                }
                Ok(packet)
            }

            fn parse_inner(
                mut #span: &mut Cell<&[u8]> #(, #parse_arg_names: #parse_arg_types)*
            ) -> Result<Self> {
                #field_parser
                Ok(Self {
                    #(#field_names,)*
                    #child_field
                })
            }

@@ -310,8 +343,7 @@ pub fn constraint_to_value(
    }
}

/// Generate code for `ast::Decl::Packet` and `ast::Decl::Struct`
/// values.
/// Generate code for a `ast::Decl::Packet`.
fn generate_packet_decl(
    scope: &lint::Scope<'_>,
    endianness: ast::EndiannessValue,
@@ -578,7 +610,7 @@ fn generate_packet_decl(
            }

            fn parse_inner(mut bytes: &mut Cell<&[u8]>) -> Result<Self> {
                let data = #top_level_data::parse(&mut bytes)?;
                let data = #top_level_data::parse_inner(&mut bytes)?;
                Ok(Self::new(Arc::new(data)).unwrap())
            }

@@ -627,6 +659,19 @@ fn generate_packet_decl(
    }
}

/// Generate code for a `ast::Decl::Struct`.
fn generate_struct_decl(
    scope: &lint::Scope<'_>,
    endianness: ast::EndiannessValue,
    id: &str,
) -> proc_macro2::TokenStream {
    let (struct_decl, struct_impl) = generate_data_struct(scope, endianness, id);
    quote! {
        #struct_decl
        #struct_impl
    }
}

fn generate_enum_decl(id: &str, tags: &[ast::Tag]) -> proc_macro2::TokenStream {
    let name = format_ident!("{id}");
    let variants =
@@ -694,9 +739,17 @@ fn generate_decl(
    decl: &parser_ast::Decl,
) -> String {
    match &decl.desc {
        ast::DeclDesc::Packet { id, .. } | ast::DeclDesc::Struct { id, .. } => {
        ast::DeclDesc::Packet { id, .. } => {
            generate_packet_decl(scope, file.endianness.value, id).to_string()
        }
        ast::DeclDesc::Struct { id, parent_id: None, .. } => {
            // TODO(mgeisler): handle structs with parents. We could
            // generate code for them, but the code is not useful
            // since it would require the caller to unpack everything
            // manually. We either need to change the API, or
            // implement the recursive (de)serialization.
            generate_struct_decl(scope, file.endianness.value, id).to_string()
        }
        ast::DeclDesc::Enum { id, tags, .. } => generate_enum_decl(id, tags).to_string(),
        _ => todo!("unsupported Decl::{:?}", decl),
    }
@@ -1096,4 +1149,63 @@ mod tests {
          }
        "
    );

    // TODO(mgeisler): enable this test when we have an approach to
    // struct fields with parents.
    //
    // test_pdl!(
    //     struct_decl_child_structs,
    //     "
    //       enum Enum16 : 16 {
    //         A = 1,
    //         B = 2,
    //       }
    //
    //       struct Foo {
    //           a: 8,
    //           b: Enum16,
    //           _size_(_payload_): 8,
    //           _payload_
    //       }
    //
    //       struct Bar : Foo (a = 100) {
    //           x: 8,
    //       }
    //
    //       struct Baz : Foo (b = B) {
    //           y: 16,
    //       }
    //     "
    // );
    //
    // test_pdl!(
    //     struct_decl_grand_children,
    //     "
    //       enum Enum16 : 16 {
    //         A = 1,
    //         B = 2,
    //       }
    //
    //       struct Parent {
    //           foo: Enum16,
    //           bar: Enum16,
    //           baz: Enum16,
    //           _size_(_payload_): 8,
    //           _payload_
    //       }
    //
    //       struct Child : Parent (foo = A) {
    //           quux: Enum16,
    //           _payload_,
    //       }
    //
    //       struct GrandChild : Child (bar = A, quux = A) {
    //           _body_,
    //       }
    //
    //       struct GrandGrandChild : GrandChild (baz = A) {
    //           _body_,
    //       }
    //     "
    // );
}
+0 −47
Original line number Diff line number Diff line
use crate::backends::rust::types;
use crate::lint;
use crate::parser::ast as parser_ast;
use quote::{format_ident, quote};

pub struct FieldDeclarations<'a> {
    scope: &'a lint::Scope<'a>,
    packet_name: &'a str,
    code: Vec<proc_macro2::TokenStream>,
}

impl<'a> FieldDeclarations<'a> {
    pub fn new(scope: &'a lint::Scope<'a>, packet_name: &'a str) -> FieldDeclarations<'a> {
        FieldDeclarations { scope, packet_name, code: Vec::new() }
    }

    pub fn add(&mut self, field: &parser_ast::Field) {
        let id = match field.id() {
            Some(id) => format_ident!("{id}"),
            None => return, // No id => field not stored.
        };

        let field_type = types::rust_type(field);
        self.code.push(quote! {
            #id: #field_type,
        })
    }

    pub fn done(&mut self) {
        let packet_data_child = format_ident!("{}DataChild", self.packet_name);
        let packet_scope = &self.scope.scopes[&self.scope.typedef[self.packet_name]];
        if self.scope.children.contains_key(self.packet_name) || packet_scope.payload.is_some() {
            self.code.push(quote! {
                child: #packet_data_child,
            });
        }
    }
}

impl quote::ToTokens for FieldDeclarations<'_> {
    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
        let code = &self.code;
        tokens.extend(quote! {
            #(#code)*
        });
    }
}
+14 −2
Original line number Diff line number Diff line
@@ -520,6 +520,13 @@ impl<'a> FieldParser<'a> {
                #span.get_mut().advance(payload.len());
            });
        }

        let decl = self.scope.typedef[self.packet_name];
        if let ast::DeclDesc::Struct { .. } = &decl.desc {
            self.code.push(quote! {
                let payload = Vec::from(payload);
            });
        }
    }

    /// Parse a single array field element from `span`.
@@ -560,7 +567,12 @@ impl<'a> FieldParser<'a> {
    }

    pub fn done(&mut self) {
        let packet_scope = &self.scope.scopes[&self.scope.typedef[self.packet_name]];
        let decl = self.scope.typedef[self.packet_name];
        if let parser_ast::DeclDesc::Struct { .. } = &decl.desc {
            return; // Structs don't parse the child structs recursively.
        }

        let packet_scope = &self.scope.scopes[&decl];
        let children =
            self.scope.children.get(self.packet_name).map(Vec::as_slice).unwrap_or_default();
        if children.is_empty() && packet_scope.payload.is_none() {
@@ -616,7 +628,7 @@ impl<'a> FieldParser<'a> {
            let child = match (#(#constrained_field_idents),*) {
                #(#match_values => {
                    let mut cell = Cell::new(payload);
                    let child_data = #child_ids_data::parse(&mut cell #child_parse_args)?;
                    let child_data = #child_ids_data::parse_inner(&mut cell #child_parse_args)?;
                    if !cell.get().is_empty() {
                        return Err(Error::InvalidPacketError);
                    }
+23 −11
Original line number Diff line number Diff line
@@ -121,8 +121,11 @@ impl<'a> FieldSerializer<'a> {
                let array_size = match (&value_field.desc, value_field_decl.map(|decl| &decl.desc))
                {
                    (ast::FieldDesc::Payload { .. } | ast::FieldDesc::Body { .. }, _) => {
                        //let span = format_ident!("{}", self.span);
                        if let ast::DeclDesc::Packet { .. } = &decl.desc {
                            quote! { self.child.get_total_size() }
                        } else {
                            quote! { self.payload.len() }
                        }
                    }
                    (ast::FieldDesc::Array { width: Some(width), .. }, _)
                    | (ast::FieldDesc::Array { .. }, Some(ast::DeclDesc::Enum { width, .. })) => {
@@ -291,6 +294,9 @@ impl<'a> FieldSerializer<'a> {
            panic!("Payload field does not start on an octet boundary");
        }

        let decl = self.scope.typedef[self.packet_name];
        let is_packet = matches!(&decl.desc, ast::DeclDesc::Packet { .. });

        let children =
            self.scope.children.get(self.packet_name).map(Vec::as_slice).unwrap_or_default();
        let child_ids = children
@@ -298,8 +304,9 @@ impl<'a> FieldSerializer<'a> {
            .map(|child| format_ident!("{}", child.id().unwrap()))
            .collect::<Vec<_>>();

        if self.shift == 0 {
        let span = format_ident!("{}", self.span);
        if self.shift == 0 {
            if is_packet {
                let packet_data_child = format_ident!("{}DataChild", self.packet_name);
                self.code.push(quote! {
                    match &self.child {
@@ -308,6 +315,11 @@ impl<'a> FieldSerializer<'a> {
                        #packet_data_child::None => {},
                    }
                })
            } else {
                self.code.push(quote! {
                    #span.put_slice(&self.payload);
                });
            }
        } else {
            todo!("Shifted payloads");
        }
+1 −0
Original line number Diff line number Diff line
@@ -17,6 +17,7 @@ pub mod ast {

    pub type Field = crate::ast::Field<Annotation>;
    pub type Decl = crate::ast::Decl<Annotation>;
    pub type DeclDesc = crate::ast::DeclDesc<Annotation>;
    pub type File = crate::ast::File<Annotation>;
}

Loading