Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 569da72a authored by Zach Johnson's avatar Zach Johnson
Browse files

rusty-gd: give special treament to unconstrained only-children

usually a bad thing with actual human kids :p

if a packet only has one child and that child is running around
unconstrained and causing mischief, assume we may be able to parse
remaining bytes as the child because it's the only possibility

Bug: 171749953
Tag: #gd-refactor
Test: gd/cert/run --rhost DirectHciTest
Change-Id: I759d041c602e1d96eb640a2ec911bef2b80957bb
parent 325cf938
Loading
Loading
Loading
Loading
+0 −3
Original line number Diff line number Diff line
@@ -159,9 +159,6 @@ bool generate_rust_source_one_file(
  }

  for (const auto& packet_def : decls.packet_defs_queue_) {
    if (packet_def.second->name_.rfind("LeGetVendorCapabilitiesComplete", 0) == 0) {
      continue;
    }
    packet_def.second->GenRustDef(out_file);
    out_file << "\n\n";
  }
+17 −27
Original line number Diff line number Diff line
@@ -746,9 +746,6 @@ void PacketDef::GenRustChildEnums(std::ostream& s) const {
    s << "#[derive(Debug)] ";
    s << "enum " << name_ << "DataChild {";
    for (const auto& child : children_) {
      if (child->name_.rfind("LeGetVendorCapabilitiesComplete", 0) == 0) {
        continue;
      }
      s << child->name_ << "(Arc<" << child->name_ << "Data>),";
    }
    if (payload) {
@@ -761,9 +758,6 @@ void PacketDef::GenRustChildEnums(std::ostream& s) const {
    s << "fn get_total_size(&self) -> usize {";
    s << "match self {";
    for (const auto& child : children_) {
      if (child->name_.rfind("LeGetVendorCapabilitiesComplete", 0) == 0) {
        continue;
      }
      s << name_ << "DataChild::" << child->name_ << "(value) => value.get_total_size(),";
    }
    if (payload) {
@@ -777,9 +771,6 @@ void PacketDef::GenRustChildEnums(std::ostream& s) const {
    s << "#[derive(Debug)] ";
    s << "pub enum " << name_ << "Child {";
    for (const auto& child : children_) {
      if (child->name_.rfind("LeGetVendorCapabilitiesComplete", 0) == 0) {
        continue;
      }
      s << child->name_ << "(" << child->name_ << "Packet),";
    }
    if (payload) {
@@ -889,12 +880,11 @@ void PacketDef::GenRustStructImpls(std::ostream& s) const {
  s << "}";

  // parse function
  if (parent_constraints_.empty() && !children_.empty() && parent_ != nullptr) {
  if (parent_constraints_.empty() && children_.size() > 1 && parent_ != nullptr) {
    auto constraint = FindConstraintField();
    auto constraint_field = GetParamList().GetField(constraint);
    auto constraint_type = constraint_field->GetRustDataType();
      s << "fn parse(bytes: &[u8], " << constraint << ": " << constraint_type
          << ") -> Result<Self> {";
    s << "fn parse(bytes: &[u8], " << constraint << ": " << constraint_type << ") -> Result<Self> {";
  } else {
    s << "fn parse(bytes: &[u8]) -> Result<Self> {";
  }
@@ -928,13 +918,10 @@ void PacketDef::GenRustStructImpls(std::ostream& s) const {
  auto constraint_name = FindConstraintField();
  auto constrained_descendants = FindDescendantsWithConstraint(constraint_name);

  if (!children_.empty()) {
  if (children_.size() > 1) {
    s << "let child = match " << constraint_name << " {";

    for (const auto& desc : constrained_descendants) {
      if (desc.first->name_.rfind("LeGetVendorCapabilitiesComplete", 0) == 0) {
        continue;
      }
      auto desc_path = FindPathToDescendant(desc.first->name_);
      std::reverse(desc_path.begin(), desc_path.end());
      auto constraint_field = GetParamList().GetField(constraint_name);
@@ -970,6 +957,15 @@ void PacketDef::GenRustStructImpls(std::ostream& s) const {
    }

    s << "};\n";
  } else if (children_.size() == 1) {
    auto child = children_.at(0);
    s << "let child = match " << child->name_ << "Data::parse(&bytes[..]) {";
    s << " Ok(c) if " << child->name_ << "Data::conforms(&bytes[..]) => {";
    s << name_ << "DataChild::" << child->name_ << "(Arc::new(c))";
    s << " },";
    s << " Err(Error::InvalidLengthError { .. }) => " << name_ << "DataChild::None,";
    s << " _ => return Err(Error::InvalidPacketError),";
    s << "};";
  } else if (fields_.HasPayload()) {
    s << "let child = if payload.len() > 0 {";
    s << name_ << "DataChild::Payload(Bytes::from(payload))";
@@ -1010,9 +1006,6 @@ void PacketDef::GenRustStructImpls(std::ostream& s) const {
  if (HasChildEnums()) {
    s << "match &self.child {";
    for (const auto& child : children_) {
      if (child->name_.rfind("LeGetVendorCapabilitiesComplete", 0) == 0) {
        continue;
      }
      s << name_ << "DataChild::" << child->name_ << "(value) => value.write_to(buffer),";
    }
    if (fields_.HasPayload()) {
@@ -1040,7 +1033,7 @@ void PacketDef::GenRustStructImpls(std::ostream& s) const {
}

void PacketDef::GenRustAccessStructImpls(std::ostream& s) const {
  if (complement_ != nullptr && complement_->name_.rfind("LeGetVendorCapabilitiesComplete", 0) != 0) {
  if (complement_ != nullptr) {
    auto complement_root = complement_->GetRootDef();
    auto complement_root_accessor = util::CamelCaseToUnderScore(complement_root->name_);
    s << "impl CommandExpectations for " << name_ << "Packet {";
@@ -1076,9 +1069,6 @@ void PacketDef::GenRustAccessStructImpls(std::ostream& s) const {
    s << " pub fn specialize(&self) -> " << name_ << "Child {";
    s << " match &self." << util::CamelCaseToUnderScore(name_) << ".child {";
    for (const auto& child : children_) {
      if (child->name_.rfind("LeGetVendorCapabilitiesComplete", 0) == 0) {
        continue;
      }
      s << name_ << "DataChild::" << child->name_ << "(_) => " << name_ << "Child::" << child->name_ << "("
        << child->name_ << "Packet::new(self." << root_accessor << ".clone())),";
    }
@@ -1152,7 +1142,7 @@ void PacketDef::GenRustAccessStructImpls(std::ostream& s) const {
}

void PacketDef::GenRustBuilderStructImpls(std::ostream& s) const {
  if (complement_ != nullptr && complement_->name_.rfind("LeGetVendorCapabilitiesComplete", 0) != 0) {
  if (complement_ != nullptr) {
    auto complement_root = complement_->GetRootDef();
    auto complement_root_accessor = util::CamelCaseToUnderScore(complement_root->name_);
    s << "impl CommandExpectations for " << name_ << "Builder {";