Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit ac989337 authored by Henri Chataing's avatar Henri Chataing
Browse files

pdl: Implement support for tag ranges in the rust generator

Bug: 267339120
Test: cargo test
Ignore-AOSP-First: API breaking change in PDL w/ UWB dependant
Change-Id: Ic79a818a082cedac38c0d103a7b1e1abd60914fb
parent 2d2c1a7c
Loading
Loading
Loading
Loading
+231 −41
Original line number Diff line number Diff line
@@ -4,6 +4,7 @@ use crate::{ast, lint};
use quote::{format_ident, quote};
use std::collections::BTreeSet;
use std::path::Path;
use syn::LitInt;

use crate::analyzer::ast as analyzer_ast;

@@ -680,65 +681,204 @@ fn generate_struct_decl(
    }
}

fn generate_enum_decl(id: &str, tags: &[ast::Tag]) -> proc_macro2::TokenStream {
    let name = format_ident!("{id}");
    let variants =
        tags.iter().map(|t| format_ident!("{}", t.id().to_upper_camel_case())).collect::<Vec<_>>();
    let values = tags
/// Generate an enum declaration.
///
/// # Arguments
/// * `id` - Enum identifier.
/// * `tags` - List of enum tags.
/// * `width` - Width of the backing type of the enum, in bits.
/// * `open` - Whether to generate an open or closed enum. Open enums have
///            an additional Unknown case for unmatched valued. Complete
///            enums (where the full range of values is covered) are
///            automatically closed.
fn generate_enum_decl(
    id: &str,
    tags: &[ast::Tag],
    width: usize,
    open: bool,
) -> proc_macro2::TokenStream {
    // Determine if the enum is complete, i.e. all values in the backing
    // integer range have a matching tag in the original declaration.
    fn enum_is_complete(tags: &[ast::Tag], max: usize) -> bool {
        let mut ranges = tags
            .iter()
        .map(|t| syn::parse_str::<syn::LitInt>(&format!("{:#x}", t.value().unwrap())).unwrap())
            .map(|tag| match tag {
                ast::Tag::Value(tag) => (tag.value, tag.value),
                ast::Tag::Range(tag) => tag.range.clone().into_inner(),
            })
            .collect::<Vec<_>>();
    let visitor_name = format_ident!("{id}Visitor");
        ranges.sort_unstable();
        ranges.first().unwrap().0 == 0
            && ranges.last().unwrap().1 == max
            && ranges.windows(2).all(|window| {
                if let [left, right] = window {
                    left.1 == right.0 - 1
                } else {
                    false
                }
            })
    }

    quote! {
        #[derive(FromPrimitive, ToPrimitive, Debug, Hash, Eq, PartialEq, Clone, Copy)]
        #[repr(u64)]
        pub enum #name {
            #(#variants = #values,)*
    // Determine if the enum is primitive, i.e. does not contain any
    // tag range.
    fn enum_is_primitive(tags: &[ast::Tag]) -> bool {
        tags.iter().all(|tag| matches!(tag, ast::Tag::Value(_)))
    }

        #[cfg(feature = "serde")]
        impl serde::Serialize for #name {
            fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
            where
                S: serde::Serializer,
            {
                serializer.serialize_u64(*self as u64)
    // Return the maximum value for the scalar type.
    fn scalar_max(width: usize) -> usize {
        if width >= usize::BITS as usize {
            usize::MAX
        } else {
            (1 << width) - 1
        }
    }

    // Format an enum tag identifier to rust upper caml case.
    fn format_tag_ident(id: &str) -> proc_macro2::TokenStream {
        let id = format_ident!("{}", id.to_upper_camel_case());
        quote! { #id }
    }

    // Format a constant value as hexadecimal constant.
    fn format_value(value: usize) -> LitInt {
        syn::parse_str::<syn::LitInt>(&format!("{:#x}", value)).unwrap()
    }

        #[cfg(feature = "serde")]
        struct #visitor_name;
    // Backing type for the enum.
    let backing_type = types::Integer::new(width);
    let backing_type_str = proc_macro2::Literal::string(&format!("u{}", backing_type.width));
    let range_max = scalar_max(width);
    let is_complete = enum_is_complete(tags, scalar_max(width));
    let is_primitive = enum_is_primitive(tags);
    let name = format_ident!("{id}");

        #[cfg(feature = "serde")]
        impl<'de> serde::de::Visitor<'de> for #visitor_name {
            type Value = #name;
    // Generate the variant cases for the enum declaration.
    // Tags declared in ranges are flattened in the same declaration.
    let use_variant_values = is_primitive && (is_complete || !open);
    let mut variants = vec![];
    for tag in tags.iter() {
        match tag {
            ast::Tag::Value(tag) if use_variant_values => {
                let id = format_tag_ident(&tag.id);
                let value = format_value(tag.value);
                variants.push(quote! { #id = #value })
            }
            ast::Tag::Value(tag) => variants.push(format_tag_ident(&tag.id)),
            ast::Tag::Range(tag) => {
                variants.extend(tag.tags.iter().map(|tag| format_tag_ident(&tag.id)));
                let id = format_tag_ident(&tag.id);
                variants.push(quote! { #id(Private<#backing_type>) })
            }
        }
    }

    // Generate the cases for parsing the enum value from an integer.
    let mut from_cases = vec![];
    for tag in tags.iter() {
        match tag {
            ast::Tag::Value(tag) => {
                let id = format_tag_ident(&tag.id);
                let value = format_value(tag.value);
                from_cases.push(quote! { #value => Ok(#name::#id) })
            }
            ast::Tag::Range(tag) => {
                from_cases.extend(tag.tags.iter().map(|tag| {
                    let id = format_tag_ident(&tag.id);
                    let value = format_value(tag.value);
                    quote! { #value => Ok(#name::#id) }
                }));
                let id = format_tag_ident(&tag.id);
                let start = format_value(*tag.range.start());
                let end = format_value(*tag.range.end());
                from_cases.push(quote! { #start ..= #end => Ok(#name::#id(Private(value))) })
            }
        }
    }

    // Generate the cases for serializing the enum value to an integer.
    let mut into_cases = vec![];
    for tag in tags.iter() {
        match tag {
            ast::Tag::Value(tag) => {
                let id = format_tag_ident(&tag.id);
                let value = format_value(tag.value);
                into_cases.push(quote! { #name::#id => #value })
            }
            ast::Tag::Range(tag) => {
                into_cases.extend(tag.tags.iter().map(|tag| {
                    let id = format_tag_ident(&tag.id);
                    let value = format_value(tag.value);
                    quote! { #name::#id => #value }
                }));
                let id = format_tag_ident(&tag.id);
                into_cases.push(quote! { #name::#id(Private(value)) => *value })
            }
        }
    }

    // Generate a default case if the enum is open and incomplete.
    if !is_complete && open {
        variants.push(quote! { Unknown(Private<#backing_type>) });
        from_cases.push(quote! { 0..#range_max => Ok(#name::Unknown(Private(value))) });
        into_cases.push(quote! { #name::Unknown(Private(value)) => *value });
    }

    // Generate an error case if the enum size is lower than the backing
    // type size, or if the enum is closed or incomplete.
    if backing_type.width != width || (!is_complete && !open) {
        from_cases.push(quote! { _ => Err(value) });
    }

    // Derive other Into<uN> and Into<iN> implementations from the explicit
    // implementation, where the type is larger than the backing type.
    let derived_signed_into_types = [8, 16, 32, 64]
        .into_iter()
        .filter(|w| *w > width)
        .map(|w| syn::parse_str::<syn::Type>(&format!("i{}", w)).unwrap());
    let derived_unsigned_into_types = [8, 16, 32, 64]
        .into_iter()
        .filter(|w| *w >= width && *w != backing_type.width)
        .map(|w| syn::parse_str::<syn::Type>(&format!("u{}", w)).unwrap());
    let derived_into_types = derived_signed_into_types.chain(derived_unsigned_into_types);

            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
                formatter.write_str("a valid discriminant")
    quote! {
        #[derive(Debug, Clone, Copy, Hash, Eq, PartialEq)]
        #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
        #[cfg_attr(feature = "serde", serde(try_from = #backing_type_str, into = #backing_type_str))]
        pub enum #name {
            #(#variants,)*
        }

            fn visit_u64<E>(self, value: u64) -> std::result::Result<Self::Value, E>
            where
                E: serde::de::Error,
            {
        impl TryFrom<#backing_type> for #name {
            type Error = #backing_type;
            fn try_from(value: #backing_type) -> std::result::Result<Self, Self::Error> {
                match value {
                    #(#values => Ok(#name::#variants),)*
                    _ => Err(E::custom(format!("invalid discriminant: {value}"))),
                    #(#from_cases,)*
                }
            }
        }

        #[cfg(feature = "serde")]
        impl<'de> serde::Deserialize<'de> for #name {
            fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
            where
                D: serde::Deserializer<'de>,
            {
                deserializer.deserialize_u64(#visitor_name)
        impl From<&#name> for #backing_type {
            fn from(value: &#name) -> Self {
                match value {
                    #(#into_cases,)*
                }
            }
        }

        impl From<#name> for #backing_type {
            fn from(value: #name) -> Self {
                (&value).into()
            }
        }

        #(impl From<#name> for #derived_into_types {
            fn from(value: #name) -> Self {
                #backing_type::from(value) as Self
            }
        })*
    }
}

fn generate_decl(
@@ -758,7 +898,9 @@ fn generate_decl(
            // implement the recursive (de)serialization.
            generate_struct_decl(scope, file.endianness.value, id).to_string()
        }
        ast::DeclDesc::Enum { id, tags, .. } => generate_enum_decl(id, tags).to_string(),
        ast::DeclDesc::Enum { id, tags, width } => {
            generate_enum_decl(id, tags, *width, false).to_string()
        }
        _ => todo!("unsupported Decl::{:?}", decl),
    }
}
@@ -896,6 +1038,54 @@ mod tests {
    test_pdl!(packet_decl_24bit_scalar, "packet Foo { x: 24 }");
    test_pdl!(packet_decl_64bit_scalar, "packet Foo { x: 64 }");

    test_pdl!(
        enum_declaration,
        r#"
        // Should generate unknown case.
        enum IncompleteTruncated : 3 {
            A = 0,
            B = 1,
        }

        // Should generate unknown case.
        enum IncompleteTruncatedWithRange : 3 {
            A = 0,
            B = 1..6 {
                X = 1,
                Y = 2,
            }
        }

        // Should generate unreachable case.
        enum CompleteTruncated : 3 {
            A = 0,
            B = 1,
            C = 2,
            D = 3,
            E = 4,
            F = 5,
            G = 6,
            H = 7,
        }

        // Should generate unreachable case.
        enum CompleteTruncatedWithRange : 3 {
            A = 0,
            B = 1..7 {
                X = 1,
                Y = 2,
            }
        }

        // Should generate no unknown or unreachable case.
        enum CompleteWithRange : 8 {
            A = 0,
            B = 1,
            C = 2..255,
        }
        "#
    );

    test_pdl!(
        packet_decl_simple_scalars,
        r#"
+6 −11
Original line number Diff line number Diff line
@@ -133,9 +133,9 @@ impl<'a> FieldParser<'a> {
                    let enum_id = format_ident!("{enum_id}");
                    let tag_id = format_ident!("{}", tag_id.to_upper_camel_case());
                    quote! {
                        if #v != #enum_id::#tag_id as #value_type {
                        if #v != #enum_id::#tag_id.into()  {
                            return Err(Error::InvalidFixedValue {
                                expected: #enum_id::#tag_id as u64,
                                expected: #value_type::from(#enum_id::#tag_id) as u64,
                                actual: #v as u64,
                            });
                        }
@@ -153,15 +153,12 @@ impl<'a> FieldParser<'a> {
                    }
                }
                ast::FieldDesc::Typedef { id, type_id } => {
                    let id = format_ident!("{id}");
                    let type_id = format_ident!("{type_id}");
                    let from_u = format_ident!("from_u{}", value_type.width);
                    // TODO(mgeisler): Remove the `unwrap` from the
                    // generated code and return the error to the
                    // caller.
                    quote! {
                        let #id = #type_id::#from_u(#v).unwrap();
                    }
                    let id = format_ident!("{id}");
                    let type_id = format_ident!("{type_id}");
                    quote! { let #id = #type_id::try_from(#v).unwrap(); }
                }
                ast::FieldDesc::Reserved { .. } => {
                    if single_value {
@@ -545,13 +542,11 @@ impl<'a> FieldParser<'a> {
        }

        if let Some(ast::DeclDesc::Enum { id, width, .. }) = decl.map(|decl| &decl.desc) {
            let element_type = types::Integer::new(*width);
            let get_uint = types::get_uint(self.endianness, *width, span);
            let type_id = format_ident!("{id}");
            let from_u = format_ident!("from_u{}", element_type.width);
            let packet_name = &self.packet_name;
            return quote! {
                #type_id::#from_u(#get_uint).ok_or_else(|| Error::InvalidEnumValueError {
                #type_id::try_from(#get_uint).map_err(|_| Error::InvalidEnumValueError {
                    obj: #packet_name.to_string(),
                    field: String::new(), // TODO(mgeisler): fill out or remove
                    value: 0,
+16 −2
Original line number Diff line number Diff line
@@ -24,8 +24,6 @@ pub fn generate(path: &Path) -> String {

    code.push_str(&quote_block! {
        use bytes::{Buf, BufMut, Bytes, BytesMut};
        use num_derive::{FromPrimitive, ToPrimitive};
        use num_traits::{FromPrimitive, ToPrimitive};
        use std::convert::{TryFrom, TryInto};
        use std::cell::Cell;
        use std::fmt;
@@ -37,6 +35,22 @@ pub fn generate(path: &Path) -> String {
        type Result<T> = std::result::Result<T, Error>;
    });

    code.push_str(&quote_block! {
        /// Private prevents users from creating arbitrary scalar values
        /// in situations where the value needs to be validated.
        /// Users can freely deref the value, but only the backend
        /// may create it.
        #[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, PartialOrd, Ord)]
        pub struct Private<T>(T);

        impl<T> std::ops::Deref for Private<T> {
            type Target = T;
            fn deref(&self) -> &Self::Target {
                &self.0
            }
        }
    });

    code.push_str(&quote_block! {
        #[derive(Debug, Error)]
        pub enum Error {
+8 −8
Original line number Diff line number Diff line
@@ -80,7 +80,11 @@ impl<'a> FieldSerializer<'a> {
                let field_type = types::Integer::new(width);
                let enum_id = format_ident!("{enum_id}");
                let tag_id = format_ident!("{}", tag_id.to_upper_camel_case());
                self.chunk.push(BitField { value: quote!(#enum_id::#tag_id), field_type, shift });
                self.chunk.push(BitField {
                    value: quote!(#field_type::from(#enum_id::#tag_id)),
                    field_type,
                    shift,
                });
            }
            ast::FieldDesc::FixedScalar { value, .. } => {
                let field_type = types::Integer::new(width);
@@ -90,11 +94,8 @@ impl<'a> FieldSerializer<'a> {
            ast::FieldDesc::Typedef { id, .. } => {
                let field_name = format_ident!("{id}");
                let field_type = types::Integer::new(width);
                let to_u = format_ident!("to_u{}", field_type.width);
                // TODO(mgeisler): remove `unwrap` and return error to
                // caller in generated code.
                self.chunk.push(BitField {
                    value: quote!(self.#field_name.#to_u().unwrap()),
                    value: quote!(#field_type::from(self.#field_name)),
                    field_type,
                    shift,
                });
@@ -254,11 +255,10 @@ impl<'a> FieldSerializer<'a> {
            }
            None => {
                if let Some(ast::DeclDesc::Enum { width, .. }) = decl.map(|decl| &decl.desc) {
                    let field_type = types::Integer::new(*width);
                    let to_u = format_ident!("to_u{}", field_type.width);
                    let element_type = types::Integer::new(*width);
                    types::put_uint(
                        self.endianness,
                        &quote!(elem.#to_u().unwrap()),
                        &quote!(#element_type::from(elem)),
                        *width,
                        self.span,
                    )
+0 −1
Original line number Diff line number Diff line
@@ -120,7 +120,6 @@ fn generate_unit_tests(input: &str, packet_names: &[&str], module_name: &str) {
        "{}",
        &quote! {
            use #module::Packet;
            use num_traits::{FromPrimitive, ToPrimitive};
            use serde_json::json;

            #(#tests)*
Loading