Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 35b390d0 authored by Zach Johnson's avatar Zach Johnson
Browse files

rusty-gd: move size field to packet gen

this way we have full context on the targeted field

Bug: 171749953
Tag: #gd-refactor
Test: gd/cert/run --rhost
Change-Id: I2e4658cac68aaadacf35c31cc58c87ae997ae0f1
parent fdc69bfc
Loading
Loading
Loading
Loading
+5 −2
Original line number Original line Diff line number Diff line
@@ -16,6 +16,7 @@


#include "fields/scalar_field.h"
#include "fields/scalar_field.h"


#include "fields/size_field.h"
#include "util.h"
#include "util.h"


const std::string ScalarField::kFieldType = "ScalarField";
const std::string ScalarField::kFieldType = "ScalarField";
@@ -194,8 +195,10 @@ void ScalarField::GenRustWriter(std::ostream& s, Size start_offset, Size end_off
  Size size = GetSize();
  Size size = GetSize();
  int num_leading_bits = GetRustBitOffset(s, start_offset, end_offset, GetSize());
  int num_leading_bits = GetRustBitOffset(s, start_offset, end_offset, GetSize());


  if (GetFieldType() == SizeField::kFieldType) {
    // Do nothing, the field access has already happened in packet_def
  } else if (GetRustParseDataType() != GetRustDataType()) {
    // needs casting to primitive
    // needs casting to primitive
  if (GetRustParseDataType() != GetRustDataType()) {
    s << "let " << GetName() << " = self." << GetName() << ".to_" << GetRustParseDataType() << "().unwrap();";
    s << "let " << GetName() << " = self." << GetName() << ".to_" << GetRustParseDataType() << "().unwrap();";
  } else {
  } else {
    s << "let " << GetName() << " = self." << GetName() << ";";
    s << "let " << GetName() << " = self." << GetName() << ";";
+0 −35
Original line number Original line Diff line number Diff line
@@ -65,38 +65,3 @@ std::string SizeField::GetSizedFieldName() const {
void SizeField::GenStringRepresentation(std::ostream& s, std::string accessor) const {
void SizeField::GenStringRepresentation(std::ostream& s, std::string accessor) const {
  s << accessor;
  s << accessor;
}
}

void SizeField::GenRustWriter(std::ostream& s, Size start_offset, Size end_offset) const {
  Size size = GetSize();
  int num_leading_bits = GetRustBitOffset(s, start_offset, end_offset, GetSize());

  s << "let mut " << GetName() << ": " << GetRustDataType() << " = self.get_total_size() as ";
  s << GetRustDataType() << ";";
  s << GetName() << " -= self.get_size() as " << GetRustDataType() << ";";
  if (util::RoundSizeUp(size.bits()) != size.bits()) {
    uint64_t mask = 0;
    for (int i = 0; i < size.bits(); i++) {
      mask <<= 1;
      mask |= 1;
    }
    s << "let " << GetName() << " = ";
    s << GetName() << " & 0x" << std::hex << mask << std::dec << ";";
  }

  int access_offset = 0;
  if (num_leading_bits != 0) {
    access_offset = -1;
    uint64_t mask = 0;
    for (int i = 0; i < num_leading_bits; i++) {
      mask <<= 1;
      mask |= 1;
    }
    s << "let " << GetName() << " = (" << GetName() << " << " << num_leading_bits << ") | ("
      << "(buffer[" << start_offset.bytes() << "] as " << GetRustParseDataType() << ") & 0x" << std::hex << mask
      << std::dec << ");";
  }

  s << "buffer[" << start_offset.bytes() + access_offset << ".."
    << start_offset.bytes() + GetSize().bytes() + access_offset << "].copy_from_slice(&" << GetName()
    << ".to_le_bytes());";
}
+0 −2
Original line number Original line Diff line number Diff line
@@ -48,8 +48,6 @@ class SizeField : public ScalarField {


  virtual void GenStringRepresentation(std::ostream& s, std::string accessor) const override;
  virtual void GenStringRepresentation(std::ostream& s, std::string accessor) const override;


  virtual void GenRustWriter(std::ostream& s, Size start_offset, Size end_offset) const override;

 private:
 private:
  int size_;
  int size_;
  std::string sized_field_name_;
  std::string sized_field_name_;
+37 −0
Original line number Original line Diff line number Diff line
@@ -1017,6 +1017,43 @@ void PacketDef::GenRustStructImpls(std::ostream& s) const {
                   << "no method exists to determine field location from begin() or end().\n";
                   << "no method exists to determine field location from begin() or end().\n";
    }
    }


    if (field->GetFieldType() == SizeField::kFieldType) {
      const auto& field_name = ((SizeField*)field)->GetSizedFieldName();
      const auto& sized_field = fields_.GetField(field_name);
      if (sized_field == nullptr) {
        ERROR(field) << __func__ << ": Can't find sized field named " << field_name;
      }
      if (sized_field->GetFieldType() == PayloadField::kFieldType) {
        std::string modifier = ((PayloadField*)sized_field)->size_modifier_;
        if (modifier != "") {
          ERROR(field) << __func__ << ": size modifiers not implemented yet for " << field_name;
        }

        s << "let " << field->GetName() << " = " << field->GetRustDataType()
          << "::try_from(self.child.get_total_size()).expect(\"payload size did not fit\");";
      } else if (sized_field->GetFieldType() == BodyField::kFieldType) {
        s << "let " << field->GetName() << " = " << field->GetRustDataType()
          << "::try_from(self.get_total_size() - self.get_size()).expect(\"payload size did not fit\");";
      } else if (sized_field->GetFieldType() == VectorField::kFieldType) {
        const auto& vector_name = field_name + "_bytes";
        const VectorField* vector = (VectorField*)sized_field;
        if (vector->element_size_.empty() || vector->element_size_.has_dynamic()) {
          s << "let " << vector_name + " = self." << field_name << ".iter().fold(0, |acc, x| acc + x.get_size());";
        } else {
          s << "let " << vector_name + " = self." << field_name << ".len() * ((" << vector->element_size_ << ") / 8);";
        }
        std::string modifier = vector->GetSizeModifier();
        if (modifier != "") {
          ERROR(field) << __func__ << ": size modifiers not implemented yet for " << field_name;
        }

        s << "let " << field->GetName() << " = " << field->GetRustDataType() << "::try_from(" << vector_name
          << ").expect(\"payload size did not fit\");";
      } else {
        ERROR(field) << __func__ << ": Unhandled sized field type for " << field_name;
      }
    }

    field->GenRustWriter(s, start_field_offset, end_field_offset);
    field->GenRustWriter(s, start_field_offset, end_field_offset);
  }
  }