Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 9a388d53 authored by Luis Hector Chavez's avatar Luis Hector Chavez
Browse files

adb: Make the Connection object a std::shared_ptr

This change is in preparation to allow the TCP-based transports to be
able to reconnect. This is needed because multiple threads can access
the Connection object. It used to be safe to do because one instance of
atransport would have the same Connection instance throughout its
lifetime, but now it is possible to replace the Connection instance,
which could cause threads that were attempting to Write to an
atransport* to use-after-free the Connection instance.

Bug: 74411879
Test: system/core/adb/test_adb.py
Change-Id: I4f092be11b2095088a9a9de2c0386086814d37ce
parent 56fe7530
Loading
Loading
Loading
Loading
+15 −9
Original line number Diff line number Diff line
@@ -517,8 +517,8 @@ static void transport_registration_func(int _fd, unsigned ev, void*) {
    if (t->GetConnectionState() != kCsNoPerm) {
        /* initial references are the two threads */
        t->ref_count = 1;
        t->connection->SetTransportName(t->serial_name());
        t->connection->SetReadCallback([t](Connection*, std::unique_ptr<apacket> p) {
        t->connection()->SetTransportName(t->serial_name());
        t->connection()->SetReadCallback([t](Connection*, std::unique_ptr<apacket> p) {
            if (!check_header(p.get(), t)) {
                D("%s: remote read: bad header", t->serial);
                return false;
@@ -531,7 +531,7 @@ static void transport_registration_func(int _fd, unsigned ev, void*) {
            fdevent_run_on_main_thread([packet, t]() { handle_packet(packet, t); });
            return true;
        });
        t->connection->SetErrorCallback([t](Connection*, const std::string& error) {
        t->connection()->SetErrorCallback([t](Connection*, const std::string& error) {
            D("%s: connection terminated: %s", t->serial, error.c_str());
            fdevent_run_on_main_thread([t]() {
                handle_offline(t);
@@ -539,7 +539,7 @@ static void transport_registration_func(int _fd, unsigned ev, void*) {
            });
        });

        t->connection->Start();
        t->connection()->Start();
#if ADB_HOST
        send_connect(t);
#endif
@@ -608,7 +608,7 @@ static void transport_unref(atransport* t) {
    t->ref_count--;
    if (t->ref_count == 0) {
        D("transport: %s unref (kicking and closing)", t->serial);
        t->connection->Stop();
        t->connection()->Stop();
        remove_transport(t);
    } else {
        D("transport: %s unref (count=%zu)", t->serial, t->ref_count);
@@ -758,14 +758,14 @@ atransport::~atransport() {
}

int atransport::Write(apacket* p) {
    return this->connection->Write(std::unique_ptr<apacket>(p)) ? 0 : -1;
    return this->connection()->Write(std::unique_ptr<apacket>(p)) ? 0 : -1;
}

void atransport::Kick() {
    if (!kicked_) {
        D("kicking transport %s", this->serial);
        kicked_ = true;
        this->connection->Stop();
        this->connection()->Stop();
    }
}

@@ -778,6 +778,11 @@ void atransport::SetConnectionState(ConnectionState state) {
    connection_state_ = state;
}

void atransport::SetConnection(std::unique_ptr<Connection> connection) {
    std::lock_guard<std::mutex> lock(mutex_);
    connection_ = std::shared_ptr<Connection>(std::move(connection));
}

std::string atransport::connection_state_name() const {
    ConnectionState state = GetConnectionState();
    switch (state) {
@@ -1094,8 +1099,9 @@ void register_usb_transport(usb_handle* usb, const char* serial, const char* dev
void unregister_usb_transport(usb_handle* usb) {
    std::lock_guard<std::recursive_mutex> lock(transport_lock);
    transport_list.remove_if([usb](atransport* t) {
        if (auto connection = dynamic_cast<UsbConnection*>(t->connection.get())) {
            return connection->handle_ == usb && t->GetConnectionState() == kCsNoPerm;
        auto connection = t->connection();
        if (auto usb_connection = dynamic_cast<UsbConnection*>(connection.get())) {
            return usb_connection->handle_ == usb && t->GetConnectionState() == kCsNoPerm;
        }
        return false;
    });
+13 −3
Original line number Diff line number Diff line
@@ -201,7 +201,8 @@ class atransport {
    atransport(ConnectionState state = kCsOffline)
        : id(NextTransportId()),
          connection_state_(state),
          connection_waitable_(std::make_shared<ConnectionWaitable>()) {
          connection_waitable_(std::make_shared<ConnectionWaitable>()),
          connection_(nullptr) {
        // Initialize protocol to min version for compatibility with older versions.
        // Version will be updated post-connect.
        protocol_version = A_VERSION_MIN;
@@ -216,13 +217,17 @@ class atransport {
    ConnectionState GetConnectionState() const;
    void SetConnectionState(ConnectionState state);

    void SetConnection(std::unique_ptr<Connection> connection);
    std::shared_ptr<Connection> connection() {
        std::lock_guard<std::mutex> lock(mutex_);
        return connection_;
    }

    const TransportId id;
    size_t ref_count = 0;
    bool online = false;
    TransportType type = kTransportAny;

    std::unique_ptr<Connection> connection;

    // Used to identify transports for clients.
    char* serial = nullptr;
    char* product = nullptr;
@@ -302,6 +307,11 @@ class atransport {
    // connection to be established.
    std::shared_ptr<ConnectionWaitable> connection_waitable_;

    // The underlying connection object.
    std::shared_ptr<Connection> connection_ GUARDED_BY(mutex_);

    std::mutex mutex_;

    DISALLOW_COPY_AND_ASSIGN(atransport);
};

+3 −2
Original line number Diff line number Diff line
@@ -456,7 +456,8 @@ int init_socket_transport(atransport* t, int s, int adb_port, int local) {
    // Emulator connection.
    if (local) {
        auto emulator_connection = std::make_unique<EmulatorConnection>(std::move(fd), adb_port);
        t->connection = std::make_unique<BlockingConnectionAdapter>(std::move(emulator_connection));
        t->SetConnection(
            std::make_unique<BlockingConnectionAdapter>(std::move(emulator_connection)));
        std::lock_guard<std::mutex> lock(local_transports_lock);
        atransport* existing_transport = find_emulator_transport_by_adb_port_locked(adb_port);
        if (existing_transport != NULL) {
@@ -476,6 +477,6 @@ int init_socket_transport(atransport* t, int s, int adb_port, int local) {

    // Regular tcp connection.
    auto fd_connection = std::make_unique<FdConnection>(std::move(fd));
    t->connection = std::make_unique<BlockingConnectionAdapter>(std::move(fd_connection));
    t->SetConnection(std::make_unique<BlockingConnectionAdapter>(std::move(fd_connection)));
    return fail;
}
+1 −1
Original line number Diff line number Diff line
@@ -176,7 +176,7 @@ void UsbConnection::Close() {
void init_usb_transport(atransport* t, usb_handle* h) {
    D("transport: usb");
    auto connection = std::make_unique<UsbConnection>(h);
    t->connection = std::make_unique<BlockingConnectionAdapter>(std::move(connection));
    t->SetConnection(std::make_unique<BlockingConnectionAdapter>(std::move(connection)));
    t->type = kTransportUsb;
}