Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 609bcdfa authored by Treehugger Robot's avatar Treehugger Robot Committed by Android (Google) Code Review
Browse files

Merge "Allow recording the spring state in tests" into main

parents d4ba09a9 76675d9a
Loading
Loading
Loading
Loading
+9 −22
Original line number Diff line number Diff line
@@ -24,7 +24,6 @@ import androidx.compose.runtime.setValue
import androidx.compose.runtime.snapshots.Snapshot
import com.android.mechanics.DistanceGestureContext
import com.android.mechanics.MotionValue
import com.android.mechanics.debug.FrameData
import com.android.mechanics.spec.InputDirection
import com.android.mechanics.spec.MotionSpec
import kotlinx.coroutines.Dispatchers
@@ -50,12 +49,12 @@ data object ComposeMotionValueToolkit : MotionValueToolkit<MotionValue, Distance
        motionTestRule: MotionTestRule<*>,
        spec: MotionSpec,
        createDerived: (underTest: MotionValue) -> List<MotionValue>,
        semantics: List<CapturedSemantics<*>>,
        initialValue: Float,
        initialDirection: InputDirection,
        directionChangeSlop: Float,
        stableThreshold: Float,
        verifyTimeSeries: TimeSeries.() -> VerifyTimeSeriesResult,
        capture: CaptureTimeSeriesFn,
        testInput: suspend InputScope<MotionValue, DistanceGestureContext>.() -> Unit,
    ) = runMonotonicClockTest {
        val frameEmitter = MutableStateFlow<Long>(0)
@@ -73,21 +72,20 @@ data object ComposeMotionValueToolkit : MotionValueToolkit<MotionValue, Distance
        val underTest = testHarness.underTest
        val derived = testHarness.derived

        val motionValues = derived + underTest
        val motionValueCaptures = buildList {
            add(MotionValueCapture(underTest.debugInspector()))
            derived.forEach { add(MotionValueCapture(it.debugInspector(), "${it.label}-")) }
        }

        val inspectors = motionValues.map { it to it.debugInspector() }.toMap()
        val keepRunningJobs = motionValues.map { launch { it.keepRunning() } }
        val keepRunningJobs = (derived + underTest).map { launch { it.keepRunning() } }

        val recordingJob = launch { testInput.invoke(testHarness) }

        val frameIds = mutableListOf<FrameId>()
        val frameData = mutableMapOf<MotionValue, MutableList<FrameData>>()

        fun recordFrame(frameId: TimestampFrameId) {
            frameIds.add(frameId)
            inspectors.forEach { (motionValue, inspector) ->
                frameData.computeIfAbsent(motionValue) { mutableListOf() }.add(inspector.frame)
            }
            motionValueCaptures.forEach { it.captureCurrentFrame(capture) }
        }
        runBlocking(Dispatchers.Main) {
            val startFrameTime = testScheduler.currentTime
@@ -110,19 +108,8 @@ data object ComposeMotionValueToolkit : MotionValueToolkit<MotionValue, Distance
            }
        }

        val timeSeries =
            createTimeSeries(
                frameIds,
                frameData.entries
                    .map { (motionValue, frameData) ->
                        val prefix = if (motionValue == underTest) "" else "${motionValue.label}-"
                        prefix to frameData
                    }
                    .sortedBy { it.first },
                semantics,
            )

        inspectors.values.forEach { it.dispose() }
        val timeSeries = createTimeSeries(frameIds, motionValueCaptures)
        motionValueCaptures.forEach { it.debugger.dispose() }
        keepRunningJobs.forEach { it.cancel() }
        verifyTimeSeries(motionTestRule, timeSeries, verifyTimeSeries)
    }
+23 −0
Original line number Diff line number Diff line
@@ -17,13 +17,17 @@
package com.android.mechanics.testing

import com.android.mechanics.spring.SpringParameters
import com.android.mechanics.spring.SpringState
import com.android.mechanics.testing.DataPointTypes.springParameters
import com.android.mechanics.testing.DataPointTypes.springState
import org.json.JSONObject
import platform.test.motion.golden.DataPointType
import platform.test.motion.golden.UnknownTypeException

fun SpringParameters.asDataPoint() = springParameters.makeDataPoint(this)

fun SpringState.asDataPoint() = springState.makeDataPoint(this)

object DataPointTypes {
    val springParameters: DataPointType<SpringParameters> =
        DataPointType(
@@ -43,4 +47,23 @@ object DataPointTypes {
                }
            },
        )

    val springState: DataPointType<SpringState> =
        DataPointType(
            "springState",
            jsonToValue = {
                with(it as? JSONObject ?: throw UnknownTypeException()) {
                    SpringState(
                        getDouble("displacement").toFloat(),
                        getDouble("velocity").toFloat(),
                    )
                }
            },
            valueToJson = {
                JSONObject().apply {
                    put("displacement", it.displacement)
                    put("velocity", it.velocity)
                }
            },
        )
}
+71 −0
Original line number Diff line number Diff line
/*
 * Copyright (C) 2025 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.android.mechanics.testing

import com.android.mechanics.debug.DebugInspector
import com.android.mechanics.spec.SemanticKey
import com.android.mechanics.spring.SpringParameters
import com.android.mechanics.spring.SpringState
import platform.test.motion.golden.DataPointType
import platform.test.motion.golden.FeatureCapture
import platform.test.motion.golden.asDataPoint

/** Feature captures on MotionValue's [DebugInspector] */
object FeatureCaptures {
    /** Input value of the current frame. */
    val input = FeatureCapture<DebugInspector, Float>("input") { it.frame.input.asDataPoint() }

    /** Gesture direction of the current frame. */
    val gestureDirection =
        FeatureCapture<DebugInspector, String>("gestureDirection") {
            it.frame.gestureDirection.name.asDataPoint()
        }

    /** Animated output value of the current frame. */
    val output = FeatureCapture<DebugInspector, Float>("output") { it.frame.output.asDataPoint() }

    /** Output target value of the current frame. */
    val outputTarget =
        FeatureCapture<DebugInspector, Float>("outputTarget") {
            it.frame.outputTarget.asDataPoint()
        }

    /** Spring parameters currently in use. */
    val springParameters =
        FeatureCapture<DebugInspector, SpringParameters>("springParameters") {
            it.frame.springParameters.asDataPoint()
        }

    /** Spring state currently in use. */
    val springState =
        FeatureCapture<DebugInspector, SpringState>("springState") {
            it.frame.springState.asDataPoint()
        }

    /** Whether the spring is currently stable. */
    val isStable =
        FeatureCapture<DebugInspector, Boolean>("isStable") { it.frame.isStable.asDataPoint() }

    /** A semantic value to capture in the golden. */
    fun <T> semantics(
        key: SemanticKey<T>,
        dataPointType: DataPointType<T & Any>,
        name: String = key.debugLabel,
    ): FeatureCapture<DebugInspector, T & Any> {
        return FeatureCapture(name) { dataPointType.makeDataPoint(it.frame.semantic(key)) }
    }
}
+33 −46
Original line number Diff line number Diff line
@@ -16,10 +16,10 @@

package com.android.mechanics.testing

import com.android.mechanics.debug.FrameData
import com.android.mechanics.MotionValue
import com.android.mechanics.debug.DebugInspector
import com.android.mechanics.spec.InputDirection
import com.android.mechanics.spec.MotionSpec
import com.android.mechanics.spec.SemanticKey
import kotlin.math.abs
import kotlin.math.floor
import kotlin.math.sign
@@ -27,11 +27,10 @@ import kotlin.time.Duration.Companion.milliseconds
import platform.test.motion.MotionTestRule
import platform.test.motion.RecordedMotion.Companion.create
import platform.test.motion.golden.DataPoint
import platform.test.motion.golden.DataPointType
import platform.test.motion.golden.Feature
import platform.test.motion.golden.FrameId
import platform.test.motion.golden.TimeSeries
import platform.test.motion.golden.asDataPoint
import platform.test.motion.golden.TimeSeriesCaptureScope

/**
 * Records and verifies a timeseries of the [MotionValue]'s output.
@@ -40,7 +39,6 @@ import platform.test.motion.golden.asDataPoint
 * [MotionValue] input over time.
 *
 * @param spec The initial [MotionSpec]
 * @param semantics The list of semantic values to capture in the golden
 * @param initialValue The initial value of the [MotionValue]
 * @param initialDirection The initial [InputDirection] of the [MotionValue]
 * @param directionChangeSlop the minimum distance for the input to change in the opposite direction
@@ -51,6 +49,8 @@ import platform.test.motion.golden.asDataPoint
 *   series. If the function returns `SkipGoldenVerification`, the timeseries won`t be compared to a
 *   golden.
 * @param createDerived (experimental) Creates derived MotionValues
 * @param capture The features to capture on each motion value. See [defaultFeatureCaptures] for
 *   defaults.
 * @param testInput Controls the MotionValue during the test. The timeseries is being recorded until
 *   the function completes.
 * @see ComposeMotionValueToolkit
@@ -62,7 +62,6 @@ fun <
    GestureContextType,
> MotionTestRule<T>.goldenTest(
    spec: MotionSpec,
    semantics: List<CapturedSemantics<*>> = emptyList(),
    initialValue: Float = 0f,
    initialDirection: InputDirection = InputDirection.Max,
    directionChangeSlop: Float = 5f,
@@ -71,18 +70,19 @@ fun <
        VerifyTimeSeriesResult.AssertTimeSeriesMatchesGolden()
    },
    createDerived: (underTest: MotionValueType) -> List<MotionValueType> = { emptyList() },
    capture: CaptureTimeSeriesFn = defaultFeatureCaptures,
    testInput: suspend (InputScope<MotionValueType, GestureContextType>).() -> Unit,
) {
    toolkit.goldenTest(
        this,
        spec,
        createDerived,
        semantics,
        initialValue,
        initialDirection,
        directionChangeSlop,
        stableThreshold,
        verifyTimeSeries,
        capture,
        testInput,
    )
}
@@ -149,15 +149,16 @@ interface VerifyTimeSeriesResult {
        VerifyTimeSeriesResult
}

/** A semantic value to capture in the golden. */
class CapturedSemantics<T>(
    val key: SemanticKey<T>,
    val dataPointType: DataPointType<T & Any>,
    val name: String = key.debugLabel,
) {
    fun toDataPoint(frameData: FrameData): DataPoint<T> {
        return dataPointType.makeDataPoint(frameData.semantic(key))
    }
typealias CaptureTimeSeriesFn = TimeSeriesCaptureScope<DebugInspector>.() -> Unit

/** Default feature captures. */
val defaultFeatureCaptures: CaptureTimeSeriesFn = {
    feature(FeatureCaptures.input)
    feature(FeatureCaptures.gestureDirection)
    feature(FeatureCaptures.output)
    feature(FeatureCaptures.outputTarget)
    feature(FeatureCaptures.springParameters, name = "outputSpring")
    feature(FeatureCaptures.isStable)
}

sealed class MotionValueToolkit<MotionValueType, GestureContextType> {
@@ -165,53 +166,30 @@ sealed class MotionValueToolkit<MotionValueType, GestureContextType> {
        motionTestRule: MotionTestRule<*>,
        spec: MotionSpec,
        createDerived: (underTest: MotionValueType) -> List<MotionValueType>,
        semantics: List<CapturedSemantics<*>>,
        initialValue: Float,
        initialDirection: InputDirection,
        directionChangeSlop: Float,
        stableThreshold: Float,
        verifyTimeSeries: TimeSeries.() -> VerifyTimeSeriesResult,
        capture: CaptureTimeSeriesFn,
        testInput: suspend (InputScope<MotionValueType, GestureContextType>).() -> Unit,
    )

    protected fun createTimeSeries(
    internal fun createTimeSeries(
        frameIds: List<FrameId>,
        frameData: List<Pair<String, List<FrameData>>>,
        semantics: List<CapturedSemantics<*>>,
        motionValueCaptures: List<MotionValueCapture>,
    ): TimeSeries {
        return TimeSeries(
            frameIds.toList(),
            buildList {
                frameData.forEach { (prefix, frames) ->
                    add(Feature("${prefix}input", frames.map { it.input.asDataPoint() }))
                    add(
                        Feature(
                            "${prefix}gestureDirection",
                            frames.map { it.gestureDirection.name.asDataPoint() },
                        )
                    )
                    add(Feature("${prefix}output", frames.map { it.output.asDataPoint() }))
                    add(
                        Feature(
                            "${prefix}outputTarget",
                            frames.map { it.outputTarget.asDataPoint() },
                        )
                    )
                    add(
                        Feature(
                            "${prefix}outputSpring",
                            frames.map { it.springParameters.asDataPoint() },
                        )
                    )
                    add(Feature("${prefix}isStable", frames.map { it.isStable.asDataPoint() }))

                    semantics.forEach { add(Feature(it.name, frames.map(it::toDataPoint))) }
            motionValueCaptures.flatMap { motionValueCapture ->
                motionValueCapture.propertyCollector.entries.map { (name, dataPoints) ->
                    Feature("${motionValueCapture.prefix}$name", dataPoints)
                }
            },
        )
    }

    protected fun verifyTimeSeries(
    internal fun verifyTimeSeries(
        motionTestRule: MotionTestRule<*>,
        timeSeries: TimeSeries,
        verificationFn: TimeSeries.() -> VerifyTimeSeriesResult,
@@ -241,3 +219,12 @@ sealed class MotionValueToolkit<MotionValueType, GestureContextType> {
        val FrameDuration = 16.milliseconds
    }
}

internal class MotionValueCapture(val debugger: DebugInspector, val prefix: String = "") {
    val propertyCollector = mutableMapOf<String, MutableList<DataPoint<*>>>()
    val captureScope = TimeSeriesCaptureScope(debugger, propertyCollector)

    fun captureCurrentFrame(captureFn: CaptureTimeSeriesFn) {
        captureFn(captureScope)
    }
}
+5 −10
Original line number Diff line number Diff line
@@ -19,7 +19,6 @@
package com.android.mechanics.testing

import android.animation.AnimatorTestRule
import com.android.mechanics.debug.FrameData
import com.android.mechanics.spec.InputDirection
import com.android.mechanics.spec.MotionSpec
import com.android.mechanics.view.DistanceGestureContext
@@ -49,12 +48,12 @@ class ViewMotionValueToolkit(private val animatorTestRule: AnimatorTestRule) :
        motionTestRule: MotionTestRule<*>,
        spec: MotionSpec,
        createDerived: (underTest: ViewMotionValue) -> List<ViewMotionValue>,
        semantics: List<CapturedSemantics<*>>,
        initialValue: Float,
        initialDirection: InputDirection,
        directionChangeSlop: Float,
        stableThreshold: Float,
        verifyTimeSeries: TimeSeries.() -> VerifyTimeSeriesResult,
        capture: CaptureTimeSeriesFn,
        testInput: suspend InputScope<ViewMotionValue, DistanceGestureContext>.() -> Unit,
    ) = runTest {
        val frameEmitter = MutableStateFlow<Long>(0)
@@ -74,17 +73,14 @@ class ViewMotionValueToolkit(private val animatorTestRule: AnimatorTestRule) :
            }

        val underTest = testHarness.underTest
        val inspectors = buildMap { put(underTest, underTest.debugInspector()) }
        val motionValueCapture = MotionValueCapture(underTest.debugInspector())
        val recordingJob = launch { testInput.invoke(testHarness) }

        val frameIds = mutableListOf<FrameId>()
        val frameData = mutableMapOf<ViewMotionValue, MutableList<FrameData>>()

        fun recordFrame(frameId: TimestampFrameId) {
            frameIds.add(frameId)
            inspectors.forEach { (motionValue, inspector) ->
                frameData.computeIfAbsent(motionValue) { mutableListOf() }.add(inspector.frame)
            }
            motionValueCapture.captureCurrentFrame(capture)
        }

        runBlocking(Dispatchers.Main) {
@@ -99,10 +95,9 @@ class ViewMotionValueToolkit(private val animatorTestRule: AnimatorTestRule) :
                runCurrent()
            }

            val timeSeries =
                createTimeSeries(frameIds, listOf("" to frameData.values.first()), semantics)
            val timeSeries = createTimeSeries(frameIds, listOf(motionValueCapture))

            inspectors.values.forEach { it.dispose() }
            motionValueCapture.debugger.dispose()
            underTest.dispose()
            verifyTimeSeries(motionTestRule, timeSeries, verifyTimeSeries)
        }
Loading