Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit a4f58798 authored by Myles Watson's avatar Myles Watson Committed by Jakub Pawlowski
Browse files

PDL: Move GenSerialize and GenSize to ParentDef

Test: bluetooth_packet_parser_test
Change-Id: Ifa748a874ec5a84dea19eb40093b33ac5b6fd9cb
parent 26dfa76c
Loading
Loading
Loading
Loading
+2 −160
Original line number Diff line number Diff line
@@ -98,164 +98,6 @@ void PacketDef::GenParserFieldGetter(std::ostream& s, const PacketField* field)
  field->GenGetter(s, start_field_offset, end_field_offset);
}

void PacketDef::GenSerialize(std::ostream& s) const {
  auto header_fields = fields_.GetFieldsBeforePayloadOrBody();
  auto footer_fields = fields_.GetFieldsAfterPayloadOrBody();

  s << "protected:";
  s << "void SerializeHeader(BitInserter&";
  if (parent_ != nullptr || header_fields.size() != 0) {
    s << " i ";
  }
  s << ") const {";

  if (parent_ != nullptr) {
    s << parent_->name_ << "Builder::SerializeHeader(i);";
  }

  for (const auto& field : header_fields) {
    if (field->GetFieldType() == PacketField::Type::SIZE) {
      const auto& field_name = ((SizeField*)field)->GetSizedFieldName();
      const auto& sized_field = fields_.GetField(field_name);
      if (sized_field == nullptr) {
        ERROR(field) << __func__ << ": Can't find sized field named " << field_name;
      }
      if (sized_field->GetFieldType() == PacketField::Type::PAYLOAD) {
        s << "size_t payload_bytes = GetPayloadSize();";
        std::string modifier = ((PayloadField*)sized_field)->size_modifier_;
        if (modifier != "") {
          s << "static_assert((" << modifier << ")%8 == 0, \"Modifiers must be byte-aligned\");";
          s << "payload_bytes = payload_bytes + (" << modifier << ") / 8;";
        }
        s << "ASSERT(payload_bytes < (static_cast<size_t>(1) << " << field->GetSize().bits() << "));";
        s << "insert(static_cast<" << field->GetType() << ">(payload_bytes), i," << field->GetSize().bits() << ");";
      } else {
        if (sized_field->GetFieldType() != PacketField::Type::ARRAY) {
          ERROR(field) << __func__ << ": Unhandled sized field type for " << field_name;
        }
        const auto& array_name = field_name + "_";
        const ArrayField* array = (ArrayField*)sized_field;
        s << "size_t " << array_name + "bytes =  0;";
        if (array->element_size_ == -1) {
          s << "for (auto elem : " << array_name << ") {";
          s << array_name + "bytes += elem.size(); }";
        } else {
          s << array_name + "bytes = ";
          s << array_name << ".size() * (" << array->element_size_ << " / 8);";
        }
        s << "ASSERT(" << array_name + "bytes < (1 << " << field->GetSize().bits() << "));";
        s << "insert(" << array_name << "bytes";
        s << array->GetSizeModifier() << ", i, ";
        s << field->GetSize().bits() << ");";
      }
    } else if (field->GetFieldType() == PacketField::Type::CHECKSUM_START) {
      const auto& field_name = ((ChecksumStartField*)field)->GetStartedFieldName();
      const auto& started_field = fields_.GetField(field_name);
      if (started_field == nullptr) {
        ERROR(field) << __func__ << ": Can't find checksum field named " << field_name << "(" << field->GetName()
                     << ")";
      }
      s << "auto shared_checksum_ptr = std::make_shared<" << started_field->GetType() << ">();";
      s << started_field->GetType() << "::Initialize(*shared_checksum_ptr);";
      s << "i.RegisterObserver(packet::ByteObserver(";
      s << "[shared_checksum_ptr](uint8_t byte){" << started_field->GetType()
        << "::AddByte(*shared_checksum_ptr, byte);},";
      s << "[shared_checksum_ptr](){ return static_cast<uint64_t>(" << started_field->GetType()
        << "::GetChecksum(*shared_checksum_ptr));}));";
    } else if (field->GetFieldType() == PacketField::Type::COUNT) {
      const auto& array_name = ((SizeField*)field)->GetSizedFieldName() + "_";
      s << "insert(" << array_name << ".size(), i, " << field->GetSize().bits() << ");";
    } else {
      field->GenInserter(s);
    }
  }
  s << "}\n\n";

  s << "void SerializeFooter(BitInserter&";
  if (parent_ != nullptr || footer_fields.size() != 0) {
    s << " i ";
  }
  s << ") const {";

  for (const auto& field : footer_fields) {
    field->GenInserter(s);
  }
  if (parent_ != nullptr) {
    s << parent_->name_ << "Builder::SerializeFooter(i);";
  }
  s << "}\n\n";

  s << "public:";
  s << "virtual void Serialize(BitInserter& i) const override {";
  s << "SerializeHeader(i);";
  if (fields_.HasPayload()) {
    s << "payload_->Serialize(i);";
  }
  s << "SerializeFooter(i);";

  s << "}\n";
}

void PacketDef::GenBuilderSize(std::ostream& s) const {
  auto header_fields = fields_.GetFieldsBeforePayloadOrBody();
  auto footer_fields = fields_.GetFieldsAfterPayloadOrBody();

  s << "protected:";
  s << "size_t BitsOfHeader() const {";
  s << "return 0";

  if (parent_ != nullptr) {
    s << " + " << parent_->name_ << "Builder::BitsOfHeader() ";
  }

  for (const auto& field : header_fields) {
    Size field_size = field->GetBuilderSize();
    if (field_size.has_bits()) {
      s << " + " << field_size.bits();
    }
    if (field_size.has_dynamic()) {
      s << " + " << field_size.dynamic_string();
    }
  }
  s << ";";

  s << "}\n\n";

  s << "size_t BitsOfFooter() const {";
  s << "return 0";
  for (const auto& field : footer_fields) {
    Size field_size = field->GetBuilderSize();
    if (field_size.has_bits()) {
      s << " + " << field_size.bits();
    }
    if (field_size.has_dynamic()) {
      s << " + " << field_size.dynamic_string();
    }
  }

  if (parent_ != nullptr) {
    s << " + " << parent_->name_ << "Builder::BitsOfFooter()";
  }
  s << ";";
  s << "}\n\n";

  if (fields_.HasPayload()) {
    s << "size_t GetPayloadSize() const {";
    s << "if (payload_ != nullptr) {return payload_->size();}";
    s << "else { return size() - (BitsOfHeader() + BitsOfFooter()) / 8;}";
    s << ";}\n\n";
  }

  s << "public:";
  s << "virtual size_t size() const override {";
  s << "return (BitsOfHeader() / 8)";
  if (fields_.HasPayload()) {
    s << "+ payload_->size()";
  }
  s << " + (BitsOfFooter() / 8);";
  s << "}\n";
}

TypeDef::Type PacketDef::GetDefinitionType() const {
  return TypeDef::Type::PACKET;
}
@@ -279,7 +121,7 @@ void PacketDef::GenValidator(std::ostream& s) const {
  // Offset by the parents known size. We know that any dynamic fields can
  // already be called since the parent must have already been validated by
  // this point.
  auto parent_size = Size();
  auto parent_size = Size(0);
  if (parent_ != nullptr) {
    parent_size = parent_->GetSize(true);
  }
@@ -429,7 +271,7 @@ void PacketDef::GenBuilderDefinition(std::ostream& s) const {
  GenSerialize(s);
  s << "\n";

  GenBuilderSize(s);
  GenSize(s);
  s << "\n";

  s << " protected:\n";
+0 −4
Original line number Diff line number Diff line
@@ -35,10 +35,6 @@ class PacketDef : public ParentDef {

  void GenParserFieldGetter(std::ostream& s, const PacketField* field) const;

  void GenSerialize(std::ostream& s) const;

  void GenBuilderSize(std::ostream& s) const;

  void GenValidator(std::ostream& s) const;

  TypeDef::Type GetDefinitionType() const;
+174 −0
Original line number Diff line number Diff line
@@ -241,3 +241,177 @@ void ParentDef::GenMembers(std::ostream& s) const {
    }
  }
}

void ParentDef::GenSize(std::ostream& s) const {
  auto header_fields = fields_.GetFieldsBeforePayloadOrBody();
  auto footer_fields = fields_.GetFieldsAfterPayloadOrBody();

  s << "protected:";
  s << "size_t BitsOfHeader() const {";
  s << "return 0";

  if (parent_ != nullptr) {
    if (parent_->GetDefinitionType() == Type::PACKET) {
      s << " + " << parent_->name_ << "Builder::BitsOfHeader() ";
    } else {
      s << " + " << parent_->name_ << "::BitsOfHeader() ";
    }
  }

  for (const auto& field : header_fields) {
    Size field_size = field->GetBuilderSize();
    if (field_size.has_bits()) {
      s << " + " << field_size.bits();
    }
    if (field_size.has_dynamic()) {
      s << " + " << field_size.dynamic_string();
    }
  }
  s << ";";

  s << "}\n\n";

  s << "size_t BitsOfFooter() const {";
  s << "return 0";
  for (const auto& field : footer_fields) {
    Size field_size = field->GetBuilderSize();
    if (field_size.has_bits()) {
      s << " + " << field_size.bits();
    }
    if (field_size.has_dynamic()) {
      s << " + " << field_size.dynamic_string();
    }
  }

  if (parent_ != nullptr) {
    if (parent_->GetDefinitionType() == Type::PACKET) {
      s << " + " << parent_->name_ << "Builder::BitsOfFooter() ";
    } else {
      s << " + " << parent_->name_ << "::BitsOfFooter() ";
    }
  }
  s << ";";
  s << "}\n\n";

  if (fields_.HasPayload()) {
    s << "size_t GetPayloadSize() const {";
    s << "if (payload_ != nullptr) {return payload_->size();}";
    s << "else { return size() - (BitsOfHeader() + BitsOfFooter()) / 8;}";
    s << ";}\n\n";
  }

  s << "public:";
  s << "virtual size_t size() const override {";
  s << "return (BitsOfHeader() / 8)";
  if (fields_.HasPayload()) {
    s << "+ payload_->size()";
  }
  s << " + (BitsOfFooter() / 8);";
  s << "}\n";
}

void ParentDef::GenSerialize(std::ostream& s) const {
  auto header_fields = fields_.GetFieldsBeforePayloadOrBody();
  auto footer_fields = fields_.GetFieldsAfterPayloadOrBody();

  s << "protected:";
  s << "void SerializeHeader(BitInserter&";
  if (parent_ != nullptr || header_fields.size() != 0) {
    s << " i ";
  }
  s << ") const {";

  if (parent_ != nullptr) {
    if (parent_->GetDefinitionType() == Type::PACKET) {
      s << parent_->name_ << "Builder::SerializeHeader(i);";
    } else {
      s << parent_->name_ << "::SerializeHeader(i);";
    }
  }

  for (const auto& field : header_fields) {
    if (field->GetFieldType() == PacketField::Type::SIZE) {
      const auto& field_name = ((SizeField*)field)->GetSizedFieldName();
      const auto& sized_field = fields_.GetField(field_name);
      if (sized_field == nullptr) {
        ERROR(field) << __func__ << ": Can't find sized field named " << field_name;
      }
      if (sized_field->GetFieldType() == PacketField::Type::PAYLOAD) {
        s << "size_t payload_bytes = GetPayloadSize();";
        std::string modifier = ((PayloadField*)sized_field)->size_modifier_;
        if (modifier != "") {
          s << "static_assert((" << modifier << ")%8 == 0, \"Modifiers must be byte-aligned\");";
          s << "payload_bytes = payload_bytes + (" << modifier << ") / 8;";
        }
        s << "ASSERT(payload_bytes < (static_cast<size_t>(1) << " << field->GetSize().bits() << "));";
        s << "insert(static_cast<" << field->GetType() << ">(payload_bytes), i," << field->GetSize().bits() << ");";
      } else {
        if (sized_field->GetFieldType() != PacketField::Type::ARRAY) {
          ERROR(field) << __func__ << ": Unhandled sized field type for " << field_name;
        }
        const auto& array_name = field_name + "_";
        const ArrayField* array = (ArrayField*)sized_field;
        s << "size_t " << array_name + "bytes =  0;";
        if (array->element_size_ == -1) {
          s << "for (auto elem : " << array_name << ") {";
          s << array_name + "bytes += elem.size(); }";
        } else {
          s << array_name + "bytes = ";
          s << array_name << ".size() * (" << array->element_size_ << " / 8);";
        }
        s << "ASSERT(" << array_name + "bytes < (1 << " << field->GetSize().bits() << "));";
        s << "insert(" << array_name << "bytes";
        s << array->GetSizeModifier() << ", i, ";
        s << field->GetSize().bits() << ");";
      }
    } else if (field->GetFieldType() == PacketField::Type::CHECKSUM_START) {
      const auto& field_name = ((ChecksumStartField*)field)->GetStartedFieldName();
      const auto& started_field = fields_.GetField(field_name);
      if (started_field == nullptr) {
        ERROR(field) << __func__ << ": Can't find checksum field named " << field_name << "(" << field->GetName()
                     << ")";
      }
      s << "auto shared_checksum_ptr = std::make_shared<" << started_field->GetType() << ">();";
      s << started_field->GetType() << "::Initialize(*shared_checksum_ptr);";
      s << "i.RegisterObserver(packet::ByteObserver(";
      s << "[shared_checksum_ptr](uint8_t byte){" << started_field->GetType()
        << "::AddByte(*shared_checksum_ptr, byte);},";
      s << "[shared_checksum_ptr](){ return static_cast<uint64_t>(" << started_field->GetType()
        << "::GetChecksum(*shared_checksum_ptr));}));";
    } else if (field->GetFieldType() == PacketField::Type::COUNT) {
      const auto& array_name = ((SizeField*)field)->GetSizedFieldName() + "_";
      s << "insert(" << array_name << ".size(), i, " << field->GetSize().bits() << ");";
    } else {
      field->GenInserter(s);
    }
  }
  s << "}\n\n";

  s << "void SerializeFooter(BitInserter&";
  if (parent_ != nullptr || footer_fields.size() != 0) {
    s << " i ";
  }
  s << ") const {";

  for (const auto& field : footer_fields) {
    field->GenInserter(s);
  }
  if (parent_ != nullptr) {
    if (parent_->GetDefinitionType() == Type::PACKET) {
      s << parent_->name_ << "Builder::SerializeFooter(i);";
    } else {
      s << parent_->name_ << "::SerializeFooter(i);";
    }
  }
  s << "}\n\n";

  s << "public:";
  s << "virtual void Serialize(BitInserter& i) const override {";
  s << "SerializeHeader(i);";
  if (fields_.HasPayload()) {
    s << "payload_->Serialize(i);";
  }
  s << "SerializeFooter(i);";

  s << "}\n";
}
+4 −0
Original line number Diff line number Diff line
@@ -55,6 +55,10 @@ class ParentDef : public TypeDef {

  void GenMembers(std::ostream& s) const;

  void GenSize(std::ostream& s) const;

  void GenSerialize(std::ostream& s) const;

  FieldList fields_;

  ParentDef* parent_{nullptr};
+5 −5
Original line number Diff line number Diff line
@@ -62,23 +62,23 @@ class Size {
    return dynamic_;
  }

  bool empty() {
  bool empty() const {
    return !is_valid_;
  }

  bool has_bits() {
  bool has_bits() const {
    return bits_ != 0;
  }

  bool has_dynamic() {
  bool has_dynamic() const {
    return !dynamic_.empty();
  }

  int bits() {
  int bits() const {
    return bits_;
  }

  int bytes() {
  int bytes() const {
    return bits_ / 8;
  }