Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 2185d8ce authored by Josh Gao's avatar Josh Gao Committed by Gerrit Code Review
Browse files

Merge "adb: don't close sockets before hitting EOF."

parents 042f2da1 ffc11d3c
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -24,6 +24,7 @@ cc_defaults {
        "-Wno-missing-field-initializers",
        "-Wvla",
    ],
    cpp_std: "gnu++17",
    rtti: true,

    use_version_lib: true,
+151 −4
Original line number Diff line number Diff line
@@ -26,10 +26,14 @@
#include <unistd.h>

#include <algorithm>
#include <map>
#include <mutex>
#include <string>
#include <thread>
#include <vector>

#include <android-base/thread_annotations.h>

#if !ADB_HOST
#include <android-base/properties.h>
#include <log/log_properties.h>
@@ -37,9 +41,150 @@

#include "adb.h"
#include "adb_io.h"
#include "adb_utils.h"
#include "sysdeps/chrono.h"
#include "transport.h"
#include "types.h"

// The standard (RFC 1122 - 4.2.2.13) says that if we call close on a
// socket while we have pending data, a TCP RST should be sent to the
// other end to notify it that we didn't read all of its data. However,
// this can result in data that we've successfully written out to be dropped
// on the other end. To avoid this, instead of immediately closing a
// socket, call shutdown on it instead, and then read from the file
// descriptor until we hit EOF or an error before closing.
struct LingeringSocketCloser {
    LingeringSocketCloser() = default;
    ~LingeringSocketCloser() = delete;

    // Defer thread creation until it's needed, because we need for there to
    // only be one thread when dropping privileges in adbd.
    void Start() {
        CHECK(!thread_.joinable());

        int fds[2];
        if (adb_socketpair(fds) != 0) {
            PLOG(FATAL) << "adb_socketpair failed";
        }

        set_file_block_mode(fds[0], false);
        set_file_block_mode(fds[1], false);

        notify_fd_read_.reset(fds[0]);
        notify_fd_write_.reset(fds[1]);

        thread_ = std::thread([this]() { Run(); });
    }

    void EnqueueSocket(unique_fd socket) {
        // Shutdown the socket in the outgoing direction only, so that
        // we don't have the same problem on the opposite end.
        adb_shutdown(socket.get(), SHUT_WR);
        set_file_block_mode(socket.get(), false);

        std::lock_guard<std::mutex> lock(mutex_);
        int fd = socket.get();
        SocketInfo info = {
                .fd = std::move(socket),
                .deadline = std::chrono::steady_clock::now() + 1s,
        };

        D("LingeringSocketCloser received fd %d", fd);

        fds_.emplace(fd, std::move(info));
        if (adb_write(notify_fd_write_, "", 1) == -1 && errno != EAGAIN) {
            PLOG(FATAL) << "failed to write to LingeringSocketCloser notify fd";
        }
    }

  private:
    std::vector<adb_pollfd> GeneratePollFds() {
        std::lock_guard<std::mutex> lock(mutex_);
        std::vector<adb_pollfd> result;
        result.push_back(adb_pollfd{.fd = notify_fd_read_, .events = POLLIN});
        for (auto& [fd, _] : fds_) {
            result.push_back(adb_pollfd{.fd = fd, .events = POLLIN});
        }
        return result;
    }

    void Run() {
        while (true) {
            std::vector<adb_pollfd> pfds = GeneratePollFds();
            int rc = adb_poll(pfds.data(), pfds.size(), 1000);
            if (rc == -1) {
                PLOG(FATAL) << "poll failed in LingeringSocketCloser";
            }

            std::lock_guard<std::mutex> lock(mutex_);
            if (rc == 0) {
                // Check deadlines.
                auto now = std::chrono::steady_clock::now();
                for (auto it = fds_.begin(); it != fds_.end();) {
                    if (now > it->second.deadline) {
                        D("LingeringSocketCloser closing fd %d due to deadline", it->first);
                        it = fds_.erase(it);
                    } else {
                        D("deadline still not expired for fd %d", it->first);
                        ++it;
                    }
                }
                continue;
            }

            for (auto& pfd : pfds) {
                if ((pfd.revents & POLLIN) == 0) {
                    continue;
                }

                // Empty the fd.
                ssize_t rc;
                char buf[32768];
                while ((rc = adb_read(pfd.fd, buf, sizeof(buf))) > 0) {
                    continue;
                }

                if (pfd.fd == notify_fd_read_) {
                    continue;
                }

                auto it = fds_.find(pfd.fd);
                if (it == fds_.end()) {
                    LOG(FATAL) << "fd is missing";
                }

                if (rc == -1 && errno == EAGAIN) {
                    if (std::chrono::steady_clock::now() > it->second.deadline) {
                        D("LingeringSocketCloser closing fd %d due to deadline", pfd.fd);
                    } else {
                        continue;
                    }
                } else if (rc == -1) {
                    D("LingeringSocketCloser closing fd %d due to error %d", pfd.fd, errno);
                } else {
                    D("LingeringSocketCloser closing fd %d due to EOF", pfd.fd);
                }

                fds_.erase(it);
            }
        }
    }

    std::thread thread_;
    unique_fd notify_fd_read_;
    unique_fd notify_fd_write_;

    struct SocketInfo {
        unique_fd fd;
        std::chrono::steady_clock::time_point deadline;
    };

    std::mutex mutex_;
    std::map<int, SocketInfo> fds_ GUARDED_BY(mutex_);
};

static auto& socket_closer = *new LingeringSocketCloser();

static std::recursive_mutex& local_socket_list_lock = *new std::recursive_mutex();
static unsigned local_socket_next_id = 1;

@@ -243,10 +388,12 @@ static void local_socket_destroy(asocket* s) {

    D("LS(%d): destroying fde.fd=%d", s->id, s->fd);

    /* IMPORTANT: the remove closes the fd
    ** that belongs to this socket
    */
    fdevent_destroy(s->fde);
    // Defer thread creation until it's needed, because we need for there to
    // only be one thread when dropping privileges in adbd.
    static std::once_flag once;
    std::call_once(once, []() { socket_closer.Start(); });

    socket_closer.EnqueueSocket(fdevent_release(s->fde));

    remove_socket(s);
    delete s;
+59 −0
Original line number Diff line number Diff line
@@ -35,6 +35,8 @@ import threading
import time
import unittest

from datetime import datetime

import adb

def requires_root(func):
@@ -1335,6 +1337,63 @@ class DeviceOfflineTest(DeviceTest):
            self.device.forward_remove("tcp:{}".format(local_port))


class SocketTest(DeviceTest):
    def test_socket_flush(self):
        """Test that we handle socket closure properly.

        If we're done writing to a socket, closing before the other end has
        closed will send a TCP_RST if we have incoming data queued up, which
        may result in data that we've written being discarded.

        Bug: http://b/74616284
        """
        s = socket.create_connection(("localhost", 5037))

        def adb_length_prefixed(string):
            encoded = string.encode("utf8")
            result = b"%04x%s" % (len(encoded), encoded)
            return result

        if "ANDROID_SERIAL" in os.environ:
            transport_string = "host:transport:" + os.environ["ANDROID_SERIAL"]
        else:
            transport_string = "host:transport-any"

        s.sendall(adb_length_prefixed(transport_string))
        response = s.recv(4)
        self.assertEquals(b"OKAY", response)

        shell_string = "shell:sleep 0.5; dd if=/dev/zero bs=1m count=1 status=none; echo foo"
        s.sendall(adb_length_prefixed(shell_string))

        response = s.recv(4)
        self.assertEquals(b"OKAY", response)

        # Spawn a thread that dumps garbage into the socket until failure.
        def spam():
            buf = b"\0" * 16384
            try:
                while True:
                    s.sendall(buf)
            except Exception as ex:
                print(ex)

        thread = threading.Thread(target=spam)
        thread.start()

        time.sleep(1)

        received = b""
        while True:
            read = s.recv(512)
            if len(read) == 0:
                break
            received += read

        self.assertEquals(1024 * 1024 + len("foo\n"), len(received))
        thread.join()


if sys.platform == "win32":
    # From https://stackoverflow.com/a/38749458
    import os