Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 7007a86e authored by Jordan Demeulenaere's avatar Jordan Demeulenaere
Browse files

Prevent elements from jump-cutting after an interruption

This CL adds initial support for interruptions.

Before this CL, whenever a new transition was started, all elements
taking part in this transition (i.e. elements in the transition
fromScene or toScene) would instantly jump to their new state in this
transition.

This CL improves this by tracking the current transition used by an
element. If that transition is interrupted (i.e. the transition for that
element changed), we save the current state of the element. Then, the
next time we compute the state of the element, which is the first time
we compute the state of the element given the new transition, we also
compute and save the diff/delta between the state we saved earlier and
the new state. This delta is then animated to zero and added to the
state computed with the new transition. That way, we nicely animate to
the new state of the new transition while preventing jump cuts caused by
the interruption.

This CL adds support for elements offset, alpha and scale. The size will
be supported in a follow-up CL; it is quite harder to support given that
elements can be measured only once.

The unit test focus on offset only at the moment, because testing scale
and alpha in a unit test can't really be done without changing
production code. I *might* add some screenshot tests for this in another
CL.

See b/290930950#comment5 for details on how this works.

Bug: 290930950
Test: ElementTest
Test: Performed a lot of different transitions in a row in the STL demo
 manually.
Test: This change should not have any impact on usages outside of the
 STL demo given that interruptions were disabled everywhere in
 ag/26621600 and will be enabled later after thorough manual testing.

Flag: N/A

Change-Id: I5b2ee38d71b947cd3aaf587194281c75eca1a0ca
parent 606bcdf3
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -190,4 +190,4 @@ private class OneOffTransition(

// TODO(b/290184746): Compute a good default visibility threshold that depends on the layout size
// and screen density.
private const val ProgressVisibilityThreshold = 1e-3f
internal const val ProgressVisibilityThreshold = 1e-3f
+276 −33
Original line number Diff line number Diff line
@@ -24,6 +24,7 @@ import androidx.compose.runtime.snapshots.SnapshotStateMap
import androidx.compose.ui.ExperimentalComposeUiApi
import androidx.compose.ui.Modifier
import androidx.compose.ui.geometry.Offset
import androidx.compose.ui.geometry.isSpecified
import androidx.compose.ui.geometry.isUnspecified
import androidx.compose.ui.geometry.lerp
import androidx.compose.ui.graphics.CompositingStrategy
@@ -55,9 +56,15 @@ import kotlinx.coroutines.launch
internal class Element(val key: ElementKey) {
    /** The mapping between a scene and the state this element has in that scene, if any. */
    // TODO(b/316901148): Make this a normal map instead once we can make sure that new transitions
    // are first seen by composition then layout/drawing code. See 316901148#comment2 for details.
    // are first seen by composition then layout/drawing code. See b/316901148#comment2 for details.
    val sceneStates = SnapshotStateMap<SceneKey, SceneState>()

    /**
     * The last transition that was used when computing the state (size, position and alpha) of this
     * element in any scene, or `null` if it was last laid out when idle.
     */
    var lastTransition: TransitionState.Transition? = null

    override fun toString(): String {
        return "Element(key=$key)"
    }
@@ -65,9 +72,33 @@ internal class Element(val key: ElementKey) {
    /** The last and target state of this element in a given scene. */
    @Stable
    class SceneState(val scene: SceneKey) {
        /**
         * The *target* state of this element in this scene, i.e. the state of this element when we
         * are idle on this scene.
         */
        var targetSize by mutableStateOf(SizeUnspecified)
        var targetOffset by mutableStateOf(Offset.Unspecified)

        /** The last state this element had in this scene. */
        var lastOffset = Offset.Unspecified
        var lastScale = Scale.Unspecified
        var lastAlpha = AlphaUnspecified

        /** The state of this element in this scene right before the last interruption (if any). */
        var offsetBeforeInterruption = Offset.Unspecified
        var scaleBeforeInterruption = Scale.Unspecified
        var alphaBeforeInterruption = AlphaUnspecified

        /**
         * The delta values to add to this element state to have smoother interruptions. These
         * should be multiplied by the
         * [current interruption progress][TransitionState.Transition.interruptionProgress] so that
         * they nicely animate from their values down to 0.
         */
        var offsetInterruptionDelta = Offset.Zero
        var scaleInterruptionDelta = Scale.Zero
        var alphaInterruptionDelta = 0f

        /**
         * The attached [ElementNode] a Modifier.element() for a given element and scene. During
         * composition, this set could have 0 to 2 elements. After composition and after all
@@ -78,12 +109,15 @@ internal class Element(val key: ElementKey) {

    companion object {
        val SizeUnspecified = IntSize(Int.MAX_VALUE, Int.MAX_VALUE)
        val AlphaUnspecified = Float.MAX_VALUE
    }
}

data class Scale(val scaleX: Float, val scaleY: Float, val pivot: Offset = Offset.Unspecified) {
    companion object {
        val Default = Scale(1f, 1f, Offset.Unspecified)
        val Zero = Scale(0f, 0f, Offset.Zero)
        val Unspecified = Scale(Float.MAX_VALUE, Float.MAX_VALUE, Offset.Unspecified)
    }
}

@@ -212,6 +246,10 @@ internal class ElementNode(
        val isOtherSceneOverscrolling = overscrollScene != null && overscrollScene != scene.key
        val isNotPartOfAnyOngoingTransitions = transitions.isNotEmpty() && transition == null
        if (isNotPartOfAnyOngoingTransitions || isOtherSceneOverscrolling) {
            sceneState.lastOffset = Offset.Unspecified
            sceneState.lastScale = Scale.Unspecified
            sceneState.lastAlpha = Element.AlphaUnspecified

            val placeable = measurable.measure(constraints)
            return layout(placeable.width, placeable.height) {}
        }
@@ -233,7 +271,7 @@ internal class ElementNode(

    override fun ContentDrawScope.draw() {
        val transition = elementTransition(element, layoutImpl.state.currentTransitions)
        val drawScale = getDrawScale(layoutImpl, scene, element, transition)
        val drawScale = getDrawScale(layoutImpl, scene, element, transition, sceneState)
        if (drawScale == Scale.Default) {
            drawContent()
        } else {
@@ -276,9 +314,117 @@ private fun elementTransition(
    element: Element,
    transitions: List<TransitionState.Transition>,
): TransitionState.Transition? {
    return transitions.fastLastOrNull { transition ->
    val transition =
        transitions.fastLastOrNull { transition ->
            transition.fromScene in element.sceneStates || transition.toScene in element.sceneStates
        }

    val previousTransition = element.lastTransition
    element.lastTransition = transition

    if (transition != previousTransition && transition != null && previousTransition != null) {
        // The previous transition was interrupted by another transition.
        prepareInterruption(element)
    }

    if (transition == null && previousTransition != null) {
        // The transition was just finished.
        element.sceneStates.values.forEach { sceneState ->
            sceneState.offsetInterruptionDelta = Offset.Zero
            sceneState.scaleInterruptionDelta = Scale.Zero
            sceneState.alphaInterruptionDelta = 0f
        }
    }

    return transition
}

private fun prepareInterruption(element: Element) {
    // We look for the last unique state of this element so that we animate the delta with its
    // future state.
    val sceneStates = element.sceneStates.values
    var lastUniqueState: Element.SceneState? = null
    for (sceneState in sceneStates) {
        val offset = sceneState.lastOffset

        // If the element was placed in this scene...
        if (offset != Offset.Unspecified) {
            // ... and it is the first (and potentially the only) scene where the element was
            // placed, save the state for later.
            if (lastUniqueState == null) {
                lastUniqueState = sceneState
            } else {
                // The element was placed in multiple scenes: we abort the interruption for this
                // element.
                // TODO(b/290930950): Better support cases where a shared element animation is
                // disabled and the same element is drawn/placed in multiple scenes at the same
                // time.
                lastUniqueState = null
                break
            }
        }
    }

    val lastOffset = lastUniqueState?.lastOffset ?: Offset.Unspecified
    val lastScale = lastUniqueState?.lastScale ?: Scale.Unspecified
    val lastAlpha = lastUniqueState?.lastAlpha ?: Element.AlphaUnspecified

    // Store the state of the element before the interruption and reset the deltas.
    sceneStates.forEach { sceneState ->
        sceneState.offsetBeforeInterruption = lastOffset
        sceneState.scaleBeforeInterruption = lastScale
        sceneState.alphaBeforeInterruption = lastAlpha

        sceneState.offsetInterruptionDelta = Offset.Zero
        sceneState.scaleInterruptionDelta = Scale.Zero
        sceneState.alphaInterruptionDelta = 0f
    }
}

/**
 * Compute what [value] should be if we take the
 * [interruption progress][TransitionState.Transition.interruptionProgress] of [transition] into
 * account.
 */
private inline fun <T> computeInterruptedValue(
    layoutImpl: SceneTransitionLayoutImpl,
    transition: TransitionState.Transition?,
    value: T,
    unspecifiedValue: T,
    zeroValue: T,
    getValueBeforeInterruption: () -> T,
    setValueBeforeInterruption: (T) -> Unit,
    getInterruptionDelta: () -> T,
    setInterruptionDelta: (T) -> Unit,
    diff: (a: T, b: T) -> T, // a - b
    add: (a: T, b: T, bProgress: Float) -> T, // a + (b * bProgress)
): T {
    val valueBeforeInterruption = getValueBeforeInterruption()

    // If the value before the interruption is specified, it means that this is the first time we
    // compute [value] right after an interruption.
    if (valueBeforeInterruption != unspecifiedValue) {
        // Compute and store the delta between the value before the interruption and the current
        // value.
        setInterruptionDelta(diff(valueBeforeInterruption, value))

        // Reset the value before interruption now that we processed it.
        setValueBeforeInterruption(unspecifiedValue)
    }

    val delta = getInterruptionDelta()
    return if (delta == zeroValue || transition == null) {
        // There was no interruption or there is no transition: just return the value.
        value
    } else {
        // Add `delta * interruptionProgress` to the value so that we animate to value.
        val interruptionProgress = transition.interruptionProgress(layoutImpl)
        if (interruptionProgress == 0f) {
            value
        } else {
            add(value, delta, interruptionProgress)
        }
    }
}

private fun shouldPlaceElement(
@@ -417,8 +563,10 @@ private fun elementAlpha(
    scene: Scene,
    element: Element,
    transition: TransitionState.Transition?,
    sceneState: Element.SceneState,
): Float {
    return computeValue(
    val alpha =
        computeValue(
                layoutImpl,
                scene,
                element,
@@ -431,6 +579,31 @@ private fun elementAlpha(
                ::lerp,
            )
            .fastCoerceIn(0f, 1f)

    val interruptedAlpha = interruptedAlpha(layoutImpl, transition, sceneState, alpha)
    sceneState.lastAlpha = interruptedAlpha
    return interruptedAlpha
}

private fun interruptedAlpha(
    layoutImpl: SceneTransitionLayoutImpl,
    transition: TransitionState.Transition?,
    sceneState: Element.SceneState,
    alpha: Float,
): Float {
    return computeInterruptedValue(
        layoutImpl,
        transition,
        value = alpha,
        unspecifiedValue = Element.AlphaUnspecified,
        zeroValue = 0f,
        getValueBeforeInterruption = { sceneState.alphaBeforeInterruption },
        setValueBeforeInterruption = { sceneState.alphaBeforeInterruption = it },
        getInterruptionDelta = { sceneState.alphaInterruptionDelta },
        setInterruptionDelta = { sceneState.alphaInterruptionDelta = it },
        diff = { a, b -> a - b },
        add = { a, b, bProgress -> a + b * bProgress },
    )
}

@OptIn(ExperimentalComposeUiApi::class)
@@ -480,13 +653,15 @@ private fun ApproachMeasureScope.measure(
        )
}

private fun getDrawScale(
private fun ContentDrawScope.getDrawScale(
    layoutImpl: SceneTransitionLayoutImpl,
    scene: Scene,
    element: Element,
    transition: TransitionState.Transition?,
    sceneState: Element.SceneState,
): Scale {
    return computeValue(
    val scale =
        computeValue(
            layoutImpl,
            scene,
            element,
@@ -498,6 +673,50 @@ private fun getDrawScale(
            isSpecified = { true },
            ::lerp,
        )

    fun Offset.specifiedOrCenter(): Offset {
        return this.takeIf { isSpecified } ?: center
    }

    val interruptedScale =
        computeInterruptedValue(
            layoutImpl,
            transition,
            value = scale,
            unspecifiedValue = Scale.Unspecified,
            zeroValue = Scale.Zero,
            getValueBeforeInterruption = { sceneState.scaleBeforeInterruption },
            setValueBeforeInterruption = { sceneState.scaleBeforeInterruption = it },
            getInterruptionDelta = { sceneState.scaleInterruptionDelta },
            setInterruptionDelta = { sceneState.scaleInterruptionDelta = it },
            diff = { a, b ->
                Scale(
                    scaleX = a.scaleX - b.scaleX,
                    scaleY = a.scaleY - b.scaleY,
                    pivot =
                        if (a.pivot.isUnspecified && b.pivot.isUnspecified) {
                            Offset.Unspecified
                        } else {
                            a.pivot.specifiedOrCenter() - b.pivot.specifiedOrCenter()
                        }
                )
            },
            add = { a, b, bProgress ->
                Scale(
                    scaleX = a.scaleX + b.scaleX * bProgress,
                    scaleY = a.scaleY + b.scaleY * bProgress,
                    pivot =
                        if (a.pivot.isUnspecified && b.pivot.isUnspecified) {
                            Offset.Unspecified
                        } else {
                            a.pivot.specifiedOrCenter() + b.pivot.specifiedOrCenter() * bProgress
                        }
                )
            }
        )

    sceneState.lastScale = interruptedScale
    return interruptedScale
}

@OptIn(ExperimentalComposeUiApi::class)
@@ -524,6 +743,8 @@ private fun ApproachMeasureScope.place(

        // No need to place the element in this scene if we don't want to draw it anyways.
        if (!shouldPlaceElement(layoutImpl, scene, element, transition)) {
            sceneState.lastOffset = Offset.Unspecified
            sceneState.offsetBeforeInterruption = Offset.Unspecified
            return
        }

@@ -542,15 +763,37 @@ private fun ApproachMeasureScope.place(
                ::lerp,
            )

        val offset = (targetOffset - currentOffset).round()
        if (isElementOpaque(scene, element, transition)) {
        val interruptedOffset =
            computeInterruptedValue(
                layoutImpl,
                transition,
                value = targetOffset,
                unspecifiedValue = Offset.Unspecified,
                zeroValue = Offset.Zero,
                getValueBeforeInterruption = { sceneState.offsetBeforeInterruption },
                setValueBeforeInterruption = { sceneState.offsetBeforeInterruption = it },
                getInterruptionDelta = { sceneState.offsetInterruptionDelta },
                setInterruptionDelta = { sceneState.offsetInterruptionDelta = it },
                diff = { a, b -> a - b },
                add = { a, b, bProgress -> a + b * bProgress },
            )

        sceneState.lastOffset = interruptedOffset

        val offset = (interruptedOffset - currentOffset).round()
        if (
            isElementOpaque(scene, element, transition) &&
                interruptedAlpha(layoutImpl, transition, sceneState, alpha = 1f) == 1f
        ) {
            sceneState.lastAlpha = 1f

            // TODO(b/291071158): Call placeWithLayer() if offset != IntOffset.Zero and size is not
            // animated once b/305195729 is fixed. Test that drawing is not invalidated in that
            // case.
            placeable.place(offset)
        } else {
            placeable.placeWithLayer(offset) {
                alpha = elementAlpha(layoutImpl, scene, element, transition)
                alpha = elementAlpha(layoutImpl, scene, element, transition, sceneState)
                compositingStrategy = CompositingStrategy.ModulateAlpha
            }
        }
+1 −1
Original line number Diff line number Diff line
@@ -49,7 +49,7 @@ internal class SceneTransitionLayoutImpl(
    internal var swipeSourceDetector: SwipeSourceDetector,
    internal var transitionInterceptionThreshold: Float,
    builder: SceneTransitionLayoutScope.() -> Unit,
    private val coroutineScope: CoroutineScope,
    internal val coroutineScope: CoroutineScope,
) {
    /**
     * The map of [Scene]s.
+37 −0
Original line number Diff line number Diff line
@@ -18,6 +18,9 @@ package com.android.compose.animation.scene

import android.util.Log
import androidx.annotation.VisibleForTesting
import androidx.compose.animation.core.Animatable
import androidx.compose.animation.core.AnimationVector1D
import androidx.compose.animation.core.spring
import androidx.compose.foundation.gestures.Orientation
import androidx.compose.runtime.Composable
import androidx.compose.runtime.LaunchedEffect
@@ -34,6 +37,7 @@ import kotlin.math.absoluteValue
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Job
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.launch

/**
 * The state of a [SceneTransitionLayout].
@@ -253,6 +257,12 @@ sealed interface TransitionState {
                }
            }

        /**
         * An animatable that animates from 1f to 0f. This will be used to nicely animate the sudden
         * jump of values when this transitions interrupts another one.
         */
        private var interruptionDecay: Animatable<Float, AnimationVector1D>? = null

        init {
            check(fromScene != toScene)
        }
@@ -289,6 +299,33 @@ sealed interface TransitionState {
            fromOverscrollSpec = fromSpec
            toOverscrollSpec = toSpec
        }

        internal open fun interruptionProgress(
            layoutImpl: SceneTransitionLayoutImpl,
        ): Float {
            if (!layoutImpl.state.enableInterruptions) {
                return 0f
            }

            fun create(): Animatable<Float, AnimationVector1D> {
                val animatable = Animatable(1f, visibilityThreshold = ProgressVisibilityThreshold)
                layoutImpl.coroutineScope.launch {
                    val swipeSpec = layoutImpl.state.transitions.defaultSwipeSpec
                    val progressSpec =
                        spring(
                            stiffness = swipeSpec.stiffness,
                            dampingRatio = swipeSpec.dampingRatio,
                            visibilityThreshold = ProgressVisibilityThreshold,
                        )
                    animatable.animateTo(0f, progressSpec)
                }

                return animatable
            }

            val animatable = interruptionDecay ?: create().also { interruptionDecay = it }
            return animatable.value
        }
    }

    interface HasOverscrollProperties {
+122 −0
Original line number Diff line number Diff line
@@ -37,6 +37,7 @@ import androidx.compose.runtime.Composable
import androidx.compose.runtime.LaunchedEffect
import androidx.compose.runtime.SideEffect
import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableFloatStateOf
import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.rememberCoroutineScope
import androidx.compose.runtime.setValue
@@ -55,7 +56,10 @@ import androidx.compose.ui.test.onNodeWithTag
import androidx.compose.ui.test.onRoot
import androidx.compose.ui.test.performTouchInput
import androidx.compose.ui.unit.Dp
import androidx.compose.ui.unit.DpOffset
import androidx.compose.ui.unit.DpSize
import androidx.compose.ui.unit.dp
import androidx.compose.ui.unit.lerp
import androidx.test.ext.junit.runners.AndroidJUnit4
import com.android.compose.animation.scene.TestScenes.SceneA
import com.android.compose.animation.scene.TestScenes.SceneB
@@ -1019,4 +1023,122 @@ class ElementTest {
        rule.onNode(isElement(TestElements.Foo)).assertDoesNotExist()
        rule.onNode(isElement(TestElements.Bar)).assertPositionInRootIsEqualTo(100.dp, 100.dp)
    }

    @Test
    fun interruption() = runTest {
        // 4 frames of animation.
        val duration = 4 * 16

        val state =
            MutableSceneTransitionLayoutStateImpl(
                SceneA,
                transitions {
                    from(SceneA, to = SceneB) { spec = tween(duration, easing = LinearEasing) }
                    from(SceneB, to = SceneC) { spec = tween(duration, easing = LinearEasing) }
                },
                enableInterruptions = false,
            )

        val layoutSize = DpSize(200.dp, 100.dp)
        val fooSize = DpSize(20.dp, 10.dp)

        @Composable
        fun SceneScope.Foo(modifier: Modifier = Modifier) {
            Box(modifier.element(TestElements.Foo).size(fooSize))
        }

        rule.setContent {
            SceneTransitionLayout(state, Modifier.size(layoutSize)) {
                // In scene A, Foo is aligned at the TopStart.
                scene(SceneA) {
                    Box(Modifier.fillMaxSize()) { Foo(Modifier.align(Alignment.TopStart)) }
                }

                // In scene B, Foo is aligned at the TopEnd, so it moves horizontally when coming
                // from A.
                scene(SceneB) {
                    Box(Modifier.fillMaxSize()) { Foo(Modifier.align(Alignment.TopEnd)) }
                }

                // In scene C, Foo is aligned at the BottomEnd, so it moves vertically when coming
                // from B.
                scene(SceneC) {
                    Box(Modifier.fillMaxSize()) { Foo(Modifier.align(Alignment.BottomEnd)) }
                }
            }
        }

        // The offset of Foo when idle in A, B or C.
        val offsetInA = DpOffset.Zero
        val offsetInB = DpOffset(layoutSize.width - fooSize.width, 0.dp)
        val offsetInC =
            DpOffset(layoutSize.width - fooSize.width, layoutSize.height - fooSize.height)

        // Initial state (idle in A).
        rule
            .onNode(isElement(TestElements.Foo, SceneA))
            .assertPositionInRootIsEqualTo(offsetInA.x, offsetInA.y)

        // Current transition is A => B at 50%.
        val aToBProgress = 0.5f
        val aToB =
            transition(
                from = SceneA,
                to = SceneB,
                progress = { aToBProgress },
                onFinish = neverFinish(),
            )
        val offsetInAToB = lerp(offsetInA, offsetInB, aToBProgress)
        rule.runOnUiThread { state.startTransition(aToB, transitionKey = null) }
        rule
            .onNode(isElement(TestElements.Foo, SceneB))
            .assertPositionInRootIsEqualTo(offsetInAToB.x, offsetInAToB.y)

        // Start B => C at 0%.
        var bToCProgress by mutableFloatStateOf(0f)
        var interruptionProgress by mutableFloatStateOf(1f)
        val bToC =
            transition(
                from = SceneB,
                to = SceneC,
                progress = { bToCProgress },
                interruptionProgress = { interruptionProgress },
            )
        rule.runOnUiThread { state.startTransition(bToC, transitionKey = null) }

        // The offset interruption delta, which will be multiplied by the interruption progress then
        // added to the current transition offset.
        val interruptionDelta = offsetInAToB - offsetInB

        // Interruption progress is at 100% and bToC is at 0%, so Foo should be at the same offset
        // as right before the interruption.
        rule
            .onNode(isElement(TestElements.Foo, SceneC))
            .assertPositionInRootIsEqualTo(offsetInAToB.x, offsetInAToB.y)

        // Move the transition forward at 30% and set the interruption progress to 50%.
        bToCProgress = 0.3f
        interruptionProgress = 0.5f
        val offsetInBToC = lerp(offsetInB, offsetInC, bToCProgress)
        val offsetInBToCWithInterruption =
            offsetInBToC +
                DpOffset(
                    interruptionDelta.x * interruptionProgress,
                    interruptionDelta.y * interruptionProgress,
                )
        rule.waitForIdle()
        rule
            .onNode(isElement(TestElements.Foo, SceneC))
            .assertPositionInRootIsEqualTo(
                offsetInBToCWithInterruption.x,
                offsetInBToCWithInterruption.y,
            )

        // Finish the transition and interruption.
        bToCProgress = 1f
        interruptionProgress = 0f
        rule
            .onNode(isElement(TestElements.Foo, SceneC))
            .assertPositionInRootIsEqualTo(offsetInC.x, offsetInC.y)
    }
}
Loading