Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit ab431b16 authored by Nicolò Mazzucato's avatar Nicolò Mazzucato Committed by Android (Google) Code Review
Browse files

Merge "tracinglib: trace sections for coroutine scopes" into main

parents 246c179e f1369524
Loading
Loading
Loading
Loading
+8 −8
Original line number Diff line number Diff line
@@ -23,8 +23,8 @@ import android.platform.test.flag.junit.SetFlagsRule
import android.platform.test.rule.EnsureDeviceSettingsRule
import androidx.test.ext.junit.runners.AndroidJUnit4
import androidx.test.filters.SmallTest
import com.android.app.tracing.coroutines.TraceContextElement
import com.android.app.tracing.coroutines.launch
import com.android.app.tracing.coroutines.createCoroutineTracingContext
import com.android.app.tracing.coroutines.nameCoroutine
import com.android.app.tracing.coroutines.traceCoroutine
import com.android.systemui.Flags
import kotlinx.coroutines.delay
@@ -74,7 +74,7 @@ class TraceContextMicroBenchmark {
    @Test
    fun testSingleTraceSection() {
        val state = perfStatusReporter.benchmarkState
        runBlocking(TraceContextElement()) {
        runBlocking(createCoroutineTracingContext("root")) {
            while (state.keepRunning()) {
                traceCoroutine("hello-world") { ensureSuspend(state) }
            }
@@ -86,8 +86,8 @@ class TraceContextMicroBenchmark {
    fun testNestedContext() {
        val state = perfStatusReporter.benchmarkState

        val context1 = TraceContextElement()
        val context2 = TraceContextElement()
        val context1 = createCoroutineTracingContext("scope1")
        val context2 = nameCoroutine("scope2")
        runBlocking {
            while (state.keepRunning()) {
                withContext(context1) {
@@ -113,9 +113,9 @@ class TraceContextMicroBenchmark {
    fun testInterleavedLaunch() {
        val state = perfStatusReporter.benchmarkState

        runBlocking(TraceContextElement()) {
        runBlocking(createCoroutineTracingContext("root")) {
            val job1 =
                launch(TraceContextElement()) {
                launch(nameCoroutine("scope1")) {
                    while (true) {
                        traceCoroutine("hello") {
                            traceCoroutine("world") { yield() }
@@ -124,7 +124,7 @@ class TraceContextMicroBenchmark {
                    }
                }
            val job2 =
                launch(TraceContextElement()) {
                launch(nameCoroutine("scope2")) {
                    while (true) {
                        traceCoroutine("hallo") {
                            traceCoroutine("welt") { yield() }
+0 −1
Original line number Diff line number Diff line
@@ -21,7 +21,6 @@ java_library {
    static_libs: [
        "kotlinx_coroutines_android",
        "com_android_systemui_flags_lib",
        "//frameworks/libs/systemui:compilelib",
    ],
    libs: [
        "androidx.annotation_annotation",
+31 −45
Original line number Diff line number Diff line
@@ -16,15 +16,14 @@

package com.android.app.tracing.coroutines

import android.os.Trace
import com.android.systemui.util.Compile
import java.util.concurrent.ThreadLocalRandom
import com.android.systemui.Flags
import kotlin.contracts.ExperimentalContracts
import kotlin.contracts.InvocationKind
import kotlin.contracts.contract
import kotlin.coroutines.CoroutineContext
import kotlin.coroutines.EmptyCoroutineContext
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.CoroutineStart
import kotlinx.coroutines.Deferred
import kotlinx.coroutines.Job
import kotlinx.coroutines.async
@@ -33,8 +32,6 @@ import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.withContext

const val TAG = "CoroutineTracing"

const val DEFAULT_TRACK_NAME = "Coroutines"

@OptIn(ExperimentalContracts::class)
@@ -58,21 +55,21 @@ suspend inline fun <R> coroutineScope(
inline fun CoroutineScope.launch(
    crossinline spanName: () -> String,
    context: CoroutineContext = EmptyCoroutineContext,
    // TODO(b/306457056): DO NOT pass CoroutineStart; doing so will regress .odex size
    crossinline block: suspend CoroutineScope.() -> Unit,
): Job = launch(context) { traceCoroutine(spanName) { block() } }
    start: CoroutineStart = CoroutineStart.DEFAULT,
    noinline block: suspend CoroutineScope.() -> Unit,
): Job = launch(nameCoroutine(spanName) + context, start, block)

/**
 * Convenience function for calling [CoroutineScope.launch] with [traceCoroutine] to enable tracing.
 *
 * @see traceCoroutine
 */
inline fun CoroutineScope.launch(
fun CoroutineScope.launch(
    spanName: String,
    context: CoroutineContext = EmptyCoroutineContext,
    // TODO(b/306457056): DO NOT pass CoroutineStart; doing so will regress .odex size
    crossinline block: suspend CoroutineScope.() -> Unit,
): Job = launch(context) { traceCoroutine(spanName) { block() } }
    start: CoroutineStart = CoroutineStart.DEFAULT,
    block: suspend CoroutineScope.() -> Unit,
): Job = launch(nameCoroutine(spanName) + context, start, block)

/**
 * Convenience function for calling [CoroutineScope.async] with [traceCoroutine] enable tracing
@@ -80,23 +77,23 @@ inline fun CoroutineScope.launch(
 * @see traceCoroutine
 */
inline fun <T> CoroutineScope.async(
    crossinline spanName: () -> String,
    spanName: () -> String,
    context: CoroutineContext = EmptyCoroutineContext,
    // TODO(b/306457056): DO NOT pass CoroutineStart; doing so will regress .odex size
    crossinline block: suspend CoroutineScope.() -> T,
): Deferred<T> = async(context) { traceCoroutine(spanName) { block() } }
    start: CoroutineStart = CoroutineStart.DEFAULT,
    noinline block: suspend CoroutineScope.() -> T,
): Deferred<T> = async(nameCoroutine(spanName) + context, start, block)

/**
 * Convenience function for calling [CoroutineScope.async] with [traceCoroutine] enable tracing.
 *
 * @see traceCoroutine
 */
inline fun <T> CoroutineScope.async(
fun <T> CoroutineScope.async(
    spanName: String,
    context: CoroutineContext = EmptyCoroutineContext,
    // TODO(b/306457056): DO NOT pass CoroutineStart; doing so will regress .odex size
    crossinline block: suspend CoroutineScope.() -> T,
): Deferred<T> = async(context) { traceCoroutine(spanName) { block() } }
    start: CoroutineStart = CoroutineStart.DEFAULT,
    block: suspend CoroutineScope.() -> T,
): Deferred<T> = async(nameCoroutine(spanName) + context, start, block)

/**
 * Convenience function for calling [runBlocking] with [traceCoroutine] to enable tracing.
@@ -104,32 +101,32 @@ inline fun <T> CoroutineScope.async(
 * @see traceCoroutine
 */
inline fun <T> runBlocking(
    crossinline spanName: () -> String,
    spanName: () -> String,
    context: CoroutineContext,
    crossinline block: suspend () -> T,
): T = runBlocking(context) { traceCoroutine(spanName) { block() } }
    noinline block: suspend CoroutineScope.() -> T,
): T = runBlocking(nameCoroutine(spanName) + context, block)

/**
 * Convenience function for calling [runBlocking] with [traceCoroutine] to enable tracing.
 *
 * @see traceCoroutine
 */
inline fun <T> runBlocking(
fun <T> runBlocking(
    spanName: String,
    context: CoroutineContext,
    crossinline block: suspend CoroutineScope.() -> T,
): T = runBlocking(context) { traceCoroutine(spanName) { block() } }
    block: suspend CoroutineScope.() -> T,
): T = runBlocking(nameCoroutine(spanName) + context, block)

/**
 * Convenience function for calling [withContext] with [traceCoroutine] to enable tracing.
 *
 * @see traceCoroutine
 */
suspend inline fun <T> withContext(
suspend fun <T> withContext(
    spanName: String,
    context: CoroutineContext,
    crossinline block: suspend CoroutineScope.() -> T,
): T = withContext(context) { traceCoroutine(spanName) { block() } }
    block: suspend CoroutineScope.() -> T,
): T = withContext(nameCoroutine(spanName) + context, block)

/**
 * Convenience function for calling [withContext] with [traceCoroutine] to enable tracing.
@@ -137,10 +134,10 @@ suspend inline fun <T> withContext(
 * @see traceCoroutine
 */
suspend inline fun <T> withContext(
    crossinline spanName: () -> String,
    spanName: () -> String,
    context: CoroutineContext,
    crossinline block: suspend CoroutineScope.() -> T,
): T = withContext(context) { traceCoroutine(spanName) { block() } }
    noinline block: suspend CoroutineScope.() -> T,
): T = withContext(nameCoroutine(spanName) + context, block)

/**
 * Traces a section of work of a `suspend` [block]. The trace sections will appear on the thread
@@ -189,22 +186,11 @@ inline fun <T> traceCoroutine(spanName: () -> String, block: () -> T): T {
    // For coroutine tracing to work, trace spans must be added and removed even when
    // tracing is not active (i.e. when TRACE_TAG_APP is disabled). Otherwise, when the
    // coroutine resumes when tracing is active, we won't know its name.
    val traceData = if (Compile.IS_DEBUG) traceThreadLocal.get() else null
    val asyncTracingEnabled = Trace.isEnabled()
    val spanString = if (traceData != null || asyncTracingEnabled) spanName() else "<none>"
    traceData?.beginSpan(spanString)

    // Also trace to the "Coroutines" async track. This makes it easy to see the duration of
    // coroutine spans. When the coroutine_tracing flag is enabled, those same names will
    // appear in small slices on each thread as the coroutines are suspended and resumed.
    val cookie = if (asyncTracingEnabled) ThreadLocalRandom.current().nextInt() else 0
    if (asyncTracingEnabled)
        Trace.asyncTraceForTrackBegin(Trace.TRACE_TAG_APP, DEFAULT_TRACK_NAME, spanString, cookie)
    val traceData = if (Flags.coroutineTracing()) traceThreadLocal.get() else null
    traceData?.beginSpan(spanName())
    try {
        return block()
    } finally {
        if (asyncTracingEnabled)
            Trace.asyncTraceForTrackEnd(Trace.TRACE_TAG_APP, DEFAULT_TRACK_NAME, cookie)
        traceData?.endSpan()
    }
}
+121 −40
Original line number Diff line number Diff line
@@ -16,22 +16,21 @@

package com.android.app.tracing.coroutines

import android.annotation.SuppressLint
import android.os.Trace
import android.util.Log
import androidx.annotation.VisibleForTesting
import com.android.systemui.Flags
import com.android.systemui.util.Compile
import java.util.concurrent.atomic.AtomicInteger
import kotlin.coroutines.AbstractCoroutineContextKey
import kotlin.coroutines.CoroutineContext
import kotlin.coroutines.EmptyCoroutineContext
import kotlin.coroutines.getPolymorphicElement
import kotlin.coroutines.minusPolymorphicKey
import kotlinx.coroutines.CopyableThreadContextElement
import kotlinx.coroutines.DelicateCoroutinesApi
import kotlinx.coroutines.ExperimentalCoroutinesApi

private const val DEBUG = false

/** Log a message with a tag indicating the current thread ID */
private inline fun debug(message: () -> String) {
    if (DEBUG) println("Thread #${Thread.currentThread().id}: ${message()}")
}

/** Use a final subclass to avoid virtual calls (b/316642146). */
class TraceDataThreadLocal : ThreadLocal<TraceData?>()

@@ -51,13 +50,54 @@ val traceThreadLocal = TraceDataThreadLocal()
/**
 * Returns a new [CoroutineContext] used for tracing. Used to hide internal implementation details.
 */
fun createCoroutineTracingContext(): CoroutineContext =
    if (Compile.IS_DEBUG && Flags.coroutineTracing()) {
        TraceContextElement()
    } else {
        EmptyCoroutineContext
fun createCoroutineTracingContext(name: String = "UnnamedScope"): CoroutineContext =
    if (Flags.coroutineTracing()) TraceContextElement(name) else EmptyCoroutineContext

fun nameCoroutine(name: String): CoroutineContext =
    if (Flags.coroutineTracing()) CoroutineTraceName(name) else EmptyCoroutineContext

inline fun nameCoroutine(name: () -> String): CoroutineContext =
    if (Flags.coroutineTracing()) CoroutineTraceName(name()) else EmptyCoroutineContext

open class BaseTraceElement : CoroutineContext.Element {
    companion object Key : CoroutineContext.Key<BaseTraceElement>

    override val key: CoroutineContext.Key<*>
        get() = Key

    // It is important to use getPolymorphicKey and minusPolymorphicKey
    @OptIn(ExperimentalStdlibApi::class)
    override fun <E : CoroutineContext.Element> get(key: CoroutineContext.Key<E>): E? =
        getPolymorphicElement(key)

    @OptIn(ExperimentalStdlibApi::class)
    override fun minusKey(key: CoroutineContext.Key<*>): CoroutineContext = minusPolymorphicKey(key)

    @Suppress("DeprecatedCallableAddReplaceWith")
    @Deprecated(
        message =
            "Operator `+` on two BaseTraceElement objects is meaningless. " +
                "If used, the context element to the right of `+` would simply replace the " +
                "element to the left. To properly use `BaseTraceElement`, `CoroutineTraceName` " +
                "should be used when creating a top-level `CoroutineScope`, " +
                "and `TraceContextElement` should be passed to the child context " +
                "that is under construction.",
        level = DeprecationLevel.ERROR,
    )
    operator fun plus(other: BaseTraceElement): BaseTraceElement = other
}

class CoroutineTraceName(val name: String) : BaseTraceElement() {
    @OptIn(ExperimentalStdlibApi::class)
    companion object Key :
        AbstractCoroutineContextKey<BaseTraceElement, CoroutineTraceName>(
            BaseTraceElement,
            { it as? CoroutineTraceName },
        )
}

const val ROOT_SCOPE = 0

/**
 * Used for safely persisting [TraceData] state when coroutines are suspended and resumed.
 *
@@ -68,18 +108,38 @@ fun createCoroutineTracingContext(): CoroutineContext =
 */
@OptIn(DelicateCoroutinesApi::class, ExperimentalCoroutinesApi::class)
@VisibleForTesting(otherwise = VisibleForTesting.PACKAGE_PRIVATE)
class TraceContextElement private constructor(val traceData: TraceData?) :
    CopyableThreadContextElement<TraceData?> {
class TraceContextElement
private constructor(
    coroutineTraceName: String,
    inheritedTracePrefix: String,
    @get:VisibleForTesting(otherwise = VisibleForTesting.PACKAGE_PRIVATE)
    val contextTraceData: TraceData?,
    private val coroutineDepth: Int, // depth relative to first TraceContextElement
    parentId: Int,
) : CopyableThreadContextElement<TraceData?>, BaseTraceElement() {

    companion object Key : CoroutineContext.Key<TraceContextElement>
    @OptIn(ExperimentalStdlibApi::class)
    companion object Key :
        AbstractCoroutineContextKey<BaseTraceElement, TraceContextElement>(
            BaseTraceElement,
            { it as? TraceContextElement },
        )

    constructor() : this(if (Compile.IS_DEBUG) TraceData() else null)
    /**
     * Minor perf optimization: no need to create TraceData() for root scopes since all launches
     * require creation of child via [copyForChild] or [mergeForChild].
     */
    constructor(scopeName: String) : this(scopeName, "", null, 0, ROOT_SCOPE)

    override val key: CoroutineContext.Key<*>
        get() = Key
    private var childCoroutineCount = AtomicInteger(0)
    private val currentId = hashCode()

    private val fullCoroutineTraceName = "$inheritedTracePrefix$coroutineTraceName"
    private val continuationTraceMessage =
        "$fullCoroutineTraceName;$coroutineTraceName;d=$coroutineDepth;c=$currentId;p=$parentId"

    init {
        debug { "$this #init" }
        debug { "#init" }
    }

    /**
@@ -97,17 +157,18 @@ class TraceContextElement private constructor(val traceData: TraceData?) :
     * (`...` indicate coroutine body is running; whitespace indicates the thread is not scheduled;
     * `^` is a suspension point)
     */
    @SuppressLint("UnclosedTrace")
    override fun updateThreadContext(context: CoroutineContext): TraceData? {
        if (!Compile.IS_DEBUG) return null
        val oldState = traceThreadLocal.get()
        debug { "$this #updateThreadContext oldState=$oldState" }
        if (oldState !== traceData) {
            traceThreadLocal.set(traceData)
        debug { "#updateThreadContext oldState=$oldState" }
        if (oldState !== contextTraceData) {
            Trace.traceBegin(Trace.TRACE_TAG_APP, continuationTraceMessage)
            traceThreadLocal.set(contextTraceData)
            // Calls to `updateThreadContext` will not happen in parallel on the same context, and
            // they cannot happen before the prior suspension point. Additionally,
            // `restoreThreadContext` does not modify `traceData`, so it is safe to iterate over the
            // collection here:
            traceData?.beginAllOnThread()
            contextTraceData?.beginAllOnThread()
        }
        return oldState
    }
@@ -140,34 +201,54 @@ class TraceContextElement private constructor(val traceData: TraceData?) :
     * ```
     */
    override fun restoreThreadContext(context: CoroutineContext, oldState: TraceData?) {
        if (!Compile.IS_DEBUG) return
        debug { "$this#restoreThreadContext restoring=$oldState" }
        debug { "#restoreThreadContext restoring=$oldState" }
        // We not use the `TraceData` object here because it may have been modified on another
        // thread after the last suspension point. This is why we use a [TraceStateHolder]:
        // so we can end the correct number of trace sections, restoring the thread to its state
        // prior to the last call to [updateThreadContext].
        if (oldState !== traceThreadLocal.get()) {
            traceData?.endAllOnThread()
            contextTraceData?.endAllOnThread()
            traceThreadLocal.set(oldState)
            Trace.traceEnd(Trace.TRACE_TAG_APP) // end: currentScopeTraceMessage
        }
    }

    override fun copyForChild(): CopyableThreadContextElement<TraceData?> {
        if (!Compile.IS_DEBUG) return TraceContextElement(null)
        debug { "$this #copyForChild" }
        return TraceContextElement(traceData?.clone())
        debug { "#copyForChild" }
        return createChildContext()
    }

    override fun mergeForChild(overwritingElement: CoroutineContext.Element): CoroutineContext {
        if (!Compile.IS_DEBUG) return EmptyCoroutineContext
        debug { "$this #mergeForChild" }
        // For our use-case, we always give precedence to the parent trace context, and the
        // child context (overwritingElement) is ignored
        return TraceContextElement(traceData?.clone())
        debug { "#mergeForChild" }
        val otherTraceContext = overwritingElement[TraceContextElement]
        if (DEBUG && otherTraceContext != null) {
            Log.e(
                TAG,
                UNEXPECTED_TRACE_DATA_ERROR_MESSAGE +
                    "Current CoroutineContext.Element=$fullCoroutineTraceName, other CoroutineContext.Element=${otherTraceContext.fullCoroutineTraceName}",
            )
        }
        return createChildContext(overwritingElement[CoroutineTraceName]?.name ?: "")
    }

    @OptIn(ExperimentalStdlibApi::class)
    override fun toString(): String {
        return "TraceContextElement@${hashCode().toHexString()}[$traceData]"
    private fun createChildContext(coroutineTraceName: String = ""): TraceContextElement {
        val childCount = childCoroutineCount.incrementAndGet()
        return TraceContextElement(
            coroutineTraceName,
            "$fullCoroutineTraceName:$childCount^",
            TraceData(),
            coroutineDepth + 1,
            currentId,
        )
    }

    private inline fun debug(message: () -> String) {
        if (DEBUG) Log.d(TAG, "@$currentId ${message()} $contextTraceData")
    }
}

private const val UNEXPECTED_TRACE_DATA_ERROR_MESSAGE =
    "Overwriting context element with non-empty trace data. There should only be one " +
        "TraceContextElement per coroutine, and it should be installed in the root scope. "
private const val TAG = "TraceContextElement"
internal const val DEBUG = false
+28 −32
Original line number Diff line number Diff line
@@ -42,7 +42,9 @@ class TraceCountThreadLocal : ThreadLocal<Int>() {
 * @see traceCoroutine
 */
@VisibleForTesting(otherwise = VisibleForTesting.PACKAGE_PRIVATE)
class TraceData(val slices: ArrayDeque<TraceSection> = ArrayDeque()) : Cloneable {
class TraceData {

    var slices: ArrayDeque<TraceSection>? = null

    /**
     * ThreadLocal counter for how many open trace sections there are. This is needed because it is
@@ -57,8 +59,8 @@ class TraceData(val slices: ArrayDeque<TraceSection> = ArrayDeque()) : Cloneable
    /** Adds current trace slices back to the current thread. Called when coroutine is resumed. */
    fun beginAllOnThread() {
        strictModeCheck()
        slices.descendingIterator().forEach { beginSlice(it) }
        openSliceCount.set(slices.size)
        slices?.descendingIterator()?.forEach { beginSlice(it) }
        openSliceCount.set(slices?.size ?: 0)
    }

    /**
@@ -78,8 +80,11 @@ class TraceData(val slices: ArrayDeque<TraceSection> = ArrayDeque()) : Cloneable
     */
    fun beginSpan(name: String) {
        strictModeCheck()
        slices.push(name)
        openSliceCount.set(slices.size)
        if (slices == null) {
            slices = ArrayDeque()
        }
        slices!!.push(name)
        openSliceCount.set(slices!!.size)
        beginSlice(name)
    }

@@ -91,48 +96,39 @@ class TraceData(val slices: ArrayDeque<TraceSection> = ArrayDeque()) : Cloneable
    fun endSpan() {
        strictModeCheck()
        // Should never happen, but we should be defensive rather than crash the whole application
        if (slices.size > 0) {
            slices.pop()
            openSliceCount.set(slices.size)
        if (slices != null && slices!!.size > 0) {
            slices!!.pop()
            openSliceCount.set(slices!!.size)
            endSlice()
        } else if (strictModeForTesting) {
        } else if (STRICT_MODE_FOR_TESTING) {
            throw IllegalStateException(INVALID_SPAN_END_CALL_ERROR_MESSAGE)
        }
    }

    /**
     * Used by [TraceContextElement] when launching a child coroutine so that the child coroutine's
     * state is isolated from the parent.
     */
    public override fun clone(): TraceData {
        return TraceData(slices.clone())
    }

    @OptIn(ExperimentalStdlibApi::class)
    override fun toString(): String {
        return "TraceData@${hashCode().toHexString()}-size=${slices.size}"
    }
    override fun toString(): String =
        if (DEBUG) "{${slices?.joinToString(separator = "\", \"", prefix = "\"", postfix = "\"")}}"
        else super.toString()

    private fun strictModeCheck() {
        if (strictModeForTesting && traceThreadLocal.get() !== this) {
        if (STRICT_MODE_FOR_TESTING && traceThreadLocal.get() !== this) {
            throw ConcurrentModificationException(STRICT_MODE_ERROR_MESSAGE)
        }
    }
}

    companion object {
/**
 * Whether to add additional checks to the coroutine machinery, throwing a
         * `ConcurrentModificationException` if TraceData is modified from the wrong thread. This
         * should only be set for testing.
 * `ConcurrentModificationException` if TraceData is modified from the wrong thread. This should
 * only be set for testing.
 */
        var strictModeForTesting: Boolean = false
    }
}
var STRICT_MODE_FOR_TESTING: Boolean = false

private const val INVALID_SPAN_END_CALL_ERROR_MESSAGE =
    "TraceData#endSpan called when there were no active trace sections."
    "TraceData#endSpan called when there were no active trace sections in its scope."

private const val STRICT_MODE_ERROR_MESSAGE =
    "TraceData should only be accessed using " +
        "the ThreadLocal: CURRENT_TRACE.get(). Accessing TraceData by other means, such as " +
        "through the TraceContextElement's property may lead to concurrent modification."

@OptIn(ExperimentalStdlibApi::class) val hexFormatForId = HexFormat { number.prefix = "0x" }
Loading