Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 5ccd9a01 authored by Nicolo' Mazzucato's avatar Nicolo' Mazzucato
Browse files

Fix race condition in screenshot code with multiple displays

DisplayRepository.displays flow was initialized as an empty set at first when converted to a StateFlow.
Now it is a SharedFlow without any initial values: consumers of the flow will have to wait to get the first value.
As the replay number is equal to 1 and onStart emits an event, everyone collecting the flow will have something immediately and not starve.

Fixes: 306680094
Test: TakeScreenshotExecutorTest, DisplayRepositoryTest
Change-Id: I6ac7f3d6aa31f7f3fff4a5691a51bf8db33021df
parent 57cb73f2
Loading
Loading
Loading
Loading
+12 −17
Original line number Diff line number Diff line
@@ -47,6 +47,8 @@ import kotlinx.coroutines.flow.filter
import kotlinx.coroutines.flow.flowOn
import kotlinx.coroutines.flow.map
import kotlinx.coroutines.flow.onEach
import kotlinx.coroutines.flow.onStart
import kotlinx.coroutines.flow.shareIn
import kotlinx.coroutines.flow.stateIn

/** Provides a [Flow] of [Display] as returned by [DisplayManager]. */
@@ -54,10 +56,7 @@ interface DisplayRepository {
    /** Display change event indicating a change to the given displayId has occurred. */
    val displayChangeEvent: Flow<Int>

    /**
     * Provides a nullable set of displays. Updates when new displays have been added or removed but
     * not when a display's info has changed.
     */
    /** Provides the current set of displays. */
    val displays: Flow<Set<Display>>

    /**
@@ -112,10 +111,6 @@ constructor(
                            trySend(DisplayEvent.Changed(displayId))
                        }
                    }
                // Triggers an initial event when subscribed. This is needed to avoid getDisplays to
                // be called when this class is constructed, but only when someone subscribes to
                // this flow.
                trySend(DisplayEvent.Changed(Display.DEFAULT_DISPLAY))
                displayManager.registerDisplayListener(
                    callback,
                    backgroundHandler,
@@ -125,6 +120,7 @@ constructor(
                )
                awaitClose { displayManager.unregisterDisplayListener(callback) }
            }
            .onStart { emit(DisplayEvent.Changed(Display.DEFAULT_DISPLAY)) }
            .flowOn(backgroundCoroutineDispatcher)

    override val displayChangeEvent: Flow<Int> =
@@ -134,13 +130,9 @@ constructor(
        allDisplayEvents
            .map { getDisplays() }
            .flowOn(backgroundCoroutineDispatcher)
            .stateIn(
                applicationScope,
                started = SharingStarted.WhileSubscribed(),
                // To avoid getting displays on this object construction, they are get after the
                // first event. allDisplayEvents emits a changed event when we subscribe to it.
                initialValue = emptySet()
            )
            .shareIn(applicationScope, started = SharingStarted.WhileSubscribed(), replay = 1)

    override val displays: Flow<Set<Display>> = enabledDisplays

    private fun getDisplays(): Set<Display> =
        traceSection("DisplayRepository#getDisplays()") {
@@ -148,8 +140,6 @@ constructor(
        }

    /** Propagate to the listeners only enabled displays */
    override val displays: Flow<Set<Display>> = enabledDisplays

    private val enabledDisplayIds: Flow<Set<Int>> =
        enabledDisplays
            .map { enabledDisplaysSet -> enabledDisplaysSet.map { it.displayId }.toSet() }
@@ -251,6 +241,7 @@ constructor(
                val id = pendingDisplayIds.maxOrNull() ?: return@map null
                object : DisplayRepository.PendingDisplay {
                    override val id = id

                    override suspend fun enable() {
                        traceSection("DisplayRepository#enable($id)") {
                            if (DEBUG) {
@@ -303,8 +294,12 @@ constructor(
private interface DisplayConnectionListener : DisplayListener {

    override fun onDisplayConnected(id: Int) {}

    override fun onDisplayDisconnected(id: Int) {}

    override fun onDisplayAdded(id: Int) {}

    override fun onDisplayRemoved(id: Int) {}

    override fun onDisplayChanged(id: Int) {}
}
+5 −13
Original line number Diff line number Diff line
@@ -16,10 +16,7 @@ import com.android.systemui.screenshot.TakeScreenshotService.RequestCallback
import java.util.function.Consumer
import javax.inject.Inject
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Job
import kotlinx.coroutines.flow.SharingStarted
import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.flow.stateIn
import kotlinx.coroutines.flow.first
import kotlinx.coroutines.launch

/**
@@ -40,12 +37,7 @@ constructor(
    private val screenshotNotificationControllerFactory: ScreenshotNotificationsController.Factory,
) {

    private lateinit var displays: StateFlow<Set<Display>>
    private val displaysCollectionJob: Job =
        mainScope.launch {
            displays = displayRepository.displays.stateIn(this, SharingStarted.Eagerly, emptySet())
        }

    private val displays = displayRepository.displays
    private val screenshotControllers = mutableMapOf<Int, ScreenshotController>()
    private val notificationControllers = mutableMapOf<Int, ScreenshotNotificationsController>()

@@ -63,6 +55,7 @@ constructor(
        val displayIds = getDisplaysToScreenshot(screenshotRequest.type)
        val resultCallbackWrapper = MultiResultCallbackWrapper(requestCallback)
        displayIds.forEach { displayId: Int ->
            Log.d(TAG, "Executing screenshot for display $displayId")
            dispatchToController(
                rawScreenshotData = ScreenshotData.fromRequest(screenshotRequest, displayId),
                onSaved =
@@ -126,12 +119,12 @@ constructor(
        callback.reportError()
    }

    private fun getDisplaysToScreenshot(requestType: Int): List<Int> {
    private suspend fun getDisplaysToScreenshot(requestType: Int): List<Int> {
        return if (requestType == TAKE_SCREENSHOT_PROVIDED_IMAGE) {
            // If this is a provided image, let's show the UI on the default display only.
            listOf(Display.DEFAULT_DISPLAY)
        } else {
            displays.value.filter { it.type in ALLOWED_DISPLAY_TYPES }.map { it.displayId }
            displays.first().filter { it.type in ALLOWED_DISPLAY_TYPES }.map { it.displayId }
        }
    }

@@ -163,7 +156,6 @@ constructor(
            screenshotController.onDestroy()
        }
        screenshotControllers.clear()
        displaysCollectionJob.cancel()
    }

    private fun getScreenshotController(id: Int): ScreenshotController {
+5 −4
Original line number Diff line number Diff line
@@ -43,9 +43,10 @@ fun createPendingDisplay(id: Int = 0): DisplayRepository.PendingDisplay =
    mock<DisplayRepository.PendingDisplay> { whenever(this.id).thenReturn(id) }

/** Fake [DisplayRepository] implementation for testing. */
class FakeDisplayRepository() : DisplayRepository {
    private val flow = MutableSharedFlow<Set<Display>>()
    private val pendingDisplayFlow = MutableSharedFlow<DisplayRepository.PendingDisplay?>()
class FakeDisplayRepository : DisplayRepository {
    private val flow = MutableSharedFlow<Set<Display>>(replay = 1)
    private val pendingDisplayFlow =
        MutableSharedFlow<DisplayRepository.PendingDisplay?>(replay = 1)

    /** Emits [value] as [displays] flow value. */
    suspend fun emit(value: Set<Display>) = flow.emit(value)
@@ -59,7 +60,7 @@ class FakeDisplayRepository() : DisplayRepository {
    override val pendingDisplay: Flow<DisplayRepository.PendingDisplay?>
        get() = pendingDisplayFlow

    private val _displayChangeEvent = MutableSharedFlow<Int>()
    private val _displayChangeEvent = MutableSharedFlow<Int>(replay = 1)
    override val displayChangeEvent: Flow<Int> = _displayChangeEvent
    suspend fun emitDisplayChangeEvent(displayId: Int) = _displayChangeEvent.emit(displayId)
}