Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit cae12168 authored by Martin Geisler's avatar Martin Geisler
Browse files

pdl: Test canonical vectors using Serde

Being able to serialize data with Serde allows us to test enums using
the canonical test vectors: we can take the output of PDL and
serialize it to JSON, which we then compare with the canonical test
vectors.

While this was motivated by the need for testing, Serde support can be
useful in general so I think this is an overall nice change. The
support is behind an off-by-default feature flag, just in case people
don’t want the generated code.

Test: atest pdl_tests pdl_rust_generator_tests_{le,be}
Change-Id: I6a19a939df7b449d350dbcda8a7daa2cb96733a2
parent d9892bbd
Loading
Loading
Loading
Loading
+8 −1
Original line number Diff line number Diff line
@@ -104,13 +104,18 @@ genrule {

rust_defaults {
    name: "pdl_backend_defaults",
    features: ["serde"],
    rustlibs: [
        "libbytes",
        "libnum_traits",
        "libserde",
        "libtempfile",
        "libthiserror",
    ],
    proc_macros: ["libnum_derive"],
    proc_macros: [
        "libnum_derive",
        "libserde_derive",
    ],
}

rust_library_host {
@@ -175,6 +180,7 @@ rust_test_host {
    rustlibs: [
        "libnum_traits",
        "libpdl_le_backend",
        "libserde_json",
    ],
}

@@ -185,6 +191,7 @@ rust_test_host {
    rustlibs: [
        "libnum_traits",
        "libpdl_be_backend",
        "libserde_json",
    ],
}

+8 −1
Original line number Diff line number Diff line
@@ -6,17 +6,24 @@ default-run = "pdl"

[workspace]

[features]
default = ["serde"]

[dependencies]
codespan-reporting = "0.11.1"
pest = "2.4.0"
pest_derive = "2.4.0"
proc-macro2 = "1.0.46"
quote = "1.0.21"
serde = { version = "1.0.145", features = ["default", "derive", "serde_derive", "std"] }
serde_json = "1.0.86"
clap = { version = "3.2.22", features = ["derive"] }
syn = "1.0.102"

[dependencies.serde]
version = "1.0.145"
features = ["default", "derive", "serde_derive", "std", "rc"]
optional = true

[dev-dependencies]
tempfile = "3.3.0"
bytes = "1.2.1"
+55 −10
Original line number Diff line number Diff line
@@ -86,16 +86,20 @@ fn generate_packet_decl(

    quote! {
        #[derive(Debug)]
        #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
        struct #id_data {
            #field_declarations
        }

        #[derive(Debug, Clone)]
        #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
        pub struct #id_packet {
            #[cfg_attr(feature = "serde", serde(flatten))]
            #id_lower: Arc<#id_data>,
        }

        #[derive(Debug)]
        #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
        pub struct #id_builder {
            #(pub #field_names: #field_types),*
        }
@@ -173,20 +177,61 @@ fn generate_packet_decl(
}

fn generate_enum_decl(id: &str, tags: &[ast::Tag]) -> proc_macro2::TokenStream {
    let variants = tags.iter().map(|t| {
        let variant = format_ident!("{}", t.id);
        let value = syn::parse_str::<syn::LitInt>(&format!("{:#x}", t.value)).unwrap();
        quote! {
            #variant = #value
        }
    });
    let name = format_ident!("{id}");
    let variants = tags.iter().map(|t| format_ident!("{}", t.id)).collect::<Vec<_>>();
    let values = tags
        .iter()
        .map(|t| syn::parse_str::<syn::LitInt>(&format!("{:#x}", t.value)).unwrap())
        .collect::<Vec<_>>();
    let visitor_name = format_ident!("{id}Visitor");

    let name = format_ident!("{}", id);
    quote! {
        #[derive(FromPrimitive, ToPrimitive, Debug, Hash, Eq, PartialEq, Clone, Copy)]
        #[repr(u64)]
        pub enum #name {
            #(#variants),*
            #(#variants = #values,)*
        }

        #[cfg(feature = "serde")]
        impl serde::Serialize for #name {
            fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
            where
                S: serde::Serializer,
            {
                serializer.serialize_u64(*self as u64)
            }
        }

        #[cfg(feature = "serde")]
        struct #visitor_name;

        #[cfg(feature = "serde")]
        impl<'de> serde::de::Visitor<'de> for #visitor_name {
            type Value = #name;

            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
                formatter.write_str("a valid discriminant")
            }

            fn visit_u64<E>(self, value: u64) -> std::result::Result<Self::Value, E>
            where
                E: serde::de::Error,
            {
                match value {
                    #(#values => Ok(#name::#variants),)*
                    _ => Err(E::custom(format!("invalid discriminant: {value}"))),
                }
            }
        }

        #[cfg(feature = "serde")]
        impl<'de> serde::Deserialize<'de> for #name {
            fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
            where
                D: serde::Deserializer<'de>,
            {
                deserializer.deserialize_u64(#visitor_name)
            }
        }
    }
}
+22 −28
Original line number Diff line number Diff line
//! Generate Rust unit tests for canonical test vectors.

use quote::{format_ident, quote};
use serde::Deserialize;
use serde::{Deserialize, Serialize};
use serde_json::Value;

#[derive(Debug, Deserialize)]
@@ -18,10 +18,10 @@ struct TestVector {
    packet: Option<String>,
}

// Convert a string of hexadecimal characters into a Rust vector of
// bytes.
//
// The string `"80038302"` becomes `vec![0x80, 0x03, 0x83, 0x02]`.
/// Convert a string of hexadecimal characters into a Rust vector of
/// bytes.
///
/// The string `"80038302"` becomes `vec![0x80, 0x03, 0x83, 0x02]`.
fn hexadecimal_to_vec(hex: &str) -> proc_macro2::TokenStream {
    assert!(hex.len() % 2 == 0, "Expects an even number of hex digits");
    let bytes = hex.as_bytes().chunks_exact(2).map(|chunk| {
@@ -34,6 +34,16 @@ fn hexadecimal_to_vec(hex: &str) -> proc_macro2::TokenStream {
    }
}

/// Convert `value` to a JSON string literal.
///
/// The string literal is a raw literal to avoid escaping
/// double-quotes.
fn to_json<T: Serialize>(value: &T) -> syn::LitStr {
    let json = serde_json::to_string(value).unwrap();
    assert!(!json.contains("\"#"), "Please increase number of # for {json:?}");
    syn::parse_str::<syn::LitStr>(&format!("r#\" {json} \"#")).unwrap()
}

fn generate_unit_tests(input: &str, packet_names: &[&str], module_name: &str) {
    eprintln!("Reading test vectors from {input}, will use {} packets", packet_names.len());

@@ -71,30 +81,15 @@ fn generate_unit_tests(input: &str, packet_names: &[&str], module_name: &str) {
            });
            let assertions = object.iter().map(|(key, value)| {
                let getter = format_ident!("get_{key}");
                let value_u64 = value
                    .as_u64()
                    .unwrap_or_else(|| panic!("Expected u64 for {key:?} key, got {value}"));
                let value = proc_macro2::Literal::u64_unsuffixed(value_u64);
                // We lack type information, but ToPrimitive allows us
                // to convert both integers and enums to u64.
                quote! {
                    assert_eq!(actual.#getter().to_u64().unwrap(), #value);
                }
            });

            let builder_fields = object.iter().map(|(key, value)| {
                let field = format_ident!("{key}");
                let value_u64 = value
                    .as_u64()
                    .unwrap_or_else(|| panic!("Expected u64 for {key:?} key, got {value}"));
                let value = proc_macro2::Literal::u64_unsuffixed(value_u64);
                // We lack type information, but FromPrimitive allows
                // us to convert both integers and enums to u64.
                let expected = format_ident!("expected_{key}");
                let json = to_json(&value);
                quote! {
                    #field: FromPrimitive::from_u64(#value).unwrap()
                    let #expected: serde_json::Value = serde_json::from_str(#json).unwrap();
                    assert_eq!(json!(actual.#getter()), #expected);
                }
            });

            let json = to_json(&object);
            tests.push(quote! {
                #[test]
                fn #parse_test_name() {
@@ -105,9 +100,7 @@ fn generate_unit_tests(input: &str, packet_names: &[&str], module_name: &str) {

                #[test]
                fn #serialize_test_name() {
                    let builder =  #module::#builder_name {
                        #(#builder_fields,)*
                    };
                    let builder: #module::#builder_name = serde_json::from_str(#json).unwrap();
                    let packet = builder.build();
                    let packed = #packed;
                    assert_eq!(packet.to_vec(), packed);
@@ -124,6 +117,7 @@ fn generate_unit_tests(input: &str, packet_names: &[&str], module_name: &str) {
        &quote! {
            use #module::Packet;
            use num_traits::{FromPrimitive, ToPrimitive};
            use serde_json::json;

            #(#tests)*
        }
+41 −0
Original line number Diff line number Diff line
@@ -41,16 +41,57 @@ pub enum Foo {
    A = 0x1,
    B = 0x2,
}
#[cfg(feature = "serde")]
impl serde::Serialize for Foo {
    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        serializer.serialize_u64(*self as u64)
    }
}
#[cfg(feature = "serde")]
struct FooVisitor;
#[cfg(feature = "serde")]
impl<'de> serde::de::Visitor<'de> for FooVisitor {
    type Value = Foo;
    fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
        formatter.write_str("a valid discriminant")
    }
    fn visit_u64<E>(self, value: u64) -> std::result::Result<Self::Value, E>
    where
        E: serde::de::Error,
    {
        match value {
            0x1 => Ok(Foo::A),
            0x2 => Ok(Foo::B),
            _ => Err(E::custom(format!("invalid discriminant: {value}"))),
        }
    }
}
#[cfg(feature = "serde")]
impl<'de> serde::Deserialize<'de> for Foo {
    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        deserializer.deserialize_u64(FooVisitor)
    }
}

#[derive(Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
struct BarData {
    x: Foo,
}
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct BarPacket {
    #[cfg_attr(feature = "serde", serde(flatten))]
    bar: Arc<BarData>,
}
#[derive(Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct BarBuilder {
    pub x: Foo,
}
Loading