Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit d1f1e5f2 authored by Martin Geisler's avatar Martin Geisler
Browse files

pdl: let lint phase do all error handling

All errors in the code generator are internal bugs: the lint phase
should catch all illegal PDL constructs. This means that we can safely
crash when generating code (and then add more lints as needed to
prevent these crashes).

Bug: 230475552
Test: atest 'pdl_*'
Change-Id: If686acbb0f5edce0bae17611c4b1aa39dc27eac5
parent ce0be425
Loading
Loading
Loading
Loading
+0 −1
Original line number Diff line number Diff line
@@ -12,7 +12,6 @@ rust_defaults {
    srcs: ["src/main.rs"],
    // LINT.IfChange
    rustlibs: [
        "libanyhow",
        "libcodespan_reporting",
        "libpest",
        "libproc_macro2",
+0 −1
Original line number Diff line number Diff line
@@ -6,7 +6,6 @@ edition = "2021"
[workspace]

[dependencies]
anyhow = "*"
codespan-reporting = "*"
pest = "*"
pest_derive = "*"
+54 −82
Original line number Diff line number Diff line
use crate::ast;
use anyhow::{anyhow, bail, Context, Result};
use quote::{format_ident, quote};
use std::collections::HashMap;
use std::fmt::Write;
@@ -17,12 +16,9 @@ macro_rules! quote_block {
}

/// Generate the file preamble.
fn generate_preamble(path: &Path) -> Result<String> {
fn generate_preamble(path: &Path) -> String {
    let mut code = String::new();
    let filename = path
        .file_name()
        .and_then(|path| path.to_str())
        .ok_or_else(|| anyhow!("could not find filename in {:?}", path))?;
    let filename = path.file_name().unwrap().to_str().expect("non UTF-8 filename");
    let _ = write!(code, "// @generated rust packets from {filename}\n\n");

    let _ = write!(
@@ -88,58 +84,52 @@ fn generate_preamble(path: &Path) -> Result<String> {
        }
    );

    Ok(code)
    code
}

/// Round up the bit width to a Rust integer size.
fn round_bit_width(width: usize) -> Result<usize> {
fn round_bit_width(width: usize) -> usize {
    match width {
        8 => Ok(8),
        16 => Ok(16),
        24 | 32 => Ok(32),
        40 | 48 | 56 | 64 => Ok(64),
        _ => bail!("unsupported field width: {width}"),
        8 => 8,
        16 => 16,
        24 | 32 => 32,
        40 | 48 | 56 | 64 => 64,
        _ => todo!("unsupported field width: {width}"),
    }
}

/// Generate a Rust unsigned integer type large enough to hold
/// integers of the given bit width.
fn type_for_width(width: usize) -> Result<syn::Type> {
    let rounded_width = round_bit_width(width)?;
    syn::parse_str(&format!("u{rounded_width}")).map_err(anyhow::Error::from)
fn type_for_width(width: usize) -> syn::Type {
    let rounded_width = round_bit_width(width);
    syn::parse_str(&format!("u{rounded_width}")).unwrap()
}

fn generate_field(
    field: &ast::Field,
    visibility: syn::Visibility,
) -> Result<proc_macro2::TokenStream> {
fn generate_field(field: &ast::Field, visibility: syn::Visibility) -> proc_macro2::TokenStream {
    match field {
        ast::Field::Scalar { id, width, .. } => {
            let field_name = format_ident!("{id}");
            let field_type = type_for_width(*width)?;
            Ok(quote! {
            let field_type = type_for_width(*width);
            quote! {
                #visibility #field_name: #field_type
            })
            }
        }
        _ => todo!("unsupported field: {:?}", field),
    }
}

fn generate_field_getter(
    packet_name: &syn::Ident,
    field: &ast::Field,
) -> Result<proc_macro2::TokenStream> {
fn generate_field_getter(packet_name: &syn::Ident, field: &ast::Field) -> proc_macro2::TokenStream {
    match field {
        ast::Field::Scalar { id, width, .. } => {
            // TODO(mgeisler): refactor with generate_field above.
            let getter_name = format_ident!("get_{id}");
            let field_name = format_ident!("{id}");
            let field_type = type_for_width(*width)?;
            Ok(quote! {
            let field_type = type_for_width(*width);
            quote! {
                pub fn #getter_name(&self) -> #field_type {
                    self.#packet_name.as_ref().#field_name
                }
            })
            }
        }
        _ => todo!("unsupported field: {:?}", field),
    }
@@ -150,12 +140,12 @@ fn generate_field_parser(
    packet_name: &str,
    field: &ast::Field,
    offset: usize,
) -> Result<proc_macro2::TokenStream> {
) -> proc_macro2::TokenStream {
    match field {
        ast::Field::Scalar { id, width, .. } => {
            let field_name = format_ident!("{id}");
            let type_width = round_bit_width(*width)?;
            let field_type = type_for_width(*width)?;
            let type_width = round_bit_width(*width);
            let field_type = type_for_width(*width);

            let getter = match endianness_value {
                ast::EndiannessValue::BigEndian => format_ident!("from_be_bytes"),
@@ -173,7 +163,7 @@ fn generate_field_parser(
                None
            };

            Ok(quote! {
            quote! {
                // TODO(mgeisler): call a function instead to avoid
                // generating so much code for this.
                if bytes.len() < #wanted_len {
@@ -186,7 +176,7 @@ fn generate_field_parser(
                }
                let #field_name = #field_type::#getter([#(bytes[#indices]),* #(, #padding)*]);
                #mask
            })
            }
        }
        _ => todo!("unsupported field: {:?}", field),
    }
@@ -196,11 +186,11 @@ fn generate_field_writer(
    grammar: &ast::Grammar,
    field: &ast::Field,
    offset: usize,
) -> Result<proc_macro2::TokenStream> {
) -> proc_macro2::TokenStream {
    match field {
        ast::Field::Scalar { id, width, .. } => {
            let field_name = format_ident!("{id}");
            let bit_width = round_bit_width(*width)?;
            let bit_width = round_bit_width(*width);
            let start = syn::Index::from(offset);
            let end = syn::Index::from(offset + bit_width / 8);
            let byte_width = syn::Index::from(bit_width / 8);
@@ -208,10 +198,10 @@ fn generate_field_writer(
                ast::EndiannessValue::BigEndian => format_ident!("to_be_bytes"),
                ast::EndiannessValue::LittleEndian => format_ident!("to_le_bytes"),
            };
            Ok(quote! {
            quote! {
                let #field_name = self.#field_name;
                buffer[#start..#end].copy_from_slice(&#field_name.#writer()[0..#byte_width]);
            })
            }
        }
        _ => todo!("unsupported field: {:?}", field),
    }
@@ -232,7 +222,7 @@ fn generate_packet_decl(
    id: &str,
    fields: &[ast::Field],
    parent_id: &Option<String>,
) -> Result<String> {
) -> String {
    // TODO(mgeisler): use the convert_case crate to convert between
    // `FooBar` and `foo_bar` in the code below.
    let mut code = String::new();
@@ -282,10 +272,7 @@ fn generate_packet_decl(
            child: #data_child_ident,
        }
    });
    let plain_fields = fields
        .iter()
        .map(|field| generate_field(field, parse_quote!()))
        .collect::<Result<Vec<_>>>()?;
    let plain_fields = fields.iter().map(|field| generate_field(field, parse_quote!()));
    let _ = write!(
        code,
        "{}",
@@ -323,10 +310,7 @@ fn generate_packet_decl(
    );

    let builder_name = format_ident!("{id}Builder");
    let pub_fields = fields
        .iter()
        .map(|field| generate_field(field, parse_quote!(pub)))
        .collect::<Result<Vec<_>>>()?;
    let pub_fields = fields.iter().map(|field| generate_field(field, parse_quote!(pub)));
    let _ = write!(
        code,
        "{}",
@@ -341,14 +325,11 @@ fn generate_packet_decl(
    // TODO(mgeisler): use the `Buf` trait instead of tracking
    // the offset manually.
    let mut offset = 0;
    let field_parsers = fields
        .iter()
        .map(|field| {
    let field_parsers = fields.iter().map(|field| {
        let parser = generate_field_parser(&grammar.endianness.value, id, field, offset);
        offset += get_field_size(field);
        parser
        })
        .collect::<Result<Vec<_>>>()?;
    });
    let field_names = fields
        .iter()
        .map(|field| match field {
@@ -357,14 +338,11 @@ fn generate_packet_decl(
        })
        .collect::<Vec<_>>();
    let mut offset = 0;
    let field_writers = fields
        .iter()
        .map(|field| {
    let field_writers = fields.iter().map(|field| {
        let writer = generate_field_writer(grammar, field, offset);
        offset += get_field_size(field);
        writer
        })
        .collect::<Result<Vec<_>>>()?;
    });

    let total_field_size = syn::Index::from(fields.iter().map(get_field_size).sum::<usize>());
    let get_size_adjustment = (total_field_size.index > 0).then(|| {
@@ -449,10 +427,7 @@ fn generate_packet_decl(
            }
        }
    });
    let field_getters = fields
        .iter()
        .map(|field| generate_field_getter(&ident, field))
        .collect::<Result<Vec<_>>>()?;
    let field_getters = fields.iter().map(|field| generate_field_getter(&ident, field));
    let _ = write!(
        code,
        "{}",
@@ -495,7 +470,7 @@ fn generate_packet_decl(
        }
    );

    Ok(code)
    code
}

fn generate_decl(
@@ -503,7 +478,7 @@ fn generate_decl(
    packets: &HashMap<&str, &ast::Decl>,
    children: &HashMap<&str, Vec<&str>>,
    decl: &ast::Decl,
) -> Result<String> {
) -> String {
    let empty: Vec<&str> = vec![];
    match decl {
        ast::Decl::Packet { id, fields, parent_id, .. } => generate_packet_decl(
@@ -522,9 +497,8 @@ fn generate_decl(
///
/// The code is not formatted, pipe it through `rustfmt` to get
/// readable source code.
pub fn generate_rust(sources: &ast::SourceDatabase, grammar: &ast::Grammar) -> Result<String> {
    let source =
        sources.get(grammar.file).with_context(|| format!("could not read {}", grammar.file))?;
pub fn generate_rust(sources: &ast::SourceDatabase, grammar: &ast::Grammar) -> String {
    let source = sources.get(grammar.file).expect("could not read source");

    let mut children = HashMap::new();
    let mut packets = HashMap::new();
@@ -539,16 +513,14 @@ pub fn generate_rust(sources: &ast::SourceDatabase, grammar: &ast::Grammar) -> R

    let mut code = String::new();

    code.push_str(&generate_preamble(Path::new(source.name()))?);
    code.push_str(&generate_preamble(Path::new(source.name())));

    for decl in &grammar.declarations {
        let decl_code = generate_decl(grammar, &packets, &children, decl)
            .with_context(|| format!("failed to generating code for {:?}", decl))?;
        code.push_str(&decl_code);
        code.push_str(&generate_decl(grammar, &packets, &children, decl));
        code.push_str("\n\n");
    }

    Ok(code)
    code
}

#[cfg(test)]
@@ -570,7 +542,7 @@ mod tests {

    #[test]
    fn test_generate_preamble() {
        let actual_code = generate_preamble(Path::new("some/path/foo.pdl")).unwrap();
        let actual_code = generate_preamble(Path::new("some/path/foo.pdl"));
        assert_snapshot_eq("tests/generated/preamble.rs", &rustfmt(&actual_code));
    }

@@ -585,7 +557,7 @@ mod tests {
        let packets = HashMap::new();
        let children = HashMap::new();
        let decl = &grammar.declarations[0];
        let actual_code = generate_decl(&grammar, &packets, &children, decl).unwrap();
        let actual_code = generate_decl(&grammar, &packets, &children, decl);
        assert_snapshot_eq("tests/generated/packet_decl_empty.rs", &rustfmt(&actual_code));
    }

@@ -604,7 +576,7 @@ mod tests {
        let packets = HashMap::new();
        let children = HashMap::new();
        let decl = &grammar.declarations[0];
        let actual_code = generate_decl(&grammar, &packets, &children, decl).unwrap();
        let actual_code = generate_decl(&grammar, &packets, &children, decl);
        assert_snapshot_eq(
            "tests/generated/packet_decl_simple_little_endian.rs",
            &rustfmt(&actual_code),
@@ -626,7 +598,7 @@ mod tests {
        let packets = HashMap::new();
        let children = HashMap::new();
        let decl = &grammar.declarations[0];
        let actual_code = generate_decl(&grammar, &packets, &children, decl).unwrap();
        let actual_code = generate_decl(&grammar, &packets, &children, decl);
        assert_snapshot_eq(
            "tests/generated/packet_decl_simple_big_endian.rs",
            &rustfmt(&actual_code),
+14 −8
Original line number Diff line number Diff line
@@ -47,33 +47,39 @@ struct Opt {
    input_file: String,
}

fn main() {
fn main() -> std::process::ExitCode {
    let opt = Opt::from_args();

    if opt.version {
        println!("Packet Description Language parser version 1.0");
        return;
        return std::process::ExitCode::SUCCESS;
    }

    let mut sources = ast::SourceDatabase::new();
    match parser::parse_file(&mut sources, opt.input_file) {
        Ok(grammar) => {
            let _ = grammar.lint().print(&sources, termcolor::ColorChoice::Always);
            let lint = grammar.lint();
            if !lint.diagnostics.is_empty() {
                lint.print(&sources, termcolor::ColorChoice::Always)
                    .expect("Could not print lint diagnostics");
                return std::process::ExitCode::FAILURE;
            }

            match opt.output_format {
                OutputFormat::JSON => {
                    println!("{}", serde_json::to_string_pretty(&grammar).unwrap())
                }
                OutputFormat::Rust => match generator::generate_rust(&sources, &grammar) {
                    Ok(code) => println!("{}", &code),
                    Err(err) => println!("failed to generate code: {}", err),
                },
                OutputFormat::Rust => {
                    println!("{}", generator::generate_rust(&sources, &grammar))
                }
            }
            std::process::ExitCode::SUCCESS
        }
        Err(err) => {
            let writer = termcolor::StandardStream::stderr(termcolor::ColorChoice::Always);
            let config = term::Config::default();
            _ = term::emit(&mut writer.lock(), &config, &sources, &err);
            term::emit(&mut writer.lock(), &config, &sources, &err).expect("Could not print error");
            std::process::ExitCode::FAILURE
        }
    }
}