Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit c7f61812 authored by Derek Wu's avatar Derek Wu Committed by Android (Google) Code Review
Browse files

Merge "Add jerk thresholded pruning." into main

parents cf0a70f6 aaa47312
Loading
Loading
Loading
Loading
+8 −0
Original line number Diff line number Diff line
@@ -31,5 +31,13 @@
       the UX issue mentioned above.
  -->
  <distance-noise-floor>0.2</distance-noise-floor>
  <!-- The low and high jerk thresholds for prediction pruning.

    The jerk thresholds are based on normalized dt = 1 calculations, and
    are taken from Jetpacks MotionEventPredictor's KalmanPredictor
    implementation (using its ACCURATE_LOW_JANK and ACCURATE_HIGH_JANK).
  -->
  <low-jerk>0.1</low-jerk>
  <high-jerk>0.2</high-jerk>
</motion-predictor>
+5 −0
Original line number Diff line number Diff line
@@ -105,6 +105,11 @@ public:
        // The noise floor for predictions.
        // Distances (r) less than this should be discarded as noise.
        float distanceNoiseFloor = 0;

        // Low and high jerk thresholds (with normalized dt = 1) for predictions.
        // High jerk means more predictions will be pruned, vice versa for low.
        float lowJerk = 0;
        float highJerk = 0;
    };

    // Creates a model from an encoded Flatbuffer model.
+20 −4
Original line number Diff line number Diff line
@@ -18,6 +18,7 @@

#include <input/MotionPredictor.h>

#include <algorithm>
#include <array>
#include <cinttypes>
#include <cmath>
@@ -62,6 +63,11 @@ TfLiteMotionPredictorSample::Point convertPrediction(
    return {.x = axisTo.x + x_delta, .y = axisTo.y + y_delta};
}

float normalizeRange(float x, float min, float max) {
    const float normalized = (x - min) / (max - min);
    return std::min(1.0f, std::max(0.0f, normalized));
}

} // namespace

// --- JerkTracker ---
@@ -255,6 +261,17 @@ std::unique_ptr<MotionEvent> MotionPredictor::predict(nsecs_t timestamp) {
    int64_t predictionTime = mBuffers->lastTimestamp();
    const int64_t futureTime = timestamp + mPredictionTimestampOffsetNanos;

    const float jerkMagnitude = mJerkTracker.jerkMagnitude().value_or(0);
    const float fractionKept =
            1 - normalizeRange(jerkMagnitude, mModel->config().lowJerk, mModel->config().highJerk);
    // float to ensure proper division below.
    const float predictionTimeWindow = futureTime - predictionTime;
    const int maxNumPredictions = static_cast<int>(
            std::ceil(predictionTimeWindow / mModel->config().predictionInterval * fractionKept));
    ALOGD_IF(isDebug(),
             "jerk (d^3p/normalizedDt^3): %f, fraction of prediction window pruned: %f, max number "
             "of predictions: %d",
             jerkMagnitude, 1 - fractionKept, maxNumPredictions);
    for (size_t i = 0; i < static_cast<size_t>(predictedR.size()) && predictionTime <= futureTime;
         ++i) {
        if (predictedR[i] < mModel->config().distanceNoiseFloor) {
@@ -269,13 +286,12 @@ std::unique_ptr<MotionEvent> MotionPredictor::predict(nsecs_t timestamp) {
            break;
        }
        if (input_flags::enable_prediction_pruning_via_jerk_thresholding()) {
            // TODO(b/266747654): Stop predictions if confidence is < some threshold
            // Arbitrarily high pruning index, will correct once jerk thresholding is implemented.
            const size_t upperBoundPredictionIndex = std::numeric_limits<size_t>::max();
            if (i > upperBoundPredictionIndex) {
            if (i >= static_cast<size_t>(maxNumPredictions)) {
                break;
            }
        }
        // TODO(b/266747654): Stop predictions if confidence is < some
        // threshold. Currently predictions are pruned via jerk thresholding.

        const TfLiteMotionPredictorSample::Point predictedPoint =
                convertPrediction(axisFrom, axisTo, predictedR[i], predictedPhi[i]);
+2 −0
Original line number Diff line number Diff line
@@ -281,6 +281,8 @@ std::unique_ptr<TfLiteMotionPredictorModel> TfLiteMotionPredictorModel::create()
    Config config{
            .predictionInterval = parseXMLInt64(*configRoot, "prediction-interval"),
            .distanceNoiseFloor = parseXMLFloat(*configRoot, "distance-noise-floor"),
            .lowJerk = parseXMLFloat(*configRoot, "low-jerk"),
            .highJerk = parseXMLFloat(*configRoot, "high-jerk"),
    };

    return std::unique_ptr<TfLiteMotionPredictorModel>(
+1 −0
Original line number Diff line number Diff line
@@ -36,6 +36,7 @@ cc_test {
        "tensorflow_headers",
    ],
    static_libs: [
        "libflagtest",
        "libgmock",
        "libgui_window_info_static",
        "libinput",
Loading