Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 0322e156 authored by Nicolo' Mazzucato's avatar Nicolo' Mazzucato Committed by Matt Casey
Browse files

Fix Race in ScreenshotController

TakeScreenshotExecutor was creating the ScreenshotController by passing the displayId. However, if the display was disconnected before ScreenshotController was fully initialized, it ended up crashing (exactly when creating the display context).

Now, ScreenshotController is not created unless we have a not-null display in TakescreenshotExecutor.

Test: TakeScreenshotExecutorTest
Flag: None
Fixes: 329147174
Merged-In: I0c7c9ec4e7d3dfc8896bd325bd250d372a6c0891
Change-Id: Ib5dbf463a89342f1f61e0d84f4031d7c72f2b87d
parent a922ac82
Loading
Loading
Loading
Loading
+22 −29
Original line number Diff line number Diff line
@@ -47,7 +47,6 @@ import android.content.res.Configuration;
import android.graphics.Bitmap;
import android.graphics.Insets;
import android.graphics.Rect;
import android.hardware.display.DisplayManager;
import android.net.Uri;
import android.os.Process;
import android.os.UserHandle;
@@ -209,8 +208,7 @@ public class ScreenshotController {
    @Nullable
    private final ScreenshotSoundController mScreenshotSoundController;
    private final PhoneWindow mWindow;
    private final DisplayManager mDisplayManager;
    private final int mDisplayId;
    private final Display mDisplay;
    private final ScrollCaptureExecutor mScrollCaptureExecutor;
    private final ScreenshotNotificationSmartActionsProvider
            mScreenshotNotificationSmartActionsProvider;
@@ -250,7 +248,6 @@ public class ScreenshotController {
    @AssistedInject
    ScreenshotController(
            Context context,
            DisplayManager displayManager,
            WindowManager windowManager,
            FeatureFlags flags,
            ScreenshotViewProxy.Factory viewProxyFactory,
@@ -272,12 +269,14 @@ public class ScreenshotController {
            AssistContentRequester assistContentRequester,
            MessageContainerController messageContainerController,
            Provider<ScreenshotSoundController> screenshotSoundController,
            @Assisted int displayId,
            @Assisted Display display,
            @Assisted boolean showUIOnExternalDisplay
    ) {
        mScreenshotSmartActions = screenshotSmartActions;
        mWindowManager = windowManager;
        mActionsProviderFactory = actionsProviderFactory;
        mNotificationsController = screenshotNotificationsControllerFactory.create(displayId);
        mNotificationsController = screenshotNotificationsControllerFactory.create(
                display.getDisplayId());
        mUiEventLogger = uiEventLogger;
        mImageExporter = imageExporter;
        mImageCapture = imageCapture;
@@ -290,12 +289,8 @@ public class ScreenshotController {

        mScreenshotHandler = timeoutHandler;
        mScreenshotHandler.setDefaultTimeoutMillis(SCREENSHOT_CORNER_DEFAULT_TIMEOUT_MILLIS);


        mDisplayId = displayId;
        mDisplayManager = displayManager;
        mWindowManager = windowManager;
        final Context displayContext = context.createDisplayContext(getDisplay());
        mDisplay = display;
        final Context displayContext = context.createDisplayContext(display);
        mContext = (WindowContext) displayContext.createWindowContext(TYPE_SCREENSHOT, null);
        mFlags = flags;
        mActionIntentExecutor = actionIntentExecutor;
@@ -303,7 +298,7 @@ public class ScreenshotController {
        mMessageContainerController = messageContainerController;
        mAssistContentRequester = assistContentRequester;

        mViewProxy = viewProxyFactory.getProxy(mContext, mDisplayId);
        mViewProxy = viewProxyFactory.getProxy(mContext, mDisplay.getDisplayId());

        mScreenshotHandler.setOnTimeoutRunnable(() -> {
            if (DEBUG_UI) {
@@ -329,7 +324,7 @@ public class ScreenshotController {
                });

        // Sound is only reproduced from the controller of the default display.
        if (displayId == Display.DEFAULT_DISPLAY) {
        if (display.getDisplayId() == Display.DEFAULT_DISPLAY) {
            mScreenshotSoundController = screenshotSoundController.get();
        } else {
            mScreenshotSoundController = null;
@@ -357,7 +352,7 @@ public class ScreenshotController {
        if (screenshot.getType() == WindowManager.TAKE_SCREENSHOT_FULLSCREEN) {
            Rect bounds = getFullScreenRect();
            screenshot.setBitmap(
                    mImageCapture.captureDisplay(mDisplayId, bounds));
                    mImageCapture.captureDisplay(mDisplay.getDisplayId(), bounds));
            screenshot.setScreenBounds(bounds);
        }

@@ -460,7 +455,7 @@ public class ScreenshotController {
    }

    private boolean shouldShowUi() {
        return mDisplayId == Display.DEFAULT_DISPLAY || mShowUIOnExternalDisplay;
        return mDisplay.getDisplayId() == Display.DEFAULT_DISPLAY || mShowUIOnExternalDisplay;
    }

    void prepareViewForNewScreenshot(@NonNull ScreenshotData screenshot, String oldPackageName) {
@@ -619,7 +614,7 @@ public class ScreenshotController {

    private void requestScrollCapture(UserHandle owner) {
        mScrollCaptureExecutor.requestScrollCapture(
                mDisplayId,
                mDisplay.getDisplayId(),
                mWindow.getDecorView().getWindowToken(),
                (response) -> {
                    mUiEventLogger.log(ScreenshotEvent.SCREENSHOT_LONG_SCREENSHOT_IMPRESSION,
@@ -642,7 +637,8 @@ public class ScreenshotController {
        }
        mUiEventLogger.log(ScreenshotEvent.SCREENSHOT_LONG_SCREENSHOT_REQUESTED, 0,
                response.getPackageName());
        Bitmap newScreenshot = mImageCapture.captureDisplay(mDisplayId, getFullScreenRect());
        Bitmap newScreenshot = mImageCapture.captureDisplay(mDisplay.getDisplayId(),
                getFullScreenRect());
        if (newScreenshot == null) {
            Log.e(TAG, "Failed to capture current screenshot for scroll transition!");
            return;
@@ -820,7 +816,8 @@ public class ScreenshotController {
    private void saveScreenshotInBackground(
            ScreenshotData screenshot, UUID requestId, Consumer<Uri> finisher) {
        ListenableFuture<ImageExporter.Result> future = mImageExporter.export(mBgExecutor,
                requestId, screenshot.getBitmap(), screenshot.getUserOrDefault(), mDisplayId);
                requestId, screenshot.getBitmap(), screenshot.getUserOrDefault(),
                mDisplay.getDisplayId());
        future.addListener(() -> {
            try {
                ImageExporter.Result result = future.get();
@@ -862,7 +859,7 @@ public class ScreenshotController {
        data.mActionsReadyListener = actionsReadyListener;
        data.mQuickShareActionsReadyListener = quickShareActionsReadyListener;
        data.owner = owner;
        data.displayId = mDisplayId;
        data.displayId = mDisplay.getDisplayId();

        if (mSaveInBgTask != null) {
            // just log success/failure for the pre-existing screenshot
@@ -987,13 +984,9 @@ public class ScreenshotController {
        }
    }

    private Display getDisplay() {
        return mDisplayManager.getDisplay(mDisplayId);
    }

    private Rect getFullScreenRect() {
        DisplayMetrics displayMetrics = new DisplayMetrics();
        getDisplay().getRealMetrics(displayMetrics);
        mDisplay.getRealMetrics(displayMetrics);
        return new Rect(0, 0, displayMetrics.widthPixels, displayMetrics.heightPixels);
    }

@@ -1029,10 +1022,10 @@ public class ScreenshotController {
        /**
         * Creates an instance of the controller for that specific displayId.
         *
         * @param displayId:               display to capture
         * @param showUIOnExternalDisplay: Whether the UI should be shown if this is an external
         * @param display                 Display to capture.
         * @param showUIOnExternalDisplay Whether the UI should be shown if this is an external
         *                                display.
         */
        ScreenshotController create(int displayId, boolean showUIOnExternalDisplay);
        ScreenshotController create(Display display, boolean showUIOnExternalDisplay);
    }
}
+13 −10
Original line number Diff line number Diff line
@@ -52,11 +52,13 @@ constructor(
        onSaved: (Uri?) -> Unit,
        requestCallback: RequestCallback
    ) {
        val displayIds = getDisplaysToScreenshot(screenshotRequest.type)
        val displays = getDisplaysToScreenshot(screenshotRequest.type)
        val resultCallbackWrapper = MultiResultCallbackWrapper(requestCallback)
        displayIds.forEach { displayId: Int ->
        displays.forEach { display: Display ->
            val displayId = display.displayId
            Log.d(TAG, "Executing screenshot for display $displayId")
            dispatchToController(
                display,
                rawScreenshotData = ScreenshotData.fromRequest(screenshotRequest, displayId),
                onSaved =
                    if (displayId == Display.DEFAULT_DISPLAY) {
@@ -69,6 +71,7 @@ constructor(

    /** All logging should be triggered only by this method. */
    private suspend fun dispatchToController(
        display: Display,
        rawScreenshotData: ScreenshotData,
        onSaved: (Uri?) -> Unit,
        callback: RequestCallback
@@ -88,8 +91,7 @@ constructor(
        logScreenshotRequested(screenshotData)
        Log.d(TAG, "Screenshot request: $screenshotData")
        try {
            getScreenshotController(screenshotData.displayId)
                .handleScreenshot(screenshotData, onSaved, callback)
            getScreenshotController(display).handleScreenshot(screenshotData, onSaved, callback)
        } catch (e: IllegalStateException) {
            Log.e(TAG, "Error while ScreenshotController was handling ScreenshotData!", e)
            onFailedScreenshotRequest(screenshotData, callback)
@@ -119,12 +121,13 @@ constructor(
        callback.reportError()
    }

    private suspend fun getDisplaysToScreenshot(requestType: Int): List<Int> {
    private suspend fun getDisplaysToScreenshot(requestType: Int): List<Display> {
        val allDisplays = displays.first()
        return if (requestType == TAKE_SCREENSHOT_PROVIDED_IMAGE) {
            // If this is a provided image, let's show the UI on the default display only.
            listOf(Display.DEFAULT_DISPLAY)
            allDisplays.filter { it.displayId == Display.DEFAULT_DISPLAY }
        } else {
            displays.first().filter { it.type in ALLOWED_DISPLAY_TYPES }.map { it.displayId }
            allDisplays.filter { it.type in ALLOWED_DISPLAY_TYPES }
        }
    }

@@ -158,9 +161,9 @@ constructor(
        screenshotControllers.clear()
    }

    private fun getScreenshotController(id: Int): ScreenshotController {
        return screenshotControllers.computeIfAbsent(id) {
            screenshotControllerFactory.create(id, /* showUIOnExternalDisplay= */ false)
    private fun getScreenshotController(display: Display): ScreenshotController {
        return screenshotControllers.computeIfAbsent(display.displayId) {
            screenshotControllerFactory.create(display, /* showUIOnExternalDisplay= */ false)
        }
    }

+5 −2
Original line number Diff line number Diff line
@@ -37,6 +37,7 @@ import android.content.ComponentName;
import android.content.Context;
import android.content.Intent;
import android.content.IntentFilter;
import android.hardware.display.DisplayManager;
import android.net.Uri;
import android.os.Handler;
import android.os.IBinder;
@@ -116,7 +117,8 @@ public class TakeScreenshotService extends Service {
            UiEventLogger uiEventLogger,
            ScreenshotNotificationsController.Factory notificationsControllerFactory,
            Context context, @Background Executor bgExecutor, FeatureFlags featureFlags,
            RequestProcessor processor, Provider<TakeScreenshotExecutor> takeScreenshotExecutor) {
            RequestProcessor processor, Provider<TakeScreenshotExecutor> takeScreenshotExecutor,
            DisplayManager displayManager) {
        if (DEBUG_SERVICE) {
            Log.d(TAG, "new " + this);
        }
@@ -134,7 +136,8 @@ public class TakeScreenshotService extends Service {
            mScreenshot = null;
        } else {
            mScreenshot = screenshotControllerFactory.create(
                    Display.DEFAULT_DISPLAY, /* showUIOnExternalDisplay= */ false);
                    displayManager.getDisplay(
                            Display.DEFAULT_DISPLAY), /* showUIOnExternalDisplay= */ false);
        }
    }

+14 −8
Original line number Diff line number Diff line
@@ -69,8 +69,9 @@ class TakeScreenshotExecutorTest : SysuiTestCase() {

    @Before
    fun setUp() {
        whenever(controllerFactory.create(eq(0), any())).thenReturn(controller0)
        whenever(controllerFactory.create(eq(1), any())).thenReturn(controller1)
        whenever(controllerFactory.create(any(), any())).thenAnswer {
            if (it.getArgument<Display>(0).displayId == 0) controller0 else controller1
        }
        whenever(notificationControllerFactory.create(eq(0))).thenReturn(notificationsController0)
        whenever(notificationControllerFactory.create(eq(1))).thenReturn(notificationsController1)
    }
@@ -78,12 +79,14 @@ class TakeScreenshotExecutorTest : SysuiTestCase() {
    @Test
    fun executeScreenshots_severalDisplays_callsControllerForEachOne() =
        testScope.runTest {
            setDisplays(display(TYPE_INTERNAL, id = 0), display(TYPE_EXTERNAL, id = 1))
            val internalDisplay = display(TYPE_INTERNAL, id = 0)
            val externalDisplay = display(TYPE_EXTERNAL, id = 1)
            setDisplays(internalDisplay, externalDisplay)
            val onSaved = { _: Uri? -> }
            screenshotExecutor.executeScreenshots(createScreenshotRequest(), onSaved, callback)

            verify(controllerFactory).create(eq(0), any())
            verify(controllerFactory).create(eq(1), any())
            verify(controllerFactory).create(eq(internalDisplay), any())
            verify(controllerFactory).create(eq(externalDisplay), any())

            val capturer = ArgumentCaptor<ScreenshotData>()

@@ -107,7 +110,9 @@ class TakeScreenshotExecutorTest : SysuiTestCase() {
    @Test
    fun executeScreenshots_providedImageType_callsOnlyDefaultDisplayController() =
        testScope.runTest {
            setDisplays(display(TYPE_INTERNAL, id = 0), display(TYPE_EXTERNAL, id = 1))
            val internalDisplay = display(TYPE_INTERNAL, id = 0)
            val externalDisplay = display(TYPE_EXTERNAL, id = 1)
            setDisplays(internalDisplay, externalDisplay)
            val onSaved = { _: Uri? -> }
            screenshotExecutor.executeScreenshots(
                createScreenshotRequest(TAKE_SCREENSHOT_PROVIDED_IMAGE),
@@ -115,8 +120,8 @@ class TakeScreenshotExecutorTest : SysuiTestCase() {
                callback
            )

            verify(controllerFactory).create(eq(0), any())
            verify(controllerFactory, never()).create(eq(1), any())
            verify(controllerFactory).create(eq(internalDisplay), any())
            verify(controllerFactory, never()).create(eq(externalDisplay), any())

            val capturer = ArgumentCaptor<ScreenshotData>()

@@ -473,6 +478,7 @@ class TakeScreenshotExecutorTest : SysuiTestCase() {
        var processed: ScreenshotData? = null
        var toReturn: ScreenshotData? = null
        var shouldThrowException = false

        override suspend fun process(screenshot: ScreenshotData): ScreenshotData {
            if (shouldThrowException) throw RequestProcessorException("")
            processed = screenshot
+5 −1
Original line number Diff line number Diff line
@@ -20,6 +20,7 @@ import android.app.admin.DevicePolicyManager
import android.app.admin.DevicePolicyResources.Strings.SystemUi.SCREENSHOT_BLOCKED_BY_ADMIN
import android.app.admin.DevicePolicyResourcesManager
import android.content.ComponentName
import android.hardware.display.DisplayManager
import android.os.UserHandle
import android.os.UserManager
import android.testing.AndroidTestingRunner
@@ -38,6 +39,7 @@ import com.android.systemui.screenshot.TakeScreenshotService.RequestCallback
import com.android.systemui.util.mockito.any
import com.android.systemui.util.mockito.eq
import com.android.systemui.util.mockito.mock
import com.android.systemui.util.mockito.nullable
import com.android.systemui.util.mockito.whenever
import java.util.function.Consumer
import org.junit.Assert.assertEquals
@@ -68,6 +70,7 @@ class TakeScreenshotServiceTest : SysuiTestCase() {
    private val notificationsControllerFactory = mock<ScreenshotNotificationsController.Factory>()
    private val notificationsController = mock<ScreenshotNotificationsController>()
    private val callback = mock<RequestCallback>()
    private val displayManager = mock<DisplayManager>()

    private val eventLogger = UiEventLoggerFake()
    private val flags = FakeFeatureFlags()
@@ -87,7 +90,7 @@ class TakeScreenshotServiceTest : SysuiTestCase() {
            )
            .thenReturn(false)
        whenever(userManager.isUserUnlocked).thenReturn(true)
        whenever(controllerFactory.create(any(), any())).thenReturn(controller)
        whenever(controllerFactory.create(nullable<Display>(), any())).thenReturn(controller)
        whenever(notificationsControllerFactory.create(any())).thenReturn(notificationsController)

        // Stub request processor as a synchronous no-op for tests with the flag enabled
@@ -331,6 +334,7 @@ class TakeScreenshotServiceTest : SysuiTestCase() {
                flags,
                requestProcessor,
                { takeScreenshotExecutor },
                displayManager,
            )
        service.attach(
            mContext,
Loading