Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit cc1fb3d1 authored by Lev Proleev's avatar Lev Proleev Committed by Android (Google) Code Review
Browse files

Merge "Fix VTS tests."

parents d9c767a4 923b8c58
Loading
Loading
Loading
Loading
+32 −2
Original line number Diff line number Diff line
@@ -157,6 +157,7 @@ static uint32_t getInvalidRank(OperandType type) {
        case OperandType::UINT32:
        case OperandType::BOOL:
            return 1;
        case OperandType::TENSOR_BOOL8:
        case OperandType::TENSOR_FLOAT16:
        case OperandType::TENSOR_FLOAT32:
        case OperandType::TENSOR_INT32:
@@ -194,6 +195,7 @@ static float getInvalidScale(OperandType type) {
        case OperandType::INT32:
        case OperandType::UINT32:
        case OperandType::BOOL:
        case OperandType::TENSOR_BOOL8:
        case OperandType::TENSOR_FLOAT16:
        case OperandType::TENSOR_FLOAT32:
        case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
@@ -230,6 +232,7 @@ static std::vector<int32_t> getInvalidZeroPoints(OperandType type) {
        case OperandType::INT32:
        case OperandType::UINT32:
        case OperandType::BOOL:
        case OperandType::TENSOR_BOOL8:
        case OperandType::TENSOR_FLOAT16:
        case OperandType::TENSOR_FLOAT32:
        case OperandType::TENSOR_INT32:
@@ -283,6 +286,7 @@ static void mutateOperand(Operand* operand, OperandType type) {
            newOperand.scale = 0.0f;
            newOperand.zeroPoint = 0;
            break;
        case OperandType::TENSOR_BOOL8:
        case OperandType::TENSOR_FLOAT16:
        case OperandType::TENSOR_FLOAT32:
            newOperand.dimensions =
@@ -339,6 +343,10 @@ static bool mutateOperationOperandTypeSkip(size_t operand, OperandType type, con
        // TENSOR_(FLOAT16|FLOAT32|INT32|QUANT8_ASYMM).
        // - CAST's argument can be any of TENSOR_(FLOAT16|FLOAT32|INT32|QUANT8_ASYMM).
        // - RANDOM_MULTINOMIAL's argument can be either TENSOR_FLOAT16 or TENSOR_FLOAT32.
        // - DEQUANTIZE input can be any of
        // TENSOR_(QUANT8_ASYMM|QUANT8_SYMM|QUANT8_SYMM_PER_CHANNEL), output can
        // be of either TENSOR_FLOAT16 or TENSOR_FLOAT32.
        // - QUANTIZE input can be either TENSOR_FLOAT16 or TENSOR_FLOAT32
        // - CONV_2D filter type (arg 1) can be QUANT8_ASYMM or QUANT8_SYMM_PER_CHANNEL
        // - DEPTHWISE_CONV_2D filter type (arg 1) can be QUANT8_ASYMM or QUANT8_SYMM_PER_CHANNEL
        // - GROUPED_CONV_2D filter type (arg 1) can be QUANT8_ASYMM or QUANT8_SYMM_PER_CHANNEL
@@ -357,8 +365,22 @@ static bool mutateOperationOperandTypeSkip(size_t operand, OperandType type, con
                    return true;
                }
            } break;
            case OperationType::QUANTIZE:
            case OperationType::RANDOM_MULTINOMIAL: {
                if (type == OperandType::TENSOR_FLOAT16 || type == OperandType::TENSOR_FLOAT32) {
                if (operand == operation.inputs[0] &&
                    (type == OperandType::TENSOR_FLOAT16 || type == OperandType::TENSOR_FLOAT32)) {
                    return true;
                }
            } break;
            case OperationType::DEQUANTIZE: {
                if (operand == operation.inputs[0] &&
                    (type == OperandType::TENSOR_QUANT8_ASYMM ||
                     type == OperandType::TENSOR_QUANT8_SYMM ||
                     type == OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL)) {
                    return true;
                }
                if (operand == operation.outputs[0] &&
                    (type == OperandType::TENSOR_FLOAT16 || type == OperandType::TENSOR_FLOAT32)) {
                    return true;
                }
            } break;
@@ -397,7 +419,6 @@ static void mutateOperationOperandTypeTest(const sp<IDevice>& device, const Mode
///////////////////////// VALIDATE MODEL OPERATION TYPE /////////////////////////

static const uint32_t invalidOperationTypes[] = {
        static_cast<uint32_t>(OperationTypeRange::FUNDAMENTAL_MIN) - 1,
        static_cast<uint32_t>(OperationTypeRange::FUNDAMENTAL_MAX) + 1,
        static_cast<uint32_t>(OperationTypeRange::OEM_MIN) - 1,
        static_cast<uint32_t>(OperationTypeRange::OEM_MAX) + 1,
@@ -484,6 +505,15 @@ static bool removeOperandSkip(size_t operand, const Model& model) {
                }
            }
        }
        // BIDIRECTIONAL_SEQUENCE_RNN can have either on or two outputs
        // depending on a mergeOutputs parameter
        if (operation.type == OperationType::BIDIRECTIONAL_SEQUENCE_RNN) {
            for (const size_t outOprand : operation.outputs) {
                if (operand == outOprand) {
                    return true;
                }
            }
        }
    }
    return false;
}