Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit a89a3680 authored by Martin Geisler's avatar Martin Geisler
Browse files

pdl: push length checks down to ‘Chunk’

Test: atest pdl_tests pdl_inline_tests
Change-Id: I46bf0c83bfaf51165cd2ee6c699bc6032b38de44
parent f6a23743
Loading
Loading
Loading
Loading
+34 −61
Original line number Diff line number Diff line
@@ -34,7 +34,7 @@ macro_rules! quote_block {
}

/// Find byte indices covering `offset..offset+width` bits.
fn get_field_range(offset: usize, width: usize) -> std::ops::Range<usize> {
pub fn get_field_range(offset: usize, width: usize) -> std::ops::Range<usize> {
    let start = offset / 8;
    let mut end = (offset + width) / 8;
    if (offset + width) % 8 != 0 {
@@ -48,7 +48,7 @@ fn generate_chunk_read(
    packet_name: &str,
    endianness_value: ast::EndiannessValue,
    offset: usize,
    chunk: &[ast::Field],
    chunk: &Chunk,
) -> proc_macro2::TokenStream {
    assert!(offset % 8 == 0, "Chunks must be byte-aligned, got offset: {offset}");
    let getter = match endianness_value {
@@ -58,45 +58,17 @@ fn generate_chunk_read(

    // Work directly with the field name if we are reading a single
    // field. This generates simpler code.
    let chunk_name = match chunk {
        [ast::Field::Scalar { id: field_name, .. }] => format_ident!("{}", field_name),
        _ => format_ident!("chunk"),
    };
    let chunk_fields = chunk.iter().map(Field::from).collect::<Vec<_>>();
    let chunk_width = Chunk::new(&chunk_fields).get_width();
    let chunk_name = chunk.get_name();
    let chunk_width = chunk.get_width();
    let chunk_type = types::Integer::new(chunk_width);
    assert!(chunk_width % 8 == 0, "Chunks must have a byte size, got width: {chunk_width}");

    let range = get_field_range(offset, chunk_width);
    let indices = range.map(syn::Index::from).collect::<Vec<_>>();

    let mut field_offset = offset;
    let mut last_field_range_end = 0;
    // TODO(mgeisler): emit just a single length check per chunk. We
    // could even emit a single length check per packet.
    let length_checks = chunk.iter().map(|field| match field {
        ast::Field::Scalar { id, width, .. } => {
            let field_range = get_field_range(field_offset, *width);
            field_offset += *width;
            if field_range.end == last_field_range_end {
                None // Suppress redundant length check.
            } else {
                last_field_range_end = field_range.end;
                let range_end = syn::Index::from(field_range.end);
                Some(quote! {
                    if bytes.len() < #range_end {
                        return Err(Error::InvalidLengthError {
                            obj: #packet_name.to_string(),
                            field: #id.to_string(),
                            wanted: #range_end,
                            got: bytes.len(),
                        });
                    }
                })
            }
        }
        _ => todo!("unsupported field: {:?}", field),
    });
    let length_checks = chunk.generate_length_checks(packet_name, offset);

    // When the chunk_type.width is larger than chunk_width (e.g.
    // chunk_width is 24 but chunk_type.width is 32), then we need
@@ -254,14 +226,9 @@ fn generate_chunk_write(
        ast::EndiannessValue::LittleEndian => format_ident!("to_le_bytes"),
    };

    // Work directly with the field name if we are writing a single
    // field. This generates simpler code.
    let chunk_name = match chunk {
        [ast::Field::Scalar { id, .. }] => format_ident!("{id}"),
        _ => format_ident!("chunk"),
    };
    let chunk_fields = chunk.iter().map(Field::from).collect::<Vec<_>>();
    let chunk_width = Chunk::new(&chunk_fields).get_width();
    let chunk_name = Chunk::new(&chunk_fields).get_name();
    assert!(chunk_width % 8 == 0, "Chunks must have a byte size, got width: {chunk_width}");

    let range = get_field_range(offset, chunk_width);
@@ -375,13 +342,18 @@ fn generate_packet_decl(
    let mut field_writers = Vec::new();
    let mut offset = 0;
    for chunk in chunks {
        field_parsers.push(generate_chunk_read(id, file.endianness.value, offset, chunk));
        let chunk_fields = chunk.iter().map(Field::from).collect::<Vec<_>>();
        field_parsers.push(generate_chunk_read(
            id,
            file.endianness.value,
            offset,
            &Chunk::new(&chunk_fields),
        ));
        field_parsers.push(generate_chunk_read_field_adjustments(chunk));

        field_writers.push(generate_chunk_write_field_adjustments(chunk));
        field_writers.push(generate_chunk_write(file.endianness.value, offset, chunk));

        let chunk_fields = chunk.iter().map(Field::from).collect::<Vec<_>>();
        offset += Chunk::new(&chunk_fields).get_width();
    }

@@ -559,6 +531,7 @@ pub fn generate(sources: &ast::SourceDatabase, file: &ast::File) -> String {
mod tests {
    use super::*;
    use crate::ast;
    use crate::backends::rust::field::ScalarField;
    use crate::parser::parse_inline;
    use crate::test_utils::{assert_eq_with_diff, assert_snapshot_eq, rustfmt};

@@ -730,10 +703,10 @@ mod tests {

    #[test]
    fn test_generate_chunk_read_8bit() {
        let loc = ast::SourceRange::default();
        let fields = &[ast::Field::Scalar { loc, id: String::from("a"), width: 8 }];
        let fields = [Field::Scalar(ScalarField { id: String::from("a"), width: 8 })];
        let chunk = Chunk::new(&fields);
        assert_expr_eq(
            generate_chunk_read("Foo", ast::EndiannessValue::BigEndian, 80, fields),
            generate_chunk_read("Foo", ast::EndiannessValue::BigEndian, 80, &chunk),
            quote! {
                if bytes.len() < 11 {
                    return Err(Error::InvalidLengthError {
@@ -750,10 +723,10 @@ mod tests {

    #[test]
    fn test_generate_chunk_read_16bit_le() {
        let loc = ast::SourceRange::default();
        let fields = &[ast::Field::Scalar { loc, id: String::from("a"), width: 16 }];
        let fields = [Field::Scalar(ScalarField { id: String::from("a"), width: 16 })];
        let chunk = Chunk::new(&fields);
        assert_expr_eq(
            generate_chunk_read("Foo", ast::EndiannessValue::LittleEndian, 80, fields),
            generate_chunk_read("Foo", ast::EndiannessValue::LittleEndian, 80, &chunk),
            quote! {
                if bytes.len() < 12 {
                    return Err(Error::InvalidLengthError {
@@ -770,10 +743,10 @@ mod tests {

    #[test]
    fn test_generate_chunk_read_16bit_be() {
        let loc = ast::SourceRange::default();
        let fields = &[ast::Field::Scalar { loc, id: String::from("a"), width: 16 }];
        let fields = [Field::Scalar(ScalarField { id: String::from("a"), width: 16 })];
        let chunk = Chunk::new(&fields);
        assert_expr_eq(
            generate_chunk_read("Foo", ast::EndiannessValue::BigEndian, 80, fields),
            generate_chunk_read("Foo", ast::EndiannessValue::BigEndian, 80, &chunk),
            quote! {
                if bytes.len() < 12 {
                    return Err(Error::InvalidLengthError {
@@ -790,10 +763,10 @@ mod tests {

    #[test]
    fn test_generate_chunk_read_24bit_le() {
        let loc = ast::SourceRange::default();
        let fields = &[ast::Field::Scalar { loc, id: String::from("a"), width: 24 }];
        let fields = [Field::Scalar(ScalarField { id: String::from("a"), width: 24 })];
        let chunk = Chunk::new(&fields);
        assert_expr_eq(
            generate_chunk_read("Foo", ast::EndiannessValue::LittleEndian, 80, fields),
            generate_chunk_read("Foo", ast::EndiannessValue::LittleEndian, 80, &chunk),
            quote! {
                if bytes.len() < 13 {
                    return Err(Error::InvalidLengthError {
@@ -810,10 +783,10 @@ mod tests {

    #[test]
    fn test_generate_chunk_read_24bit_be() {
        let loc = ast::SourceRange::default();
        let fields = &[ast::Field::Scalar { loc, id: String::from("a"), width: 24 }];
        let fields = [Field::Scalar(ScalarField { id: String::from("a"), width: 24 })];
        let chunk = Chunk::new(&fields);
        assert_expr_eq(
            generate_chunk_read("Foo", ast::EndiannessValue::BigEndian, 80, fields),
            generate_chunk_read("Foo", ast::EndiannessValue::BigEndian, 80, &chunk),
            quote! {
                if bytes.len() < 13 {
                    return Err(Error::InvalidLengthError {
@@ -830,13 +803,13 @@ mod tests {

    #[test]
    fn test_generate_chunk_read_multiple_fields() {
        let loc = ast::SourceRange::default();
        let fields = &[
            ast::Field::Scalar { loc, id: String::from("a"), width: 16 },
            ast::Field::Scalar { loc, id: String::from("b"), width: 24 },
        let fields = [
            Field::Scalar(ScalarField { id: String::from("a"), width: 16 }),
            Field::Scalar(ScalarField { id: String::from("b"), width: 24 }),
        ];
        let chunk = Chunk::new(&fields);
        assert_expr_eq(
            generate_chunk_read("Foo", ast::EndiannessValue::BigEndian, 80, fields),
            generate_chunk_read("Foo", ast::EndiannessValue::BigEndian, 80, &chunk),
            quote! {
                if bytes.len() < 12 {
                    return Err(Error::InvalidLengthError {
+48 −2
Original line number Diff line number Diff line
use crate::backends::rust::field::Field;
use crate::backends::rust::get_field_range;
use quote::{format_ident, quote};

/// A chunk of field.
///
/// While fields can have arbitrary widths, a chunk is always an
/// integer number of bytes wide.
pub struct Chunk<'a> {
    fields: &'a [Field],
    pub fields: &'a [Field],
}

impl Chunk<'_> {
@@ -15,8 +17,52 @@ impl Chunk<'_> {
        Chunk { fields }
    }

    /// Generate a name for this chunk.
    ///
    /// The name is `"chunk"` if there is more than one field.
    pub fn get_name(&self) -> proc_macro2::Ident {
        match self.fields {
            [field] => field.get_ident(),
            _ => format_ident!("chunk"),
        }
    }

    /// Return the width in bits.
    pub fn get_width(self) -> usize {
    pub fn get_width(&self) -> usize {
        self.fields.iter().map(|field| field.get_width()).sum()
    }

    /// Generate length checks for this chunk.
    pub fn generate_length_checks(
        &self,
        packet_name: &str,
        offset: usize,
    ) -> Vec<proc_macro2::TokenStream> {
        let mut field_offset = offset;
        let mut last_field_range_end = 0;
        let mut length_checks = Vec::new();
        for field in self.fields {
            let id = field.get_id();
            let width = field.get_width();
            let field_range = get_field_range(field_offset, width);
            field_offset += width;
            if field_range.end == last_field_range_end {
                continue;
            }

            last_field_range_end = field_range.end;
            let range_end = syn::Index::from(field_range.end);
            length_checks.push(quote! {
                if bytes.len() < #range_end {
                    return Err(Error::InvalidLengthError {
                        obj: #packet_name.to_string(),
                        field: #id.to_string(),
                        wanted: #range_end,
                        got: bytes.len(),
                    });
                }
            });
        }
        length_checks
    }
}
+12 −2
Original line number Diff line number Diff line
@@ -6,8 +6,8 @@ use crate::backends::rust::types;
/// Like [`ast::Field::Scalar`].
#[derive(Debug, Clone)]
pub struct ScalarField {
    id: String,
    width: usize,
    pub id: String,
    pub width: usize,
}

impl ScalarField {
@@ -19,6 +19,10 @@ impl ScalarField {
        self.width
    }

    fn get_id(&self) -> &str {
        &self.id
    }

    fn get_ident(&self) -> proc_macro2::Ident {
        format_ident!("{}", self.id)
    }
@@ -66,6 +70,12 @@ impl Field {
        }
    }

    pub fn get_id(&self) -> &str {
        match self {
            Field::Scalar(field) => field.get_id(),
        }
    }

    pub fn get_ident(&self) -> proc_macro2::Ident {
        match self {
            Field::Scalar(field) => field.get_ident(),