Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit def7a3cf authored by Devin Moore's avatar Devin Moore
Browse files

Keep track of DeathMonitor cookies

This change keeps track of the objects that the cookies points to so the
serviceDied callback knows when it can use the cookie.

Test: atest neuralnetworks_utils_hal_aidl_test
Tets: atest NeuralNetworksTest_static
Bug: 319210610

Change-Id: I418cbc6baa19aa702d9fd2e7d8096fe1a02b7794
parent 1e3f8851
Loading
Loading
Loading
Loading
+11 −0
Original line number Diff line number Diff line
@@ -56,6 +56,8 @@ class IProtectedCallback {
// Thread safe class
class DeathMonitor final {
  public:
    explicit DeathMonitor(uintptr_t cookieKey) : kCookieKey(cookieKey) {}

    static void serviceDied(void* cookie);
    void serviceDied();
    // Precondition: `killable` must be non-null.
@@ -63,9 +65,18 @@ class DeathMonitor final {
    // Precondition: `killable` must be non-null.
    void remove(IProtectedCallback* killable) const;

    uintptr_t getCookieKey() const { return kCookieKey; }

    ~DeathMonitor();
    DeathMonitor(const DeathMonitor&) = delete;
    DeathMonitor(DeathMonitor&&) noexcept = delete;
    DeathMonitor& operator=(const DeathMonitor&) = delete;
    DeathMonitor& operator=(DeathMonitor&&) noexcept = delete;

  private:
    mutable std::mutex mMutex;
    mutable std::vector<IProtectedCallback*> mObjects GUARDED_BY(mMutex);
    const uintptr_t kCookieKey;
};

class DeathHandler final {
+49 −7
Original line number Diff line number Diff line
@@ -25,6 +25,7 @@

#include <algorithm>
#include <functional>
#include <map>
#include <memory>
#include <mutex>
#include <vector>
@@ -33,6 +34,16 @@

namespace aidl::android::hardware::neuralnetworks::utils {

namespace {

// Only dereference the cookie if it's valid (if it's in this set)
// Only used with ndk
std::mutex sCookiesMutex;
uintptr_t sCookieKeyCounter GUARDED_BY(sCookiesMutex) = 0;
std::map<uintptr_t, std::weak_ptr<DeathMonitor>> sCookies GUARDED_BY(sCookiesMutex);

}  // namespace

void DeathMonitor::serviceDied() {
    std::lock_guard guard(mMutex);
    std::for_each(mObjects.begin(), mObjects.end(),
@@ -40,8 +51,24 @@ void DeathMonitor::serviceDied() {
}

void DeathMonitor::serviceDied(void* cookie) {
    auto deathMonitor = static_cast<DeathMonitor*>(cookie);
    deathMonitor->serviceDied();
    std::shared_ptr<DeathMonitor> monitor;
    {
        std::lock_guard<std::mutex> guard(sCookiesMutex);
        if (auto it = sCookies.find(reinterpret_cast<uintptr_t>(cookie)); it != sCookies.end()) {
            monitor = it->second.lock();
            sCookies.erase(it);
        } else {
            LOG(INFO)
                    << "Service died, but cookie is no longer valid so there is nothing to notify.";
            return;
        }
    }
    if (monitor) {
        LOG(INFO) << "Notifying DeathMonitor from serviceDied.";
        monitor->serviceDied();
    } else {
        LOG(INFO) << "Tried to notify DeathMonitor from serviceDied but could not promote.";
    }
}

void DeathMonitor::add(IProtectedCallback* killable) const {
@@ -57,12 +84,25 @@ void DeathMonitor::remove(IProtectedCallback* killable) const {
    mObjects.erase(removedIter);
}

DeathMonitor::~DeathMonitor() {
    // lock must be taken so object is not used in OnBinderDied"
    std::lock_guard<std::mutex> guard(sCookiesMutex);
    sCookies.erase(kCookieKey);
}

nn::GeneralResult<DeathHandler> DeathHandler::create(std::shared_ptr<ndk::ICInterface> object) {
    if (object == nullptr) {
        return NN_ERROR(nn::ErrorStatus::INVALID_ARGUMENT)
               << "utils::DeathHandler::create must have non-null object";
    }
    auto deathMonitor = std::make_shared<DeathMonitor>();

    std::shared_ptr<DeathMonitor> deathMonitor;
    {
        std::lock_guard<std::mutex> guard(sCookiesMutex);
        deathMonitor = std::make_shared<DeathMonitor>(sCookieKeyCounter++);
        sCookies[deathMonitor->getCookieKey()] = deathMonitor;
    }

    auto deathRecipient = ndk::ScopedAIBinder_DeathRecipient(
            AIBinder_DeathRecipient_new(DeathMonitor::serviceDied));

@@ -70,8 +110,9 @@ nn::GeneralResult<DeathHandler> DeathHandler::create(std::shared_ptr<ndk::ICInte
    // STATUS_INVALID_OPERATION. We ignore this case because we only use local binders in tests
    // where this is not an error.
    if (object->isRemote()) {
        const auto ret = ndk::ScopedAStatus::fromStatus(AIBinder_linkToDeath(
                object->asBinder().get(), deathRecipient.get(), deathMonitor.get()));
        const auto ret = ndk::ScopedAStatus::fromStatus(
                AIBinder_linkToDeath(object->asBinder().get(), deathRecipient.get(),
                                     reinterpret_cast<void*>(deathMonitor->getCookieKey())));
        HANDLE_ASTATUS(ret) << "AIBinder_linkToDeath failed";
    }

@@ -91,8 +132,9 @@ DeathHandler::DeathHandler(std::shared_ptr<ndk::ICInterface> object,

DeathHandler::~DeathHandler() {
    if (kObject != nullptr && kDeathRecipient.get() != nullptr && kDeathMonitor != nullptr) {
        const auto ret = ndk::ScopedAStatus::fromStatus(AIBinder_unlinkToDeath(
                kObject->asBinder().get(), kDeathRecipient.get(), kDeathMonitor.get()));
        const auto ret = ndk::ScopedAStatus::fromStatus(
                AIBinder_unlinkToDeath(kObject->asBinder().get(), kDeathRecipient.get(),
                                       reinterpret_cast<void*>(kDeathMonitor->getCookieKey())));
        const auto maybeSuccess = handleTransportError(ret);
        if (!maybeSuccess.ok()) {
            LOG(ERROR) << maybeSuccess.error().message;
+6 −3
Original line number Diff line number Diff line
@@ -697,7 +697,8 @@ TEST_P(DeviceTest, prepareModelAsyncCrash) {
    const auto mockDevice = createMockDevice();
    const auto device = Device::create(kName, mockDevice, kVersion).value();
    const auto ret = [&device]() {
        DeathMonitor::serviceDied(device->getDeathMonitor());
        DeathMonitor::serviceDied(
                reinterpret_cast<void*>(device->getDeathMonitor()->getCookieKey()));
        return ndk::ScopedAStatus::ok();
    };
    EXPECT_CALL(*mockDevice, prepareModel(_, _, _, _, _, _, _, _))
@@ -846,7 +847,8 @@ TEST_P(DeviceTest, prepareModelWithConfigAsyncCrash) {
    const auto mockDevice = createMockDevice();
    const auto device = Device::create(kName, mockDevice, kVersion).value();
    const auto ret = [&device]() {
        DeathMonitor::serviceDied(device->getDeathMonitor());
        DeathMonitor::serviceDied(
                reinterpret_cast<void*>(device->getDeathMonitor()->getCookieKey()));
        return ndk::ScopedAStatus::ok();
    };
    EXPECT_CALL(*mockDevice, prepareModelWithConfig(_, _, _))
@@ -970,7 +972,8 @@ TEST_P(DeviceTest, prepareModelFromCacheAsyncCrash) {
    const auto mockDevice = createMockDevice();
    const auto device = Device::create(kName, mockDevice, kVersion).value();
    const auto ret = [&device]() {
        DeathMonitor::serviceDied(device->getDeathMonitor());
        DeathMonitor::serviceDied(
                reinterpret_cast<void*>(device->getDeathMonitor()->getCookieKey()));
        return ndk::ScopedAStatus::ok();
    };
    EXPECT_CALL(*mockDevice, prepareModelFromCache(_, _, _, _, _))