Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 950f20bf authored by Arthur Ishiguro's avatar Arthur Ishiguro
Browse files

Adds message duplication detection in ContextHubEndpointBroker

This CL adds logic to detect duplicate received messages and reject them if the message is in the history within a predefined received time window.

Bug: 395884574
Flag: android.chre.flags.offload_implementation
Test: atest FrameworksServicesTests_contexthub_presubmit

Change-Id: I63bffd946192eb845755b24f6cd682610e928205
parent 84711c83
Loading
Loading
Loading
Loading
+76 −25
Original line number Diff line number Diff line
@@ -16,6 +16,8 @@

package com.android.server.location.contexthub;

import static com.android.server.location.contexthub.ContextHubTransactionManager.RELIABLE_MESSAGE_DUPLICATE_DETECTION_TIMEOUT;

import android.annotation.NonNull;
import android.app.AppOpsManager;
import android.content.Context;
@@ -44,6 +46,9 @@ import com.android.internal.annotations.GuardedBy;

import java.util.Collection;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.function.Consumer;
@@ -119,6 +124,14 @@ public class ContextHubEndpointBroker extends IContextHubEndpoint.Stub
         */
        private final Set<Integer> mPendingSequenceNumbers = new HashSet<>();

        /**
         * Stores the history of received messages that are timestamped. We use a LinkedHashMap to
         * guarantee insertion ordering for easier manipulation of removing expired entries.
         *
         * <p>The key is the sequence number, and the value is the timestamp in milliseconds.
         */
        private final LinkedHashMap<Integer, Long> mRxMessageHistoryMap = new LinkedHashMap<>();

        Session(HubEndpointInfo remoteEndpointInfo, boolean remoteInitiated) {
            mRemoteEndpointInfo = remoteEndpointInfo;
            mRemoteInitiated = remoteInitiated;
@@ -157,6 +170,38 @@ public class ContextHubEndpointBroker extends IContextHubEndpoint.Stub
                consumer.accept(sequenceNumber);
            }
        }

        public boolean isInMessageHistory(HubMessage message) {
            // Clean up the history
            Iterator<Map.Entry<Integer, Long>> iterator =
                    mRxMessageHistoryMap.entrySet().iterator();
            long nowMillis = System.currentTimeMillis();
            while (iterator.hasNext()) {
                Map.Entry<Integer, Long> nextEntry = iterator.next();
                long expiryMillis = RELIABLE_MESSAGE_DUPLICATE_DETECTION_TIMEOUT.toMillis();
                if (nowMillis >= nextEntry.getValue() + expiryMillis) {
                    iterator.remove();
                }
                break;
            }

            return mRxMessageHistoryMap.containsKey(message.getMessageSequenceNumber());
        }

        public void addMessageToHistory(HubMessage message) {
            if (mRxMessageHistoryMap.containsKey(message.getMessageSequenceNumber())) {
                long value = mRxMessageHistoryMap.get(message.getMessageSequenceNumber());
                Log.w(
                        TAG,
                        "Message already exists in history (inserted @ "
                                + value
                                + " ms): "
                                + message);
                return;
            }
            mRxMessageHistoryMap.put(
                    message.getMessageSequenceNumber(), System.currentTimeMillis());
        }
    }

    /** A map between a session ID which maps to its current state. */
@@ -492,9 +537,9 @@ public class ContextHubEndpointBroker extends IContextHubEndpoint.Stub
    }

    /* package */ void onMessageReceived(int sessionId, HubMessage message) {
        byte code = onMessageReceivedInternal(sessionId, message);
        if (code != ErrorCode.OK && message.isResponseRequired()) {
            sendMessageDeliveryStatus(sessionId, message.getMessageSequenceNumber(), code);
        byte errorCode = onMessageReceivedInternal(sessionId, message);
        if (errorCode != ErrorCode.OK && message.isResponseRequired()) {
            sendMessageDeliveryStatus(sessionId, message.getMessageSequenceNumber(), errorCode);
        }
    }

@@ -567,7 +612,6 @@ public class ContextHubEndpointBroker extends IContextHubEndpoint.Stub
    }

    private byte onMessageReceivedInternal(int sessionId, HubMessage message) {
        HubEndpointInfo remote;
        synchronized (mOpenSessionLock) {
            if (!isSessionActive(sessionId)) {
                Log.e(
@@ -578,7 +622,10 @@ public class ContextHubEndpointBroker extends IContextHubEndpoint.Stub
                                + message);
                return ErrorCode.PERMANENT_ERROR;
            }
            remote = mSessionMap.get(sessionId).getRemoteEndpointInfo();
            HubEndpointInfo remote = mSessionMap.get(sessionId).getRemoteEndpointInfo();
            if (mSessionMap.get(sessionId).isInMessageHistory(message)) {
                Log.e(TAG, "Dropping duplicate message: " + message);
                return ErrorCode.TRANSIENT_ERROR;
            }

            try {
@@ -600,8 +647,12 @@ public class ContextHubEndpointBroker extends IContextHubEndpoint.Stub

            boolean success =
                    invokeCallback((consumer) -> consumer.onMessageReceived(sessionId, message));
            if (success) {
                mSessionMap.get(sessionId).addMessageToHistory(message);
            }
            return success ? ErrorCode.OK : ErrorCode.TRANSIENT_ERROR;
        }
    }

    /**
     * Calls the HAL closeEndpointSession API.
+50 −19
Original line number Diff line number Diff line
@@ -17,6 +17,7 @@
package com.android.server.location.contexthub;

import static com.google.common.truth.Truth.assertThat;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.timeout;
@@ -42,12 +43,10 @@ import android.os.Binder;
import android.os.RemoteException;
import android.platform.test.annotations.Presubmit;
import android.util.Log;

import androidx.test.ext.junit.runners.AndroidJUnit4;
import androidx.test.platform.app.InstrumentationRegistry;

import java.util.Collections;
import java.util.List;

import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
@@ -57,6 +56,9 @@ import org.mockito.Mock;
import org.mockito.junit.MockitoJUnit;
import org.mockito.junit.MockitoRule;

import java.util.Collections;
import java.util.List;

@RunWith(AndroidJUnit4.class)
@Presubmit
public class ContextHubEndpointTest {
@@ -73,6 +75,12 @@ public class ContextHubEndpointTest {
    private static final String TARGET_ENDPOINT_NAME = "Example target endpoint";
    private static final int TARGET_ENDPOINT_ID = 1;

    private static final int SAMPLE_MESSAGE_TYPE = 1234;
    private static final HubMessage SAMPLE_MESSAGE =
            new HubMessage.Builder(SAMPLE_MESSAGE_TYPE, new byte[] {1, 2, 3, 4, 5})
                    .setResponseRequired(true)
                    .build();

    private ContextHubClientManager mClientManager;
    private ContextHubEndpointManager mEndpointManager;
    private HubInfoRegistry mHubInfoRegistry;
@@ -229,23 +237,34 @@ public class ContextHubEndpointTest {
        assertThat(mTransactionManager.numReliableMessageTransactionPending()).isEqualTo(0);
    }

    @Test
    public void testDuplicateMessageRejected() throws RemoteException {
        IContextHubEndpoint endpoint = registerExampleEndpoint();
        int sessionId = openTestSession(endpoint);

        mEndpointManager.onMessageReceived(sessionId, SAMPLE_MESSAGE);
        ArgumentCaptor<HubMessage> messageCaptor = ArgumentCaptor.forClass(HubMessage.class);
        verify(mMockCallback).onMessageReceived(eq(sessionId), messageCaptor.capture());
        assertThat(messageCaptor.getValue()).isEqualTo(SAMPLE_MESSAGE);

        // Send a duplicate message and confirm it can be rejected
        mEndpointManager.onMessageReceived(sessionId, SAMPLE_MESSAGE);
        ArgumentCaptor<MessageDeliveryStatus> statusCaptor =
                ArgumentCaptor.forClass(MessageDeliveryStatus.class);
        verify(mMockEndpointCommunications)
                .sendMessageDeliveryStatusToEndpoint(eq(sessionId), statusCaptor.capture());
        assertThat(statusCaptor.getValue().messageSequenceNumber)
                .isEqualTo(SAMPLE_MESSAGE.getMessageSequenceNumber());
        assertThat(statusCaptor.getValue().errorCode).isEqualTo(ErrorCode.TRANSIENT_ERROR);

        unregisterExampleEndpoint(endpoint);
    }

    /** A helper method to create a session and validates reliable message sending. */
    private void testMessageTransactionInternal(
            IContextHubEndpoint endpoint, boolean deliverMessageStatus) throws RemoteException {
        HubEndpointInfo targetInfo =
                new HubEndpointInfo(
                        TARGET_ENDPOINT_NAME,
                        TARGET_ENDPOINT_ID,
                        ENDPOINT_PACKAGE_NAME,
                        Collections.emptyList());
        int sessionId = endpoint.openSession(targetInfo, /* serviceDescriptor= */ null);
        mEndpointManager.onEndpointSessionOpenComplete(sessionId);
        int sessionId = openTestSession(endpoint);

        final int messageType = 1234;
        HubMessage message =
                new HubMessage.Builder(messageType, new byte[] {1, 2, 3, 4, 5})
                        .setResponseRequired(true)
                        .build();
        IContextHubTransactionCallback callback =
                new IContextHubTransactionCallback.Stub() {
                    @Override
@@ -258,13 +277,13 @@ public class ContextHubEndpointTest {
                        Log.i(TAG, "Received onTransactionComplete callback, result=" + result);
                    }
                };
        endpoint.sendMessage(sessionId, message, callback);
        endpoint.sendMessage(sessionId, SAMPLE_MESSAGE, callback);
        ArgumentCaptor<Message> messageCaptor = ArgumentCaptor.forClass(Message.class);
        verify(mMockEndpointCommunications, timeout(1000))
                .sendMessageToEndpoint(eq(sessionId), messageCaptor.capture());
        Message halMessage = messageCaptor.getValue();
        assertThat(halMessage.type).isEqualTo(message.getMessageType());
        assertThat(halMessage.content).isEqualTo(message.getMessageBody());
        assertThat(halMessage.type).isEqualTo(SAMPLE_MESSAGE.getMessageType());
        assertThat(halMessage.content).isEqualTo(SAMPLE_MESSAGE.getMessageBody());
        assertThat(mTransactionManager.numReliableMessageTransactionPending()).isEqualTo(1);

        if (deliverMessageStatus) {
@@ -308,4 +327,16 @@ public class ContextHubEndpointTest {
                .isEqualTo(expectedInfo.getIdentifier().getHub());
        assertThat(mEndpointManager.getNumRegisteredClients()).isEqualTo(0);
    }

    private int openTestSession(IContextHubEndpoint endpoint) throws RemoteException {
        HubEndpointInfo targetInfo =
                new HubEndpointInfo(
                        TARGET_ENDPOINT_NAME,
                        TARGET_ENDPOINT_ID,
                        ENDPOINT_PACKAGE_NAME,
                        Collections.emptyList());
        int sessionId = endpoint.openSession(targetInfo, /* serviceDescriptor= */ null);
        mEndpointManager.onEndpointSessionOpenComplete(sessionId);
        return sessionId;
    }
}