From ef3d343254405cc360b4df843c6e4a843c335012 Mon Sep 17 00:00:00 2001 From: Josh Gao Date: Tue, 2 May 2017 15:01:09 -0700 Subject: [PATCH] adb: use the actual wMaxPacketSize for usb endpoints. Previously, adb was assuming a fixed maximum packet size of 1024 bytes (the value for an endpoint connected via USB 3.0). When connected to an endpoint that has an actual maximum packet size of 512 bytes (i.e. every single device over USB 2.0), the following could occur: device sends amessage with 512 byte payload client reads amessage client tries to read payload with a length of 1024 In this scenario, the kernel will block, waiting for an additional packet which won't arrive until something else gets sent across the wire, which will result in the previous read failing, and the new packet being dropped. Bug: http://b/37783561 Test: python test_device.py on linux/darwin, with native/libusb Change-Id: I556f5344945e22dd1533b076f662a97eea24628e --- adb/client/usb_dispatch.cpp | 6 ++++ adb/client/usb_libusb.cpp | 30 ++++++++++++---- adb/client/usb_linux.cpp | 29 ++++++++++------ adb/client/usb_osx.cpp | 18 ++++++++-- adb/client/usb_windows.cpp | 8 +++++ adb/test_device.py | 20 +++++++++++ adb/transport_usb.cpp | 68 +++++++++++++++---------------------- adb/usb.h | 5 ++- 8 files changed, 122 insertions(+), 62 deletions(-) diff --git a/adb/client/usb_dispatch.cpp b/adb/client/usb_dispatch.cpp index bfc8e164b..710a3ce85 100644 --- a/adb/client/usb_dispatch.cpp +++ b/adb/client/usb_dispatch.cpp @@ -48,3 +48,9 @@ void usb_kick(usb_handle* h) { should_use_libusb() ? libusb::usb_kick(reinterpret_cast(h)) : native::usb_kick(reinterpret_cast(h)); } + +size_t usb_get_max_packet_size(usb_handle* h) { + return should_use_libusb() + ? libusb::usb_get_max_packet_size(reinterpret_cast(h)) + : native::usb_get_max_packet_size(reinterpret_cast(h)); +} diff --git a/adb/client/usb_libusb.cpp b/adb/client/usb_libusb.cpp index fec4742b2..d39884ac7 100644 --- a/adb/client/usb_libusb.cpp +++ b/adb/client/usb_libusb.cpp @@ -91,7 +91,7 @@ namespace libusb { struct usb_handle : public ::usb_handle { usb_handle(const std::string& device_address, const std::string& serial, unique_device_handle&& device_handle, uint8_t interface, uint8_t bulk_in, - uint8_t bulk_out, size_t zero_mask) + uint8_t bulk_out, size_t zero_mask, size_t max_packet_size) : device_address(device_address), serial(serial), closing(false), @@ -100,7 +100,8 @@ struct usb_handle : public ::usb_handle { write("write", zero_mask, true), interface(interface), bulk_in(bulk_in), - bulk_out(bulk_out) {} + bulk_out(bulk_out), + max_packet_size(max_packet_size) {} ~usb_handle() { Close(); @@ -143,6 +144,8 @@ struct usb_handle : public ::usb_handle { uint8_t interface; uint8_t bulk_in; uint8_t bulk_out; + + size_t max_packet_size; }; static auto& usb_handles = *new std::unordered_map>(); @@ -206,6 +209,7 @@ static void poll_for_devices() { size_t interface_num; uint16_t zero_mask; uint8_t bulk_in = 0, bulk_out = 0; + size_t packet_size = 0; bool found_adb = false; for (interface_num = 0; interface_num < config->bNumInterfaces; ++interface_num) { @@ -252,6 +256,14 @@ static void poll_for_devices() { found_in = true; bulk_in = endpoint_addr; } + + size_t endpoint_packet_size = endpoint_desc.wMaxPacketSize; + CHECK(endpoint_packet_size != 0); + if (packet_size == 0) { + packet_size = endpoint_packet_size; + } else { + CHECK(packet_size == endpoint_packet_size); + } } if (found_in && found_out) { @@ -280,7 +292,7 @@ static void poll_for_devices() { } libusb_device_handle* handle_raw; - rc = libusb_open(list[i], &handle_raw); + rc = libusb_open(device, &handle_raw); if (rc != 0) { LOG(WARNING) << "failed to open usb device at " << device_address << ": " << libusb_error_name(rc); @@ -324,9 +336,9 @@ static void poll_for_devices() { } } - auto result = - std::make_unique(device_address, device_serial, std::move(handle), - interface_num, bulk_in, bulk_out, zero_mask); + auto result = std::make_unique(device_address, device_serial, + std::move(handle), interface_num, bulk_in, + bulk_out, zero_mask, packet_size); usb_handle* usb_handle_raw = result.get(); { @@ -507,4 +519,10 @@ int usb_close(usb_handle* h) { void usb_kick(usb_handle* h) { h->Close(); } + +size_t usb_get_max_packet_size(usb_handle* h) { + CHECK(h->max_packet_size != 0); + return h->max_packet_size; +} + } // namespace libusb diff --git a/adb/client/usb_linux.cpp b/adb/client/usb_linux.cpp index 6efed274b..f9ba7cbc2 100644 --- a/adb/client/usb_linux.cpp +++ b/adb/client/usb_linux.cpp @@ -65,6 +65,7 @@ struct usb_handle : public ::usb_handle { unsigned char ep_in; unsigned char ep_out; + size_t max_packet_size; unsigned zero_mask; unsigned writeable = 1; @@ -120,9 +121,9 @@ static inline bool contains_non_digit(const char* name) { } static void find_usb_device(const std::string& base, - void (*register_device_callback) - (const char*, const char*, unsigned char, unsigned char, int, int, unsigned)) -{ + void (*register_device_callback)(const char*, const char*, + unsigned char, unsigned char, int, int, + unsigned, size_t)) { std::unique_ptr bus_dir(opendir(base.c_str()), closedir); if (!bus_dir) return; @@ -144,6 +145,7 @@ static void find_usb_device(const std::string& base, struct usb_interface_descriptor* interface; struct usb_endpoint_descriptor *ep1, *ep2; unsigned zero_mask = 0; + size_t max_packet_size = 0; unsigned vid, pid; if (contains_non_digit(de->d_name)) continue; @@ -251,7 +253,8 @@ static void find_usb_device(const std::string& base, continue; } /* aproto 01 needs 0 termination */ - if(interface->bInterfaceProtocol == 0x01) { + if (interface->bInterfaceProtocol == 0x01) { + max_packet_size = ep1->wMaxPacketSize; zero_mask = ep1->wMaxPacketSize - 1; } @@ -281,9 +284,9 @@ static void find_usb_device(const std::string& base, } } - register_device_callback(dev_name.c_str(), devpath, - local_ep_in, local_ep_out, - interface->bInterfaceNumber, device->iSerialNumber, zero_mask); + register_device_callback(dev_name.c_str(), devpath, local_ep_in, + local_ep_out, interface->bInterfaceNumber, + device->iSerialNumber, zero_mask, max_packet_size); break; } } else { @@ -497,10 +500,13 @@ int usb_close(usb_handle* h) { return 0; } -static void register_device(const char* dev_name, const char* dev_path, - unsigned char ep_in, unsigned char ep_out, - int interface, int serial_index, - unsigned zero_mask) { +size_t usb_get_max_packet_size(usb_handle* h) { + return h->max_packet_size; +} + +static void register_device(const char* dev_name, const char* dev_path, unsigned char ep_in, + unsigned char ep_out, int interface, int serial_index, + unsigned zero_mask, size_t max_packet_size) { // Since Linux will not reassign the device ID (and dev_name) as long as the // device is open, we can add to the list here once we open it and remove // from the list when we're finally closed and everything will work out @@ -523,6 +529,7 @@ static void register_device(const char* dev_name, const char* dev_path, usb->ep_in = ep_in; usb->ep_out = ep_out; usb->zero_mask = zero_mask; + usb->max_packet_size = max_packet_size; // Initialize mark so we don't get garbage collected after the device scan. usb->mark = true; diff --git a/adb/client/usb_osx.cpp b/adb/client/usb_osx.cpp index fcd0bc044..e4a543bba 100644 --- a/adb/client/usb_osx.cpp +++ b/adb/client/usb_osx.cpp @@ -51,15 +51,21 @@ struct usb_handle UInt8 bulkOut; IOUSBInterfaceInterface190** interface; unsigned int zero_mask; + size_t max_packet_size; // For garbage collecting disconnected devices. bool mark; std::string devpath; std::atomic dead; - usb_handle() : bulkIn(0), bulkOut(0), interface(nullptr), - zero_mask(0), mark(false), dead(false) { - } + usb_handle() + : bulkIn(0), + bulkOut(0), + interface(nullptr), + zero_mask(0), + max_packet_size(0), + mark(false), + dead(false) {} }; static std::atomic usb_inited_flag; @@ -390,6 +396,7 @@ CheckInterface(IOUSBInterfaceInterface190 **interface, UInt16 vendor, UInt16 pro } handle->zero_mask = maxPacketSize - 1; + handle->max_packet_size = maxPacketSize; } handle->interface = interface; @@ -558,4 +565,9 @@ void usb_kick(usb_handle *handle) { std::lock_guard lock_guard(g_usb_handles_mutex); usb_kick_locked(handle); } + +size_t usb_get_max_packet_size(usb_handle* handle) { + return handle->max_packet_size; +} + } // namespace native diff --git a/adb/client/usb_windows.cpp b/adb/client/usb_windows.cpp index ee7f8024f..ec55b0e2a 100644 --- a/adb/client/usb_windows.cpp +++ b/adb/client/usb_windows.cpp @@ -65,6 +65,9 @@ struct usb_handle { /// Interface name wchar_t* interface_name; + /// Maximum packet size. + unsigned max_packet_size; + /// Mask for determining when to use zero length packets unsigned zero_mask; }; @@ -522,6 +525,10 @@ int usb_close(usb_handle* handle) { return 0; } +size_t usb_get_max_packet_size(usb_handle* handle) { + return handle->max_packet_size; +} + int recognized_device(usb_handle* handle) { if (NULL == handle) return 0; @@ -557,6 +564,7 @@ int recognized_device(usb_handle* handle) { AdbEndpointInformation endpoint_info; // assuming zero is a valid bulk endpoint ID if (AdbGetEndpointInformation(handle->adb_interface, 0, &endpoint_info)) { + handle->max_packet_size = endpoint_info.max_packet_size; handle->zero_mask = endpoint_info.max_packet_size - 1; D("device zero_mask: 0x%x", handle->zero_mask); } else { diff --git a/adb/test_device.py b/adb/test_device.py index a30972e54..e44cc83f0 100644 --- a/adb/test_device.py +++ b/adb/test_device.py @@ -1259,6 +1259,26 @@ class DeviceOfflineTest(DeviceTest): self.assertEqual(self._get_device_state(serialno), 'device') + def test_packet_size_regression(self): + """Test for http://b/37783561 + + Receiving packets of a length divisible by 512 but not 1024 resulted in + the adb client waiting indefinitely for more input. + """ + # The values that trigger things are 507 (512 - 5 bytes from shell protocol) + 1024*n + # Probe some surrounding values as well, for the hell of it. + for length in [506, 507, 508, 1018, 1019, 1020, 1530, 1531, 1532]: + cmd = ['dd', 'if=/dev/zero', 'bs={}'.format(length), 'count=1', '2>/dev/null;' + 'echo', 'foo'] + rc, stdout, _ = self.device.shell_nocheck(cmd) + + self.assertEqual(0, rc) + + # Output should be '\0' * length, followed by "foo\n" + self.assertEqual(length, len(stdout) - 4) + self.assertEqual(stdout, "\0" * length + "foo\n") + + def main(): random.seed(0) if len(adb.get_devices()) > 0: diff --git a/adb/transport_usb.cpp b/adb/transport_usb.cpp index ce419b88d..885d7230e 100644 --- a/adb/transport_usb.cpp +++ b/adb/transport_usb.cpp @@ -27,57 +27,43 @@ #if ADB_HOST -static constexpr size_t MAX_USB_BULK_PACKET_SIZE = 1024u; - -// Call usb_read using a buffer having a multiple of MAX_USB_BULK_PACKET_SIZE bytes +// Call usb_read using a buffer having a multiple of usb_get_max_packet_size() bytes // to avoid overflow. See http://libusb.sourceforge.net/api-1.0/packetoverflow.html. static int UsbReadMessage(usb_handle* h, amessage* msg) { D("UsbReadMessage"); - char buffer[MAX_USB_BULK_PACKET_SIZE]; - int n = usb_read(h, buffer, sizeof(buffer)); - if (n == sizeof(*msg)) { - memcpy(msg, buffer, sizeof(*msg)); + + size_t usb_packet_size = usb_get_max_packet_size(h); + CHECK(usb_packet_size >= sizeof(*msg)); + CHECK(usb_packet_size < 4096); + + char buffer[4096]; + int n = usb_read(h, buffer, usb_packet_size); + if (n != sizeof(*msg)) { + D("usb_read returned unexpected length %d (expected %zu)", n, sizeof(*msg)); + return -1; } + memcpy(msg, buffer, sizeof(*msg)); return n; } -// Call usb_read using a buffer having a multiple of MAX_USB_BULK_PACKET_SIZE bytes +// Call usb_read using a buffer having a multiple of usb_get_max_packet_size() bytes // to avoid overflow. See http://libusb.sourceforge.net/api-1.0/packetoverflow.html. static int UsbReadPayload(usb_handle* h, apacket* p) { - D("UsbReadPayload"); - size_t need_size = p->msg.data_length; - size_t data_pos = 0u; - while (need_size > 0u) { - int n = 0; - if (data_pos + MAX_USB_BULK_PACKET_SIZE <= sizeof(p->data)) { - // Read directly to p->data. - size_t rem_size = need_size % MAX_USB_BULK_PACKET_SIZE; - size_t direct_read_size = need_size - rem_size; - if (rem_size && - data_pos + direct_read_size + MAX_USB_BULK_PACKET_SIZE <= sizeof(p->data)) { - direct_read_size += MAX_USB_BULK_PACKET_SIZE; - } - n = usb_read(h, &p->data[data_pos], direct_read_size); - if (n < 0) { - D("usb_read(size %zu) failed", direct_read_size); - return n; - } - } else { - // Read indirectly using a buffer. - char buffer[MAX_USB_BULK_PACKET_SIZE]; - n = usb_read(h, buffer, sizeof(buffer)); - if (n < 0) { - D("usb_read(size %zu) failed", sizeof(buffer)); - return -1; - } - size_t copy_size = std::min(static_cast(n), need_size); - D("usb read %d bytes, need %zu bytes, copy %zu bytes", n, need_size, copy_size); - memcpy(&p->data[data_pos], buffer, copy_size); - } - data_pos += n; - need_size -= std::min(static_cast(n), need_size); + D("UsbReadPayload(%d)", p->msg.data_length); + + size_t usb_packet_size = usb_get_max_packet_size(h); + CHECK(sizeof(p->data) % usb_packet_size == 0); + + // Round the data length up to the nearest packet size boundary. + // The device won't send a zero packet for packet size aligned payloads, + // so don't read any more packets than needed. + size_t len = p->msg.data_length; + size_t rem_size = len % usb_packet_size; + if (rem_size) { + len += usb_packet_size - rem_size; } - return static_cast(data_pos); + CHECK(len <= sizeof(p->data)); + return usb_read(h, &p->data, len); } static int remote_read(apacket* p, atransport* t) { diff --git a/adb/usb.h b/adb/usb.h index ba70de43e..e867ec8a3 100644 --- a/adb/usb.h +++ b/adb/usb.h @@ -16,6 +16,8 @@ #pragma once +#include + // USB host/client interface. #define ADB_USB_INTERFACE(handle_ref_type) \ @@ -23,7 +25,8 @@ int usb_write(handle_ref_type h, const void* data, int len); \ int usb_read(handle_ref_type h, void* data, int len); \ int usb_close(handle_ref_type h); \ - void usb_kick(handle_ref_type h) + void usb_kick(handle_ref_type h); \ + size_t usb_get_max_packet_size(handle_ref_type) #if defined(_WIN32) || !ADB_HOST // Windows and the daemon have a single implementation.