Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 1f3cd3bc authored by Martin Geisler's avatar Martin Geisler Committed by Cherrypicker Worker
Browse files

pdl: Add support for payloads and constraints

This CL adds support for payload fields and child packets.

The data struct for each packet is now marked `pub`: this is because
it can be mentioned in the enums we build to hold child data.

Tag: #feature
Bug: 233340327
Bug: 233340326
Test: atest pdl_tests pdl_rust_generator_tests_{le,be}
(cherry picked from https://android-review.googlesource.com/q/commit:abc636fe8c921923e33498a2bde61545e208b653)
Merged-In: I8284f3a7f5b4a7e5c6ae12077a8466492d74cea0
Change-Id: I8284f3a7f5b4a7e5c6ae12077a8466492d74cea0
parent e5822130
Loading
Loading
Loading
Loading
+9 −1
Original line number Diff line number Diff line
@@ -82,6 +82,8 @@ rust_test_host {
        "tests/generated/packet_decl_array_unknown_element_width_dynamic_count_little_endian.rs",
        "tests/generated/packet_decl_array_unknown_element_width_dynamic_size_big_endian.rs",
        "tests/generated/packet_decl_array_unknown_element_width_dynamic_size_little_endian.rs",
        "tests/generated/packet_decl_child_packets_big_endian.rs",
        "tests/generated/packet_decl_child_packets_little_endian.rs",
        "tests/generated/packet_decl_complex_scalars_big_endian.rs",
        "tests/generated/packet_decl_complex_scalars_little_endian.rs",
        "tests/generated/packet_decl_empty_big_endian.rs",
@@ -94,6 +96,12 @@ rust_test_host {
        "tests/generated/packet_decl_mask_scalar_value_little_endian.rs",
        "tests/generated/packet_decl_mixed_scalars_enums_big_endian.rs",
        "tests/generated/packet_decl_mixed_scalars_enums_little_endian.rs",
        "tests/generated/packet_decl_payload_field_unknown_size_big_endian.rs",
        "tests/generated/packet_decl_payload_field_unknown_size_little_endian.rs",
        "tests/generated/packet_decl_payload_field_unknown_size_terminal_big_endian.rs",
        "tests/generated/packet_decl_payload_field_unknown_size_terminal_little_endian.rs",
        "tests/generated/packet_decl_payload_field_variable_size_big_endian.rs",
        "tests/generated/packet_decl_payload_field_variable_size_little_endian.rs",
        "tests/generated/packet_decl_reserved_field_big_endian.rs",
        "tests/generated/packet_decl_reserved_field_little_endian.rs",
        "tests/generated/packet_decl_simple_scalars_big_endian.rs",
+1 −1
Original line number Diff line number Diff line
@@ -26,7 +26,7 @@ optional = true

[dev-dependencies]
tempfile = "3.3.0"
bytes = "1.2.1"
bytes = { version = "1.2.1", features = ["serde"] }
num-derive = "0.3.3"
num-traits = "0.2.15"
thiserror = "1.0.37"
+10 −4
Original line number Diff line number Diff line
@@ -23,7 +23,7 @@ pub struct SourceLocation {
    pub column: usize,
}

#[derive(Debug, Default, Copy, Clone, PartialEq, Eq, Serialize)]
#[derive(Default, Copy, Clone, PartialEq, Eq, Serialize)]
pub struct SourceRange {
    pub file: FileId,
    pub start: SourceLocation,
@@ -64,7 +64,7 @@ pub struct Tag {
    pub value: usize,
}

#[derive(Debug, Serialize, Clone)]
#[derive(Debug, Serialize, Clone, PartialEq, Eq)]
#[serde(tag = "kind", rename = "constraint")]
pub struct Constraint {
    pub id: String,
@@ -73,7 +73,7 @@ pub struct Constraint {
    pub tag_id: Option<String>,
}

#[derive(Debug, Serialize, Clone)]
#[derive(Debug, Serialize, Clone, PartialEq, Eq)]
#[serde(tag = "kind")]
pub enum FieldDesc {
    #[serde(rename = "checksum_field")]
@@ -112,7 +112,7 @@ pub enum FieldDesc {
    Group { group_id: String, constraints: Vec<Constraint> },
}

#[derive(Debug, Serialize)]
#[derive(Debug, Serialize, PartialEq, Eq)]
pub struct Field<A: Annotation> {
    pub loc: SourceRange,
    #[serde(skip_serializing)]
@@ -216,6 +216,12 @@ impl fmt::Display for SourceRange {
    }
}

impl fmt::Debug for SourceRange {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("SourceRange").finish_non_exhaustive()
    }
}

impl ops::Add<SourceRange> for SourceRange {
    type Output = SourceRange;

+324 −37
Original line number Diff line number Diff line
@@ -57,9 +57,9 @@ fn generate_packet_size_getter(

        let decl = field.declaration(scope);
        dynamic_widths.push(match &field.desc {
            ast::FieldDesc::Payload { .. } | ast::FieldDesc::Body { .. } => {
                quote!(self.payload.len())
            }
            ast::FieldDesc::Payload { .. } | ast::FieldDesc::Body { .. } => quote! {
                self.child.get_total_size()
            },
            ast::FieldDesc::Typedef { id, .. } => {
                let id = format_ident!("{id}");
                quote!(self.#id.get_size())
@@ -112,6 +112,16 @@ fn generate_packet_size_getter(
    )
}

fn top_level_packet<'a>(scope: &lint::Scope<'a>, packet_name: &'a str) -> &'a parser_ast::Decl {
    let mut decl = scope.typedef[packet_name];
    while let ast::DeclDesc::Packet { parent_id: Some(parent_id), .. }
    | ast::DeclDesc::Struct { parent_id: Some(parent_id), .. } = &decl.desc
    {
        decl = scope.typedef[parent_id];
    }
    decl
}

/// Generate code for `ast::Decl::Packet` and `ast::Decl::Struct`
/// values.
fn generate_packet_decl(
@@ -122,13 +132,20 @@ fn generate_packet_decl(
    id: &str,
    _constraints: &[ast::Constraint],
    fields: &[parser_ast::Field],
    _parent_id: Option<&str>,
) -> proc_macro2::TokenStream {
    let packet_scope = &scope.scopes[&scope.typedef[id]];

    let top_level = top_level_packet(scope, id);
    let top_level_id = top_level.id().unwrap();
    let top_level_packet = format_ident!("{top_level_id}");
    let top_level_data = format_ident!("{top_level_id}Data");
    let top_level_id_lower = format_ident!("{}", top_level_id.to_lowercase());

    // TODO(mgeisler): use the convert_case crate to convert between
    // `FooBar` and `foo_bar` in the code below.
    let span = format_ident!("bytes");
    let serializer_span = format_ident!("buffer");
    let mut field_declarations = FieldDeclarations::new();
    let mut field_declarations = FieldDeclarations::new(scope, id);
    let mut field_parser = FieldParser::new(scope, endianness, id, &span);
    let mut field_serializer = FieldSerializer::new(scope, endianness, id, &serializer_span);
    for field in fields {
@@ -136,19 +153,214 @@ fn generate_packet_decl(
        field_parser.add(field);
        field_serializer.add(field);
    }
    field_declarations.done();
    field_parser.done();

    let id_lower = format_ident!("{}", id.to_lowercase());
    let id_packet = format_ident!("{id}");
    let id_child = format_ident!("{id}Child");
    let id_data = format_ident!("{id}Data");
    let id_data_child = format_ident!("{id}DataChild");
    let id_builder = format_ident!("{id}Builder");

    let fields_with_ids = fields.iter().filter(|f| f.id().is_some()).collect::<Vec<_>>();
    let field_names =
        fields_with_ids.iter().map(|f| format_ident!("{}", f.id().unwrap())).collect::<Vec<_>>();
    let field_types = fields_with_ids.iter().map(|f| types::rust_type(f)).collect::<Vec<_>>();
    let field_borrows = fields_with_ids.iter().map(|f| types::rust_borrow(f)).collect::<Vec<_>>();
    let getter_names = field_names.iter().map(|id| format_ident!("get_{id}"));

    let mut decl = scope.typedef[id];
    let mut parents = vec![decl];
    while let ast::DeclDesc::Packet { parent_id: Some(parent_id), .. }
    | ast::DeclDesc::Struct { parent_id: Some(parent_id), .. } = &decl.desc
    {
        decl = scope.typedef[parent_id];
        parents.push(decl);
    }
    parents.reverse();

    let parent_ids = parents.iter().map(|p| p.id().unwrap()).collect::<Vec<_>>();
    let parent_shifted_ids = parent_ids.iter().skip(1).map(|id| format_ident!("{id}"));
    let parent_lower_ids =
        parent_ids.iter().map(|id| format_ident!("{}", id.to_lowercase())).collect::<Vec<_>>();
    let parent_shifted_lower_ids = parent_lower_ids.iter().skip(1).collect::<Vec<_>>();
    let parent_packet = parent_ids.iter().map(|id| format_ident!("{id}"));
    let parent_data = parent_ids.iter().map(|id| format_ident!("{id}Data"));
    let parent_data_child = parent_ids.iter().map(|id| format_ident!("{id}DataChild"));

    let all_fields = {
        let mut fields = packet_scope.all_fields.values().collect::<Vec<_>>();
        fields.sort_by_key(|f| f.id());
        fields
    };
    let all_field_names =
        all_fields.iter().map(|f| format_ident!("{}", f.id().unwrap())).collect::<Vec<_>>();
    let all_field_types = all_fields.iter().map(|f| types::rust_type(f)).collect::<Vec<_>>();
    let all_field_borrows = all_fields.iter().map(|f| types::rust_borrow(f)).collect::<Vec<_>>();
    let all_field_getter_names = all_field_names.iter().map(|id| format_ident!("get_{id}"));
    let all_field_self_field = all_fields.iter().map(|f| {
        for (parent, parent_id) in parents.iter().zip(parent_lower_ids.iter()) {
            if scope.scopes[parent].fields.contains(f) {
                return quote!(self.#parent_id);
            }
        }
        unreachable!("Could not find {f:?} in parent chain");
    });

    let unconstrained_fields = all_fields
        .iter()
        .filter(|f| !packet_scope.all_constraints.contains_key(f.id().unwrap()))
        .collect::<Vec<_>>();
    let unconstrained_field_names = unconstrained_fields
        .iter()
        .map(|f| format_ident!("{}", f.id().unwrap()))
        .collect::<Vec<_>>();
    let unconstrained_field_types = unconstrained_fields.iter().map(|f| types::rust_type(f));

    let rev_parents = parents.iter().rev().collect::<Vec<_>>();
    let builder_assignments = rev_parents.iter().enumerate().map(|(idx, parent)| {
        let parent_id = parent.id().unwrap();
        let parent_id_lower = format_ident!("{}", parent_id.to_lowercase());
        let parent_data = format_ident!("{parent_id}Data");
        let parent_data_child = format_ident!("{parent_id}DataChild");
        let parent_packet_scope = &scope.scopes[&scope.typedef[parent_id]];

        let named_fields = {
            let mut names = parent_packet_scope.named.keys().collect::<Vec<_>>();
            names.sort();
            names
        };

        let mut field = named_fields.iter().map(|id| format_ident!("{id}")).collect::<Vec<_>>();
        let mut value = named_fields
            .iter()
            .map(|&id| match packet_scope.all_constraints.get(id) {
                Some(constraint) => {
                    let value = match constraint {
                        ast::Constraint { value: Some(value), .. } => {
                            let value = proc_macro2::Literal::usize_unsuffixed(*value);
                            quote!(#value)
                        }
                        ast::Constraint { tag_id: Some(tag_id), .. } => {
                            let type_id = match packet_scope.all_fields.get(id).map(|f| &f.desc) {
                                Some(ast::FieldDesc::Typedef { type_id, .. }) => {
                                    format_ident!("{type_id}")
                                }
                                _ => unreachable!("Invalid constraint: {constraint:?}"),
                            };
                            let tag_id = format_ident!("{tag_id}");
                            quote!(#type_id::#tag_id)
                        }
                        _ => unreachable!("Invalid constraint: {constraint:?}"),
                    };
                    quote!(#value)
                }
                None => {
                    let id = format_ident!("{id}");
                    quote!(self.#id)
                }
            })
            .collect::<Vec<_>>();

        if parent_packet_scope.payload.is_some() {
            field.push(format_ident!("child"));
            if idx == 0 {
                // Top-most parent, the child is simply created from
                // our payload.
                value.push(quote! {
                    match self.payload {
                        None => #parent_data_child::None,
                        Some(bytes) => #parent_data_child::Payload(bytes),
                    }
                });
            } else {
                // Child is created from the previous parent.
                let prev_parent_id = rev_parents[idx - 1].id().unwrap();
                let prev_parent_id_lower = format_ident!("{}", prev_parent_id.to_lowercase());
                let prev_parent_id = format_ident!("{prev_parent_id}");
                value.push(quote! {
                    #parent_data_child::#prev_parent_id(#prev_parent_id_lower)
                });
            }
        }

        quote! {
            let #parent_id_lower = Arc::new(#parent_data {
                #(#field: #value,)*
            });
        }
    });

    let children = scope.children.get(id).map(Vec::as_slice).unwrap_or_default();
    let has_payload = packet_scope.payload.is_some();
    let has_children_or_payload = !children.is_empty() || has_payload;
    let child =
        children.iter().map(|child| format_ident!("{}", child.id().unwrap())).collect::<Vec<_>>();
    let child_data = child.iter().map(|child| format_ident!("{child}Data")).collect::<Vec<_>>();
    let get_payload = (children.is_empty() && has_payload).then(|| {
        quote! {
            pub fn get_payload(&self) -> &[u8] {
                match &self.#id_lower.child {
                    #id_data_child::Payload(bytes) => &bytes,
                    #id_data_child::None => &[],
                }
            }
        }
    });
    let child_declaration = has_children_or_payload.then(|| {
        quote! {
            #[derive(Debug)]
            #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
            pub enum #id_data_child {
                #(#child(Arc<#child_data>),)*
                Payload(Bytes),
                None,
            }

            impl #id_data_child {
                fn get_total_size(&self) -> usize {
                    match self {
                        #(#id_data_child::#child(value) => value.get_total_size(),)*
                        #id_data_child::Payload(bytes) => bytes.len(),
                        #id_data_child::None => 0,
                    }
                }
            }

            #[derive(Debug)]
            #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
            pub enum #id_child {
                #(#child(#child),)*
                Payload(Bytes),
                None,
            }
        }
    });
    let child_field = has_children_or_payload.then(|| quote!(child));
    let builder_payload_field = has_children_or_payload.then(|| {
        quote! {
            pub payload: Option<Bytes>
        }
    });

    let ancestor_packets =
        parent_ids[..parent_ids.len() - 1].iter().map(|id| format_ident!("{id}"));
    let impl_from_and_try_from = (top_level_id != id).then(|| {
        quote! {
            #(
                impl From<#id_packet> for #ancestor_packets {
                    fn from(packet: #id_packet) -> #ancestor_packets {
                        #ancestor_packets::new(packet.#top_level_id_lower).unwrap()
                    }
                }
            )*

            impl TryFrom<#top_level_packet> for #id_packet {
                type Error = TryFromError;
                fn try_from(packet: #top_level_packet) -> std::result::Result<#id_packet, TryFromError> {
                    #id_packet::new(packet.#top_level_id_lower).map_err(TryFromError)
                }
            }
        }
    });

    let (constant_width, packet_size) = generate_packet_size_getter(scope, fields);
    let conforms = if constant_width == 0 {
@@ -159,23 +371,28 @@ fn generate_packet_decl(
    };

    quote! {
        #child_declaration

        #[derive(Debug)]
        #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
        struct #id_data {
        pub struct #id_data {
            #field_declarations
        }

        #[derive(Debug, Clone)]
        #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
        pub struct #id_packet {
            #(
                #[cfg_attr(feature = "serde", serde(flatten))]
            #id_lower: Arc<#id_data>,
                #parent_lower_ids: Arc<#parent_data>,
            )*
        }

        #[derive(Debug)]
        #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
        pub struct #id_builder {
            #(pub #field_names: #field_types),*
            #(pub #unconstrained_field_names: #unconstrained_field_types,)*
            #builder_payload_field
        }

        impl #id_data {
@@ -185,7 +402,10 @@ fn generate_packet_decl(

            fn parse(mut #span: &mut Cell<&[u8]>) -> Result<Self> {
                #field_parser
                Ok(Self { #(#field_names),* })
                Ok(Self {
                    #(#field_names,)*
                    #child_field
                })
            }

            fn write_to(&self, buffer: &mut BytesMut) {
@@ -203,8 +423,8 @@ fn generate_packet_decl(

        impl Packet for #id_packet {
            fn to_bytes(self) -> Bytes {
                let mut buffer = BytesMut::with_capacity(self.#id_lower.get_total_size());
                self.#id_lower.write_to(&mut buffer);
                let mut buffer = BytesMut::with_capacity(self.#top_level_id_lower.get_size());
                self.#top_level_id_lower.write_to(&mut buffer);
                buffer.freeze()
            }

@@ -225,6 +445,8 @@ fn generate_packet_decl(
            }
        }

        #impl_from_and_try_from

        impl #id_packet {
            pub fn parse(#span: &[u8]) -> Result<Self> {
                let mut cell = Cell::new(#span);
@@ -236,35 +458,49 @@ fn generate_packet_decl(
            }

            fn parse_inner(mut bytes: &mut Cell<&[u8]>) -> Result<Self> {
                let packet = #id_data::parse(&mut bytes)?;
                Ok(Self::new(Arc::new(packet)).unwrap())
            }
            fn new(root: Arc<#id_data>) -> std::result::Result<Self, &'static str> {
                let #id_lower = root;
                Ok(Self { #id_lower })
                let data = #top_level_data::parse(&mut bytes)?;
                Ok(Self::new(Arc::new(data)).unwrap())
            }
            fn new(#top_level_id_lower: Arc<#top_level_data>)
                   -> std::result::Result<Self, &'static str> {
                #(
                    let #parent_shifted_lower_ids = match &#parent_lower_ids.child {
                        #parent_data_child::#parent_shifted_ids(value) => value.clone(),
                        _ => return Err("Could not parse data, wrong child type"),
                    };
                )*
                Ok(Self { #(#parent_lower_ids),* })
            }

            #(pub fn #getter_names(&self) -> #field_borrows #field_types {
                #field_borrows self.#id_lower.as_ref().#field_names
            #(pub fn #all_field_getter_names(&self) -> #all_field_borrows #all_field_types {
                #all_field_borrows #all_field_self_field.as_ref().#all_field_names
            })*

            #get_payload

            fn write_to(&self, buffer: &mut BytesMut) {
                self.#id_lower.write_to(buffer)
            }

            pub fn get_size(&self) -> usize {
                self.#id_lower.get_size()
                self.#top_level_id_lower.get_size()
            }
        }

        impl #id_builder {
            pub fn build(self) -> #id_packet {
                let #id_lower = Arc::new(#id_data {
                    #(#field_names: self.#field_names),*
                });
                #id_packet::new(#id_lower).unwrap()
                #(#builder_assignments;)*
                #id_packet::new(#top_level_id_lower).unwrap()
            }
        }

        #(
            impl From<#id_builder> for #parent_packet {
                fn from(builder: #id_builder) -> #parent_packet {
                    builder.build().into()
                }
            }
        )*
    }
}

@@ -334,16 +570,10 @@ fn generate_decl(
    decl: &parser_ast::Decl,
) -> String {
    match &decl.desc {
        ast::DeclDesc::Packet { id, constraints, fields, parent_id, .. }
        | ast::DeclDesc::Struct { id, constraints, fields, parent_id, .. } => generate_packet_decl(
            scope,
            file.endianness.value,
            id,
            constraints,
            fields,
            parent_id.as_deref(),
        )
        .to_string(),
        ast::DeclDesc::Packet { id, constraints, fields, .. }
        | ast::DeclDesc::Struct { id, constraints, fields, .. } => {
            generate_packet_decl(scope, file.endianness.value, id, constraints, fields).to_string()
        }
        ast::DeclDesc::Enum { id, tags, .. } => generate_enum_decl(id, tags).to_string(),
        _ => todo!("unsupported Decl::{:?}", decl),
    }
@@ -608,4 +838,61 @@ mod tests {
          }
        "
    );

    test_pdl!(
        packet_decl_payload_field_variable_size,
        "
          packet Foo {
              a: 8,
              _size_(_payload_): 8,
              _payload_,
              b: 16,
          }
        "
    );

    test_pdl!(
        packet_decl_payload_field_unknown_size,
        "
          packet Foo {
              a: 24,
              _payload_,
          }
        "
    );

    test_pdl!(
        packet_decl_payload_field_unknown_size_terminal,
        "
          packet Foo {
              _payload_,
              a: 24,
          }
        "
    );

    test_pdl!(
        packet_decl_child_packets,
        "
          enum Enum16 : 16 {
            A = 1,
            B = 2,
          }

          packet Foo {
              a: 8,
              b: Enum16,
              _size_(_payload_): 8,
              _payload_
          }

          packet Bar : Foo (a = 100) {
              x: 8,
          }

          packet Baz : Foo (b = B) {
              y: 16,
          }
        "
    );
}
+18 −5
Original line number Diff line number Diff line
use crate::backends::rust::types;
use crate::lint;
use crate::parser::ast as parser_ast;
use quote::{format_ident, quote};

pub struct FieldDeclarations {
pub struct FieldDeclarations<'a> {
    scope: &'a lint::Scope<'a>,
    packet_name: &'a str,
    code: Vec<proc_macro2::TokenStream>,
}

impl FieldDeclarations {
    pub fn new() -> FieldDeclarations {
        FieldDeclarations { code: Vec::new() }
impl<'a> FieldDeclarations<'a> {
    pub fn new(scope: &'a lint::Scope<'a>, packet_name: &'a str) -> FieldDeclarations<'a> {
        FieldDeclarations { scope, packet_name, code: Vec::new() }
    }

    pub fn add(&mut self, field: &parser_ast::Field) {
@@ -22,9 +25,19 @@ impl FieldDeclarations {
            #id: #field_type,
        })
    }

    pub fn done(&mut self) {
        let packet_data_child = format_ident!("{}DataChild", self.packet_name);
        let packet_scope = &self.scope.scopes[&self.scope.typedef[self.packet_name]];
        if self.scope.children.contains_key(self.packet_name) || packet_scope.payload.is_some() {
            self.code.push(quote! {
                child: #packet_data_child,
            });
        }
    }
}

impl quote::ToTokens for FieldDeclarations {
impl quote::ToTokens for FieldDeclarations<'_> {
    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
        let code = &self.code;
        tokens.extend(quote! {
Loading