Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit a4a73d75 authored by Martin Geisler's avatar Martin Geisler
Browse files

pdl: Add support for scalar and enum arrays

This implements support for integer and enum arrays with static and
dynamic sizes.

The parse function on the data structs changed its argument from
‘&[u8]’ to ‘Cell<&[u8]>’. This allows us to pass around a single slice
which we mutate in-place. The ‘Cell’ type is a no-overhead type which
cannot panic, so this simply allows us to mutate the slice in-place.

The in-place mutations allow us to eat away the slice as we parse. The
alternative would be to make all parse functions return two values:
the parsed result and the remaining slice. That would in turn
complicate the logic where we collect results.

Another alternative would be to make the parsing stateful by storing
the slice in a struct (and that’s actually what ‘Cell<&[u8]>’ does).

Test: atest pdl_tests pdl_rust_generator_tests_{le,be}
Change-Id: Idc13ab75c5d6c4704f7edb70306be761ee815cab
parent 66d44917
Loading
Loading
Loading
Loading
+20 −0
Original line number Diff line number Diff line
@@ -50,18 +50,38 @@ rust_test_host {
    data: [
        ":rustfmt",
        ":rustfmt.toml",
        "tests/generated/packet_decl_8bit_enum_array_big_endian.rs",
        "tests/generated/packet_decl_8bit_enum_array_little_endian.rs",
        "tests/generated/packet_decl_8bit_enum_big_endian.rs",
        "tests/generated/packet_decl_8bit_enum_little_endian.rs",
        "tests/generated/packet_decl_8bit_scalar_array_big_endian.rs",
        "tests/generated/packet_decl_8bit_scalar_array_little_endian.rs",
        "tests/generated/packet_decl_8bit_scalar_big_endian.rs",
        "tests/generated/packet_decl_8bit_scalar_little_endian.rs",
        "tests/generated/packet_decl_24bit_enum_array_big_endian.rs",
        "tests/generated/packet_decl_24bit_enum_array_little_endian.rs",
        "tests/generated/packet_decl_24bit_enum_big_endian.rs",
        "tests/generated/packet_decl_24bit_enum_little_endian.rs",
        "tests/generated/packet_decl_24bit_scalar_array_big_endian.rs",
        "tests/generated/packet_decl_24bit_scalar_array_little_endian.rs",
        "tests/generated/packet_decl_24bit_scalar_big_endian.rs",
        "tests/generated/packet_decl_24bit_scalar_little_endian.rs",
        "tests/generated/packet_decl_64bit_enum_array_big_endian.rs",
        "tests/generated/packet_decl_64bit_enum_array_little_endian.rs",
        "tests/generated/packet_decl_64bit_enum_big_endian.rs",
        "tests/generated/packet_decl_64bit_enum_little_endian.rs",
        "tests/generated/packet_decl_64bit_scalar_array_big_endian.rs",
        "tests/generated/packet_decl_64bit_scalar_array_little_endian.rs",
        "tests/generated/packet_decl_64bit_scalar_big_endian.rs",
        "tests/generated/packet_decl_64bit_scalar_little_endian.rs",
        "tests/generated/packet_decl_array_dynamic_count_big_endian.rs",
        "tests/generated/packet_decl_array_dynamic_count_little_endian.rs",
        "tests/generated/packet_decl_array_dynamic_size_big_endian.rs",
        "tests/generated/packet_decl_array_dynamic_size_little_endian.rs",
        "tests/generated/packet_decl_array_unknown_element_width_dynamic_count_big_endian.rs",
        "tests/generated/packet_decl_array_unknown_element_width_dynamic_count_little_endian.rs",
        "tests/generated/packet_decl_array_unknown_element_width_dynamic_size_big_endian.rs",
        "tests/generated/packet_decl_array_unknown_element_width_dynamic_size_little_endian.rs",
        "tests/generated/packet_decl_complex_scalars_big_endian.rs",
        "tests/generated/packet_decl_complex_scalars_little_endian.rs",
        "tests/generated/packet_decl_empty_big_endian.rs",
+51 −6
Original line number Diff line number Diff line
@@ -259,6 +259,31 @@ impl<A: Annotation> Decl<A> {
        }
    }

    /// Determine the size of a declaration type in bits, if possible.
    ///
    /// If the type is dynamically sized (e.g. contains an array or
    /// payload), `None` is returned. If `skip_payload` is set,
    /// payload and body fields are counted as having size `0` rather
    /// than a variable size.
    pub fn width(&self, scope: &lint::Scope<'_>, skip_payload: bool) -> Option<usize> {
        match &self.desc {
            DeclDesc::Enum { width, .. } | DeclDesc::Checksum { width, .. } => Some(*width),
            DeclDesc::CustomField { width, .. } => *width,
            DeclDesc::Packet { fields, parent_id, .. }
            | DeclDesc::Struct { fields, parent_id, .. } => {
                let mut packet_size = match parent_id {
                    None => 0,
                    Some(id) => scope.typedef.get(id.as_str())?.width(scope, true)?,
                };
                for field in fields.iter() {
                    packet_size += field.width(scope, skip_payload)?;
                }
                Some(packet_size)
            }
            DeclDesc::Group { .. } | DeclDesc::Test { .. } => None,
        }
    }

    pub fn new(loc: SourceRange, desc: DeclDesc<A>) -> Decl<A> {
        Decl { loc, annot: Default::default(), desc }
    }
@@ -301,18 +326,38 @@ impl<A: Annotation> Field<A> {
        }
    }

    pub fn width(&self, scope: &lint::Scope<'_>) -> Option<usize> {
    pub fn declaration<'a>(
        &self,
        scope: &'a lint::Scope<'a>,
    ) -> Option<&'a crate::parser::ast::Decl> {
        match &self.desc {
            FieldDesc::FixedEnum { enum_id, .. } => scope.typedef.get(enum_id).copied(),
            FieldDesc::Array { type_id: Some(type_id), .. } => scope.typedef.get(type_id).copied(),
            FieldDesc::Typedef { type_id, .. } => scope.typedef.get(type_id.as_str()).copied(),
            _ => None,
        }
    }

    /// Determine the size of a field in bits, if possible.
    ///
    /// If the field is dynamically sized (e.g. unsized array or
    /// payload field), `None` is returned. If `skip_payload` is set,
    /// payload and body fields are counted as having size `0` rather
    /// than a variable size.
    pub fn width(&self, scope: &lint::Scope<'_>, skip_payload: bool) -> Option<usize> {
        match &self.desc {
            FieldDesc::Scalar { width, .. }
            | FieldDesc::Size { width, .. }
            | FieldDesc::Count { width, .. }
            | FieldDesc::ElementSize { width, .. }
            | FieldDesc::Reserved { width, .. } => Some(*width),
            FieldDesc::Typedef { type_id, .. } => match scope.typedef.get(type_id.as_str()) {
                Some(Decl { desc: DeclDesc::Enum { width, .. }, .. }) => Some(*width),
                _ => None,
            },
            // TODO(mgeisler): padding, arrays, etc.
            FieldDesc::Array { size: Some(size), width, .. } => {
                let width = width.or_else(|| self.declaration(scope)?.width(scope, false))?;
                Some(width * size)
            }
            FieldDesc::Typedef { .. } => self.declaration(scope)?.width(scope, false),
            FieldDesc::Checksum { .. } => Some(0),
            FieldDesc::Payload { .. } | FieldDesc::Body { .. } if skip_payload => Some(0),
            _ => None,
        }
    }
+167 −9
Original line number Diff line number Diff line
@@ -42,6 +42,76 @@ pub fn mask_bits(n: usize) -> syn::LitInt {
    syn::parse_str::<syn::LitInt>(&format!("{:#x}{suffix}", (1u64 << n) - 1)).unwrap()
}

fn generate_packet_size_getter(
    scope: &lint::Scope<'_>,
    fields: &[parser_ast::Field],
) -> (usize, proc_macro2::TokenStream) {
    let mut constant_width = 0;
    let mut dynamic_widths = Vec::new();

    for field in fields {
        if let Some(width) = field.width(scope, false) {
            constant_width += width;
            continue;
        }

        let decl = field.declaration(scope);
        dynamic_widths.push(match &field.desc {
            ast::FieldDesc::Payload { .. } | ast::FieldDesc::Body { .. } => {
                quote!(self.payload.len())
            }
            ast::FieldDesc::Typedef { id, .. } => {
                let id = format_ident!("{id}");
                quote!(self.#id.get_size())
            }
            ast::FieldDesc::Array { id, width, .. } => {
                let id = format_ident!("{id}");
                match &decl {
                    Some(parser_ast::Decl {
                        desc: ast::DeclDesc::Struct { .. } | ast::DeclDesc::CustomField { .. },
                        ..
                    }) => {
                        quote! {
                            self.#id.iter().map(|elem| elem.get_size()).sum::<usize>()
                        }
                    }
                    Some(parser_ast::Decl { desc: ast::DeclDesc::Enum { .. }, .. }) => {
                        let width =
                            syn::Index::from(decl.unwrap().width(scope, false).unwrap() / 8);
                        let mul_width = (width.index > 1).then(|| quote!(* #width));
                        quote! {
                            self.#id.len() #mul_width
                        }
                    }
                    _ => {
                        let width = syn::Index::from(width.unwrap() / 8);
                        let mul_width = (width.index > 1).then(|| quote!(* #width));
                        quote! {
                            self.#id.len() #mul_width
                        }
                    }
                }
            }
            _ => panic!("Unsupported field type: {field:?}"),
        });
    }

    if constant_width > 0 {
        let width = syn::Index::from(constant_width / 8);
        dynamic_widths.insert(0, quote!(#width));
    }
    if dynamic_widths.is_empty() {
        dynamic_widths.push(quote!(0))
    }

    (
        constant_width,
        quote! {
            #(#dynamic_widths)+*
        },
    )
}

/// Generate code for `ast::Decl::Packet` and `ast::Decl::Struct`
/// values.
fn generate_packet_decl(
@@ -77,14 +147,15 @@ fn generate_packet_decl(
    let field_names =
        fields_with_ids.iter().map(|f| format_ident!("{}", f.id().unwrap())).collect::<Vec<_>>();
    let field_types = fields_with_ids.iter().map(|f| types::rust_type(f)).collect::<Vec<_>>();
    let field_borrows = fields_with_ids.iter().map(|f| types::rust_borrow(f)).collect::<Vec<_>>();
    let getter_names = field_names.iter().map(|id| format_ident!("get_{id}"));

    let packet_size =
        syn::Index::from(fields.iter().filter_map(|f| f.width(scope)).sum::<usize>() / 8);
    let conforms = if packet_size.index == 0 {
    let (constant_width, packet_size) = generate_packet_size_getter(scope, fields);
    let conforms = if constant_width == 0 {
        quote! { true }
    } else {
        quote! { #span.len() >= #packet_size }
        let constant_width = syn::Index::from(constant_width / 8);
        quote! { #span.len() >= #constant_width }
    };

    quote! {
@@ -112,7 +183,7 @@ fn generate_packet_decl(
                #conforms
            }

            fn parse(mut #span: &[u8]) -> Result<Self> {
            fn parse(mut #span: &mut Cell<&[u8]>) -> Result<Self> {
                #field_parser
                Ok(Self { #(#field_names),* })
            }
@@ -155,17 +226,35 @@ fn generate_packet_decl(
        }

        impl #id_packet {
            pub fn parse(mut bytes: &[u8]) -> Result<Self> {
                Ok(Self::new(Arc::new(#id_data::parse(bytes)?)).unwrap())
            pub fn parse(#span: &[u8]) -> Result<Self> {
                let mut cell = Cell::new(#span);
                let packet = Self::parse_inner(&mut cell)?;
                if !cell.get().is_empty() {
                    return Err(Error::InvalidPacketError);
                }
                Ok(packet)
            }

            fn parse_inner(mut bytes: &mut Cell<&[u8]>) -> Result<Self> {
                let packet = #id_data::parse(&mut bytes)?;
                Ok(Self::new(Arc::new(packet)).unwrap())
            }
            fn new(root: Arc<#id_data>) -> std::result::Result<Self, &'static str> {
                let #id_lower = root;
                Ok(Self { #id_lower })
            }

            #(pub fn #getter_names(&self) -> #field_types {
                self.#id_lower.as_ref().#field_names
            #(pub fn #getter_names(&self) -> #field_borrows #field_types {
                #field_borrows self.#id_lower.as_ref().#field_names
            })*

            fn write_to(&self, buffer: &mut BytesMut) {
                self.#id_lower.write_to(buffer)
            }

            pub fn get_size(&self) -> usize {
                self.#id_lower.get_size()
            }
        }

        impl #id_builder {
@@ -417,6 +506,75 @@ mod tests {
        "
    );

    test_pdl!(packet_decl_8bit_scalar_array, " packet Foo { x:  8[3] }");
    test_pdl!(packet_decl_24bit_scalar_array, "packet Foo { x: 24[5] }");
    test_pdl!(packet_decl_64bit_scalar_array, "packet Foo { x: 64[7] }");

    test_pdl!(
        packet_decl_8bit_enum_array,
        "enum Foo :  8 { A = 1, B = 2 } packet Bar { x: Foo[3] }"
    );
    test_pdl!(
        packet_decl_24bit_enum_array,
        "enum Foo : 24 { A = 1, B = 2 } packet Bar { x: Foo[5] }"
    );
    test_pdl!(
        packet_decl_64bit_enum_array,
        "enum Foo : 64 { A = 1, B = 2 } packet Bar { x: Foo[7] }"
    );

    test_pdl!(
        packet_decl_array_dynamic_count,
        "
          packet Foo {
            _count_(x): 5,
            padding: 3,
            x: 24[]
          }
        "
    );

    test_pdl!(
        packet_decl_array_dynamic_size,
        "
          packet Foo {
            _size_(x): 5,
            padding: 3,
            x: 24[]
          }
        "
    );

    test_pdl!(
        packet_decl_array_unknown_element_width_dynamic_size,
        "
          struct Foo {
            _count_(a): 8,
            a: 16[],
          }

          packet Bar {
            _size_(x): 8,
            x: Foo[],
          }
        "
    );

    test_pdl!(
        packet_decl_array_unknown_element_width_dynamic_count,
        "
          struct Foo {
            _count_(a): 8,
            a: 16[],
          }

          packet Bar {
            _count_(x): 8,
            x: Foo[],
          }
        "
    );

    test_pdl!(
        packet_decl_reserved_field,
        "
+9 −22
Original line number Diff line number Diff line
use crate::ast;
use crate::backends::rust::types;
use crate::parser::ast as parser_ast;
use quote::{format_ident, quote};
@@ -13,27 +12,15 @@ impl FieldDeclarations {
    }

    pub fn add(&mut self, field: &parser_ast::Field) {
        self.code.push(match &field.desc {
            ast::FieldDesc::Scalar { id, width } => {
                let id = format_ident!("{id}");
                let field_type = types::Integer::new(*width);
                quote! {
                    #id: #field_type,
                }
            }
            ast::FieldDesc::Typedef { id, type_id } => {
                let id = format_ident!("{id}");
                let field_type = format_ident!("{type_id}");
                quote! {
        let id = match field.id() {
            Some(id) => format_ident!("{id}"),
            None => return, // No id => field not stored.
        };

        let field_type = types::rust_type(field);
        self.code.push(quote! {
            #id: #field_type,
                }
            }
            ast::FieldDesc::Reserved { .. } => {
                // Nothing to do here.
                quote! {}
            }
            _ => todo!(),
        });
        })
    }
}

+329 −17

File changed.

Preview size limit exceeded, changes collapsed.

Loading