Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit bd66e620 authored by Philip Quinn's avatar Philip Quinn
Browse files

Postpone loading the TFLite model until a supported event is recorded.

Bug: 267050081
Test: atest libinput_tests
Change-Id: I09666da123a58786e8a6d47d4c29a475e92f2bbf
parent cb3229aa
Loading
Loading
Loading
Loading
+2 −0
Original line number Diff line number Diff line
@@ -19,6 +19,7 @@
#include <cstdint>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>

#include <android-base/thread_annotations.h>
@@ -73,6 +74,7 @@ public:

private:
    const nsecs_t mPredictionTimestampOffsetNanos;
    const std::string mModelPath;
    const std::function<bool()> mCheckMotionPredictionEnabled;

    std::unique_ptr<TfLiteMotionPredictorModel> mModel;
+8 −3
Original line number Diff line number Diff line
@@ -65,9 +65,8 @@ TfLiteMotionPredictorSample::Point convertPrediction(
MotionPredictor::MotionPredictor(nsecs_t predictionTimestampOffsetNanos, const char* modelPath,
                                 std::function<bool()> checkMotionPredictionEnabled)
      : mPredictionTimestampOffsetNanos(predictionTimestampOffsetNanos),
        mCheckMotionPredictionEnabled(std::move(checkMotionPredictionEnabled)),
        mModel(TfLiteMotionPredictorModel::create(modelPath == nullptr ? DEFAULT_MODEL_PATH
                                                                       : modelPath)) {}
        mModelPath(modelPath == nullptr ? DEFAULT_MODEL_PATH : modelPath),
        mCheckMotionPredictionEnabled(std::move(checkMotionPredictionEnabled)) {}

void MotionPredictor::record(const MotionEvent& event) {
    if (!isPredictionAvailable(event.getDeviceId(), event.getSource())) {
@@ -76,6 +75,11 @@ void MotionPredictor::record(const MotionEvent& event) {
        return;
    }

    // Initialise the model now that it's likely to be used.
    if (!mModel) {
        mModel = TfLiteMotionPredictorModel::create(mModelPath.c_str());
    }

    TfLiteMotionPredictorBuffers& buffers =
            mDeviceBuffers.try_emplace(event.getDeviceId(), mModel->inputLength()).first->second;

@@ -130,6 +134,7 @@ std::vector<std::unique_ptr<MotionEvent>> MotionPredictor::predict(nsecs_t times
            continue;
        }

        LOG_ALWAYS_FATAL_IF(!mModel);
        buffer.copyTo(*mModel);
        LOG_ALWAYS_FATAL_IF(!mModel->invoke());