Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 6dab87f1 authored by Martin Geisler's avatar Martin Geisler
Browse files

pdl: add ‘Chunk::generate_read’

Test: atest pdl_tests pdl_inline_tests
Change-Id: I11416277d43ef8ebb78608e7403f5e63636dd02c
parent d3afb2a3
Loading
Loading
Loading
Loading
+2 −204
Original line number Diff line number Diff line
@@ -43,53 +43,6 @@ pub fn get_field_range(offset: usize, width: usize) -> std::ops::Range<usize> {
    start..end
}

/// Read data for a byte-aligned chunk.
///
/// Returns a token stream which reads `chunk` from a `bytes` buffer,
/// starting at the bit `offset`, using the byte order given by
/// `endianness_value`. `packet_name` is only used in the generated
/// length-check error messages.
///
/// # Panics
///
/// Panics if `offset` or the chunk width is not a multiple of 8.
fn generate_chunk_read(
    packet_name: &str,
    endianness_value: ast::EndiannessValue,
    offset: usize,
    chunk: &Chunk,
) -> proc_macro2::TokenStream {
    assert!(offset % 8 == 0, "Chunks must be byte-aligned, got offset: {offset}");
    // Pick the integer constructor matching the wire byte order.
    let getter = match endianness_value {
        ast::EndiannessValue::BigEndian => format_ident!("from_be_bytes"),
        ast::EndiannessValue::LittleEndian => format_ident!("from_le_bytes"),
    };

    // Work directly with the field name if we are reading a single
    // field. This generates simpler code.
    let chunk_name = chunk.get_name();
    let chunk_width = chunk.get_width();
    let chunk_type = types::Integer::new(chunk_width);
    assert!(chunk_width % 8 == 0, "Chunks must have a byte size, got width: {chunk_width}");

    // Byte indices of the chunk within the `bytes` buffer.
    let range = get_field_range(offset, chunk_width);
    let indices = range.map(syn::Index::from).collect::<Vec<_>>();

    // TODO(mgeisler): emit just a single length check per chunk. We
    // could even emit a single length check per packet.
    let length_checks = chunk.generate_length_checks(packet_name, offset);

    // When the chunk_type.width is larger than chunk_width (e.g.
    // chunk_width is 24 but chunk_type.width is 32), then we need
    // zero padding.
    let zero_padding_len = (chunk_type.width - chunk_width) / 8;
    // We need the padding on the MSB side of the payload, so for
    // big-endian, we need the padding on the left, for little-endian
    // we need it on the right.
    let (zero_padding_before, zero_padding_after) = match endianness_value {
        ast::EndiannessValue::BigEndian => (vec![syn::Index::from(0); zero_padding_len], vec![]),
        ast::EndiannessValue::LittleEndian => (vec![], vec![syn::Index::from(0); zero_padding_len]),
    };

    quote! {
        #(#length_checks)*
        let #chunk_name = #chunk_type::#getter([
            #(#zero_padding_before,)* #(bytes[#indices]),* #(, #zero_padding_after)*
        ]);
    }
}

fn generate_chunk_read_field_adjustments(fields: &[ast::Field]) -> proc_macro2::TokenStream {
    // If there is a single field in the chunk, then we don't have to
    // shift, mask, or cast.
@@ -343,11 +296,10 @@ fn generate_packet_decl(
    let mut offset = 0;
    for chunk in chunks {
        let chunk_fields = chunk.iter().map(Field::from).collect::<Vec<_>>();
        field_parsers.push(generate_chunk_read(
        field_parsers.push(Chunk::new(&chunk_fields).generate_read(
            id,
            file.endianness.value,
            offset,
            &Chunk::new(&chunk_fields),
        ));
        field_parsers.push(generate_chunk_read_field_adjustments(chunk));

@@ -531,9 +483,8 @@ pub fn generate(sources: &ast::SourceDatabase, file: &ast::File) -> String {
mod tests {
    use super::*;
    use crate::ast;
    use crate::backends::rust::field::ScalarField;
    use crate::parser::parse_inline;
    use crate::test_utils::{assert_eq_with_diff, assert_snapshot_eq, rustfmt};
    use crate::test_utils::{assert_expr_eq, assert_snapshot_eq, rustfmt};

    /// Parse a string fragment as a PDL file.
    ///
@@ -680,159 +631,6 @@ mod tests {
        assert_eq!(get_field_range(/*offset=*/ 5, /*width=*/ 20), (0..4));
    }

    // Assert that an expression equals the given expression.
    //
    // Both expressions are wrapped in a `main` function (so we can
    // format it with `rustfmt`) and a diff is shown if they differ.
    #[track_caller]
    fn assert_expr_eq(left: proc_macro2::TokenStream, right: proc_macro2::TokenStream) {
        // Wrap a token stream in `fn main` so `rustfmt` accepts it.
        let as_main = |body: proc_macro2::TokenStream| quote! { fn main() { #body } };
        let left = rustfmt(&as_main(left).to_string());
        let right = rustfmt(&as_main(right).to_string());
        assert_eq_with_diff("left", &left, "right", &right);
    }

    #[test]
    fn test_generate_chunk_read_8bit() {
        // One 8-bit field at bit offset 80: byte 10 is read directly.
        let scalar_fields = [Field::Scalar(ScalarField { id: "a".to_string(), width: 8 })];
        let expected = quote! {
            if bytes.len() < 11 {
                return Err(Error::InvalidLengthError {
                    obj: "Foo".to_string(),
                    field: "a".to_string(),
                    wanted: 11,
                    got: bytes.len(),
                });
            }
            let a = u8::from_be_bytes([bytes[10]]);
        };
        let actual = generate_chunk_read(
            "Foo",
            ast::EndiannessValue::BigEndian,
            80,
            &Chunk::new(&scalar_fields),
        );
        assert_expr_eq(actual, expected);
    }

    #[test]
    fn test_generate_chunk_read_16bit_le() {
        // A 16-bit little-endian field spans bytes 10..12.
        let scalar_fields = [Field::Scalar(ScalarField { id: "a".to_string(), width: 16 })];
        let expected = quote! {
            if bytes.len() < 12 {
                return Err(Error::InvalidLengthError {
                    obj: "Foo".to_string(),
                    field: "a".to_string(),
                    wanted: 12,
                    got: bytes.len(),
                });
            }
            let a = u16::from_le_bytes([bytes[10], bytes[11]]);
        };
        let actual = generate_chunk_read(
            "Foo",
            ast::EndiannessValue::LittleEndian,
            80,
            &Chunk::new(&scalar_fields),
        );
        assert_expr_eq(actual, expected);
    }

    #[test]
    fn test_generate_chunk_read_16bit_be() {
        // A 16-bit big-endian field spans bytes 10..12.
        let scalar_fields = [Field::Scalar(ScalarField { id: "a".to_string(), width: 16 })];
        let expected = quote! {
            if bytes.len() < 12 {
                return Err(Error::InvalidLengthError {
                    obj: "Foo".to_string(),
                    field: "a".to_string(),
                    wanted: 12,
                    got: bytes.len(),
                });
            }
            let a = u16::from_be_bytes([bytes[10], bytes[11]]);
        };
        let actual = generate_chunk_read(
            "Foo",
            ast::EndiannessValue::BigEndian,
            80,
            &Chunk::new(&scalar_fields),
        );
        assert_expr_eq(actual, expected);
    }

    #[test]
    fn test_generate_chunk_read_24bit_le() {
        // A 24-bit field widens to u32; for little-endian the zero
        // padding byte lands on the right (MSB) side.
        let scalar_fields = [Field::Scalar(ScalarField { id: "a".to_string(), width: 24 })];
        let expected = quote! {
            if bytes.len() < 13 {
                return Err(Error::InvalidLengthError {
                    obj: "Foo".to_string(),
                    field: "a".to_string(),
                    wanted: 13,
                    got: bytes.len(),
                });
            }
            let a = u32::from_le_bytes([bytes[10], bytes[11], bytes[12], 0]);
        };
        let actual = generate_chunk_read(
            "Foo",
            ast::EndiannessValue::LittleEndian,
            80,
            &Chunk::new(&scalar_fields),
        );
        assert_expr_eq(actual, expected);
    }

    #[test]
    fn test_generate_chunk_read_24bit_be() {
        // A 24-bit field widens to u32; for big-endian the zero
        // padding byte lands on the left (MSB) side.
        let scalar_fields = [Field::Scalar(ScalarField { id: "a".to_string(), width: 24 })];
        let expected = quote! {
            if bytes.len() < 13 {
                return Err(Error::InvalidLengthError {
                    obj: "Foo".to_string(),
                    field: "a".to_string(),
                    wanted: 13,
                    got: bytes.len(),
                });
            }
            let a = u32::from_be_bytes([0, bytes[10], bytes[11], bytes[12]]);
        };
        let actual = generate_chunk_read(
            "Foo",
            ast::EndiannessValue::BigEndian,
            80,
            &Chunk::new(&scalar_fields),
        );
        assert_expr_eq(actual, expected);
    }

    #[test]
    fn test_generate_chunk_read_multiple_fields() {
        // Two fields (16 + 24 bits) share one 40-bit chunk, read as a
        // u64 with three padding bytes and one length check per field.
        let scalar_fields = [
            Field::Scalar(ScalarField { id: "a".to_string(), width: 16 }),
            Field::Scalar(ScalarField { id: "b".to_string(), width: 24 }),
        ];
        let expected = quote! {
            if bytes.len() < 12 {
                return Err(Error::InvalidLengthError {
                    obj: "Foo".to_string(),
                    field: "a".to_string(),
                    wanted: 12,
                    got: bytes.len(),
                });
            }
            if bytes.len() < 15 {
                return Err(Error::InvalidLengthError {
                    obj: "Foo".to_string(),
                    field: "b".to_string(),
                    wanted: 15,
                    got: bytes.len(),
                });
            }
            let chunk =
                u64::from_be_bytes([0, 0, 0, bytes[10], bytes[11], bytes[12], bytes[13], bytes[14]]);
        };
        let actual = generate_chunk_read(
            "Foo",
            ast::EndiannessValue::BigEndian,
            80,
            &Chunk::new(&scalar_fields),
        );
        assert_expr_eq(actual, expected);
    }

    #[test]
    fn test_generate_chunk_read_field_adjustments_8bit() {
        let loc = ast::SourceRange::default();
+190 −0
Original line number Diff line number Diff line
use crate::ast;
use crate::backends::rust::field::Field;
use crate::backends::rust::get_field_range;
use crate::backends::rust::types::Integer;
use quote::{format_ident, quote};

/// A chunk of field.
@@ -65,4 +67,192 @@ impl Chunk<'_> {
        }
        length_checks
    }

    /// Read data for a chunk.
    ///
    /// Returns a token stream which reads this chunk from a `bytes`
    /// buffer, starting at the bit `offset`, using the byte order
    /// given by `endianness_value`. `packet_name` is only used in the
    /// generated length-check error messages.
    ///
    /// # Panics
    ///
    /// Panics if `offset` or the chunk width is not a multiple of 8.
    pub fn generate_read(
        &self,
        packet_name: &str,
        endianness_value: ast::EndiannessValue,
        offset: usize,
    ) -> proc_macro2::TokenStream {
        assert!(offset % 8 == 0, "Chunks must be byte-aligned, got offset: {offset}");
        // Pick the integer constructor matching the wire byte order.
        let getter = match endianness_value {
            ast::EndiannessValue::BigEndian => format_ident!("from_be_bytes"),
            ast::EndiannessValue::LittleEndian => format_ident!("from_le_bytes"),
        };

        // Name of the generated local: the field name for a
        // single-field chunk, a shared name otherwise.
        let chunk_name = self.get_name();
        let chunk_width = self.get_width();
        let chunk_type = Integer::new(chunk_width);
        assert!(chunk_width % 8 == 0, "Chunks must have a byte size, got width: {chunk_width}");

        // Byte indices of the chunk within the `bytes` buffer.
        let range = get_field_range(offset, chunk_width);
        let indices = range.map(syn::Index::from).collect::<Vec<_>>();

        // TODO(mgeisler): emit just a single length check per chunk. We
        // could even emit a single length check per packet.
        let length_checks = self.generate_length_checks(packet_name, offset);

        // When the chunk_type.width is larger than chunk_width (e.g.
        // chunk_width is 24 but chunk_type.width is 32), then we need
        // zero padding.
        let zero_padding_len = (chunk_type.width - chunk_width) / 8;
        // We need the padding on the MSB side of the payload, so for
        // big-endian, we need the padding on the left, for little-endian
        // we need it on the right.
        let (zero_padding_before, zero_padding_after) = match endianness_value {
            ast::EndiannessValue::BigEndian => {
                (vec![syn::Index::from(0); zero_padding_len], vec![])
            }
            ast::EndiannessValue::LittleEndian => {
                (vec![], vec![syn::Index::from(0); zero_padding_len])
            }
        };

        quote! {
            #(#length_checks)*
            let #chunk_name = #chunk_type::#getter([
                #(#zero_padding_before,)* #(bytes[#indices]),* #(, #zero_padding_after)*
            ]);
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::backends::rust::field::ScalarField;
    use crate::test_utils::assert_expr_eq;

    #[test]
    fn test_generate_read_8bit() {
        // One 8-bit field at bit offset 80: byte 10 is read directly.
        let scalar_fields = [Field::Scalar(ScalarField { id: "a".to_string(), width: 8 })];
        let expected = quote! {
            if bytes.len() < 11 {
                return Err(Error::InvalidLengthError {
                    obj: "Foo".to_string(),
                    field: "a".to_string(),
                    wanted: 11,
                    got: bytes.len(),
                });
            }
            let a = u8::from_be_bytes([bytes[10]]);
        };
        assert_expr_eq(
            Chunk::new(&scalar_fields).generate_read("Foo", ast::EndiannessValue::BigEndian, 80),
            expected,
        );
    }

    #[test]
    fn test_generate_read_16bit_le() {
        // A 16-bit little-endian field spans bytes 10..12.
        let scalar_fields = [Field::Scalar(ScalarField { id: "a".to_string(), width: 16 })];
        let expected = quote! {
            if bytes.len() < 12 {
                return Err(Error::InvalidLengthError {
                    obj: "Foo".to_string(),
                    field: "a".to_string(),
                    wanted: 12,
                    got: bytes.len(),
                });
            }
            let a = u16::from_le_bytes([bytes[10], bytes[11]]);
        };
        assert_expr_eq(
            Chunk::new(&scalar_fields).generate_read("Foo", ast::EndiannessValue::LittleEndian, 80),
            expected,
        );
    }

    #[test]
    fn test_generate_read_16bit_be() {
        // A 16-bit big-endian field spans bytes 10..12.
        let scalar_fields = [Field::Scalar(ScalarField { id: "a".to_string(), width: 16 })];
        let expected = quote! {
            if bytes.len() < 12 {
                return Err(Error::InvalidLengthError {
                    obj: "Foo".to_string(),
                    field: "a".to_string(),
                    wanted: 12,
                    got: bytes.len(),
                });
            }
            let a = u16::from_be_bytes([bytes[10], bytes[11]]);
        };
        assert_expr_eq(
            Chunk::new(&scalar_fields).generate_read("Foo", ast::EndiannessValue::BigEndian, 80),
            expected,
        );
    }

    #[test]
    fn test_generate_read_24bit_le() {
        // A 24-bit field widens to u32; for little-endian the zero
        // padding byte lands on the right (MSB) side.
        let scalar_fields = [Field::Scalar(ScalarField { id: "a".to_string(), width: 24 })];
        let expected = quote! {
            if bytes.len() < 13 {
                return Err(Error::InvalidLengthError {
                    obj: "Foo".to_string(),
                    field: "a".to_string(),
                    wanted: 13,
                    got: bytes.len(),
                });
            }
            let a = u32::from_le_bytes([bytes[10], bytes[11], bytes[12], 0]);
        };
        assert_expr_eq(
            Chunk::new(&scalar_fields).generate_read("Foo", ast::EndiannessValue::LittleEndian, 80),
            expected,
        );
    }

    #[test]
    fn test_generate_read_24bit_be() {
        // A 24-bit field widens to u32; for big-endian the zero
        // padding byte lands on the left (MSB) side.
        let scalar_fields = [Field::Scalar(ScalarField { id: "a".to_string(), width: 24 })];
        let expected = quote! {
            if bytes.len() < 13 {
                return Err(Error::InvalidLengthError {
                    obj: "Foo".to_string(),
                    field: "a".to_string(),
                    wanted: 13,
                    got: bytes.len(),
                });
            }
            let a = u32::from_be_bytes([0, bytes[10], bytes[11], bytes[12]]);
        };
        assert_expr_eq(
            Chunk::new(&scalar_fields).generate_read("Foo", ast::EndiannessValue::BigEndian, 80),
            expected,
        );
    }

    #[test]
    fn test_generate_read_multiple_fields() {
        // Two fields (16 + 24 bits) share one 40-bit chunk, read as a
        // u64 with three padding bytes and one length check per field.
        let scalar_fields = [
            Field::Scalar(ScalarField { id: "a".to_string(), width: 16 }),
            Field::Scalar(ScalarField { id: "b".to_string(), width: 24 }),
        ];
        let expected = quote! {
            if bytes.len() < 12 {
                return Err(Error::InvalidLengthError {
                    obj: "Foo".to_string(),
                    field: "a".to_string(),
                    wanted: 12,
                    got: bytes.len(),
                });
            }
            if bytes.len() < 15 {
                return Err(Error::InvalidLengthError {
                    obj: "Foo".to_string(),
                    field: "b".to_string(),
                    wanted: 15,
                    got: bytes.len(),
                });
            }
            let chunk =
                u64::from_be_bytes([0, 0, 0, bytes[10], bytes[11], bytes[12], bytes[13], bytes[14]]);
        };
        assert_expr_eq(
            Chunk::new(&scalar_fields).generate_read("Foo", ast::EndiannessValue::BigEndian, 80),
            expected,
        );
    }
}
+16 −0
Original line number Diff line number Diff line
@@ -5,6 +5,7 @@
// rest of the `pdl` crate. To make this work, avoid `use crate::`
// statements below.

use quote::quote;
use std::fs;
use std::io::Write;
use std::path::Path;
@@ -102,6 +103,21 @@ pub fn assert_eq_with_diff(left_label: &str, left: &str, right_label: &str, righ
    );
}

// Assert that an expression equals the given expression.
//
// Both expressions are wrapped in a `main` function (so we can format
// it with `rustfmt`) and a diff is shown if they differ.
#[track_caller]
pub fn assert_expr_eq(left: proc_macro2::TokenStream, right: proc_macro2::TokenStream) {
    // Wrap a token stream in `fn main` so `rustfmt` accepts it.
    let as_main = |body: proc_macro2::TokenStream| quote! { fn main() { #body } };
    let left = rustfmt(&as_main(left).to_string());
    let right = rustfmt(&as_main(right).to_string());
    assert_eq_with_diff("left", &left, "right", &right);
}

/// Check that `haystack` contains `needle`.
///
/// Panic with a nice message if not.