Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 484074ae authored by Kyunglyul Hyun's avatar Kyunglyul Hyun
Browse files

MediaRouter: make setControlCategories synchronous

MR2.getRoutes() right after MR2.setControlCategories() will return
routes filtered with new control categories.

In order to call backs properly, mHandledControlCategories is added,
which is changed only in handler.

getRoutes() is also changed such that routes are evaluated lazily
to avoid unnecessary evaluation.

CallbackRecords are implemented in a way that doesn't acquire explicit
lock.

Callbacks are modified so that newly register callbacks no longer
get notified existing routes.

Also, fixed test according to changes of MR2 and MRM.

Bug: 145488462
Test: atest mediaroutertest (5 times)
Change-Id: I9ba3d74bcf423d801249420d947f41eccb37d67a
parent a15562ee
Loading
Loading
Loading
Loading
+102 −108
Original line number Diff line number Diff line
@@ -23,6 +23,7 @@ import static java.lang.annotation.RetentionPolicy.SOURCE;
import android.annotation.CallbackExecutor;
import android.annotation.IntDef;
import android.annotation.NonNull;
import android.annotation.Nullable;
import android.content.Context;
import android.content.Intent;
import android.os.Bundle;
@@ -51,7 +52,6 @@ import java.util.concurrent.Executor;
 * @hide
 */
public class MediaRouter2 {

    /** @hide */
    @Retention(SOURCE)
    @IntDef(value = {
@@ -102,13 +102,11 @@ public class MediaRouter2 {
            new CopyOnWriteArrayList<>();

    private final String mPackageName;
    @GuardedBy("sLock")
    private final Map<String, MediaRoute2Info> mRoutes = new HashMap<>();

    //TODO: Use a lock for this to cover the below use case
    // mRouter.setControlCategories(...);
    // routes = mRouter.getRoutes();
    // The current implementation returns empty list
    private volatile List<String> mControlCategories = Collections.emptyList();
    @GuardedBy("sLock")
    private List<String> mControlCategories = Collections.emptyList();

    private MediaRoute2Info mSelectedRoute;
    @GuardedBy("sLock")
@@ -117,7 +115,9 @@ public class MediaRouter2 {
    private Client2 mClient;

    final Handler mHandler;
    volatile List<MediaRoute2Info> mFilteredRoutes = Collections.emptyList();
    @GuardedBy("sLock")
    private boolean mShouldUpdateRoutes;
    private volatile List<MediaRoute2Info> mFilteredRoutes = Collections.emptyList();

    /**
     * Gets an instance of the media router associated with the context.
@@ -171,8 +171,7 @@ public class MediaRouter2 {
    /**
     * Registers a callback to discover routes and to receive events when they change.
     * <p>
     * If you register the same callback twice or more, the previous arguments will be overwritten
     * with the new arguments.
     * If you register the same callback twice or more, it will be ignored.
     * </p>
     */
    public void registerCallback(@NonNull @CallbackExecutor Executor executor,
@@ -180,18 +179,10 @@ public class MediaRouter2 {
        Objects.requireNonNull(executor, "executor must not be null");
        Objects.requireNonNull(callback, "callback must not be null");

        CallbackRecord record;
        // This is required to prevent adding the same callback twice.
        synchronized (mCallbackRecords) {
            final int index = findCallbackRecordIndexLocked(callback);
            if (index < 0) {
                record = new CallbackRecord(callback);
                mCallbackRecords.add(record);
            } else {
                record = mCallbackRecords.get(index);
            }
            record.mExecutor = executor;
            record.mFlags = flags;
        CallbackRecord record = new CallbackRecord(callback, executor, flags);
        if (!mCallbackRecords.addIfAbsent(record)) {
            Log.w(TAG, "Ignoring the same callback");
            return;
        }

        synchronized (sLock) {
@@ -206,8 +197,6 @@ public class MediaRouter2 {
                }
            }
        }
        //TODO: Is it thread-safe?
        record.notifyRoutes();

        //TODO: Update discovery request here.
    }
@@ -222,13 +211,11 @@ public class MediaRouter2 {
    public void unregisterCallback(@NonNull Callback callback) {
        Objects.requireNonNull(callback, "callback must not be null");

        synchronized (mCallbackRecords) {
            final int index = findCallbackRecordIndexLocked(callback);
            if (index < 0) {
                Log.w(TAG, "Ignoring to remove unknown callback. " + callback);
        if (!mCallbackRecords.remove(new CallbackRecord(callback, null, 0))) {
            Log.w(TAG, "Ignoring unknown callback");
            return;
        }
            mCallbackRecords.remove(index);

        synchronized (sLock) {
            if (mCallbackRecords.size() == 0 && mClient != null) {
                try {
@@ -241,31 +228,56 @@ public class MediaRouter2 {
            }
        }
    }
    }

    //TODO(b/139033746): Rename "Control Category" when it's finalized.
    /**
     * Sets the control categories of the application.
     * Routes that support at least one of the given control categories only exists and are handled
     * Routes that support at least one of the given control categories are handled
     * by the media router.
     */
    public void setControlCategories(@NonNull Collection<String> controlCategories) {
        Objects.requireNonNull(controlCategories, "control categories must not be null");

        // To ensure invoking callbacks correctly according to control categories
        mHandler.sendMessage(obtainMessage(MediaRouter2::setControlCategoriesOnHandler,
                MediaRouter2.this, new ArrayList<>(controlCategories)));
        List<String> newControlCategories = new ArrayList<>(controlCategories);

        synchronized (sLock) {
            mShouldUpdateRoutes = true;

            // invoke callbacks due to control categories change
            handleControlCategoriesChangedLocked(newControlCategories);
            if (mClient != null) {
                try {
                    mMediaRouterService.setControlCategories(mClient, mControlCategories);
                } catch (RemoteException ex) {
                    Log.e(TAG, "Unable to set control categories.", ex);
                }
            }
        }
    }

    /**
     * Gets the unmodifiable list of {@link MediaRoute2Info routes} currently
     * known to the media router.
     * Please note that the list can be changed before callbacks are invoked.
     *
     * @return the list of routes that support at least one of the control categories set by
     * the application
     */
    @NonNull
    public List<MediaRoute2Info> getRoutes() {
        synchronized (sLock) {
            if (mShouldUpdateRoutes) {
                mShouldUpdateRoutes = false;

                List<MediaRoute2Info> filteredRoutes = new ArrayList<>();
                for (MediaRoute2Info route : mRoutes.values()) {
                    if (route.supportsControlCategory(mControlCategories)) {
                        filteredRoutes.add(route);
                    }
                }
                mFilteredRoutes = Collections.unmodifiableList(filteredRoutes);
            }
        }
        return mFilteredRoutes;
    }

@@ -379,43 +391,16 @@ public class MediaRouter2 {
        }
    }

    @GuardedBy("mCallbackRecords")
    private int findCallbackRecordIndexLocked(Callback callback) {
        final int count = mCallbackRecords.size();
        for (int i = 0; i < count; i++) {
            CallbackRecord callbackRecord = mCallbackRecords.get(i);
            if (callbackRecord.mCallback == callback) {
                return i;
            }
        }
        return -1;
    }

    private void setControlCategoriesOnHandler(List<String> newControlCategories) {
        List<String> prevControlCategories = mControlCategories;
    private void handleControlCategoriesChangedLocked(List<String> newControlCategories) {
        List<MediaRoute2Info> addedRoutes = new ArrayList<>();
        List<MediaRoute2Info> removedRoutes = new ArrayList<>();
        List<MediaRoute2Info> filteredRoutes = new ArrayList<>();

        List<String> prevControlCategories = mControlCategories;
        mControlCategories = newControlCategories;
        Client2 client;
        synchronized (sLock) {
            client = mClient;
        }
        if (client != null) {
            try {
                mMediaRouterService.setControlCategories(client, mControlCategories);
            } catch (RemoteException ex) {
                Log.e(TAG, "Unable to set control categories.", ex);
            }
        }

        for (MediaRoute2Info route : mRoutes.values()) {
            boolean preSupported = route.supportsControlCategory(prevControlCategories);
            boolean postSupported = route.supportsControlCategory(newControlCategories);
            if (postSupported) {
                filteredRoutes.add(route);
            }
            if (preSupported == postSupported) {
                continue;
            }
@@ -425,13 +410,14 @@ public class MediaRouter2 {
                addedRoutes.add(route);
            }
        }
        mFilteredRoutes = Collections.unmodifiableList(filteredRoutes);

        if (removedRoutes.size() > 0) {
            notifyRoutesRemoved(removedRoutes);
            mHandler.sendMessage(obtainMessage(MediaRouter2::notifyRoutesRemoved,
                    MediaRouter2.this, removedRoutes));
        }
        if (addedRoutes.size() > 0) {
            notifyRoutesAdded(addedRoutes);
            mHandler.sendMessage(obtainMessage(MediaRouter2::notifyRoutesAdded,
                    MediaRouter2.this, addedRoutes));
        }
    }

@@ -441,42 +427,47 @@ public class MediaRouter2 {
        //  2) Call onRouteSelected(system_route, reason_fallback) if previously selected route
        //     does not exist anymore. => We may need 'boolean MediaRoute2Info#isSystemRoute()'.
        List<MediaRoute2Info> addedRoutes = new ArrayList<>();
        synchronized (sLock) {
            for (MediaRoute2Info route : routes) {
                mRoutes.put(route.getUniqueId(), route);
                if (route.supportsControlCategory(mControlCategories)) {
                    addedRoutes.add(route);
                }
            }
            mShouldUpdateRoutes = true;
        }
        if (addedRoutes.size() > 0) {
            refreshFilteredRoutes();
            notifyRoutesAdded(addedRoutes);
        }
    }

    void removeRoutesOnHandler(List<MediaRoute2Info> routes) {
        List<MediaRoute2Info> removedRoutes = new ArrayList<>();
        synchronized (sLock) {
            for (MediaRoute2Info route : routes) {
                mRoutes.remove(route.getUniqueId());
                if (route.supportsControlCategory(mControlCategories)) {
                    removedRoutes.add(route);
                }
            }
            mShouldUpdateRoutes = true;
        }
        if (removedRoutes.size() > 0) {
            refreshFilteredRoutes();
            notifyRoutesRemoved(removedRoutes);
        }
    }

    void changeRoutesOnHandler(List<MediaRoute2Info> routes) {
        List<MediaRoute2Info> changedRoutes = new ArrayList<>();
        synchronized (sLock) {
            for (MediaRoute2Info route : routes) {
                mRoutes.put(route.getUniqueId(), route);
                if (route.supportsControlCategory(mControlCategories)) {
                    changedRoutes.add(route);
                }
            }
        }
        if (changedRoutes.size() > 0) {
            refreshFilteredRoutes();
            notifyRoutesChanged(changedRoutes);
        }
    }
@@ -500,17 +491,6 @@ public class MediaRouter2 {
        notifyRouteSelected(route, reason, controlHints);
    }

    private void refreshFilteredRoutes() {
        List<MediaRoute2Info> filteredRoutes = new ArrayList<>();

        for (MediaRoute2Info route : mRoutes.values()) {
            if (route.supportsControlCategory(mControlCategories)) {
                filteredRoutes.add(route);
            }
        }
        mFilteredRoutes = Collections.unmodifiableList(filteredRoutes);
    }

    private void notifyRoutesAdded(List<MediaRoute2Info> routes) {
        for (CallbackRecord record: mCallbackRecords) {
            record.mExecutor.execute(
@@ -544,13 +524,16 @@ public class MediaRouter2 {
     */
    public static class Callback {
        /**
         * Called when routes are added.
         * Called when routes are added. Whenever you registers a callback, this will
         * be invoked with known routes.
         *
         * @param routes the list of routes that have been added. It's never empty.
         */
        public void onRoutesAdded(@NonNull List<MediaRoute2Info> routes) {}

        /**
         * Called when routes are removed.
         *
         * @param routes the list of routes that have been removed. It's never empty.
         */
        public void onRoutesRemoved(@NonNull List<MediaRoute2Info> routes) {}
@@ -569,6 +552,7 @@ public class MediaRouter2 {

        /**
         * Called when a route is selected. Exactly one route can be selected at a time.
         *
         * @param route the selected route.
         * @param reason the reason why the route is selected.
         * @param controlHints An optional bundle of provider-specific arguments which may be
@@ -587,16 +571,26 @@ public class MediaRouter2 {
        public Executor mExecutor;
        public int mFlags;

        CallbackRecord(@NonNull Callback callback) {
        CallbackRecord(@NonNull Callback callback, @Nullable Executor executor, int flags) {
            mCallback = callback;
            mExecutor = executor;
            mFlags = flags;
        }

        void notifyRoutes() {
            final List<MediaRoute2Info> routes = mFilteredRoutes;
            // notify only when bound to media router service.
            if (routes.size() > 0) {
                mExecutor.execute(() -> mCallback.onRoutesAdded(routes));
        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (!(obj instanceof CallbackRecord)) {
                return false;
            }
            return mCallback == ((CallbackRecord) obj).mCallback;
        }

        @Override
        public int hashCode() {
            return mCallback.hashCode();
        }
    }

+39 −47
Original line number Diff line number Diff line
@@ -57,7 +57,7 @@ public class MediaRouter2Manager {
    private Client mClient;
    private final IMediaRouterService mMediaRouterService;
    final Handler mHandler;
    final List<CallbackRecord> mCallbackRecords = new CopyOnWriteArrayList<>();
    final CopyOnWriteArrayList<CallbackRecord> mCallbackRecords = new CopyOnWriteArrayList<>();

    private final Object mRoutesLock = new Object();
    @GuardedBy("mRoutesLock")
@@ -99,15 +99,11 @@ public class MediaRouter2Manager {
        Objects.requireNonNull(executor, "executor must not be null");
        Objects.requireNonNull(callback, "callback must not be null");

        CallbackRecord callbackRecord;
        synchronized (mCallbackRecords) {
            if (findCallbackRecordIndexLocked(callback) >= 0) {
        CallbackRecord callbackRecord = new CallbackRecord(executor, callback);
        if (!mCallbackRecords.addIfAbsent(callbackRecord)) {
            Log.w(TAG, "Ignoring to add the same callback twice.");
            return;
        }
            callbackRecord = new CallbackRecord(executor, callback);
            mCallbackRecords.add(callbackRecord);
        }

        synchronized (sLock) {
            if (mClient == null) {
@@ -118,8 +114,6 @@ public class MediaRouter2Manager {
                } catch (RemoteException ex) {
                    Log.e(TAG, "Unable to register media router manager.", ex);
                }
            } else {
                callbackRecord.notifyRoutes();
            }
        }
    }
@@ -132,13 +126,11 @@ public class MediaRouter2Manager {
    public void unregisterCallback(@NonNull Callback callback) {
        Objects.requireNonNull(callback, "callback must not be null");

        synchronized (mCallbackRecords) {
            final int index = findCallbackRecordIndexLocked(callback);
            if (index < 0) {
        if (!mCallbackRecords.remove(new CallbackRecord(null, callback))) {
            Log.w(TAG, "Ignore removing unknown callback. " + callback);
            return;
        }
            mCallbackRecords.remove(index);

        synchronized (sLock) {
            if (mCallbackRecords.size() == 0 && mClient != null) {
                try {
@@ -148,21 +140,10 @@ public class MediaRouter2Manager {
                }
                //TODO: clear mRoutes?
                mClient = null;
                mControlCategoryMap.clear();
            }
        }
    }
    }

    @GuardedBy("mCallbackRecords")
    private int findCallbackRecordIndexLocked(Callback callback) {
        final int count = mCallbackRecords.size();
        for (int i = 0; i < count; i++) {
            if (mCallbackRecords.get(i).mCallback == callback) {
                return i;
            }
        }
        return -1;
    }

    //TODO: Use cache not to create array. For now, it's unclear when to purge the cache.
    //Do this when we finalize how to set control categories.
@@ -187,7 +168,6 @@ public class MediaRouter2Manager {
                }
            }
        }
        //TODO: Should we cache this?
        return routes;
    }

@@ -342,10 +322,14 @@ public class MediaRouter2Manager {
    }

    void updateControlCategories(String packageName, List<String> categories) {
        mControlCategoryMap.put(packageName, categories);
        List<String> prevCategories = mControlCategoryMap.put(packageName, categories);
        if ((prevCategories == null && categories.size() == 0)
                || Objects.equals(categories, prevCategories)) {
            return;
        }
        for (CallbackRecord record : mCallbackRecords) {
            record.mExecutor.execute(
                    () -> record.mCallback.onControlCategoriesChanged(packageName));
                    () -> record.mCallback.onControlCategoriesChanged(packageName, categories));
        }
    }

@@ -386,8 +370,10 @@ public class MediaRouter2Manager {
         * Called when the control categories of an app is changed.
         *
         * @param packageName the package name of the application
         * @param controlCategories the list of control categories set by an application.
         */
        public void onControlCategoriesChanged(@NonNull String packageName) {}
        public void onControlCategoriesChanged(@NonNull String packageName,
                @NonNull List<String> controlCategories) {}
    }

    final class CallbackRecord {
@@ -399,14 +385,20 @@ public class MediaRouter2Manager {
            mCallback = callback;
        }

        void notifyRoutes() {
            List<MediaRoute2Info> routes;
            synchronized (mRoutesLock) {
                routes = new ArrayList<>(mRoutes.values());
        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (routes.size() > 0) {
                mExecutor.execute(() -> mCallback.onRoutesAdded(routes));
            if (!(obj instanceof CallbackRecord)) {
                return false;
            }
            return mCallback ==  ((CallbackRecord) obj).mCallback;
        }

        @Override
        public int hashCode() {
            return mCallback.hashCode();
        }
    }

+116 −0
Original line number Diff line number Diff line
@@ -16,7 +16,15 @@

package com.android.mediaroutertest;

import static com.android.mediaroutertest.MediaRouterManagerTest.CATEGORIES_ALL;
import static com.android.mediaroutertest.MediaRouterManagerTest.CATEGORIES_SPECIAL;
import static com.android.mediaroutertest.MediaRouterManagerTest.ROUTE_ID_SPECIAL_CATEGORY;
import static com.android.mediaroutertest.MediaRouterManagerTest.ROUTE_ID_VARIABLE_VOLUME;
import static com.android.mediaroutertest.MediaRouterManagerTest.SYSTEM_PROVIDER_ID;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;

import android.content.Context;
import android.media.MediaRoute2Info;
@@ -24,20 +32,37 @@ import android.media.MediaRouter2;
import android.support.test.InstrumentationRegistry;
import android.support.test.filters.SmallTest;
import android.support.test.runner.AndroidJUnit4;
import android.text.TextUtils;

import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.function.Predicate;

@RunWith(AndroidJUnit4.class)
@SmallTest
public class MediaRouter2Test {
    private static final String TAG = "MediaRouter2Test";
    Context mContext;
    private MediaRouter2 mRouter2;
    private Executor mExecutor;

    private static final int TIMEOUT_MS = 5000;

    @Before
    public void setUp() throws Exception {
        mContext = InstrumentationRegistry.getTargetContext();
        mRouter2 = MediaRouter2.getInstance(mContext);
        mExecutor = Executors.newSingleThreadExecutor();
    }

    @After
@@ -50,4 +75,95 @@ public class MediaRouter2Test {
        MediaRoute2Info initiallySelectedRoute = router.getSelectedRoute();
        assertNotNull(initiallySelectedRoute);
    }

    /**
     * Tests if we get proper routes for application that has special control category.
     */
    @Test
    public void testGetRoutes() throws Exception {
        Map<String, MediaRoute2Info> routes = waitAndGetRoutes(CATEGORIES_SPECIAL);

        assertEquals(1, routes.size());
        assertNotNull(routes.get(ROUTE_ID_SPECIAL_CATEGORY));
    }

    @Test
    public void testControlVolumeWithRouter() throws Exception {
        Map<String, MediaRoute2Info> routes = waitAndGetRoutes(CATEGORIES_ALL);

        MediaRoute2Info volRoute = routes.get(ROUTE_ID_VARIABLE_VOLUME);
        assertNotNull(volRoute);

        int originalVolume = volRoute.getVolume();
        int deltaVolume = (originalVolume == volRoute.getVolumeMax() ? -1 : 1);

        awaitOnRouteChanged(
                () -> mRouter2.requestUpdateVolume(volRoute, deltaVolume),
                ROUTE_ID_VARIABLE_VOLUME,
                (route -> route.getVolume() == originalVolume + deltaVolume));

        awaitOnRouteChanged(
                () -> mRouter2.requestSetVolume(volRoute, originalVolume),
                ROUTE_ID_VARIABLE_VOLUME,
                (route -> route.getVolume() == originalVolume));
    }


    // Helper for getting routes easily
    static Map<String, MediaRoute2Info> createRouteMap(List<MediaRoute2Info> routes) {
        Map<String, MediaRoute2Info> routeMap = new HashMap<>();
        for (MediaRoute2Info route : routes) {
            // intentionally not using route.getUniqueId() for convenience.
            routeMap.put(route.getId(), route);
        }
        return routeMap;
    }

    Map<String, MediaRoute2Info> waitAndGetRoutes(List<String> controlCategories)
            throws Exception {
        CountDownLatch latch = new CountDownLatch(1);

        // A dummy callback is required to send control category info.
        MediaRouter2.Callback routerCallback = new MediaRouter2.Callback() {
            @Override
            public void onRoutesAdded(List<MediaRoute2Info> routes) {
                for (int i = 0; i < routes.size(); i++) {
                    //TODO: use isSystem() or similar method when it's ready
                    if (!TextUtils.equals(routes.get(i).getProviderId(), SYSTEM_PROVIDER_ID)) {
                        latch.countDown();
                    }
                }
            }
        };

        mRouter2.setControlCategories(controlCategories);
        mRouter2.registerCallback(mExecutor, routerCallback);
        try {
            latch.await(TIMEOUT_MS, TimeUnit.MILLISECONDS);
            return createRouteMap(mRouter2.getRoutes());
        } finally {
            mRouter2.unregisterCallback(routerCallback);
        }
    }

    void awaitOnRouteChanged(Runnable task, String routeId,
            Predicate<MediaRoute2Info> predicate) throws Exception {
        CountDownLatch latch = new CountDownLatch(1);
        MediaRouter2.Callback callback = new MediaRouter2.Callback() {
            @Override
            public void onRoutesChanged(List<MediaRoute2Info> changed) {
                MediaRoute2Info route = createRouteMap(changed).get(routeId);
                if (route != null && predicate.test(route)) {
                    latch.countDown();
                }
            }
        };
        mRouter2.registerCallback(mExecutor, callback);
        try {
            task.run();
            assertTrue(latch.await(TIMEOUT_MS, TimeUnit.MILLISECONDS));
        } finally {
            mRouter2.unregisterCallback(callback);
        }
    }
}
+47 −103

File changed.

Preview size limit exceeded, changes collapsed.

+6 −2
Original line number Diff line number Diff line
@@ -183,8 +183,9 @@ class MediaRouter2ServiceImpl {
    }

    public void setControlCategories(@NonNull IMediaRouter2Client client,
            @Nullable List<String> categories) {
            @NonNull List<String> categories) {
        Objects.requireNonNull(client, "client must not be null");
        Objects.requireNonNull(categories, "categories must not be null");

        final long token = Binder.clearCallingIdentity();
        try {
@@ -390,8 +391,11 @@ class MediaRouter2ServiceImpl {

    private void setControlCategoriesLocked(Client2Record clientRecord, List<String> categories) {
        if (clientRecord != null) {
            clientRecord.mControlCategories = categories;
            if (clientRecord.mControlCategories.equals(categories)) {
                return;
            }

            clientRecord.mControlCategories = categories;
            clientRecord.mUserRecord.mHandler.sendMessage(
                    obtainMessage(UserHandler::updateClientUsage,
                            clientRecord.mUserRecord.mHandler, clientRecord));