Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit ffc11d3c authored by Josh Gao's avatar Josh Gao
Browse files

adb: don't close sockets before hitting EOF.

The standard (RFC 1122 - 4.2.2.13) says that if we call close on a
socket while we have pending data, a TCP RST should be sent to the
other end to notify it that we didn't read all of its data. However,
this can result in data that we've succesfully written out to be dropped
on the other end. To avoid this, instead of immediately closing a
socket, call shutdown on it instead, and then read from the file
descriptor until we hit EOF or an error before closing.

Bug: http://b/74616284
Test: ./test_adb.py
Test: ./test_device.py
Change-Id: I36f72bd14965821dc23de82774b0806b2db24f13
parent a3303fd2
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -24,6 +24,7 @@ cc_defaults {
        "-Wno-missing-field-initializers",
        "-Wvla",
    ],
    cpp_std: "gnu++17",
    rtti: true,

    use_version_lib: true,
+151 −4
Original line number Diff line number Diff line
@@ -26,10 +26,14 @@
#include <unistd.h>

#include <algorithm>
#include <map>
#include <mutex>
#include <string>
#include <thread>
#include <vector>

#include <android-base/thread_annotations.h>

#if !ADB_HOST
#include <android-base/properties.h>
#include <log/log_properties.h>
@@ -37,9 +41,150 @@

#include "adb.h"
#include "adb_io.h"
#include "adb_utils.h"
#include "sysdeps/chrono.h"
#include "transport.h"
#include "types.h"

// The standard (RFC 1122 - 4.2.2.13) says that if we call close on a
// socket while we have pending data, a TCP RST should be sent to the
// other end to notify it that we didn't read all of its data. However,
// this can result in data that we've successfully written out to be dropped
// on the other end. To avoid this, instead of immediately closing a
// socket, call shutdown on it instead, and then read from the file
// descriptor until we hit EOF or an error before closing.
struct LingeringSocketCloser {
    LingeringSocketCloser() = default;
    ~LingeringSocketCloser() = delete;

    // Defer thread creation until it's needed, because we need for there to
    // only be one thread when dropping privileges in adbd.
    void Start() {
        CHECK(!thread_.joinable());

        int fds[2];
        if (adb_socketpair(fds) != 0) {
            PLOG(FATAL) << "adb_socketpair failed";
        }

        set_file_block_mode(fds[0], false);
        set_file_block_mode(fds[1], false);

        notify_fd_read_.reset(fds[0]);
        notify_fd_write_.reset(fds[1]);

        thread_ = std::thread([this]() { Run(); });
    }

    void EnqueueSocket(unique_fd socket) {
        // Shutdown the socket in the outgoing direction only, so that
        // we don't have the same problem on the opposite end.
        adb_shutdown(socket.get(), SHUT_WR);
        set_file_block_mode(socket.get(), false);

        std::lock_guard<std::mutex> lock(mutex_);
        int fd = socket.get();
        SocketInfo info = {
                .fd = std::move(socket),
                .deadline = std::chrono::steady_clock::now() + 1s,
        };

        D("LingeringSocketCloser received fd %d", fd);

        fds_.emplace(fd, std::move(info));
        if (adb_write(notify_fd_write_, "", 1) == -1 && errno != EAGAIN) {
            PLOG(FATAL) << "failed to write to LingeringSocketCloser notify fd";
        }
    }

  private:
    std::vector<adb_pollfd> GeneratePollFds() {
        std::lock_guard<std::mutex> lock(mutex_);
        std::vector<adb_pollfd> result;
        result.push_back(adb_pollfd{.fd = notify_fd_read_, .events = POLLIN});
        for (auto& [fd, _] : fds_) {
            result.push_back(adb_pollfd{.fd = fd, .events = POLLIN});
        }
        return result;
    }

    void Run() {
        while (true) {
            std::vector<adb_pollfd> pfds = GeneratePollFds();
            int rc = adb_poll(pfds.data(), pfds.size(), 1000);
            if (rc == -1) {
                PLOG(FATAL) << "poll failed in LingeringSocketCloser";
            }

            std::lock_guard<std::mutex> lock(mutex_);
            if (rc == 0) {
                // Check deadlines.
                auto now = std::chrono::steady_clock::now();
                for (auto it = fds_.begin(); it != fds_.end();) {
                    if (now > it->second.deadline) {
                        D("LingeringSocketCloser closing fd %d due to deadline", it->first);
                        it = fds_.erase(it);
                    } else {
                        D("deadline still not expired for fd %d", it->first);
                        ++it;
                    }
                }
                continue;
            }

            for (auto& pfd : pfds) {
                if ((pfd.revents & POLLIN) == 0) {
                    continue;
                }

                // Empty the fd.
                ssize_t rc;
                char buf[32768];
                while ((rc = adb_read(pfd.fd, buf, sizeof(buf))) > 0) {
                    continue;
                }

                if (pfd.fd == notify_fd_read_) {
                    continue;
                }

                auto it = fds_.find(pfd.fd);
                if (it == fds_.end()) {
                    LOG(FATAL) << "fd is missing";
                }

                if (rc == -1 && errno == EAGAIN) {
                    if (std::chrono::steady_clock::now() > it->second.deadline) {
                        D("LingeringSocketCloser closing fd %d due to deadline", pfd.fd);
                    } else {
                        continue;
                    }
                } else if (rc == -1) {
                    D("LingeringSocketCloser closing fd %d due to error %d", pfd.fd, errno);
                } else {
                    D("LingeringSocketCloser closing fd %d due to EOF", pfd.fd);
                }

                fds_.erase(it);
            }
        }
    }

    std::thread thread_;
    unique_fd notify_fd_read_;
    unique_fd notify_fd_write_;

    struct SocketInfo {
        unique_fd fd;
        std::chrono::steady_clock::time_point deadline;
    };

    std::mutex mutex_;
    std::map<int, SocketInfo> fds_ GUARDED_BY(mutex_);
};

static auto& socket_closer = *new LingeringSocketCloser();

static std::recursive_mutex& local_socket_list_lock = *new std::recursive_mutex();
static unsigned local_socket_next_id = 1;

@@ -243,10 +388,12 @@ static void local_socket_destroy(asocket* s) {

    D("LS(%d): destroying fde.fd=%d", s->id, s->fd);

    /* IMPORTANT: the remove closes the fd
    ** that belongs to this socket
    */
    fdevent_destroy(s->fde);
    // Defer thread creation until it's needed, because we need for there to
    // only be one thread when dropping privileges in adbd.
    static std::once_flag once;
    std::call_once(once, []() { socket_closer.Start(); });

    socket_closer.EnqueueSocket(fdevent_release(s->fde));

    remove_socket(s);
    delete s;
+59 −0
Original line number Diff line number Diff line
@@ -35,6 +35,8 @@ import threading
import time
import unittest

from datetime import datetime

import adb

def requires_root(func):
@@ -1335,6 +1337,63 @@ class DeviceOfflineTest(DeviceTest):
            self.device.forward_remove("tcp:{}".format(local_port))


class SocketTest(DeviceTest):
    def test_socket_flush(self):
        """Test that we handle socket closure properly.

        If we're done writing to a socket, closing before the other end has
        closed will send a TCP_RST if we have incoming data queued up, which
        may result in data that we've written being discarded.

        Bug: http://b/74616284
        """
        s = socket.create_connection(("localhost", 5037))

        def adb_length_prefixed(string):
            encoded = string.encode("utf8")
            result = b"%04x%s" % (len(encoded), encoded)
            return result

        if "ANDROID_SERIAL" in os.environ:
            transport_string = "host:transport:" + os.environ["ANDROID_SERIAL"]
        else:
            transport_string = "host:transport-any"

        s.sendall(adb_length_prefixed(transport_string))
        response = s.recv(4)
        self.assertEquals(b"OKAY", response)

        shell_string = "shell:sleep 0.5; dd if=/dev/zero bs=1m count=1 status=none; echo foo"
        s.sendall(adb_length_prefixed(shell_string))

        response = s.recv(4)
        self.assertEquals(b"OKAY", response)

        # Spawn a thread that dumps garbage into the socket until failure.
        def spam():
            buf = b"\0" * 16384
            try:
                while True:
                    s.sendall(buf)
            except Exception as ex:
                print(ex)

        thread = threading.Thread(target=spam)
        thread.start()

        time.sleep(1)

        received = b""
        while True:
            read = s.recv(512)
            if len(read) == 0:
                break
            received += read

        self.assertEquals(1024 * 1024 + len("foo\n"), len(received))
        thread.join()


if sys.platform == "win32":
    # From https://stackoverflow.com/a/38749458
    import os