Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 2a0cf0d7 authored by Peter Kalauskas's avatar Peter Kalauskas
Browse files

Add snapshot state to perf tests

Test: atest StructuredConcurrencyPerfTests --test-filter '.*benchmark_stateObservers_shallow'
Bug: 404377320
Flag: EXEMPT test
Change-Id: I0b8fc4acbce563a66d8683adf9f97541e49ca469
parent 73e783d4
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -41,6 +41,7 @@ java_defaults {
        // binary, which will be run on device to create metrics
        "androidx.benchmark_benchmark-macro-junit4",
        "androidx.benchmark_benchmark-traceprocessor",
        "androidx.compose.runtime_runtime",
    ],
    test_suites: ["performance-tests"],
    platform_apis: true,
+86 −0
Original line number Diff line number Diff line
/*
 * Copyright (C) 2025 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.android.app.concurrent.benchmark

import com.android.app.concurrent.benchmark.base.BaseCoroutineBenchmark
import com.android.app.concurrent.benchmark.base.BaseCoroutineBenchmark.Companion.ExecutorThreadScopeBuilder
import com.android.app.concurrent.benchmark.base.BaseExecutorBenchmark
import com.android.app.concurrent.benchmark.base.BaseExecutorBenchmark.Companion.ExecutorThreadBuilder
import com.android.app.concurrent.benchmark.base.StateCollectBenchmark
import com.android.app.concurrent.benchmark.base.times
import com.android.app.concurrent.benchmark.builder.BenchmarkWithStateProvider
import com.android.app.concurrent.benchmark.builder.SnapshotStateCoroutineBuilder
import com.android.app.concurrent.benchmark.builder.SnapshotStateExecutorBuilder
import com.android.app.concurrent.benchmark.builder.StateBuilder
import com.android.app.concurrent.benchmark.util.ThreadFactory
import java.util.concurrent.Executor
import kotlinx.coroutines.CoroutineScope
import org.junit.FixMethodOrder
import org.junit.runner.RunWith
import org.junit.runners.MethodSorters
import org.junit.runners.Parameterized
import org.junit.runners.Parameterized.Parameters

@RunWith(Parameterized::class)
@FixMethodOrder(MethodSorters.NAME_ASCENDING)
class SnapshotStateCollectExecutorBenchmark(
    threadParam: ThreadFactory<Any, Executor>,
    override val producerCount: Int,
    override val consumerCount: Int,
) : BaseSnapshotStateExecutorBenchmark(threadParam), StateCollectBenchmark {

    companion object {
        @Parameters(name = "{0},{1},{2}")
        @JvmStatic
        fun getDispatchers() =
            listOf(ExecutorThreadBuilder) *
                StateCollectBenchmark.PRODUCER_LIST *
                StateCollectBenchmark.CONSUMER_LIST
    }
}

@RunWith(Parameterized::class)
@FixMethodOrder(MethodSorters.NAME_ASCENDING)
class SnapshotStateCollectCoroutineBenchmark(
    threadParam: ThreadFactory<Any, CoroutineScope>,
    override val producerCount: Int,
    override val consumerCount: Int,
) : BaseSnapshotStateCoroutineBenchmark(threadParam), StateCollectBenchmark {

    companion object {
        @Parameters(name = "{0},{1},{2}")
        @JvmStatic
        fun getDispatchers() =
            listOf(ExecutorThreadScopeBuilder) *
                StateCollectBenchmark.PRODUCER_LIST *
                StateCollectBenchmark.CONSUMER_LIST
    }
}

abstract class BaseSnapshotStateExecutorBenchmark(threadParam: ThreadFactory<Any, Executor>) :
    BaseExecutorBenchmark(threadParam), BenchmarkWithStateProvider {

    override fun <T> getStateBuilder(): StateBuilder<*, *, T> =
        SnapshotStateExecutorBuilder(executor)
}

abstract class BaseSnapshotStateCoroutineBenchmark(
    threadParam: ThreadFactory<Any, CoroutineScope>
) : BaseCoroutineBenchmark(threadParam), BenchmarkWithStateProvider {

    override fun <T> getStateBuilder(): StateBuilder<*, *, T> =
        SnapshotStateCoroutineBuilder(bgScope)
}
+9 −9
Original line number Diff line number Diff line
@@ -33,25 +33,25 @@ interface StateCollectBenchmark : BenchmarkWithStateProvider {
    val consumerCount: Int

    companion object {
        val PRODUCER_LIST = listOf(1, 2, 5, 10, 25)
        val CONSUMER_LIST = listOf(1, 10, 25, 50, 100)
        val PRODUCER_LIST = listOf(10, 25, 50)
        val CONSUMER_LIST = listOf(10, 25, 50)
    }

    @Test
    fun benchmark_stateListeners() {
    fun benchmark_stateObservers_shallow() {
        class Benchmark<M : R, R>(stateBuilder: StateBuilder<M, R, Int>) :
            StateBenchmarkTask<M, R, Int>(stateBuilder) {
            var receivedVal = Array(producerCount) { IntArray(consumerCount) }
            var receivedVal = Array(consumerCount) { IntArray(producerCount) }
            var producers = List(producerCount) { stateBuilder.createMutableState(0) }

            override fun ConcurrentBenchmarkBuilder.build() {
                if (consumerCount != 0) {
                    beforeFirstIteration(producerCount * consumerCount) { barrier ->
                        repeat(consumerCount) { consumerIndex ->
                            stateBuilder.readScope {
                                producers.forEachIndexed { producerIndex, state ->
                                repeat(consumerCount) { consumerIndex ->
                                    state.observe { newValue ->
                                        receivedVal[producerIndex][consumerIndex] = newValue
                                        receivedVal[consumerIndex][producerIndex] = newValue
                                        barrier.countDown()
                                    }
                                }
@@ -102,7 +102,7 @@ interface ChainedStateCollectBenchmark : BenchmarkWithStateProvider {
    }

    @Test
    fun benchmark() {
    fun benchmark_stateObservers_chained() {
        class Benchmark<M : R, R>(stateBuilder: StateBuilder<M, R, Int>) :
            StateBenchmarkTask<M, R, Int>(stateBuilder) {
            var receivedVal = 0
+166 −0
Original line number Diff line number Diff line
/*
 * Copyright (C) 2025 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.android.app.concurrent.benchmark.builder

import androidx.compose.runtime.MutableState
import androidx.compose.runtime.State
import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.snapshots.Snapshot
import androidx.compose.runtime.snapshots.SnapshotStateObserver
import java.io.Closeable
import java.util.concurrent.Executor
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.launch

// NOTE: The Snapshot APIs used here would typically not be called directly. This benchmark is for
// stress testing snapshot updates and observations. It's not meant to portray a realistic scenario.

class SnapshotStateExecutorBuilder<T>(val executor: Executor) : SnapshotStateBuilder<T>() {
    override fun startObservation(block: () -> Unit): Closeable {
        return SnapshotStateExecutorObserver(executor, block).start()
    }
}

class SnapshotStateCoroutineBuilder<T>(val scope: CoroutineScope) : SnapshotStateBuilder<T>() {
    override fun startObservation(block: () -> Unit): Closeable {
        return SnapshotStateCoroutineObserver(scope) { with(stateReader) { block() } }.start()
    }
}

abstract class SnapshotStateBuilder<T>() : StateBuilder<MutableState<T>, State<T>, T> {

    val openResources = mutableListOf<Closeable>()

    private val stateWriter =
        object : StateBuilder.WriteContext<MutableState<T>, T> {
            override fun MutableState<T>.update(newValue: T) {
                value = newValue
            }
        }

    override fun createMutableState(initialValue: T): MutableState<T> {
        return mutableStateOf(initialValue)
    }

    override fun writeScope(block: StateBuilder.WriteContext<MutableState<T>, T>.() -> Unit) {
        Snapshot.withMutableSnapshot { with(stateWriter) { block() } }
    }

    val stateReader =
        object : StateBuilder.ReadContext<State<T>, T> {
            override fun State<T>.observe(callback: (T) -> Unit) {
                callback(value)
            }
        }

    abstract fun startObservation(block: () -> Unit): Closeable

    @OptIn(ExperimentalStdlibApi::class)
    override fun readScope(block: StateBuilder.ReadContext<State<T>, T>.() -> Unit) {
        synchronized(openResources) {
            openResources += startObservation { with(stateReader) { block() } }
        }
    }

    override fun State<T>.mapState(transform: (T) -> T): State<T> {
        TODO("Not yet implemented")
    }

    override fun combineState(a: State<T>, b: State<T>, transform: (T, T) -> T): State<T> {
        TODO("Not yet implemented")
    }

    override fun combineState(
        a: State<T>,
        b: State<T>,
        c: State<T>,
        transform: (T, T, T) -> T,
    ): State<T> {
        TODO("Not yet implemented")
    }

    override fun dispose() {
        synchronized(openResources) {
            openResources.forEach { it.close() }
            openResources.clear()
        }
    }
}

private class SnapshotStateExecutorObserver(val executor: Executor, private val block: () -> Unit) {
    private val observer =
        SnapshotStateObserver(onChangedExecutor = { callback -> executor.execute(callback) })

    private val onValueChanged = { _: Unit -> observeBlock() }

    private fun observeBlock() {
        observer.observeReads(
            // Scope would only need to be used if we wanted to pass different data to
            // onValueChangedInBlock
            scope = Unit,
            onValueChangedForScope = onValueChanged,
            block = block,
        )
    }

    fun start(): Closeable {
        executor.execute {
            observer.start()
            observeBlock()
        }
        return Closeable { observer.stop() }
    }
}

private class SnapshotStateCoroutineObserver(
    val scope: CoroutineScope,
    private val block: () -> Unit,
) {
    private val changeCallbacks = Channel<() -> Unit>(Channel.UNLIMITED)

    private val observer = SnapshotStateObserver { callback -> changeCallbacks.trySend(callback) }

    private val onValueChanged = { _: Unit -> observeBlock() }

    private fun observeBlock() {
        observer.observeReads(
            // Scope would only need to be used if we wanted to pass different data to
            // onValueChangedInBlock
            scope = Unit,
            onValueChangedForScope = onValueChanged,
            block = block,
        )
    }

    fun start(): Closeable {
        val job =
            scope.launch {
                observer.start()
                try {
                    observeBlock()

                    // Process changes until cancelled:
                    for (callback in changeCallbacks) {
                        callback()
                    }
                } finally {
                    observer.stop()
                }
            }
        return Closeable { job.cancel() }
    }
}
+7 −0
Original line number Diff line number Diff line
@@ -37,16 +37,23 @@ interface StateBuilder<M : R, R, T> {
    fun combineState(a: R, b: R, transform: (T, T) -> T): R

    fun combineState(a: R, b: R, c: R, transform: (T, T, T) -> T): R

    fun dispose() {}
}

abstract class StateBenchmarkTask<M : R, R, T>(val stateBuilder: StateBuilder<M, R, T>) {
    abstract fun ConcurrentBenchmarkBuilder.build()

    fun dispose() {
        stateBuilder.dispose()
    }
}

fun <M : R, R, T> ConcurrentBenchmarkRule.runBenchmark(benchmark: StateBenchmarkTask<M, R, T>) {
    with(ConcurrentBenchmarkBuilder()) {
        with(benchmark) { build() }
        measure()
        benchmark.dispose()
    }
}