Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 78cc20f0 authored by Josh Gao's avatar Josh Gao
Browse files

libcutils: try all addresses in socket_network_client_timeout.

If a connection fails to an address that resolves to multiple
sockaddrs, attempt connecting to subsequent addresses if the initial
connection fails to a reason other than timeout. This is primarily
useful for localhost, which can resolve to both an IPv4 and and IPv6
address.

Also, add an adb test to verify that this behavior.

Bug: http://b/30313466
Change-Id: Ib2df706a66cf6ef8c1097fdfd7aedb69b8df2d6e
Test: python test_adb.py (+ the test fails before this patch)
parent 8e7ae1e3
Loading
Loading
Loading
Loading
+22 −0
Original line number Diff line number Diff line
@@ -207,6 +207,28 @@ class NonApiTest(unittest.TestCase):
            # reading the response from the adb emu kill command (on Windows).
            self.assertEqual(0, p.returncode)

    def test_connect_ipv4_ipv6(self):
        """Ensure that `adb connect localhost:1234` will try both IPv4 and IPv6.

        Bug: http://b/30313466
        """
        ipv4 = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        ipv4.bind(('127.0.0.1', 0))
        ipv4.listen(1)

        ipv6 = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
        ipv6.bind(('::1', ipv4.getsockname()[1] + 1))
        ipv6.listen(1)

        for s in (ipv4, ipv6):
            port = s.getsockname()[1]
            output = subprocess.check_output(
                ['adb', 'connect', 'localhost:{}'.format(port)])

            self.assertEqual(
                output.strip(), 'connected to localhost:{}'.format(port))
            s.close()


def main():
    random.seed(0)
+48 −49
Original line number Diff line number Diff line
@@ -59,25 +59,19 @@ int socket_network_client_timeout(const char* host, int port, int type, int time
        return -1;
    }

    // TODO: try all the addresses if there's more than one?
    int family = addrs[0].ai_family;
    int protocol = addrs[0].ai_protocol;
    socklen_t addr_len = addrs[0].ai_addrlen;
    struct sockaddr_storage addr;
    memcpy(&addr, addrs[0].ai_addr, addr_len);

    freeaddrinfo(addrs);

    int result = -1;
    for (struct addrinfo* addr = addrs; addr != NULL; addr = addr->ai_next) {
        // The Mac doesn't have SOCK_NONBLOCK.
    int s = socket(family, type, protocol);
        int s = socket(addr->ai_family, type, addr->ai_protocol);
        if (s == -1 || toggle_O_NONBLOCK(s) == -1) return -1;

    int rc = connect(s, (const struct sockaddr*) &addr, addr_len);
        int rc = connect(s, addr->ai_addr, addr->ai_addrlen);
        if (rc == 0) {
        return toggle_O_NONBLOCK(s);
            result = toggle_O_NONBLOCK(s);
            break;
        } else if (rc == -1 && errno != EINPROGRESS) {
            close(s);
        return -1;
            continue;
        }

        fd_set r_set;
@@ -90,12 +84,12 @@ int socket_network_client_timeout(const char* host, int port, int type, int time
        ts.tv_usec = 0;
        if ((rc = select(s + 1, &r_set, &w_set, NULL, (timeout != 0) ? &ts : NULL)) == -1) {
            close(s);
        return -1;
            break;
        }
        if (rc == 0) {  // we had a timeout
            errno = ETIMEDOUT;
            close(s);
        return -1;
            break;
        }

        int error = 0;
@@ -103,20 +97,25 @@ int socket_network_client_timeout(const char* host, int port, int type, int time
        if (FD_ISSET(s, &r_set) || FD_ISSET(s, &w_set)) {
            if (getsockopt(s, SOL_SOCKET, SO_ERROR, &error, &len) < 0) {
                close(s);
            return -1;
                break;
            }
        } else {
            close(s);
        return -1;
            break;
        }

        if (error) {  // check if we had a socket error
            // TODO: Update the timeout.
            errno = error;
            close(s);
        return -1;
            continue;
        }

    return toggle_O_NONBLOCK(s);
        result = toggle_O_NONBLOCK(s);
    }

    freeaddrinfo(addrs);
    return result;
}

int socket_network_client(const char* host, int port, int type) {