diff --git a/adb/adb.cpp b/adb/adb.cpp index 32fbb6555..3c0788254 100644 --- a/adb/adb.cpp +++ b/adb/adb.cpp @@ -1018,8 +1018,9 @@ static int SendOkay(int fd, const std::string& s) { return 0; } -bool handle_host_request(std::string_view service, TransportType type, const char* serial, - TransportId transport_id, int reply_fd, asocket* s) { +HostRequestResult handle_host_request(std::string_view service, TransportType type, + const char* serial, TransportId transport_id, int reply_fd, + asocket* s) { if (service == "kill") { fprintf(stderr, "adb server killed by remote request\n"); fflush(stdout); @@ -1032,29 +1033,49 @@ bool handle_host_request(std::string_view service, TransportType type, const cha exit(0); } - // "transport:" is used for switching transport with a specified serial number - // "transport-usb:" is used for switching transport to the only USB transport - // "transport-local:" is used for switching transport to the only local transport - // "transport-any:" is used for switching transport to the only transport - if (service.starts_with("transport")) { + LOG(DEBUG) << "handle_host_request(" << service << ")"; + + // Transport selection: + if (service.starts_with("transport") || service.starts_with("tport:")) { TransportType type = kTransportAny; std::string serial_storage; + bool legacy = true; - if (ConsumePrefix(&service, "transport-id:")) { - if (!ParseUint(&transport_id, service)) { - SendFail(reply_fd, "invalid transport id"); - return true; + // New transport selection protocol: + // This is essentially identical to the previous version, except it returns the selected + // transport id to the caller as well. + if (ConsumePrefix(&service, "tport:")) { + legacy = false; + if (ConsumePrefix(&service, "serial:")) { + serial_storage = service; + serial = serial_storage.c_str(); + } else if (service == "usb") { + type = kTransportUsb; + } else if (service == "local") { + type = kTransportLocal; + } else if (service == "any") { + type = kTransportAny; + } + + // Selection by id is unimplemented, since you obviously already know the transport id + // you're connecting to. + } else { + if (ConsumePrefix(&service, "transport-id:")) { + if (!ParseUint(&transport_id, service)) { + SendFail(reply_fd, "invalid transport id"); + return HostRequestResult::Handled; + } + } else if (service == "transport-usb") { + type = kTransportUsb; + } else if (service == "transport-local") { + type = kTransportLocal; + } else if (service == "transport-any") { + type = kTransportAny; + } else if (ConsumePrefix(&service, "transport:")) { + serial_storage = service; + serial = serial_storage.c_str(); } - } else if (service == "transport-usb") { - type = kTransportUsb; - } else if (service == "transport-local") { - type = kTransportLocal; - } else if (service == "transport-any") { - type = kTransportAny; - } else if (ConsumePrefix(&service, "transport:")) { - serial_storage = service; - serial = serial_storage.c_str(); } std::string error; @@ -1063,11 +1084,15 @@ bool handle_host_request(std::string_view service, TransportType type, const cha s->transport = t; SendOkay(reply_fd); - // We succesfully handled the device selection, but there's another request coming. - return false; + if (!legacy) { + // Nothing we can do if this fails. + WriteFdExactly(reply_fd, &t->id, sizeof(t->id)); + } + + return HostRequestResult::SwitchedTransport; } else { SendFail(reply_fd, error); - return true; + return HostRequestResult::Handled; } } @@ -1078,7 +1103,7 @@ bool handle_host_request(std::string_view service, TransportType type, const cha std::string device_list = list_transports(long_listing); D("Sending device list..."); SendOkay(reply_fd, device_list); - return true; + return HostRequestResult::Handled; } if (service == "reconnect-offline") { @@ -1094,7 +1119,7 @@ bool handle_host_request(std::string_view service, TransportType type, const cha response.resize(response.size() - 1); } SendOkay(reply_fd, response); - return true; + return HostRequestResult::Handled; } if (service == "features") { @@ -1105,7 +1130,7 @@ bool handle_host_request(std::string_view service, TransportType type, const cha } else { SendFail(reply_fd, error); } - return true; + return HostRequestResult::Handled; } if (service == "host-features") { @@ -1116,7 +1141,7 @@ bool handle_host_request(std::string_view service, TransportType type, const cha } features.insert(kFeaturePushSync); SendOkay(reply_fd, FeatureSetToString(features)); - return true; + return HostRequestResult::Handled; } // remove TCP transport @@ -1125,7 +1150,7 @@ bool handle_host_request(std::string_view service, TransportType type, const cha if (address.empty()) { kick_all_tcp_devices(); SendOkay(reply_fd, "disconnected everything"); - return true; + return HostRequestResult::Handled; } std::string serial; @@ -1137,22 +1162,22 @@ bool handle_host_request(std::string_view service, TransportType type, const cha } else if (!android::base::ParseNetAddress(address, &host, &port, &serial, &error)) { SendFail(reply_fd, android::base::StringPrintf("couldn't parse '%s': %s", address.c_str(), error.c_str())); - return true; + return HostRequestResult::Handled; } atransport* t = find_transport(serial.c_str()); if (t == nullptr) { SendFail(reply_fd, android::base::StringPrintf("no such device '%s'", serial.c_str())); - return true; + return HostRequestResult::Handled; } kick_transport(t); SendOkay(reply_fd, android::base::StringPrintf("disconnected %s", address.c_str())); - return true; + return HostRequestResult::Handled; } // Returns our value for ADB_SERVER_VERSION. if (service == "version") { SendOkay(reply_fd, android::base::StringPrintf("%04x", ADB_SERVER_VERSION)); - return true; + return HostRequestResult::Handled; } // These always report "unknown" rather than the actual error, for scripts. @@ -1164,7 +1189,7 @@ bool handle_host_request(std::string_view service, TransportType type, const cha } else { SendFail(reply_fd, error); } - return true; + return HostRequestResult::Handled; } if (service == "get-devpath") { std::string error; @@ -1174,7 +1199,7 @@ bool handle_host_request(std::string_view service, TransportType type, const cha } else { SendFail(reply_fd, error); } - return true; + return HostRequestResult::Handled; } if (service == "get-state") { std::string error; @@ -1184,7 +1209,7 @@ bool handle_host_request(std::string_view service, TransportType type, const cha } else { SendFail(reply_fd, error); } - return true; + return HostRequestResult::Handled; } // Indicates a new emulator instance has started. @@ -1197,7 +1222,7 @@ bool handle_host_request(std::string_view service, TransportType type, const cha } /* we don't even need to send a reply */ - return true; + return HostRequestResult::Handled; } if (service == "reconnect") { @@ -1209,7 +1234,7 @@ bool handle_host_request(std::string_view service, TransportType type, const cha "reconnecting " + t->serial_name() + " [" + t->connection_state_name() + "]\n"; } SendOkay(reply_fd, response); - return true; + return HostRequestResult::Handled; } // TODO: Switch handle_forward_request to string_view. @@ -1220,10 +1245,10 @@ bool handle_host_request(std::string_view service, TransportType type, const cha return acquire_one_transport(type, serial, transport_id, nullptr, error); }, reply_fd)) { - return true; + return HostRequestResult::Handled; } - return false; + return HostRequestResult::Unhandled; } static auto& init_mutex = *new std::mutex(); diff --git a/adb/adb.h b/adb/adb.h index f575adb47..5eea8bea7 100644 --- a/adb/adb.h +++ b/adb/adb.h @@ -219,8 +219,15 @@ extern const char* adb_device_banner; #define USB_FFS_ADB_IN USB_FFS_ADB_EP(ep2) #endif -bool handle_host_request(std::string_view service, TransportType type, const char* serial, - TransportId transport_id, int reply_fd, asocket* s); +enum class HostRequestResult { + Handled, + SwitchedTransport, + Unhandled, +}; + +HostRequestResult handle_host_request(std::string_view service, TransportType type, + const char* serial, TransportId transport_id, int reply_fd, + asocket* s); void handle_online(atransport* t); void handle_offline(atransport* t); diff --git a/adb/client/adb_client.cpp b/adb/client/adb_client.cpp index 0a09d1ee5..4cf3a743a 100644 --- a/adb/client/adb_client.cpp +++ b/adb/client/adb_client.cpp @@ -70,46 +70,60 @@ void adb_set_socket_spec(const char* socket_spec) { __adb_server_socket_spec = socket_spec; } -static int switch_socket_transport(int fd, std::string* error) { +static std::optional switch_socket_transport(int fd, std::string* error) { + TransportId result; + bool read_transport = true; + std::string service; if (__adb_transport_id) { + read_transport = false; service += "host:transport-id:"; service += std::to_string(__adb_transport_id); + result = __adb_transport_id; } else if (__adb_serial) { - service += "host:transport:"; + service += "host:tport:serial:"; service += __adb_serial; } else { const char* transport_type = "???"; switch (__adb_transport) { case kTransportUsb: - transport_type = "transport-usb"; - break; + transport_type = "usb"; + break; case kTransportLocal: - transport_type = "transport-local"; - break; + transport_type = "local"; + break; case kTransportAny: - transport_type = "transport-any"; - break; + transport_type = "any"; + break; case kTransportHost: // no switch necessary return 0; } - service += "host:"; + service += "host:tport:"; service += transport_type; } if (!SendProtocolString(fd, service)) { *error = perror_str("write failure during connection"); - return -1; + return std::nullopt; } - D("Switch transport in progress"); + + LOG(DEBUG) << "Switch transport in progress: " << service; if (!adb_status(fd, error)) { D("Switch transport failed: %s", error->c_str()); - return -1; + return std::nullopt; } + + if (read_transport) { + if (!ReadFdExactly(fd, &result, sizeof(result))) { + *error = "failed to read transport id from server"; + return std::nullopt; + } + } + D("Switch transport success"); - return 0; + return result; } bool adb_status(int fd, std::string* error) { @@ -133,11 +147,10 @@ bool adb_status(int fd, std::string* error) { return false; } -static int _adb_connect(const std::string& service, std::string* error) { - D("_adb_connect: %s", service.c_str()); +static int _adb_connect(std::string_view service, TransportId* transport, std::string* error) { + LOG(DEBUG) << "_adb_connect: " << service; if (service.empty() || service.size() > MAX_PAYLOAD) { - *error = android::base::StringPrintf("bad service name length (%zd)", - service.size()); + *error = android::base::StringPrintf("bad service name length (%zd)", service.size()); return -1; } @@ -149,8 +162,15 @@ static int _adb_connect(const std::string& service, std::string* error) { return -2; } - if (memcmp(&service[0], "host", 4) != 0 && switch_socket_transport(fd.get(), error)) { - return -1; + if (!service.starts_with("host")) { + std::optional transport_result = switch_socket_transport(fd.get(), error); + if (!transport_result) { + return -1; + } + + if (transport) { + *transport = *transport_result; + } } if (!SendProtocolString(fd.get(), service)) { @@ -190,11 +210,15 @@ bool adb_kill_server() { return true; } -int adb_connect(const std::string& service, std::string* error) { - // first query the adb server's version - unique_fd fd(_adb_connect("host:version", error)); +int adb_connect(std::string_view service, std::string* error) { + return adb_connect(nullptr, service, error); +} - D("adb_connect: service %s", service.c_str()); +int adb_connect(TransportId* transport, std::string_view service, std::string* error) { + // first query the adb server's version + unique_fd fd(_adb_connect("host:version", nullptr, error)); + + LOG(DEBUG) << "adb_connect: service: " << service; if (fd == -2 && !is_local_socket_spec(__adb_server_socket_spec)) { fprintf(stderr, "* cannot start server on remote host\n"); // error is the original network connection error @@ -216,7 +240,7 @@ int adb_connect(const std::string& service, std::string* error) { // Fall through to _adb_connect. } else { // If a server is already running, check its version matches. - int version = ADB_SERVER_VERSION - 1; + int version = 0; // If we have a file descriptor, then parse version result. if (fd >= 0) { @@ -254,7 +278,7 @@ int adb_connect(const std::string& service, std::string* error) { return 0; } - fd.reset(_adb_connect(service, error)); + fd.reset(_adb_connect(service, transport, error)); if (fd == -1) { D("_adb_connect error: %s", error->c_str()); } else if(fd == -2) { @@ -265,7 +289,6 @@ int adb_connect(const std::string& service, std::string* error) { return fd.release(); } - bool adb_command(const std::string& service) { std::string error; unique_fd fd(adb_connect(service, &error)); diff --git a/adb/client/adb_client.h b/adb/client/adb_client.h index d4675396f..0a7378770 100644 --- a/adb/client/adb_client.h +++ b/adb/client/adb_client.h @@ -24,7 +24,10 @@ // Connect to adb, connect to the named service, and return a valid fd for // interacting with that service upon success or a negative number on failure. -int adb_connect(const std::string& service, std::string* _Nonnull error); +int adb_connect(std::string_view service, std::string* _Nonnull error); + +// Same as above, except returning the TransportId for the service that we've connected to. +int adb_connect(TransportId* _Nullable id, std::string_view service, std::string* _Nonnull error); // Kill the currently running adb server, if it exists. bool adb_kill_server(); diff --git a/adb/sockets.cpp b/adb/sockets.cpp index 04d92db60..dc4402612 100644 --- a/adb/sockets.cpp +++ b/adb/sockets.cpp @@ -792,16 +792,22 @@ static int smart_socket_enqueue(asocket* s, apacket::payload_type data) { // Some requests are handled immediately -- in that case the handle_host_request() routine // has sent the OKAY or FAIL message and all we have to do is clean up. - if (handle_host_request(service, type, - serial.empty() ? nullptr : std::string(serial).c_str(), - transport_id, s->peer->fd, s)) { - LOG(VERBOSE) << "SS(" << s->id << "): handled host service '" << service << "'"; - goto fail; - } - if (service.starts_with("transport")) { - D("SS(%d): okay transport", s->id); - s->smart_socket_data.clear(); - return 0; + auto host_request_result = handle_host_request( + service, type, serial.empty() ? nullptr : std::string(serial).c_str(), transport_id, + s->peer->fd, s); + + switch (host_request_result) { + case HostRequestResult::Handled: + LOG(VERBOSE) << "SS(" << s->id << "): handled host service '" << service << "'"; + goto fail; + + case HostRequestResult::SwitchedTransport: + D("SS(%d): okay transport", s->id); + s->smart_socket_data.clear(); + return 0; + + case HostRequestResult::Unhandled: + break; } /* try to find a local service with this name.