Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 5e0e7cf9 authored by Derek Wu's avatar Derek Wu
Browse files

Add smoothing to jerk calculations and updated jerk thresholds.

Test: atest libinput_tests
Test: atest CtsInputTestCases
Test: atest MotionPredictorBenchmark MotionPredictorTest
Test: Using stylus in a drawing app and seeing the jerk logs.
Bug: 266747654
Bug: 353161308
Flag: com.android.input.flags.enable_prediction_pruning_via_jerk_thresholding
Change-Id: I3d6c47d94d66e5ff2b33474acbca72daca051242
parent cda4744d
Loading
Loading
Loading
Loading
+5 −2
Original line number Diff line number Diff line
@@ -35,7 +35,10 @@

    The jerk thresholds are based on normalized dt = 1 calculations.
  -->
  <low-jerk>1.0</low-jerk>
  <high-jerk>1.1</high-jerk>
  <low-jerk>1.5</low-jerk>
  <high-jerk>2.0</high-jerk>

  <!-- The forget factor in the first-order IIR filter for jerk smoothing -->
  <jerk-forget-factor>0.25</jerk-forget-factor>
</motion-predictor>
+13 −0
Original line number Diff line number Diff line
@@ -56,12 +56,20 @@ public:
    // acceleration) and has the units of d^3p/dt^3.
    std::optional<float> jerkMagnitude() const;

    // forgetFactor is the coefficient of the first-order IIR filter for jerk. A factor of 1 results
    // in no smoothing.
    void setForgetFactor(float forgetFactor);
    float getForgetFactor() const;

private:
    const bool mNormalizedDt;
    // Coefficient of first-order IIR filter to smooth jerk calculation.
    float mForgetFactor = 1;

    RingBuffer<int64_t> mTimestamps{4};
    std::array<float, 4> mXDerivatives{}; // [x, x', x'', x''']
    std::array<float, 4> mYDerivatives{}; // [y, y', y'', y''']
    float mJerkMagnitude;
};

/**
@@ -116,6 +124,11 @@ public:

    bool isPredictionAvailable(int32_t deviceId, int32_t source);

    /**
     * Currently used to expose config constants in testing.
     */
    const TfLiteMotionPredictorModel::Config& getModelConfig();

private:
    const nsecs_t mPredictionTimestampOffsetNanos;
    const std::function<bool()> mCheckMotionPredictionEnabled;
+3 −0
Original line number Diff line number Diff line
@@ -110,6 +110,9 @@ public:
        // High jerk means more predictions will be pruned, vice versa for low.
        float lowJerk = 0;
        float highJerk = 0;

        // Coefficient for the first-order IIR filter for jerk calculation.
        float jerkForgetFactor = 1;
    };

    // Creates a model from an encoded Flatbuffer model.
+31 −1
Original line number Diff line number Diff line
@@ -75,6 +75,9 @@ float normalizeRange(float x, float min, float max) {
JerkTracker::JerkTracker(bool normalizedDt) : mNormalizedDt(normalizedDt) {}

void JerkTracker::pushSample(int64_t timestamp, float xPos, float yPos) {
    // If we previously had full samples, we have a previous jerk calculation
    // to do weighted smoothing.
    const bool applySmoothing = mTimestamps.size() == mTimestamps.capacity();
    mTimestamps.pushBack(timestamp);
    const int numSamples = mTimestamps.size();

@@ -115,6 +118,16 @@ void JerkTracker::pushSample(int64_t timestamp, float xPos, float yPos) {
        }
    }

    if (numSamples == static_cast<int>(mTimestamps.capacity())) {
        float newJerkMagnitude = std::hypot(newXDerivatives[3], newYDerivatives[3]);
        ALOGD_IF(isDebug(), "raw jerk: %f", newJerkMagnitude);
        if (applySmoothing) {
            mJerkMagnitude = mJerkMagnitude + (mForgetFactor * (newJerkMagnitude - mJerkMagnitude));
        } else {
            mJerkMagnitude = newJerkMagnitude;
        }
    }

    std::swap(newXDerivatives, mXDerivatives);
    std::swap(newYDerivatives, mYDerivatives);
}
@@ -125,11 +138,19 @@ void JerkTracker::reset() {

std::optional<float> JerkTracker::jerkMagnitude() const {
    if (mTimestamps.size() == mTimestamps.capacity()) {
        return std::hypot(mXDerivatives[3], mYDerivatives[3]);
        return mJerkMagnitude;
    }
    return std::nullopt;
}

void JerkTracker::setForgetFactor(float forgetFactor) {
    mForgetFactor = forgetFactor;
}

float JerkTracker::getForgetFactor() const {
    return mForgetFactor;
}

// --- MotionPredictor ---

MotionPredictor::MotionPredictor(nsecs_t predictionTimestampOffsetNanos,
@@ -159,6 +180,7 @@ android::base::Result<void> MotionPredictor::record(const MotionEvent& event) {
    if (!mModel) {
        mModel = TfLiteMotionPredictorModel::create();
        LOG_ALWAYS_FATAL_IF(!mModel);
        mJerkTracker.setForgetFactor(mModel->config().jerkForgetFactor);
    }

    if (!mBuffers) {
@@ -357,4 +379,12 @@ bool MotionPredictor::isPredictionAvailable(int32_t /*deviceId*/, int32_t source
    return true;
}

const TfLiteMotionPredictorModel::Config& MotionPredictor::getModelConfig() {
    if (!mModel) {
        mModel = TfLiteMotionPredictorModel::create();
        LOG_ALWAYS_FATAL_IF(!mModel);
    }
    return mModel->config();
}

} // namespace android
+1 −0
Original line number Diff line number Diff line
@@ -283,6 +283,7 @@ std::unique_ptr<TfLiteMotionPredictorModel> TfLiteMotionPredictorModel::create()
            .distanceNoiseFloor = parseXMLFloat(*configRoot, "distance-noise-floor"),
            .lowJerk = parseXMLFloat(*configRoot, "low-jerk"),
            .highJerk = parseXMLFloat(*configRoot, "high-jerk"),
            .jerkForgetFactor = parseXMLFloat(*configRoot, "jerk-forget-factor"),
    };

    return std::unique_ptr<TfLiteMotionPredictorModel>(
Loading