Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 7d629e95 authored by Mike Schneider's avatar Mike Schneider
Browse files

Protect MotionValue runtime from non-finite numbers

1) Ignore non-finite values produced by mappings
2) Require predefined mappings to be created with finite args only
3) Fix common source of `NaN`s from breakpoints at the same position

Bug: 394235639
Test: Unit tests
Flag: com.android.systemui.scene_container
Change-Id: I2d49a8617030fd3e331500c0106fedcb84f8fac5
parent bc403a68
Loading
Loading
Loading
Loading
+40 −9
Original line number Diff line number Diff line
@@ -16,6 +16,7 @@

package com.android.mechanics

import android.util.Log
import androidx.compose.runtime.FloatState
import androidx.compose.runtime.derivedStateOf
import androidx.compose.runtime.getValue
@@ -28,6 +29,7 @@ import androidx.compose.runtime.snapshotFlow
import androidx.compose.runtime.withFrameNanos
import androidx.compose.ui.util.fastCoerceAtLeast
import androidx.compose.ui.util.fastCoerceIn
import androidx.compose.ui.util.fastIsFinite
import androidx.compose.ui.util.lerp
import androidx.compose.ui.util.packFloats
import androidx.compose.ui.util.unpackFloat1
@@ -670,10 +672,23 @@ class MotionValue(
            SegmentChangeType.Direction,
            SegmentChangeType.Spec -> {
                // Determine the delta in the output, as produced by the old and new mapping.
                val delta =
                    currentSegment.mapping.map(currentInput) - lastSegment.mapping.map(currentInput)
                val currentMapping = currentSegment.mapping.map(currentInput)
                val lastMapping = lastSegment.mapping.map(currentInput)
                val delta = currentMapping - lastMapping

                val deltaIsFinite = delta.fastIsFinite()
                if (!deltaIsFinite) {
                    Log.wtf(
                        TAG,
                        "Delta between mappings is undefined!\n" +
                            "  MotionValue: $label\n" +
                            "  input: $currentInput\n" +
                            "  lastMapping: $lastMapping (lastSegment: $lastSegment)\n" +
                            "  currentMapping: $currentMapping (currentSegment: $currentSegment)",
                    )
                }

                if (delta == 0f) {
                if (delta == 0f || !deltaIsFinite) {
                    // Nothing new to animate.
                    lastAnimation
                } else {
@@ -765,14 +780,28 @@ class MotionValue(
                            )
                        lastAnimationTime = nextBreakpointCrossTime

                        val beforeBreakpoint = mappings[segmentIndex].map(nextBreakpoint.position)
                        val afterBreakpoint =
                            mappings[segmentIndex + directionOffset].map(nextBreakpoint.position)
                        val mappingBefore = mappings[segmentIndex]
                        val beforeBreakpoint = mappingBefore.map(nextBreakpoint.position)
                        val mappingAfter = mappings[segmentIndex + directionOffset]
                        val afterBreakpoint = mappingAfter.map(nextBreakpoint.position)

                        val delta = afterBreakpoint - beforeBreakpoint
                        val deltaIsFinite = delta.fastIsFinite()
                        if (!deltaIsFinite) {
                            Log.wtf(
                                TAG,
                                "Delta between breakpoints is undefined!\n" +
                                    "  MotionValue: $label\n" +
                                    "  position: ${nextBreakpoint.position}\n" +
                                    "  before: $beforeBreakpoint (mapping: $mappingBefore)\n" +
                                    "  after: $afterBreakpoint (mapping: $mappingAfter)",
                            )
                        }

                        if (deltaIsFinite) {
                            springTarget += delta
                            springState = springState.addDisplacement(-delta)

                        }
                        segmentIndex += directionOffset
                        lastBreakpoint = nextBreakpoint
                        guaranteeState =
@@ -828,7 +857,9 @@ class MotionValue(
    }

    private val currentDirectMapped: Float
        get() = currentSegment.mapping.map(currentInput()) - currentAnimation.targetValue
        get() {
            return currentSegment.mapping.map(currentInput()) - currentAnimation.targetValue
        }

    private val currentAnimatedDelta: Float
        get() = currentAnimation.targetValue + currentSpringState.displacement
+19 −12
Original line number Diff line number Diff line
@@ -417,6 +417,11 @@ private class DirectionalMotionSpecBuilderImpl(override val defaultSpring: Sprin
            check(!sourceValue.isNaN())

            val sourcePosition = breakpoints.last().position
            val breakpointDistance = atPosition - sourcePosition
            val mapping =
                if (breakpointDistance == 0f) {
                    Mapping.Fixed(sourceValue)
                } else {

                    if (fractionalMapping.isNaN()) {
                        val delta = targetValue - sourceValue
@@ -427,8 +432,10 @@ private class DirectionalMotionSpecBuilderImpl(override val defaultSpring: Sprin
                    }

                    val offset = sourceValue - (sourcePosition * fractionalMapping)
                    Mapping.Linear(fractionalMapping, offset)
                }

            mappings.add(Mapping.Linear(fractionalMapping, offset))
            mappings.add(mapping)
            targetValue = Float.NaN
            sourceValue = Float.NaN
            fractionalMapping = Float.NaN
+20 −14
Original line number Diff line number Diff line
@@ -305,20 +305,26 @@ private class FluentSpecBuilder<R>(
            check(!sourceValue.isNaN())

            val sourcePosition = breakpoints.last().position

            val breakpointDistance = atPosition - sourcePosition
            val mapping =
                if (breakpointDistance == 0f) {
                    Mapping.Fixed(sourceValue)
                } else {
                    if (fractionalMapping.isNaN()) {
                        val delta = targetValue - sourceValue
                fractionalMapping = delta / (atPosition - sourcePosition)
                        fractionalMapping = delta / breakpointDistance
                    } else {
                val delta = (atPosition - sourcePosition) * fractionalMapping
                        val delta = breakpointDistance * fractionalMapping
                        targetValue = sourceValue + delta
                    }

                    val offset =
                        if (buildForward) sourceValue - (sourcePosition * fractionalMapping)
                        else targetValue - (atPosition * fractionalMapping)
                    Mapping.Linear(fractionalMapping, offset)
                }

            mappings.add(Mapping.Linear(fractionalMapping, offset))
            mappings.add(mapping)
            targetValue = Float.NaN
            sourceValue = Float.NaN
            fractionalMapping = Float.NaN
+15 −0
Original line number Diff line number Diff line
@@ -95,6 +95,10 @@ fun interface Mapping {

    /** `f(x) = value` */
    data class Fixed(val value: Float) : Mapping {
        init {
            require(value.isFinite())
        }

        override fun map(input: Float): Float {
            return value
        }
@@ -102,6 +106,11 @@ fun interface Mapping {

    /** `f(x) = factor*x + offset` */
    data class Linear(val factor: Float, val offset: Float = 0f) : Mapping {
        init {
            require(factor.isFinite())
            require(offset.isFinite())
        }

        override fun map(input: Float): Float {
            return input * factor + offset
        }
@@ -109,6 +118,12 @@ fun interface Mapping {

    data class Tanh(val scaling: Float, val tilt: Float, val offset: Float = 0f) : Mapping {

        init {
            require(scaling.isFinite())
            require(tilt.isFinite())
            require(offset.isFinite())
        }

        override fun map(input: Float): Float {
            return scaling * kotlin.math.tanh((input + offset) / (scaling * tilt))
        }
+98 −1
Original line number Diff line number Diff line
@@ -18,6 +18,8 @@

package com.android.mechanics

import android.util.Log
import android.util.Log.TerribleFailureHandler
import androidx.compose.runtime.LaunchedEffect
import androidx.compose.runtime.mutableFloatStateOf
import androidx.compose.runtime.mutableStateOf
@@ -26,7 +28,6 @@ import androidx.compose.ui.test.ExperimentalTestApi
import androidx.compose.ui.test.TestMonotonicFrameClock
import androidx.compose.ui.test.junit4.createComposeRule
import androidx.test.ext.junit.runners.AndroidJUnit4
import com.android.internal.R.id.primary
import com.android.mechanics.spec.BreakpointKey
import com.android.mechanics.spec.DirectionalMotionSpec
import com.android.mechanics.spec.Guarantee
@@ -42,6 +43,8 @@ import com.android.mechanics.testing.MotionValueToolkit.Companion.dataPoints
import com.android.mechanics.testing.MotionValueToolkit.Companion.input
import com.android.mechanics.testing.MotionValueToolkit.Companion.isStable
import com.android.mechanics.testing.MotionValueToolkit.Companion.output
import com.android.mechanics.testing.VerifyTimeSeriesResult.AssertTimeSeriesMatchesGolden
import com.android.mechanics.testing.VerifyTimeSeriesResult.SkipGoldenVerification
import com.android.mechanics.testing.goldenTest
import com.google.common.truth.Truth.assertThat
import com.google.common.truth.Truth.assertWithMessage
@@ -54,6 +57,7 @@ import kotlinx.coroutines.test.runTest
import kotlinx.coroutines.withContext
import org.junit.Rule
import org.junit.Test
import org.junit.rules.ExternalResource
import org.junit.runner.RunWith
import platform.test.motion.MotionTestRule
import platform.test.motion.testing.createGoldenPathManager
@@ -65,6 +69,7 @@ class MotionValueTest {

    @get:Rule(order = 0) val rule = createComposeRule()
    @get:Rule(order = 1) val motion = MotionTestRule(MotionValueToolkit(rule), goldenPathManager)
    @get:Rule(order = 2) val wtfLog = WtfLogRule()

    @Test
    fun emptySpec_outputMatchesInput_withoutAnimation() =
@@ -75,6 +80,8 @@ class MotionValueTest {
                assertThat(output).containsExactlyElementsIn(input).inOrder()
                // There must never be an ongoing animation.
                assertThat(isStable).doesNotContain(false)

                AssertTimeSeriesMatchesGolden
            },
        ) {
            animateValueTo(100f)
@@ -353,6 +360,8 @@ class MotionValueTest {
                    .inOrder()
                // and its never animated.
                assertThat(dataPoints<Float>("derived-isStable")).doesNotContain(false)

                AssertTimeSeriesMatchesGolden
            },
        ) {
            animateValueTo(1f, changePerFrame = 0.1f)
@@ -379,6 +388,75 @@ class MotionValueTest {
        }
    }

    @Test
    fun nonFiniteNumbers_producesNaN_recoversOnSubsequentFrames() {
        motion.goldenTest(
            spec = specBuilder(Mapping { if (it >= 1f) Float.NaN else 0f }).complete(),
            verifyTimeSeries = {
                assertThat(output.drop(1).take(5))
                    .containsExactlyElementsIn(listOf(0f, Float.NaN, Float.NaN, 0f, 0f))
                    .inOrder()
                SkipGoldenVerification
            },
        ) {
            animatedInputSequence(0f, 1f, 1f, 0f, 0f)
        }

        assertThat(wtfLog.loggedFailures).isEmpty()
    }

    @Test
    fun nonFiniteNumbers_segmentChange_skipsAnimation() {
        motion.goldenTest(
            spec = MotionSpec.Empty,
            verifyTimeSeries = {
                // The mappings produce a non-finite number during a segment change.
                // The animation thereof is skipped to avoid poisoning the state with non-finite
                // numbers
                assertThat(output.drop(1).take(5))
                    .containsExactlyElementsIn(listOf(0f, 1f, Float.NaN, 0f, 0f))
                    .inOrder()
                SkipGoldenVerification
            },
        ) {
            animatedInputSequence(0f, 1f)
            underTest.spec =
                specBuilder()
                    .toBreakpoint(0f)
                    .completeWith(Mapping { if (it >= 1f) Float.NaN else 0f })
            awaitFrames()

            animatedInputSequence(0f, 0f)
        }

        assertThat(wtfLog.loggedFailures).hasSize(1)
        assertThat(wtfLog.loggedFailures.first()).startsWith("Delta between mappings is undefined")
    }

    @Test
    fun nonFiniteNumbers_segmentTraverse_skipsAnimation() {
        motion.goldenTest(
            spec =
                specBuilder(Mapping.Zero)
                    .toBreakpoint(1f)
                    .completeWith(Mapping { if (it < 2f) Float.NaN else 2f }),
            verifyTimeSeries = {
                // The mappings produce a non-finite number during a breakpoint traversal.
                // The animation thereof is skipped to avoid poisoning the state with non-finite
                // numbers
                assertThat(output.drop(1).take(6))
                    .containsExactlyElementsIn(listOf(0f, 0f, Float.NaN, Float.NaN, 2f, 2f))
                    .inOrder()
                SkipGoldenVerification
            },
        ) {
            animatedInputSequence(0f, 0.5f, 1f, 1.5f, 2f, 3f)
        }
        assertThat(wtfLog.loggedFailures).hasSize(1)
        assertThat(wtfLog.loggedFailures.first())
            .startsWith("Delta between breakpoints is undefined")
    }

    @Test
    fun keepRunning_concurrentInvocationThrows() = runTestWithFrameClock { testScheduler, _ ->
        val underTest = MotionValue({ 1f }, FakeGestureContext, label = "Foo")
@@ -563,6 +641,25 @@ class MotionValueTest {
        }
    }

    class WtfLogRule : ExternalResource() {
        val loggedFailures = mutableListOf<String>()

        private lateinit var oldHandler: TerribleFailureHandler

        override fun before() {
            oldHandler =
                Log.setWtfHandler { tag, what, _ ->
                    if (tag == MotionValue.TAG) {
                        loggedFailures.add(checkNotNull(what.message))
                    }
                }
        }

        override fun after() {
            Log.setWtfHandler(oldHandler)
        }
    }

    companion object {
        val B1 = BreakpointKey("breakpoint1")
        val B2 = BreakpointKey("breakpoint2")
Loading