Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 1178f20d authored by Zach Johnson's avatar Zach Johnson
Browse files

rusty-gd: add builder structs

by making them structs, you must name the parameters explicitly (which
is great for readability - also it doesn't matter which order you
specify them as long as you specify them all)

since builders can be constructed for middle structs in the lineage, add
None option for child enums

Bug: 171749953
Tag: #gd-refactor
Test: gd/cert/run --rhost SimpleHalTest
Change-Id: Ic2cfde113ccd6f8475c7514c513ac7b62b3689e3
parent 22c4f0da
Loading
Loading
Loading
Loading
+89 −14
Original line number Diff line number Diff line
@@ -746,11 +746,13 @@ void PacketDef::GenRustChildEnums(std::ostream& s) const {
    for (const auto& child : children_) {
      s << child->name_ << "(Rc<" << child->name_ << "Data>),";
    }
    s << "None,";
    s << "}\n";
    s << "pub enum " << name_ << "Child {";
    for (const auto& child : children_) {
      s << child->name_ << "(" << child->name_ << "Packet),";
    }
    s << "None,";
    s << "}\n";
  }
}
@@ -763,19 +765,9 @@ void PacketDef::GenRustStructDeclarations(std::ostream& s) const {
  if (!children_.empty()) {
    s << "child: " << name_ << "DataChild,";
  }

  // Generate size field
  auto fields = fields_.GetFieldsWithoutTypes({
      BodyField::kFieldType,
      CountField::kFieldType,
      PaddingField::kFieldType,
      SizeField::kFieldType,
  });
  if (fields.size() > 0) {
    s << " size: usize";
  }
  s << "}\n";

  // Generate accessor struct
  s << "pub struct " << name_ << "Packet {";
  auto lineage = GetAncestors();
  lineage.push_back(this);
@@ -784,6 +776,19 @@ void PacketDef::GenRustStructDeclarations(std::ostream& s) const {
    s << util::CamelCaseToUnderScore(def->name_) << ": Rc<" << def->name_ << "Data>,";
  }
  s << "}\n";

  // Generate builder struct
  s << "pub struct " << name_ << "Builder {";
  auto params = GetParamList().GetFieldsWithoutTypes({
      PayloadField::kFieldType,
      BodyField::kFieldType,
  });
  for (auto param : params) {
    s << "pub ";
    param->GenRustNameAndType(s);
    s << ", ";
  }
  s << "}\n";
}

bool PacketDef::GenRustStructFieldNameAndType(std::ostream& s) const {
@@ -794,6 +799,7 @@ bool PacketDef::GenRustStructFieldNameAndType(std::ostream& s) const {
      ReservedField::kFieldType,
      SizeField::kFieldType,
      PayloadField::kFieldType,
      FixedScalarField::kFieldType,
  });
  if (fields.size() == 0) {
    return false;
@@ -813,6 +819,7 @@ void PacketDef::GenRustStructFieldNames(std::ostream& s) const {
      ReservedField::kFieldType,
      SizeField::kFieldType,
      PayloadField::kFieldType,
      FixedScalarField::kFieldType,
  });
  for (int i = 0; i < fields.size(); i++) {
    s << fields[i]->GetName();
@@ -831,7 +838,7 @@ void PacketDef::GenRustStructSizeField(std::ostream& s) const {
    size += fields[i]->GetSize().bytes();
  }
  if (fields.size() > 0) {
    s << " size: " << size;
    s << size;
  }
}

@@ -889,7 +896,7 @@ void PacketDef::GenRustStructImpls(std::ostream& s) const {
  // write_to function
  s << "fn write_to(&self, buffer: &mut BytesMut) {";
  if (fields_exist) {
    s << " buffer.resize(buffer.len() + self.size, 0);";
    s << " buffer.resize(buffer.len() + self.get_size(), 0);";
  }

  fields = fields_.GetFieldsWithoutTypes({
@@ -898,6 +905,7 @@ void PacketDef::GenRustStructImpls(std::ostream& s) const {
      PaddingField::kFieldType,
      ReservedField::kFieldType,
      SizeField::kFieldType,
      FixedScalarField::kFieldType,
  });

  for (auto const& field : fields) {
@@ -917,13 +925,16 @@ void PacketDef::GenRustStructImpls(std::ostream& s) const {
    for (const auto& child : children_) {
      s << name_ << "DataChild::" << child->name_ << "(value) => value.write_to(buffer),";
    }
    s << name_ << "DataChild::None => {}";
    s << "}";
  }

  s << "}\n";

  if (fields_exist) {
    s << "pub fn get_size(&self) -> usize { self.size }";
    s << "pub fn get_size(&self) -> usize {";
    GenRustStructSizeField(s);
    s << "}";
  }
  s << "}\n";
}
@@ -951,6 +962,7 @@ void PacketDef::GenRustAccessStructImpls(std::ostream& s) const {
      s << name_ << "DataChild::" << child->name_ << "(_) => " << name_ << "Child::" << child->name_ << "("
        << child->name_ << "Packet::new(self." << root_accessor << ")),";
    }
    s << name_ << "DataChild::None => " << name_ << "Child::None,";
    s << "}}";
  }
  auto lineage = GetAncestors();
@@ -987,6 +999,7 @@ void PacketDef::GenRustAccessStructImpls(std::ostream& s) const {
        ReservedField::kFieldType,
        SizeField::kFieldType,
        PayloadField::kFieldType,
        FixedScalarField::kFieldType,
    });

    for (auto const& field : fields) {
@@ -1009,9 +1022,71 @@ void PacketDef::GenRustAccessStructImpls(std::ostream& s) const {
  }
}

void PacketDef::GenRustBuilderStructImpls(std::ostream& s) const {
  s << "impl " << name_ << "Builder {";
  s << "pub fn build(self) -> " << name_ << "Packet {";
  auto lineage = GetAncestors();
  lineage.push_back(this);
  std::reverse(lineage.begin(), lineage.end());

  auto all_constraints = GetAllConstraints();

  const ParentDef* prev = nullptr;
  for (auto ancestor : lineage) {
    auto fields = ancestor->fields_.GetFieldsWithoutTypes({
        BodyField::kFieldType,
        CountField::kFieldType,
        PaddingField::kFieldType,
        ReservedField::kFieldType,
        SizeField::kFieldType,
        PayloadField::kFieldType,
        FixedScalarField::kFieldType,
    });

    auto accessor_name = util::CamelCaseToUnderScore(ancestor->name_);
    s << "let " << accessor_name << "= Rc::new(" << ancestor->name_ << "Data {";
    for (auto field : fields) {
      auto constraint = all_constraints.find(field->GetName());
      s << field->GetName() << ": ";
      if (constraint != all_constraints.end()) {
        if (field->GetFieldType() == ScalarField::kFieldType) {
          s << std::get<int64_t>(constraint->second);
        } else if (field->GetFieldType() == EnumField::kFieldType) {
          auto value = std::get<std::string>(constraint->second);
          auto constant = value.substr(value.find("::") + 2, std::string::npos);
          s << field->GetDataType() << "::" << util::ConstantCaseToCamelCase(constant) << " as "
            << field->GetRustDataType();
          ;
        } else {
          ERROR(field) << "Constraints on non enum/scalar fields should be impossible.";
        }
      } else {
        s << "self." << field->GetName();
      }
      s << ", ";
    }
    if (!ancestor->children_.empty()) {
      if (prev == nullptr) {
        s << "child: " << name_ << "DataChild::None,";
      } else {
        s << "child: " << ancestor->name_ << "DataChild::" << prev->name_ << "("
          << util::CamelCaseToUnderScore(prev->name_) << "),";
      }
    }
    s << "});";
    prev = ancestor;
  }

  s << name_ << "Packet::new(" << util::CamelCaseToUnderScore(prev->name_) << ")";
  s << "}\n";

  s << "}\n";
}

void PacketDef::GenRustDef(std::ostream& s) const {
  GenRustChildEnums(s);
  GenRustStructDeclarations(s);
  GenRustStructImpls(s);
  GenRustAccessStructImpls(s);
  GenRustBuilderStructImpls(s);
}
+2 −0
Original line number Diff line number Diff line
@@ -79,5 +79,7 @@ class PacketDef : public ParentDef {

  void GenRustAccessStructImpls(std::ostream& s) const;

  void GenRustBuilderStructImpls(std::ostream& s) const;

  void GenRustDef(std::ostream& s) const;
};
+9 −0
Original line number Diff line number Diff line
@@ -501,3 +501,12 @@ std::vector<const ParentDef*> ParentDef::GetAncestors() const {
  std::reverse(res.begin(), res.end());
  return res;
}

std::map<std::string, std::variant<int64_t, std::string>> ParentDef::GetAllConstraints() const {
  std::map<std::string, std::variant<int64_t, std::string>> res;
  res.insert(parent_constraints_.begin(), parent_constraints_.end());
  for (auto parent : GetAncestors()) {
    res.insert(parent->parent_constraints_.begin(), parent->parent_constraints_.end());
  }
  return res;
}
+2 −0
Original line number Diff line number Diff line
@@ -63,6 +63,8 @@ class ParentDef : public TypeDef {

  const ParentDef* GetRootDef() const;

  std::map<std::string, std::variant<int64_t, std::string>> GetAllConstraints() const;

  std::vector<const ParentDef*> GetAncestors() const;

  FieldList fields_;