Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 7b77945a authored by Siarhei Vishniakou's avatar Siarhei Vishniakou Committed by Automerger Merge Worker
Browse files

Merge "Single-device prediction only" into udc-dev am: 020705b7

parents 1d33c61c 020705b7
Loading
Loading
Loading
Loading
+13 −5
Original line number Diff line number Diff line
@@ -22,6 +22,7 @@
#include <string>
#include <unordered_map>

#include <android-base/result.h>
#include <android-base/thread_annotations.h>
#include <android/sysprop/InputProperties.sysprop.h>
#include <input/Input.h>
@@ -68,8 +69,15 @@ public:
     */
    MotionPredictor(nsecs_t predictionTimestampOffsetNanos, const char* modelPath = nullptr,
                    std::function<bool()> checkEnableMotionPrediction = isMotionPredictionEnabled);
    void record(const MotionEvent& event);
    std::vector<std::unique_ptr<MotionEvent>> predict(nsecs_t timestamp);
    /**
     * Record the actual motion received by the view. This event will be used for calculating the
     * predictions.
     *
     * @return empty result if the event was processed correctly, error if the event is not
     * consistent with the previously recorded events.
     */
    android::base::Result<void> record(const MotionEvent& event);
    std::unique_ptr<MotionEvent> predict(nsecs_t timestamp);
    bool isPredictionAvailable(int32_t deviceId, int32_t source);

private:
@@ -78,9 +86,9 @@ private:
    const std::function<bool()> mCheckMotionPredictionEnabled;

    std::unique_ptr<TfLiteMotionPredictorModel> mModel;
    // Buffers/events for each device seen by record().
    std::unordered_map</*deviceId*/ int32_t, TfLiteMotionPredictorBuffers> mDeviceBuffers;
    std::unordered_map</*deviceId*/ int32_t, MotionEvent> mLastEvents;

    std::unique_ptr<TfLiteMotionPredictorBuffers> mBuffers;
    std::optional<MotionEvent> mLastEvent;
};

} // namespace android
+104 −95
Original line number Diff line number Diff line
@@ -68,11 +68,20 @@ MotionPredictor::MotionPredictor(nsecs_t predictionTimestampOffsetNanos, const c
        mModelPath(modelPath == nullptr ? DEFAULT_MODEL_PATH : modelPath),
        mCheckMotionPredictionEnabled(std::move(checkMotionPredictionEnabled)) {}

void MotionPredictor::record(const MotionEvent& event) {
android::base::Result<void> MotionPredictor::record(const MotionEvent& event) {
    if (mLastEvent && mLastEvent->getDeviceId() != event.getDeviceId()) {
        // We still have an active gesture for another device. The provided MotionEvent is not
        // consistent with the previous gesture.
        LOG(ERROR) << "Inconsistent event stream: last event is " << *mLastEvent << ", but "
                   << __func__ << " is called with " << event;
        return android::base::Error()
                << "Inconsistent event stream: still have an active gesture from device "
                << mLastEvent->getDeviceId() << ", but received " << event;
    }
    if (!isPredictionAvailable(event.getDeviceId(), event.getSource())) {
        ALOGE("Prediction not supported for device %d's %s source", event.getDeviceId(),
              inputEventSourceToString(event.getSource()).c_str());
        return;
        return {};
    }

    // Initialise the model now that it's likely to be used.
@@ -80,30 +89,32 @@ void MotionPredictor::record(const MotionEvent& event) {
        mModel = TfLiteMotionPredictorModel::create(mModelPath.c_str());
    }

    TfLiteMotionPredictorBuffers& buffers =
            mDeviceBuffers.try_emplace(event.getDeviceId(), mModel->inputLength()).first->second;
    if (mBuffers == nullptr) {
        mBuffers = std::make_unique<TfLiteMotionPredictorBuffers>(mModel->inputLength());
    }

    const int32_t action = event.getActionMasked();
    if (action == AMOTION_EVENT_ACTION_UP) {
    if (action == AMOTION_EVENT_ACTION_UP || action == AMOTION_EVENT_ACTION_CANCEL) {
        ALOGD_IF(isDebug(), "End of event stream");
        buffers.reset();
        return;
        mBuffers->reset();
        mLastEvent.reset();
        return {};
    } else if (action != AMOTION_EVENT_ACTION_DOWN && action != AMOTION_EVENT_ACTION_MOVE) {
        ALOGD_IF(isDebug(), "Skipping unsupported %s action",
                 MotionEvent::actionToString(action).c_str());
        return;
        return {};
    }

    if (event.getPointerCount() != 1) {
        ALOGD_IF(isDebug(), "Prediction not supported for multiple pointers");
        return;
        return {};
    }

    const int32_t toolType = event.getPointerProperties(0)->toolType;
    if (toolType != AMOTION_EVENT_TOOL_TYPE_STYLUS) {
        ALOGD_IF(isDebug(), "Prediction not supported for non-stylus tool: %s",
                 motionToolTypeToString(toolType));
        return;
        return {};
    }

    for (size_t i = 0; i <= event.getHistorySize(); ++i) {
@@ -111,31 +122,31 @@ void MotionPredictor::record(const MotionEvent& event) {
            continue;
        }
        const PointerCoords* coords = event.getHistoricalRawPointerCoords(0, i);
        buffers.pushSample(event.getHistoricalEventTime(i),
        mBuffers->pushSample(event.getHistoricalEventTime(i),
                             {
                                     .position.x = coords->getAxisValue(AMOTION_EVENT_AXIS_X),
                                     .position.y = coords->getAxisValue(AMOTION_EVENT_AXIS_Y),
                                     .pressure = event.getHistoricalPressure(0, i),
                                   .tilt = event.getHistoricalAxisValue(AMOTION_EVENT_AXIS_TILT, 0,
                                                                        i),
                                     .tilt = event.getHistoricalAxisValue(AMOTION_EVENT_AXIS_TILT,
                                                                          0, i),
                                     .orientation = event.getHistoricalOrientation(0, i),
                             });
    }

    mLastEvents.try_emplace(event.getDeviceId())
            .first->second.copyFrom(&event, /*keepHistory=*/false);
    if (!mLastEvent) {
        mLastEvent = MotionEvent();
    }
    mLastEvent->copyFrom(&event, /*keepHistory=*/false);
    return {};
}

std::vector<std::unique_ptr<MotionEvent>> MotionPredictor::predict(nsecs_t timestamp) {
    std::vector<std::unique_ptr<MotionEvent>> predictions;

    for (const auto& [deviceId, buffer] : mDeviceBuffers) {
        if (!buffer.isReady()) {
            continue;
std::unique_ptr<MotionEvent> MotionPredictor::predict(nsecs_t timestamp) {
    if (mBuffers == nullptr || !mBuffers->isReady()) {
        return nullptr;
    }

    LOG_ALWAYS_FATAL_IF(!mModel);
        buffer.copyTo(*mModel);
    mBuffers->copyTo(*mModel);
    LOG_ALWAYS_FATAL_IF(!mModel->invoke());

    // Read out the predictions.
@@ -143,11 +154,10 @@ std::vector<std::unique_ptr<MotionEvent>> MotionPredictor::predict(nsecs_t times
    const std::span<const float> predictedPhi = mModel->outputPhi();
    const std::span<const float> predictedPressure = mModel->outputPressure();

        TfLiteMotionPredictorSample::Point axisFrom = buffer.axisFrom().position;
        TfLiteMotionPredictorSample::Point axisTo = buffer.axisTo().position;
    TfLiteMotionPredictorSample::Point axisFrom = mBuffers->axisFrom().position;
    TfLiteMotionPredictorSample::Point axisTo = mBuffers->axisTo().position;

    if (isDebug()) {
            ALOGD("deviceId: %d", deviceId);
        ALOGD("axisFrom: %f, %f", axisFrom.x, axisFrom.y);
        ALOGD("axisTo: %f, %f", axisTo.x, axisTo.y);
        ALOGD("mInputR: %s", base::Join(mModel->inputR(), ", ").c_str());
@@ -160,10 +170,11 @@ std::vector<std::unique_ptr<MotionEvent>> MotionPredictor::predict(nsecs_t times
        ALOGD("predictedPressure: %s", base::Join(predictedPressure, ", ").c_str());
    }

        const MotionEvent& event = mLastEvents[deviceId];
    LOG_ALWAYS_FATAL_IF(!mLastEvent);
    const MotionEvent& event = *mLastEvent;
    bool hasPredictions = false;
    std::unique_ptr<MotionEvent> prediction = std::make_unique<MotionEvent>();
        int64_t predictionTime = buffer.lastTimestamp();
    int64_t predictionTime = mBuffers->lastTimestamp();
    const int64_t futureTime = timestamp + mPredictionTimestampOffsetNanos;

    for (int i = 0; i < predictedR.size() && predictionTime <= futureTime; ++i) {
@@ -183,15 +194,14 @@ std::vector<std::unique_ptr<MotionEvent>> MotionPredictor::predict(nsecs_t times
        if (i == 0) {
            hasPredictions = true;
            prediction->initialize(InputEvent::nextId(), event.getDeviceId(), event.getSource(),
                                       event.getDisplayId(), INVALID_HMAC,
                                       AMOTION_EVENT_ACTION_MOVE, event.getActionButton(),
                                       event.getFlags(), event.getEdgeFlags(), event.getMetaState(),
                                       event.getButtonState(), event.getClassification(),
                                       event.getTransform(), event.getXPrecision(),
                                       event.getYPrecision(), event.getRawXCursorPosition(),
                                       event.getRawYCursorPosition(), event.getRawTransform(),
                                       event.getDownTime(), predictionTime, event.getPointerCount(),
                                       event.getPointerProperties(), &coords);
                                   event.getDisplayId(), INVALID_HMAC, AMOTION_EVENT_ACTION_MOVE,
                                   event.getActionButton(), event.getFlags(), event.getEdgeFlags(),
                                   event.getMetaState(), event.getButtonState(),
                                   event.getClassification(), event.getTransform(),
                                   event.getXPrecision(), event.getYPrecision(),
                                   event.getRawXCursorPosition(), event.getRawYCursorPosition(),
                                   event.getRawTransform(), event.getDownTime(), predictionTime,
                                   event.getPointerCount(), event.getPointerProperties(), &coords);
        } else {
            prediction->addSample(predictionTime, &coords);
        }
@@ -200,11 +210,10 @@ std::vector<std::unique_ptr<MotionEvent>> MotionPredictor::predict(nsecs_t times
        axisTo = point;
    }
    // TODO(b/266747511): Interpolate to futureTime?
        if (hasPredictions) {
            predictions.push_back(std::move(prediction));
        }
    if (!hasPredictions) {
        return nullptr;
    }
    return predictions;
    return prediction;
}

bool MotionPredictor::isPredictionAvailable(int32_t /*deviceId*/, int32_t source) {
+29 −38
Original line number Diff line number Diff line
@@ -84,9 +84,9 @@ TEST(MotionPredictorTest, Offset) {
                              []() { return true /*enable prediction*/; });
    predictor.record(getMotionEvent(DOWN, 0, 1, 30ms));
    predictor.record(getMotionEvent(MOVE, 0, 2, 35ms));
    std::vector<std::unique_ptr<MotionEvent>> predicted = predictor.predict(40 * NSEC_PER_MSEC);
    ASSERT_EQ(1u, predicted.size());
    ASSERT_GE(predicted[0]->getEventTime(), 41);
    std::unique_ptr<MotionEvent> predicted = predictor.predict(40 * NSEC_PER_MSEC);
    ASSERT_NE(nullptr, predicted);
    ASSERT_GE(predicted->getEventTime(), 41);
}

TEST(MotionPredictorTest, FollowsGesture) {
@@ -95,52 +95,43 @@ TEST(MotionPredictorTest, FollowsGesture) {

    // MOVE without a DOWN is ignored.
    predictor.record(getMotionEvent(MOVE, 1, 3, 10ms));
    EXPECT_THAT(predictor.predict(20 * NSEC_PER_MSEC), IsEmpty());
    EXPECT_EQ(nullptr, predictor.predict(20 * NSEC_PER_MSEC));

    predictor.record(getMotionEvent(DOWN, 2, 5, 20ms));
    predictor.record(getMotionEvent(MOVE, 2, 7, 30ms));
    predictor.record(getMotionEvent(MOVE, 3, 9, 40ms));
    EXPECT_THAT(predictor.predict(50 * NSEC_PER_MSEC), SizeIs(1));
    EXPECT_NE(nullptr, predictor.predict(50 * NSEC_PER_MSEC));

    predictor.record(getMotionEvent(UP, 4, 11, 50ms));
    EXPECT_THAT(predictor.predict(20 * NSEC_PER_MSEC), IsEmpty());
    EXPECT_EQ(nullptr, predictor.predict(20 * NSEC_PER_MSEC));
}

TEST(MotionPredictorTest, MultipleDevicesTracked) {
TEST(MotionPredictorTest, MultipleDevicesNotSupported) {
    MotionPredictor predictor(/*predictionTimestampOffsetNanos=*/0, MODEL_PATH,
                              []() { return true /*enable prediction*/; });

    predictor.record(getMotionEvent(DOWN, 1, 3, 0ms, /*deviceId=*/0));
    predictor.record(getMotionEvent(MOVE, 1, 3, 10ms, /*deviceId=*/0));
    predictor.record(getMotionEvent(MOVE, 2, 5, 20ms, /*deviceId=*/0));
    predictor.record(getMotionEvent(MOVE, 3, 7, 30ms, /*deviceId=*/0));
    ASSERT_TRUE(predictor.record(getMotionEvent(DOWN, 1, 3, 0ms, /*deviceId=*/0)).ok());
    ASSERT_TRUE(predictor.record(getMotionEvent(MOVE, 1, 3, 10ms, /*deviceId=*/0)).ok());
    ASSERT_TRUE(predictor.record(getMotionEvent(MOVE, 2, 5, 20ms, /*deviceId=*/0)).ok());
    ASSERT_TRUE(predictor.record(getMotionEvent(MOVE, 3, 7, 30ms, /*deviceId=*/0)).ok());

    predictor.record(getMotionEvent(DOWN, 100, 300, 0ms, /*deviceId=*/1));
    predictor.record(getMotionEvent(MOVE, 100, 300, 10ms, /*deviceId=*/1));
    predictor.record(getMotionEvent(MOVE, 200, 500, 20ms, /*deviceId=*/1));
    predictor.record(getMotionEvent(MOVE, 300, 700, 30ms, /*deviceId=*/1));

    {
        std::vector<std::unique_ptr<MotionEvent>> predicted = predictor.predict(40 * NSEC_PER_MSEC);
        ASSERT_EQ(2u, predicted.size());

        // Order of the returned vector is not guaranteed.
        std::vector<int32_t> seenDeviceIds;
        for (const auto& prediction : predicted) {
            seenDeviceIds.push_back(prediction->getDeviceId());
        }
        EXPECT_THAT(seenDeviceIds, UnorderedElementsAre(0, 1));
    ASSERT_FALSE(predictor.record(getMotionEvent(DOWN, 100, 300, 40ms, /*deviceId=*/1)).ok());
    ASSERT_FALSE(predictor.record(getMotionEvent(MOVE, 100, 300, 50ms, /*deviceId=*/1)).ok());
}

    // End the gesture for device 0.
    predictor.record(getMotionEvent(UP, 4, 9, 40ms, /*deviceId=*/0));
    predictor.record(getMotionEvent(MOVE, 400, 900, 40ms, /*deviceId=*/1));
TEST(MotionPredictorTest, IndividualGesturesFromDifferentDevicesAreSupported) {
    MotionPredictor predictor(/*predictionTimestampOffsetNanos=*/0, MODEL_PATH,
                              []() { return true /*enable prediction*/; });

    {
        std::vector<std::unique_ptr<MotionEvent>> predicted = predictor.predict(40 * NSEC_PER_MSEC);
        ASSERT_EQ(1u, predicted.size());
        ASSERT_EQ(predicted[0]->getDeviceId(), 1);
    }
    ASSERT_TRUE(predictor.record(getMotionEvent(DOWN, 1, 3, 0ms, /*deviceId=*/0)).ok());
    ASSERT_TRUE(predictor.record(getMotionEvent(MOVE, 1, 3, 10ms, /*deviceId=*/0)).ok());
    ASSERT_TRUE(predictor.record(getMotionEvent(MOVE, 2, 5, 20ms, /*deviceId=*/0)).ok());
    ASSERT_TRUE(predictor.record(getMotionEvent(UP, 2, 5, 30ms, /*deviceId=*/0)).ok());

    // Now, send a gesture from a different device. Since we have no active gesture, the new gesture
    // should be processed correctly.
    ASSERT_TRUE(predictor.record(getMotionEvent(DOWN, 100, 300, 40ms, /*deviceId=*/1)).ok());
    ASSERT_TRUE(predictor.record(getMotionEvent(MOVE, 100, 300, 50ms, /*deviceId=*/1)).ok());
}

TEST(MotionPredictorTest, FlagDisablesPrediction) {
@@ -148,8 +139,8 @@ TEST(MotionPredictorTest, FlagDisablesPrediction) {
                              []() { return false /*disable prediction*/; });
    predictor.record(getMotionEvent(DOWN, 0, 1, 30ms));
    predictor.record(getMotionEvent(MOVE, 0, 1, 35ms));
    std::vector<std::unique_ptr<MotionEvent>> predicted = predictor.predict(40 * NSEC_PER_MSEC);
    ASSERT_EQ(0u, predicted.size());
    std::unique_ptr<MotionEvent> predicted = predictor.predict(40 * NSEC_PER_MSEC);
    ASSERT_EQ(nullptr, predicted);
    ASSERT_FALSE(predictor.isPredictionAvailable(/*deviceId=*/1, AINPUT_SOURCE_STYLUS));
    ASSERT_FALSE(predictor.isPredictionAvailable(/*deviceId=*/1, AINPUT_SOURCE_TOUCHSCREEN));
}