Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 40b64d7f authored by Martin Geisler's avatar Martin Geisler
Browse files

pdl: Rework reading of scalar fields

This CL reverts the code reading scalar fields to match the logic in
the Python backend. This approach will make it simpler to update the
two backends going forward.

Test: atest pdl_tests pdl_rust_generator_tests_{le,be}
Change-Id: I49ddb3dae8eab1bb00cf730eda50a051d326ce2c
parent 47ae88c8
Loading
Loading
Loading
Loading
+12 −10
Original line number Diff line number Diff line
@@ -50,18 +50,20 @@ rust_test_host {
    data: [
        ":rustfmt",
        ":rustfmt.toml",
        "tests/generated/generate_chunk_read_8bit.rs",
        "tests/generated/generate_chunk_read_16bit_be.rs",
        "tests/generated/generate_chunk_read_16bit_le.rs",
        "tests/generated/generate_chunk_read_24bit_be.rs",
        "tests/generated/generate_chunk_read_24bit_le.rs",
        "tests/generated/generate_chunk_read_multiple_fields.rs",
        "tests/generated/packet_decl_complex_big_endian.rs",
        "tests/generated/packet_decl_complex_little_endian.rs",
        "tests/generated/packet_decl_8bit_scalar_big_endian.rs",
        "tests/generated/packet_decl_8bit_scalar_little_endian.rs",
        "tests/generated/packet_decl_24bit_scalar_big_endian.rs",
        "tests/generated/packet_decl_24bit_scalar_little_endian.rs",
        "tests/generated/packet_decl_64bit_scalar_big_endian.rs",
        "tests/generated/packet_decl_64bit_scalar_little_endian.rs",
        "tests/generated/packet_decl_complex_scalars_big_endian.rs",
        "tests/generated/packet_decl_complex_scalars_little_endian.rs",
        "tests/generated/packet_decl_empty_big_endian.rs",
        "tests/generated/packet_decl_empty_little_endian.rs",
        "tests/generated/packet_decl_simple_big_endian.rs",
        "tests/generated/packet_decl_simple_little_endian.rs",
        "tests/generated/packet_decl_mask_scalar_value_big_endian.rs",
        "tests/generated/packet_decl_mask_scalar_value_little_endian.rs",
        "tests/generated/packet_decl_simple_scalars_big_endian.rs",
        "tests/generated/packet_decl_simple_scalars_little_endian.rs",
        "tests/generated/preamble.rs",
    ],
}
+29 −2
Original line number Diff line number Diff line
use crate::lint;
use codespan_reporting::diagnostic;
use codespan_reporting::files;
use serde::Serialize;
@@ -58,7 +59,7 @@ pub struct Tag {
    pub value: usize,
}

#[derive(Debug, Serialize)]
#[derive(Debug, Serialize, Clone)]
#[serde(tag = "kind", rename = "constraint")]
pub struct Constraint {
    pub id: String,
@@ -67,7 +68,7 @@ pub struct Constraint {
    pub tag_id: Option<String>,
}

#[derive(Debug, Serialize)]
#[derive(Debug, Serialize, Clone)]
#[serde(tag = "kind")]
pub enum Field {
    #[serde(rename = "checksum_field")]
@@ -287,6 +288,32 @@ impl Field {
            }
        }
    }

    pub fn is_bitfield(&self, scope: &lint::Scope<'_>) -> bool {
        match self {
            Field::Size { .. }
            | Field::Count { .. }
            | Field::Fixed { .. }
            | Field::Reserved { .. }
            | Field::Scalar { .. } => true,
            Field::Typedef { type_id, .. } => {
                let field = scope.typedef.get(type_id.as_str());
                matches!(field, Some(Decl::Enum { .. }))
            }
            _ => false,
        }
    }

    pub fn width(&self) -> Option<usize> {
        match self {
            Field::Scalar { width, .. }
            | Field::Size { width, .. }
            | Field::Count { width, .. }
            | Field::Reserved { width, .. } => Some(*width),
            // TODO(mgeisler): padding, arrays, etc.
            _ => None,
        }
    }
}

#[cfg(test)]
+113 −176
Original line number Diff line number Diff line
@@ -11,15 +11,16 @@
use crate::{ast, lint};
use quote::{format_ident, quote};
use std::path::Path;
use syn::parse_quote;

mod chunk;
mod field;
mod declarations;
mod parser;
mod preamble;
mod serializer;
mod types;

use chunk::Chunk;
use field::Field;
use declarations::FieldDeclarations;
use parser::FieldParser;
use serializer::FieldSerializer;

/// Generate a block of code.
///
@@ -34,147 +35,86 @@ macro_rules! quote_block {

/// Generate a bit-mask which masks out `n` least significant bits.
pub fn mask_bits(n: usize) -> syn::LitInt {
    syn::parse_str::<syn::LitInt>(&format!("{:#x}", (1u64 << n) - 1)).unwrap()
    // The literal needs a suffix if it's larger than an i32.
    let suffix = if n > 31 { "u64" } else { "" };
    syn::parse_str::<syn::LitInt>(&format!("{:#x}{suffix}", (1u64 << n) - 1)).unwrap()
}

/// Generate code for an `ast::Decl::Packet` enum value.
fn generate_packet_decl(
    scope: &lint::Scope<'_>,
    file: &ast::File,
    //  File:
    endianness: ast::EndiannessValue,
    // Packet:
    id: &str,
    fields: &[Field],
    parent_id: &Option<String>,
) -> String {
    _constraints: &[ast::Constraint],
    fields: &[ast::Field],
    _parent_id: Option<&str>,
) -> proc_macro2::TokenStream {
    // TODO(mgeisler): use the convert_case crate to convert between
    // `FooBar` and `foo_bar` in the code below.
    let mut code = String::new();
    let child_ids = scope
        .typedef
        .values()
        .filter_map(|p| match p {
            ast::Decl::Packet { id, parent_id, .. } if parent_id.as_deref() == Some(id) => Some(id),
            _ => None,
        })
    let span = format_ident!("bytes");
    let serializer_span = format_ident!("buffer");
    let mut field_declarations = FieldDeclarations::new();
    let mut field_parser = FieldParser::new(scope, endianness, id, &span);
    let mut field_serializer = FieldSerializer::new(scope, endianness, id, &serializer_span);
    for field in fields {
        field_declarations.add(field);
        field_parser.add(field);
        field_serializer.add(field);
    }
    field_parser.done();

    let id_lower = format_ident!("{}", id.to_lowercase());
    let id_data = format_ident!("{id}Data");
    let id_packet = format_ident!("{id}Packet");
    let id_builder = format_ident!("{id}Builder");

    let field_names =
        fields.iter().filter_map(|f| f.id()).map(|id| format_ident!("{id}")).collect::<Vec<_>>();
    let field_types = fields
        .iter()
        .filter_map(|f| f.width())
        .map(|w| format_ident!("u{}", types::Integer::new(w).width))
        .collect::<Vec<_>>();
    let has_children = !child_ids.is_empty();
    let child_idents = child_ids.iter().map(|id| format_ident!("{id}")).collect::<Vec<_>>();

    let ident = format_ident!("{}", id.to_lowercase());
    let data_child_ident = format_ident!("{id}DataChild");
    let child_decl_packet_name =
        child_idents.iter().map(|ident| format_ident!("{ident}Packet")).collect::<Vec<_>>();
    let child_name = format_ident!("{id}Child");
    if has_children {
        let child_data_idents = child_idents.iter().map(|ident| format_ident!("{ident}Data"));
        code.push_str(&quote_block! {
            #[derive(Debug)]
            enum #data_child_ident {
                #(#child_idents(Arc<#child_data_idents>),)*
                None,
            }

            impl #data_child_ident {
                fn get_total_size(&self) -> usize {
                    // TODO(mgeisler): use Self instad of #data_child_ident.
                    match self {
                        #(#data_child_ident::#child_idents(value) => value.get_total_size(),)*
                        #data_child_ident::None => 0,
                    }
                }
            }
    let getter_names = field_names.iter().map(|id| format_ident!("get_{id}"));

            #[derive(Debug)]
            pub enum #child_name {
                #(#child_idents(#child_decl_packet_name),)*
                None,
            }
        });
    }
    let packet_size = syn::Index::from(fields.iter().filter_map(|f| f.width()).sum::<usize>() / 8);
    let conforms = if packet_size.index == 0 {
        quote! { true }
    } else {
        quote! { #span.len() >= #packet_size }
    };

    let data_name = format_ident!("{id}Data");
    let child_field = has_children.then(|| {
    quote! {
            child: #data_child_ident,
        }
    });
    let plain_fields = fields.iter().map(|field| field.generate_decl(parse_quote!()));
    code.push_str(&quote_block! {
    #[derive(Debug)]
        struct #data_name {
            #(#plain_fields,)*
            #child_field
        struct #id_data {
            #field_declarations
        }
    });

    let parent = parent_id.as_ref().map(|parent_id| match scope.typedef.get(parent_id.as_str()) {
        Some(ast::Decl::Packet { id, .. }) => {
            let parent_ident = format_ident!("{}", id.to_lowercase());
            let parent_data = format_ident!("{id}Data");
            quote! {
                #parent_ident: Arc<#parent_data>,
            }
        }
        _ => panic!("Could not find {parent_id}"),
    });

    let packet_name = format_ident!("{id}Packet");
    code.push_str(&quote_block! {
        #[derive(Debug, Clone)]
        pub struct #packet_name {
            #parent
            #ident: Arc<#data_name>,
        pub struct #id_packet {
            #id_lower: Arc<#id_data>,
        }
    });

    let builder_name = format_ident!("{id}Builder");
    let pub_fields = fields.iter().map(|field| field.generate_decl(parse_quote!(pub)));
    code.push_str(&quote_block! {
        #[derive(Debug)]
        pub struct #builder_name {
            #(#pub_fields,)*
        }
    });

    let mut chunk_width = 0;
    let chunks = fields.split_inclusive(|field| {
        chunk_width += field.width();
        chunk_width % 8 == 0
    });
    let mut field_parsers = Vec::new();
    let mut field_writers = Vec::new();
    for fields in chunks {
        let chunk = Chunk::new(fields);
        field_parsers.push(chunk.generate_read(id, file.endianness.value));
        field_writers.push(chunk.generate_write(file.endianness.value));
    }

    let field_names = fields.iter().map(Field::ident).collect::<Vec<_>>();

    let packet_size_bits = Chunk::new(fields).width();
    if packet_size_bits % 8 != 0 {
        panic!("packet {id} does not end on a byte boundary, size: {packet_size_bits} bits",);
        pub struct #id_builder {
            #(pub #field_names: #field_types),*
        }
    let packet_size_bytes = syn::Index::from(packet_size_bits / 8);

    let conforms = if packet_size_bytes.index == 0 {
        quote! { true }
    } else {
        quote! { bytes.len() >= #packet_size_bytes }
    };

    code.push_str(&quote_block! {
        impl #data_name {
            fn conforms(bytes: &[u8]) -> bool {
        impl #id_data {
            fn conforms(#span: &[u8]) -> bool {
                #conforms
            }

            fn parse(mut bytes: &[u8]) -> Result<Self> {
                #(#field_parsers)*
            fn parse(mut #span: &[u8]) -> Result<Self> {
                #field_parser
                Ok(Self { #(#field_names),* })
            }

            fn write_to(&self, buffer: &mut BytesMut) {
                #(#field_writers)*
                #field_serializer
            }

            fn get_total_size(&self) -> usize {
@@ -182,90 +122,70 @@ fn generate_packet_decl(
            }

            fn get_size(&self) -> usize {
                #packet_size_bytes
                #packet_size
            }
        }
    });

    code.push_str(&quote_block! {
        impl Packet for #packet_name {
        impl Packet for #id_packet {
            fn to_bytes(self) -> Bytes {
                let mut buffer = BytesMut::with_capacity(self.#ident.get_total_size());
                self.#ident.write_to(&mut buffer);
                let mut buffer = BytesMut::with_capacity(self.#id_lower.get_total_size());
                self.#id_lower.write_to(&mut buffer);
                buffer.freeze()
            }

            fn to_vec(self) -> Vec<u8> {
                self.to_bytes().to_vec()
            }
        }
        impl From<#packet_name> for Bytes {
            fn from(packet: #packet_name) -> Self {

        impl From<#id_packet> for Bytes {
            fn from(packet: #id_packet) -> Self {
                packet.to_bytes()
            }
        }
        impl From<#packet_name> for Vec<u8> {
            fn from(packet: #packet_name) -> Self {

        impl From<#id_packet> for Vec<u8> {
            fn from(packet: #id_packet) -> Self {
                packet.to_vec()
            }
        }
    });

    let specialize = has_children.then(|| {
        quote! {
            pub fn specialize(&self) -> #child_name {
                match &self.#ident.child {
                    #(#data_child_ident::#child_idents(_) =>
                      #child_name::#child_idents(
                          #child_decl_packet_name::new(self.#ident.clone()).unwrap()),)*
                    #data_child_ident::None => #child_name::None,
                }
            }
        }
    });
    let field_getters = fields.iter().map(|field| field.generate_getter(&ident));
    code.push_str(&quote_block! {
        impl #packet_name {
        impl #id_packet {
            pub fn parse(mut bytes: &[u8]) -> Result<Self> {
                Ok(Self::new(Arc::new(#data_name::parse(bytes)?)).unwrap())
                Ok(Self::new(Arc::new(#id_data::parse(bytes)?)).unwrap())
            }

            #specialize

            fn new(root: Arc<#data_name>) -> std::result::Result<Self, &'static str> {
                let #ident = root;
                Ok(Self { #ident })
            fn new(root: Arc<#id_data>) -> std::result::Result<Self, &'static str> {
                let #id_lower = root;
                Ok(Self { #id_lower })
            }

            #(#field_getters)*
            #(pub fn #getter_names(&self) -> #field_types {
                self.#id_lower.as_ref().#field_names
            })*
        }
    });

    let child = has_children.then(|| {
        quote! {
            child: #data_child_ident::None,
        }
    });
    code.push_str(&quote_block! {
        impl #builder_name {
            pub fn build(self) -> #packet_name {
                let #ident = Arc::new(#data_name {
                    #(#field_names: self.#field_names,)*
                    #child
        impl #id_builder {
            pub fn build(self) -> #id_packet {
                let #id_lower = Arc::new(#id_data {
                    #(#field_names: self.#field_names),*
                });
                #packet_name::new(#ident).unwrap()
                #id_packet::new(#id_lower).unwrap()
            }
        }
    }
    });

    code
}

fn generate_decl(scope: &lint::Scope<'_>, file: &ast::File, decl: &ast::Decl) -> String {
    match decl {
        ast::Decl::Packet { id, fields, parent_id, .. } => {
            let fields = fields.iter().map(Field::from).collect::<Vec<_>>();
            generate_packet_decl(scope, file, id, &fields, parent_id)
        }
        ast::Decl::Packet { id, constraints, fields, parent_id, .. } => generate_packet_decl(
            scope,
            file.endianness.value,
            id,
            constraints,
            fields,
            parent_id.as_deref(),
        )
        .to_string(),
        _ => todo!("unsupported Decl::{:?}", decl),
    }
}
@@ -345,8 +265,12 @@ mod tests {

    test_pdl!(packet_decl_empty, "packet Foo {}");

    test_pdl!(packet_decl_8bit_scalar, " packet Foo { x:  8 }");
    test_pdl!(packet_decl_24bit_scalar, "packet Foo { x: 24 }");
    test_pdl!(packet_decl_64bit_scalar, "packet Foo { x: 64 }");

    test_pdl!(
        packet_decl_simple,
        packet_decl_simple_scalars,
        r#"
          packet Foo {
            x: 8,
@@ -357,7 +281,7 @@ mod tests {
    );

    test_pdl!(
        packet_decl_complex,
        packet_decl_complex_scalars,
        r#"
          packet Foo {
            a: 3,
@@ -369,4 +293,17 @@ mod tests {
          }
        "#,
    );

    // Test that we correctly mask a byte-sized value in the middle of
    // a chunk.
    test_pdl!(
        packet_decl_mask_scalar_value,
        r#"
          packet Foo {
            a: 2,
            b: 24,
            c: 6,
          }
        "#,
    );
}
+0 −417

File deleted.

Preview size limit exceeded, changes collapsed.

+35 −0
Original line number Diff line number Diff line
use crate::ast;
use crate::backends::rust::types;
use quote::{format_ident, quote};

pub struct FieldDeclarations {
    code: Vec<proc_macro2::TokenStream>,
}

impl FieldDeclarations {
    pub fn new() -> FieldDeclarations {
        FieldDeclarations { code: Vec::new() }
    }

    pub fn add(&mut self, field: &ast::Field) {
        self.code.push(match field {
            ast::Field::Scalar { id, width, .. } => {
                let id = format_ident!("{id}");
                let field_type = types::Integer::new(*width);
                quote! {
                    #id: #field_type,
                }
            }
            _ => todo!(),
        });
    }
}

impl quote::ToTokens for FieldDeclarations {
    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
        let code = &self.code;
        tokens.extend(quote! {
            #(#code)*
        });
    }
}
Loading