Merge "adb: fix two device offline problems."

am: 2e821078e4

Change-Id: Iffba674a90dae88610541fe93c8df751e9ad63d2
This commit is contained in:
Yabin Cui 2017-04-20 19:49:08 +00:00 committed by android-build-merger
commit ae6a3605d2
20 changed files with 356 additions and 108 deletions

View file

@ -253,6 +253,19 @@ void send_connect(atransport* t) {
send_packet(cp, t); send_packet(cp, t);
} }
#if ADB_HOST
void SendConnectOnHost(atransport* t) {
// Send an empty message before A_CNXN message. This is because the data toggle of the ep_out on
// host and ep_in on device may not be the same.
apacket* p = get_apacket();
CHECK(p);
send_packet(p, t);
send_connect(t);
}
#endif
// qual_overwrite is used to overwrite a qualifier string. dst is a // qual_overwrite is used to overwrite a qualifier string. dst is a
// pointer to a char pointer. It is assumed that if *dst is non-NULL, it // pointer to a char pointer. It is assumed that if *dst is non-NULL, it
// was malloc'ed and needs to freed. *dst will be set to a dup of src. // was malloc'ed and needs to freed. *dst will be set to a dup of src.
@ -299,29 +312,29 @@ void parse_banner(const std::string& banner, atransport* t) {
const std::string& type = pieces[0]; const std::string& type = pieces[0];
if (type == "bootloader") { if (type == "bootloader") {
D("setting connection_state to kCsBootloader"); D("setting connection_state to kCsBootloader");
t->connection_state = kCsBootloader; t->SetConnectionState(kCsBootloader);
update_transports(); update_transports();
} else if (type == "device") { } else if (type == "device") {
D("setting connection_state to kCsDevice"); D("setting connection_state to kCsDevice");
t->connection_state = kCsDevice; t->SetConnectionState(kCsDevice);
update_transports(); update_transports();
} else if (type == "recovery") { } else if (type == "recovery") {
D("setting connection_state to kCsRecovery"); D("setting connection_state to kCsRecovery");
t->connection_state = kCsRecovery; t->SetConnectionState(kCsRecovery);
update_transports(); update_transports();
} else if (type == "sideload") { } else if (type == "sideload") {
D("setting connection_state to kCsSideload"); D("setting connection_state to kCsSideload");
t->connection_state = kCsSideload; t->SetConnectionState(kCsSideload);
update_transports(); update_transports();
} else { } else {
D("setting connection_state to kCsHost"); D("setting connection_state to kCsHost");
t->connection_state = kCsHost; t->SetConnectionState(kCsHost);
} }
} }
static void handle_new_connection(atransport* t, apacket* p) { static void handle_new_connection(atransport* t, apacket* p) {
if (t->connection_state != kCsOffline) { if (t->GetConnectionState() != kCsOffline) {
t->connection_state = kCsOffline; t->SetConnectionState(kCsOffline);
handle_offline(t); handle_offline(t);
} }
@ -355,10 +368,10 @@ void handle_packet(apacket *p, atransport *t)
if (p->msg.arg0){ if (p->msg.arg0){
send_packet(p, t); send_packet(p, t);
#if ADB_HOST #if ADB_HOST
send_connect(t); SendConnectOnHost(t);
#endif #endif
} else { } else {
t->connection_state = kCsOffline; t->SetConnectionState(kCsOffline);
handle_offline(t); handle_offline(t);
send_packet(p, t); send_packet(p, t);
} }
@ -372,7 +385,9 @@ void handle_packet(apacket *p, atransport *t)
switch (p->msg.arg0) { switch (p->msg.arg0) {
#if ADB_HOST #if ADB_HOST
case ADB_AUTH_TOKEN: case ADB_AUTH_TOKEN:
t->connection_state = kCsUnauthorized; if (t->GetConnectionState() == kCsOffline) {
t->SetConnectionState(kCsUnauthorized);
}
send_auth_response(p->data, p->msg.data_length, t); send_auth_response(p->data, p->msg.data_length, t);
break; break;
#else #else
@ -391,7 +406,7 @@ void handle_packet(apacket *p, atransport *t)
break; break;
#endif #endif
default: default:
t->connection_state = kCsOffline; t->SetConnectionState(kCsOffline);
handle_offline(t); handle_offline(t);
break; break;
} }
@ -1032,7 +1047,6 @@ static int SendOkay(int fd, const std::string& s) {
SendProtocolString(fd, s); SendProtocolString(fd, s);
return 0; return 0;
} }
#endif
int handle_host_request(const char* service, TransportType type, int handle_host_request(const char* service, TransportType type,
const char* serial, int reply_fd, asocket* s) { const char* serial, int reply_fd, asocket* s) {
@ -1051,7 +1065,6 @@ int handle_host_request(const char* service, TransportType type,
android::base::quick_exit(0); android::base::quick_exit(0);
} }
#if ADB_HOST
// "transport:" is used for switching transport with a specified serial number // "transport:" is used for switching transport with a specified serial number
// "transport-usb:" is used for switching transport to the only USB transport // "transport-usb:" is used for switching transport to the only USB transport
// "transport-local:" is used for switching transport to the only local transport // "transport-local:" is used for switching transport to the only local transport
@ -1096,16 +1109,10 @@ int handle_host_request(const char* service, TransportType type,
if (!strcmp(service, "reconnect-offline")) { if (!strcmp(service, "reconnect-offline")) {
std::string response; std::string response;
close_usb_devices([&response](const atransport* transport) { close_usb_devices([&response](const atransport* transport) {
switch (transport->connection_state) { switch (transport->GetConnectionState()) {
case kCsOffline: case kCsOffline:
case kCsUnauthorized: case kCsUnauthorized:
response += "reconnecting "; response += "reconnecting " + transport->serial_name() + "\n";
if (transport->serial) {
response += transport->serial;
} else {
response += "<unknown>";
}
response += "\n";
return true; return true;
default: default:
return false; return false;
@ -1129,7 +1136,6 @@ int handle_host_request(const char* service, TransportType type,
return 0; return 0;
} }
#if ADB_HOST
if (!strcmp(service, "host-features")) { if (!strcmp(service, "host-features")) {
FeatureSet features = supported_features(); FeatureSet features = supported_features();
// Abuse features to report libusb status. // Abuse features to report libusb status.
@ -1139,7 +1145,6 @@ int handle_host_request(const char* service, TransportType type,
SendOkay(reply_fd, FeatureSetToString(features)); SendOkay(reply_fd, FeatureSetToString(features));
return 0; return 0;
} }
#endif
// remove TCP transport // remove TCP transport
if (!strncmp(service, "disconnect:", 11)) { if (!strncmp(service, "disconnect:", 11)) {
@ -1209,15 +1214,19 @@ int handle_host_request(const char* service, TransportType type,
} }
if (!strcmp(service, "reconnect")) { if (!strcmp(service, "reconnect")) {
if (s->transport != nullptr) { std::string response;
kick_transport(s->transport); atransport* t = acquire_one_transport(type, serial, nullptr, &response, true);
if (t != nullptr) {
kick_transport(t);
response =
"reconnecting " + t->serial_name() + " [" + t->connection_state_name() + "]\n";
} }
return SendOkay(reply_fd, "done"); return SendOkay(reply_fd, response);
} }
#endif // ADB_HOST
int ret = handle_forward_request(service, type, serial, reply_fd); int ret = handle_forward_request(service, type, serial, reply_fd);
if (ret >= 0) if (ret >= 0)
return ret - 1; return ret - 1;
return -1; return -1;
} }
#endif // ADB_HOST

View file

@ -139,7 +139,7 @@ int adb_server_main(int is_daemon, const std::string& socket_spec, int ack_reply
int get_available_local_transport_index(); int get_available_local_transport_index();
#endif #endif
int init_socket_transport(atransport *t, int s, int port, int local); int init_socket_transport(atransport *t, int s, int port, int local);
void init_usb_transport(atransport *t, usb_handle *usb, ConnectionState state); void init_usb_transport(atransport* t, usb_handle* usb);
std::string getEmulatorSerialString(int console_port); std::string getEmulatorSerialString(int console_port);
#if ADB_HOST #if ADB_HOST
@ -222,6 +222,9 @@ void handle_online(atransport *t);
void handle_offline(atransport *t); void handle_offline(atransport *t);
void send_connect(atransport *t); void send_connect(atransport *t);
#if ADB_HOST
void SendConnectOnHost(atransport* t);
#endif
void parse_banner(const std::string&, atransport* t); void parse_banner(const std::string&, atransport* t);

View file

@ -136,8 +136,7 @@ int _adb_connect(const std::string& service, std::string* error) {
return -2; return -2;
} }
if ((memcmp(&service[0],"host",4) != 0 || service == "host:reconnect") && if (memcmp(&service[0], "host", 4) != 0 && switch_socket_transport(fd, error)) {
switch_socket_transport(fd, error)) {
return -1; return -1;
} }
@ -147,11 +146,9 @@ int _adb_connect(const std::string& service, std::string* error) {
return -1; return -1;
} }
if (service != "reconnect") { if (!adb_status(fd, error)) {
if (!adb_status(fd, error)) { adb_close(fd);
adb_close(fd); return -1;
return -1;
}
} }
D("_adb_connect: return fd %d", fd); D("_adb_connect: return fd %d", fd);

View file

@ -155,7 +155,7 @@ void adb_trace_init(char** argv) {
} }
#endif #endif
#if !defined(_WIN32) #if ADB_HOST && !defined(_WIN32)
// adb historically ignored $ANDROID_LOG_TAGS but passed it through to logcat. // adb historically ignored $ANDROID_LOG_TAGS but passed it through to logcat.
// If set, move it out of the way so that libbase logging doesn't try to parse it. // If set, move it out of the way so that libbase logging doesn't try to parse it.
std::string log_tags; std::string log_tags;
@ -168,7 +168,7 @@ void adb_trace_init(char** argv) {
android::base::InitLogging(argv, &AdbLogger); android::base::InitLogging(argv, &AdbLogger);
#if !defined(_WIN32) #if ADB_HOST && !defined(_WIN32)
// Put $ANDROID_LOG_TAGS back so we can pass it to logcat. // Put $ANDROID_LOG_TAGS back so we can pass it to logcat.
if (!log_tags.empty()) setenv("ANDROID_LOG_TAGS", log_tags.c_str(), 1); if (!log_tags.empty()) setenv("ANDROID_LOG_TAGS", log_tags.c_str(), 1);
#endif #endif

View file

@ -58,6 +58,9 @@ extern int adb_trace_mask;
void adb_trace_init(char**); void adb_trace_init(char**);
void adb_trace_enable(AdbTrace trace_tag); void adb_trace_enable(AdbTrace trace_tag);
// Include <atomic> before stdatomic.h (introduced in cutils/trace.h) to avoid compile error.
#include <atomic>
#define ATRACE_TAG ATRACE_TAG_ADB #define ATRACE_TAG ATRACE_TAG_ADB
#include <cutils/trace.h> #include <cutils/trace.h>
#include <utils/Trace.h> #include <utils/Trace.h>

View file

@ -62,12 +62,11 @@ struct DeviceHandleDeleter {
using unique_device_handle = std::unique_ptr<libusb_device_handle, DeviceHandleDeleter>; using unique_device_handle = std::unique_ptr<libusb_device_handle, DeviceHandleDeleter>;
struct transfer_info { struct transfer_info {
transfer_info(const char* name, uint16_t zero_mask) : transfer_info(const char* name, uint16_t zero_mask, bool is_bulk_out)
name(name), : name(name),
transfer(libusb_alloc_transfer(0)), transfer(libusb_alloc_transfer(0)),
zero_mask(zero_mask) is_bulk_out(is_bulk_out),
{ zero_mask(zero_mask) {}
}
~transfer_info() { ~transfer_info() {
libusb_free_transfer(transfer); libusb_free_transfer(transfer);
@ -75,6 +74,7 @@ struct transfer_info {
const char* name; const char* name;
libusb_transfer* transfer; libusb_transfer* transfer;
bool is_bulk_out;
bool transfer_complete; bool transfer_complete;
std::condition_variable cv; std::condition_variable cv;
std::mutex mutex; std::mutex mutex;
@ -96,12 +96,11 @@ struct usb_handle : public ::usb_handle {
serial(serial), serial(serial),
closing(false), closing(false),
device_handle(device_handle.release()), device_handle(device_handle.release()),
read("read", zero_mask), read("read", zero_mask, false),
write("write", zero_mask), write("write", zero_mask, true),
interface(interface), interface(interface),
bulk_in(bulk_in), bulk_in(bulk_in),
bulk_out(bulk_out) { bulk_out(bulk_out) {}
}
~usb_handle() { ~usb_handle() {
Close(); Close();
@ -365,11 +364,6 @@ void usb_init() {
device_poll_thread = new std::thread(poll_for_devices); device_poll_thread = new std::thread(poll_for_devices);
android::base::at_quick_exit([]() { android::base::at_quick_exit([]() {
terminate_device_poll_thread = true; terminate_device_poll_thread = true;
std::unique_lock<std::mutex> lock(usb_handles_mutex);
for (auto& it : usb_handles) {
it.second->Close();
}
lock.unlock();
device_poll_thread->join(); device_poll_thread->join();
}); });
} }
@ -397,7 +391,8 @@ static int perform_usb_transfer(usb_handle* h, transfer_info* info,
return; return;
} }
if (transfer->actual_length != transfer->length) { // usb_read() can return when receiving some data.
if (info->is_bulk_out && transfer->actual_length != transfer->length) {
LOG(DEBUG) << info->name << " transfer incomplete, resubmitting"; LOG(DEBUG) << info->name << " transfer incomplete, resubmitting";
transfer->length -= transfer->actual_length; transfer->length -= transfer->actual_length;
transfer->buffer += transfer->actual_length; transfer->buffer += transfer->actual_length;
@ -491,8 +486,12 @@ int usb_read(usb_handle* h, void* d, int len) {
info->transfer->num_iso_packets = 0; info->transfer->num_iso_packets = 0;
int rc = perform_usb_transfer(h, info, std::move(lock)); int rc = perform_usb_transfer(h, info, std::move(lock));
LOG(DEBUG) << "usb_read(" << len << ") = " << rc; LOG(DEBUG) << "usb_read(" << len << ") = " << rc << ", actual_length "
return rc; << info->transfer->actual_length;
if (rc < 0) {
return rc;
}
return info->transfer->actual_length;
} }
int usb_close(usb_handle* h) { int usb_close(usb_handle* h) {

View file

@ -401,7 +401,6 @@ static int usb_bulk_read(usb_handle* h, void* data, int len) {
} }
} }
int usb_write(usb_handle *h, const void *_data, int len) int usb_write(usb_handle *h, const void *_data, int len)
{ {
D("++ usb_write ++"); D("++ usb_write ++");
@ -429,19 +428,16 @@ int usb_read(usb_handle *h, void *_data, int len)
int n; int n;
D("++ usb_read ++"); D("++ usb_read ++");
while(len > 0) { int orig_len = len;
while (len == orig_len) {
int xfer = len; int xfer = len;
D("[ usb read %d fd = %d], path=%s", xfer, h->fd, h->path.c_str()); D("[ usb read %d fd = %d], path=%s", xfer, h->fd, h->path.c_str());
n = usb_bulk_read(h, data, xfer); n = usb_bulk_read(h, data, xfer);
D("[ usb read %d ] = %d, path=%s", xfer, n, h->path.c_str()); D("[ usb read %d ] = %d, path=%s", xfer, n, h->path.c_str());
if(n != xfer) { if (n <= 0) {
if((errno == ETIMEDOUT) && (h->fd != -1)) { if((errno == ETIMEDOUT) && (h->fd != -1)) {
D("[ timeout ]"); D("[ timeout ]");
if(n > 0){
data += n;
len -= n;
}
continue; continue;
} }
D("ERROR: n = %d, errno = %d (%s)", D("ERROR: n = %d, errno = %d (%s)",
@ -449,12 +445,12 @@ int usb_read(usb_handle *h, void *_data, int len)
return -1; return -1;
} }
len -= xfer; len -= n;
data += xfer; data += n;
} }
D("-- usb_read --"); D("-- usb_read --");
return 0; return orig_len - len;
} }
void usb_kick(usb_handle* h) { void usb_kick(usb_handle* h) {

View file

@ -518,7 +518,7 @@ int usb_read(usb_handle *handle, void *buf, int len)
} }
if (kIOReturnSuccess == result) if (kIOReturnSuccess == result)
return 0; return numBytes;
else { else {
LOG(ERROR) << "usb_read failed with status: " << std::hex << result; LOG(ERROR) << "usb_read failed with status: " << std::hex << result;
} }

View file

@ -415,6 +415,7 @@ int usb_read(usb_handle *handle, void* data, int len) {
unsigned long time_out = 0; unsigned long time_out = 0;
unsigned long read = 0; unsigned long read = 0;
int err = 0; int err = 0;
int orig_len = len;
D("usb_read %d", len); D("usb_read %d", len);
if (NULL == handle) { if (NULL == handle) {
@ -423,9 +424,8 @@ int usb_read(usb_handle *handle, void* data, int len) {
goto fail; goto fail;
} }
while (len > 0) { while (len == orig_len) {
if (!AdbReadEndpointSync(handle->adb_read_pipe, data, len, &read, if (!AdbReadEndpointSync(handle->adb_read_pipe, data, len, &read, time_out)) {
time_out)) {
D("AdbReadEndpointSync failed: %s", D("AdbReadEndpointSync failed: %s",
android::base::SystemErrorCodeToString(GetLastError()).c_str()); android::base::SystemErrorCodeToString(GetLastError()).c_str());
err = EIO; err = EIO;
@ -433,11 +433,11 @@ int usb_read(usb_handle *handle, void* data, int len) {
} }
D("usb_read got: %ld, expected: %d", read, len); D("usb_read got: %ld, expected: %d", read, len);
data = (char *)data + read; data = (char*)data + read;
len -= read; len -= read;
} }
return 0; return orig_len - len;
fail: fail:
// Any failure should cause us to kick the device instead of leaving it a // Any failure should cause us to kick the device instead of leaving it a

View file

@ -212,6 +212,7 @@ static void help() {
" kill-server kill the server if it is running\n" " kill-server kill the server if it is running\n"
" reconnect kick connection from host side to force reconnect\n" " reconnect kick connection from host side to force reconnect\n"
" reconnect device kick connection from device side to force reconnect\n" " reconnect device kick connection from device side to force reconnect\n"
" reconnect offline reset offline/unauthorized devices to force reconnect\n"
"\n" "\n"
"environment variables:\n" "environment variables:\n"
" $ADB_TRACE\n" " $ADB_TRACE\n"
@ -1929,7 +1930,7 @@ int adb_commandline(int argc, const char** argv) {
return adb_query_command("host:host-features"); return adb_query_command("host:host-features");
} else if (!strcmp(argv[0], "reconnect")) { } else if (!strcmp(argv[0], "reconnect")) {
if (argc == 1) { if (argc == 1) {
return adb_query_command("host:reconnect"); return adb_query_command(format_host_command(argv[0], transport_type, serial));
} else if (argc == 2) { } else if (argc == 2) {
if (!strcmp(argv[1], "device")) { if (!strcmp(argv[1], "device")) {
std::string err; std::string err;

View file

@ -75,13 +75,13 @@ static std::atomic<bool> terminate_loop(false);
static bool main_thread_valid; static bool main_thread_valid;
static unsigned long main_thread_id; static unsigned long main_thread_id;
static void check_main_thread() { void check_main_thread() {
if (main_thread_valid) { if (main_thread_valid) {
CHECK_EQ(main_thread_id, adb_thread_id()); CHECK_EQ(main_thread_id, adb_thread_id());
} }
} }
static void set_main_thread() { void set_main_thread() {
main_thread_valid = true; main_thread_valid = true;
main_thread_id = adb_thread_id(); main_thread_id = adb_thread_id();
} }

View file

@ -76,9 +76,12 @@ void fdevent_set_timeout(fdevent *fde, int64_t timeout_ms);
*/ */
void fdevent_loop(); void fdevent_loop();
void check_main_thread();
// The following functions are used only for tests. // The following functions are used only for tests.
void fdevent_terminate_loop(); void fdevent_terminate_loop();
size_t fdevent_installed_count(); size_t fdevent_installed_count();
void fdevent_reset(); void fdevent_reset();
void set_main_thread();
#endif #endif

View file

@ -347,7 +347,7 @@ static void wait_for_state(int fd, void* data) {
std::string error = "unknown error"; std::string error = "unknown error";
const char* serial = sinfo->serial.length() ? sinfo->serial.c_str() : NULL; const char* serial = sinfo->serial.length() ? sinfo->serial.c_str() : NULL;
atransport* t = acquire_one_transport(sinfo->transport_type, serial, &is_ambiguous, &error); atransport* t = acquire_one_transport(sinfo->transport_type, serial, &is_ambiguous, &error);
if (t != nullptr && (sinfo->state == kCsAny || sinfo->state == t->connection_state)) { if (t != nullptr && (sinfo->state == kCsAny || sinfo->state == t->GetConnectionState())) {
SendOkay(fd); SendOkay(fd);
break; break;
} else if (!is_ambiguous) { } else if (!is_ambiguous) {

View file

@ -794,7 +794,7 @@ static int smart_socket_enqueue(asocket* s, apacket* p) {
if (!s->transport) { if (!s->transport) {
SendFail(s->peer->fd, "device offline (no transport)"); SendFail(s->peer->fd, "device offline (no transport)");
goto fail; goto fail;
} else if (s->transport->connection_state == kCsOffline) { } else if (s->transport->GetConnectionState() == kCsOffline) {
/* if there's no remote we fail the connection /* if there's no remote we fail the connection
** right here and terminate it ** right here and terminate it
*/ */

View file

@ -1188,6 +1188,77 @@ class FileOperationsTest(DeviceTest):
self.device.shell(['rm', '-f', '/data/local/tmp/adb-test-*']) self.device.shell(['rm', '-f', '/data/local/tmp/adb-test-*'])
class DeviceOfflineTest(DeviceTest):
def _get_device_state(self, serialno):
output = subprocess.check_output(self.device.adb_cmd + ['devices'])
for line in output.split('\n'):
m = re.match('(\S+)\s+(\S+)', line)
if m and m.group(1) == serialno:
return m.group(2)
return None
def test_killed_when_pushing_a_large_file(self):
"""
While running adb push with a large file, kill adb server.
Occasionally the device becomes offline. Because the device is still
reading data without realizing that the adb server has been restarted.
Test if we can bring the device online automatically now.
http://b/32952319
"""
serialno = subprocess.check_output(self.device.adb_cmd + ['get-serialno']).strip()
# 1. Push a large file
file_path = 'tmp_large_file'
try:
fh = open(file_path, 'w')
fh.write('\0' * (100 * 1024 * 1024))
fh.close()
subproc = subprocess.Popen(self.device.adb_cmd + ['push', file_path, '/data/local/tmp'])
time.sleep(0.1)
# 2. Kill the adb server
subprocess.check_call(self.device.adb_cmd + ['kill-server'])
subproc.terminate()
finally:
try:
os.unlink(file_path)
except:
pass
# 3. See if the device still exist.
# Sleep to wait for the adb server exit.
time.sleep(0.5)
# 4. The device should be online
self.assertEqual(self._get_device_state(serialno), 'device')
def test_killed_when_pulling_a_large_file(self):
"""
While running adb pull with a large file, kill adb server.
Occasionally the device can't be connected. Because the device is trying to
send a message larger than what is expected by the adb server.
Test if we can bring the device online automatically now.
"""
serialno = subprocess.check_output(self.device.adb_cmd + ['get-serialno']).strip()
file_path = 'tmp_large_file'
try:
# 1. Create a large file on device.
self.device.shell(['dd', 'if=/dev/zero', 'of=/data/local/tmp/tmp_large_file',
'bs=1000000', 'count=100'])
# 2. Pull the large file on host.
subproc = subprocess.Popen(self.device.adb_cmd +
['pull','/data/local/tmp/tmp_large_file', file_path])
time.sleep(0.1)
# 3. Kill the adb server
subprocess.check_call(self.device.adb_cmd + ['kill-server'])
subproc.terminate()
finally:
try:
os.unlink(file_path)
except:
pass
# 4. See if the device still exist.
# Sleep to wait for the adb server exit.
time.sleep(0.5)
self.assertEqual(self._get_device_state(serialno), 'device')
def main(): def main():
random.seed(0) random.seed(0)
if len(adb.get_devices()) > 0: if len(adb.get_devices()) > 0:

View file

@ -33,6 +33,7 @@
#include <android-base/logging.h> #include <android-base/logging.h>
#include <android-base/parsenetaddress.h> #include <android-base/parsenetaddress.h>
#include <android-base/quick_exit.h>
#include <android-base/stringprintf.h> #include <android-base/stringprintf.h>
#include <android-base/strings.h> #include <android-base/strings.h>
@ -41,6 +42,7 @@
#include "adb_trace.h" #include "adb_trace.h"
#include "adb_utils.h" #include "adb_utils.h"
#include "diagnose_usb.h" #include "diagnose_usb.h"
#include "fdevent.h"
static void transport_unref(atransport *t); static void transport_unref(atransport *t);
@ -209,6 +211,11 @@ static void read_transport_thread(void* _t) {
put_apacket(p); put_apacket(p);
break; break;
} }
#if ADB_HOST
if (p->msg.command == 0) {
continue;
}
#endif
} }
D("%s: received remote packet, sending to transport", t->serial); D("%s: received remote packet, sending to transport", t->serial);
@ -271,7 +278,11 @@ static void write_transport_thread(void* _t) {
if (active) { if (active) {
D("%s: transport got packet, sending to remote", t->serial); D("%s: transport got packet, sending to remote", t->serial);
ATRACE_NAME("write_transport write_remote"); ATRACE_NAME("write_transport write_remote");
t->write_to_remote(p, t); if (t->Write(p) != 0) {
D("%s: remote write failed for transport", t->serial);
put_apacket(p);
break;
}
} else { } else {
D("%s: transport ignoring packet while offline", t->serial); D("%s: transport ignoring packet while offline", t->serial);
} }
@ -493,7 +504,7 @@ static void transport_registration_func(int _fd, unsigned ev, void* data) {
} }
/* don't create transport threads for inaccessible devices */ /* don't create transport threads for inaccessible devices */
if (t->connection_state != kCsNoPerm) { if (t->GetConnectionState() != kCsNoPerm) {
/* initial references are the two threads */ /* initial references are the two threads */
t->ref_count = 2; t->ref_count = 2;
@ -538,6 +549,15 @@ void init_transport_registration(void) {
transport_registration_func, 0); transport_registration_func, 0);
fdevent_set(&transport_registration_fde, FDE_READ); fdevent_set(&transport_registration_fde, FDE_READ);
#if ADB_HOST
android::base::at_quick_exit([]() {
// To avoid only writing part of a packet to a transport after exit, kick all transports.
std::lock_guard<std::mutex> lock(transport_lock);
for (auto t : transport_list) {
t->Kick();
}
});
#endif
} }
/* the fdevent select pump is single threaded */ /* the fdevent select pump is single threaded */
@ -600,7 +620,7 @@ static int qual_match(const char* to_test, const char* prefix, const char* qual,
} }
atransport* acquire_one_transport(TransportType type, const char* serial, bool* is_ambiguous, atransport* acquire_one_transport(TransportType type, const char* serial, bool* is_ambiguous,
std::string* error_out) { std::string* error_out, bool accept_any_state) {
atransport* result = nullptr; atransport* result = nullptr;
if (serial) { if (serial) {
@ -615,7 +635,7 @@ atransport* acquire_one_transport(TransportType type, const char* serial, bool*
std::unique_lock<std::mutex> lock(transport_lock); std::unique_lock<std::mutex> lock(transport_lock);
for (const auto& t : transport_list) { for (const auto& t : transport_list) {
if (t->connection_state == kCsNoPerm) { if (t->GetConnectionState() == kCsNoPerm) {
#if ADB_HOST #if ADB_HOST
*error_out = UsbNoPermissionsLongHelpText(); *error_out = UsbNoPermissionsLongHelpText();
#endif #endif
@ -664,7 +684,7 @@ atransport* acquire_one_transport(TransportType type, const char* serial, bool*
lock.unlock(); lock.unlock();
// Don't return unauthorized devices; the caller can't do anything with them. // Don't return unauthorized devices; the caller can't do anything with them.
if (result && result->connection_state == kCsUnauthorized) { if (result && result->GetConnectionState() == kCsUnauthorized && !accept_any_state) {
*error_out = "device unauthorized.\n"; *error_out = "device unauthorized.\n";
char* ADB_VENDOR_KEYS = getenv("ADB_VENDOR_KEYS"); char* ADB_VENDOR_KEYS = getenv("ADB_VENDOR_KEYS");
*error_out += "This adb server's $ADB_VENDOR_KEYS is "; *error_out += "This adb server's $ADB_VENDOR_KEYS is ";
@ -676,7 +696,7 @@ atransport* acquire_one_transport(TransportType type, const char* serial, bool*
} }
// Don't return offline devices; the caller can't do anything with them. // Don't return offline devices; the caller can't do anything with them.
if (result && result->connection_state == kCsOffline) { if (result && result->GetConnectionState() == kCsOffline && !accept_any_state) {
*error_out = "device offline"; *error_out = "device offline";
result = nullptr; result = nullptr;
} }
@ -688,16 +708,38 @@ atransport* acquire_one_transport(TransportType type, const char* serial, bool*
return result; return result;
} }
int atransport::Write(apacket* p) {
#if ADB_HOST
std::lock_guard<std::mutex> lock(write_msg_lock_);
#endif
return write_func_(p, this);
}
void atransport::Kick() { void atransport::Kick() {
if (!kicked_) { if (!kicked_) {
kicked_ = true; kicked_ = true;
CHECK(kick_func_ != nullptr); CHECK(kick_func_ != nullptr);
#if ADB_HOST
// On host, adb server should avoid writing part of a packet, so don't
// kick a transport whiling writing a packet.
std::lock_guard<std::mutex> lock(write_msg_lock_);
#endif
kick_func_(this); kick_func_(this);
} }
} }
ConnectionState atransport::GetConnectionState() const {
return connection_state_;
}
void atransport::SetConnectionState(ConnectionState state) {
check_main_thread();
connection_state_ = state;
}
const std::string atransport::connection_state_name() const { const std::string atransport::connection_state_name() const {
switch (connection_state) { ConnectionState state = GetConnectionState();
switch (state) {
case kCsOffline: case kCsOffline:
return "offline"; return "offline";
case kCsBootloader: case kCsBootloader:
@ -963,10 +1005,10 @@ void kick_all_tcp_devices() {
void register_usb_transport(usb_handle* usb, const char* serial, const char* devpath, void register_usb_transport(usb_handle* usb, const char* serial, const char* devpath,
unsigned writeable) { unsigned writeable) {
atransport* t = new atransport(); atransport* t = new atransport((writeable ? kCsOffline : kCsNoPerm));
D("transport: %p init'ing for usb_handle %p (sn='%s')", t, usb, serial ? serial : ""); D("transport: %p init'ing for usb_handle %p (sn='%s')", t, usb, serial ? serial : "");
init_usb_transport(t, usb, (writeable ? kCsOffline : kCsNoPerm)); init_usb_transport(t, usb);
if (serial) { if (serial) {
t->serial = strdup(serial); t->serial = strdup(serial);
} }
@ -987,12 +1029,13 @@ void register_usb_transport(usb_handle* usb, const char* serial, const char* dev
void unregister_usb_transport(usb_handle* usb) { void unregister_usb_transport(usb_handle* usb) {
std::lock_guard<std::mutex> lock(transport_lock); std::lock_guard<std::mutex> lock(transport_lock);
transport_list.remove_if( transport_list.remove_if(
[usb](atransport* t) { return t->usb == usb && t->connection_state == kCsNoPerm; }); [usb](atransport* t) { return t->usb == usb && t->GetConnectionState() == kCsNoPerm; });
} }
int check_header(apacket* p, atransport* t) { int check_header(apacket* p, atransport* t) {
if (p->msg.magic != (p->msg.command ^ 0xffffffff)) { if (p->msg.magic != (p->msg.command ^ 0xffffffff)) {
VLOG(RWX) << "check_header(): invalid magic"; VLOG(RWX) << "check_header(): invalid magic command = " << std::hex << p->msg.command
<< ", magic = " << p->msg.magic;
return -1; return -1;
} }
@ -1020,4 +1063,11 @@ std::shared_ptr<RSA> atransport::NextKey() {
keys_.pop_front(); keys_.pop_front();
return result; return result;
} }
bool atransport::SetSendConnectOnError() {
if (has_send_connect_on_error_) {
return false;
}
has_send_connect_on_error_ = true;
return true;
}
#endif #endif

View file

@ -19,10 +19,12 @@
#include <sys/types.h> #include <sys/types.h>
#include <atomic>
#include <deque> #include <deque>
#include <functional> #include <functional>
#include <list> #include <list>
#include <memory> #include <memory>
#include <mutex>
#include <string> #include <string>
#include <unordered_set> #include <unordered_set>
@ -57,31 +59,35 @@ public:
// class in one go is a very large change. Given how bad our testing is, // class in one go is a very large change. Given how bad our testing is,
// it's better to do this piece by piece. // it's better to do this piece by piece.
atransport() { atransport(ConnectionState state = kCsOffline) : connection_state_(state) {
transport_fde = {}; transport_fde = {};
protocol_version = A_VERSION; protocol_version = A_VERSION;
max_payload = MAX_PAYLOAD; max_payload = MAX_PAYLOAD;
} }
virtual ~atransport() {} virtual ~atransport() {}
int (*read_from_remote)(apacket* p, atransport* t) = nullptr; int (*read_from_remote)(apacket* p, atransport* t) = nullptr;
int (*write_to_remote)(apacket* p, atransport* t) = nullptr;
void (*close)(atransport* t) = nullptr; void (*close)(atransport* t) = nullptr;
void SetWriteFunction(int (*write_func)(apacket*, atransport*)) { write_func_ = write_func; }
void SetKickFunction(void (*kick_func)(atransport*)) { void SetKickFunction(void (*kick_func)(atransport*)) {
kick_func_ = kick_func; kick_func_ = kick_func;
} }
bool IsKicked() { bool IsKicked() {
return kicked_; return kicked_;
} }
int Write(apacket* p);
void Kick(); void Kick();
// ConnectionState can be read by all threads, but can only be written in the main thread.
ConnectionState GetConnectionState() const;
void SetConnectionState(ConnectionState state);
int fd = -1; int fd = -1;
int transport_socket = -1; int transport_socket = -1;
fdevent transport_fde; fdevent transport_fde;
size_t ref_count = 0; size_t ref_count = 0;
uint32_t sync_token = 0; uint32_t sync_token = 0;
ConnectionState connection_state = kCsOffline;
bool online = false; bool online = false;
TransportType type = kTransportAny; TransportType type = kTransportAny;
@ -114,11 +120,13 @@ public:
#if ADB_HOST #if ADB_HOST
std::shared_ptr<RSA> NextKey(); std::shared_ptr<RSA> NextKey();
bool SetSendConnectOnError();
#endif #endif
char token[TOKEN_SIZE] = {}; char token[TOKEN_SIZE] = {};
size_t failed_auth_attempts = 0; size_t failed_auth_attempts = 0;
const std::string serial_name() const { return serial ? serial : "<unknown>"; }
const std::string connection_state_name() const; const std::string connection_state_name() const;
void update_version(int version, size_t payload); void update_version(int version, size_t payload);
@ -157,6 +165,7 @@ private:
int local_port_for_emulator_ = -1; int local_port_for_emulator_ = -1;
bool kicked_ = false; bool kicked_ = false;
void (*kick_func_)(atransport*) = nullptr; void (*kick_func_)(atransport*) = nullptr;
int (*write_func_)(apacket*, atransport*) = nullptr;
// A set of features transmitted in the banner with the initial connection. // A set of features transmitted in the banner with the initial connection.
// This is stored in the banner as 'features=feature0,feature1,etc'. // This is stored in the banner as 'features=feature0,feature1,etc'.
@ -167,8 +176,11 @@ private:
// A list of adisconnect callbacks called when the transport is kicked. // A list of adisconnect callbacks called when the transport is kicked.
std::list<adisconnect*> disconnects_; std::list<adisconnect*> disconnects_;
std::atomic<ConnectionState> connection_state_;
#if ADB_HOST #if ADB_HOST
std::deque<std::shared_ptr<RSA>> keys_; std::deque<std::shared_ptr<RSA>> keys_;
std::mutex write_msg_lock_;
bool has_send_connect_on_error_ = false;
#endif #endif
DISALLOW_COPY_AND_ASSIGN(atransport); DISALLOW_COPY_AND_ASSIGN(atransport);
@ -181,8 +193,8 @@ private:
* is set to true and nullptr returned. * is set to true and nullptr returned.
* If no suitable transport is found, error is set and nullptr returned. * If no suitable transport is found, error is set and nullptr returned.
*/ */
atransport* acquire_one_transport(TransportType type, const char* serial, atransport* acquire_one_transport(TransportType type, const char* serial, bool* is_ambiguous,
bool* is_ambiguous, std::string* error_out); std::string* error_out, bool accept_any_state = false);
void kick_transport(atransport* t); void kick_transport(atransport* t);
void update_transports(void); void update_transports(void);

View file

@ -515,12 +515,11 @@ int init_socket_transport(atransport *t, int s, int adb_port, int local)
int fail = 0; int fail = 0;
t->SetKickFunction(remote_kick); t->SetKickFunction(remote_kick);
t->SetWriteFunction(remote_write);
t->close = remote_close; t->close = remote_close;
t->read_from_remote = remote_read; t->read_from_remote = remote_read;
t->write_to_remote = remote_write;
t->sfd = s; t->sfd = s;
t->sync_token = 1; t->sync_token = 1;
t->connection_state = kCsOffline;
t->type = kTransportLocal; t->type = kTransportLocal;
#if ADB_HOST #if ADB_HOST

View file

@ -94,12 +94,13 @@ TEST(transport, SetFeatures) {
} }
TEST(transport, parse_banner_no_features) { TEST(transport, parse_banner_no_features) {
set_main_thread();
atransport t; atransport t;
parse_banner("host::", &t); parse_banner("host::", &t);
ASSERT_EQ(0U, t.features().size()); ASSERT_EQ(0U, t.features().size());
ASSERT_EQ(kCsHost, t.connection_state); ASSERT_EQ(kCsHost, t.GetConnectionState());
ASSERT_EQ(nullptr, t.product); ASSERT_EQ(nullptr, t.product);
ASSERT_EQ(nullptr, t.model); ASSERT_EQ(nullptr, t.model);
@ -113,7 +114,7 @@ TEST(transport, parse_banner_product_features) {
"host::ro.product.name=foo;ro.product.model=bar;ro.product.device=baz;"; "host::ro.product.name=foo;ro.product.model=bar;ro.product.device=baz;";
parse_banner(banner, &t); parse_banner(banner, &t);
ASSERT_EQ(kCsHost, t.connection_state); ASSERT_EQ(kCsHost, t.GetConnectionState());
ASSERT_EQ(0U, t.features().size()); ASSERT_EQ(0U, t.features().size());
@ -130,7 +131,7 @@ TEST(transport, parse_banner_features) {
"features=woodly,doodly"; "features=woodly,doodly";
parse_banner(banner, &t); parse_banner(banner, &t);
ASSERT_EQ(kCsHost, t.connection_state); ASSERT_EQ(kCsHost, t.GetConnectionState());
ASSERT_EQ(2U, t.features().size()); ASSERT_EQ(2U, t.features().size());
ASSERT_TRUE(t.has_feature("woodly")); ASSERT_TRUE(t.has_feature("woodly"));

View file

@ -25,9 +25,115 @@
#include "adb.h" #include "adb.h"
#if ADB_HOST
static constexpr size_t MAX_USB_BULK_PACKET_SIZE = 1024u;
// Call usb_read using a buffer having a multiple of MAX_USB_BULK_PACKET_SIZE bytes
// to avoid overflow. See http://libusb.sourceforge.net/api-1.0/packetoverflow.html.
static int UsbReadMessage(usb_handle* h, amessage* msg) {
D("UsbReadMessage");
char buffer[MAX_USB_BULK_PACKET_SIZE];
int n = usb_read(h, buffer, sizeof(buffer));
if (n == sizeof(*msg)) {
memcpy(msg, buffer, sizeof(*msg));
}
return n;
}
// Call usb_read using a buffer having a multiple of MAX_USB_BULK_PACKET_SIZE bytes
// to avoid overflow. See http://libusb.sourceforge.net/api-1.0/packetoverflow.html.
static int UsbReadPayload(usb_handle* h, apacket* p) {
D("UsbReadPayload");
size_t need_size = p->msg.data_length;
size_t data_pos = 0u;
while (need_size > 0u) {
int n = 0;
if (data_pos + MAX_USB_BULK_PACKET_SIZE <= sizeof(p->data)) {
// Read directly to p->data.
size_t rem_size = need_size % MAX_USB_BULK_PACKET_SIZE;
size_t direct_read_size = need_size - rem_size;
if (rem_size &&
data_pos + direct_read_size + MAX_USB_BULK_PACKET_SIZE <= sizeof(p->data)) {
direct_read_size += MAX_USB_BULK_PACKET_SIZE;
}
n = usb_read(h, &p->data[data_pos], direct_read_size);
if (n < 0) {
D("usb_read(size %zu) failed", direct_read_size);
return n;
}
} else {
// Read indirectly using a buffer.
char buffer[MAX_USB_BULK_PACKET_SIZE];
n = usb_read(h, buffer, sizeof(buffer));
if (n < 0) {
D("usb_read(size %zu) failed", sizeof(buffer));
return -1;
}
size_t copy_size = std::min(static_cast<size_t>(n), need_size);
D("usb read %d bytes, need %zu bytes, copy %zu bytes", n, need_size, copy_size);
memcpy(&p->data[data_pos], buffer, copy_size);
}
data_pos += n;
need_size -= std::min(static_cast<size_t>(n), need_size);
}
return static_cast<int>(data_pos);
}
static int remote_read(apacket* p, atransport* t) {
int n = UsbReadMessage(t->usb, &p->msg);
if (n < 0) {
D("remote usb: read terminated (message)");
return -1;
}
if (static_cast<size_t>(n) != sizeof(p->msg) || check_header(p, t)) {
D("remote usb: check_header failed, skip it");
goto err_msg;
}
if (t->GetConnectionState() == kCsOffline) {
// If we read a wrong msg header declaring a large message payload, don't read its payload.
// Otherwise we may miss true messages from the device.
if (p->msg.command != A_CNXN && p->msg.command != A_AUTH) {
goto err_msg;
}
}
if (p->msg.data_length) {
n = UsbReadPayload(t->usb, p);
if (n < 0) {
D("remote usb: terminated (data)");
return -1;
}
if (static_cast<uint32_t>(n) != p->msg.data_length) {
D("remote usb: read payload failed (need %u bytes, give %d bytes), skip it",
p->msg.data_length, n);
goto err_msg;
}
}
if (check_data(p)) {
D("remote usb: check_data failed, skip it");
goto err_msg;
}
return 0;
err_msg:
p->msg.command = 0;
if (t->GetConnectionState() == kCsOffline) {
// If the data toggle of ep_out on device and ep_in on host are not the same, we may receive
// an error message. In this case, resend one A_CNXN message to connect the device.
if (t->SetSendConnectOnError()) {
SendConnectOnHost(t);
}
}
return 0;
}
#else
// On Android devices, we rely on the kernel to provide buffered read.
// So we can recover automatically from EOVERFLOW.
static int remote_read(apacket *p, atransport *t) static int remote_read(apacket *p, atransport *t)
{ {
if(usb_read(t->usb, &p->msg, sizeof(amessage))){ if (usb_read(t->usb, &p->msg, sizeof(amessage))) {
D("remote usb: read terminated (message)"); D("remote usb: read terminated (message)");
return -1; return -1;
} }
@ -38,7 +144,7 @@ static int remote_read(apacket *p, atransport *t)
} }
if(p->msg.data_length) { if(p->msg.data_length) {
if(usb_read(t->usb, p->data, p->msg.data_length)){ if (usb_read(t->usb, p->data, p->msg.data_length)) {
D("remote usb: terminated (data)"); D("remote usb: terminated (data)");
return -1; return -1;
} }
@ -51,17 +157,18 @@ static int remote_read(apacket *p, atransport *t)
return 0; return 0;
} }
#endif
static int remote_write(apacket *p, atransport *t) static int remote_write(apacket *p, atransport *t)
{ {
unsigned size = p->msg.data_length; unsigned size = p->msg.data_length;
if(usb_write(t->usb, &p->msg, sizeof(amessage))) { if (usb_write(t->usb, &p->msg, sizeof(amessage))) {
D("remote usb: 1 - write terminated"); D("remote usb: 1 - write terminated");
return -1; return -1;
} }
if(p->msg.data_length == 0) return 0; if(p->msg.data_length == 0) return 0;
if(usb_write(t->usb, &p->data, size)) { if (usb_write(t->usb, &p->data, size)) {
D("remote usb: 2 - write terminated"); D("remote usb: 2 - write terminated");
return -1; return -1;
} }
@ -75,20 +182,17 @@ static void remote_close(atransport *t)
t->usb = 0; t->usb = 0;
} }
static void remote_kick(atransport *t) static void remote_kick(atransport* t) {
{
usb_kick(t->usb); usb_kick(t->usb);
} }
void init_usb_transport(atransport *t, usb_handle *h, ConnectionState state) void init_usb_transport(atransport* t, usb_handle* h) {
{
D("transport: usb"); D("transport: usb");
t->close = remote_close; t->close = remote_close;
t->SetKickFunction(remote_kick); t->SetKickFunction(remote_kick);
t->SetWriteFunction(remote_write);
t->read_from_remote = remote_read; t->read_from_remote = remote_read;
t->write_to_remote = remote_write;
t->sync_token = 1; t->sync_token = 1;
t->connection_state = state;
t->type = kTransportUsb; t->type = kTransportUsb;
t->usb = h; t->usb = h;
} }