Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 6964e11c authored by Philip Quinn's avatar Philip Quinn Committed by Android (Google) Code Review
Browse files

Merge "Postpone loading the TFLite model until a supported event is recorded."

parents 30bf0cde bd66e620
Loading
Loading
Loading
Loading
+2 −0
Original line number Original line Diff line number Diff line
@@ -19,6 +19,7 @@
#include <cstdint>
#include <cstdint>
#include <memory>
#include <memory>
#include <mutex>
#include <mutex>
#include <string>
#include <unordered_map>
#include <unordered_map>


#include <android-base/thread_annotations.h>
#include <android-base/thread_annotations.h>
@@ -73,6 +74,7 @@ public:


private:
private:
    const nsecs_t mPredictionTimestampOffsetNanos;
    const nsecs_t mPredictionTimestampOffsetNanos;
    const std::string mModelPath;
    const std::function<bool()> mCheckMotionPredictionEnabled;
    const std::function<bool()> mCheckMotionPredictionEnabled;


    std::unique_ptr<TfLiteMotionPredictorModel> mModel;
    std::unique_ptr<TfLiteMotionPredictorModel> mModel;
+8 −3
Original line number Original line Diff line number Diff line
@@ -65,9 +65,8 @@ TfLiteMotionPredictorSample::Point convertPrediction(
MotionPredictor::MotionPredictor(nsecs_t predictionTimestampOffsetNanos, const char* modelPath,
MotionPredictor::MotionPredictor(nsecs_t predictionTimestampOffsetNanos, const char* modelPath,
                                 std::function<bool()> checkMotionPredictionEnabled)
                                 std::function<bool()> checkMotionPredictionEnabled)
      : mPredictionTimestampOffsetNanos(predictionTimestampOffsetNanos),
      : mPredictionTimestampOffsetNanos(predictionTimestampOffsetNanos),
        mCheckMotionPredictionEnabled(std::move(checkMotionPredictionEnabled)),
        mModelPath(modelPath == nullptr ? DEFAULT_MODEL_PATH : modelPath),
        mModel(TfLiteMotionPredictorModel::create(modelPath == nullptr ? DEFAULT_MODEL_PATH
        mCheckMotionPredictionEnabled(std::move(checkMotionPredictionEnabled)) {}
                                                                       : modelPath)) {}


void MotionPredictor::record(const MotionEvent& event) {
void MotionPredictor::record(const MotionEvent& event) {
    if (!isPredictionAvailable(event.getDeviceId(), event.getSource())) {
    if (!isPredictionAvailable(event.getDeviceId(), event.getSource())) {
@@ -76,6 +75,11 @@ void MotionPredictor::record(const MotionEvent& event) {
        return;
        return;
    }
    }


    // Initialise the model now that it's likely to be used.
    if (!mModel) {
        mModel = TfLiteMotionPredictorModel::create(mModelPath.c_str());
    }

    TfLiteMotionPredictorBuffers& buffers =
    TfLiteMotionPredictorBuffers& buffers =
            mDeviceBuffers.try_emplace(event.getDeviceId(), mModel->inputLength()).first->second;
            mDeviceBuffers.try_emplace(event.getDeviceId(), mModel->inputLength()).first->second;


@@ -130,6 +134,7 @@ std::vector<std::unique_ptr<MotionEvent>> MotionPredictor::predict(nsecs_t times
            continue;
            continue;
        }
        }


        LOG_ALWAYS_FATAL_IF(!mModel);
        buffer.copyTo(*mModel);
        buffer.copyTo(*mModel);
        LOG_ALWAYS_FATAL_IF(!mModel->invoke());
        LOG_ALWAYS_FATAL_IF(!mModel->invoke());