Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 667dc2dc authored by Michael Butler's avatar Michael Butler
Browse files

Add recovery code to NN ResilientPreparedModel and *Buffer

Prior to this CL, ResilientPreparedModel and ResilientBuffer were
passthrough interfaces that just forwarded calls to the underlying
interface object. This CL implements the full recovery mechanism for
these two classes. However, because we do not want to enable this
functionality in the NN runtime yet, ResilientDevice hides the paths
that create ResilientPreparedModel and ResilientBuffer behind an #if
until we want to enable those paths.

Bug: N/A
Test: mma
Change-Id: Idfe8093c63c7ba2f16c995eec872d150696e7a08
parent 1fd4b9d0
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -42,7 +42,7 @@ class ResilientBuffer final : public nn::IBuffer {
                             nn::SharedBuffer buffer);

    nn::SharedBuffer getBuffer() const;
    nn::SharedBuffer recover(const nn::IBuffer* failingBuffer, bool blocking) const;
    nn::GeneralResult<nn::SharedBuffer> recover(const nn::IBuffer* failingBuffer) const;

    nn::Request::MemoryDomainToken getToken() const override;

+2 −2
Original line number Diff line number Diff line
@@ -43,8 +43,8 @@ class ResilientPreparedModel final : public nn::IPreparedModel {
                                    nn::SharedPreparedModel preparedModel);

    nn::SharedPreparedModel getPreparedModel() const;
    nn::SharedPreparedModel recover(const nn::IPreparedModel* failingPreparedModel,
                                    bool blocking) const;
    nn::GeneralResult<nn::SharedPreparedModel> recover(
            const nn::IPreparedModel* failingPreparedModel) const;

    nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> execute(
            const nn::Request& request, nn::MeasureTiming measure,
+44 −4
Original line number Diff line number Diff line
@@ -20,6 +20,7 @@
#include <android-base/thread_annotations.h>
#include <nnapi/IBuffer.h>
#include <nnapi/Result.h>
#include <nnapi/TypeUtils.h>
#include <nnapi/Types.h>

#include <functional>
@@ -29,6 +30,34 @@
#include <vector>

namespace android::hardware::neuralnetworks::utils {
namespace {

template <typename FnType>
auto protect(const ResilientBuffer& resilientBuffer, const FnType& fn)
        -> decltype(fn(*resilientBuffer.getBuffer())) {
    auto buffer = resilientBuffer.getBuffer();
    auto result = fn(*buffer);

    // Immediately return if device is not dead.
    if (result.has_value() || result.error().code != nn::ErrorStatus::DEAD_OBJECT) {
        return result;
    }

    // Attempt recovery and return if it fails.
    auto maybeBuffer = resilientBuffer.recover(buffer.get());
    if (!maybeBuffer.has_value()) {
        const auto& [resultErrorMessage, resultErrorCode] = result.error();
        const auto& [recoveryErrorMessage, recoveryErrorCode] = maybeBuffer.error();
        return nn::error(resultErrorCode)
               << resultErrorMessage << ", and failed to recover dead buffer with error "
               << recoveryErrorCode << ": " << recoveryErrorMessage;
    }
    buffer = std::move(maybeBuffer).value();

    return fn(*buffer);
}

}  // namespace

nn::GeneralResult<std::shared_ptr<const ResilientBuffer>> ResilientBuffer::create(
        Factory makeBuffer) {
@@ -53,9 +82,16 @@ nn::SharedBuffer ResilientBuffer::getBuffer() const {
    std::lock_guard guard(mMutex);
    return mBuffer;
}
nn::SharedBuffer ResilientBuffer::recover(const nn::IBuffer* /*failingBuffer*/,
                                          bool /*blocking*/) const {
nn::GeneralResult<nn::SharedBuffer> ResilientBuffer::recover(
        const nn::IBuffer* failingBuffer) const {
    std::lock_guard guard(mMutex);

    // Another caller updated the failing prepared model.
    if (mBuffer.get() != failingBuffer) {
        return mBuffer;
    }

    mBuffer = NN_TRY(kMakeBuffer());
    return mBuffer;
}

@@ -64,12 +100,16 @@ nn::Request::MemoryDomainToken ResilientBuffer::getToken() const {
}

nn::GeneralResult<void> ResilientBuffer::copyTo(const nn::Memory& dst) const {
    return getBuffer()->copyTo(dst);
    const auto fn = [&dst](const nn::IBuffer& buffer) { return buffer.copyTo(dst); };
    return protect(*this, fn);
}

nn::GeneralResult<void> ResilientBuffer::copyFrom(const nn::Memory& src,
                                                  const nn::Dimensions& dimensions) const {
    return getBuffer()->copyFrom(src, dimensions);
    const auto fn = [&src, &dimensions](const nn::IBuffer& buffer) {
        return buffer.copyFrom(src, dimensions);
    };
    return protect(*this, fn);
}

}  // namespace android::hardware::neuralnetworks::utils
+16 −3
Original line number Diff line number Diff line
@@ -180,6 +180,7 @@ nn::GeneralResult<nn::SharedPreparedModel> ResilientDevice::prepareModel(
        const nn::Model& model, nn::ExecutionPreference preference, nn::Priority priority,
        nn::OptionalTimePoint deadline, const std::vector<nn::SharedHandle>& modelCache,
        const std::vector<nn::SharedHandle>& dataCache, const nn::CacheToken& token) const {
#if 0
    auto self = shared_from_this();
    ResilientPreparedModel::Factory makePreparedModel = [device = std::move(self), model,
                                                         preference, priority, deadline, modelCache,
@@ -188,29 +189,41 @@ nn::GeneralResult<nn::SharedPreparedModel> ResilientDevice::prepareModel(
                                            dataCache, token);
    };
    return ResilientPreparedModel::create(std::move(makePreparedModel));
#else
    return prepareModelInternal(model, preference, priority, deadline, modelCache, dataCache,
                                token);
#endif
}

nn::GeneralResult<nn::SharedPreparedModel> ResilientDevice::prepareModelFromCache(
        nn::OptionalTimePoint deadline, const std::vector<nn::SharedHandle>& modelCache,
        const std::vector<nn::SharedHandle>& dataCache, const nn::CacheToken& token) const {
#if 0
    auto self = shared_from_this();
    ResilientPreparedModel::Factory makePreparedModel = [device = std::move(self), deadline,
                                                         modelCache, dataCache, token] {
        return device->prepareModelFromCacheInternal(deadline, modelCache, dataCache, token);
    };
    return ResilientPreparedModel::create(std::move(makePreparedModel));
#else
    return prepareModelFromCacheInternal(deadline, modelCache, dataCache, token);
#endif
}

nn::GeneralResult<nn::SharedBuffer> ResilientDevice::allocate(
        const nn::BufferDesc& desc, const std::vector<nn::SharedPreparedModel>& preparedModels,
        const std::vector<nn::BufferRole>& inputRoles,
        const std::vector<nn::BufferRole>& outputRoles) const {
#if 0
    auto self = shared_from_this();
    ResilientBuffer::Factory makeBuffer = [device = std::move(self), desc, preparedModels,
                                           inputRoles, outputRoles] {
        return device->allocateInternal(desc, preparedModels, inputRoles, outputRoles);
    };
    return ResilientBuffer::create(std::move(makeBuffer));
#else
    return allocateInternal(desc, preparedModels, inputRoles, outputRoles);
#endif
}

bool ResilientDevice::isValidInternal() const {
@@ -225,8 +238,8 @@ nn::GeneralResult<nn::SharedPreparedModel> ResilientDevice::prepareModelInternal
    if (!isValidInternal()) {
        return std::make_shared<const InvalidPreparedModel>();
    }
    const auto fn = [&model, preference, priority, deadline, &modelCache, &dataCache,
                     token](const nn::IDevice& device) {
    const auto fn = [&model, preference, priority, &deadline, &modelCache, &dataCache,
                     &token](const nn::IDevice& device) {
        return device.prepareModel(model, preference, priority, deadline, modelCache, dataCache,
                                   token);
    };
@@ -239,7 +252,7 @@ nn::GeneralResult<nn::SharedPreparedModel> ResilientDevice::prepareModelFromCach
    if (!isValidInternal()) {
        return std::make_shared<const InvalidPreparedModel>();
    }
    const auto fn = [deadline, &modelCache, &dataCache, token](const nn::IDevice& device) {
    const auto fn = [&deadline, &modelCache, &dataCache, &token](const nn::IDevice& device) {
        return device.prepareModelFromCache(deadline, modelCache, dataCache, token);
    };
    return protect(*this, fn, /*blocking=*/false);
+50 −5
Original line number Diff line number Diff line
@@ -20,15 +20,45 @@
#include <android-base/thread_annotations.h>
#include <nnapi/IPreparedModel.h>
#include <nnapi/Result.h>
#include <nnapi/TypeUtils.h>
#include <nnapi/Types.h>

#include <functional>
#include <memory>
#include <mutex>
#include <sstream>
#include <utility>
#include <vector>

namespace android::hardware::neuralnetworks::utils {
namespace {

template <typename FnType>
auto protect(const ResilientPreparedModel& resilientPreparedModel, const FnType& fn)
        -> decltype(fn(*resilientPreparedModel.getPreparedModel())) {
    auto preparedModel = resilientPreparedModel.getPreparedModel();
    auto result = fn(*preparedModel);

    // Immediately return if prepared model is not dead.
    if (result.has_value() || result.error().code != nn::ErrorStatus::DEAD_OBJECT) {
        return result;
    }

    // Attempt recovery and return if it fails.
    auto maybePreparedModel = resilientPreparedModel.recover(preparedModel.get());
    if (!maybePreparedModel.has_value()) {
        const auto& [message, code] = maybePreparedModel.error();
        std::ostringstream oss;
        oss << ", and failed to recover dead prepared model with error " << code << ": " << message;
        result.error().message += oss.str();
        return result;
    }
    preparedModel = std::move(maybePreparedModel).value();

    return fn(*preparedModel);
}

}  // namespace

nn::GeneralResult<std::shared_ptr<const ResilientPreparedModel>> ResilientPreparedModel::create(
        Factory makePreparedModel) {
@@ -55,9 +85,16 @@ nn::SharedPreparedModel ResilientPreparedModel::getPreparedModel() const {
    return mPreparedModel;
}

nn::SharedPreparedModel ResilientPreparedModel::recover(
        const nn::IPreparedModel* /*failingPreparedModel*/, bool /*blocking*/) const {
nn::GeneralResult<nn::SharedPreparedModel> ResilientPreparedModel::recover(
        const nn::IPreparedModel* failingPreparedModel) const {
    std::lock_guard guard(mMutex);

    // Another caller updated the failing prepared model.
    if (mPreparedModel.get() != failingPreparedModel) {
        return mPreparedModel;
    }

    mPreparedModel = NN_TRY(kMakePreparedModel());
    return mPreparedModel;
}

@@ -65,7 +102,11 @@ nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>>
ResilientPreparedModel::execute(const nn::Request& request, nn::MeasureTiming measure,
                                const nn::OptionalTimePoint& deadline,
                                const nn::OptionalDuration& loopTimeoutDuration) const {
    return getPreparedModel()->execute(request, measure, deadline, loopTimeoutDuration);
    const auto fn = [&request, measure, &deadline,
                     &loopTimeoutDuration](const nn::IPreparedModel& preparedModel) {
        return preparedModel.execute(request, measure, deadline, loopTimeoutDuration);
    };
    return protect(*this, fn);
}

nn::GeneralResult<std::pair<nn::SyncFence, nn::ExecuteFencedInfoCallback>>
@@ -75,8 +116,12 @@ ResilientPreparedModel::executeFenced(const nn::Request& request,
                                      const nn::OptionalTimePoint& deadline,
                                      const nn::OptionalDuration& loopTimeoutDuration,
                                      const nn::OptionalDuration& timeoutDurationAfterFence) const {
    return getPreparedModel()->executeFenced(request, waitFor, measure, deadline,
                                             loopTimeoutDuration, timeoutDurationAfterFence);
    const auto fn = [&request, &waitFor, measure, &deadline, &loopTimeoutDuration,
                     &timeoutDurationAfterFence](const nn::IPreparedModel& preparedModel) {
        return preparedModel.executeFenced(request, waitFor, measure, deadline, loopTimeoutDuration,
                                           timeoutDurationAfterFence);
    };
    return protect(*this, fn);
}

std::any ResilientPreparedModel::getUnderlyingResource() const {