diff --git a/fs_mgr/libsnapshot/Android.bp b/fs_mgr/libsnapshot/Android.bp
index 5e5f06dc3..95606d703 100644
--- a/fs_mgr/libsnapshot/Android.bp
+++ b/fs_mgr/libsnapshot/Android.bp
@@ -403,7 +403,7 @@ cc_defaults {
     ],
     srcs: [
         "snapuserd_server.cpp",
-        "snapuserd.cpp",
+        "snapuserd.cpp",
         "snapuserd_daemon.cpp",
     ],
 
@@ -558,7 +558,7 @@ cc_test {
         "libbrotli",
         "libgtest",
         "libsnapshot_cow",
-        "libsnapshot_snapuserd",
+        "libsnapshot_snapuserd",
         "libcutils_sockets",
         "libz",
         "libdm",
diff --git a/fs_mgr/libsnapshot/include/libsnapshot/snapuserd_client.h b/fs_mgr/libsnapshot/include/libsnapshot/snapuserd_client.h
index ab2149e04..d6713b80c 100644
--- a/fs_mgr/libsnapshot/include/libsnapshot/snapuserd_client.h
+++ b/fs_mgr/libsnapshot/include/libsnapshot/snapuserd_client.h
@@ -26,6 +26,9 @@ namespace snapshot {
 static constexpr uint32_t PACKET_SIZE = 512;
 static constexpr uint32_t MAX_CONNECT_RETRY_COUNT = 10;
 
+static constexpr char kSnapuserdSocketFirstStage[] = "snapuserd_first_stage";
+static constexpr char kSnapuserdSocket[] = "snapuserd";
+
 class SnapuserdClient {
   private:
     int sockfd_ = 0;
diff --git a/fs_mgr/libsnapshot/include/libsnapshot/snapuserd_daemon.h b/fs_mgr/libsnapshot/include/libsnapshot/snapuserd_daemon.h
index 94542d760..c6779b817 100644
--- a/fs_mgr/libsnapshot/include/libsnapshot/snapuserd_daemon.h
+++ b/fs_mgr/libsnapshot/include/libsnapshot/snapuserd_daemon.h
@@ -25,22 +25,21 @@ class Daemon {
     // The Daemon class is a singleton to avoid
     // instantiating more than once
   public:
+    Daemon() {}
+
     static Daemon& Instance() {
         static Daemon instance;
         return instance;
     }
 
-    int StartServer(std::string socketname);
-    bool IsRunning();
+    bool StartServer(const std::string& socketname);
     void Run();
+    void Interrupt();
 
   private:
-    bool is_running_;
-    std::unique_ptr<struct pollfd> poll_fd_;
     // Signal mask used with ppoll()
     sigset_t signal_mask_;
 
-    Daemon();
     Daemon(Daemon const&) = delete;
     void operator=(Daemon const&) = delete;
 
diff --git a/fs_mgr/libsnapshot/include/libsnapshot/snapuserd_server.h b/fs_mgr/libsnapshot/include/libsnapshot/snapuserd_server.h
index a1ebd3af4..357acac9a 100644
--- a/fs_mgr/libsnapshot/include/libsnapshot/snapuserd_server.h
+++ b/fs_mgr/libsnapshot/include/libsnapshot/snapuserd_server.h
@@ -14,6 +14,8 @@
 
 #pragma once
 
+#include <poll.h>
+
 #include
 #include
 #include
@@ -34,12 +36,11 @@ static constexpr uint32_t MAX_PACKET_SIZE = 512;
 enum class DaemonOperations {
     START,
     QUERY,
-    TERMINATING,
     STOP,
     INVALID,
 };
 
-class Client {
+class DmUserHandler {
   private:
     std::unique_ptr<std::thread> threadHandler_;
 
@@ -77,7 +78,15 @@ class SnapuserdServer : public Stoppable {
   private:
     android::base::unique_fd sockfd_;
     bool terminating_;
-    std::vector<std::unique_ptr<Client>> clients_vec_;
+    std::vector<std::unique_ptr<DmUserHandler>> dm_users_;
+    std::vector<struct pollfd> watched_fds_;
+
+    void AddWatchedFd(android::base::borrowed_fd fd);
+    void AcceptClient();
+    bool HandleClient(android::base::borrowed_fd fd, int revents);
+    bool Recv(android::base::borrowed_fd fd, std::string* data);
+    bool Sendmsg(android::base::borrowed_fd fd, const std::string& msg);
+    bool Receivemsg(android::base::borrowed_fd fd, const std::string& msg);
 
     void ThreadStart(std::string cow_device, std::string backing_device,
                      std::string control_device) override;
@@ -92,13 +101,11 @@ class SnapuserdServer : public Stoppable {
 
   public:
     SnapuserdServer() { terminating_ = false; }
+    ~SnapuserdServer();
 
-    int Start(std::string socketname);
-    int AcceptClient();
-    int Receivemsg(int fd);
-    int Sendmsg(int fd, char* msg, size_t len);
-    std::string Recvmsg(int fd, int* ret);
-    android::base::borrowed_fd GetSocketFd() { return sockfd_; }
+    bool Start(const std::string& socketname);
+    bool Run();
+    void Interrupt();
 };
 
 }  // namespace snapshot
diff --git a/fs_mgr/libsnapshot/snapuserd.cpp b/fs_mgr/libsnapshot/snapuserd.cpp
index 6a82a004e..6e772ad46 100644
--- a/fs_mgr/libsnapshot/snapuserd.cpp
+++ b/fs_mgr/libsnapshot/snapuserd.cpp
@@ -17,6 +17,7 @@
 #include
 
 #include
+#include <libsnapshot/snapuserd_client.h>
 #include
 #include
 
@@ -482,13 +483,13 @@ int Snapuserd::WriteDmUserPayload(size_t size) {
 bool Snapuserd::Init() {
     backing_store_fd_.reset(open(backing_store_device_.c_str(), O_RDONLY));
     if (backing_store_fd_ < 0) {
-        LOG(ERROR) << "Open Failed: " << backing_store_device_;
+        PLOG(ERROR) << "Open Failed: " << backing_store_device_;
         return false;
     }
 
     cow_fd_.reset(open(cow_device_.c_str(), O_RDWR));
     if (cow_fd_ < 0) {
-        LOG(ERROR) << "Open Failed: " << cow_device_;
+        PLOG(ERROR) << "Open Failed: " << cow_device_;
         return false;
     }
 
@@ -498,7 +499,7 @@ bool Snapuserd::Init() {
 
     ctrl_fd_.reset(open(control_path.c_str(), O_RDWR));
     if (ctrl_fd_ < 0) {
-        LOG(ERROR) << "Unable to open " << control_path;
+        PLOG(ERROR) << "Unable to open " << control_path;
         return false;
     }
 
@@ -629,7 +630,11 @@ int main([[maybe_unused]] int argc, char** argv) {
 
     android::snapshot::Daemon& daemon = android::snapshot::Daemon::Instance();
 
-    daemon.StartServer(argv[1]);
+    std::string socket = android::snapshot::kSnapuserdSocket;
+    if (argc >= 2) {
+        socket = argv[1];
+    }
+    daemon.StartServer(socket);
     daemon.Run();
 
     return 0;
diff --git a/fs_mgr/libsnapshot/snapuserd_daemon.cpp b/fs_mgr/libsnapshot/snapuserd_daemon.cpp
index 8e7661896..4c8fa5768 100644
--- a/fs_mgr/libsnapshot/snapuserd_daemon.cpp
+++ b/fs_mgr/libsnapshot/snapuserd_daemon.cpp
@@ -20,16 +20,12 @@ namespace android {
 
 namespace snapshot {
 
-int Daemon::StartServer(std::string socketname) {
-    int ret;
-
-    ret = server_.Start(socketname);
-    if (ret < 0) {
+bool Daemon::StartServer(const std::string& socketname) {
+    if (!server_.Start(socketname)) {
         LOG(ERROR) << "Snapuserd daemon failed to start...";
         exit(EXIT_FAILURE);
     }
-
-    return ret;
+    return true;
 }
 
 void Daemon::MaskAllSignalsExceptIntAndTerm() {
@@ -51,51 +47,26 @@ void Daemon::MaskAllSignals() {
     }
 }
 
-Daemon::Daemon() {
-    is_running_ = true;
-}
-
-bool Daemon::IsRunning() {
-    return is_running_;
-}
-
 void Daemon::Run() {
-    poll_fd_ = std::make_unique<struct pollfd>();
-    poll_fd_->fd = server_.GetSocketFd().get();
-    poll_fd_->events = POLLIN;
-
     sigfillset(&signal_mask_);
     sigdelset(&signal_mask_, SIGINT);
     sigdelset(&signal_mask_, SIGTERM);
 
-    // Masking signals here ensure that after this point, we won't handle INT/TERM
-    // until after we call into ppoll()
-    MaskAllSignals();
+    // Install the signal handlers before entering the server loop; INT/TERM
+    // reach SnapuserdServer::Interrupt() through Daemon::SignalHandler.
     signal(SIGINT, Daemon::SignalHandler);
     signal(SIGTERM, Daemon::SignalHandler);
     signal(SIGPIPE, Daemon::SignalHandler);
 
     LOG(DEBUG) << "Snapuserd-server: ready to accept connections";
 
-    while (IsRunning()) {
-        int ret = ppoll(poll_fd_.get(), 1, nullptr, &signal_mask_);
-        MaskAllSignalsExceptIntAndTerm();
+    MaskAllSignalsExceptIntAndTerm();
 
-        if (ret == -1) {
-            PLOG(ERROR) << "Snapuserd:ppoll error";
-            break;
-        }
+    server_.Run();
+}
 
-        if (poll_fd_->revents == POLLIN) {
-            if (server_.AcceptClient() == static_cast<int>(DaemonOperations::STOP)) {
-                Daemon::Instance().is_running_ = false;
-            }
-        }
-
-        // Mask all signals to ensure that is_running_ can't become false between
-        // checking it in the while condition and calling into ppoll()
-        MaskAllSignals();
-    }
+void Daemon::Interrupt() {
+    server_.Interrupt();
 }
 
 void Daemon::SignalHandler(int signal) {
@@ -103,7 +74,7 @@ void Daemon::SignalHandler(int signal) {
     switch (signal) {
         case SIGINT:
         case SIGTERM: {
-            Daemon::Instance().is_running_ = false;
+            Daemon::Instance().Interrupt();
             break;
         }
         case SIGPIPE: {
diff --git a/fs_mgr/libsnapshot/snapuserd_server.cpp b/fs_mgr/libsnapshot/snapuserd_server.cpp
index 53101aafa..48a3b2a79 100644
--- a/fs_mgr/libsnapshot/snapuserd_server.cpp
+++ b/fs_mgr/libsnapshot/snapuserd_server.cpp
@@ -35,12 +35,18 @@ namespace snapshot {
 DaemonOperations SnapuserdServer::Resolveop(std::string& input) {
     if (input == "start") return DaemonOperations::START;
     if (input == "stop") return DaemonOperations::STOP;
-    if (input == "terminate-request") return DaemonOperations::TERMINATING;
     if (input == "query") return DaemonOperations::QUERY;
 
     return DaemonOperations::INVALID;
 }
 
+SnapuserdServer::~SnapuserdServer() {
+    // Close any client sockets that were added via AcceptClient().
+    for (size_t i = 1; i < watched_fds_.size(); i++) {
+        close(watched_fds_[i].fd);
+    }
+}
+
 std::string SnapuserdServer::GetDaemonStatus() {
     std::string msg = "";
 
@@ -67,7 +73,7 @@ void SnapuserdServer::ThreadStart(std::string cow_device, std::string backing_de
                                   std::string control_device) {
     Snapuserd snapd(cow_device, backing_device, control_device);
     if (!snapd.Init()) {
-        PLOG(ERROR) << "Snapuserd: Init failed";
+        LOG(ERROR) << "Snapuserd: Init failed";
         return;
     }
 
@@ -84,158 +90,202 @@ void SnapuserdServer::ThreadStart(std::string cow_device, std::string backing_de
 void SnapuserdServer::ShutdownThreads() {
     StopThreads();
 
-    for (auto& client : clients_vec_) {
+    for (auto& client : dm_users_) {
         auto& th = client->GetThreadHandler();
 
         if (th->joinable()) th->join();
     }
 }
 
-int SnapuserdServer::Sendmsg(int fd, char* msg, size_t size) {
-    int ret = TEMP_FAILURE_RETRY(send(fd, (char*)msg, size, 0));
+// Sends |msg| to the client. Returns false on a send() error or a short write.
+bool SnapuserdServer::Sendmsg(android::base::borrowed_fd fd, const std::string& msg) {
+    ssize_t ret = TEMP_FAILURE_RETRY(send(fd.get(), msg.data(), msg.size(), 0));
     if (ret < 0) {
         PLOG(ERROR) << "Snapuserd:server: send() failed";
-        return -1;
+        return false;
     }
 
-    if (ret < size) {
-        PLOG(ERROR) << "Partial data sent";
-        return -1;
+    if (static_cast<size_t>(ret) < msg.size()) {
+        LOG(ERROR) << "Partial send; expected " << msg.size() << " bytes, sent " << ret;
+        return false;
     }
-
-    return 0;
+    return true;
 }
 
-std::string SnapuserdServer::Recvmsg(int fd, int* ret) {
-    struct timeval tv;
-    fd_set set;
+// Reads up to MAX_PACKET_SIZE bytes from the client into |data|. Returns false
+// on a recv() error or when the peer has closed the connection.
+bool SnapuserdServer::Recv(android::base::borrowed_fd fd, std::string* data) {
     char msg[MAX_PACKET_SIZE];
+    ssize_t rv = TEMP_FAILURE_RETRY(recv(fd.get(), msg, sizeof(msg), 0));
+    if (rv < 0) {
+        PLOG(ERROR) << "recv failed";
+        return false;
+    }
+    if (rv == 0) {
+        // A zero-length read on a stream socket means the peer closed it.
+        LOG(DEBUG) << "Snapuserd client disconnected";
+        return false;
+    }
+    *data = std::string(msg, rv);
+    return true;
+}
 
-    tv.tv_sec = 2;
-    tv.tv_usec = 0;
-    FD_ZERO(&set);
-    FD_SET(fd, &set);
-    *ret = select(fd + 1, &set, NULL, NULL, &tv);
-    if (*ret == -1) {  // select failed
-        return {};
-    } else if (*ret == 0) {  // timeout
-        return {};
+// Parses one comma-separated control message and acts on it. Returns false
+// when the client connection should be dropped.
+bool SnapuserdServer::Receivemsg(android::base::borrowed_fd fd, const std::string& str) {
+    const char delim = ',';
+
+    std::vector<std::string> out;
+    Parsemsg(str, delim, out);
+    if (out.empty()) {
+        LOG(ERROR) << "Received empty message from client";
+        Sendmsg(fd, "fail");
+        return false;
+    }
+    DaemonOperations op = Resolveop(out[0]);
+
+    switch (op) {
+        case DaemonOperations::START: {
+            // Message format:
+            // start,<cow_device>,<backing_device>,<control_device>
+            //
+            // Start the new thread which binds to dm-user misc device
+            if (out.size() < 4) {
+                LOG(ERROR) << "Malformed start message, got " << out.size() << " fields";
+                Sendmsg(fd, "fail");
+                return false;
+            }
+            auto handler = std::make_unique<DmUserHandler>();
+            handler->SetThreadHandler(
+                    std::bind(&SnapuserdServer::ThreadStart, this, out[1], out[2], out[3]));
+            dm_users_.push_back(std::move(handler));
+            return Sendmsg(fd, "success");
+        }
+        case DaemonOperations::STOP: {
+            // Message format: stop
+            //
+            // Stop all the threads gracefully and then shutdown the
+            // main thread
+            SetTerminating();
+            ShutdownThreads();
+            return true;
+        }
+        case DaemonOperations::QUERY: {
+            // Message format: query
+            //
+            // As part of transition, Second stage daemon will be
+            // created before terminating the first stage daemon. Hence,
+            // for a brief period client may have to distinguish between
+            // first stage daemon and second stage daemon.
+            //
+            // Second stage daemon is marked as active and hence will
+            // be ready to receive control message.
+            return Sendmsg(fd, GetDaemonStatus());
+        }
+        default: {
+            LOG(ERROR) << "Received unknown message type from client";
+            Sendmsg(fd, "fail");
+            return false;
+        }
+    }
+}
+
+// Uses the socket from android_get_control_socket() when that yields a valid fd,
+// otherwise creates one with socket_local_server(), then starts watching it.
+bool SnapuserdServer::Start(const std::string& socketname) {
+    sockfd_.reset(android_get_control_socket(socketname.c_str()));
+    if (sockfd_ >= 0) {
+        if (listen(sockfd_.get(), 4) < 0) {
+            PLOG(ERROR) << "listen socket failed: " << socketname;
+            return false;
+        }
     } else {
-        *ret = TEMP_FAILURE_RETRY(recv(fd, msg, MAX_PACKET_SIZE, 0));
-        if (*ret < 0) {
-            PLOG(ERROR) << "Snapuserd:server: recv failed";
-            return {};
-        } else if (*ret == 0) {
-            LOG(DEBUG) << "Snapuserd client disconnected";
-            return {};
-        } else {
-            std::string str(msg);
-            return str;
+        sockfd_.reset(socket_local_server(socketname.c_str(), ANDROID_SOCKET_NAMESPACE_RESERVED,
+                                          SOCK_STREAM));
+        if (sockfd_ < 0) {
+            PLOG(ERROR) << "Failed to create server socket " << socketname;
+            return false;
         }
     }
-}
 
-int SnapuserdServer::Receivemsg(int fd) {
-    char msg[MAX_PACKET_SIZE];
-    std::unique_ptr<Client> newClient;
-    int ret = 0;
-
-    while (1) {
-        memset(msg, '\0', MAX_PACKET_SIZE);
-        std::string str = Recvmsg(fd, &ret);
-
-        if (ret <= 0) {
-            LOG(DEBUG) << "recv failed with ret: " << ret;
-            return 0;
-        }
-
-        const char delim = ',';
-
-        std::vector<std::string> out;
-        Parsemsg(str, delim, out);
-        DaemonOperations op = Resolveop(out[0]);
-        memset(msg, '\0', MAX_PACKET_SIZE);
-
-        switch (op) {
-            case DaemonOperations::START: {
-                // Message format:
-                // start,,,
-                //
-                // Start the new thread which binds to dm-user misc device
-                newClient = std::make_unique<Client>();
-                newClient->SetThreadHandler(
-                        std::bind(&SnapuserdServer::ThreadStart, this, out[1], out[2], out[3]));
-                clients_vec_.push_back(std::move(newClient));
-                sprintf(msg, "success");
-                Sendmsg(fd, msg, MAX_PACKET_SIZE);
-                return 0;
-            }
-            case DaemonOperations::STOP: {
-                // Message format: stop
-                //
-                // Stop all the threads gracefully and then shutdown the
-                // main thread
-                ShutdownThreads();
-                return static_cast<int>(DaemonOperations::STOP);
-            }
-            case DaemonOperations::TERMINATING: {
-                // Message format: terminate-request
-                //
-                // This is invoked during transition. First stage
-                // daemon will receive this request. First stage daemon
-                // will be considered as a passive daemon from hereon.
-                SetTerminating();
-                sprintf(msg, "success");
-                Sendmsg(fd, msg, MAX_PACKET_SIZE);
-                return 0;
-            }
-            case DaemonOperations::QUERY: {
-                // Message format: query
-                //
-                // As part of transition, Second stage daemon will be
-                // created before terminating the first stage daemon. Hence,
-                // for a brief period client may have to distiguish between
-                // first stage daemon and second stage daemon.
-                //
-                // Second stage daemon is marked as active and hence will
-                // be ready to receive control message.
-                std::string dstr = GetDaemonStatus();
-                memcpy(msg, dstr.c_str(), dstr.size());
-                Sendmsg(fd, msg, MAX_PACKET_SIZE);
-                if (dstr == "active")
-                    break;
-                else
-                    return 0;
-            }
-            default: {
-                sprintf(msg, "fail");
-                Sendmsg(fd, msg, MAX_PACKET_SIZE);
-                return 0;
-            }
-        }
-    }
-}
-
-int SnapuserdServer::Start(std::string socketname) {
-    sockfd_.reset(socket_local_server(socketname.c_str(), ANDROID_SOCKET_NAMESPACE_RESERVED,
-                                      SOCK_STREAM));
-    if (sockfd_ < 0) {
-        PLOG(ERROR) << "Failed to create server socket " << socketname;
-        return -1;
-    }
+    AddWatchedFd(sockfd_);
 
     LOG(DEBUG) << "Snapuserd server successfully started with socket name " << socketname;
-    return 0;
+    return true;
 }
 
-int SnapuserdServer::AcceptClient() {
-    int fd = accept(sockfd_.get(), NULL, NULL);
+// Polls the listening socket and every client socket until termination is
+// requested. Returns false only when poll() itself fails.
+bool SnapuserdServer::Run() {
+    while (!IsTerminating()) {
+        int rv = TEMP_FAILURE_RETRY(poll(watched_fds_.data(), watched_fds_.size(), -1));
+        if (rv < 0) {
+            PLOG(ERROR) << "poll failed";
+            return false;
+        }
+        if (!rv) {
+            continue;
+        }
+
+        if (watched_fds_[0].revents) {
+            AcceptClient();
+        }
+
+        auto iter = watched_fds_.begin() + 1;
+        while (iter != watched_fds_.end()) {
+            if (iter->revents && !HandleClient(iter->fd, iter->revents)) {
+                close(iter->fd);
+                iter = watched_fds_.erase(iter);
+            } else {
+                iter++;
+            }
+        }
+    }
+    return true;
+}
+
+// Adds |fd| to the poll set, watching for POLLIN.
+void SnapuserdServer::AddWatchedFd(android::base::borrowed_fd fd) {
+    struct pollfd p = {};
+    p.fd = fd.get();
+    p.events = POLLIN;
+    watched_fds_.emplace_back(std::move(p));
+}
+
+// Accepts one pending connection and adds it to the poll set.
+void SnapuserdServer::AcceptClient() {
+    int fd = TEMP_FAILURE_RETRY(accept4(sockfd_.get(), nullptr, nullptr, SOCK_CLOEXEC));
     if (fd < 0) {
-        PLOG(ERROR) << "Socket accept failed: " << strerror(errno);
-        return -1;
+        PLOG(ERROR) << "accept4 failed";
+        return;
     }
 
-    return Receivemsg(fd);
+    AddWatchedFd(fd);
+}
+
+// Services one poll event for a client. Returns false when the caller should
+// close the socket and stop watching it.
+bool SnapuserdServer::HandleClient(android::base::borrowed_fd fd, int revents) {
+    if (revents & POLLHUP) {
+        LOG(DEBUG) << "Snapuserd client disconnected";
+        return false;
+    }
+
+    std::string str;
+    if (!Recv(fd, &str)) {
+        return false;
+    }
+    if (!Receivemsg(fd, str)) {
+        LOG(ERROR) << "Encountered error handling client message, revents: " << revents;
+        return false;
+    }
+    return true;
+}
+
+void SnapuserdServer::Interrupt() {
+    // Drop the listening socket and flag termination so the loop in Run() exits.
+    sockfd_ = {};
+    SetTerminating();
 }
 
 }  // namespace snapshot