Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit e3055aa7 authored by Henri Chataing's avatar Henri Chataing Committed by Automerger Merge Worker
Browse files

Merge "pdl: Implement c++ backend generator" am: b315d83f am: b43bfffc am:...

Merge "pdl: Implement c++ backend generator" am: b315d83f am: b43bfffc am: 6b8c168f am: 0725fb5b am: 23de6983

Original change: https://android-review.googlesource.com/c/platform/packages/modules/Bluetooth/+/2496096



Change-Id: Id2d7f7e0a79aaa3ae029906588b86b19d4a608a7
Signed-off-by: default avatarAutomerger Merge Worker <android-build-automerger-merge-worker@system.gserviceaccount.com>
parents 0264c446 23de6983
Loading
Loading
Loading
Loading
+123 −0
Original line number Diff line number Diff line
@@ -318,6 +318,15 @@ genrule_defaults {
    ],
}

// Defaults for PDL python backend generation.
genrule_defaults {
    name: "pdl_cxx_generator_defaults",
    tools: [
        ":pdl",
        ":pdl_cxx_generator",
    ],
}

// Generate the python parser+serializer backend for the
// little endian test file located at tests/canonical/le_test_file.pdl.
genrule {
@@ -440,3 +449,117 @@ rust_test_host {
    ],
    test_suites: ["general-tests"],
}

// Generate the C++ parser+serializer backend for the
// little endian test file located at tests/canonical/le_test_file.pdl.
genrule {
    name: "pdl_cxx_canonical_le_src_gen",
    defaults: ["pdl_cxx_generator_defaults"],
    cmd: "set -o pipefail;" +
        " $(location :pdl) $(in) |" +
        " $(location :pdl_cxx_generator)" +
        " --namespace le_test" +
        " --output $(out)",
    srcs: [
        "tests/canonical/le_test_file.pdl",
    ],
    out: [
        "canonical_le_test_file.h",
    ],
}

// Generate the C++ parser+serializer backend tests for the
// little endian test file located at tests/canonical/le_test_file.pdl.
genrule {
    name: "pdl_cxx_canonical_le_test_gen",
    cmd: "set -o pipefail;" +
        " inputs=( $(in) ) &&" +
        " $(location :pdl) $${inputs[0]} |" +
        " $(location :pdl_cxx_unittest_generator)" +
        " --output $(out)" +
        " --test-vectors $${inputs[1]}" +
        " --include-header $$(basename $${inputs[2]})" +
        " --using-namespace le_test" +
        " --namespace le_test" +
        " --parser-test-suite LeParserTest" +
        " --serializer-test-suite LeSerializerTest",
    tools: [
        ":pdl",
        ":pdl_cxx_unittest_generator",
    ],
    srcs: [
        "tests/canonical/le_test_file.pdl",
        "tests/canonical/le_test_vectors.json",
        ":pdl_cxx_canonical_le_src_gen",
    ],
    out: [
        "canonical_le_test.cc",
    ],
}

// Generate the C++ parser+serializer backend for the
// big endian test file.
genrule {
    name: "pdl_cxx_canonical_be_src_gen",
    defaults: ["pdl_cxx_generator_defaults"],
    cmd: "set -o pipefail;" +
        " $(location :pdl) $(in) |" +
        " $(location :pdl_cxx_generator)" +
        " --namespace be_test" +
        " --output $(out)",
    srcs: [
        ":pdl_be_test_file",
    ],
    out: [
        "canonical_be_test_file.h",
    ],
}

// Generate the C++ parser+serializer backend tests for the
// big endian test file.
genrule {
    name: "pdl_cxx_canonical_be_test_gen",
    cmd: "set -o pipefail;" +
        " inputs=( $(in) ) &&" +
        " $(location :pdl) $${inputs[0]} |" +
        " $(location :pdl_cxx_unittest_generator)" +
        " --output $(out)" +
        " --test-vectors $${inputs[1]}" +
        " --include-header $$(basename $${inputs[2]})" +
        " --using-namespace be_test" +
        " --namespace be_test" +
        " --parser-test-suite BeParserTest" +
        " --serializer-test-suite BeSerializerTest",
    tools: [
        ":pdl",
        ":pdl_cxx_unittest_generator",
    ],
    srcs: [
        ":pdl_be_test_file",
        "tests/canonical/be_test_vectors.json",
        ":pdl_cxx_canonical_be_src_gen",
    ],
    out: [
        "canonical_be_test.cc",
    ],
}

// Test the generated C++ parser+serializer against
// pre-generated binary inputs.
cc_test_host {
    name: "pdl_cxx_generator_test",
    local_include_dirs: [
        "scripts",
    ],
    generated_headers: [
        "pdl_cxx_canonical_le_src_gen",
        "pdl_cxx_canonical_be_src_gen",
    ],
    generated_sources: [
        "pdl_cxx_canonical_le_test_gen",
        "pdl_cxx_canonical_be_test_gen",
    ],
    static_libs: [
        "libgtest",
    ],
}
+25 −0
Original line number Diff line number Diff line
@@ -21,5 +21,30 @@ python_binary_host {
        "generate_python_backend.py",
        "pdl/ast.py",
        "pdl/core.py",
        "pdl/utils.py",
    ],
}

// C++ generator.
python_binary_host {
    name: "pdl_cxx_generator",
    main: "generate_cxx_backend.py",
    srcs: [
        "generate_cxx_backend.py",
        "pdl/ast.py",
        "pdl/core.py",
        "pdl/utils.py",
    ],
}

// C++ test generator.
python_binary_host {
    name: "pdl_cxx_unittest_generator",
    main: "generate_cxx_backend_tests.py",
    srcs: [
        "generate_cxx_backend_tests.py",
        "pdl/ast.py",
        "pdl/core.py",
        "pdl/utils.py",
    ],
}
+1378 −0

File added.

Preview size limit exceeded, changes collapsed.

+305 −0
Original line number Diff line number Diff line
#!/usr/bin/env python3

import argparse
from dataclasses import dataclass, field
import json
from pathlib import Path
import sys
from textwrap import dedent
from typing import List, Tuple, Union, Optional

from pdl import ast, core
from pdl.utils import indent, to_pascal_case


def get_cxx_scalar_type(width: int) -> str:
    """Return the cxx scalar type to be used to back a PDL type."""
    for n in [8, 16, 32, 64]:
        if width <= n:
            return f'uint{n}_t'
    # PDL type does not fit on non-extended scalar types.
    assert False


def generate_packet_parser_test(parser_test_suite: str, packet: ast.PacketDeclaration, tests: List[object]) -> str:
    """Generate the implementation of unit tests for the selected packet."""

    def parse_packet(packet: ast.PacketDeclaration) -> str:
        parent = parse_packet(packet.parent) if packet.parent else "input"
        return f"{packet.id}View::Create({parent})"

    def input_bytes(input: str) -> List[str]:
        input = bytes.fromhex(input)
        input_bytes = []
        for i in range(0, len(input), 16):
            input_bytes.append(' '.join(f'0x{b:x},' for b in input[i:i + 16]))
        return input_bytes

    def get_field(decl: ast.Declaration, var: str, id: str) -> str:
        if isinstance(decl, ast.StructDeclaration):
            return f"{var}.{id}_"
        else:
            return f"{var}.Get{to_pascal_case(id)}()"

    def check_members(decl: ast.Declaration, var: str, expected: object) -> List[str]:
        checks = []
        for (id, value) in expected.items():
            field = core.get_packet_field(decl, id)
            sanitized_var = var.replace('[', '_').replace(']', '')
            field_var = f'{sanitized_var}_{id}'

            if isinstance(field, ast.ScalarField):
                checks.append(f"ASSERT_EQ({get_field(decl, var, id)}, {value});")

            elif (isinstance(field, ast.TypedefField) and
                  isinstance(field.type, (ast.EnumDeclaration, ast.CustomFieldDeclaration, ast.ChecksumDeclaration))):
                checks.append(f"ASSERT_EQ({get_field(decl, var, id)}, {field.type_id}({value}));")

            elif isinstance(field, ast.TypedefField):
                checks.append(f"{field.type_id} const& {field_var} = {get_field(decl, var, id)};")
                checks.extend(check_members(field.type, field_var, value))

            elif isinstance(field, (ast.PayloadField, ast.BodyField)):
                checks.append(f"std::vector<uint8_t> expected_{field_var} {{")
                for i in range(0, len(value), 16):
                    checks.append('    ' + ' '.join([f"0x{v:x}," for v in value[i:i + 16]]))
                checks.append("};")
                checks.append(f"ASSERT_EQ({get_field(decl, var, id)}, expected_{field_var});")

            elif isinstance(field, ast.ArrayField) and field.width:
                checks.append(f"std::vector<{get_cxx_scalar_type(field.width)}> expected_{field_var} {{")
                step = int(16 * 8 / field.width)
                for i in range(0, len(value), step):
                    checks.append('    ' + ' '.join([f"0x{v:x}," for v in value[i:i + step]]))
                checks.append("};")
                checks.append(f"ASSERT_EQ({get_field(decl, var, id)}, expected_{field_var});")

            elif (isinstance(field, ast.ArrayField) and isinstance(field.type, ast.EnumDeclaration)):
                checks.append(f"std::vector<{field.type_id}> expected_{field_var} {{")
                for v in value:
                    checks.append(f"    {field.type_id}({v}),")
                checks.append("};")
                checks.append(f"ASSERT_EQ({get_field(decl, var, id)}, expected_{field_var});")

            elif isinstance(field, ast.ArrayField):
                checks.append(f"std::vector<{field.type_id}> {field_var} = {get_field(decl, var, id)};")
                checks.append(f"ASSERT_EQ({field_var}.size(), {len(value)});")
                for (n, value) in enumerate(value):
                    checks.extend(check_members(field.type, f"{field_var}[{n}]", value))

            else:
                pass

        return checks

    generated_tests = []
    for (test_nr, test) in enumerate(tests):
        child_packet_id = test.get('packet', packet.id)
        child_packet = packet.file.packet_scope[child_packet_id]

        generated_tests.append(
            dedent("""\

            TEST_F({parser_test_suite}, {packet_id}_Case{test_nr}) {{
                pdl::packet::slice input(std::shared_ptr<std::vector<uint8_t>>(new std::vector<uint8_t> {{
                    {input_bytes}
                }}));
                {child_packet_id}View packet = {parse_packet};
                ASSERT_TRUE(packet.IsValid());
                {checks}
            }}
            """).format(parser_test_suite=parser_test_suite,
                        packet_id=packet.id,
                        child_packet_id=child_packet_id,
                        test_nr=test_nr,
                        input_bytes=indent(input_bytes(test['packed']), 2),
                        parse_packet=parse_packet(child_packet),
                        checks=indent(check_members(packet, 'packet', test['unpacked']), 1)))

    return ''.join(generated_tests)


def generate_packet_serializer_test(serializer_test_suite: str, packet: ast.PacketDeclaration,
                                    tests: List[object]) -> str:
    """Generate the implementation of unit tests for the selected packet."""

    def build_packet(decl: ast.Declaration, var: str, initializer: object) -> (str, List[str]):
        fields = core.get_unconstrained_parent_fields(decl) + decl.fields
        declarations = []
        parameters = []
        for field in fields:
            sanitized_var = var.replace('[', '_').replace(']', '')
            field_id = getattr(field, 'id', None)
            field_var = f'{sanitized_var}_{field_id}'
            value = initializer['payload'] if isinstance(field, (ast.PayloadField,
                                                                 ast.BodyField)) else initializer.get(field_id, None)

            if isinstance(field, ast.ScalarField):
                parameters.append(f"{value}")

            elif isinstance(field, ast.TypedefField) and isinstance(field.type, ast.EnumDeclaration):
                parameters.append(f"{field.type_id}({value})")

            elif isinstance(field, ast.TypedefField):
                (element, intermediate_declarations) = build_packet(field.type, field_var, value)
                declarations.extend(intermediate_declarations)
                parameters.append(element)

            elif isinstance(field, (ast.PayloadField, ast.BodyField)):
                declarations.append(f"std::vector<uint8_t> {field_var} {{")
                for i in range(0, len(value), 16):
                    declarations.append('    ' + ' '.join([f"0x{v:x}," for v in value[i:i + 16]]))
                declarations.append("};")
                parameters.append(f"std::move({field_var})")

            elif isinstance(field, ast.ArrayField) and field.width:
                declarations.append(f"std::vector<{get_cxx_scalar_type(field.width)}> {field_var} {{")
                step = int(16 * 8 / field.width)
                for i in range(0, len(value), step):
                    declarations.append('    ' + ' '.join([f"0x{v:x}," for v in value[i:i + step]]))
                declarations.append("};")
                parameters.append(f"std::move({field_var})")

            elif isinstance(field, ast.ArrayField) and isinstance(field.type, ast.EnumDeclaration):
                declarations.append(f"std::vector<{field.type_id}> {field_var} {{")
                for v in value:
                    declarations.append(f"    {field.type_id}({v}),")
                declarations.append("};")
                parameters.append(f"std::move({field_var})")

            elif isinstance(field, ast.ArrayField):
                elements = []
                for (n, value) in enumerate(value):
                    (element, intermediate_declarations) = build_packet(field.type, f'{field_var}_{n}', value)
                    elements.append(element)
                    declarations.extend(intermediate_declarations)
                declarations.append(f"std::vector<{field.type_id}> {field_var} {{")
                for element in elements:
                    declarations.append(f"    {element},")
                declarations.append("};")
                parameters.append(f"std::move({field_var})")

            else:
                pass

        constructor_name = f'{decl.id}Builder' if isinstance(decl, ast.PacketDeclaration) else decl.id
        return (f"{constructor_name}({', '.join(parameters)})", declarations)

    def output_bytes(output: str) -> List[str]:
        output = bytes.fromhex(output)
        output_bytes = []
        for i in range(0, len(output), 16):
            output_bytes.append(' '.join(f'0x{b:x},' for b in output[i:i + 16]))
        return output_bytes

    generated_tests = []
    for (test_nr, test) in enumerate(tests):
        child_packet_id = test.get('packet', packet.id)
        child_packet = packet.file.packet_scope[child_packet_id]

        (built_packet, intermediate_declarations) = build_packet(child_packet, 'packet', test['unpacked'])
        generated_tests.append(
            dedent("""\

            TEST_F({serializer_test_suite}, {packet_id}_Case{test_nr}) {{
                std::vector<uint8_t> expected_output {{
                    {output_bytes}
                }};
                {intermediate_declarations}
                {child_packet_id}Builder packet = {built_packet};
                ASSERT_EQ(packet.pdl::packet::Builder::Serialize(), expected_output);
            }}
            """).format(serializer_test_suite=serializer_test_suite,
                        packet_id=packet.id,
                        child_packet_id=child_packet_id,
                        test_nr=test_nr,
                        output_bytes=indent(output_bytes(test['packed']), 2),
                        built_packet=built_packet,
                        intermediate_declarations=indent(intermediate_declarations, 1)))

    return ''.join(generated_tests)


def run(input: argparse.FileType, output: argparse.FileType, test_vectors: argparse.FileType, include_header: List[str],
        using_namespace: List[str], namespace: str, parser_test_suite: str, serializer_test_suite: str):

    file = ast.File.from_json(json.load(input))
    tests = json.load(test_vectors)
    core.desugar(file)

    include_header = '\n'.join([f'#include <{header}>' for header in include_header])
    using_namespace = '\n'.join([f'using namespace {namespace};' for namespace in using_namespace])

    skipped_tests = [
        'Packet_Checksum_Field_FromStart',
        'Packet_Checksum_Field_FromEnd',
        'Struct_Checksum_Field_FromStart',
        'Struct_Checksum_Field_FromEnd',
        'PartialParent5',
        'PartialParent12',
    ]

    output.write(
        dedent("""\
        // File generated from {input_name} and {test_vectors_name}, with the command:
        //  {input_command}
        // /!\\ Do not edit by hand

        #include <cstdint>
        #include <string>
        #include <gtest/gtest.h>
        #include <packet_runtime.h>

        {include_header}
        {using_namespace}

        namespace {namespace} {{

        class {parser_test_suite} : public testing::Test {{}};
        class {serializer_test_suite} : public testing::Test {{}};
        """).format(parser_test_suite=parser_test_suite,
                    serializer_test_suite=serializer_test_suite,
                    input_name=input.name,
                    input_command=' '.join(sys.argv),
                    test_vectors_name=test_vectors.name,
                    include_header=include_header,
                    using_namespace=using_namespace,
                    namespace=namespace))

    for decl in file.declarations:
        if decl.id in skipped_tests:
            continue

        if isinstance(decl, ast.PacketDeclaration):
            matching_tests = [test['tests'] for test in tests if test['packet'] == decl.id]
            matching_tests = [test for test_list in matching_tests for test in test_list]
            if matching_tests:
                output.write(generate_packet_parser_test(parser_test_suite, decl, matching_tests))
                output.write(generate_packet_serializer_test(serializer_test_suite, decl, matching_tests))

    output.write(f"}}  // namespace {namespace}\n")


def main() -> int:
    """Generate cxx PDL backend."""
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument('--input', type=argparse.FileType('r'), default=sys.stdin, help='Input PDL-JSON source')
    parser.add_argument('--output', type=argparse.FileType('w'), default=sys.stdout, help='Output C++ file')
    parser.add_argument('--test-vectors', type=argparse.FileType('r'), required=True, help='Input PDL test file')
    parser.add_argument('--namespace', type=str, default='pdl', help='Namespace of the generated file')
    parser.add_argument('--parser-test-suite', type=str, default='ParserTest', help='Name of the parser test suite')
    parser.add_argument('--serializer-test-suite',
                        type=str,
                        default='SerializerTest',
                        help='Name of the serializer test suite')
    parser.add_argument('--include-header', type=str, default=[], action='append', help='Added include directives')
    parser.add_argument('--using-namespace',
                        type=str,
                        default=[],
                        action='append',
                        help='Added using namespace statements')
    return run(**vars(parser.parse_args()))


if __name__ == '__main__':
    sys.exit(main())
+1 −14
Original line number Diff line number Diff line
@@ -9,20 +9,7 @@ from textwrap import dedent
from typing import List, Tuple, Union, Optional

from pdl import ast, core


def indent(lines: List[str], depth: int) -> str:
    """Indent a code block to the selected depth.
    The first line is intentionally not indented so that
    the caller may use it as:

    '''
    def generated():
        {codeblock}
    '''
    """
    sep = '\n' + (' ' * (depth * 4))
    return sep.join(lines)
from pdl.utils import indent


def mask(width: int) -> str:
Loading