diff --git a/adb/adb.cpp b/adb/adb.cpp index 90ee7b4c9..65fa2e795 100644 --- a/adb/adb.cpp +++ b/adb/adb.cpp @@ -136,8 +136,16 @@ void handle_online(atransport *t) void handle_offline(atransport *t) { - D("adb: offline"); - //Close the associated usb + if (t->GetConnectionState() == kCsOffline) { + LOG(INFO) << t->serial_name() << ": already offline"; + return; + } + + LOG(INFO) << t->serial_name() << ": offline"; + + t->SetConnectionState(kCsOffline); + + // Close the associated usb t->online = 0; // This is necessary to avoid a race condition that occurred when a transport closes @@ -318,10 +326,7 @@ void parse_banner(const std::string& banner, atransport* t) { } static void handle_new_connection(atransport* t, apacket* p) { - if (t->GetConnectionState() != kCsOffline) { - t->SetConnectionState(kCsOffline); - handle_offline(t); - } + handle_offline(t); t->update_version(p->msg.arg0, p->msg.arg1); parse_banner(p->payload, t); @@ -350,19 +355,6 @@ void handle_packet(apacket *p, atransport *t) CHECK_EQ(p->payload.size(), p->msg.data_length); switch(p->msg.command){ - case A_SYNC: - if (p->msg.arg0){ - send_packet(p, t); -#if ADB_HOST - send_connect(t); -#endif - } else { - t->SetConnectionState(kCsOffline); - handle_offline(t); - send_packet(p, t); - } - return; - case A_CNXN: // CONNECT(version, maxdata, "system-id-string") handle_new_connection(t, p); break; diff --git a/adb/protocol.txt b/adb/protocol.txt index 55ea87f0e..f4523c4be 100644 --- a/adb/protocol.txt +++ b/adb/protocol.txt @@ -183,9 +183,11 @@ requirement, since they will be ignored. Command constant: A_SYNC -The SYNC message is used by the io pump to make sure that stale +*** obsolete, no longer used *** + +The SYNC message was used by the io pump to make sure that stale outbound messages are discarded when the connection to the remote side -is broken. It is only used internally to the bridge and never valid +is broken. It was only used internally to the bridge and never valid to send across the wire. * when the connection to the remote side goes offline, the io pump diff --git a/adb/transport.cpp b/adb/transport.cpp index e7a94d517..1ccff921c 100644 --- a/adb/transport.cpp +++ b/adb/transport.cpp @@ -28,6 +28,7 @@ #include #include +#include #include #include #include @@ -66,6 +67,82 @@ TransportId NextTransportId() { return next++; } +BlockingConnectionAdapter::BlockingConnectionAdapter(std::unique_ptr connection) + : underlying_(std::move(connection)) {} + +BlockingConnectionAdapter::~BlockingConnectionAdapter() { + LOG(INFO) << "BlockingConnectionAdapter(" << this->transport_name_ << "): destructing"; + Stop(); +} + +void BlockingConnectionAdapter::Start() { + read_thread_ = std::thread([this]() { + LOG(INFO) << this->transport_name_ << ": read thread spawning"; + while (true) { + std::unique_ptr packet(new apacket()); + if (!underlying_->Read(packet.get())) { + PLOG(INFO) << this->transport_name_ << ": read failed"; + break; + } + read_callback_(this, std::move(packet)); + } + std::call_once(this->error_flag_, [this]() { this->error_callback_(this, "read failed"); }); + }); + + write_thread_ = std::thread([this]() { + LOG(INFO) << this->transport_name_ << ": write thread spawning"; + while (true) { + std::unique_lock lock(mutex_); + cv_.wait(lock, [this]() { return this->stopped_ || !this->write_queue_.empty(); }); + + if (this->stopped_) { + return; + } + + std::unique_ptr packet = std::move(this->write_queue_.front()); + this->write_queue_.pop_front(); + lock.unlock(); + + if (!this->underlying_->Write(packet.get())) { + break; + } + } + std::call_once(this->error_flag_, [this]() { this->error_callback_(this, "write failed"); }); + }); +} + +void BlockingConnectionAdapter::Stop() { + std::unique_lock lock(mutex_); + if (stopped_) { + LOG(INFO) << "BlockingConnectionAdapter(" << this->transport_name_ << "): already stopped"; + return; + } + + stopped_ = true; + lock.unlock(); + + LOG(INFO) << "BlockingConnectionAdapter(" << this->transport_name_ << "): stopping"; + + this->underlying_->Close(); + + this->cv_.notify_one(); + read_thread_.join(); + write_thread_.join(); + + LOG(INFO) << "BlockingConnectionAdapter(" << this->transport_name_ << "): stopped"; + std::call_once(this->error_flag_, [this]() { this->error_callback_(this, "requested stop"); }); +} + +bool BlockingConnectionAdapter::Write(std::unique_ptr packet) { + { + std::unique_lock lock(this->mutex_); + write_queue_.emplace_back(std::move(packet)); + } + + cv_.notify_one(); + return true; +} + bool FdConnection::Read(apacket* packet) { if (!ReadFdExactly(fd_.get(), &packet->msg, sizeof(amessage))) { D("remote local: read terminated (message)"); @@ -144,67 +221,6 @@ static std::string dump_packet(const char* name, const char* func, apacket* p) { return result; } -static int read_packet(int fd, const char* name, apacket** ppacket) { - ATRACE_NAME("read_packet"); - char buff[8]; - if (!name) { - snprintf(buff, sizeof buff, "fd=%d", fd); - name = buff; - } - char* p = reinterpret_cast(ppacket); /* really read a packet address */ - int len = sizeof(apacket*); - while (len > 0) { - int r = adb_read(fd, p, len); - if (r > 0) { - len -= r; - p += r; - } else { - D("%s: read_packet (fd=%d), error ret=%d: %s", name, fd, r, strerror(errno)); - return -1; - } - } - - VLOG(TRANSPORT) << dump_packet(name, "from remote", *ppacket); - return 0; -} - -static int write_packet(int fd, const char* name, apacket** ppacket) { - ATRACE_NAME("write_packet"); - char buff[8]; - if (!name) { - snprintf(buff, sizeof buff, "fd=%d", fd); - name = buff; - } - VLOG(TRANSPORT) << dump_packet(name, "to remote", *ppacket); - char* p = reinterpret_cast(ppacket); /* we really write the packet address */ - int len = sizeof(apacket*); - while (len > 0) { - int r = adb_write(fd, p, len); - if (r > 0) { - len -= r; - p += r; - } else { - D("%s: write_packet (fd=%d) error ret=%d: %s", name, fd, r, strerror(errno)); - return -1; - } - } - return 0; -} - -static void transport_socket_events(int fd, unsigned events, void* _t) { - atransport* t = reinterpret_cast(_t); - D("transport_socket_events(fd=%d, events=%04x,...)", fd, events); - if (events & FDE_READ) { - apacket* p = 0; - if (read_packet(fd, t->serial, &p)) { - D("%s: failed to read packet from transport socket on fd %d", t->serial, fd); - return; - } - - handle_packet(p, (atransport*)_t); - } -} - void send_packet(apacket* p, atransport* t) { p->msg.magic = p->msg.command ^ 0xffffffff; // compute a checksum for connection/auth packets for compatibility reasons @@ -214,162 +230,18 @@ void send_packet(apacket* p, atransport* t) { p->msg.data_check = calculate_apacket_checksum(p); } - print_packet("send", p); + VLOG(TRANSPORT) << dump_packet(t->serial, "to remote", p); if (t == NULL) { fatal("Transport is null"); } - if (write_packet(t->transport_socket, t->serial, &p)) { - fatal_errno("cannot enqueue packet on transport socket"); + if (t->Write(p) != 0) { + D("%s: failed to enqueue packet, closing transport", t->serial); + t->Kick(); } } -// The transport is opened by transport_register_func before -// the read_transport and write_transport threads are started. -// -// The read_transport thread issues a SYNC(1, token) message to let -// the write_transport thread know to start things up. In the event -// of transport IO failure, the read_transport thread will post a -// SYNC(0,0) message to ensure shutdown. -// -// The transport will not actually be closed until both threads exit, but the threads -// will kick the transport on their way out to disconnect the underlying device. -// -// read_transport thread reads data from a transport (representing a usb/tcp connection), -// and makes the main thread call handle_packet(). -static void read_transport_thread(void* _t) { - atransport* t = reinterpret_cast(_t); - apacket* p; - - adb_thread_setname( - android::base::StringPrintf("<-%s", (t->serial != nullptr ? t->serial : "transport"))); - D("%s: starting read_transport thread on fd %d, SYNC online (%d)", t->serial, t->fd, - t->sync_token + 1); - p = get_apacket(); - p->msg.command = A_SYNC; - p->msg.arg0 = 1; - p->msg.arg1 = ++(t->sync_token); - p->msg.magic = A_SYNC ^ 0xffffffff; - D("sending SYNC packet (len = %u, payload.size() = %zu)", p->msg.data_length, p->payload.size()); - if (write_packet(t->fd, t->serial, &p)) { - put_apacket(p); - D("%s: failed to write SYNC packet", t->serial); - goto oops; - } - - D("%s: data pump started", t->serial); - for (;;) { - ATRACE_NAME("read_transport loop"); - p = get_apacket(); - - { - ATRACE_NAME("read_transport read_remote"); - if (!t->connection->Read(p)) { - D("%s: remote read failed for transport", t->serial); - put_apacket(p); - break; - } - - if (!check_header(p, t)) { - D("%s: remote read: bad header", t->serial); - put_apacket(p); - break; - } - -#if ADB_HOST - if (p->msg.command == 0) { - put_apacket(p); - continue; - } -#endif - } - - D("%s: received remote packet, sending to transport", t->serial); - if (write_packet(t->fd, t->serial, &p)) { - put_apacket(p); - D("%s: failed to write apacket to transport", t->serial); - goto oops; - } - } - - D("%s: SYNC offline for transport", t->serial); - p = get_apacket(); - p->msg.command = A_SYNC; - p->msg.arg0 = 0; - p->msg.arg1 = 0; - p->msg.magic = A_SYNC ^ 0xffffffff; - if (write_packet(t->fd, t->serial, &p)) { - put_apacket(p); - D("%s: failed to write SYNC apacket to transport", t->serial); - } - -oops: - D("%s: read_transport thread is exiting", t->serial); - kick_transport(t); - transport_unref(t); -} - -// write_transport thread gets packets sent by the main thread (through send_packet()), -// and writes to a transport (representing a usb/tcp connection). -static void write_transport_thread(void* _t) { - atransport* t = reinterpret_cast(_t); - apacket* p; - int active = 0; - - adb_thread_setname( - android::base::StringPrintf("->%s", (t->serial != nullptr ? t->serial : "transport"))); - D("%s: starting write_transport thread, reading from fd %d", t->serial, t->fd); - - for (;;) { - ATRACE_NAME("write_transport loop"); - if (read_packet(t->fd, t->serial, &p)) { - D("%s: failed to read apacket from transport on fd %d", t->serial, t->fd); - break; - } - - if (p->msg.command == A_SYNC) { - if (p->msg.arg0 == 0) { - D("%s: transport SYNC offline", t->serial); - put_apacket(p); - break; - } else { - if (p->msg.arg1 == t->sync_token) { - D("%s: transport SYNC online", t->serial); - active = 1; - } else { - D("%s: transport ignoring SYNC %d != %d", t->serial, p->msg.arg1, t->sync_token); - } - } - } else { - if (active) { - D("%s: transport got packet, sending to remote", t->serial); - ATRACE_NAME("write_transport write_remote"); - - // Allow sending the payload's implicit null terminator. - if (p->msg.data_length != p->payload.size()) { - LOG(FATAL) << "packet data length doesn't match payload: msg.data_length = " - << p->msg.data_length << ", payload.size() = " << p->payload.size(); - } - - if (t->Write(p) != 0) { - D("%s: remote write failed for transport", t->serial); - put_apacket(p); - break; - } - } else { - D("%s: transport ignoring packet while offline", t->serial); - } - } - - put_apacket(p); - } - - D("%s: write_transport thread is exiting, fd %d", t->serial, t->fd); - kick_transport(t); - transport_unref(t); -} - void kick_transport(atransport* t) { std::lock_guard lock(transport_lock); // As kick_transport() can be called from threads without guarantee that t is valid, @@ -560,9 +432,10 @@ static int transport_write_action(int fd, struct tmsg* m) { return 0; } -static void transport_registration_func(int _fd, unsigned ev, void* data) { +static void remove_transport(atransport*); + +static void transport_registration_func(int _fd, unsigned ev, void*) { tmsg m; - int s[2]; atransport* t; if (!(ev & FDE_READ)) { @@ -576,13 +449,7 @@ static void transport_registration_func(int _fd, unsigned ev, void* data) { t = m.transport; if (m.action == 0) { - D("transport: %s removing and free'ing %d", t->serial, t->transport_socket); - - /* IMPORTANT: the remove closes one half of the - ** socket pair. The close closes the other half. - */ - fdevent_remove(&(t->transport_fde)); - adb_close(t->fd); + D("transport: %s deleting", t->serial); { std::lock_guard lock(transport_lock); @@ -604,23 +471,33 @@ static void transport_registration_func(int _fd, unsigned ev, void* data) { /* don't create transport threads for inaccessible devices */ if (t->GetConnectionState() != kCsNoPerm) { /* initial references are the two threads */ - t->ref_count = 2; + t->ref_count = 1; + t->connection->SetTransportName(t->serial_name()); + t->connection->SetReadCallback([t](Connection*, std::unique_ptr p) { + if (!check_header(p.get(), t)) { + D("%s: remote read: bad header", t->serial); + return false; + } - if (adb_socketpair(s)) { - fatal_errno("cannot open transport socketpair"); - } + VLOG(TRANSPORT) << dump_packet(t->serial, "from remote", p.get()); + apacket* packet = p.release(); - D("transport: %s socketpair: (%d,%d) starting", t->serial, s[0], s[1]); + // TODO: Does this need to run on the main thread? + fdevent_run_on_main_thread([packet, t]() { handle_packet(packet, t); }); + return true; + }); + t->connection->SetErrorCallback([t](Connection*, const std::string& error) { + D("%s: connection terminated: %s", t->serial, error.c_str()); + fdevent_run_on_main_thread([t]() { + handle_offline(t); + transport_unref(t); + }); + }); - t->transport_socket = s[0]; - t->fd = s[1]; - - fdevent_install(&(t->transport_fde), t->transport_socket, transport_socket_events, t); - - fdevent_set(&(t->transport_fde), FDE_READ); - - std::thread(write_transport_thread, t).detach(); - std::thread(read_transport_thread, t).detach(); + t->connection->Start(); +#if ADB_HOST + send_connect(t); +#endif } { @@ -686,7 +563,7 @@ static void transport_unref(atransport* t) { t->ref_count--; if (t->ref_count == 0) { D("transport: %s unref (kicking and closing)", t->serial); - t->connection->Close(); + t->connection->Stop(); remove_transport(t); } else { D("transport: %s unref (count=%zu)", t->serial, t->ref_count); @@ -812,14 +689,14 @@ atransport* acquire_one_transport(TransportType type, const char* serial, Transp } int atransport::Write(apacket* p) { - return this->connection->Write(p) ? 0 : -1; + return this->connection->Write(std::unique_ptr(p)) ? 0 : -1; } void atransport::Kick() { if (!kicked_) { D("kicking transport %s", this->serial); kicked_ = true; - this->connection->Close(); + this->connection->Stop(); } } diff --git a/adb/transport.h b/adb/transport.h index 9700f445b..a492f008d 100644 --- a/adb/transport.h +++ b/adb/transport.h @@ -20,12 +20,14 @@ #include #include +#include #include #include #include #include #include #include +#include #include #include @@ -57,15 +59,47 @@ extern const char* const kFeaturePushSync; TransportId NextTransportId(); -// Abstraction for a blocking packet transport. +// Abstraction for a non-blocking packet transport. struct Connection { Connection() = default; - Connection(const Connection& copy) = delete; - Connection(Connection&& move) = delete; - - // Destroy a Connection. Formerly known as 'Close' in atransport. virtual ~Connection() = default; + void SetTransportName(std::string transport_name) { + transport_name_ = std::move(transport_name); + } + + using ReadCallback = std::function)>; + void SetReadCallback(ReadCallback callback) { + CHECK(!read_callback_); + read_callback_ = callback; + } + + // Called after the Connection has terminated, either by an error or because Stop was called. + using ErrorCallback = std::function; + void SetErrorCallback(ErrorCallback callback) { + CHECK(!error_callback_); + error_callback_ = callback; + } + + virtual bool Write(std::unique_ptr packet) = 0; + + virtual void Start() = 0; + virtual void Stop() = 0; + + std::string transport_name_; + ReadCallback read_callback_; + ErrorCallback error_callback_; +}; + +// Abstraction for a blocking packet transport. +struct BlockingConnection { + BlockingConnection() = default; + BlockingConnection(const BlockingConnection& copy) = delete; + BlockingConnection(BlockingConnection&& move) = delete; + + // Destroy a BlockingConnection. Formerly known as 'Close' in atransport. + virtual ~BlockingConnection() = default; + // Read/Write a packet. These functions are concurrently called from a transport's reader/writer // threads. virtual bool Read(apacket* packet) = 0; @@ -77,7 +111,30 @@ struct Connection { virtual void Close() = 0; }; -struct FdConnection : public Connection { +struct BlockingConnectionAdapter : public Connection { + explicit BlockingConnectionAdapter(std::unique_ptr connection); + + virtual ~BlockingConnectionAdapter(); + + virtual bool Write(std::unique_ptr packet) override final; + + virtual void Start() override final; + virtual void Stop() override final; + + bool stopped_ = false; + + std::unique_ptr underlying_; + std::thread read_thread_; + std::thread write_thread_; + + std::deque> write_queue_; + std::mutex mutex_; + std::condition_variable cv_; + + std::once_flag error_flag_; +}; + +struct FdConnection : public BlockingConnection { explicit FdConnection(unique_fd fd) : fd_(std::move(fd)) {} bool Read(apacket* packet) override final; @@ -89,7 +146,7 @@ struct FdConnection : public Connection { unique_fd fd_; }; -struct UsbConnection : public Connection { +struct UsbConnection : public BlockingConnection { explicit UsbConnection(usb_handle* handle) : handle_(handle) {} ~UsbConnection(); @@ -110,7 +167,6 @@ class atransport { atransport(ConnectionState state = kCsOffline) : id(NextTransportId()), connection_state_(state) { - transport_fde = {}; // Initialize protocol to min version for compatibility with older versions. // Version will be updated post-connect. protocol_version = A_VERSION_MIN; @@ -126,11 +182,7 @@ class atransport { void SetConnectionState(ConnectionState state); const TransportId id; - int fd = -1; - int transport_socket = -1; - fdevent transport_fde; size_t ref_count = 0; - uint32_t sync_token = 0; bool online = false; TransportType type = kTransportAny; diff --git a/adb/transport_local.cpp b/adb/transport_local.cpp index 560a0312b..ff395dc7e 100644 --- a/adb/transport_local.cpp +++ b/adb/transport_local.cpp @@ -445,13 +445,14 @@ int init_socket_transport(atransport* t, int s, int adb_port, int local) { int fail = 0; unique_fd fd(s); - t->sync_token = 1; t->type = kTransportLocal; #if ADB_HOST // Emulator connection. if (local) { - t->connection.reset(new EmulatorConnection(std::move(fd), adb_port)); + std::unique_ptr emulator_connection( + new EmulatorConnection(std::move(fd), adb_port)); + t->connection.reset(new BlockingConnectionAdapter(std::move(emulator_connection))); std::lock_guard lock(local_transports_lock); atransport* existing_transport = find_emulator_transport_by_adb_port_locked(adb_port); if (existing_transport != NULL) { @@ -470,6 +471,7 @@ int init_socket_transport(atransport* t, int s, int adb_port, int local) { #endif // Regular tcp connection. - t->connection.reset(new FdConnection(std::move(fd))); + std::unique_ptr fd_connection(new FdConnection(std::move(fd))); + t->connection.reset(new BlockingConnectionAdapter(std::move(fd_connection))); return fail; } diff --git a/adb/transport_usb.cpp b/adb/transport_usb.cpp index d7565f63d..33e00a1f4 100644 --- a/adb/transport_usb.cpp +++ b/adb/transport_usb.cpp @@ -174,8 +174,8 @@ void UsbConnection::Close() { void init_usb_transport(atransport* t, usb_handle* h) { D("transport: usb"); - t->connection.reset(new UsbConnection(h)); - t->sync_token = 1; + std::unique_ptr connection(new UsbConnection(h)); + t->connection.reset(new BlockingConnectionAdapter(std::move(connection))); t->type = kTransportUsb; }