Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 6f1c3fe8 authored by Dennis Shen's avatar Dennis Shen Committed by Automerger Merge Worker
Browse files

Merge "cpp codegen redesign, unit test support" am: 99d4a49d

parents f3207430 99d4a49d
Loading
Loading
Loading
Loading
+337 −20
Original line number Diff line number Diff line
@@ -16,13 +16,18 @@

use anyhow::{ensure, Result};
use serde::Serialize;
use std::path::PathBuf;
use tinytemplate::TinyTemplate;

use crate::codegen;
use crate::commands::OutputFile;
use crate::commands::{CodegenMode, OutputFile};
use crate::protos::{ProtoFlagPermission, ProtoFlagState, ProtoParsedFlag};

pub fn generate_cpp_code<'a, I>(package: &str, parsed_flags_iter: I) -> Result<OutputFile>
pub fn generate_cpp_code<'a, I>(
    package: &str,
    parsed_flags_iter: I,
    codegen_mode: CodegenMode,
) -> Result<Vec<OutputFile>>
where
    I: Iterator<Item = &'a ProtoParsedFlag>,
{
@@ -37,29 +42,66 @@ where
        cpp_namespace,
        package: package.to_string(),
        readwrite,
        for_prod: codegen_mode == CodegenMode::Production,
        class_elements,
    };

    let files = [
        FileSpec {
            name: &format!("{}.h", header),
            template: include_str!("../templates/cpp_exported_header.template"),
            dir: "include",
        },
        FileSpec {
            name: &format!("{}.cc", header),
            template: include_str!("../templates/cpp_source_file.template"),
            dir: "",
        },
        FileSpec {
            name: &format!("{}_flag_provider.h", header),
            template: match codegen_mode {
                CodegenMode::Production => {
                    include_str!("../templates/cpp_prod_flag_provider.template")
                }
                CodegenMode::Test => include_str!("../templates/cpp_test_flag_provider.template"),
            },
            dir: "",
        },
    ];
    files.iter().map(|file| generate_file(file, &context)).collect()
}

pub fn generate_file(file: &FileSpec, context: &Context) -> Result<OutputFile> {
    let mut template = TinyTemplate::new();
    template.add_template("cpp_code_gen", include_str!("../templates/cpp.template"))?;
    let contents = template.render("cpp_code_gen", &context)?;
    let path = ["aconfig", &(header + ".h")].iter().collect();
    template.add_template(file.name, file.template)?;
    let contents = template.render(file.name, &context)?;
    let path: PathBuf = [&file.dir, &file.name].iter().collect();
    Ok(OutputFile { contents: contents.into(), path })
}

#[derive(Serialize)]
struct Context {
pub struct FileSpec<'a> {
    pub name: &'a str,
    pub template: &'a str,
    pub dir: &'a str,
}

#[derive(Serialize)]
pub struct Context {
    pub header: String,
    pub cpp_namespace: String,
    pub package: String,
    pub readwrite: bool,
    pub for_prod: bool,
    pub class_elements: Vec<ClassElement>,
}

#[derive(Serialize)]
struct ClassElement {
pub struct ClassElement {
    pub readwrite: bool,
    pub default_value: String,
    pub flag_name: String,
    pub uppercase_flag_name: String,
    pub device_config_namespace: String,
    pub device_config_flag: String,
}
@@ -73,6 +115,7 @@ fn create_class_element(package: &str, pf: &ProtoParsedFlag) -> ClassElement {
            "false".to_string()
        },
        flag_name: pf.name().to_string(),
        uppercase_flag_name: pf.name().to_string().to_ascii_uppercase(),
        device_config_namespace: pf.namespace().to_string(),
        device_config_flag: codegen::create_device_config_ident(package, pf.name())
            .expect("values checked at flag parse time"),
@@ -82,51 +125,325 @@ fn create_class_element(package: &str, pf: &ProtoParsedFlag) -> ClassElement {
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    #[test]
    fn test_generate_cpp_code() {
        let parsed_flags = crate::test::parse_test_flags();
        let generated =
            generate_cpp_code(crate::test::TEST_PACKAGE, parsed_flags.parsed_flag.iter()).unwrap();
        assert_eq!("aconfig/com_android_aconfig_test.h", format!("{}", generated.path.display()));
        let expected = r#"
    const EXPORTED_PROD_HEADER_EXPECTED: &str = r#"
#ifndef com_android_aconfig_test_HEADER_H
#define com_android_aconfig_test_HEADER_H

#include <string>
#include <memory>
#include <server_configurable_flags/get_flags.h>
using namespace server_configurable_flags;

namespace com::android::aconfig::test {
class flag_provider_interface {
public:

    virtual ~flag_provider_interface() = default;

    virtual bool disabled_ro() = 0;

    virtual bool disabled_rw() = 0;

    virtual bool enabled_ro() = 0;

    virtual bool enabled_rw() = 0;

    virtual void override_flag(std::string const&, bool) {}

    virtual void reset_overrides() {}
};

extern std::unique_ptr<flag_provider_interface> provider_;

extern std::string const DISABLED_RO;
extern std::string const DISABLED_RW;
extern std::string const ENABLED_RO;
extern std::string const ENABLED_RW;

inline bool disabled_ro() {
    return false;
}

inline bool disabled_rw() {
    return provider_->disabled_rw();
}

inline bool enabled_ro() {
    return true;
}

inline bool enabled_rw() {
    return provider_->enabled_rw();
}

inline void override_flag(std::string const& name, bool val) {
    return provider_->override_flag(name, val);
}

inline void reset_overrides() {
    return provider_->reset_overrides();
}

}
#endif
"#;

    const EXPORTED_TEST_HEADER_EXPECTED: &str = r#"
#ifndef com_android_aconfig_test_HEADER_H
#define com_android_aconfig_test_HEADER_H

#include <string>
#include <memory>
#include <server_configurable_flags/get_flags.h>
using namespace server_configurable_flags;

namespace com::android::aconfig::test {
    static const bool disabled_ro() {
class flag_provider_interface {
public:

    virtual ~flag_provider_interface() = default;

    virtual bool disabled_ro() = 0;

    virtual bool disabled_rw() = 0;

    virtual bool enabled_ro() = 0;

    virtual bool enabled_rw() = 0;

    virtual void override_flag(std::string const&, bool) {}

    virtual void reset_overrides() {}
};

extern std::unique_ptr<flag_provider_interface> provider_;

extern std::string const DISABLED_RO;
extern std::string const DISABLED_RW;
extern std::string const ENABLED_RO;
extern std::string const ENABLED_RW;

inline bool disabled_ro() {
    return provider_->disabled_ro();
}

inline bool disabled_rw() {
    return provider_->disabled_rw();
}

inline bool enabled_ro() {
    return provider_->enabled_ro();
}

inline bool enabled_rw() {
    return provider_->enabled_rw();
}

inline void override_flag(std::string const& name, bool val) {
    return provider_->override_flag(name, val);
}

inline void reset_overrides() {
    return provider_->reset_overrides();
}

}
#endif
"#;

    const PROD_FLAG_PROVIDER_HEADER_EXPECTED: &str = r#"
#ifndef com_android_aconfig_test_flag_provider_HEADER_H
#define com_android_aconfig_test_flag_provider_HEADER_H

#include "com_android_aconfig_test.h"

namespace com::android::aconfig::test {
class flag_provider : public flag_provider_interface {
public:

    virtual bool disabled_ro() override {
        return false;
    }

    static const bool disabled_rw() {
    virtual bool disabled_rw() override {
        return GetServerConfigurableFlag(
            "aconfig_test",
            "com.android.aconfig.test.disabled_rw",
            "false") == "true";
    }

    static const bool enabled_ro() {
    virtual bool enabled_ro() override {
        return true;
    }

    static const bool enabled_rw() {
    virtual bool enabled_rw() override {
        return GetServerConfigurableFlag(
            "aconfig_test",
            "com.android.aconfig.test.enabled_rw",
            "true") == "true";
    }
};
}
#endif
"#;

    const TEST_FLAG_PROVIDER_HEADER_EXPECTED: &str = r#"
#ifndef com_android_aconfig_test_flag_provider_HEADER_H
#define com_android_aconfig_test_flag_provider_HEADER_H

#include "com_android_aconfig_test.h"

#include <unordered_map>
#include <unordered_set>
#include <cassert>

namespace com::android::aconfig::test {
class flag_provider : public flag_provider_interface {
private:
    std::unordered_map<std::string, bool> overrides_;
    std::unordered_set<std::string> flag_names_;

public:

    flag_provider()
        : overrides_(),
          flag_names_() {
        flag_names_.insert(DISABLED_RO);
        flag_names_.insert(DISABLED_RW);
        flag_names_.insert(ENABLED_RO);
        flag_names_.insert(ENABLED_RW);
    }

    virtual bool disabled_ro() override {
        auto it = overrides_.find(DISABLED_RO);
        if (it != overrides_.end()) {
            return it->second;
        } else {
            return false;
        }
    }

    virtual bool disabled_rw() override {
        auto it = overrides_.find(DISABLED_RW);
        if (it != overrides_.end()) {
            return it->second;
        } else {
            return GetServerConfigurableFlag(
                "aconfig_test",
                "com.android.aconfig.test.disabled_rw",
                "false") == "true";
        }
    }

    virtual bool enabled_ro() override {
        auto it = overrides_.find(ENABLED_RO);
        if (it != overrides_.end()) {
            return it->second;
        } else {
            return true;
        }
    }

    virtual bool enabled_rw() override {
        auto it = overrides_.find(ENABLED_RW);
        if (it != overrides_.end()) {
            return it->second;
        } else {
            return GetServerConfigurableFlag(
                "aconfig_test",
                "com.android.aconfig.test.enabled_rw",
                "true") == "true";
        }
    }

    virtual void override_flag(std::string const& flag, bool val) override {
        assert(flag_names_.count(flag));
        overrides_[flag] = val;
    }

    virtual void reset_overrides() override {
        overrides_.clear();
    }
};
}
#endif
"#;

    const SOURCE_FILE_EXPECTED: &str = r#"
#include "com_android_aconfig_test.h"
#include "com_android_aconfig_test_flag_provider.h"

namespace com::android::aconfig::test {

    std::string const DISABLED_RO = "com.android.aconfig.test.disabled_ro";
    std::string const DISABLED_RW = "com.android.aconfig.test.disabled_rw";
    std::string const ENABLED_RO = "com.android.aconfig.test.enabled_ro";
    std::string const ENABLED_RW = "com.android.aconfig.test.enabled_rw";

    std::unique_ptr<flag_provider_interface> provider_ =
        std::make_unique<flag_provider>();
}
"#;

    fn test_generate_cpp_code(mode: CodegenMode) {
        let parsed_flags = crate::test::parse_test_flags();
        let generated =
            generate_cpp_code(crate::test::TEST_PACKAGE, parsed_flags.parsed_flag.iter(), mode)
                .unwrap();
        let mut generated_files_map = HashMap::new();
        for file in generated {
            generated_files_map.insert(
                String::from(file.path.to_str().unwrap()),
                String::from_utf8(file.contents.clone()).unwrap(),
            );
        }

        let mut target_file_path = String::from("include/com_android_aconfig_test.h");
        assert!(generated_files_map.contains_key(&target_file_path));
        assert_eq!(
            None,
            crate::test::first_significant_code_diff(
                expected,
                &String::from_utf8(generated.contents).unwrap()
                match mode {
                    CodegenMode::Production => EXPORTED_PROD_HEADER_EXPECTED,
                    CodegenMode::Test => EXPORTED_TEST_HEADER_EXPECTED,
                },
                generated_files_map.get(&target_file_path).unwrap()
            )
        );

        target_file_path = String::from("com_android_aconfig_test_flag_provider.h");
        assert!(generated_files_map.contains_key(&target_file_path));
        assert_eq!(
            None,
            crate::test::first_significant_code_diff(
                match mode {
                    CodegenMode::Production => PROD_FLAG_PROVIDER_HEADER_EXPECTED,
                    CodegenMode::Test => TEST_FLAG_PROVIDER_HEADER_EXPECTED,
                },
                generated_files_map.get(&target_file_path).unwrap()
            )
        );

        target_file_path = String::from("com_android_aconfig_test.cc");
        assert!(generated_files_map.contains_key(&target_file_path));
        assert_eq!(
            None,
            crate::test::first_significant_code_diff(
                SOURCE_FILE_EXPECTED,
                generated_files_map.get(&target_file_path).unwrap()
            )
        );
    }

    #[test]
    fn test_generate_cpp_code_for_prod() {
        test_generate_cpp_code(CodegenMode::Production);
    }

    #[test]
    fn test_generate_cpp_code_for_test() {
        test_generate_cpp_code(CodegenMode::Test);
    }
}
+2 −2
Original line number Diff line number Diff line
@@ -143,12 +143,12 @@ pub fn create_java_lib(mut input: Input, codegen_mode: CodegenMode) -> Result<Ve
    generate_java_code(package, parsed_flags.parsed_flag.iter(), codegen_mode)
}

pub fn create_cpp_lib(mut input: Input) -> Result<OutputFile> {
pub fn create_cpp_lib(mut input: Input, codegen_mode: CodegenMode) -> Result<Vec<OutputFile>> {
    let parsed_flags = input.try_parse_flags()?;
    let Some(package) = find_unique_package(&parsed_flags) else {
        bail!("no parsed flags, or the parsed flags use different packages");
    };
    generate_cpp_code(package, parsed_flags.parsed_flag.iter())
    generate_cpp_code(package, parsed_flags.parsed_flag.iter(), codegen_mode)
}

pub fn create_rust_lib(mut input: Input) -> Result<OutputFile> {
+12 −3
Original line number Diff line number Diff line
@@ -60,7 +60,13 @@ fn cli() -> Command {
        .subcommand(
            Command::new("create-cpp-lib")
                .arg(Arg::new("cache").long("cache").required(true))
                .arg(Arg::new("out").long("out").required(true)),
                .arg(Arg::new("out").long("out").required(true))
                .arg(
                    Arg::new("mode")
                        .long("mode")
                        .value_parser(EnumValueParser::<commands::CodegenMode>::new())
                        .default_value("production"),
                ),
        )
        .subcommand(
            Command::new("create-rust-lib")
@@ -163,9 +169,12 @@ fn main() -> Result<()> {
        }
        Some(("create-cpp-lib", sub_matches)) => {
            let cache = open_single_file(sub_matches, "cache")?;
            let generated_file = commands::create_cpp_lib(cache)?;
            let mode = get_required_arg::<CodegenMode>(sub_matches, "mode")?;
            let generated_files = commands::create_cpp_lib(cache, *mode)?;
            let dir = PathBuf::from(get_required_arg::<String>(sub_matches, "out")?);
            write_output_file_realtive_to_dir(&dir, &generated_file)?;
            generated_files
                .iter()
                .try_for_each(|file| write_output_file_realtive_to_dir(&dir, file))?;
        }
        Some(("create-rust-lib", sub_matches)) => {
            let cache = open_single_file(sub_matches, "cache")?;
+48 −0
Original line number Diff line number Diff line
#ifndef {header}_HEADER_H
#define {header}_HEADER_H

#include <string>
#include <memory>
{{ if readwrite }}
#include <server_configurable_flags/get_flags.h>
using namespace server_configurable_flags;
{{ endif }}
namespace {cpp_namespace} \{

class flag_provider_interface \{
public:
    virtual ~flag_provider_interface() = default;
    {{ for item in class_elements}}
    virtual bool {item.flag_name}() = 0;
    {{ endfor }}
    virtual void override_flag(std::string const&, bool) \{}

    virtual void reset_overrides() \{}
};

extern std::unique_ptr<flag_provider_interface> provider_;
{{ for item in class_elements}}
extern std::string const {item.uppercase_flag_name};{{ endfor }}
{{ for item in class_elements}}
inline bool {item.flag_name}() \{
    {{ if for_prod }}
    {{ if not item.readwrite- }}
    return {item.default_value};
    {{ -else- }}
    return provider_->{item.flag_name}();
    {{ -endif }}
    {{ -else- }}
    return provider_->{item.flag_name}();
    {{ -endif }}
}
{{ endfor }}
inline void override_flag(std::string const& name, bool val) \{
    return provider_->override_flag(name, val);
}

inline void reset_overrides() \{
    return provider_->reset_overrides();
}

}
#endif
+8 −7
Original line number Diff line number Diff line
#ifndef {header}_HEADER_H
#define {header}_HEADER_H
{{ if readwrite }}
#include <server_configurable_flags/get_flags.h>
using namespace server_configurable_flags;
{{ endif }}
#ifndef {header}_flag_provider_HEADER_H
#define {header}_flag_provider_HEADER_H
#include "{header}.h"

namespace {cpp_namespace} \{
class flag_provider : public flag_provider_interface \{
public:
    {{ for item in class_elements}}
    static const bool {item.flag_name}() \{
    virtual bool {item.flag_name}() override \{
        {{ if item.readwrite- }}
        return GetServerConfigurableFlag(
            "{item.device_config_namespace}",
@@ -17,5 +17,6 @@ namespace {cpp_namespace} \{
        {{ -endif }}
    }
    {{ endfor }}
};
}
#endif
Loading