Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 3c60252e authored by Philip Quinn's avatar Philip Quinn Committed by Android (Google) Code Review
Browse files

Merge "Update motion prediction model." into udc-qpr-dev

parents 027baa7a 107ce707
Loading
Loading
Loading
Loading
+15 −0
Original line number Diff line number Diff line
@@ -16,5 +16,20 @@
<motion-predictor>
  <!-- The time interval (ns) between the model's predictions. -->
  <prediction-interval>4166666</prediction-interval>  <!-- 4.167 ms = ~240 Hz -->
  <!-- The noise floor (px) for predicted distances.

       As the model is trained stochastically, there is some expected minimum
       variability in its output. This can be a UX issue when the input device
       is moving slowly and the variability is large relative to the magnitude
       of the motion. In these cases, it is better to inhibit the prediction,
       rather than show noisy predictions (and there is little benefit to
       prediction anyway).

       The value for this parameter should at least be close to the maximum
       predicted distance when the input device is held stationary (i.e. the
       expected minimum variability), and perhaps a little larger to capture
       the UX issue mentioned above.
  -->
  <distance-noise-floor>0.2</distance-noise-floor>
</motion-predictor>
+11 −4
Original line number Diff line number Diff line
@@ -99,6 +99,14 @@ private:
// A TFLite model for generating motion predictions.
class TfLiteMotionPredictorModel {
public:
    struct Config {
        // The time between predictions.
        nsecs_t predictionInterval = 0;
        // The noise floor for predictions.
        // Distances (r) less than this should be discarded as noise.
        float distanceNoiseFloor = 0;
    };

    // Creates a model from an encoded Flatbuffer model.
    static std::unique_ptr<TfLiteMotionPredictorModel> create();

@@ -110,8 +118,7 @@ public:
    // Returns the length of the model's output buffers.
    size_t outputLength() const;

    // Returns the time interval between predictions.
    nsecs_t predictionInterval() const { return mPredictionInterval; }
    const Config& config() const { return mConfig; }

    // Executes the model.
    // Returns true if the model successfully executed and the output tensors can be read.
@@ -132,7 +139,7 @@ public:

private:
    explicit TfLiteMotionPredictorModel(std::unique_ptr<android::base::MappedFile> model,
                                        nsecs_t predictionInterval);
                                        Config config);

    void allocateTensors();
    void attachInputTensors();
@@ -154,7 +161,7 @@ private:
    std::unique_ptr<tflite::Interpreter> mInterpreter;
    tflite::SignatureRunner* mRunner = nullptr;

    const nsecs_t mPredictionInterval = 0;
    const Config mConfig = {};
};

} // namespace android
+15 −4
Original line number Diff line number Diff line
@@ -138,7 +138,8 @@ android::base::Result<void> MotionPredictor::record(const MotionEvent& event) {
    // Pass input event to the MetricsManager.
    if (!mMetricsManager) {
        mMetricsManager =
                std::make_optional<MotionPredictorMetricsManager>(mModel->predictionInterval(),
                std::make_optional<MotionPredictorMetricsManager>(mModel->config()
                                                                          .predictionInterval,
                                                                  mModel->outputLength());
    }
    mMetricsManager->onRecord(event);
@@ -184,8 +185,18 @@ std::unique_ptr<MotionEvent> MotionPredictor::predict(nsecs_t timestamp) {
    const int64_t futureTime = timestamp + mPredictionTimestampOffsetNanos;

    for (int i = 0; i < predictedR.size() && predictionTime <= futureTime; ++i) {
        // TODO(b/266747654): Stop predictions if confidence and/or predicted pressure are below
        // some thresholds.
        if (predictedR[i] < mModel->config().distanceNoiseFloor) {
            // Stop predicting when the predicted output is below the model's noise floor.
            //
            // We assume that all subsequent predictions in the batch are unreliable because later
            // predictions are conditional on earlier predictions, and a state of noise is not a
            // good basis for prediction.
            //
            // The UX trade-off is that this potentially sacrifices some predictions when the input
            // device starts to speed up, but avoids producing noisy predictions as it slows down.
            break;
        }
        // TODO(b/266747654): Stop predictions if confidence is < some threshold.

        const TfLiteMotionPredictorSample::Point predictedPoint =
                convertPrediction(axisFrom, axisTo, predictedR[i], predictedPhi[i]);
@@ -197,7 +208,7 @@ std::unique_ptr<MotionEvent> MotionPredictor::predict(nsecs_t timestamp) {
        coords.setAxisValue(AMOTION_EVENT_AXIS_Y, predictedPoint.y);
        coords.setAxisValue(AMOTION_EVENT_AXIS_PRESSURE, predictedPressure[i]);

        predictionTime += mModel->predictionInterval();
        predictionTime += mModel->config().predictionInterval;
        if (i == 0) {
            hasPredictions = true;
            prediction->initialize(InputEvent::nextId(), event.getDeviceId(), event.getSource(),
+23 −13
Original line number Diff line number Diff line
@@ -100,6 +100,16 @@ int64_t parseXMLInt64(const tinyxml2::XMLElement& configRoot, const char* elemen
    return value;
}

float parseXMLFloat(const tinyxml2::XMLElement& configRoot, const char* elementName) {
    const tinyxml2::XMLElement* element = configRoot.FirstChildElement(elementName);
    LOG_ALWAYS_FATAL_IF(!element, "Could not find '%s' element", elementName);

    float value = 0;
    LOG_ALWAYS_FATAL_IF(element->QueryFloatText(&value) != tinyxml2::XML_SUCCESS,
                        "Failed to parse %s: %s", elementName, element->GetText());
    return value;
}

// A TFLite ErrorReporter that logs to logcat.
class LoggingErrorReporter : public tflite::ErrorReporter {
public:
@@ -152,6 +162,7 @@ std::unique_ptr<tflite::OpResolver> createOpResolver() {
                         ::tflite::ops::builtin::Register_CONCATENATION());
    resolver->AddBuiltin(::tflite::BuiltinOperator_FULLY_CONNECTED,
                         ::tflite::ops::builtin::Register_FULLY_CONNECTED());
    resolver->AddBuiltin(::tflite::BuiltinOperator_GELU, ::tflite::ops::builtin::Register_GELU());
    return resolver;
}

@@ -208,13 +219,7 @@ void TfLiteMotionPredictorBuffers::pushSample(int64_t timestamp,
    float phi = 0;
    float orientation = 0;

    // Ignore the sample if there is no movement. These samples can occur when there's change to a
    // property other than the coordinates and pollute the input to the model.
    if (r == 0) {
        return;
    }

    if (!mAxisFrom) { // Second point.
    if (!mAxisFrom && r > 0) { // Second point.
        // We can only determine the distance from the first point, and not any
        // angle. However, if the second point forms an axis, the orientation can
        // be transformed relative to that axis.
@@ -235,8 +240,10 @@ void TfLiteMotionPredictorBuffers::pushSample(int64_t timestamp,
    }

    // Update the axis for the next point.
    if (r > 0) {
        mAxisFrom = mAxisTo;
        mAxisTo = sample;
    }

    // Push the current sample onto the end of the input buffers.
    mInputR.pushBack(r);
@@ -272,15 +279,18 @@ std::unique_ptr<TfLiteMotionPredictorModel> TfLiteMotionPredictorModel::create()
    // Parse configuration file.
    const tinyxml2::XMLElement* configRoot = configDocument.FirstChildElement("motion-predictor");
    LOG_ALWAYS_FATAL_IF(!configRoot);
    const nsecs_t predictionInterval = parseXMLInt64(*configRoot, "prediction-interval");
    Config config{
            .predictionInterval = parseXMLInt64(*configRoot, "prediction-interval"),
            .distanceNoiseFloor = parseXMLFloat(*configRoot, "distance-noise-floor"),
    };

    return std::unique_ptr<TfLiteMotionPredictorModel>(
            new TfLiteMotionPredictorModel(std::move(modelBuffer), predictionInterval));
            new TfLiteMotionPredictorModel(std::move(modelBuffer), std::move(config)));
}

TfLiteMotionPredictorModel::TfLiteMotionPredictorModel(
        std::unique_ptr<android::base::MappedFile> model, nsecs_t predictionInterval)
      : mFlatBuffer(std::move(model)), mPredictionInterval(predictionInterval) {
        std::unique_ptr<android::base::MappedFile> model, Config config)
      : mFlatBuffer(std::move(model)), mConfig(std::move(config)) {
    CHECK(mFlatBuffer);
    mErrorReporter = std::make_unique<LoggingErrorReporter>();
    mModel = tflite::FlatBufferModel::VerifyAndBuildFromBuffer(mFlatBuffer->data(),
Loading