Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit cd8373a6 authored by Martin Geisler's avatar Martin Geisler Committed by Cherrypicker Worker
Browse files

pdl: Add u64 or usize suffix to mask_bit output

We already handled the case where the generated integer literal is so
large that it doesn’t fit into the default i32 type. This was handled
by adding an explicit u64 suffix.

However, this produces invalid code if we use the u64 integer in an
expression where we actually expect an usize. We now take that into
account with an explicit suffix argument.

This was found during the integration with UWB.

Tag: #feature
Bug: 228306436
Test: atest pdl_tests pdl_rust_generator_tests_{le,be}
(cherry picked from https://android-review.googlesource.com/q/commit:e7849d25cfe1322f441a5bf72c818c8bfe096e92)
Merged-In: I756a70ad776edfd9732c9bc4c8a0498f4a82defc
Change-Id: I756a70ad776edfd9732c9bc4c8a0498f4a82defc
parent 030c5a43
Loading
Loading
Loading
Loading
+20 −8
Original line number Diff line number Diff line
@@ -38,10 +38,22 @@ macro_rules! quote_block {
}

/// Generate a bit-mask which masks out `n` least significant bits.
pub fn mask_bits(n: usize) -> syn::LitInt {
    // The literal needs a suffix if it's larger than an i32.
    let suffix = if n > 31 { "u64" } else { "" };
    syn::parse_str::<syn::LitInt>(&format!("{:#x}{suffix}", (1u64 << n) - 1)).unwrap()
///
/// Literal integers in Rust default to the `i32` type. For this
/// reason, if `n` is larger than 31, a suffix is added to the
/// `LitInt` returned. This should either be `u64` or `usize`
/// depending on where the result is used.
pub fn mask_bits(n: usize, suffix: &str) -> syn::LitInt {
    let suffix = if n > 31 { format!("_{suffix}") } else { String::new() };
    // Format the hex digits as 0x1111_2222_3333_usize.
    let hex_digits = format!("{:x}", (1u64 << n) - 1)
        .as_bytes()
        .rchunks(4)
        .rev()
        .map(|chunk| std::str::from_utf8(chunk).unwrap())
        .collect::<Vec<&str>>()
        .join("_");
    syn::parse_str::<syn::LitInt>(&format!("0x{hex_digits}{suffix}")).unwrap()
}

fn generate_packet_size_getter(
@@ -919,12 +931,12 @@ mod tests {
        packet_decl_array_unknown_element_width_dynamic_size,
        "
          struct Foo {
            _count_(a): 8,
            _count_(a): 40,
            a: 16[],
          }

          packet Bar {
            _size_(x): 8,
            _size_(x): 40,
            x: Foo[],
          }
        "
@@ -934,12 +946,12 @@ mod tests {
        packet_decl_array_unknown_element_width_dynamic_count,
        "
          struct Foo {
            _count_(a): 8,
            _count_(a): 40,
            a: 16[],
          }

          packet Bar {
            _count_(x): 8,
            _count_(x): 40,
            x: Foo[],
          }
        "
+1 −1
Original line number Diff line number Diff line
@@ -115,7 +115,7 @@ impl<'a> FieldParser<'a> {
            if !single_value && width < value_type.width {
                // Mask value if we grabbed more than `width` and if
                // `as #value_type` doesn't already do the masking.
                let mask = mask_bits(width);
                let mask = mask_bits(width, "u64");
                v = quote! { (#v & #mask) };
            }

+3 −3
Original line number Diff line number Diff line
@@ -65,7 +65,7 @@ impl<'a> FieldSerializer<'a> {
                let field_type = types::Integer::new(*width);
                if field_type.width > *width {
                    let packet_name = &self.packet_name;
                    let max_value = mask_bits(*width);
                    let max_value = mask_bits(*width, "u64");
                    self.code.push(quote! {
                        if self.#field_name > #max_value {
                            panic!(
@@ -105,7 +105,7 @@ impl<'a> FieldSerializer<'a> {
            }
            ast::FieldDesc::Size { field_id, width, .. } => {
                let packet_name = &self.packet_name;
                let max_value = mask_bits(*width);
                let max_value = mask_bits(*width, "usize");

                let decl = self.scope.typedef.get(self.packet_name).unwrap();
                let scope = self.scope.scopes.get(decl).unwrap();
@@ -165,7 +165,7 @@ impl<'a> FieldSerializer<'a> {
                let field_type = types::Integer::new(*width);
                if field_type.width > *width {
                    let packet_name = &self.packet_name;
                    let max_value = mask_bits(*width);
                    let max_value = mask_bits(*width, "usize");
                    self.code.push(quote! {
                        if self.#field_name.len() > #max_value {
                            panic!(
+2 −2
Original line number Diff line number Diff line
@@ -70,8 +70,8 @@ impl FooData {
        Ok(Self { x })
    }
    fn write_to(&self, buffer: &mut BytesMut) {
        if self.x > 0xffffff {
            panic!("Invalid value for {}::{}: {} > {}", "Foo", "x", self.x, 0xffffff);
        if self.x > 0xff_ffff {
            panic!("Invalid value for {}::{}: {} > {}", "Foo", "x", self.x, 0xff_ffff);
        }
        buffer.put_uint(self.x as u64, 3);
    }
+2 −2
Original line number Diff line number Diff line
@@ -70,8 +70,8 @@ impl FooData {
        Ok(Self { x })
    }
    fn write_to(&self, buffer: &mut BytesMut) {
        if self.x > 0xffffff {
            panic!("Invalid value for {}::{}: {} > {}", "Foo", "x", self.x, 0xffffff);
        if self.x > 0xff_ffff {
            panic!("Invalid value for {}::{}: {} > {}", "Foo", "x", self.x, 0xff_ffff);
        }
        buffer.put_uint_le(self.x as u64, 3);
    }
Loading