diff --git a/adb/Android.bp b/adb/Android.bp index bccc71a4e..7f82ca6c2 100644 --- a/adb/Android.bp +++ b/adb/Android.bp @@ -24,6 +24,7 @@ cc_defaults { "-Wno-missing-field-initializers", "-Wvla", ], + cpp_std: "gnu++17", rtti: true, use_version_lib: true, diff --git a/adb/sockets.cpp b/adb/sockets.cpp index 15347929a..dfd9a0afa 100644 --- a/adb/sockets.cpp +++ b/adb/sockets.cpp @@ -26,10 +26,14 @@ #include #include +#include #include #include +#include #include +#include + #if !ADB_HOST #include #include @@ -37,9 +41,150 @@ #include "adb.h" #include "adb_io.h" +#include "adb_utils.h" +#include "sysdeps/chrono.h" #include "transport.h" #include "types.h" +// The standard (RFC 1122 - 4.2.2.13) says that if we call close on a +// socket while we have pending data, a TCP RST should be sent to the +// other end to notify it that we didn't read all of its data. However, +// this can result in data that we've successfully written out to be dropped +// on the other end. To avoid this, instead of immediately closing a +// socket, call shutdown on it instead, and then read from the file +// descriptor until we hit EOF or an error before closing. +struct LingeringSocketCloser { + LingeringSocketCloser() = default; + ~LingeringSocketCloser() = delete; + + // Defer thread creation until it's needed, because we need for there to + // only be one thread when dropping privileges in adbd. + void Start() { + CHECK(!thread_.joinable()); + + int fds[2]; + if (adb_socketpair(fds) != 0) { + PLOG(FATAL) << "adb_socketpair failed"; + } + + set_file_block_mode(fds[0], false); + set_file_block_mode(fds[1], false); + + notify_fd_read_.reset(fds[0]); + notify_fd_write_.reset(fds[1]); + + thread_ = std::thread([this]() { Run(); }); + } + + void EnqueueSocket(unique_fd socket) { + // Shutdown the socket in the outgoing direction only, so that + // we don't have the same problem on the opposite end. + adb_shutdown(socket.get(), SHUT_WR); + set_file_block_mode(socket.get(), false); + + std::lock_guard lock(mutex_); + int fd = socket.get(); + SocketInfo info = { + .fd = std::move(socket), + .deadline = std::chrono::steady_clock::now() + 1s, + }; + + D("LingeringSocketCloser received fd %d", fd); + + fds_.emplace(fd, std::move(info)); + if (adb_write(notify_fd_write_, "", 1) == -1 && errno != EAGAIN) { + PLOG(FATAL) << "failed to write to LingeringSocketCloser notify fd"; + } + } + + private: + std::vector GeneratePollFds() { + std::lock_guard lock(mutex_); + std::vector result; + result.push_back(adb_pollfd{.fd = notify_fd_read_, .events = POLLIN}); + for (auto& [fd, _] : fds_) { + result.push_back(adb_pollfd{.fd = fd, .events = POLLIN}); + } + return result; + } + + void Run() { + while (true) { + std::vector pfds = GeneratePollFds(); + int rc = adb_poll(pfds.data(), pfds.size(), 1000); + if (rc == -1) { + PLOG(FATAL) << "poll failed in LingeringSocketCloser"; + } + + std::lock_guard lock(mutex_); + if (rc == 0) { + // Check deadlines. + auto now = std::chrono::steady_clock::now(); + for (auto it = fds_.begin(); it != fds_.end();) { + if (now > it->second.deadline) { + D("LingeringSocketCloser closing fd %d due to deadline", it->first); + it = fds_.erase(it); + } else { + D("deadline still not expired for fd %d", it->first); + ++it; + } + } + continue; + } + + for (auto& pfd : pfds) { + if ((pfd.revents & POLLIN) == 0) { + continue; + } + + // Empty the fd. + ssize_t rc; + char buf[32768]; + while ((rc = adb_read(pfd.fd, buf, sizeof(buf))) > 0) { + continue; + } + + if (pfd.fd == notify_fd_read_) { + continue; + } + + auto it = fds_.find(pfd.fd); + if (it == fds_.end()) { + LOG(FATAL) << "fd is missing"; + } + + if (rc == -1 && errno == EAGAIN) { + if (std::chrono::steady_clock::now() > it->second.deadline) { + D("LingeringSocketCloser closing fd %d due to deadline", pfd.fd); + } else { + continue; + } + } else if (rc == -1) { + D("LingeringSocketCloser closing fd %d due to error %d", pfd.fd, errno); + } else { + D("LingeringSocketCloser closing fd %d due to EOF", pfd.fd); + } + + fds_.erase(it); + } + } + } + + std::thread thread_; + unique_fd notify_fd_read_; + unique_fd notify_fd_write_; + + struct SocketInfo { + unique_fd fd; + std::chrono::steady_clock::time_point deadline; + }; + + std::mutex mutex_; + std::map fds_ GUARDED_BY(mutex_); +}; + +static auto& socket_closer = *new LingeringSocketCloser(); + static std::recursive_mutex& local_socket_list_lock = *new std::recursive_mutex(); static unsigned local_socket_next_id = 1; @@ -243,10 +388,12 @@ static void local_socket_destroy(asocket* s) { D("LS(%d): destroying fde.fd=%d", s->id, s->fd); - /* IMPORTANT: the remove closes the fd - ** that belongs to this socket - */ - fdevent_destroy(s->fde); + // Defer thread creation until it's needed, because we need for there to + // only be one thread when dropping privileges in adbd. + static std::once_flag once; + std::call_once(once, []() { socket_closer.Start(); }); + + socket_closer.EnqueueSocket(fdevent_release(s->fde)); remove_socket(s); delete s; diff --git a/adb/test_device.py b/adb/test_device.py index c3166ffe1..4c45a7378 100755 --- a/adb/test_device.py +++ b/adb/test_device.py @@ -35,6 +35,8 @@ import threading import time import unittest +from datetime import datetime + import adb def requires_root(func): @@ -1335,6 +1337,63 @@ class DeviceOfflineTest(DeviceTest): self.device.forward_remove("tcp:{}".format(local_port)) +class SocketTest(DeviceTest): + def test_socket_flush(self): + """Test that we handle socket closure properly. + + If we're done writing to a socket, closing before the other end has + closed will send a TCP_RST if we have incoming data queued up, which + may result in data that we've written being discarded. + + Bug: http://b/74616284 + """ + s = socket.create_connection(("localhost", 5037)) + + def adb_length_prefixed(string): + encoded = string.encode("utf8") + result = b"%04x%s" % (len(encoded), encoded) + return result + + if "ANDROID_SERIAL" in os.environ: + transport_string = "host:transport:" + os.environ["ANDROID_SERIAL"] + else: + transport_string = "host:transport-any" + + s.sendall(adb_length_prefixed(transport_string)) + response = s.recv(4) + self.assertEquals(b"OKAY", response) + + shell_string = "shell:sleep 0.5; dd if=/dev/zero bs=1m count=1 status=none; echo foo" + s.sendall(adb_length_prefixed(shell_string)) + + response = s.recv(4) + self.assertEquals(b"OKAY", response) + + # Spawn a thread that dumps garbage into the socket until failure. + def spam(): + buf = b"\0" * 16384 + try: + while True: + s.sendall(buf) + except Exception as ex: + print(ex) + + thread = threading.Thread(target=spam) + thread.start() + + time.sleep(1) + + received = b"" + while True: + read = s.recv(512) + if len(read) == 0: + break + received += read + + self.assertEquals(1024 * 1024 + len("foo\n"), len(received)) + thread.join() + + if sys.platform == "win32": # From https://stackoverflow.com/a/38749458 import os