Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 539dbe61 authored by Junyu Lai's avatar Junyu Lai Committed by Gerrit Code Review
Browse files

Merge "[VCN05] Pass request type when requesting network"

parents 3b26172c 11fb36ea
Loading
Loading
Loading
Loading
+10 −9
Original line number Diff line number Diff line
@@ -16,6 +16,9 @@
package android.net;

import static android.net.IpSecManager.INVALID_RESOURCE_ID;
import static android.net.NetworkRequest.Type.LISTEN;
import static android.net.NetworkRequest.Type.REQUEST;
import static android.net.NetworkRequest.Type.TRACK_DEFAULT;

import android.annotation.CallbackExecutor;
import android.annotation.IntDef;
@@ -3730,14 +3733,12 @@ public class ConnectivityManager {
    private static final HashMap<NetworkRequest, NetworkCallback> sCallbacks = new HashMap<>();
    private static CallbackHandler sCallbackHandler;

    private static final int LISTEN  = 1;
    private static final int REQUEST = 2;

    private NetworkRequest sendRequestForNetwork(NetworkCapabilities need, NetworkCallback callback,
            int timeoutMs, int action, int legacyType, CallbackHandler handler) {
            int timeoutMs, NetworkRequest.Type reqType, int legacyType, CallbackHandler handler) {
        printStackTrace();
        checkCallbackNotNull(callback);
        Preconditions.checkArgument(action == REQUEST || need != null, "null NetworkCapabilities");
        Preconditions.checkArgument(
                reqType == TRACK_DEFAULT || need != null, "null NetworkCapabilities");
        final NetworkRequest request;
        final String callingPackageName = mContext.getOpPackageName();
        try {
@@ -3750,13 +3751,13 @@ public class ConnectivityManager {
                }
                Messenger messenger = new Messenger(handler);
                Binder binder = new Binder();
                if (action == LISTEN) {
                if (reqType == LISTEN) {
                    request = mService.listenForNetwork(
                            need, messenger, binder, callingPackageName);
                } else {
                    request = mService.requestNetwork(
                            need, messenger, timeoutMs, binder, legacyType, callingPackageName,
                            getAttributionTag());
                            need, reqType.ordinal(), messenger, timeoutMs, binder, legacyType,
                            callingPackageName, getAttributionTag());
                }
                if (request != null) {
                    sCallbacks.put(request, callback);
@@ -4260,7 +4261,7 @@ public class ConnectivityManager {
        // request, i.e., the system default network.
        CallbackHandler cbHandler = new CallbackHandler(handler);
        sendRequestForNetwork(null /* NetworkCapabilities need */, networkCallback, 0,
                REQUEST, TYPE_NONE, cbHandler);
                TRACK_DEFAULT, TYPE_NONE, cbHandler);
    }

    /**
+1 −1
Original line number Diff line number Diff line
@@ -167,7 +167,7 @@ interface IConnectivityManager
            in NetworkCapabilities nc, int score, in NetworkAgentConfig config,
            in int factorySerialNumber);

    NetworkRequest requestNetwork(in NetworkCapabilities networkCapabilities,
    NetworkRequest requestNetwork(in NetworkCapabilities networkCapabilities, int reqType,
            in Messenger messenger, int timeoutSec, in IBinder binder, int legacy,
            String callingPackageName, String callingAttributionTag);

+29 −20
Original line number Diff line number Diff line
@@ -5642,31 +5642,40 @@ public class ConnectivityService extends IConnectivityManager.Stub

    @Override
    public NetworkRequest requestNetwork(NetworkCapabilities networkCapabilities,
            Messenger messenger, int timeoutMs, IBinder binder, int legacyType,
            @NonNull String callingPackageName, @Nullable String callingAttributionTag) {
            int reqTypeInt, Messenger messenger, int timeoutMs, IBinder binder,
            int legacyType, @NonNull String callingPackageName,
            @Nullable String callingAttributionTag) {
        if (legacyType != TYPE_NONE && !checkNetworkStackPermission()) {
            if (checkUnsupportedStartingFrom(Build.VERSION_CODES.M, callingPackageName)) {
                throw new SecurityException("Insufficient permissions to specify legacy type");
            }
        }
        final int callingUid = mDeps.getCallingUid();
        final NetworkRequest.Type type = (networkCapabilities == null)
                ? NetworkRequest.Type.TRACK_DEFAULT
                : NetworkRequest.Type.REQUEST;
        // If the requested networkCapabilities is null, take them instead from
        // the default network request. This allows callers to keep track of
        // the system default network.
        if (type == NetworkRequest.Type.TRACK_DEFAULT) {
        final NetworkRequest.Type reqType;
        try {
            reqType = NetworkRequest.Type.values()[reqTypeInt];
        } catch (ArrayIndexOutOfBoundsException e) {
            throw new IllegalArgumentException("Unsupported request type " + reqTypeInt);
        }
        switch (reqType) {
            case TRACK_DEFAULT:
                // If the request type is TRACK_DEFAULT, the passed {@code networkCapabilities}
                // is unused and will be replaced by the one from the default network request.
                // This allows callers to keep track of the system default network.
                networkCapabilities = createDefaultNetworkCapabilitiesForUid(callingUid);
                enforceAccessPermission();
        } else {
                break;
            case REQUEST:
                networkCapabilities = new NetworkCapabilities(networkCapabilities);
                enforceNetworkRequestPermissions(networkCapabilities, callingPackageName,
                        callingAttributionTag);
            // TODO: this is incorrect. We mark the request as metered or not depending on the state
            // of the app when the request is filed, but we never change the request if the app
            // changes network state. http://b/29964605
                // TODO: this is incorrect. We mark the request as metered or not depending on
                //  the state of the app when the request is filed, but we never change the
                //  request if the app changes network state. http://b/29964605
                enforceMeteredApnPolicy(networkCapabilities);
                break;
            default:
                throw new IllegalArgumentException("Unsupported request type " + reqType);
        }
        ensureRequestableCapabilities(networkCapabilities);
        ensureSufficientPermissionsForRequest(networkCapabilities,
@@ -5685,7 +5694,7 @@ public class ConnectivityService extends IConnectivityManager.Stub
        ensureValid(networkCapabilities);

        NetworkRequest networkRequest = new NetworkRequest(networkCapabilities, legacyType,
                nextNetworkRequestId(), type);
                nextNetworkRequestId(), reqType);
        NetworkRequestInfo nri = new NetworkRequestInfo(messenger, networkRequest, binder);
        if (DBG) log("requestNetwork for " + nri);

+42 −12
Original line number Diff line number Diff line
@@ -16,6 +16,7 @@

package android.net;

import static android.net.ConnectivityManager.TYPE_NONE;
import static android.net.NetworkCapabilities.NET_CAPABILITY_CBS;
import static android.net.NetworkCapabilities.NET_CAPABILITY_DUN;
import static android.net.NetworkCapabilities.NET_CAPABILITY_FOTA;
@@ -31,16 +32,21 @@ import static android.net.NetworkCapabilities.TRANSPORT_BLUETOOTH;
import static android.net.NetworkCapabilities.TRANSPORT_CELLULAR;
import static android.net.NetworkCapabilities.TRANSPORT_ETHERNET;
import static android.net.NetworkCapabilities.TRANSPORT_WIFI;
import static android.net.NetworkRequest.Type.REQUEST;
import static android.net.NetworkRequest.Type.TRACK_DEFAULT;

import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.nullable;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.anyBoolean;
import static org.mockito.Mockito.anyInt;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.timeout;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
@@ -49,9 +55,7 @@ import static org.mockito.Mockito.when;
import android.app.PendingIntent;
import android.content.Context;
import android.content.pm.ApplicationInfo;
import android.net.ConnectivityManager;
import android.net.ConnectivityManager.NetworkCallback;
import android.net.NetworkCapabilities;
import android.os.Build.VERSION_CODES;
import android.os.Bundle;
import android.os.Handler;
@@ -213,9 +217,8 @@ public class ConnectivityManagerTest {
        ArgumentCaptor<Messenger> captor = ArgumentCaptor.forClass(Messenger.class);

        // register callback
        when(mService.requestNetwork(
                any(), captor.capture(), anyInt(), any(), anyInt(), any(), nullable(String.class)))
                .thenReturn(request);
        when(mService.requestNetwork(any(), anyInt(), captor.capture(), anyInt(), any(), anyInt(),
                any(), nullable(String.class))).thenReturn(request);
        manager.requestNetwork(request, callback, handler);

        // callback triggers
@@ -242,9 +245,8 @@ public class ConnectivityManagerTest {
        ArgumentCaptor<Messenger> captor = ArgumentCaptor.forClass(Messenger.class);

        // register callback
        when(mService.requestNetwork(
                any(), captor.capture(), anyInt(), any(), anyInt(), any(), nullable(String.class)))
                .thenReturn(req1);
        when(mService.requestNetwork(any(), anyInt(), captor.capture(), anyInt(), any(), anyInt(),
                any(), nullable(String.class))).thenReturn(req1);
        manager.requestNetwork(req1, callback, handler);

        // callback triggers
@@ -261,9 +263,8 @@ public class ConnectivityManagerTest {
        verify(callback, timeout(100).times(0)).onLosing(any(), anyInt());

        // callback can be registered again
        when(mService.requestNetwork(
                any(), captor.capture(), anyInt(), any(), anyInt(), any(), nullable(String.class)))
                .thenReturn(req2);
        when(mService.requestNetwork(any(), anyInt(), captor.capture(), anyInt(), any(), anyInt(),
                any(), nullable(String.class))).thenReturn(req2);
        manager.requestNetwork(req2, callback, handler);

        // callback triggers
@@ -286,7 +287,7 @@ public class ConnectivityManagerTest {
        info.targetSdkVersion = VERSION_CODES.N_MR1 + 1;

        when(mCtx.getApplicationInfo()).thenReturn(info);
        when(mService.requestNetwork(any(), any(), anyInt(), any(), anyInt(), any(),
        when(mService.requestNetwork(any(), anyInt(), any(), anyInt(), any(), anyInt(), any(),
                nullable(String.class))).thenReturn(request);

        Handler handler = new Handler(Looper.getMainLooper());
@@ -340,6 +341,35 @@ public class ConnectivityManagerTest {
        }
    }

    @Test
    public void testRequestType() throws Exception {
        final String testPkgName = "MyPackage";
        final ConnectivityManager manager = new ConnectivityManager(mCtx, mService);
        when(mCtx.getOpPackageName()).thenReturn(testPkgName);
        final NetworkRequest request = makeRequest(1);
        final NetworkCallback callback = new ConnectivityManager.NetworkCallback();

        manager.requestNetwork(request, callback);
        verify(mService).requestNetwork(eq(request.networkCapabilities),
                eq(REQUEST.ordinal()), any(), anyInt(), any(), eq(TYPE_NONE),
                eq(testPkgName), eq(null));
        reset(mService);

        // Verify that register network callback does not calls requestNetwork at all.
        manager.registerNetworkCallback(request, callback);
        verify(mService, never()).requestNetwork(any(), anyInt(), any(), anyInt(), any(),
                anyInt(), any(), any());
        verify(mService).listenForNetwork(eq(request.networkCapabilities), any(), any(),
                eq(testPkgName));
        reset(mService);

        manager.registerDefaultNetworkCallback(callback);
        verify(mService).requestNetwork(eq(null),
                eq(TRACK_DEFAULT.ordinal()), any(), anyInt(), any(), eq(TYPE_NONE),
                eq(testPkgName), eq(null));
        reset(mService);
    }

    static Message makeMessage(NetworkRequest req, int messageType) {
        Bundle bundle = new Bundle();
        bundle.putParcelable(NetworkRequest.class.getSimpleName(), req);
+2 −2
Original line number Diff line number Diff line
@@ -3360,8 +3360,8 @@ public class ConnectivityServiceTest {
            NetworkCapabilities networkCapabilities = new NetworkCapabilities();
            networkCapabilities.addTransportType(TRANSPORT_WIFI)
                    .setNetworkSpecifier(new MatchAllNetworkSpecifier());
            mService.requestNetwork(networkCapabilities, null, 0, null,
                    ConnectivityManager.TYPE_WIFI, mContext.getPackageName(),
            mService.requestNetwork(networkCapabilities, NetworkRequest.Type.REQUEST.ordinal(),
                    null, 0, null, ConnectivityManager.TYPE_WIFI, mContext.getPackageName(),
                    getAttributionTag());
        });