Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit aacde147 authored by Fabian Kozynski's avatar Fabian Kozynski
Browse files

Added Lock for mCallbacks and test for Concurrency

mCallbacks in CastControllerImpl is now
thread-locked by a guard of the same name.

The test implemented in CastControllerImplTest fails
in previous builds (bug 79419738) and passes now.
In order to properly test,
CastControllerImpl::fireOnCastDevicesChanged is
now VisibleForTesting.

Change-Id: I4160938b2da1749a4370d902e314deaf445cda1a
Fixes: 79419738
Test: CastControllerImplTest::testConcurrencyOnMCallback
parent 8a8f138d
Loading
Loading
Loading
Loading
+30 −10
Original line number Diff line number Diff line
@@ -30,6 +30,7 @@ import android.util.ArrayMap;
import android.util.ArraySet;
import android.util.Log;

import com.android.internal.annotations.GuardedBy;
import com.android.systemui.R;

import java.io.FileDescriptor;
@@ -41,12 +42,16 @@ import java.util.UUID;

import static android.media.MediaRouter.ROUTE_TYPE_REMOTE_DISPLAY;

import androidx.annotation.VisibleForTesting;


/** Platform implementation of the cast controller. **/
public class CastControllerImpl implements CastController {
    private static final String TAG = "CastController";
    private static final boolean DEBUG = Log.isLoggable(TAG, Log.DEBUG);

    private final Context mContext;
    @GuardedBy("mCallbacks")
    private final ArrayList<Callback> mCallbacks = new ArrayList<Callback>();
    private final MediaRouter mMediaRouter;
    private final ArrayMap<String, RouteInfo> mRoutes = new ArrayMap<>();
@@ -72,7 +77,7 @@ public class CastControllerImpl implements CastController {
        pw.println("CastController state:");
        pw.print("  mDiscovering="); pw.println(mDiscovering);
        pw.print("  mCallbackRegistered="); pw.println(mCallbackRegistered);
        pw.print("  mCallbacks.size="); pw.println(mCallbacks.size());
        pw.print("  mCallbacks.size="); synchronized (mCallbacks) {pw.println(mCallbacks.size());}
        pw.print("  mRoutes.size="); pw.println(mRoutes.size());
        for (int i = 0; i < mRoutes.size(); i++) {
            final RouteInfo route = mRoutes.valueAt(i);
@@ -83,7 +88,9 @@ public class CastControllerImpl implements CastController {

    @Override
    public void addCallback(Callback callback) {
        synchronized (mCallbacks) {
            mCallbacks.add(callback);
        }
        fireOnCastDevicesChanged(callback);
        synchronized (mDiscoveringLock) {
            handleDiscoveryChangeLocked();
@@ -92,7 +99,9 @@ public class CastControllerImpl implements CastController {

    @Override
    public void removeCallback(Callback callback) {
        synchronized (mCallbacks) {
            mCallbacks.remove(callback);
        }
        synchronized (mDiscoveringLock) {
            handleDiscoveryChangeLocked();
        }
@@ -117,12 +126,18 @@ public class CastControllerImpl implements CastController {
            mMediaRouter.addCallback(ROUTE_TYPE_REMOTE_DISPLAY, mMediaCallback,
                    MediaRouter.CALLBACK_FLAG_REQUEST_DISCOVERY);
            mCallbackRegistered = true;
        } else if (mCallbacks.size() != 0) {
        } else {
            boolean hasCallbacks = false;
            synchronized (mCallbacks) {
                hasCallbacks = mCallbacks.isEmpty();
            }
            if (!hasCallbacks) {
                mMediaRouter.addCallback(ROUTE_TYPE_REMOTE_DISPLAY, mMediaCallback,
                        MediaRouter.CALLBACK_FLAG_PASSIVE_DISCOVERY);
                mCallbackRegistered = true;
            }
        }
    }

    @Override
    public void setCurrentUserId(int currentUserId) {
@@ -248,11 +263,16 @@ public class CastControllerImpl implements CastController {
        }
    }

    private void fireOnCastDevicesChanged() {
    @VisibleForTesting
    void fireOnCastDevicesChanged() {
        synchronized (mCallbacks) {
            for (Callback callback : mCallbacks) {
                fireOnCastDevicesChanged(callback);
            }

        }
    }


    private void fireOnCastDevicesChanged(Callback callback) {
        callback.onCastDevicesChanged();
+54 −5
Original line number Diff line number Diff line
@@ -6,8 +6,8 @@ import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.junit.Assert.fail;

import android.content.Context;
import android.media.MediaRouter;
import android.media.projection.MediaProjectionInfo;
import android.media.projection.MediaProjectionManager;
@@ -24,6 +24,11 @@ import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.junit.Test;

import java.util.ConcurrentModificationException;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;

@SmallTest
@RunWith(AndroidTestingRunner.class)
@TestableLooper.RunWithLooper
@@ -74,4 +79,48 @@ public class CastControllerImplTest extends SysuiTestCase {
        mController.removeCallback(mockCallback);
        verify(mockCallback, never()).onCastDevicesChanged();
    }

    @Test
    public void testAddCallbackRemoveCallback_concurrently() throws InterruptedException {
        int callbackCount = 20;
        int numThreads = 2 * callbackCount;
        CountDownLatch startThreadsLatch = new CountDownLatch(1);
        CountDownLatch threadsDone = new CountDownLatch(numThreads);
        Callback[] callbackList = new Callback[callbackCount];
        mController.setDiscovering(true);
        AtomicBoolean error = new AtomicBoolean(false);
        for (int cbIndex = 0; cbIndex < callbackCount; cbIndex++) {
            callbackList[cbIndex] = mock(Callback.class);
        }
        for (int i = 0; i < numThreads; i++) {
            final Callback mCallback = callbackList[i / 2];
            final boolean shouldAdd = (i % 2 == 0);
            new Thread() {
                public void run() {
                    try {
                        startThreadsLatch.await(10, TimeUnit.SECONDS);
                    } catch (InterruptedException e) {
                        throw new RuntimeException(e);
                    }
                    try {
                        if (shouldAdd) {
                            mController.addCallback(mCallback);
                        } else {
                            mController.removeCallback(mCallback);
                        }
                        mController.fireOnCastDevicesChanged();
                    } catch (ConcurrentModificationException exc) {
                        error.compareAndSet(false, true);
                    } finally {
                        threadsDone.countDown();
                    }
                }
            }.start();
        }
        startThreadsLatch.countDown();
        threadsDone.await(10, TimeUnit.SECONDS);
        if (error.get()) {
            fail("Concurrent modification exception");
        }
    }
}