Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 090d3546 authored by Michael Butler's avatar Michael Butler Committed by Automerger Merge Worker
Browse files

NNAPI VTS: Add validation for Priority am: 68a6de7a

Change-Id: Ia68b2b30dc56a0783a20c07e2475c894b5fceb02
parents e904793d 68a6de7a
Loading
Loading
Loading
Loading
+9 −6
Original line number Diff line number Diff line
@@ -24,6 +24,8 @@ namespace android::hardware::neuralnetworks::V1_0::vts::functional {

using implementation::PreparedModelCallback;

using PrepareModelMutation = std::function<void(Model*)>;

///////////////////////// UTILITY FUNCTIONS /////////////////////////

static void validateGetSupportedOperations(const sp<IDevice>& device, const std::string& message,
@@ -54,12 +56,13 @@ static void validatePrepareModel(const sp<IDevice>& device, const std::string& m
}

// Primary validation function. This function will take a valid model, apply a
// mutation to it to invalidate the model, then pass it to interface calls that
// use the model. Note that the model here is passed by value, and any mutation
// to the model does not leave this function.
static void validate(const sp<IDevice>& device, const std::string& message, Model model,
                     const std::function<void(Model*)>& mutation) {
    mutation(&model);
// mutation to invalidate the model, then pass these to supportedOperations and
// prepareModel.
static void validate(const sp<IDevice>& device, const std::string& message,
                     const Model& originalModel, const PrepareModelMutation& mutate) {
    Model model = originalModel;
    mutate(&model);

    validateGetSupportedOperations(device, message, model);
    validatePrepareModel(device, message, model);
}
+6 −4
Original line number Diff line number Diff line
@@ -24,15 +24,17 @@ namespace android::hardware::neuralnetworks::V1_0::vts::functional {

using implementation::ExecutionCallback;

using ExecutionMutation = std::function<void(Request*)>;

///////////////////////// UTILITY FUNCTIONS /////////////////////////

// Primary validation function. This function will take a valid request, apply a
// mutation to it to invalidate the request, then pass it to interface calls
// that use the request. Note that the request here is passed by value, and any
// mutation to the request does not leave this function.
// that use the request.
static void validate(const sp<IPreparedModel>& preparedModel, const std::string& message,
                     Request request, const std::function<void(Request*)>& mutation) {
    mutation(&request);
                     const Request& originalRequest, const ExecutionMutation& mutate) {
    Request request = originalRequest;
    mutate(&request);
    SCOPED_TRACE(message + " [execute]");

    sp<ExecutionCallback> executionCallback = new ExecutionCallback();
+68 −50
Original line number Diff line number Diff line
@@ -30,6 +30,8 @@ using V1_0::OperandLifeTime;
using V1_0::OperandType;
using V1_0::implementation::PreparedModelCallback;

using PrepareModelMutation = std::function<void(Model*, ExecutionPreference*)>;

///////////////////////// UTILITY FUNCTIONS /////////////////////////

static void validateGetSupportedOperations(const sp<IDevice>& device, const std::string& message,
@@ -67,16 +69,19 @@ static bool validExecutionPreference(ExecutionPreference preference) {
}

// Primary validation function. This function will take a valid model, apply a
// mutation to it to invalidate the model, then pass it to interface calls that
// use the model. Note that the model here is passed by value, and any mutation
// to the model does not leave this function.
static void validate(const sp<IDevice>& device, const std::string& message, Model model,
                     const std::function<void(Model*)>& mutation,
                     ExecutionPreference preference = ExecutionPreference::FAST_SINGLE_ANSWER) {
    mutation(&model);
// mutation to invalidate either the model or the execution preference, then
// pass these to supportedOperations and/or prepareModel if that method is
// called with an invalid argument.
static void validate(const sp<IDevice>& device, const std::string& message,
                     const Model& originalModel, const PrepareModelMutation& mutate) {
    Model model = originalModel;
    ExecutionPreference preference = ExecutionPreference::FAST_SINGLE_ANSWER;
    mutate(&model, &preference);

    if (validExecutionPreference(preference)) {
        validateGetSupportedOperations(device, message, model);
    }

    validatePrepareModel(device, message, model, preference);
}

@@ -115,8 +120,10 @@ static void mutateOperandTypeTest(const sp<IDevice>& device, const Model& model)
            const std::string message = "mutateOperandTypeTest: operand " +
                                        std::to_string(operand) + " set to value " +
                                        std::to_string(invalidOperandType);
            validate(device, message, model, [operand, invalidOperandType](Model* model) {
                model->operands[operand].type = static_cast<OperandType>(invalidOperandType);
            validate(device, message, model,
                     [operand, invalidOperandType](Model* model, ExecutionPreference*) {
                         model->operands[operand].type =
                                 static_cast<OperandType>(invalidOperandType);
                     });
        }
    }
@@ -144,7 +151,8 @@ static void mutateOperandRankTest(const sp<IDevice>& device, const Model& model)
        const uint32_t invalidRank = getInvalidRank(model.operands[operand].type);
        const std::string message = "mutateOperandRankTest: operand " + std::to_string(operand) +
                                    " has rank of " + std::to_string(invalidRank);
        validate(device, message, model, [operand, invalidRank](Model* model) {
        validate(device, message, model,
                 [operand, invalidRank](Model* model, ExecutionPreference*) {
                     model->operands[operand].dimensions = std::vector<uint32_t>(invalidRank, 0);
                 });
    }
@@ -173,7 +181,8 @@ static void mutateOperandScaleTest(const sp<IDevice>& device, const Model& model
        const float invalidScale = getInvalidScale(model.operands[operand].type);
        const std::string message = "mutateOperandScaleTest: operand " + std::to_string(operand) +
                                    " has scale of " + std::to_string(invalidScale);
        validate(device, message, model, [operand, invalidScale](Model* model) {
        validate(device, message, model,
                 [operand, invalidScale](Model* model, ExecutionPreference*) {
                     model->operands[operand].scale = invalidScale;
                 });
    }
@@ -204,7 +213,8 @@ static void mutateOperandZeroPointTest(const sp<IDevice>& device, const Model& m
            const std::string message = "mutateOperandZeroPointTest: operand " +
                                        std::to_string(operand) + " has zero point of " +
                                        std::to_string(invalidZeroPoint);
            validate(device, message, model, [operand, invalidZeroPoint](Model* model) {
            validate(device, message, model,
                     [operand, invalidZeroPoint](Model* model, ExecutionPreference*) {
                         model->operands[operand].zeroPoint = invalidZeroPoint;
                     });
        }
@@ -282,7 +292,8 @@ static void mutateOperationOperandTypeTest(const sp<IDevice>& device, const Mode
            const std::string message = "mutateOperationOperandTypeTest: operand " +
                                        std::to_string(operand) + " set to type " +
                                        toString(invalidOperandType);
            validate(device, message, model, [operand, invalidOperandType](Model* model) {
            validate(device, message, model,
                     [operand, invalidOperandType](Model* model, ExecutionPreference*) {
                         mutateOperand(&model->operands[operand], invalidOperandType);
                     });
        }
@@ -304,7 +315,8 @@ static void mutateOperationTypeTest(const sp<IDevice>& device, const Model& mode
            const std::string message = "mutateOperationTypeTest: operation " +
                                        std::to_string(operation) + " set to value " +
                                        std::to_string(invalidOperationType);
            validate(device, message, model, [operation, invalidOperationType](Model* model) {
            validate(device, message, model,
                     [operation, invalidOperationType](Model* model, ExecutionPreference*) {
                         model->operations[operation].type =
                                 static_cast<OperationType>(invalidOperationType);
                     });
@@ -321,7 +333,8 @@ static void mutateOperationInputOperandIndexTest(const sp<IDevice>& device, cons
            const std::string message = "mutateOperationInputOperandIndexTest: operation " +
                                        std::to_string(operation) + " input " +
                                        std::to_string(input);
            validate(device, message, model, [operation, input, invalidOperand](Model* model) {
            validate(device, message, model,
                     [operation, input, invalidOperand](Model* model, ExecutionPreference*) {
                         model->operations[operation].inputs[input] = invalidOperand;
                     });
        }
@@ -337,7 +350,8 @@ static void mutateOperationOutputOperandIndexTest(const sp<IDevice>& device, con
            const std::string message = "mutateOperationOutputOperandIndexTest: operation " +
                                        std::to_string(operation) + " output " +
                                        std::to_string(output);
            validate(device, message, model, [operation, output, invalidOperand](Model* model) {
            validate(device, message, model,
                     [operation, output, invalidOperand](Model* model, ExecutionPreference*) {
                         model->operations[operation].outputs[output] = invalidOperand;
                     });
        }
@@ -372,7 +386,7 @@ static void removeOperandTest(const sp<IDevice>& device, const Model& model) {
    for (size_t operand = 0; operand < model.operands.size(); ++operand) {
        const std::string message = "removeOperandTest: operand " + std::to_string(operand);
        validate(device, message, model,
                 [operand](Model* model) { removeOperand(model, operand); });
                 [operand](Model* model, ExecutionPreference*) { removeOperand(model, operand); });
    }
}

@@ -388,8 +402,9 @@ static void removeOperation(Model* model, uint32_t index) {
static void removeOperationTest(const sp<IDevice>& device, const Model& model) {
    for (size_t operation = 0; operation < model.operations.size(); ++operation) {
        const std::string message = "removeOperationTest: operation " + std::to_string(operation);
        validate(device, message, model,
                 [operation](Model* model) { removeOperation(model, operation); });
        validate(device, message, model, [operation](Model* model, ExecutionPreference*) {
            removeOperation(model, operation);
        });
    }
}

@@ -409,7 +424,8 @@ static void removeOperationInputTest(const sp<IDevice>& device, const Model& mod
            const std::string message = "removeOperationInputTest: operation " +
                                        std::to_string(operation) + ", input " +
                                        std::to_string(input);
            validate(device, message, model, [operation, input](Model* model) {
            validate(device, message, model,
                     [operation, input](Model* model, ExecutionPreference*) {
                         uint32_t operand = model->operations[operation].inputs[input];
                         model->operands[operand].numberOfConsumers--;
                         hidl_vec_removeAt(&model->operations[operation].inputs, input);
@@ -426,7 +442,8 @@ static void removeOperationOutputTest(const sp<IDevice>& device, const Model& mo
            const std::string message = "removeOperationOutputTest: operation " +
                                        std::to_string(operation) + ", output " +
                                        std::to_string(output);
            validate(device, message, model, [operation, output](Model* model) {
            validate(device, message, model,
                     [operation, output](Model* model, ExecutionPreference*) {
                         hidl_vec_removeAt(&model->operations[operation].outputs, output);
                     });
        }
@@ -444,7 +461,7 @@ static void removeOperationOutputTest(const sp<IDevice>& device, const Model& mo
static void addOperationInputTest(const sp<IDevice>& device, const Model& model) {
    for (size_t operation = 0; operation < model.operations.size(); ++operation) {
        const std::string message = "addOperationInputTest: operation " + std::to_string(operation);
        validate(device, message, model, [operation](Model* model) {
        validate(device, message, model, [operation](Model* model, ExecutionPreference*) {
            uint32_t index = addOperand(model, OperandLifeTime::MODEL_INPUT);
            hidl_vec_push_back(&model->operations[operation].inputs, index);
            hidl_vec_push_back(&model->inputIndexes, index);
@@ -458,7 +475,7 @@ static void addOperationOutputTest(const sp<IDevice>& device, const Model& model
    for (size_t operation = 0; operation < model.operations.size(); ++operation) {
        const std::string message =
                "addOperationOutputTest: operation " + std::to_string(operation);
        validate(device, message, model, [operation](Model* model) {
        validate(device, message, model, [operation](Model* model, ExecutionPreference*) {
            uint32_t index = addOperand(model, OperandLifeTime::MODEL_OUTPUT);
            hidl_vec_push_back(&model->operations[operation].outputs, index);
            hidl_vec_push_back(&model->outputIndexes, index);
@@ -474,12 +491,13 @@ static const int32_t invalidExecutionPreferences[] = {
};

static void mutateExecutionPreferenceTest(const sp<IDevice>& device, const Model& model) {
    for (int32_t preference : invalidExecutionPreferences) {
    for (int32_t invalidPreference : invalidExecutionPreferences) {
        const std::string message =
                "mutateExecutionPreferenceTest: preference " + std::to_string(preference);
        validate(
                device, message, model, [](Model*) {},
                static_cast<ExecutionPreference>(preference));
                "mutateExecutionPreferenceTest: preference " + std::to_string(invalidPreference);
        validate(device, message, model,
                 [invalidPreference](Model*, ExecutionPreference* preference) {
                     *preference = static_cast<ExecutionPreference>(invalidPreference);
                 });
    }
}

+6 −4
Original line number Diff line number Diff line
@@ -28,15 +28,17 @@ using V1_0::IPreparedModel;
using V1_0::Request;
using V1_0::implementation::ExecutionCallback;

using ExecutionMutation = std::function<void(Request*)>;

///////////////////////// UTILITY FUNCTIONS /////////////////////////

// Primary validation function. This function will take a valid request, apply a
// mutation to it to invalidate the request, then pass it to interface calls
// that use the request. Note that the request here is passed by value, and any
// mutation to the request does not leave this function.
// that use the request.
static void validate(const sp<IPreparedModel>& preparedModel, const std::string& message,
                     Request request, const std::function<void(Request*)>& mutation) {
    mutation(&request);
                     const Request& originalRequest, const ExecutionMutation& mutate) {
    Request request = originalRequest;
    mutate(&request);
    SCOPED_TRACE(message + " [execute]");

    sp<ExecutionCallback> executionCallback = new ExecutionCallback();
+8 −6
Original line number Diff line number Diff line
@@ -37,6 +37,8 @@ using V1_0::ErrorStatus;
using V1_0::Request;
using ExecutionBurstCallback = ExecutionBurstController::ExecutionBurstCallback;

using BurstExecutionMutation = std::function<void(std::vector<FmqRequestDatum>*)>;

// This constant value represents the length of an FMQ that is large enough to
// return a result from a burst execution for all of the generated test cases.
constexpr size_t kExecutionBurstChannelLength = 1024;
@@ -115,13 +117,13 @@ static void createBurstWithResultChannelLength(

// Primary validation function. This function will take a valid serialized
// request, apply a mutation to it to invalidate the serialized request, then
// pass it to interface calls that use the serialized request. Note that the
// serialized request here is passed by value, and any mutation to the
// serialized request does not leave this function.
// pass it to interface calls that use the serialized request.
static void validate(RequestChannelSender* sender, ResultChannelReceiver* receiver,
                     const std::string& message, std::vector<FmqRequestDatum> serialized,
                     const std::function<void(std::vector<FmqRequestDatum>*)>& mutation) {
    mutation(&serialized);
                     const std::string& message,
                     const std::vector<FmqRequestDatum>& originalSerialized,
                     const BurstExecutionMutation& mutate) {
    std::vector<FmqRequestDatum> serialized = originalSerialized;
    mutate(&serialized);

    // skip if packet is too large to send
    if (serialized.size() > kExecutionBurstChannelLength) {
Loading