diff --git a/adb/socket.h b/adb/socket.h index f9ad4f881..03927315a 100644 --- a/adb/socket.h +++ b/adb/socket.h @@ -113,7 +113,8 @@ void connect_to_smartsocket(asocket *s); namespace internal { #if ADB_HOST -char* skip_host_serial(char* service); +bool parse_host_service(std::string_view* out_serial, std::string_view* out_command, + std::string_view service); #endif } // namespace internal diff --git a/adb/socket_test.cpp b/adb/socket_test.cpp index 04214a20b..80f9430b7 100644 --- a/adb/socket_test.cpp +++ b/adb/socket_test.cpp @@ -34,6 +34,9 @@ #include "sysdeps.h" #include "sysdeps/chrono.h" +using namespace std::string_literals; +using namespace std::string_view_literals; + struct ThreadArg { int first_read_fd; int last_write_fd; @@ -303,56 +306,78 @@ TEST_F(LocalSocketTest, close_socket_in_CLOSE_WAIT_state) { #if ADB_HOST -// Checks that skip_host_serial(serial) returns a pointer to the part of |serial| which matches -// |expected|, otherwise logs the failure to gtest. -void VerifySkipHostSerial(std::string serial, const char* expected) { - char* result = internal::skip_host_serial(&serial[0]); - if (expected == nullptr) { - EXPECT_EQ(nullptr, result); - } else { - EXPECT_STREQ(expected, result); - } -} +#define VerifyParseHostServiceFailed(s) \ + do { \ + std::string service(s); \ + std::string_view serial, command; \ + bool result = internal::parse_host_service(&serial, &command, service); \ + EXPECT_FALSE(result); \ + } while (0) + +#define VerifyParseHostService(s, expected_serial, expected_command) \ + do { \ + std::string service(s); \ + std::string_view serial, command; \ + bool result = internal::parse_host_service(&serial, &command, service); \ + EXPECT_TRUE(result); \ + EXPECT_EQ(std::string(expected_serial), std::string(serial)); \ + EXPECT_EQ(std::string(expected_command), std::string(command)); \ + } while (0); // Check [tcp:|udp:][:]: format. -TEST(socket_test, test_skip_host_serial) { +TEST(socket_test, test_parse_host_service) { for (const std::string& protocol : {"", "tcp:", "udp:"}) { - VerifySkipHostSerial(protocol, nullptr); - VerifySkipHostSerial(protocol + "foo", nullptr); + VerifyParseHostServiceFailed(protocol); + VerifyParseHostServiceFailed(protocol + "foo"); - VerifySkipHostSerial(protocol + "foo:bar", ":bar"); - VerifySkipHostSerial(protocol + "foo:bar:baz", ":bar:baz"); + { + std::string serial = protocol + "foo"; + VerifyParseHostService(serial + ":bar", serial, "bar"); + VerifyParseHostService(serial + " :bar:baz", serial, "bar:baz"); + } - VerifySkipHostSerial(protocol + "foo:123:bar", ":bar"); - VerifySkipHostSerial(protocol + "foo:123:456", ":456"); - VerifySkipHostSerial(protocol + "foo:123:bar:baz", ":bar:baz"); + { + // With port. + std::string serial = protocol + "foo:123"; + VerifyParseHostService(serial + ":bar", serial, "bar"); + VerifyParseHostService(serial + ":456", serial, "456"); + VerifyParseHostService(serial + ":bar:baz", serial, "bar:baz"); + } // Don't register a port unless it's all numbers and ends with ':'. - VerifySkipHostSerial(protocol + "foo:123", ":123"); - VerifySkipHostSerial(protocol + "foo:123bar:baz", ":123bar:baz"); + VerifyParseHostService(protocol + "foo:123", protocol + "foo", "123"); + VerifyParseHostService(protocol + "foo:123bar:baz", protocol + "foo", "123bar:baz"); - VerifySkipHostSerial(protocol + "100.100.100.100:5555:foo", ":foo"); - VerifySkipHostSerial(protocol + "[0123:4567:89ab:CDEF:0:9:a:f]:5555:foo", ":foo"); - VerifySkipHostSerial(protocol + "[::1]:5555:foo", ":foo"); + std::string addresses[] = {"100.100.100.100", "[0123:4567:89ab:CDEF:0:9:a:f]", "[::1]"}; + for (const std::string& address : addresses) { + std::string serial = protocol + address; + std::string serial_with_port = protocol + address + ":5555"; + VerifyParseHostService(serial + ":foo", serial, "foo"); + VerifyParseHostService(serial_with_port + ":foo", serial_with_port, "foo"); + } // If we can't find both [] then treat it as a normal serial with [ in it. - VerifySkipHostSerial(protocol + "[0123:foo", ":foo"); + VerifyParseHostService(protocol + "[0123:foo", protocol + "[0123", "foo"); // Don't be fooled by random IPv6 addresses in the command string. - VerifySkipHostSerial(protocol + "foo:ping [0123:4567:89ab:CDEF:0:9:a:f]:5555", - ":ping [0123:4567:89ab:CDEF:0:9:a:f]:5555"); + VerifyParseHostService(protocol + "foo:ping [0123:4567:89ab:CDEF:0:9:a:f]:5555", + protocol + "foo", "ping [0123:4567:89ab:CDEF:0:9:a:f]:5555"); + + // Handle embedded NULs properly. + VerifyParseHostService(protocol + "foo:echo foo\0bar"s, protocol + "foo", + "echo foo\0bar"sv); } } // Check :: format. -TEST(socket_test, test_skip_host_serial_prefix) { +TEST(socket_test, test_parse_host_service_prefix) { for (const std::string& prefix : {"usb:", "product:", "model:", "device:"}) { - VerifySkipHostSerial(prefix, nullptr); - VerifySkipHostSerial(prefix + "foo", nullptr); + VerifyParseHostServiceFailed(prefix); + VerifyParseHostServiceFailed(prefix + "foo"); - VerifySkipHostSerial(prefix + "foo:bar", ":bar"); - VerifySkipHostSerial(prefix + "foo:bar:baz", ":bar:baz"); - VerifySkipHostSerial(prefix + "foo:123:bar", ":123:bar"); + VerifyParseHostService(prefix + "foo:bar", prefix + "foo", "bar"); + VerifyParseHostService(prefix + "foo:bar:baz", prefix + "foo", "bar:baz"); + VerifyParseHostService(prefix + "foo:123:bar", prefix + "foo", "123:bar"); } } diff --git a/adb/sockets.cpp b/adb/sockets.cpp index cb8cd16ff..676ef4402 100644 --- a/adb/sockets.cpp +++ b/adb/sockets.cpp @@ -37,6 +37,7 @@ #include "adb.h" #include "adb_io.h" +#include "adb_utils.h" #include "transport.h" #include "types.h" @@ -546,57 +547,119 @@ static unsigned unhex(const char* s, int len) { namespace internal { -// Returns the position in |service| following the target serial parameter. Serial format can be -// any of: +// Parses a host service string of the following format: // * [tcp:|udp:][:]: // * :: // Where must be a base-10 number and may be any of {usb,product,model,device}. -// -// The returned pointer will point to the ':' just before , or nullptr if not found. -char* skip_host_serial(char* service) { - static const std::vector& prefixes = - *(new std::vector{"usb:", "product:", "model:", "device:"}); +bool parse_host_service(std::string_view* out_serial, std::string_view* out_command, + std::string_view full_service) { + if (full_service.empty()) { + return false; + } - for (const std::string& prefix : prefixes) { - if (!strncmp(service, prefix.c_str(), prefix.length())) { - return strchr(service + prefix.length(), ':'); + std::string_view serial; + std::string_view command = full_service; + // Remove |count| bytes from the beginning of command and add them to |serial|. + auto consume = [&full_service, &serial, &command](size_t count) { + CHECK_LE(count, command.size()); + if (!serial.empty()) { + CHECK_EQ(serial.data() + serial.size(), command.data()); + } + + serial = full_service.substr(0, serial.size() + count); + command.remove_prefix(count); + }; + + // Remove the trailing : from serial, and assign the values to the output parameters. + auto finish = [out_serial, out_command, &serial, &command] { + if (serial.empty() || command.empty()) { + return false; + } + + CHECK_EQ(':', serial.back()); + serial.remove_suffix(1); + + *out_serial = serial; + *out_command = command; + return true; + }; + + static constexpr std::string_view prefixes[] = {"usb:", "product:", "model:", "device:"}; + for (std::string_view prefix : prefixes) { + if (command.starts_with(prefix)) { + consume(prefix.size()); + + size_t offset = command.find_first_of(':'); + if (offset == std::string::npos) { + return false; + } + consume(offset + 1); + return finish(); } } // For fastboot compatibility, ignore protocol prefixes. - if (!strncmp(service, "tcp:", 4) || !strncmp(service, "udp:", 4)) { - service += 4; - } - - // Check for an IPv6 address. `adb connect` creates the serial number from the canonical - // network address so it will always have the [] delimiters. - if (service[0] == '[') { - char* ipv6_end = strchr(service, ']'); - if (ipv6_end != nullptr) { - service = ipv6_end; + if (command.starts_with("tcp:") || command.starts_with("udp:")) { + consume(4); + if (command.empty()) { + return false; } } - // The next colon we find must either begin the port field or the command field. - char* colon_ptr = strchr(service, ':'); - if (!colon_ptr) { - // No colon in service string. - return nullptr; + bool found_address = false; + if (command[0] == '[') { + // Read an IPv6 address. `adb connect` creates the serial number from the canonical + // network address so it will always have the [] delimiters. + size_t ipv6_end = command.find_first_of(']'); + if (ipv6_end != std::string::npos) { + consume(ipv6_end + 1); + if (command.empty()) { + // Nothing after the IPv6 address. + return false; + } else if (command[0] != ':') { + // Garbage after the IPv6 address. + return false; + } + consume(1); + found_address = true; + } } - // If the next field is only decimal digits and ends with another colon, it's a port. - char* serial_end = colon_ptr; - if (isdigit(serial_end[1])) { - serial_end++; - while (*serial_end && isdigit(*serial_end)) { - serial_end++; + if (!found_address) { + // Scan ahead to the next colon. + size_t offset = command.find_first_of(':'); + if (offset == std::string::npos) { + return false; } - if (*serial_end != ':') { - // Something other than ":" was found, this must be the command field instead. - serial_end = colon_ptr; + consume(offset + 1); + } + + // We're either at the beginning of a port, or the command itself. + // Look for a port in between colons. + size_t next_colon = command.find_first_of(':'); + if (next_colon == std::string::npos) { + // No colon, we must be at the command. + return finish(); + } + + bool port_valid = true; + if (command.size() <= next_colon) { + return false; + } + + std::string_view port = command.substr(0, next_colon); + for (auto digit : port) { + if (!isdigit(digit)) { + // Port isn't a number. + port_valid = false; + break; } } - return serial_end; + + if (port_valid) { + consume(next_colon + 1); + } + return finish(); } } // namespace internal @@ -605,8 +668,8 @@ char* skip_host_serial(char* service) { static int smart_socket_enqueue(asocket* s, apacket::payload_type data) { #if ADB_HOST - char* service = nullptr; - char* serial = nullptr; + std::string_view service; + std::string_view serial; TransportId transport_id = 0; TransportType type = kTransportAny; #endif @@ -643,49 +706,52 @@ static int smart_socket_enqueue(asocket* s, apacket::payload_type data) { D("SS(%d): '%s'", s->id, (char*)(s->smart_socket_data.data() + 4)); #if ADB_HOST - service = &s->smart_socket_data[4]; - if (!strncmp(service, "host-serial:", strlen("host-serial:"))) { - char* serial_end; - service += strlen("host-serial:"); + service = std::string_view(s->smart_socket_data).substr(4); + if (service.starts_with("host-serial:")) { + service.remove_prefix(strlen("host-serial:")); // serial number should follow "host:" and could be a host:port string. - serial_end = internal::skip_host_serial(service); - if (serial_end) { - *serial_end = 0; // terminate string - serial = service; - service = serial_end + 1; + if (!internal::parse_host_service(&serial, &service, service)) { + LOG(ERROR) << "SS(" << s->id << "): failed to parse host service: " << service; + goto fail; } - } else if (!strncmp(service, "host-transport-id:", strlen("host-transport-id:"))) { - service += strlen("host-transport-id:"); - transport_id = strtoll(service, &service, 10); - - if (*service != ':') { + } else if (service.starts_with("host-transport-id:")) { + service.remove_prefix(strlen("host-transport-id:")); + if (!ParseUint(&transport_id, service, &service)) { + LOG(ERROR) << "SS(" << s->id << "): failed to parse host transport id: " << service; return -1; } - service++; - } else if (!strncmp(service, "host-usb:", strlen("host-usb:"))) { + if (!service.starts_with(":")) { + LOG(ERROR) << "SS(" << s->id << "): host-transport-id without command"; + return -1; + } + service.remove_prefix(1); + } else if (service.starts_with("host-usb:")) { type = kTransportUsb; - service += strlen("host-usb:"); - } else if (!strncmp(service, "host-local:", strlen("host-local:"))) { + service.remove_prefix(strlen("host-usb:")); + } else if (service.starts_with("host-local:")) { type = kTransportLocal; - service += strlen("host-local:"); - } else if (!strncmp(service, "host:", strlen("host:"))) { + service.remove_prefix(strlen("host-local:")); + } else if (service.starts_with("host:")) { type = kTransportAny; - service += strlen("host:"); + service.remove_prefix(strlen("host:")); } else { - service = nullptr; + service = std::string_view{}; } - if (service) { + if (!service.empty()) { asocket* s2; // Some requests are handled immediately -- in that case the handle_host_request() routine // has sent the OKAY or FAIL message and all we have to do is clean up. - if (handle_host_request(service, type, serial, transport_id, s->peer->fd, s)) { - D("SS(%d): handled host service '%s'", s->id, service); + // TODO: Convert to string_view. + if (handle_host_request(std::string(service).c_str(), type, + serial.empty() ? nullptr : std::string(serial).c_str(), + transport_id, s->peer->fd, s)) { + LOG(VERBOSE) << "SS(" << s->id << "): handled host service '" << service << "'"; goto fail; } - if (!strncmp(service, "transport", strlen("transport"))) { + if (service.starts_with("transport")) { D("SS(%d): okay transport", s->id); s->smart_socket_data.clear(); return 0; @@ -695,9 +761,11 @@ static int smart_socket_enqueue(asocket* s, apacket::payload_type data) { ** if no such service exists, we'll fail out ** and tear down here. */ - s2 = create_host_service_socket(service, serial, transport_id); + // TODO: Convert to string_view. + s2 = create_host_service_socket(std::string(service).c_str(), std::string(serial).c_str(), + transport_id); if (s2 == nullptr) { - D("SS(%d): couldn't create host service '%s'", s->id, service); + LOG(VERBOSE) << "SS(" << s->id << "): couldn't create host service '" << service << "'"; SendFail(s->peer->fd, "unknown host service"); goto fail; }