diff --git a/init/Android.bp b/init/Android.bp index 2dd968369..58f4a9eff 100644 --- a/init/Android.bp +++ b/init/Android.bp @@ -217,6 +217,7 @@ cc_library_static { "selinux_policy_version", ], srcs: init_common_sources + init_device_sources, + export_include_dirs: ["."], generated_sources: [ "apex-info-list", ], @@ -246,6 +247,10 @@ cc_library_static { ], }, }, + visibility: [ + "//system/apex/apexd", + "//frameworks/native/cmds/installd", + ], } phony { diff --git a/init/init.cpp b/init/init.cpp index 5f516b704..91922143b 100644 --- a/init/init.cpp +++ b/init/init.cpp @@ -79,6 +79,7 @@ #include "selabel.h" #include "selinux.h" #include "service.h" +#include "service_list.h" #include "service_parser.h" #include "sigchld_handler.h" #include "snapuserd_transition.h" @@ -443,11 +444,32 @@ static Result DoControlRestart(Service* service) { return {}; } +int StopServicesFromApex(const std::string& apex_name) { + auto services = ServiceList::GetInstance().FindServicesByApexName(apex_name); + if (services.empty()) { + LOG(INFO) << "No service found for APEX: " << apex_name; + return 0; + } + std::set service_names; + for (const auto& service : services) { + service_names.emplace(service->name()); + } + constexpr std::chrono::milliseconds kServiceStopTimeout = 10s; + int still_running = StopServicesAndLogViolations(service_names, kServiceStopTimeout, + true /*SIGTERM*/); + // Send SIGKILL to ones that didn't terminate cleanly. + if (still_running > 0) { + still_running = StopServicesAndLogViolations(service_names, 0ms, false /*SIGKILL*/); + } + return still_running; +} + static Result DoUnloadApex(const std::string& apex_name) { - std::string prop_name = "init.apex." + apex_name; + if (StopServicesFromApex(apex_name) > 0) { + return Error() << "Unable to stop all service from " << apex_name; + } // TODO(b/232114573) remove services and actions read from the apex - // TODO(b/232799709) kill services from the apex - SetProperty(prop_name, "unloaded"); + SetProperty("init.apex." + apex_name, "unloaded"); return {}; } @@ -471,14 +493,12 @@ static Result UpdateApexLinkerConfig(const std::string& apex_name) { } static Result DoLoadApex(const std::string& apex_name) { - std::string prop_name = "init.apex." + apex_name; // TODO(b/232799709) read .rc files from the apex - if (auto result = UpdateApexLinkerConfig(apex_name); !result.ok()) { return result.error(); } - SetProperty(prop_name, "loaded"); + SetProperty("init.apex." + apex_name, "loaded"); return {}; } diff --git a/init/init.h b/init/init.h index 522053549..dd44e95b6 100644 --- a/init/init.h +++ b/init/init.h @@ -46,5 +46,7 @@ void DebugRebootLogging(); int SecondStageMain(int argc, char** argv); +int StopServicesFromApex(const std::string& apex_name); + } // namespace init } // namespace android diff --git a/init/init_test.cpp b/init/init_test.cpp index 5651a835d..e7218e8fe 100644 --- a/init/init_test.cpp +++ b/init/init_test.cpp @@ -15,11 +15,14 @@ */ #include +#include +#include #include #include #include #include +#include #include "action.h" #include "action_manager.h" @@ -27,6 +30,7 @@ #include "builtin_arguments.h" #include "builtins.h" #include "import_parser.h" +#include "init.h" #include "keyword_map.h" #include "parser.h" #include "service.h" @@ -37,6 +41,7 @@ using android::base::GetIntProperty; using android::base::GetProperty; using android::base::SetProperty; +using android::base::StringReplace; using android::base::WaitForProperty; using namespace std::literals; @@ -188,6 +193,186 @@ service A something EXPECT_TRUE(service->is_override()); } +static std::string GetSecurityContext() { + char* ctx; + if (getcon(&ctx) == -1) { + ADD_FAILURE() << "Failed to call getcon : " << strerror(errno); + } + std::string result = std::string(ctx); + freecon(ctx); + return result; +} + +void TestStartApexServices(const std::vector& service_names, + const std::string& apex_name) { + for (auto const& svc : service_names) { + auto service = ServiceList::GetInstance().FindService(svc); + ASSERT_NE(nullptr, service); + ASSERT_RESULT_OK(service->Start()); + ASSERT_TRUE(service->IsRunning()); + LOG(INFO) << "Service " << svc << " is running"; + if (!apex_name.empty()) { + service->set_filename("/apex/" + apex_name + "/init_test.rc"); + } else { + service->set_filename(""); + } + } + if (!apex_name.empty()) { + auto apex_services = ServiceList::GetInstance().FindServicesByApexName(apex_name); + EXPECT_EQ(service_names.size(), apex_services.size()); + } +} + +void TestStopApexServices(const std::vector& service_names, bool expect_to_run) { + for (auto const& svc : service_names) { + auto service = ServiceList::GetInstance().FindService(svc); + ASSERT_NE(nullptr, service); + EXPECT_EQ(expect_to_run, service->IsRunning()); + } + ServiceList::GetInstance().RemoveServiceIf([&](const std::unique_ptr& s) -> bool { + if (std::find(service_names.begin(), service_names.end(), s->name()) + != service_names.end()) { + return true; + } + return false; + }); +} + +void InitApexService(const std::string_view& init_template) { + std::string init_script = StringReplace(init_template, "$selabel", + GetSecurityContext(), true); + + ActionManager action_manager; + TestInitText(init_script, BuiltinFunctionMap(), {}, &action_manager, + &ServiceList::GetInstance()); +} + +void TestApexServicesInit(const std::vector& apex_services, + const std::vector& other_apex_services, + const std::vector non_apex_services) { + auto num_svc = apex_services.size() + other_apex_services.size() + non_apex_services.size(); + ASSERT_EQ(static_cast(num_svc), std::distance(ServiceList::GetInstance().begin(), + ServiceList::GetInstance().end())); + + TestStartApexServices(apex_services, "com.android.apex.test_service"); + TestStartApexServices(other_apex_services, "com.android.other_apex.test_service"); + TestStartApexServices(non_apex_services, /*apex_anme=*/ ""); + + StopServicesFromApex("com.android.apex.test_service"); + TestStopApexServices(apex_services, /*expect_to_run=*/ false); + TestStopApexServices(other_apex_services, /*expect_to_run=*/ true); + TestStopApexServices(non_apex_services, /*expect_to_run=*/ true); + + ASSERT_EQ(0, std::distance(ServiceList::GetInstance().begin(), + ServiceList::GetInstance().end())); +} + +TEST(init, StopServiceByApexName) { + std::string_view script_template = R"init( +service apex_test_service /system/bin/yes + user shell + group shell + seclabel $selabel +)init"; + InitApexService(script_template); + TestApexServicesInit({"apex_test_service"}, {}, {}); +} + +TEST(init, StopMultipleServicesByApexName) { + std::string_view script_template = R"init( +service apex_test_service_multiple_a /system/bin/yes + user shell + group shell + seclabel $selabel +service apex_test_service_multiple_b /system/bin/id + user shell + group shell + seclabel $selabel +)init"; + InitApexService(script_template); + TestApexServicesInit({"apex_test_service_multiple_a", + "apex_test_service_multiple_b"}, {}, {}); +} + +TEST(init, StopServicesFromMultipleApexes) { + std::string_view apex_script_template = R"init( +service apex_test_service_multi_apex_a /system/bin/yes + user shell + group shell + seclabel $selabel +service apex_test_service_multi_apex_b /system/bin/id + user shell + group shell + seclabel $selabel +)init"; + InitApexService(apex_script_template); + + std::string_view other_apex_script_template = R"init( +service apex_test_service_multi_apex_c /system/bin/yes + user shell + group shell + seclabel $selabel +)init"; + InitApexService(other_apex_script_template); + + TestApexServicesInit({"apex_test_service_multi_apex_a", + "apex_test_service_multi_apex_b"}, {"apex_test_service_multi_apex_c"}, {}); +} + +TEST(init, StopServicesFromApexAndNonApex) { + std::string_view apex_script_template = R"init( +service apex_test_service_apex_a /system/bin/yes + user shell + group shell + seclabel $selabel +service apex_test_service_apex_b /system/bin/id + user shell + group shell + seclabel $selabel +)init"; + InitApexService(apex_script_template); + + std::string_view non_apex_script_template = R"init( +service apex_test_service_non_apex /system/bin/yes + user shell + group shell + seclabel $selabel +)init"; + InitApexService(non_apex_script_template); + + TestApexServicesInit({"apex_test_service_apex_a", + "apex_test_service_apex_b"}, {}, {"apex_test_service_non_apex"}); +} + +TEST(init, StopServicesFromApexMixed) { + std::string_view script_template = R"init( +service apex_test_service_mixed_a /system/bin/yes + user shell + group shell + seclabel $selabel +)init"; + InitApexService(script_template); + + std::string_view other_apex_script_template = R"init( +service apex_test_service_mixed_b /system/bin/yes + user shell + group shell + seclabel $selabel +)init"; + InitApexService(other_apex_script_template); + + std::string_view non_apex_script_template = R"init( +service apex_test_service_mixed_c /system/bin/yes + user shell + group shell + seclabel $selabel +)init"; + InitApexService(non_apex_script_template); + + TestApexServicesInit({"apex_test_service_mixed_a"}, + {"apex_test_service_mixed_b"}, {"apex_test_service_mixed_c"}); +} + TEST(init, EventTriggerOrderMultipleFiles) { // 6 total files, which should have their triggers executed in the following order: // 1: start - original script parsed diff --git a/init/service.h b/init/service.h index f7f32d92d..7af361584 100644 --- a/init/service.h +++ b/init/service.h @@ -142,6 +142,8 @@ class Service { } } Subcontext* subcontext() const { return subcontext_; } + const std::string& filename() const { return filename_; } + void set_filename(const std::string& name) { filename_ = name; } private: void NotifyStateChange(const std::string& new_state) const; diff --git a/init/service_list.h b/init/service_list.h index 555da258a..33aaa5f8b 100644 --- a/init/service_list.h +++ b/init/service_list.h @@ -16,10 +16,14 @@ #pragma once +#include #include #include +#include + #include "service.h" +#include "util.h" namespace android { namespace init { @@ -52,6 +56,17 @@ class ServiceList { return nullptr; } + std::vector FindServicesByApexName(const std::string& apex_name) const { + CHECK(!apex_name.empty()) << "APEX name cannot be empty"; + std::vector matches; + for (const auto& svc : services_) { + if (GetApexNameFromFileName(svc->filename()) == apex_name) { + matches.emplace_back(svc.get()); + } + } + return matches; + } + Service* FindInterface(const std::string& interface_name) { for (const auto& svc : services_) { if (svc->interfaces().count(interface_name) > 0) {