Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit cb3229aa authored by Philip Quinn's avatar Philip Quinn
Browse files

Use mmap to read TFLite model.

The buffers in the model file are used directly by TFLite, and so a
small memory saving can be achieved by backing those memory pages with
the file itself.

Bug: 267050081
Test: atest libinput_tests
Change-Id: I743a3c94477d4bb778b6e0c4b4890a44f4e19aa4
parent da6a448e
Loading
Loading
Loading
Loading
+3 −3
Original line number Original line Diff line number Diff line
@@ -22,8 +22,8 @@
#include <memory>
#include <memory>
#include <optional>
#include <optional>
#include <span>
#include <span>
#include <string>


#include <android-base/mapped_file.h>
#include <input/RingBuffer.h>
#include <input/RingBuffer.h>


#include <tensorflow/lite/core/api/error_reporter.h>
#include <tensorflow/lite/core/api/error_reporter.h>
@@ -124,7 +124,7 @@ public:
    std::span<const float> outputPressure() const;
    std::span<const float> outputPressure() const;


private:
private:
    explicit TfLiteMotionPredictorModel(std::string model);
    explicit TfLiteMotionPredictorModel(std::unique_ptr<android::base::MappedFile> model);


    void allocateTensors();
    void allocateTensors();
    void attachInputTensors();
    void attachInputTensors();
@@ -140,7 +140,7 @@ private:
    const TfLiteTensor* mOutputPhi = nullptr;
    const TfLiteTensor* mOutputPhi = nullptr;
    const TfLiteTensor* mOutputPressure = nullptr;
    const TfLiteTensor* mOutputPressure = nullptr;


    std::string mFlatBuffer;
    std::unique_ptr<android::base::MappedFile> mFlatBuffer;
    std::unique_ptr<tflite::ErrorReporter> mErrorReporter;
    std::unique_ptr<tflite::ErrorReporter> mErrorReporter;
    std::unique_ptr<tflite::FlatBufferModel> mModel;
    std::unique_ptr<tflite::FlatBufferModel> mModel;
    std::unique_ptr<tflite::Interpreter> mInterpreter;
    std::unique_ptr<tflite::Interpreter> mInterpreter;
+29 −12
Original line number Original line Diff line number Diff line
@@ -17,19 +17,21 @@
#define LOG_TAG "TfLiteMotionPredictor"
#define LOG_TAG "TfLiteMotionPredictor"
#include <input/TfLiteMotionPredictor.h>
#include <input/TfLiteMotionPredictor.h>


#include <fcntl.h>
#include <sys/mman.h>
#include <unistd.h>

#include <algorithm>
#include <algorithm>
#include <cmath>
#include <cmath>
#include <cstddef>
#include <cstddef>
#include <cstdint>
#include <cstdint>
#include <fstream>
#include <ios>
#include <iterator>
#include <memory>
#include <memory>
#include <span>
#include <span>
#include <string>
#include <type_traits>
#include <type_traits>
#include <utility>
#include <utility>


#include <android-base/logging.h>
#include <android-base/mapped_file.h>
#define ATRACE_TAG ATRACE_TAG_INPUT
#define ATRACE_TAG ATRACE_TAG_INPUT
#include <cutils/trace.h>
#include <cutils/trace.h>
#include <log/log.h>
#include <log/log.h>
@@ -206,21 +208,36 @@ void TfLiteMotionPredictorBuffers::pushSample(int64_t timestamp,


std::unique_ptr<TfLiteMotionPredictorModel> TfLiteMotionPredictorModel::create(
std::unique_ptr<TfLiteMotionPredictorModel> TfLiteMotionPredictorModel::create(
        const char* modelPath) {
        const char* modelPath) {
    std::ifstream f(modelPath, std::ios::binary);
    const int fd = open(modelPath, O_RDONLY);
    LOG_ALWAYS_FATAL_IF(!f, "Could not read model from %s", modelPath);
    if (fd == -1) {
        PLOG(FATAL) << "Could not read model from " << modelPath;
    }

    const off_t fdSize = lseek(fd, 0, SEEK_END);
    if (fdSize == -1) {
        PLOG(FATAL) << "Failed to determine file size";
    }


    std::string data;
    std::unique_ptr<android::base::MappedFile> modelBuffer =
    data.assign(std::istreambuf_iterator<char>(f), std::istreambuf_iterator<char>());
            android::base::MappedFile::FromFd(fd, /*offset=*/0, fdSize, PROT_READ);
    if (!modelBuffer) {
        PLOG(FATAL) << "Failed to mmap model";
    }
    if (close(fd) == -1) {
        PLOG(FATAL) << "Failed to close model fd";
    }


    return std::unique_ptr<TfLiteMotionPredictorModel>(
    return std::unique_ptr<TfLiteMotionPredictorModel>(
            new TfLiteMotionPredictorModel(std::move(data)));
            new TfLiteMotionPredictorModel(std::move(modelBuffer)));
}
}


TfLiteMotionPredictorModel::TfLiteMotionPredictorModel(std::string model)
TfLiteMotionPredictorModel::TfLiteMotionPredictorModel(
        std::unique_ptr<android::base::MappedFile> model)
      : mFlatBuffer(std::move(model)) {
      : mFlatBuffer(std::move(model)) {
    CHECK(mFlatBuffer);
    mErrorReporter = std::make_unique<LoggingErrorReporter>();
    mErrorReporter = std::make_unique<LoggingErrorReporter>();
    mModel = tflite::FlatBufferModel::VerifyAndBuildFromBuffer(mFlatBuffer.data(),
    mModel = tflite::FlatBufferModel::VerifyAndBuildFromBuffer(mFlatBuffer->data(),
                                                               mFlatBuffer.length(),
                                                               mFlatBuffer->size(),
                                                               /*extra_verifier=*/nullptr,
                                                               /*extra_verifier=*/nullptr,
                                                               mErrorReporter.get());
                                                               mErrorReporter.get());
    LOG_ALWAYS_FATAL_IF(!mModel);
    LOG_ALWAYS_FATAL_IF(!mModel);