diff --git a/adb/socket_test.cpp b/adb/socket_test.cpp index 7908f82e9..5e28f7601 100644 --- a/adb/socket_test.cpp +++ b/adb/socket_test.cpp @@ -221,6 +221,8 @@ TEST_F(LocalSocketTest, write_error_when_having_packets) { EXPECT_EQ(2u + GetAdditionalLocalSocketCount(), fdevent_installed_count()); ASSERT_EQ(0, adb_close(socket_fd[0])); + std::this_thread::sleep_for(2s); + WaitForFdeventLoop(); ASSERT_EQ(GetAdditionalLocalSocketCount(), fdevent_installed_count()); TerminateThread(); diff --git a/adb/sockets.cpp b/adb/sockets.cpp index f7c39f0a1..420a6d5a3 100644 --- a/adb/sockets.cpp +++ b/adb/sockets.cpp @@ -26,6 +26,7 @@ #include #include +#include #include #include #include @@ -41,6 +42,8 @@ #include "transport.h" #include "types.h" +using namespace std::chrono_literals; + static std::recursive_mutex& local_socket_list_lock = *new std::recursive_mutex(); static unsigned local_socket_next_id = 1; @@ -238,16 +241,64 @@ static void local_socket_ready(asocket* s) { fdevent_add(s->fde, FDE_READ); } +struct ClosingSocket { + std::chrono::steady_clock::time_point begin; +}; + +// The standard (RFC 1122 - 4.2.2.13) says that if we call close on a +// socket while we have pending data, a TCP RST should be sent to the +// other end to notify it that we didn't read all of its data. However, +// this can result in data that we've successfully written out to be dropped +// on the other end. To avoid this, instead of immediately closing a +// socket, call shutdown on it instead, and then read from the file +// descriptor until we hit EOF or an error before closing. +static void deferred_close(unique_fd fd) { + // Shutdown the socket in the outgoing direction only, so that + // we don't have the same problem on the opposite end. + adb_shutdown(fd.get(), SHUT_WR); + auto callback = [](fdevent* fde, unsigned event, void* arg) { + auto socket_info = static_cast(arg); + if (event & FDE_READ) { + ssize_t rc; + char buf[BUFSIZ]; + while ((rc = adb_read(fde->fd.get(), buf, sizeof(buf))) > 0) { + continue; + } + + if (rc == -1 && errno == EAGAIN) { + // There's potentially more data to read. + auto duration = std::chrono::steady_clock::now() - socket_info->begin; + if (duration > 1s) { + LOG(WARNING) << "timeout expired while flushing socket, closing"; + } else { + return; + } + } + } else if (event & FDE_TIMEOUT) { + LOG(WARNING) << "timeout expired while flushing socket, closing"; + } + + // Either there was an error, we hit the end of the socket, or our timeout expired. + fdevent_destroy(fde); + delete socket_info; + }; + + ClosingSocket* socket_info = new ClosingSocket{ + .begin = std::chrono::steady_clock::now(), + }; + + fdevent* fde = fdevent_create(fd.release(), callback, socket_info); + fdevent_add(fde, FDE_READ); + fdevent_set_timeout(fde, 1s); +} + // be sure to hold the socket list lock when calling this static void local_socket_destroy(asocket* s) { int exit_on_close = s->exit_on_close; D("LS(%d): destroying fde.fd=%d", s->id, s->fd); - /* IMPORTANT: the remove closes the fd - ** that belongs to this socket - */ - fdevent_destroy(s->fde); + deferred_close(fdevent_release(s->fde)); remove_socket(s); delete s; diff --git a/adb/test_device.py b/adb/test_device.py index 34f8fd9fa..f95a5b3cd 100755 --- a/adb/test_device.py +++ b/adb/test_device.py @@ -35,6 +35,8 @@ import threading import time import unittest +from datetime import datetime + import adb def requires_root(func): @@ -1335,6 +1337,63 @@ class DeviceOfflineTest(DeviceTest): self.device.forward_remove("tcp:{}".format(local_port)) +class SocketTest(DeviceTest): + def test_socket_flush(self): + """Test that we handle socket closure properly. + + If we're done writing to a socket, closing before the other end has + closed will send a TCP_RST if we have incoming data queued up, which + may result in data that we've written being discarded. + + Bug: http://b/74616284 + """ + s = socket.create_connection(("localhost", 5037)) + + def adb_length_prefixed(string): + encoded = string.encode("utf8") + result = b"%04x%s" % (len(encoded), encoded) + return result + + if "ANDROID_SERIAL" in os.environ: + transport_string = "host:transport:" + os.environ["ANDROID_SERIAL"] + else: + transport_string = "host:transport-any" + + s.sendall(adb_length_prefixed(transport_string)) + response = s.recv(4) + self.assertEquals(b"OKAY", response) + + shell_string = "shell:sleep 0.5; dd if=/dev/zero bs=1m count=1 status=none; echo foo" + s.sendall(adb_length_prefixed(shell_string)) + + response = s.recv(4) + self.assertEquals(b"OKAY", response) + + # Spawn a thread that dumps garbage into the socket until failure. + def spam(): + buf = b"\0" * 16384 + try: + while True: + s.sendall(buf) + except Exception as ex: + print(ex) + + thread = threading.Thread(target=spam) + thread.start() + + time.sleep(1) + + received = b"" + while True: + read = s.recv(512) + if len(read) == 0: + break + received += read + + self.assertEquals(1024 * 1024 + len("foo\n"), len(received)) + thread.join() + + if sys.platform == "win32": # From https://stackoverflow.com/a/38749458 import os