Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit fa98fe99 authored by Martin Geisler's avatar Martin Geisler Committed by Gerrit Code Review
Browse files

Merge "pdl: Generate simpler code for PDL structs"

parents d163ac53 0c8189de
Loading
Loading
Loading
Loading
+142 −30
Original line number Diff line number Diff line
@@ -16,13 +16,11 @@ use std::path::Path;

use crate::parser::ast as parser_ast;

mod declarations;
mod parser;
mod preamble;
mod serializer;
mod types;

use declarations::FieldDeclarations;
use parser::FieldParser;
use serializer::FieldSerializer;

@@ -59,6 +57,7 @@ pub fn mask_bits(n: usize, suffix: &str) -> syn::LitInt {
fn generate_packet_size_getter(
    scope: &lint::Scope<'_>,
    fields: &[&parser_ast::Field],
    is_packet: bool,
) -> (usize, proc_macro2::TokenStream) {
    let mut constant_width = 0;
    let mut dynamic_widths = Vec::new();
@@ -71,9 +70,17 @@ fn generate_packet_size_getter(

        let decl = field.declaration(scope);
        dynamic_widths.push(match &field.desc {
            ast::FieldDesc::Payload { .. } | ast::FieldDesc::Body { .. } => quote! {
            ast::FieldDesc::Payload { .. } | ast::FieldDesc::Body { .. } => {
                if is_packet {
                    quote! {
                        self.child.get_total_size()
            },
                    }
                } else {
                    quote! {
                        self.payload.len()
                    }
                }
            }
            ast::FieldDesc::Typedef { id, .. } => {
                let id = format_ident!("{id}");
                quote!(self.#id.get_size())
@@ -191,33 +198,31 @@ fn generate_data_struct(
    endianness: ast::EndiannessValue,
    id: &str,
) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
    let packet_scope = &scope.scopes[&scope.typedef[id]];
    let id_data = format_ident!("{id}Data");

    let fields_with_ids =
        packet_scope.fields.iter().filter(|f| f.id().is_some()).collect::<Vec<_>>();
    let field_names =
        fields_with_ids.iter().map(|f| format_ident!("{}", f.id().unwrap())).collect::<Vec<_>>();
    let decl = scope.typedef[id];
    let packet_scope = &scope.scopes[&decl];
    let is_packet = matches!(&decl.desc, ast::DeclDesc::Packet { .. });

    let span = format_ident!("bytes");
    let serializer_span = format_ident!("buffer");
    let mut field_declarations = FieldDeclarations::new(scope, id);
    let mut field_parser = FieldParser::new(scope, endianness, id, &span);
    let mut field_serializer = FieldSerializer::new(scope, endianness, id, &serializer_span);
    for field in &packet_scope.fields {
        field_declarations.add(field);
        field_parser.add(field);
        field_serializer.add(field);
    }
    field_declarations.done();
    field_parser.done();

    let parse_fields = find_constrained_parent_fields(scope, id).collect::<Vec<_>>();
    let parse_field_names =
        parse_fields.iter().map(|f| format_ident!("{}", f.id().unwrap())).collect::<Vec<_>>();
    let parse_field_types = parse_fields.iter().map(|f| types::rust_type(f));
    let (parse_arg_names, parse_arg_types) = if is_packet {
        let fields = find_constrained_parent_fields(scope, id).collect::<Vec<_>>();
        let names = fields.iter().map(|f| format_ident!("{}", f.id().unwrap())).collect::<Vec<_>>();
        let types = fields.iter().map(|f| types::rust_type(f)).collect::<Vec<_>>();
        (names, types)
    } else {
        (Vec::new(), Vec::new()) // No extra arguments to parse in structs.
    };

    let (constant_width, packet_size) = generate_packet_size_getter(scope, &packet_scope.fields);
    let (constant_width, packet_size) =
        generate_packet_size_getter(scope, &packet_scope.fields, is_packet);
    let conforms = if constant_width == 0 {
        quote! { true }
    } else {
@@ -225,30 +230,58 @@ fn generate_data_struct(
        quote! { #span.len() >= #constant_width }
    };

    let visibility = if is_packet { quote!() } else { quote!(pub) };
    let has_payload = packet_scope.payload.is_some();
    let children = get_packet_children(scope, id);
    let has_children_or_payload = !children.is_empty() || has_payload;
    let child_field = has_children_or_payload.then(|| quote!(child));
    let struct_name = if is_packet { format_ident!("{id}Data") } else { format_ident!("{id}") };
    let fields_with_ids =
        packet_scope.fields.iter().filter(|f| f.id().is_some()).collect::<Vec<_>>();
    let mut field_names =
        fields_with_ids.iter().map(|f| format_ident!("{}", f.id().unwrap())).collect::<Vec<_>>();
    let mut field_types = fields_with_ids.iter().map(|f| types::rust_type(f)).collect::<Vec<_>>();
    if has_children_or_payload {
        if is_packet {
            field_names.push(format_ident!("child"));
            let field_type = format_ident!("{id}DataChild");
            field_types.push(quote!(#field_type));
        } else {
            field_names.push(format_ident!("payload"));
            field_types.push(quote!(Vec<u8>));
        }
    }

    let data_struct_decl = quote! {
        #[derive(Debug, Clone, PartialEq, Eq)]
        #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
        pub struct #id_data {
            #field_declarations
        pub struct #struct_name {
            #(#visibility #field_names: #field_types,)*
        }
    };

    let data_struct_impl = quote! {
        impl #id_data {
        impl #struct_name {
            fn conforms(#span: &[u8]) -> bool {
                #conforms
            }

            fn parse(mut #span: &mut Cell<&[u8]> #(, #parse_field_names: #parse_field_types)*) -> Result<Self> {
            #visibility fn parse(
                #span: &[u8] #(, #parse_arg_names: #parse_arg_types)*
            ) -> Result<Self> {
                let mut cell = Cell::new(#span);
                let packet = Self::parse_inner(&mut cell #(, #parse_arg_names)*)?;
                if !cell.get().is_empty() {
                    return Err(Error::InvalidPacketError);
                }
                Ok(packet)
            }

            fn parse_inner(
                mut #span: &mut Cell<&[u8]> #(, #parse_arg_names: #parse_arg_types)*
            ) -> Result<Self> {
                #field_parser
                Ok(Self {
                    #(#field_names,)*
                    #child_field
                })
            }

@@ -310,8 +343,7 @@ pub fn constraint_to_value(
    }
}

/// Generate code for `ast::Decl::Packet` and `ast::Decl::Struct`
/// values.
/// Generate code for a `ast::Decl::Packet`.
fn generate_packet_decl(
    scope: &lint::Scope<'_>,
    endianness: ast::EndiannessValue,
@@ -578,7 +610,7 @@ fn generate_packet_decl(
            }

            fn parse_inner(mut bytes: &mut Cell<&[u8]>) -> Result<Self> {
                let data = #top_level_data::parse(&mut bytes)?;
                let data = #top_level_data::parse_inner(&mut bytes)?;
                Ok(Self::new(Arc::new(data)).unwrap())
            }

@@ -627,6 +659,19 @@ fn generate_packet_decl(
    }
}

/// Generate code for a `ast::Decl::Struct`.
fn generate_struct_decl(
    scope: &lint::Scope<'_>,
    endianness: ast::EndiannessValue,
    id: &str,
) -> proc_macro2::TokenStream {
    let (struct_decl, struct_impl) = generate_data_struct(scope, endianness, id);
    quote! {
        #struct_decl
        #struct_impl
    }
}

fn generate_enum_decl(id: &str, tags: &[ast::Tag]) -> proc_macro2::TokenStream {
    let name = format_ident!("{id}");
    let variants =
@@ -694,9 +739,17 @@ fn generate_decl(
    decl: &parser_ast::Decl,
) -> String {
    match &decl.desc {
        ast::DeclDesc::Packet { id, .. } | ast::DeclDesc::Struct { id, .. } => {
        ast::DeclDesc::Packet { id, .. } => {
            generate_packet_decl(scope, file.endianness.value, id).to_string()
        }
        ast::DeclDesc::Struct { id, parent_id: None, .. } => {
            // TODO(mgeisler): handle structs with parents. We could
            // generate code for them, but the code is not useful
            // since it would require the caller to unpack everything
            // manually. We either need to change the API, or
            // implement the recursive (de)serialization.
            generate_struct_decl(scope, file.endianness.value, id).to_string()
        }
        ast::DeclDesc::Enum { id, tags, .. } => generate_enum_decl(id, tags).to_string(),
        _ => todo!("unsupported Decl::{:?}", decl),
    }
@@ -1096,4 +1149,63 @@ mod tests {
          }
        "
    );

    // TODO(mgeisler): enable this test when we have an approach to
    // struct fields with parents.
    //
    // test_pdl!(
    //     struct_decl_child_structs,
    //     "
    //       enum Enum16 : 16 {
    //         A = 1,
    //         B = 2,
    //       }
    //
    //       struct Foo {
    //           a: 8,
    //           b: Enum16,
    //           _size_(_payload_): 8,
    //           _payload_
    //       }
    //
    //       struct Bar : Foo (a = 100) {
    //           x: 8,
    //       }
    //
    //       struct Baz : Foo (b = B) {
    //           y: 16,
    //       }
    //     "
    // );
    //
    // test_pdl!(
    //     struct_decl_grand_children,
    //     "
    //       enum Enum16 : 16 {
    //         A = 1,
    //         B = 2,
    //       }
    //
    //       struct Parent {
    //           foo: Enum16,
    //           bar: Enum16,
    //           baz: Enum16,
    //           _size_(_payload_): 8,
    //           _payload_
    //       }
    //
    //       struct Child : Parent (foo = A) {
    //           quux: Enum16,
    //           _payload_,
    //       }
    //
    //       struct GrandChild : Child (bar = A, quux = A) {
    //           _body_,
    //       }
    //
    //       struct GrandGrandChild : GrandChild (baz = A) {
    //           _body_,
    //       }
    //     "
    // );
}
+0 −47
Original line number Diff line number Diff line
use crate::backends::rust::types;
use crate::lint;
use crate::parser::ast as parser_ast;
use quote::{format_ident, quote};

pub struct FieldDeclarations<'a> {
    scope: &'a lint::Scope<'a>,
    packet_name: &'a str,
    code: Vec<proc_macro2::TokenStream>,
}

impl<'a> FieldDeclarations<'a> {
    pub fn new(scope: &'a lint::Scope<'a>, packet_name: &'a str) -> FieldDeclarations<'a> {
        FieldDeclarations { scope, packet_name, code: Vec::new() }
    }

    pub fn add(&mut self, field: &parser_ast::Field) {
        let id = match field.id() {
            Some(id) => format_ident!("{id}"),
            None => return, // No id => field not stored.
        };

        let field_type = types::rust_type(field);
        self.code.push(quote! {
            #id: #field_type,
        })
    }

    pub fn done(&mut self) {
        let packet_data_child = format_ident!("{}DataChild", self.packet_name);
        let packet_scope = &self.scope.scopes[&self.scope.typedef[self.packet_name]];
        if self.scope.children.contains_key(self.packet_name) || packet_scope.payload.is_some() {
            self.code.push(quote! {
                child: #packet_data_child,
            });
        }
    }
}

impl quote::ToTokens for FieldDeclarations<'_> {
    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
        let code = &self.code;
        tokens.extend(quote! {
            #(#code)*
        });
    }
}
+14 −2
Original line number Diff line number Diff line
@@ -520,6 +520,13 @@ impl<'a> FieldParser<'a> {
                #span.get_mut().advance(payload.len());
            });
        }

        let decl = self.scope.typedef[self.packet_name];
        if let ast::DeclDesc::Struct { .. } = &decl.desc {
            self.code.push(quote! {
                let payload = Vec::from(payload);
            });
        }
    }

    /// Parse a single array field element from `span`.
@@ -560,7 +567,12 @@ impl<'a> FieldParser<'a> {
    }

    pub fn done(&mut self) {
        let packet_scope = &self.scope.scopes[&self.scope.typedef[self.packet_name]];
        let decl = self.scope.typedef[self.packet_name];
        if let parser_ast::DeclDesc::Struct { .. } = &decl.desc {
            return; // Structs don't parse the child structs recursively.
        }

        let packet_scope = &self.scope.scopes[&decl];
        let children =
            self.scope.children.get(self.packet_name).map(Vec::as_slice).unwrap_or_default();
        if children.is_empty() && packet_scope.payload.is_none() {
@@ -616,7 +628,7 @@ impl<'a> FieldParser<'a> {
            let child = match (#(#constrained_field_idents),*) {
                #(#match_values => {
                    let mut cell = Cell::new(payload);
                    let child_data = #child_ids_data::parse(&mut cell #child_parse_args)?;
                    let child_data = #child_ids_data::parse_inner(&mut cell #child_parse_args)?;
                    if !cell.get().is_empty() {
                        return Err(Error::InvalidPacketError);
                    }
+23 −11
Original line number Diff line number Diff line
@@ -121,8 +121,11 @@ impl<'a> FieldSerializer<'a> {
                let array_size = match (&value_field.desc, value_field_decl.map(|decl| &decl.desc))
                {
                    (ast::FieldDesc::Payload { .. } | ast::FieldDesc::Body { .. }, _) => {
                        //let span = format_ident!("{}", self.span);
                        if let ast::DeclDesc::Packet { .. } = &decl.desc {
                            quote! { self.child.get_total_size() }
                        } else {
                            quote! { self.payload.len() }
                        }
                    }
                    (ast::FieldDesc::Array { width: Some(width), .. }, _)
                    | (ast::FieldDesc::Array { .. }, Some(ast::DeclDesc::Enum { width, .. })) => {
@@ -291,6 +294,9 @@ impl<'a> FieldSerializer<'a> {
            panic!("Payload field does not start on an octet boundary");
        }

        let decl = self.scope.typedef[self.packet_name];
        let is_packet = matches!(&decl.desc, ast::DeclDesc::Packet { .. });

        let children =
            self.scope.children.get(self.packet_name).map(Vec::as_slice).unwrap_or_default();
        let child_ids = children
@@ -298,8 +304,9 @@ impl<'a> FieldSerializer<'a> {
            .map(|child| format_ident!("{}", child.id().unwrap()))
            .collect::<Vec<_>>();

        if self.shift == 0 {
        let span = format_ident!("{}", self.span);
        if self.shift == 0 {
            if is_packet {
                let packet_data_child = format_ident!("{}DataChild", self.packet_name);
                self.code.push(quote! {
                    match &self.child {
@@ -308,6 +315,11 @@ impl<'a> FieldSerializer<'a> {
                        #packet_data_child::None => {},
                    }
                })
            } else {
                self.code.push(quote! {
                    #span.put_slice(&self.payload);
                });
            }
        } else {
            todo!("Shifted payloads");
        }
+1 −0
Original line number Diff line number Diff line
@@ -17,6 +17,7 @@ pub mod ast {

    pub type Field = crate::ast::Field<Annotation>;
    pub type Decl = crate::ast::Decl<Annotation>;
    pub type DeclDesc = crate::ast::DeclDesc<Annotation>;
    pub type File = crate::ast::File<Annotation>;
}

Loading