Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit d77e5333 authored by Henri Chataing's avatar Henri Chataing
Browse files

RootCanal: Fix potential use-after-free issues

The tasks registered within LinkLayerController capture
the pointer to the Controller instance for member access.
These tasks can outlive the Controller instance if the
controller is deleted (e.g. through the test channel, or
because the TCP connection is lost). In this case
attempting to execute the task callback would cause a
use-after-free invalid pointer access.

This change bypasses the issue by capturing a weak pointer
to the instance to check if it is live before running the
task callback.

Test: m root-canal && launch_cvd
Change-Id: I5384fc056fc4fbc0095b99ef7113fd4f706803ba
parent 069a7c41
Loading
Loading
Loading
Loading
+8 −4
Original line number Diff line number Diff line
@@ -27,10 +27,14 @@ using ::bluetooth::hci::Address;
using ::bluetooth::hci::AddressType;
using ::bluetooth::hci::AddressWithType;

void AclConnectionHandler::RegisterTaskScheduler(
    std::function<AsyncTaskId(std::chrono::milliseconds, const TaskCallback&)>
        event_scheduler) {
  schedule_task_ = event_scheduler;
void AclConnectionHandler::Reset(std::function<void(AsyncTaskId)> stopStream) {
  // Leave no dangling periodic task.
  for (auto& [_, sco_connection] : sco_connections_) {
    sco_connection.StopStream(stopStream);
  }

  sco_connections_.clear();
  acl_connections_.clear();
}

bool AclConnectionHandler::HasHandle(uint16_t handle) const {
+3 −7
Original line number Diff line number Diff line
@@ -35,12 +35,11 @@ static constexpr uint16_t kReservedHandle = 0xF00;
class AclConnectionHandler {
 public:
  AclConnectionHandler() = default;

  virtual ~AclConnectionHandler() = default;

  void RegisterTaskScheduler(
      std::function<AsyncTaskId(std::chrono::milliseconds, const TaskCallback&)>
          event_scheduler);
  // Reset the connection manager state, stopping any pending
  // SCO connections.
  void Reset(std::function<void(AsyncTaskId)> stopStream);

  bool CreatePendingConnection(bluetooth::hci::Address addr,
                               bool authenticate_on_connect);
@@ -152,9 +151,6 @@ class AclConnectionHandler {
  std::unordered_map<uint16_t, AclConnection> acl_connections_;
  std::unordered_map<uint16_t, ScoConnection> sco_connections_;

  std::function<AsyncTaskId(std::chrono::milliseconds, const TaskCallback&)>
      schedule_task_;

  bool classic_connection_pending_{false};
  bluetooth::hci::Address pending_connection_address_{
      bluetooth::hci::Address::kEmpty};
+38 −4
Original line number Diff line number Diff line
@@ -331,16 +331,50 @@ void DualModeController::SniffSubrating(CommandView command) {
}

void DualModeController::RegisterTaskScheduler(
    std::function<AsyncTaskId(std::chrono::milliseconds, const TaskCallback&)>
    std::function<AsyncTaskId(std::chrono::milliseconds, TaskCallback)>
        task_scheduler) {
  link_layer_controller_.RegisterTaskScheduler(task_scheduler);
  link_layer_controller_.RegisterTaskScheduler(
      [this, schedule = std::move(task_scheduler)](
          std::chrono::milliseconds delay_ms, TaskCallback callback) {
        // weak_from_this is valid only if [this] is already protected
        // behind a shared_ptr; this is the case in TestModel.
        return schedule(delay_ms, [lifetime = weak_from_this(),
                                   callback = std::move(callback)] {
          // Capture a weak_ptr of the DualModeController object to protect
          // against the execution of callbacks capturing dead pointers.
          // This can occur if the device is deleted with scheduled events.
          if (lifetime.lock() != nullptr) {
            callback();
          }
        });
      });
}

void DualModeController::RegisterPeriodicTaskScheduler(
    std::function<AsyncTaskId(std::chrono::milliseconds,
                              std::chrono::milliseconds, const TaskCallback&)>
                              std::chrono::milliseconds, TaskCallback)>
        periodic_task_scheduler) {
  link_layer_controller_.RegisterPeriodicTaskScheduler(periodic_task_scheduler);
  link_layer_controller_.RegisterPeriodicTaskScheduler(
      [this, schedule = std::move(periodic_task_scheduler)](
          std::chrono::milliseconds delay_ms,
          std::chrono::milliseconds interval_ms, TaskCallback callback) {
        // weak_from_this is valid only if [this] is already protected
        // behind a shared_ptr; this is the case in TestModel.
        return schedule(
            delay_ms, interval_ms,
            [lifetime = weak_from_this(), callback = std::move(callback)] {
              // Capture a weak_ptr of the DualModeController object to protect
              // against the execution of callbacks capturing dead pointers.
              // This can occur if the device is deleted with scheduled events.
              //
              // Note: the task handle cannot be cancelled from this context;
              // we depend on the link layer to properly clean-up pending
              // periodic tasks when deleted.
              if (lifetime.lock() != nullptr) {
                callback();
              }
            });
      });
}

void DualModeController::RegisterTaskCancel(
+5 −3
Original line number Diff line number Diff line
@@ -51,7 +51,9 @@ using ::bluetooth::hci::CommandView;
// the controller's default constructor. Be sure to name your method after the
// corresponding Bluetooth command in the Core Specification with the prefix
// "Hci" to distinguish it as a controller command.
class DualModeController : public Device {
class DualModeController
    : public Device,
      public std::enable_shared_from_this<DualModeController> {
  static constexpr uint16_t kSecurityManagerNumKeys = 15;

 public:
@@ -78,12 +80,12 @@ class DualModeController : public Device {

  // Set the callbacks for scheduling tasks.
  void RegisterTaskScheduler(
      std::function<AsyncTaskId(std::chrono::milliseconds, const TaskCallback&)>
      std::function<AsyncTaskId(std::chrono::milliseconds, TaskCallback)>
          task_scheduler);

  void RegisterPeriodicTaskScheduler(
      std::function<AsyncTaskId(std::chrono::milliseconds,
                                std::chrono::milliseconds, const TaskCallback&)>
                                std::chrono::milliseconds, TaskCallback)>
          periodic_task_scheduler);

  void RegisterTaskCancel(std::function<void(AsyncTaskId)> cancel);
+14 −13
Original line number Diff line number Diff line
@@ -1421,6 +1421,12 @@ LinkLayerController::LinkLayerController(const Address& address,
    : address_(address), properties_(properties) {}
#endif

LinkLayerController::~LinkLayerController() {
  // Clear out periodic tasks for opened SCO connections in the
  // connection manager state.
  connections_.Reset(cancel_task_);
}

void LinkLayerController::SendLeLinkLayerPacket(
    std::unique_ptr<model::packets::LinkLayerPacketBuilder> packet) {
  std::shared_ptr<model::packets::LinkLayerPacketBuilder> shared_packet =
@@ -4814,36 +4820,31 @@ void LinkLayerController::RegisterRemoteChannel(
}

void LinkLayerController::RegisterTaskScheduler(
    std::function<AsyncTaskId(milliseconds, const TaskCallback&)>
        task_scheduler) {
    std::function<AsyncTaskId(milliseconds, TaskCallback)> task_scheduler) {
  schedule_task_ = task_scheduler;
}

AsyncTaskId LinkLayerController::ScheduleTask(
    milliseconds delay_ms, const TaskCallback& task_callback) {
AsyncTaskId LinkLayerController::ScheduleTask(milliseconds delay_ms,
                                              TaskCallback task_callback) {
  if (schedule_task_) {
    return schedule_task_(delay_ms, task_callback);
  }
  if (delay_ms == milliseconds::zero()) {
    task_callback();
    return 0;
    return schedule_task_(delay_ms, std::move(task_callback));
  }
  LOG_ERROR("Unable to schedule task on delay");
  return 0;
}

AsyncTaskId LinkLayerController::SchedulePeriodicTask(
    milliseconds delay_ms, milliseconds period_ms,
    const TaskCallback& task_callback) {
    milliseconds delay_ms, milliseconds period_ms, TaskCallback task_callback) {
  if (schedule_periodic_task_) {
    return schedule_periodic_task_(delay_ms, period_ms, task_callback);
    return schedule_periodic_task_(delay_ms, period_ms,
                                   std::move(task_callback));
  }
  LOG_ERROR("Unable to schedule task on delay");
  return 0;
}

void LinkLayerController::RegisterPeriodicTaskScheduler(
    std::function<AsyncTaskId(milliseconds, milliseconds, const TaskCallback&)>
    std::function<AsyncTaskId(milliseconds, milliseconds, TaskCallback)>
        periodic_task_scheduler) {
  schedule_periodic_task_ = periodic_task_scheduler;
}
Loading