Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit feae1434 authored by Peter Kalauskas's avatar Peter Kalauskas Committed by Android (Google) Code Review
Browse files

Merge "Updates to coroutine tracing" into main

parents 3c94afed b6c0607e
Loading
Loading
Loading
Loading
+3 −4
Original line number Diff line number Diff line
@@ -29,7 +29,6 @@ import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Deferred
import kotlinx.coroutines.Job
import kotlinx.coroutines.async
import kotlinx.coroutines.currentCoroutineContext
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.withContext
@@ -168,7 +167,7 @@ suspend inline fun <T> withContext(
 * @see traceCoroutine
 */
@OptIn(ExperimentalContracts::class)
suspend inline fun <T> traceCoroutine(spanName: () -> String, block: () -> T): T {
inline fun <T> traceCoroutine(spanName: () -> String, block: () -> T): T {
    contract {
        callsInPlace(spanName, InvocationKind.AT_MOST_ONCE)
        callsInPlace(block, InvocationKind.EXACTLY_ONCE)
@@ -177,7 +176,7 @@ suspend inline fun <T> traceCoroutine(spanName: () -> String, block: () -> T): T
    // For coroutine tracing to work, trace spans must be added and removed even when
    // tracing is not active (i.e. when TRACE_TAG_APP is disabled). Otherwise, when the
    // coroutine resumes when tracing is active, we won't know its name.
    val tracer = currentCoroutineContext()[TraceContextElement]?.traceData
    val tracer = CURRENT_TRACE.get()

    val asyncTracingEnabled = isEnabled()
    val spanString = if (tracer != null || asyncTracingEnabled) spanName() else "<none>"
@@ -198,5 +197,5 @@ suspend inline fun <T> traceCoroutine(spanName: () -> String, block: () -> T): T
}

/** @see traceCoroutine */
suspend inline fun <T> traceCoroutine(spanName: String, block: () -> T): T =
inline fun <T> traceCoroutine(spanName: String, block: () -> T): T =
    traceCoroutine({ spanName }, block)
+58 −19
Original line number Diff line number Diff line
@@ -16,8 +16,6 @@

package com.android.app.tracing.coroutines

import com.android.app.tracing.beginSlice
import com.android.app.tracing.endSlice
import com.android.systemui.Flags.coroutineTracing
import kotlin.coroutines.CoroutineContext
import kotlin.coroutines.EmptyCoroutineContext
@@ -25,6 +23,9 @@ import kotlinx.coroutines.CopyableThreadContextElement
import kotlinx.coroutines.CoroutineDispatcher
import kotlinx.coroutines.CoroutineName

/** Use a final subclass to avoid virtual calls (b/316642146). */
@PublishedApi internal class ThreadStateLocal : ThreadLocal<TraceData?>()

/**
 * Thread-local storage for giving each thread a unique [TraceData]. It can only be used when paired
 * with a [TraceContextElement].
@@ -36,13 +37,7 @@ import kotlinx.coroutines.CoroutineName
 *
 * @see traceCoroutine
 */
internal val CURRENT_TRACE = ThreadLocal<TraceData?>()

/**
 * If `true`, the CoroutineDispatcher and CoroutineName will be included in the trace each time the
 * coroutine context changes. This makes the trace extremely noisy, so it is off by default.
 */
private const val DEBUG_COROUTINE_CONTEXT_UPDATES = false
@PublishedApi internal val CURRENT_TRACE = ThreadStateLocal()

/**
 * Returns a new [CoroutineContext] used for tracing. Used to hide internal implementation details.
@@ -65,7 +60,6 @@ private fun CoroutineContext.nameForTrace(): String {
 *
 * @see traceCoroutine
 */
@PublishedApi
internal class TraceContextElement(@PublishedApi internal val traceData: TraceData = TraceData()) :
    CopyableThreadContextElement<TraceData?> {

@@ -74,29 +68,74 @@ internal class TraceContextElement(@PublishedApi internal val traceData: TraceDa
    override val key: CoroutineContext.Key<*>
        get() = Key

    /**
     * This function is invoked before the coroutine is resumed on the current thread. When a
     * multi-threaded dispatcher is used, calls to `updateThreadContext` may happen in parallel to
     * the prior `restoreThreadContext` in the same context. However, calls to `updateThreadContext`
     * will not run in parallel on the same context.
     *
     * ```
     * Thread #1 | [updateThreadContext]....^              [restoreThreadContext]
     * --------------------------------------------------------------------------------------------
     * Thread #2 |                           [updateThreadContext]...........^[restoreThreadContext]
     * ```
     *
     * (`...` indicate coroutine body is running; whitespace indicates the thread is not scheduled;
     * `^` is a suspension point)
     */
    override fun updateThreadContext(context: CoroutineContext): TraceData? {
        val oldState = CURRENT_TRACE.get()
        // oldState should never be null because we always initialize the thread-local with a
        // non-null instance,
        oldState?.endAllOnThread()
        CURRENT_TRACE.set(traceData)
        if (DEBUG_COROUTINE_CONTEXT_UPDATES) beginSlice(context.nameForTrace())
        // Calls to `updateThreadContext` will not happen in parallel on the same context, and
        // they cannot happen before the prior suspension point. Additionally,
        // `restoreThreadContext` does not modify `traceData`, so it is safe to iterate over the
        // collection here:
        traceData.beginAllOnThread()
        return oldState
    }

    /**
     * This function is invoked after the coroutine has suspended on the current thread. When a
     * multi-threaded dispatcher is used, calls to `restoreThreadContext` may happen in parallel to
     * the subsequent `updateThreadContext` and `restoreThreadContext` operations. The coroutine
     * body itself will not run in parallel, but `TraceData` could be modified by a coroutine body
     * after the suspension point in parallel to `restoreThreadContext` associated with the
     * coroutine body _prior_ to the suspension point.
     *
     * ```
     * Thread #1 | [updateThreadContext].x..^              [restoreThreadContext]
     * --------------------------------------------------------------------------------------------
     * Thread #2 |                           [updateThreadContext]..x..x.....^[restoreThreadContext]
     * ```
     *
     * OR
     *
     * ```
     * Thread #1 |                                   [restoreThreadContext]
     * --------------------------------------------------------------------------------------------
     * Thread #2 |     [updateThreadContext]...x.......^[restoreThreadContext]
     * ```
     *
     * (`...` indicate coroutine body is running; whitespace indicates the thread is not scheduled;
     * `^` is a suspension point; `x` are calls to modify the thread-local trace data)
     *
     * ```
     */
    override fun restoreThreadContext(context: CoroutineContext, oldState: TraceData?) {
        if (DEBUG_COROUTINE_CONTEXT_UPDATES) endSlice()
        traceData.endAllOnThread()
        // We should normally should not use the `TraceData` object here because it may have been
        // modified on another thread after the last suspension point, but `endAllOnThread()` uses a
        // `ThreadLocal` internally and is thread-safe:
        CURRENT_TRACE.get()?.endAllOnThread()
        CURRENT_TRACE.set(oldState)
        oldState?.beginAllOnThread()
    }

    override fun copyForChild(): CopyableThreadContextElement<TraceData?> {
        return TraceContextElement(traceData.clone())
        return TraceContextElement(CURRENT_TRACE.get()?.clone() ?: TraceData())
    }

    override fun mergeForChild(overwritingElement: CoroutineContext.Element): CoroutineContext {
        return TraceContextElement(traceData.clone())
        // For our use-case, we always give precedence to the parent trace context, and the
        // child context is ignored
        return TraceContextElement(CURRENT_TRACE.get()?.clone() ?: TraceData())
    }
}
+63 −11
Original line number Diff line number Diff line
@@ -28,6 +28,13 @@ import java.util.ArrayDeque
 */
typealias TraceSection = String

@PublishedApi
internal class TraceCountThreadLocal : ThreadLocal<Int>() {
    override fun initialValue(): Int {
        return 0
    }
}

/**
 * Used for storing trace sections so that they can be added and removed from the currently running
 * thread when the coroutine is suspended and resumed.
@@ -35,17 +42,43 @@ typealias TraceSection = String
 * @see traceCoroutine
 */
@PublishedApi
internal class TraceData(private val slices: ArrayDeque<TraceSection> = ArrayDeque()) : Cloneable {
internal class TraceData(
    internal val slices: ArrayDeque<TraceSection> = ArrayDeque(),
) : Cloneable {

    companion object {
        /**
         * ThreadLocal counter for how many open trace sections there are. This is needed because it
         * is possible that on a multi-threaded dispatcher, one of the threads could be slow, and
         * `restoreThreadContext` might be invoked _after_ the coroutine has already resumed and
         * modified TraceData - either adding or removing trace sections and changing the count. If
         * we did not store this thread-locally, then we would incorrectly end too many or too few
         * trace sections.
         */
        @PublishedApi internal val openSliceCount = TraceCountThreadLocal()

        /**
         * Whether to add additional checks to the coroutine machinery, throwing a
         * `ConcurrentModificationException` if TraceData is modified from the wrong thread. This
         * should only be set for testing.
         */
        internal var strictModeForTesting: Boolean = false
    }

    /** Adds current trace slices back to the current thread. Called when coroutine is resumed. */
    internal fun beginAllOnThread() {
        strictModeCheck()
        slices.descendingIterator().forEach { beginSlice(it) }
        openSliceCount.set(slices.size)
    }

    /**
     * Removes all current trace slices from the current thread. Called when coroutine is suspended.
     */
    internal fun endAllOnThread() {
        repeat(slices.size) { endSlice() }
        strictModeCheck()
        repeat(openSliceCount.get()) { endSlice() }
        openSliceCount.set(0)
    }

    /**
@@ -56,18 +89,12 @@ internal class TraceData(private val slices: ArrayDeque<TraceSection> = ArrayDeq
     */
    @PublishedApi
    internal fun beginSpan(name: String) {
        strictModeCheck()
        slices.push(name)
        openSliceCount.set(slices.size)
        beginSlice(name)
    }

    /**
     * Used by [TraceContextElement] when launching a child coroutine so that the child coroutine's
     * state is isolated from the parent.
     */
    public override fun clone(): TraceData {
        return TraceData(slices.clone())
    }

    /**
     * Ends the trace section and validates it corresponds with an earlier call to [beginSpan]. The
     * trace slice will immediately be removed from the current thread. This information will not
@@ -75,7 +102,32 @@ internal class TraceData(private val slices: ArrayDeque<TraceSection> = ArrayDeq
     */
    @PublishedApi
    internal fun endSpan() {
        strictModeCheck()
        // Should never happen, but we should be defensive rather than crash the whole application
        if (slices.size > 0) {
            slices.pop()
            openSliceCount.set(slices.size)
            endSlice()
        }
    }

    /**
     * Used by [TraceContextElement] when launching a child coroutine so that the child coroutine's
     * state is isolated from the parent.
     */
    public override fun clone(): TraceData {
        strictModeCheck()
        return TraceData(slices.clone())
    }

    private fun strictModeCheck() {
        if (strictModeForTesting && CURRENT_TRACE.get() !== this) {
            throw ConcurrentModificationException(strictModeErrorMessage)
        }
    }
}

private const val strictModeErrorMessage =
    "TraceData should only be accessed using " +
        "the ThreadLocal: CURRENT_TRACE.get(). Accessing TraceData by other means, such as " +
        "through the TraceContextElement's property may lead to concurrent modification."
+33 −16
Original line number Diff line number Diff line
@@ -29,45 +29,62 @@ internal actual fun traceCounter(counterName: String, counterValue: Int) {
    traceCounters[counterName] = counterValue
}

object TraceState {
    private val traceSections = mutableMapOf<Long, MutableList<String>>()
private val allThreadStates = HashMap<Long, MutableList<String>>()

private class FakeThreadStateLocal : ThreadLocal<MutableList<String>>() {
    override fun initialValue(): MutableList<String> {
        val openTraceSections = mutableListOf<String>()
        val threadId = Thread.currentThread().id
        synchronized(allThreadStates) { allThreadStates.put(threadId, openTraceSections) }
        return openTraceSections
    }
}

private val threadLocalTraceState = FakeThreadStateLocal()

object FakeTraceState {

    fun begin(sectionName: String) {
        synchronized(this) {
            traceSections.getOrPut(Thread.currentThread().id) { mutableListOf() }.add(sectionName)
        }
        threadLocalTraceState.get().add(sectionName)
    }

    fun end() {
        synchronized(this) {
            val openSectionsOnThread = traceSections[Thread.currentThread().id]
        threadLocalTraceState.get().let {
            assertFalse(
                "Attempting to close trace section on thread=${Thread.currentThread().id}, " +
                    "but there are no open sections",
                openSectionsOnThread.isNullOrEmpty()
                it.isNullOrEmpty()
            )
            // TODO: Replace with .removeLast() once available
            openSectionsOnThread!!.removeAt(openSectionsOnThread!!.lastIndex)
            it.removeAt(it.lastIndex)
        }
    }

    fun openSectionsOnCurrentThread(): Array<String> {
        return synchronized(this) {
            traceSections.getOrPut(Thread.currentThread().id) { mutableListOf() }.toTypedArray()
        }
    fun getOpenTraceSectionsOnCurrentThread(): Array<String> {
        return threadLocalTraceState.get().toTypedArray()
    }

    /**
     * Helper function for debugging; use as follows:
     * ```
     * println(FakeThreadStateLocal)
     * ```
     */
    override fun toString(): String {
        return traceSections.toString()
        val sb = StringBuilder()
        synchronized(allThreadStates) {
            allThreadStates.entries.forEach { sb.appendLine("${it.key} -> ${it.value}") }
        }
        return sb.toString()
    }
}

internal actual fun traceBegin(methodName: String) {
    TraceState.begin(methodName)
    FakeTraceState.begin(methodName)
}

internal actual fun traceEnd() {
    TraceState.end()
    FakeTraceState.end()
}

internal actual fun asyncTraceBegin(methodName: String, cookie: Int) {}
+6 −1
Original line number Diff line number Diff line
@@ -16,6 +16,11 @@

package com.android.systemui

private var isCoroutineTracingFlagEnabledForTests = true

object Flags {
    fun coroutineTracing() = true
    fun coroutineTracing() = isCoroutineTracingFlagEnabledForTests
    fun disableCoroutineTracing() {
        isCoroutineTracingFlagEnabledForTests = false
    }
}
Loading