Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 082cefe3 authored by Michael Butler's avatar Michael Butler
Browse files

Add additional bounds checks to NNAPI FMQ deserialize utility functions

This CL adds the following additional bounds checks:
* Adds additional checks of the index of the std::vector before
  accessing the element at the index
* Changes the array index operator [] to the checked std::vector::at
  method

Bug: 256589724
Test: mma
Merged-In: I6bfb02a5cd76258284cc4d797a4508b21e672c4b
Change-Id: I6bfb02a5cd76258284cc4d797a4508b21e672c4b
parent 7a95f4b5
Loading
Loading
Loading
Loading
+33 −23
Original line number Original line Diff line number Diff line
@@ -192,12 +192,13 @@ nn::Result<std::tuple<V1_0::Request, std::vector<int32_t>, V1_2::MeasureTiming>>
    size_t index = 0;
    size_t index = 0;


    // validate packet information
    // validate packet information
    if (data.size() == 0 || data[index].getDiscriminator() != discriminator::packetInformation) {
    if (index >= data.size() ||
        data.at(index).getDiscriminator() != discriminator::packetInformation) {
        return NN_ERROR() << "FMQ Request packet ill-formed";
        return NN_ERROR() << "FMQ Request packet ill-formed";
    }
    }


    // unpackage packet information
    // unpackage packet information
    const FmqRequestDatum::PacketInformation& packetInfo = data[index].packetInformation();
    const FmqRequestDatum::PacketInformation& packetInfo = data.at(index).packetInformation();
    index++;
    index++;
    const uint32_t packetSize = packetInfo.packetSize;
    const uint32_t packetSize = packetInfo.packetSize;
    const uint32_t numberOfInputOperands = packetInfo.numberOfInputOperands;
    const uint32_t numberOfInputOperands = packetInfo.numberOfInputOperands;
@@ -214,13 +215,14 @@ nn::Result<std::tuple<V1_0::Request, std::vector<int32_t>, V1_2::MeasureTiming>>
    inputs.reserve(numberOfInputOperands);
    inputs.reserve(numberOfInputOperands);
    for (size_t operand = 0; operand < numberOfInputOperands; ++operand) {
    for (size_t operand = 0; operand < numberOfInputOperands; ++operand) {
        // validate input operand information
        // validate input operand information
        if (data[index].getDiscriminator() != discriminator::inputOperandInformation) {
        if (index >= data.size() ||
            data.at(index).getDiscriminator() != discriminator::inputOperandInformation) {
            return NN_ERROR() << "FMQ Request packet ill-formed";
            return NN_ERROR() << "FMQ Request packet ill-formed";
        }
        }


        // unpackage operand information
        // unpackage operand information
        const FmqRequestDatum::OperandInformation& operandInfo =
        const FmqRequestDatum::OperandInformation& operandInfo =
                data[index].inputOperandInformation();
                data.at(index).inputOperandInformation();
        index++;
        index++;
        const bool hasNoValue = operandInfo.hasNoValue;
        const bool hasNoValue = operandInfo.hasNoValue;
        const V1_0::DataLocation location = operandInfo.location;
        const V1_0::DataLocation location = operandInfo.location;
@@ -231,12 +233,13 @@ nn::Result<std::tuple<V1_0::Request, std::vector<int32_t>, V1_2::MeasureTiming>>
        dimensions.reserve(numberOfDimensions);
        dimensions.reserve(numberOfDimensions);
        for (size_t i = 0; i < numberOfDimensions; ++i) {
        for (size_t i = 0; i < numberOfDimensions; ++i) {
            // validate dimension
            // validate dimension
            if (data[index].getDiscriminator() != discriminator::inputOperandDimensionValue) {
            if (index >= data.size() ||
                data.at(index).getDiscriminator() != discriminator::inputOperandDimensionValue) {
                return NN_ERROR() << "FMQ Request packet ill-formed";
                return NN_ERROR() << "FMQ Request packet ill-formed";
            }
            }


            // unpackage dimension
            // unpackage dimension
            const uint32_t dimension = data[index].inputOperandDimensionValue();
            const uint32_t dimension = data.at(index).inputOperandDimensionValue();
            index++;
            index++;


            // store result
            // store result
@@ -253,13 +256,14 @@ nn::Result<std::tuple<V1_0::Request, std::vector<int32_t>, V1_2::MeasureTiming>>
    outputs.reserve(numberOfOutputOperands);
    outputs.reserve(numberOfOutputOperands);
    for (size_t operand = 0; operand < numberOfOutputOperands; ++operand) {
    for (size_t operand = 0; operand < numberOfOutputOperands; ++operand) {
        // validate output operand information
        // validate output operand information
        if (data[index].getDiscriminator() != discriminator::outputOperandInformation) {
        if (index >= data.size() ||
            data.at(index).getDiscriminator() != discriminator::outputOperandInformation) {
            return NN_ERROR() << "FMQ Request packet ill-formed";
            return NN_ERROR() << "FMQ Request packet ill-formed";
        }
        }


        // unpackage operand information
        // unpackage operand information
        const FmqRequestDatum::OperandInformation& operandInfo =
        const FmqRequestDatum::OperandInformation& operandInfo =
                data[index].outputOperandInformation();
                data.at(index).outputOperandInformation();
        index++;
        index++;
        const bool hasNoValue = operandInfo.hasNoValue;
        const bool hasNoValue = operandInfo.hasNoValue;
        const V1_0::DataLocation location = operandInfo.location;
        const V1_0::DataLocation location = operandInfo.location;
@@ -270,12 +274,13 @@ nn::Result<std::tuple<V1_0::Request, std::vector<int32_t>, V1_2::MeasureTiming>>
        dimensions.reserve(numberOfDimensions);
        dimensions.reserve(numberOfDimensions);
        for (size_t i = 0; i < numberOfDimensions; ++i) {
        for (size_t i = 0; i < numberOfDimensions; ++i) {
            // validate dimension
            // validate dimension
            if (data[index].getDiscriminator() != discriminator::outputOperandDimensionValue) {
            if (index >= data.size() ||
                data.at(index).getDiscriminator() != discriminator::outputOperandDimensionValue) {
                return NN_ERROR() << "FMQ Request packet ill-formed";
                return NN_ERROR() << "FMQ Request packet ill-formed";
            }
            }


            // unpackage dimension
            // unpackage dimension
            const uint32_t dimension = data[index].outputOperandDimensionValue();
            const uint32_t dimension = data.at(index).outputOperandDimensionValue();
            index++;
            index++;


            // store result
            // store result
@@ -292,12 +297,13 @@ nn::Result<std::tuple<V1_0::Request, std::vector<int32_t>, V1_2::MeasureTiming>>
    slots.reserve(numberOfPools);
    slots.reserve(numberOfPools);
    for (size_t pool = 0; pool < numberOfPools; ++pool) {
    for (size_t pool = 0; pool < numberOfPools; ++pool) {
        // validate input operand information
        // validate input operand information
        if (data[index].getDiscriminator() != discriminator::poolIdentifier) {
        if (index >= data.size() ||
            data.at(index).getDiscriminator() != discriminator::poolIdentifier) {
            return NN_ERROR() << "FMQ Request packet ill-formed";
            return NN_ERROR() << "FMQ Request packet ill-formed";
        }
        }


        // unpackage operand information
        // unpackage operand information
        const int32_t poolId = data[index].poolIdentifier();
        const int32_t poolId = data.at(index).poolIdentifier();
        index++;
        index++;


        // store result
        // store result
@@ -305,17 +311,17 @@ nn::Result<std::tuple<V1_0::Request, std::vector<int32_t>, V1_2::MeasureTiming>>
    }
    }


    // validate measureTiming
    // validate measureTiming
    if (data[index].getDiscriminator() != discriminator::measureTiming) {
    if (index >= data.size() || data.at(index).getDiscriminator() != discriminator::measureTiming) {
        return NN_ERROR() << "FMQ Request packet ill-formed";
        return NN_ERROR() << "FMQ Request packet ill-formed";
    }
    }


    // unpackage measureTiming
    // unpackage measureTiming
    const V1_2::MeasureTiming measure = data[index].measureTiming();
    const V1_2::MeasureTiming measure = data.at(index).measureTiming();
    index++;
    index++;


    // validate packet information
    // validate packet information
    if (index != packetSize) {
    if (index != packetSize) {
        return NN_ERROR() << "FMQ Result packet ill-formed";
        return NN_ERROR() << "FMQ Request packet ill-formed";
    }
    }


    // return request
    // return request
@@ -330,12 +336,13 @@ nn::Result<std::tuple<V1_0::ErrorStatus, std::vector<V1_2::OutputShape>, V1_2::T
    size_t index = 0;
    size_t index = 0;


    // validate packet information
    // validate packet information
    if (data.size() == 0 || data[index].getDiscriminator() != discriminator::packetInformation) {
    if (index >= data.size() ||
        data.at(index).getDiscriminator() != discriminator::packetInformation) {
        return NN_ERROR() << "FMQ Result packet ill-formed";
        return NN_ERROR() << "FMQ Result packet ill-formed";
    }
    }


    // unpackage packet information
    // unpackage packet information
    const FmqResultDatum::PacketInformation& packetInfo = data[index].packetInformation();
    const FmqResultDatum::PacketInformation& packetInfo = data.at(index).packetInformation();
    index++;
    index++;
    const uint32_t packetSize = packetInfo.packetSize;
    const uint32_t packetSize = packetInfo.packetSize;
    const V1_0::ErrorStatus errorStatus = packetInfo.errorStatus;
    const V1_0::ErrorStatus errorStatus = packetInfo.errorStatus;
@@ -351,12 +358,13 @@ nn::Result<std::tuple<V1_0::ErrorStatus, std::vector<V1_2::OutputShape>, V1_2::T
    outputShapes.reserve(numberOfOperands);
    outputShapes.reserve(numberOfOperands);
    for (size_t operand = 0; operand < numberOfOperands; ++operand) {
    for (size_t operand = 0; operand < numberOfOperands; ++operand) {
        // validate operand information
        // validate operand information
        if (data[index].getDiscriminator() != discriminator::operandInformation) {
        if (index >= data.size() ||
            data.at(index).getDiscriminator() != discriminator::operandInformation) {
            return NN_ERROR() << "FMQ Result packet ill-formed";
            return NN_ERROR() << "FMQ Result packet ill-formed";
        }
        }


        // unpackage operand information
        // unpackage operand information
        const FmqResultDatum::OperandInformation& operandInfo = data[index].operandInformation();
        const FmqResultDatum::OperandInformation& operandInfo = data.at(index).operandInformation();
        index++;
        index++;
        const bool isSufficient = operandInfo.isSufficient;
        const bool isSufficient = operandInfo.isSufficient;
        const uint32_t numberOfDimensions = operandInfo.numberOfDimensions;
        const uint32_t numberOfDimensions = operandInfo.numberOfDimensions;
@@ -366,12 +374,13 @@ nn::Result<std::tuple<V1_0::ErrorStatus, std::vector<V1_2::OutputShape>, V1_2::T
        dimensions.reserve(numberOfDimensions);
        dimensions.reserve(numberOfDimensions);
        for (size_t i = 0; i < numberOfDimensions; ++i) {
        for (size_t i = 0; i < numberOfDimensions; ++i) {
            // validate dimension
            // validate dimension
            if (data[index].getDiscriminator() != discriminator::operandDimensionValue) {
            if (index >= data.size() ||
                data.at(index).getDiscriminator() != discriminator::operandDimensionValue) {
                return NN_ERROR() << "FMQ Result packet ill-formed";
                return NN_ERROR() << "FMQ Result packet ill-formed";
            }
            }


            // unpackage dimension
            // unpackage dimension
            const uint32_t dimension = data[index].operandDimensionValue();
            const uint32_t dimension = data.at(index).operandDimensionValue();
            index++;
            index++;


            // store result
            // store result
@@ -383,12 +392,13 @@ nn::Result<std::tuple<V1_0::ErrorStatus, std::vector<V1_2::OutputShape>, V1_2::T
    }
    }


    // validate execution timing
    // validate execution timing
    if (data[index].getDiscriminator() != discriminator::executionTiming) {
    if (index >= data.size() ||
        data.at(index).getDiscriminator() != discriminator::executionTiming) {
        return NN_ERROR() << "FMQ Result packet ill-formed";
        return NN_ERROR() << "FMQ Result packet ill-formed";
    }
    }


    // unpackage execution timing
    // unpackage execution timing
    const V1_2::Timing timing = data[index].executionTiming();
    const V1_2::Timing timing = data.at(index).executionTiming();
    index++;
    index++;


    // validate packet information
    // validate packet information