Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 7844d421 authored by Joe Antonetti's avatar Joe Antonetti
Browse files

Create ConnectedAssociationStore

Create a new class which tracks which associations are connected. This logic was previously in TaskBroadcaster, but it's now needed for other usages (such as in the RemoteTaskStore).

Flag: android.companion.enable_task_continuity
Test: Added unit tests
Bug: 400970610
parent 61173642
Loading
Loading
Loading
Loading
+22 −30
Original line number Diff line number Diff line
@@ -21,10 +21,10 @@ import static android.companion.CompanionDeviceManager.MESSAGE_TASK_CONTINUITY;
import android.app.ActivityManager;
import android.app.ActivityTaskManager;
import android.companion.CompanionDeviceManager;
import android.companion.AssociationInfo;
import android.content.Context;
import android.util.Slog;

import com.android.server.companion.datatransfer.continuity.connectivity.ConnectedAssociationStore;
import com.android.server.companion.datatransfer.continuity.messages.ContinuityDeviceConnected;
import com.android.server.companion.datatransfer.continuity.messages.RemoteTaskInfo;
import com.android.server.companion.datatransfer.continuity.messages.TaskContinuityMessage;
@@ -34,28 +34,28 @@ import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.function.Consumer;

/**
 * Responsible for broadcasting recent tasks on the current device to the user's
 * other devices via {@link CompanionDeviceManager}.
 */
class TaskBroadcaster {
class TaskBroadcaster implements ConnectedAssociationStore.Observer {

    private static final String TAG = "TaskBroadcaster";

    private final Context mContext;
    private final ActivityTaskManager mActivityTaskManager;
    private final CompanionDeviceManager mCompanionDeviceManager;
    private final Set<Integer> mConnectedAssociationIds = new HashSet<>();

    private final Consumer<List<AssociationInfo>> mOnTransportsChangedListener =
        this::onTransportsChanged;
    private final ConnectedAssociationStore mConnectedAssociationStore;

    private boolean mIsBroadcasting = false;

    public TaskBroadcaster(Context context) {
    public TaskBroadcaster(
        Context context,
        ConnectedAssociationStore connectedAssociationStore) {

        mContext = context;
        mConnectedAssociationStore = connectedAssociationStore;

        mActivityTaskManager
            = context.getSystemService(ActivityTaskManager.class);
@@ -71,10 +71,7 @@ class TaskBroadcaster {
        }

        Slog.v(TAG, "Starting broadcasting");
        mCompanionDeviceManager.addOnTransportsChangedListener(
            mContext.getMainExecutor(),
            mOnTransportsChangedListener
        );
        mConnectedAssociationStore.addObserver(this);
        mIsBroadcasting = true;
    }

@@ -86,28 +83,23 @@ class TaskBroadcaster {

        Slog.v(TAG, "Stopping broadcasting");
        mIsBroadcasting = false;
        mCompanionDeviceManager.removeOnTransportsChangedListener(
            mOnTransportsChangedListener
        );
        mConnectedAssociationStore.removeObserver(this);
    }

    private void onTransportsChanged(List<AssociationInfo> associationInfos) {
        Set<Integer> removedAssociationIds
            = new HashSet<>(mConnectedAssociationIds);

        for (AssociationInfo associationInfo : associationInfos) {
            if (!mConnectedAssociationIds.contains(associationInfo.getId())) {
                sendDeviceConnectedMessage(associationInfo.getId());
            } else {
                removedAssociationIds.remove(associationInfo.getId());
            }
    @Override
    public void onTransportConnected(int associationId) {
        Slog.v(
            TAG,
            "Transport connected for association id: " + associationId);

            mConnectedAssociationIds.add(associationInfo.getId());
        sendDeviceConnectedMessage(associationId);
    }

        for (Integer removedAssociationId : removedAssociationIds) {
            mConnectedAssociationIds.remove(removedAssociationId);
        }
    @Override
    public void onTransportDisconnected(int associationId) {
        Slog.v(
            TAG,
            "Transport disconnected for association id: " + associationId);
    }

    private void sendDeviceConnectedMessage(int associationId) {
+9 −1
Original line number Diff line number Diff line
@@ -16,6 +16,7 @@

package com.android.server.companion.datatransfer.continuity;

import android.companion.CompanionDeviceManager;
import android.companion.datatransfer.continuity.ITaskContinuityManager;
import android.content.Context;
import android.util.Slog;
@@ -25,6 +26,7 @@ import com.android.server.companion.datatransfer.continuity.messages.TaskContinu
import com.android.server.companion.datatransfer.continuity.tasks.RemoteTaskStore;

import com.android.server.SystemService;
import com.android.server.companion.datatransfer.continuity.connectivity.ConnectedAssociationStore;

/**
 * Service to handle task continuity features
@@ -38,12 +40,18 @@ public final class TaskContinuityManagerService extends SystemService {

    private TaskContinuityManagerServiceImpl mTaskContinuityManagerService;
    private TaskBroadcaster mTaskBroadcaster;
    private ConnectedAssociationStore mConnectedAssociationStore;
    private TaskContinuityMessageReceiver mTaskContinuityMessageReceiver;
    private RemoteTaskStore mRemoteTaskStore;

    public TaskContinuityManagerService(Context context) {
        super(context);
        mTaskBroadcaster = new TaskBroadcaster(context);
        mConnectedAssociationStore = new ConnectedAssociationStore(context);

        mTaskBroadcaster = new TaskBroadcaster(
            context,
            mConnectedAssociationStore);

        mTaskContinuityMessageReceiver = new TaskContinuityMessageReceiver(context);
        mRemoteTaskStore = new RemoteTaskStore();
    }
+103 −0
Original line number Diff line number Diff line
/*
 * Copyright (C) 2025 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.android.server.companion.datatransfer.continuity.connectivity;

import android.annotation.NonNull;
import android.companion.AssociationInfo;
import android.companion.CompanionDeviceManager;
import android.content.Context;
import android.os.Handler;
import android.util.Log;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.function.Consumer;

public class ConnectedAssociationStore {

    private static final String TAG = "ConnectedAssociationStore";

    private final CompanionDeviceManager mCompanionDeviceManager;
    private final Set<Integer> mConnectedAssociations = new HashSet<>();
    private final List<Observer> mObservers = new ArrayList<>();

    public interface Observer {
        void onTransportConnected(int associationId);
        void onTransportDisconnected(int associationId);
    }

    public ConnectedAssociationStore(
            @NonNull Context context) {
        mCompanionDeviceManager = context
            .getSystemService(CompanionDeviceManager.class);

        mCompanionDeviceManager.addOnTransportsChangedListener(
                context.getMainExecutor(),
                this::onTransportsChanged);
   }

    public void addObserver(@NonNull Observer observer) {
        mObservers.add(observer);
    }

    public void removeObserver(@NonNull Observer observer) {
        mObservers.remove(observer);
    }

    public Set<Integer> getConnectedAssociations() {
        return mConnectedAssociations;
    }

    private void onTransportsChanged(List<AssociationInfo> associationInfos) {
        Set<Integer> removedAssociations
            = new HashSet<>(mConnectedAssociations);

        Set<Integer> addedAssociations = new HashSet<>();
        for (AssociationInfo associationInfo : associationInfos) {
            if (!mConnectedAssociations.contains(associationInfo.getId())) {
                addedAssociations.add(associationInfo.getId());
            }

            if (removedAssociations.contains(associationInfo.getId())) {
                removedAssociations.remove(associationInfo.getId());
            }
        }

        mConnectedAssociations.addAll(addedAssociations);
        mConnectedAssociations.removeAll(removedAssociations);

        for (Integer associationId : removedAssociations) {
            Log.i(
                TAG,
                "Transport disconnected for association: " + associationId);

            for (Observer observer : mObservers) {
                observer.onTransportDisconnected(associationId);
            }
        }

        for (Integer associationId : addedAssociations) {
            Log.i(TAG, "Transport connected for association: " + associationId);
            for (Observer observer : mObservers) {
                observer.onTransportConnected(associationId);
            }
        }
    }
}
 No newline at end of file
+17 −26
Original line number Diff line number Diff line
@@ -38,6 +38,7 @@ import android.testing.TestableLooper;

import androidx.test.platform.app.InstrumentationRegistry;

import com.android.server.companion.datatransfer.continuity.connectivity.ConnectedAssociationStore;
import com.android.server.companion.datatransfer.continuity.messages.ContinuityDeviceConnected;
import com.android.server.companion.datatransfer.continuity.messages.TaskContinuityMessage;

@@ -65,6 +66,8 @@ public class TaskBroadcasterTest {

    private CompanionDeviceManager mCompanionDeviceManager;

    @Mock private ConnectedAssociationStore mMockConnectedAssociationStore;

    private TaskBroadcaster mTaskBroadcaster;

    @Before
@@ -88,7 +91,9 @@ public class TaskBroadcasterTest {
            .thenReturn(mCompanionDeviceManager);

        // Create TaskBroadcaster.
        mTaskBroadcaster = new TaskBroadcaster(mMockContext);
        mTaskBroadcaster = new TaskBroadcaster(
            mMockContext,
            mMockConnectedAssociationStore);
    }

    @Test
@@ -96,40 +101,31 @@ public class TaskBroadcasterTest {
        throws Exception {

        mTaskBroadcaster.stopBroadcasting();
        verify(mMockCompanionDeviceManagerService, never())
            .removeOnTransportsChangedListener(any());
        verify(mMockConnectedAssociationStore, never())
            .addObserver(mTaskBroadcaster);
    }

    @Test
    public void testStartAndStopBroadcasting_updatesTransportsListener()
        throws Exception {

        // Start broadcasting, verifying a transport listener is added.
        ArgumentCaptor<IOnTransportsChangedListener> listenerCaptor
            = ArgumentCaptor.forClass(IOnTransportsChangedListener.class);
        // Start broadcasting, verifying an association listener is added.
        mTaskBroadcaster.startBroadcasting();
        verify(mMockCompanionDeviceManagerService, times(1))
            .addOnTransportsChangedListener(
                listenerCaptor.capture());
        IOnTransportsChangedListener listener = listenerCaptor.getValue();
        assertThat(listener).isNotNull();
        verify(mMockConnectedAssociationStore, times(1))
            .addObserver(mTaskBroadcaster);

        // Stop broadcasting, verifying the transport listener is removed.
        // Stop broadcasting, verifying the association listener is removed.
        mTaskBroadcaster.stopBroadcasting();
        verify(mMockCompanionDeviceManagerService, times(1))
            .removeOnTransportsChangedListener(listener);
        verify(mMockConnectedAssociationStore, times(1))
            .removeObserver(mTaskBroadcaster);
    }

    @Test
    public void testStartBroadcasting_startsBroadcasting() throws Exception {
        // Start broadcasting, verifying a transport listener is added.
        ArgumentCaptor<IOnTransportsChangedListener> listenerCaptor
            = ArgumentCaptor.forClass(IOnTransportsChangedListener.class);
        mTaskBroadcaster.startBroadcasting();
        verify(mMockCompanionDeviceManagerService, times(1))
            .addOnTransportsChangedListener(
                listenerCaptor.capture());
        IOnTransportsChangedListener listener = listenerCaptor.getValue();
        verify(mMockConnectedAssociationStore, times(1))
            .addObserver(mTaskBroadcaster);

        // Setup a fake foreground task.
        String expectedLabel = "test";
@@ -143,12 +139,7 @@ public class TaskBroadcasterTest {
            .thenReturn(Arrays.asList(taskInfo));

        // Add a new transport
        AssociationInfo associationInfo = new AssociationInfo.Builder(1, 0, "")
            .setDisplayName("test")
            .build();

        listener.onTransportsChanged(Arrays.asList(associationInfo));
        TestableLooper.get(this).processAllMessages();
        mTaskBroadcaster.onTransportConnected(1);

        // Verify the message is sent.
        ArgumentCaptor<byte[]> messageCaptor
+204 −0
Original line number Diff line number Diff line
/*
 * Copyright (C) 2025 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.android.server.companion.datatransfer.continuity.connectivity;

import static com.google.common.truth.Truth.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import android.companion.AssociationInfo;
import android.companion.CompanionDeviceManager;
import android.companion.ICompanionDeviceManager;
import android.companion.IOnTransportsChangedListener;
import android.content.Context;
import android.content.ContextWrapper;
import android.os.RemoteException;
import android.platform.test.annotations.Presubmit;
import android.testing.AndroidTestingRunner;
import android.testing.TestableLooper;

import androidx.test.platform.app.InstrumentationRegistry;

import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.MockitoAnnotations;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Set;
import java.util.concurrent.Executor;

@Presubmit
@RunWith(AndroidTestingRunner.class)
@TestableLooper.RunWithLooper(setAsMainLooper = true)
public class ConnectedAssociationStoreTest {

    private Context mMockContext;
    @Mock private ICompanionDeviceManager mMockCompanionDeviceManagerService;
    @Mock private Executor mMockExecutor;
    @Mock private ConnectedAssociationStore.Observer mMockObserver;

    @Captor
    private ArgumentCaptor<IOnTransportsChangedListener> mListenerCaptor;

    private ConnectedAssociationStore mConnectedAssociationStore;

    private CompanionDeviceManager mCompanionDeviceManager;

    @Before
    public void setUp() throws RemoteException {
        MockitoAnnotations.initMocks(this);
        mMockContext =  Mockito.spy(
            new ContextWrapper(
                InstrumentationRegistry
                    .getInstrumentation()
                    .getTargetContext()));

        mCompanionDeviceManager = new CompanionDeviceManager(
                mMockCompanionDeviceManagerService,
                mMockContext);

        when(mMockContext.getSystemService(Context.COMPANION_DEVICE_SERVICE))
            .thenReturn(mCompanionDeviceManager);

        mConnectedAssociationStore = new ConnectedAssociationStore(mMockContext);
        mConnectedAssociationStore.addObserver(mMockObserver);
        verify(mMockCompanionDeviceManagerService).addOnTransportsChangedListener(
                mListenerCaptor.capture());
    }

    @Test
    public void testOnTransportConnected_notifyObserver() throws RemoteException {
        // Simulate a new association connected.
        int associationId = 1;
        notifyTransportsChanged(
                Arrays.asList(createAssociationInfo(associationId)));

        // Verify the observer is notified.
        verify(mMockObserver).onTransportConnected(associationId);
        verify(mMockObserver, never()).onTransportDisconnected(associationId);
    }

    @Test
    public void testOnTransportDisconnected_notifyObserver() throws RemoteException {
        // Start with an association connected.
        int associationId = 1;
        notifyTransportsChanged(
                Arrays.asList(createAssociationInfo(associationId)));

        // Simulate the association being disconnected.
        notifyTransportsChanged(Collections.emptyList());

        // Verify the observer is notified of the disconnection.
        verify(mMockObserver).onTransportDisconnected(associationId);
    }

    @Test
    public void testOnTransportChanged_noChange_noNotification() throws RemoteException {
        // Start with an association connected.
        int associationId = 1;
        notifyTransportsChanged(
                Arrays.asList(createAssociationInfo(associationId)));

        // Simulate the same association still connected.
        notifyTransportsChanged(
                Arrays.asList(createAssociationInfo(associationId)));

        // Verify the observer is only notified once for the initial connection.
        verify(mMockObserver, times(1)).onTransportConnected(associationId);
        verify(mMockObserver, never()).onTransportDisconnected(associationId);
    }

    @Test
    public void testGetConnectedAssociations() throws RemoteException {
        // Connect two associations.
        int associationId1 = 1;
        int associationId2 = 2;
        notifyTransportsChanged(
                Arrays.asList(
                        createAssociationInfo(associationId1),
                        createAssociationInfo(associationId2)));

        // Verify that getConnectedAssociations returns the correct set.
        Set<Integer> connectedAssociations
            = mConnectedAssociationStore.getConnectedAssociations();
        assertThat(connectedAssociations)
            .containsExactly(associationId1, associationId2);

        // Disconnect one association.
        notifyTransportsChanged(
                Arrays.asList(createAssociationInfo(associationId1)));

        // Verify that getConnectedAssociations returns the updated set.
        connectedAssociations
            = mConnectedAssociationStore.getConnectedAssociations();

        assertThat(connectedAssociations).containsExactly(associationId1);
    }

    @Test
    public void testAddAndRemoveObserver() throws RemoteException {
        ConnectedAssociationStore.Observer newMockObserver = mock(
            ConnectedAssociationStore.Observer.class);

        // Add a new observer
        mConnectedAssociationStore.addObserver(newMockObserver);

        // Simulate a new association connected.
        int associationId = 1;
        notifyTransportsChanged(
            Arrays.asList(createAssociationInfo(associationId)));

        // Verify the new observer is notified.
        verify(newMockObserver).onTransportConnected(associationId);

        // Remove the new observer
        mConnectedAssociationStore.removeObserver(newMockObserver);

        // Simulate the association being disconnected.
        notifyTransportsChanged(Collections.emptyList());

        // Verify the removed observer is not notified.
        verify(newMockObserver, never()).onTransportDisconnected(associationId);
        // But the original observer is still notified
        verify(mMockObserver).onTransportDisconnected(associationId);
    }

    private void notifyTransportsChanged(
        List<AssociationInfo> associationInfos) throws RemoteException {

        mListenerCaptor.getValue().onTransportsChanged(associationInfos);
        TestableLooper.get(this).processAllMessages();
    }

    private AssociationInfo createAssociationInfo(int associationId) {
        return new AssociationInfo.Builder(associationId, 0, "test_device_mac_address")
                .setDisplayName("test_device_name")
                .build();
    }
}
 No newline at end of file