Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit e5499d31 authored by Przemyslaw Szczepaniak's avatar Przemyslaw Szczepaniak Committed by Android (Google) Code Review
Browse files

Merge "Add TENSOR_QUANT8_SYMM_PER_CHANNEL to operand types"

parents 8032fd68 faa59b8a
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -14,6 +14,7 @@ hidl_interface {
        "android.hardware.neuralnetworks@1.0",
        "android.hardware.neuralnetworks@1.1",
        "android.hidl.base@1.0",
        "android.hidl.safe_union@1.0",
    ],
    types: [
        "Model",
+67 −1
Original line number Diff line number Diff line
@@ -22,6 +22,8 @@ import @1.0::OperandType;
import @1.0::PerformanceInfo;
import @1.1::OperationType;

import android.hidl.safe_union@1.0::Monostate;

enum OperandType : @1.0::OperandType {
    /**
     * An 8 bit boolean scalar value.
@@ -51,6 +53,29 @@ enum OperandType : @1.0::OperandType {
    TENSOR_BOOL8 = 9,
    /** An IEEE 754 16 bit floating point scalar value. */
    FLOAT16 = 10,
    /**
     * A tensor of 8 bit signed integers that represent real numbers.
     *
     * This tensor is associated with additional fields that are
     * used to convert the 8 bit signed integer to the real value and vice versa.
     * These fields are:
     * - channelDim: a 32 bit unsigned integer indicating channel dimension.
     * - scales: an array of positive 32 bit floating point values.
     * The size of the scales array must be equal to dimensions[channelDim].
     * These fields are located inside Operand's extraParams union, inside the
     * SymmPerChannelQuantParams struct.
     *
     * An Operand of this type must use 'channelQuant' field of its extraParams
     * union.
     *
     * The channel dimension of this tensor must not be unknown (dimensions[channelDim] != 0).
     *
     * The formula for real values:
     * realValue[..., C, ...] =
     *     integerValue[..., C, ...] * scales[C]
     * where C is an index in the Channel dimension.
     */
    TENSOR_QUANT8_SYMM_PER_CHANNEL = 11,
    /* ADDING A NEW FUNDAMENTAL TYPE REQUIRES UPDATING THE VALUE OF
     * OperandTypeRange::OPERAND_FUNDAMENTAL_MAX.
     */
@@ -64,7 +89,7 @@ enum OperandType : @1.0::OperandType {
 */
enum OperandTypeRange : uint32_t {
    OPERAND_FUNDAMENTAL_MIN = 0,
    OPERAND_FUNDAMENTAL_MAX = 10,
    OPERAND_FUNDAMENTAL_MAX = 11,
    OPERAND_OEM_MIN     = 10000,
    OPERAND_OEM_MAX     = 10001,
};
@@ -175,6 +200,25 @@ struct Operation {
    vec<uint32_t> outputs;
};

/**
 * Parameters for TENSOR_QUANT8_SYMM_PER_CHANNEL operand.
 */
struct SymmPerChannelQuantParams {
    /** Array of scaling values for each channel. Each value must be greater than zero. */
    vec<float> scales;
    /** Index of the channel dimension */
    uint32_t channelDim;
};

// TODO(slavash): Operand Extension support
// /**
//  * Parameters for an unknown (as of 1.2) operand extension. This is
//  * a vendor-specific extension or a platform extension (backport of
//  * functionality from newer NNAPI interface).
//  */
// struct OperandParamsUnknown {
// };

/**
 * Describes one operand of the model's graph.
 */
@@ -268,6 +312,28 @@ struct Operand {
     * - location.length is set.
     */
    DataLocation location;

    /**
     * Union of extra parameters, used by some types of Operands that need additional
     * information for the complete definition of an Operand.
     */
    safe_union ExtraParams {
       /**
        * Placeholder for operand with no extra parameters.
        */
       Monostate none;

       /**
        * Used with TENSOR_QUANT8_SYMM_PER_CHANNEL operand type.
        */
       SymmPerChannelQuantParams channelQuant;

       // TODO(slavash): Operand Extension support
       // /**
       //  * Used with Extension operand type.
       //  */
       // OperandParamsUnknown unknown;
    } extraParams;
};

/**
+18 −0
Original line number Diff line number Diff line
@@ -163,6 +163,7 @@ static uint32_t getInvalidRank(OperandType type) {
        case OperandType::TENSOR_INT32:
        case OperandType::TENSOR_QUANT8_ASYMM:
        case OperandType::TENSOR_QUANT16_SYMM:
        case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
            return 0;
        default:
            return 0;
@@ -191,6 +192,7 @@ static float getInvalidScale(OperandType type) {
        case OperandType::BOOL:
        case OperandType::TENSOR_FLOAT16:
        case OperandType::TENSOR_FLOAT32:
        case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
            return 1.0f;
        case OperandType::TENSOR_INT32:
            return -1.0f;
@@ -225,6 +227,7 @@ static std::vector<int32_t> getInvalidZeroPoints(OperandType type) {
        case OperandType::TENSOR_FLOAT16:
        case OperandType::TENSOR_FLOAT32:
        case OperandType::TENSOR_INT32:
        case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
            return {1};
        case OperandType::TENSOR_QUANT8_ASYMM:
            return {-1, 256};
@@ -288,6 +291,21 @@ static void mutateOperand(Operand* operand, OperandType type) {
                operand->dimensions.size() > 0 ? operand->dimensions : hidl_vec<uint32_t>({1});
            newOperand.scale = operand->scale != 0.0f ? operand->scale : 1.0f;
            break;
        case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL: {
            newOperand.dimensions =
                operand->dimensions.size() > 0 ? operand->dimensions : hidl_vec<uint32_t>({1});
            newOperand.scale = 0.0f;
            newOperand.zeroPoint = 0;

            SymmPerChannelQuantParams channelQuant;
            channelQuant.channelDim = 0;
            channelQuant.scales = hidl_vec<float>(
                operand->dimensions.size() > 0 ? static_cast<size_t>(operand->dimensions[0]) : 0);
            for (size_t i = 0; i < channelQuant.scales.size(); ++i) {
                channelQuant.scales[i] = 1.0f;
            }
            newOperand.extraParams.channelQuant(std::move(channelQuant));
        } break;
        case OperandType::OEM:
        case OperandType::TENSOR_OEM_BYTE:
        default: