From 077ac11106ed8e48eaa629083157c1d6766ae3dc Mon Sep 17 00:00:00 2001 From: Joshua Duong Date: Tue, 7 Apr 2020 15:16:42 -0700 Subject: [PATCH] [adb client] Fix mdns discovery service registry. We were getting stale service ip addresses because we weren't destroying the sdref correctly. Also, we were leaking the ResolvedServices when removing it from the ServiceRegistry. Converted them to smart pointers to fix that. Bug: 153343580 Test: test_adb.py Change-Id: Ib7c1dbf54937d4ac6d9885cb5f7289bef616d12e --- adb/client/transport_mdns.cpp | 89 +++++++++++++++++++++++------------ adb/test_adb.py | 4 +- 2 files changed, 61 insertions(+), 32 deletions(-) diff --git a/adb/client/transport_mdns.cpp b/adb/client/transport_mdns.cpp index 2b6aa7cbb..c9993b787 100644 --- a/adb/client/transport_mdns.cpp +++ b/adb/client/transport_mdns.cpp @@ -144,7 +144,7 @@ class AsyncServiceRef { return initialized_; } - virtual ~AsyncServiceRef() { + void DestroyServiceRef() { if (!initialized_) { return; } @@ -152,9 +152,13 @@ class AsyncServiceRef { // Order matters here! Must destroy the fdevent first since it has a // reference to |sdRef_|. fdevent_destroy(fde_); + D("DNSServiceRefDeallocate(sdRef=%p)", sdRef_); DNSServiceRefDeallocate(sdRef_); + initialized_ = false; } + virtual ~AsyncServiceRef() { DestroyServiceRef(); } + protected: DNSServiceRef sdRef_; @@ -203,6 +207,7 @@ class ResolvedService : public AsyncServiceRef { if (ret != kDNSServiceErr_NoError) { D("Got %d from DNSServiceGetAddrInfo.", ret); } else { + D("DNSServiceGetAddrInfo(sdRef=%p, hosttarget=%s)", sdRef_, hosttarget); Initialize(); } @@ -223,7 +228,7 @@ class ResolvedService : public AsyncServiceRef { return true; } - void Connect(const sockaddr* address) { + bool AddToServiceRegistry(const sockaddr* address) { sa_family_ = address->sa_family; if (sa_family_ == AF_INET) { @@ -234,13 +239,13 @@ class ResolvedService : public AsyncServiceRef { addr_format_ = "[%s]:%hu"; } else { // Should be impossible D("mDNS resolved non-IP address."); - return; + return false; } // Winsock version requires the const cast Because Microsoft. if (!inet_ntop(sa_family_, const_cast(ip_addr_data_), ip_addr_, sizeof(ip_addr_))) { D("Could not convert IP address to string."); - return; + return false; } // adb secure service needs to do something different from just @@ -264,19 +269,32 @@ class ResolvedService : public AsyncServiceRef { } int adbSecureServiceType = serviceIndex(); + ServiceRegistry* services = nullptr; switch (adbSecureServiceType) { case kADBTransportServiceRefIndex: - sAdbTransportServices->push_back(this); + services = sAdbTransportServices; break; case kADBSecurePairingServiceRefIndex: - sAdbSecurePairingServices->push_back(this); + services = sAdbSecurePairingServices; break; case kADBSecureConnectServiceRefIndex: - sAdbSecureConnectServices->push_back(this); + services = sAdbSecureConnectServices; break; default: - break; + LOG(WARNING) << "No registry available for reg_type=[" << regType_ << "]"; + return false; } + + if (!services->empty()) { + // Remove the previous resolved service, if any. + services->erase(std::remove_if(services->begin(), services->end(), + [&](std::unique_ptr& service) { + return (serviceName_ == service->serviceName()); + })); + } + services->push_back(std::unique_ptr(this)); + + return true; } int serviceIndex() const { return adb_DNSServiceIndexByName(regType_.c_str()); } @@ -291,7 +309,7 @@ class ResolvedService : public AsyncServiceRef { uint16_t port() const { return port_; } - using ServiceRegistry = std::vector; + using ServiceRegistry = std::vector>; // unencrypted tcp connections static ServiceRegistry* sAdbTransportServices; @@ -321,13 +339,13 @@ class ResolvedService : public AsyncServiceRef { }; // static -std::vector* ResolvedService::sAdbTransportServices = NULL; +ResolvedService::ServiceRegistry* ResolvedService::sAdbTransportServices = NULL; // static -std::vector* ResolvedService::sAdbSecurePairingServices = NULL; +ResolvedService::ServiceRegistry* ResolvedService::sAdbSecurePairingServices = NULL; // static -std::vector* ResolvedService::sAdbSecureConnectServices = NULL; +ResolvedService::ServiceRegistry* ResolvedService::sAdbSecureConnectServices = NULL; // static void ResolvedService::initAdbServiceRegistries() { @@ -348,7 +366,7 @@ void ResolvedService::forEachService(const ServiceRegistry& services, adb_secure_foreach_service_callback cb) { initAdbServiceRegistries(); - for (auto service : services) { + for (const auto& service : services) { auto service_name = service->serviceName(); auto reg_type = service->regType(); auto ip = service->ipAddress(); @@ -366,7 +384,7 @@ void ResolvedService::forEachService(const ServiceRegistry& services, bool ResolvedService::connectByServiceName(const ServiceRegistry& services, const std::string& service_name) { initAdbServiceRegistries(); - for (auto service : services) { + for (const auto& service : services) { if (service_name == service->serviceName()) { D("Got service_name match [%s]", service->serviceName().c_str()); return service->ConnectSecureWifiDevice(); @@ -393,23 +411,28 @@ bool adb_secure_connect_by_service_name(const char* service_name) { service_name); } -static void DNSSD_API register_service_ip(DNSServiceRef /*sdRef*/, - DNSServiceFlags /*flags*/, +static void DNSSD_API register_service_ip(DNSServiceRef sdRef, DNSServiceFlags flags, uint32_t /*interfaceIndex*/, - DNSServiceErrorType /*errorCode*/, - const char* /*hostname*/, - const sockaddr* address, - uint32_t /*ttl*/, - void* context) { - D("Got IP for service."); + DNSServiceErrorType errorCode, const char* hostname, + const sockaddr* address, uint32_t ttl, void* context) { + D("%s: sdRef=%p flags=0x%08x errorCode=%u ttl=%u", __func__, sdRef, flags, errorCode, ttl); std::unique_ptr data( reinterpret_cast(context)); - data->Connect(address); + // Only resolve the address once. If the address or port changes, we'll just get another + // registration. + data->DestroyServiceRef(); - // For ADB Secure services, keep those ResolvedService's around - // for later processing with secure connection establishment. - if (data->serviceIndex() != kADBTransportServiceRefIndex) { - data.release(); + if (errorCode != kDNSServiceErr_NoError) { + D("Got error while looking up ipaddr [%u]", errorCode); + return; + } + + if (flags & kDNSServiceFlagsAdd) { + D("Resolved IP address for [%s]. Adding to service registry.", hostname); + auto* ptr = data.release(); + if (!ptr->AddToServiceRegistry(address)) { + data.reset(ptr); + } } } @@ -459,6 +482,7 @@ class DiscoveredService : public AsyncServiceRef { }; static void adb_RemoveDNSService(const char* regType, const char* serviceName) { + D("%s: regType=[%s] serviceName=[%s]", __func__, regType, serviceName); int index = adb_DNSServiceIndexByName(regType); ResolvedService::ServiceRegistry* services; switch (index) { @@ -475,10 +499,15 @@ static void adb_RemoveDNSService(const char* regType, const char* serviceName) { return; } + if (services->empty()) { + return; + } + std::string sName(serviceName); - services->erase(std::remove_if( - services->begin(), services->end(), - [&sName](ResolvedService* service) { return (sName == service->serviceName()); })); + services->erase(std::remove_if(services->begin(), services->end(), + [&sName](std::unique_ptr& service) { + return (sName == service->serviceName()); + })); } // Returns the version the device wanted to advertise, diff --git a/adb/test_adb.py b/adb/test_adb.py index 03bdcbd8d..9912f11e8 100755 --- a/adb/test_adb.py +++ b/adb/test_adb.py @@ -659,14 +659,14 @@ class MdnsTest(unittest.TestCase): print(f"Registering {serv_instance}.{serv_type} ...") with zeroconf_register_service(zc, service_info) as info: """Give adb some time to register the service""" - time.sleep(0.25) + time.sleep(1) print(f"services={_mdns_services(server_port)}") self.assertTrue(any((serv_instance in line and serv_type in line) for line in _mdns_services(server_port))) """Give adb some time to unregister the service""" print("Unregistering mdns service...") - time.sleep(0.25) + time.sleep(1) print(f"services={_mdns_services(server_port)}") self.assertFalse(any((serv_instance in line and serv_type in line) for line in _mdns_services(server_port)))