Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 7dc791d1 authored by Michael Butler's avatar Michael Butler Committed by android-build-merger
Browse files

Merge "NNAPI: validate that FmqResultDatum padding is 0 -- VTS" into qt-dev am: 2cd02a41

am: 557b9a95

Change-Id: Ice778d3f6d5b31bc2d3e7f570b04196973fd2d63
parents e14328ca 557b9a95
Loading
Loading
Loading
Loading
+80 −0
Original line number Diff line number Diff line
@@ -25,6 +25,7 @@
#include "Utils.h"

#include <android-base/logging.h>
#include <cstring>

namespace android {
namespace hardware {
@@ -317,12 +318,91 @@ static void validateBurstFmqLength(const sp<IPreparedModel>& preparedModel,
    }
}

static bool isSanitized(const FmqResultDatum& datum) {
    using Discriminator = FmqResultDatum::hidl_discriminator;

    // check to ensure the padding values in the returned
    // FmqResultDatum::OperandInformation are initialized to 0
    if (datum.getDiscriminator() == Discriminator::operandInformation) {
        static_assert(
                offsetof(FmqResultDatum::OperandInformation, isSufficient) == 0,
                "unexpected value for offset of FmqResultDatum::OperandInformation::isSufficient");
        static_assert(
                sizeof(FmqResultDatum::OperandInformation::isSufficient) == 1,
                "unexpected value for size of FmqResultDatum::OperandInformation::isSufficient");
        static_assert(offsetof(FmqResultDatum::OperandInformation, numberOfDimensions) == 4,
                      "unexpected value for offset of "
                      "FmqResultDatum::OperandInformation::numberOfDimensions");
        static_assert(sizeof(FmqResultDatum::OperandInformation::numberOfDimensions) == 4,
                      "unexpected value for size of "
                      "FmqResultDatum::OperandInformation::numberOfDimensions");
        static_assert(sizeof(FmqResultDatum::OperandInformation) == 8,
                      "unexpected value for size of "
                      "FmqResultDatum::OperandInformation");

        constexpr size_t paddingOffset =
                offsetof(FmqResultDatum::OperandInformation, isSufficient) +
                sizeof(FmqResultDatum::OperandInformation::isSufficient);
        constexpr size_t paddingSize =
                offsetof(FmqResultDatum::OperandInformation, numberOfDimensions) - paddingOffset;

        FmqResultDatum::OperandInformation initialized{};
        std::memset(&initialized, 0, sizeof(initialized));

        const char* initializedPaddingStart =
                reinterpret_cast<const char*>(&initialized) + paddingOffset;
        const char* datumPaddingStart =
                reinterpret_cast<const char*>(&datum.operandInformation()) + paddingOffset;

        return std::memcmp(datumPaddingStart, initializedPaddingStart, paddingSize) == 0;
    }

    // there are no other padding initialization checks required, so return true
    // for any sum-type that isn't FmqResultDatum::OperandInformation
    return true;
}

static void validateBurstSanitized(const sp<IPreparedModel>& preparedModel,
                                   const std::vector<Request>& requests) {
    // create burst
    std::unique_ptr<RequestChannelSender> sender;
    std::unique_ptr<ResultChannelReceiver> receiver;
    sp<ExecutionBurstCallback> callback = new ExecutionBurstCallback();
    sp<IBurstContext> context;
    ASSERT_NO_FATAL_FAILURE(createBurst(preparedModel, callback, &sender, &receiver, &context));
    ASSERT_NE(nullptr, sender.get());
    ASSERT_NE(nullptr, receiver.get());
    ASSERT_NE(nullptr, context.get());

    // validate each request
    for (const Request& request : requests) {
        // load memory into callback slots
        std::vector<intptr_t> keys;
        keys.reserve(request.pools.size());
        std::transform(request.pools.begin(), request.pools.end(), std::back_inserter(keys),
                       [](const auto& pool) { return reinterpret_cast<intptr_t>(&pool); });
        const std::vector<int32_t> slots = callback->getSlots(request.pools, keys);

        // send valid request
        ASSERT_TRUE(sender->send(request, MeasureTiming::YES, slots));

        // receive valid result
        auto serialized = receiver->getPacketBlocking();
        ASSERT_TRUE(serialized.has_value());

        // sanitize result
        ASSERT_TRUE(std::all_of(serialized->begin(), serialized->end(), isSanitized))
                << "The result serialized data is not properly sanitized";
    }
}

///////////////////////////// ENTRY POINT //////////////////////////////////

void ValidationTest::validateBurst(const sp<IPreparedModel>& preparedModel,
                                   const std::vector<Request>& requests) {
    ASSERT_NO_FATAL_FAILURE(validateBurstSerialization(preparedModel, requests));
    ASSERT_NO_FATAL_FAILURE(validateBurstFmqLength(preparedModel, requests));
    ASSERT_NO_FATAL_FAILURE(validateBurstSanitized(preparedModel, requests));
}

}  // namespace functional