Merge changes Icce121a4,I0f95d348

am: d2bd2edf25

Change-Id: I023bf73c1c364affb29285701b76826240c7fba1
This commit is contained in:
Josh Gao 2018-12-26 18:19:25 -08:00 committed by android-build-merger
commit 6707ae0678
4 changed files with 202 additions and 105 deletions

View file

@ -105,7 +105,7 @@ static void listener_event_func(int _fd, unsigned ev, void* _l)
s = create_local_socket(fd); s = create_local_socket(fd);
if (s) { if (s) {
s->transport = listener->transport; s->transport = listener->transport;
connect_to_remote(s, listener->connect_to.c_str()); connect_to_remote(s, listener->connect_to);
return; return;
} }

View file

@ -106,14 +106,15 @@ asocket *create_local_socket(int fd);
asocket* create_local_service_socket(std::string_view destination, atransport* transport); asocket* create_local_service_socket(std::string_view destination, atransport* transport);
asocket *create_remote_socket(unsigned id, atransport *t); asocket *create_remote_socket(unsigned id, atransport *t);
void connect_to_remote(asocket *s, const char *destination); void connect_to_remote(asocket* s, std::string_view destination);
void connect_to_smartsocket(asocket *s); void connect_to_smartsocket(asocket *s);
// Internal functions that are only made available here for testing purposes. // Internal functions that are only made available here for testing purposes.
namespace internal { namespace internal {
#if ADB_HOST #if ADB_HOST
char* skip_host_serial(char* service); bool parse_host_service(std::string_view* out_serial, std::string_view* out_command,
std::string_view service);
#endif #endif
} // namespace internal } // namespace internal

View file

@ -34,6 +34,9 @@
#include "sysdeps.h" #include "sysdeps.h"
#include "sysdeps/chrono.h" #include "sysdeps/chrono.h"
using namespace std::string_literals;
using namespace std::string_view_literals;
struct ThreadArg { struct ThreadArg {
int first_read_fd; int first_read_fd;
int last_write_fd; int last_write_fd;
@ -303,56 +306,78 @@ TEST_F(LocalSocketTest, close_socket_in_CLOSE_WAIT_state) {
#if ADB_HOST #if ADB_HOST
// Checks that skip_host_serial(serial) returns a pointer to the part of |serial| which matches #define VerifyParseHostServiceFailed(s) \
// |expected|, otherwise logs the failure to gtest. do { \
void VerifySkipHostSerial(std::string serial, const char* expected) { std::string service(s); \
char* result = internal::skip_host_serial(&serial[0]); std::string_view serial, command; \
if (expected == nullptr) { bool result = internal::parse_host_service(&serial, &command, service); \
EXPECT_EQ(nullptr, result); EXPECT_FALSE(result); \
} else { } while (0)
EXPECT_STREQ(expected, result);
} #define VerifyParseHostService(s, expected_serial, expected_command) \
} do { \
std::string service(s); \
std::string_view serial, command; \
bool result = internal::parse_host_service(&serial, &command, service); \
EXPECT_TRUE(result); \
EXPECT_EQ(std::string(expected_serial), std::string(serial)); \
EXPECT_EQ(std::string(expected_command), std::string(command)); \
} while (0);
// Check [tcp:|udp:]<serial>[:<port>]:<command> format. // Check [tcp:|udp:]<serial>[:<port>]:<command> format.
TEST(socket_test, test_skip_host_serial) { TEST(socket_test, test_parse_host_service) {
for (const std::string& protocol : {"", "tcp:", "udp:"}) { for (const std::string& protocol : {"", "tcp:", "udp:"}) {
VerifySkipHostSerial(protocol, nullptr); VerifyParseHostServiceFailed(protocol);
VerifySkipHostSerial(protocol + "foo", nullptr); VerifyParseHostServiceFailed(protocol + "foo");
VerifySkipHostSerial(protocol + "foo:bar", ":bar"); {
VerifySkipHostSerial(protocol + "foo:bar:baz", ":bar:baz"); std::string serial = protocol + "foo";
VerifyParseHostService(serial + ":bar", serial, "bar");
VerifyParseHostService(serial + " :bar:baz", serial, "bar:baz");
}
VerifySkipHostSerial(protocol + "foo:123:bar", ":bar"); {
VerifySkipHostSerial(protocol + "foo:123:456", ":456"); // With port.
VerifySkipHostSerial(protocol + "foo:123:bar:baz", ":bar:baz"); std::string serial = protocol + "foo:123";
VerifyParseHostService(serial + ":bar", serial, "bar");
VerifyParseHostService(serial + ":456", serial, "456");
VerifyParseHostService(serial + ":bar:baz", serial, "bar:baz");
}
// Don't register a port unless it's all numbers and ends with ':'. // Don't register a port unless it's all numbers and ends with ':'.
VerifySkipHostSerial(protocol + "foo:123", ":123"); VerifyParseHostService(protocol + "foo:123", protocol + "foo", "123");
VerifySkipHostSerial(protocol + "foo:123bar:baz", ":123bar:baz"); VerifyParseHostService(protocol + "foo:123bar:baz", protocol + "foo", "123bar:baz");
VerifySkipHostSerial(protocol + "100.100.100.100:5555:foo", ":foo"); std::string addresses[] = {"100.100.100.100", "[0123:4567:89ab:CDEF:0:9:a:f]", "[::1]"};
VerifySkipHostSerial(protocol + "[0123:4567:89ab:CDEF:0:9:a:f]:5555:foo", ":foo"); for (const std::string& address : addresses) {
VerifySkipHostSerial(protocol + "[::1]:5555:foo", ":foo"); std::string serial = protocol + address;
std::string serial_with_port = protocol + address + ":5555";
VerifyParseHostService(serial + ":foo", serial, "foo");
VerifyParseHostService(serial_with_port + ":foo", serial_with_port, "foo");
}
// If we can't find both [] then treat it as a normal serial with [ in it. // If we can't find both [] then treat it as a normal serial with [ in it.
VerifySkipHostSerial(protocol + "[0123:foo", ":foo"); VerifyParseHostService(protocol + "[0123:foo", protocol + "[0123", "foo");
// Don't be fooled by random IPv6 addresses in the command string. // Don't be fooled by random IPv6 addresses in the command string.
VerifySkipHostSerial(protocol + "foo:ping [0123:4567:89ab:CDEF:0:9:a:f]:5555", VerifyParseHostService(protocol + "foo:ping [0123:4567:89ab:CDEF:0:9:a:f]:5555",
":ping [0123:4567:89ab:CDEF:0:9:a:f]:5555"); protocol + "foo", "ping [0123:4567:89ab:CDEF:0:9:a:f]:5555");
// Handle embedded NULs properly.
VerifyParseHostService(protocol + "foo:echo foo\0bar"s, protocol + "foo",
"echo foo\0bar"sv);
} }
} }
// Check <prefix>:<serial>:<command> format. // Check <prefix>:<serial>:<command> format.
TEST(socket_test, test_skip_host_serial_prefix) { TEST(socket_test, test_parse_host_service_prefix) {
for (const std::string& prefix : {"usb:", "product:", "model:", "device:"}) { for (const std::string& prefix : {"usb:", "product:", "model:", "device:"}) {
VerifySkipHostSerial(prefix, nullptr); VerifyParseHostServiceFailed(prefix);
VerifySkipHostSerial(prefix + "foo", nullptr); VerifyParseHostServiceFailed(prefix + "foo");
VerifySkipHostSerial(prefix + "foo:bar", ":bar"); VerifyParseHostService(prefix + "foo:bar", prefix + "foo", "bar");
VerifySkipHostSerial(prefix + "foo:bar:baz", ":bar:baz"); VerifyParseHostService(prefix + "foo:bar:baz", prefix + "foo", "bar:baz");
VerifySkipHostSerial(prefix + "foo:123:bar", ":123:bar"); VerifyParseHostService(prefix + "foo:123:bar", prefix + "foo", "123:bar");
} }
} }

View file

@ -37,6 +37,7 @@
#include "adb.h" #include "adb.h"
#include "adb_io.h" #include "adb_io.h"
#include "adb_utils.h"
#include "transport.h" #include "transport.h"
#include "types.h" #include "types.h"
@ -461,16 +462,19 @@ asocket* create_remote_socket(unsigned id, atransport* t) {
return s; return s;
} }
void connect_to_remote(asocket* s, const char* destination) { void connect_to_remote(asocket* s, std::string_view destination) {
D("Connect_to_remote call RS(%d) fd=%d", s->id, s->fd); D("Connect_to_remote call RS(%d) fd=%d", s->id, s->fd);
apacket* p = get_apacket(); apacket* p = get_apacket();
D("LS(%d): connect('%s')", s->id, destination); LOG(VERBOSE) << "LS(" << s->id << ": connect(" << destination << ")";
p->msg.command = A_OPEN; p->msg.command = A_OPEN;
p->msg.arg0 = s->id; p->msg.arg0 = s->id;
// adbd expects a null-terminated string. // adbd used to expect a null-terminated string.
p->payload.assign(destination, destination + strlen(destination) + 1); // Keep doing so to maintain backward compatibility.
p->payload.resize(destination.size() + 1);
memcpy(p->payload.data(), destination.data(), destination.size());
p->payload[destination.size()] = '\0';
p->msg.data_length = p->payload.size(); p->msg.data_length = p->payload.size();
CHECK_LE(p->msg.data_length, s->get_max_payload()); CHECK_LE(p->msg.data_length, s->get_max_payload());
@ -546,57 +550,119 @@ static unsigned unhex(const char* s, int len) {
namespace internal { namespace internal {
// Returns the position in |service| following the target serial parameter. Serial format can be // Parses a host service string of the following format:
// any of:
// * [tcp:|udp:]<serial>[:<port>]:<command> // * [tcp:|udp:]<serial>[:<port>]:<command>
// * <prefix>:<serial>:<command> // * <prefix>:<serial>:<command>
// Where <port> must be a base-10 number and <prefix> may be any of {usb,product,model,device}. // Where <port> must be a base-10 number and <prefix> may be any of {usb,product,model,device}.
// bool parse_host_service(std::string_view* out_serial, std::string_view* out_command,
// The returned pointer will point to the ':' just before <command>, or nullptr if not found. std::string_view full_service) {
char* skip_host_serial(char* service) { if (full_service.empty()) {
static const std::vector<std::string>& prefixes = return false;
*(new std::vector<std::string>{"usb:", "product:", "model:", "device:"}); }
for (const std::string& prefix : prefixes) { std::string_view serial;
if (!strncmp(service, prefix.c_str(), prefix.length())) { std::string_view command = full_service;
return strchr(service + prefix.length(), ':'); // Remove |count| bytes from the beginning of command and add them to |serial|.
auto consume = [&full_service, &serial, &command](size_t count) {
CHECK_LE(count, command.size());
if (!serial.empty()) {
CHECK_EQ(serial.data() + serial.size(), command.data());
}
serial = full_service.substr(0, serial.size() + count);
command.remove_prefix(count);
};
// Remove the trailing : from serial, and assign the values to the output parameters.
auto finish = [out_serial, out_command, &serial, &command] {
if (serial.empty() || command.empty()) {
return false;
}
CHECK_EQ(':', serial.back());
serial.remove_suffix(1);
*out_serial = serial;
*out_command = command;
return true;
};
static constexpr std::string_view prefixes[] = {"usb:", "product:", "model:", "device:"};
for (std::string_view prefix : prefixes) {
if (command.starts_with(prefix)) {
consume(prefix.size());
size_t offset = command.find_first_of(':');
if (offset == std::string::npos) {
return false;
}
consume(offset + 1);
return finish();
} }
} }
// For fastboot compatibility, ignore protocol prefixes. // For fastboot compatibility, ignore protocol prefixes.
if (!strncmp(service, "tcp:", 4) || !strncmp(service, "udp:", 4)) { if (command.starts_with("tcp:") || command.starts_with("udp:")) {
service += 4; consume(4);
} if (command.empty()) {
return false;
// Check for an IPv6 address. `adb connect` creates the serial number from the canonical
// network address so it will always have the [] delimiters.
if (service[0] == '[') {
char* ipv6_end = strchr(service, ']');
if (ipv6_end != nullptr) {
service = ipv6_end;
} }
} }
// The next colon we find must either begin the port field or the command field. bool found_address = false;
char* colon_ptr = strchr(service, ':'); if (command[0] == '[') {
if (!colon_ptr) { // Read an IPv6 address. `adb connect` creates the serial number from the canonical
// No colon in service string. // network address so it will always have the [] delimiters.
return nullptr; size_t ipv6_end = command.find_first_of(']');
if (ipv6_end != std::string::npos) {
consume(ipv6_end + 1);
if (command.empty()) {
// Nothing after the IPv6 address.
return false;
} else if (command[0] != ':') {
// Garbage after the IPv6 address.
return false;
}
consume(1);
found_address = true;
}
} }
// If the next field is only decimal digits and ends with another colon, it's a port. if (!found_address) {
char* serial_end = colon_ptr; // Scan ahead to the next colon.
if (isdigit(serial_end[1])) { size_t offset = command.find_first_of(':');
serial_end++; if (offset == std::string::npos) {
while (*serial_end && isdigit(*serial_end)) { return false;
serial_end++;
} }
if (*serial_end != ':') { consume(offset + 1);
// Something other than "<port>:" was found, this must be the command field instead. }
serial_end = colon_ptr;
// We're either at the beginning of a port, or the command itself.
// Look for a port in between colons.
size_t next_colon = command.find_first_of(':');
if (next_colon == std::string::npos) {
// No colon, we must be at the command.
return finish();
}
bool port_valid = true;
if (command.size() <= next_colon) {
return false;
}
std::string_view port = command.substr(0, next_colon);
for (auto digit : port) {
if (!isdigit(digit)) {
// Port isn't a number.
port_valid = false;
break;
} }
} }
return serial_end;
if (port_valid) {
consume(next_colon + 1);
}
return finish();
} }
} // namespace internal } // namespace internal
@ -605,8 +671,8 @@ char* skip_host_serial(char* service) {
static int smart_socket_enqueue(asocket* s, apacket::payload_type data) { static int smart_socket_enqueue(asocket* s, apacket::payload_type data) {
#if ADB_HOST #if ADB_HOST
char* service = nullptr; std::string_view service;
char* serial = nullptr; std::string_view serial;
TransportId transport_id = 0; TransportId transport_id = 0;
TransportType type = kTransportAny; TransportType type = kTransportAny;
#endif #endif
@ -643,49 +709,52 @@ static int smart_socket_enqueue(asocket* s, apacket::payload_type data) {
D("SS(%d): '%s'", s->id, (char*)(s->smart_socket_data.data() + 4)); D("SS(%d): '%s'", s->id, (char*)(s->smart_socket_data.data() + 4));
#if ADB_HOST #if ADB_HOST
service = &s->smart_socket_data[4]; service = std::string_view(s->smart_socket_data).substr(4);
if (!strncmp(service, "host-serial:", strlen("host-serial:"))) { if (service.starts_with("host-serial:")) {
char* serial_end; service.remove_prefix(strlen("host-serial:"));
service += strlen("host-serial:");
// serial number should follow "host:" and could be a host:port string. // serial number should follow "host:" and could be a host:port string.
serial_end = internal::skip_host_serial(service); if (!internal::parse_host_service(&serial, &service, service)) {
if (serial_end) { LOG(ERROR) << "SS(" << s->id << "): failed to parse host service: " << service;
*serial_end = 0; // terminate string goto fail;
serial = service;
service = serial_end + 1;
} }
} else if (!strncmp(service, "host-transport-id:", strlen("host-transport-id:"))) { } else if (service.starts_with("host-transport-id:")) {
service += strlen("host-transport-id:"); service.remove_prefix(strlen("host-transport-id:"));
transport_id = strtoll(service, &service, 10); if (!ParseUint(&transport_id, service, &service)) {
LOG(ERROR) << "SS(" << s->id << "): failed to parse host transport id: " << service;
if (*service != ':') {
return -1; return -1;
} }
service++; if (!service.starts_with(":")) {
} else if (!strncmp(service, "host-usb:", strlen("host-usb:"))) { LOG(ERROR) << "SS(" << s->id << "): host-transport-id without command";
return -1;
}
service.remove_prefix(1);
} else if (service.starts_with("host-usb:")) {
type = kTransportUsb; type = kTransportUsb;
service += strlen("host-usb:"); service.remove_prefix(strlen("host-usb:"));
} else if (!strncmp(service, "host-local:", strlen("host-local:"))) { } else if (service.starts_with("host-local:")) {
type = kTransportLocal; type = kTransportLocal;
service += strlen("host-local:"); service.remove_prefix(strlen("host-local:"));
} else if (!strncmp(service, "host:", strlen("host:"))) { } else if (service.starts_with("host:")) {
type = kTransportAny; type = kTransportAny;
service += strlen("host:"); service.remove_prefix(strlen("host:"));
} else { } else {
service = nullptr; service = std::string_view{};
} }
if (service) { if (!service.empty()) {
asocket* s2; asocket* s2;
// Some requests are handled immediately -- in that case the handle_host_request() routine // Some requests are handled immediately -- in that case the handle_host_request() routine
// has sent the OKAY or FAIL message and all we have to do is clean up. // has sent the OKAY or FAIL message and all we have to do is clean up.
if (handle_host_request(service, type, serial, transport_id, s->peer->fd, s)) { // TODO: Convert to string_view.
D("SS(%d): handled host service '%s'", s->id, service); if (handle_host_request(std::string(service).c_str(), type,
serial.empty() ? nullptr : std::string(serial).c_str(),
transport_id, s->peer->fd, s)) {
LOG(VERBOSE) << "SS(" << s->id << "): handled host service '" << service << "'";
goto fail; goto fail;
} }
if (!strncmp(service, "transport", strlen("transport"))) { if (service.starts_with("transport")) {
D("SS(%d): okay transport", s->id); D("SS(%d): okay transport", s->id);
s->smart_socket_data.clear(); s->smart_socket_data.clear();
return 0; return 0;
@ -695,9 +764,11 @@ static int smart_socket_enqueue(asocket* s, apacket::payload_type data) {
** if no such service exists, we'll fail out ** if no such service exists, we'll fail out
** and tear down here. ** and tear down here.
*/ */
s2 = create_host_service_socket(service, serial, transport_id); // TODO: Convert to string_view.
s2 = create_host_service_socket(std::string(service).c_str(), std::string(serial).c_str(),
transport_id);
if (s2 == nullptr) { if (s2 == nullptr) {
D("SS(%d): couldn't create host service '%s'", s->id, service); LOG(VERBOSE) << "SS(" << s->id << "): couldn't create host service '" << service << "'";
SendFail(s->peer->fd, "unknown host service"); SendFail(s->peer->fd, "unknown host service");
goto fail; goto fail;
} }
@ -758,7 +829,7 @@ static int smart_socket_enqueue(asocket* s, apacket::payload_type data) {
/* give him our transport and upref it */ /* give him our transport and upref it */
s->peer->transport = s->transport; s->peer->transport = s->transport;
connect_to_remote(s->peer, s->smart_socket_data.data() + 4); connect_to_remote(s->peer, std::string_view(s->smart_socket_data).substr(4));
s->peer = nullptr; s->peer = nullptr;
s->close(s); s->close(s);
return 1; return 1;