diff --git a/adb/Android.bp b/adb/Android.bp index 32581a255..1004483ac 100644 --- a/adb/Android.bp +++ b/adb/Android.bp @@ -225,9 +225,11 @@ cc_library_host_static { srcs: libadb_srcs + [ "client/auth.cpp", + "client/adb_wifi.cpp", "client/usb_libusb.cpp", "client/usb_dispatch.cpp", "client/transport_mdns.cpp", + "client/pairing/pairing_client.cpp", ], generated_headers: ["platform_tools_version"], @@ -257,6 +259,8 @@ cc_library_host_static { static_libs: [ "libadb_crypto", "libadb_protos", + "libadb_pairing_connection", + "libadb_tls_connection", "libbase", "libcrypto_utils", "libcrypto", @@ -266,6 +270,7 @@ cc_library_host_static { "libutils", "liblog", "libcutils", + "libprotobuf-cpp-lite", ], } @@ -274,8 +279,12 @@ cc_test_host { defaults: ["adb_defaults"], srcs: libadb_test_srcs, static_libs: [ - "libadb_crypto", + "libadb_crypto_static", "libadb_host", + "libadb_pairing_auth_static", + "libadb_pairing_connection_static", + "libadb_protos_static", + "libadb_tls_connection_static", "libbase", "libcutils", "libcrypto_utils", @@ -283,6 +292,8 @@ cc_test_host { "liblog", "libmdnssd", "libdiagnose_usb", + "libprotobuf-cpp-lite", + "libssl", "libusb", ], @@ -314,12 +325,16 @@ cc_benchmark { }, static_libs: [ + "libadb_crypto_static", + "libadb_tls_connection_static", + "libadbd_auth", "libbase", "libcutils", "libcrypto_utils", "libcrypto_static", "libdiagnose_usb", "liblog", + "libssl", "libusb", ], } @@ -354,6 +369,10 @@ cc_binary_host { static_libs: [ "libadb_crypto", "libadb_host", + "libadb_pairing_auth", + "libadb_pairing_connection", + "libadb_protos", + "libadb_tls_connection", "libandroidfw", "libbase", "libcutils", @@ -365,6 +384,7 @@ cc_binary_host { "liblz4", "libmdnssd", "libprotobuf-cpp-lite", + "libssl", "libusb", "libutils", "liblog", @@ -415,6 +435,7 @@ cc_library_static { srcs: libadb_srcs + libadb_linux_srcs + libadb_posix_srcs + [ "daemon/auth.cpp", "daemon/jdwp_service.cpp", + "daemon/adb_wifi.cpp", ], local_include_dirs: [ @@ -430,6 +451,9 @@ cc_library_static { shared_libs: [ "libadb_crypto", + "libadb_pairing_connection", + "libadb_protos", + "libadb_tls_connection", "libadbd_auth", "libasyncio", "libbase", @@ -484,6 +508,10 @@ cc_library { ], shared_libs: [ + "libadb_crypto", + "libadb_pairing_connection", + "libadb_protos", + "libadb_tls_connection", "libadbd_auth", "libasyncio", "libbase", @@ -532,6 +560,9 @@ cc_library { ], shared_libs: [ + "libadb_crypto", + "libadb_pairing_connection", + "libadb_tls_connection", "libadbd_auth", "libadbd_services", "libasyncio", @@ -580,9 +611,14 @@ cc_binary { "libmdnssd", "libminijail", "libselinux", + "libssl", ], shared_libs: [ + "libadb_crypto", + "libadb_pairing_connection", + "libadb_protos", + "libadb_tls_connection", "libadbd_auth", "libcrypto", ], @@ -659,6 +695,9 @@ cc_test { static_libs: [ "libadbd", "libadbd_auth", + "libadb_crypto_static", + "libadb_pairing_connection_static", + "libadb_tls_connection_static", "libbase", "libcutils", "libcrypto_utils", @@ -773,8 +812,12 @@ cc_test_host { "fastdeploy/deploypatchgenerator/patch_utils_test.cpp", ], static_libs: [ - "libadb_crypto", + "libadb_crypto_static", "libadb_host", + "libadb_pairing_auth_static", + "libadb_pairing_connection_static", + "libadb_protos_static", + "libadb_tls_connection_static", "libandroidfw", "libbase", "libcutils", @@ -785,6 +828,7 @@ cc_test_host { "liblog", "libmdnssd", "libprotobuf-cpp-lite", + "libssl", "libusb", "libutils", "libziparchive", diff --git a/adb/adb.cpp b/adb/adb.cpp index 460ddde55..554a754fd 100644 --- a/adb/adb.cpp +++ b/adb/adb.cpp @@ -52,6 +52,7 @@ #include "adb_listeners.h" #include "adb_unique_fd.h" #include "adb_utils.h" +#include "adb_wifi.h" #include "sysdeps/chrono.h" #include "transport.h" @@ -140,6 +141,9 @@ void print_packet(const char *label, apacket *p) case A_CLSE: tag = "CLSE"; break; case A_WRTE: tag = "WRTE"; break; case A_AUTH: tag = "AUTH"; break; + case A_STLS: + tag = "ATLS"; + break; default: tag = "????"; break; } @@ -209,6 +213,15 @@ std::string get_connection_string() { android::base::Join(connection_properties, ';').c_str()); } +void send_tls_request(atransport* t) { + D("Calling send_tls_request"); + apacket* p = get_apacket(); + p->msg.command = A_STLS; + p->msg.arg0 = A_STLS_VERSION; + p->msg.data_length = 0; + send_packet(p, t); +} + void send_connect(atransport* t) { D("Calling send_connect"); apacket* cp = get_apacket(); @@ -299,7 +312,12 @@ static void handle_new_connection(atransport* t, apacket* p) { #if ADB_HOST handle_online(t); #else - if (!auth_required) { + if (t->use_tls) { + // We still handshake in TLS mode. If auth_required is disabled, + // we'll just not verify the client's certificate. This should be the + // first packet the client receives to indicate the new protocol. + send_tls_request(t); + } else if (!auth_required) { LOG(INFO) << "authentication not required"; handle_online(t); send_connect(t); @@ -324,8 +342,21 @@ void handle_packet(apacket *p, atransport *t) case A_CNXN: // CONNECT(version, maxdata, "system-id-string") handle_new_connection(t, p); break; + case A_STLS: // TLS(version, "") + t->use_tls = true; +#if ADB_HOST + send_tls_request(t); + adb_auth_tls_handshake(t); +#else + adbd_auth_tls_handshake(t); +#endif + break; case A_AUTH: + // All AUTH commands are ignored in TLS mode + if (t->use_tls) { + break; + } switch (p->msg.arg0) { #if ADB_HOST case ADB_AUTH_TOKEN: diff --git a/adb/adb.h b/adb/adb.h index 7f7dd0d2e..86d205c98 100644 --- a/adb/adb.h +++ b/adb/adb.h @@ -44,6 +44,7 @@ constexpr size_t LINUX_MAX_SOCKET_SIZE = 4194304; #define A_CLSE 0x45534c43 #define A_WRTE 0x45545257 #define A_AUTH 0x48545541 +#define A_STLS 0x534C5453 // ADB protocol version. // Version revision: @@ -53,6 +54,10 @@ constexpr size_t LINUX_MAX_SOCKET_SIZE = 4194304; #define A_VERSION_SKIP_CHECKSUM 0x01000001 #define A_VERSION 0x01000001 +// Stream-based TLS protocol version +#define A_STLS_VERSION_MIN 0x01000000 +#define A_STLS_VERSION 0x01000000 + // Used for help/version information. #define ADB_VERSION_MAJOR 1 #define ADB_VERSION_MINOR 0 @@ -229,6 +234,7 @@ void handle_online(atransport* t); void handle_offline(atransport* t); void send_connect(atransport* t); +void send_tls_request(atransport* t); void parse_banner(const std::string&, atransport* t); diff --git a/adb/adb_auth.h b/adb/adb_auth.h index 2be9a7684..7e858dce4 100644 --- a/adb/adb_auth.h +++ b/adb/adb_auth.h @@ -38,10 +38,14 @@ void adb_auth_init(); int adb_auth_keygen(const char* filename); int adb_auth_pubkey(const char* filename); std::string adb_auth_get_userkey(); +bssl::UniquePtr adb_auth_get_user_privkey(); std::deque> adb_auth_get_private_keys(); void send_auth_response(const char* token, size_t token_size, atransport* t); +int adb_tls_set_certificate(SSL* ssl); +void adb_auth_tls_handshake(atransport* t); + #else // !ADB_HOST extern bool auth_required; @@ -57,6 +61,10 @@ void adbd_notify_framework_connected_key(atransport* t); void send_auth_request(atransport *t); +void adbd_auth_tls_handshake(atransport* t); +int adbd_tls_verify_cert(X509_STORE_CTX* ctx, std::string* auth_key); +bssl::UniquePtr adbd_tls_client_ca_list(); + #endif // ADB_HOST #endif // __ADB_AUTH_H diff --git a/adb/adb_wifi.h b/adb/adb_wifi.h new file mode 100644 index 000000000..585748c91 --- /dev/null +++ b/adb/adb_wifi.h @@ -0,0 +1,37 @@ +/* + * Copyright (C) 2019 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +#include "adb.h" + +#if ADB_HOST + +void adb_wifi_init(void); +void adb_wifi_pair_device(const std::string& host, const std::string& password, + std::string& response); +bool adb_wifi_is_known_host(const std::string& host); + +#else // !ADB_HOST + +struct AdbdAuthContext; + +void adbd_wifi_init(AdbdAuthContext* ctx); +void adbd_wifi_secure_connect(atransport* t); + +#endif diff --git a/adb/apex/Android.bp b/adb/apex/Android.bp index b62a8ff14..01894552a 100644 --- a/adb/apex/Android.bp +++ b/adb/apex/Android.bp @@ -5,7 +5,12 @@ apex_defaults { compile_multilib: "both", multilib: { both: { - native_shared_libs: ["libadbconnection_client"], + native_shared_libs: [ + "libadb_pairing_auth", + "libadb_pairing_connection", + "libadb_pairing_server", + "libadbconnection_client", + ], }, }, prebuilts: ["com.android.adbd.init.rc"], diff --git a/adb/client/adb_client.h b/adb/client/adb_client.h index 758fcab42..1c6cde77d 100644 --- a/adb/client/adb_client.h +++ b/adb/client/adb_client.h @@ -91,12 +91,15 @@ extern const char* _Nullable * _Nullable __adb_envp; // ADB Secure DNS service interface. Used to query what ADB Secure DNS services have been // resolved, and to run some kind of callback for each one. using adb_secure_foreach_service_callback = std::function; + const char* _Nonnull service_name, const char* _Nonnull ip_address, uint16_t port)>; // Queries pairing/connect services that have been discovered and resolved. // If |host_name| is not null, run |cb| only for services // matching |host_name|. Otherwise, run for all services. -void adb_secure_foreach_pairing_service(const char* _Nullable host_name, +void adb_secure_foreach_pairing_service(const char* _Nullable service_name, adb_secure_foreach_service_callback cb); -void adb_secure_foreach_connect_service(const char* _Nullable host_name, +void adb_secure_foreach_connect_service(const char* _Nullable service_name, adb_secure_foreach_service_callback cb); +// Tries to connect to a |service_name| if found. Returns true if found and +// connected, false otherwise. +bool adb_secure_connect_by_service_name(const char* _Nonnull service_name); diff --git a/adb/client/adb_wifi.cpp b/adb/client/adb_wifi.cpp new file mode 100644 index 000000000..fa7102811 --- /dev/null +++ b/adb/client/adb_wifi.cpp @@ -0,0 +1,246 @@ +/* + * Copyright (C) 2019 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "adb_wifi.h" + +#include +#include +#include + +#include +#include +#include +#include +#include "client/pairing/pairing_client.h" + +#include "adb_auth.h" +#include "adb_known_hosts.pb.h" +#include "adb_utils.h" +#include "client/adb_client.h" +#include "sysdeps.h" + +using adbwifi::pairing::PairingClient; +using namespace adb::crypto; + +struct PairingResultWaiter { + std::mutex mutex_; + std::condition_variable cv_; + std::optional is_valid_; + PeerInfo peer_info_; + + static void OnResult(const PeerInfo* peer_info, void* opaque) { + CHECK(opaque); + auto* p = reinterpret_cast(opaque); + { + std::lock_guard lock(p->mutex_); + if (peer_info) { + memcpy(&(p->peer_info_), peer_info, sizeof(PeerInfo)); + } + p->is_valid_ = (peer_info != nullptr); + } + p->cv_.notify_one(); + } +}; // PairingResultWaiter + +void adb_wifi_init() {} + +static std::vector stringToUint8(const std::string& str) { + auto* p8 = reinterpret_cast(str.data()); + return std::vector(p8, p8 + str.length()); +} + +// Tries to replace the |old_file| with |new_file|. +// On success, then |old_file| has been removed and replaced with the +// contents of |new_file|, |new_file| will be removed, and only |old_file| will +// remain. +// On failure, both files will be unchanged. +// |new_file| must exist, but |old_file| does not need to exist. +bool SafeReplaceFile(std::string_view old_file, std::string_view new_file) { + std::string to_be_deleted(old_file); + to_be_deleted += ".tbd"; + + bool old_renamed = true; + if (adb_rename(old_file.data(), to_be_deleted.c_str()) != 0) { + // Don't exit here. This is not necessarily an error, because |old_file| + // may not exist. + PLOG(INFO) << "Failed to rename " << old_file; + old_renamed = false; + } + + if (adb_rename(new_file.data(), old_file.data()) != 0) { + PLOG(ERROR) << "Unable to rename file (" << new_file << " => " << old_file << ")"; + if (old_renamed) { + // Rename the .tbd file back to it's original name + adb_rename(to_be_deleted.c_str(), old_file.data()); + } + return false; + } + + adb_unlink(to_be_deleted.c_str()); + return true; +} + +static std::string get_user_known_hosts_path() { + return adb_get_android_dir_path() + OS_PATH_SEPARATOR + "adb_known_hosts.pb"; +} + +bool load_known_hosts_from_file(const std::string& path, adb::proto::AdbKnownHosts& known_hosts) { + // Check for file existence. + struct stat buf; + if (stat(path.c_str(), &buf) == -1) { + LOG(INFO) << "Known hosts file [" << path << "] does not exist..."; + return false; + } + + std::ifstream file(path, std::ios::binary); + if (!file) { + PLOG(ERROR) << "Unable to open [" << path << "]."; + return false; + } + + if (!known_hosts.ParseFromIstream(&file)) { + PLOG(ERROR) << "Failed to parse [" << path << "]. Deleting it as it may be corrupted."; + adb_unlink(path.c_str()); + return false; + } + + return true; +} + +static bool write_known_host_to_file(std::string& known_host) { + std::string path = get_user_known_hosts_path(); + if (path.empty()) { + PLOG(ERROR) << "Error getting user known hosts filename"; + return false; + } + + adb::proto::AdbKnownHosts known_hosts; + load_known_hosts_from_file(path, known_hosts); + auto* host_info = known_hosts.add_host_infos(); + host_info->set_guid(known_host); + + std::unique_ptr temp_file(new TemporaryFile(adb_get_android_dir_path())); + if (temp_file->fd == -1) { + PLOG(ERROR) << "Failed to open [" << temp_file->path << "] for writing"; + return false; + } + + if (!known_hosts.SerializeToFileDescriptor(temp_file->fd)) { + LOG(ERROR) << "Unable to write out adb_knowns_hosts"; + return false; + } + temp_file->DoNotRemove(); + std::string temp_file_name(temp_file->path); + temp_file.reset(); + + // Replace the existing adb_known_hosts with the new one + if (!SafeReplaceFile(path, temp_file_name.c_str())) { + LOG(ERROR) << "Failed to replace old adb_known_hosts"; + adb_unlink(temp_file_name.c_str()); + return false; + } + chmod(path.c_str(), S_IRUSR | S_IWUSR | S_IRGRP); + + return true; +} + +bool adb_wifi_is_known_host(const std::string& host) { + std::string path = get_user_known_hosts_path(); + if (path.empty()) { + PLOG(ERROR) << "Error getting user known hosts filename"; + return false; + } + + adb::proto::AdbKnownHosts known_hosts; + if (!load_known_hosts_from_file(path, known_hosts)) { + return false; + } + + for (const auto& host_info : known_hosts.host_infos()) { + if (host == host_info.guid()) { + return true; + } + } + return false; +} + +void adb_wifi_pair_device(const std::string& host, const std::string& password, + std::string& response) { + // Check the address for a valid address and port. + std::string parsed_host; + std::string err; + int port = -1; + if (!android::base::ParseNetAddress(host, &parsed_host, &port, nullptr, &err)) { + response = "Failed to parse address for pairing: " + err; + return; + } + if (port <= 0 || port > 65535) { + response = "Invalid port while parsing address [" + host + "]"; + return; + } + + auto priv_key = adb_auth_get_user_privkey(); + auto x509_cert = GenerateX509Certificate(priv_key.get()); + if (!x509_cert) { + LOG(ERROR) << "Unable to create X509 certificate for pairing"; + return; + } + auto cert_str = X509ToPEMString(x509_cert.get()); + auto priv_str = Key::ToPEMString(priv_key.get()); + + // Send our public key on pairing success + PeerInfo system_info = {}; + system_info.type = ADB_RSA_PUB_KEY; + std::string public_key = adb_auth_get_userkey(); + CHECK_LE(public_key.size(), sizeof(system_info.data) - 1); // -1 for null byte + memcpy(system_info.data, public_key.data(), public_key.size()); + + auto pswd8 = stringToUint8(password); + auto cert8 = stringToUint8(cert_str); + auto priv8 = stringToUint8(priv_str); + + auto client = PairingClient::Create(pswd8, system_info, cert8, priv8); + if (client == nullptr) { + response = "Failed: unable to create pairing client."; + return; + } + + PairingResultWaiter waiter; + std::unique_lock lock(waiter.mutex_); + if (!client->Start(host, waiter.OnResult, &waiter)) { + response = "Failed: Unable to start pairing client."; + return; + } + waiter.cv_.wait(lock, [&]() { return waiter.is_valid_.has_value(); }); + if (!*(waiter.is_valid_)) { + response = "Failed: Wrong password or connection was dropped."; + return; + } + + if (waiter.peer_info_.type != ADB_DEVICE_GUID) { + response = "Failed: Successfully paired but server returned unknown response="; + response += waiter.peer_info_.type; + return; + } + + std::string device_guid = reinterpret_cast(waiter.peer_info_.data); + response = "Successfully paired to " + host + " [guid=" + device_guid + "]"; + + // Write to adb_known_hosts + write_known_host_to_file(device_guid); + // Try to auto-connect. + adb_secure_connect_by_service_name(device_guid.c_str()); +} diff --git a/adb/client/auth.cpp b/adb/client/auth.cpp index dcf4bc0ad..8738ce77a 100644 --- a/adb/client/auth.cpp +++ b/adb/client/auth.cpp @@ -30,6 +30,9 @@ #include #include +#include +#include +#include #include #include #include @@ -55,6 +58,7 @@ static std::map>& g_keys = static std::map& g_monitored_paths = *new std::map; using namespace adb::crypto; +using namespace adb::tls; static bool generate_key(const std::string& file) { LOG(INFO) << "generate_key(" << file << ")..."; @@ -144,6 +148,7 @@ static bool load_key(const std::string& file) { if (g_keys.find(fingerprint) != g_keys.end()) { LOG(INFO) << "ignoring already-loaded key: " << file; } else { + LOG(INFO) << "Loaded fingerprint=[" << SHA256BitsToHexString(fingerprint) << "]"; g_keys[fingerprint] = std::move(key); } return true; @@ -279,6 +284,28 @@ static bool pubkey_from_privkey(std::string* out, const std::string& path) { return CalculatePublicKey(out, privkey.get()); } +bssl::UniquePtr adb_auth_get_user_privkey() { + std::string path = get_user_key_path(); + if (path.empty()) { + PLOG(ERROR) << "Error getting user key filename"; + return nullptr; + } + + std::shared_ptr rsa_privkey = read_key_file(path); + if (!rsa_privkey) { + return nullptr; + } + + bssl::UniquePtr pkey(EVP_PKEY_new()); + if (!pkey) { + LOG(ERROR) << "Failed to allocate key"; + return nullptr; + } + + EVP_PKEY_set1_RSA(pkey.get(), rsa_privkey.get()); + return pkey; +} + std::string adb_auth_get_userkey() { std::string path = get_user_key_path(); if (path.empty()) { @@ -453,3 +480,72 @@ void send_auth_response(const char* token, size_t token_size, atransport* t) { p->msg.data_length = p->payload.size(); send_packet(p, t); } + +void adb_auth_tls_handshake(atransport* t) { + std::thread([t]() { + std::shared_ptr key = t->Key(); + if (key == nullptr) { + // Can happen if !auth_required + LOG(INFO) << "t->auth_key not set before handshake"; + key = t->NextKey(); + CHECK(key); + } + + LOG(INFO) << "Attempting to TLS handshake"; + bool success = t->connection()->DoTlsHandshake(key.get()); + if (success) { + LOG(INFO) << "Handshake succeeded. Waiting for CNXN packet..."; + } else { + LOG(INFO) << "Handshake failed. Kicking transport"; + t->Kick(); + } + }).detach(); +} + +int adb_tls_set_certificate(SSL* ssl) { + LOG(INFO) << __func__; + + const STACK_OF(X509_NAME)* ca_list = SSL_get_client_CA_list(ssl); + if (ca_list == nullptr) { + // Either the device doesn't know any keys, or !auth_required. + // So let's just try with the default certificate and see what happens. + LOG(INFO) << "No client CA list. Trying with default certificate."; + return 1; + } + + const size_t num_cas = sk_X509_NAME_num(ca_list); + for (size_t i = 0; i < num_cas; ++i) { + auto* x509_name = sk_X509_NAME_value(ca_list, i); + auto adbFingerprint = ParseEncodedKeyFromCAIssuer(x509_name); + if (!adbFingerprint.has_value()) { + // This could be a real CA issuer. Unfortunately, we don't support + // it ATM. + continue; + } + + LOG(INFO) << "Checking for fingerprint match [" << *adbFingerprint << "]"; + auto encoded_key = SHA256HexStringToBits(*adbFingerprint); + if (!encoded_key.has_value()) { + continue; + } + // Check against our list of encoded keys for a match + std::lock_guard lock(g_keys_mutex); + auto rsa_priv_key = g_keys.find(*encoded_key); + if (rsa_priv_key != g_keys.end()) { + LOG(INFO) << "Got SHA256 match on a key"; + bssl::UniquePtr evp_pkey(EVP_PKEY_new()); + CHECK(EVP_PKEY_set1_RSA(evp_pkey.get(), rsa_priv_key->second.get())); + auto x509 = GenerateX509Certificate(evp_pkey.get()); + auto x509_str = X509ToPEMString(x509.get()); + auto evp_str = Key::ToPEMString(evp_pkey.get()); + TlsConnection::SetCertAndKey(ssl, x509_str, evp_str); + return 1; + } else { + LOG(INFO) << "No match for [" << *adbFingerprint << "]"; + } + } + + // Let's just try with the default certificate anyways, because daemon might + // not require auth, even though it has a list of keys. + return 1; +} diff --git a/adb/client/commandline.cpp b/adb/client/commandline.cpp index 84c0e0134..081bac4df 100644 --- a/adb/client/commandline.cpp +++ b/adb/client/commandline.cpp @@ -30,6 +30,7 @@ #include #include #include +#include #include #include @@ -97,8 +98,10 @@ static void help() { " version show version num\n" "\n" "networking:\n" - " connect HOST[:PORT] connect to a device via TCP/IP\n" - " disconnect [[HOST]:PORT] disconnect from given TCP/IP device, or all\n" + " connect HOST[:PORT] connect to a device via TCP/IP [default port=5555]\n" + " disconnect [HOST[:PORT]]\n" + " disconnect from given TCP/IP device [default port=5555], or all\n" + " pair HOST[:PORT] pair with a device for secure TCP/IP communication\n" " forward --list list all forward socket connections\n" " forward [--no-rebind] LOCAL REMOTE\n" " forward socket connection using:\n" @@ -1638,6 +1641,19 @@ int adb_commandline(int argc, const char** argv) { return adb_query_command(query); } else if (!strcmp(argv[0], "abb")) { return adb_abb(argc, argv); + } else if (!strcmp(argv[0], "pair")) { + if (argc != 2) error_exit("usage: adb pair [:]"); + + std::string password; + printf("Enter pairing code: "); + fflush(stdout); + if (!std::getline(std::cin, password) || password.empty()) { + error_exit("No pairing code provided"); + } + std::string query = + android::base::StringPrintf("host:pair:%s:%s", password.c_str(), argv[1]); + + return adb_query_command(query); } else if (!strcmp(argv[0], "emu")) { return adb_send_emulator_command(argc, argv, serial); } else if (!strcmp(argv[0], "shell")) { diff --git a/adb/client/main.cpp b/adb/client/main.cpp index e5ffe4c09..a85a18c4e 100644 --- a/adb/client/main.cpp +++ b/adb/client/main.cpp @@ -35,6 +35,7 @@ #include "adb_client.h" #include "adb_listeners.h" #include "adb_utils.h" +#include "adb_wifi.h" #include "commandline.h" #include "sysdeps/chrono.h" #include "transport.h" @@ -118,6 +119,7 @@ int adb_server_main(int is_daemon, const std::string& socket_spec, int ack_reply init_transport_registration(); init_reconnect_handler(); + adb_wifi_init(); if (!getenv("ADB_MDNS") || strcmp(getenv("ADB_MDNS"), "0") != 0) { init_mdns_transport_discovery(); } diff --git a/adb/client/pairing/pairing_client.cpp b/adb/client/pairing/pairing_client.cpp new file mode 100644 index 000000000..2f878bf04 --- /dev/null +++ b/adb/client/pairing/pairing_client.cpp @@ -0,0 +1,172 @@ +/* + * Copyright (C) 2019 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "client/pairing/pairing_client.h" + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include "sysdeps.h" + +namespace adbwifi { +namespace pairing { + +using android::base::unique_fd; + +namespace { + +struct ConnectionDeleter { + void operator()(PairingConnectionCtx* p) { pairing_connection_destroy(p); } +}; // ConnectionDeleter +using ConnectionPtr = std::unique_ptr; + +class PairingClientImpl : public PairingClient { + public: + virtual ~PairingClientImpl(); + + explicit PairingClientImpl(const Data& pswd, const PeerInfo& peer_info, const Data& cert, + const Data& priv_key); + + // Starts the pairing client. This call is non-blocking. Upon pairing + // completion, |cb| will be called with the PeerInfo on success, + // or an empty value on failure. + // + // Returns true if PairingClient was successfully started. Otherwise, + // return false. + virtual bool Start(std::string_view ip_addr, pairing_client_result_cb cb, + void* opaque) override; + + static void OnPairingResult(const PeerInfo* peer_info, int fd, void* opaque); + + private: + // Setup and start the PairingConnection + bool StartConnection(); + + enum class State { + Ready, + Running, + Stopped, + }; + + State state_ = State::Ready; + Data pswd_; + PeerInfo peer_info_; + Data cert_; + Data priv_key_; + std::string host_; + int port_; + + ConnectionPtr connection_; + pairing_client_result_cb cb_; + void* opaque_ = nullptr; +}; // PairingClientImpl + +PairingClientImpl::PairingClientImpl(const Data& pswd, const PeerInfo& peer_info, const Data& cert, + const Data& priv_key) + : pswd_(pswd), peer_info_(peer_info), cert_(cert), priv_key_(priv_key) { + CHECK(!pswd_.empty() && !cert_.empty() && !priv_key_.empty()); + + state_ = State::Ready; +} + +PairingClientImpl::~PairingClientImpl() { + // Make sure to kill the PairingConnection before terminating the fdevent + // looper. + if (connection_ != nullptr) { + connection_.reset(); + } +} + +bool PairingClientImpl::Start(std::string_view ip_addr, pairing_client_result_cb cb, void* opaque) { + CHECK(!ip_addr.empty()); + cb_ = cb; + opaque_ = opaque; + + if (state_ != State::Ready) { + LOG(ERROR) << "PairingClient already running or finished"; + return false; + } + + // Try to parse the host address + std::string err; + CHECK(android::base::ParseNetAddress(std::string(ip_addr), &host_, &port_, nullptr, &err)); + CHECK(port_ > 0 && port_ <= 65535); + + if (!StartConnection()) { + LOG(ERROR) << "Unable to start PairingClient connection"; + state_ = State::Stopped; + return false; + } + + state_ = State::Running; + return true; +} + +bool PairingClientImpl::StartConnection() { + std::string err; + const int timeout = 10; // seconds + unique_fd fd(network_connect(host_, port_, SOCK_STREAM, timeout, &err)); + if (fd.get() == -1) { + LOG(ERROR) << "Failed to start pairing connection client [" << err << "]"; + return false; + } + int off = 1; + adb_setsockopt(fd.get(), IPPROTO_TCP, TCP_NODELAY, &off, sizeof(off)); + + connection_ = ConnectionPtr( + pairing_connection_client_new(pswd_.data(), pswd_.size(), &peer_info_, cert_.data(), + cert_.size(), priv_key_.data(), priv_key_.size())); + CHECK(connection_); + + if (!pairing_connection_start(connection_.get(), fd.release(), OnPairingResult, this)) { + LOG(ERROR) << "PairingClient failed to start the PairingConnection"; + state_ = State::Stopped; + return false; + } + + return true; +} + +// static +void PairingClientImpl::OnPairingResult(const PeerInfo* peer_info, int /* fd */, void* opaque) { + auto* p = reinterpret_cast(opaque); + p->cb_(peer_info, p->opaque_); +} + +} // namespace + +// static +std::unique_ptr PairingClient::Create(const Data& pswd, const PeerInfo& peer_info, + const Data& cert, const Data& priv_key) { + CHECK(!pswd.empty()); + CHECK(!cert.empty()); + CHECK(!priv_key.empty()); + + return std::unique_ptr(new PairingClientImpl(pswd, peer_info, cert, priv_key)); +} + +} // namespace pairing +} // namespace adbwifi diff --git a/adb/client/pairing/pairing_client.h b/adb/client/pairing/pairing_client.h new file mode 100644 index 000000000..dbd72a549 --- /dev/null +++ b/adb/client/pairing/pairing_client.h @@ -0,0 +1,68 @@ +/* + * Copyright (C) 2019 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +#include +#include +#include +#include + +#include "adb/pairing/pairing_connection.h" + +namespace adbwifi { +namespace pairing { + +typedef void (*pairing_client_result_cb)(const PeerInfo*, void*); + +// PairingClient is the client side of the PairingConnection protocol. It will +// attempt to connect to a PairingServer specified at |host| and |port|, and +// allocate a new PairingConnection for processing. +// +// See pairing_connection_test.cpp for example usage. +// +class PairingClient { + public: + using Data = std::vector; + + virtual ~PairingClient() = default; + + // Starts the pairing client. This call is non-blocking. Upon completion, + // if the pairing was successful, then |cb| will be called with the PeerInfo + // containing the info of the trusted peer. Otherwise, |cb| will be + // called with an empty value. Start can only be called once in the lifetime + // of this object. + // + // Returns true if PairingClient was successfully started. Otherwise, + // returns false. + virtual bool Start(std::string_view ip_addr, pairing_client_result_cb cb, void* opaque) = 0; + + // Creates a new PairingClient instance. May return null if unable + // to create an instance. |pswd|, |certificate|, |priv_key| and + // |ip_addr| cannot be empty. |peer_info| must contain non-empty strings for + // the guid and name fields. + static std::unique_ptr Create(const Data& pswd, const PeerInfo& peer_info, + const Data& certificate, const Data& priv_key); + + protected: + PairingClient() = default; +}; // class PairingClient + +} // namespace pairing +} // namespace adbwifi diff --git a/adb/client/pairing/tests/pairing_connection_test.cpp b/adb/client/pairing/tests/pairing_connection_test.cpp new file mode 100644 index 000000000..c69c1c2be --- /dev/null +++ b/adb/client/pairing/tests/pairing_connection_test.cpp @@ -0,0 +1,473 @@ +/* + * Copyright 2019 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#define LOG_TAG "AdbWifiPairingConnectionTest" + +#include +#include +#include + +#include +#include +#include + +#include "adb/client/pairing/tests/pairing_client.h" + +namespace adbwifi { +namespace pairing { + +static const std::string kTestServerCert = + "-----BEGIN CERTIFICATE-----\n" + "MIIBljCCAT2gAwIBAgIBATAKBggqhkjOPQQDAjAzMQswCQYDVQQGEwJVUzEQMA4G\n" + "A1UECgwHQW5kcm9pZDESMBAGA1UEAwwJbG9jYWxob3N0MB4XDTE5MTEwNzAyMDkx\n" + "NVoXDTI5MTEwNDAyMDkxNVowMzELMAkGA1UEBhMCVVMxEDAOBgNVBAoMB0FuZHJv\n" + "aWQxEjAQBgNVBAMMCWxvY2FsaG9zdDBZMBMGByqGSM49AgEGCCqGSM49AwEHA0IA\n" + "BCXRovy3RhtK0Khle48vUmkcuI0OF7K8o9sVPE4oVnp24l+cCYr3BtrgifoHPgj4\n" + "vq7n105qzK7ngBHH+LBmYIijQjBAMA8GA1UdEwEB/wQFMAMBAf8wDgYDVR0PAQH/\n" + "BAQDAgGGMB0GA1UdDgQWBBQi4eskzqVG3SCX2CwJF/aTZqUcuTAKBggqhkjOPQQD\n" + "AgNHADBEAiBPYvLOCIvPDtq3vMF7A2z7t7JfcCmbC7g8ftEVJucJBwIgepf+XjTb\n" + "L7RCE16p7iVkpHUrWAOl7zDxqD+jaji5MkQ=\n" + "-----END CERTIFICATE-----\n"; + +static const std::string kTestServerPrivKey = + "-----BEGIN PRIVATE KEY-----\n" + "MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgSCaskWPtutIgh8uQ\n" + "UBH6ZIea5Kxm7m6kkGNkd8FYPSOhRANCAAQl0aL8t0YbStCoZXuPL1JpHLiNDhey\n" + "vKPbFTxOKFZ6duJfnAmK9wba4In6Bz4I+L6u59dOasyu54ARx/iwZmCI\n" + "-----END PRIVATE KEY-----\n"; + +static const std::string kTestClientCert = + "-----BEGIN CERTIFICATE-----\n" + "MIIBlzCCAT2gAwIBAgIBATAKBggqhkjOPQQDAjAzMQswCQYDVQQGEwJVUzEQMA4G\n" + "A1UECgwHQW5kcm9pZDESMBAGA1UEAwwJbG9jYWxob3N0MB4XDTE5MTEwOTAxNTAy\n" + "OFoXDTI5MTEwNjAxNTAyOFowMzELMAkGA1UEBhMCVVMxEDAOBgNVBAoMB0FuZHJv\n" + "aWQxEjAQBgNVBAMMCWxvY2FsaG9zdDBZMBMGByqGSM49AgEGCCqGSM49AwEHA0IA\n" + "BGW+RuoEIzbt42zAuZzbXaC0bvh8n4OLFDnqkkW6kWA43GYg/mUMVc9vg/nuxyuM\n" + "aT0KqbTaLhm+NjCXVRnxBrajQjBAMA8GA1UdEwEB/wQFMAMBAf8wDgYDVR0PAQH/\n" + "BAQDAgGGMB0GA1UdDgQWBBTjCaC8/NXgdBz9WlMVCNwhx7jn0jAKBggqhkjOPQQD\n" + "AgNIADBFAiB/xp2boj7b1KK2saS6BL59deo/TvfgZ+u8HPq4k4VP3gIhAMXswp9W\n" + "XdlziccQdj+0KpbUojDKeHOr4fIj/+LxsWPa\n" + "-----END CERTIFICATE-----\n"; + +static const std::string kTestClientPrivKey = + "-----BEGIN PRIVATE KEY-----\n" + "MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgFw/CWY1f6TSB70AF\n" + "yVe8n6QdYFu8HW5t/tij2SrXx42hRANCAARlvkbqBCM27eNswLmc212gtG74fJ+D\n" + "ixQ56pJFupFgONxmIP5lDFXPb4P57scrjGk9Cqm02i4ZvjYwl1UZ8Qa2\n" + "-----END PRIVATE KEY-----\n"; + +class AdbWifiPairingConnectionTest : public testing::Test { + protected: + virtual void SetUp() override {} + + virtual void TearDown() override {} + + void initPairing(const std::vector server_pswd, + const std::vector client_pswd) { + std::vector cert; + std::vector key; + // Include the null-byte as well. + cert.assign(reinterpret_cast(kTestServerCert.data()), + reinterpret_cast(kTestServerCert.data()) + + kTestServerCert.size() + 1); + key.assign(reinterpret_cast(kTestServerPrivKey.data()), + reinterpret_cast(kTestServerPrivKey.data()) + + kTestServerPrivKey.size() + 1); + server_ = PairingServer::create(server_pswd, server_info_, cert, key, kDefaultPairingPort); + cert.assign(reinterpret_cast(kTestClientCert.data()), + reinterpret_cast(kTestClientCert.data()) + + kTestClientCert.size() + 1); + key.assign(reinterpret_cast(kTestClientPrivKey.data()), + reinterpret_cast(kTestClientPrivKey.data()) + + kTestClientPrivKey.size() + 1); + client_ = PairingClient::create(client_pswd, client_info_, cert, key, "127.0.0.1"); + } + + std::unique_ptr createServer(const std::vector pswd) { + std::vector cert; + std::vector key; + // Include the null-byte as well. + cert.assign(reinterpret_cast(kTestServerCert.data()), + reinterpret_cast(kTestServerCert.data()) + + kTestServerCert.size() + 1); + key.assign(reinterpret_cast(kTestServerPrivKey.data()), + reinterpret_cast(kTestServerPrivKey.data()) + + kTestServerPrivKey.size() + 1); + return PairingServer::create(pswd, server_info_, cert, key, kDefaultPairingPort); + } + + std::unique_ptr createClient(const std::vector pswd) { + std::vector cert; + std::vector key; + // Include the null-byte as well. + cert.assign(reinterpret_cast(kTestClientCert.data()), + reinterpret_cast(kTestClientCert.data()) + + kTestClientCert.size() + 1); + key.assign(reinterpret_cast(kTestClientPrivKey.data()), + reinterpret_cast(kTestClientPrivKey.data()) + + kTestClientPrivKey.size() + 1); + return PairingClient::create(pswd, client_info_, cert, key, "127.0.0.1"); + } + + std::unique_ptr server_; + const PeerInfo server_info_ = { + .name = "my_server_name", + .guid = "my_server_guid", + }; + std::unique_ptr client_; + const PeerInfo client_info_ = { + .name = "my_client_name", + .guid = "my_client_guid", + }; +}; + +TEST_F(AdbWifiPairingConnectionTest, ServerCreation) { + // All parameters bad + auto server = PairingServer::create({}, {}, {}, {}, -1); + EXPECT_EQ(nullptr, server); + // Bad password + server = PairingServer::create({}, server_info_, {0x01}, {0x01}, -1); + EXPECT_EQ(nullptr, server); + // Bad peer_info + server = PairingServer::create({0x01}, {}, {0x01}, {0x01}, -1); + EXPECT_EQ(nullptr, server); + // Bad certificate + server = PairingServer::create({0x01}, server_info_, {}, {0x01}, -1); + EXPECT_EQ(nullptr, server); + // Bad private key + server = PairingServer::create({0x01}, server_info_, {0x01}, {}, -1); + EXPECT_EQ(nullptr, server); + // Bad port + server = PairingServer::create({0x01}, server_info_, {0x01}, {0x01}, -1); + EXPECT_EQ(nullptr, server); + // Valid params + server = PairingServer::create({0x01}, server_info_, {0x01}, {0x01}, 7776); + EXPECT_NE(nullptr, server); +} + +TEST_F(AdbWifiPairingConnectionTest, ClientCreation) { + // All parameters bad + auto client = PairingClient::create({}, client_info_, {}, {}, ""); + EXPECT_EQ(nullptr, client); + // Bad password + client = PairingClient::create({}, client_info_, {0x01}, {0x01}, "127.0.0.1"); + EXPECT_EQ(nullptr, client); + // Bad peer_info + client = PairingClient::create({0x01}, {}, {0x01}, {0x01}, "127.0.0.1"); + EXPECT_EQ(nullptr, client); + // Bad certificate + client = PairingClient::create({0x01}, client_info_, {}, {0x01}, "127.0.0.1"); + EXPECT_EQ(nullptr, client); + // Bad private key + client = PairingClient::create({0x01}, client_info_, {0x01}, {}, "127.0.0.1"); + EXPECT_EQ(nullptr, client); + // Bad ip address + client = PairingClient::create({0x01}, client_info_, {0x01}, {0x01}, ""); + EXPECT_EQ(nullptr, client); + // Valid params + client = PairingClient::create({0x01}, client_info_, {0x01}, {0x01}, "127.0.0.1"); + EXPECT_NE(nullptr, client); +} + +TEST_F(AdbWifiPairingConnectionTest, SmokeValidPairing) { + std::vector pswd{0x01, 0x03, 0x05, 0x07}; + initPairing(pswd, pswd); + + // Start the server first, to open the port for connections + std::mutex server_mutex; + std::condition_variable server_cv; + std::unique_lock server_lock(server_mutex); + + auto server_callback = [&](const PeerInfo* peer_info, const std::vector* cert, + void* opaque) { + ASSERT_NE(nullptr, peer_info); + ASSERT_NE(nullptr, cert); + EXPECT_FALSE(cert->empty()); + EXPECT_EQ(nullptr, opaque); + + // Verify the peer_info and cert + ASSERT_EQ(strlen(peer_info->name), strlen(client_info_.name)); + EXPECT_EQ(::memcmp(peer_info->name, client_info_.name, strlen(client_info_.name)), 0); + ASSERT_EQ(strlen(peer_info->guid), strlen(client_info_.guid)); + EXPECT_EQ(::memcmp(peer_info->guid, client_info_.guid, strlen(client_info_.guid)), 0); + ASSERT_EQ(cert->size(), kTestClientCert.size() + 1); + EXPECT_EQ(::memcmp(cert->data(), kTestClientCert.data(), kTestClientCert.size() + 1), 0); + + std::lock_guard lock(server_mutex); + server_cv.notify_one(); + }; + ASSERT_TRUE(server_->start(server_callback, nullptr)); + + // Start the client + bool got_valid_pairing = false; + std::mutex client_mutex; + std::condition_variable client_cv; + std::unique_lock client_lock(client_mutex); + auto client_callback = [&](const PeerInfo* peer_info, const std::vector* cert, + void* opaque) { + ASSERT_NE(nullptr, peer_info); + ASSERT_NE(nullptr, cert); + EXPECT_FALSE(cert->empty()); + EXPECT_EQ(nullptr, opaque); + + // Verify the peer_info and cert + ASSERT_EQ(strlen(peer_info->name), strlen(server_info_.name)); + EXPECT_EQ(::memcmp(peer_info->name, server_info_.name, strlen(server_info_.name)), 0); + ASSERT_EQ(strlen(peer_info->guid), strlen(server_info_.guid)); + EXPECT_EQ(::memcmp(peer_info->guid, server_info_.guid, strlen(server_info_.guid)), 0); + ASSERT_EQ(cert->size(), kTestServerCert.size() + 1); + EXPECT_EQ(::memcmp(cert->data(), kTestServerCert.data(), kTestServerCert.size() + 1), 0); + + got_valid_pairing = (peer_info != nullptr && cert != nullptr && !cert->empty()); + std::lock_guard lock(client_mutex); + client_cv.notify_one(); + }; + ASSERT_TRUE(client_->start(client_callback, nullptr)); + client_cv.wait(client_lock); + + // Kill server if the pairing failed, since server only shuts down when + // it gets a valid pairing. + if (!got_valid_pairing) { + server_lock.unlock(); + server_.reset(); + } else { + server_cv.wait(server_lock); + } +} + +TEST_F(AdbWifiPairingConnectionTest, CancelPairing) { + std::vector pswd{0x01, 0x03, 0x05, 0x07}; + std::vector pswd2{0x01, 0x03, 0x05, 0x06}; + initPairing(pswd, pswd2); + + // Start the server first, to open the port for connections + std::mutex server_mutex; + std::condition_variable server_cv; + std::unique_lock server_lock(server_mutex); + + bool server_got_valid_pairing = true; + auto server_callback = [&](const PeerInfo* peer_info, const std::vector* cert, + void* opaque) { + // Pairing will be cancelled, which should initiate this callback with + // empty values. + ASSERT_EQ(nullptr, peer_info); + ASSERT_EQ(nullptr, cert); + EXPECT_EQ(nullptr, opaque); + std::lock_guard lock(server_mutex); + server_cv.notify_one(); + server_got_valid_pairing = false; + }; + ASSERT_TRUE(server_->start(server_callback, nullptr)); + + // Start the client (should fail because of different passwords). + bool got_valid_pairing = false; + std::mutex client_mutex; + std::condition_variable client_cv; + std::unique_lock client_lock(client_mutex); + auto client_callback = [&](const PeerInfo* peer_info, const std::vector* cert, + void* opaque) { + ASSERT_EQ(nullptr, peer_info); + ASSERT_EQ(nullptr, cert); + EXPECT_EQ(nullptr, opaque); + + got_valid_pairing = (peer_info != nullptr && cert != nullptr && !cert->empty()); + std::lock_guard lock(client_mutex); + client_cv.notify_one(); + }; + ASSERT_TRUE(client_->start(client_callback, nullptr)); + client_cv.wait(client_lock); + + server_lock.unlock(); + // This should trigger the callback to be on the same thread. + server_.reset(); + EXPECT_FALSE(server_got_valid_pairing); +} + +TEST_F(AdbWifiPairingConnectionTest, MultipleClientsAllFail) { + std::vector pswd{0x01, 0x03, 0x05, 0x07}; + std::vector pswd2{0x01, 0x03, 0x05, 0x06}; + + auto server = createServer(pswd); + ASSERT_NE(nullptr, server); + // Start the server first, to open the port for connections + std::mutex server_mutex; + std::condition_variable server_cv; + std::unique_lock server_lock(server_mutex); + + bool server_got_valid_pairing = true; + auto server_callback = [&](const PeerInfo* peer_info, const std::vector* cert, + void* opaque) { + // Pairing will be cancelled, which should initiate this callback with + // empty values. + ASSERT_EQ(nullptr, peer_info); + ASSERT_EQ(nullptr, cert); + EXPECT_EQ(nullptr, opaque); + std::lock_guard lock(server_mutex); + server_cv.notify_one(); + server_got_valid_pairing = false; + }; + ASSERT_TRUE(server->start(server_callback, nullptr)); + + // Start multiple clients, all with bad passwords + std::vector> clients; + int num_clients_done = 0; + int test_num_clients = 5; + std::mutex client_mutex; + std::condition_variable client_cv; + std::unique_lock client_lock(client_mutex); + while (clients.size() < test_num_clients) { + auto client = createClient(pswd2); + ASSERT_NE(nullptr, client); + auto callback = [&](const PeerInfo* peer_info, const std::vector* cert, + void* opaque) { + ASSERT_EQ(nullptr, peer_info); + ASSERT_EQ(nullptr, cert); + EXPECT_EQ(nullptr, opaque); + + { + std::lock_guard lock(client_mutex); + num_clients_done++; + } + client_cv.notify_one(); + }; + ASSERT_TRUE(client->start(callback, nullptr)); + clients.push_back(std::move(client)); + } + + client_cv.wait(client_lock, [&]() { return (num_clients_done == test_num_clients); }); + EXPECT_EQ(num_clients_done, test_num_clients); + + server_lock.unlock(); + // This should trigger the callback to be on the same thread. + server.reset(); + EXPECT_FALSE(server_got_valid_pairing); +} + +TEST_F(AdbWifiPairingConnectionTest, MultipleClientsOnePass) { + // Send multiple clients with bad passwords, but send the last one with the + // correct password. + std::vector pswd{0x01, 0x03, 0x05, 0x07}; + std::vector pswd2{0x01, 0x03, 0x05, 0x06}; + + auto server = createServer(pswd); + ASSERT_NE(nullptr, server); + // Start the server first, to open the port for connections + std::mutex server_mutex; + std::condition_variable server_cv; + std::unique_lock server_lock(server_mutex); + + bool server_got_valid_pairing = false; + auto server_callback = [&](const PeerInfo* peer_info, const std::vector* cert, + void* opaque) { + // Pairing will be cancelled, which should initiate this callback with + // empty values. + + ASSERT_NE(nullptr, peer_info); + ASSERT_NE(nullptr, cert); + EXPECT_FALSE(cert->empty()); + EXPECT_EQ(nullptr, opaque); + + // Verify the peer_info and cert + ASSERT_EQ(strlen(peer_info->name), strlen(client_info_.name)); + EXPECT_EQ(::memcmp(peer_info->name, client_info_.name, strlen(client_info_.name)), 0); + ASSERT_EQ(strlen(peer_info->guid), strlen(client_info_.guid)); + EXPECT_EQ(::memcmp(peer_info->guid, client_info_.guid, strlen(client_info_.guid)), 0); + ASSERT_EQ(cert->size(), kTestClientCert.size() + 1); + EXPECT_EQ(::memcmp(cert->data(), kTestClientCert.data(), kTestClientCert.size() + 1), 0); + + std::lock_guard lock(server_mutex); + server_got_valid_pairing = true; + server_cv.notify_one(); + }; + ASSERT_TRUE(server->start(server_callback, nullptr)); + + // Start multiple clients, all with bad passwords (except for the last one) + std::vector> clients; + int num_clients_done = 0; + int test_num_clients = 5; + std::mutex client_mutex; + std::condition_variable client_cv; + std::unique_lock client_lock(client_mutex); + bool got_valid_pairing = false; + while (clients.size() < test_num_clients) { + std::unique_ptr client; + if (clients.size() == test_num_clients - 1) { + // Make this one have the valid password + client = createClient(pswd); + ASSERT_NE(nullptr, client); + auto callback = [&](const PeerInfo* peer_info, const std::vector* cert, + void* opaque) { + ASSERT_NE(nullptr, peer_info); + ASSERT_NE(nullptr, cert); + EXPECT_FALSE(cert->empty()); + EXPECT_EQ(nullptr, opaque); + + // Verify the peer_info and cert + ASSERT_EQ(strlen(peer_info->name), strlen(server_info_.name)); + EXPECT_EQ(::memcmp(peer_info->name, server_info_.name, strlen(server_info_.name)), + 0); + ASSERT_EQ(strlen(peer_info->guid), strlen(server_info_.guid)); + EXPECT_EQ(::memcmp(peer_info->guid, server_info_.guid, strlen(server_info_.guid)), + 0); + ASSERT_EQ(cert->size(), kTestServerCert.size() + 1); + EXPECT_EQ( + ::memcmp(cert->data(), kTestServerCert.data(), kTestServerCert.size() + 1), + 0); + got_valid_pairing = (peer_info != nullptr && cert != nullptr && !cert->empty()); + + { + std::lock_guard lock(client_mutex); + num_clients_done++; + } + client_cv.notify_one(); + }; + ASSERT_TRUE(client->start(callback, nullptr)); + } else { + client = createClient(pswd2); + ASSERT_NE(nullptr, client); + auto callback = [&](const PeerInfo* peer_info, const std::vector* cert, + void* opaque) { + ASSERT_EQ(nullptr, peer_info); + ASSERT_EQ(nullptr, cert); + EXPECT_EQ(nullptr, opaque); + + { + std::lock_guard lock(client_mutex); + num_clients_done++; + } + client_cv.notify_one(); + }; + ASSERT_TRUE(client->start(callback, nullptr)); + } + clients.push_back(std::move(client)); + } + + client_cv.wait(client_lock, [&]() { return (num_clients_done == test_num_clients); }); + EXPECT_EQ(num_clients_done, test_num_clients); + + // Kill server if the pairing failed, since server only shuts down when + // it gets a valid pairing. + if (!got_valid_pairing) { + server_lock.unlock(); + server_.reset(); + } else { + server_cv.wait(server_lock); + } + EXPECT_TRUE(server_got_valid_pairing); +} + +} // namespace pairing +} // namespace adbwifi diff --git a/adb/client/pairing/tests/pairing_server.cpp b/adb/client/pairing/tests/pairing_server.cpp new file mode 100644 index 000000000..9201e7a0e --- /dev/null +++ b/adb/client/pairing/tests/pairing_server.cpp @@ -0,0 +1,426 @@ +/* + * Copyright (C) 2019 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "adbwifi/pairing/pairing_server.h" + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +namespace adbwifi { +namespace pairing { + +using android::base::ScopedLockAssertion; +using android::base::unique_fd; + +namespace { + +// The implimentation has two background threads running: one to handle and +// accept any new pairing connection requests (socket accept), and the other to +// handle connection events (connection started, connection finished). +class PairingServerImpl : public PairingServer { + public: + virtual ~PairingServerImpl(); + + // All parameters must be non-empty. + explicit PairingServerImpl(const Data& pswd, const PeerInfo& peer_info, const Data& cert, + const Data& priv_key, int port); + + // Starts the pairing server. This call is non-blocking. Upon completion, + // if the pairing was successful, then |cb| will be called with the PublicKeyHeader + // containing the info of the trusted peer. Otherwise, |cb| will be + // called with an empty value. Start can only be called once in the lifetime + // of this object. + // + // Returns true if PairingServer was successfully started. Otherwise, + // returns false. + virtual bool start(PairingConnection::ResultCallback cb, void* opaque) override; + + private: + // Setup the server socket to accept incoming connections + bool setupServer(); + // Force stop the server thread. + void stopServer(); + + // handles a new pairing client connection + bool handleNewClientConnection(int fd) EXCLUDES(conn_mutex_); + + // ======== connection events thread ============= + std::mutex conn_mutex_; + std::condition_variable conn_cv_; + + using FdVal = int; + using ConnectionPtr = std::unique_ptr; + using NewConnectionEvent = std::tuple; + // + using ConnectionFinishedEvent = std::tuple, + std::optional, std::optional>; + using ConnectionEvent = std::variant; + // Queue for connections to write into. We have a separate queue to read + // from, in order to minimize the time the server thread is blocked. + std::deque conn_write_queue_ GUARDED_BY(conn_mutex_); + std::deque conn_read_queue_; + // Map of fds to their PairingConnections currently running. + std::unordered_map connections_; + + // Two threads launched when starting the pairing server: + // 1) A server thread that waits for incoming client connections, and + // 2) A connection events thread that synchonizes events from all of the + // clients, since each PairingConnection is running in it's own thread. + void startConnectionEventsThread(); + void startServerThread(); + + std::thread conn_events_thread_; + void connectionEventsWorker(); + std::thread server_thread_; + void serverWorker(); + bool is_terminate_ GUARDED_BY(conn_mutex_) = false; + + enum class State { + Ready, + Running, + Stopped, + }; + State state_ = State::Ready; + Data pswd_; + PeerInfo peer_info_; + Data cert_; + Data priv_key_; + int port_ = -1; + + PairingConnection::ResultCallback cb_; + void* opaque_ = nullptr; + bool got_valid_pairing_ = false; + + static const int kEpollConstSocket = 0; + // Used to break the server thread from epoll_wait + static const int kEpollConstEventFd = 1; + unique_fd epoll_fd_; + unique_fd server_fd_; + unique_fd event_fd_; +}; // PairingServerImpl + +PairingServerImpl::PairingServerImpl(const Data& pswd, const PeerInfo& peer_info, const Data& cert, + const Data& priv_key, int port) + : pswd_(pswd), peer_info_(peer_info), cert_(cert), priv_key_(priv_key), port_(port) { + CHECK(!pswd_.empty() && !cert_.empty() && !priv_key_.empty() && port_ > 0); + CHECK('\0' == peer_info.name[kPeerNameLength - 1] && + '\0' == peer_info.guid[kPeerGuidLength - 1] && strlen(peer_info.name) > 0 && + strlen(peer_info.guid) > 0); +} + +PairingServerImpl::~PairingServerImpl() { + // Since these connections have references to us, let's make sure they + // destruct before us. + if (server_thread_.joinable()) { + stopServer(); + server_thread_.join(); + } + + { + std::lock_guard lock(conn_mutex_); + is_terminate_ = true; + } + conn_cv_.notify_one(); + if (conn_events_thread_.joinable()) { + conn_events_thread_.join(); + } + + // Notify the cb_ if it hasn't already. + if (!got_valid_pairing_ && cb_ != nullptr) { + cb_(nullptr, nullptr, opaque_); + } +} + +bool PairingServerImpl::start(PairingConnection::ResultCallback cb, void* opaque) { + cb_ = cb; + opaque_ = opaque; + + if (state_ != State::Ready) { + LOG(ERROR) << "PairingServer already running or stopped"; + return false; + } + + if (!setupServer()) { + LOG(ERROR) << "Unable to start PairingServer"; + state_ = State::Stopped; + return false; + } + + state_ = State::Running; + return true; +} + +void PairingServerImpl::stopServer() { + if (event_fd_.get() == -1) { + return; + } + uint64_t value = 1; + ssize_t rc = write(event_fd_.get(), &value, sizeof(value)); + if (rc == -1) { + // This can happen if the server didn't start. + PLOG(ERROR) << "write to eventfd failed"; + } else if (rc != sizeof(value)) { + LOG(FATAL) << "write to event returned short (" << rc << ")"; + } +} + +bool PairingServerImpl::setupServer() { + epoll_fd_.reset(epoll_create1(EPOLL_CLOEXEC)); + if (epoll_fd_ == -1) { + PLOG(ERROR) << "failed to create epoll fd"; + return false; + } + + event_fd_.reset(eventfd(0, EFD_CLOEXEC | EFD_NONBLOCK)); + if (event_fd_ == -1) { + PLOG(ERROR) << "failed to create eventfd"; + return false; + } + + server_fd_.reset(socket_inaddr_any_server(port_, SOCK_STREAM)); + if (server_fd_.get() == -1) { + PLOG(ERROR) << "Failed to start pairing connection server"; + return false; + } else if (fcntl(server_fd_.get(), F_SETFD, FD_CLOEXEC) != 0) { + PLOG(ERROR) << "Failed to make server socket cloexec"; + return false; + } else if (fcntl(server_fd_.get(), F_SETFD, O_NONBLOCK) != 0) { + PLOG(ERROR) << "Failed to make server socket nonblocking"; + return false; + } + + startConnectionEventsThread(); + startServerThread(); + return true; +} + +void PairingServerImpl::startServerThread() { + server_thread_ = std::thread([this]() { serverWorker(); }); +} + +void PairingServerImpl::startConnectionEventsThread() { + conn_events_thread_ = std::thread([this]() { connectionEventsWorker(); }); +} + +void PairingServerImpl::serverWorker() { + { + struct epoll_event event; + event.events = EPOLLIN; + event.data.u64 = kEpollConstSocket; + CHECK_EQ(0, epoll_ctl(epoll_fd_.get(), EPOLL_CTL_ADD, server_fd_.get(), &event)); + } + + { + struct epoll_event event; + event.events = EPOLLIN; + event.data.u64 = kEpollConstEventFd; + CHECK_EQ(0, epoll_ctl(epoll_fd_.get(), EPOLL_CTL_ADD, event_fd_.get(), &event)); + } + + while (true) { + struct epoll_event events[2]; + int rc = TEMP_FAILURE_RETRY(epoll_wait(epoll_fd_.get(), events, 2, -1)); + if (rc == -1) { + PLOG(ERROR) << "epoll_wait failed"; + return; + } else if (rc == 0) { + LOG(ERROR) << "epoll_wait returned 0"; + return; + } + + for (int i = 0; i < rc; ++i) { + struct epoll_event& event = events[i]; + switch (event.data.u64) { + case kEpollConstSocket: + handleNewClientConnection(server_fd_.get()); + break; + case kEpollConstEventFd: + uint64_t dummy; + int rc = TEMP_FAILURE_RETRY(read(event_fd_.get(), &dummy, sizeof(dummy))); + if (rc != sizeof(dummy)) { + PLOG(FATAL) << "failed to read from eventfd (rc=" << rc << ")"; + } + return; + } + } + } +} + +void PairingServerImpl::connectionEventsWorker() { + for (;;) { + // Transfer the write queue to the read queue. + { + std::unique_lock lock(conn_mutex_); + ScopedLockAssertion assume_locked(conn_mutex_); + + if (is_terminate_) { + // We check |is_terminate_| twice because condition_variable's + // notify() only wakes up a thread if it is in the wait state + // prior to notify(). Furthermore, we aren't holding the mutex + // when processing the events in |conn_read_queue_|. + return; + } + if (conn_write_queue_.empty()) { + // We need to wait for new events, or the termination signal. + conn_cv_.wait(lock, [this]() REQUIRES(conn_mutex_) { + return (is_terminate_ || !conn_write_queue_.empty()); + }); + } + if (is_terminate_) { + // We're done. + return; + } + // Move all events into the read queue. + conn_read_queue_ = std::move(conn_write_queue_); + conn_write_queue_.clear(); + } + + // Process all events in the read queue. + while (conn_read_queue_.size() > 0) { + auto& event = conn_read_queue_.front(); + if (auto* p = std::get_if(&event)) { + // Ignore if we are already at the max number of connections + if (connections_.size() >= internal::kMaxConnections) { + conn_read_queue_.pop_front(); + continue; + } + auto [ufd, connection] = std::move(*p); + int fd = ufd.release(); + bool started = connection->start( + fd, + [fd](const PeerInfo* peer_info, const Data* cert, void* opaque) { + auto* p = reinterpret_cast(opaque); + + ConnectionFinishedEvent event; + if (peer_info != nullptr && cert != nullptr) { + event = std::make_tuple(fd, std::string(peer_info->name), + std::string(peer_info->guid), Data(*cert)); + } else { + event = std::make_tuple(fd, std::nullopt, std::nullopt, + std::nullopt); + } + { + std::lock_guard lock(p->conn_mutex_); + p->conn_write_queue_.push_back(std::move(event)); + } + p->conn_cv_.notify_one(); + }, + this); + if (!started) { + LOG(ERROR) << "PairingServer unable to start a PairingConnection fd=" << fd; + ufd.reset(fd); + } else { + connections_[fd] = std::move(connection); + } + } else if (auto* p = std::get_if(&event)) { + auto [fd, name, guid, cert] = std::move(*p); + if (name.has_value() && guid.has_value() && cert.has_value() && !name->empty() && + !guid->empty() && !cert->empty()) { + // Valid pairing. Let's shutdown the server and close any + // pairing connections in progress. + stopServer(); + connections_.clear(); + + CHECK_LE(name->size(), kPeerNameLength); + CHECK_LE(guid->size(), kPeerGuidLength); + PeerInfo info = {}; + strncpy(info.name, name->data(), name->size()); + strncpy(info.guid, guid->data(), guid->size()); + + cb_(&info, &*cert, opaque_); + + got_valid_pairing_ = true; + return; + } + // Invalid pairing. Close the invalid connection. + if (connections_.find(fd) != connections_.end()) { + connections_.erase(fd); + } + } + conn_read_queue_.pop_front(); + } + } +} + +bool PairingServerImpl::handleNewClientConnection(int fd) { + unique_fd ufd(TEMP_FAILURE_RETRY(accept4(fd, nullptr, nullptr, SOCK_CLOEXEC))); + if (ufd == -1) { + PLOG(WARNING) << "adb_socket_accept failed fd=" << fd; + return false; + } + auto connection = PairingConnection::create(PairingConnection::Role::Server, pswd_, peer_info_, + cert_, priv_key_); + if (connection == nullptr) { + LOG(ERROR) << "PairingServer unable to create a PairingConnection fd=" << fd; + return false; + } + // send the new connection to the connection thread for further processing + NewConnectionEvent event = std::make_tuple(std::move(ufd), std::move(connection)); + { + std::lock_guard lock(conn_mutex_); + conn_write_queue_.push_back(std::move(event)); + } + conn_cv_.notify_one(); + + return true; +} + +} // namespace + +// static +std::unique_ptr PairingServer::create(const Data& pswd, const PeerInfo& peer_info, + const Data& cert, const Data& priv_key, + int port) { + if (pswd.empty() || cert.empty() || priv_key.empty() || port <= 0) { + return nullptr; + } + // Make sure peer_info has a non-empty, null-terminated string for guid and + // name. + if ('\0' != peer_info.name[kPeerNameLength - 1] || + '\0' != peer_info.guid[kPeerGuidLength - 1] || strlen(peer_info.name) == 0 || + strlen(peer_info.guid) == 0) { + LOG(ERROR) << "The GUID/short name fields are empty or not null-terminated"; + return nullptr; + } + + if (port != kDefaultPairingPort) { + LOG(WARNING) << "Starting server with non-default pairing port=" << port; + } + + return std::unique_ptr( + new PairingServerImpl(pswd, peer_info, cert, priv_key, port)); +} + +} // namespace pairing +} // namespace adbwifi diff --git a/adb/client/pairing/tests/pairing_server.h b/adb/client/pairing/tests/pairing_server.h new file mode 100644 index 000000000..6fb51ccf1 --- /dev/null +++ b/adb/client/pairing/tests/pairing_server.h @@ -0,0 +1,70 @@ +/* + * Copyright (C) 2019 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +#include +#include +#include +#include + +#include + +namespace adbwifi { +namespace pairing { + +// PairingServer is the server side of the PairingConnection protocol. It will +// listen for incoming PairingClient connections, and allocate a new +// PairingConnection per client for processing. PairingServer can handle multiple +// connections, but the first one to establish the pairing will be the only one +// to succeed. All others will be disconnected. +// +// See pairing_connection_test.cpp for example usage. +// +class PairingServer { + public: + using Data = std::vector; + + virtual ~PairingServer() = default; + + // Starts the pairing server. This call is non-blocking. Upon completion, + // if the pairing was successful, then |cb| will be called with the PeerInfo + // containing the info of the trusted peer. Otherwise, |cb| will be + // called with an empty value. Start can only be called once in the lifetime + // of this object. + // + // Returns true if PairingServer was successfully started. Otherwise, + // returns false. + virtual bool start(PairingConnection::ResultCallback cb, void* opaque) = 0; + + // Creates a new PairingServer instance. May return null if unable + // to create an instance. |pswd|, |certificate| and |priv_key| cannot + // be empty. |port| is the port PairingServer will listen to PairingClient + // connections on. |peer_info| must contain non-empty strings for the guid + // and name fields. + static std::unique_ptr create(const Data& pswd, const PeerInfo& peer_info, + const Data& certificate, const Data& priv_key, + int port); + + protected: + PairingServer() = default; +}; // class PairingServer + +} // namespace pairing +} // namespace adbwifi diff --git a/adb/client/transport_mdns.cpp b/adb/client/transport_mdns.cpp index f5811a4ca..ff1f7b4c6 100644 --- a/adb/client/transport_mdns.cpp +++ b/adb/client/transport_mdns.cpp @@ -24,15 +24,19 @@ #include #endif +#include #include #include #include +#include #include #include "adb_client.h" #include "adb_mdns.h" #include "adb_trace.h" +#include "adb_utils.h" +#include "adb_wifi.h" #include "fdevent/fdevent.h" #include "sysdeps.h" @@ -48,9 +52,17 @@ static int adb_DNSServiceIndexByName(const char* regType) { return -1; } -static bool adb_DNSServiceShouldConnect(const char* regType) { +static bool adb_DNSServiceShouldConnect(const char* regType, const char* serviceName) { int index = adb_DNSServiceIndexByName(regType); - return index == kADBTransportServiceRefIndex; + if (index == kADBTransportServiceRefIndex) { + // Ignore adb-EMULATOR* service names, as it interferes with the + // emulator ports that are already connected. + if (android::base::StartsWith(serviceName, "adb-EMULATOR")) { + LOG(INFO) << "Ignoring emulator transport service [" << serviceName << "]"; + return false; + } + } + return (index == kADBTransportServiceRefIndex || index == kADBSecureConnectServiceRefIndex); } // Use adb_DNSServiceRefSockFD() instead of calling DNSServiceRefSockFD() @@ -88,8 +100,10 @@ class AsyncServiceRef { return; } - DNSServiceRefDeallocate(sdRef_); + // Order matters here! Must destroy the fdevent first since it has a + // reference to |sdRef_|. fdevent_destroy(fde_); + DNSServiceRefDeallocate(sdRef_); } protected: @@ -97,6 +111,10 @@ class AsyncServiceRef { void Initialize() { fde_ = fdevent_create(adb_DNSServiceRefSockFD(sdRef_), pump_service_ref, &sdRef_); + if (fde_ == nullptr) { + D("Unable to create fdevent"); + return; + } fdevent_set(fde_, FDE_READ); initialized_ = true; } @@ -142,16 +160,29 @@ class ResolvedService : public AsyncServiceRef { D("Client version: %d Service version: %d\n", clientVersion_, serviceVersion_); } + bool ConnectSecureWifiDevice() { + if (!adb_wifi_is_known_host(serviceName_)) { + LOG(INFO) << "serviceName=" << serviceName_ << " not in keystore"; + return false; + } + + std::string response; + connect_device(android::base::StringPrintf(addr_format_.c_str(), ip_addr_, port_), + &response); + D("Secure connect to %s regtype %s (%s:%hu) : %s", serviceName_.c_str(), regType_.c_str(), + ip_addr_, port_, response.c_str()); + return true; + } + void Connect(const sockaddr* address) { sa_family_ = address->sa_family; - const char* addr_format; if (sa_family_ == AF_INET) { ip_addr_data_ = &reinterpret_cast(address)->sin_addr; - addr_format = "%s:%hu"; + addr_format_ = "%s:%hu"; } else if (sa_family_ == AF_INET6) { ip_addr_data_ = &reinterpret_cast(address)->sin6_addr; - addr_format = "[%s]:%hu"; + addr_format_ = "[%s]:%hu"; } else { // Should be impossible D("mDNS resolved non-IP address."); return; @@ -165,11 +196,19 @@ class ResolvedService : public AsyncServiceRef { // adb secure service needs to do something different from just // connecting here. - if (adb_DNSServiceShouldConnect(regType_.c_str())) { + if (adb_DNSServiceShouldConnect(regType_.c_str(), serviceName_.c_str())) { std::string response; - connect_device(android::base::StringPrintf(addr_format, ip_addr_, port_), &response); - D("Connect to %s regtype %s (%s:%hu) : %s", serviceName_.c_str(), regType_.c_str(), - ip_addr_, port_, response.c_str()); + D("Attempting to serviceName=[%s], regtype=[%s] ipaddr=(%s:%hu)", serviceName_.c_str(), + regType_.c_str(), ip_addr_, port_); + int index = adb_DNSServiceIndexByName(regType_.c_str()); + if (index == kADBSecureConnectServiceRefIndex) { + ConnectSecureWifiDevice(); + } else { + connect_device(android::base::StringPrintf(addr_format_.c_str(), ip_addr_, port_), + &response); + D("Connect to %s regtype %s (%s:%hu) : %s", serviceName_.c_str(), regType_.c_str(), + ip_addr_, port_, response.c_str()); + } } else { D("Not immediately connecting to serviceName=[%s], regtype=[%s] ipaddr=(%s:%hu)", serviceName_.c_str(), regType_.c_str(), ip_addr_, port_); @@ -192,6 +231,8 @@ class ResolvedService : public AsyncServiceRef { std::string hostTarget() const { return hosttarget_; } + std::string serviceName() const { return serviceName_; } + std::string ipAddress() const { return ip_addr_; } uint16_t port() const { return port_; } @@ -206,8 +247,12 @@ class ResolvedService : public AsyncServiceRef { static void forEachService(const ServiceRegistry& services, const std::string& hostname, adb_secure_foreach_service_callback cb); + static bool connectByServiceName(const ServiceRegistry& services, + const std::string& service_name); + private: int clientVersion_ = ADB_SECURE_CLIENT_VERSION; + std::string addr_format_; std::string serviceName_; std::string regType_; std::string hosttarget_; @@ -236,35 +281,52 @@ void ResolvedService::initAdbSecure() { // static void ResolvedService::forEachService(const ServiceRegistry& services, - const std::string& wanted_host, + const std::string& wanted_service_name, adb_secure_foreach_service_callback cb) { initAdbSecure(); for (auto service : services) { - auto hostname = service->hostTarget(); + auto service_name = service->serviceName(); auto ip = service->ipAddress(); auto port = service->port(); - if (wanted_host == "") { - cb(hostname.c_str(), ip.c_str(), port); - } else if (hostname == wanted_host) { - cb(hostname.c_str(), ip.c_str(), port); + if (wanted_service_name == "") { + cb(service_name.c_str(), ip.c_str(), port); + } else if (service_name == wanted_service_name) { + cb(service_name.c_str(), ip.c_str(), port); } } } // static -void adb_secure_foreach_pairing_service(const char* host_name, - adb_secure_foreach_service_callback cb) { - ResolvedService::forEachService(*ResolvedService::sAdbSecurePairingServices, - host_name ? host_name : "", cb); +bool ResolvedService::connectByServiceName(const ServiceRegistry& services, + const std::string& service_name) { + initAdbSecure(); + for (auto service : services) { + if (service_name == service->serviceName()) { + D("Got service_name match [%s]", service->serviceName().c_str()); + return service->ConnectSecureWifiDevice(); + } + } + D("No registered serviceNames matched [%s]", service_name.c_str()); + return false; } -// static -void adb_secure_foreach_connect_service(const char* host_name, +void adb_secure_foreach_pairing_service(const char* service_name, + adb_secure_foreach_service_callback cb) { + ResolvedService::forEachService(*ResolvedService::sAdbSecurePairingServices, + service_name ? service_name : "", cb); +} + +void adb_secure_foreach_connect_service(const char* service_name, adb_secure_foreach_service_callback cb) { ResolvedService::forEachService(*ResolvedService::sAdbSecureConnectServices, - host_name ? host_name : "", cb); + service_name ? service_name : "", cb); +} + +bool adb_secure_connect_by_service_name(const char* service_name) { + return ResolvedService::connectByServiceName(*ResolvedService::sAdbSecureConnectServices, + service_name); } static void DNSSD_API register_service_ip(DNSServiceRef /*sdRef*/, @@ -332,6 +394,26 @@ class DiscoveredService : public AsyncServiceRef { std::string regType_; }; +static void adb_RemoveDNSService(const char* regType, const char* serviceName) { + int index = adb_DNSServiceIndexByName(regType); + ResolvedService::ServiceRegistry* services; + switch (index) { + case kADBSecurePairingServiceRefIndex: + services = ResolvedService::sAdbSecurePairingServices; + break; + case kADBSecureConnectServiceRefIndex: + services = ResolvedService::sAdbSecureConnectServices; + break; + default: + return; + } + + std::string sName(serviceName); + std::remove_if(services->begin(), services->end(), [&sName](ResolvedService* service) { + return (sName == service->serviceName()); + }); +} + // Returns the version the device wanted to advertise, // or -1 if parsing fails. static int parse_version_from_txt_record(uint16_t txtLen, const unsigned char* txtRecord) { @@ -400,10 +482,12 @@ static void DNSSD_API register_resolved_mdns_service( interfaceIndex, hosttarget, ntohs(port), serviceVersion); if (! resolved->Initialized()) { + D("Unable to init resolved service"); delete resolved; } if (flags) { /* Only ever equals MoreComing or 0 */ + D("releasing discovered service"); discovered.release(); } } @@ -412,7 +496,6 @@ static void DNSSD_API on_service_browsed(DNSServiceRef sdRef, DNSServiceFlags fl uint32_t interfaceIndex, DNSServiceErrorType errorCode, const char* serviceName, const char* regtype, const char* domain, void* /*context*/) { - D("Registering a transport."); if (errorCode != kDNSServiceErr_NoError) { D("Got error %d during mDNS browse.", errorCode); DNSServiceRefDeallocate(sdRef); @@ -423,9 +506,17 @@ static void DNSSD_API on_service_browsed(DNSServiceRef sdRef, DNSServiceFlags fl return; } - auto discovered = new DiscoveredService(interfaceIndex, serviceName, regtype, domain); - if (!discovered->Initialized()) { - delete discovered; + if (flags & kDNSServiceFlagsAdd) { + D("%s: Discover found new serviceName=[%s] regtype=[%s] domain=[%s]", __func__, serviceName, + regtype, domain); + auto discovered = new DiscoveredService(interfaceIndex, serviceName, regtype, domain); + if (!discovered->Initialized()) { + delete discovered; + } + } else { + D("%s: Discover lost serviceName=[%s] regtype=[%s] domain=[%s]", __func__, serviceName, + regtype, domain); + adb_RemoveDNSService(regtype, serviceName); } } diff --git a/adb/crypto/Android.bp b/adb/crypto/Android.bp index da4869a9a..b7f75edd2 100644 --- a/adb/crypto/Android.bp +++ b/adb/crypto/Android.bp @@ -64,10 +64,6 @@ cc_library { "com.android.adbd", "test_com.android.adbd", ], - - static_libs: [ - "libadb_protos", - ], } // For running atest (b/147158681) diff --git a/adb/daemon/adb_wifi.cpp b/adb/daemon/adb_wifi.cpp new file mode 100644 index 000000000..bce303b2a --- /dev/null +++ b/adb/daemon/adb_wifi.cpp @@ -0,0 +1,228 @@ +/* + * Copyright (C) 2019 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#if !ADB_HOST + +#define TRACE_TAG ADB_WIRELESS + +#include "adb_wifi.h" + +#include +#include + +#include +#include + +#include "adb.h" +#include "daemon/mdns.h" +#include "sysdeps.h" +#include "transport.h" + +using namespace android::base; + +namespace { + +static AdbdAuthContext* auth_ctx; + +static void adb_disconnected(void* unused, atransport* t); +static struct adisconnect adb_disconnect = {adb_disconnected, nullptr}; + +static void adb_disconnected(void* unused, atransport* t) { + LOG(INFO) << "ADB wifi device disconnected"; + adbd_auth_tls_device_disconnected(auth_ctx, kAdbTransportTypeWifi, t->auth_id); +} + +// TODO(b/31559095): need bionic host so that we can use 'prop_info' returned +// from WaitForProperty +#if defined(__ANDROID__) + +class TlsServer { + public: + explicit TlsServer(int port); + virtual ~TlsServer(); + bool Start(); + uint16_t port() { return port_; }; + + private: + void OnFdEvent(int fd, unsigned ev); + static void StaticOnFdEvent(int fd, unsigned ev, void* opaque); + + fdevent* fd_event_ = nullptr; + uint16_t port_; +}; // TlsServer + +TlsServer::TlsServer(int port) : port_(port) {} + +TlsServer::~TlsServer() { + fdevent* fde = fd_event_; + fdevent_run_on_main_thread([fde]() { + if (fde != nullptr) { + fdevent_destroy(fde); + } + }); +} + +bool TlsServer::Start() { + std::condition_variable cv; + std::mutex mutex; + std::optional success; + auto callback = [&](bool result) { + { + std::lock_guard lock(mutex); + success = result; + } + cv.notify_one(); + }; + + std::string err; + unique_fd fd(network_inaddr_any_server(port_, SOCK_STREAM, &err)); + if (fd.get() == -1) { + LOG(ERROR) << "Failed to start TLS server [" << err << "]"; + return false; + } + close_on_exec(fd.get()); + int port = socket_get_local_port(fd.get()); + if (port <= 0 || port > 65535) { + LOG(ERROR) << "Invalid port for tls server"; + return false; + } + port_ = static_cast(port); + LOG(INFO) << "adbwifi started on port " << port_; + + std::unique_lock lock(mutex); + fdevent_run_on_main_thread([&]() { + fd_event_ = fdevent_create(fd.release(), &TlsServer::StaticOnFdEvent, this); + if (fd_event_ == nullptr) { + LOG(ERROR) << "Failed to create fd event for TlsServer."; + callback(false); + return; + } + callback(true); + }); + + cv.wait(lock, [&]() { return success.has_value(); }); + if (!*success) { + LOG(INFO) << "TlsServer fdevent_create failed"; + return false; + } + fdevent_set(fd_event_, FDE_READ); + LOG(INFO) << "TlsServer running on port " << port_; + + return *success; +} + +// static +void TlsServer::StaticOnFdEvent(int fd, unsigned ev, void* opaque) { + auto server = reinterpret_cast(opaque); + server->OnFdEvent(fd, ev); +} + +void TlsServer::OnFdEvent(int fd, unsigned ev) { + if ((ev & FDE_READ) == 0 || fd != fd_event_->fd.get()) { + LOG(INFO) << __func__ << ": No read [ev=" << ev << " fd=" << fd << "]"; + return; + } + + unique_fd new_fd(adb_socket_accept(fd, nullptr, nullptr)); + if (new_fd >= 0) { + LOG(INFO) << "New TLS connection [fd=" << new_fd.get() << "]"; + close_on_exec(new_fd.get()); + disable_tcp_nagle(new_fd.get()); + std::string serial = android::base::StringPrintf("host-%d", new_fd.get()); + register_socket_transport( + std::move(new_fd), std::move(serial), port_, 1, + [](atransport*) { return ReconnectResult::Abort; }, true); + } +} + +TlsServer* sTlsServer = nullptr; +const char kWifiPortProp[] = "service.adb.tls.port"; + +const char kWifiEnabledProp[] = "persist.adb.tls_server.enable"; + +static void enable_wifi_debugging() { + start_mdnsd(); + + if (sTlsServer != nullptr) { + delete sTlsServer; + } + sTlsServer = new TlsServer(0); + if (!sTlsServer->Start()) { + LOG(ERROR) << "Failed to start TlsServer"; + delete sTlsServer; + sTlsServer = nullptr; + return; + } + + // Start mdns connect service for discovery + register_adb_secure_connect_service(sTlsServer->port()); + LOG(INFO) << "adb wifi started on port " << sTlsServer->port(); + SetProperty(kWifiPortProp, std::to_string(sTlsServer->port())); +} + +static void disable_wifi_debugging() { + if (sTlsServer != nullptr) { + delete sTlsServer; + sTlsServer = nullptr; + } + if (is_adb_secure_connect_service_registered()) { + unregister_adb_secure_connect_service(); + } + kick_all_tcp_tls_transports(); + LOG(INFO) << "adb wifi stopped"; + SetProperty(kWifiPortProp, ""); +} + +// Watches for the #kWifiEnabledProp property to toggle the TlsServer +static void start_wifi_enabled_observer() { + std::thread([]() { + bool wifi_enabled = false; + while (true) { + std::string toggled_val = wifi_enabled ? "0" : "1"; + LOG(INFO) << "Waiting for " << kWifiEnabledProp << "=" << toggled_val; + if (WaitForProperty(kWifiEnabledProp, toggled_val)) { + wifi_enabled = !wifi_enabled; + LOG(INFO) << kWifiEnabledProp << " changed to " << toggled_val; + if (wifi_enabled) { + enable_wifi_debugging(); + } else { + disable_wifi_debugging(); + } + } + } + }).detach(); +} +#endif //__ANDROID__ + +} // namespace + +void adbd_wifi_init(AdbdAuthContext* ctx) { + auth_ctx = ctx; +#if defined(__ANDROID__) + start_wifi_enabled_observer(); +#endif //__ANDROID__ +} + +void adbd_wifi_secure_connect(atransport* t) { + t->AddDisconnect(&adb_disconnect); + handle_online(t); + send_connect(t); + LOG(INFO) << __func__ << ": connected " << t->serial; + t->auth_id = adbd_auth_tls_device_connected(auth_ctx, kAdbTransportTypeWifi, t->auth_key.data(), + t->auth_key.size()); +} + +#endif /* !HOST */ diff --git a/adb/daemon/auth.cpp b/adb/daemon/auth.cpp index 1f6664e6e..2edf582d0 100644 --- a/adb/daemon/auth.cpp +++ b/adb/daemon/auth.cpp @@ -23,10 +23,14 @@ #include #include +#include #include #include #include +#include +#include +#include #include #include #include @@ -35,16 +39,24 @@ #include #include #include +#include #include "adb.h" #include "adb_auth.h" #include "adb_io.h" +#include "adb_wifi.h" #include "fdevent/fdevent.h" #include "transport.h" #include "types.h" +using namespace adb::crypto; +using namespace adb::tls; +using namespace std::chrono_literals; + static AdbdAuthContext* auth_ctx; +static RSA* rsa_pkey = nullptr; + static void adb_disconnected(void* unused, atransport* t); static struct adisconnect adb_disconnect = {adb_disconnected, nullptr}; @@ -91,6 +103,55 @@ static void IteratePublicKeys(std::function f &f); } +bssl::UniquePtr adbd_tls_client_ca_list() { + if (!auth_required) { + return nullptr; + } + + bssl::UniquePtr ca_list(sk_X509_NAME_new_null()); + + IteratePublicKeys([&](std::string_view public_key) { + // TODO: do we really have to support both ' ' and '\t'? + std::vector split = android::base::Split(std::string(public_key), " \t"); + uint8_t keybuf[ANDROID_PUBKEY_ENCODED_SIZE + 1]; + const std::string& pubkey = split[0]; + if (b64_pton(pubkey.c_str(), keybuf, sizeof(keybuf)) != ANDROID_PUBKEY_ENCODED_SIZE) { + LOG(ERROR) << "Invalid base64 key " << pubkey; + return true; + } + + RSA* key = nullptr; + if (!android_pubkey_decode(keybuf, ANDROID_PUBKEY_ENCODED_SIZE, &key)) { + LOG(ERROR) << "Failed to parse key " << pubkey; + return true; + } + bssl::UniquePtr rsa_key(key); + + unsigned char* dkey = nullptr; + int len = i2d_RSA_PUBKEY(rsa_key.get(), &dkey); + if (len <= 0 || dkey == nullptr) { + LOG(ERROR) << "Failed to encode RSA public key"; + return true; + } + + uint8_t digest[SHA256_DIGEST_LENGTH]; + // Put the encoded key in the commonName attribute of the issuer name. + // Note that the commonName has a max length of 64 bytes, which is less + // than the SHA256_DIGEST_LENGTH. + SHA256(dkey, len, digest); + OPENSSL_free(dkey); + + auto digest_str = SHA256BitsToHexString( + std::string_view(reinterpret_cast(&digest[0]), sizeof(digest))); + LOG(INFO) << "fingerprint=[" << digest_str << "]"; + auto issuer = CreateCAIssuerFromEncodedKey(digest_str); + CHECK(bssl::PushToStack(ca_list.get(), std::move(issuer))); + return true; + }); + + return ca_list; +} + bool adbd_auth_verify(const char* token, size_t token_size, const std::string& sig, std::string* auth_key) { bool authorized = false; @@ -159,11 +220,20 @@ static void adbd_auth_key_authorized(void* arg, uint64_t id) { }); } +static void adbd_key_removed(const char* public_key, size_t len) { + // The framework removed the key from its keystore. We need to disconnect all + // devices using that key. Search by t->auth_key + std::string_view auth_key(public_key, len); + kick_all_transports_by_auth_key(auth_key); +} + void adbd_auth_init(void) { AdbdAuthCallbacksV1 cb; cb.version = 1; cb.key_authorized = adbd_auth_key_authorized; + cb.key_removed = adbd_key_removed; auth_ctx = adbd_auth_new(&cb); + adbd_wifi_init(auth_ctx); std::thread([]() { adb_thread_setname("adbd auth"); adbd_auth_run(auth_ctx); @@ -206,5 +276,89 @@ void adbd_auth_confirm_key(atransport* t) { } void adbd_notify_framework_connected_key(atransport* t) { - adbd_auth_notify_auth(auth_ctx, t->auth_key.data(), t->auth_key.size()); + t->auth_id = adbd_auth_notify_auth(auth_ctx, t->auth_key.data(), t->auth_key.size()); +} + +int adbd_tls_verify_cert(X509_STORE_CTX* ctx, std::string* auth_key) { + if (!auth_required) { + // Any key will do. + LOG(INFO) << __func__ << ": auth not required"; + return 1; + } + + bool authorized = false; + X509* cert = X509_STORE_CTX_get0_cert(ctx); + if (cert == nullptr) { + LOG(INFO) << "got null x509 certificate"; + return 0; + } + bssl::UniquePtr evp_pkey(X509_get_pubkey(cert)); + if (evp_pkey == nullptr) { + LOG(INFO) << "got null evp_pkey from x509 certificate"; + return 0; + } + + IteratePublicKeys([&](std::string_view public_key) { + // TODO: do we really have to support both ' ' and '\t'? + std::vector split = android::base::Split(std::string(public_key), " \t"); + uint8_t keybuf[ANDROID_PUBKEY_ENCODED_SIZE + 1]; + const std::string& pubkey = split[0]; + if (b64_pton(pubkey.c_str(), keybuf, sizeof(keybuf)) != ANDROID_PUBKEY_ENCODED_SIZE) { + LOG(ERROR) << "Invalid base64 key " << pubkey; + return true; + } + + RSA* key = nullptr; + if (!android_pubkey_decode(keybuf, ANDROID_PUBKEY_ENCODED_SIZE, &key)) { + LOG(ERROR) << "Failed to parse key " << pubkey; + return true; + } + + bool verified = false; + bssl::UniquePtr known_evp(EVP_PKEY_new()); + EVP_PKEY_set1_RSA(known_evp.get(), key); + if (EVP_PKEY_cmp(known_evp.get(), evp_pkey.get())) { + LOG(INFO) << "Matched auth_key=" << public_key; + verified = true; + } else { + LOG(INFO) << "auth_key doesn't match [" << public_key << "]"; + } + RSA_free(key); + if (verified) { + *auth_key = public_key; + authorized = true; + return false; + } + + return true; + }); + + return authorized ? 1 : 0; +} + +void adbd_auth_tls_handshake(atransport* t) { + if (rsa_pkey == nullptr) { + // Generate a random RSA key to feed into the X509 certificate + auto rsa_2048 = CreateRSA2048Key(); + CHECK(rsa_2048.has_value()); + rsa_pkey = EVP_PKEY_get1_RSA(rsa_2048->GetEvpPkey()); + CHECK(rsa_pkey); + } + + std::thread([t]() { + std::string auth_key; + if (t->connection()->DoTlsHandshake(rsa_pkey, &auth_key)) { + LOG(INFO) << "auth_key=" << auth_key; + if (t->IsTcpDevice()) { + t->auth_key = auth_key; + adbd_wifi_secure_connect(t); + } else { + adbd_auth_verified(t); + adbd_notify_framework_connected_key(t); + } + } else { + // Only allow one attempt at the handshake. + t->Kick(); + } + }).detach(); } diff --git a/adb/daemon/main.cpp b/adb/daemon/main.cpp index 3322574ce..9e02e89ab 100644 --- a/adb/daemon/main.cpp +++ b/adb/daemon/main.cpp @@ -53,6 +53,7 @@ #include "adb_auth.h" #include "adb_listeners.h" #include "adb_utils.h" +#include "adb_wifi.h" #include "socket_spec.h" #include "transport.h" @@ -196,6 +197,7 @@ static void setup_adb(const std::vector& addrs) { if (port == -1) { port = DEFAULT_ADB_LOCAL_TRANSPORT_PORT; } + LOG(INFO) << "Setup mdns on port= " << port; setup_mdns(port); #endif for (const auto& addr : addrs) { @@ -317,9 +319,10 @@ int main(int argc, char** argv) { while (true) { static struct option opts[] = { - {"root_seclabel", required_argument, nullptr, 's'}, - {"device_banner", required_argument, nullptr, 'b'}, - {"version", no_argument, nullptr, 'v'}, + {"root_seclabel", required_argument, nullptr, 's'}, + {"device_banner", required_argument, nullptr, 'b'}, + {"version", no_argument, nullptr, 'v'}, + {"logpostfsdata", no_argument, nullptr, 'l'}, }; int option_index = 0; @@ -341,6 +344,9 @@ int main(int argc, char** argv) { printf("Android Debug Bridge Daemon version %d.%d.%d\n", ADB_VERSION_MAJOR, ADB_VERSION_MINOR, ADB_SERVER_VERSION); return 0; + case 'l': + LOG(ERROR) << "post-fs-data triggered"; + return 0; default: // getopt already prints "adbd: invalid option -- %c" for us. return 1; diff --git a/adb/daemon/mdns.cpp b/adb/daemon/mdns.cpp index fa98340b5..fa692c039 100644 --- a/adb/daemon/mdns.cpp +++ b/adb/daemon/mdns.cpp @@ -24,6 +24,7 @@ #include #include +#include #include #include @@ -36,7 +37,7 @@ static int port; static DNSServiceRef mdns_refs[kNumADBDNSServices]; static bool mdns_registered[kNumADBDNSServices]; -static void start_mdns() { +void start_mdnsd() { if (android::base::GetProperty("init.svc.mdnsd", "") == "running") { return; } @@ -61,11 +62,9 @@ static void mdns_callback(DNSServiceRef /*ref*/, } } -static void register_mdns_service(int index, int port) { +static void register_mdns_service(int index, int port, const std::string service_name) { std::lock_guard lock(mdns_lock); - std::string hostname = "adb-"; - hostname += android::base::GetProperty("ro.serialno", "unidentified"); // https://tools.ietf.org/html/rfc6763 // """ @@ -95,7 +94,7 @@ static void register_mdns_service(int index, int port) { } auto error = DNSServiceRegister( - &mdns_refs[index], 0, 0, hostname.c_str(), kADBDNSServices[index], nullptr, nullptr, + &mdns_refs[index], 0, 0, service_name.c_str(), kADBDNSServices[index], nullptr, nullptr, htobe16((uint16_t)port), (uint16_t)txtRecord.size(), txtRecord.empty() ? nullptr : txtRecord.data(), mdns_callback, nullptr); @@ -120,11 +119,13 @@ static void unregister_mdns_service(int index) { } static void register_base_mdns_transport() { - register_mdns_service(kADBTransportServiceRefIndex, port); + std::string hostname = "adb-"; + hostname += android::base::GetProperty("ro.serialno", "unidentified"); + register_mdns_service(kADBTransportServiceRefIndex, port, hostname); } static void setup_mdns_thread() { - start_mdns(); + start_mdnsd(); // We will now only set up the normal transport mDNS service // instead of registering all the adb secure mDNS services @@ -139,9 +140,57 @@ static void teardown_mdns() { } } +static std::string RandomAlphaNumString(size_t len) { + std::string ret; + std::random_device rd; + std::mt19937 mt(rd()); + // Generate values starting with zero and then up to enough to cover numeric + // digits, small letters and capital letters (26 each). + std::uniform_int_distribution dist(0, 61); + for (size_t i = 0; i < len; ++i) { + uint8_t val = dist(mt); + if (val < 10) { + ret += '0' + val; + } else if (val < 36) { + ret += 'A' + (val - 10); + } else { + ret += 'a' + (val - 36); + } + } + return ret; +} + +static std::string GenerateDeviceGuid() { + // The format is adb-- + std::string guid = "adb-"; + + std::string serial = android::base::GetProperty("ro.serialno", ""); + if (serial.empty()) { + // Generate 16-bytes of random alphanum string + serial = RandomAlphaNumString(16); + } + guid += serial + '-'; + // Random six-char suffix + guid += RandomAlphaNumString(6); + return guid; +} + +static std::string ReadDeviceGuid() { + std::string guid = android::base::GetProperty("persist.adb.wifi.guid", ""); + if (guid.empty()) { + guid = GenerateDeviceGuid(); + CHECK(!guid.empty()); + android::base::SetProperty("persist.adb.wifi.guid", guid); + } + return guid; +} + // Public interface///////////////////////////////////////////////////////////// void setup_mdns(int port_in) { + // Make sure the adb wifi guid is generated. + std::string guid = ReadDeviceGuid(); + CHECK(!guid.empty()); port = port_in; std::thread(setup_mdns_thread).detach(); @@ -149,24 +198,14 @@ void setup_mdns(int port_in) { atexit(teardown_mdns); } -void register_adb_secure_pairing_service(int port) { - std::thread([port]() { - register_mdns_service(kADBSecurePairingServiceRefIndex, port); - }).detach(); -} - -void unregister_adb_secure_pairing_service() { - std::thread([]() { unregister_mdns_service(kADBSecurePairingServiceRefIndex); }).detach(); -} - -bool is_adb_secure_pairing_service_registered() { - std::lock_guard lock(mdns_lock); - return mdns_registered[kADBSecurePairingServiceRefIndex]; -} - void register_adb_secure_connect_service(int port) { std::thread([port]() { - register_mdns_service(kADBSecureConnectServiceRefIndex, port); + auto service_name = ReadDeviceGuid(); + if (service_name.empty()) { + return; + } + LOG(INFO) << "Registering secure_connect service (" << service_name << ")"; + register_mdns_service(kADBSecureConnectServiceRefIndex, port, service_name); }).detach(); } diff --git a/adb/daemon/mdns.h b/adb/daemon/mdns.h index a18093b50..e7e7a6217 100644 --- a/adb/daemon/mdns.h +++ b/adb/daemon/mdns.h @@ -19,12 +19,9 @@ void setup_mdns(int port); -void register_adb_secure_pairing_service(int port); -void unregister_adb_secure_pairing_service(int port); -bool is_adb_secure_pairing_service_registered(); - void register_adb_secure_connect_service(int port); -void unregister_adb_secure_connect_service(int port); +void unregister_adb_secure_connect_service(); bool is_adb_secure_connect_service_registered(); +void start_mdnsd(); #endif // _DAEMON_MDNS_H_ diff --git a/adb/daemon/transport_qemu.cpp b/adb/daemon/transport_qemu.cpp index 901efeea2..e458cea7e 100644 --- a/adb/daemon/transport_qemu.cpp +++ b/adb/daemon/transport_qemu.cpp @@ -105,8 +105,9 @@ void qemu_socket_thread(std::string_view addr) { * exchange. */ std::string serial = android::base::StringPrintf("host-%d", fd.get()); WriteFdExactly(fd.get(), _start_req, strlen(_start_req)); - register_socket_transport(std::move(fd), std::move(serial), port, 1, - [](atransport*) { return ReconnectResult::Abort; }); + register_socket_transport( + std::move(fd), std::move(serial), port, 1, + [](atransport*) { return ReconnectResult::Abort; }, false); } /* Prepare for accepting of the next ADB host connection. */ diff --git a/adb/daemon/usb.cpp b/adb/daemon/usb.cpp index a9ad805d8..c7f8895b8 100644 --- a/adb/daemon/usb.cpp +++ b/adb/daemon/usb.cpp @@ -260,6 +260,12 @@ struct UsbFfsConnection : public Connection { CHECK_EQ(static_cast(rc), sizeof(notify)); } + virtual bool DoTlsHandshake(RSA* key, std::string* auth_key) override final { + // TODO: support TLS for usb connections. + LOG(FATAL) << "Not supported yet."; + return false; + } + private: void StartMonitor() { // This is a bit of a mess. diff --git a/adb/fdevent/fdevent_test.h b/adb/fdevent/fdevent_test.h index 2139d0f66..ecda4da97 100644 --- a/adb/fdevent/fdevent_test.h +++ b/adb/fdevent/fdevent_test.h @@ -48,6 +48,12 @@ class FdeventTest : public ::testing::Test { protected: unique_fd dummy; + ~FdeventTest() { + if (thread_.joinable()) { + TerminateThread(); + } + } + static void SetUpTestCase() { #if !defined(_WIN32) ASSERT_NE(SIG_ERR, signal(SIGPIPE, SIG_IGN)); diff --git a/adb/pairing_connection/Android.bp b/adb/pairing_connection/Android.bp new file mode 100644 index 000000000..c05385475 --- /dev/null +++ b/adb/pairing_connection/Android.bp @@ -0,0 +1,185 @@ +// Copyright (C) 2020 The Android Open Source Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +cc_defaults { + name: "libadb_pairing_connection_defaults", + cflags: [ + "-Wall", + "-Wextra", + "-Wthread-safety", + "-Werror", + ], + + compile_multilib: "both", + + srcs: [ + "pairing_connection.cpp", + ], + target: { + android: { + version_script: "libadb_pairing_connection.map.txt", + }, + windows: { + compile_multilib: "first", + enabled: true, + }, + }, + export_include_dirs: ["include"], + + visibility: [ + "//art:__subpackages__", + "//system/core/adb:__subpackages__", + "//frameworks/base/services:__subpackages__", + ], + apex_available: [ + "com.android.adbd", + ], + + // libadb_pairing_connection doesn't need an embedded build number. + use_version_lib: false, + + stl: "libc++_static", + + host_supported: true, + recovery_available: true, + + static_libs: [ + "libbase", + "libssl", + ], + shared_libs: [ + "libcrypto", + "liblog", + "libadb_pairing_auth", + ], +} + +cc_library { + name: "libadb_pairing_connection", + defaults: ["libadb_pairing_connection_defaults"], + + apex_available: [ + "com.android.adbd", + ], + + stubs: { + symbol_file: "libadb_pairing_connection.map.txt", + versions: ["30"], + }, + + static_libs: [ + "libadb_protos", + // Statically link libadb_tls_connection because it is not + // ABI-stable. + "libadb_tls_connection", + "libprotobuf-cpp-lite", + ], +} + +// For running atest (b/147158681) +cc_library_static { + name: "libadb_pairing_connection_static", + defaults: ["libadb_pairing_connection_defaults"], + + apex_available: [ + "//apex_available:platform", + ], + + static_libs: [ + "libadb_protos_static", + "libprotobuf-cpp-lite", + "libadb_tls_connection_static", + ], +} + +cc_defaults { + name: "libadb_pairing_server_defaults", + cflags: [ + "-Wall", + "-Wextra", + "-Wthread-safety", + "-Werror", + ], + + compile_multilib: "both", + + srcs: [ + "pairing_server.cpp", + ], + target: { + android: { + version_script: "libadb_pairing_server.map.txt", + }, + }, + export_include_dirs: ["include"], + + visibility: [ + "//art:__subpackages__", + "//system/core/adb:__subpackages__", + "//frameworks/base/services:__subpackages__", + ], + + host_supported: true, + recovery_available: true, + + stl: "libc++_static", + + static_libs: [ + "libbase", + ], + shared_libs: [ + "libcrypto", + "libcrypto_utils", + "libcutils", + "liblog", + "libadb_pairing_auth", + "libadb_pairing_connection", + ], +} + +cc_library { + name: "libadb_pairing_server", + defaults: ["libadb_pairing_server_defaults"], + + apex_available: [ + "com.android.adbd", + ], + + stubs: { + symbol_file: "libadb_pairing_server.map.txt", + versions: ["30"], + }, + + static_libs: [ + // Statically link libadb_crypto because it is not + // ABI-stable. + "libadb_crypto", + "libadb_protos", + ], +} + +// For running atest (b/147158681) +cc_library_static { + name: "libadb_pairing_server_static", + defaults: ["libadb_pairing_server_defaults"], + + apex_available: [ + "//apex_available:platform", + ], + + static_libs: [ + "libadb_crypto_static", + "libadb_protos_static", + ], +} diff --git a/adb/pairing_connection/include/adb/pairing/pairing_connection.h b/adb/pairing_connection/include/adb/pairing/pairing_connection.h new file mode 100644 index 000000000..3543b8738 --- /dev/null +++ b/adb/pairing_connection/include/adb/pairing/pairing_connection.h @@ -0,0 +1,130 @@ +/* + * Copyright (C) 2020 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include + +#if !defined(__INTRODUCED_IN) +#define __INTRODUCED_IN(__api_level) /* nothing */ +#endif + +// These APIs are for the Adb pairing protocol. This protocol requires both +// sides to possess a shared secret to authenticate each other. The connection +// is over TLS, and requires that both the client and server have a valid +// certificate. +// +// This protocol is one-to-one, i.e., one PairingConnectionCtx server instance +// interacts with only one PairingConnectionCtx client instance. In other words, +// every new client instance must be bound to a new server instance. +// +// If both sides have authenticated, they will exchange their peer information +// (see #PeerInfo). +__BEGIN_DECLS +#if !defined(__ANDROID__) || __ANDROID_API__ >= 30 + +const uint32_t kMaxPeerInfoSize = 8192; +struct PeerInfo { + uint8_t type; + uint8_t data[kMaxPeerInfoSize - 1]; +} __attribute__((packed)); +typedef struct PeerInfo PeerInfo; +static_assert(sizeof(PeerInfo) == kMaxPeerInfoSize, "PeerInfo has weird size"); + +enum PeerInfoType : uint8_t { + ADB_RSA_PUB_KEY = 0, + ADB_DEVICE_GUID = 1, +}; + +struct PairingConnectionCtx; +typedef struct PairingConnectionCtx PairingConnectionCtx; +typedef void (*pairing_result_cb)(const PeerInfo*, int, void*); + +// Starts the pairing connection on a separate thread. +// +// Upon completion, if the pairing was successful, +// |cb| will be called with the peer information and certificate. +// Otherwise, |cb| will be called with empty data. |fd| should already +// be opened. PairingConnectionCtx will take ownership of the |fd|. +// +// Pairing is successful if both server/client uses the same non-empty +// |pswd|, and they are able to exchange the information. |pswd| and +// |certificate| must be non-empty. start() can only be called once in the +// lifetime of this object. +// +// @param ctx the PairingConnectionCtx instance. Will abort if null. +// @param fd the fd connecting the peers. This will take ownership of fd. +// @param cb the user-provided callback that is called with the result of the +// pairing. The callback will be called on a different thread from the +// caller. +// @param opaque opaque userdata. +// @return true if the thread was successfully started, false otherwise. To stop +// the connection process, destroy the instance (see +// #pairing_connection_destroy). If false is returned, cb will not be +// invoked. Otherwise, cb is guaranteed to be invoked, even if you +// destroy the ctx while in the pairing process. +bool pairing_connection_start(PairingConnectionCtx* ctx, int fd, pairing_result_cb cb, void* opaque) + __INTRODUCED_IN(30); + +// Creates a new PairingConnectionCtx instance as the client. +// +// @param pswd the password to authenticate both peers. Will abort if null. +// @param pswd_len the length of pswd. Will abort if 0. +// @param peer_info the PeerInfo struct that is exchanged between peers if the +// pairing was successful. Will abort if null. +// @param x509_cert_pem the X.509 certificate in PEM format. Will abort if null. +// @param x509_size the size of x509_cert_pem. Will abort if 0. +// @param priv_key_pem the private key corresponding to the given X.509 +// certificate, in PEM format. Will abort if null. +// @param priv_size the size of priv_key_pem. Will abort if 0. +// @return a new PairingConnectionCtx client instance. The caller is responsible +// for destroying the context via #pairing_connection_destroy. +PairingConnectionCtx* pairing_connection_client_new(const uint8_t* pswd, size_t pswd_len, + const PeerInfo* peer_info, + const uint8_t* x509_cert_pem, size_t x509_size, + const uint8_t* priv_key_pem, size_t priv_size) + __INTRODUCED_IN(30); + +// Creates a new PairingConnectionCtx instance as the server. +// +// @param pswd the password to authenticate both peers. Will abort if null. +// @param pswd_len the length of pswd. Will abort if 0. +// @param peer_info the PeerInfo struct that is exchanged between peers if the +// pairing was successful. Will abort if null. +// @param x509_cert_pem the X.509 certificate in PEM format. Will abort if null. +// @param x509_size the size of x509_cert_pem. Will abort if 0. +// @param priv_key_pem the private key corresponding to the given X.509 +// certificate, in PEM format. Will abort if null. +// @param priv_size the size of priv_key_pem. Will abort if 0. +// @return a new PairingConnectionCtx server instance. The caller is responsible +// for destroying the context via #pairing_connection_destroy. +PairingConnectionCtx* pairing_connection_server_new(const uint8_t* pswd, size_t pswd_len, + const PeerInfo* peer_info, + const uint8_t* x509_cert_pem, size_t x509_size, + const uint8_t* priv_key_pem, size_t priv_size) + __INTRODUCED_IN(30); + +// Destroys the PairingConnectionCtx instance. +// +// It is safe to destroy the instance at any point in the pairing process. +// +// @param ctx the PairingConnectionCtx instance to destroy. Will abort if null. +void pairing_connection_destroy(PairingConnectionCtx* ctx) __INTRODUCED_IN(30); + +#endif //!__ANDROID__ || __ANDROID_API__ >= 30 +__END_DECLS diff --git a/adb/pairing_connection/include/adb/pairing/pairing_server.h b/adb/pairing_connection/include/adb/pairing/pairing_server.h new file mode 100644 index 000000000..178a174bd --- /dev/null +++ b/adb/pairing_connection/include/adb/pairing/pairing_server.h @@ -0,0 +1,111 @@ +/* + * Copyright (C) 2020 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include + +#include +#include +#include +#include + +#include "adb/pairing/pairing_connection.h" + +#if !defined(__INTRODUCED_IN) +#define __INTRODUCED_IN(__api_level) /* nothing */ +#endif + +__BEGIN_DECLS +#if !defined(__ANDROID__) || __ANDROID_API__ >= 30 + +// PairingServerCtx is a wrapper around the #PairingConnectionCtx APIs, +// which handles multiple client connections. +// +// See pairing_connection_test.cpp for example usage. +// +struct PairingServerCtx; +typedef struct PairingServerCtx PairingServerCtx; + +// Callback containing the result of the pairing. If #PeerInfo is null, +// then the pairing failed. Otherwise, pairing succeeded and #PeerInfo +// contains information about the peer. +typedef void (*pairing_server_result_cb)(const PeerInfo*, void*) __INTRODUCED_IN(30); + +// Starts the pairing server. +// +// This call is non-blocking. Upon completion, if the pairing was successful, +// then |cb| will be called with the PeerInfo +// containing the info of the trusted peer. Otherwise, |cb| will be +// called with an empty value. Start can only be called once in the lifetime +// of this object. +// +// @param ctx the PairingServerCtx instance. +// @param cb the user-provided callback to notify the result of the pairing. See +// #pairing_server_result_cb. +// @param opaque the opaque userdata. +// @return the port number the server is listening on. Returns 0 on failure. +uint16_t pairing_server_start(PairingServerCtx* ctx, pairing_server_result_cb cb, void* opaque) + __INTRODUCED_IN(30); + +// Creates a new PairingServerCtx instance. +// +// @param pswd the password used to authenticate the client and server. +// @param pswd_len the length of pswd. +// @param peer_info the #PeerInfo struct passed to the client on successful +// pairing. +// @param x509_cert_pem the X.509 certificate in PEM format. Cannot be empty. +// @param x509_size the size of x509_cert_pem. +// @param priv_key_pem the private key corresponding to the given X.509 +// certificate, in PEM format. Cannot be empty. +// @param priv_size the size of priv_key_pem. +// @param port the port number the server should listen on. Must be within the +// valid port range [0, 65535]. If port is 0, then the server will +// find an open port to listen on. See #pairing_server_start to +// obtain the port used. +// @return a new PairingServerCtx instance The caller is responsible +// for destroying the context via #pairing_server_destroy. +PairingServerCtx* pairing_server_new(const uint8_t* pswd, size_t pswd_len, + const PeerInfo* peer_info, const uint8_t* x509_cert_pem, + size_t x509_size, const uint8_t* priv_key_pem, + size_t priv_size, uint16_t port) __INTRODUCED_IN(30); + +// Same as #pairing_server_new, except that the x509 certificate and private key +// is generated internally. +// +// @param pswd the password used to authenticate the client and server. +// @param pswd_len the length of pswd. +// @param peer_info the #PeerInfo struct passed to the client on successful +// pairing. +// @param port the port number the server should listen on. Must be within the +// valid port range [0, 65535]. If port is 0, then the server will +// find an open port to listen on. See #pairing_server_start to +// obtain the port used. +// @return a new PairingServerCtx instance The caller is responsible +// for destroying the context via #pairing_server_destroy. +PairingServerCtx* pairing_server_new_no_cert(const uint8_t* pswd, size_t pswd_len, + const PeerInfo* peer_info, uint16_t port) + __INTRODUCED_IN(30); + +// Destroys the PairingServerCtx instance. +// +// @param ctx the PairingServerCtx instance to destroy. +void pairing_server_destroy(PairingServerCtx* ctx) __INTRODUCED_IN(30); + +#endif //!__ANDROID__ || __ANDROID_API__ >= 30 +__END_DECLS diff --git a/adb/pairing_connection/internal/constants.h b/adb/pairing_connection/internal/constants.h new file mode 100644 index 000000000..9a04f174e --- /dev/null +++ b/adb/pairing_connection/internal/constants.h @@ -0,0 +1,34 @@ +/* + * Copyright (C) 2020 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +// This file contains constants that can be used both in the pairing_connection +// code and tested in the pairing_connection_test code. +namespace adb { +namespace pairing { +namespace internal { + +// The maximum number of connections the PairingServer can handle at once. +constexpr int kMaxConnections = 10; +// The maximum number of attempts the PairingServer will take before quitting. +// This is to prevent someone malicious from quickly brute-forcing every +// combination. +constexpr int kMaxPairingAttempts = 20; + +} // namespace internal +} // namespace pairing +} // namespace adb diff --git a/adb/pairing_connection/libadb_pairing_connection.map.txt b/adb/pairing_connection/libadb_pairing_connection.map.txt new file mode 100644 index 000000000..abd5f16d4 --- /dev/null +++ b/adb/pairing_connection/libadb_pairing_connection.map.txt @@ -0,0 +1,10 @@ +LIBADB_PAIRING_CONNECTION { + global: + pairing_connection_client_new; # apex introduced=30 + pairing_connection_server_new; # apex introduced=30 + pairing_connection_start; # apex introduced=30 + pairing_connection_destroy; # apex introduced=30 + + local: + *; +}; diff --git a/adb/pairing_connection/libadb_pairing_server.map.txt b/adb/pairing_connection/libadb_pairing_server.map.txt new file mode 100644 index 000000000..dc0dc89a3 --- /dev/null +++ b/adb/pairing_connection/libadb_pairing_server.map.txt @@ -0,0 +1,10 @@ +LIBADB_PAIRING_SERVER { + global: + pairing_server_start; # apex introduced=30 + pairing_server_new; # apex introduced=30 + pairing_server_new_no_cert; # apex introduced=30 + pairing_server_destroy; # apex introduced=30 + + local: + *; +}; diff --git a/adb/pairing_connection/pairing_connection.cpp b/adb/pairing_connection/pairing_connection.cpp new file mode 100644 index 000000000..a26a6b4d2 --- /dev/null +++ b/adb/pairing_connection/pairing_connection.cpp @@ -0,0 +1,491 @@ +/* + * Copyright (C) 2020 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "adb/pairing/pairing_connection.h" + +#include +#include + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "pairing.pb.h" + +using namespace adb; +using android::base::unique_fd; +using TlsError = tls::TlsConnection::TlsError; + +const uint8_t kCurrentKeyHeaderVersion = 1; +const uint8_t kMinSupportedKeyHeaderVersion = 1; +const uint8_t kMaxSupportedKeyHeaderVersion = 1; +const uint32_t kMaxPayloadSize = kMaxPeerInfoSize * 2; + +struct PairingPacketHeader { + uint8_t version; // PairingPacket version + uint8_t type; // the type of packet (PairingPacket.Type) + uint32_t payload; // Size of the payload in bytes +} __attribute__((packed)); + +struct PairingAuthDeleter { + void operator()(PairingAuthCtx* p) { pairing_auth_destroy(p); } +}; // PairingAuthDeleter +using PairingAuthPtr = std::unique_ptr; + +// PairingConnectionCtx encapsulates the protocol to authenticate two peers with +// each other. This class will open the tcp sockets and handle the pairing +// process. On completion, both sides will have each other's public key +// (certificate) if successful, otherwise, the pairing failed. The tcp port +// number is hardcoded (see pairing_connection.cpp). +// +// Each PairingConnectionCtx instance represents a different device trying to +// pair. So for the device, we can have multiple PairingConnectionCtxs while the +// host may have only one (unless host has a PairingServer). +// +// See pairing_connection_test.cpp for example usage. +// +struct PairingConnectionCtx { + public: + using Data = std::vector; + using ResultCallback = pairing_result_cb; + enum class Role { + Client, + Server, + }; + + explicit PairingConnectionCtx(Role role, const Data& pswd, const PeerInfo& peer_info, + const Data& certificate, const Data& priv_key); + virtual ~PairingConnectionCtx(); + + // Starts the pairing connection on a separate thread. + // Upon completion, if the pairing was successful, + // |cb| will be called with the peer information and certificate. + // Otherwise, |cb| will be called with empty data. |fd| should already + // be opened. PairingConnectionCtx will take ownership of the |fd|. + // + // Pairing is successful if both server/client uses the same non-empty + // |pswd|, and they are able to exchange the information. |pswd| and + // |certificate| must be non-empty. Start() can only be called once in the + // lifetime of this object. + // + // Returns true if the thread was successfully started, false otherwise. + bool Start(int fd, ResultCallback cb, void* opaque); + + private: + // Setup the tls connection. + bool SetupTlsConnection(); + + /************ PairingPacketHeader methods ****************/ + // Tries to write out the header and payload. + bool WriteHeader(const PairingPacketHeader* header, std::string_view payload); + // Tries to parse incoming data into the |header|. Returns true if header + // is valid and header version is supported. |header| is filled on success. + // |header| may contain garbage if unsuccessful. + bool ReadHeader(PairingPacketHeader* header); + // Creates a PairingPacketHeader. + void CreateHeader(PairingPacketHeader* header, adb::proto::PairingPacket::Type type, + uint32_t payload_size); + // Checks if actual matches expected. + bool CheckHeaderType(adb::proto::PairingPacket::Type expected, uint8_t actual); + + /*********** State related methods **************/ + // Handles the State::ExchangingMsgs state. + bool DoExchangeMsgs(); + // Handles the State::ExchangingPeerInfo state. + bool DoExchangePeerInfo(); + + // The background task to do the pairing. + void StartWorker(); + + // Calls |cb_| and sets the state to Stopped. + void NotifyResult(const PeerInfo* p); + + static PairingAuthPtr CreatePairingAuthPtr(Role role, const Data& pswd); + + enum class State { + Ready, + ExchangingMsgs, + ExchangingPeerInfo, + Stopped, + }; + + std::atomic state_{State::Ready}; + Role role_; + Data pswd_; + PeerInfo peer_info_; + Data cert_; + Data priv_key_; + + // Peer's info + PeerInfo their_info_; + + ResultCallback cb_; + void* opaque_ = nullptr; + std::unique_ptr tls_; + PairingAuthPtr auth_; + unique_fd fd_; + std::thread thread_; + static constexpr size_t kExportedKeySize = 64; +}; // PairingConnectionCtx + +PairingConnectionCtx::PairingConnectionCtx(Role role, const Data& pswd, const PeerInfo& peer_info, + const Data& cert, const Data& priv_key) + : role_(role), pswd_(pswd), peer_info_(peer_info), cert_(cert), priv_key_(priv_key) { + CHECK(!pswd_.empty() && !cert_.empty() && !priv_key_.empty()); +} + +PairingConnectionCtx::~PairingConnectionCtx() { + // Force close the fd and wait for the worker thread to finish. + fd_.reset(); + if (thread_.joinable()) { + thread_.join(); + } +} + +bool PairingConnectionCtx::SetupTlsConnection() { + tls_ = tls::TlsConnection::Create( + role_ == Role::Server ? tls::TlsConnection::Role::Server + : tls::TlsConnection::Role::Client, + std::string_view(reinterpret_cast(cert_.data()), cert_.size()), + std::string_view(reinterpret_cast(priv_key_.data()), priv_key_.size()), + fd_); + + if (tls_ == nullptr) { + LOG(ERROR) << "Unable to start TlsConnection. Unable to pair fd=" << fd_.get(); + return false; + } + + // Allow any peer certificate + tls_->SetCertVerifyCallback([](X509_STORE_CTX*) { return 1; }); + + // SSL doesn't seem to behave correctly with fdevents so just do a blocking + // read for the pairing data. + if (tls_->DoHandshake() != TlsError::Success) { + LOG(ERROR) << "Failed to handshake with the peer fd=" << fd_.get(); + return false; + } + + // To ensure the connection is not stolen while we do the PAKE, append the + // exported key material from the tls connection to the password. + std::vector exportedKeyMaterial = tls_->ExportKeyingMaterial(kExportedKeySize); + if (exportedKeyMaterial.empty()) { + LOG(ERROR) << "Failed to export key material"; + return false; + } + pswd_.insert(pswd_.end(), std::make_move_iterator(exportedKeyMaterial.begin()), + std::make_move_iterator(exportedKeyMaterial.end())); + auth_ = CreatePairingAuthPtr(role_, pswd_); + + return true; +} + +bool PairingConnectionCtx::WriteHeader(const PairingPacketHeader* header, + std::string_view payload) { + PairingPacketHeader network_header = *header; + network_header.payload = htonl(network_header.payload); + if (!tls_->WriteFully(std::string_view(reinterpret_cast(&network_header), + sizeof(PairingPacketHeader))) || + !tls_->WriteFully(payload)) { + LOG(ERROR) << "Failed to write out PairingPacketHeader"; + state_ = State::Stopped; + return false; + } + return true; +} + +bool PairingConnectionCtx::ReadHeader(PairingPacketHeader* header) { + auto data = tls_->ReadFully(sizeof(PairingPacketHeader)); + if (data.empty()) { + return false; + } + + uint8_t* p = data.data(); + // First byte is always PairingPacketHeader version + header->version = *p; + ++p; + if (header->version < kMinSupportedKeyHeaderVersion || + header->version > kMaxSupportedKeyHeaderVersion) { + LOG(ERROR) << "PairingPacketHeader version mismatch (us=" << kCurrentKeyHeaderVersion + << " them=" << header->version << ")"; + return false; + } + // Next byte is the PairingPacket::Type + if (!adb::proto::PairingPacket::Type_IsValid(*p)) { + LOG(ERROR) << "Unknown PairingPacket type=" << static_cast(*p); + return false; + } + header->type = *p; + ++p; + // Last, the payload size + header->payload = ntohl(*(reinterpret_cast(p))); + if (header->payload == 0 || header->payload > kMaxPayloadSize) { + LOG(ERROR) << "header payload not within a safe payload size (size=" << header->payload + << ")"; + return false; + } + + return true; +} + +void PairingConnectionCtx::CreateHeader(PairingPacketHeader* header, + adb::proto::PairingPacket::Type type, + uint32_t payload_size) { + header->version = kCurrentKeyHeaderVersion; + uint8_t type8 = static_cast(static_cast(type)); + header->type = type8; + header->payload = payload_size; +} + +bool PairingConnectionCtx::CheckHeaderType(adb::proto::PairingPacket::Type expected_type, + uint8_t actual) { + uint8_t expected = *reinterpret_cast(&expected_type); + if (actual != expected) { + LOG(ERROR) << "Unexpected header type (expected=" << static_cast(expected) + << " actual=" << static_cast(actual) << ")"; + return false; + } + return true; +} + +void PairingConnectionCtx::NotifyResult(const PeerInfo* p) { + cb_(p, fd_.get(), opaque_); + state_ = State::Stopped; +} + +bool PairingConnectionCtx::Start(int fd, ResultCallback cb, void* opaque) { + if (fd < 0) { + return false; + } + + State expected = State::Ready; + if (!state_.compare_exchange_strong(expected, State::ExchangingMsgs)) { + return false; + } + + fd_.reset(fd); + cb_ = cb; + opaque_ = opaque; + + thread_ = std::thread([this] { StartWorker(); }); + return true; +} + +bool PairingConnectionCtx::DoExchangeMsgs() { + uint32_t payload = pairing_auth_msg_size(auth_.get()); + std::vector msg(payload); + pairing_auth_get_spake2_msg(auth_.get(), msg.data()); + + PairingPacketHeader header; + CreateHeader(&header, adb::proto::PairingPacket::SPAKE2_MSG, payload); + + // Write our SPAKE2 msg + if (!WriteHeader(&header, + std::string_view(reinterpret_cast(msg.data()), msg.size()))) { + LOG(ERROR) << "Failed to write SPAKE2 msg."; + return false; + } + + // Read the peer's SPAKE2 msg header + if (!ReadHeader(&header)) { + LOG(ERROR) << "Invalid PairingPacketHeader."; + return false; + } + if (!CheckHeaderType(adb::proto::PairingPacket::SPAKE2_MSG, header.type)) { + return false; + } + + // Read the SPAKE2 msg payload and initialize the cipher for + // encrypting the PeerInfo and certificate. + auto their_msg = tls_->ReadFully(header.payload); + if (their_msg.empty() || + !pairing_auth_init_cipher(auth_.get(), their_msg.data(), their_msg.size())) { + LOG(ERROR) << "Unable to initialize pairing cipher [their_msg.size=" << their_msg.size() + << "]"; + return false; + } + + return true; +} + +bool PairingConnectionCtx::DoExchangePeerInfo() { + // Encrypt PeerInfo + std::vector buf; + uint8_t* p = reinterpret_cast(&peer_info_); + buf.assign(p, p + sizeof(peer_info_)); + std::vector outbuf(pairing_auth_safe_encrypted_size(auth_.get(), buf.size())); + CHECK(!outbuf.empty()); + size_t outsize; + if (!pairing_auth_encrypt(auth_.get(), buf.data(), buf.size(), outbuf.data(), &outsize)) { + LOG(ERROR) << "Failed to encrypt peer info"; + return false; + } + outbuf.resize(outsize); + + // Write out the packet header + PairingPacketHeader out_header; + out_header.version = kCurrentKeyHeaderVersion; + out_header.type = static_cast(static_cast(adb::proto::PairingPacket::PEER_INFO)); + out_header.payload = htonl(outbuf.size()); + if (!tls_->WriteFully( + std::string_view(reinterpret_cast(&out_header), sizeof(out_header)))) { + LOG(ERROR) << "Unable to write PairingPacketHeader"; + return false; + } + + // Write out the encrypted payload + if (!tls_->WriteFully( + std::string_view(reinterpret_cast(outbuf.data()), outbuf.size()))) { + LOG(ERROR) << "Unable to write encrypted peer info"; + return false; + } + + // Read in the peer's packet header + PairingPacketHeader header; + if (!ReadHeader(&header)) { + LOG(ERROR) << "Invalid PairingPacketHeader."; + return false; + } + + if (!CheckHeaderType(adb::proto::PairingPacket::PEER_INFO, header.type)) { + return false; + } + + // Read in the encrypted peer certificate + buf = tls_->ReadFully(header.payload); + if (buf.empty()) { + return false; + } + + // Try to decrypt the certificate + outbuf.resize(pairing_auth_safe_decrypted_size(auth_.get(), buf.data(), buf.size())); + if (outbuf.empty()) { + LOG(ERROR) << "Unsupported payload while decrypting peer info."; + return false; + } + + if (!pairing_auth_decrypt(auth_.get(), buf.data(), buf.size(), outbuf.data(), &outsize)) { + LOG(ERROR) << "Failed to decrypt"; + return false; + } + outbuf.resize(outsize); + + // The decrypted message should contain the PeerInfo. + if (outbuf.size() != sizeof(PeerInfo)) { + LOG(ERROR) << "Got size=" << outbuf.size() << "PeerInfo.size=" << sizeof(PeerInfo); + return false; + } + + p = outbuf.data(); + ::memcpy(&their_info_, p, sizeof(PeerInfo)); + p += sizeof(PeerInfo); + + return true; +} + +void PairingConnectionCtx::StartWorker() { + // Setup the secure transport + if (!SetupTlsConnection()) { + NotifyResult(nullptr); + return; + } + + for (;;) { + switch (state_) { + case State::ExchangingMsgs: + if (!DoExchangeMsgs()) { + NotifyResult(nullptr); + return; + } + state_ = State::ExchangingPeerInfo; + break; + case State::ExchangingPeerInfo: + if (!DoExchangePeerInfo()) { + NotifyResult(nullptr); + return; + } + NotifyResult(&their_info_); + return; + case State::Ready: + case State::Stopped: + LOG(FATAL) << __func__ << ": Got invalid state"; + return; + } + } +} + +// static +PairingAuthPtr PairingConnectionCtx::CreatePairingAuthPtr(Role role, const Data& pswd) { + switch (role) { + case Role::Client: + return PairingAuthPtr(pairing_auth_client_new(pswd.data(), pswd.size())); + break; + case Role::Server: + return PairingAuthPtr(pairing_auth_server_new(pswd.data(), pswd.size())); + break; + } +} + +static PairingConnectionCtx* CreateConnection(PairingConnectionCtx::Role role, const uint8_t* pswd, + size_t pswd_len, const PeerInfo* peer_info, + const uint8_t* x509_cert_pem, size_t x509_size, + const uint8_t* priv_key_pem, size_t priv_size) { + CHECK(pswd); + CHECK_GT(pswd_len, 0U); + CHECK(x509_cert_pem); + CHECK_GT(x509_size, 0U); + CHECK(priv_key_pem); + CHECK_GT(priv_size, 0U); + CHECK(peer_info); + std::vector vec_pswd(pswd, pswd + pswd_len); + std::vector vec_x509_cert(x509_cert_pem, x509_cert_pem + x509_size); + std::vector vec_priv_key(priv_key_pem, priv_key_pem + priv_size); + return new PairingConnectionCtx(role, vec_pswd, *peer_info, vec_x509_cert, vec_priv_key); +} + +PairingConnectionCtx* pairing_connection_client_new(const uint8_t* pswd, size_t pswd_len, + const PeerInfo* peer_info, + const uint8_t* x509_cert_pem, size_t x509_size, + const uint8_t* priv_key_pem, size_t priv_size) { + return CreateConnection(PairingConnectionCtx::Role::Client, pswd, pswd_len, peer_info, + x509_cert_pem, x509_size, priv_key_pem, priv_size); +} + +PairingConnectionCtx* pairing_connection_server_new(const uint8_t* pswd, size_t pswd_len, + const PeerInfo* peer_info, + const uint8_t* x509_cert_pem, size_t x509_size, + const uint8_t* priv_key_pem, size_t priv_size) { + return CreateConnection(PairingConnectionCtx::Role::Server, pswd, pswd_len, peer_info, + x509_cert_pem, x509_size, priv_key_pem, priv_size); +} + +void pairing_connection_destroy(PairingConnectionCtx* ctx) { + CHECK(ctx); + delete ctx; +} + +bool pairing_connection_start(PairingConnectionCtx* ctx, int fd, pairing_result_cb cb, + void* opaque) { + return ctx->Start(fd, cb, opaque); +} diff --git a/adb/pairing_connection/pairing_server.cpp b/adb/pairing_connection/pairing_server.cpp new file mode 100644 index 000000000..7218eacf2 --- /dev/null +++ b/adb/pairing_connection/pairing_server.cpp @@ -0,0 +1,466 @@ +/* + * Copyright (C) 2020 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "adb/pairing/pairing_server.h" + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "internal/constants.h" + +using android::base::ScopedLockAssertion; +using android::base::unique_fd; +using namespace adb::crypto; +using namespace adb::pairing; + +// The implementation has two background threads running: one to handle and +// accept any new pairing connection requests (socket accept), and the other to +// handle connection events (connection started, connection finished). +struct PairingServerCtx { + public: + using Data = std::vector; + + virtual ~PairingServerCtx(); + + // All parameters must be non-empty. + explicit PairingServerCtx(const Data& pswd, const PeerInfo& peer_info, const Data& cert, + const Data& priv_key, uint16_t port); + + // Starts the pairing server. This call is non-blocking. Upon completion, + // if the pairing was successful, then |cb| will be called with the PublicKeyHeader + // containing the info of the trusted peer. Otherwise, |cb| will be + // called with an empty value. Start can only be called once in the lifetime + // of this object. + // + // Returns the port number if PairingServerCtx was successfully started. Otherwise, + // returns 0. + uint16_t Start(pairing_server_result_cb cb, void* opaque); + + private: + // Setup the server socket to accept incoming connections. Returns the + // server port number (> 0 on success). + uint16_t SetupServer(); + // Force stop the server thread. + void StopServer(); + + // handles a new pairing client connection + bool HandleNewClientConnection(int fd) EXCLUDES(conn_mutex_); + + // ======== connection events thread ============= + std::mutex conn_mutex_; + std::condition_variable conn_cv_; + + using FdVal = int; + struct ConnectionDeleter { + void operator()(PairingConnectionCtx* p) { pairing_connection_destroy(p); } + }; + using ConnectionPtr = std::unique_ptr; + static ConnectionPtr CreatePairingConnection(const Data& pswd, const PeerInfo& peer_info, + const Data& cert, const Data& priv_key); + using NewConnectionEvent = std::tuple; + // + using ConnectionFinishedEvent = std::tuple>; + using ConnectionEvent = std::variant; + // Queue for connections to write into. We have a separate queue to read + // from, in order to minimize the time the server thread is blocked. + std::deque conn_write_queue_ GUARDED_BY(conn_mutex_); + std::deque conn_read_queue_; + // Map of fds to their PairingConnections currently running. + std::unordered_map connections_; + + // Two threads launched when starting the pairing server: + // 1) A server thread that waits for incoming client connections, and + // 2) A connection events thread that synchonizes events from all of the + // clients, since each PairingConnection is running in it's own thread. + void StartConnectionEventsThread(); + void StartServerThread(); + + static void PairingConnectionCallback(const PeerInfo* peer_info, int fd, void* opaque); + + std::thread conn_events_thread_; + void ConnectionEventsWorker(); + std::thread server_thread_; + void ServerWorker(); + bool is_terminate_ GUARDED_BY(conn_mutex_) = false; + + enum class State { + Ready, + Running, + Stopped, + }; + State state_ = State::Ready; + Data pswd_; + PeerInfo peer_info_; + Data cert_; + Data priv_key_; + uint16_t port_; + + pairing_server_result_cb cb_; + void* opaque_ = nullptr; + bool got_valid_pairing_ = false; + + static const int kEpollConstSocket = 0; + // Used to break the server thread from epoll_wait + static const int kEpollConstEventFd = 1; + unique_fd epoll_fd_; + unique_fd server_fd_; + unique_fd event_fd_; +}; // PairingServerCtx + +// static +PairingServerCtx::ConnectionPtr PairingServerCtx::CreatePairingConnection(const Data& pswd, + const PeerInfo& peer_info, + const Data& cert, + const Data& priv_key) { + return ConnectionPtr(pairing_connection_server_new(pswd.data(), pswd.size(), &peer_info, + cert.data(), cert.size(), priv_key.data(), + priv_key.size())); +} + +PairingServerCtx::PairingServerCtx(const Data& pswd, const PeerInfo& peer_info, const Data& cert, + const Data& priv_key, uint16_t port) + : pswd_(pswd), peer_info_(peer_info), cert_(cert), priv_key_(priv_key), port_(port) { + CHECK(!pswd_.empty() && !cert_.empty() && !priv_key_.empty()); +} + +PairingServerCtx::~PairingServerCtx() { + // Since these connections have references to us, let's make sure they + // destruct before us. + if (server_thread_.joinable()) { + StopServer(); + server_thread_.join(); + } + + { + std::lock_guard lock(conn_mutex_); + is_terminate_ = true; + } + conn_cv_.notify_one(); + if (conn_events_thread_.joinable()) { + conn_events_thread_.join(); + } + + // Notify the cb_ if it hasn't already. + if (!got_valid_pairing_ && cb_ != nullptr) { + cb_(nullptr, opaque_); + } +} + +uint16_t PairingServerCtx::Start(pairing_server_result_cb cb, void* opaque) { + cb_ = cb; + opaque_ = opaque; + + if (state_ != State::Ready) { + LOG(ERROR) << "PairingServerCtx already running or stopped"; + return 0; + } + + port_ = SetupServer(); + if (port_ == 0) { + LOG(ERROR) << "Unable to start PairingServer"; + state_ = State::Stopped; + return 0; + } + LOG(INFO) << "Pairing server started on port " << port_; + + state_ = State::Running; + return port_; +} + +void PairingServerCtx::StopServer() { + if (event_fd_.get() == -1) { + return; + } + uint64_t value = 1; + ssize_t rc = write(event_fd_.get(), &value, sizeof(value)); + if (rc == -1) { + // This can happen if the server didn't start. + PLOG(ERROR) << "write to eventfd failed"; + } else if (rc != sizeof(value)) { + LOG(FATAL) << "write to event returned short (" << rc << ")"; + } +} + +uint16_t PairingServerCtx::SetupServer() { + epoll_fd_.reset(epoll_create1(EPOLL_CLOEXEC)); + if (epoll_fd_ == -1) { + PLOG(ERROR) << "failed to create epoll fd"; + return 0; + } + + event_fd_.reset(eventfd(0, EFD_CLOEXEC | EFD_NONBLOCK)); + if (event_fd_ == -1) { + PLOG(ERROR) << "failed to create eventfd"; + return 0; + } + + server_fd_.reset(socket_inaddr_any_server(port_, SOCK_STREAM)); + if (server_fd_.get() == -1) { + PLOG(ERROR) << "Failed to start pairing connection server"; + return 0; + } else if (fcntl(server_fd_.get(), F_SETFD, FD_CLOEXEC) != 0) { + PLOG(ERROR) << "Failed to make server socket cloexec"; + return 0; + } else if (fcntl(server_fd_.get(), F_SETFD, O_NONBLOCK) != 0) { + PLOG(ERROR) << "Failed to make server socket nonblocking"; + return 0; + } + + StartConnectionEventsThread(); + StartServerThread(); + int port = socket_get_local_port(server_fd_.get()); + return (port <= 0 ? 0 : port); +} + +void PairingServerCtx::StartServerThread() { + server_thread_ = std::thread([this]() { ServerWorker(); }); +} + +void PairingServerCtx::StartConnectionEventsThread() { + conn_events_thread_ = std::thread([this]() { ConnectionEventsWorker(); }); +} + +void PairingServerCtx::ServerWorker() { + { + struct epoll_event event; + event.events = EPOLLIN; + event.data.u64 = kEpollConstSocket; + CHECK_EQ(0, epoll_ctl(epoll_fd_.get(), EPOLL_CTL_ADD, server_fd_.get(), &event)); + } + + { + struct epoll_event event; + event.events = EPOLLIN; + event.data.u64 = kEpollConstEventFd; + CHECK_EQ(0, epoll_ctl(epoll_fd_.get(), EPOLL_CTL_ADD, event_fd_.get(), &event)); + } + + while (true) { + struct epoll_event events[2]; + int rc = TEMP_FAILURE_RETRY(epoll_wait(epoll_fd_.get(), events, 2, -1)); + if (rc == -1) { + PLOG(ERROR) << "epoll_wait failed"; + return; + } else if (rc == 0) { + LOG(ERROR) << "epoll_wait returned 0"; + return; + } + + for (int i = 0; i < rc; ++i) { + struct epoll_event& event = events[i]; + switch (event.data.u64) { + case kEpollConstSocket: + HandleNewClientConnection(server_fd_.get()); + break; + case kEpollConstEventFd: + uint64_t dummy; + int rc = TEMP_FAILURE_RETRY(read(event_fd_.get(), &dummy, sizeof(dummy))); + if (rc != sizeof(dummy)) { + PLOG(FATAL) << "failed to read from eventfd (rc=" << rc << ")"; + } + return; + } + } + } +} + +// static +void PairingServerCtx::PairingConnectionCallback(const PeerInfo* peer_info, int fd, void* opaque) { + auto* p = reinterpret_cast(opaque); + + ConnectionFinishedEvent event; + if (peer_info != nullptr) { + if (peer_info->type == ADB_RSA_PUB_KEY) { + event = std::make_tuple(fd, peer_info->type, + std::string(reinterpret_cast(peer_info->data))); + } else { + LOG(WARNING) << "Ignoring successful pairing because of unknown " + << "PeerInfo type=" << peer_info->type; + } + } else { + event = std::make_tuple(fd, 0, std::nullopt); + } + { + std::lock_guard lock(p->conn_mutex_); + p->conn_write_queue_.push_back(std::move(event)); + } + p->conn_cv_.notify_one(); +} + +void PairingServerCtx::ConnectionEventsWorker() { + uint8_t num_tries = 0; + for (;;) { + // Transfer the write queue to the read queue. + { + std::unique_lock lock(conn_mutex_); + ScopedLockAssertion assume_locked(conn_mutex_); + + if (is_terminate_) { + // We check |is_terminate_| twice because condition_variable's + // notify() only wakes up a thread if it is in the wait state + // prior to notify(). Furthermore, we aren't holding the mutex + // when processing the events in |conn_read_queue_|. + return; + } + if (conn_write_queue_.empty()) { + // We need to wait for new events, or the termination signal. + conn_cv_.wait(lock, [this]() REQUIRES(conn_mutex_) { + return (is_terminate_ || !conn_write_queue_.empty()); + }); + } + if (is_terminate_) { + // We're done. + return; + } + // Move all events into the read queue. + conn_read_queue_ = std::move(conn_write_queue_); + conn_write_queue_.clear(); + } + + // Process all events in the read queue. + while (conn_read_queue_.size() > 0) { + auto& event = conn_read_queue_.front(); + if (auto* p = std::get_if(&event)) { + // Ignore if we are already at the max number of connections + if (connections_.size() >= internal::kMaxConnections) { + conn_read_queue_.pop_front(); + continue; + } + auto [ufd, connection] = std::move(*p); + int fd = ufd.release(); + bool started = pairing_connection_start(connection.get(), fd, + PairingConnectionCallback, this); + if (!started) { + LOG(ERROR) << "PairingServer unable to start a PairingConnection fd=" << fd; + ufd.reset(fd); + } else { + connections_[fd] = std::move(connection); + } + } else if (auto* p = std::get_if(&event)) { + auto [fd, info_type, public_key] = std::move(*p); + if (public_key.has_value() && !public_key->empty()) { + // Valid pairing. Let's shutdown the server and close any + // pairing connections in progress. + StopServer(); + connections_.clear(); + + PeerInfo info = {}; + info.type = info_type; + strncpy(reinterpret_cast(info.data), public_key->data(), + public_key->size()); + + cb_(&info, opaque_); + + got_valid_pairing_ = true; + return; + } + // Invalid pairing. Close the invalid connection. + if (connections_.find(fd) != connections_.end()) { + connections_.erase(fd); + } + + if (++num_tries >= internal::kMaxPairingAttempts) { + cb_(nullptr, opaque_); + // To prevent the destructor from calling it again. + cb_ = nullptr; + return; + } + } + conn_read_queue_.pop_front(); + } + } +} + +bool PairingServerCtx::HandleNewClientConnection(int fd) { + unique_fd ufd(TEMP_FAILURE_RETRY(accept4(fd, nullptr, nullptr, SOCK_CLOEXEC))); + if (ufd == -1) { + PLOG(WARNING) << "adb_socket_accept failed fd=" << fd; + return false; + } + auto connection = CreatePairingConnection(pswd_, peer_info_, cert_, priv_key_); + if (connection == nullptr) { + LOG(ERROR) << "PairingServer unable to create a PairingConnection fd=" << fd; + return false; + } + // send the new connection to the connection thread for further processing + NewConnectionEvent event = std::make_tuple(std::move(ufd), std::move(connection)); + { + std::lock_guard lock(conn_mutex_); + conn_write_queue_.push_back(std::move(event)); + } + conn_cv_.notify_one(); + + return true; +} + +uint16_t pairing_server_start(PairingServerCtx* ctx, pairing_server_result_cb cb, void* opaque) { + return ctx->Start(cb, opaque); +} + +PairingServerCtx* pairing_server_new(const uint8_t* pswd, size_t pswd_len, + const PeerInfo* peer_info, const uint8_t* x509_cert_pem, + size_t x509_size, const uint8_t* priv_key_pem, + size_t priv_size, uint16_t port) { + CHECK(pswd); + CHECK_GT(pswd_len, 0U); + CHECK(x509_cert_pem); + CHECK_GT(x509_size, 0U); + CHECK(priv_key_pem); + CHECK_GT(priv_size, 0U); + CHECK(peer_info); + std::vector vec_pswd(pswd, pswd + pswd_len); + std::vector vec_x509_cert(x509_cert_pem, x509_cert_pem + x509_size); + std::vector vec_priv_key(priv_key_pem, priv_key_pem + priv_size); + return new PairingServerCtx(vec_pswd, *peer_info, vec_x509_cert, vec_priv_key, port); +} + +PairingServerCtx* pairing_server_new_no_cert(const uint8_t* pswd, size_t pswd_len, + const PeerInfo* peer_info, uint16_t port) { + auto rsa_2048 = CreateRSA2048Key(); + auto x509_cert = GenerateX509Certificate(rsa_2048->GetEvpPkey()); + std::string pkey_pem = Key::ToPEMString(rsa_2048->GetEvpPkey()); + std::string cert_pem = X509ToPEMString(x509_cert.get()); + + return pairing_server_new(pswd, pswd_len, peer_info, + reinterpret_cast(cert_pem.data()), cert_pem.size(), + reinterpret_cast(pkey_pem.data()), pkey_pem.size(), + port); +} + +void pairing_server_destroy(PairingServerCtx* ctx) { + CHECK(ctx); + delete ctx; +} diff --git a/adb/pairing_connection/tests/Android.bp b/adb/pairing_connection/tests/Android.bp new file mode 100644 index 000000000..bf075bcb9 --- /dev/null +++ b/adb/pairing_connection/tests/Android.bp @@ -0,0 +1,47 @@ +// +// Copyright (C) 2020 The Android Open Source Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +cc_test { + name: "adb_pairing_connection_test", + srcs: [ + "pairing_client.cpp", + "pairing_connection_test.cpp", + ], + + compile_multilib: "first", + + shared_libs: [ + "libbase", + "libcutils", + "libcrypto", + "libcrypto_utils", + "libprotobuf-cpp-lite", + "libssl", + ], + + // Let's statically link them so we don't have to install it onto the + // system image for testing. + static_libs: [ + "libadb_pairing_auth_static", + "libadb_pairing_connection_static", + "libadb_pairing_server_static", + "libadb_crypto_static", + "libadb_protos_static", + "libadb_tls_connection_static", + ], + + test_suites: ["device-tests"], +} diff --git a/adb/pairing_connection/tests/pairing_client.cpp b/adb/pairing_connection/tests/pairing_client.cpp new file mode 100644 index 000000000..1f3ef5a34 --- /dev/null +++ b/adb/pairing_connection/tests/pairing_client.cpp @@ -0,0 +1,201 @@ +/* + * Copyright (C) 2020 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "pairing_client.h" + +#include +#include + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +namespace adb { +namespace pairing { + +using android::base::unique_fd; + +static void ConnectionDeleter(PairingConnectionCtx* p) { + pairing_connection_destroy(p); +} +using ConnectionPtr = std::unique_ptr; + +namespace { + +class PairingClientImpl : public PairingClient { + public: + explicit PairingClientImpl(const Data& pswd, const PeerInfo& peer_info, const Data& cert, + const Data& priv_key); + + // Starts the pairing client. This call is non-blocking. Upon pairing + // completion, |cb| will be called with the PeerInfo on success, + // or an empty value on failure. + // + // Returns true if PairingClient was successfully started. Otherwise, + // return false. + virtual bool Start(std::string_view ip_addr, pairing_client_result_cb cb, + void* opaque) override; + + private: + static ConnectionPtr CreatePairingConnection(const Data& pswd, const PeerInfo& peer_info, + const Data& cert, const Data& priv_key); + + static void PairingResultCallback(const PeerInfo* peer_info, int fd, void* opaque); + // Setup and start the PairingConnection + bool StartConnection(); + + enum class State { + Ready, + Running, + Stopped, + }; + + State state_ = State::Ready; + Data pswd_; + PeerInfo peer_info_; + Data cert_; + Data priv_key_; + std::string host_; + int port_; + + ConnectionPtr connection_; + pairing_client_result_cb cb_; + void* opaque_ = nullptr; +}; // PairingClientImpl + +// static +ConnectionPtr PairingClientImpl::CreatePairingConnection(const Data& pswd, + const PeerInfo& peer_info, + const Data& cert, const Data& priv_key) { + return ConnectionPtr( + pairing_connection_client_new(pswd.data(), pswd.size(), &peer_info, cert.data(), + cert.size(), priv_key.data(), priv_key.size()), + ConnectionDeleter); +} + +PairingClientImpl::PairingClientImpl(const Data& pswd, const PeerInfo& peer_info, const Data& cert, + const Data& priv_key) + : pswd_(pswd), + peer_info_(peer_info), + cert_(cert), + priv_key_(priv_key), + connection_(nullptr, ConnectionDeleter) { + CHECK(!pswd_.empty() && !cert_.empty() && !priv_key_.empty()); + + state_ = State::Ready; +} + +bool PairingClientImpl::Start(std::string_view ip_addr, pairing_client_result_cb cb, void* opaque) { + CHECK(!ip_addr.empty()); + cb_ = cb; + opaque_ = opaque; + + if (state_ != State::Ready) { + LOG(ERROR) << "PairingClient already running or finished"; + return false; + } + + // Try to parse the host address + std::string err; + CHECK(android::base::ParseNetAddress(std::string(ip_addr), &host_, &port_, nullptr, &err)); + CHECK(port_ > 0 && port_ <= 65535); + + if (!StartConnection()) { + LOG(ERROR) << "Unable to start PairingClient connection"; + state_ = State::Stopped; + return false; + } + + state_ = State::Running; + return true; +} + +static int network_connect(const std::string& host, int port, int type, int timeout, + std::string* error) { + int getaddrinfo_error = 0; + int fd = socket_network_client_timeout(host.c_str(), port, type, timeout, &getaddrinfo_error); + if (fd != -1) { + return fd; + } + if (getaddrinfo_error != 0) { + *error = android::base::StringPrintf("failed to resolve host: '%s': %s", host.c_str(), + gai_strerror(getaddrinfo_error)); + LOG(WARNING) << *error; + } else { + *error = android::base::StringPrintf("failed to connect to '%s:%d': %s", host.c_str(), port, + strerror(errno)); + LOG(WARNING) << *error; + } + return -1; +} + +// static +void PairingClientImpl::PairingResultCallback(const PeerInfo* peer_info, int /* fd */, + void* opaque) { + auto* p = reinterpret_cast(opaque); + p->cb_(peer_info, p->opaque_); +} + +bool PairingClientImpl::StartConnection() { + std::string err; + const int timeout = 10; // seconds + unique_fd fd(network_connect(host_, port_, SOCK_STREAM, timeout, &err)); + if (fd.get() == -1) { + LOG(ERROR) << "Failed to start pairing connection client [" << err << "]"; + return false; + } + int off = 1; + setsockopt(fd.get(), IPPROTO_TCP, TCP_NODELAY, &off, sizeof(off)); + + connection_ = CreatePairingConnection(pswd_, peer_info_, cert_, priv_key_); + if (connection_ == nullptr) { + LOG(ERROR) << "PairingClient unable to create a PairingConnection"; + return false; + } + + if (!pairing_connection_start(connection_.get(), fd.release(), PairingResultCallback, this)) { + LOG(ERROR) << "PairingClient failed to start the PairingConnection"; + state_ = State::Stopped; + return false; + } + + return true; +} + +} // namespace + +// static +std::unique_ptr PairingClient::Create(const Data& pswd, const PeerInfo& peer_info, + const Data& cert, const Data& priv_key) { + CHECK(!pswd.empty()); + CHECK(!cert.empty()); + CHECK(!priv_key.empty()); + + return std::unique_ptr(new PairingClientImpl(pswd, peer_info, cert, priv_key)); +} + +} // namespace pairing +} // namespace adb diff --git a/adb/pairing_connection/tests/pairing_client.h b/adb/pairing_connection/tests/pairing_client.h new file mode 100644 index 000000000..be0db5ce4 --- /dev/null +++ b/adb/pairing_connection/tests/pairing_client.h @@ -0,0 +1,68 @@ +/* + * Copyright (C) 2020 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +#include +#include +#include +#include + +#include "adb/pairing/pairing_connection.h" + +typedef void (*pairing_client_result_cb)(const PeerInfo*, void*); + +namespace adb { +namespace pairing { + +// PairingClient is the client side of the PairingConnection protocol. It will +// attempt to connect to a PairingServer specified at |host| and |port|, and +// allocate a new PairingConnection for processing. +// +// See pairing_connection_test.cpp for example usage. +// +class PairingClient { + public: + using Data = std::vector; + + virtual ~PairingClient() = default; + + // Starts the pairing client. This call is non-blocking. Upon completion, + // if the pairing was successful, then |cb| will be called with the PeerInfo + // containing the info of the trusted peer. Otherwise, |cb| will be + // called with an empty value. Start can only be called once in the lifetime + // of this object. |ip_addr| requires a port to be specified. + // + // Returns true if PairingClient was successfully started. Otherwise, + // returns false. + virtual bool Start(std::string_view ip_addr, pairing_client_result_cb cb, void* opaque) = 0; + + // Creates a new PairingClient instance. May return null if unable + // to create an instance. |pswd|, |certificate|, |priv_key| and + // |ip_addr| cannot be empty. |peer_info| must contain non-empty strings for + // the guid and name fields. + static std::unique_ptr Create(const Data& pswd, const PeerInfo& peer_info, + const Data& certificate, const Data& priv_key); + + protected: + PairingClient() = default; +}; // class PairingClient + +} // namespace pairing +} // namespace adb diff --git a/adb/pairing_connection/tests/pairing_connection_test.cpp b/adb/pairing_connection/tests/pairing_connection_test.cpp new file mode 100644 index 000000000..b6e09f190 --- /dev/null +++ b/adb/pairing_connection/tests/pairing_connection_test.cpp @@ -0,0 +1,500 @@ +/* + * Copyright 2020 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#define LOG_TAG "AdbPairingConnectionTest" + +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "../internal/constants.h" +#include "pairing_client.h" + +using namespace std::chrono_literals; + +namespace adb { +namespace pairing { + +// Test X.509 certificates (RSA 2048) +static const std::string kTestRsa2048ServerCert = + "-----BEGIN CERTIFICATE-----\n" + "MIIDFzCCAf+gAwIBAgIBATANBgkqhkiG9w0BAQsFADAtMQswCQYDVQQGEwJVUzEQ\n" + "MA4GA1UECgwHQW5kcm9pZDEMMAoGA1UEAwwDQWRiMB4XDTIwMDEyMTIyMjU1NVoX\n" + "DTMwMDExODIyMjU1NVowLTELMAkGA1UEBhMCVVMxEDAOBgNVBAoMB0FuZHJvaWQx\n" + "DDAKBgNVBAMMA0FkYjCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAK8E\n" + "2Ck9TfuKlz7wqWdMfknjZ1luFDp2IHxAUZzh/F6jeI2dOFGAjpeloSnGOE86FIaT\n" + "d1EvpyTh7nBwbrLZAA6XFZTo7Bl6BdNOQdqb2d2+cLEN0inFxqUIycevRtohUE1Y\n" + "FHM9fg442X1jOTWXjDZWeiqFWo95paAPhzm6pWqfJK1+YKfT1LsWZpYqJGGQE5pi\n" + "C3qOBYYgFpoXMxTYJNoZo3uOYEdM6upc8/vh15nMgIxX/ymJxEY5BHPpZPPWjXLg\n" + "BfzVaV9fUfv0JT4HQ4t2WvxC3cD/UsjWp2a6p454uUp2ENrANa+jRdRJepepg9D2\n" + "DKsx9L8zjc5Obqexrt0CAwEAAaNCMEAwDwYDVR0TAQH/BAUwAwEB/zAOBgNVHQ8B\n" + "Af8EBAMCAYYwHQYDVR0OBBYEFDFW+8GTErwoZN5Uu9KyY4QdGYKpMA0GCSqGSIb3\n" + "DQEBCwUAA4IBAQBCDEn6SHXGlq5TU7J8cg1kRPd9bsJW+0hDuKSq0REXDkl0PcBf\n" + "fy282Agg9enKPPKmnpeQjM1dmnxdM8tT8LIUbMl779i3fn6v9HJVB+yG4gmRFThW\n" + "c+AGlBnrIT820cX/gU3h3R3FTahfsq+1rrSJkEgHyuC0HYeRyveSckBdaEOLvx0S\n" + "toun+32JJl5hWydpUUZhE9Mbb3KHBRM2YYZZU9JeJ08Apjl+3lRUeMAUwI5fkAAu\n" + "z/1SqnuGL96bd8P5ixdkA1+rF8FPhodGcq9mQOuUGP9g5HOXjaNoJYvwVRUdLeGh\n" + "cP/ReOTwQIzM1K5a83p8cX8AGGYmM7dQp7ec\n" + "-----END CERTIFICATE-----\n"; + +static const std::string kTestRsa2048ServerPrivKey = + "-----BEGIN PRIVATE KEY-----\n" + "MIIEvAIBADANBgkqhkiG9w0BAQEFAASCBKYwggSiAgEAAoIBAQCvBNgpPU37ipc+\n" + "8KlnTH5J42dZbhQ6diB8QFGc4fxeo3iNnThRgI6XpaEpxjhPOhSGk3dRL6ck4e5w\n" + "cG6y2QAOlxWU6OwZegXTTkHam9ndvnCxDdIpxcalCMnHr0baIVBNWBRzPX4OONl9\n" + "Yzk1l4w2VnoqhVqPeaWgD4c5uqVqnyStfmCn09S7FmaWKiRhkBOaYgt6jgWGIBaa\n" + "FzMU2CTaGaN7jmBHTOrqXPP74deZzICMV/8picRGOQRz6WTz1o1y4AX81WlfX1H7\n" + "9CU+B0OLdlr8Qt3A/1LI1qdmuqeOeLlKdhDawDWvo0XUSXqXqYPQ9gyrMfS/M43O\n" + "Tm6nsa7dAgMBAAECggEAFCS2bPdUKIgjbzLgtHW+hT+J2hD20rcHdyAp+dNH/2vI\n" + "yLfDJHJA4chGMRondKA704oDw2bSJxxlG9t83326lB35yxPhye7cM8fqgWrK8PVl\n" + "tU22FhO1ZgeJvb9OeXWNxKZyDW9oOOJ8eazNXVMuEo+dFj7B6l3MXQyHJPL2mJDm\n" + "u9ofFLdypX+gJncVO0oW0FNJnEUn2MMwHDNlo7gc4WdQuidPkuZItKRGcB8TTGF3\n" + "Ka1/2taYdTQ4Aq//Z84LlFvE0zD3T4c8LwYYzOzD4gGGTXvft7vSHzIun1S8YLRS\n" + "dEKXdVjtaFhgH3uUe4j+1b/vMvSHeoGBNX/G88GD+wKBgQDWUYVlMVqc9HD2IeYi\n" + "EfBcNwAJFJkh51yAl5QbUBgFYgFJVkkS/EDxEGFPvEmI3/pAeQFHFY13BI466EPs\n" + "o8Z8UUwWDp+Z1MFHHKQKnFakbsZbZlbqjJ9VJsqpezbpWhMHTOmcG0dmE7rf0lyM\n" + "eQv9slBB8qp2NEUs5Of7f2C2bwKBgQDRDq4nUuMQF1hbjM05tGKSIwkobmGsLspv\n" + "TMhkM7fq4RpbFHmbNgsFqMhcqYZ8gY6/scv5KCuAZ4yHUkbqwf5h+QCwrJ4uJeUJ\n" + "ZgJfHus2mmcNSo8FwSkNoojIQtzcbJav7bs2K9VTuertk/i7IJLApU4FOZZ5pghN\n" + "EXu0CZF1cwKBgDWFGhjRIF29tU/h20R60llU6s9Zs3wB+NmsALJpZ/ZAKS4VPB5f\n" + "nCAXBRYSYRKrTCU5kpYbzb4BBzuysPOxWmnFK4j+keCqfrGxd02nCQP7HdHJVr8v\n" + "6sIq88UrHeVcNxBFprjzHvtgxfQK5k22FMZ/9wbhAKyQFQ5HA5+MiaxFAoGAIcZZ\n" + "ZIkDninnYIMS9OursShv5lRO+15j3i9tgKLKZ+wOMgDQ1L6acUOfezj4PU1BHr8+\n" + "0PYocQpJreMhCfRlgLaV4fVBaPs+UZJld7CrF5tCYudUy/01ALrtlk0XGZWBktK5\n" + "mDrksC4tQkzRtonAq9cJD9cJ9IVaefkFH0UcdvkCgYBpZj50VLeGhnHHBnkJRlV1\n" + "fV+/P6PAq6RtqjA6O9Qdaoj5V3w2d63aQcQXQLJjH2BBmtCIy47r04rFvZpbCxP7\n" + "NH/OnK9NHpk2ucRTe8TAnVbvF/TZzPJoIxAO/D3OWaW6df4R8en8u6GYzWFglAyT\n" + "sydGT8yfWD1FYUWgfrVRbg==\n" + "-----END PRIVATE KEY-----\n"; + +static const std::string kTestRsa2048ClientCert = + "-----BEGIN CERTIFICATE-----\n" + "MIIDFzCCAf+gAwIBAgIBATANBgkqhkiG9w0BAQsFADAtMQswCQYDVQQGEwJVUzEQ\n" + "MA4GA1UECgwHQW5kcm9pZDEMMAoGA1UEAwwDQWRiMB4XDTIwMDEyMTIyMjU1NloX\n" + "DTMwMDExODIyMjU1NlowLTELMAkGA1UEBhMCVVMxEDAOBgNVBAoMB0FuZHJvaWQx\n" + "DDAKBgNVBAMMA0FkYjCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAI3a\n" + "EXh1S5FTbet7JVONswffRPaekdIK53cb8SnAbSO9X5OLA4zGwdkrBvDTsd96SKrp\n" + "JxmoNOE1DhbZh05KPlWAPkGKacjGWaz+S7biDOL0I6aaLbTlU/il1Ub9olPSBVUx\n" + "0nhdtEFgIOzddnP6/1KmyIIeRxS5lTKeg4avqUkZNXkz/wL1dHBFL7FNFf0SCcbo\n" + "tsub/deFbjZ27LTDN+SIBgFttTNqC5NTvoBAoMdyCOAgNYwaHO+fKiK3edfJieaw\n" + "7HD8qqmQxcpCtRlA8CUPj7GfR+WHiCJmlevhnkFXCo56R1BS0F4wuD4KPdSWt8gc\n" + "27ejH/9/z2cKo/6SLJMCAwEAAaNCMEAwDwYDVR0TAQH/BAUwAwEB/zAOBgNVHQ8B\n" + "Af8EBAMCAYYwHQYDVR0OBBYEFO/Mr5ygqqpyU/EHM9v7RDvcqaOkMA0GCSqGSIb3\n" + "DQEBCwUAA4IBAQAH33KMouzF2DYbjg90KDrDQr4rq3WfNb6P743knxdUFuvb+40U\n" + "QjC2OJZHkSexH7wfG/y6ic7vfCfF4clNs3QvU1lEjOZC57St8Fk7mdNdsWLwxEMD\n" + "uePFz0dvclSxNUHyCVMqNxddzQYzxiDWQRmXWrUBliMduQqEQelcxW2yDtg8bj+s\n" + "aMpR1ra9scaD4jzIZIIxLoOS9zBMuNRbgP217sZrniyGMhzoI1pZ/izN4oXpyH7O\n" + "THuaCzzRT3ph2f8EgmHSodz3ttgSf2DHzi/Ez1xUkk7NOlgNtmsxEdrM47+cC5ae\n" + "fIf2V+1o1JW8J7D11RmRbNPh3vfisueB4f88\n" + "-----END CERTIFICATE-----\n"; + +static const std::string kTestRsa2048ClientPrivKey = + "-----BEGIN PRIVATE KEY-----\n" + "MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQCN2hF4dUuRU23r\n" + "eyVTjbMH30T2npHSCud3G/EpwG0jvV+TiwOMxsHZKwbw07Hfekiq6ScZqDThNQ4W\n" + "2YdOSj5VgD5BimnIxlms/ku24gzi9COmmi205VP4pdVG/aJT0gVVMdJ4XbRBYCDs\n" + "3XZz+v9SpsiCHkcUuZUynoOGr6lJGTV5M/8C9XRwRS+xTRX9EgnG6LbLm/3XhW42\n" + "duy0wzfkiAYBbbUzaguTU76AQKDHcgjgIDWMGhzvnyoit3nXyYnmsOxw/KqpkMXK\n" + "QrUZQPAlD4+xn0flh4giZpXr4Z5BVwqOekdQUtBeMLg+Cj3UlrfIHNu3ox//f89n\n" + "CqP+kiyTAgMBAAECggEAAa64eP6ggCob1P3c73oayYPIbvRqiQdAFOrr7Vwu7zbr\n" + "z0rde+n6RU0mrpc+4NuzyPMtrOGQiatLbidJB5Cx3z8U00ovqbCl7PtcgorOhFKe\n" + "VEzihebCcYyQqbWQcKtpDMhOgBxRwFoXieJb6VGXfa96FAZalCWvXgOrTl7/BF2X\n" + "qMqIm9nJi+yS5tIO8VdOsOmrMWRH/b/ENUcef4WpLoxTXr0EEgyKWraeZ/hhXo1e\n" + "z29dZKqdr9wMsq11NPsRddwS94jnDkXTo+EQyWVTfB7gb6yyp07s8jysaDb21tVv\n" + "UXB9MRhDV1mOv0ncXfXZ4/+4A2UahmZaLDAVLaat4QKBgQDAVRredhGRGl2Nkic3\n" + "KvZCAfyxug788CgasBdEiouz19iCCwcgMIDwnq0s3/WM7h/laCamT2x38riYDnpq\n" + "rkYMfuVtU9CjEL9pTrdfwbIRhTwYNqADaPz2mXwQUhRXutE5TIdgxxC/a+ZTh0qN\n" + "S+vhTj/4hf0IZhMh5Nqj7IPExQKBgQC8zxEzhmSGjys0GuE6Wl6Doo2TpiR6vwvi\n" + "xPLU9lmIz5eca/Rd/eERioFQqeoIWDLzx52DXuz6rUoQhbJWz9hP3yqCwXD+pbNP\n" + "oDJqDDbCC4IMYEb0IK/PEPH+gIpnTjoFcW+ecKDFG7W5Lt05J8WsJsfOaJvMrOU+\n" + "dLXq3IgxdwKBgQC5RAFq0v6e8G+3hFaEHL0z3igkpt3zJf7rnj37hx2FMmDa+3Z0\n" + "umQp5B9af61PgL12xLmeMBmC/Wp1BlVDV/Yf6Uhk5Hyv5t0KuomHEtTNbbLyfAPs\n" + "5P/vJu/L5NS1oT4S3LX3MineyjgGs+bLbpub3z1dzutrYLADUSiPCK/xJQKBgBQt\n" + "nQ0Ao+Wtj1R2OvPdjJRM3wyUiPmFSWPm4HzaBx+T8AQLlYYmB9O0FbXlMtnJc0iS\n" + "YMcVcgYoVu4FG9YjSF7g3s4yljzgwJUV7c1fmMqMKE3iTDLy+1cJ3JLycdgwiArk\n" + "4KTyLHxkRbuQwpvFIF8RlfD9RQlOwQE3v+llwDhpAoGBAL6XG6Rp6mBoD2Ds5c9R\n" + "943yYgSUes3ji1SI9zFqeJtj8Ml/enuK1xu+8E/BxB0//+vgZsH6i3i8GFwygKey\n" + "CGJF8CbiHc3EJc3NQIIRXcni/CGacf0HwC6m+PGFDBIpA4H2iDpVvCSofxttQiq0\n" + "/Z7HXmXUvZHVyYi/QzX2Gahj\n" + "-----END PRIVATE KEY-----\n"; + +struct ServerDeleter { + void operator()(PairingServerCtx* p) { pairing_server_destroy(p); } +}; +using ServerPtr = std::unique_ptr; + +struct ResultWaiter { + std::mutex mutex_; + std::condition_variable cv_; + std::optional is_valid_; + PeerInfo peer_info_; + + static void ResultCallback(const PeerInfo* peer_info, void* opaque) { + auto* p = reinterpret_cast(opaque); + { + std::unique_lock lock(p->mutex_); + if (peer_info) { + memcpy(&(p->peer_info_), peer_info, sizeof(PeerInfo)); + } + p->is_valid_ = (peer_info != nullptr); + } + p->cv_.notify_one(); + } +}; + +class AdbPairingConnectionTest : public testing::Test { + protected: + virtual void SetUp() override {} + + virtual void TearDown() override {} + + void InitPairing(const std::vector& server_pswd, + const std::vector& client_pswd) { + server_ = CreateServer(server_pswd); + client_ = CreateClient(client_pswd); + } + + ServerPtr CreateServer(const std::vector& pswd) { + return CreateServer(pswd, &server_info_, kTestRsa2048ServerCert, kTestRsa2048ServerPrivKey, + 0); + } + + std::unique_ptr CreateClient(const std::vector pswd) { + std::vector cert; + std::vector key; + // Include the null-byte as well. + cert.assign(reinterpret_cast(kTestRsa2048ClientCert.data()), + reinterpret_cast(kTestRsa2048ClientCert.data()) + + kTestRsa2048ClientCert.size() + 1); + key.assign(reinterpret_cast(kTestRsa2048ClientPrivKey.data()), + reinterpret_cast(kTestRsa2048ClientPrivKey.data()) + + kTestRsa2048ClientPrivKey.size() + 1); + return PairingClient::Create(pswd, client_info_, cert, key); + } + + static ServerPtr CreateServer(const std::vector& pswd, const PeerInfo* peer_info, + const std::string_view cert, const std::string_view priv_key, + int port) { + return ServerPtr(pairing_server_new( + pswd.data(), pswd.size(), peer_info, reinterpret_cast(cert.data()), + cert.size(), reinterpret_cast(priv_key.data()), priv_key.size(), + port)); + } + + ServerPtr server_; + const PeerInfo server_info_ = { + .type = ADB_DEVICE_GUID, + .data = "my_server_info", + }; + std::unique_ptr client_; + const PeerInfo client_info_ = { + .type = ADB_RSA_PUB_KEY, + .data = "my_client_info", + }; + std::string ip_addr_ = "127.0.0.1:"; +}; + +TEST_F(AdbPairingConnectionTest, ServerCreation) { + // All parameters bad + ASSERT_DEATH({ auto server = CreateServer({}, nullptr, "", "", 0); }, ""); + // Bad password + ASSERT_DEATH( + { + auto server = CreateServer({}, &server_info_, kTestRsa2048ServerCert, + kTestRsa2048ServerPrivKey, 0); + }, + ""); + // Bad peer_info + ASSERT_DEATH( + { + auto server = CreateServer({0x01}, nullptr, kTestRsa2048ServerCert, + kTestRsa2048ServerPrivKey, 0); + }, + ""); + // Bad certificate + ASSERT_DEATH( + { + auto server = CreateServer({0x01}, &server_info_, "", kTestRsa2048ServerPrivKey, 0); + }, + ""); + // Bad private key + ASSERT_DEATH( + { auto server = CreateServer({0x01}, &server_info_, kTestRsa2048ServerCert, "", 0); }, + ""); + // Valid params + auto server = CreateServer({0x01}, &server_info_, kTestRsa2048ServerCert, + kTestRsa2048ServerPrivKey, 0); + EXPECT_NE(nullptr, server); +} + +TEST_F(AdbPairingConnectionTest, ClientCreation) { + std::vector pswd{0x01, 0x03, 0x05, 0x07}; + // Bad password + ASSERT_DEATH( + { + pairing_connection_client_new( + nullptr, pswd.size(), &client_info_, + reinterpret_cast(kTestRsa2048ClientCert.data()), + kTestRsa2048ClientCert.size(), + reinterpret_cast(kTestRsa2048ClientPrivKey.data()), + kTestRsa2048ClientPrivKey.size()); + }, + ""); + ASSERT_DEATH( + { + pairing_connection_client_new( + pswd.data(), 0, &client_info_, + reinterpret_cast(kTestRsa2048ClientCert.data()), + kTestRsa2048ClientCert.size(), + reinterpret_cast(kTestRsa2048ClientPrivKey.data()), + kTestRsa2048ClientPrivKey.size()); + }, + ""); + + // Bad peer_info + ASSERT_DEATH( + { + pairing_connection_client_new( + pswd.data(), pswd.size(), nullptr, + reinterpret_cast(kTestRsa2048ClientCert.data()), + kTestRsa2048ClientCert.size(), + reinterpret_cast(kTestRsa2048ClientPrivKey.data()), + kTestRsa2048ClientPrivKey.size()); + }, + ""); + + // Bad certificate + ASSERT_DEATH( + { + pairing_connection_client_new( + pswd.data(), pswd.size(), &client_info_, nullptr, + kTestRsa2048ClientCert.size(), + reinterpret_cast(kTestRsa2048ClientPrivKey.data()), + kTestRsa2048ClientPrivKey.size()); + }, + ""); + ASSERT_DEATH( + { + pairing_connection_client_new( + pswd.data(), pswd.size(), &client_info_, + reinterpret_cast(kTestRsa2048ClientCert.data()), 0, + reinterpret_cast(kTestRsa2048ClientPrivKey.data()), + kTestRsa2048ClientPrivKey.size()); + }, + ""); + + // Bad private key + ASSERT_DEATH( + { + pairing_connection_client_new( + pswd.data(), pswd.size(), &client_info_, + reinterpret_cast(kTestRsa2048ClientCert.data()), + kTestRsa2048ClientCert.size(), nullptr, kTestRsa2048ClientPrivKey.size()); + }, + ""); + ASSERT_DEATH( + { + pairing_connection_client_new( + pswd.data(), pswd.size(), &client_info_, + reinterpret_cast(kTestRsa2048ClientCert.data()), + kTestRsa2048ClientCert.size(), + reinterpret_cast(kTestRsa2048ClientPrivKey.data()), 0); + }, + ""); + + // Valid params + auto client = pairing_connection_client_new( + pswd.data(), pswd.size(), &client_info_, + reinterpret_cast(kTestRsa2048ClientCert.data()), + kTestRsa2048ClientCert.size(), + reinterpret_cast(kTestRsa2048ClientPrivKey.data()), + kTestRsa2048ClientPrivKey.size()); + EXPECT_NE(nullptr, client); +} + +TEST_F(AdbPairingConnectionTest, SmokeValidPairing) { + std::vector pswd{0x01, 0x03, 0x05, 0x07}; + InitPairing(pswd, pswd); + + // Start the server + ResultWaiter server_waiter; + std::unique_lock server_lock(server_waiter.mutex_); + auto port = pairing_server_start(server_.get(), server_waiter.ResultCallback, &server_waiter); + ASSERT_GT(port, 0); + ip_addr_ += std::to_string(port); + + // Start the client + ResultWaiter client_waiter; + std::unique_lock client_lock(client_waiter.mutex_); + ASSERT_TRUE(client_->Start(ip_addr_, client_waiter.ResultCallback, &client_waiter)); + client_waiter.cv_.wait(client_lock, [&]() { return client_waiter.is_valid_.has_value(); }); + ASSERT_TRUE(*(client_waiter.is_valid_)); + ASSERT_EQ(strlen(reinterpret_cast(client_waiter.peer_info_.data)), + strlen(reinterpret_cast(server_info_.data))); + EXPECT_EQ(memcmp(client_waiter.peer_info_.data, server_info_.data, sizeof(server_info_.data)), + 0); + + // Kill server if the pairing failed, since server only shuts down when + // it gets a valid pairing. + if (!client_waiter.is_valid_) { + server_lock.unlock(); + server_.reset(); + } else { + server_waiter.cv_.wait(server_lock, [&]() { return server_waiter.is_valid_.has_value(); }); + ASSERT_TRUE(*(server_waiter.is_valid_)); + ASSERT_EQ(strlen(reinterpret_cast(server_waiter.peer_info_.data)), + strlen(reinterpret_cast(client_info_.data))); + EXPECT_EQ( + memcmp(server_waiter.peer_info_.data, client_info_.data, sizeof(client_info_.data)), + 0); + } +} + +TEST_F(AdbPairingConnectionTest, CancelPairing) { + std::vector pswd{0x01, 0x03, 0x05, 0x07}; + std::vector pswd2{0x01, 0x03, 0x05, 0x06}; + InitPairing(pswd, pswd2); + + // Start the server + ResultWaiter server_waiter; + std::unique_lock server_lock(server_waiter.mutex_); + auto port = pairing_server_start(server_.get(), server_waiter.ResultCallback, &server_waiter); + ASSERT_GT(port, 0); + ip_addr_ += std::to_string(port); + + // Start the client. Client should fail to pair + ResultWaiter client_waiter; + std::unique_lock client_lock(client_waiter.mutex_); + ASSERT_TRUE(client_->Start(ip_addr_, client_waiter.ResultCallback, &client_waiter)); + client_waiter.cv_.wait(client_lock, [&]() { return client_waiter.is_valid_.has_value(); }); + ASSERT_FALSE(*(client_waiter.is_valid_)); + + // Kill the server. We should still receive the callback with no valid + // pairing. + server_lock.unlock(); + server_.reset(); + server_lock.lock(); + ASSERT_TRUE(server_waiter.is_valid_.has_value()); + EXPECT_FALSE(*(server_waiter.is_valid_)); +} + +TEST_F(AdbPairingConnectionTest, MultipleClientsAllFail) { + std::vector pswd{0x01, 0x03, 0x05, 0x07}; + std::vector pswd2{0x01, 0x03, 0x05, 0x06}; + + // Start the server + auto server = CreateServer(pswd); + ResultWaiter server_waiter; + std::unique_lock server_lock(server_waiter.mutex_); + auto port = pairing_server_start(server.get(), server_waiter.ResultCallback, &server_waiter); + ASSERT_GT(port, 0); + ip_addr_ += std::to_string(port); + + // Start multiple clients, all with bad passwords + int test_num_clients = 5; + int num_clients_done = 0; + std::mutex global_clients_mutex; + std::unique_lock global_clients_lock(global_clients_mutex); + std::condition_variable global_cv_; + for (int i = 0; i < test_num_clients; ++i) { + std::thread([&]() { + auto client = CreateClient(pswd2); + ResultWaiter client_waiter; + std::unique_lock client_lock(client_waiter.mutex_); + ASSERT_TRUE(client->Start(ip_addr_, client_waiter.ResultCallback, &client_waiter)); + client_waiter.cv_.wait(client_lock, + [&]() { return client_waiter.is_valid_.has_value(); }); + ASSERT_FALSE(*(client_waiter.is_valid_)); + { + std::lock_guard global_lock(global_clients_mutex); + ++num_clients_done; + } + global_cv_.notify_one(); + }).detach(); + } + + global_cv_.wait(global_clients_lock, [&]() { return num_clients_done == test_num_clients; }); + server_lock.unlock(); + server.reset(); + server_lock.lock(); + ASSERT_TRUE(server_waiter.is_valid_.has_value()); + EXPECT_FALSE(*(server_waiter.is_valid_)); +} + +TEST_F(AdbPairingConnectionTest, MultipleClientsOnePass) { + // Send multiple clients with bad passwords, but send the last one with the + // correct password. + std::vector pswd{0x01, 0x03, 0x05, 0x07}; + std::vector pswd2{0x01, 0x03, 0x05, 0x06}; + + // Start the server + auto server = CreateServer(pswd); + ResultWaiter server_waiter; + std::unique_lock server_lock(server_waiter.mutex_); + auto port = pairing_server_start(server.get(), server_waiter.ResultCallback, &server_waiter); + ASSERT_GT(port, 0); + ip_addr_ += std::to_string(port); + + // Start multiple clients, all with bad passwords + int test_num_clients = 5; + int num_clients_done = 0; + std::mutex global_clients_mutex; + std::unique_lock global_clients_lock(global_clients_mutex); + std::condition_variable global_cv_; + for (int i = 0; i < test_num_clients; ++i) { + std::thread([&, i]() { + bool good_client = (i == (test_num_clients - 1)); + auto client = CreateClient((good_client ? pswd : pswd2)); + ResultWaiter client_waiter; + std::unique_lock client_lock(client_waiter.mutex_); + ASSERT_TRUE(client->Start(ip_addr_, client_waiter.ResultCallback, &client_waiter)); + client_waiter.cv_.wait(client_lock, + [&]() { return client_waiter.is_valid_.has_value(); }); + if (good_client) { + ASSERT_TRUE(*(client_waiter.is_valid_)); + ASSERT_EQ(strlen(reinterpret_cast(client_waiter.peer_info_.data)), + strlen(reinterpret_cast(server_info_.data))); + EXPECT_EQ(memcmp(client_waiter.peer_info_.data, server_info_.data, + sizeof(server_info_.data)), + 0); + } else { + ASSERT_FALSE(*(client_waiter.is_valid_)); + } + { + std::lock_guard global_lock(global_clients_mutex); + ++num_clients_done; + } + global_cv_.notify_one(); + }).detach(); + } + + global_cv_.wait(global_clients_lock, [&]() { return num_clients_done == test_num_clients; }); + server_waiter.cv_.wait(server_lock, [&]() { return server_waiter.is_valid_.has_value(); }); + ASSERT_TRUE(*(server_waiter.is_valid_)); + ASSERT_EQ(strlen(reinterpret_cast(server_waiter.peer_info_.data)), + strlen(reinterpret_cast(client_info_.data))); + EXPECT_EQ(memcmp(server_waiter.peer_info_.data, client_info_.data, sizeof(client_info_.data)), + 0); +} + +} // namespace pairing +} // namespace adb diff --git a/adb/protocol.txt b/adb/protocol.txt index f4523c4be..75700a4d5 100644 --- a/adb/protocol.txt +++ b/adb/protocol.txt @@ -79,6 +79,14 @@ where systemtype is "bootloader", "device", or "host", serialno is some kind of unique ID (or empty), and banner is a human-readable version or identifier string. The banner is used to transmit useful properties. +--- STLS(type, version, "") -------------------------------------------- + +Command constant: A_STLS + +The TLS message informs the recipient that the connection will be encrypted +and will need to perform a TLS handshake. version is the current version of +the protocol. + --- AUTH(type, 0, "data") ---------------------------------------------- @@ -207,6 +215,7 @@ to send across the wire. #define A_OKAY 0x59414b4f #define A_CLSE 0x45534c43 #define A_WRTE 0x45545257 +#define A_STLS 0x534C5453 diff --git a/adb/services.cpp b/adb/services.cpp index 6185aa68a..853d65897 100644 --- a/adb/services.cpp +++ b/adb/services.cpp @@ -24,6 +24,7 @@ #include #include +#include #include #include @@ -34,6 +35,7 @@ #include "adb_io.h" #include "adb_unique_fd.h" #include "adb_utils.h" +#include "adb_wifi.h" #include "services.h" #include "socket_spec.h" #include "sysdeps.h" @@ -193,6 +195,12 @@ static void connect_service(unique_fd fd, std::string host) { // Send response for emulator and device SendProtocolString(fd.get(), response); } + +static void pair_service(unique_fd fd, std::string host, std::string password) { + std::string response; + adb_wifi_pair_device(host, password, response); + SendProtocolString(fd.get(), response); +} #endif #if ADB_HOST @@ -248,6 +256,16 @@ asocket* host_service_to_socket(std::string_view name, std::string_view serial, unique_fd fd = create_service_thread( "connect", std::bind(connect_service, std::placeholders::_1, host)); return create_local_socket(std::move(fd)); + } else if (android::base::ConsumePrefix(&name, "pair:")) { + const char* divider = strchr(name.data(), ':'); + if (!divider) { + return nullptr; + } + std::string password(name.data(), divider); + std::string host(divider + 1); + unique_fd fd = create_service_thread( + "pair", std::bind(pair_service, std::placeholders::_1, host, password)); + return create_local_socket(std::move(fd)); } return nullptr; } diff --git a/adb/sysdeps.h b/adb/sysdeps.h index 231839505..4efbc02c3 100644 --- a/adb/sysdeps.h +++ b/adb/sysdeps.h @@ -88,6 +88,8 @@ extern int adb_mkdir(const std::string& path, int mode); #undef mkdir #define mkdir ___xxx_mkdir +extern int adb_rename(const char* oldpath, const char* newpath); + // See the comments for the !defined(_WIN32) versions of adb_*(). extern int adb_open(const char* path, int options); extern int adb_creat(const char* path, int mode); @@ -101,6 +103,9 @@ extern int adb_close(int fd); extern int adb_register_socket(SOCKET s); extern HANDLE adb_get_os_handle(borrowed_fd fd); +extern int adb_gethostname(char* name, size_t len); +extern int adb_getlogin_r(char* buf, size_t bufsize); + // See the comments for the !defined(_WIN32) version of unix_close(). static __inline__ int unix_close(int fd) { return close(fd); @@ -461,6 +466,14 @@ __inline__ int adb_register_socket(int s) { return s; } +static __inline__ int adb_gethostname(char* name, size_t len) { + return gethostname(name, len); +} + +static __inline__ int adb_getlogin_r(char* buf, size_t bufsize) { + return getlogin_r(buf, bufsize); +} + static __inline__ int adb_read(borrowed_fd fd, void* buf, size_t len) { return TEMP_FAILURE_RETRY(read(fd.get(), buf, len)); } @@ -637,6 +650,10 @@ static __inline__ int adb_mkdir(const std::string& path, int mode) { #undef mkdir #define mkdir ___xxx_mkdir +static __inline__ int adb_rename(const char* oldpath, const char* newpath) { + return rename(oldpath, newpath); +} + static __inline__ int adb_is_absolute_host_path(const char* path) { return path[0] == '/'; } diff --git a/adb/sysdeps_win32.cpp b/adb/sysdeps_win32.cpp index e33d51cc3..be82bc0d1 100644 --- a/adb/sysdeps_win32.cpp +++ b/adb/sysdeps_win32.cpp @@ -18,8 +18,9 @@ #include "sysdeps.h" -#include /* winsock.h *must* be included before windows.h. */ +#include #include +#include /* winsock.h *must* be included before windows.h. */ #include #include @@ -1009,6 +1010,55 @@ int adb_register_socket(SOCKET s) { return _fh_to_int(f); } +static bool isBlankStr(const char* str) { + for (; *str != '\0'; ++str) { + if (!isblank(*str)) { + return false; + } + } + return true; +} + +int adb_gethostname(char* name, size_t len) { + const char* computerName = adb_getenv("COMPUTERNAME"); + if (computerName && !isBlankStr(computerName)) { + strncpy(name, computerName, len); + name[len - 1] = '\0'; + return 0; + } + + wchar_t buffer[MAX_COMPUTERNAME_LENGTH + 1]; + DWORD size = sizeof(buffer); + if (!GetComputerNameW(buffer, &size)) { + return -1; + } + std::string name_utf8; + if (!android::base::WideToUTF8(buffer, &name_utf8)) { + return -1; + } + + strncpy(name, name_utf8.c_str(), len); + name[len - 1] = '\0'; + return 0; +} + +int adb_getlogin_r(char* buf, size_t bufsize) { + wchar_t buffer[UNLEN + 1]; + DWORD len = sizeof(buffer); + if (!GetUserNameW(buffer, &len)) { + return -1; + } + + std::string login; + if (!android::base::WideToUTF8(buffer, &login)) { + return -1; + } + + strncpy(buf, login.c_str(), bufsize); + buf[bufsize - 1] = '\0'; + return 0; +} + #undef accept int adb_socket_accept(borrowed_fd serverfd, struct sockaddr* addr, socklen_t* addrlen) { FH serverfh = _fh_from_int(serverfd, __func__); @@ -2342,6 +2392,20 @@ int adb_mkdir(const std::string& path, int mode) { return _wmkdir(path_wide.c_str()); } +int adb_rename(const char* oldpath, const char* newpath) { + std::wstring oldpath_wide, newpath_wide; + if (!android::base::UTF8ToWide(oldpath, &oldpath_wide)) { + return -1; + } + if (!android::base::UTF8ToWide(newpath, &newpath_wide)) { + return -1; + } + + // MSDN just says the return value is non-zero on failure, make sure it + // returns -1 on failure so that it behaves the same as other systems. + return _wrename(oldpath_wide.c_str(), newpath_wide.c_str()) ? -1 : 0; +} + // Version of utime() that takes a UTF-8 path. int adb_utime(const char* path, struct utimbuf* u) { std::wstring path_wide; diff --git a/adb/transport.cpp b/adb/transport.cpp index 9dd6ec642..8b3461a83 100644 --- a/adb/transport.cpp +++ b/adb/transport.cpp @@ -36,6 +36,9 @@ #include #include +#include +#include +#include #include #include #include @@ -52,7 +55,10 @@ #include "fdevent/fdevent.h" #include "sysdeps/chrono.h" +using namespace adb::crypto; +using namespace adb::tls; using android::base::ScopedLockAssertion; +using TlsError = TlsConnection::TlsError; static void remove_transport(atransport* transport); static void transport_destroy(atransport* transport); @@ -279,18 +285,7 @@ void BlockingConnectionAdapter::Start() { << "): started multiple times"; } - read_thread_ = std::thread([this]() { - LOG(INFO) << this->transport_name_ << ": read thread spawning"; - while (true) { - auto packet = std::make_unique(); - if (!underlying_->Read(packet.get())) { - PLOG(INFO) << this->transport_name_ << ": read failed"; - break; - } - read_callback_(this, std::move(packet)); - } - std::call_once(this->error_flag_, [this]() { this->error_callback_(this, "read failed"); }); - }); + StartReadThread(); write_thread_ = std::thread([this]() { LOG(INFO) << this->transport_name_ << ": write thread spawning"; @@ -319,6 +314,46 @@ void BlockingConnectionAdapter::Start() { started_ = true; } +void BlockingConnectionAdapter::StartReadThread() { + read_thread_ = std::thread([this]() { + LOG(INFO) << this->transport_name_ << ": read thread spawning"; + while (true) { + auto packet = std::make_unique(); + if (!underlying_->Read(packet.get())) { + PLOG(INFO) << this->transport_name_ << ": read failed"; + break; + } + + bool got_stls_cmd = false; + if (packet->msg.command == A_STLS) { + got_stls_cmd = true; + } + + read_callback_(this, std::move(packet)); + + // If we received the STLS packet, we are about to perform the TLS + // handshake. So this read thread must stop and resume after the + // handshake completes otherwise this will interfere in the process. + if (got_stls_cmd) { + LOG(INFO) << this->transport_name_ + << ": Received STLS packet. Stopping read thread."; + return; + } + } + std::call_once(this->error_flag_, [this]() { this->error_callback_(this, "read failed"); }); + }); +} + +bool BlockingConnectionAdapter::DoTlsHandshake(RSA* key, std::string* auth_key) { + std::lock_guard lock(mutex_); + if (read_thread_.joinable()) { + read_thread_.join(); + } + bool success = this->underlying_->DoTlsHandshake(key, auth_key); + StartReadThread(); + return success; +} + void BlockingConnectionAdapter::Reset() { { std::lock_guard lock(mutex_); @@ -388,8 +423,36 @@ bool BlockingConnectionAdapter::Write(std::unique_ptr packet) { return true; } +FdConnection::FdConnection(unique_fd fd) : fd_(std::move(fd)) {} + +FdConnection::~FdConnection() {} + +bool FdConnection::DispatchRead(void* buf, size_t len) { + if (tls_ != nullptr) { + // The TlsConnection doesn't allow 0 byte reads + if (len == 0) { + return true; + } + return tls_->ReadFully(buf, len); + } + + return ReadFdExactly(fd_.get(), buf, len); +} + +bool FdConnection::DispatchWrite(void* buf, size_t len) { + if (tls_ != nullptr) { + // The TlsConnection doesn't allow 0 byte writes + if (len == 0) { + return true; + } + return tls_->WriteFully(std::string_view(reinterpret_cast(buf), len)); + } + + return WriteFdExactly(fd_.get(), buf, len); +} + bool FdConnection::Read(apacket* packet) { - if (!ReadFdExactly(fd_.get(), &packet->msg, sizeof(amessage))) { + if (!DispatchRead(&packet->msg, sizeof(amessage))) { D("remote local: read terminated (message)"); return false; } @@ -401,7 +464,7 @@ bool FdConnection::Read(apacket* packet) { packet->payload.resize(packet->msg.data_length); - if (!ReadFdExactly(fd_.get(), &packet->payload[0], packet->payload.size())) { + if (!DispatchRead(&packet->payload[0], packet->payload.size())) { D("remote local: terminated (data)"); return false; } @@ -410,13 +473,13 @@ bool FdConnection::Read(apacket* packet) { } bool FdConnection::Write(apacket* packet) { - if (!WriteFdExactly(fd_.get(), &packet->msg, sizeof(packet->msg))) { + if (!DispatchWrite(&packet->msg, sizeof(packet->msg))) { D("remote local: write terminated"); return false; } if (packet->msg.data_length) { - if (!WriteFdExactly(fd_.get(), &packet->payload[0], packet->msg.data_length)) { + if (!DispatchWrite(&packet->payload[0], packet->msg.data_length)) { D("remote local: write terminated"); return false; } @@ -425,6 +488,51 @@ bool FdConnection::Write(apacket* packet) { return true; } +bool FdConnection::DoTlsHandshake(RSA* key, std::string* auth_key) { + bssl::UniquePtr evp_pkey(EVP_PKEY_new()); + if (!EVP_PKEY_set1_RSA(evp_pkey.get(), key)) { + LOG(ERROR) << "EVP_PKEY_set1_RSA failed"; + return false; + } + auto x509 = GenerateX509Certificate(evp_pkey.get()); + auto x509_str = X509ToPEMString(x509.get()); + auto evp_str = Key::ToPEMString(evp_pkey.get()); +#if ADB_HOST + tls_ = TlsConnection::Create(TlsConnection::Role::Client, +#else + tls_ = TlsConnection::Create(TlsConnection::Role::Server, +#endif + x509_str, evp_str, fd_); + CHECK(tls_); +#if ADB_HOST + // TLS 1.3 gives the client no message if the server rejected the + // certificate. This will enable a check in the tls connection to check + // whether the client certificate got rejected. Note that this assumes + // that, on handshake success, the server speaks first. + tls_->EnableClientPostHandshakeCheck(true); + // Add callback to set the certificate when server issues the + // CertificateRequest. + tls_->SetCertificateCallback(adb_tls_set_certificate); + // Allow any server certificate + tls_->SetCertVerifyCallback([](X509_STORE_CTX*) { return 1; }); +#else + // Add callback to check certificate against a list of known public keys + tls_->SetCertVerifyCallback( + [auth_key](X509_STORE_CTX* ctx) { return adbd_tls_verify_cert(ctx, auth_key); }); + // Add the list of allowed client CA issuers + auto ca_list = adbd_tls_client_ca_list(); + tls_->SetClientCAList(ca_list.get()); +#endif + + auto err = tls_->DoHandshake(); + if (err == TlsError::Success) { + return true; + } + + tls_.reset(); + return false; +} + void FdConnection::Close() { adb_shutdown(fd_.get()); fd_.reset(); @@ -750,6 +858,26 @@ void kick_all_transports() { } } +void kick_all_tcp_tls_transports() { + std::lock_guard lock(transport_lock); + for (auto t : transport_list) { + if (t->IsTcpDevice() && t->use_tls) { + t->Kick(); + } + } +} + +#if !ADB_HOST +void kick_all_transports_by_auth_key(std::string_view auth_key) { + std::lock_guard lock(transport_lock); + for (auto t : transport_list) { + if (auth_key == t->auth_key) { + t->Kick(); + } + } +} +#endif + /* the fdevent select pump is single threaded */ void register_transport(atransport* transport) { tmsg m; @@ -1026,6 +1154,10 @@ int atransport::get_protocol_version() const { return protocol_version; } +int atransport::get_tls_version() const { + return tls_version; +} + size_t atransport::get_max_payload() const { return max_payload; } @@ -1221,8 +1353,9 @@ void close_usb_devices(bool reset) { #endif // ADB_HOST bool register_socket_transport(unique_fd s, std::string serial, int port, int local, - atransport::ReconnectCallback reconnect, int* error) { + atransport::ReconnectCallback reconnect, bool use_tls, int* error) { atransport* t = new atransport(std::move(reconnect), kCsOffline); + t->use_tls = use_tls; D("transport: %s init'ing for socket %d, on port %d", serial.c_str(), s.get(), port); if (init_socket_transport(t, std::move(s), port, local) < 0) { @@ -1360,6 +1493,15 @@ bool check_header(apacket* p, atransport* t) { } #if ADB_HOST +std::shared_ptr atransport::Key() { + if (keys_.empty()) { + return nullptr; + } + + std::shared_ptr result = keys_[0]; + return result; +} + std::shared_ptr atransport::NextKey() { if (keys_.empty()) { LOG(INFO) << "fetching keys for transport " << this->serial_name(); @@ -1367,10 +1509,11 @@ std::shared_ptr atransport::NextKey() { // We should have gotten at least one key: the one that's automatically generated. CHECK(!keys_.empty()); + } else { + keys_.pop_front(); } std::shared_ptr result = keys_[0]; - keys_.pop_front(); return result; } diff --git a/adb/transport.h b/adb/transport.h index 5a750eea1..8a0f62ab5 100644 --- a/adb/transport.h +++ b/adb/transport.h @@ -43,6 +43,14 @@ typedef std::unordered_set FeatureSet; +namespace adb { +namespace tls { + +class TlsConnection; + +} // namespace tls +} // namespace adb + const FeatureSet& supported_features(); // Encodes and decodes FeatureSet objects into human-readable strings. @@ -104,6 +112,8 @@ struct Connection { virtual void Start() = 0; virtual void Stop() = 0; + virtual bool DoTlsHandshake(RSA* key, std::string* auth_key = nullptr) = 0; + // Stop, and reset the device if it's a USB connection. virtual void Reset(); @@ -128,6 +138,8 @@ struct BlockingConnection { virtual bool Read(apacket* packet) = 0; virtual bool Write(apacket* packet) = 0; + virtual bool DoTlsHandshake(RSA* key, std::string* auth_key = nullptr) = 0; + // Terminate a connection. // This method must be thread-safe, and must cause concurrent Reads/Writes to terminate. // Formerly known as 'Kick' in atransport. @@ -146,9 +158,12 @@ struct BlockingConnectionAdapter : public Connection { virtual void Start() override final; virtual void Stop() override final; + virtual bool DoTlsHandshake(RSA* key, std::string* auth_key) override final; virtual void Reset() override final; + private: + void StartReadThread() REQUIRES(mutex_); bool started_ GUARDED_BY(mutex_) = false; bool stopped_ GUARDED_BY(mutex_) = false; @@ -164,16 +179,22 @@ struct BlockingConnectionAdapter : public Connection { }; struct FdConnection : public BlockingConnection { - explicit FdConnection(unique_fd fd) : fd_(std::move(fd)) {} + explicit FdConnection(unique_fd fd); + ~FdConnection(); bool Read(apacket* packet) override final; bool Write(apacket* packet) override final; + bool DoTlsHandshake(RSA* key, std::string* auth_key) override final; void Close() override; virtual void Reset() override final { Close(); } private: + bool DispatchRead(void* buf, size_t len); + bool DispatchWrite(void* buf, size_t len); + unique_fd fd_; + std::unique_ptr tls_; }; struct UsbConnection : public BlockingConnection { @@ -182,6 +203,7 @@ struct UsbConnection : public BlockingConnection { bool Read(apacket* packet) override final; bool Write(apacket* packet) override final; + bool DoTlsHandshake(RSA* key, std::string* auth_key) override final; void Close() override final; virtual void Reset() override final; @@ -279,6 +301,12 @@ class atransport : public enable_weak_from_this { std::string device; std::string devpath; + // If this is set, the transport will initiate the connection with a + // START_TLS command, instead of AUTH. + bool use_tls = false; + int tls_version = A_STLS_VERSION; + int get_tls_version() const; + #if !ADB_HOST // Used to provide the key to the framework. std::string auth_key; @@ -288,6 +316,8 @@ class atransport : public enable_weak_from_this { bool IsTcpDevice() const { return type == kTransportLocal; } #if ADB_HOST + // The current key being authorized. + std::shared_ptr Key(); std::shared_ptr NextKey(); void ResetKeys(); #endif @@ -400,6 +430,10 @@ std::string list_transports(bool long_listing); atransport* find_transport(const char* serial); void kick_all_tcp_devices(); void kick_all_transports(); +void kick_all_tcp_tls_transports(); +#if !ADB_HOST +void kick_all_transports_by_auth_key(std::string_view auth_key); +#endif void register_transport(atransport* transport); void register_usb_transport(usb_handle* h, const char* serial, @@ -410,7 +444,8 @@ void connect_device(const std::string& address, std::string* response); /* cause new transports to be init'd and added to the list */ bool register_socket_transport(unique_fd s, std::string serial, int port, int local, - atransport::ReconnectCallback reconnect, int* error = nullptr); + atransport::ReconnectCallback reconnect, bool use_tls, + int* error = nullptr); // This should only be used for transports with connection_state == kCsNoPerm. void unregister_usb_transport(usb_handle* usb); diff --git a/adb/transport_fd.cpp b/adb/transport_fd.cpp index 8d2ad6616..b9b4f42b9 100644 --- a/adb/transport_fd.cpp +++ b/adb/transport_fd.cpp @@ -155,6 +155,11 @@ struct NonblockingFdConnection : public Connection { thread_.join(); } + bool DoTlsHandshake(RSA* key, std::string* auth_key) override final { + LOG(FATAL) << "Not supported yet"; + return false; + } + void WakeThread() { uint64_t buf = 0; if (TEMP_FAILURE_RETRY(adb_write(wake_fd_write_.get(), &buf, sizeof(buf))) != sizeof(buf)) { diff --git a/adb/transport_local.cpp b/adb/transport_local.cpp index c72618603..5ec8e1665 100644 --- a/adb/transport_local.cpp +++ b/adb/transport_local.cpp @@ -126,7 +126,8 @@ void connect_device(const std::string& address, std::string* response) { }; int error; - if (!register_socket_transport(std::move(fd), serial, port, 0, std::move(reconnect), &error)) { + if (!register_socket_transport(std::move(fd), serial, port, 0, std::move(reconnect), false, + &error)) { if (error == EALREADY) { *response = android::base::StringPrintf("already connected to %s", serial.c_str()); } else if (error == EPERM) { @@ -163,8 +164,9 @@ int local_connect_arbitrary_ports(int console_port, int adb_port, std::string* e close_on_exec(fd.get()); disable_tcp_nagle(fd.get()); std::string serial = getEmulatorSerialString(console_port); - if (register_socket_transport(std::move(fd), std::move(serial), adb_port, 1, - [](atransport*) { return ReconnectResult::Abort; })) { + if (register_socket_transport( + std::move(fd), std::move(serial), adb_port, 1, + [](atransport*) { return ReconnectResult::Abort; }, false)) { return 0; } } @@ -271,8 +273,9 @@ void server_socket_thread(std::function(std::move(fd), adb_port); t->SetConnection( - std::make_unique(std::move(emulator_connection))); + std::make_unique(std::move(emulator_connection))); std::lock_guard lock(local_transports_lock); atransport* existing_transport = find_emulator_transport_by_adb_port_locked(adb_port); if (existing_transport != nullptr) { diff --git a/adb/transport_usb.cpp b/adb/transport_usb.cpp index 3e87522e5..fb81b37e0 100644 --- a/adb/transport_usb.cpp +++ b/adb/transport_usb.cpp @@ -171,6 +171,12 @@ bool UsbConnection::Write(apacket* packet) { return true; } +bool UsbConnection::DoTlsHandshake(RSA* key, std::string* auth_key) { + // TODO: support TLS for usb connections + LOG(FATAL) << "Not supported yet."; + return false; +} + void UsbConnection::Reset() { usb_reset(handle_); usb_kick(handle_);