diff --git a/adb/daemon/auth.cpp b/adb/daemon/auth.cpp index 2e84ce6b9..ec4ab4ad2 100644 --- a/adb/daemon/auth.cpp +++ b/adb/daemon/auth.cpp @@ -16,36 +16,72 @@ #define TRACE_TAG AUTH -#include "adb.h" -#include "adb_auth.h" -#include "adb_io.h" -#include "fdevent/fdevent.h" #include "sysdeps.h" -#include "transport.h" #include #include #include -#include #include +#include +#include #include #include #include +#include #include #include #include #include #include +#include "adb.h" +#include "adb_auth.h" +#include "adb_io.h" +#include "fdevent/fdevent.h" +#include "transport.h" +#include "types.h" + static AdbdAuthContext* auth_ctx; static void adb_disconnected(void* unused, atransport* t); static struct adisconnect adb_disconnect = {adb_disconnected, nullptr}; +static android::base::NoDestructor>> transports; +static uint32_t transport_auth_id = 0; + bool auth_required = true; +static void* transport_to_callback_arg(atransport* transport) { + uint32_t id = transport_auth_id++; + (*transports)[id] = transport->weak(); + return reinterpret_cast(id); +} + +static atransport* transport_from_callback_arg(void* id) { + uint64_t id_u64 = reinterpret_cast(id); + if (id_u64 > std::numeric_limits::max()) { + LOG(FATAL) << "transport_from_callback_arg called on out of range value: " << id_u64; + } + + uint32_t id_u32 = static_cast(id_u64); + auto it = transports->find(id_u32); + if (it == transports->end()) { + LOG(ERROR) << "transport_from_callback_arg failed to find transport for id " << id_u32; + return nullptr; + } + + atransport* t = it->second.get(); + if (!t) { + LOG(WARNING) << "transport_from_callback_arg found already destructed transport"; + return nullptr; + } + + transports->erase(it); + return t; +} + static void IteratePublicKeys(std::function f) { adbd_auth_get_public_keys( auth_ctx, @@ -111,9 +147,16 @@ void adbd_cloexec_auth_socket() { static void adbd_auth_key_authorized(void* arg, uint64_t id) { LOG(INFO) << "adb client authorized"; - auto* transport = static_cast(arg); - transport->auth_id = id; - adbd_auth_verified(transport); + fdevent_run_on_main_thread([=]() { + LOG(INFO) << "arg = " << reinterpret_cast(arg); + auto* transport = transport_from_callback_arg(arg); + if (!transport) { + LOG(ERROR) << "authorization received for deleted transport, ignoring"; + return; + } + transport->auth_id = id; + adbd_auth_verified(transport); + }); } void adbd_auth_init(void) { @@ -158,7 +201,8 @@ static void adb_disconnected(void* unused, atransport* t) { void adbd_auth_confirm_key(atransport* t) { LOG(INFO) << "prompting user to authorize key"; t->AddDisconnect(&adb_disconnect); - adbd_auth_prompt_user(auth_ctx, t->auth_key.data(), t->auth_key.size(), t); + adbd_auth_prompt_user(auth_ctx, t->auth_key.data(), t->auth_key.size(), + transport_to_callback_arg(t)); } void adbd_notify_framework_connected_key(atransport* t) { diff --git a/adb/transport.h b/adb/transport.h index 569e8bbdf..5a750eea1 100644 --- a/adb/transport.h +++ b/adb/transport.h @@ -38,6 +38,7 @@ #include "adb.h" #include "adb_unique_fd.h" +#include "types.h" #include "usb.h" typedef std::unordered_set FeatureSet; @@ -223,7 +224,7 @@ enum class ReconnectResult { Abort, }; -class atransport { +class atransport : public enable_weak_from_this { public: // TODO(danalbert): We expose waaaaaaay too much stuff because this was // historically just a struct, but making the whole thing a more idiomatic @@ -246,7 +247,7 @@ class atransport { } atransport(ConnectionState state = kCsOffline) : atransport([](atransport*) { return ReconnectResult::Abort; }, state) {} - virtual ~atransport(); + ~atransport(); int Write(apacket* p); void Reset(); diff --git a/adb/types.h b/adb/types.h index 6b0022472..c619fffcf 100644 --- a/adb/types.h +++ b/adb/types.h @@ -25,6 +25,7 @@ #include +#include "fdevent/fdevent.h" #include "sysdeps/uio.h" // Essentially std::vector, except without zero initialization or reallocation. @@ -245,3 +246,97 @@ struct IOVector { size_t start_index_ = 0; std::vector chain_; }; + +// An implementation of weak pointers tied to the fdevent run loop. +// +// This allows for code to submit a request for an object, and upon receiving +// a response, know whether the object is still alive, or has been destroyed +// because of other reasons. We keep a list of living weak_ptrs in each object, +// and clear the weak_ptrs when the object is destroyed. This is safe, because +// we require that both the destructor of the referent and the get method on +// the weak_ptr are executed on the main thread. +template +struct enable_weak_from_this; + +template +struct weak_ptr { + weak_ptr() = default; + explicit weak_ptr(T* ptr) { reset(ptr); } + weak_ptr(const weak_ptr& copy) { reset(copy.get()); } + + weak_ptr(weak_ptr&& move) { + reset(move.get()); + move.reset(); + } + + ~weak_ptr() { reset(); } + + weak_ptr& operator=(const weak_ptr& copy) { + if (© == this) { + return *this; + } + + reset(copy.get()); + return *this; + } + + weak_ptr& operator=(weak_ptr&& move) { + if (&move == this) { + return *this; + } + + reset(move.get()); + move.reset(); + return *this; + } + + T* get() { + check_main_thread(); + return ptr_; + } + + void reset(T* ptr = nullptr) { + check_main_thread(); + + if (ptr == ptr_) { + return; + } + + if (ptr_) { + ptr_->weak_ptrs_.erase( + std::remove(ptr_->weak_ptrs_.begin(), ptr_->weak_ptrs_.end(), this)); + } + + ptr_ = ptr; + if (ptr_) { + ptr_->weak_ptrs_.push_back(this); + } + } + + private: + friend struct enable_weak_from_this; + T* ptr_ = nullptr; +}; + +template +struct enable_weak_from_this { + ~enable_weak_from_this() { + if (!weak_ptrs_.empty()) { + check_main_thread(); + for (auto& weak : weak_ptrs_) { + weak->ptr_ = nullptr; + } + weak_ptrs_.clear(); + } + } + + weak_ptr weak() { return weak_ptr(static_cast(this)); } + + void schedule_deletion() { + fdevent_run_on_main_thread([this]() { delete this; }); + } + + private: + friend struct weak_ptr; + std::vector*> weak_ptrs_; +};