Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit bebdd723 authored by TreeHugger Robot's avatar TreeHugger Robot Committed by Android (Google) Code Review
Browse files

Merge "Added Lock for mCallbacks and test for Concurrency"

parents 03154666 aacde147
Loading
Loading
Loading
Loading
+30 −10
Original line number Diff line number Diff line
@@ -30,6 +30,7 @@ import android.util.ArrayMap;
import android.util.ArraySet;
import android.util.Log;

import com.android.internal.annotations.GuardedBy;
import com.android.systemui.R;

import java.io.FileDescriptor;
@@ -41,12 +42,16 @@ import java.util.UUID;

import static android.media.MediaRouter.ROUTE_TYPE_REMOTE_DISPLAY;

import androidx.annotation.VisibleForTesting;


/** Platform implementation of the cast controller. **/
public class CastControllerImpl implements CastController {
    private static final String TAG = "CastController";
    private static final boolean DEBUG = Log.isLoggable(TAG, Log.DEBUG);

    private final Context mContext;
    @GuardedBy("mCallbacks")
    private final ArrayList<Callback> mCallbacks = new ArrayList<Callback>();
    private final MediaRouter mMediaRouter;
    private final ArrayMap<String, RouteInfo> mRoutes = new ArrayMap<>();
@@ -72,7 +77,7 @@ public class CastControllerImpl implements CastController {
        pw.println("CastController state:");
        pw.print("  mDiscovering="); pw.println(mDiscovering);
        pw.print("  mCallbackRegistered="); pw.println(mCallbackRegistered);
        pw.print("  mCallbacks.size="); pw.println(mCallbacks.size());
        pw.print("  mCallbacks.size="); synchronized (mCallbacks) {pw.println(mCallbacks.size());}
        pw.print("  mRoutes.size="); pw.println(mRoutes.size());
        for (int i = 0; i < mRoutes.size(); i++) {
            final RouteInfo route = mRoutes.valueAt(i);
@@ -83,7 +88,9 @@ public class CastControllerImpl implements CastController {

    @Override
    public void addCallback(Callback callback) {
        synchronized (mCallbacks) {
            mCallbacks.add(callback);
        }
        fireOnCastDevicesChanged(callback);
        synchronized (mDiscoveringLock) {
            handleDiscoveryChangeLocked();
@@ -92,7 +99,9 @@ public class CastControllerImpl implements CastController {

    @Override
    public void removeCallback(Callback callback) {
        synchronized (mCallbacks) {
            mCallbacks.remove(callback);
        }
        synchronized (mDiscoveringLock) {
            handleDiscoveryChangeLocked();
        }
@@ -117,12 +126,18 @@ public class CastControllerImpl implements CastController {
            mMediaRouter.addCallback(ROUTE_TYPE_REMOTE_DISPLAY, mMediaCallback,
                    MediaRouter.CALLBACK_FLAG_REQUEST_DISCOVERY);
            mCallbackRegistered = true;
        } else if (mCallbacks.size() != 0) {
        } else {
            boolean hasCallbacks = false;
            synchronized (mCallbacks) {
                hasCallbacks = mCallbacks.isEmpty();
            }
            if (!hasCallbacks) {
                mMediaRouter.addCallback(ROUTE_TYPE_REMOTE_DISPLAY, mMediaCallback,
                        MediaRouter.CALLBACK_FLAG_PASSIVE_DISCOVERY);
                mCallbackRegistered = true;
            }
        }
    }

    @Override
    public void setCurrentUserId(int currentUserId) {
@@ -248,11 +263,16 @@ public class CastControllerImpl implements CastController {
        }
    }

    private void fireOnCastDevicesChanged() {
    @VisibleForTesting
    void fireOnCastDevicesChanged() {
        synchronized (mCallbacks) {
            for (Callback callback : mCallbacks) {
                fireOnCastDevicesChanged(callback);
            }

        }
    }


    private void fireOnCastDevicesChanged(Callback callback) {
        callback.onCastDevicesChanged();
+54 −5
Original line number Diff line number Diff line
@@ -6,8 +6,8 @@ import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.junit.Assert.fail;

import android.content.Context;
import android.media.MediaRouter;
import android.media.projection.MediaProjectionInfo;
import android.media.projection.MediaProjectionManager;
@@ -24,6 +24,11 @@ import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.junit.Test;

import java.util.ConcurrentModificationException;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;

@SmallTest
@RunWith(AndroidTestingRunner.class)
@TestableLooper.RunWithLooper
@@ -74,4 +79,48 @@ public class CastControllerImplTest extends SysuiTestCase {
        mController.removeCallback(mockCallback);
        verify(mockCallback, never()).onCastDevicesChanged();
    }

    @Test
    public void testAddCallbackRemoveCallback_concurrently() throws InterruptedException {
        int callbackCount = 20;
        int numThreads = 2 * callbackCount;
        CountDownLatch startThreadsLatch = new CountDownLatch(1);
        CountDownLatch threadsDone = new CountDownLatch(numThreads);
        Callback[] callbackList = new Callback[callbackCount];
        mController.setDiscovering(true);
        AtomicBoolean error = new AtomicBoolean(false);
        for (int cbIndex = 0; cbIndex < callbackCount; cbIndex++) {
            callbackList[cbIndex] = mock(Callback.class);
        }
        for (int i = 0; i < numThreads; i++) {
            final Callback mCallback = callbackList[i / 2];
            final boolean shouldAdd = (i % 2 == 0);
            new Thread() {
                public void run() {
                    try {
                        startThreadsLatch.await(10, TimeUnit.SECONDS);
                    } catch (InterruptedException e) {
                        throw new RuntimeException(e);
                    }
                    try {
                        if (shouldAdd) {
                            mController.addCallback(mCallback);
                        } else {
                            mController.removeCallback(mCallback);
                        }
                        mController.fireOnCastDevicesChanged();
                    } catch (ConcurrentModificationException exc) {
                        error.compareAndSet(false, true);
                    } finally {
                        threadsDone.countDown();
                    }
                }
            }.start();
        }
        startThreadsLatch.countDown();
        threadsDone.await(10, TimeUnit.SECONDS);
        if (error.get()) {
            fail("Concurrent modification exception");
        }
    }
}