Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 75b3339e authored by Peter Kalauskas's avatar Peter Kalauskas
Browse files

tracinglib: improve inline usage

 - Mark spanName() lambdas as crossinline to prevent non-local returns

 - Add contracts to tracing functions

 - Add test for traceCoroutine with receiver

 - Replace runBlockingTraced with simpler implementation that
   does not use coroutine tracing

 - Fix default naming for lambdas passed to coroutine builders

 - Add mapLatestTraced

Test: Build and install CoroutineTracingDemoApp
Flag: com.android.systemui.coroutine_tracing
Bug: 383660219
Bug: 334171711
Bug: 381583986
Bug: 383660219
Bug: 350657545
Change-Id: I2aed20d45d80e2d4ce736a8c30db5772c3af7049
parent 66c8d887
Loading
Loading
Loading
Loading
+108 −60
Original line number Diff line number Diff line
@@ -14,8 +14,11 @@
 * limitations under the License.
 */

@file:OptIn(ExperimentalContracts::class, ExperimentalContracts::class)

package com.android.app.tracing.coroutines

import com.android.app.tracing.traceSection
import com.android.systemui.Flags
import kotlin.contracts.ExperimentalContracts
import kotlin.contracts.InvocationKind
@@ -34,17 +37,25 @@ import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.withContext

@OptIn(ExperimentalContracts::class)
/** @see kotlinx.coroutines.coroutineScope */
public suspend inline fun <R> coroutineScopeTraced(
    traceName: String,
    crossinline spanName: () -> String,
    crossinline block: suspend CoroutineScope.() -> R,
): R {
    contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
    return coroutineScope {
        traceCoroutine(traceName) {
            return@coroutineScope block()
    contract {
        callsInPlace(spanName, InvocationKind.AT_MOST_ONCE)
        callsInPlace(block, InvocationKind.EXACTLY_ONCE)
    }
    return coroutineScope { traceCoroutine(spanName) { block() } }
}

/** @see kotlinx.coroutines.coroutineScope */
public suspend inline fun <R> coroutineScopeTraced(
    traceName: String,
    crossinline block: suspend CoroutineScope.() -> R,
): R {
    contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
    return coroutineScopeTraced({ traceName }, block)
}

/**
@@ -58,7 +69,8 @@ public inline fun CoroutineScope.launchTraced(
    start: CoroutineStart = CoroutineStart.DEFAULT,
    noinline block: suspend CoroutineScope.() -> Unit,
): Job {
    return launch(nameCoroutine(spanName) + context, start, block)
    contract { callsInPlace(spanName, InvocationKind.AT_MOST_ONCE) }
    return launch(addName(spanName, context), start, block)
}

/**
@@ -71,11 +83,23 @@ public fun CoroutineScope.launchTraced(
    context: CoroutineContext = EmptyCoroutineContext,
    start: CoroutineStart = CoroutineStart.DEFAULT,
    block: suspend CoroutineScope.() -> Unit,
): Job = launchTraced({ spanName ?: block::class.simpleName ?: "launch" }, context, start, block)
): Job {
    return launchTraced({ spanName ?: block.traceName }, context, start, block)
}

/** @see kotlinx.coroutines.flow.launchIn */
public fun <T> Flow<T>.launchInTraced(name: String, scope: CoroutineScope): Job =
    scope.launchTraced(name) { collect() }
public inline fun <T> Flow<T>.launchInTraced(
    crossinline spanName: () -> String,
    scope: CoroutineScope,
): Job {
    contract { callsInPlace(spanName, InvocationKind.AT_MOST_ONCE) }
    return scope.launchTraced(spanName) { collect() }
}

/** @see kotlinx.coroutines.flow.launchIn */
public fun <T> Flow<T>.launchInTraced(spanName: String, scope: CoroutineScope): Job {
    return scope.launchTraced({ spanName }) { collect() }
}

/**
 * Convenience function for calling [CoroutineScope.async] with [traceCoroutine] enable tracing
@@ -83,11 +107,14 @@ public fun <T> Flow<T>.launchInTraced(name: String, scope: CoroutineScope): Job
 * @see traceCoroutine
 */
public inline fun <T> CoroutineScope.asyncTraced(
    spanName: () -> String,
    crossinline spanName: () -> String,
    context: CoroutineContext = EmptyCoroutineContext,
    start: CoroutineStart = CoroutineStart.DEFAULT,
    noinline block: suspend CoroutineScope.() -> T,
): Deferred<T> = async(nameCoroutine(spanName) + context, start, block)
): Deferred<T> {
    contract { callsInPlace(spanName, InvocationKind.AT_MOST_ONCE) }
    return async(addName(spanName, context), start, block)
}

/**
 * Convenience function for calling [CoroutineScope.async] with [traceCoroutine] enable tracing.
@@ -99,57 +126,62 @@ public fun <T> CoroutineScope.asyncTraced(
    context: CoroutineContext = EmptyCoroutineContext,
    start: CoroutineStart = CoroutineStart.DEFAULT,
    block: suspend CoroutineScope.() -> T,
): Deferred<T> =
    asyncTraced({ spanName ?: block::class.simpleName ?: "async" }, context, start, block)
): Deferred<T> {
    return asyncTraced({ spanName ?: block.traceName }, context, start, block)
}

/**
 * Convenience function for calling [runBlocking] with [traceCoroutine] to enable tracing.
 * Convenience function for calling [withContext] with [traceCoroutine] to enable tracing.
 *
 * @see traceCoroutine
 */
public inline fun <T> runBlockingTraced(
    spanName: () -> String,
public suspend inline fun <T> withContextTraced(
    crossinline spanName: () -> String,
    context: CoroutineContext,
    noinline block: suspend CoroutineScope.() -> T,
): T = runBlocking(nameCoroutine(spanName) + context, block)

/**
 * Convenience function for calling [runBlocking] with [traceCoroutine] to enable tracing.
 *
 * @see traceCoroutine
 */
public fun <T> runBlockingTraced(
    spanName: String? = null,
    context: CoroutineContext,
    block: suspend CoroutineScope.() -> T,
): T = runBlockingTraced({ spanName ?: block::class.simpleName ?: "runBlocking" }, context, block)

/**
 * Convenience function for calling [withContext] with [traceCoroutine] to enable tracing.
 *
 * @see traceCoroutine
 */
public suspend fun <T> withContextTraced(
    spanName: String? = null,
    context: CoroutineContext = EmptyCoroutineContext,
    block: suspend CoroutineScope.() -> T,
): T = withContextTraced({ spanName ?: block::class.simpleName ?: "withContext" }, context, block)
): T {
    contract {
        callsInPlace(spanName, InvocationKind.AT_MOST_ONCE)
        callsInPlace(block, InvocationKind.EXACTLY_ONCE)
    }
    return traceCoroutine(spanName) { withContext(context, block) }
}

/**
 * Convenience function for calling [withContext] with [traceCoroutine] to enable tracing.
 *
 * @see traceCoroutine
 */
@OptIn(ExperimentalContracts::class)
public suspend inline fun <T> withContextTraced(
    spanName: () -> String,
    context: CoroutineContext = EmptyCoroutineContext,
    spanName: String? = null,
    context: CoroutineContext,
    noinline block: suspend CoroutineScope.() -> T,
): T {
    contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
    traceCoroutine(spanName) {
        return@withContextTraced withContext(context, block)
    return withContextTraced({ spanName ?: block.traceName }, context, block)
}

/** @see kotlinx.coroutines.runBlocking */
public inline fun <T> runBlockingTraced(
    crossinline spanName: () -> String,
    context: CoroutineContext,
    noinline block: suspend CoroutineScope.() -> T,
): T {
    contract {
        callsInPlace(spanName, InvocationKind.AT_MOST_ONCE)
        callsInPlace(block, InvocationKind.EXACTLY_ONCE)
    }
    return traceSection(spanName) { runBlocking(context, block) }
}

/** @see kotlinx.coroutines.runBlocking */
public fun <T> runBlockingTraced(
    spanName: String?,
    context: CoroutineContext,
    block: suspend CoroutineScope.() -> T,
): T {
    contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
    return runBlockingTraced({ spanName ?: block.traceName }, context, block)
}

/**
@@ -182,30 +214,24 @@ public suspend inline fun <T> withContextTraced(
 * @param spanName The name of the code section to appear in the trace
 * @see traceCoroutine
 */
@OptIn(ExperimentalContracts::class)
public inline fun <T> traceCoroutine(spanName: () -> String, block: () -> T): T {
public inline fun <T, R> R.traceCoroutine(crossinline spanName: () -> String, block: R.() -> T): T {
    contract {
        callsInPlace(spanName, InvocationKind.AT_MOST_ONCE)
        callsInPlace(block, InvocationKind.EXACTLY_ONCE)
    }

    // For coroutine tracing to work, trace spans must be added and removed even when
    // tracing is not active (i.e. when TRACE_TAG_APP is disabled). Otherwise, when the
    // coroutine resumes when tracing is active, we won't know its name.
    if (Flags.coroutineTracing()) {
        traceThreadLocal.get()?.beginSpan(spanName())
    }
    val traceData = if (Flags.coroutineTracing()) traceThreadLocal.get() else null
    traceData?.beginSpan(spanName())
    try {
        return block()
    } finally {
        if (Flags.coroutineTracing()) {
            traceThreadLocal.get()?.endSpan()
        }
        traceData?.endSpan()
    }
}

@OptIn(ExperimentalContracts::class)
public inline fun <T, R> R.traceCoroutine(spanName: () -> String, block: R.() -> T): T {
public inline fun <T> traceCoroutine(crossinline spanName: () -> String, block: () -> T): T {
    contract {
        callsInPlace(spanName, InvocationKind.AT_MOST_ONCE)
        callsInPlace(block, InvocationKind.EXACTLY_ONCE)
@@ -223,9 +249,31 @@ public inline fun <T, R> R.traceCoroutine(spanName: () -> String, block: R.() ->
}

/** @see traceCoroutine */
public inline fun <T, R> R.traceCoroutine(spanName: String, block: R.() -> T): T =
    traceCoroutine({ spanName }, block)
public inline fun <T, R> R.traceCoroutine(spanName: String, block: R.() -> T): T {
    contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
    return traceCoroutine({ spanName }, block)
}

/** @see traceCoroutine */
public inline fun <T> traceCoroutine(spanName: String, block: () -> T): T =
    traceCoroutine({ spanName }, block)
public inline fun <T> traceCoroutine(spanName: String, block: () -> T): T {
    contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
    return traceCoroutine({ spanName }, block)
}

/**
 * Returns the passed context if [Flags.coroutineTracing] is false. Otherwise, returns a new context
 * by adding [CoroutineTraceName] to the given context. The [CoroutineTraceName] in the passed
 * context will take precedence over the new [CoroutineTraceName].
 */
@PublishedApi
internal inline fun addName(
    crossinline spanName: () -> String,
    context: CoroutineContext,
): CoroutineContext {
    contract { callsInPlace(spanName, InvocationKind.AT_MOST_ONCE) }
    return if (Flags.coroutineTracing()) CoroutineTraceName(spanName()) + context else context
}

@PublishedApi
internal inline val <reified T : Any> T.traceName: String
    inline get() = this::class.java.name.substringAfterLast(".")
+4 −7
Original line number Diff line number Diff line
@@ -300,9 +300,8 @@ internal class TraceContextElement(
     * `^` is a suspension point)
     */
    @SuppressLint("UnclosedTrace")
    public override fun updateThreadContext(context: CoroutineContext): TraceData? {
    override fun updateThreadContext(context: CoroutineContext): TraceData? {
        val oldState = traceThreadLocal.get()
        //        val coroutineName = context[CoroutineTraceName]?.name ?: ""
        debug { "TCE#update;$nameWithId oldState=${oldState?.currentId}" }
        if (oldState !== contextTraceData) {
            traceThreadLocal.set(contextTraceData)
@@ -349,7 +348,7 @@ internal class TraceContextElement(
     *
     * ```
     */
    public override fun restoreThreadContext(context: CoroutineContext, oldState: TraceData?) {
    override fun restoreThreadContext(context: CoroutineContext, oldState: TraceData?) {
        debug { "TCE#restore;$nameWithId restoring=${oldState?.currentId}" }
        // We not use the `TraceData` object here because it may have been modified on another
        // thread after the last suspension point. This is why we use a [TraceStateHolder]:
@@ -362,7 +361,7 @@ internal class TraceContextElement(
        }
    }

    public override fun copyForChild(): CopyableThreadContextElement<TraceData?> {
    override fun copyForChild(): CopyableThreadContextElement<TraceData?> {
        debug { copyForChildTraceMessage }
        try {
            Trace.traceBegin(Trace.TRACE_TAG_APP, copyForChildTraceMessage)
@@ -374,9 +373,7 @@ internal class TraceContextElement(
        }
    }

    public override fun mergeForChild(
        overwritingElement: CoroutineContext.Element
    ): CoroutineContext {
    override fun mergeForChild(overwritingElement: CoroutineContext.Element): CoroutineContext {
        debug { mergeForChildTraceMessage }
        try {
            Trace.traceBegin(Trace.TRACE_TAG_APP, mergeForChildTraceMessage)
+38 −22
Original line number Diff line number Diff line
@@ -14,13 +14,17 @@
 * limitations under the License.
 */

@file:OptIn(ExperimentalTypeInference::class)

package com.android.app.tracing.coroutines.flow

import com.android.app.tracing.coroutines.CoroutineTraceName
import com.android.app.tracing.coroutines.traceCoroutine
import com.android.app.tracing.coroutines.traceName
import com.android.systemui.Flags
import kotlin.experimental.ExperimentalTypeInference
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.FlowCollector
import kotlinx.coroutines.flow.SharedFlow
@@ -32,6 +36,7 @@ import kotlinx.coroutines.flow.filter
import kotlinx.coroutines.flow.flow as safeFlow
import kotlinx.coroutines.flow.flowOn
import kotlinx.coroutines.flow.map
import kotlinx.coroutines.flow.mapLatest
import kotlinx.coroutines.flow.onEach
import kotlinx.coroutines.flow.shareIn
import kotlinx.coroutines.flow.stateIn
@@ -77,7 +82,7 @@ internal inline fun <T, R> Flow<T>.unsafeTransform(
 */
public fun <T> Flow<T>.flowName(name: String): Flow<T> {
    return if (Flags.coroutineTracing()) {
        unsafeFlow(name) { collect { traceCoroutine("emit") { emit(it) } } }
        unsafeTransform(name) { traceCoroutine("emit") { emit(it) } }
    } else {
        this
    }
@@ -87,9 +92,7 @@ public fun <T> Flow<T>.onEachTraced(name: String, action: suspend (T) -> Unit):
    return if (Flags.coroutineTracing()) {
        unsafeTransform(name) { value ->
            traceCoroutine("onEach:action") { action(value) }
            traceCoroutine("onEach:emit") {
                return@unsafeTransform emit(value)
            }
            traceCoroutine("onEach:emit") { emit(value) }
        }
    } else {
        onEach(action)
@@ -120,6 +123,7 @@ public suspend fun <T> Flow<T>.collectTraced(name: String, collector: FlowCollec
    }
}

/** @see kotlinx.coroutines.flow.collect */
public suspend fun <T> Flow<T>.collectTraced(name: String) {
    if (Flags.coroutineTracing()) {
        flowName(name).collect()
@@ -131,15 +135,37 @@ public suspend fun <T> Flow<T>.collectTraced(name: String) {
/** @see kotlinx.coroutines.flow.collect */
public suspend fun <T> Flow<T>.collectTraced(collector: FlowCollector<T>) {
    if (Flags.coroutineTracing()) {
        collectTraced(
            name = collector::class.java.name.substringAfterLast("."),
            collector = collector,
        )
        collectTraced(name = collector.traceName, collector = collector)
    } else {
        collect(collector)
    }
}

@ExperimentalCoroutinesApi
public fun <T, R> Flow<T>.mapLatestTraced(
    name: String,
    @BuilderInference transform: suspend (value: T) -> R,
): Flow<R> {
    return if (Flags.coroutineTracing()) {
        val collectName = "mapLatest:$name"
        val actionName = "$collectName:transform"
        traceCoroutine(collectName) { mapLatest { traceCoroutine(actionName) { transform(it) } } }
    } else {
        mapLatest(transform)
    }
}

@ExperimentalCoroutinesApi
public fun <T, R> Flow<T>.mapLatestTraced(
    @BuilderInference transform: suspend (value: T) -> R
): Flow<R> {
    return if (Flags.coroutineTracing()) {
        mapLatestTraced(transform.traceName, transform)
    } else {
        mapLatestTraced(transform)
    }
}

/** @see kotlinx.coroutines.flow.collectLatest */
internal suspend fun <T> Flow<T>.collectLatestTraced(
    name: String,
@@ -159,7 +185,7 @@ internal suspend fun <T> Flow<T>.collectLatestTraced(
/** @see kotlinx.coroutines.flow.collectLatest */
public suspend fun <T> Flow<T>.collectLatestTraced(action: suspend (value: T) -> Unit) {
    if (Flags.coroutineTracing()) {
        collectLatestTraced(action::class.java.name.substringAfterLast("."), action)
        collectLatestTraced(action.traceName, action)
    } else {
        collectLatest(action)
    }
@@ -173,13 +199,7 @@ public inline fun <T, R> Flow<T>.transformTraced(
): Flow<R> {
    return if (Flags.coroutineTracing()) {
        // Safe flow must be used because collector is exposed to the caller
        safeFlow {
            collect { value ->
                traceCoroutine("$name:transform") {
                    return@collect transform(value)
                }
            }
        }
        safeFlow { collect { value -> traceCoroutine("$name:transform") { transform(value) } } }
    } else {
        transform(transform)
    }
@@ -193,9 +213,7 @@ public inline fun <T> Flow<T>.filterTraced(
    return if (Flags.coroutineTracing()) {
        unsafeTransform(name) { value ->
            if (traceCoroutine("filter:predicate") { predicate(value) }) {
                traceCoroutine("filter:emit") {
                    return@unsafeTransform emit(value)
                }
                traceCoroutine("filter:emit") { emit(value) }
            }
        }
    } else {
@@ -211,9 +229,7 @@ public inline fun <T, R> Flow<T>.mapTraced(
    return if (Flags.coroutineTracing()) {
        unsafeTransform(name) { value ->
            val transformedValue = traceCoroutine("map:transform") { transform(value) }
            traceCoroutine("map:emit") {
                return@unsafeTransform emit(transformedValue)
            }
            traceCoroutine("map:emit") { emit(transformedValue) }
        }
    } else {
        map(transform)
+2 −1
Original line number Diff line number Diff line
@@ -26,6 +26,7 @@ import com.android.app.tracing.coroutines.nameCoroutine
import com.android.app.tracing.coroutines.traceCoroutine
import com.android.app.tracing.coroutines.withContextTraced
import com.android.systemui.Flags.FLAG_COROUTINE_TRACING
import kotlin.coroutines.EmptyCoroutineContext
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.delay
@@ -151,7 +152,7 @@ class CoroutineTracingTest : TestBase() {
    fun withContext_correctUsage() =
        runTest(finalEvent = 4) {
            expect(1, "1^main")
            withContextTraced("inside-withContext") {
            withContextTraced("inside-withContext", EmptyCoroutineContext) {
                assertTrue(coroutineContext[CoroutineTraceName] is TraceContextElement)
                expect(2, "1^main", "inside-withContext")
                delay(1)
+28 −1
Original line number Diff line number Diff line
@@ -14,9 +14,12 @@
 * limitations under the License.
 */

@file:OptIn(ExperimentalCoroutinesApi::class, DelicateCoroutinesApi::class)

package com.android.test.tracing.coroutines

import android.platform.test.annotations.EnableFlags
import com.android.app.tracing.coroutines.asyncTraced
import com.android.app.tracing.coroutines.createCoroutineTracingContext
import com.android.app.tracing.coroutines.flow.collectLatestTraced
import com.android.app.tracing.coroutines.flow.collectTraced
@@ -24,7 +27,9 @@ import com.android.app.tracing.coroutines.flow.filterTraced
import com.android.app.tracing.coroutines.flow.mapTraced
import com.android.app.tracing.coroutines.flow.transformTraced
import com.android.app.tracing.coroutines.launchTraced
import com.android.app.tracing.coroutines.withContextTraced
import com.android.systemui.Flags.FLAG_COROUTINE_TRACING
import kotlin.coroutines.EmptyCoroutineContext
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.DelicateCoroutinesApi
import kotlinx.coroutines.ExperimentalCoroutinesApi
@@ -362,7 +367,6 @@ class DefaultNamingTest : TestBase() {
            expect(8, "1^main")
        }

    @OptIn(ExperimentalCoroutinesApi::class, DelicateCoroutinesApi::class)
    @Test
    fun collectTraced12_badTransform() =
        runTest(
@@ -385,6 +389,29 @@ class DefaultNamingTest : TestBase() {
                    .collectTraced("COLLECT") {}
            },
        )

    @Test
    fun coroutineBuilder_defaultNames() {
        val localFun: suspend CoroutineScope.() -> Unit = {
            expectAny(
                arrayOf("1^main:4^DefaultNamingTest\$coroutineBuilder_defaultNames\$localFun$1"),
                arrayOf("1^main", "DefaultNamingTest\$coroutineBuilder_defaultNames\$localFun$1"),
                arrayOf("1^main:2^DefaultNamingTest\$coroutineBuilder_defaultNames\$localFun$1"),
            )
        }
        runTest(totalEvents = 6) {
            launchTraced { expect("1^main:1^DefaultNamingTest\$coroutineBuilder_defaultNames$1$1") }
                .join()
            launchTraced(block = localFun).join()
            asyncTraced { expect("1^main:3^DefaultNamingTest\$coroutineBuilder_defaultNames$1$2") }
                .await()
            asyncTraced(block = localFun).await()
            withContextTraced(context = EmptyCoroutineContext) {
                expect("1^main", "DefaultNamingTest\$coroutineBuilder_defaultNames$1$3")
            }
            withContextTraced(context = EmptyCoroutineContext, block = localFun)
        }
    }
}

fun topLevelFun(value: Int) {
Loading