Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 0ac08468 authored by Zach Johnson's avatar Zach Johnson
Browse files

rusty-gd: controller supported commands & le buffer sizes

also, generate TryFrom for OpCodeIndex from OpCode, to remove 700+ lines
of repeated code in C++ gd controller

Bug: 171749953
Tag: #gd-refactor
Test: gd/cert/run --rhost SimpleHalTest
Change-Id: Ic6d95ccc87a0ccc64636f16b0ef379c1f0d2f069
parent cfd17817
Loading
Loading
Loading
Loading
+2 −0
Original line number Diff line number Diff line
@@ -39,4 +39,6 @@ class EnumDef : public TypeDef {
  // data
  std::map<uint32_t, std::string> constants_;
  std::set<std::string> entries_;

  EnumDef* try_from_enum_ = nullptr;
};
+21 −0
Original line number Diff line number Diff line
@@ -67,4 +67,25 @@ void EnumGen::GenRustDef(std::ostream& stream) {
    stream << util::ConstantCaseToCamelCase(pair.second) << " = 0x" << std::hex << pair.first << std::dec << ",";
  }
  stream << "}";

  if (e_.try_from_enum_ != nullptr) {
    std::vector<std::string> other_items;
    for (const auto& pair : e_.try_from_enum_->constants_) {
      other_items.push_back(pair.second);
    }
    stream << "impl TryFrom<" << e_.try_from_enum_->name_ << "> for " << e_.name_ << " {";
    stream << "type Error = &'static str;";
    stream << "fn try_from(value: " << e_.try_from_enum_->name_ << ") -> std::result::Result<Self, Self::Error> {";
    stream << "match value {";
    for (const auto& pair : e_.constants_) {
      if (std::find(other_items.begin(), other_items.end(), pair.second) == other_items.end()) {
        continue;
      }
      auto constant_name = util::ConstantCaseToCamelCase(pair.second);
      stream << e_.try_from_enum_->name_ << "::" << constant_name << " => Ok(" << e_.name_ << "::" << constant_name
             << "),";
    }
    stream << "_ => Err(\"No mapping for provided key\"),";
    stream << "}}}";
  }
}
+18 −0
Original line number Diff line number Diff line
@@ -109,6 +109,24 @@ bool generate_rust_source_one_file(
        }
      }
    }

    EnumDef* opcode = nullptr;
    EnumDef* opcode_index = nullptr;
    for (const auto& e : decls.type_defs_queue_) {
      if (e.second->GetDefinitionType() == TypeDef::Type::ENUM) {
        auto* enum_def = dynamic_cast<EnumDef*>(e.second);
        if (enum_def->name_ == "OpCode") {
          opcode = enum_def;
        } else if (enum_def->name_ == "OpCodeIndex") {
          opcode_index = enum_def;
        }
      }
    }

    if (opcode_index != nullptr && opcode != nullptr) {
      opcode_index->try_from_enum_ = opcode;
      out_file << "use std::convert::TryFrom;";
    }
  }

  for (const auto& e : decls.type_defs_queue_) {
+13 −2
Original line number Diff line number Diff line
@@ -362,10 +362,21 @@ void StructDef::GenRustDeclarations(std::ostream& s) const {
  s << "pub struct " << name_ << "{";

  // Generate struct fields
  GenRustFieldNameAndType(s, true);
  auto fields = fields_.GetFieldsWithoutTypes({
      BodyField::kFieldType,
      CountField::kFieldType,
      PaddingField::kFieldType,
      ReservedField::kFieldType,
      SizeField::kFieldType,
  });
  for (const auto& field : fields) {
    s << "pub ";
    field->GenRustNameAndType(s);
    s << ", ";
  }

  // Generate size field
  auto fields = fields_.GetFieldsWithoutTypes({
  fields = fields_.GetFieldsWithoutTypes({
      BodyField::kFieldType,
      CountField::kFieldType,
      PaddingField::kFieldType,
+72 −8
Original line number Diff line number Diff line
@@ -2,12 +2,15 @@

use crate::HciExports;
use bt_packets::hci::{
    Enable, ErrorCode, LeSetEventMaskBuilder, LocalVersionInformation, ReadBufferSizeBuilder,
    Enable, ErrorCode, LeReadBufferSizeV1Builder, LeReadBufferSizeV2Builder, LeSetEventMaskBuilder,
    LocalVersionInformation, OpCode, OpCodeIndex, ReadBufferSizeBuilder,
    ReadLocalExtendedFeaturesBuilder, ReadLocalNameBuilder, ReadLocalSupportedCommandsBuilder,
    ReadLocalVersionInformationBuilder, SetEventMaskBuilder, WriteLeHostSupportBuilder,
    WriteSimplePairingModeBuilder,
};
use gddi::{module, provides, Stoppable};
use num_traits::ToPrimitive;
use std::convert::TryFrom;

module! {
    controller_module,
@@ -48,23 +51,53 @@ async fn provide_controller(mut hci: HciExports) -> ControllerExports {
        .get_local_version_information()
        .clone();

    let supported_commands = assert_success!(hci.send(ReadLocalSupportedCommandsBuilder {}))
        .get_supported_commands()
        .clone();
    let commands = SupportedCommands {
        supported: *assert_success!(hci.send(ReadLocalSupportedCommandsBuilder {}))
            .get_supported_commands(),
    };

    let lmp_features = read_lmp_features(&mut hci).await;

    let buffer_size = assert_success!(hci.send(ReadBufferSizeBuilder {}));
    let acl_buffer_length = buffer_size.get_acl_data_packet_length();
    let mut acl_buffers = buffer_size.get_total_num_acl_data_packets();

    let mut le_buffer_length;
    let mut le_buffers;
    let mut iso_buffer_length = 0;
    let mut iso_buffers = 0;
    if commands.is_supported(OpCode::LeReadBufferSizeV2) {
        let response = assert_success!(hci.send(LeReadBufferSizeV2Builder {}));
        le_buffer_length = response.get_le_buffer_size().le_data_packet_length;
        le_buffers = response.get_le_buffer_size().total_num_le_packets;
        iso_buffer_length = response.get_iso_buffer_size().le_data_packet_length;
        iso_buffers = response.get_iso_buffer_size().total_num_le_packets;
    } else {
        let response = assert_success!(hci.send(LeReadBufferSizeV1Builder {}));
        le_buffer_length = response.get_le_buffer_size().le_data_packet_length;
        le_buffers = response.get_le_buffer_size().total_num_le_packets;
    }

    // If the controller reports zero LE buffers, the ACL buffers are shared between classic & LE
    if le_buffers == 0 {
        le_buffers = (acl_buffers / 2) as u8;
        acl_buffers -= le_buffers as u16;
        le_buffer_length = acl_buffer_length;
    }

    ControllerExports {
        name,
        version_info,
        supported_commands,
        commands,
        lmp_features,
        acl_buffer_length: buffer_size.get_acl_data_packet_length(),
        acl_buffers: buffer_size.get_total_num_acl_data_packets(),
        acl_buffer_length,
        acl_buffers,
        sco_buffer_length: buffer_size.get_synchronous_data_packet_length(),
        sco_buffers: buffer_size.get_total_num_synchronous_data_packets(),
        le_buffer_length,
        le_buffers,
        iso_buffer_length,
        iso_buffers,
    }
}

@@ -87,10 +120,41 @@ async fn read_lmp_features(hci: &mut HciExports) -> Vec<u64> {
pub struct ControllerExports {
    name: String,
    version_info: LocalVersionInformation,
    supported_commands: [u8; 64],
    commands: SupportedCommands,
    lmp_features: Vec<u64>,
    acl_buffer_length: u16,
    acl_buffers: u16,
    sco_buffer_length: u8,
    sco_buffers: u16,
    le_buffer_length: u16,
    le_buffers: u8,
    iso_buffer_length: u16,
    iso_buffers: u8,
}

/// Convenience struct for checking what commands are supported
#[derive(Clone)]
pub struct SupportedCommands {
    supported: [u8; 64],
}

impl SupportedCommands {
    /// Check whether a given opcode is supported by the controller
    pub fn is_supported(&self, opcode: OpCode) -> bool {
        match opcode {
            OpCode::ReadLocalSupportedCommands | OpCode::CreateNewUnitKey => true,
            _ => {
                let converted = OpCodeIndex::try_from(opcode);
                if converted.is_err() {
                    return false;
                }

                let index = converted.unwrap().to_usize().unwrap();

                // The 10 here looks sus, but hci_packets.pdl mentions the index value
                // is octet * 10 + bit
                self.supported[index / 10] & (1 << (index % 10)) == 1
            }
        }
    }
}