Loading include/input/MotionPredictor.h +32 −0 Original line number Diff line number Diff line Loading @@ -16,6 +16,7 @@ #pragma once #include <array> #include <cstdint> #include <memory> #include <mutex> Loading @@ -28,6 +29,7 @@ #include <android/sysprop/InputProperties.sysprop.h> #include <input/Input.h> #include <input/MotionPredictorMetricsManager.h> #include <input/RingBuffer.h> #include <input/TfLiteMotionPredictor.h> #include <utils/Timers.h> // for nsecs_t Loading @@ -37,6 +39,31 @@ static inline bool isMotionPredictionEnabled() { return sysprop::InputProperties::enable_motion_prediction().value_or(true); } // Tracker to calculate jerk from motion position samples. class JerkTracker { public: // Initialize the tracker. If normalizedDt is true, assume that each sample pushed has dt=1. JerkTracker(bool normalizedDt); // Add a position to the tracker and update derivative estimates. void pushSample(int64_t timestamp, float xPos, float yPos); // Reset JerkTracker for a new motion input. void reset(); // Return last jerk calculation, if enough samples have been collected. // Jerk is defined as the 3rd derivative of position (change in // acceleration) and has the units of d^3p/dt^3. std::optional<float> jerkMagnitude() const; private: const bool mNormalizedDt; RingBuffer<int64_t> mTimestamps{4}; std::array<float, 4> mXDerivatives{}; // [x, x', x'', x'''] std::array<float, 4> mYDerivatives{}; // [y, y', y'', y'''] }; /** * Given a set of MotionEvents for the current gesture, predict the motion. The returned MotionEvent * contains a set of samples in the future. Loading Loading @@ -97,6 +124,11 @@ private: std::unique_ptr<TfLiteMotionPredictorBuffers> mBuffers; std::optional<MotionEvent> mLastEvent; // mJerkTracker assumes normalized dt = 1 between recorded samples because // the underlying mModel input also assumes fixed-interval samples. // Normalized dt as 1 is also used to correspond with the similar Jank // implementation from the JetPack MotionPredictor implementation. JerkTracker mJerkTracker{true}; std::optional<MotionPredictorMetricsManager> mMetricsManager; Loading libs/input/MotionPredictor.cpp +67 −0 Original line number Diff line number Diff line Loading @@ -18,12 +18,15 @@ #include <input/MotionPredictor.h> #include <array> #include <cinttypes> #include <cmath> #include <cstddef> #include <cstdint> #include <limits> #include <optional> #include <string> #include <utility> #include <vector> #include <android-base/logging.h> Loading Loading @@ -61,6 +64,66 @@ TfLiteMotionPredictorSample::Point convertPrediction( } // namespace // --- JerkTracker --- JerkTracker::JerkTracker(bool normalizedDt) : mNormalizedDt(normalizedDt) {} void JerkTracker::pushSample(int64_t timestamp, float xPos, float yPos) { mTimestamps.pushBack(timestamp); const int numSamples = mTimestamps.size(); std::array<float, 4> newXDerivatives; std::array<float, 4> newYDerivatives; /** * Diagram showing the calculation of higher order derivatives of sample x3 * collected at time=t3. * Terms in parentheses are not stored (and not needed for calculations) * t0 ----- t1 ----- t2 ----- t3 * (x0)-----(x1) ----- x2 ----- x3 * (x'0) --- x'1 --- x'2 * x''0 - x''1 * x'''0 * * In this example: * x'2 = (x3 - x2) / (t3 - t2) * x''1 = (x'2 - x'1) / (t2 - t1) * x'''0 = (x''1 - x''0) / (t1 - t0) * Therefore, timestamp history is needed to calculate higher order derivatives, * compared to just the last calculated derivative sample. * * If mNormalizedDt = true, then dt = 1 and the division is moot. */ for (int i = 0; i < numSamples; ++i) { if (i == 0) { newXDerivatives[i] = xPos; newYDerivatives[i] = yPos; } else { newXDerivatives[i] = newXDerivatives[i - 1] - mXDerivatives[i - 1]; newYDerivatives[i] = newYDerivatives[i - 1] - mYDerivatives[i - 1]; if (!mNormalizedDt) { const float dt = mTimestamps[numSamples - i] - mTimestamps[numSamples - i - 1]; newXDerivatives[i] = newXDerivatives[i] / dt; newYDerivatives[i] = newYDerivatives[i] / dt; } } } std::swap(newXDerivatives, mXDerivatives); std::swap(newYDerivatives, mYDerivatives); } void JerkTracker::reset() { mTimestamps.clear(); } std::optional<float> JerkTracker::jerkMagnitude() const { if (mTimestamps.size() == mTimestamps.capacity()) { return std::hypot(mXDerivatives[3], mYDerivatives[3]); } return std::nullopt; } // --- MotionPredictor --- MotionPredictor::MotionPredictor(nsecs_t predictionTimestampOffsetNanos, Loading Loading @@ -107,6 +170,7 @@ android::base::Result<void> MotionPredictor::record(const MotionEvent& event) { if (action == AMOTION_EVENT_ACTION_UP || action == AMOTION_EVENT_ACTION_CANCEL) { ALOGD_IF(isDebug(), "End of event stream"); mBuffers->reset(); mJerkTracker.reset(); mLastEvent.reset(); return {}; } else if (action != AMOTION_EVENT_ACTION_DOWN && action != AMOTION_EVENT_ACTION_MOVE) { Loading Loading @@ -141,6 +205,9 @@ android::base::Result<void> MotionPredictor::record(const MotionEvent& event) { 0, i), .orientation = event.getHistoricalOrientation(0, i), }); mJerkTracker.pushSample(event.getHistoricalEventTime(i), coords->getAxisValue(AMOTION_EVENT_AXIS_X), coords->getAxisValue(AMOTION_EVENT_AXIS_Y)); } if (!mLastEvent) { Loading libs/input/tests/MotionPredictor_test.cpp +103 −0 Original line number Diff line number Diff line Loading @@ -15,6 +15,7 @@ */ #include <chrono> #include <cmath> #include <gmock/gmock.h> #include <gtest/gtest.h> Loading Loading @@ -65,6 +66,108 @@ static MotionEvent getMotionEvent(int32_t action, float x, float y, return event; } TEST(JerkTrackerTest, JerkReadiness) { JerkTracker jerkTracker(true); EXPECT_FALSE(jerkTracker.jerkMagnitude()); jerkTracker.pushSample(/*timestamp=*/0, 20, 50); EXPECT_FALSE(jerkTracker.jerkMagnitude()); jerkTracker.pushSample(/*timestamp=*/1, 25, 53); EXPECT_FALSE(jerkTracker.jerkMagnitude()); jerkTracker.pushSample(/*timestamp=*/2, 30, 60); EXPECT_FALSE(jerkTracker.jerkMagnitude()); jerkTracker.pushSample(/*timestamp=*/3, 35, 70); EXPECT_TRUE(jerkTracker.jerkMagnitude()); jerkTracker.reset(); EXPECT_FALSE(jerkTracker.jerkMagnitude()); jerkTracker.pushSample(/*timestamp=*/4, 30, 60); EXPECT_FALSE(jerkTracker.jerkMagnitude()); } TEST(JerkTrackerTest, JerkCalculationNormalizedDtTrue) { JerkTracker jerkTracker(true); jerkTracker.pushSample(/*timestamp=*/0, 20, 50); jerkTracker.pushSample(/*timestamp=*/1, 25, 53); jerkTracker.pushSample(/*timestamp=*/2, 30, 60); jerkTracker.pushSample(/*timestamp=*/3, 45, 70); /** * Jerk derivative table * x: 20 25 30 45 * x': 5 5 15 * x'': 0 10 * x''': 10 * * y: 50 53 60 70 * y': 3 7 10 * y'': 4 3 * y''': -1 */ EXPECT_FLOAT_EQ(jerkTracker.jerkMagnitude().value(), std::hypot(10, -1)); jerkTracker.pushSample(/*timestamp=*/4, 20, 65); /** * (continuing from above table) * x: 45 -> 20 * x': 15 -> -25 * x'': 10 -> -40 * x''': -50 * * y: 70 -> 65 * y': 10 -> -5 * y'': 3 -> -15 * y''': -18 */ EXPECT_FLOAT_EQ(jerkTracker.jerkMagnitude().value(), std::hypot(-50, -18)); } TEST(JerkTrackerTest, JerkCalculationNormalizedDtFalse) { JerkTracker jerkTracker(false); jerkTracker.pushSample(/*timestamp=*/0, 20, 50); jerkTracker.pushSample(/*timestamp=*/10, 25, 53); jerkTracker.pushSample(/*timestamp=*/20, 30, 60); jerkTracker.pushSample(/*timestamp=*/30, 45, 70); /** * Jerk derivative table * x: 20 25 30 45 * x': .5 .5 1.5 * x'': 0 .1 * x''': .01 * * y: 50 53 60 70 * y': .3 .7 1 * y'': .04 .03 * y''': -.001 */ EXPECT_FLOAT_EQ(jerkTracker.jerkMagnitude().value(), std::hypot(.01, -.001)); jerkTracker.pushSample(/*timestamp=*/50, 20, 65); /** * (continuing from above table) * x: 45 -> 20 * x': 1.5 -> -1.25 (delta above, divide by 20) * x'': .1 -> -.275 (delta above, divide by 10) * x''': -.0375 (delta above, divide by 10) * * y: 70 -> 65 * y': 1 -> -.25 (delta above, divide by 20) * y'': .03 -> -.125 (delta above, divide by 10) * y''': -.0155 (delta above, divide by 10) */ EXPECT_FLOAT_EQ(jerkTracker.jerkMagnitude().value(), std::hypot(-.0375, -.0155)); } TEST(JerkTrackerTest, JerkCalculationAfterReset) { JerkTracker jerkTracker(true); jerkTracker.pushSample(/*timestamp=*/0, 20, 50); jerkTracker.pushSample(/*timestamp=*/1, 25, 53); jerkTracker.pushSample(/*timestamp=*/2, 30, 60); jerkTracker.pushSample(/*timestamp=*/3, 45, 70); jerkTracker.pushSample(/*timestamp=*/4, 20, 65); jerkTracker.reset(); jerkTracker.pushSample(/*timestamp=*/5, 20, 50); jerkTracker.pushSample(/*timestamp=*/6, 25, 53); jerkTracker.pushSample(/*timestamp=*/7, 30, 60); jerkTracker.pushSample(/*timestamp=*/8, 45, 70); EXPECT_FLOAT_EQ(jerkTracker.jerkMagnitude().value(), std::hypot(10, -1)); } TEST(MotionPredictorTest, IsPredictionAvailable) { MotionPredictor predictor(/*predictionTimestampOffsetNanos=*/0, []() { return true /*enable prediction*/; }); Loading Loading
include/input/MotionPredictor.h +32 −0 Original line number Diff line number Diff line Loading @@ -16,6 +16,7 @@ #pragma once #include <array> #include <cstdint> #include <memory> #include <mutex> Loading @@ -28,6 +29,7 @@ #include <android/sysprop/InputProperties.sysprop.h> #include <input/Input.h> #include <input/MotionPredictorMetricsManager.h> #include <input/RingBuffer.h> #include <input/TfLiteMotionPredictor.h> #include <utils/Timers.h> // for nsecs_t Loading @@ -37,6 +39,31 @@ static inline bool isMotionPredictionEnabled() { return sysprop::InputProperties::enable_motion_prediction().value_or(true); } // Tracker to calculate jerk from motion position samples. class JerkTracker { public: // Initialize the tracker. If normalizedDt is true, assume that each sample pushed has dt=1. JerkTracker(bool normalizedDt); // Add a position to the tracker and update derivative estimates. void pushSample(int64_t timestamp, float xPos, float yPos); // Reset JerkTracker for a new motion input. void reset(); // Return last jerk calculation, if enough samples have been collected. // Jerk is defined as the 3rd derivative of position (change in // acceleration) and has the units of d^3p/dt^3. std::optional<float> jerkMagnitude() const; private: const bool mNormalizedDt; RingBuffer<int64_t> mTimestamps{4}; std::array<float, 4> mXDerivatives{}; // [x, x', x'', x'''] std::array<float, 4> mYDerivatives{}; // [y, y', y'', y'''] }; /** * Given a set of MotionEvents for the current gesture, predict the motion. The returned MotionEvent * contains a set of samples in the future. Loading Loading @@ -97,6 +124,11 @@ private: std::unique_ptr<TfLiteMotionPredictorBuffers> mBuffers; std::optional<MotionEvent> mLastEvent; // mJerkTracker assumes normalized dt = 1 between recorded samples because // the underlying mModel input also assumes fixed-interval samples. // Normalized dt as 1 is also used to correspond with the similar Jank // implementation from the JetPack MotionPredictor implementation. JerkTracker mJerkTracker{true}; std::optional<MotionPredictorMetricsManager> mMetricsManager; Loading
libs/input/MotionPredictor.cpp +67 −0 Original line number Diff line number Diff line Loading @@ -18,12 +18,15 @@ #include <input/MotionPredictor.h> #include <array> #include <cinttypes> #include <cmath> #include <cstddef> #include <cstdint> #include <limits> #include <optional> #include <string> #include <utility> #include <vector> #include <android-base/logging.h> Loading Loading @@ -61,6 +64,66 @@ TfLiteMotionPredictorSample::Point convertPrediction( } // namespace // --- JerkTracker --- JerkTracker::JerkTracker(bool normalizedDt) : mNormalizedDt(normalizedDt) {} void JerkTracker::pushSample(int64_t timestamp, float xPos, float yPos) { mTimestamps.pushBack(timestamp); const int numSamples = mTimestamps.size(); std::array<float, 4> newXDerivatives; std::array<float, 4> newYDerivatives; /** * Diagram showing the calculation of higher order derivatives of sample x3 * collected at time=t3. * Terms in parentheses are not stored (and not needed for calculations) * t0 ----- t1 ----- t2 ----- t3 * (x0)-----(x1) ----- x2 ----- x3 * (x'0) --- x'1 --- x'2 * x''0 - x''1 * x'''0 * * In this example: * x'2 = (x3 - x2) / (t3 - t2) * x''1 = (x'2 - x'1) / (t2 - t1) * x'''0 = (x''1 - x''0) / (t1 - t0) * Therefore, timestamp history is needed to calculate higher order derivatives, * compared to just the last calculated derivative sample. * * If mNormalizedDt = true, then dt = 1 and the division is moot. */ for (int i = 0; i < numSamples; ++i) { if (i == 0) { newXDerivatives[i] = xPos; newYDerivatives[i] = yPos; } else { newXDerivatives[i] = newXDerivatives[i - 1] - mXDerivatives[i - 1]; newYDerivatives[i] = newYDerivatives[i - 1] - mYDerivatives[i - 1]; if (!mNormalizedDt) { const float dt = mTimestamps[numSamples - i] - mTimestamps[numSamples - i - 1]; newXDerivatives[i] = newXDerivatives[i] / dt; newYDerivatives[i] = newYDerivatives[i] / dt; } } } std::swap(newXDerivatives, mXDerivatives); std::swap(newYDerivatives, mYDerivatives); } void JerkTracker::reset() { mTimestamps.clear(); } std::optional<float> JerkTracker::jerkMagnitude() const { if (mTimestamps.size() == mTimestamps.capacity()) { return std::hypot(mXDerivatives[3], mYDerivatives[3]); } return std::nullopt; } // --- MotionPredictor --- MotionPredictor::MotionPredictor(nsecs_t predictionTimestampOffsetNanos, Loading Loading @@ -107,6 +170,7 @@ android::base::Result<void> MotionPredictor::record(const MotionEvent& event) { if (action == AMOTION_EVENT_ACTION_UP || action == AMOTION_EVENT_ACTION_CANCEL) { ALOGD_IF(isDebug(), "End of event stream"); mBuffers->reset(); mJerkTracker.reset(); mLastEvent.reset(); return {}; } else if (action != AMOTION_EVENT_ACTION_DOWN && action != AMOTION_EVENT_ACTION_MOVE) { Loading Loading @@ -141,6 +205,9 @@ android::base::Result<void> MotionPredictor::record(const MotionEvent& event) { 0, i), .orientation = event.getHistoricalOrientation(0, i), }); mJerkTracker.pushSample(event.getHistoricalEventTime(i), coords->getAxisValue(AMOTION_EVENT_AXIS_X), coords->getAxisValue(AMOTION_EVENT_AXIS_Y)); } if (!mLastEvent) { Loading
libs/input/tests/MotionPredictor_test.cpp +103 −0 Original line number Diff line number Diff line Loading @@ -15,6 +15,7 @@ */ #include <chrono> #include <cmath> #include <gmock/gmock.h> #include <gtest/gtest.h> Loading Loading @@ -65,6 +66,108 @@ static MotionEvent getMotionEvent(int32_t action, float x, float y, return event; } TEST(JerkTrackerTest, JerkReadiness) { JerkTracker jerkTracker(true); EXPECT_FALSE(jerkTracker.jerkMagnitude()); jerkTracker.pushSample(/*timestamp=*/0, 20, 50); EXPECT_FALSE(jerkTracker.jerkMagnitude()); jerkTracker.pushSample(/*timestamp=*/1, 25, 53); EXPECT_FALSE(jerkTracker.jerkMagnitude()); jerkTracker.pushSample(/*timestamp=*/2, 30, 60); EXPECT_FALSE(jerkTracker.jerkMagnitude()); jerkTracker.pushSample(/*timestamp=*/3, 35, 70); EXPECT_TRUE(jerkTracker.jerkMagnitude()); jerkTracker.reset(); EXPECT_FALSE(jerkTracker.jerkMagnitude()); jerkTracker.pushSample(/*timestamp=*/4, 30, 60); EXPECT_FALSE(jerkTracker.jerkMagnitude()); } TEST(JerkTrackerTest, JerkCalculationNormalizedDtTrue) { JerkTracker jerkTracker(true); jerkTracker.pushSample(/*timestamp=*/0, 20, 50); jerkTracker.pushSample(/*timestamp=*/1, 25, 53); jerkTracker.pushSample(/*timestamp=*/2, 30, 60); jerkTracker.pushSample(/*timestamp=*/3, 45, 70); /** * Jerk derivative table * x: 20 25 30 45 * x': 5 5 15 * x'': 0 10 * x''': 10 * * y: 50 53 60 70 * y': 3 7 10 * y'': 4 3 * y''': -1 */ EXPECT_FLOAT_EQ(jerkTracker.jerkMagnitude().value(), std::hypot(10, -1)); jerkTracker.pushSample(/*timestamp=*/4, 20, 65); /** * (continuing from above table) * x: 45 -> 20 * x': 15 -> -25 * x'': 10 -> -40 * x''': -50 * * y: 70 -> 65 * y': 10 -> -5 * y'': 3 -> -15 * y''': -18 */ EXPECT_FLOAT_EQ(jerkTracker.jerkMagnitude().value(), std::hypot(-50, -18)); } TEST(JerkTrackerTest, JerkCalculationNormalizedDtFalse) { JerkTracker jerkTracker(false); jerkTracker.pushSample(/*timestamp=*/0, 20, 50); jerkTracker.pushSample(/*timestamp=*/10, 25, 53); jerkTracker.pushSample(/*timestamp=*/20, 30, 60); jerkTracker.pushSample(/*timestamp=*/30, 45, 70); /** * Jerk derivative table * x: 20 25 30 45 * x': .5 .5 1.5 * x'': 0 .1 * x''': .01 * * y: 50 53 60 70 * y': .3 .7 1 * y'': .04 .03 * y''': -.001 */ EXPECT_FLOAT_EQ(jerkTracker.jerkMagnitude().value(), std::hypot(.01, -.001)); jerkTracker.pushSample(/*timestamp=*/50, 20, 65); /** * (continuing from above table) * x: 45 -> 20 * x': 1.5 -> -1.25 (delta above, divide by 20) * x'': .1 -> -.275 (delta above, divide by 10) * x''': -.0375 (delta above, divide by 10) * * y: 70 -> 65 * y': 1 -> -.25 (delta above, divide by 20) * y'': .03 -> -.125 (delta above, divide by 10) * y''': -.0155 (delta above, divide by 10) */ EXPECT_FLOAT_EQ(jerkTracker.jerkMagnitude().value(), std::hypot(-.0375, -.0155)); } TEST(JerkTrackerTest, JerkCalculationAfterReset) { JerkTracker jerkTracker(true); jerkTracker.pushSample(/*timestamp=*/0, 20, 50); jerkTracker.pushSample(/*timestamp=*/1, 25, 53); jerkTracker.pushSample(/*timestamp=*/2, 30, 60); jerkTracker.pushSample(/*timestamp=*/3, 45, 70); jerkTracker.pushSample(/*timestamp=*/4, 20, 65); jerkTracker.reset(); jerkTracker.pushSample(/*timestamp=*/5, 20, 50); jerkTracker.pushSample(/*timestamp=*/6, 25, 53); jerkTracker.pushSample(/*timestamp=*/7, 30, 60); jerkTracker.pushSample(/*timestamp=*/8, 45, 70); EXPECT_FLOAT_EQ(jerkTracker.jerkMagnitude().value(), std::hypot(10, -1)); } TEST(MotionPredictorTest, IsPredictionAvailable) { MotionPredictor predictor(/*predictionTimestampOffsetNanos=*/0, []() { return true /*enable prediction*/; }); Loading