Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 93f2b5ee authored by Derek Wu's avatar Derek Wu Committed by Android (Google) Code Review
Browse files

Merge "Refactor JerkTracker and MotionPredictor for better testing." into main

parents 2526c331 cc6aec59
Loading
Loading
Loading
Loading
+3 −2
Original line number Diff line number Diff line
@@ -38,7 +38,8 @@
  <low-jerk>1.5</low-jerk>
  <high-jerk>2.0</high-jerk>

  <!-- The forget factor in the first-order IIR filter for jerk smoothing -->
  <jerk-forget-factor>0.25</jerk-forget-factor>
  <!-- The alpha in the first-order IIR filter for jerk smoothing. An alpha
       of 1 results in no smoothing.-->
  <jerk-alpha>0.25</jerk-alpha>
</motion-predictor>
+12 −18
Original line number Diff line number Diff line
@@ -43,7 +43,9 @@ static inline bool isMotionPredictionEnabled() {
class JerkTracker {
public:
    // Initialize the tracker. If normalizedDt is true, assume that each sample pushed has dt=1.
    JerkTracker(bool normalizedDt);
    // alpha is the coefficient of the first-order IIR filter for jerk. A factor of 1 results
    // in no smoothing.
    JerkTracker(bool normalizedDt, float alpha);

    // Add a position to the tracker and update derivative estimates.
    void pushSample(int64_t timestamp, float xPos, float yPos);
@@ -56,15 +58,10 @@ public:
    // acceleration) and has the units of d^3p/dt^3.
    std::optional<float> jerkMagnitude() const;

    // forgetFactor is the coefficient of the first-order IIR filter for jerk. A factor of 1 results
    // in no smoothing.
    void setForgetFactor(float forgetFactor);
    float getForgetFactor() const;

private:
    const bool mNormalizedDt;
    // Coefficient of first-order IIR filter to smooth jerk calculation.
    float mForgetFactor = 1;
    const float mAlpha;

    RingBuffer<int64_t> mTimestamps{4};
    std::array<float, 4> mXDerivatives{}; // [x, x', x'', x''']
@@ -124,11 +121,6 @@ public:

    bool isPredictionAvailable(int32_t deviceId, int32_t source);

    /**
     * Currently used to expose config constants in testing.
     */
    const TfLiteMotionPredictorModel::Config& getModelConfig();

private:
    const nsecs_t mPredictionTimestampOffsetNanos;
    const std::function<bool()> mCheckMotionPredictionEnabled;
@@ -137,15 +129,17 @@ private:

    std::unique_ptr<TfLiteMotionPredictorBuffers> mBuffers;
    std::optional<MotionEvent> mLastEvent;
    // mJerkTracker assumes normalized dt = 1 between recorded samples because
    // the underlying mModel input also assumes fixed-interval samples.
    // Normalized dt as 1 is also used to correspond with the similar Jank
    // implementation from the JetPack MotionPredictor implementation.
    JerkTracker mJerkTracker{true};

    std::optional<MotionPredictorMetricsManager> mMetricsManager;
    std::unique_ptr<JerkTracker> mJerkTracker;

    std::unique_ptr<MotionPredictorMetricsManager> mMetricsManager;

    const ReportAtomFunction mReportAtomFunction;

    // Initialize prediction model and associated objects.
    // Called during lazy initialization.
    // TODO: b/210158587 Consider removing lazy initialization.
    void initializeObjects();
};

} // namespace android
+1 −1
Original line number Diff line number Diff line
@@ -112,7 +112,7 @@ public:
        float highJerk = 0;

        // Coefficient for the first-order IIR filter for jerk calculation.
        float jerkForgetFactor = 1;
        float jerkAlpha = 1;
    };

    // Creates a model from an encoded Flatbuffer model.
+27 −35
Original line number Diff line number Diff line
@@ -72,7 +72,8 @@ float normalizeRange(float x, float min, float max) {

// --- JerkTracker ---

JerkTracker::JerkTracker(bool normalizedDt) : mNormalizedDt(normalizedDt) {}
JerkTracker::JerkTracker(bool normalizedDt, float alpha)
      : mNormalizedDt(normalizedDt), mAlpha(alpha) {}

void JerkTracker::pushSample(int64_t timestamp, float xPos, float yPos) {
    // If we previously had full samples, we have a previous jerk calculation
@@ -122,7 +123,7 @@ void JerkTracker::pushSample(int64_t timestamp, float xPos, float yPos) {
        float newJerkMagnitude = std::hypot(newXDerivatives[3], newYDerivatives[3]);
        ALOGD_IF(isDebug(), "raw jerk: %f", newJerkMagnitude);
        if (applySmoothing) {
            mJerkMagnitude = mJerkMagnitude + (mForgetFactor * (newJerkMagnitude - mJerkMagnitude));
            mJerkMagnitude = mJerkMagnitude + (mAlpha * (newJerkMagnitude - mJerkMagnitude));
        } else {
            mJerkMagnitude = newJerkMagnitude;
        }
@@ -143,14 +144,6 @@ std::optional<float> JerkTracker::jerkMagnitude() const {
    return std::nullopt;
}

void JerkTracker::setForgetFactor(float forgetFactor) {
    mForgetFactor = forgetFactor;
}

float JerkTracker::getForgetFactor() const {
    return mForgetFactor;
}

// --- MotionPredictor ---

MotionPredictor::MotionPredictor(nsecs_t predictionTimestampOffsetNanos,
@@ -160,6 +153,24 @@ MotionPredictor::MotionPredictor(nsecs_t predictionTimestampOffsetNanos,
        mCheckMotionPredictionEnabled(std::move(checkMotionPredictionEnabled)),
        mReportAtomFunction(reportAtomFunction) {}

void MotionPredictor::initializeObjects() {
    mModel = TfLiteMotionPredictorModel::create();
    LOG_ALWAYS_FATAL_IF(!mModel);

    // mJerkTracker assumes normalized dt = 1 between recorded samples because
    // the underlying mModel input also assumes fixed-interval samples.
    // Normalized dt as 1 is also used to correspond with the similar Jank
    // implementation from the JetPack MotionPredictor implementation.
    mJerkTracker = std::make_unique<JerkTracker>(/*normalizedDt=*/true, mModel->config().jerkAlpha);

    mBuffers = std::make_unique<TfLiteMotionPredictorBuffers>(mModel->inputLength());

    mMetricsManager =
            std::make_unique<MotionPredictorMetricsManager>(mModel->config().predictionInterval,
                                                            mModel->outputLength(),
                                                            mReportAtomFunction);
}

android::base::Result<void> MotionPredictor::record(const MotionEvent& event) {
    if (mLastEvent && mLastEvent->getDeviceId() != event.getDeviceId()) {
        // We still have an active gesture for another device. The provided MotionEvent is not
@@ -176,29 +187,18 @@ android::base::Result<void> MotionPredictor::record(const MotionEvent& event) {
        return {};
    }

    // Initialise the model now that it's likely to be used.
    if (!mModel) {
        mModel = TfLiteMotionPredictorModel::create();
        LOG_ALWAYS_FATAL_IF(!mModel);
        mJerkTracker.setForgetFactor(mModel->config().jerkForgetFactor);
    }

    if (!mBuffers) {
        mBuffers = std::make_unique<TfLiteMotionPredictorBuffers>(mModel->inputLength());
        initializeObjects();
    }

    // Pass input event to the MetricsManager.
    if (!mMetricsManager) {
        mMetricsManager.emplace(mModel->config().predictionInterval, mModel->outputLength(),
                                mReportAtomFunction);
    }
    mMetricsManager->onRecord(event);

    const int32_t action = event.getActionMasked();
    if (action == AMOTION_EVENT_ACTION_UP || action == AMOTION_EVENT_ACTION_CANCEL) {
        ALOGD_IF(isDebug(), "End of event stream");
        mBuffers->reset();
        mJerkTracker.reset();
        mJerkTracker->reset();
        mLastEvent.reset();
        return {};
    } else if (action != AMOTION_EVENT_ACTION_DOWN && action != AMOTION_EVENT_ACTION_MOVE) {
@@ -233,7 +233,7 @@ android::base::Result<void> MotionPredictor::record(const MotionEvent& event) {
                                                                          0, i),
                                     .orientation = event.getHistoricalOrientation(0, i),
                             });
        mJerkTracker.pushSample(event.getHistoricalEventTime(i),
        mJerkTracker->pushSample(event.getHistoricalEventTime(i),
                                 coords->getAxisValue(AMOTION_EVENT_AXIS_X),
                                 coords->getAxisValue(AMOTION_EVENT_AXIS_Y));
    }
@@ -283,7 +283,7 @@ std::unique_ptr<MotionEvent> MotionPredictor::predict(nsecs_t timestamp) {
    int64_t predictionTime = mBuffers->lastTimestamp();
    const int64_t futureTime = timestamp + mPredictionTimestampOffsetNanos;

    const float jerkMagnitude = mJerkTracker.jerkMagnitude().value_or(0);
    const float jerkMagnitude = mJerkTracker->jerkMagnitude().value_or(0);
    const float fractionKept =
            1 - normalizeRange(jerkMagnitude, mModel->config().lowJerk, mModel->config().highJerk);
    // float to ensure proper division below.
@@ -379,12 +379,4 @@ bool MotionPredictor::isPredictionAvailable(int32_t /*deviceId*/, int32_t source
    return true;
}

const TfLiteMotionPredictorModel::Config& MotionPredictor::getModelConfig() {
    if (!mModel) {
        mModel = TfLiteMotionPredictorModel::create();
        LOG_ALWAYS_FATAL_IF(!mModel);
    }
    return mModel->config();
}

} // namespace android
+1 −1
Original line number Diff line number Diff line
@@ -283,7 +283,7 @@ std::unique_ptr<TfLiteMotionPredictorModel> TfLiteMotionPredictorModel::create()
            .distanceNoiseFloor = parseXMLFloat(*configRoot, "distance-noise-floor"),
            .lowJerk = parseXMLFloat(*configRoot, "low-jerk"),
            .highJerk = parseXMLFloat(*configRoot, "high-jerk"),
            .jerkForgetFactor = parseXMLFloat(*configRoot, "jerk-forget-factor"),
            .jerkAlpha = parseXMLFloat(*configRoot, "jerk-alpha"),
    };

    return std::unique_ptr<TfLiteMotionPredictorModel>(
Loading