Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit e735877d authored by Martin Geisler's avatar Martin Geisler Committed by Automerger Merge Worker
Browse files

pdl: let lint phase do all error handling am: d1f1e5f2 am: eff69d76

parents 2c7676d0 eff69d76
Loading
Loading
Loading
Loading
+0 −1
Original line number Diff line number Diff line
@@ -12,7 +12,6 @@ rust_defaults {
    srcs: ["src/main.rs"],
    // LINT.IfChange
    rustlibs: [
        "libanyhow",
        "libcodespan_reporting",
        "libpest",
        "libproc_macro2",
+0 −1
Original line number Diff line number Diff line
@@ -6,7 +6,6 @@ edition = "2021"
[workspace]

[dependencies]
anyhow = "*"
codespan-reporting = "*"
pest = "*"
pest_derive = "*"
+54 −82
Original line number Diff line number Diff line
use crate::ast;
use anyhow::{anyhow, bail, Context, Result};
use quote::{format_ident, quote};
use std::collections::HashMap;
use std::fmt::Write;
@@ -17,12 +16,9 @@ macro_rules! quote_block {
}

/// Generate the file preamble.
fn generate_preamble(path: &Path) -> Result<String> {
fn generate_preamble(path: &Path) -> String {
    let mut code = String::new();
    let filename = path
        .file_name()
        .and_then(|path| path.to_str())
        .ok_or_else(|| anyhow!("could not find filename in {:?}", path))?;
    let filename = path.file_name().unwrap().to_str().expect("non UTF-8 filename");
    let _ = write!(code, "// @generated rust packets from {filename}\n\n");

    let _ = write!(
@@ -88,58 +84,52 @@ fn generate_preamble(path: &Path) -> Result<String> {
        }
    );

    Ok(code)
    code
}

/// Round up the bit width to a Rust integer size.
fn round_bit_width(width: usize) -> Result<usize> {
fn round_bit_width(width: usize) -> usize {
    match width {
        8 => Ok(8),
        16 => Ok(16),
        24 | 32 => Ok(32),
        40 | 48 | 56 | 64 => Ok(64),
        _ => bail!("unsupported field width: {width}"),
        8 => 8,
        16 => 16,
        24 | 32 => 32,
        40 | 48 | 56 | 64 => 64,
        _ => todo!("unsupported field width: {width}"),
    }
}

/// Generate a Rust unsigned integer type large enough to hold
/// integers of the given bit width.
fn type_for_width(width: usize) -> Result<syn::Type> {
    let rounded_width = round_bit_width(width)?;
    syn::parse_str(&format!("u{rounded_width}")).map_err(anyhow::Error::from)
fn type_for_width(width: usize) -> syn::Type {
    let rounded_width = round_bit_width(width);
    syn::parse_str(&format!("u{rounded_width}")).unwrap()
}

fn generate_field(
    field: &ast::Field,
    visibility: syn::Visibility,
) -> Result<proc_macro2::TokenStream> {
fn generate_field(field: &ast::Field, visibility: syn::Visibility) -> proc_macro2::TokenStream {
    match field {
        ast::Field::Scalar { id, width, .. } => {
            let field_name = format_ident!("{id}");
            let field_type = type_for_width(*width)?;
            Ok(quote! {
            let field_type = type_for_width(*width);
            quote! {
                #visibility #field_name: #field_type
            })
            }
        }
        _ => todo!("unsupported field: {:?}", field),
    }
}

fn generate_field_getter(
    packet_name: &syn::Ident,
    field: &ast::Field,
) -> Result<proc_macro2::TokenStream> {
fn generate_field_getter(packet_name: &syn::Ident, field: &ast::Field) -> proc_macro2::TokenStream {
    match field {
        ast::Field::Scalar { id, width, .. } => {
            // TODO(mgeisler): refactor with generate_field above.
            let getter_name = format_ident!("get_{id}");
            let field_name = format_ident!("{id}");
            let field_type = type_for_width(*width)?;
            Ok(quote! {
            let field_type = type_for_width(*width);
            quote! {
                pub fn #getter_name(&self) -> #field_type {
                    self.#packet_name.as_ref().#field_name
                }
            })
            }
        }
        _ => todo!("unsupported field: {:?}", field),
    }
@@ -150,12 +140,12 @@ fn generate_field_parser(
    packet_name: &str,
    field: &ast::Field,
    offset: usize,
) -> Result<proc_macro2::TokenStream> {
) -> proc_macro2::TokenStream {
    match field {
        ast::Field::Scalar { id, width, .. } => {
            let field_name = format_ident!("{id}");
            let type_width = round_bit_width(*width)?;
            let field_type = type_for_width(*width)?;
            let type_width = round_bit_width(*width);
            let field_type = type_for_width(*width);

            let getter = match endianness_value {
                ast::EndiannessValue::BigEndian => format_ident!("from_be_bytes"),
@@ -173,7 +163,7 @@ fn generate_field_parser(
                None
            };

            Ok(quote! {
            quote! {
                // TODO(mgeisler): call a function instead to avoid
                // generating so much code for this.
                if bytes.len() < #wanted_len {
@@ -186,7 +176,7 @@ fn generate_field_parser(
                }
                let #field_name = #field_type::#getter([#(bytes[#indices]),* #(, #padding)*]);
                #mask
            })
            }
        }
        _ => todo!("unsupported field: {:?}", field),
    }
@@ -196,11 +186,11 @@ fn generate_field_writer(
    grammar: &ast::Grammar,
    field: &ast::Field,
    offset: usize,
) -> Result<proc_macro2::TokenStream> {
) -> proc_macro2::TokenStream {
    match field {
        ast::Field::Scalar { id, width, .. } => {
            let field_name = format_ident!("{id}");
            let bit_width = round_bit_width(*width)?;
            let bit_width = round_bit_width(*width);
            let start = syn::Index::from(offset);
            let end = syn::Index::from(offset + bit_width / 8);
            let byte_width = syn::Index::from(bit_width / 8);
@@ -208,10 +198,10 @@ fn generate_field_writer(
                ast::EndiannessValue::BigEndian => format_ident!("to_be_bytes"),
                ast::EndiannessValue::LittleEndian => format_ident!("to_le_bytes"),
            };
            Ok(quote! {
            quote! {
                let #field_name = self.#field_name;
                buffer[#start..#end].copy_from_slice(&#field_name.#writer()[0..#byte_width]);
            })
            }
        }
        _ => todo!("unsupported field: {:?}", field),
    }
@@ -232,7 +222,7 @@ fn generate_packet_decl(
    id: &str,
    fields: &[ast::Field],
    parent_id: &Option<String>,
) -> Result<String> {
) -> String {
    // TODO(mgeisler): use the convert_case crate to convert between
    // `FooBar` and `foo_bar` in the code below.
    let mut code = String::new();
@@ -282,10 +272,7 @@ fn generate_packet_decl(
            child: #data_child_ident,
        }
    });
    let plain_fields = fields
        .iter()
        .map(|field| generate_field(field, parse_quote!()))
        .collect::<Result<Vec<_>>>()?;
    let plain_fields = fields.iter().map(|field| generate_field(field, parse_quote!()));
    let _ = write!(
        code,
        "{}",
@@ -323,10 +310,7 @@ fn generate_packet_decl(
    );

    let builder_name = format_ident!("{id}Builder");
    let pub_fields = fields
        .iter()
        .map(|field| generate_field(field, parse_quote!(pub)))
        .collect::<Result<Vec<_>>>()?;
    let pub_fields = fields.iter().map(|field| generate_field(field, parse_quote!(pub)));
    let _ = write!(
        code,
        "{}",
@@ -341,14 +325,11 @@ fn generate_packet_decl(
    // TODO(mgeisler): use the `Buf` trait instead of tracking
    // the offset manually.
    let mut offset = 0;
    let field_parsers = fields
        .iter()
        .map(|field| {
    let field_parsers = fields.iter().map(|field| {
        let parser = generate_field_parser(&grammar.endianness.value, id, field, offset);
        offset += get_field_size(field);
        parser
        })
        .collect::<Result<Vec<_>>>()?;
    });
    let field_names = fields
        .iter()
        .map(|field| match field {
@@ -357,14 +338,11 @@ fn generate_packet_decl(
        })
        .collect::<Vec<_>>();
    let mut offset = 0;
    let field_writers = fields
        .iter()
        .map(|field| {
    let field_writers = fields.iter().map(|field| {
        let writer = generate_field_writer(grammar, field, offset);
        offset += get_field_size(field);
        writer
        })
        .collect::<Result<Vec<_>>>()?;
    });

    let total_field_size = syn::Index::from(fields.iter().map(get_field_size).sum::<usize>());
    let get_size_adjustment = (total_field_size.index > 0).then(|| {
@@ -449,10 +427,7 @@ fn generate_packet_decl(
            }
        }
    });
    let field_getters = fields
        .iter()
        .map(|field| generate_field_getter(&ident, field))
        .collect::<Result<Vec<_>>>()?;
    let field_getters = fields.iter().map(|field| generate_field_getter(&ident, field));
    let _ = write!(
        code,
        "{}",
@@ -495,7 +470,7 @@ fn generate_packet_decl(
        }
    );

    Ok(code)
    code
}

fn generate_decl(
@@ -503,7 +478,7 @@ fn generate_decl(
    packets: &HashMap<&str, &ast::Decl>,
    children: &HashMap<&str, Vec<&str>>,
    decl: &ast::Decl,
) -> Result<String> {
) -> String {
    let empty: Vec<&str> = vec![];
    match decl {
        ast::Decl::Packet { id, fields, parent_id, .. } => generate_packet_decl(
@@ -522,9 +497,8 @@ fn generate_decl(
///
/// The code is not formatted, pipe it through `rustfmt` to get
/// readable source code.
pub fn generate_rust(sources: &ast::SourceDatabase, grammar: &ast::Grammar) -> Result<String> {
    let source =
        sources.get(grammar.file).with_context(|| format!("could not read {}", grammar.file))?;
pub fn generate_rust(sources: &ast::SourceDatabase, grammar: &ast::Grammar) -> String {
    let source = sources.get(grammar.file).expect("could not read source");

    let mut children = HashMap::new();
    let mut packets = HashMap::new();
@@ -539,16 +513,14 @@ pub fn generate_rust(sources: &ast::SourceDatabase, grammar: &ast::Grammar) -> R

    let mut code = String::new();

    code.push_str(&generate_preamble(Path::new(source.name()))?);
    code.push_str(&generate_preamble(Path::new(source.name())));

    for decl in &grammar.declarations {
        let decl_code = generate_decl(grammar, &packets, &children, decl)
            .with_context(|| format!("failed to generating code for {:?}", decl))?;
        code.push_str(&decl_code);
        code.push_str(&generate_decl(grammar, &packets, &children, decl));
        code.push_str("\n\n");
    }

    Ok(code)
    code
}

#[cfg(test)]
@@ -570,7 +542,7 @@ mod tests {

    #[test]
    fn test_generate_preamble() {
        let actual_code = generate_preamble(Path::new("some/path/foo.pdl")).unwrap();
        let actual_code = generate_preamble(Path::new("some/path/foo.pdl"));
        assert_snapshot_eq("tests/generated/preamble.rs", &rustfmt(&actual_code));
    }

@@ -585,7 +557,7 @@ mod tests {
        let packets = HashMap::new();
        let children = HashMap::new();
        let decl = &grammar.declarations[0];
        let actual_code = generate_decl(&grammar, &packets, &children, decl).unwrap();
        let actual_code = generate_decl(&grammar, &packets, &children, decl);
        assert_snapshot_eq("tests/generated/packet_decl_empty.rs", &rustfmt(&actual_code));
    }

@@ -604,7 +576,7 @@ mod tests {
        let packets = HashMap::new();
        let children = HashMap::new();
        let decl = &grammar.declarations[0];
        let actual_code = generate_decl(&grammar, &packets, &children, decl).unwrap();
        let actual_code = generate_decl(&grammar, &packets, &children, decl);
        assert_snapshot_eq(
            "tests/generated/packet_decl_simple_little_endian.rs",
            &rustfmt(&actual_code),
@@ -626,7 +598,7 @@ mod tests {
        let packets = HashMap::new();
        let children = HashMap::new();
        let decl = &grammar.declarations[0];
        let actual_code = generate_decl(&grammar, &packets, &children, decl).unwrap();
        let actual_code = generate_decl(&grammar, &packets, &children, decl);
        assert_snapshot_eq(
            "tests/generated/packet_decl_simple_big_endian.rs",
            &rustfmt(&actual_code),
+14 −8
Original line number Diff line number Diff line
@@ -47,33 +47,39 @@ struct Opt {
    input_file: String,
}

fn main() {
fn main() -> std::process::ExitCode {
    let opt = Opt::from_args();

    if opt.version {
        println!("Packet Description Language parser version 1.0");
        return;
        return std::process::ExitCode::SUCCESS;
    }

    let mut sources = ast::SourceDatabase::new();
    match parser::parse_file(&mut sources, opt.input_file) {
        Ok(grammar) => {
            let _ = grammar.lint().print(&sources, termcolor::ColorChoice::Always);
            let lint = grammar.lint();
            if !lint.diagnostics.is_empty() {
                lint.print(&sources, termcolor::ColorChoice::Always)
                    .expect("Could not print lint diagnostics");
                return std::process::ExitCode::FAILURE;
            }

            match opt.output_format {
                OutputFormat::JSON => {
                    println!("{}", serde_json::to_string_pretty(&grammar).unwrap())
                }
                OutputFormat::Rust => match generator::generate_rust(&sources, &grammar) {
                    Ok(code) => println!("{}", &code),
                    Err(err) => println!("failed to generate code: {}", err),
                },
                OutputFormat::Rust => {
                    println!("{}", generator::generate_rust(&sources, &grammar))
                }
            }
            std::process::ExitCode::SUCCESS
        }
        Err(err) => {
            let writer = termcolor::StandardStream::stderr(termcolor::ColorChoice::Always);
            let config = term::Config::default();
            _ = term::emit(&mut writer.lock(), &config, &sources, &err);
            term::emit(&mut writer.lock(), &config, &sources, &err).expect("Could not print error");
            std::process::ExitCode::FAILURE
        }
    }
}