
Commit 9ac83174 authored by Henri Chataing, committed by Automerger Merge Worker

Merge "pdl: Complete the python backend generator" am: 8128c60f

parents 8a9d4424 8128c60f
Android.bp +27 −10
@@ -82,18 +82,35 @@ rust_test_host {
    ],
}

// Python generator.
python_binary_host {
    name: "pdl_python_generator",
    main: "scripts/generate_python_backend.py",
    srcs: [
        "scripts/generate_python_backend.py",
        "scripts/pdl/ast.py",
        "scripts/pdl/core.py",
    ]
}

// Defaults for PDL python backend generation.
genrule_defaults {
    name: "pdl_python_generator_defaults",
    tools: [
        ":pdl",
        ":pdl_python_generator",
    ],
}

// Generate the python parser+serializer backend for the
// little endian test file located at tests/canonical/le_test_file.pdl.
genrule {
    name: "pdl_python_generator_le_test_gen",
    defaults: [ "pdl_python_generator_defaults" ],
    cmd: "$(location :pdl) $(in) |" +
        " $(location scripts/generate_python_backend.py)" +
        " $(location :pdl_python_generator)" +
        " --output $(out) --custom-type-location tests.custom_types",
    tools: [ ":pdl" ],
    tool_files: [
        "scripts/generate_python_backend.py",
        "scripts/pdl/core.py",
        "scripts/pdl/ast.py",
        "tests/custom_types.py",
    ],
    srcs: [
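
For reference, the cmd above chains the two tools: `pdl` parses the grammar and writes the AST to stdout, and the generator script turns that AST into a Python parser+serializer module. A rough manual equivalent, assuming the generator reads the AST from stdin when no input file is given (which is what the pipe in the cmd implies):

```python
import subprocess

# Parse the PDL grammar; `pdl` emits the parsed AST on stdout.
ast = subprocess.run(
    ["pdl", "tests/canonical/le_test_file.pdl"],
    check=True, capture_output=True).stdout

# Feed the AST to the backend generator to produce the Python module.
subprocess.run(
    ["python3", "scripts/generate_python_backend.py",
     "--output", "le_backend.py",
     "--custom-type-location", "tests.custom_types"],
    input=ast, check=True)
```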
@@ -108,14 +125,11 @@ genrule {
// big endian test file located at tests/canonical/be_test_file.pdl.
genrule {
    name: "pdl_python_generator_be_test_gen",
    defaults: [ "pdl_python_generator_defaults" ],
    cmd: "$(location :pdl) $(in) |" +
        " $(location scripts/generate_python_backend.py)" +
        " $(location :pdl_python_generator)" +
        " --output $(out) --custom-type-location tests.custom_types",
    tools: [ ":pdl" ],
    tool_files: [
        "scripts/generate_python_backend.py",
        "scripts/pdl/core.py",
        "scripts/pdl/ast.py",
        "tests/custom_types.py",
    ],
    srcs: [
@@ -141,6 +155,9 @@ python_test_host {
        "tests/canonical/le_test_vectors.json",
        "tests/canonical/be_test_vectors.json",
    ],
    libs: [
        "typing_extensions",
    ],
    test_options: {
        unit_test: true,
    },
scripts/generate_python_backend.py +322 −29
@@ -39,7 +39,7 @@ def generate_prelude() -> str:

        @dataclass
        class Packet:
-            payload: Optional[bytes] = field(repr=False)
+            payload: Optional[bytes] = field(repr=False, default_factory=bytes)

            @classmethod
            def parse_all(cls, span: bytes) -> 'Packet':
@@ -48,6 +48,10 @@ def generate_prelude() -> str:
                    raise Exception('Unexpected parsing remainder')
                return packet

            @property
            def size(self) -> int:
                pass

            def show(self, prefix: str = ''):
                print(f'{self.__class__.__name__}')
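
The prelude defines the `Packet` base class that every generated packet derives from. A usage sketch, with `TestPacket` standing in for a hypothetical class emitted by the generator:

```python
from le_backend import TestPacket  # hypothetical generated module

# parse_all() parses the whole buffer and raises if bytes remain.
pkt = TestPacket.parse_all(bytes.fromhex("0102"))

# The size property (added in this change) reports the serialized length,
# and serialize() round-trips the packet back to bytes.
assert pkt.size == len(pkt.serialize())
pkt.show()
```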

@@ -141,11 +145,11 @@ class FieldParser:
            self.check_size_(str(self.offset))
            self.code.extend(unchecked_code)

-    def consume_span_(self) -> str:
+    def consume_span_(self, keep: int = 0) -> str:
        """Skip consumed span bytes."""
        if self.offset > 0:
            self.check_code_()
-            self.append_(f'span = span[{self.offset}:]')
+            self.append_(f'span = span[{self.offset - keep}:]')
            self.offset = 0

    def parse_array_element_dynamic_(self, field: ast.ArrayField, span: str):
@@ -196,9 +200,7 @@ class FieldParser:

        # Apply the size modifier.
        if field.size_modifier and size:
            self.append_(f"{size} = {size} {field.size_modifier}")
        if field.size_modifier and count:
            self.append_(f"{count} = {count} {field.size_modifier}")
            self.append_(f"{size} = {size} - {field.size_modifier}")

        # The element width is not known, but the array full octet size
        # is known by size field. Parse elements item by item as a vector.
@@ -288,7 +290,7 @@ class FieldParser:
            value = "value_"

        for shift, width, field in self.chunk:
-            v = (value if len(self.chunk) == 1 else f"({value} >> {shift}) & {mask(width)}")
+            v = (value if len(self.chunk) == 1 and shift == 0 else f"({value} >> {shift}) & {mask(width)}")

            if isinstance(field, ast.ScalarField):
                self.unchecked_append_(f"fields['{field.id}'] = {v}")
@@ -349,14 +351,29 @@ class FieldParser:
    def parse_payload_field_(self, field: Union[ast.BodyField, ast.PayloadField]):
        """Parse body and payload fields."""

-        size = core.get_payload_field_size(field)
+        payload_size = core.get_payload_field_size(field)
        offset_from_end = core.get_field_offset_from_end(field)

        # If the payload is not byte aligned, do parse the bit fields
        # that can be extracted, but do not consume the input bytes as
        # they will also be included in the payload span.
        if self.shift != 0:
            if payload_size:
                raise Exception("Unexpected payload size for non byte aligned payload")

            rounded_size = int((self.shift + 7) / 8)
            padding_bits = 8 * rounded_size - self.shift
            self.parse_bit_field_(core.make_reserved_field(padding_bits))
            self.consume_span_(rounded_size)
        else:
            self.consume_span_()

        # The payload or body has a known size.
        # Consume the payload and update the span in case
        # fields are placed after the payload.
-        if size:
+        if payload_size:
            if getattr(field, 'size_modifier', None):
                self.append_(f"{field.id}_size -= {field.size_modifier}")
            self.check_size_(f'{field.id}_size')
            self.append_(f"payload = span[:{field.id}_size]")
            self.append_(f"span = span[{field.id}_size:]")
@@ -429,7 +446,7 @@ class FieldParser:
            checksum_span = f'span[:-{offset_from_end}]'
            if value_size > 1:
                start = offset_from_end
-                end = offset_from_start - value_size
+                end = offset_from_end - value_size
                value = f"int.from_bytes(span[-{start}:-{end}], byteorder='{self.byteorder}')"
            else:
                value = f'span[-{offset_from_end}]'
@@ -479,23 +496,262 @@ class FieldParser:
        self.consume_span_()


@dataclass
class FieldSerializer:
    byteorder: str
    shift: int = 0
    value: List[str] = field(default_factory=lambda: [])
    code: List[str] = field(default_factory=lambda: [])
    indent: int = 0

    def indent_(self):
        self.indent += 1

    def unindent_(self):
        self.indent -= 1

    def append_(self, line: str):
        """Append field serializing code."""
        lines = line.split('\n')
        self.code.extend(['    ' * self.indent + line for line in lines])

    def extend_(self, value: str, length: int):
        """Append data to the span being constructed."""
        if length == 1:
            self.append_(f"_span.append({value})")
        else:
            self.append_(f"_span.extend(int.to_bytes({value}, length={length}, byteorder='{self.byteorder}'))")

    def serialize_array_element_(self, field: ast.ArrayField):
        """Serialize a single array field element."""
        if field.width is not None:
            length = int(field.width / 8)
            self.extend_('_elt', length)
        elif isinstance(field.type, ast.EnumDeclaration):
            length = int(field.type.width / 8)
            self.extend_('_elt', length)
        else:
            self.append_("_span.extend(_elt.serialize())")

    def serialize_array_field_(self, field: ast.ArrayField):
        """Serialize the selected array field."""
        if field.width == 8:
            self.append_(f"_span.extend(self.{field.id})")
        else:
            self.append_(f"for _elt in self.{field.id}:")
            self.indent_()
            self.serialize_array_element_(field)
            self.unindent_()
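
Evaluated by hand, the two strategies produce the following spans (little endian and a 16-bit element width assumed for illustration):

```python
_span = bytearray()

# width == 8: the array is raw bytes and is copied wholesale.
_span.extend(bytes([1, 2, 3]))

# Otherwise each element is serialized individually.
for _elt in [0x0102, 0x0304]:
    _span.extend(int.to_bytes(_elt, length=2, byteorder='little'))

assert _span.hex() == '010203' + '0201' + '0403'
```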

    def serialize_bit_field_(self, field: ast.Field):
        """Serialize the selected field as a bit field.
        The field is added to the current chunk. When a byte boundary
        is reached all saved fields are serialized together."""

        # Add to current chunk.
        width = core.get_field_size(field)
        shift = self.shift

        if isinstance(field, str):
            self.value.append(f"({field} << {shift})")
        elif isinstance(field, ast.ScalarField):
            max_value = (1 << field.width) - 1
            self.append_(f"if self.{field.id} > {max_value}:")
            self.append_(f"    print(f\"Invalid value for field {field.parent.id}::{field.id}:" +
                         f" {{self.{field.id}}} > {max_value}; the value will be truncated\")")
            self.append_(f"    self.{field.id} &= {max_value}")
            self.value.append(f"(self.{field.id} << {shift})")
        elif isinstance(field, ast.FixedField) and field.enum_id:
            self.value.append(f"({field.enum_id}.{field.tag_id} << {shift})")
        elif isinstance(field, ast.FixedField):
            self.value.append(f"({field.value} << {shift})")
        elif isinstance(field, ast.TypedefField):
            self.value.append(f"(self.{field.id} << {shift})")

        elif isinstance(field, ast.SizeField):
            max_size = (1 << field.width) - 1
            value_field = core.get_packet_field(field.parent, field.field_id)
            size_modifier = ''

            if getattr(value_field, 'size_modifier', None):
                size_modifier = f' + {value_field.size_modifier}'

            if isinstance(value_field, (ast.PayloadField, ast.BodyField)):
                self.append_(f"_size = len(payload or self.payload or []){size_modifier}")
                self.append_(f"if _size > {max_size}:")
                self.append_(f"    print(f\"Invalid length for payload field:" +
                             f"  {{_size}} > {max_size}; the packet cannot be generated\")")
                self.append_(f"    raise Exception(\"Invalid payload length\")")
                array_size = "_size"
            elif isinstance(value_field, ast.ArrayField) and value_field.width:
                array_size = f"(len(self.{value_field.id}) * {int(value_field.width / 8)})"
            elif isinstance(value_field, ast.ArrayField):
                self.append_(f"_size = sum([elt.size for elt in self.{value_field.id}]){size_modifier}")
                array_size = "_size"
            else:
                raise Exception("Unsupported field type")
            self.value.append(f"({array_size} << {shift})")

        elif isinstance(field, ast.CountField):
            max_count = (1 << field.width) - 1
            self.append_(f"if len(self.{field.field_id}) > {max_count}:")
            self.append_(f"    print(f\"Invalid length for field {field.parent.id}::{field.field_id}:" +
                         f"  {{len(self.{field.field_id})}} > {max_count}; the array will be truncated\")")
            self.append_(f"    del self.{field.field_id}[{max_count}:]")
            self.value.append(f"(len(self.{field.field_id}) << {shift})")
        elif isinstance(field, ast.ReservedField):
            pass
        else:
            raise Exception(f'Unsupported bit field type {field.kind}')

        # Check if a byte boundary is reached.
        self.shift += width
        if (self.shift % 8) == 0:
            self.pack_bit_fields_()

    def pack_bit_fields_(self):
        """Pack serialized bit fields."""

        # Should have an integral number of bytes now.
        assert (self.shift % 8) == 0

        # Generate the backing integer, and serialize it
        # using the configured endianness.
        size = int(self.shift / 8)

        if len(self.value) == 0:
            self.append_(f"_span.extend([0] * {size})")
        elif len(self.value) == 1:
            self.extend_(self.value[0], size)
        else:
            self.append_(f"_value = (")
            self.append_("    " + " |\n    ".join(self.value))
            self.append_(")")
            self.extend_('_value', size)

        # Reset state.
        self.shift = 0
        self.value = []
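
A concrete trace: two 4-bit fields `a` (shift 0) and `b` (shift 4) reach a byte boundary together, so the queued values are ORed into one backing integer and emitted as a single byte. Hand-evaluating the generated code (little endian assumed):

```python
a, b = 0x3, 0xA
_span = bytearray()

# serialize_bit_field_ queued "(a << 0)" and "(b << 4)";
# pack_bit_fields_ combines them and serializes the result.
_value = (a << 0) | (b << 4)
_span.extend(int.to_bytes(_value, length=1, byteorder='little'))

assert _span == bytearray([0xa3])
```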

    def serialize_typedef_field_(self, field: ast.TypedefField):
        """Serialize a typedef field, to the exclusion of Enum fields."""

        if self.shift != 0:
            raise Exception('Typedef field does not start on an octet boundary')
        if (isinstance(field.type, ast.StructDeclaration) and field.type.parent_id is not None):
            raise Exception('Derived struct used in typedef field')

        if isinstance(field.type, ast.ChecksumDeclaration):
            size = int(field.type.width / 8)
            self.append_(f"_checksum = {field.type.function}(_span[_checksum_start:])")
            self.extend_('_checksum', size)
        else:
            self.append_(f"_span.extend(self.{field.id}.serialize())")

    def serialize_padding_field_(self, field: ast.PaddingField):
        """Serialize a padding field. The value is zero."""

        if self.shift != 0:
            raise Exception('Padding field does not start on an octet boundary')
        self.append_(f"_span.extend([0] * {field.width})")

    def serialize_payload_field_(self, field: Union[ast.BodyField, ast.PayloadField]):
        """Serialize body and payload fields."""

        if self.shift != 0 and self.byteorder == 'big':
            raise Exception('Payload field does not start on an octet boundary')

        if self.shift == 0:
            self.append_(f"_span.extend(payload or self.payload or [])")
        else:
            # Supported case of packet inheritance;
            # the incomplete fields are serialized into
            # the payload, rather than separately.
            # First extract the padding bits from the payload,
            # then recombine them with the bit fields to be serialized.
            rounded_size = int((self.shift + 7) / 8)
            padding_bits = 8 * rounded_size - self.shift
            self.append_(f"_payload = payload or self.payload or bytes()")
            self.append_(f"if len(_payload) < {rounded_size}:")
            self.append_(f"    raise Exception(f\"Invalid length for payload field:" +
                         f"  {{len(_payload)}} < {rounded_size}\")")
            self.append_(
                f"_padding = int.from_bytes(_payload[:{rounded_size}], byteorder='{self.byteorder}') >> {self.shift}")
            self.value.append(f"(_padding << {self.shift})")
            self.shift += padding_bits
            self.pack_bit_fields_()
            self.append_(f"_span.extend(_payload[{rounded_size}:])")

    def serialize_checksum_field_(self, field: ast.ChecksumField):
        """Generate a checksum check."""

        self.append_("_checksum_start = len(_span)")

    def serialize(self, field: ast.Field):
        # Field has bit granularity.
        # Append the field to the current chunk,
        # check if a byte boundary was reached.
        if core.is_bit_field(field):
            self.serialize_bit_field_(field)

        # Padding fields.
        elif isinstance(field, ast.PaddingField):
            self.serialize_padding_field_(field)

        # Array fields.
        elif isinstance(field, ast.ArrayField):
            self.serialize_array_field_(field)

        # Other typedef fields.
        elif isinstance(field, ast.TypedefField):
            self.serialize_typedef_field_(field)

        # Payload and body fields.
        elif isinstance(field, (ast.PayloadField, ast.BodyField)):
            self.serialize_payload_field_(field)

        # Checksum fields.
        elif isinstance(field, ast.ChecksumField):
            self.serialize_checksum_field_(field)

        else:
            raise Exception(f'Unimplemented field type {field.kind}')


def generate_toplevel_packet_serializer(packet: ast.Declaration) -> List[str]:
    """Generate the serialize() function for a toplevel Packet or Struct
       declaration."""
    return ["pass"]

    serializer = FieldSerializer(byteorder=packet.file.byteorder)
    for f in packet.fields:
        serializer.serialize(f)
    return ['_span = bytearray()'] + serializer.code + ['return bytes(_span)']


def generate_derived_packet_serializer(packet: ast.Declaration) -> List[str]:
    """Generate the serialize() function for a derived Packet or Struct
       declaration."""
    return ["pass"]

    packet_shift = core.get_packet_shift(packet)
    if packet_shift and packet.file.byteorder == 'big':
        raise Exception(f"Big-endian packet {packet.id} has an unsupported body shift")

    serializer = FieldSerializer(byteorder=packet.file.byteorder, shift=packet_shift)
    for f in packet.fields:
        serializer.serialize(f)
    return ['_span = bytearray()'
           ] + serializer.code + [f'return {packet.parent.id}.serialize(self, payload = bytes(_span))']


def generate_packet_parser(packet: ast.Declaration) -> List[str]:
    """Generate the parse() function for a toplevel Packet or Struct
       declaration."""

-    parser = FieldParser(byteorder=packet.file.byteorder)
    packet_shift = core.get_packet_shift(packet)
    if packet_shift and packet.file.byteorder == 'big':
        raise Exception(f"Big-endian packet {packet.id} has an unsupported body shift")

    parser = FieldParser(byteorder=packet.file.byteorder, shift=packet_shift)
    for f in packet.fields:
        parser.parse(f)
    parser.done()
@@ -527,15 +783,39 @@ def generate_packet_parser(packet: ast.Declaration) -> List[str]:
        return decl + parser.code + [f"return {packet.id}(**fields), span"]


-def generate_derived_packet_parser(packet: ast.Declaration) -> List[str]:
-    """Generate the parse() function for a derived Packet or Struct
-       declaration."""
-    print(f"Parsing packet {packet.id}", file=sys.stderr)
-    parser = FieldParser(byteorder=packet.file.byteorder)
+def generate_packet_size_getter(packet: ast.Declaration) -> List[str]:
    constant_width = 0
    variable_width = []
    for f in packet.fields:
-        parser.parse(f)
-    parser.done()
-    return parser.code + [f"return {packet.id}(**fields)"]
        field_size = core.get_field_size(f)
        if field_size is not None:
            constant_width += field_size
        elif isinstance(f, (ast.PayloadField, ast.BodyField)):
            variable_width.append("len(self.payload)")
        elif isinstance(f, ast.TypedefField):
            variable_width.append(f"self.{f.id}.size")
        elif isinstance(f, ast.ArrayField) and isinstance(f.type, (ast.StructDeclaration, ast.CustomFieldDeclaration)):
            variable_width.append(f"sum([elt.size for elt in self.{f.id}])")
        elif isinstance(f, ast.ArrayField) and isinstance(f.type, ast.EnumDeclaration):
            variable_width.append(f"len(self.{f.id}) * {f.type.width}")
        elif isinstance(f, ast.ArrayField):
            variable_width.append(f"len(self.{f.id}) * {int(f.width / 8)}")
        else:
            raise Exception("Unsupported field type")

    constant_width = int(constant_width / 8)
    if len(variable_width) == 0:
        return [f"return {constant_width}"]
    elif len(variable_width) == 1 and constant_width:
        return [f"return {variable_width[0]} + {constant_width}"]
    elif len(variable_width) == 1:
        return [f"return {variable_width[0]}"]
    elif len(variable_width) > 1 and constant_width:
        return ([f"return {constant_width} + ("] + " +\n    ".join(variable_width).split("\n") + [")"])
    elif len(variable_width) > 1:
        return (["return ("] + " +\n    ".join(variable_width).split("\n") + [")"])
    else:
        assert False
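
To see the branches in action: a hypothetical packet with one 8-bit scalar plus a payload has a constant width of 8 bits and a single variable term, which lands in the `len(variable_width) == 1 and constant_width` branch:

```python
constant_width = 8                      # bits from fixed-size fields
variable_width = ["len(self.payload)"]  # resolved at runtime

constant_width = int(constant_width / 8)
body = [f"return {variable_width[0]} + {constant_width}"]

assert body == ["return len(self.payload) + 1"]
```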


def generate_enum_declaration(decl: ast.EnumDeclaration) -> str:
@@ -562,15 +842,22 @@ def generate_packet_declaration(packet: ast.Declaration) -> str:
    field_decls = []
    for f in packet.fields:
        if isinstance(f, ast.ScalarField):
            field_decls.append(f"{f.id}: int")
            field_decls.append(f"{f.id}: int = 0")
        elif isinstance(f, ast.TypedefField):
            field_decls.append(f"{f.id}: {f.type_id}")
            if isinstance(f.type, ast.EnumDeclaration):
                field_decls.append(f"{f.id}: {f.type_id} = {f.type_id}.{f.type.tags[0].id}")
            elif isinstance(f.type, ast.ChecksumDeclaration):
                field_decls.append(f"{f.id}: int = 0")
            elif isinstance(f.type, (ast.StructDeclaration, ast.CustomFieldDeclaration)):
                field_decls.append(f"{f.id}: {f.type_id} = field(default_factory={f.type_id})")
            else:
                raise Exception("Unsupported typedef field type")
        elif isinstance(f, ast.ArrayField) and f.width == 8:
            field_decls.append(f"{f.id}: bytes")
            field_decls.append(f"{f.id}: bytearray = field(default_factory=bytearray)")
        elif isinstance(f, ast.ArrayField) and f.width:
            field_decls.append(f"{f.id}: List[int]")
            field_decls.append(f"{f.id}: List[int] = field(default_factory=list)")
        elif isinstance(f, ast.ArrayField) and f.type_id:
            field_decls.append(f"{f.id}: List[{f.type_id}]")
            field_decls.append(f"{f.id}: List[{f.type_id}] = field(default_factory=list)")

    if packet.parent_id:
        parent_name = packet.parent_id
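
Taken together, the defaults above make every generated dataclass constructible with no arguments. The emitted declaration for a hypothetical packet with a scalar, a byte array, a struct field, and a scalar array would look roughly like:

```python
from dataclasses import dataclass, field
from typing import List

@dataclass
class SomeStruct:  # hypothetical struct declaration
    x: int = 0

@dataclass
class Example:  # shape of a generated packet class
    a: int = 0                                         # ScalarField
    b: bytearray = field(default_factory=bytearray)    # ArrayField, width == 8
    c: SomeStruct = field(default_factory=SomeStruct)  # TypedefField (struct)
    d: List[int] = field(default_factory=list)         # ArrayField, scalar elements

pkt = Example()  # constructible with no arguments
```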
@@ -582,6 +869,7 @@ def generate_packet_declaration(packet: ast.Declaration) -> str:
        serializer = generate_toplevel_packet_serializer(packet)

    parser = generate_packet_parser(packet)
    size = generate_packet_size_getter(packet)

    return dedent("""\

@@ -593,15 +881,20 @@ def generate_packet_declaration(packet: ast.Declaration) -> str:
            def parse({parent_fields}span: bytes) -> Tuple['{packet_name}', bytes]:
                {parser}

-            def serialize(self) -> bytes:
+            def serialize(self, payload: bytes = None) -> bytes:
                {serializer}

            @property
            def size(self) -> int:
                {size}
        """).format(
        packet_name=packet_name,
        parent_name=parent_name,
        parent_fields=parent_fields,
        field_decls=indent(field_decls, 1),
        parser=indent(parser, 2),
-        serializer=indent(serializer, 2))
+        serializer=indent(serializer, 2),
+        size=indent(size, 2))


def generate_custom_field_declaration_check(decl: ast.CustomFieldDeclaration) -> str:
scripts/pdl/core.py +52 −0
@@ -54,6 +54,11 @@ def desugar(file: File):
    file.group_scope = {}


def make_reserved_field(width: int) -> ReservedField:
    """Create a reserved field of specified width."""
    return ReservedField(kind='reserved_field', loc=None, width=width)


def get_packet_field(packet: Union[PacketDeclaration, StructDeclaration], id: str) -> Optional[Field]:
    """Return the field with selected identifier declared in the provided
    packet or its ancestors."""
@@ -70,6 +75,53 @@ def get_packet_field(packet: Union[PacketDeclaration, StructDeclaration], id: st
        return None


def get_packet_shift(packet: Union[PacketDeclaration, StructDeclaration]) -> int:
    """Return the bit shift of the payload or body field in the parent packet.

    When using packet derivation on bit fields, the body may be shifted.
    The shift is handled statically in the implementation of child packets,
    and the incomplete field is included in the body.
    ```
    packet Basic {
        type: 1,
        _body_
    }
    ```
    """

    # Traverse empty parents.
    parent = packet.parent
    while parent and len(parent.fields) == 1:
        parent = parent.parent

    if not parent:
        return 0

    shift = 0
    for f in parent.fields:
        if isinstance(f, (BodyField, PayloadField)):
            return 0 if (shift % 8) == 0 else shift
        else:
            # Fields that do not have a constant size are assumed to start
            # on a byte boundary, and measure an integral number of bytes.
            # Start the count over.
            size = get_field_size(f)
            shift = 0 if size is None else shift + size

    # No payload or body in parent packet.
    # Not raising an error, the generation will fail somewhere else.
    return 0
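
For the `Basic` packet in the docstring, the single 1-bit `type` field preceding `_body_` yields a shift of 1, so derived packets serialize their own fields one bit deep into the payload. The traversal reduced to its arithmetic:

```python
field_sizes = [1]  # bit widths of the fields before _body_ in Basic
shift = 0
for size in field_sizes:
    shift = 0 if size is None else shift + size

assert (0 if shift % 8 == 0 else shift) == 1
```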


def get_packet_ancestor(
        decl: Union[PacketDeclaration, StructDeclaration]) -> Union[PacketDeclaration, StructDeclaration]:
    """Return the root ancestor of the selected packet or struct."""
    if decl.parent_id is None:
        return decl
    else:
        return get_packet_ancestor(decl.grammar.packet_scope[decl.parent_id])


def get_derived_packets(decl: Union[PacketDeclaration, StructDeclaration]
                       ) -> List[Tuple[List[Constraint], Union[PacketDeclaration, StructDeclaration]]]:
    """Return the list of packets or structs that immediately derive from the
+29 −0
@@ -143,6 +143,29 @@
      }
    ]
  },
  {
    "packet": "Packet_Payload_Field_SizeModifier",
    "tests": [
      {
        "packed": "02",
        "unpacked": {
          "payload": []
        }
      },
      {
        "packed": "070001020304",
        "unpacked": {
          "payload": [
            0,
            1,
            2,
            3,
            4
          ]
        }
      }
    ]
  },
  {
    "packet": "Packet_Payload_Field_UnknownSize",
    "tests": [
@@ -593,6 +616,7 @@
    "packet": "ScalarParent",
    "tests": [
      {
        "packet": "ScalarChild_A",
        "packed": "0001da",
        "unpacked": {
          "a": 0,
@@ -600,6 +624,7 @@
        }
      },
      {
        "packet": "ScalarChild_B",
        "packed": "0102dedc",
        "unpacked": {
          "a": 1,
@@ -607,6 +632,7 @@
        }
      },
      {
        "packet": "AliasedChild_A",
        "packed": "0201d8",
        "unpacked": {
          "a": 2,
@@ -614,6 +640,7 @@
        }
      },
      {
        "packet": "AliasedChild_B",
        "packed": "0302e70a",
        "unpacked": {
          "a": 3,
@@ -626,6 +653,7 @@
    "packet": "EnumParent",
    "tests": [
      {
        "packet": "EnumChild_A",
        "packed": "aabb01dd",
        "unpacked": {
          "a": 43707,
@@ -633,6 +661,7 @@
        }
      },
      {
        "packet": "EnumChild_B",
        "packed": "ccdd02def7",
        "unpacked": {
          "a": 52445,