Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 77a4de0d authored by Henri Chataing's avatar Henri Chataing Committed by Gerrit Code Review
Browse files

Merge "PDL: Improve the Python backend generator"

parents 9e69e602 614f094c
Loading
Loading
Loading
Loading
+76 −36
Original line number Diff line number Diff line
@@ -39,7 +39,7 @@ def generate_prelude() -> str:

        @dataclass
        class Packet:
            payload: Optional[bytes] = field(repr=False, default_factory=bytes)
            payload: Optional[bytes] = field(repr=False, default_factory=bytes, compare=False)

            @classmethod
            def parse_all(cls, span: bytes) -> 'Packet':
@@ -86,7 +86,7 @@ def generate_prelude() -> str:
                        val.show(prefix=pp)

                    # Array fields.
                    elif typ.__origin__ == list:
                    elif getattr(typ, '__origin__', None) == list:
                        print(f'{p}{name:{align}}')
                        last = len(val) - 1
                        align = 5
@@ -95,6 +95,10 @@ def generate_prelude() -> str:
                            n_pp = pp + ('' if idx != last else '    ')
                            print_val(n_p, n_pp, f'[{idx}]', align, typ.__args__[0], val[idx])

                    # Custom fields.
                    elif inspect.isclass(typ):
                        print(f'{p}{name:{align}} = {repr(val)}')

                    else:
                        print(f'{p}{name:{align}} = ##{typ}##')

@@ -755,36 +759,43 @@ def generate_packet_parser(packet: ast.Declaration) -> List[str]:
    if packet_shift and packet.file.byteorder == 'big':
        raise Exception(f"Big-endian packet {packet.id} has an unsupported body shift")

    # Convert the packet constraints to a boolean expression.
    validation = []
    if packet.constraints:
        cond = []
        for c in packet.constraints:
            if c.value is not None:
                cond.append(f"fields['{c.id}'] != {hex(c.value)}")
            else:
                field = core.get_packet_field(packet, c.id)
                cond.append(f"fields['{c.id}'] != {field.type_id}.{c.tag_id}")

        validation = [f"if {' or '.join(cond)}:", "    raise Exception(\"Invalid constraint field values\")"]

    # Parse fields iteratively.
    parser = FieldParser(byteorder=packet.file.byteorder, shift=packet_shift)
    for f in packet.fields:
        parser.parse(f)
    parser.done()

    # Specialize to child packets.
    children = core.get_derived_packets(packet)
    decl = [] if packet.parent_id else ['fields = {\'payload\': None}']
    specialization = []

    if len(children) != 0:
        # Generate dissector on constrained fields, continue parsing the
        # child packets.
        code = decl + parser.code
        op = 'if'
        for constraints, child in children:
            cond = []
            for c in constraints:
                if c.value is not None:
                    cond.append(f"fields['{c.id}'] == {hex(c.value)}")
                else:
                    field = core.get_packet_field(packet, c.id)
                    cond.append(f"fields['{c.id}'] == {field.type_id}.{c.tag_id}")
            cond = ' and '.join(cond)
            code.append(f"{op} {cond}:")
            code.append(f"    return {child.id}.parse(fields, payload)")
            op = 'elif'

        code.append("else:")
        code.append(f"    return {packet.id}(**fields), span")
        return code
    else:
        return decl + parser.code + [f"return {packet.id}(**fields), span"]
        # Try parsing every child packet successively until one is
        # successfully parsed. Return a parsing error if none is valid.
        # Return parent packet if no child packet matches.
        # TODO: order child packets by decreasing size in case no constraint
        # is given for specialization.
        for _, child in children:
            specialization.append("try:")
            specialization.append(f"    return {child.id}.parse(fields, payload)")
            specialization.append("except Exception as exn:")
            specialization.append("    pass")

    return decl + validation + parser.code + specialization + [f"return {packet.id}(**fields), span"]


def generate_packet_size_getter(packet: ast.Declaration) -> List[str]:
@@ -822,6 +833,30 @@ def generate_packet_size_getter(packet: ast.Declaration) -> List[str]:
        assert False


def generate_packet_post_init(decl: ast.Declaration) -> List[str]:
    """Generate __post_init__ function to set constraint field values."""

    # Gather all constraints from parent packets.
    constraints = []
    current = decl
    while current.parent_id:
        constraints.extend(current.constraints)
        current = current.parent

    if constraints:
        code = []
        for c in constraints:
            if c.value is not None:
                code.append(f"self.{c.id} = {c.value}")
            else:
                field = core.get_packet_field(decl, c.id)
                code.append(f"self.{c.id} = {field.type_id}.{c.tag_id}")
        return code

    else:
        return ["pass"]


def generate_enum_declaration(decl: ast.EnumDeclaration) -> str:
    """Generate the implementation of an enum type."""

@@ -834,8 +869,7 @@ def generate_enum_declaration(decl: ast.EnumDeclaration) -> str:

        class {enum_name}(enum.IntEnum):
            {tag_decls}
        """).format(
        enum_name=enum_name, tag_decls=indent(tag_decls, 1))
        """).format(enum_name=enum_name, tag_decls=indent(tag_decls, 1))


def generate_packet_declaration(packet: ast.Declaration) -> str:
@@ -874,6 +908,7 @@ def generate_packet_declaration(packet: ast.Declaration) -> str:

    parser = generate_packet_parser(packet)
    size = generate_packet_size_getter(packet)
    post_init = generate_packet_post_init(packet)

    return dedent("""\

@@ -881,6 +916,9 @@ def generate_packet_declaration(packet: ast.Declaration) -> str:
        class {packet_name}({parent_name}):
            {field_decls}

            def __post_init__(self):
                {post_init}

            @staticmethod
            def parse({parent_fields}span: bytes) -> Tuple['{packet_name}', bytes]:
                {parser}
@@ -891,11 +929,11 @@ def generate_packet_declaration(packet: ast.Declaration) -> str:
            @property
            def size(self) -> int:
                {size}
        """).format(
        packet_name=packet_name,
        """).format(packet_name=packet_name,
                    parent_name=parent_name,
                    parent_fields=parent_fields,
                    field_decls=indent(field_decls, 1),
                    post_init=indent(post_init, 2),
                    parser=indent(parser, 2),
                    serializer=indent(serializer, 2),
                    size=indent(size, 2))
@@ -964,8 +1002,10 @@ def main() -> int:
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument('--input', type=argparse.FileType('r'), default=sys.stdin, help='Input PDL-JSON source')
    parser.add_argument('--output', type=argparse.FileType('w'), default=sys.stdout, help='Output Python file')
    parser.add_argument(
        '--custom-type-location', type=str, required=False, help='Module of declaration of custom types')
    parser.add_argument('--custom-type-location',
                        type=str,
                        required=False,
                        help='Module of declaration of custom types')
    return run(**vars(parser.parse_args()))