Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit d1a2be60 authored by Myles Watson's avatar Myles Watson Committed by Gerrit Code Review
Browse files

Merge "PDL: Rust code handles multiple constraints"

parents 461b72c7 03a42a8a
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -33,10 +33,10 @@ cc_binary_host {
        "fields/variable_length_struct_field.cc",
        "checksum_def.cc",
        "custom_field_def.cc",
        "packet_dependency.cc",
        "enum_def.cc",
        "enum_gen.cc",
        "packet_def.cc",
        "packet_dependency.cc",
        "parent_def.cc",
        "struct_def.cc",
        "struct_parser_generator.cc",
+1 −0
Original line number Diff line number Diff line
@@ -62,6 +62,7 @@ executable("bluetooth_packetgen") {
    "gen_rust.cc",
    "main.cc",
    "packet_def.cc",
    "packet_dependency.cc",
    "parent_def.cc",
    "struct_def.cc",
    "struct_parser_generator.cc",
+110 −36
Original line number Diff line number Diff line
@@ -21,6 +21,7 @@
#include <set>

#include "fields/all_fields.h"
#include "packet_dependency.h"
#include "util.h"

PacketDef::PacketDef(std::string name, FieldList fields) : ParentDef(name, fields, nullptr) {}
@@ -883,8 +884,9 @@ void PacketDef::GenRustStructFieldNames(std::ostream& s) const {
}

void PacketDef::GenRustStructImpls(std::ostream& s) const {
  s << "impl " << name_ << "Data {";
  auto packet_dep = PacketDependency(GetRootDef());

  s << "impl " << name_ << "Data {";
  // conforms function
  s << "fn conforms(bytes: &[u8]) -> bool {";
  GenRustConformanceCheck(s);
@@ -904,15 +906,15 @@ void PacketDef::GenRustStructImpls(std::ostream& s) const {
  s << " true";
  s << "}";

  // parse function
  if (parent_constraints_.empty() && children_.size() > 1 && parent_ != nullptr) {
    auto constraint = FindConstraintField();
    auto constraint_field = GetParamList().GetField(constraint);
  auto parse_params = packet_dep.GetDependencies(name_);
  s << "fn parse(bytes: &[u8]";
  for (auto field_name : parse_params) {
    auto constraint_field = GetParamList().GetField(field_name);
    auto constraint_type = constraint_field->GetRustDataType();
    s << "fn parse(bytes: &[u8], " << constraint << ": " << constraint_type << ") -> Result<Self> {";
  } else {
    s << "fn parse(bytes: &[u8]) -> Result<Self> {";
    s << ", " << field_name << ": " << constraint_type;
  }
  s << ") -> Result<Self> {";

  fields = fields_.GetFieldsWithoutTypes({
      BodyField::kFieldType,
  });
@@ -940,51 +942,123 @@ void PacketDef::GenRustStructImpls(std::ostream& s) const {
    payload_offset = GetOffsetForField(payload_field[0]->GetName(), false);
  }

  auto constraint_name = FindConstraintField();
  auto constrained_descendants = FindDescendantsWithConstraint(constraint_name);

  if (children_.size() > 1) {
    s << "let child = match " << constraint_name << " {";
    auto match_on_variables = packet_dep.GetChildrenDependencies(name_);
    // If match_on_variables is empty, this means there are multiple abstract packets which will
    // specialize to a child down the packet tree.
    // In this case match variables will be the union of parent fields and parse params of children.
    if (match_on_variables.empty()) {
      for (auto& field : fields_) {
        if (std::any_of(children_.begin(), children_.end(), [&](auto child) {
              auto pass_me = packet_dep.GetDependencies(child->name_);
              return std::find(pass_me.begin(), pass_me.end(), field->GetName()) != pass_me.end();
            })) {
          match_on_variables.push_back(field->GetName());
        }
      }
    }

    for (const auto& desc : constrained_descendants) {
      auto desc_path = FindPathToDescendant(desc.first->name_);
      std::reverse(desc_path.begin(), desc_path.end());
      auto constraint_field = GetParamList().GetField(constraint_name);
    s << "let child = match (";

    for (auto var : match_on_variables) {
      if (var == match_on_variables[match_on_variables.size() - 1]) {
        s << var;
      } else {
        s << var << ", ";
      }
    }
    s << ") {";

    auto get_match_val = [&](
        std::string& match_var,
        std::variant<int64_t,
        std::string> constraint) -> std::string {
      auto constraint_field = GetParamList().GetField(match_var);
      auto constraint_type = constraint_field->GetFieldType();

      if (constraint_type == EnumField::kFieldType) {
        auto type = std::get<std::string>(desc.second);
        auto type = std::get<std::string>(constraint);
        auto variant_name = type.substr(type.find("::") + 2, type.length());
        auto enum_type = type.substr(0, type.find("::"));
        auto enum_variant = enum_type + "::"
            + util::UnderscoreToCamelCase(util::ToLowerCase(variant_name));
        s << enum_variant;
        s << " if " << desc_path[0]->name_ << "Data::conforms(&bytes[..])";
        return enum_type + "::" + util::UnderscoreToCamelCase(util::ToLowerCase(variant_name));
      }
      if (constraint_type == ScalarField::kFieldType) {
        return std::to_string(std::get<int64_t>(constraint));
      }
      return "_";
    };

    for (auto& child : children_) {
      s << "(";
      for (auto var : match_on_variables) {
        std::string match_val = "_";

        if (child->parent_constraints_.find(var) != child->parent_constraints_.end()) {
          match_val = get_match_val(var, child->parent_constraints_[var]);
        } else {
          auto dcs = child->FindDescendantsWithConstraint(var);
          std::vector<std::string> all_match_vals;
          for (auto& desc : dcs) {
            all_match_vals.push_back(get_match_val(var, desc.second));
          }
          match_val = "";
          for (std::size_t i = 0; i < all_match_vals.size(); ++i) {
            match_val += all_match_vals[i];
            if (i != all_match_vals.size() - 1) {
              match_val += " | ";
            }
          }
          match_val = (match_val == "") ? "_" : match_val;
        }

        if (var == match_on_variables[match_on_variables.size() - 1]) {
          s << match_val << ")";
        } else {
          s << match_val << ", ";
        }
      }
      s << " if " << child->name_ << "Data::conforms(&bytes[..])";
      s << " => {";
      s << name_ << "DataChild::";
        s << desc_path[0]->name_ << "(Arc::new(";
        if (desc_path[0]->parent_constraints_.empty()) {
          s << desc_path[0]->name_ << "Data::parse(&bytes[..]";
          s << ", " << enum_variant << ")?))";
      s << child->name_ << "(Arc::new(";

      auto child_parse_params = packet_dep.GetDependencies(child->name_);
      if (child_parse_params.size() == 0) {
        s << child->name_ << "Data::parse(&bytes[..]";
      } else {
          s << desc_path[0]->name_ << "Data::parse(&bytes[..])?))";
        s << child->name_ << "Data::parse(&bytes[..], ";
      }

      for (auto var : child_parse_params) {
        if (var == child_parse_params[child_parse_params.size() - 1]) {
          s << var;
        } else {
          s << var << ", ";
        }
      } else if (constraint_type == ScalarField::kFieldType) {
        s << std::get<int64_t>(desc.second) << " => {";
        s << "unimplemented!();";
      }
      s << ")?))";
      s << "}\n";
    }

    if (!constrained_descendants.empty()) {
      s << "v => return Err(Error::ConstraintOutOfBounds{field: \"" << constraint_name
        << "\".to_string(), value: v as u64}),";
    s << "(";
    for (int i = 1; i <= match_on_variables.size(); i++) {
      if (i == match_on_variables.size()) {
        s << "_";
      } else {
        s << "_, ";
      }

    }
    s << ")";
    s << " => return Err(Error::InvalidPacketError),";
    s << "};\n";
  } else if (children_.size() == 1) {
    auto child = children_.at(0);
    s << "let child = match " << child->name_ << "Data::parse(&bytes[..]) {";
    auto params = packet_dep.GetDependencies(child->name_);
    s << "let child = match " << child->name_ << "Data::parse(&bytes[..]";
    for (auto field_name : params) {
      s << ", " << field_name;
    }
    s << ") {";
    s << " Ok(c) if " << child->name_ << "Data::conforms(&bytes[..]) => {";
    s << name_ << "DataChild::" << child->name_ << "(Arc::new(c))";
    s << " },";