Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit af809a4a authored by Lev Proleev's avatar Lev Proleev Committed by android-build-merger
Browse files

Fix VTS tests.

am: dce38f13

Change-Id: I374b7b665526dd7b2a02fbcb36e66907fb9529b4
parents 7231bbe7 dce38f13
Loading
Loading
Loading
Loading
+32 −2
Original line number Diff line number Diff line
@@ -157,6 +157,7 @@ static uint32_t getInvalidRank(OperandType type) {
        case OperandType::UINT32:
        case OperandType::BOOL:
            return 1;
        case OperandType::TENSOR_BOOL8:
        case OperandType::TENSOR_FLOAT16:
        case OperandType::TENSOR_FLOAT32:
        case OperandType::TENSOR_INT32:
@@ -194,6 +195,7 @@ static float getInvalidScale(OperandType type) {
        case OperandType::INT32:
        case OperandType::UINT32:
        case OperandType::BOOL:
        case OperandType::TENSOR_BOOL8:
        case OperandType::TENSOR_FLOAT16:
        case OperandType::TENSOR_FLOAT32:
        case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
@@ -230,6 +232,7 @@ static std::vector<int32_t> getInvalidZeroPoints(OperandType type) {
        case OperandType::INT32:
        case OperandType::UINT32:
        case OperandType::BOOL:
        case OperandType::TENSOR_BOOL8:
        case OperandType::TENSOR_FLOAT16:
        case OperandType::TENSOR_FLOAT32:
        case OperandType::TENSOR_INT32:
@@ -283,6 +286,7 @@ static void mutateOperand(Operand* operand, OperandType type) {
            newOperand.scale = 0.0f;
            newOperand.zeroPoint = 0;
            break;
        case OperandType::TENSOR_BOOL8:
        case OperandType::TENSOR_FLOAT16:
        case OperandType::TENSOR_FLOAT32:
            newOperand.dimensions =
@@ -339,6 +343,10 @@ static bool mutateOperationOperandTypeSkip(size_t operand, OperandType type, con
        // TENSOR_(FLOAT16|FLOAT32|INT32|QUANT8_ASYMM).
        // - CAST's argument can be any of TENSOR_(FLOAT16|FLOAT32|INT32|QUANT8_ASYMM).
        // - RANDOM_MULTINOMIAL's argument can be either TENSOR_FLOAT16 or TENSOR_FLOAT32.
        // - DEQUANTIZE input can be any of
        // TENSOR_(QUANT8_ASYMM|QUANT8_SYMM|QUANT8_SYMM_PER_CHANNEL), output can
        // be of either TENSOR_FLOAT16 or TENSOR_FLOAT32.
        // - QUANTIZE input can be either TENSOR_FLOAT16 or TENSOR_FLOAT32
        // - CONV_2D filter type (arg 1) can be QUANT8_ASYMM or QUANT8_SYMM_PER_CHANNEL
        // - DEPTHWISE_CONV_2D filter type (arg 1) can be QUANT8_ASYMM or QUANT8_SYMM_PER_CHANNEL
        // - GROUPED_CONV_2D filter type (arg 1) can be QUANT8_ASYMM or QUANT8_SYMM_PER_CHANNEL
@@ -357,8 +365,22 @@ static bool mutateOperationOperandTypeSkip(size_t operand, OperandType type, con
                    return true;
                }
            } break;
            case OperationType::QUANTIZE:
            case OperationType::RANDOM_MULTINOMIAL: {
                if (type == OperandType::TENSOR_FLOAT16 || type == OperandType::TENSOR_FLOAT32) {
                if (operand == operation.inputs[0] &&
                    (type == OperandType::TENSOR_FLOAT16 || type == OperandType::TENSOR_FLOAT32)) {
                    return true;
                }
            } break;
            case OperationType::DEQUANTIZE: {
                if (operand == operation.inputs[0] &&
                    (type == OperandType::TENSOR_QUANT8_ASYMM ||
                     type == OperandType::TENSOR_QUANT8_SYMM ||
                     type == OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL)) {
                    return true;
                }
                if (operand == operation.outputs[0] &&
                    (type == OperandType::TENSOR_FLOAT16 || type == OperandType::TENSOR_FLOAT32)) {
                    return true;
                }
            } break;
@@ -397,7 +419,6 @@ static void mutateOperationOperandTypeTest(const sp<IDevice>& device, const Mode
///////////////////////// VALIDATE MODEL OPERATION TYPE /////////////////////////

static const uint32_t invalidOperationTypes[] = {
        static_cast<uint32_t>(OperationTypeRange::FUNDAMENTAL_MIN) - 1,
        static_cast<uint32_t>(OperationTypeRange::FUNDAMENTAL_MAX) + 1,
        static_cast<uint32_t>(OperationTypeRange::OEM_MIN) - 1,
        static_cast<uint32_t>(OperationTypeRange::OEM_MAX) + 1,
@@ -484,6 +505,15 @@ static bool removeOperandSkip(size_t operand, const Model& model) {
                }
            }
        }
        // BIDIRECTIONAL_SEQUENCE_RNN can have either on or two outputs
        // depending on a mergeOutputs parameter
        if (operation.type == OperationType::BIDIRECTIONAL_SEQUENCE_RNN) {
            for (const size_t outOprand : operation.outputs) {
                if (operand == outOprand) {
                    return true;
                }
            }
        }
    }
    return false;
}