Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit e002a8ff authored by Martin Geisler's avatar Martin Geisler Committed by Gerrit Code Review
Browse files

Merge "pdl: Rework reading of scalar fields"

parents b2df79f1 40b64d7f
Loading
Loading
Loading
Loading
+12 −10
Original line number Diff line number Diff line
@@ -50,18 +50,20 @@ rust_test_host {
    data: [
        ":rustfmt",
        ":rustfmt.toml",
        "tests/generated/generate_chunk_read_8bit.rs",
        "tests/generated/generate_chunk_read_16bit_be.rs",
        "tests/generated/generate_chunk_read_16bit_le.rs",
        "tests/generated/generate_chunk_read_24bit_be.rs",
        "tests/generated/generate_chunk_read_24bit_le.rs",
        "tests/generated/generate_chunk_read_multiple_fields.rs",
        "tests/generated/packet_decl_complex_big_endian.rs",
        "tests/generated/packet_decl_complex_little_endian.rs",
        "tests/generated/packet_decl_8bit_scalar_big_endian.rs",
        "tests/generated/packet_decl_8bit_scalar_little_endian.rs",
        "tests/generated/packet_decl_24bit_scalar_big_endian.rs",
        "tests/generated/packet_decl_24bit_scalar_little_endian.rs",
        "tests/generated/packet_decl_64bit_scalar_big_endian.rs",
        "tests/generated/packet_decl_64bit_scalar_little_endian.rs",
        "tests/generated/packet_decl_complex_scalars_big_endian.rs",
        "tests/generated/packet_decl_complex_scalars_little_endian.rs",
        "tests/generated/packet_decl_empty_big_endian.rs",
        "tests/generated/packet_decl_empty_little_endian.rs",
        "tests/generated/packet_decl_simple_big_endian.rs",
        "tests/generated/packet_decl_simple_little_endian.rs",
        "tests/generated/packet_decl_mask_scalar_value_big_endian.rs",
        "tests/generated/packet_decl_mask_scalar_value_little_endian.rs",
        "tests/generated/packet_decl_simple_scalars_big_endian.rs",
        "tests/generated/packet_decl_simple_scalars_little_endian.rs",
        "tests/generated/preamble.rs",
    ],
}
+29 −2
Original line number Diff line number Diff line
use crate::lint;
use codespan_reporting::diagnostic;
use codespan_reporting::files;
use serde::Serialize;
@@ -58,7 +59,7 @@ pub struct Tag {
    pub value: usize,
}

#[derive(Debug, Serialize)]
#[derive(Debug, Serialize, Clone)]
#[serde(tag = "kind", rename = "constraint")]
pub struct Constraint {
    pub id: String,
@@ -67,7 +68,7 @@ pub struct Constraint {
    pub tag_id: Option<String>,
}

#[derive(Debug, Serialize)]
#[derive(Debug, Serialize, Clone)]
#[serde(tag = "kind")]
pub enum Field {
    #[serde(rename = "checksum_field")]
@@ -287,6 +288,32 @@ impl Field {
            }
        }
    }

    pub fn is_bitfield(&self, scope: &lint::Scope<'_>) -> bool {
        match self {
            Field::Size { .. }
            | Field::Count { .. }
            | Field::Fixed { .. }
            | Field::Reserved { .. }
            | Field::Scalar { .. } => true,
            Field::Typedef { type_id, .. } => {
                let field = scope.typedef.get(type_id.as_str());
                matches!(field, Some(Decl::Enum { .. }))
            }
            _ => false,
        }
    }

    pub fn width(&self) -> Option<usize> {
        match self {
            Field::Scalar { width, .. }
            | Field::Size { width, .. }
            | Field::Count { width, .. }
            | Field::Reserved { width, .. } => Some(*width),
            // TODO(mgeisler): padding, arrays, etc.
            _ => None,
        }
    }
}

#[cfg(test)]
+113 −176
Original line number Diff line number Diff line
@@ -11,15 +11,16 @@
use crate::{ast, lint};
use quote::{format_ident, quote};
use std::path::Path;
use syn::parse_quote;

mod chunk;
mod field;
mod declarations;
mod parser;
mod preamble;
mod serializer;
mod types;

use chunk::Chunk;
use field::Field;
use declarations::FieldDeclarations;
use parser::FieldParser;
use serializer::FieldSerializer;

/// Generate a block of code.
///
@@ -34,147 +35,86 @@ macro_rules! quote_block {

/// Generate a bit-mask which masks out `n` least significant bits.
pub fn mask_bits(n: usize) -> syn::LitInt {
    syn::parse_str::<syn::LitInt>(&format!("{:#x}", (1u64 << n) - 1)).unwrap()
    // The literal needs a suffix if it's larger than an i32.
    let suffix = if n > 31 { "u64" } else { "" };
    syn::parse_str::<syn::LitInt>(&format!("{:#x}{suffix}", (1u64 << n) - 1)).unwrap()
}

/// Generate code for an `ast::Decl::Packet` enum value.
fn generate_packet_decl(
    scope: &lint::Scope<'_>,
    file: &ast::File,
    //  File:
    endianness: ast::EndiannessValue,
    // Packet:
    id: &str,
    fields: &[Field],
    parent_id: &Option<String>,
) -> String {
    _constraints: &[ast::Constraint],
    fields: &[ast::Field],
    _parent_id: Option<&str>,
) -> proc_macro2::TokenStream {
    // TODO(mgeisler): use the convert_case crate to convert between
    // `FooBar` and `foo_bar` in the code below.
    let mut code = String::new();
    let child_ids = scope
        .typedef
        .values()
        .filter_map(|p| match p {
            ast::Decl::Packet { id, parent_id, .. } if parent_id.as_deref() == Some(id) => Some(id),
            _ => None,
        })
    let span = format_ident!("bytes");
    let serializer_span = format_ident!("buffer");
    let mut field_declarations = FieldDeclarations::new();
    let mut field_parser = FieldParser::new(scope, endianness, id, &span);
    let mut field_serializer = FieldSerializer::new(scope, endianness, id, &serializer_span);
    for field in fields {
        field_declarations.add(field);
        field_parser.add(field);
        field_serializer.add(field);
    }
    field_parser.done();

    let id_lower = format_ident!("{}", id.to_lowercase());
    let id_data = format_ident!("{id}Data");
    let id_packet = format_ident!("{id}Packet");
    let id_builder = format_ident!("{id}Builder");

    let field_names =
        fields.iter().filter_map(|f| f.id()).map(|id| format_ident!("{id}")).collect::<Vec<_>>();
    let field_types = fields
        .iter()
        .filter_map(|f| f.width())
        .map(|w| format_ident!("u{}", types::Integer::new(w).width))
        .collect::<Vec<_>>();
    let has_children = !child_ids.is_empty();
    let child_idents = child_ids.iter().map(|id| format_ident!("{id}")).collect::<Vec<_>>();

    let ident = format_ident!("{}", id.to_lowercase());
    let data_child_ident = format_ident!("{id}DataChild");
    let child_decl_packet_name =
        child_idents.iter().map(|ident| format_ident!("{ident}Packet")).collect::<Vec<_>>();
    let child_name = format_ident!("{id}Child");
    if has_children {
        let child_data_idents = child_idents.iter().map(|ident| format_ident!("{ident}Data"));
        code.push_str(&quote_block! {
            #[derive(Debug)]
            enum #data_child_ident {
                #(#child_idents(Arc<#child_data_idents>),)*
                None,
            }

            impl #data_child_ident {
                fn get_total_size(&self) -> usize {
                    // TODO(mgeisler): use Self instad of #data_child_ident.
                    match self {
                        #(#data_child_ident::#child_idents(value) => value.get_total_size(),)*
                        #data_child_ident::None => 0,
                    }
                }
            }
    let getter_names = field_names.iter().map(|id| format_ident!("get_{id}"));

            #[derive(Debug)]
            pub enum #child_name {
                #(#child_idents(#child_decl_packet_name),)*
                None,
            }
        });
    }
    let packet_size = syn::Index::from(fields.iter().filter_map(|f| f.width()).sum::<usize>() / 8);
    let conforms = if packet_size.index == 0 {
        quote! { true }
    } else {
        quote! { #span.len() >= #packet_size }
    };

    let data_name = format_ident!("{id}Data");
    let child_field = has_children.then(|| {
    quote! {
            child: #data_child_ident,
        }
    });
    let plain_fields = fields.iter().map(|field| field.generate_decl(parse_quote!()));
    code.push_str(&quote_block! {
    #[derive(Debug)]
        struct #data_name {
            #(#plain_fields,)*
            #child_field
        struct #id_data {
            #field_declarations
        }
    });

    let parent = parent_id.as_ref().map(|parent_id| match scope.typedef.get(parent_id.as_str()) {
        Some(ast::Decl::Packet { id, .. }) => {
            let parent_ident = format_ident!("{}", id.to_lowercase());
            let parent_data = format_ident!("{id}Data");
            quote! {
                #parent_ident: Arc<#parent_data>,
            }
        }
        _ => panic!("Could not find {parent_id}"),
    });

    let packet_name = format_ident!("{id}Packet");
    code.push_str(&quote_block! {
        #[derive(Debug, Clone)]
        pub struct #packet_name {
            #parent
            #ident: Arc<#data_name>,
        pub struct #id_packet {
            #id_lower: Arc<#id_data>,
        }
    });

    let builder_name = format_ident!("{id}Builder");
    let pub_fields = fields.iter().map(|field| field.generate_decl(parse_quote!(pub)));
    code.push_str(&quote_block! {
        #[derive(Debug)]
        pub struct #builder_name {
            #(#pub_fields,)*
        }
    });

    let mut chunk_width = 0;
    let chunks = fields.split_inclusive(|field| {
        chunk_width += field.width();
        chunk_width % 8 == 0
    });
    let mut field_parsers = Vec::new();
    let mut field_writers = Vec::new();
    for fields in chunks {
        let chunk = Chunk::new(fields);
        field_parsers.push(chunk.generate_read(id, file.endianness.value));
        field_writers.push(chunk.generate_write(file.endianness.value));
    }

    let field_names = fields.iter().map(Field::ident).collect::<Vec<_>>();

    let packet_size_bits = Chunk::new(fields).width();
    if packet_size_bits % 8 != 0 {
        panic!("packet {id} does not end on a byte boundary, size: {packet_size_bits} bits",);
        pub struct #id_builder {
            #(pub #field_names: #field_types),*
        }
    let packet_size_bytes = syn::Index::from(packet_size_bits / 8);

    let conforms = if packet_size_bytes.index == 0 {
        quote! { true }
    } else {
        quote! { bytes.len() >= #packet_size_bytes }
    };

    code.push_str(&quote_block! {
        impl #data_name {
            fn conforms(bytes: &[u8]) -> bool {
        impl #id_data {
            fn conforms(#span: &[u8]) -> bool {
                #conforms
            }

            fn parse(mut bytes: &[u8]) -> Result<Self> {
                #(#field_parsers)*
            fn parse(mut #span: &[u8]) -> Result<Self> {
                #field_parser
                Ok(Self { #(#field_names),* })
            }

            fn write_to(&self, buffer: &mut BytesMut) {
                #(#field_writers)*
                #field_serializer
            }

            fn get_total_size(&self) -> usize {
@@ -182,90 +122,70 @@ fn generate_packet_decl(
            }

            fn get_size(&self) -> usize {
                #packet_size_bytes
                #packet_size
            }
        }
    });

    code.push_str(&quote_block! {
        impl Packet for #packet_name {
        impl Packet for #id_packet {
            fn to_bytes(self) -> Bytes {
                let mut buffer = BytesMut::with_capacity(self.#ident.get_total_size());
                self.#ident.write_to(&mut buffer);
                let mut buffer = BytesMut::with_capacity(self.#id_lower.get_total_size());
                self.#id_lower.write_to(&mut buffer);
                buffer.freeze()
            }

            fn to_vec(self) -> Vec<u8> {
                self.to_bytes().to_vec()
            }
        }
        impl From<#packet_name> for Bytes {
            fn from(packet: #packet_name) -> Self {

        impl From<#id_packet> for Bytes {
            fn from(packet: #id_packet) -> Self {
                packet.to_bytes()
            }
        }
        impl From<#packet_name> for Vec<u8> {
            fn from(packet: #packet_name) -> Self {

        impl From<#id_packet> for Vec<u8> {
            fn from(packet: #id_packet) -> Self {
                packet.to_vec()
            }
        }
    });

    let specialize = has_children.then(|| {
        quote! {
            pub fn specialize(&self) -> #child_name {
                match &self.#ident.child {
                    #(#data_child_ident::#child_idents(_) =>
                      #child_name::#child_idents(
                          #child_decl_packet_name::new(self.#ident.clone()).unwrap()),)*
                    #data_child_ident::None => #child_name::None,
                }
            }
        }
    });
    let field_getters = fields.iter().map(|field| field.generate_getter(&ident));
    code.push_str(&quote_block! {
        impl #packet_name {
        impl #id_packet {
            pub fn parse(mut bytes: &[u8]) -> Result<Self> {
                Ok(Self::new(Arc::new(#data_name::parse(bytes)?)).unwrap())
                Ok(Self::new(Arc::new(#id_data::parse(bytes)?)).unwrap())
            }

            #specialize

            fn new(root: Arc<#data_name>) -> std::result::Result<Self, &'static str> {
                let #ident = root;
                Ok(Self { #ident })
            fn new(root: Arc<#id_data>) -> std::result::Result<Self, &'static str> {
                let #id_lower = root;
                Ok(Self { #id_lower })
            }

            #(#field_getters)*
            #(pub fn #getter_names(&self) -> #field_types {
                self.#id_lower.as_ref().#field_names
            })*
        }
    });

    let child = has_children.then(|| {
        quote! {
            child: #data_child_ident::None,
        }
    });
    code.push_str(&quote_block! {
        impl #builder_name {
            pub fn build(self) -> #packet_name {
                let #ident = Arc::new(#data_name {
                    #(#field_names: self.#field_names,)*
                    #child
        impl #id_builder {
            pub fn build(self) -> #id_packet {
                let #id_lower = Arc::new(#id_data {
                    #(#field_names: self.#field_names),*
                });
                #packet_name::new(#ident).unwrap()
                #id_packet::new(#id_lower).unwrap()
            }
        }
    }
    });

    code
}

fn generate_decl(scope: &lint::Scope<'_>, file: &ast::File, decl: &ast::Decl) -> String {
    match decl {
        ast::Decl::Packet { id, fields, parent_id, .. } => {
            let fields = fields.iter().map(Field::from).collect::<Vec<_>>();
            generate_packet_decl(scope, file, id, &fields, parent_id)
        }
        ast::Decl::Packet { id, constraints, fields, parent_id, .. } => generate_packet_decl(
            scope,
            file.endianness.value,
            id,
            constraints,
            fields,
            parent_id.as_deref(),
        )
        .to_string(),
        _ => todo!("unsupported Decl::{:?}", decl),
    }
}
@@ -345,8 +265,12 @@ mod tests {

    test_pdl!(packet_decl_empty, "packet Foo {}");

    test_pdl!(packet_decl_8bit_scalar, " packet Foo { x:  8 }");
    test_pdl!(packet_decl_24bit_scalar, "packet Foo { x: 24 }");
    test_pdl!(packet_decl_64bit_scalar, "packet Foo { x: 64 }");

    test_pdl!(
        packet_decl_simple,
        packet_decl_simple_scalars,
        r#"
          packet Foo {
            x: 8,
@@ -357,7 +281,7 @@ mod tests {
    );

    test_pdl!(
        packet_decl_complex,
        packet_decl_complex_scalars,
        r#"
          packet Foo {
            a: 3,
@@ -369,4 +293,17 @@ mod tests {
          }
        "#,
    );

    // Test that we correctly mask a byte-sized value in the middle of
    // a chunk.
    test_pdl!(
        packet_decl_mask_scalar_value,
        r#"
          packet Foo {
            a: 2,
            b: 24,
            c: 6,
          }
        "#,
    );
}
+0 −417

File deleted.

Preview size limit exceeded, changes collapsed.

+35 −0
Original line number Diff line number Diff line
use crate::ast;
use crate::backends::rust::types;
use quote::{format_ident, quote};

pub struct FieldDeclarations {
    code: Vec<proc_macro2::TokenStream>,
}

impl FieldDeclarations {
    pub fn new() -> FieldDeclarations {
        FieldDeclarations { code: Vec::new() }
    }

    pub fn add(&mut self, field: &ast::Field) {
        self.code.push(match field {
            ast::Field::Scalar { id, width, .. } => {
                let id = format_ident!("{id}");
                let field_type = types::Integer::new(*width);
                quote! {
                    #id: #field_type,
                }
            }
            _ => todo!(),
        });
    }
}

impl quote::ToTokens for FieldDeclarations {
    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
        let code = &self.code;
        tokens.extend(quote! {
            #(#code)*
        });
    }
}
Loading