Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 6c40f008 authored by Cody Heiner's avatar Cody Heiner Committed by Android (Google) Code Review
Browse files

Merge "Implement Stylus Prediction Metrics" into main

parents 651244f2 52db4741
Loading
Loading
Loading
Loading
+176 −6
Original line number Diff line number Diff line
/*
 * Copyright (C) 2023 The Android Open Source Project
 * Copyright 2023 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -14,23 +14,193 @@
 * limitations under the License.
 */

#include <utils/Timers.h>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <limits>
#include <optional>
#include <vector>

#include <input/Input.h> // for MotionEvent
#include <input/RingBuffer.h>
#include <utils/Timers.h> // for nsecs_t

#include "Eigen/Core"

namespace android {

/**
 * Class to handle computing and reporting metrics for MotionPredictor.
 *
 * Currently an empty implementation, containing only the API.
 * The public API provides two methods: `onRecord` and `onPredict`, which expect to receive the
 * MotionEvents from the corresponding methods in MotionPredictor.
 *
 * This class stores AggregatedStrokeMetrics, updating them as new MotionEvents are passed in. When
 * onRecord receives an UP or CANCEL event, this indicates the end of the stroke, and the final
 * AtomFields are computed and reported to the stats library.
 *
 * If mMockLoggedAtomFields is set, the batch of AtomFields that are reported to the stats library
 * for one stroke are also stored in mMockLoggedAtomFields at the time they're reported.
 */
class MotionPredictorMetricsManager {
public:
    // Note: the MetricsManager assumes that the input interval equals the prediction interval.
    MotionPredictorMetricsManager(nsecs_t /*predictionInterval*/, size_t /*maxNumPredictions*/) {}
    MotionPredictorMetricsManager(nsecs_t predictionInterval, size_t maxNumPredictions);

    // This method should be called once for each call to MotionPredictor::record, receiving the
    // forwarded MotionEvent argument.
    void onRecord(const MotionEvent& inputEvent);

    // This method should be called once for each call to MotionPredictor::predict, receiving the
    // MotionEvent that will be returned by MotionPredictor::predict.
    void onPredict(const MotionEvent& predictionEvent);

    // Simple structs to hold relevant touch input information. Public so they can be used in tests.

    struct TouchPoint {
        Eigen::Vector2f position; // (y, x) in pixels
        float pressure;
    };

    struct GroundTruthPoint : TouchPoint {
        nsecs_t timestamp;
    };

    struct PredictionPoint : TouchPoint {
        // The timestamp of the last ground truth point when the prediction was made.
        nsecs_t originTimestamp;

        nsecs_t targetTimestamp;

        // Order by targetTimestamp when sorting.
        bool operator<(const PredictionPoint& other) const {
            return this->targetTimestamp < other.targetTimestamp;
        }
    };

    // Metrics aggregated so far for the current stroke. These are not the final fields to be
    // reported in the atom (see AtomFields below), but rather an intermediate representation of the
    // data that can be conveniently aggregated and from which the atom fields can be derived later.
    //
    // Displacement units are in pixels.
    //
    // "Along-trajectory error" is the dot product of the prediction error with the unit vector
    // pointing towards the ground truth point whose timestamp corresponds to the prediction
    // target timestamp, originating from the preceding ground truth point.
    //
    // "Off-trajectory error" is the component of the prediction error orthogonal to the
    // "along-trajectory" unit vector described above.
    //
    // "High-velocity" errors are errors that are only accumulated when the velocity between the
    // most recent two input events exceeds a certain threshold.
    //
    // "Scale-invariant errors" are the errors produced when the path length of the stroke is
    // scaled to 1. (In other words, the error distances are normalized by the path length.)
    struct AggregatedStrokeMetrics {
        // General errors
        float alongTrajectoryErrorSum = 0;
        float alongTrajectorySumSquaredErrors = 0;
        float offTrajectorySumSquaredErrors = 0;
        float pressureSumSquaredErrors = 0;
        size_t generalErrorsCount = 0;

        // High-velocity errors
        float highVelocityAlongTrajectorySse = 0;
        float highVelocityOffTrajectorySse = 0;
        size_t highVelocityErrorsCount = 0;

        // Scale-invariant errors
        float scaleInvariantAlongTrajectorySse = 0;
        float scaleInvariantOffTrajectorySse = 0;
        size_t scaleInvariantErrorsCount = 0;
    };

    // In order to explicitly indicate "no relevant data" for a metric, we report this
    // large-magnitude negative sentinel value. (Most metrics are non-negative, so this value is
    // completely unobtainable. For along-trajectory error mean, which can be negative, the
    // magnitude makes it unobtainable in practice.)
    static const int NO_DATA_SENTINEL = std::numeric_limits<int32_t>::min();

    // Final metrics reported in the atom.
    struct AtomFields {
        int deltaTimeBucketMilliseconds = 0;

        // General errors
        int alongTrajectoryErrorMeanMillipixels = NO_DATA_SENTINEL;
        int alongTrajectoryErrorStdMillipixels = NO_DATA_SENTINEL;
        int offTrajectoryRmseMillipixels = NO_DATA_SENTINEL;
        int pressureRmseMilliunits = NO_DATA_SENTINEL;

        // High-velocity errors
        int highVelocityAlongTrajectoryRmse = NO_DATA_SENTINEL; // millipixels
        int highVelocityOffTrajectoryRmse = NO_DATA_SENTINEL;   // millipixels

        // Scale-invariant errors
        int scaleInvariantAlongTrajectoryRmse = NO_DATA_SENTINEL; // millipixels
        int scaleInvariantOffTrajectoryRmse = NO_DATA_SENTINEL;   // millipixels
    };

    // Allow tests to pass in a mock AtomFields pointer.
    //
    // When metrics are reported to the stats library on stroke end, they will also be written to
    // mockLoggedAtomFields, overwriting existing data. The size of mockLoggedAtomFields will equal
    // the number of calls to stats_write for that stroke.
    void setMockLoggedAtomFields(std::vector<AtomFields>* mockLoggedAtomFields) {
        mMockLoggedAtomFields = mockLoggedAtomFields;
    }

private:
    // The interval between consecutive predictions' target timestamps. We assume that the input
    // interval also equals this value.
    const nsecs_t mPredictionInterval;

    // The maximum number of input frames into the future the model can predict.
    // Used to perform time-bucketing of metrics.
    const size_t mMaxNumPredictions;

    // History of mMaxNumPredictions + 1 ground truth points, used to compute scale-invariant
    // error. (Also, the last two points are used to compute the ground truth trajectory.)
    RingBuffer<GroundTruthPoint> mRecentGroundTruthPoints;

    // Predictions having a targetTimestamp after the most recent ground truth point's timestamp.
    // Invariant: sorted in ascending order of targetTimestamp.
    std::vector<PredictionPoint> mRecentPredictions;

    // Containers for the intermediate representation of stroke metrics and the final atom fields.
    // These are indexed by the number of input frames into the future being predicted minus one,
    // and always have size mMaxNumPredictions.
    std::vector<AggregatedStrokeMetrics> mAggregatedMetrics;
    std::vector<AtomFields> mAtomFields;

    // Non-owning pointer to the location of mock AtomFields. If present, will be filled with the
    // values reported to stats_write on each batch of reported metrics.
    //
    // This pointer must remain valid as long as the MotionPredictorMetricsManager exists.
    std::vector<AtomFields>* mMockLoggedAtomFields = nullptr;

    // Helper methods for the implementation of onRecord and onPredict.

    // Clears stored ground truth and prediction points, as well as all stored metrics for the
    // current stroke.
    void clearStrokeData();

    // Adds the new ground truth point to mRecentGroundTruths, removes outdated predictions from
    // mRecentPredictions, and updates the aggregated metrics to include the recent predictions that
    // fuzzily match with the new ground truth point.
    void incorporateNewGroundTruth(const GroundTruthPoint& groundTruthPoint);

    // Given a new prediction with targetTimestamp matching the latest ground truth point's
    // timestamp, computes the corresponding metrics and updates mAggregatedMetrics.
    void updateAggregatedMetrics(const PredictionPoint& predictionPoint);

    void onRecord(const MotionEvent& /*inputEvent*/) {}
    // Computes the atom fields to mAtomFields from the values in mAggregatedMetrics.
    void computeAtomFields();

    void onPredict(const MotionEvent& /*predictionEvent*/) {}
    // Reports the metrics given by the current data in mAtomFields:
    //  • If on an Android device, reports the metrics to stats_write.
    //  • If mMockLoggedAtomFields is present, it will be overwritten with logged metrics, with one
    //    AtomFields element per call to stats_write.
    void reportMetrics();
};

} // namespace android
+54 −1
Original line number Diff line number Diff line
@@ -139,6 +139,7 @@ cc_library {
        "KeyCharacterMap.cpp",
        "KeyLayoutMap.cpp",
        "MotionPredictor.cpp",
        "MotionPredictorMetricsManager.cpp",
        "PrintTools.cpp",
        "PropertyMap.cpp",
        "TfLiteMotionPredictor.cpp",
@@ -152,9 +153,13 @@ cc_library {
    header_libs: [
        "flatbuffer_headers",
        "jni_headers",
        "libeigen",
        "tensorflow_headers",
    ],
    export_header_lib_headers: ["jni_headers"],
    export_header_lib_headers: [
        "jni_headers",
        "libeigen",
    ],

    generated_headers: [
        "cxx-bridge-header",
@@ -206,6 +211,17 @@ cc_library {

    target: {
        android: {
            export_shared_lib_headers: ["libbinder"],

            shared_libs: [
                "libutils",
                "libbinder",
                // Stats logging library and its dependencies.
                "libstatslog_libinput",
                "libstatsbootstrap",
                "android.os.statsbootstrap_aidl-cpp",
            ],

            required: [
                "motion_predictor_model_prebuilt",
                "motion_predictor_model_config",
@@ -228,6 +244,43 @@ cc_library {
    },
}

// Use bootstrap version of stats logging library.
// libinput is a bootstrap process (starts early in the boot process), and thus can't use the normal
// `libstatslog` because that requires `libstatssocket`, which is only available later in the boot.
cc_library {
    name: "libstatslog_libinput",
    generated_sources: ["statslog_libinput.cpp"],
    generated_headers: ["statslog_libinput.h"],
    export_generated_headers: ["statslog_libinput.h"],
    shared_libs: [
        "libbinder",
        "libstatsbootstrap",
        "libutils",
        "android.os.statsbootstrap_aidl-cpp",
    ],
}

genrule {
    name: "statslog_libinput.h",
    tools: ["stats-log-api-gen"],
    cmd: "$(location stats-log-api-gen) --header $(genDir)/statslog_libinput.h --module libinput" +
        " --namespace android,stats,libinput --bootstrap",
    out: [
        "statslog_libinput.h",
    ],
}

genrule {
    name: "statslog_libinput.cpp",
    tools: ["stats-log-api-gen"],
    cmd: "$(location stats-log-api-gen) --cpp $(genDir)/statslog_libinput.cpp --module libinput" +
        " --namespace android,stats,libinput --importHeader statslog_libinput.h" +
        " --bootstrap",
    out: [
        "statslog_libinput.cpp",
    ],
}

cc_defaults {
    name: "libinput_fuzz_defaults",
    cpp_std: "c++20",
+1 −4
Original line number Diff line number Diff line
@@ -137,10 +137,7 @@ android::base::Result<void> MotionPredictor::record(const MotionEvent& event) {

    // Pass input event to the MetricsManager.
    if (!mMetricsManager) {
        mMetricsManager =
                std::make_optional<MotionPredictorMetricsManager>(mModel->config()
                                                                          .predictionInterval,
                                                                  mModel->outputLength());
        mMetricsManager.emplace(mModel->config().predictionInterval, mModel->outputLength());
    }
    mMetricsManager->onRecord(event);

+373 −0

File added.

Preview size limit exceeded, changes collapsed.

+16 −7
Original line number Diff line number Diff line
@@ -20,6 +20,7 @@ cc_test {
        "InputPublisherAndConsumer_test.cpp",
        "InputVerifier_test.cpp",
        "MotionPredictor_test.cpp",
        "MotionPredictorMetricsManager_test.cpp",
        "RingBuffer_test.cpp",
        "TfLiteMotionPredictor_test.cpp",
        "TouchResampling_test.cpp",
@@ -52,13 +53,6 @@ cc_test {
            undefined: true,
        },
    },
    target: {
        host: {
            sanitize: {
                address: true,
            },
        },
    },
    shared_libs: [
        "libbase",
        "libbinder",
@@ -77,6 +71,21 @@ cc_test {
        unit_test: true,
    },
    test_suites: ["device-tests"],
    target: {
        host: {
            sanitize: {
                address: true,
            },
        },
        android: {
            static_libs: [
                // Stats logging library and its dependencies.
                "libstatslog_libinput",
                "libstatsbootstrap",
                "android.os.statsbootstrap_aidl-cpp",
            ],
        },
    },
}

// NOTE: This is a compile time test, and does not need to be
Loading