diff --git a/libstats/push_compat/StatsEventCompat.cpp b/libstats/push_compat/StatsEventCompat.cpp index b065af21c..e1a86ae1a 100644 --- a/libstats/push_compat/StatsEventCompat.cpp +++ b/libstats/push_compat/StatsEventCompat.cpp @@ -40,10 +40,10 @@ const bool StatsEventCompat::mPlatformAtLeastR = GetProperty("ro.build.version.codename", "") == "R" || android_get_device_api_level() > __ANDROID_API_Q__; -// definitions of static class variables +// initializations of static class variables bool StatsEventCompat::mAttemptedLoad = false; -void* StatsEventCompat::mStatsEventApi = nullptr; std::mutex StatsEventCompat::mLoadLock; +AStatsEventApi StatsEventCompat::mAStatsEventApi; static int64_t elapsedRealtimeNano() { return std::chrono::time_point_cast(boot_clock::now()) @@ -56,11 +56,10 @@ StatsEventCompat::StatsEventCompat() : mEventQ(kStatsEventTag) { // environment { std::lock_guard lg(mLoadLock); - if (!mAttemptedLoad) { + if (!mAttemptedLoad && mPlatformAtLeastR) { void* handle = dlopen("libstatssocket.so", RTLD_NOW); if (handle) { - // mStatsEventApi = (struct AStatsEvent_apiTable*)dlsym(handle, - // "table"); + initializeApiTableLocked(handle); } else { ALOGE("dlopen failed: %s\n", dlerror()); } @@ -68,61 +67,93 @@ StatsEventCompat::StatsEventCompat() : mEventQ(kStatsEventTag) { mAttemptedLoad = true; } - if (mStatsEventApi) { - // mEventR = mStatsEventApi->obtain(); - } else if (!mPlatformAtLeastR) { + if (useRSchema()) { + mEventR = mAStatsEventApi.obtain(); + } else if (useQSchema()) { mEventQ << elapsedRealtimeNano(); } } StatsEventCompat::~StatsEventCompat() { - // if (mStatsEventApi) mStatsEventApi->release(mEventR); + if (useRSchema()) mAStatsEventApi.release(mEventR); +} + +// Populates the AStatsEventApi struct by calling dlsym to find the address of +// each API function. +void StatsEventCompat::initializeApiTableLocked(void* handle) { + mAStatsEventApi.obtain = (AStatsEvent* (*)())dlsym(handle, "AStatsEvent_obtain"); + mAStatsEventApi.build = (void (*)(AStatsEvent*))dlsym(handle, "AStatsEvent_build"); + mAStatsEventApi.write = (int (*)(AStatsEvent*))dlsym(handle, "AStatsEvent_write"); + mAStatsEventApi.release = (void (*)(AStatsEvent*))dlsym(handle, "AStatsEvent_release"); + mAStatsEventApi.setAtomId = + (void (*)(AStatsEvent*, uint32_t))dlsym(handle, "AStatsEvent_setAtomId"); + mAStatsEventApi.writeInt32 = + (void (*)(AStatsEvent*, int32_t))dlsym(handle, "AStatsEvent_writeInt32"); + mAStatsEventApi.writeInt64 = + (void (*)(AStatsEvent*, int64_t))dlsym(handle, "AStatsEvent_writeInt64"); + mAStatsEventApi.writeFloat = + (void (*)(AStatsEvent*, float))dlsym(handle, "AStatsEvent_writeFloat"); + mAStatsEventApi.writeBool = + (void (*)(AStatsEvent*, bool))dlsym(handle, "AStatsEvent_writeBool"); + mAStatsEventApi.writeByteArray = (void (*)(AStatsEvent*, const uint8_t*, size_t))dlsym( + handle, "AStatsEvent_writeByteArray"); + mAStatsEventApi.writeString = + (void (*)(AStatsEvent*, const char*))dlsym(handle, "AStatsEvent_writeString"); + mAStatsEventApi.writeAttributionChain = + (void (*)(AStatsEvent*, const uint32_t*, const char* const*, uint8_t))dlsym( + handle, "AStatsEvent_writeAttributionChain"); + mAStatsEventApi.addBoolAnnotation = + (void (*)(AStatsEvent*, uint8_t, bool))dlsym(handle, "AStatsEvent_addBoolAnnotation"); + mAStatsEventApi.addInt32Annotation = (void (*)(AStatsEvent*, uint8_t, int32_t))dlsym( + handle, "AStatsEvent_addInt32Annotation"); + + mAStatsEventApi.initialized = true; } void StatsEventCompat::setAtomId(int32_t atomId) { - if (mStatsEventApi) { - // mStatsEventApi->setAtomId(mEventR, (uint32_t)atomId); - } else if (!mPlatformAtLeastR) { + if (useRSchema()) { + mAStatsEventApi.setAtomId(mEventR, (uint32_t)atomId); + } else if (useQSchema()) { mEventQ << atomId; } } void StatsEventCompat::writeInt32(int32_t value) { - if (mStatsEventApi) { - // mStatsEventApi->writeInt32(mEventR, value); - } else if (!mPlatformAtLeastR) { + if (useRSchema()) { + mAStatsEventApi.writeInt32(mEventR, value); + } else if (useQSchema()) { mEventQ << value; } } void StatsEventCompat::writeInt64(int64_t value) { - if (mStatsEventApi) { - // mStatsEventApi->writeInt64(mEventR, value); - } else if (!mPlatformAtLeastR) { + if (useRSchema()) { + mAStatsEventApi.writeInt64(mEventR, value); + } else if (useQSchema()) { mEventQ << value; } } void StatsEventCompat::writeFloat(float value) { - if (mStatsEventApi) { - // mStatsEventApi->writeFloat(mEventR, value); - } else if (!mPlatformAtLeastR) { + if (useRSchema()) { + mAStatsEventApi.writeFloat(mEventR, value); + } else if (useQSchema()) { mEventQ << value; } } void StatsEventCompat::writeBool(bool value) { - if (mStatsEventApi) { - // mStatsEventApi->writeBool(mEventR, value); - } else if (!mPlatformAtLeastR) { + if (useRSchema()) { + mAStatsEventApi.writeBool(mEventR, value); + } else if (useQSchema()) { mEventQ << value; } } void StatsEventCompat::writeByteArray(const char* buffer, size_t length) { - if (mStatsEventApi) { - // mStatsEventApi->writeByteArray(mEventR, (const uint8_t*)buffer, length); - } else if (!mPlatformAtLeastR) { + if (useRSchema()) { + mAStatsEventApi.writeByteArray(mEventR, reinterpret_cast(buffer), length); + } else if (useQSchema()) { mEventQ.AppendCharArray(buffer, length); } } @@ -130,19 +161,19 @@ void StatsEventCompat::writeByteArray(const char* buffer, size_t length) { void StatsEventCompat::writeString(const char* value) { if (value == nullptr) value = ""; - if (mStatsEventApi) { - // mStatsEventApi->writeString(mEventR, value); - } else if (!mPlatformAtLeastR) { + if (useRSchema()) { + mAStatsEventApi.writeString(mEventR, value); + } else if (useQSchema()) { mEventQ << value; } } void StatsEventCompat::writeAttributionChain(const int32_t* uids, size_t numUids, const vector& tags) { - if (mStatsEventApi) { - // mStatsEventApi->writeAttributionChain(mEventR, (const uint32_t*)uids, tags.data(), - // (uint8_t)numUids); - } else if (!mPlatformAtLeastR) { + if (useRSchema()) { + mAStatsEventApi.writeAttributionChain(mEventR, (const uint32_t*)uids, tags.data(), + (uint8_t)numUids); + } else if (useQSchema()) { mEventQ.begin(); for (size_t i = 0; i < numUids; i++) { mEventQ.begin(); @@ -159,8 +190,8 @@ void StatsEventCompat::writeKeyValuePairs(const map& int32Map, const map& int64Map, const map& stringMap, const map& floatMap) { - // Key value pairs are not supported with AStatsEvent. - if (!mPlatformAtLeastR) { + // AStatsEvent does not support key value pairs. + if (useQSchema()) { mEventQ.begin(); writeKeyValuePairMap(int32Map); writeKeyValuePairMap(int64Map); @@ -187,34 +218,36 @@ template void StatsEventCompat::writeKeyValuePairMap(const map(const map&); void StatsEventCompat::addBoolAnnotation(uint8_t annotationId, bool value) { - // Workaround for unused params. - (void)annotationId; - (void)value; - // if (mStatsEventApi) mStatsEventApi->addBoolAnnotation(mEventR, annotationId, value); + if (useRSchema()) { + mAStatsEventApi.addBoolAnnotation(mEventR, annotationId, value); + } // Don't do anything if on Q. } void StatsEventCompat::addInt32Annotation(uint8_t annotationId, int32_t value) { - // Workaround for unused params. - (void)annotationId; - (void)value; - // if (mStatsEventApi) mStatsEventApi->addInt32Annotation(mEventR, annotationId, value); + if (useRSchema()) { + mAStatsEventApi.addInt32Annotation(mEventR, annotationId, value); + } // Don't do anything if on Q. } int StatsEventCompat::writeToSocket() { - if (mStatsEventApi) { - // mStatsEventApi->build(mEventR); - // return mStatsEventApi->write(mEventR); + if (useRSchema()) { + mAStatsEventApi.build(mEventR); + return mAStatsEventApi.write(mEventR); } - if (!mPlatformAtLeastR) return mEventQ.write(LOG_ID_STATS); + if (useQSchema()) return mEventQ.write(LOG_ID_STATS); - // We reach here only if we're on R, but libstatspush_compat was unable to + // We reach here only if we're on R, but libstatssocket was unable to // be loaded using dlopen. return -ENOLINK; } -bool StatsEventCompat::usesNewSchema() { - return mStatsEventApi != nullptr; +bool StatsEventCompat::useRSchema() { + return mPlatformAtLeastR && mAStatsEventApi.initialized; +} + +bool StatsEventCompat::useQSchema() { + return !mPlatformAtLeastR; } diff --git a/libstats/push_compat/include/StatsEventCompat.h b/libstats/push_compat/include/StatsEventCompat.h index ad423a1c9..00bf48bf5 100644 --- a/libstats/push_compat/include/StatsEventCompat.h +++ b/libstats/push_compat/include/StatsEventCompat.h @@ -26,6 +26,26 @@ using std::map; using std::vector; +struct AStatsEventApi { + // Indicates whether the below function pointers have been set using dlsym. + bool initialized = false; + + AStatsEvent* (*obtain)(void); + void (*build)(AStatsEvent*); + int (*write)(AStatsEvent*); + void (*release)(AStatsEvent*); + void (*setAtomId)(AStatsEvent*, uint32_t); + void (*writeInt32)(AStatsEvent*, int32_t); + void (*writeInt64)(AStatsEvent*, int64_t); + void (*writeFloat)(AStatsEvent*, float); + void (*writeBool)(AStatsEvent*, bool); + void (*writeByteArray)(AStatsEvent*, const uint8_t*, size_t); + void (*writeString)(AStatsEvent*, const char*); + void (*writeAttributionChain)(AStatsEvent*, const uint32_t*, const char* const*, uint8_t); + void (*addBoolAnnotation)(AStatsEvent*, uint8_t, bool); + void (*addInt32Annotation)(AStatsEvent*, uint8_t, int32_t); +}; + class StatsEventCompat { public: StatsEventCompat(); @@ -57,8 +77,7 @@ class StatsEventCompat { const static bool mPlatformAtLeastR; static bool mAttemptedLoad; static std::mutex mLoadLock; - // static struct AStatsEvent_apiTable* mStatsEventApi; - static void* mStatsEventApi; + static AStatsEventApi mAStatsEventApi; // non-static member variables AStatsEvent* mEventR = nullptr; @@ -67,6 +86,9 @@ class StatsEventCompat { template void writeKeyValuePairMap(const map& keyValuePairMap); - bool usesNewSchema(); + void initializeApiTableLocked(void* handle); + bool useRSchema(); + bool useQSchema(); + FRIEND_TEST(StatsEventCompatTest, TestDynamicLoading); }; diff --git a/libstats/push_compat/tests/StatsEventCompat_test.cpp b/libstats/push_compat/tests/StatsEventCompat_test.cpp index 2be24ec10..dcb37973e 100644 --- a/libstats/push_compat/tests/StatsEventCompat_test.cpp +++ b/libstats/push_compat/tests/StatsEventCompat_test.cpp @@ -29,10 +29,10 @@ using android::base::GetProperty; * * TODO(b/146019024): migrate to android_get_device_api_level() */ -const static bool mPlatformAtLeastR = GetProperty("ro.build.version.release", "") == "R" || +const static bool mPlatformAtLeastR = GetProperty("ro.build.version.codename", "") == "R" || android_get_device_api_level() > __ANDROID_API_Q__; TEST(StatsEventCompatTest, TestDynamicLoading) { StatsEventCompat event; - EXPECT_EQ(mPlatformAtLeastR, event.usesNewSchema()); + EXPECT_EQ(mPlatformAtLeastR, event.useRSchema()); }