Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit ce46731d authored by Hyundo Moon's avatar Hyundo Moon Committed by Automerger Merge Worker
Browse files

Merge "Prevent abuse of MediaRoute2ProviderService#notifyRequestFailed()" into...

Merge "Prevent abuse of MediaRoute2ProviderService#notifyRequestFailed()" into rvc-dev am: f3ff3de3

Original change: https://googleplex-android-review.googlesource.com/c/platform/frameworks/base/+/11788495

Change-Id: Idd3857936b6d14ad8b12b3e37efa3611386c698b
parents d1d8bcab f3ff3de3
Loading
Loading
Loading
Loading
+58 −1
Original line number Diff line number Diff line
@@ -40,8 +40,10 @@ import com.android.internal.annotations.GuardedBy;

import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Deque;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicBoolean;
@@ -132,15 +134,21 @@ public abstract class MediaRoute2ProviderService extends Service {
    @Retention(RetentionPolicy.SOURCE)
    public @interface Reason {}

    private static final int MAX_REQUEST_IDS_SIZE = 500;

    private final Handler mHandler;
    private final Object mSessionLock = new Object();
    private final Object mRequestIdsLock = new Object();
    private final AtomicBoolean mStatePublishScheduled = new AtomicBoolean(false);
    private MediaRoute2ProviderServiceStub mStub;
    private IMediaRoute2ProviderServiceCallback mRemoteCallback;
    private volatile MediaRoute2ProviderInfo mProviderInfo;

    @GuardedBy("mRequestIdsLock")
    private final Deque<Long> mRequestIds = new ArrayDeque<>(MAX_REQUEST_IDS_SIZE);

    @GuardedBy("mSessionLock")
    private ArrayMap<String, RoutingSessionInfo> mSessionInfo = new ArrayMap<>();
    private final ArrayMap<String, RoutingSessionInfo> mSessionInfo = new ArrayMap<>();

    public MediaRoute2ProviderService() {
        mHandler = new Handler(Looper.getMainLooper());
@@ -230,6 +238,11 @@ public abstract class MediaRoute2ProviderService extends Service {
            @NonNull RoutingSessionInfo sessionInfo) {
        Objects.requireNonNull(sessionInfo, "sessionInfo must not be null");

        if (requestId != REQUEST_ID_NONE && !removeRequestId(requestId)) {
            Log.w(TAG, "notifySessionCreated: The requestId doesn't exist. requestId=" + requestId);
            return;
        }

        String sessionId = sessionInfo.getId();
        synchronized (mSessionLock) {
            if (mSessionInfo.containsKey(sessionId)) {
@@ -322,6 +335,13 @@ public abstract class MediaRoute2ProviderService extends Service {
        if (mRemoteCallback == null) {
            return;
        }

        if (!removeRequestId(requestId)) {
            Log.w(TAG, "notifyRequestFailed: The requestId doesn't exist. requestId="
                    + requestId);
            return;
        }

        try {
            mRemoteCallback.notifyRequestFailed(requestId, reason);
        } catch (RemoteException ex) {
@@ -469,6 +489,36 @@ public abstract class MediaRoute2ProviderService extends Service {
        }
    }

    /**
     * Adds a requestId in the request ID list whose max size is {@link #MAX_REQUEST_IDS_SIZE}.
     * When the max size is reached, the first element is removed (FIFO).
     */
    private void addRequestId(long requestId) {
        synchronized (mRequestIdsLock) {
            if (mRequestIds.size() >= MAX_REQUEST_IDS_SIZE) {
                mRequestIds.removeFirst();
            }
            mRequestIds.addLast(requestId);
        }
    }

    /**
     * Removes the given {@code requestId} from received request ID list.
     * <p>
     * Returns whether the list contains the {@code requestId}. These are the cases when the list
     * doesn't contain the given {@code requestId}:
     * <ul>
     *     <li>This service has never received a request with the requestId. </li>
     *     <li>{@link #notifyRequestFailed} or {@link #notifySessionCreated} already has been called
     *         for the requestId. </li>
     * </ul>
     */
    private boolean removeRequestId(long requestId) {
        synchronized (mRequestIdsLock) {
            return mRequestIds.removeFirstOccurrence(requestId);
        }
    }

    final class MediaRoute2ProviderServiceStub extends IMediaRoute2ProviderService.Stub {
        MediaRoute2ProviderServiceStub() { }

@@ -529,6 +579,7 @@ public abstract class MediaRoute2ProviderService extends Service {
            if (!checkRouteIdIsValid(routeId, "setRouteVolume")) {
                return;
            }
            addRequestId(requestId);
            mHandler.sendMessage(obtainMessage(MediaRoute2ProviderService::onSetRouteVolume,
                    MediaRoute2ProviderService.this, requestId, routeId, volume));
        }
@@ -542,6 +593,7 @@ public abstract class MediaRoute2ProviderService extends Service {
            if (!checkRouteIdIsValid(routeId, "requestCreateSession")) {
                return;
            }
            addRequestId(requestId);
            mHandler.sendMessage(obtainMessage(MediaRoute2ProviderService::onCreateSession,
                    MediaRoute2ProviderService.this, requestId, packageName, routeId,
                    requestCreateSession));
@@ -556,6 +608,7 @@ public abstract class MediaRoute2ProviderService extends Service {
                    || !checkRouteIdIsValid(routeId, "selectRoute")) {
                return;
            }
            addRequestId(requestId);
            mHandler.sendMessage(obtainMessage(MediaRoute2ProviderService::onSelectRoute,
                    MediaRoute2ProviderService.this, requestId, sessionId, routeId));
        }
@@ -569,6 +622,7 @@ public abstract class MediaRoute2ProviderService extends Service {
                    || !checkRouteIdIsValid(routeId, "deselectRoute")) {
                return;
            }
            addRequestId(requestId);
            mHandler.sendMessage(obtainMessage(MediaRoute2ProviderService::onDeselectRoute,
                    MediaRoute2ProviderService.this, requestId, sessionId, routeId));
        }
@@ -582,6 +636,7 @@ public abstract class MediaRoute2ProviderService extends Service {
                    || !checkRouteIdIsValid(routeId, "transferToRoute")) {
                return;
            }
            addRequestId(requestId);
            mHandler.sendMessage(obtainMessage(MediaRoute2ProviderService::onTransferToRoute,
                    MediaRoute2ProviderService.this, requestId, sessionId, routeId));
        }
@@ -594,6 +649,7 @@ public abstract class MediaRoute2ProviderService extends Service {
            if (!checkSessionIdIsValid(sessionId, "setSessionVolume")) {
                return;
            }
            addRequestId(requestId);
            mHandler.sendMessage(obtainMessage(MediaRoute2ProviderService::onSetSessionVolume,
                    MediaRoute2ProviderService.this, requestId, sessionId, volume));
        }
@@ -606,6 +662,7 @@ public abstract class MediaRoute2ProviderService extends Service {
            if (!checkSessionIdIsValid(sessionId, "releaseSession")) {
                return;
            }
            addRequestId(requestId);
            mHandler.sendMessage(obtainMessage(MediaRoute2ProviderService::onReleaseSession,
                    MediaRoute2ProviderService.this, requestId, sessionId));
        }
+11 −1
Original line number Diff line number Diff line
@@ -593,11 +593,16 @@ public class MediaRouter2ManagerTest {

        final int failureReason = REASON_REJECTED;
        final CountDownLatch onRequestFailedLatch = new CountDownLatch(1);
        final CountDownLatch onRequestFailedSecondCallLatch = new CountDownLatch(1);
        addManagerCallback(new MediaRouter2Manager.Callback() {
            @Override
            public void onRequestFailed(int reason) {
                if (reason == failureReason) {
                    if (onRequestFailedLatch.getCount() > 0) {
                        onRequestFailedLatch.countDown();
                    } else {
                        onRequestFailedSecondCallLatch.countDown();
                    }
                }
            }
        });
@@ -609,6 +614,11 @@ public class MediaRouter2ManagerTest {
        final long validRequestId = requestIds.get(0);
        instance.notifyRequestFailed(validRequestId, failureReason);
        assertTrue(onRequestFailedLatch.await(TIMEOUT_MS, TimeUnit.MILLISECONDS));

        // Test calling notifyRequestFailed() multiple times with the same valid requestId.
        // onRequestFailed() shouldn't be called since the requestId has been already handled.
        instance.notifyRequestFailed(validRequestId, failureReason);
        assertFalse(onRequestFailedSecondCallLatch.await(TIMEOUT_MS, TimeUnit.MILLISECONDS));
    }

    @Test