Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit d6a7286b authored by Martin Geisler's avatar Martin Geisler Committed by Automerger Merge Worker
Browse files

Merge changes I46bf0c83,I5f639a94,Ib4ff14eb,I59c1d4c4,Ib9147d48 am: 6c770e9d

parents d125fded 6c770e9d
Loading
Loading
Loading
Loading
+55 −119
Original line number Diff line number Diff line
@@ -14,9 +14,14 @@ use std::collections::HashMap;
use std::path::Path;
use syn::parse_quote;

mod chunk;
mod field;
mod preamble;
mod types;

use chunk::Chunk;
use field::Field;

/// Generate a block of code.
///
/// Like `quote!`, but the code block will be followed by an empty
@@ -28,38 +33,8 @@ macro_rules! quote_block {
    }
}

/// Generate the struct-field declaration for `field`, prefixed with
/// the given `visibility`.
fn generate_field(field: &ast::Field, visibility: syn::Visibility) -> proc_macro2::TokenStream {
    // Only scalar fields are supported so far.
    if let ast::Field::Scalar { id, width, .. } = field {
        let name = format_ident!("{id}");
        let ty = types::Integer::new(*width);
        quote! {
            #visibility #name: #ty
        }
    } else {
        todo!("unsupported field: {:?}", field)
    }
}

/// Generate a public `get_*` accessor for `field` which reads the
/// value through `self.#packet_name`.
fn generate_field_getter(packet_name: &syn::Ident, field: &ast::Field) -> proc_macro2::TokenStream {
    if let ast::Field::Scalar { id, width, .. } = field {
        // TODO(mgeisler): refactor with generate_field above.
        let getter = format_ident!("get_{id}");
        let name = format_ident!("{id}");
        let ty = types::Integer::new(*width);
        quote! {
            pub fn #getter(&self) -> #ty {
                self.#packet_name.as_ref().#name
            }
        }
    } else {
        todo!("unsupported field: {:?}", field)
    }
}

/// Find byte indices covering `offset..offset+width` bits.
fn get_field_range(offset: usize, width: usize) -> std::ops::Range<usize> {
pub fn get_field_range(offset: usize, width: usize) -> std::ops::Range<usize> {
    let start = offset / 8;
    let mut end = (offset + width) / 8;
    if (offset + width) % 8 != 0 {
@@ -68,16 +43,12 @@ fn get_field_range(offset: usize, width: usize) -> std::ops::Range<usize> {
    start..end
}

fn get_chunk_width(fields: &[ast::Field]) -> usize {
    fields.iter().map(get_field_width).sum()
}

/// Read data for a byte-aligned chunk.
fn generate_chunk_read(
    packet_name: &str,
    endianness_value: ast::EndiannessValue,
    offset: usize,
    chunk: &[ast::Field],
    chunk: &Chunk,
) -> proc_macro2::TokenStream {
    assert!(offset % 8 == 0, "Chunks must be byte-aligned, got offset: {offset}");
    let getter = match endianness_value {
@@ -87,44 +58,17 @@ fn generate_chunk_read(

    // Work directly with the field name if we are reading a single
    // field. This generates simpler code.
    let chunk_name = match chunk {
        [ast::Field::Scalar { id: field_name, .. }] => format_ident!("{}", field_name),
        _ => format_ident!("chunk"),
    };
    let chunk_width = get_chunk_width(chunk);
    let chunk_name = chunk.get_name();
    let chunk_width = chunk.get_width();
    let chunk_type = types::Integer::new(chunk_width);
    assert!(chunk_width % 8 == 0, "Chunks must have a byte size, got width: {chunk_width}");

    let range = get_field_range(offset, chunk_width);
    let indices = range.map(syn::Index::from).collect::<Vec<_>>();

    let mut field_offset = offset;
    let mut last_field_range_end = 0;
    // TODO(mgeisler): emit just a single length check per chunk. We
    // could even emit a single length check per packet.
    let length_checks = chunk.iter().map(|field| match field {
        ast::Field::Scalar { id, width, .. } => {
            let field_range = get_field_range(field_offset, *width);
            field_offset += *width;
            if field_range.end == last_field_range_end {
                None // Suppress redundant length check.
            } else {
                last_field_range_end = field_range.end;
                let range_end = syn::Index::from(field_range.end);
                Some(quote! {
                    if bytes.len() < #range_end {
                        return Err(Error::InvalidLengthError {
                            obj: #packet_name.to_string(),
                            field: #id.to_string(),
                            wanted: #range_end,
                            got: bytes.len(),
                        });
                    }
                })
            }
        }
        _ => todo!("unsupported field: {:?}", field),
    });
    let length_checks = chunk.generate_length_checks(packet_name, offset);

    // When the chunk_type.width is larger than chunk_width (e.g.
    // chunk_width is 24 but chunk_type.width is 32), then we need
@@ -153,15 +97,16 @@ fn generate_chunk_read_field_adjustments(fields: &[ast::Field]) -> proc_macro2::
        return quote! {};
    }

    let chunk_width = get_chunk_width(fields);
    let chunk_fields = fields.iter().map(Field::from).collect::<Vec<_>>();
    let chunk_width = Chunk::new(&chunk_fields).get_width();
    let chunk_type = types::Integer::new(chunk_width);

    let mut field_parsers = Vec::new();
    let mut field_offset = 0;
    for field in fields {
        let field_name = Field::from(field).get_ident();
        match field {
            ast::Field::Scalar { id, width, .. } => {
                let field_name = format_ident!("{id}");
            ast::Field::Scalar { width, .. } => {
                let field_type = types::Integer::new(*width);

                let mut field = quote! {
@@ -214,7 +159,8 @@ fn generate_chunk_write_field_adjustments(chunk: &[ast::Field]) -> proc_macro2::
        };
    }

    let chunk_width = get_chunk_width(chunk);
    let chunk_fields = chunk.iter().map(Field::from).collect::<Vec<_>>();
    let chunk_width = Chunk::new(&chunk_fields).get_width();
    let chunk_type = types::Integer::new(chunk_width);

    let mut field_parsers = Vec::new();
@@ -280,13 +226,9 @@ fn generate_chunk_write(
        ast::EndiannessValue::LittleEndian => format_ident!("to_le_bytes"),
    };

    // Work directly with the field name if we are writing a single
    // field. This generates simpler code.
    let chunk_name = match chunk {
        [ast::Field::Scalar { id, .. }] => format_ident!("{id}"),
        _ => format_ident!("chunk"),
    };
    let chunk_width = get_chunk_width(chunk);
    let chunk_fields = chunk.iter().map(Field::from).collect::<Vec<_>>();
    let chunk_width = Chunk::new(&chunk_fields).get_width();
    let chunk_name = Chunk::new(&chunk_fields).get_name();
    assert!(chunk_width % 8 == 0, "Chunks must have a byte size, got width: {chunk_width}");

    let range = get_field_range(offset, chunk_width);
@@ -299,14 +241,6 @@ fn generate_chunk_write(
    }
}

/// Field size in bits.
fn get_field_width(field: &ast::Field) -> usize {
    if let ast::Field::Scalar { width, .. } = field {
        *width
    } else {
        todo!("unsupported field: {:?}", field)
    }
}

/// Generate code for an `ast::Decl::Packet` enum value.
fn generate_packet_decl(
    file: &ast::File,
@@ -361,7 +295,7 @@ fn generate_packet_decl(
            child: #data_child_ident,
        }
    });
    let plain_fields = fields.iter().map(|field| generate_field(field, parse_quote!()));
    let plain_fields = fields.iter().map(|field| Field::from(field).generate_decl(parse_quote!()));
    code.push_str(&quote_block! {
        #[derive(Debug)]
        struct #data_name {
@@ -391,7 +325,7 @@ fn generate_packet_decl(
    });

    let builder_name = format_ident!("{id}Builder");
    let pub_fields = fields.iter().map(|field| generate_field(field, parse_quote!(pub)));
    let pub_fields = fields.iter().map(|field| Field::from(field).generate_decl(parse_quote!(pub)));
    code.push_str(&quote_block! {
        #[derive(Debug)]
        pub struct #builder_name {
@@ -401,31 +335,32 @@ fn generate_packet_decl(

    let mut chunk_width = 0;
    let chunks = fields.split_inclusive(|field| {
        chunk_width += get_field_width(field);
        chunk_width += Field::from(field).get_width();
        chunk_width % 8 == 0
    });
    let mut field_parsers = Vec::new();
    let mut field_writers = Vec::new();
    let mut offset = 0;
    for chunk in chunks {
        field_parsers.push(generate_chunk_read(id, file.endianness.value, offset, chunk));
        let chunk_fields = chunk.iter().map(Field::from).collect::<Vec<_>>();
        field_parsers.push(generate_chunk_read(
            id,
            file.endianness.value,
            offset,
            &Chunk::new(&chunk_fields),
        ));
        field_parsers.push(generate_chunk_read_field_adjustments(chunk));

        field_writers.push(generate_chunk_write_field_adjustments(chunk));
        field_writers.push(generate_chunk_write(file.endianness.value, offset, chunk));

        offset += get_chunk_width(chunk);
        offset += Chunk::new(&chunk_fields).get_width();
    }

    let field_names = fields
        .iter()
        .map(|field| match field {
            ast::Field::Scalar { id, .. } => format_ident!("{id}"),
            _ => todo!("unsupported field: {:?}", field),
        })
        .collect::<Vec<_>>();
    let field_names = fields.iter().map(|field| Field::from(field).get_ident()).collect::<Vec<_>>();

    let packet_size_bits = get_chunk_width(fields);
    let chunk_fields = fields.iter().map(Field::from).collect::<Vec<_>>();
    let packet_size_bits = Chunk::new(&chunk_fields).get_width();
    if packet_size_bits % 8 != 0 {
        panic!("packet {id} does not end on a byte boundary, size: {packet_size_bits} bits",);
    }
@@ -504,7 +439,7 @@ fn generate_packet_decl(
            }
        }
    });
    let field_getters = fields.iter().map(|field| generate_field_getter(&ident, field));
    let field_getters = fields.iter().map(|field| Field::from(field).generate_getter(&ident));
    code.push_str(&quote_block! {
        impl #packet_name {
            pub fn parse(bytes: &[u8]) -> Result<Self> {
@@ -596,6 +531,7 @@ pub fn generate(sources: &ast::SourceDatabase, file: &ast::File) -> String {
mod tests {
    use super::*;
    use crate::ast;
    use crate::backends::rust::field::ScalarField;
    use crate::parser::parse_inline;
    use crate::test_utils::{assert_eq_with_diff, assert_snapshot_eq, rustfmt};

@@ -767,10 +703,10 @@ mod tests {

    #[test]
    fn test_generate_chunk_read_8bit() {
        let loc = ast::SourceRange::default();
        let fields = &[ast::Field::Scalar { loc, id: String::from("a"), width: 8 }];
        let fields = [Field::Scalar(ScalarField { id: String::from("a"), width: 8 })];
        let chunk = Chunk::new(&fields);
        assert_expr_eq(
            generate_chunk_read("Foo", ast::EndiannessValue::BigEndian, 80, fields),
            generate_chunk_read("Foo", ast::EndiannessValue::BigEndian, 80, &chunk),
            quote! {
                if bytes.len() < 11 {
                    return Err(Error::InvalidLengthError {
@@ -787,10 +723,10 @@ mod tests {

    #[test]
    fn test_generate_chunk_read_16bit_le() {
        let loc = ast::SourceRange::default();
        let fields = &[ast::Field::Scalar { loc, id: String::from("a"), width: 16 }];
        let fields = [Field::Scalar(ScalarField { id: String::from("a"), width: 16 })];
        let chunk = Chunk::new(&fields);
        assert_expr_eq(
            generate_chunk_read("Foo", ast::EndiannessValue::LittleEndian, 80, fields),
            generate_chunk_read("Foo", ast::EndiannessValue::LittleEndian, 80, &chunk),
            quote! {
                if bytes.len() < 12 {
                    return Err(Error::InvalidLengthError {
@@ -807,10 +743,10 @@ mod tests {

    #[test]
    fn test_generate_chunk_read_16bit_be() {
        let loc = ast::SourceRange::default();
        let fields = &[ast::Field::Scalar { loc, id: String::from("a"), width: 16 }];
        let fields = [Field::Scalar(ScalarField { id: String::from("a"), width: 16 })];
        let chunk = Chunk::new(&fields);
        assert_expr_eq(
            generate_chunk_read("Foo", ast::EndiannessValue::BigEndian, 80, fields),
            generate_chunk_read("Foo", ast::EndiannessValue::BigEndian, 80, &chunk),
            quote! {
                if bytes.len() < 12 {
                    return Err(Error::InvalidLengthError {
@@ -827,10 +763,10 @@ mod tests {

    #[test]
    fn test_generate_chunk_read_24bit_le() {
        let loc = ast::SourceRange::default();
        let fields = &[ast::Field::Scalar { loc, id: String::from("a"), width: 24 }];
        let fields = [Field::Scalar(ScalarField { id: String::from("a"), width: 24 })];
        let chunk = Chunk::new(&fields);
        assert_expr_eq(
            generate_chunk_read("Foo", ast::EndiannessValue::LittleEndian, 80, fields),
            generate_chunk_read("Foo", ast::EndiannessValue::LittleEndian, 80, &chunk),
            quote! {
                if bytes.len() < 13 {
                    return Err(Error::InvalidLengthError {
@@ -847,10 +783,10 @@ mod tests {

    #[test]
    fn test_generate_chunk_read_24bit_be() {
        let loc = ast::SourceRange::default();
        let fields = &[ast::Field::Scalar { loc, id: String::from("a"), width: 24 }];
        let fields = [Field::Scalar(ScalarField { id: String::from("a"), width: 24 })];
        let chunk = Chunk::new(&fields);
        assert_expr_eq(
            generate_chunk_read("Foo", ast::EndiannessValue::BigEndian, 80, fields),
            generate_chunk_read("Foo", ast::EndiannessValue::BigEndian, 80, &chunk),
            quote! {
                if bytes.len() < 13 {
                    return Err(Error::InvalidLengthError {
@@ -867,13 +803,13 @@ mod tests {

    #[test]
    fn test_generate_chunk_read_multiple_fields() {
        let loc = ast::SourceRange::default();
        let fields = &[
            ast::Field::Scalar { loc, id: String::from("a"), width: 16 },
            ast::Field::Scalar { loc, id: String::from("b"), width: 24 },
        let fields = [
            Field::Scalar(ScalarField { id: String::from("a"), width: 16 }),
            Field::Scalar(ScalarField { id: String::from("b"), width: 24 }),
        ];
        let chunk = Chunk::new(&fields);
        assert_expr_eq(
            generate_chunk_read("Foo", ast::EndiannessValue::BigEndian, 80, fields),
            generate_chunk_read("Foo", ast::EndiannessValue::BigEndian, 80, &chunk),
            quote! {
                if bytes.len() < 12 {
                    return Err(Error::InvalidLengthError {
+68 −0
Original line number Diff line number Diff line
use crate::backends::rust::field::Field;
use crate::backends::rust::get_field_range;
use quote::{format_ident, quote};

/// A chunk of fields.
///
/// While fields can have arbitrary widths, a chunk is always an
/// integer number of bytes wide.
pub struct Chunk<'a> {
    /// The fields making up this chunk, in packet order.
    pub fields: &'a [Field],
}

impl Chunk<'_> {
    /// Construct a new `Chunk` from the fields.
    pub fn new(fields: &[Field]) -> Chunk {
        // TODO(mgeisler): check that the width % 8 == 0?
        Chunk { fields }
    }

    /// Generate a name for this chunk.
    ///
    /// The name is `"chunk"` if there is more than one field.
    pub fn get_name(&self) -> proc_macro2::Ident {
        match self.fields {
            [field] => field.get_ident(),
            _ => format_ident!("chunk"),
        }
    }

    /// Return the width in bits.
    pub fn get_width(&self) -> usize {
        self.fields.iter().map(|field| field.get_width()).sum()
    }

    /// Generate length checks for this chunk.
    pub fn generate_length_checks(
        &self,
        packet_name: &str,
        offset: usize,
    ) -> Vec<proc_macro2::TokenStream> {
        let mut field_offset = offset;
        let mut last_field_range_end = 0;
        let mut length_checks = Vec::new();
        for field in self.fields {
            let id = field.get_id();
            let width = field.get_width();
            let field_range = get_field_range(field_offset, width);
            field_offset += width;
            if field_range.end == last_field_range_end {
                continue;
            }

            last_field_range_end = field_range.end;
            let range_end = syn::Index::from(field_range.end);
            length_checks.push(quote! {
                if bytes.len() < #range_end {
                    return Err(Error::InvalidLengthError {
                        obj: #packet_name.to_string(),
                        field: #id.to_string(),
                        wanted: #range_end,
                        got: bytes.len(),
                    });
                }
            });
        }
        length_checks
    }
}
+96 −0
Original line number Diff line number Diff line
use quote::{format_ident, quote};

use crate::ast;
use crate::backends::rust::types;

/// Like [`ast::Field::Scalar`].
#[derive(Debug, Clone)]
pub struct ScalarField {
    /// Name of the field.
    pub id: String,
    /// Width of the field in bits.
    pub width: usize,
}

impl ScalarField {
    /// Create a scalar field named `id` that is `width` bits wide.
    fn new(id: &str, width: usize) -> ScalarField {
        ScalarField { id: id.to_string(), width }
    }

    /// Width of the field in bits.
    fn get_width(&self) -> usize {
        self.width
    }

    /// Name of the field.
    fn get_id(&self) -> &str {
        self.id.as_str()
    }

    /// The field name as a Rust identifier.
    fn get_ident(&self) -> proc_macro2::Ident {
        format_ident!("{}", self.id)
    }

    /// Generate the struct-field declaration for this field with the
    /// given `visibility`.
    fn generate_decl(&self, visibility: syn::Visibility) -> proc_macro2::TokenStream {
        let name = self.get_ident();
        let ty = types::Integer::new(self.width);
        quote! {
            #visibility #name: #ty
        }
    }

    /// Generate a public `get_*` accessor which reads the value
    /// through `self.#packet_name`.
    fn generate_getter(&self, packet_name: &syn::Ident) -> proc_macro2::TokenStream {
        let getter = format_ident!("get_{}", self.id);
        let name = self.get_ident();
        let ty = types::Integer::new(self.width);
        quote! {
            pub fn #getter(&self) -> #ty {
                self.#packet_name.as_ref().#name
            }
        }
    }
}

/// Projection of [`ast::Field`] with the bits needed for the Rust
/// backend.
#[derive(Debug, Clone)]
pub enum Field {
    /// A scalar (integer) field.
    Scalar(ScalarField),
}

impl From<&ast::Field> for Field {
    /// Project an AST field onto the backend representation.
    fn from(field: &ast::Field) -> Field {
        if let ast::Field::Scalar { id, width, .. } = field {
            Field::Scalar(ScalarField::new(id, *width))
        } else {
            todo!("Unsupported field: {:?}", field)
        }
    }
}

impl Field {
    /// Width of the field in bits.
    pub fn get_width(&self) -> usize {
        // `Field` has a single variant, so the pattern is irrefutable.
        let Field::Scalar(scalar) = self;
        scalar.get_width()
    }

    /// Name of the field.
    pub fn get_id(&self) -> &str {
        let Field::Scalar(scalar) = self;
        scalar.get_id()
    }

    /// The field name as a Rust identifier.
    pub fn get_ident(&self) -> proc_macro2::Ident {
        let Field::Scalar(scalar) = self;
        scalar.get_ident()
    }

    /// Generate the struct-field declaration with the given `visibility`.
    pub fn generate_decl(&self, visibility: syn::Visibility) -> proc_macro2::TokenStream {
        let Field::Scalar(scalar) = self;
        scalar.generate_decl(visibility)
    }

    /// Generate a public getter reading through `self.#packet_name`.
    pub fn generate_getter(&self, packet_name: &syn::Ident) -> proc_macro2::TokenStream {
        let Field::Scalar(scalar) = self;
        scalar.generate_getter(packet_name)
    }
}