Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 32acc061 authored by Michael Butler's avatar Michael Butler
Browse files

Validate during NN conversions by default -- hal

This change renames all `convert` functions to `unvalidatedConvert`.
This change also introduces new `convert` functions that act only on the
types that appear in the NN HIDL methods directly. These new `convert`
functions perform validation. Specifically, if either the source or
destination value is invalid, then the conversion fails.

Bug: 160667419
Test: mma
Test: NeuralNetworksTest_static
Change-Id: I492956ff60ad1466c67893993d28cdd6f3860708
parent 77cbc789
Loading
Loading
Loading
Loading
+38 −23
Original line number Diff line number Diff line
@@ -24,20 +24,28 @@

namespace android::nn {

GeneralResult<OperandType> convert(const hal::V1_0::OperandType& operandType);
GeneralResult<OperationType> convert(const hal::V1_0::OperationType& operationType);
GeneralResult<Operand::LifeTime> convert(const hal::V1_0::OperandLifeTime& lifetime);
GeneralResult<DeviceStatus> convert(const hal::V1_0::DeviceStatus& deviceStatus);
GeneralResult<Capabilities::PerformanceInfo> convert(
GeneralResult<OperandType> unvalidatedConvert(const hal::V1_0::OperandType& operandType);
GeneralResult<OperationType> unvalidatedConvert(const hal::V1_0::OperationType& operationType);
GeneralResult<Operand::LifeTime> unvalidatedConvert(const hal::V1_0::OperandLifeTime& lifetime);
GeneralResult<DeviceStatus> unvalidatedConvert(const hal::V1_0::DeviceStatus& deviceStatus);
GeneralResult<Capabilities::PerformanceInfo> unvalidatedConvert(
        const hal::V1_0::PerformanceInfo& performanceInfo);
GeneralResult<Capabilities> unvalidatedConvert(const hal::V1_0::Capabilities& capabilities);
GeneralResult<DataLocation> unvalidatedConvert(const hal::V1_0::DataLocation& location);
GeneralResult<Operand> unvalidatedConvert(const hal::V1_0::Operand& operand);
GeneralResult<Operation> unvalidatedConvert(const hal::V1_0::Operation& operation);
GeneralResult<Model::OperandValues> unvalidatedConvert(
        const hardware::hidl_vec<uint8_t>& operandValues);
GeneralResult<Memory> unvalidatedConvert(const hardware::hidl_memory& memory);
GeneralResult<Model> unvalidatedConvert(const hal::V1_0::Model& model);
GeneralResult<Request::Argument> unvalidatedConvert(
        const hal::V1_0::RequestArgument& requestArgument);
GeneralResult<Request> unvalidatedConvert(const hal::V1_0::Request& request);
GeneralResult<ErrorStatus> unvalidatedConvert(const hal::V1_0::ErrorStatus& status);

GeneralResult<DeviceStatus> convert(const hal::V1_0::DeviceStatus& deviceStatus);
GeneralResult<Capabilities> convert(const hal::V1_0::Capabilities& capabilities);
GeneralResult<DataLocation> convert(const hal::V1_0::DataLocation& location);
GeneralResult<Operand> convert(const hal::V1_0::Operand& operand);
GeneralResult<Operation> convert(const hal::V1_0::Operation& operation);
GeneralResult<Model::OperandValues> convert(const hardware::hidl_vec<uint8_t>& operandValues);
GeneralResult<Memory> convert(const hardware::hidl_memory& memory);
GeneralResult<Model> convert(const hal::V1_0::Model& model);
GeneralResult<Request::Argument> convert(const hal::V1_0::RequestArgument& requestArgument);
GeneralResult<Request> convert(const hal::V1_0::Request& request);
GeneralResult<ErrorStatus> convert(const hal::V1_0::ErrorStatus& status);

@@ -45,21 +53,28 @@ GeneralResult<ErrorStatus> convert(const hal::V1_0::ErrorStatus& status);

namespace android::hardware::neuralnetworks::V1_0::utils {

nn::GeneralResult<OperandType> convert(const nn::OperandType& operandType);
nn::GeneralResult<OperationType> convert(const nn::OperationType& operationType);
nn::GeneralResult<OperandLifeTime> convert(const nn::Operand::LifeTime& lifetime);
nn::GeneralResult<DeviceStatus> convert(const nn::DeviceStatus& deviceStatus);
nn::GeneralResult<PerformanceInfo> convert(
nn::GeneralResult<OperandType> unvalidatedConvert(const nn::OperandType& operandType);
nn::GeneralResult<OperationType> unvalidatedConvert(const nn::OperationType& operationType);
nn::GeneralResult<OperandLifeTime> unvalidatedConvert(const nn::Operand::LifeTime& lifetime);
nn::GeneralResult<DeviceStatus> unvalidatedConvert(const nn::DeviceStatus& deviceStatus);
nn::GeneralResult<PerformanceInfo> unvalidatedConvert(
        const nn::Capabilities::PerformanceInfo& performanceInfo);
nn::GeneralResult<Capabilities> unvalidatedConvert(const nn::Capabilities& capabilities);
nn::GeneralResult<DataLocation> unvalidatedConvert(const nn::DataLocation& location);
nn::GeneralResult<Operand> unvalidatedConvert(const nn::Operand& operand);
nn::GeneralResult<Operation> unvalidatedConvert(const nn::Operation& operation);
nn::GeneralResult<hidl_vec<uint8_t>> unvalidatedConvert(
        const nn::Model::OperandValues& operandValues);
nn::GeneralResult<hidl_memory> unvalidatedConvert(const nn::Memory& memory);
nn::GeneralResult<Model> unvalidatedConvert(const nn::Model& model);
nn::GeneralResult<RequestArgument> unvalidatedConvert(const nn::Request::Argument& requestArgument);
nn::GeneralResult<hidl_memory> unvalidatedConvert(const nn::Request::MemoryPool& memoryPool);
nn::GeneralResult<Request> unvalidatedConvert(const nn::Request& request);
nn::GeneralResult<ErrorStatus> unvalidatedConvert(const nn::ErrorStatus& status);

nn::GeneralResult<DeviceStatus> convert(const nn::DeviceStatus& deviceStatus);
nn::GeneralResult<Capabilities> convert(const nn::Capabilities& capabilities);
nn::GeneralResult<DataLocation> convert(const nn::DataLocation& location);
nn::GeneralResult<Operand> convert(const nn::Operand& operand);
nn::GeneralResult<Operation> convert(const nn::Operation& operation);
nn::GeneralResult<hidl_vec<uint8_t>> convert(const nn::Model::OperandValues& operandValues);
nn::GeneralResult<hidl_memory> convert(const nn::Memory& memory);
nn::GeneralResult<Model> convert(const nn::Model& model);
nn::GeneralResult<RequestArgument> convert(const nn::Request::Argument& requestArgument);
nn::GeneralResult<hidl_memory> convert(const nn::Request::MemoryPool& memoryPool);
nn::GeneralResult<Request> convert(const nn::Request& request);
nn::GeneralResult<ErrorStatus> convert(const nn::ErrorStatus& status);

+0 −24
Original line number Diff line number Diff line
@@ -22,25 +22,16 @@
#include <android-base/logging.h>
#include <android/hardware/neuralnetworks/1.0/types.h>
#include <nnapi/Result.h>
#include <nnapi/TypeUtils.h>
#include <nnapi/Types.h>
#include <nnapi/Validation.h>

namespace android::hardware::neuralnetworks::V1_0::utils {

constexpr auto kVersion = nn::Version::ANDROID_OC_MR1;

template <typename Type>
nn::Result<void> validate(const Type& halObject) {
    const auto maybeCanonical = nn::convert(halObject);
    if (!maybeCanonical.has_value()) {
        return nn::error() << maybeCanonical.error().message;
    }
    const auto version = NN_TRY(nn::validate(maybeCanonical.value()));
    if (version > utils::kVersion) {
        return NN_ERROR() << "Insufficient version: " << version << " vs required "
                          << utils::kVersion;
    }
    return {};
}

@@ -53,21 +44,6 @@ bool valid(const Type& halObject) {
    return result.has_value();
}

template <typename Type>
decltype(nn::convert(std::declval<Type>())) validatedConvertToCanonical(const Type& halObject) {
    auto canonical = NN_TRY(nn::convert(halObject));
    const auto maybeVersion = nn::validate(canonical);
    if (!maybeVersion.has_value()) {
        return nn::error() << maybeVersion.error();
    }
    const auto version = maybeVersion.value();
    if (version > utils::kVersion) {
        return NN_ERROR() << "Insufficient version: " << version << " vs required "
                          << utils::kVersion;
    }
    return canonical;
}

}  // namespace android::hardware::neuralnetworks::V1_0::utils

#endif  // ANDROID_HARDWARE_INTERFACES_NEURALNETWORKS_1_0_UTILS_H
+2 −4
Original line number Diff line number Diff line
@@ -45,8 +45,7 @@ nn::GeneralResult<nn::SharedPreparedModel> convertPreparedModel(
Return<void> PreparedModelCallback::notify(ErrorStatus status,
                                           const sp<IPreparedModel>& preparedModel) {
    if (status != ErrorStatus::NONE) {
        const auto canonical =
                validatedConvertToCanonical(status).value_or(nn::ErrorStatus::GENERAL_FAILURE);
        const auto canonical = nn::convert(status).value_or(nn::ErrorStatus::GENERAL_FAILURE);
        notifyInternal(NN_ERROR(canonical) << "preparedModel failed with " << toString(status));
    } else if (preparedModel == nullptr) {
        notifyInternal(NN_ERROR(nn::ErrorStatus::GENERAL_FAILURE)
@@ -73,8 +72,7 @@ void PreparedModelCallback::notifyInternal(PreparedModelCallback::Data result) {

Return<void> ExecutionCallback::notify(ErrorStatus status) {
    if (status != ErrorStatus::NONE) {
        const auto canonical =
                validatedConvertToCanonical(status).value_or(nn::ErrorStatus::GENERAL_FAILURE);
        const auto canonical = nn::convert(status).value_or(nn::ErrorStatus::GENERAL_FAILURE);
        notifyInternal(NN_ERROR(canonical) << "execute failed with " << toString(status));
    } else {
        notifyInternal({});
+150 −72

File changed.

Preview size limit exceeded, changes collapsed.

+4 −7
Original line number Diff line number Diff line
@@ -48,11 +48,10 @@ nn::GeneralResult<nn::Capabilities> initCapabilities(V1_0::IDevice* device) {
                                                 << "uninitialized";
    const auto cb = [&result](ErrorStatus status, const Capabilities& capabilities) {
        if (status != ErrorStatus::NONE) {
            const auto canonical =
                    validatedConvertToCanonical(status).value_or(nn::ErrorStatus::GENERAL_FAILURE);
            const auto canonical = nn::convert(status).value_or(nn::ErrorStatus::GENERAL_FAILURE);
            result = NN_ERROR(canonical) << "getCapabilities failed with " << toString(status);
        } else {
            result = validatedConvertToCanonical(capabilities);
            result = nn::convert(capabilities);
        }
    };

@@ -135,8 +134,7 @@ nn::GeneralResult<std::vector<bool>> Device::getSupportedOperations(const nn::Mo
                                                  << "uninitialized";
    auto cb = [&result, &model](ErrorStatus status, const hidl_vec<bool>& supportedOperations) {
        if (status != ErrorStatus::NONE) {
            const auto canonical =
                    validatedConvertToCanonical(status).value_or(nn::ErrorStatus::GENERAL_FAILURE);
            const auto canonical = nn::convert(status).value_or(nn::ErrorStatus::GENERAL_FAILURE);
            result = NN_ERROR(canonical)
                     << "getSupportedOperations failed with " << toString(status);
        } else if (supportedOperations.size() != model.main.operations.size()) {
@@ -172,8 +170,7 @@ nn::GeneralResult<nn::SharedPreparedModel> Device::prepareModel(
    const auto ret = kDevice->prepareModel(hidlModel, cb);
    const auto status = NN_TRY(hal::utils::handleTransportError(ret));
    if (status != ErrorStatus::NONE) {
        const auto canonical =
                validatedConvertToCanonical(status).value_or(nn::ErrorStatus::GENERAL_FAILURE);
        const auto canonical = nn::convert(status).value_or(nn::ErrorStatus::GENERAL_FAILURE);
        return NN_ERROR(canonical) << "prepareModel failed with " << toString(status);
    }

Loading