Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 403d1162 authored by Henri Chataing's avatar Henri Chataing Committed by Cherrypicker Worker
Browse files

pdl: Implement c++ backend generator

The generated parsing API closely follows the legacy
generator to minimize the porting work, but differs on
some notable points:

 - Generated packet view classes *do not* inherit from one another.
 - Packet fields are now backed in the class. Scalar and typedef fields
   are backed by value, payload and array fields by slice.
 - PacketView::Create() now does the parsing work, accessors just return
   the internal value in case of scalar and typedef fields, or parse the
   slice in case of payload and array fields.
 - The generated code eschews the legacy packet runtime in favor of a
   trimmed down version packet_runtime.h

Some features of PDL are not supported:

 - Payload with a bit shift
 - Checksums

Test: atest pdl_cxx_generator_tests
(cherry picked from https://android-review.googlesource.com/q/commit:075d0d3da75e1d4c55781881145851e0a33116e3)
Merged-In: Ifcb5b21c7a6e796ed43985132ad8bc4b116f4c48
Change-Id: Ifcb5b21c7a6e796ed43985132ad8bc4b116f4c48
parent 96c9a3ff
Loading
Loading
Loading
Loading
+123 −0
Original line number Diff line number Diff line
@@ -319,6 +319,15 @@ genrule_defaults {
    ],
}

// Defaults for PDL C++ backend generation.
genrule_defaults {
    name: "pdl_cxx_generator_defaults",
    tools: [
        ":pdl",
        ":pdl_cxx_generator",
    ],
}

// Generate the python parser+serializer backend for the
// little endian test file located at tests/canonical/le_test_file.pdl.
genrule {
@@ -441,3 +450,117 @@ rust_test_host {
    ],
    test_suites: ["general-tests"],
}

// Generate the C++ parser+serializer backend for the
// little endian test file located at tests/canonical/le_test_file.pdl.
genrule {
    name: "pdl_cxx_canonical_le_src_gen",
    defaults: ["pdl_cxx_generator_defaults"],
    cmd: "set -o pipefail;" +
        " $(location :pdl) $(in) |" +
        " $(location :pdl_cxx_generator)" +
        " --namespace le_test" +
        " --output $(out)",
    srcs: [
        "tests/canonical/le_test_file.pdl",
    ],
    out: [
        "canonical_le_test_file.h",
    ],
}

// Generate the C++ parser+serializer backend tests for the
// little endian test file located at tests/canonical/le_test_file.pdl.
genrule {
    name: "pdl_cxx_canonical_le_test_gen",
    cmd: "set -o pipefail;" +
        " inputs=( $(in) ) &&" +
        " $(location :pdl) $${inputs[0]} |" +
        " $(location :pdl_cxx_unittest_generator)" +
        " --output $(out)" +
        " --test-vectors $${inputs[1]}" +
        " --include-header $$(basename $${inputs[2]})" +
        " --using-namespace le_test" +
        " --namespace le_test" +
        " --parser-test-suite LeParserTest" +
        " --serializer-test-suite LeSerializerTest",
    tools: [
        ":pdl",
        ":pdl_cxx_unittest_generator",
    ],
    srcs: [
        "tests/canonical/le_test_file.pdl",
        "tests/canonical/le_test_vectors.json",
        ":pdl_cxx_canonical_le_src_gen",
    ],
    out: [
        "canonical_le_test.cc",
    ],
}

// Generate the C++ parser+serializer backend for the
// big endian test file.
genrule {
    name: "pdl_cxx_canonical_be_src_gen",
    defaults: ["pdl_cxx_generator_defaults"],
    cmd: "set -o pipefail;" +
        " $(location :pdl) $(in) |" +
        " $(location :pdl_cxx_generator)" +
        " --namespace be_test" +
        " --output $(out)",
    srcs: [
        ":pdl_be_test_file",
    ],
    out: [
        "canonical_be_test_file.h",
    ],
}

// Generate the C++ parser+serializer backend tests for the
// big endian test file.
genrule {
    name: "pdl_cxx_canonical_be_test_gen",
    cmd: "set -o pipefail;" +
        " inputs=( $(in) ) &&" +
        " $(location :pdl) $${inputs[0]} |" +
        " $(location :pdl_cxx_unittest_generator)" +
        " --output $(out)" +
        " --test-vectors $${inputs[1]}" +
        " --include-header $$(basename $${inputs[2]})" +
        " --using-namespace be_test" +
        " --namespace be_test" +
        " --parser-test-suite BeParserTest" +
        " --serializer-test-suite BeSerializerTest",
    tools: [
        ":pdl",
        ":pdl_cxx_unittest_generator",
    ],
    srcs: [
        ":pdl_be_test_file",
        "tests/canonical/be_test_vectors.json",
        ":pdl_cxx_canonical_be_src_gen",
    ],
    out: [
        "canonical_be_test.cc",
    ],
}

// Test the generated C++ parser+serializer against
// pre-generated binary inputs.
cc_test_host {
    name: "pdl_cxx_generator_test",
    local_include_dirs: [
        "scripts",
    ],
    generated_headers: [
        "pdl_cxx_canonical_le_src_gen",
        "pdl_cxx_canonical_be_src_gen",
    ],
    generated_sources: [
        "pdl_cxx_canonical_le_test_gen",
        "pdl_cxx_canonical_be_test_gen",
    ],
    static_libs: [
        "libgtest",
    ],
}
+25 −0
Original line number Diff line number Diff line
@@ -21,5 +21,30 @@ python_binary_host {
        "generate_python_backend.py",
        "pdl/ast.py",
        "pdl/core.py",
        "pdl/utils.py",
    ],
}

// C++ generator.
python_binary_host {
    name: "pdl_cxx_generator",
    main: "generate_cxx_backend.py",
    srcs: [
        "generate_cxx_backend.py",
        "pdl/ast.py",
        "pdl/core.py",
        "pdl/utils.py",
    ],
}

// C++ test generator.
python_binary_host {
    name: "pdl_cxx_unittest_generator",
    main: "generate_cxx_backend_tests.py",
    srcs: [
        "generate_cxx_backend_tests.py",
        "pdl/ast.py",
        "pdl/core.py",
        "pdl/utils.py",
    ],
}
+1378 −0

File added.

Preview size limit exceeded, changes collapsed.

+305 −0
Original line number Diff line number Diff line
#!/usr/bin/env python3

import argparse
from dataclasses import dataclass, field
import json
from pathlib import Path
import sys
from textwrap import dedent
from typing import List, Tuple, Union, Optional

from pdl import ast, core
from pdl.utils import indent, to_pascal_case


def get_cxx_scalar_type(width: int) -> str:
    """Return the C++ scalar type used to back a PDL field of `width` bits.

    Args:
        width: Field width in bits; must be 64 or less.

    Returns:
        The smallest fixed-width unsigned integer type name
        (e.g. 'uint8_t', 'uint16_t') wide enough to hold `width` bits.

    Raises:
        ValueError: If the width does not fit any non-extended scalar type.
    """
    for n in (8, 16, 32, 64):
        if width <= n:
            return f'uint{n}_t'
    # An `assert False` here would be silently stripped under `python -O`;
    # raise explicitly so oversized widths always fail loudly.
    raise ValueError(f'width {width} does not fit a non-extended scalar type')


def generate_packet_parser_test(parser_test_suite: str, packet: ast.PacketDeclaration, tests: List[object]) -> str:
    """Generate the implementation of unit tests for the selected packet.

    One gtest TEST_F case is emitted per entry in `tests`: each case parses
    the packed input bytes into a packet view and asserts every expected
    unpacked field value.
    """

    def parse_packet(packet: ast.PacketDeclaration) -> str:
        # Recursively wrap the raw `input` slice in View::Create() calls,
        # parsing ancestors first so the derived view is built from its parent.
        parent = parse_packet(packet.parent) if packet.parent else "input"
        return f"{packet.id}View::Create({parent})"

    def input_bytes(input: str) -> List[str]:
        # Render the hex-encoded test vector as C++ byte literals,
        # 16 bytes per generated source line.
        input = bytes.fromhex(input)
        input_bytes = []
        for i in range(0, len(input), 16):
            input_bytes.append(' '.join(f'0x{b:x},' for b in input[i:i + 16]))
        return input_bytes

    def get_field(decl: ast.Declaration, var: str, id: str) -> str:
        # Struct members are accessed directly (`var.id_`); packet views
        # expose Get<FieldName>() accessors instead.
        if isinstance(decl, ast.StructDeclaration):
            return f"{var}.{id}_"
        else:
            return f"{var}.Get{to_pascal_case(id)}()"

    def check_members(decl: ast.Declaration, var: str, expected: object) -> List[str]:
        # Emit one ASSERT per expected field value, recursing into typedef
        # and array-of-struct fields. Field kinds with no matching branch
        # (e.g. checksum placeholders) are intentionally skipped.
        checks = []
        for (id, value) in expected.items():
            field = core.get_packet_field(decl, id)
            # Derive a valid C++ identifier from the variable path:
            # array subscripts like `a[0]` become `a_0`.
            sanitized_var = var.replace('[', '_').replace(']', '')
            field_var = f'{sanitized_var}_{id}'

            if isinstance(field, ast.ScalarField):
                checks.append(f"ASSERT_EQ({get_field(decl, var, id)}, {value});")

            elif (isinstance(field, ast.TypedefField) and
                  isinstance(field.type, (ast.EnumDeclaration, ast.CustomFieldDeclaration, ast.ChecksumDeclaration))):
                # Enum-like typedefs compare against an explicit construction
                # from the raw expected value.
                checks.append(f"ASSERT_EQ({get_field(decl, var, id)}, {field.type_id}({value}));")

            elif isinstance(field, ast.TypedefField):
                # Struct typedef: bind a const reference, then recurse into
                # the struct's members.
                checks.append(f"{field.type_id} const& {field_var} = {get_field(decl, var, id)};")
                checks.extend(check_members(field.type, field_var, value))

            elif isinstance(field, (ast.PayloadField, ast.BodyField)):
                checks.append(f"std::vector<uint8_t> expected_{field_var} {{")
                for i in range(0, len(value), 16):
                    checks.append('    ' + ' '.join([f"0x{v:x}," for v in value[i:i + 16]]))
                checks.append("};")
                checks.append(f"ASSERT_EQ({get_field(decl, var, id)}, expected_{field_var});")

            elif isinstance(field, ast.ArrayField) and field.width:
                checks.append(f"std::vector<{get_cxx_scalar_type(field.width)}> expected_{field_var} {{")
                # Keep each generated line at roughly 16 bytes worth of elements.
                step = int(16 * 8 / field.width)
                for i in range(0, len(value), step):
                    checks.append('    ' + ' '.join([f"0x{v:x}," for v in value[i:i + step]]))
                checks.append("};")
                checks.append(f"ASSERT_EQ({get_field(decl, var, id)}, expected_{field_var});")

            elif (isinstance(field, ast.ArrayField) and isinstance(field.type, ast.EnumDeclaration)):
                checks.append(f"std::vector<{field.type_id}> expected_{field_var} {{")
                for v in value:
                    checks.append(f"    {field.type_id}({v}),")
                checks.append("};")
                checks.append(f"ASSERT_EQ({get_field(decl, var, id)}, expected_{field_var});")

            elif isinstance(field, ast.ArrayField):
                # Array of structs: check the element count, then recurse into
                # each element. NOTE: `value` is deliberately rebound to the
                # element by the enumerate loop below.
                checks.append(f"std::vector<{field.type_id}> {field_var} = {get_field(decl, var, id)};")
                checks.append(f"ASSERT_EQ({field_var}.size(), {len(value)});")
                for (n, value) in enumerate(value):
                    checks.extend(check_members(field.type, f"{field_var}[{n}]", value))

            else:
                pass

        return checks

    generated_tests = []
    for (test_nr, test) in enumerate(tests):
        # A test vector may target a packet derived from `packet`;
        # default to the packet itself when no override is given.
        child_packet_id = test.get('packet', packet.id)
        child_packet = packet.file.packet_scope[child_packet_id]

        generated_tests.append(
            dedent("""\

            TEST_F({parser_test_suite}, {packet_id}_Case{test_nr}) {{
                pdl::packet::slice input(std::shared_ptr<std::vector<uint8_t>>(new std::vector<uint8_t> {{
                    {input_bytes}
                }}));
                {child_packet_id}View packet = {parse_packet};
                ASSERT_TRUE(packet.IsValid());
                {checks}
            }}
            """).format(parser_test_suite=parser_test_suite,
                        packet_id=packet.id,
                        child_packet_id=child_packet_id,
                        test_nr=test_nr,
                        input_bytes=indent(input_bytes(test['packed']), 2),
                        parse_packet=parse_packet(child_packet),
                        checks=indent(check_members(packet, 'packet', test['unpacked']), 1)))

    return ''.join(generated_tests)


def generate_packet_serializer_test(serializer_test_suite: str, packet: ast.PacketDeclaration,
                                    tests: List[object]) -> str:
    """Generate the implementation of unit tests for the selected packet.

    One gtest TEST_F case is emitted per entry in `tests`: each case builds
    the packet from the unpacked field values and asserts that serializing
    it reproduces the packed byte vector.
    """

    def build_packet(decl: ast.Declaration, var: str, initializer: object) -> (str, List[str]):
        # Returns a (constructor-expression, supporting-declarations) pair.
        # Constrained parent fields are derived by the builder, so only
        # unconstrained ones take constructor parameters.
        fields = core.get_unconstrained_parent_fields(decl) + decl.fields
        declarations = []
        parameters = []
        for field in fields:
            # Derive a valid C++ identifier from the variable path
            # (array subscripts like `a[0]` become `a_0`).
            sanitized_var = var.replace('[', '_').replace(']', '')
            field_id = getattr(field, 'id', None)
            field_var = f'{sanitized_var}_{field_id}'
            # Payload/body values come from the 'payload' key; all other
            # fields are looked up by their id. NOTE(review): this assumes
            # a 'payload' entry is always present when the declaration has
            # a payload field — confirm against the test vector format.
            value = initializer['payload'] if isinstance(field, (ast.PayloadField,
                                                                 ast.BodyField)) else initializer.get(field_id, None)

            if isinstance(field, ast.ScalarField):
                parameters.append(f"{value}")

            elif isinstance(field, ast.TypedefField) and isinstance(field.type, ast.EnumDeclaration):
                parameters.append(f"{field.type_id}({value})")

            elif isinstance(field, ast.TypedefField):
                # Struct typedef: build the nested value recursively and
                # hoist its supporting declarations.
                (element, intermediate_declarations) = build_packet(field.type, field_var, value)
                declarations.extend(intermediate_declarations)
                parameters.append(element)

            elif isinstance(field, (ast.PayloadField, ast.BodyField)):
                declarations.append(f"std::vector<uint8_t> {field_var} {{")
                for i in range(0, len(value), 16):
                    declarations.append('    ' + ' '.join([f"0x{v:x}," for v in value[i:i + 16]]))
                declarations.append("};")
                parameters.append(f"std::move({field_var})")

            elif isinstance(field, ast.ArrayField) and field.width:
                declarations.append(f"std::vector<{get_cxx_scalar_type(field.width)}> {field_var} {{")
                # Keep each generated line at roughly 16 bytes worth of elements.
                step = int(16 * 8 / field.width)
                for i in range(0, len(value), step):
                    declarations.append('    ' + ' '.join([f"0x{v:x}," for v in value[i:i + step]]))
                declarations.append("};")
                parameters.append(f"std::move({field_var})")

            elif isinstance(field, ast.ArrayField) and isinstance(field.type, ast.EnumDeclaration):
                declarations.append(f"std::vector<{field.type_id}> {field_var} {{")
                for v in value:
                    declarations.append(f"    {field.type_id}({v}),")
                declarations.append("};")
                parameters.append(f"std::move({field_var})")

            elif isinstance(field, ast.ArrayField):
                # Array of structs: build each element recursively, then
                # aggregate them into a vector declaration.
                # NOTE: `value` is deliberately rebound by the enumerate loop.
                elements = []
                for (n, value) in enumerate(value):
                    (element, intermediate_declarations) = build_packet(field.type, f'{field_var}_{n}', value)
                    elements.append(element)
                    declarations.extend(intermediate_declarations)
                declarations.append(f"std::vector<{field.type_id}> {field_var} {{")
                for element in elements:
                    declarations.append(f"    {element},")
                declarations.append("};")
                parameters.append(f"std::move({field_var})")

            else:
                # Field kinds with no matching branch (e.g. checksum
                # placeholders) are intentionally skipped.
                pass

        # Packets are built through their generated Builder class;
        # structs are constructed directly.
        constructor_name = f'{decl.id}Builder' if isinstance(decl, ast.PacketDeclaration) else decl.id
        return (f"{constructor_name}({', '.join(parameters)})", declarations)

    def output_bytes(output: str) -> List[str]:
        # Render the hex-encoded expected output as C++ byte literals,
        # 16 bytes per generated source line.
        output = bytes.fromhex(output)
        output_bytes = []
        for i in range(0, len(output), 16):
            output_bytes.append(' '.join(f'0x{b:x},' for b in output[i:i + 16]))
        return output_bytes

    generated_tests = []
    for (test_nr, test) in enumerate(tests):
        # A test vector may target a packet derived from `packet`;
        # default to the packet itself when no override is given.
        child_packet_id = test.get('packet', packet.id)
        child_packet = packet.file.packet_scope[child_packet_id]

        (built_packet, intermediate_declarations) = build_packet(child_packet, 'packet', test['unpacked'])
        generated_tests.append(
            dedent("""\

            TEST_F({serializer_test_suite}, {packet_id}_Case{test_nr}) {{
                std::vector<uint8_t> expected_output {{
                    {output_bytes}
                }};
                {intermediate_declarations}
                {child_packet_id}Builder packet = {built_packet};
                ASSERT_EQ(packet.pdl::packet::Builder::Serialize(), expected_output);
            }}
            """).format(serializer_test_suite=serializer_test_suite,
                        packet_id=packet.id,
                        child_packet_id=child_packet_id,
                        test_nr=test_nr,
                        output_bytes=indent(output_bytes(test['packed']), 2),
                        built_packet=built_packet,
                        intermediate_declarations=indent(intermediate_declarations, 1)))

    return ''.join(generated_tests)


def run(input: argparse.FileType, output: argparse.FileType, test_vectors: argparse.FileType, include_header: List[str],
        using_namespace: List[str], namespace: str, parser_test_suite: str, serializer_test_suite: str):
    """Write the generated C++ unit test source to `output`.

    Args:
        input: PDL-JSON source stream.
        output: Destination stream for the generated C++ test file.
        test_vectors: JSON stream of canonical packed/unpacked test vectors.
        include_header: Header names to emit as #include directives.
        using_namespace: Namespaces to emit as `using namespace` statements.
        namespace: Namespace wrapping the generated test code.
        parser_test_suite: Class name of the generated parser test fixture.
        serializer_test_suite: Class name of the generated serializer test fixture.
    """

    file = ast.File.from_json(json.load(input))
    tests = json.load(test_vectors)
    core.desugar(file)

    # Rebind the list parameters to preformatted directive text.
    # (The comprehension variable `namespace` does not leak in Python 3,
    # so the `namespace` parameter used below is unaffected.)
    include_header = '\n'.join([f'#include <{header}>' for header in include_header])
    using_namespace = '\n'.join([f'using namespace {namespace};' for namespace in using_namespace])

    # Declarations exercising features the C++ generator does not support
    # (checksum fields, certain partial parents); no tests are emitted for them.
    skipped_tests = [
        'Packet_Checksum_Field_FromStart',
        'Packet_Checksum_Field_FromEnd',
        'Struct_Checksum_Field_FromStart',
        'Struct_Checksum_Field_FromEnd',
        'PartialParent5',
        'PartialParent12',
    ]

    # File prologue: provenance comment, includes, namespace, and the two
    # empty test fixture classes referenced by the generated TEST_F cases.
    output.write(
        dedent("""\
        // File generated from {input_name} and {test_vectors_name}, with the command:
        //  {input_command}
        // /!\\ Do not edit by hand

        #include <cstdint>
        #include <string>
        #include <gtest/gtest.h>
        #include <packet_runtime.h>

        {include_header}
        {using_namespace}

        namespace {namespace} {{

        class {parser_test_suite} : public testing::Test {{}};
        class {serializer_test_suite} : public testing::Test {{}};
        """).format(parser_test_suite=parser_test_suite,
                    serializer_test_suite=serializer_test_suite,
                    input_name=input.name,
                    input_command=' '.join(sys.argv),
                    test_vectors_name=test_vectors.name,
                    include_header=include_header,
                    using_namespace=using_namespace,
                    namespace=namespace))

    for decl in file.declarations:
        if decl.id in skipped_tests:
            continue

        if isinstance(decl, ast.PacketDeclaration):
            # Gather and flatten the test cases targeting this packet, then
            # emit both parser and serializer tests for it.
            matching_tests = [test['tests'] for test in tests if test['packet'] == decl.id]
            matching_tests = [test for test_list in matching_tests for test in test_list]
            if matching_tests:
                output.write(generate_packet_parser_test(parser_test_suite, decl, matching_tests))
                output.write(generate_packet_serializer_test(serializer_test_suite, decl, matching_tests))

    output.write(f"}}  // namespace {namespace}\n")


def main() -> int:
    """Parse command line arguments and generate the C++ backend unit tests.

    Returns the value of run(), which the entry guard passes to sys.exit().
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument('--input', type=argparse.FileType('r'), default=sys.stdin, help='Input PDL-JSON source')
    parser.add_argument('--output', type=argparse.FileType('w'), default=sys.stdout, help='Output C++ file')
    parser.add_argument('--test-vectors', type=argparse.FileType('r'), required=True, help='Input PDL test file')
    parser.add_argument('--namespace', type=str, default='pdl', help='Namespace of the generated file')
    parser.add_argument('--parser-test-suite', type=str, default='ParserTest', help='Name of the parser test suite')
    parser.add_argument('--serializer-test-suite',
                        type=str,
                        default='SerializerTest',
                        help='Name of the serializer test suite')
    parser.add_argument('--include-header', type=str, default=[], action='append', help='Added include directives')
    parser.add_argument('--using-namespace',
                        type=str,
                        default=[],
                        action='append',
                        help='Added using namespace statements')
    # Option names (dashes mapped to underscores) match run()'s parameters,
    # so the parsed namespace can be splatted directly.
    return run(**vars(parser.parse_args()))


if __name__ == '__main__':
    # Exit with main()'s return value (None is treated as success by sys.exit).
    sys.exit(main())
+1 −14
Original line number Diff line number Diff line
@@ -9,20 +9,7 @@ from textwrap import dedent
from typing import List, Tuple, Union, Optional

from pdl import ast, core


def indent(lines: List[str], depth: int) -> str:
    """Indent a code block to the selected depth.
    The first line is intentionally not indented so that
    the caller may use it as:

    '''
    def generated():
        {codeblock}
    '''
    """
    sep = '\n' + (' ' * (depth * 4))
    return sep.join(lines)
from pdl.utils import indent


def mask(width: int) -> str:
Loading