Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 3aee419e authored by Martin Geisler's avatar Martin Geisler
Browse files

pdl: Generate just a single length check per chunk

We can now deviate from what ‘bluetooth_packetgen’ generates and thus
simplify our code a little.

Test: atest pdl_tests
Change-Id: Ica2cfe4e7d97dab5edfc897d6a2581b36dfb7885
parent e0f9ad4e
Loading
Loading
Loading
Loading
+13 −45
Original line number Diff line number Diff line
@@ -35,37 +35,22 @@ impl Chunk<'_> {
    }

    /// Generate length checks for this chunk.
    pub fn generate_length_checks(
    pub fn generate_length_check(
        &self,
        packet_name: &str,
        offset: usize,
    ) -> Vec<proc_macro2::TokenStream> {
        let mut field_offset = offset;
        let mut last_field_range_end = 0;
        let mut length_checks = Vec::new();
        for field in self.fields {
            let id = field.get_id();
            let width = field.get_width();
            let field_range = get_field_range(field_offset, width);
            field_offset += width;
            if field_range.end == last_field_range_end {
                continue;
            }

            last_field_range_end = field_range.end;
            let range_end = syn::Index::from(field_range.end);
            length_checks.push(quote! {
                if bytes.len() < #range_end {
    ) -> proc_macro2::TokenStream {
        let range = get_field_range(offset, self.get_width());
        let wanted_length = syn::Index::from(range.end);
        quote! {
            if bytes.len() < #wanted_length {
                return Err(Error::InvalidLengthError {
                    obj: #packet_name.to_string(),
                        field: #id.to_string(),
                        wanted: #range_end,
                    wanted: #wanted_length,
                    got: bytes.len(),
                });
            }
            });
        }
        length_checks
    }

    /// Read data for a chunk.
@@ -88,10 +73,7 @@ impl Chunk<'_> {

        let range = get_field_range(offset, chunk_width);
        let indices = range.map(syn::Index::from).collect::<Vec<_>>();

        // TODO(mgeisler): emit just a single length check per chunk. We
        // could even emit a single length check per packet.
        let length_checks = self.generate_length_checks(packet_name, offset);
        let length_check = self.generate_length_check(packet_name, offset);

        // When the chunk_type.width is larger than chunk_width (e.g.
        // chunk_width is 24 but chunk_type.width is 32), then we need
@@ -112,7 +94,7 @@ impl Chunk<'_> {
        let read_adjustments = self.generate_read_adjustments();

        quote! {
            #(#length_checks)*
            #length_check
            let #chunk_name = #chunk_type::#getter([
                #(#zero_padding_before,)* #(bytes[#indices]),* #(, #zero_padding_after)*
            ]);
@@ -211,7 +193,6 @@ mod tests {
                if bytes.len() < 11 {
                    return Err(Error::InvalidLengthError {
                        obj: "Foo".to_string(),
                        field: "a".to_string(),
                        wanted: 11,
                        got: bytes.len(),
                    });
@@ -231,7 +212,6 @@ mod tests {
                if bytes.len() < 12 {
                    return Err(Error::InvalidLengthError {
                        obj: "Foo".to_string(),
                        field: "a".to_string(),
                        wanted: 12,
                        got: bytes.len(),
                    });
@@ -251,7 +231,6 @@ mod tests {
                if bytes.len() < 12 {
                    return Err(Error::InvalidLengthError {
                        obj: "Foo".to_string(),
                        field: "a".to_string(),
                        wanted: 12,
                        got: bytes.len(),
                    });
@@ -271,7 +250,6 @@ mod tests {
                if bytes.len() < 13 {
                    return Err(Error::InvalidLengthError {
                        obj: "Foo".to_string(),
                        field: "a".to_string(),
                        wanted: 13,
                        got: bytes.len(),
                    });
@@ -291,7 +269,6 @@ mod tests {
                if bytes.len() < 13 {
                    return Err(Error::InvalidLengthError {
                        obj: "Foo".to_string(),
                        field: "a".to_string(),
                        wanted: 13,
                        got: bytes.len(),
                    });
@@ -311,18 +288,9 @@ mod tests {
        assert_expr_eq(
            chunk.generate_read("Foo", ast::EndiannessValue::BigEndian, 80),
            quote! {
                if bytes.len() < 12 {
                    return Err(Error::InvalidLengthError {
                        obj: "Foo".to_string(),
                        field: "a".to_string(),
                        wanted: 12,
                        got: bytes.len(),
                    });
                }
                if bytes.len() < 15 {
                    return Err(Error::InvalidLengthError {
                        obj: "Foo".to_string(),
                        field: "b".to_string(),
                        wanted: 15,
                        got: bytes.len(),
                    });
+0 −10
Original line number Diff line number Diff line
@@ -20,10 +20,6 @@ impl ScalarField {
        self.width
    }

    fn get_id(&self) -> &str {
        &self.id
    }

    fn get_ident(&self) -> proc_macro2::Ident {
        format_ident!("{}", self.id)
    }
@@ -149,12 +145,6 @@ impl Field {
        }
    }

    pub fn get_id(&self) -> &str {
        match self {
            Field::Scalar(field) => field.get_id(),
        }
    }

    pub fn get_ident(&self) -> proc_macro2::Ident {
        match self {
            Field::Scalar(field) => field.get_ident(),
+2 −2
Original line number Diff line number Diff line
@@ -29,8 +29,8 @@ pub fn generate(path: &Path) -> String {
            InvalidPacketError,
            #[error("{field} was {value:x}, which is not known")]
            ConstraintOutOfBounds { field: String, value: u64 },
            #[error("when parsing {obj}.{field} needed length of {wanted} but got {got}")]
            InvalidLengthError { obj: String, field: String, wanted: usize, got: usize },
            #[error("when parsing {obj} needed length of {wanted} but got {got}")]
            InvalidLengthError { obj: String, wanted: usize, got: usize },
            #[error("Due to size restrictions a struct could not be parsed.")]
            ImpossibleStructError,
            #[error("when parsing field {obj}.{field}, {value} is not a valid {type_} value")]
+0 −11
Original line number Diff line number Diff line
@@ -31,18 +31,9 @@ impl FooData {
        true
    }
    fn parse(bytes: &[u8]) -> Result<Self> {
        if bytes.len() < 1 {
            return Err(Error::InvalidLengthError {
                obj: "Foo".to_string(),
                field: "a".to_string(),
                wanted: 1,
                got: bytes.len(),
            });
        }
        if bytes.len() < 2 {
            return Err(Error::InvalidLengthError {
                obj: "Foo".to_string(),
                field: "b".to_string(),
                wanted: 2,
                got: bytes.len(),
            });
@@ -54,7 +45,6 @@ impl FooData {
        if bytes.len() < 5 {
            return Err(Error::InvalidLengthError {
                obj: "Foo".to_string(),
                field: "d".to_string(),
                wanted: 5,
                got: bytes.len(),
            });
@@ -63,7 +53,6 @@ impl FooData {
        if bytes.len() < 7 {
            return Err(Error::InvalidLengthError {
                obj: "Foo".to_string(),
                field: "e".to_string(),
                wanted: 7,
                got: bytes.len(),
            });
+0 −11
Original line number Diff line number Diff line
@@ -31,18 +31,9 @@ impl FooData {
        true
    }
    fn parse(bytes: &[u8]) -> Result<Self> {
        if bytes.len() < 1 {
            return Err(Error::InvalidLengthError {
                obj: "Foo".to_string(),
                field: "a".to_string(),
                wanted: 1,
                got: bytes.len(),
            });
        }
        if bytes.len() < 2 {
            return Err(Error::InvalidLengthError {
                obj: "Foo".to_string(),
                field: "b".to_string(),
                wanted: 2,
                got: bytes.len(),
            });
@@ -54,7 +45,6 @@ impl FooData {
        if bytes.len() < 5 {
            return Err(Error::InvalidLengthError {
                obj: "Foo".to_string(),
                field: "d".to_string(),
                wanted: 5,
                got: bytes.len(),
            });
@@ -63,7 +53,6 @@ impl FooData {
        if bytes.len() < 7 {
            return Err(Error::InvalidLengthError {
                obj: "Foo".to_string(),
                field: "e".to_string(),
                wanted: 7,
                got: bytes.len(),
            });
Loading