Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit d2bd2edf authored by Josh Gao's avatar Josh Gao Committed by Gerrit Code Review
Browse files

Merge changes Icce121a4,I0f95d348

* changes:
  adb: switch connect_to_remote to string_view.
  adb: switch skip_host_serial to string_view.
parents d5db4e1b d0fa13e5
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -105,7 +105,7 @@ static void listener_event_func(int _fd, unsigned ev, void* _l)
        s = create_local_socket(fd);
        if (s) {
            s->transport = listener->transport;
            connect_to_remote(s, listener->connect_to.c_str());
            connect_to_remote(s, listener->connect_to);
            return;
        }

+3 −2
Original line number Diff line number Diff line
@@ -106,14 +106,15 @@ asocket *create_local_socket(int fd);
asocket* create_local_service_socket(std::string_view destination, atransport* transport);

asocket *create_remote_socket(unsigned id, atransport *t);
void connect_to_remote(asocket *s, const char *destination);
void connect_to_remote(asocket* s, std::string_view destination);
void connect_to_smartsocket(asocket *s);

// Internal functions that are only made available here for testing purposes.
namespace internal {

#if ADB_HOST
char* skip_host_serial(char* service);
bool parse_host_service(std::string_view* out_serial, std::string_view* out_command,
                        std::string_view service);
#endif

}  // namespace internal
+58 −33
Original line number Diff line number Diff line
@@ -34,6 +34,9 @@
#include "sysdeps.h"
#include "sysdeps/chrono.h"

using namespace std::string_literals;
using namespace std::string_view_literals;

struct ThreadArg {
    int first_read_fd;
    int last_write_fd;
@@ -303,56 +306,78 @@ TEST_F(LocalSocketTest, close_socket_in_CLOSE_WAIT_state) {

#if ADB_HOST

// Checks that skip_host_serial(serial) returns a pointer to the part of |serial| which matches
// |expected|, otherwise logs the failure to gtest.
void VerifySkipHostSerial(std::string serial, const char* expected) {
    char* result = internal::skip_host_serial(&serial[0]);
    if (expected == nullptr) {
        EXPECT_EQ(nullptr, result);
    } else {
        EXPECT_STREQ(expected, result);
    }
}
#define VerifyParseHostServiceFailed(s)                                         \
    do {                                                                        \
        std::string service(s);                                                 \
        std::string_view serial, command;                                       \
        bool result = internal::parse_host_service(&serial, &command, service); \
        EXPECT_FALSE(result);                                                   \
    } while (0)

#define VerifyParseHostService(s, expected_serial, expected_command)            \
    do {                                                                        \
        std::string service(s);                                                 \
        std::string_view serial, command;                                       \
        bool result = internal::parse_host_service(&serial, &command, service); \
        EXPECT_TRUE(result);                                                    \
        EXPECT_EQ(std::string(expected_serial), std::string(serial));           \
        EXPECT_EQ(std::string(expected_command), std::string(command));         \
    } while (0);

// Check [tcp:|udp:]<serial>[:<port>]:<command> format.
TEST(socket_test, test_skip_host_serial) {
TEST(socket_test, test_parse_host_service) {
    for (const std::string& protocol : {"", "tcp:", "udp:"}) {
        VerifySkipHostSerial(protocol, nullptr);
        VerifySkipHostSerial(protocol + "foo", nullptr);
        VerifyParseHostServiceFailed(protocol);
        VerifyParseHostServiceFailed(protocol + "foo");

        VerifySkipHostSerial(protocol + "foo:bar", ":bar");
        VerifySkipHostSerial(protocol + "foo:bar:baz", ":bar:baz");
        {
            std::string serial = protocol + "foo";
            VerifyParseHostService(serial + ":bar", serial, "bar");
            VerifyParseHostService(serial + " :bar:baz", serial, "bar:baz");
        }

        VerifySkipHostSerial(protocol + "foo:123:bar", ":bar");
        VerifySkipHostSerial(protocol + "foo:123:456", ":456");
        VerifySkipHostSerial(protocol + "foo:123:bar:baz", ":bar:baz");
        {
            // With port.
            std::string serial = protocol + "foo:123";
            VerifyParseHostService(serial + ":bar", serial, "bar");
            VerifyParseHostService(serial + ":456", serial, "456");
            VerifyParseHostService(serial + ":bar:baz", serial, "bar:baz");
        }

        // Don't register a port unless it's all numbers and ends with ':'.
        VerifySkipHostSerial(protocol + "foo:123", ":123");
        VerifySkipHostSerial(protocol + "foo:123bar:baz", ":123bar:baz");

        VerifySkipHostSerial(protocol + "100.100.100.100:5555:foo", ":foo");
        VerifySkipHostSerial(protocol + "[0123:4567:89ab:CDEF:0:9:a:f]:5555:foo", ":foo");
        VerifySkipHostSerial(protocol + "[::1]:5555:foo", ":foo");
        VerifyParseHostService(protocol + "foo:123", protocol + "foo", "123");
        VerifyParseHostService(protocol + "foo:123bar:baz", protocol + "foo", "123bar:baz");

        std::string addresses[] = {"100.100.100.100", "[0123:4567:89ab:CDEF:0:9:a:f]", "[::1]"};
        for (const std::string& address : addresses) {
            std::string serial = protocol + address;
            std::string serial_with_port = protocol + address + ":5555";
            VerifyParseHostService(serial + ":foo", serial, "foo");
            VerifyParseHostService(serial_with_port + ":foo", serial_with_port, "foo");
        }

        // If we can't find both [] then treat it as a normal serial with [ in it.
        VerifySkipHostSerial(protocol + "[0123:foo", ":foo");
        VerifyParseHostService(protocol + "[0123:foo", protocol + "[0123", "foo");

        // Don't be fooled by random IPv6 addresses in the command string.
        VerifySkipHostSerial(protocol + "foo:ping [0123:4567:89ab:CDEF:0:9:a:f]:5555",
                             ":ping [0123:4567:89ab:CDEF:0:9:a:f]:5555");
        VerifyParseHostService(protocol + "foo:ping [0123:4567:89ab:CDEF:0:9:a:f]:5555",
                               protocol + "foo", "ping [0123:4567:89ab:CDEF:0:9:a:f]:5555");

        // Handle embedded NULs properly.
        VerifyParseHostService(protocol + "foo:echo foo\0bar"s, protocol + "foo",
                               "echo foo\0bar"sv);
    }
}

// Check <prefix>:<serial>:<command> format.
TEST(socket_test, test_skip_host_serial_prefix) {
TEST(socket_test, test_parse_host_service_prefix) {
    for (const std::string& prefix : {"usb:", "product:", "model:", "device:"}) {
        VerifySkipHostSerial(prefix, nullptr);
        VerifySkipHostSerial(prefix + "foo", nullptr);
        VerifyParseHostServiceFailed(prefix);
        VerifyParseHostServiceFailed(prefix + "foo");

        VerifySkipHostSerial(prefix + "foo:bar", ":bar");
        VerifySkipHostSerial(prefix + "foo:bar:baz", ":bar:baz");
        VerifySkipHostSerial(prefix + "foo:123:bar", ":123:bar");
        VerifyParseHostService(prefix + "foo:bar", prefix + "foo", "bar");
        VerifyParseHostService(prefix + "foo:bar:baz", prefix + "foo", "bar:baz");
        VerifyParseHostService(prefix + "foo:123:bar", prefix + "foo", "123:bar");
    }
}

+141 −70
Original line number Diff line number Diff line
@@ -37,6 +37,7 @@

#include "adb.h"
#include "adb_io.h"
#include "adb_utils.h"
#include "transport.h"
#include "types.h"

@@ -461,16 +462,19 @@ asocket* create_remote_socket(unsigned id, atransport* t) {
    return s;
}

void connect_to_remote(asocket* s, const char* destination) {
void connect_to_remote(asocket* s, std::string_view destination) {
    D("Connect_to_remote call RS(%d) fd=%d", s->id, s->fd);
    apacket* p = get_apacket();

    D("LS(%d): connect('%s')", s->id, destination);
    LOG(VERBOSE) << "LS(" << s->id << ": connect(" << destination << ")";
    p->msg.command = A_OPEN;
    p->msg.arg0 = s->id;

    // adbd expects a null-terminated string.
    p->payload.assign(destination, destination + strlen(destination) + 1);
    // adbd used to expect a null-terminated string.
    // Keep doing so to maintain backward compatibility.
    p->payload.resize(destination.size() + 1);
    memcpy(p->payload.data(), destination.data(), destination.size());
    p->payload[destination.size()] = '\0';
    p->msg.data_length = p->payload.size();

    CHECK_LE(p->msg.data_length, s->get_max_payload());
@@ -546,57 +550,119 @@ static unsigned unhex(const char* s, int len) {

namespace internal {

// Returns the position in |service| following the target serial parameter. Serial format can be
// any of:
// Parses a host service string of the following format:
//   * [tcp:|udp:]<serial>[:<port>]:<command>
//   * <prefix>:<serial>:<command>
// Where <port> must be a base-10 number and <prefix> may be any of {usb,product,model,device}.
//
// The returned pointer will point to the ':' just before <command>, or nullptr if not found.
char* skip_host_serial(char* service) {
    static const std::vector<std::string>& prefixes =
        *(new std::vector<std::string>{"usb:", "product:", "model:", "device:"});
bool parse_host_service(std::string_view* out_serial, std::string_view* out_command,
                        std::string_view full_service) {
    if (full_service.empty()) {
        return false;
    }

    std::string_view serial;
    std::string_view command = full_service;
    // Remove |count| bytes from the beginning of command and add them to |serial|.
    auto consume = [&full_service, &serial, &command](size_t count) {
        CHECK_LE(count, command.size());
        if (!serial.empty()) {
            CHECK_EQ(serial.data() + serial.size(), command.data());
        }

        serial = full_service.substr(0, serial.size() + count);
        command.remove_prefix(count);
    };

    // Remove the trailing : from serial, and assign the values to the output parameters.
    auto finish = [out_serial, out_command, &serial, &command] {
        if (serial.empty() || command.empty()) {
            return false;
        }

        CHECK_EQ(':', serial.back());
        serial.remove_suffix(1);

        *out_serial = serial;
        *out_command = command;
        return true;
    };

    for (const std::string& prefix : prefixes) {
        if (!strncmp(service, prefix.c_str(), prefix.length())) {
            return strchr(service + prefix.length(), ':');
    static constexpr std::string_view prefixes[] = {"usb:", "product:", "model:", "device:"};
    for (std::string_view prefix : prefixes) {
        if (command.starts_with(prefix)) {
            consume(prefix.size());

            size_t offset = command.find_first_of(':');
            if (offset == std::string::npos) {
                return false;
            }
            consume(offset + 1);
            return finish();
        }
    }

    // For fastboot compatibility, ignore protocol prefixes.
    if (!strncmp(service, "tcp:", 4) || !strncmp(service, "udp:", 4)) {
        service += 4;
    if (command.starts_with("tcp:") || command.starts_with("udp:")) {
        consume(4);
        if (command.empty()) {
            return false;
        }
    }

    // Check for an IPv6 address. `adb connect` creates the serial number from the canonical
    bool found_address = false;
    if (command[0] == '[') {
        // Read an IPv6 address. `adb connect` creates the serial number from the canonical
        // network address so it will always have the [] delimiters.
    if (service[0] == '[') {
        char* ipv6_end = strchr(service, ']');
        if (ipv6_end != nullptr) {
            service = ipv6_end;
        size_t ipv6_end = command.find_first_of(']');
        if (ipv6_end != std::string::npos) {
            consume(ipv6_end + 1);
            if (command.empty()) {
                // Nothing after the IPv6 address.
                return false;
            } else if (command[0] != ':') {
                // Garbage after the IPv6 address.
                return false;
            }
            consume(1);
            found_address = true;
        }
    }

    // The next colon we find must either begin the port field or the command field.
    char* colon_ptr = strchr(service, ':');
    if (!colon_ptr) {
        // No colon in service string.
        return nullptr;
    if (!found_address) {
        // Scan ahead to the next colon.
        size_t offset = command.find_first_of(':');
        if (offset == std::string::npos) {
            return false;
        }
        consume(offset + 1);
    }

    // We're either at the beginning of a port, or the command itself.
    // Look for a port in between colons.
    size_t next_colon = command.find_first_of(':');
    if (next_colon == std::string::npos) {
        // No colon, we must be at the command.
        return finish();
    }

    // If the next field is only decimal digits and ends with another colon, it's a port.
    char* serial_end = colon_ptr;
    if (isdigit(serial_end[1])) {
        serial_end++;
        while (*serial_end && isdigit(*serial_end)) {
            serial_end++;
    bool port_valid = true;
    if (command.size() <= next_colon) {
        return false;
    }
        if (*serial_end != ':') {
            // Something other than "<port>:" was found, this must be the command field instead.
            serial_end = colon_ptr;

    std::string_view port = command.substr(0, next_colon);
    for (auto digit : port) {
        if (!isdigit(digit)) {
            // Port isn't a number.
            port_valid = false;
            break;
        }
    }
    return serial_end;

    if (port_valid) {
        consume(next_colon + 1);
    }
    return finish();
}

}  // namespace internal
@@ -605,8 +671,8 @@ char* skip_host_serial(char* service) {

static int smart_socket_enqueue(asocket* s, apacket::payload_type data) {
#if ADB_HOST
    char* service = nullptr;
    char* serial = nullptr;
    std::string_view service;
    std::string_view serial;
    TransportId transport_id = 0;
    TransportType type = kTransportAny;
#endif
@@ -643,49 +709,52 @@ static int smart_socket_enqueue(asocket* s, apacket::payload_type data) {
    D("SS(%d): '%s'", s->id, (char*)(s->smart_socket_data.data() + 4));

#if ADB_HOST
    service = &s->smart_socket_data[4];
    if (!strncmp(service, "host-serial:", strlen("host-serial:"))) {
        char* serial_end;
        service += strlen("host-serial:");
    service = std::string_view(s->smart_socket_data).substr(4);
    if (service.starts_with("host-serial:")) {
        service.remove_prefix(strlen("host-serial:"));

        // serial number should follow "host:" and could be a host:port string.
        serial_end = internal::skip_host_serial(service);
        if (serial_end) {
            *serial_end = 0;  // terminate string
            serial = service;
            service = serial_end + 1;
        if (!internal::parse_host_service(&serial, &service, service)) {
            LOG(ERROR) << "SS(" << s->id << "): failed to parse host service: " << service;
            goto fail;
        }
    } else if (!strncmp(service, "host-transport-id:", strlen("host-transport-id:"))) {
        service += strlen("host-transport-id:");
        transport_id = strtoll(service, &service, 10);

        if (*service != ':') {
    } else if (service.starts_with("host-transport-id:")) {
        service.remove_prefix(strlen("host-transport-id:"));
        if (!ParseUint(&transport_id, service, &service)) {
            LOG(ERROR) << "SS(" << s->id << "): failed to parse host transport id: " << service;
            return -1;
        }
        if (!service.starts_with(":")) {
            LOG(ERROR) << "SS(" << s->id << "): host-transport-id without command";
            return -1;
        }
        service++;
    } else if (!strncmp(service, "host-usb:", strlen("host-usb:"))) {
        service.remove_prefix(1);
    } else if (service.starts_with("host-usb:")) {
        type = kTransportUsb;
        service += strlen("host-usb:");
    } else if (!strncmp(service, "host-local:", strlen("host-local:"))) {
        service.remove_prefix(strlen("host-usb:"));
    } else if (service.starts_with("host-local:")) {
        type = kTransportLocal;
        service += strlen("host-local:");
    } else if (!strncmp(service, "host:", strlen("host:"))) {
        service.remove_prefix(strlen("host-local:"));
    } else if (service.starts_with("host:")) {
        type = kTransportAny;
        service += strlen("host:");
        service.remove_prefix(strlen("host:"));
    } else {
        service = nullptr;
        service = std::string_view{};
    }

    if (service) {
    if (!service.empty()) {
        asocket* s2;

        // Some requests are handled immediately -- in that case the handle_host_request() routine
        // has sent the OKAY or FAIL message and all we have to do is clean up.
        if (handle_host_request(service, type, serial, transport_id, s->peer->fd, s)) {
            D("SS(%d): handled host service '%s'", s->id, service);
        // TODO: Convert to string_view.
        if (handle_host_request(std::string(service).c_str(), type,
                                serial.empty() ? nullptr : std::string(serial).c_str(),
                                transport_id, s->peer->fd, s)) {
            LOG(VERBOSE) << "SS(" << s->id << "): handled host service '" << service << "'";
            goto fail;
        }
        if (!strncmp(service, "transport", strlen("transport"))) {
        if (service.starts_with("transport")) {
            D("SS(%d): okay transport", s->id);
            s->smart_socket_data.clear();
            return 0;
@@ -695,9 +764,11 @@ static int smart_socket_enqueue(asocket* s, apacket::payload_type data) {
        ** if no such service exists, we'll fail out
        ** and tear down here.
        */
        s2 = create_host_service_socket(service, serial, transport_id);
        // TODO: Convert to string_view.
        s2 = create_host_service_socket(std::string(service).c_str(), std::string(serial).c_str(),
                                        transport_id);
        if (s2 == nullptr) {
            D("SS(%d): couldn't create host service '%s'", s->id, service);
            LOG(VERBOSE) << "SS(" << s->id << "): couldn't create host service '" << service << "'";
            SendFail(s->peer->fd, "unknown host service");
            goto fail;
        }
@@ -758,7 +829,7 @@ static int smart_socket_enqueue(asocket* s, apacket::payload_type data) {
    /* give him our transport and upref it */
    s->peer->transport = s->transport;

    connect_to_remote(s->peer, s->smart_socket_data.data() + 4);
    connect_to_remote(s->peer, std::string_view(s->smart_socket_data).substr(4));
    s->peer = nullptr;
    s->close(s);
    return 1;