Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 614f094c authored by Henri Chataing's avatar Henri Chataing
Browse files

PDL: Improve the Python backend generator

- Fix specialization of packets when no constraint
  is provided (aliased packets or packets with no
  direct identification)
- Exclude payload from packet comparison
- Add post_init to set constraint values in child
  packet builder.

Test: atest --host pdl_python_generator_test
Change-Id: Ia1a1a90dac66cf43f8ea1b59854308d16aeaad24
parent 8b607516
Loading
Loading
Loading
Loading
+76 −36
Original line number Original line Diff line number Diff line
@@ -39,7 +39,7 @@ def generate_prelude() -> str:


        @dataclass
        @dataclass
        class Packet:
        class Packet:
            payload: Optional[bytes] = field(repr=False, default_factory=bytes)
            payload: Optional[bytes] = field(repr=False, default_factory=bytes, compare=False)


            @classmethod
            @classmethod
            def parse_all(cls, span: bytes) -> 'Packet':
            def parse_all(cls, span: bytes) -> 'Packet':
@@ -86,7 +86,7 @@ def generate_prelude() -> str:
                        val.show(prefix=pp)
                        val.show(prefix=pp)


                    # Array fields.
                    # Array fields.
                    elif typ.__origin__ == list:
                    elif getattr(typ, '__origin__', None) == list:
                        print(f'{p}{name:{align}}')
                        print(f'{p}{name:{align}}')
                        last = len(val) - 1
                        last = len(val) - 1
                        align = 5
                        align = 5
@@ -95,6 +95,10 @@ def generate_prelude() -> str:
                            n_pp = pp + ('' if idx != last else '    ')
                            n_pp = pp + ('' if idx != last else '    ')
                            print_val(n_p, n_pp, f'[{idx}]', align, typ.__args__[0], val[idx])
                            print_val(n_p, n_pp, f'[{idx}]', align, typ.__args__[0], val[idx])


                    # Custom fields.
                    elif inspect.isclass(typ):
                        print(f'{p}{name:{align}} = {repr(val)}')

                    else:
                    else:
                        print(f'{p}{name:{align}} = ##{typ}##')
                        print(f'{p}{name:{align}} = ##{typ}##')


@@ -755,36 +759,43 @@ def generate_packet_parser(packet: ast.Declaration) -> List[str]:
    if packet_shift and packet.file.byteorder == 'big':
    if packet_shift and packet.file.byteorder == 'big':
        raise Exception(f"Big-endian packet {packet.id} has an unsupported body shift")
        raise Exception(f"Big-endian packet {packet.id} has an unsupported body shift")


    # Convert the packet constraints to a boolean expression.
    validation = []
    if packet.constraints:
        cond = []
        for c in packet.constraints:
            if c.value is not None:
                cond.append(f"fields['{c.id}'] != {hex(c.value)}")
            else:
                field = core.get_packet_field(packet, c.id)
                cond.append(f"fields['{c.id}'] != {field.type_id}.{c.tag_id}")

        validation = [f"if {' or '.join(cond)}:", "    raise Exception(\"Invalid constraint field values\")"]

    # Parse fields iteratively.
    parser = FieldParser(byteorder=packet.file.byteorder, shift=packet_shift)
    parser = FieldParser(byteorder=packet.file.byteorder, shift=packet_shift)
    for f in packet.fields:
    for f in packet.fields:
        parser.parse(f)
        parser.parse(f)
    parser.done()
    parser.done()

    # Specialize to child packets.
    children = core.get_derived_packets(packet)
    children = core.get_derived_packets(packet)
    decl = [] if packet.parent_id else ['fields = {\'payload\': None}']
    decl = [] if packet.parent_id else ['fields = {\'payload\': None}']
    specialization = []


    if len(children) != 0:
    if len(children) != 0:
        # Generate dissector on constrained fields, continue parsing the
        # Try parsing every child packet successively until one is
        # child packets.
        # successfully parsed. Return a parsing error if none is valid.
        code = decl + parser.code
        # Return parent packet if no child packet matches.
        op = 'if'
        # TODO: order child packets by decreasing size in case no constraint
        for constraints, child in children:
        # is given for specialization.
            cond = []
        for _, child in children:
            for c in constraints:
            specialization.append("try:")
                if c.value is not None:
            specialization.append(f"    return {child.id}.parse(fields, payload)")
                    cond.append(f"fields['{c.id}'] == {hex(c.value)}")
            specialization.append("except Exception as exn:")
                else:
            specialization.append("    pass")
                    field = core.get_packet_field(packet, c.id)

                    cond.append(f"fields['{c.id}'] == {field.type_id}.{c.tag_id}")
    return decl + validation + parser.code + specialization + [f"return {packet.id}(**fields), span"]
            cond = ' and '.join(cond)
            code.append(f"{op} {cond}:")
            code.append(f"    return {child.id}.parse(fields, payload)")
            op = 'elif'

        code.append("else:")
        code.append(f"    return {packet.id}(**fields), span")
        return code
    else:
        return decl + parser.code + [f"return {packet.id}(**fields), span"]




def generate_packet_size_getter(packet: ast.Declaration) -> List[str]:
def generate_packet_size_getter(packet: ast.Declaration) -> List[str]:
@@ -822,6 +833,30 @@ def generate_packet_size_getter(packet: ast.Declaration) -> List[str]:
        assert False
        assert False




def generate_packet_post_init(decl: ast.Declaration) -> List[str]:
    """Generate __post_init__ function to set constraint field values."""

    # Gather all constraints from parent packets.
    constraints = []
    current = decl
    while current.parent_id:
        constraints.extend(current.constraints)
        current = current.parent

    if constraints:
        code = []
        for c in constraints:
            if c.value is not None:
                code.append(f"self.{c.id} = {c.value}")
            else:
                field = core.get_packet_field(decl, c.id)
                code.append(f"self.{c.id} = {field.type_id}.{c.tag_id}")
        return code

    else:
        return ["pass"]


def generate_enum_declaration(decl: ast.EnumDeclaration) -> str:
def generate_enum_declaration(decl: ast.EnumDeclaration) -> str:
    """Generate the implementation of an enum type."""
    """Generate the implementation of an enum type."""


@@ -834,8 +869,7 @@ def generate_enum_declaration(decl: ast.EnumDeclaration) -> str:


        class {enum_name}(enum.IntEnum):
        class {enum_name}(enum.IntEnum):
            {tag_decls}
            {tag_decls}
        """).format(
        """).format(enum_name=enum_name, tag_decls=indent(tag_decls, 1))
        enum_name=enum_name, tag_decls=indent(tag_decls, 1))




def generate_packet_declaration(packet: ast.Declaration) -> str:
def generate_packet_declaration(packet: ast.Declaration) -> str:
@@ -874,6 +908,7 @@ def generate_packet_declaration(packet: ast.Declaration) -> str:


    parser = generate_packet_parser(packet)
    parser = generate_packet_parser(packet)
    size = generate_packet_size_getter(packet)
    size = generate_packet_size_getter(packet)
    post_init = generate_packet_post_init(packet)


    return dedent("""\
    return dedent("""\


@@ -881,6 +916,9 @@ def generate_packet_declaration(packet: ast.Declaration) -> str:
        class {packet_name}({parent_name}):
        class {packet_name}({parent_name}):
            {field_decls}
            {field_decls}


            def __post_init__(self):
                {post_init}

            @staticmethod
            @staticmethod
            def parse({parent_fields}span: bytes) -> Tuple['{packet_name}', bytes]:
            def parse({parent_fields}span: bytes) -> Tuple['{packet_name}', bytes]:
                {parser}
                {parser}
@@ -891,11 +929,11 @@ def generate_packet_declaration(packet: ast.Declaration) -> str:
            @property
            @property
            def size(self) -> int:
            def size(self) -> int:
                {size}
                {size}
        """).format(
        """).format(packet_name=packet_name,
        packet_name=packet_name,
                    parent_name=parent_name,
                    parent_name=parent_name,
                    parent_fields=parent_fields,
                    parent_fields=parent_fields,
                    field_decls=indent(field_decls, 1),
                    field_decls=indent(field_decls, 1),
                    post_init=indent(post_init, 2),
                    parser=indent(parser, 2),
                    parser=indent(parser, 2),
                    serializer=indent(serializer, 2),
                    serializer=indent(serializer, 2),
                    size=indent(size, 2))
                    size=indent(size, 2))
@@ -964,8 +1002,10 @@ def main() -> int:
    parser = argparse.ArgumentParser(description=__doc__)
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument('--input', type=argparse.FileType('r'), default=sys.stdin, help='Input PDL-JSON source')
    parser.add_argument('--input', type=argparse.FileType('r'), default=sys.stdin, help='Input PDL-JSON source')
    parser.add_argument('--output', type=argparse.FileType('w'), default=sys.stdout, help='Output Python file')
    parser.add_argument('--output', type=argparse.FileType('w'), default=sys.stdout, help='Output Python file')
    parser.add_argument(
    parser.add_argument('--custom-type-location',
        '--custom-type-location', type=str, required=False, help='Module of declaration of custom types')
                        type=str,
                        required=False,
                        help='Module of declaration of custom types')
    return run(**vars(parser.parse_args()))
    return run(**vars(parser.parse_args()))