Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit e5041f55 authored by Peter Kalauskas
Browse files

Fix malformed coroutine traces

Flag: ACONFIG com.android.systemui.coroutine_tracing DEVELOPMENT
Bug: 289353932
Fixes: 339179378
Test: atest tracinglib-host-test tracinglib-robolectric-test
Change-Id: I9077107a2a45fb00aa94ab9429ad164c5b8d1e6b
parent b74f6c83
Loading
Loading
Loading
Loading
+4 −5
Original line number Diff line number Diff line
@@ -176,12 +176,11 @@ inline fun <T> traceCoroutine(spanName: () -> String, block: () -> T): T {
    // For coroutine tracing to work, trace spans must be added and removed even when
    // tracing is not active (i.e. when TRACE_TAG_APP is disabled). Otherwise, when the
    // coroutine resumes when tracing is active, we won't know its name.
    val tracer = CURRENT_TRACE.get()

    val traceData = traceThreadLocal.get()
    val asyncTracingEnabled = isEnabled()
    val spanString = if (tracer != null || asyncTracingEnabled) spanName() else "<none>"
    val spanString = if (traceData != null || asyncTracingEnabled) spanName() else "<none>"

    tracer?.beginSpan(spanString)
    traceData?.beginSpan(spanString)

    // Also trace to the "Coroutines" async track. This makes it easy to see the duration of
    // coroutine spans. When the coroutine_tracing flag is enabled, those same names will
@@ -192,7 +191,7 @@ inline fun <T> traceCoroutine(spanName: () -> String, block: () -> T): T {
        return block()
    } finally {
        if (asyncTracingEnabled) asyncTraceForTrackEnd(DEFAULT_TRACK_NAME, spanString, cookie)
        tracer?.endSpan()
        traceData?.endSpan()
    }
}

+50 −34
Original line number Diff line number Diff line
@@ -20,36 +20,35 @@ import com.android.systemui.Flags.coroutineTracing
import kotlin.coroutines.CoroutineContext
import kotlin.coroutines.EmptyCoroutineContext
import kotlinx.coroutines.CopyableThreadContextElement
import kotlinx.coroutines.CoroutineDispatcher
import kotlinx.coroutines.CoroutineName

private const val DEBUG = false

/**
 * Prints a debug line prefixed with the current thread's ID. The [message] lambda is evaluated
 * only when [DEBUG] is enabled.
 */
private inline fun debug(message: () -> String) {
    if (!DEBUG) return
    val threadId = Thread.currentThread().id
    println("Thread #$threadId: ${message()}")
}

/** Use a final subclass to avoid virtual calls (b/316642146). */
@PublishedApi internal class ThreadStateLocal : ThreadLocal<TraceData?>()
@PublishedApi internal class TraceDataThreadLocal : ThreadLocal<TraceData?>()

/**
 * Thread-local storage for giving each thread a unique [TraceData]. It can only be used when paired
 * with a [TraceContextElement].
 *
 * [CURRENT_TRACE] will be `null` if either 1) we aren't in a coroutine, or 2) the current coroutine
 * context does not have [TraceContextElement]. In both cases, writing to this thread-local would be
 * undefined behavior if it were not null, which is why we use null as the default value rather than
 * an empty TraceData.
 * [traceThreadLocal] will be `null` if either 1) we aren't in a coroutine, or 2) the current
 * coroutine context does not have [TraceContextElement]. In both cases, writing to this
 * thread-local would be undefined behavior if it were not null, which is why we use null as the
 * default value rather than an empty TraceData.
 *
 * @see traceCoroutine
 */
@PublishedApi internal val CURRENT_TRACE = ThreadStateLocal()
@PublishedApi internal val traceThreadLocal = TraceDataThreadLocal()

/**
 * Returns a new [CoroutineContext] used for tracing. Used to hide internal implementation details.
 */
fun createCoroutineTracingContext(): CoroutineContext {
    return if (coroutineTracing()) TraceContextElement() else EmptyCoroutineContext
}

private fun CoroutineContext.nameForTrace(): String {
    val dispatcherStr = "${this[CoroutineDispatcher]}"
    val nameStr = "${this[CoroutineName]?.name}"
    return "CoroutineDispatcher: $dispatcherStr; CoroutineName: $nameStr"
    return if (coroutineTracing()) TraceContextElement(TraceData()) else EmptyCoroutineContext
}

/**
@@ -60,14 +59,18 @@ private fun CoroutineContext.nameForTrace(): String {
 *
 * @see traceCoroutine
 */
internal class TraceContextElement(@PublishedApi internal val traceData: TraceData = TraceData()) :
internal class TraceContextElement(internal val traceData: TraceData? = TraceData()) :
    CopyableThreadContextElement<TraceData?> {

    @PublishedApi internal companion object Key : CoroutineContext.Key<TraceContextElement>
    internal companion object Key : CoroutineContext.Key<TraceContextElement>

    override val key: CoroutineContext.Key<*>
        get() = Key

    init {
        debug { "$this #init" }
    }

    /**
     * This function is invoked before the coroutine is resumed on the current thread. When a
     * multi-threaded dispatcher is used, calls to `updateThreadContext` may happen in parallel to
@@ -84,13 +87,16 @@ internal class TraceContextElement(@PublishedApi internal val traceData: TraceDa
     * `^` is a suspension point)
     */
    override fun updateThreadContext(context: CoroutineContext): TraceData? {
        val oldState = CURRENT_TRACE.get()
        CURRENT_TRACE.set(traceData)
        val oldState = traceThreadLocal.get()
        debug { "$this #updateThreadContext oldState=$oldState" }
        if (oldState !== traceData) {
            traceThreadLocal.set(traceData)
            // Calls to `updateThreadContext` will not happen in parallel on the same context, and
            // they cannot happen before the prior suspension point. Additionally,
            // `restoreThreadContext` does not modify `traceData`, so it is safe to iterate over the
            // collection here:
        traceData.beginAllOnThread()
            traceData?.beginAllOnThread()
        }
        return oldState
    }

@@ -113,7 +119,7 @@ internal class TraceContextElement(@PublishedApi internal val traceData: TraceDa
     * ```
     * Thread #1 |                                 [restoreThreadContext]
     * --------------------------------------------------------------------------------------------
     * Thread #2 |     [updateThreadContext]...x.......^[restoreThreadContext]
     * Thread #2 |     [updateThreadContext]...x....x..^[restoreThreadContext]
     * ```
     *
     * (`...` indicate coroutine body is running; whitespace indicates the thread is not scheduled;
@@ -122,20 +128,30 @@ internal class TraceContextElement(@PublishedApi internal val traceData: TraceDa
     * ```
     */
    override fun restoreThreadContext(context: CoroutineContext, oldState: TraceData?) {
        // We normally should not use the `TraceData` object here because it may have been
        // modified on another thread after the last suspension point, but `endAllOnThread()` uses a
        // `ThreadLocal` internally and is thread-safe:
        CURRENT_TRACE.get()?.endAllOnThread()
        CURRENT_TRACE.set(oldState)
        debug { "$this#restoreThreadContext restoring=$oldState" }
        // We do not use the `TraceData` object here because it may have been modified on another
        // thread after the last suspension point. This is why we use a [TraceStateHolder]:
        // so we can end the correct number of trace sections, restoring the thread to its state
        // prior to the last call to [updateThreadContext].
        if (oldState !== traceThreadLocal.get()) {
            traceData?.endAllOnThread()
            traceThreadLocal.set(oldState)
        }
    }

    override fun copyForChild(): CopyableThreadContextElement<TraceData?> {
        return TraceContextElement(CURRENT_TRACE.get()?.clone() ?: TraceData())
        debug { "$this #copyForChild" }
        return TraceContextElement(traceData?.clone())
    }

    override fun mergeForChild(overwritingElement: CoroutineContext.Element): CoroutineContext {
        debug { "$this #mergeForChild" }
        // For our use-case, we always give precedence to the parent trace context, and the
        // child context is ignored
        return TraceContextElement(CURRENT_TRACE.get()?.clone() ?: TraceData())
        // child context (overwritingElement) is ignored
        return TraceContextElement(traceData?.clone())
    }

    /** Debug label: this element's hash code (via `toHexString()`) followed by its [traceData]. */
    override fun toString(): String =
        "TraceContextElement@${hashCode().toHexString()}[$traceData]"
}
+30 −22
Original line number Diff line number Diff line
@@ -46,24 +46,15 @@ internal class TraceData(
    internal val slices: ArrayDeque<TraceSection> = ArrayDeque(),
) : Cloneable {

    companion object {
    /**
         * ThreadLocal counter for how many open trace sections there are. This is needed because it
         * is possible that on a multi-threaded dispatcher, one of the threads could be slow, and
     * ThreadLocal counter for how many open trace sections there are. This is needed because it is
     * possible that on a multi-threaded dispatcher, one of the threads could be slow, and
     * `restoreThreadContext` might be invoked _after_ the coroutine has already resumed and
         * modified TraceData - either adding or removing trace sections and changing the count. If
         * we did not store this thread-locally, then we would incorrectly end too many or too few
         * trace sections.
     * modified TraceData - either adding or removing trace sections and changing the count. If we
     * did not store this thread-locally, then we would incorrectly end too many or too few trace
     * sections.
     */
        @PublishedApi internal val openSliceCount = TraceCountThreadLocal()

        /**
         * Whether to add additional checks to the coroutine machinery, throwing a
         * `ConcurrentModificationException` if TraceData is modified from the wrong thread. This
         * should only be set for testing.
         */
        internal var strictModeForTesting: Boolean = false
    }
    private val openSliceCount = TraceCountThreadLocal()

    /** Adds current trace slices back to the current thread. Called when coroutine is resumed. */
    internal fun beginAllOnThread() {
@@ -108,6 +99,8 @@ internal class TraceData(
            slices.pop()
            openSliceCount.set(slices.size)
            endSlice()
        } else if (strictModeForTesting) {
            throw IllegalStateException(INVALID_SPAN_END_CALL_ERROR_MESSAGE)
        }
    }

@@ -116,18 +109,33 @@ internal class TraceData(
     * state is isolated from the parent.
     */
    public override fun clone(): TraceData {
        strictModeCheck()
        return TraceData(slices.clone())
    }

    /** Debug label: this object's hash code (via `toHexString()`) and the size of [slices]. */
    override fun toString(): String =
        "TraceData@${hashCode().toHexString()}-size=${slices.size}"

    private fun strictModeCheck() {
        if (strictModeForTesting && CURRENT_TRACE.get() !== this) {
            throw ConcurrentModificationException(strictModeErrorMessage)
        if (strictModeForTesting && traceThreadLocal.get() !== this) {
            throw ConcurrentModificationException(STRICT_MODE_ERROR_MESSAGE)
        }
    }

    companion object {
        /**
         * Whether to add additional checks to the coroutine machinery, throwing a
         * `ConcurrentModificationException` if TraceData is modified from the wrong thread. This
         * should only be set for testing.
         */
        internal var strictModeForTesting: Boolean = false
    }
}

private const val INVALID_SPAN_END_CALL_ERROR_MESSAGE =
    "TraceData#endSpan called when there were no active trace sections."

private const val strictModeErrorMessage =
// Thrown (as a ConcurrentModificationException message) when strict mode finds that the TraceData
// being used is not the one held by `traceThreadLocal` on the current thread. The message names
// `traceThreadLocal`, the thread-local that actually holds the current TraceData.
private const val STRICT_MODE_ERROR_MESSAGE =
    "TraceData should only be accessed using " +
        "the ThreadLocal: traceThreadLocal.get(). Accessing TraceData by other means, such as " +
        "through the TraceContextElement's property may lead to concurrent modification."
+47 −25
Original line number Diff line number Diff line
@@ -18,6 +18,13 @@ package com.android.app.tracing

import org.junit.Assert.assertFalse

const val DEBUG = false

/** Prints [message] prefixed with the current thread's ID; a no-op unless [DEBUG] is enabled. */
private fun debug(message: String) {
    if (!DEBUG) return
    val threadId = Thread.currentThread().id
    println("Thread #$threadId: $message")
}

@PublishedApi
internal actual fun isEnabled(): Boolean {
    return true
@@ -29,39 +36,38 @@ internal actual fun traceCounter(counterName: String, counterValue: Int) {
    traceCounters[counterName] = counterValue
}

private val allThreadStates = HashMap<Long, MutableList<String>>()
object FakeTraceState {

private class FakeThreadStateLocal : ThreadLocal<MutableList<String>>() {
    override fun initialValue(): MutableList<String> {
        val openTraceSections = mutableListOf<String>()
    private val allThreadStates = hashMapOf<Long, MutableList<String>>()
    fun begin(sectionName: String) {
        val threadId = Thread.currentThread().id
        synchronized(allThreadStates) { allThreadStates.put(threadId, openTraceSections) }
        return openTraceSections
        synchronized(allThreadStates) {
            if (allThreadStates.containsKey(threadId)) {
                allThreadStates[threadId]!!.add(sectionName)
            } else {
                allThreadStates[threadId] = mutableListOf(sectionName)
            }
        }

private val threadLocalTraceState = FakeThreadStateLocal()

object FakeTraceState {

    fun begin(sectionName: String) {
        threadLocalTraceState.get().add(sectionName)
    }

    fun end() {
        threadLocalTraceState.get().let {
        val threadId = Thread.currentThread().id
        synchronized(allThreadStates) {
            assertFalse(
                "Attempting to close trace section on thread=${Thread.currentThread().id}, " +
                "Attempting to close trace section on thread=$threadId, " +
                    "but there are no open sections",
                it.isNullOrEmpty()
                allThreadStates[threadId].isNullOrEmpty()
            )
            // TODO: Replace with .removeLast() once available
            it.removeAt(it.lastIndex)
            allThreadStates[threadId]!!.removeAt(allThreadStates[threadId]!!.lastIndex)
        }
    }

    fun getOpenTraceSectionsOnCurrentThread(): Array<String> {
        return threadLocalTraceState.get().toTypedArray()
        val threadId = Thread.currentThread().id
        synchronized(allThreadStates) {
            return allThreadStates[threadId]?.toTypedArray() ?: emptyArray()
        }
    }

    /**
@@ -80,23 +86,39 @@ object FakeTraceState {
}

/**
 * Test fake for `traceBegin`: logs the call via [debug], then records [methodName] as an open
 * trace section for the current thread in [FakeTraceState].
 */
internal actual fun traceBegin(methodName: String) {
    debug("traceBegin: name=$methodName")
    FakeTraceState.begin(methodName)
}

/**
 * Test fake for `traceEnd`: logs the call via [debug], then closes the most recently opened trace
 * section for the current thread in [FakeTraceState]. [FakeTraceState.end] asserts that at least
 * one section is open on this thread.
 */
internal actual fun traceEnd() {
    debug("traceEnd")
    FakeTraceState.end()
}

internal actual fun asyncTraceBegin(methodName: String, cookie: Int) {}
/**
 * Test fake for `asyncTraceBegin`: only emits a debug log line with [methodName] and [cookie]; no
 * trace state is recorded.
 */
internal actual fun asyncTraceBegin(methodName: String, cookie: Int) {
    debug("asyncTraceBegin: name=$methodName cookie=${cookie.toHexString()}")
}

internal actual fun asyncTraceEnd(methodName: String, cookie: Int) {}
/**
 * Test fake for `asyncTraceEnd`: only emits a debug log line with [methodName] and [cookie]; no
 * trace state is recorded.
 */
internal actual fun asyncTraceEnd(methodName: String, cookie: Int) {
    debug("asyncTraceEnd: name=$methodName cookie=${cookie.toHexString()}")
}

@PublishedApi
internal actual fun asyncTraceForTrackBegin(trackName: String, methodName: String, cookie: Int) {}
internal actual fun asyncTraceForTrackBegin(trackName: String, methodName: String, cookie: Int) {
    debug(
        "asyncTraceForTrackBegin: track=$trackName name=$methodName cookie=${cookie.toHexString()}"
    )
}

@PublishedApi
internal actual fun asyncTraceForTrackEnd(trackName: String, methodName: String, cookie: Int) {}
internal actual fun asyncTraceForTrackEnd(trackName: String, methodName: String, cookie: Int) {
    debug("asyncTraceForTrackEnd: track=$trackName name=$methodName cookie=${cookie.toHexString()}")
}

internal actual fun instant(eventName: String) {}
/** Test fake for `instant`: only emits a debug log line with [eventName]. */
internal actual fun instant(eventName: String) {
    debug("instant: name=$eventName")
}

internal actual fun instantForTrack(trackName: String, eventName: String) {}
/** Test fake for `instantForTrack`: only emits a debug log line with [trackName] and [eventName]. */
internal actual fun instantForTrack(trackName: String, eventName: String) {
    debug("instantForTrack: track=$trackName name=$eventName")
}
+309 −163

File changed.

Preview size limit exceeded, changes collapsed.