Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 0b6a2fa6 authored by Martin Geisler's avatar Martin Geisler Committed by Automerger Merge Worker
Browse files

Merge "pdl: correctly apply mask in ‘write_to’ as well" am: 8c96266d

parents f3f6cfd4 8c96266d
Loading
Loading
Loading
Loading
+39 −12
Original line number Diff line number Diff line
@@ -122,6 +122,26 @@ fn generate_field_getter(packet_name: &syn::Ident, field: &ast::Field) -> proc_m
    }
}

/// Mask and rebind the field value (if necessary).
fn mask_field_value(field: &ast::Field) -> Option<proc_macro2::TokenStream> {
    match field {
        ast::Field::Scalar { id, width, .. } => {
            let field_name = format_ident!("{id}");
            let type_width = round_bit_width(*width);
            if *width != type_width {
                let mask =
                    syn::parse_str::<syn::LitInt>(&format!("{:#x}", (1u64 << *width) - 1)).unwrap();
                Some(quote! {
                    let #field_name = #field_name & #mask;
                })
            } else {
                None
            }
        }
        _ => todo!("unsupported field: {:?}", field),
    }
}

fn generate_field_parser(
    endianness_value: &ast::EndiannessValue,
    packet_name: &str,
@@ -150,13 +170,7 @@ fn generate_field_parser(

            let wanted_len = syn::Index::from(offset + width / 8);
            let indices = (offset..offset + width / 8).map(syn::Index::from);
            let mask = if *width != type_width {
                Some(quote! {
                    let #field_name = #field_name & 0xfff;
                })
            } else {
                None
            };
            let mask = mask_field_value(field);

            quote! {
                // TODO(mgeisler): call a function instead to avoid
@@ -187,16 +201,17 @@ fn generate_field_writer(
    match field {
        ast::Field::Scalar { id, width, .. } => {
            let field_name = format_ident!("{id}");
            let bit_width = round_bit_width(*width);
            let start = syn::Index::from(offset);
            let end = syn::Index::from(offset + bit_width / 8);
            let byte_width = syn::Index::from(bit_width / 8);
            let end = syn::Index::from(offset + width / 8);
            let byte_width = syn::Index::from(width / 8);
            let mask = mask_field_value(field);
            let writer = match file.endianness.value {
                ast::EndiannessValue::BigEndian => format_ident!("to_be_bytes"),
                ast::EndiannessValue::LittleEndian => format_ident!("to_le_bytes"),
            };
            quote! {
                let #field_name = self.#field_name;
                #mask
                buffer[#start..#end].copy_from_slice(&#field_name.#writer()[0..#byte_width]);
            }
        }
@@ -535,6 +550,7 @@ mod tests {
              packet Foo {
                x: 8,
                y: 16,
                z: 24,
              }
            "#,
        );
@@ -557,6 +573,7 @@ mod tests {
              packet Foo {
                x: 8,
                y: 16,
                z: 24,
              }
            "#,
        );
@@ -591,6 +608,16 @@ mod tests {
        );
    }

    #[test]
    fn test_mask_field_value() {
        let loc = ast::SourceRange::default();
        let field = ast::Field::Scalar { loc: loc.clone(), id: String::from("a"), width: 8 };
        assert_eq!(mask_field_value(&field).map(|m| m.to_string()), None);

        let field = ast::Field::Scalar { loc, id: String::from("a"), width: 24 };
        assert_expr_eq(mask_field_value(&field).unwrap(), quote! { let a = a & 0xffffff; });
    }

    #[test]
    fn test_generate_field_parser_no_padding() {
        let loc = ast::SourceRange::default();
@@ -629,7 +656,7 @@ mod tests {
                    });
                }
                let a = u32::from_le_bytes([bytes[10], bytes[11], bytes[12], 0]);
                let a = a & 0xfff;
                let a = a & 0xffffff;
            },
        );
    }
@@ -651,7 +678,7 @@ mod tests {
                    });
                }
                let a = u32::from_be_bytes([0, bytes[10], bytes[11], bytes[12]]);
                let a = a & 0xfff;
                let a = a & 0xffffff;
            },
        );
    }
+22 −4
Original line number Diff line number Diff line
@@ -2,6 +2,7 @@
struct FooData {
    x: u8,
    y: u16,
    z: u32,
}

#[derive(Debug, Clone)]
@@ -13,11 +14,12 @@ pub struct FooPacket {
pub struct FooBuilder {
    pub x: u8,
    pub y: u16,
    pub z: u32,
}

impl FooData {
    fn conforms(bytes: &[u8]) -> bool {
        if bytes.len() < 3 {
        if bytes.len() < 6 {
            return false;
        }
        true
@@ -41,20 +43,33 @@ impl FooData {
            });
        }
        let y = u16::from_be_bytes([bytes[1], bytes[2]]);
        Ok(Self { x, y })
        if bytes.len() < 6 {
            return Err(Error::InvalidLengthError {
                obj: "Foo".to_string(),
                field: "z".to_string(),
                wanted: 6,
                got: bytes.len(),
            });
        }
        let z = u32::from_be_bytes([0, bytes[3], bytes[4], bytes[5]]);
        let z = z & 0xffffff;
        Ok(Self { x, y, z })
    }
    fn write_to(&self, buffer: &mut BytesMut) {
        let x = self.x;
        buffer[0..1].copy_from_slice(&x.to_be_bytes()[0..1]);
        let y = self.y;
        buffer[1..3].copy_from_slice(&y.to_be_bytes()[0..2]);
        let z = self.z;
        let z = z & 0xffffff;
        buffer[3..6].copy_from_slice(&z.to_be_bytes()[0..3]);
    }
    fn get_total_size(&self) -> usize {
        self.get_size()
    }
    fn get_size(&self) -> usize {
        let ret = 0;
        let ret = ret + 3;
        let ret = ret + 6;
        ret
    }
}
@@ -95,11 +110,14 @@ impl FooPacket {
    pub fn get_y(&self) -> u16 {
        self.foo.as_ref().y
    }
    pub fn get_z(&self) -> u32 {
        self.foo.as_ref().z
    }
}

impl FooBuilder {
    pub fn build(self) -> FooPacket {
        let foo = Arc::new(FooData { x: self.x, y: self.y });
        let foo = Arc::new(FooData { x: self.x, y: self.y, z: self.z });
        FooPacket::new(foo).unwrap()
    }
}
+22 −4
Original line number Diff line number Diff line
@@ -2,6 +2,7 @@
struct FooData {
    x: u8,
    y: u16,
    z: u32,
}

#[derive(Debug, Clone)]
@@ -13,11 +14,12 @@ pub struct FooPacket {
pub struct FooBuilder {
    pub x: u8,
    pub y: u16,
    pub z: u32,
}

impl FooData {
    fn conforms(bytes: &[u8]) -> bool {
        if bytes.len() < 3 {
        if bytes.len() < 6 {
            return false;
        }
        true
@@ -41,20 +43,33 @@ impl FooData {
            });
        }
        let y = u16::from_le_bytes([bytes[1], bytes[2]]);
        Ok(Self { x, y })
        if bytes.len() < 6 {
            return Err(Error::InvalidLengthError {
                obj: "Foo".to_string(),
                field: "z".to_string(),
                wanted: 6,
                got: bytes.len(),
            });
        }
        let z = u32::from_le_bytes([bytes[3], bytes[4], bytes[5], 0]);
        let z = z & 0xffffff;
        Ok(Self { x, y, z })
    }
    fn write_to(&self, buffer: &mut BytesMut) {
        let x = self.x;
        buffer[0..1].copy_from_slice(&x.to_le_bytes()[0..1]);
        let y = self.y;
        buffer[1..3].copy_from_slice(&y.to_le_bytes()[0..2]);
        let z = self.z;
        let z = z & 0xffffff;
        buffer[3..6].copy_from_slice(&z.to_le_bytes()[0..3]);
    }
    fn get_total_size(&self) -> usize {
        self.get_size()
    }
    fn get_size(&self) -> usize {
        let ret = 0;
        let ret = ret + 3;
        let ret = ret + 6;
        ret
    }
}
@@ -95,11 +110,14 @@ impl FooPacket {
    pub fn get_y(&self) -> u16 {
        self.foo.as_ref().y
    }
    pub fn get_z(&self) -> u32 {
        self.foo.as_ref().z
    }
}

impl FooBuilder {
    pub fn build(self) -> FooPacket {
        let foo = Arc::new(FooData { x: self.x, y: self.y });
        let foo = Arc::new(FooData { x: self.x, y: self.y, z: self.z });
        FooPacket::new(foo).unwrap()
    }
}