Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit b1b39afe authored by Michael Butler's avatar Michael Butler Committed by Android (Google) Code Review
Browse files

Merge "Update NN VTS callback objects"

parents c2689156 23d0e562
Loading
Loading
Loading
Loading
+44 −96
Original line number Diff line number Diff line
@@ -14,130 +14,78 @@
 * limitations under the License.
 */

#define LOG_TAG "Callbacks"

#include "1.0/Callbacks.h"
#include <android-base/logging.h>

namespace android {
namespace hardware {
namespace neuralnetworks {
namespace V1_0 {
namespace implementation {

CallbackBase::CallbackBase() : mNotified(false) {}

CallbackBase::~CallbackBase() {
    // Note that we cannot call CallbackBase::join_thread from here:
    // CallbackBase is intended to be reference counted, and it is possible that
    // the reference count drops to zero in the bound thread, causing the
    // bound thread to call this destructor. If a thread tries to join
    // itself, it throws an exception, producing a message like the
    // following:
    //
    //     terminating with uncaught exception of type std::__1::system_error:
    //     thread::join failed: Resource deadlock would occur
}
#include <android-base/logging.h>

void CallbackBase::wait() {
    std::unique_lock<std::mutex> lock(mMutex);
    mCondition.wait(lock, [this]{return mNotified;});
    join_thread_locked();
}
namespace android::hardware::neuralnetworks::V1_0::implementation {

bool CallbackBase::on_finish(std::function<bool(void)> post_work) {
    std::lock_guard<std::mutex> lock(mMutex);
    if (mPostWork != nullptr) {
        LOG(ERROR) << "CallbackBase::on_finish -- a post-work function has already been bound to "
                   "this callback object";
        return false;
    }
    if (post_work == nullptr) {
        LOG(ERROR) << "CallbackBase::on_finish -- the new post-work function is invalid";
        return false;
    }
    mPostWork = std::move(post_work);
    return true;
}
// PreparedModelCallback methods begin here

bool CallbackBase::bind_thread(std::thread&& asyncThread) {
    std::lock_guard<std::mutex> lock(mMutex);
    if (mThread.joinable()) {
        LOG(ERROR) << "CallbackBase::bind_thread -- a thread has already been bound to this "
                   "callback object";
        return false;
    }
    if (!asyncThread.joinable()) {
        LOG(ERROR) << "CallbackBase::bind_thread -- the new thread is not joinable";
        return false;
    }
    mThread = std::move(asyncThread);
    return true;
}
Return<void> PreparedModelCallback::notify(ErrorStatus errorStatus,
                                           const sp<IPreparedModel>& preparedModel) {
    {
        std::lock_guard<std::mutex> hold(mMutex);

void CallbackBase::join_thread() {
    std::lock_guard<std::mutex> lock(mMutex);
    join_thread_locked();
        // quick-return if object has already been notified
        if (mNotified) {
            return Void();
        }

void CallbackBase::notify() {
    {
        std::lock_guard<std::mutex> lock(mMutex);
        // store results and mark as notified
        mErrorStatus = errorStatus;
        mPreparedModel = preparedModel;
        mNotified = true;
        if (mPostWork != nullptr) {
            bool success = mPostWork();
            if (!success) {
                LOG(ERROR) << "CallbackBase::notify -- post work failed";
            }
        }
    }
    mCondition.notify_all();
    }

void CallbackBase::join_thread_locked() {
    if (mThread.joinable()) {
        mThread.join();
    }
    mCondition.notify_all();
    return Void();
}

PreparedModelCallback::PreparedModelCallback() :
        mErrorStatus(ErrorStatus::GENERAL_FAILURE), mPreparedModel(nullptr) {}

PreparedModelCallback::~PreparedModelCallback() {}

Return<void> PreparedModelCallback::notify(ErrorStatus errorStatus,
                                           const sp<V1_0::IPreparedModel>& preparedModel) {
    mErrorStatus = errorStatus;
    mPreparedModel = preparedModel;
    CallbackBase::notify();
    return Void();
void PreparedModelCallback::wait() const {
    std::unique_lock<std::mutex> lock(mMutex);
    mCondition.wait(lock, [this] { return mNotified; });
}

ErrorStatus PreparedModelCallback::getStatus() {
ErrorStatus PreparedModelCallback::getStatus() const {
    wait();
    return mErrorStatus;
}

sp<V1_0::IPreparedModel> PreparedModelCallback::getPreparedModel() {
sp<IPreparedModel> PreparedModelCallback::getPreparedModel() const {
    wait();
    return mPreparedModel;
}

ExecutionCallback::ExecutionCallback() : mErrorStatus(ErrorStatus::GENERAL_FAILURE) {}

ExecutionCallback::~ExecutionCallback() {}
// ExecutionCallback methods begin here

Return<void> ExecutionCallback::notify(ErrorStatus errorStatus) {
    {
        std::lock_guard<std::mutex> hold(mMutex);

        // quick-return if object has already been notified
        if (mNotified) {
            return Void();
        }

        mErrorStatus = errorStatus;
    CallbackBase::notify();
        mNotified = true;
    }
    mCondition.notify_all();

    return Void();
}

ErrorStatus ExecutionCallback::getStatus() {
void ExecutionCallback::wait() const {
    std::unique_lock<std::mutex> lock(mMutex);
    mCondition.wait(lock, [this] { return mNotified; });
}

ErrorStatus ExecutionCallback::getStatus() const {
    wait();
    return mErrorStatus;
}

}  // namespace implementation
}  // namespace V1_0
}  // namespace neuralnetworks
}  // namespace hardware
}  // namespace android
}  // namespace android::hardware::neuralnetworks::V1_0::implementation
+107 −230
Original line number Diff line number Diff line
@@ -17,184 +17,62 @@
#ifndef ANDROID_HARDWARE_NEURALNETWORKS_V1_0_CALLBACKS_H
#define ANDROID_HARDWARE_NEURALNETWORKS_V1_0_CALLBACKS_H

#include <android-base/thread_annotations.h>
#include <android/hardware/neuralnetworks/1.0/IExecutionCallback.h>
#include <android/hardware/neuralnetworks/1.0/IPreparedModelCallback.h>
#include <hidl/Status.h>
#include <chrono>
#include <condition_variable>
#include <functional>
#include <mutex>
#include <thread>

namespace android {
namespace hardware {
namespace neuralnetworks {
namespace V1_0 {
namespace implementation {

/**
 * The CallbackBase class is used internally by the NeuralNetworks runtime to
/*
 * The Callback classes are used internally by the NeuralNetworks runtime to
 * synchronize between different threads. An asynchronous task is launched
 * paired with a callback object. When a client thread requires the output being
 * generated by the asynchronous task, the client thread can wait for the result
 * and be blocked until it has completed or a timeout condition has been
 * reached. Any wait* may safely be called concurrently, even on the same
 * callback object. When the asynchronous task has finished its workload, it
 * must immediately call "notify". If the asynchronous task has failed to launch,
 * the function that tried to launch the asynchronous task must immediately call
 * "notify". This "notify" call awakens any client threads waiting on the
 * callback object.
 *
 * The CallbackBase class implements some of the base synchronization common to
 * both PrepareModelCallback and ExecutionCallback. For consistency, any HIDL
 * callback class must inherit from CallbackBase as well as the HIDL callback
 * interface it implements.
 *
 * This class exists to enable synchronization across HIDL. When synchronization
 * is only required in the same process, consider using std::future, std::mutex,
 * std::condition_variable, or std::experimental::latch instead.
 */
class CallbackBase {
  public:
    CallbackBase();
    ~CallbackBase();

    /**
     * CallbackBase::wait blocks until notify has been called on the callback
     * object.
     */
    void wait();

    /**
     * CallbackBase::wait_for blocks until notify has been called on the
     * callback object or the time duration from the time the wait_for function
     * was called has expired, whichever comes first.
     *
     * @return Status std::cv_status::no_timeout if the callback was notified
     *                before the time duration expired, std::cv_status::timeout
     *                otherwise.
     */
    template <class Rep, class Period>
    std::cv_status wait_for(const std::chrono::duration<Rep, Period>& timeout_duration);

    /**
     * CallbackBase::on_finish binds a function to the callback object. This
     * bound function will be executed when CallbackBase::notify is called,
     * before any calls to wait* return. (Note that CallbackBase::wait_for can
     * return std::cv_status::timeout before CallbackBase::notify is called for
     * the first time, and hence before the bound function is executed.)
     *
     * The bound function must not synchronize with or otherwise access the
     * callback object it is bound to, as this could cause a deadlock.
     *
     * CallbackBase::on_finish can be called at most once on a given callback
     * object, and the call to CallbackBase::on_finish must finish before
     * CallbackBase::notify is called.
     *
     * @param post_work Function to be invoked the first time
     *                  CallbackBase::notify is called. Must have a target --
     *                  i.e., must not compare equal to nullptr. post_work
     *                  returns true if it successfully completes, false if it
     *                  fails.
     * @return bool True if the function was successfully bound, false if
     *              unsuccessful.
     *
     * TODO: Why does the return value of the callback matter?
     */
    bool on_finish(std::function<bool(void)> post_work);

    /**
     * CallbackBase::bind_thread binds a thread to the event for later use by
     * CallbackBase::join_thread.
     *
     * The thread must be passed using std::move.
     *
     * Once a thread is bound with CallbackBase::bind_thread, the client code
     * should ensure that one of the following occurs before the event is
     * destroyed:
     * - CallbackBase::join_thread has been called.
     * - CallbackBase::wait has been called.
     * - CallbackBase::wait_for has been called and returned other than
     *   std::cv_status::no_timeout.
     *
     * The bound thread shall not call any CallbackBase method with the
     * exception of CallbackBase::notify, which it must call when the thread has
     * finished its computation.
     *
     * CallbackBase::bind_thread can be called at most once on a given callback
     * object.
     *
     * @param asyncThread Thread to be bound to the callback object. The thread
     *                    object must represent a thread of execution -- i.e.,
     *                    asyncThread.joinable() must be true.
     * @return bool True if successful, false if thread was not properly bound.
 * and be blocked until it has completed. Any wait may safely be called
 * concurrently, even on the same callback object. When the asynchronous task
 * has finished its workload, it must immediately call "notify". If the
 * asynchronous task has failed to launch, the function that tried to launch the
 * asynchronous task must immediately call "notify". This "notify" call
 * awakens any client threads waiting on the callback object.
 *
 * These classes exist to enable synchronization across HIDL. When
 * synchronization is only required in the same process, consider using
 * std::future, std::mutex, std::condition_variable, or std::experimental::latch
 * instead.
 */
    bool bind_thread(std::thread&& asyncThread);

    /**
     * CallbackBase::join_thread ensures that the thread (if any) bound to this
     * event with CallbackBase::bind_thread has fully finished and cleaned its
     * resources. It is legal to call this function multiple times, concurrently
     * or sequentially.
     */
    void join_thread();

  protected:
    /**
     * CallbackBase::notify enables all prior and future wait* calls on the
     * callback object to proceed. The call to CallbackBase::notify happens
     * before any wait* calls on this callback object return (except in the case
     * of wait_for timing out). The asynchronous call the callback object is
     * paired with must ensure that any update to state that should be visible
     * to the caller of wait* happens before the call to CallbackBase::notify.
     *
     * CallbackBase::notify must be called exactly once on a given callback
     * object.
     */
    void notify();

  private:
    // Same as CallbackBase::join_thread but assumes we already hold a lock on
    // mMutex.
    void join_thread_locked();

    bool mNotified;
    std::mutex mMutex;
    std::condition_variable mCondition;
    std::function<bool(void)> mPostWork;
    std::thread mThread;
};
namespace android::hardware::neuralnetworks::V1_0::implementation {

/**
 * The PreparedModelCallback class is used to receive the error status of
 * preparing a model as well as the prepared model from a task executing
 * asynchronously with respect to the runtime. If a calling thread calls wait*
 * asynchronously with respect to the runtime. If a calling thread calls wait
 * or get* on a PreparedModelCallback object and the corresponding asynchronous
 * task has not finished preparing the model, the calling thread will block
 * until the asynchronous task has called notify. For more information on the
 * synchronization behavior, refer to the CallbackBase class.
 * until the asynchronous task has called notify.
 *
 * This class inherits the basic blocking and signaling calls from
 * CallbackBase, and implements the HIDL notify call from
 * IPreparedModelCallback. This callback object is passed as an argument to
 * IDevice::prepareModel.
 * If the callback object is notified more than once, only the results of the
 * first call to notify are used, and the results from subsequent calls are
 * discarded.
 *
 * This callback object is passed as an argument to IDevice::prepareModel*.
 */
class PreparedModelCallback : public CallbackBase, public IPreparedModelCallback {
class PreparedModelCallback : public IPreparedModelCallback {
  public:
    PreparedModelCallback();
    ~PreparedModelCallback() override;

    /**
     * IPreparedModelCallback::notify marks the callback object with the return
     * status of the asynchronous model preparation along with the prepared
     * model and calls CallbackBase::notify, enabling all prior and future
     * wait* calls on the PreparedModelCallback object to proceed.
     * For more information on the synchronization behavior, refer to the
     * CallbackBase class.
     * model, and allows all prior and future wait calls on the
     * PreparedModelCallback object to proceed.
     *
     * IPreparedModelCallback::notify must be called exactly once on a given
     * IPreparedModelCallback::notify must be called on a given
     * PreparedModelCallback object.
     *
     * If the callback object is notified more than once, only the results of
     * the first call to notify are used, and the results from subsequent calls
     * are discarded.
     *
     * @param status Error status returned from asynchronously preparing the
     *     model; will be:
     *     - NONE if the asynchronous preparation was successful
@@ -204,11 +82,17 @@ class PreparedModelCallback : public CallbackBase, public IPreparedModelCallback
     * @param preparedModel Returned model that has been prepared for execution,
     *     nullptr if the model was unable to be prepared.
     */
    Return<void> notify(ErrorStatus status, const sp<V1_0::IPreparedModel>& preparedModel) override;
    Return<void> notify(ErrorStatus status, const sp<IPreparedModel>& preparedModel) override;

    /**
     * PreparedModelCallback::wait blocks until notify has been called on the
     * callback object.
     */
    void wait() const;

    /**
     * Retrieves the error status returned from the asynchronous task launched
     * by IDevice::prepareModel. If IDevice::prepareModel has not finished
     * by IDevice::prepareModel*. If IDevice::prepareModel* has not finished
     * asynchronously preparing the model, this call will block until the
     * asynchronous task notifies the object.
     *
@@ -219,66 +103,74 @@ class PreparedModelCallback : public CallbackBase, public IPreparedModelCallback
     *     - GENERAL_FAILURE if there is an unspecified error
     *     - INVALID_ARGUMENT if the input model is invalid
     */
    ErrorStatus getStatus();
    ErrorStatus getStatus() const;

    /**
     * Retrieves the model that has been prepared for execution from the
     * asynchronous task launched by IDevice::prepareModel. If
     * IDevice::prepareModel has not finished asynchronously preparing the
     * asynchronous task launched by IDevice::prepareModel*. If
     * IDevice::prepareModel* has not finished asynchronously preparing the
     * model, this call will block until the asynchronous task notifies the
     * object.
     *
     * @return preparedModel Returned model that has been prepared for
     *                       execution, nullptr if the model was unable to be
     *                       prepared.
     *     execution, nullptr if the model was unable to be prepared.
     */
    sp<V1_0::IPreparedModel> getPreparedModel();
    sp<IPreparedModel> getPreparedModel() const;

  private:
    ErrorStatus mErrorStatus;
    sp<V1_0::IPreparedModel> mPreparedModel;
    mutable std::mutex mMutex;
    mutable std::condition_variable mCondition;
    bool mNotified GUARDED_BY(mMutex) = false;
    ErrorStatus mErrorStatus = ErrorStatus::GENERAL_FAILURE;
    sp<IPreparedModel> mPreparedModel;
};

/**
 * The ExecutionCallback class is used to receive the error status of the
 * execution from a task executing asynchronously with respect to the runtime.
 * If a calling thread calls wait* or get* on a PreparedModelCallback object and
 * the corresponding asynchronous task has not finished the execution, the
 * calling thread will block until the asynchronous task has called notify.
 * For more information on the synchronization behavior, refer to the
 * CallbackBase class.
 * The ExecutionCallback class is used to receive the results of the execution
 * from a task executing asynchronously with respect to the runtime. If a
 * calling thread calls wait or get* on a ExecutionCallback object and the
 * corresponding asynchronous task has not finished the execution, the calling
 * thread will block until the asynchronous task has called notify.
 *
 * If the callback object is notified more than once, only the results of the
 * first call to notify are used, and the results from subsequent calls are
 * discarded.
 *
 * This class inherits the basic blocking and signaling calls from
 * CallbackBase, and implements the HIDL notify call from IExecutionCallback.
 * This callback object is passed as an argument to IPreparedModel::execute.
 * This callback object is passed as an argument to IPreparedModel::execute*.
 */
class ExecutionCallback : public CallbackBase, public IExecutionCallback {
class ExecutionCallback : public IExecutionCallback {
  public:
    ExecutionCallback();
    ~ExecutionCallback() override;

    /**
     * IExecutionCallback::notify marks the callback object with the return
     * status of the asynchronous execution that held this callback and enable
     * all prior and future wait* calls on the ExecutionCallback object to
     * proceed. For more information on the synchronization behavior, refer to
     * the CallbackBase class.
     * status of the asynchronous execution that held this callback and enables
     * all prior and future wait calls on the ExecutionCallback object to
     * proceed.
     *
     * IExecutionCallback::notify must be called exactly once on a given
     * ExecutionCallback object.
     * IExecutionCallback::notify must be called on a given ExecutionCallback
     * object.
     *
     * If the callback object is notified more than once, only the results of
     * the first call to notify are used, and the results from subsequent calls
     * are discarded.
     *
     * @param status Error status returned from launching the asynchronous task
     *               (if the launch fails) or from the asynchronous task itself
     *               (if the launch succeeds). Must be:
     *     (if the launch fails) or from the asynchronous task itself (if the
     *     launch succeeds). Must be:
     *     - NONE if the asynchronous execution was successful
     *     - DEVICE_UNAVAILABLE if driver is offline or busy
     *     - GENERAL_FAILURE if there is an unspecified error
     *               - OUTPUT_INSUFFICIENT_SIZE if provided output buffer is
     *                 not large enough to store the resultant values
     *     - OUTPUT_INSUFFICIENT_SIZE if provided output buffer is not large
     *         enough to store the resultant values
     *     - INVALID_ARGUMENT if the input request is invalid
     */
    Return<void> notify(ErrorStatus status) override;

    /**
     * ExecutionCallback::wait blocks until notify has been called on the
     * callback object.
     */
    void wait() const;

    /**
     * Retrieves the error status returned from the asynchronous task launched
     * by IPreparedModel::execute. If IPreparedModel::execute has not finished
@@ -286,41 +178,26 @@ class ExecutionCallback : public CallbackBase, public IExecutionCallback {
     * task notifies the object.
     *
     * @return status Error status returned from launching the asynchronous task
     *                (if the launch fails) or from the asynchronous task itself
     *                (if the launch succeeds). Must be:
     *     (if the launch fails) or from the asynchronous task itself (if the
     *     launch succeeds). Must be:
     *     - NONE if the asynchronous execution was successful
     *     - DEVICE_UNAVAILABLE if driver is offline or busy
     *                - GENERAL_FAILURE if the asynchronous task resulted in an
     *                  unspecified error
     *                - OUTPUT_INSUFFICIENT_SIZE if at least one output
     *                  operand buffer is not large enough to store the
     *                  corresponding output
     *                - INVALID_ARGUMENT if one of the input arguments to
     *                  prepareModel is invalid
     *     - GENERAL_FAILURE if the asynchronous task resulted in an unspecified
     *         error
     *     - OUTPUT_INSUFFICIENT_SIZE if at least one output operand buffer is
     *         not large enough to store the corresponding output
     *     - INVALID_ARGUMENT if one of the input arguments to prepareModel is
     *         invalid
     */
    ErrorStatus getStatus();
    ErrorStatus getStatus() const;

  private:
    mutable std::mutex mMutex;
    mutable std::condition_variable mCondition;
    bool mNotified GUARDED_BY(mMutex) = false;
    ErrorStatus mErrorStatus = ErrorStatus::GENERAL_FAILURE;
};

// template function implementation(s) below this point

template <class Rep, class Period>
std::cv_status CallbackBase::wait_for(const std::chrono::duration<Rep, Period>& timeout_duration) {
    std::unique_lock<std::mutex> lock(mMutex);
    std::cv_status status =
            mCondition.wait_for(lock, timeout_duration, [this] { return mNotified; });
    if (status != std::cv_status::timeout) {
        join_thread_locked();
    }
    return status;
}

}  // namespace implementation
}  // namespace V1_0
}  // namespace neuralnetworks
}  // namespace hardware
}  // namespace android
}  // namespace android::hardware::neuralnetworks::V1_0::implementation

#endif  // ANDROID_HARDWARE_NEURALNETWORKS_V1_0_CALLBACKS_H
+75 −107

File changed.

Preview size limit exceeded, changes collapsed.

+1 −0
Original line number Diff line number Diff line
@@ -26,6 +26,7 @@
#include <cstdio>
#include <cstdlib>
#include <random>
#include <thread>

#include "1.2/Callbacks.h"
#include "GeneratedTestHarness.h"
+202 −277

File changed.

Preview size limit exceeded, changes collapsed.