Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 88c5808f authored by Martin Geisler's avatar Martin Geisler Committed by Gerrit Code Review
Browse files

Merge changes Ib193838a,I09ad078e,I7c93c6ef,Ibfe35033,Iffffdca4

* changes:
  pdl: Extract generate_data_struct helper
  pdl: Remove unused generate_packet_decl arguments
  pdl: Extract find_parents helper
  pdl: Extract constraint_to_value helper
  pdl: Add support for constraints in grand children
parents 1599c721 e441c330
Loading
Loading
Loading
Loading
+2 −0
Original line number Diff line number Diff line
@@ -93,6 +93,8 @@ rust_test_host {
        "tests/generated/packet_decl_fixed_enum_field_little_endian.rs",
        "tests/generated/packet_decl_fixed_scalar_field_big_endian.rs",
        "tests/generated/packet_decl_fixed_scalar_field_little_endian.rs",
        "tests/generated/packet_decl_grand_children_big_endian.rs",
        "tests/generated/packet_decl_grand_children_little_endian.rs",
        "tests/generated/packet_decl_mask_scalar_value_big_endian.rs",
        "tests/generated/packet_decl_mask_scalar_value_little_endian.rs",
        "tests/generated/packet_decl_mixed_scalars_enums_big_endian.rs",
+256 −88
Original line number Diff line number Diff line
@@ -11,6 +11,7 @@
use crate::{ast, lint};
use heck::ToUpperCamelCase;
use quote::{format_ident, quote};
use std::collections::BTreeSet;
use std::path::Path;

use crate::parser::ast as parser_ast;
@@ -45,7 +46,7 @@ pub fn mask_bits(n: usize) -> syn::LitInt {

fn generate_packet_size_getter(
    scope: &lint::Scope<'_>,
    fields: &[parser_ast::Field],
    fields: &[&parser_ast::Field],
) -> (usize, proc_macro2::TokenStream) {
    let mut constant_width = 0;
    let mut dynamic_widths = Vec::new();
@@ -123,33 +124,75 @@ fn top_level_packet<'a>(scope: &lint::Scope<'a>, packet_name: &'a str) -> &'a pa
    decl
}

/// Generate code for `ast::Decl::Packet` and `ast::Decl::Struct`
/// values.
fn generate_packet_decl(
fn get_packet_children<'a>(scope: &'a lint::Scope<'_>, id: &str) -> &'a [&'a parser_ast::Decl] {
    scope.children.get(id).map(Vec::as_slice).unwrap_or_default()
}

/// Find all constrained fields in children of `id`.
fn find_constrained_fields<'a>(
    scope: &'a lint::Scope<'a>,
    id: &'a str,
) -> Vec<&'a parser_ast::Field> {
    let mut fields = Vec::new();
    let mut field_names = BTreeSet::new();
    let mut children = Vec::from(get_packet_children(scope, id));

    while let Some(child) = children.pop() {
        if let ast::DeclDesc::Packet { id, constraints, .. }
        | ast::DeclDesc::Struct { id, constraints, .. } = &child.desc
        {
            let packet_scope = &scope.scopes[&scope.typedef[id]];
            for constraint in constraints {
                if field_names.insert(&constraint.id) {
                    fields.push(packet_scope.all_fields[&constraint.id]);
                }
            }
            children.extend(get_packet_children(scope, id));
        }
    }

    fields
}

/// Find parent fields which are constrained in child packets.
///
/// These fields are the fields which need to be passed in when
/// parsing a `id` packet since their values are needed for one or
/// more child packets.
fn find_constrained_parent_fields<'a>(
    scope: &'a lint::Scope<'a>,
    id: &'a str,
) -> impl Iterator<Item = &'a parser_ast::Field> {
    let packet_scope = &scope.scopes[&scope.typedef[id]];
    find_constrained_fields(scope, id).into_iter().filter(|field| {
        let id = field.id().unwrap();
        packet_scope.all_fields.contains_key(id) && !packet_scope.named.contains_key(id)
    })
}

/// Generate the declaration and implementation for a data struct.
///
/// This struct will hold the data for a packet or a struct. It knows
/// how to parse and serialize its own fields.
fn generate_data_struct(
    scope: &lint::Scope<'_>,
    //  File:
    endianness: ast::EndiannessValue,
    // Packet:
    id: &str,
    _constraints: &[ast::Constraint],
    fields: &[parser_ast::Field],
) -> proc_macro2::TokenStream {
) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
    let packet_scope = &scope.scopes[&scope.typedef[id]];
    let id_data = format_ident!("{id}Data");

    let top_level = top_level_packet(scope, id);
    let top_level_id = top_level.id().unwrap();
    let top_level_packet = format_ident!("{top_level_id}");
    let top_level_data = format_ident!("{top_level_id}Data");
    let top_level_id_lower = format_ident!("{}", top_level_id.to_lowercase());
    let fields_with_ids =
        packet_scope.fields.iter().filter(|f| f.id().is_some()).collect::<Vec<_>>();
    let field_names =
        fields_with_ids.iter().map(|f| format_ident!("{}", f.id().unwrap())).collect::<Vec<_>>();

    // TODO(mgeisler): use the convert_case crate to convert between
    // `FooBar` and `foo_bar` in the code below.
    let span = format_ident!("bytes");
    let serializer_span = format_ident!("buffer");
    let mut field_declarations = FieldDeclarations::new(scope, id);
    let mut field_parser = FieldParser::new(scope, endianness, id, &span);
    let mut field_serializer = FieldSerializer::new(scope, endianness, id, &serializer_span);
    for field in fields {
    for field in &packet_scope.fields {
        field_declarations.add(field);
        field_parser.add(field);
        field_serializer.add(field);
@@ -157,17 +200,67 @@ fn generate_packet_decl(
    field_declarations.done();
    field_parser.done();

    let id_lower = format_ident!("{}", id.to_lowercase());
    let id_packet = format_ident!("{id}");
    let id_child = format_ident!("{id}Child");
    let id_data = format_ident!("{id}Data");
    let id_data_child = format_ident!("{id}DataChild");
    let id_builder = format_ident!("{id}Builder");
    let parse_fields = find_constrained_parent_fields(scope, id).collect::<Vec<_>>();
    let parse_field_names =
        parse_fields.iter().map(|f| format_ident!("{}", f.id().unwrap())).collect::<Vec<_>>();
    let parse_field_types = parse_fields.iter().map(|f| types::rust_type(f));

    let fields_with_ids = fields.iter().filter(|f| f.id().is_some()).collect::<Vec<_>>();
    let field_names =
        fields_with_ids.iter().map(|f| format_ident!("{}", f.id().unwrap())).collect::<Vec<_>>();
    let (constant_width, packet_size) = generate_packet_size_getter(scope, &packet_scope.fields);
    let conforms = if constant_width == 0 {
        quote! { true }
    } else {
        let constant_width = syn::Index::from(constant_width / 8);
        quote! { #span.len() >= #constant_width }
    };

    let has_payload = packet_scope.payload.is_some();
    let children = get_packet_children(scope, id);
    let has_children_or_payload = !children.is_empty() || has_payload;
    let child_field = has_children_or_payload.then(|| quote!(child));

    let data_struct_decl = quote! {
        #[derive(Debug)]
        #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
        pub struct #id_data {
            #field_declarations
        }
    };

    let data_struct_impl = quote! {
        impl #id_data {
            fn conforms(#span: &[u8]) -> bool {
                #conforms
            }

            fn parse(mut #span: &mut Cell<&[u8]> #(, #parse_field_names: #parse_field_types)*) -> Result<Self> {
                #field_parser
                Ok(Self {
                    #(#field_names,)*
                    #child_field
                })
            }

            fn write_to(&self, buffer: &mut BytesMut) {
                #field_serializer
            }

            fn get_total_size(&self) -> usize {
                self.get_size()
            }

            fn get_size(&self) -> usize {
                #packet_size
            }
        }
    };

    (data_struct_decl, data_struct_impl)
}

/// Find all parents from `id`.
///
/// This includes the `Decl` for `id` itself.
fn find_parents<'a>(scope: &lint::Scope<'a>, id: &str) -> Vec<&'a parser_ast::Decl> {
    let mut decl = scope.typedef[id];
    let mut parents = vec![decl];
    while let ast::DeclDesc::Packet { parent_id: Some(parent_id), .. }
@@ -177,7 +270,59 @@ fn generate_packet_decl(
        parents.push(decl);
    }
    parents.reverse();
    parents
}

/// Turn the constraint into a value (such as `10` or
/// `SomeEnum::Foo`).
pub fn constraint_to_value(
    packet_scope: &lint::PacketScope<'_>,
    constraint: &ast::Constraint,
) -> proc_macro2::TokenStream {
    match constraint {
        ast::Constraint { value: Some(value), .. } => {
            let value = proc_macro2::Literal::usize_unsuffixed(*value);
            quote!(#value)
        }
        // TODO(mgeisler): include type_id in `ast::Constraint` and
        // drop the packet_scope argument.
        ast::Constraint { tag_id: Some(tag_id), .. } => {
            let type_id = match &packet_scope.all_fields[&constraint.id].desc {
                ast::FieldDesc::Typedef { type_id, .. } => format_ident!("{type_id}"),
                _ => unreachable!("Invalid constraint: {constraint:?}"),
            };
            let tag_id = format_ident!("{}", tag_id.to_upper_camel_case());
            quote!(#type_id::#tag_id)
        }
        _ => unreachable!("Invalid constraint: {constraint:?}"),
    }
}

/// Generate code for `ast::Decl::Packet` and `ast::Decl::Struct`
/// values.
fn generate_packet_decl(
    scope: &lint::Scope<'_>,
    endianness: ast::EndiannessValue,
    id: &str,
) -> proc_macro2::TokenStream {
    let packet_scope = &scope.scopes[&scope.typedef[id]];

    let top_level = top_level_packet(scope, id);
    let top_level_id = top_level.id().unwrap();
    let top_level_packet = format_ident!("{top_level_id}");
    let top_level_data = format_ident!("{top_level_id}Data");
    let top_level_id_lower = format_ident!("{}", top_level_id.to_lowercase());

    // TODO(mgeisler): use the convert_case crate to convert between
    // `FooBar` and `foo_bar` in the code below.
    let span = format_ident!("bytes");
    let id_lower = format_ident!("{}", id.to_lowercase());
    let id_packet = format_ident!("{id}");
    let id_child = format_ident!("{id}Child");
    let id_data_child = format_ident!("{id}DataChild");
    let id_builder = format_ident!("{id}Builder");

    let parents = find_parents(scope, id);
    let parent_ids = parents.iter().map(|p| p.id().unwrap()).collect::<Vec<_>>();
    let parent_shifted_ids = parent_ids.iter().skip(1).map(|id| format_ident!("{id}"));
    let parent_lower_ids =
@@ -235,26 +380,7 @@ fn generate_packet_decl(
        let mut value = named_fields
            .iter()
            .map(|&id| match packet_scope.all_constraints.get(id) {
                Some(constraint) => {
                    let value = match constraint {
                        ast::Constraint { value: Some(value), .. } => {
                            let value = proc_macro2::Literal::usize_unsuffixed(*value);
                            quote!(#value)
                        }
                        ast::Constraint { tag_id: Some(tag_id), .. } => {
                            let type_id = match packet_scope.all_fields.get(id).map(|f| &f.desc) {
                                Some(ast::FieldDesc::Typedef { type_id, .. }) => {
                                    format_ident!("{type_id}")
                                }
                                _ => unreachable!("Invalid constraint: {constraint:?}"),
                            };
                            let tag_id = format_ident!("{}", tag_id.to_upper_camel_case());
                            quote!(#type_id::#tag_id)
                        }
                        _ => unreachable!("Invalid constraint: {constraint:?}"),
                    };
                    quote!(#value)
                }
                Some(constraint) => constraint_to_value(packet_scope, constraint),
                None => {
                    let id = format_ident!("{id}");
                    quote!(self.#id)
@@ -291,7 +417,7 @@ fn generate_packet_decl(
        }
    });

    let children = scope.children.get(id).map(Vec::as_slice).unwrap_or_default();
    let children = get_packet_children(scope, id);
    let has_payload = packet_scope.payload.is_some();
    let has_children_or_payload = !children.is_empty() || has_payload;
    let child =
@@ -336,7 +462,6 @@ fn generate_packet_decl(
            }
        }
    });
    let child_field = has_children_or_payload.then(|| quote!(child));
    let builder_payload_field = has_children_or_payload.then(|| {
        quote! {
            pub payload: Option<Bytes>
@@ -364,22 +489,12 @@ fn generate_packet_decl(
        }
    });

    let (constant_width, packet_size) = generate_packet_size_getter(scope, fields);
    let conforms = if constant_width == 0 {
        quote! { true }
    } else {
        let constant_width = syn::Index::from(constant_width / 8);
        quote! { #span.len() >= #constant_width }
    };
    let (data_struct_decl, data_struct_impl) = generate_data_struct(scope, endianness, id);

    quote! {
        #child_declaration

        #[derive(Debug)]
        #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
        pub struct #id_data {
            #field_declarations
        }
        #data_struct_decl

        #[derive(Debug, Clone)]
        #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
@@ -397,31 +512,7 @@ fn generate_packet_decl(
            #builder_payload_field
        }

        impl #id_data {
            fn conforms(#span: &[u8]) -> bool {
                #conforms
            }

            fn parse(mut #span: &mut Cell<&[u8]>) -> Result<Self> {
                #field_parser
                Ok(Self {
                    #(#field_names,)*
                    #child_field
                })
            }

            fn write_to(&self, buffer: &mut BytesMut) {
                #field_serializer
            }

            fn get_total_size(&self) -> usize {
                self.get_size()
            }

            fn get_size(&self) -> usize {
                #packet_size
            }
        }
        #data_struct_impl

        impl Packet for #id_packet {
            fn to_bytes(self) -> Bytes {
@@ -573,9 +664,8 @@ fn generate_decl(
    decl: &parser_ast::Decl,
) -> String {
    match &decl.desc {
        ast::DeclDesc::Packet { id, constraints, fields, .. }
        | ast::DeclDesc::Struct { id, constraints, fields, .. } => {
            generate_packet_decl(scope, file.endianness.value, id, constraints, fields).to_string()
        ast::DeclDesc::Packet { id, .. } | ast::DeclDesc::Struct { id, .. } => {
            generate_packet_decl(scope, file.endianness.value, id).to_string()
        }
        ast::DeclDesc::Enum { id, tags, .. } => generate_enum_decl(id, tags).to_string(),
        _ => todo!("unsupported Decl::{:?}", decl),
@@ -609,6 +699,53 @@ mod tests {
    use crate::test_utils::{assert_snapshot_eq, rustfmt};
    use paste::paste;

    /// Parse a string fragment as a PDL file.
    ///
    /// # Panics
    ///
    /// Panics on parse errors.
    pub fn parse_str(text: &str) -> parser_ast::File {
        let mut db = ast::SourceDatabase::new();
        parse_inline(&mut db, String::from("stdin"), String::from(text)).expect("parse error")
    }

    #[track_caller]
    fn assert_iter_eq<T: std::cmp::PartialEq + std::fmt::Debug>(
        left: impl IntoIterator<Item = T>,
        right: impl IntoIterator<Item = T>,
    ) {
        assert_eq!(left.into_iter().collect::<Vec<T>>(), right.into_iter().collect::<Vec<T>>());
    }

    #[test]
    fn test_find_constrained_parent_fields() {
        let code = "
              little_endian_packets
              packet Parent {
                a: 8,
                b: 8,
                c: 8,
              }
              packet Child: Parent(a = 10) {
                x: 8,
              }
              packet GrandChild: Child(b = 20) {
                y: 8,
              }
              packet GrandGrandChild: GrandChild(c = 30) {
                z: 8,
              }
            ";
        let file = parse_str(code);
        let scope = lint::Scope::new(&file).unwrap();
        let find_fields =
            |id| find_constrained_parent_fields(&scope, id).map(|field| field.id().unwrap());
        assert_iter_eq(find_fields("Parent"), vec![]);
        assert_iter_eq(find_fields("Child"), vec!["b", "c"]);
        assert_iter_eq(find_fields("GrandChild"), vec!["c"]);
        assert_iter_eq(find_fields("GrandGrandChild"), vec![]);
    }

    /// Create a unit test for the given PDL `code`.
    ///
    /// The unit test will compare the generated Rust code for all
@@ -898,4 +1035,35 @@ mod tests {
          }
        "
    );

    test_pdl!(
        packet_decl_grand_children,
        "
          enum Enum16 : 16 {
            A = 1,
            B = 2,
          }

          packet Parent {
              foo: Enum16,
              bar: Enum16,
              baz: Enum16,
              _size_(_payload_): 8,
              _payload_
          }

          packet Child : Parent (foo = A) {
              quux: Enum16,
              _payload_,
          }

          packet GrandChild : Child (bar = A, quux = A) {
              _body_,
          }

          packet GrandGrandChild : GrandChild (baz = A) {
              _body_,
          }
        "
    );
}
+13 −21
Original line number Diff line number Diff line
use crate::backends::rust::{mask_bits, types};
use crate::backends::rust::{
    constraint_to_value, find_constrained_parent_fields, mask_bits, types,
};
use crate::parser::ast as parser_ast;
use crate::{ast, lint};
use heck::ToUpperCamelCase;
@@ -571,26 +573,11 @@ impl<'a> FieldParser<'a> {
                ast::DeclDesc::Packet { id, constraints, .. }
                | ast::DeclDesc::Struct { id, constraints, .. } => {
                    for constraint in constraints.iter() {
                        let value = match constraint {
                            ast::Constraint { value: Some(value), .. } => {
                                let value = proc_macro2::Literal::usize_unsuffixed(*value);
                                quote!(#value)
                            }
                            ast::Constraint { id, tag_id: Some(tag_id), .. } => {
                                // TODO: add `type_id` to `Constraint`.
                                let type_id = match &packet_scope.named[id].desc {
                                    ast::FieldDesc::Typedef { type_id, .. } => {
                                        format_ident!("{type_id}")
                                    }
                                    _ => unreachable!("Invalid constraint: {constraint:?}"),
                                };
                                let tag_id = format_ident!("{}", tag_id.to_upper_camel_case());
                                quote!(#type_id::#tag_id)
                            }
                            _ => unreachable!("Invalid constraint: {constraint:?}"),
                        };
                        constrained_fields.insert(&constraint.id);
                        constraint_values.insert((id.as_str(), &constraint.id), value);
                        constraint_values.insert(
                            (id.as_str(), &constraint.id),
                            constraint_to_value(packet_scope, constraint),
                        );
                    }
                }
                _ => unreachable!("Invalid child: {child:?}"),
@@ -609,12 +596,17 @@ impl<'a> FieldParser<'a> {
        });
        let constrained_field_idents =
            constrained_fields.iter().map(|field| format_ident!("{field}"));
        let child_parse_args = children.iter().map(|child| {
            let fields = find_constrained_parent_fields(self.scope, child.id().unwrap())
                .map(|field| format_ident!("{}", field.id().unwrap()));
            quote!(#(, #fields)*)
        });
        let packet_data_child = format_ident!("{}DataChild", self.packet_name);
        self.code.push(quote! {
            let child = match (#(#constrained_field_idents),*) {
                #(#match_values => {
                    let mut cell = Cell::new(payload);
                    let child_data = #child_ids_data::parse(&mut cell)?;
                    let child_data = #child_ids_data::parse(&mut cell #child_parse_args)?;
                    if !cell.get().is_empty() {
                        return Err(Error::InvalidPacketError);
                    }
+1 −1
Original line number Diff line number Diff line
@@ -91,7 +91,7 @@ pub struct PacketScope<'d> {
    pub fields: Vec<&'d parser::ast::Field>,

    // Constraint declarations gathered from Group inlining.
    constraints: HashMap<String, &'d Constraint>,
    pub constraints: HashMap<String, &'d Constraint>,

    // Local and inherited field declarations. Only named fields are preserved.
    // Saved here for reference for parent constraint resolving.
+877 −0

File added.

Preview size limit exceeded, changes collapsed.

Loading