Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit ada3d9a0 authored by Martin Geisler's avatar Martin Geisler
Browse files

pdl: Use ‘bytes’ crate to read scalars

Previously, we parsed the scalars by hand, tracking an offset and picking
out individual byte indices. Now we delegate this to the bytes crate.
This simplifies the generated code significantly (the bytes crate was
already a dependency of the generated code).

Test: atest pdl_tests pdl_rust_generator_tests_{le,be}
Change-Id: Ifa6f822df6aba06bcf13b8694ea182d7f1cfac1d
parent 7b828af4
Loading
Loading
Loading
Loading
+5 −40
Original line number Diff line number Diff line
@@ -32,16 +32,6 @@ macro_rules! quote_block {
    }
}

/// Find byte indices covering `offset..offset+width` bits.
///
/// The returned range spans every byte touched by the bit range,
/// i.e. `start` is the byte holding the first bit and `end` is one
/// past the byte holding the last bit.
pub fn get_field_range(offset: usize, width: usize) -> std::ops::Range<usize> {
    let first_bit = offset;
    let last_bit = offset + width;
    // Round the end bit up to the next byte boundary (ceiling division).
    (first_bit / 8)..((last_bit + 7) / 8)
}

/// Generate a bit-mask which masks out `n` least significant bits.
pub fn mask_bits(n: usize) -> syn::LitInt {
    syn::parse_str::<syn::LitInt>(&format!("{:#x}", (1u64 << n) - 1)).unwrap()
@@ -152,12 +142,10 @@ fn generate_packet_decl(
    });
    let mut field_parsers = Vec::new();
    let mut field_writers = Vec::new();
    let mut offset = 0;
    for fields in chunks {
        let chunk = Chunk::new(fields);
        field_parsers.push(chunk.generate_read(id, file.endianness.value, offset));
        field_writers.push(chunk.generate_write(file.endianness.value, offset));
        offset += chunk.get_width();
        field_parsers.push(chunk.generate_read(id, file.endianness.value));
        field_writers.push(chunk.generate_write(file.endianness.value));
    }

    let field_names = fields.iter().map(Field::get_ident).collect::<Vec<_>>();
@@ -180,7 +168,7 @@ fn generate_packet_decl(
                #conforms
            }

            fn parse(bytes: &[u8]) -> Result<Self> {
            fn parse(mut bytes: &[u8]) -> Result<Self> {
                #(#field_parsers)*
                Ok(Self { #(#field_names),* })
            }
@@ -202,8 +190,7 @@ fn generate_packet_decl(
    code.push_str(&quote_block! {
        impl Packet for #packet_name {
            fn to_bytes(self) -> Bytes {
                let mut buffer = BytesMut::new();
                buffer.resize(self.#ident.get_total_size(), 0);
                let mut buffer = BytesMut::with_capacity(self.#ident.get_total_size());
                self.#ident.write_to(&mut buffer);
                buffer.freeze()
            }
@@ -238,7 +225,7 @@ fn generate_packet_decl(
    let field_getters = fields.iter().map(|field| field.generate_getter(&ident));
    code.push_str(&quote_block! {
        impl #packet_name {
            pub fn parse(bytes: &[u8]) -> Result<Self> {
            pub fn parse(mut bytes: &[u8]) -> Result<Self> {
                Ok(Self::new(Arc::new(#data_name::parse(bytes)?)).unwrap())
            }

@@ -382,26 +369,4 @@ mod tests {
          }
        "#,
    );

    #[test]
    fn test_get_field_range() {
        // Zero widths will give you an empty slice iff the offset is
        // byte aligned. In both cases, the slice covers the empty
        // width. In practice, PDL doesn't allow zero-width fields.
        assert_eq!(get_field_range(/*offset=*/ 0, /*width=*/ 0), (0..0));
        assert_eq!(get_field_range(/*offset=*/ 5, /*width=*/ 0), (0..1));
        assert_eq!(get_field_range(/*offset=*/ 8, /*width=*/ 0), (1..1));
        assert_eq!(get_field_range(/*offset=*/ 9, /*width=*/ 0), (1..2));

        // Non-zero widths work as expected.
        assert_eq!(get_field_range(/*offset=*/ 0, /*width=*/ 1), (0..1));
        assert_eq!(get_field_range(/*offset=*/ 0, /*width=*/ 5), (0..1));
        assert_eq!(get_field_range(/*offset=*/ 0, /*width=*/ 8), (0..1));
        assert_eq!(get_field_range(/*offset=*/ 0, /*width=*/ 20), (0..3));

        assert_eq!(get_field_range(/*offset=*/ 5, /*width=*/ 1), (0..1));
        assert_eq!(get_field_range(/*offset=*/ 5, /*width=*/ 3), (0..1));
        assert_eq!(get_field_range(/*offset=*/ 5, /*width=*/ 4), (0..2));
        assert_eq!(get_field_range(/*offset=*/ 5, /*width=*/ 20), (0..4));
    }
}
+88 −66
Original line number Diff line number Diff line
use crate::ast;
use crate::backends::rust::field::Field;
use crate::backends::rust::get_field_range;
use crate::backends::rust::types::Integer;
use quote::{format_ident, quote};

/// Return the method-name suffix selecting the little-endian variant
/// of a `bytes` accessor.
///
/// Single-byte values have no endianness, so the suffix is only
/// emitted for widths above 8 bits.
fn endianness_suffix(width: usize, endianness_value: ast::EndiannessValue) -> &'static str {
    match endianness_value {
        ast::EndiannessValue::LittleEndian if width > 8 => "_le",
        _ => "",
    }
}

/// Parse an unsigned integer from `buffer`.
///
/// The generated code requires that `buffer` is a mutable
/// `bytes::Buf` value.
fn get_uint(
    endianness: ast::EndiannessValue,
    buffer: proc_macro2::Ident,
    width: usize,
) -> proc_macro2::TokenStream {
    let suffix = endianness_suffix(width, endianness);
    match width {
        // Widths with a native Rust integer type map directly onto
        // Buf::get_uNN.
        8 | 16 | 32 | 64 => {
            let getter = format_ident!("get_u{}{}", width, suffix);
            quote! {
                #buffer.#getter()
            }
        }
        // Any other byte-sized width goes through Buf::get_uint,
        // which returns u64 and must be cast down to the value type.
        _ => {
            let getter = format_ident!("get_uint{}", suffix);
            let value_nbytes = proc_macro2::Literal::usize_unsuffixed(width / 8);
            let value_type = Integer::new(width);
            quote! {
                #buffer.#getter(#value_nbytes) as #value_type
            }
        }
    }
}

/// Write an unsigned integer `value` to `buffer`.
///
/// The generated code requires that `buffer` is a mutable
/// `bytes::BufMut` value.
fn put_uint(
    endianness: ast::EndiannessValue,
    buffer: proc_macro2::Ident,
    value: proc_macro2::TokenStream,
    width: usize,
) -> proc_macro2::TokenStream {
    let suffix = endianness_suffix(width, endianness);
    match width {
        // Widths with a native Rust integer type map directly onto
        // BufMut::put_uNN.
        8 | 16 | 32 | 64 => {
            let putter = format_ident!("put_u{}{}", width, suffix);
            quote! {
                #buffer.#putter(#value)
            }
        }
        // Any other byte-sized width goes through BufMut::put_uint,
        // which takes a u64 plus an explicit byte count.
        _ => {
            let putter = format_ident!("put_uint{}", suffix);
            let value_nbytes = proc_macro2::Literal::usize_unsuffixed(width / 8);
            quote! {
                #buffer.#putter(#value as u64, #value_nbytes)
            }
        }
    }
}

/// A chunk of fields.
///
/// While fields can have arbitrary widths, a chunk is always an
@@ -35,19 +98,14 @@ impl Chunk<'_> {
    }

    /// Generate length checks for this chunk.
    pub fn generate_length_check(
        &self,
        packet_name: &str,
        offset: usize,
    ) -> proc_macro2::TokenStream {
        let range = get_field_range(offset, self.get_width());
        let wanted_length = syn::Index::from(range.end);
    pub fn generate_length_check(&self, packet_name: &str) -> proc_macro2::TokenStream {
        let wanted_length = proc_macro2::Literal::usize_unsuffixed(self.get_width() / 8);
        quote! {
            if bytes.len() < #wanted_length {
            if bytes.remaining() < #wanted_length {
                return Err(Error::InvalidLengthError {
                    obj: #packet_name.to_string(),
                    wanted: #wanted_length,
                    got: bytes.len(),
                    got: bytes.remaining(),
                });
            }
        }
@@ -58,46 +116,18 @@ impl Chunk<'_> {
        &self,
        packet_name: &str,
        endianness_value: ast::EndiannessValue,
        offset: usize,
    ) -> proc_macro2::TokenStream {
        assert!(offset % 8 == 0, "Chunks must be byte-aligned, got offset: {offset}");
        let getter = match endianness_value {
            ast::EndiannessValue::BigEndian => format_ident!("from_be_bytes"),
            ast::EndiannessValue::LittleEndian => format_ident!("from_le_bytes"),
        };

        let chunk_name = self.get_name();
        let chunk_width = self.get_width();
        let chunk_type = Integer::new(chunk_width);
        assert!(chunk_width % 8 == 0, "Chunks must have a byte size, got width: {chunk_width}");

        let range = get_field_range(offset, chunk_width);
        let indices = range.map(syn::Index::from).collect::<Vec<_>>();
        let length_check = self.generate_length_check(packet_name, offset);

        // When the chunk_type.width is larger than chunk_width (e.g.
        // chunk_width is 24 but chunk_type.width is 32), then we need
        // zero padding.
        let zero_padding_len = (chunk_type.width - chunk_width) / 8;
        // We need the padding on the MSB side of the payload, so for
        // big-endian, we need to padding on the left, for little-endian
        // we need it on the right.
        let (zero_padding_before, zero_padding_after) = match endianness_value {
            ast::EndiannessValue::BigEndian => {
                (vec![syn::Index::from(0); zero_padding_len], vec![])
            }
            ast::EndiannessValue::LittleEndian => {
                (vec![], vec![syn::Index::from(0); zero_padding_len])
            }
        };

        let length_check = self.generate_length_check(packet_name);
        let read = get_uint(endianness_value, format_ident!("bytes"), chunk_width);
        let read_adjustments = self.generate_read_adjustments();

        quote! {
            #length_check
            let #chunk_name = #chunk_type::#getter([
                #(#zero_padding_before,)* #(bytes[#indices]),* #(, #zero_padding_after)*
            ]);
            let #chunk_name = #read;
            #read_adjustments
        }
    }
@@ -127,26 +157,18 @@ impl Chunk<'_> {
    pub fn generate_write(
        &self,
        endianness_value: ast::EndiannessValue,
        offset: usize,
    ) -> proc_macro2::TokenStream {
        let writer = match endianness_value {
            ast::EndiannessValue::BigEndian => format_ident!("to_be_bytes"),
            ast::EndiannessValue::LittleEndian => format_ident!("to_le_bytes"),
        };

        let chunk_width = self.get_width();
        let chunk_name = self.get_name();
        assert!(chunk_width % 8 == 0, "Chunks must have a byte size, got width: {chunk_width}");

        let range = get_field_range(offset, chunk_width);
        let start = syn::Index::from(range.start);
        let end = syn::Index::from(range.end);
        // TODO(mgeisler): let slice = (chunk_type_width > chunk_width).then( ... )
        let chunk_byte_width = syn::Index::from(chunk_width / 8);
        let write_adjustments = self.generate_write_adjustments();
        let write =
            put_uint(endianness_value, format_ident!("buffer"), quote!(#chunk_name), chunk_width);
        quote! {
            #write_adjustments
            buffer[#start..#end].copy_from_slice(&#chunk_name.#writer()[0..#chunk_byte_width]);
            #write;
        }
    }

@@ -187,7 +209,7 @@ mod tests {
    fn test_generate_read_8bit() {
        let fields = [Field::Scalar(ScalarField { id: String::from("a"), width: 8 })];
        let chunk = Chunk::new(&fields);
        let chunk_read = chunk.generate_read("Foo", ast::EndiannessValue::BigEndian, 80);
        let chunk_read = chunk.generate_read("Foo", ast::EndiannessValue::BigEndian);
        let code = quote! { fn main() { #chunk_read } };
        assert_snapshot_eq(
            "tests/generated/generate_chunk_read_8bit.rs",
@@ -199,7 +221,7 @@ mod tests {
    fn test_generate_read_16bit_le() {
        let fields = [Field::Scalar(ScalarField { id: String::from("a"), width: 16 })];
        let chunk = Chunk::new(&fields);
        let chunk_read = chunk.generate_read("Foo", ast::EndiannessValue::LittleEndian, 80);
        let chunk_read = chunk.generate_read("Foo", ast::EndiannessValue::LittleEndian);
        let code = quote! { fn main() { #chunk_read } };
        assert_snapshot_eq(
            "tests/generated/generate_chunk_read_16bit_le.rs",
@@ -211,7 +233,7 @@ mod tests {
    fn test_generate_read_16bit_be() {
        let fields = [Field::Scalar(ScalarField { id: String::from("a"), width: 16 })];
        let chunk = Chunk::new(&fields);
        let chunk_read = chunk.generate_read("Foo", ast::EndiannessValue::BigEndian, 80);
        let chunk_read = chunk.generate_read("Foo", ast::EndiannessValue::BigEndian);
        let code = quote! { fn main() { #chunk_read } };
        assert_snapshot_eq(
            "tests/generated/generate_chunk_read_16bit_be.rs",
@@ -223,7 +245,7 @@ mod tests {
    fn test_generate_read_24bit_le() {
        let fields = [Field::Scalar(ScalarField { id: String::from("a"), width: 24 })];
        let chunk = Chunk::new(&fields);
        let chunk_read = chunk.generate_read("Foo", ast::EndiannessValue::LittleEndian, 80);
        let chunk_read = chunk.generate_read("Foo", ast::EndiannessValue::LittleEndian);
        let code = quote! { fn main() { #chunk_read } };
        assert_snapshot_eq(
            "tests/generated/generate_chunk_read_24bit_le.rs",
@@ -235,7 +257,7 @@ mod tests {
    fn test_generate_read_24bit_be() {
        let fields = [Field::Scalar(ScalarField { id: String::from("a"), width: 24 })];
        let chunk = Chunk::new(&fields);
        let chunk_read = chunk.generate_read("Foo", ast::EndiannessValue::BigEndian, 80);
        let chunk_read = chunk.generate_read("Foo", ast::EndiannessValue::BigEndian);
        let code = quote! { fn main() { #chunk_read } };
        assert_snapshot_eq(
            "tests/generated/generate_chunk_read_24bit_be.rs",
@@ -250,7 +272,7 @@ mod tests {
            Field::Scalar(ScalarField { id: String::from("b"), width: 24 }),
        ];
        let chunk = Chunk::new(&fields);
        let chunk_read = chunk.generate_read("Foo", ast::EndiannessValue::BigEndian, 80);
        let chunk_read = chunk.generate_read("Foo", ast::EndiannessValue::BigEndian);
        let code = quote! { fn main() { #chunk_read } };
        assert_snapshot_eq(
            "tests/generated/generate_chunk_read_multiple_fields.rs",
@@ -301,10 +323,10 @@ mod tests {
        let fields = [Field::Scalar(ScalarField { id: String::from("a"), width: 8 })];
        let chunk = Chunk::new(&fields);
        assert_expr_eq(
            chunk.generate_write(ast::EndiannessValue::BigEndian, 80),
            chunk.generate_write(ast::EndiannessValue::BigEndian),
            quote! {
                let a = self.a;
                buffer[10..11].copy_from_slice(&a.to_be_bytes()[0..1]);
                buffer.put_u8(a);
            },
        );
    }
@@ -314,10 +336,10 @@ mod tests {
        let fields = [Field::Scalar(ScalarField { id: String::from("a"), width: 16 })];
        let chunk = Chunk::new(&fields);
        assert_expr_eq(
            chunk.generate_write(ast::EndiannessValue::BigEndian, 80),
            chunk.generate_write(ast::EndiannessValue::BigEndian),
            quote! {
                let a = self.a;
                buffer[10..12].copy_from_slice(&a.to_be_bytes()[0..2]);
                buffer.put_u16(a);
            },
        );
    }
@@ -327,10 +349,10 @@ mod tests {
        let fields = [Field::Scalar(ScalarField { id: String::from("a"), width: 24 })];
        let chunk = Chunk::new(&fields);
        assert_expr_eq(
            chunk.generate_write(ast::EndiannessValue::BigEndian, 80),
            chunk.generate_write(ast::EndiannessValue::BigEndian),
            quote! {
                let a = self.a;
                buffer[10..13].copy_from_slice(&a.to_be_bytes()[0..3]);
                buffer.put_uint(a as u64, 3);
            },
        );
    }
@@ -343,12 +365,12 @@ mod tests {
        ];
        let chunk = Chunk::new(&fields);
        assert_expr_eq(
            chunk.generate_write(ast::EndiannessValue::BigEndian, 80),
            chunk.generate_write(ast::EndiannessValue::BigEndian),
            quote! {
                let chunk = 0;
                let chunk = chunk | (self.a as u64);
                let chunk = chunk | (((self.b as u64) & 0xffffff) << 16);
                buffer[10..15].copy_from_slice(&chunk.to_be_bytes()[0..5]);
                buffer.put_uint(chunk as u64, 5);
            },
        );
    }
+1 −1
Original line number Diff line number Diff line
@@ -12,7 +12,7 @@ pub fn generate(path: &Path) -> String {
    code.push_str("#![allow(warnings, missing_docs)]\n\n");

    code.push_str(&quote_block! {
        use bytes::{BufMut, Bytes, BytesMut};
        use bytes::{Buf, BufMut, Bytes, BytesMut};
        use num_derive::{FromPrimitive, ToPrimitive};
        use num_traits::{FromPrimitive, ToPrimitive};
        use std::convert::{TryFrom, TryInto};
+4 −4
Original line number Diff line number Diff line
fn main() {
    if bytes.len() < 12 {
    if bytes.remaining() < 2 {
        return Err(Error::InvalidLengthError {
            obj: "Foo".to_string(),
            wanted: 12,
            got: bytes.len(),
            wanted: 2,
            got: bytes.remaining(),
        });
    }
    let a = u16::from_be_bytes([bytes[10], bytes[11]]);
    let a = bytes.get_u16();
}
+4 −4
Original line number Diff line number Diff line
fn main() {
    if bytes.len() < 12 {
    if bytes.remaining() < 2 {
        return Err(Error::InvalidLengthError {
            obj: "Foo".to_string(),
            wanted: 12,
            got: bytes.len(),
            wanted: 2,
            got: bytes.remaining(),
        });
    }
    let a = u16::from_le_bytes([bytes[10], bytes[11]]);
    let a = bytes.get_u16_le();
}
Loading