Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 9a388d53 authored by Luis Hector Chavez's avatar Luis Hector Chavez
Browse files

adb: Make the Connection object a std::shared_ptr

This change is in preparation to allow the TCP-based transports to be
able to reconnect. This is needed because multiple threads can access
the Connection object. It used to be safe to do because one instance of
atransport would have the same Connection instance throughout its
lifetime, but now it is possible to replace the Connection instance,
which could cause threads that were attempting to Write to an
atransport* to use-after-free the Connection instance.

Bug: 74411879
Test: system/core/adb/test_adb.py
Change-Id: I4f092be11b2095088a9a9de2c0386086814d37ce
parent 56fe7530
Loading
Loading
Loading
Loading
+15 −9
Original line number Original line Diff line number Diff line
@@ -517,8 +517,8 @@ static void transport_registration_func(int _fd, unsigned ev, void*) {
    if (t->GetConnectionState() != kCsNoPerm) {
    if (t->GetConnectionState() != kCsNoPerm) {
        /* initial references are the two threads */
        /* initial references are the two threads */
        t->ref_count = 1;
        t->ref_count = 1;
        t->connection->SetTransportName(t->serial_name());
        t->connection()->SetTransportName(t->serial_name());
        t->connection->SetReadCallback([t](Connection*, std::unique_ptr<apacket> p) {
        t->connection()->SetReadCallback([t](Connection*, std::unique_ptr<apacket> p) {
            if (!check_header(p.get(), t)) {
            if (!check_header(p.get(), t)) {
                D("%s: remote read: bad header", t->serial);
                D("%s: remote read: bad header", t->serial);
                return false;
                return false;
@@ -531,7 +531,7 @@ static void transport_registration_func(int _fd, unsigned ev, void*) {
            fdevent_run_on_main_thread([packet, t]() { handle_packet(packet, t); });
            fdevent_run_on_main_thread([packet, t]() { handle_packet(packet, t); });
            return true;
            return true;
        });
        });
        t->connection->SetErrorCallback([t](Connection*, const std::string& error) {
        t->connection()->SetErrorCallback([t](Connection*, const std::string& error) {
            D("%s: connection terminated: %s", t->serial, error.c_str());
            D("%s: connection terminated: %s", t->serial, error.c_str());
            fdevent_run_on_main_thread([t]() {
            fdevent_run_on_main_thread([t]() {
                handle_offline(t);
                handle_offline(t);
@@ -539,7 +539,7 @@ static void transport_registration_func(int _fd, unsigned ev, void*) {
            });
            });
        });
        });


        t->connection->Start();
        t->connection()->Start();
#if ADB_HOST
#if ADB_HOST
        send_connect(t);
        send_connect(t);
#endif
#endif
@@ -608,7 +608,7 @@ static void transport_unref(atransport* t) {
    t->ref_count--;
    t->ref_count--;
    if (t->ref_count == 0) {
    if (t->ref_count == 0) {
        D("transport: %s unref (kicking and closing)", t->serial);
        D("transport: %s unref (kicking and closing)", t->serial);
        t->connection->Stop();
        t->connection()->Stop();
        remove_transport(t);
        remove_transport(t);
    } else {
    } else {
        D("transport: %s unref (count=%zu)", t->serial, t->ref_count);
        D("transport: %s unref (count=%zu)", t->serial, t->ref_count);
@@ -758,14 +758,14 @@ atransport::~atransport() {
}
}


int atransport::Write(apacket* p) {
int atransport::Write(apacket* p) {
    return this->connection->Write(std::unique_ptr<apacket>(p)) ? 0 : -1;
    return this->connection()->Write(std::unique_ptr<apacket>(p)) ? 0 : -1;
}
}


void atransport::Kick() {
void atransport::Kick() {
    if (!kicked_) {
    if (!kicked_) {
        D("kicking transport %s", this->serial);
        D("kicking transport %s", this->serial);
        kicked_ = true;
        kicked_ = true;
        this->connection->Stop();
        this->connection()->Stop();
    }
    }
}
}


@@ -778,6 +778,11 @@ void atransport::SetConnectionState(ConnectionState state) {
    connection_state_ = state;
    connection_state_ = state;
}
}


void atransport::SetConnection(std::unique_ptr<Connection> connection) {
    std::lock_guard<std::mutex> lock(mutex_);
    connection_ = std::shared_ptr<Connection>(std::move(connection));
}

std::string atransport::connection_state_name() const {
std::string atransport::connection_state_name() const {
    ConnectionState state = GetConnectionState();
    ConnectionState state = GetConnectionState();
    switch (state) {
    switch (state) {
@@ -1094,8 +1099,9 @@ void register_usb_transport(usb_handle* usb, const char* serial, const char* dev
void unregister_usb_transport(usb_handle* usb) {
void unregister_usb_transport(usb_handle* usb) {
    std::lock_guard<std::recursive_mutex> lock(transport_lock);
    std::lock_guard<std::recursive_mutex> lock(transport_lock);
    transport_list.remove_if([usb](atransport* t) {
    transport_list.remove_if([usb](atransport* t) {
        if (auto connection = dynamic_cast<UsbConnection*>(t->connection.get())) {
        auto connection = t->connection();
            return connection->handle_ == usb && t->GetConnectionState() == kCsNoPerm;
        if (auto usb_connection = dynamic_cast<UsbConnection*>(connection.get())) {
            return usb_connection->handle_ == usb && t->GetConnectionState() == kCsNoPerm;
        }
        }
        return false;
        return false;
    });
    });
+13 −3
Original line number Original line Diff line number Diff line
@@ -201,7 +201,8 @@ class atransport {
    atransport(ConnectionState state = kCsOffline)
    atransport(ConnectionState state = kCsOffline)
        : id(NextTransportId()),
        : id(NextTransportId()),
          connection_state_(state),
          connection_state_(state),
          connection_waitable_(std::make_shared<ConnectionWaitable>()) {
          connection_waitable_(std::make_shared<ConnectionWaitable>()),
          connection_(nullptr) {
        // Initialize protocol to min version for compatibility with older versions.
        // Initialize protocol to min version for compatibility with older versions.
        // Version will be updated post-connect.
        // Version will be updated post-connect.
        protocol_version = A_VERSION_MIN;
        protocol_version = A_VERSION_MIN;
@@ -216,13 +217,17 @@ class atransport {
    ConnectionState GetConnectionState() const;
    ConnectionState GetConnectionState() const;
    void SetConnectionState(ConnectionState state);
    void SetConnectionState(ConnectionState state);


    void SetConnection(std::unique_ptr<Connection> connection);
    std::shared_ptr<Connection> connection() {
        std::lock_guard<std::mutex> lock(mutex_);
        return connection_;
    }

    const TransportId id;
    const TransportId id;
    size_t ref_count = 0;
    size_t ref_count = 0;
    bool online = false;
    bool online = false;
    TransportType type = kTransportAny;
    TransportType type = kTransportAny;


    std::unique_ptr<Connection> connection;

    // Used to identify transports for clients.
    // Used to identify transports for clients.
    char* serial = nullptr;
    char* serial = nullptr;
    char* product = nullptr;
    char* product = nullptr;
@@ -302,6 +307,11 @@ class atransport {
    // connection to be established.
    // connection to be established.
    std::shared_ptr<ConnectionWaitable> connection_waitable_;
    std::shared_ptr<ConnectionWaitable> connection_waitable_;


    // The underlying connection object.
    std::shared_ptr<Connection> connection_ GUARDED_BY(mutex_);

    std::mutex mutex_;

    DISALLOW_COPY_AND_ASSIGN(atransport);
    DISALLOW_COPY_AND_ASSIGN(atransport);
};
};


+3 −2
Original line number Original line Diff line number Diff line
@@ -456,7 +456,8 @@ int init_socket_transport(atransport* t, int s, int adb_port, int local) {
    // Emulator connection.
    // Emulator connection.
    if (local) {
    if (local) {
        auto emulator_connection = std::make_unique<EmulatorConnection>(std::move(fd), adb_port);
        auto emulator_connection = std::make_unique<EmulatorConnection>(std::move(fd), adb_port);
        t->connection = std::make_unique<BlockingConnectionAdapter>(std::move(emulator_connection));
        t->SetConnection(
            std::make_unique<BlockingConnectionAdapter>(std::move(emulator_connection)));
        std::lock_guard<std::mutex> lock(local_transports_lock);
        std::lock_guard<std::mutex> lock(local_transports_lock);
        atransport* existing_transport = find_emulator_transport_by_adb_port_locked(adb_port);
        atransport* existing_transport = find_emulator_transport_by_adb_port_locked(adb_port);
        if (existing_transport != NULL) {
        if (existing_transport != NULL) {
@@ -476,6 +477,6 @@ int init_socket_transport(atransport* t, int s, int adb_port, int local) {


    // Regular tcp connection.
    // Regular tcp connection.
    auto fd_connection = std::make_unique<FdConnection>(std::move(fd));
    auto fd_connection = std::make_unique<FdConnection>(std::move(fd));
    t->connection = std::make_unique<BlockingConnectionAdapter>(std::move(fd_connection));
    t->SetConnection(std::make_unique<BlockingConnectionAdapter>(std::move(fd_connection)));
    return fail;
    return fail;
}
}
+1 −1
Original line number Original line Diff line number Diff line
@@ -176,7 +176,7 @@ void UsbConnection::Close() {
void init_usb_transport(atransport* t, usb_handle* h) {
void init_usb_transport(atransport* t, usb_handle* h) {
    D("transport: usb");
    D("transport: usb");
    auto connection = std::make_unique<UsbConnection>(h);
    auto connection = std::make_unique<UsbConnection>(h);
    t->connection = std::make_unique<BlockingConnectionAdapter>(std::move(connection));
    t->SetConnection(std::make_unique<BlockingConnectionAdapter>(std::move(connection)));
    t->type = kTransportUsb;
    t->type = kTransportUsb;
}
}