Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit cc6aec59 authored by Derek Wu's avatar Derek Wu
Browse files

Refactor JerkTracker and MotionPredictor for better testing.

Changes include renaming forgetFactor to alpha.

Test: atest libinput_tests
Bug: 266747654
Bug: 353161308
Flag: com.android.input.flags.enable_prediction_pruning_via_jerk_thresholding
Change-Id: Icd056d36a3d7894c6c9b4b957233002ad961a9a1
parent 4b84e83d
Loading
Loading
Loading
Loading
+3 −2
Original line number Original line Diff line number Diff line
@@ -38,7 +38,8 @@
  <low-jerk>1.5</low-jerk>
  <low-jerk>1.5</low-jerk>
  <high-jerk>2.0</high-jerk>
  <high-jerk>2.0</high-jerk>


  <!-- The forget factor in the first-order IIR filter for jerk smoothing -->
  <!-- The alpha in the first-order IIR filter for jerk smoothing. An alpha
  <jerk-forget-factor>0.25</jerk-forget-factor>
       of 1 results in no smoothing.-->
  <jerk-alpha>0.25</jerk-alpha>
</motion-predictor>
</motion-predictor>
+12 −18
Original line number Original line Diff line number Diff line
@@ -43,7 +43,9 @@ static inline bool isMotionPredictionEnabled() {
class JerkTracker {
class JerkTracker {
public:
public:
    // Initialize the tracker. If normalizedDt is true, assume that each sample pushed has dt=1.
    // Initialize the tracker. If normalizedDt is true, assume that each sample pushed has dt=1.
    JerkTracker(bool normalizedDt);
    // alpha is the coefficient of the first-order IIR filter for jerk. A factor of 1 results
    // in no smoothing.
    JerkTracker(bool normalizedDt, float alpha);


    // Add a position to the tracker and update derivative estimates.
    // Add a position to the tracker and update derivative estimates.
    void pushSample(int64_t timestamp, float xPos, float yPos);
    void pushSample(int64_t timestamp, float xPos, float yPos);
@@ -56,15 +58,10 @@ public:
    // acceleration) and has the units of d^3p/dt^3.
    // acceleration) and has the units of d^3p/dt^3.
    std::optional<float> jerkMagnitude() const;
    std::optional<float> jerkMagnitude() const;


    // forgetFactor is the coefficient of the first-order IIR filter for jerk. A factor of 1 results
    // in no smoothing.
    void setForgetFactor(float forgetFactor);
    float getForgetFactor() const;

private:
private:
    const bool mNormalizedDt;
    const bool mNormalizedDt;
    // Coefficient of first-order IIR filter to smooth jerk calculation.
    // Coefficient of first-order IIR filter to smooth jerk calculation.
    float mForgetFactor = 1;
    const float mAlpha;


    RingBuffer<int64_t> mTimestamps{4};
    RingBuffer<int64_t> mTimestamps{4};
    std::array<float, 4> mXDerivatives{}; // [x, x', x'', x''']
    std::array<float, 4> mXDerivatives{}; // [x, x', x'', x''']
@@ -124,11 +121,6 @@ public:


    bool isPredictionAvailable(int32_t deviceId, int32_t source);
    bool isPredictionAvailable(int32_t deviceId, int32_t source);


    /**
     * Currently used to expose config constants in testing.
     */
    const TfLiteMotionPredictorModel::Config& getModelConfig();

private:
private:
    const nsecs_t mPredictionTimestampOffsetNanos;
    const nsecs_t mPredictionTimestampOffsetNanos;
    const std::function<bool()> mCheckMotionPredictionEnabled;
    const std::function<bool()> mCheckMotionPredictionEnabled;
@@ -137,15 +129,17 @@ private:


    std::unique_ptr<TfLiteMotionPredictorBuffers> mBuffers;
    std::unique_ptr<TfLiteMotionPredictorBuffers> mBuffers;
    std::optional<MotionEvent> mLastEvent;
    std::optional<MotionEvent> mLastEvent;
    // mJerkTracker assumes normalized dt = 1 between recorded samples because
    // the underlying mModel input also assumes fixed-interval samples.
    // Normalized dt as 1 is also used to correspond with the similar Jank
    // implementation from the JetPack MotionPredictor implementation.
    JerkTracker mJerkTracker{true};


    std::optional<MotionPredictorMetricsManager> mMetricsManager;
    std::unique_ptr<JerkTracker> mJerkTracker;

    std::unique_ptr<MotionPredictorMetricsManager> mMetricsManager;


    const ReportAtomFunction mReportAtomFunction;
    const ReportAtomFunction mReportAtomFunction;

    // Initialize prediction model and associated objects.
    // Called during lazy initialization.
    // TODO: b/210158587 Consider removing lazy initialization.
    void initializeObjects();
};
};


} // namespace android
} // namespace android
+1 −1
Original line number Original line Diff line number Diff line
@@ -112,7 +112,7 @@ public:
        float highJerk = 0;
        float highJerk = 0;


        // Coefficient for the first-order IIR filter for jerk calculation.
        // Coefficient for the first-order IIR filter for jerk calculation.
        float jerkForgetFactor = 1;
        float jerkAlpha = 1;
    };
    };


    // Creates a model from an encoded Flatbuffer model.
    // Creates a model from an encoded Flatbuffer model.
+27 −35
Original line number Original line Diff line number Diff line
@@ -72,7 +72,8 @@ float normalizeRange(float x, float min, float max) {


// --- JerkTracker ---
// --- JerkTracker ---


JerkTracker::JerkTracker(bool normalizedDt) : mNormalizedDt(normalizedDt) {}
JerkTracker::JerkTracker(bool normalizedDt, float alpha)
      : mNormalizedDt(normalizedDt), mAlpha(alpha) {}


void JerkTracker::pushSample(int64_t timestamp, float xPos, float yPos) {
void JerkTracker::pushSample(int64_t timestamp, float xPos, float yPos) {
    // If we previously had full samples, we have a previous jerk calculation
    // If we previously had full samples, we have a previous jerk calculation
@@ -122,7 +123,7 @@ void JerkTracker::pushSample(int64_t timestamp, float xPos, float yPos) {
        float newJerkMagnitude = std::hypot(newXDerivatives[3], newYDerivatives[3]);
        float newJerkMagnitude = std::hypot(newXDerivatives[3], newYDerivatives[3]);
        ALOGD_IF(isDebug(), "raw jerk: %f", newJerkMagnitude);
        ALOGD_IF(isDebug(), "raw jerk: %f", newJerkMagnitude);
        if (applySmoothing) {
        if (applySmoothing) {
            mJerkMagnitude = mJerkMagnitude + (mForgetFactor * (newJerkMagnitude - mJerkMagnitude));
            mJerkMagnitude = mJerkMagnitude + (mAlpha * (newJerkMagnitude - mJerkMagnitude));
        } else {
        } else {
            mJerkMagnitude = newJerkMagnitude;
            mJerkMagnitude = newJerkMagnitude;
        }
        }
@@ -143,14 +144,6 @@ std::optional<float> JerkTracker::jerkMagnitude() const {
    return std::nullopt;
    return std::nullopt;
}
}


void JerkTracker::setForgetFactor(float forgetFactor) {
    mForgetFactor = forgetFactor;
}

float JerkTracker::getForgetFactor() const {
    return mForgetFactor;
}

// --- MotionPredictor ---
// --- MotionPredictor ---


MotionPredictor::MotionPredictor(nsecs_t predictionTimestampOffsetNanos,
MotionPredictor::MotionPredictor(nsecs_t predictionTimestampOffsetNanos,
@@ -160,6 +153,24 @@ MotionPredictor::MotionPredictor(nsecs_t predictionTimestampOffsetNanos,
        mCheckMotionPredictionEnabled(std::move(checkMotionPredictionEnabled)),
        mCheckMotionPredictionEnabled(std::move(checkMotionPredictionEnabled)),
        mReportAtomFunction(reportAtomFunction) {}
        mReportAtomFunction(reportAtomFunction) {}


void MotionPredictor::initializeObjects() {
    mModel = TfLiteMotionPredictorModel::create();
    LOG_ALWAYS_FATAL_IF(!mModel);

    // mJerkTracker assumes normalized dt = 1 between recorded samples because
    // the underlying mModel input also assumes fixed-interval samples.
    // Normalized dt as 1 is also used to correspond with the similar Jank
    // implementation from the JetPack MotionPredictor implementation.
    mJerkTracker = std::make_unique<JerkTracker>(/*normalizedDt=*/true, mModel->config().jerkAlpha);

    mBuffers = std::make_unique<TfLiteMotionPredictorBuffers>(mModel->inputLength());

    mMetricsManager =
            std::make_unique<MotionPredictorMetricsManager>(mModel->config().predictionInterval,
                                                            mModel->outputLength(),
                                                            mReportAtomFunction);
}

android::base::Result<void> MotionPredictor::record(const MotionEvent& event) {
android::base::Result<void> MotionPredictor::record(const MotionEvent& event) {
    if (mLastEvent && mLastEvent->getDeviceId() != event.getDeviceId()) {
    if (mLastEvent && mLastEvent->getDeviceId() != event.getDeviceId()) {
        // We still have an active gesture for another device. The provided MotionEvent is not
        // We still have an active gesture for another device. The provided MotionEvent is not
@@ -176,29 +187,18 @@ android::base::Result<void> MotionPredictor::record(const MotionEvent& event) {
        return {};
        return {};
    }
    }


    // Initialise the model now that it's likely to be used.
    if (!mModel) {
    if (!mModel) {
        mModel = TfLiteMotionPredictorModel::create();
        initializeObjects();
        LOG_ALWAYS_FATAL_IF(!mModel);
        mJerkTracker.setForgetFactor(mModel->config().jerkForgetFactor);
    }

    if (!mBuffers) {
        mBuffers = std::make_unique<TfLiteMotionPredictorBuffers>(mModel->inputLength());
    }
    }


    // Pass input event to the MetricsManager.
    // Pass input event to the MetricsManager.
    if (!mMetricsManager) {
        mMetricsManager.emplace(mModel->config().predictionInterval, mModel->outputLength(),
                                mReportAtomFunction);
    }
    mMetricsManager->onRecord(event);
    mMetricsManager->onRecord(event);


    const int32_t action = event.getActionMasked();
    const int32_t action = event.getActionMasked();
    if (action == AMOTION_EVENT_ACTION_UP || action == AMOTION_EVENT_ACTION_CANCEL) {
    if (action == AMOTION_EVENT_ACTION_UP || action == AMOTION_EVENT_ACTION_CANCEL) {
        ALOGD_IF(isDebug(), "End of event stream");
        ALOGD_IF(isDebug(), "End of event stream");
        mBuffers->reset();
        mBuffers->reset();
        mJerkTracker.reset();
        mJerkTracker->reset();
        mLastEvent.reset();
        mLastEvent.reset();
        return {};
        return {};
    } else if (action != AMOTION_EVENT_ACTION_DOWN && action != AMOTION_EVENT_ACTION_MOVE) {
    } else if (action != AMOTION_EVENT_ACTION_DOWN && action != AMOTION_EVENT_ACTION_MOVE) {
@@ -233,7 +233,7 @@ android::base::Result<void> MotionPredictor::record(const MotionEvent& event) {
                                                                          0, i),
                                                                          0, i),
                                     .orientation = event.getHistoricalOrientation(0, i),
                                     .orientation = event.getHistoricalOrientation(0, i),
                             });
                             });
        mJerkTracker.pushSample(event.getHistoricalEventTime(i),
        mJerkTracker->pushSample(event.getHistoricalEventTime(i),
                                 coords->getAxisValue(AMOTION_EVENT_AXIS_X),
                                 coords->getAxisValue(AMOTION_EVENT_AXIS_X),
                                 coords->getAxisValue(AMOTION_EVENT_AXIS_Y));
                                 coords->getAxisValue(AMOTION_EVENT_AXIS_Y));
    }
    }
@@ -283,7 +283,7 @@ std::unique_ptr<MotionEvent> MotionPredictor::predict(nsecs_t timestamp) {
    int64_t predictionTime = mBuffers->lastTimestamp();
    int64_t predictionTime = mBuffers->lastTimestamp();
    const int64_t futureTime = timestamp + mPredictionTimestampOffsetNanos;
    const int64_t futureTime = timestamp + mPredictionTimestampOffsetNanos;


    const float jerkMagnitude = mJerkTracker.jerkMagnitude().value_or(0);
    const float jerkMagnitude = mJerkTracker->jerkMagnitude().value_or(0);
    const float fractionKept =
    const float fractionKept =
            1 - normalizeRange(jerkMagnitude, mModel->config().lowJerk, mModel->config().highJerk);
            1 - normalizeRange(jerkMagnitude, mModel->config().lowJerk, mModel->config().highJerk);
    // float to ensure proper division below.
    // float to ensure proper division below.
@@ -379,12 +379,4 @@ bool MotionPredictor::isPredictionAvailable(int32_t /*deviceId*/, int32_t source
    return true;
    return true;
}
}


const TfLiteMotionPredictorModel::Config& MotionPredictor::getModelConfig() {
    if (!mModel) {
        mModel = TfLiteMotionPredictorModel::create();
        LOG_ALWAYS_FATAL_IF(!mModel);
    }
    return mModel->config();
}

} // namespace android
} // namespace android
+1 −1
Original line number Original line Diff line number Diff line
@@ -283,7 +283,7 @@ std::unique_ptr<TfLiteMotionPredictorModel> TfLiteMotionPredictorModel::create()
            .distanceNoiseFloor = parseXMLFloat(*configRoot, "distance-noise-floor"),
            .distanceNoiseFloor = parseXMLFloat(*configRoot, "distance-noise-floor"),
            .lowJerk = parseXMLFloat(*configRoot, "low-jerk"),
            .lowJerk = parseXMLFloat(*configRoot, "low-jerk"),
            .highJerk = parseXMLFloat(*configRoot, "high-jerk"),
            .highJerk = parseXMLFloat(*configRoot, "high-jerk"),
            .jerkForgetFactor = parseXMLFloat(*configRoot, "jerk-forget-factor"),
            .jerkAlpha = parseXMLFloat(*configRoot, "jerk-alpha"),
    };
    };


    return std::unique_ptr<TfLiteMotionPredictorModel>(
    return std::unique_ptr<TfLiteMotionPredictorModel>(
Loading