Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 81c78d6f authored by Peter Kalauskas's avatar Peter Kalauskas
Browse files

tracinglib: fix thread-local slice counter

 - Replace ThreadLocal<Int> with ThreadLocal<MutableInt> for better
   performance.

 - Make the ThreadLocal<MutableInt> a top-level val since since only one
   instance is needed per-thread (this will reduce memory overhead).

 - Fix issue where traceCoroutine() would write to the wrong
   thread-local after a suspension point.

 - Add benchmark tests for thread-local usage so we can decide on which
   style is most performant for our tracing use-case.

 - Replace nullable slices type with lateinit.

 - idle() shadow Looper in Robolectric tests so that main dispatcher
   can run.

Bug: 351054475
Test: atest CoroutineTracingPerfTests
Flag: com.android.systemui.coroutine_tracing
Change-Id: I3619fb2f75eef845374ec76b157fe3a7f8cb5e35
parent b9dc4ca3
Loading
Loading
Loading
Loading
+102 −0
Original line number Diff line number Diff line
/*
 * Copyright (C) 2025 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.android.app.tracing.benchmark

import android.os.Trace
import android.perftests.utils.PerfStatusReporter
import android.platform.test.annotations.EnableFlags
import android.platform.test.flag.junit.SetFlagsRule
import android.platform.test.rule.EnsureDeviceSettingsRule
import androidx.test.ext.junit.runners.AndroidJUnit4
import androidx.test.filters.SmallTest
import com.android.systemui.Flags
import java.util.concurrent.atomic.AtomicInteger
import org.junit.After
import org.junit.Assert
import org.junit.Before
import org.junit.ClassRule
import org.junit.Rule
import org.junit.Test
import org.junit.runner.RunWith

@RunWith(AndroidJUnit4::class)
@EnableFlags(Flags.FLAG_COROUTINE_TRACING)
class ThreadLocalMicroBenchmark {

    @get:Rule val perfStatusReporter = PerfStatusReporter()

    @get:Rule val setFlagsRule = SetFlagsRule()

    companion object {
        @JvmField @ClassRule(order = 1) var ensureDeviceSettingsRule = EnsureDeviceSettingsRule()
    }

    @Before
    fun before() {
        Assert.assertTrue(Trace.isEnabled())
    }

    @After
    fun after() {
        Assert.assertTrue(Trace.isEnabled())
    }

    @SmallTest
    @Test
    fun testIntegerIncrement() {
        val state = perfStatusReporter.benchmarkState
        val count: ThreadLocal<Int> = ThreadLocal()
        count.set(0)
        while (state.keepRunning()) {
            count.set(count.get()!! + 1)
        }
    }

    @SmallTest
    @Test
    fun testAtomicIntegerIncrement() {
        val state = perfStatusReporter.benchmarkState
        val count: ThreadLocal<AtomicInteger> = ThreadLocal()
        count.set(AtomicInteger(0))
        while (state.keepRunning()) {
            count.get()!!.getAndIncrement()
        }
    }

    @SmallTest
    @Test
    fun testIntArrayIncrement() {
        val state = perfStatusReporter.benchmarkState
        val count: ThreadLocal<Array<Int>> = ThreadLocal()
        count.set(arrayOf(0))
        while (state.keepRunning()) {
            val arr = count.get()!!
            arr[0]++
        }
    }

    @SmallTest
    @Test
    fun testMutableIntIncrement() {
        val state = perfStatusReporter.benchmarkState
        class MutableInt(var value: Int)
        val count: ThreadLocal<MutableInt> = ThreadLocal()
        count.set(MutableInt(0))
        while (state.keepRunning()) {
            count.get()!!.value++
        }
    }
}
+4 −6
Original line number Diff line number Diff line
@@ -222,12 +222,11 @@ public inline fun <T, R> R.traceCoroutine(crossinline spanName: () -> String, bl
    // For coroutine tracing to work, trace spans must be added and removed even when
    // tracing is not active (i.e. when TRACE_TAG_APP is disabled). Otherwise, when the
    // coroutine resumes when tracing is active, we won't know its name.
    val traceData = if (Flags.coroutineTracing()) traceThreadLocal.get() else null
    traceData?.beginSpan(spanName())
    try {
        if (Flags.coroutineTracing()) traceThreadLocal.get()?.beginSpan(spanName())
        return block()
    } finally {
        traceData?.endSpan()
        if (Flags.coroutineTracing()) traceThreadLocal.get()?.endSpan()
    }
}

@@ -239,12 +238,11 @@ public inline fun <T> traceCoroutine(crossinline spanName: () -> String, block:
    // For coroutine tracing to work, trace spans must be added and removed even when
    // tracing is not active (i.e. when TRACE_TAG_APP is disabled). Otherwise, when the
    // coroutine resumes when tracing is active, we won't know its name.
    val traceData = if (Flags.coroutineTracing()) traceThreadLocal.get() else null
    traceData?.beginSpan(spanName())
    try {
        if (Flags.coroutineTracing()) traceThreadLocal.get()?.beginSpan(spanName())
        return block()
    } finally {
        traceData?.endSpan()
        if (Flags.coroutineTracing()) traceThreadLocal.get()?.endSpan()
    }
}

+8 −4
Original line number Diff line number Diff line
@@ -338,9 +338,11 @@ internal class TraceContextElement(
     * OR
     *
     * ```
     * Thread #1 |                                 [restoreThreadContext]
     * Thread #1 |  [update].x..^  [   ...    restore    ...   ]               [update].x..^[restore]
     * --------------------------------------------------------------------------------------------
     * Thread #2 |     [updateThreadContext]...x....x..^[restoreThreadContext]
     * Thread #2 |                 [update]...x....x..^[restore]
     * --------------------------------------------------------------------------------------------
     * Thread #3 |                                     [ ... update ... ] ....^  [restore]
     * ```
     *
     * (`...` indicate coroutine body is running; whitespace indicates the thread is not scheduled;
@@ -355,10 +357,12 @@ internal class TraceContextElement(
        // so we can end the correct number of trace sections, restoring the thread to its state
        // prior to the last call to [updateThreadContext].
        if (oldState !== traceThreadLocal.get()) {
            if (Trace.isTagEnabled(Trace.TRACE_TAG_APP)) {
                contextTraceData?.endAllOnThread()
            traceThreadLocal.set(oldState)
                Trace.traceEnd(Trace.TRACE_TAG_APP) // end: coroutineTraceName
            }
            traceThreadLocal.set(oldState)
        }
    }

    override fun copyForChild(): CopyableThreadContextElement<TraceData?> {
+37 −27
Original line number Diff line number Diff line
@@ -29,12 +29,24 @@ import java.util.ArrayDeque
 */
private typealias TraceSection = String

private class TraceCountThreadLocal : ThreadLocal<Int>() {
    override fun initialValue(): Int {
        return 0
private class MutableInt(var value: Int)

private class ThreadLocalInt : ThreadLocal<MutableInt>() {
    override fun initialValue(): MutableInt {
        return MutableInt(0)
    }
}

/**
 * ThreadLocal counter for how many open trace sections there are. This is needed because it is
 * possible that on a multi-threaded dispatcher, one of the threads could be slow, and
 * [TraceContextElement.restoreThreadContext] might be invoked _after_ the coroutine has already
 * resumed and modified [TraceData] - either adding or removing trace sections and changing the
 * count. If we did not store this thread-locally, then we would incorrectly end too many or too few
 * trace sections.
 */
private val openSliceCount = ThreadLocalInt()

/**
 * Used for storing trace sections so that they can be added and removed from the currently running
 * thread when the coroutine is suspended and resumed.
@@ -48,24 +60,16 @@ private class TraceCountThreadLocal : ThreadLocal<Int>() {
@PublishedApi
internal class TraceData(internal val currentId: Int, private val strictMode: Boolean) {

    internal var slices: ArrayDeque<TraceSection>? = null

    /**
     * ThreadLocal counter for how many open trace sections there are. This is needed because it is
     * possible that on a multi-threaded dispatcher, one of the threads could be slow, and
     * `restoreThreadContext` might be invoked _after_ the coroutine has already resumed and
     * modified TraceData - either adding or removing trace sections and changing the count. If we
     * did not store this thread-locally, then we would incorrectly end too many or too few trace
     * sections.
     */
    private val openSliceCount = TraceCountThreadLocal()
    internal lateinit var slices: ArrayDeque<TraceSection>

    /** Adds current trace slices back to the current thread. Called when coroutine is resumed. */
    internal fun beginAllOnThread() {
        if (Trace.isTagEnabled(Trace.TRACE_TAG_APP)) {
            strictModeCheck()
            slices?.descendingIterator()?.forEach { beginSlice(it) }
            openSliceCount.set(slices?.size ?: 0)
            if (::slices.isInitialized) {
                slices.descendingIterator().forEach { sectionName -> beginSlice(sectionName) }
                openSliceCount.get()!!.value = slices.size
            }
        }
    }

@@ -75,8 +79,9 @@ internal class TraceData(internal val currentId: Int, private val strictMode: Bo
    internal fun endAllOnThread() {
        if (Trace.isTagEnabled(Trace.TRACE_TAG_APP)) {
            strictModeCheck()
            repeat(openSliceCount.get() ?: 0) { endSlice() }
            openSliceCount.set(0)
            val sliceCount = openSliceCount.get()!!
            repeat(sliceCount.value) { endSlice() }
            sliceCount.value = 0
        }
    }

@@ -89,11 +94,11 @@ internal class TraceData(internal val currentId: Int, private val strictMode: Bo
    @PublishedApi
    internal fun beginSpan(name: String) {
        strictModeCheck()
        if (slices == null) {
            slices = ArrayDeque()
        if (!::slices.isInitialized) {
            slices = ArrayDeque<TraceSection>(4)
        }
        slices!!.push(name)
        openSliceCount.set(slices!!.size)
        slices.push(name)
        openSliceCount.get()!!.value = slices.size
        beginSlice(name)
    }

@@ -106,9 +111,9 @@ internal class TraceData(internal val currentId: Int, private val strictMode: Bo
    internal fun endSpan() {
        strictModeCheck()
        // Should never happen, but we should be defensive rather than crash the whole application
        if (slices != null && slices!!.size > 0) {
            slices!!.pop()
            openSliceCount.set(slices!!.size)
        if (::slices.isInitialized && slices.size > 0) {
            slices.pop()
            openSliceCount.get()!!.value = slices.size
            endSlice()
        } else if (strictMode) {
            throw IllegalStateException(INVALID_SPAN_END_CALL_ERROR_MESSAGE)
@@ -116,8 +121,13 @@ internal class TraceData(internal val currentId: Int, private val strictMode: Bo
    }

    public override fun toString(): String =
        if (DEBUG) "{${slices?.joinToString(separator = "\", \"", prefix = "\"", postfix = "\"")}}"
        else super.toString()
        if (DEBUG) {
            if (::slices.isInitialized) {
                "{${slices.joinToString(separator = "\", \"", prefix = "\"", postfix = "\"")}}"
            } else {
                "{<uninitialized>}"
            }
        } else super.toString()

    private fun strictModeCheck() {
        if (strictMode && traceThreadLocal.get() !== this) {
+72 −0
Original line number Diff line number Diff line
/*
 * Copyright (C) 2025 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

@file:OptIn(DelicateCoroutinesApi::class, ExperimentalCoroutinesApi::class)

package com.android.test.tracing.coroutines

import android.platform.test.annotations.EnableFlags
import com.android.app.tracing.coroutines.createCoroutineTracingContext
import com.android.app.tracing.coroutines.launchTraced
import com.android.app.tracing.coroutines.withContextTraced
import com.android.systemui.Flags.FLAG_COROUTINE_TRACING
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.DelicateCoroutinesApi
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.asCoroutineDispatcher
import kotlinx.coroutines.asExecutor
import kotlinx.coroutines.delay
import kotlinx.coroutines.newSingleThreadContext
import kotlinx.coroutines.plus
import org.junit.Test

@EnableFlags(FLAG_COROUTINE_TRACING)
class BackgroundThreadTracingTest : TestBase() {

    override val scope = CoroutineScope(createCoroutineTracingContext("main", testMode = true))

    @Test
    fun withContext_reentrant() =
        runTest(totalEvents = 11) {
            expect("1^main")
            val thread1 = newSingleThreadContext("thread-#1").asExecutor().asCoroutineDispatcher()
            val bgScope = scope.plus(thread1)
            val otherJob =
                bgScope.launchTraced("AAA") {
                    expect("2^AAA")
                    delay(1)
                    expect("2^AAA")
                    withContextTraced("BBB", Dispatchers.Main.immediate) {
                        expect("2^AAA", "BBB")
                        delay(1)
                        expect("2^AAA", "BBB")
                        withContextTraced("CCC", thread1) {
                            expect("2^AAA", "BBB", "CCC")
                            delay(1)
                            expect("2^AAA", "BBB", "CCC")
                        }
                        expect("2^AAA", "BBB")
                    }
                    expect("2^AAA")
                    delay(1)
                    expect("2^AAA")
                }
            delay(195)
            otherJob.cancel()
            expect("1^main")
        }
}
Loading