Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit b6c0607e authored by Peter Kalauskas's avatar Peter Kalauskas
Browse files

Updates to coroutine tracing

 - Small performance improvements

 - New tests

 - New comments

 - Fix issue with concurrent modification of TraceData that would
   cause malformed traces

Flag: ACONFIG com.android.systemui.coroutine_tracing DEVELOPMENT
Bug: 289353932
Test: atest tracinglib-host-test tracinglib-robolectric-test
Change-Id: I08e3b492db36f4a9c02af17587624b2740f83d30
parent dabea20f
Loading
Loading
Loading
Loading
+3 −4
Original line number Diff line number Diff line
@@ -29,7 +29,6 @@ import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Deferred
import kotlinx.coroutines.Job
import kotlinx.coroutines.async
import kotlinx.coroutines.currentCoroutineContext
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.withContext
@@ -168,7 +167,7 @@ suspend inline fun <T> withContext(
 * @see traceCoroutine
 */
@OptIn(ExperimentalContracts::class)
suspend inline fun <T> traceCoroutine(spanName: () -> String, block: () -> T): T {
inline fun <T> traceCoroutine(spanName: () -> String, block: () -> T): T {
    contract {
        callsInPlace(spanName, InvocationKind.AT_MOST_ONCE)
        callsInPlace(block, InvocationKind.EXACTLY_ONCE)
@@ -177,7 +176,7 @@ suspend inline fun <T> traceCoroutine(spanName: () -> String, block: () -> T): T
    // For coroutine tracing to work, trace spans must be added and removed even when
    // tracing is not active (i.e. when TRACE_TAG_APP is disabled). Otherwise, when the
    // coroutine resumes when tracing is active, we won't know its name.
    val tracer = currentCoroutineContext()[TraceContextElement]?.traceData
    val tracer = CURRENT_TRACE.get()

    val asyncTracingEnabled = isEnabled()
    val spanString = if (tracer != null || asyncTracingEnabled) spanName() else "<none>"
@@ -198,5 +197,5 @@ suspend inline fun <T> traceCoroutine(spanName: () -> String, block: () -> T): T
}

/** @see traceCoroutine */
suspend inline fun <T> traceCoroutine(spanName: String, block: () -> T): T =
inline fun <T> traceCoroutine(spanName: String, block: () -> T): T =
    traceCoroutine({ spanName }, block)
+58 −19
Original line number Diff line number Diff line
@@ -16,8 +16,6 @@

package com.android.app.tracing.coroutines

import com.android.app.tracing.beginSlice
import com.android.app.tracing.endSlice
import com.android.systemui.Flags.coroutineTracing
import kotlin.coroutines.CoroutineContext
import kotlin.coroutines.EmptyCoroutineContext
@@ -25,6 +23,9 @@ import kotlinx.coroutines.CopyableThreadContextElement
import kotlinx.coroutines.CoroutineDispatcher
import kotlinx.coroutines.CoroutineName

/** Use a final subclass to avoid virtual calls (b/316642146). */
@PublishedApi internal class ThreadStateLocal : ThreadLocal<TraceData?>()

/**
 * Thread-local storage for giving each thread a unique [TraceData]. It can only be used when paired
 * with a [TraceContextElement].
@@ -36,13 +37,7 @@ import kotlinx.coroutines.CoroutineName
 *
 * @see traceCoroutine
 */
internal val CURRENT_TRACE = ThreadLocal<TraceData?>()

/**
 * If `true`, the CoroutineDispatcher and CoroutineName will be included in the trace each time the
 * coroutine context changes. This makes the trace extremely noisy, so it is off by default.
 */
private const val DEBUG_COROUTINE_CONTEXT_UPDATES = false
@PublishedApi internal val CURRENT_TRACE = ThreadStateLocal()

/**
 * Returns a new [CoroutineContext] used for tracing. Used to hide internal implementation details.
@@ -65,7 +60,6 @@ private fun CoroutineContext.nameForTrace(): String {
 *
 * @see traceCoroutine
 */
@PublishedApi
internal class TraceContextElement(@PublishedApi internal val traceData: TraceData = TraceData()) :
    CopyableThreadContextElement<TraceData?> {

@@ -74,29 +68,74 @@ internal class TraceContextElement(@PublishedApi internal val traceData: TraceDa
    override val key: CoroutineContext.Key<*>
        get() = Key

    /**
     * This function is invoked before the coroutine is resumed on the current thread. When a
     * multi-threaded dispatcher is used, calls to `updateThreadContext` may happen in parallel to
     * the prior `restoreThreadContext` in the same context. However, calls to `updateThreadContext`
     * will not run in parallel on the same context.
     *
     * ```
     * Thread #1 | [updateThreadContext]....^              [restoreThreadContext]
     * --------------------------------------------------------------------------------------------
     * Thread #2 |                           [updateThreadContext]...........^[restoreThreadContext]
     * ```
     *
     * (`...` indicate coroutine body is running; whitespace indicates the thread is not scheduled;
     * `^` is a suspension point)
     */
    override fun updateThreadContext(context: CoroutineContext): TraceData? {
        val oldState = CURRENT_TRACE.get()
        // oldState should never be null because we always initialize the thread-local with a
        // non-null instance,
        oldState?.endAllOnThread()
        CURRENT_TRACE.set(traceData)
        if (DEBUG_COROUTINE_CONTEXT_UPDATES) beginSlice(context.nameForTrace())
        // Calls to `updateThreadContext` will not happen in parallel on the same context, and
        // they cannot happen before the prior suspension point. Additionally,
        // `restoreThreadContext` does not modify `traceData`, so it is safe to iterate over the
        // collection here:
        traceData.beginAllOnThread()
        return oldState
    }

    /**
     * This function is invoked after the coroutine has suspended on the current thread. When a
     * multi-threaded dispatcher is used, calls to `restoreThreadContext` may happen in parallel to
     * the subsequent `updateThreadContext` and `restoreThreadContext` operations. The coroutine
     * body itself will not run in parallel, but `TraceData` could be modified by a coroutine body
     * after the suspension point in parallel to `restoreThreadContext` associated with the
     * coroutine body _prior_ to the suspension point.
     *
     * ```
     * Thread #1 | [updateThreadContext].x..^              [restoreThreadContext]
     * --------------------------------------------------------------------------------------------
     * Thread #2 |                           [updateThreadContext]..x..x.....^[restoreThreadContext]
     * ```
     *
     * OR
     *
     * ```
     * Thread #1 |                                   [restoreThreadContext]
     * --------------------------------------------------------------------------------------------
     * Thread #2 |     [updateThreadContext]...x.......^[restoreThreadContext]
     * ```
     *
     * (`...` indicate coroutine body is running; whitespace indicates the thread is not scheduled;
     * `^` is a suspension point; `x` are calls to modify the thread-local trace data)
     *
     * ```
     */
    override fun restoreThreadContext(context: CoroutineContext, oldState: TraceData?) {
        if (DEBUG_COROUTINE_CONTEXT_UPDATES) endSlice()
        traceData.endAllOnThread()
        // We should normally should not use the `TraceData` object here because it may have been
        // modified on another thread after the last suspension point, but `endAllOnThread()` uses a
        // `ThreadLocal` internally and is thread-safe:
        CURRENT_TRACE.get()?.endAllOnThread()
        CURRENT_TRACE.set(oldState)
        oldState?.beginAllOnThread()
    }

    override fun copyForChild(): CopyableThreadContextElement<TraceData?> {
        return TraceContextElement(traceData.clone())
        return TraceContextElement(CURRENT_TRACE.get()?.clone() ?: TraceData())
    }

    override fun mergeForChild(overwritingElement: CoroutineContext.Element): CoroutineContext {
        return TraceContextElement(traceData.clone())
        // For our use-case, we always give precedence to the parent trace context, and the
        // child context is ignored
        return TraceContextElement(CURRENT_TRACE.get()?.clone() ?: TraceData())
    }
}
+63 −11
Original line number Diff line number Diff line
@@ -28,6 +28,13 @@ import java.util.ArrayDeque
 */
typealias TraceSection = String

@PublishedApi
internal class TraceCountThreadLocal : ThreadLocal<Int>() {
    override fun initialValue(): Int {
        return 0
    }
}

/**
 * Used for storing trace sections so that they can be added and removed from the currently running
 * thread when the coroutine is suspended and resumed.
@@ -35,17 +42,43 @@ typealias TraceSection = String
 * @see traceCoroutine
 */
@PublishedApi
internal class TraceData(private val slices: ArrayDeque<TraceSection> = ArrayDeque()) : Cloneable {
internal class TraceData(
    internal val slices: ArrayDeque<TraceSection> = ArrayDeque(),
) : Cloneable {

    companion object {
        /**
         * ThreadLocal counter for how many open trace sections there are. This is needed because it
         * is possible that on a multi-threaded dispatcher, one of the threads could be slow, and
         * `restoreThreadContext` might be invoked _after_ the coroutine has already resumed and
         * modified TraceData - either adding or removing trace sections and changing the count. If
         * we did not store this thread-locally, then we would incorrectly end too many or too few
         * trace sections.
         */
        @PublishedApi internal val openSliceCount = TraceCountThreadLocal()

        /**
         * Whether to add additional checks to the coroutine machinery, throwing a
         * `ConcurrentModificationException` if TraceData is modified from the wrong thread. This
         * should only be set for testing.
         */
        internal var strictModeForTesting: Boolean = false
    }

    /** Adds current trace slices back to the current thread. Called when coroutine is resumed. */
    internal fun beginAllOnThread() {
        strictModeCheck()
        slices.descendingIterator().forEach { beginSlice(it) }
        openSliceCount.set(slices.size)
    }

    /**
     * Removes all current trace slices from the current thread. Called when coroutine is suspended.
     */
    internal fun endAllOnThread() {
        repeat(slices.size) { endSlice() }
        strictModeCheck()
        repeat(openSliceCount.get()) { endSlice() }
        openSliceCount.set(0)
    }

    /**
@@ -56,18 +89,12 @@ internal class TraceData(private val slices: ArrayDeque<TraceSection> = ArrayDeq
     */
    @PublishedApi
    internal fun beginSpan(name: String) {
        strictModeCheck()
        slices.push(name)
        openSliceCount.set(slices.size)
        beginSlice(name)
    }

    /**
     * Used by [TraceContextElement] when launching a child coroutine so that the child coroutine's
     * state is isolated from the parent.
     */
    public override fun clone(): TraceData {
        return TraceData(slices.clone())
    }

    /**
     * Ends the trace section and validates it corresponds with an earlier call to [beginSpan]. The
     * trace slice will immediately be removed from the current thread. This information will not
@@ -75,7 +102,32 @@ internal class TraceData(private val slices: ArrayDeque<TraceSection> = ArrayDeq
     */
    @PublishedApi
    internal fun endSpan() {
        strictModeCheck()
        // Should never happen, but we should be defensive rather than crash the whole application
        if (slices.size > 0) {
            slices.pop()
            openSliceCount.set(slices.size)
            endSlice()
        }
    }

    /**
     * Used by [TraceContextElement] when launching a child coroutine so that the child coroutine's
     * state is isolated from the parent.
     */
    public override fun clone(): TraceData {
        strictModeCheck()
        return TraceData(slices.clone())
    }

    private fun strictModeCheck() {
        if (strictModeForTesting && CURRENT_TRACE.get() !== this) {
            throw ConcurrentModificationException(strictModeErrorMessage)
        }
    }
}

private const val strictModeErrorMessage =
    "TraceData should only be accessed using " +
        "the ThreadLocal: CURRENT_TRACE.get(). Accessing TraceData by other means, such as " +
        "through the TraceContextElement's property may lead to concurrent modification."
+33 −16
Original line number Diff line number Diff line
@@ -29,45 +29,62 @@ internal actual fun traceCounter(counterName: String, counterValue: Int) {
    traceCounters[counterName] = counterValue
}

object TraceState {
    private val traceSections = mutableMapOf<Long, MutableList<String>>()
private val allThreadStates = HashMap<Long, MutableList<String>>()

private class FakeThreadStateLocal : ThreadLocal<MutableList<String>>() {
    override fun initialValue(): MutableList<String> {
        val openTraceSections = mutableListOf<String>()
        val threadId = Thread.currentThread().id
        synchronized(allThreadStates) { allThreadStates.put(threadId, openTraceSections) }
        return openTraceSections
    }
}

private val threadLocalTraceState = FakeThreadStateLocal()

object FakeTraceState {

    fun begin(sectionName: String) {
        synchronized(this) {
            traceSections.getOrPut(Thread.currentThread().id) { mutableListOf() }.add(sectionName)
        }
        threadLocalTraceState.get().add(sectionName)
    }

    fun end() {
        synchronized(this) {
            val openSectionsOnThread = traceSections[Thread.currentThread().id]
        threadLocalTraceState.get().let {
            assertFalse(
                "Attempting to close trace section on thread=${Thread.currentThread().id}, " +
                    "but there are no open sections",
                openSectionsOnThread.isNullOrEmpty()
                it.isNullOrEmpty()
            )
            // TODO: Replace with .removeLast() once available
            openSectionsOnThread!!.removeAt(openSectionsOnThread!!.lastIndex)
            it.removeAt(it.lastIndex)
        }
    }

    fun openSectionsOnCurrentThread(): Array<String> {
        return synchronized(this) {
            traceSections.getOrPut(Thread.currentThread().id) { mutableListOf() }.toTypedArray()
        }
    fun getOpenTraceSectionsOnCurrentThread(): Array<String> {
        return threadLocalTraceState.get().toTypedArray()
    }

    /**
     * Helper function for debugging; use as follows:
     * ```
     * println(FakeThreadStateLocal)
     * ```
     */
    override fun toString(): String {
        return traceSections.toString()
        val sb = StringBuilder()
        synchronized(allThreadStates) {
            allThreadStates.entries.forEach { sb.appendLine("${it.key} -> ${it.value}") }
        }
        return sb.toString()
    }
}

internal actual fun traceBegin(methodName: String) {
    TraceState.begin(methodName)
    FakeTraceState.begin(methodName)
}

internal actual fun traceEnd() {
    TraceState.end()
    FakeTraceState.end()
}

internal actual fun asyncTraceBegin(methodName: String, cookie: Int) {}
+6 −1
Original line number Diff line number Diff line
@@ -16,6 +16,11 @@

package com.android.systemui

private var isCoroutineTracingFlagEnabledForTests = true

object Flags {
    fun coroutineTracing() = true
    fun coroutineTracing() = isCoroutineTracingFlagEnabledForTests
    fun disableCoroutineTracing() {
        isCoroutineTracingFlagEnabledForTests = false
    }
}
Loading