Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit d1b1115b authored by Seth Moore's avatar Seth Moore Committed by Android (Google) Code Review
Browse files

Merge "Use IBinder to compare callbacks in remote provisioning service" into udc-dev

parents 2ce407b1 c55fb4a8
Loading
Loading
Loading
Loading
+31 −25
Original line number Diff line number Diff line
@@ -17,6 +17,7 @@
package com.android.server.security.rkp;

import android.os.CancellationSignal;
import android.os.IBinder;
import android.os.OperationCanceledException;
import android.os.OutcomeReceiver;
import android.security.rkp.IGetKeyCallback;
@@ -39,23 +40,23 @@ import java.util.concurrent.Executor;
 */
final class RemoteProvisioningRegistration extends IRegistration.Stub {
    static final String TAG = RemoteProvisioningService.TAG;
    private final ConcurrentHashMap<IGetKeyCallback, CancellationSignal> mGetKeyOperations =
    private final ConcurrentHashMap<IBinder, CancellationSignal> mGetKeyOperations =
            new ConcurrentHashMap<>();
    private final Set<IStoreUpgradedKeyCallback> mStoreUpgradedKeyOperations =
            ConcurrentHashMap.newKeySet();
    private final Set<IBinder> mStoreUpgradedKeyOperations = ConcurrentHashMap.newKeySet();
    private final RegistrationProxy mRegistration;
    private final Executor mExecutor;

    private class GetKeyReceiver implements OutcomeReceiver<RemotelyProvisionedKey, Exception> {
        IGetKeyCallback mCallback;

        GetKeyReceiver(IGetKeyCallback callback) {
            mCallback = callback;
        }

        @Override
        public void onResult(RemotelyProvisionedKey result) {
            mGetKeyOperations.remove(mCallback);
            Log.i(TAG, "Successfully fetched key for client " + mCallback.hashCode());
            mGetKeyOperations.remove(mCallback.asBinder());
            Log.i(TAG, "Successfully fetched key for client " + mCallback.asBinder().hashCode());
            android.security.rkp.RemotelyProvisionedKey parcelable =
                    new android.security.rkp.RemotelyProvisionedKey();
            parcelable.keyBlob = result.getKeyBlob();
@@ -65,19 +66,21 @@ final class RemoteProvisioningRegistration extends IRegistration.Stub {

        @Override
        public void onError(Exception e) {
            mGetKeyOperations.remove(mCallback);
            mGetKeyOperations.remove(mCallback.asBinder());
            if (e instanceof OperationCanceledException) {
                Log.i(TAG, "Operation cancelled for client " + mCallback.hashCode());
                Log.i(TAG, "Operation cancelled for client " + mCallback.asBinder().hashCode());
                wrapCallback(mCallback::onCancel);
            } else if (e instanceof RkpProxyException) {
                Log.e(TAG, "RKP error fetching key for client " + mCallback.hashCode() + ": "
                Log.e(TAG, "RKP error fetching key for client " + mCallback.asBinder().hashCode()
                        + ": "
                        + e.getMessage());
                RkpProxyException rkpException = (RkpProxyException) e;
                wrapCallback(() -> mCallback.onError(toGetKeyError(rkpException),
                        e.getMessage()));
            } else {
                Log.e(TAG, "Unknown error fetching key for client " + mCallback.hashCode() + ": "
                        + e.getMessage());
                Log.e(TAG,
                        "Unknown error fetching key for client " + mCallback.asBinder().hashCode()
                                + ": " + e.getMessage());
                wrapCallback(() -> mCallback.onError(IGetKeyCallback.ErrorCode.ERROR_UNKNOWN,
                        e.getMessage()));
            }
@@ -108,20 +111,23 @@ final class RemoteProvisioningRegistration extends IRegistration.Stub {
    @Override
    public void getKey(int keyId, IGetKeyCallback callback) {
        CancellationSignal cancellationSignal = new CancellationSignal();
        if (mGetKeyOperations.putIfAbsent(callback, cancellationSignal) != null) {
            Log.e(TAG, "Client can only request one call at a time " + callback.hashCode());
        if (mGetKeyOperations.putIfAbsent(callback.asBinder(), cancellationSignal) != null) {
            Log.e(TAG,
                    "Client can only request one call at a time " + callback.asBinder().hashCode());
            throw new IllegalArgumentException(
                    "Callback is already associated with an existing operation: "
                            + callback.hashCode());
                            + callback.asBinder().hashCode());
        }

        try {
            Log.i(TAG, "Fetching key " + keyId + " for client " + callback.hashCode());
            Log.i(TAG, "Fetching key " + keyId + " for client " + callback.asBinder().hashCode());
            mRegistration.getKeyAsync(keyId, cancellationSignal, mExecutor,
                    new GetKeyReceiver(callback));
        } catch (Exception e) {
            Log.e(TAG, "getKeyAsync threw an exception for client " + callback.hashCode(), e);
            mGetKeyOperations.remove(callback);
            Log.e(TAG,
                    "getKeyAsync threw an exception for client " + callback.asBinder().hashCode(),
                    e);
            mGetKeyOperations.remove(callback.asBinder());
            wrapCallback(() -> callback.onError(IGetKeyCallback.ErrorCode.ERROR_UNKNOWN,
                    e.getMessage()));
        }
@@ -129,23 +135,23 @@ final class RemoteProvisioningRegistration extends IRegistration.Stub {

    @Override
    public void cancelGetKey(IGetKeyCallback callback) {
        CancellationSignal cancellationSignal = mGetKeyOperations.remove(callback);
        CancellationSignal cancellationSignal = mGetKeyOperations.remove(callback.asBinder());
        if (cancellationSignal == null) {
            throw new IllegalArgumentException(
                    "Invalid client in cancelGetKey: " + callback.hashCode());
                    "Invalid client in cancelGetKey: " + callback.asBinder().hashCode());
        }

        Log.i(TAG, "Requesting cancellation for client " + callback.hashCode());
        Log.i(TAG, "Requesting cancellation for client " + callback.asBinder().hashCode());
        cancellationSignal.cancel();
    }

    @Override
    public void storeUpgradedKeyAsync(byte[] oldKeyBlob, byte[] newKeyBlob,
            IStoreUpgradedKeyCallback callback) {
        if (!mStoreUpgradedKeyOperations.add(callback)) {
        if (!mStoreUpgradedKeyOperations.add(callback.asBinder())) {
            throw new IllegalArgumentException(
                    "Callback is already associated with an existing operation: "
                            + callback.hashCode());
                            + callback.asBinder().hashCode());
        }

        try {
@@ -153,20 +159,20 @@ final class RemoteProvisioningRegistration extends IRegistration.Stub {
                    new OutcomeReceiver<>() {
                        @Override
                        public void onResult(Void result) {
                            mStoreUpgradedKeyOperations.remove(callback);
                            mStoreUpgradedKeyOperations.remove(callback.asBinder());
                            wrapCallback(callback::onSuccess);
                        }

                        @Override
                        public void onError(Exception e) {
                            mStoreUpgradedKeyOperations.remove(callback);
                            mStoreUpgradedKeyOperations.remove(callback.asBinder());
                            wrapCallback(() -> callback.onError(e.getMessage()));
                        }
                    });
        } catch (Exception e) {
            Log.e(TAG, "storeUpgradedKeyAsync threw an exception for client "
                    + callback.hashCode(), e);
            mStoreUpgradedKeyOperations.remove(callback);
                    + callback.asBinder().hashCode(), e);
            mStoreUpgradedKeyOperations.remove(callback.asBinder());
            wrapCallback(() -> callback.onError(e.getMessage()));
        }
    }
+2 −2
Original line number Diff line number Diff line
@@ -61,7 +61,7 @@ public class RemoteProvisioningService extends SystemService {
            try {
                mCallback.onSuccess(new RemoteProvisioningRegistration(registration, mExecutor));
            } catch (RemoteException e) {
                Log.e(TAG, "Error calling success callback " + mCallback.hashCode(), e);
                Log.e(TAG, "Error calling success callback " + mCallback.asBinder().hashCode(), e);
            }
        }

@@ -70,7 +70,7 @@ public class RemoteProvisioningService extends SystemService {
            try {
                mCallback.onError(error.toString());
            } catch (RemoteException e) {
                Log.e(TAG, "Error calling error callback " + mCallback.hashCode(), e);
                Log.e(TAG, "Error calling error callback " + mCallback.asBinder().hashCode(), e);
            }
        }
    }
+34 −1
Original line number Diff line number Diff line
@@ -23,8 +23,10 @@ import static org.mockito.AdditionalAnswers.answerVoid;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.anyInt;
import static org.mockito.Mockito.argThat;
import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.contains;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.eq;
import static org.mockito.Mockito.mock;
@@ -32,6 +34,7 @@ import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;

import android.os.Binder;
import android.os.CancellationSignal;
import android.os.OperationCanceledException;
import android.os.OutcomeReceiver;
@@ -101,8 +104,10 @@ public class RemoteProvisioningRegistrationTest {
                .when(mRegistrationProxy).getKeyAsync(eq(42), any(), any(), any());

        IGetKeyCallback callback = mock(IGetKeyCallback.class);
        doReturn(new Binder()).when(callback).asBinder();
        mRegistration.getKey(42, callback);
        verify(callback).onSuccess(matches(expectedKey));
        verify(callback, atLeastOnce()).asBinder();
        verifyNoMoreInteractions(callback);
    }

@@ -114,8 +119,10 @@ public class RemoteProvisioningRegistrationTest {
                        executor.execute(() -> receiver.onError(expectedException))))
                .when(mRegistrationProxy).getKeyAsync(eq(0), any(), any(), any());
        IGetKeyCallback callback = mock(IGetKeyCallback.class);
        doReturn(new Binder()).when(callback).asBinder();
        mRegistration.getKey(0, callback);
        verify(callback).onError(eq(IGetKeyCallback.ErrorCode.ERROR_UNKNOWN), eq("oops!"));
        verify(callback, atLeastOnce()).asBinder();
        verifyNoMoreInteractions(callback);
    }

@@ -140,18 +147,28 @@ public class RemoteProvisioningRegistrationTest {
                            executor.execute(() -> receiver.onError(expectedException))))
                    .when(mRegistrationProxy).getKeyAsync(eq(0), any(), any(), any());
            IGetKeyCallback callback = mock(IGetKeyCallback.class);
            doReturn(new Binder()).when(callback).asBinder();
            mRegistration.getKey(0, callback);
            verify(callback).onError(eq(error), contains(errorField.getName()));
            verify(callback, atLeastOnce()).asBinder();
            verifyNoMoreInteractions(callback);
        }
    }

    @Test
    public void getKeyCancelDuringProxyOperation() throws Exception {
        final Binder theBinder = new Binder();
        IGetKeyCallback callback = mock(IGetKeyCallback.class);
        doReturn(theBinder).when(callback).asBinder();
        doAnswer(
                answerGetKeyAsync((keyId, cancelSignal, executor, receiver) -> {
                    mRegistration.cancelGetKey(callback);
                    // Use a different callback object to ensure that the callback equivalence
                    // relies on the actual IBinder object.
                    IGetKeyCallback differentCallback = mock(IGetKeyCallback.class);
                    doReturn(theBinder).when(differentCallback).asBinder();
                    mRegistration.cancelGetKey(differentCallback);
                    verify(differentCallback, atLeastOnce()).asBinder();
                    verifyNoMoreInteractions(differentCallback);
                    assertThat(cancelSignal.isCanceled()).isTrue();
                    executor.execute(() -> receiver.onError(new OperationCanceledException()));
                }))
@@ -159,18 +176,21 @@ public class RemoteProvisioningRegistrationTest {

        mRegistration.getKey(Integer.MAX_VALUE, callback);
        verify(callback).onCancel();
        verify(callback, atLeastOnce()).asBinder();
        verifyNoMoreInteractions(callback);
    }

    @Test
    public void cancelGetKeyWithInvalidCallback() throws Exception {
        IGetKeyCallback callback = mock(IGetKeyCallback.class);
        doReturn(new Binder()).when(callback).asBinder();
        assertThrows(IllegalArgumentException.class, () -> mRegistration.cancelGetKey(callback));
    }

    @Test
    public void getKeyRejectsDuplicateCallback() throws Exception {
        IGetKeyCallback callback = mock(IGetKeyCallback.class);
        doReturn(new Binder()).when(callback).asBinder();
        doAnswer(
                answerGetKeyAsync((keyId, cancelSignal, executor, receiver) -> {
                    assertThrows(IllegalArgumentException.class, () ->
@@ -181,12 +201,14 @@ public class RemoteProvisioningRegistrationTest {

        mRegistration.getKey(0, callback);
        verify(callback, times(1)).onSuccess(any());
        verify(callback, atLeastOnce()).asBinder();
        verifyNoMoreInteractions(callback);
    }

    @Test
    public void getKeyCancelAfterCompleteFails() throws Exception {
        IGetKeyCallback callback = mock(IGetKeyCallback.class);
        doReturn(new Binder()).when(callback).asBinder();
        doAnswer(
                answerGetKeyAsync((keyId, cancelSignal, executor, receiver) ->
                        executor.execute(() ->
@@ -197,6 +219,7 @@ public class RemoteProvisioningRegistrationTest {
        mRegistration.getKey(Integer.MIN_VALUE, callback);
        verify(callback).onSuccess(any());
        assertThrows(IllegalArgumentException.class, () -> mRegistration.cancelGetKey(callback));
        verify(callback, atLeastOnce()).asBinder();
        verifyNoMoreInteractions(callback);
    }

@@ -208,10 +231,12 @@ public class RemoteProvisioningRegistrationTest {
                .getKeyAsync(anyInt(), any(), any(), any());

        IGetKeyCallback callback = mock(IGetKeyCallback.class);
        doReturn(new Binder()).when(callback).asBinder();
        mRegistration.getKey(0, callback);
        verify(callback).onError(eq(IGetKeyCallback.ErrorCode.ERROR_UNKNOWN),
                eq(expectedException.getMessage()));
        assertThrows(IllegalArgumentException.class, () -> mRegistration.cancelGetKey(callback));
        verify(callback, atLeastOnce()).asBinder();
        verifyNoMoreInteractions(callback);
    }

@@ -224,8 +249,10 @@ public class RemoteProvisioningRegistrationTest {
                .storeUpgradedKeyAsync(any(), any(), any(), any());

        IStoreUpgradedKeyCallback callback = mock(IStoreUpgradedKeyCallback.class);
        doReturn(new Binder()).when(callback).asBinder();
        mRegistration.storeUpgradedKeyAsync(new byte[0], new byte[0], callback);
        verify(callback).onSuccess();
        verify(callback, atLeastOnce()).asBinder();
        verifyNoMoreInteractions(callback);
    }

@@ -239,8 +266,10 @@ public class RemoteProvisioningRegistrationTest {
                .storeUpgradedKeyAsync(any(), any(), any(), any());

        IStoreUpgradedKeyCallback callback = mock(IStoreUpgradedKeyCallback.class);
        doReturn(new Binder()).when(callback).asBinder();
        mRegistration.storeUpgradedKeyAsync(new byte[0], new byte[0], callback);
        verify(callback).onError(errorString);
        verify(callback, atLeastOnce()).asBinder();
        verifyNoMoreInteractions(callback);
    }

@@ -252,14 +281,17 @@ public class RemoteProvisioningRegistrationTest {
                .storeUpgradedKeyAsync(any(), any(), any(), any());

        IStoreUpgradedKeyCallback callback = mock(IStoreUpgradedKeyCallback.class);
        doReturn(new Binder()).when(callback).asBinder();
        mRegistration.storeUpgradedKeyAsync(new byte[0], new byte[0], callback);
        verify(callback).onError(errorString);
        verify(callback, atLeastOnce()).asBinder();
        verifyNoMoreInteractions(callback);
    }

    @Test
    public void storeUpgradedKeyDuplicateCallback() throws Exception {
        IStoreUpgradedKeyCallback callback = mock(IStoreUpgradedKeyCallback.class);
        doReturn(new Binder()).when(callback).asBinder();

        doAnswer(
                answerStoreUpgradedKeyAsync((oldBlob, newBlob, executor, receiver) -> {
@@ -273,6 +305,7 @@ public class RemoteProvisioningRegistrationTest {

        mRegistration.storeUpgradedKeyAsync(new byte[0], new byte[0], callback);
        verify(callback).onSuccess();
        verify(callback, atLeastOnce()).asBinder();
        verifyNoMoreInteractions(callback);
    }