Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit f02692d8 authored by Michael Butler's avatar Michael Butler
Browse files

NNAPI: Add execution preference to prepareModel (HAL)

A model can be prepared in different ways to optimize for different
use-cases. This CL propagates the execution preference across the HAL so
that the NN service can better fit the users needs.

Bug: 77864669
Test: mma
Test: NeuralNetworksTest_static
Test: VtsHalNeuralnetworksV1_1TargetTest
Merged-In: Ib928d510d462f36b6a87d5e81010513db7829fa8
Change-Id: Ib928d510d462f36b6a87d5e81010513db7829fa8
(cherry picked from commit 2504c2fe)
parent 1e9f62d4
Loading
Loading
Loading
Loading
+2 −2
Original line number Diff line number Diff line
@@ -242,8 +242,8 @@ void Execute(const sp<V1_1::IDevice>& device, std::function<V1_1::Model(void)> c
    // launch prepare model
    sp<PreparedModelCallback> preparedModelCallback = new PreparedModelCallback();
    ASSERT_NE(nullptr, preparedModelCallback.get());
    Return<ErrorStatus> prepareLaunchStatus =
        device->prepareModel_1_1(model, preparedModelCallback);
    Return<ErrorStatus> prepareLaunchStatus = device->prepareModel_1_1(
        model, ExecutionPreference::FAST_SINGLE_ANSWER, preparedModelCallback);
    ASSERT_TRUE(prepareLaunchStatus.isOk());
    ASSERT_EQ(ErrorStatus::NONE, static_cast<ErrorStatus>(prepareLaunchStatus));

+4 −1
Original line number Diff line number Diff line
@@ -102,6 +102,8 @@ interface IDevice extends @1.0::IDevice {
     * Multiple threads can call prepareModel on the same model concurrently.
     *
     * @param model The model to be prepared for execution.
     * @param preference Indicates the intended execution behavior of a prepared
     *                   model.
     * @param callback A callback object used to return the error status of
     *                 preparing the model for execution and the prepared model
     *                 if successful, nullptr otherwise. The callback object's
@@ -115,6 +117,7 @@ interface IDevice extends @1.0::IDevice {
     *                - INVALID_ARGUMENT if one of the input arguments is
     *                  invalid
     */
    prepareModel_1_1(Model model, IPreparedModelCallback callback)
    prepareModel_1_1(Model model, ExecutionPreference preference,
                     IPreparedModelCallback callback)
          generates (ErrorStatus status);
};
+21 −0
Original line number Diff line number Diff line
@@ -382,3 +382,24 @@ struct Model {
     */
    bool relaxComputationFloat32toFloat16;
};

/**
 * Execution preferences.
 */
enum ExecutionPreference : int32_t {
    /**
     * Prefer executing in a way that minimizes battery drain.
     * This is desirable for compilations that will be executed often.
     */
    LOW_POWER = 0,
    /**
     * Prefer returning a single answer as fast as possible, even if this causes
     * more power consumption.
     */
    FAST_SINGLE_ANSWER = 1,
    /**
     * Prefer maximizing the throughput of successive frames, for example when
     * processing successive frames coming from the camera.
     */
    SUSTAINED_SPEED = 2,
};
+31 −5
Original line number Diff line number Diff line
@@ -50,13 +50,13 @@ static void validateGetSupportedOperations(const sp<IDevice>& device, const std:
}

static void validatePrepareModel(const sp<IDevice>& device, const std::string& message,
                                 const V1_1::Model& model) {
                                 const V1_1::Model& model, ExecutionPreference preference) {
    SCOPED_TRACE(message + " [prepareModel_1_1]");

    sp<PreparedModelCallback> preparedModelCallback = new PreparedModelCallback();
    ASSERT_NE(nullptr, preparedModelCallback.get());
    Return<ErrorStatus> prepareLaunchStatus =
        device->prepareModel_1_1(model, preparedModelCallback);
        device->prepareModel_1_1(model, preference, preparedModelCallback);
    ASSERT_TRUE(prepareLaunchStatus.isOk());
    ASSERT_EQ(ErrorStatus::INVALID_ARGUMENT, static_cast<ErrorStatus>(prepareLaunchStatus));

@@ -67,15 +67,24 @@ static void validatePrepareModel(const sp<IDevice>& device, const std::string& m
    ASSERT_EQ(nullptr, preparedModel.get());
}

static bool validExecutionPreference(ExecutionPreference preference) {
    return preference == ExecutionPreference::LOW_POWER ||
           preference == ExecutionPreference::FAST_SINGLE_ANSWER ||
           preference == ExecutionPreference::SUSTAINED_SPEED;
}

// Primary validation function. This function will take a valid model, apply a
// mutation to it to invalidate the model, then pass it to interface calls that
// use the model. Note that the model here is passed by value, and any mutation
// to the model does not leave this function.
static void validate(const sp<IDevice>& device, const std::string& message, V1_1::Model model,
                     const std::function<void(Model*)>& mutation) {
                     const std::function<void(Model*)>& mutation,
                     ExecutionPreference preference = ExecutionPreference::FAST_SINGLE_ANSWER) {
    mutation(&model);
    if (validExecutionPreference(preference)) {
        validateGetSupportedOperations(device, message, model);
    validatePrepareModel(device, message, model);
    }
    validatePrepareModel(device, message, model, preference);
}

// Delete element from hidl_vec. hidl_vec doesn't support a "remove" operation,
@@ -486,6 +495,22 @@ static void addOperationOutputTest(const sp<IDevice>& device, const V1_1::Model&
    }
}

///////////////////////// VALIDATE EXECUTION PREFERENCE /////////////////////////

static const int32_t invalidExecutionPreferences[] = {
    static_cast<int32_t>(ExecutionPreference::LOW_POWER) - 1,        // lower bound
    static_cast<int32_t>(ExecutionPreference::SUSTAINED_SPEED) + 1,  // upper bound
};

static void mutateExecutionPreferenceTest(const sp<IDevice>& device, const V1_1::Model& model) {
    for (int32_t preference : invalidExecutionPreferences) {
        const std::string message =
            "mutateExecutionPreferenceTest: preference " + std::to_string(preference);
        validate(device, message, model, [](Model*) {},
                 static_cast<ExecutionPreference>(preference));
    }
}

////////////////////////// ENTRY POINT //////////////////////////////

void ValidationTest::validateModel(const V1_1::Model& model) {
@@ -503,6 +528,7 @@ void ValidationTest::validateModel(const V1_1::Model& model) {
    removeOperationOutputTest(device, model);
    addOperationInputTest(device, model);
    addOperationOutputTest(device, model);
    mutateExecutionPreferenceTest(device, model);
}

}  // namespace functional
+2 −2
Original line number Diff line number Diff line
@@ -60,8 +60,8 @@ static void createPreparedModel(const sp<IDevice>& device, const V1_1::Model& mo
    // launch prepare model
    sp<PreparedModelCallback> preparedModelCallback = new PreparedModelCallback();
    ASSERT_NE(nullptr, preparedModelCallback.get());
    Return<ErrorStatus> prepareLaunchStatus =
        device->prepareModel_1_1(model, preparedModelCallback);
    Return<ErrorStatus> prepareLaunchStatus = device->prepareModel_1_1(
        model, ExecutionPreference::FAST_SINGLE_ANSWER, preparedModelCallback);
    ASSERT_TRUE(prepareLaunchStatus.isOk());
    ASSERT_EQ(ErrorStatus::NONE, static_cast<ErrorStatus>(prepareLaunchStatus));