Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 66bc5288 authored by Benedict Wong's avatar Benedict Wong
Browse files

DO NOT MERGE: Add unit tests to ensure VPN meteredness

These new tests ensure that VPNs report the meteredness of their
underlying networks correctly. The added test verifies VPN meteredness
for cases of metered and unmetered WiFi and Cell

Bug: 78644887
Test: This; ran on walleye-eng
Change-Id: I28bdc71a336bfd97f7908455d4781d774df44b87
parent d08ab5a6
Loading
Loading
Loading
Loading
+37 −21
Original line number Diff line number Diff line
@@ -969,7 +969,7 @@ public class ConnectivityService extends IConnectivityManager.Stub
        if (!mLockdownEnabled) {
            int user = UserHandle.getUserId(uid);
            synchronized (mVpns) {
                Vpn vpn = mVpns.get(user);
                Vpn vpn = getVpn(user);
                if (vpn != null && vpn.appliesToUid(uid)) {
                    return vpn.getUnderlyingNetworks();
                }
@@ -1017,7 +1017,7 @@ public class ConnectivityService extends IConnectivityManager.Stub
            return false;
        }
        synchronized (mVpns) {
            final Vpn vpn = mVpns.get(UserHandle.getUserId(uid));
            final Vpn vpn = getVpn(UserHandle.getUserId(uid));
            if (vpn != null && vpn.isBlockingUid(uid)) {
                return true;
            }
@@ -1094,7 +1094,7 @@ public class ConnectivityService extends IConnectivityManager.Stub
        final int user = UserHandle.getUserId(uid);
        int vpnNetId = NETID_UNSET;
        synchronized (mVpns) {
            final Vpn vpn = mVpns.get(user);
            final Vpn vpn = getVpn(user);
            if (vpn != null && vpn.appliesToUid(uid)) vpnNetId = vpn.getNetId();
        }
        NetworkAgentInfo nai;
@@ -1224,7 +1224,7 @@ public class ConnectivityService extends IConnectivityManager.Stub

        if (!mLockdownEnabled) {
            synchronized (mVpns) {
                Vpn vpn = mVpns.get(userId);
                Vpn vpn = getVpn(userId);
                if (vpn != null) {
                    Network[] networks = vpn.getUnderlyingNetworks();
                    if (networks != null) {
@@ -3424,7 +3424,7 @@ public class ConnectivityService extends IConnectivityManager.Stub
        throwIfLockdownEnabled();

        synchronized (mVpns) {
            Vpn vpn = mVpns.get(userId);
            Vpn vpn = getVpn(userId);
            if (vpn != null) {
                return vpn.prepare(oldPackage, newPackage);
            } else {
@@ -3451,7 +3451,7 @@ public class ConnectivityService extends IConnectivityManager.Stub
        enforceCrossUserPermission(userId);

        synchronized (mVpns) {
            Vpn vpn = mVpns.get(userId);
            Vpn vpn = getVpn(userId);
            if (vpn != null) {
                vpn.setPackageAuthorization(packageName, authorized);
            }
@@ -3470,7 +3470,7 @@ public class ConnectivityService extends IConnectivityManager.Stub
        throwIfLockdownEnabled();
        int user = UserHandle.getUserId(Binder.getCallingUid());
        synchronized (mVpns) {
            return mVpns.get(user).establish(config);
            return getVpn(user).establish(config);
        }
    }

@@ -3487,7 +3487,7 @@ public class ConnectivityService extends IConnectivityManager.Stub
        }
        int user = UserHandle.getUserId(Binder.getCallingUid());
        synchronized (mVpns) {
            mVpns.get(user).startLegacyVpn(profile, mKeyStore, egress);
            getVpn(user).startLegacyVpn(profile, mKeyStore, egress);
        }
    }

@@ -3501,7 +3501,7 @@ public class ConnectivityService extends IConnectivityManager.Stub
        enforceCrossUserPermission(userId);

        synchronized (mVpns) {
            return mVpns.get(userId).getLegacyVpnInfo();
            return getVpn(userId).getLegacyVpnInfo();
        }
    }

@@ -3565,7 +3565,7 @@ public class ConnectivityService extends IConnectivityManager.Stub
    public VpnConfig getVpnConfig(int userId) {
        enforceCrossUserPermission(userId);
        synchronized (mVpns) {
            Vpn vpn = mVpns.get(userId);
            Vpn vpn = getVpn(userId);
            if (vpn != null) {
                return vpn.getVpnConfig();
            } else {
@@ -3599,7 +3599,7 @@ public class ConnectivityService extends IConnectivityManager.Stub
            }
            int user = UserHandle.getUserId(Binder.getCallingUid());
            synchronized (mVpns) {
                Vpn vpn = mVpns.get(user);
                Vpn vpn = getVpn(user);
                if (vpn == null) {
                    Slog.w(TAG, "VPN for user " + user + " not ready yet. Skipping lockdown");
                    return false;
@@ -3646,7 +3646,7 @@ public class ConnectivityService extends IConnectivityManager.Stub
     */
    private boolean startAlwaysOnVpn(int userId) {
        synchronized (mVpns) {
            Vpn vpn = mVpns.get(userId);
            Vpn vpn = getVpn(userId);
            if (vpn == null) {
                // Shouldn't happen as all codepaths that point here should have checked the Vpn
                // exists already.
@@ -3664,7 +3664,7 @@ public class ConnectivityService extends IConnectivityManager.Stub
        enforceCrossUserPermission(userId);

        synchronized (mVpns) {
            Vpn vpn = mVpns.get(userId);
            Vpn vpn = getVpn(userId);
            if (vpn == null) {
                Slog.w(TAG, "User " + userId + " has no Vpn configuration");
                return false;
@@ -3684,7 +3684,7 @@ public class ConnectivityService extends IConnectivityManager.Stub
        }

        synchronized (mVpns) {
            Vpn vpn = mVpns.get(userId);
            Vpn vpn = getVpn(userId);
            if (vpn == null) {
                Slog.w(TAG, "User " + userId + " has no Vpn configuration");
                return false;
@@ -3706,7 +3706,7 @@ public class ConnectivityService extends IConnectivityManager.Stub
        enforceCrossUserPermission(userId);

        synchronized (mVpns) {
            Vpn vpn = mVpns.get(userId);
            Vpn vpn = getVpn(userId);
            if (vpn == null) {
                Slog.w(TAG, "User " + userId + " has no Vpn configuration");
                return null;
@@ -3852,22 +3852,38 @@ public class ConnectivityService extends IConnectivityManager.Stub

    private void onUserStart(int userId) {
        synchronized (mVpns) {
            Vpn userVpn = mVpns.get(userId);
            Vpn userVpn = getVpn(userId);
            if (userVpn != null) {
                loge("Starting user already has a VPN");
                return;
            }
            userVpn = new Vpn(mHandler.getLooper(), mContext, mNetd, userId);
            mVpns.put(userId, userVpn);
            setVpn(userId, userVpn);
        }
        if (mUserManager.getUserInfo(userId).isPrimary() && LockdownVpnTracker.isEnabled()) {
            updateLockdownVpn();
        }
    }

    /** @hide */
    @VisibleForTesting
    Vpn getVpn(int userId) {
        synchronized (mVpns) {
            return mVpns.get(userId);
        }
    }

    /** @hide */
    @VisibleForTesting
    void setVpn(int userId, Vpn userVpn) {
        synchronized (mVpns) {
            mVpns.put(userId, userVpn);
        }
    }

    private void onUserStop(int userId) {
        synchronized (mVpns) {
            Vpn userVpn = mVpns.get(userId);
            Vpn userVpn = getVpn(userId);
            if (userVpn == null) {
                loge("Stopped user has no VPN");
                return;
@@ -5439,7 +5455,7 @@ public class ConnectivityService extends IConnectivityManager.Stub
        throwIfLockdownEnabled();
        int user = UserHandle.getUserId(Binder.getCallingUid());
        synchronized (mVpns) {
            return mVpns.get(user).addAddress(address, prefixLength);
            return getVpn(user).addAddress(address, prefixLength);
        }
    }

@@ -5448,7 +5464,7 @@ public class ConnectivityService extends IConnectivityManager.Stub
        throwIfLockdownEnabled();
        int user = UserHandle.getUserId(Binder.getCallingUid());
        synchronized (mVpns) {
            return mVpns.get(user).removeAddress(address, prefixLength);
            return getVpn(user).removeAddress(address, prefixLength);
        }
    }

@@ -5458,7 +5474,7 @@ public class ConnectivityService extends IConnectivityManager.Stub
        int user = UserHandle.getUserId(Binder.getCallingUid());
        boolean success;
        synchronized (mVpns) {
            success = mVpns.get(user).setUnderlyingNetworks(networks);
            success = getVpn(user).setUnderlyingNetworks(networks);
        }
        if (success) {
            notifyIfacesChangedForNetworkStats();
+87 −0
Original line number Diff line number Diff line
@@ -22,6 +22,7 @@ import static android.net.ConnectivityManager.TYPE_MOBILE;
import static android.net.ConnectivityManager.TYPE_MOBILE_FOTA;
import static android.net.ConnectivityManager.TYPE_MOBILE_MMS;
import static android.net.ConnectivityManager.TYPE_NONE;
import static android.net.ConnectivityManager.TYPE_VPN;
import static android.net.ConnectivityManager.TYPE_WIFI;
import static android.net.ConnectivityManager.getNetworkTypeName;
import static android.net.NetworkCapabilities.*;
@@ -102,6 +103,7 @@ import com.android.server.connectivity.MockableSystemProperties;
import com.android.server.connectivity.NetworkAgentInfo;
import com.android.server.connectivity.NetworkMonitor;
import com.android.server.connectivity.NetworkMonitor.CaptivePortalProbeResult;
import com.android.server.connectivity.Vpn;
import com.android.server.net.NetworkPinner;
import com.android.server.net.NetworkPolicyManagerInternal;

@@ -333,6 +335,9 @@ public class ConnectivityServiceTest extends AndroidTestCase {
                case TRANSPORT_WIFI_AWARE:
                    mScore = 20;
                    break;
                case TRANSPORT_VPN:
                    mScore = 0;
                    break;
                default:
                    throw new UnsupportedOperationException("unimplemented network type");
            }
@@ -868,6 +873,8 @@ public class ConnectivityServiceTest extends AndroidTestCase {
                return TYPE_WIFI;
            case TRANSPORT_CELLULAR:
                return TYPE_MOBILE;
            case TRANSPORT_VPN:
                return TYPE_VPN;
            default:
                return TYPE_NONE;
        }
@@ -3447,4 +3454,84 @@ public class ConnectivityServiceTest extends AndroidTestCase {
            return;
        }
    }

    @SmallTest
    public void testVpnNetworkMetered() {
        final TestNetworkCallback callback = new TestNetworkCallback();
        mCm.registerDefaultNetworkCallback(callback);

        final NetworkRequest cellRequest = new NetworkRequest.Builder()
                .addTransportType(TRANSPORT_CELLULAR).build();
        final TestNetworkCallback cellCallback = new TestNetworkCallback();
        mCm.registerNetworkCallback(cellRequest, cellCallback);

        // Setup cellular
        mCellNetworkAgent = new MockNetworkAgent(TRANSPORT_CELLULAR);
        mCellNetworkAgent.connect(true);
        callback.expectAvailableAndValidatedCallbacks(mCellNetworkAgent);
        cellCallback.expectAvailableAndValidatedCallbacks(mCellNetworkAgent);
        verifyActiveNetwork(TRANSPORT_CELLULAR);

        // Verify meteredness of cellular
        assertTrue(mCm.isActiveNetworkMetered());

        // Setup Wifi
        mWiFiNetworkAgent = new MockNetworkAgent(TRANSPORT_WIFI);
        mWiFiNetworkAgent.connect(true);
        callback.expectAvailableAndValidatedCallbacks(mWiFiNetworkAgent);
        cellCallback.expectCallback(CallbackState.LOSING, mCellNetworkAgent);
        verifyActiveNetwork(TRANSPORT_WIFI);

        // Verify meteredness of WiFi
        assertTrue(mCm.isActiveNetworkMetered());

        // Verify that setting unmetered on Wifi changes ActiveNetworkMetered
        mWiFiNetworkAgent.addCapability(NET_CAPABILITY_NOT_METERED);
        callback.expectCapabilitiesWith(NET_CAPABILITY_NOT_METERED, mWiFiNetworkAgent);
        assertFalse(mCm.isActiveNetworkMetered());

        // Setup VPN
        final MockNetworkAgent vpnNetworkAgent = new MockNetworkAgent(TRANSPORT_VPN);
        vpnNetworkAgent.connect(true);

        Vpn mockVpn = mock(Vpn.class);
        when(mockVpn.appliesToUid(anyInt())).thenReturn(true);
        when(mockVpn.getNetId()).thenReturn(vpnNetworkAgent.getNetwork().netId);

        Vpn oldVpn = mService.getVpn(UserHandle.myUserId());
        mService.setVpn(UserHandle.myUserId(), mockVpn);
        assertEquals(vpnNetworkAgent.getNetwork(), mCm.getActiveNetwork());

        // Verify meteredness of VPN on default network
        when(mockVpn.getUnderlyingNetworks()).thenReturn(null);
        assertFalse(mCm.isActiveNetworkMetered());
        assertFalse(mCm.isActiveNetworkMeteredForUid(Process.myUid()));

        // Verify meteredness of VPN on unmetered wifi
        when(mockVpn.getUnderlyingNetworks())
                .thenReturn(new Network[] {mWiFiNetworkAgent.getNetwork()});
        assertFalse(mCm.isActiveNetworkMetered());
        assertFalse(mCm.isActiveNetworkMeteredForUid(Process.myUid()));

        // Set WiFi as metered, then check to see that it has been updated on the VPN
        mWiFiNetworkAgent.removeCapability(NET_CAPABILITY_NOT_METERED);
        callback.expectCapabilitiesWithout(NET_CAPABILITY_NOT_METERED, mWiFiNetworkAgent);
        assertTrue(mCm.isActiveNetworkMetered());
        assertTrue(mCm.isActiveNetworkMeteredForUid(Process.myUid()));

        // Switch to cellular
        when(mockVpn.getUnderlyingNetworks())
                .thenReturn(new Network[] {mCellNetworkAgent.getNetwork()});
        assertTrue(mCm.isActiveNetworkMetered());
        assertTrue(mCm.isActiveNetworkMeteredForUid(Process.myUid()));

        // Test unmetered cellular
        mCellNetworkAgent.addCapability(NET_CAPABILITY_NOT_METERED);
        cellCallback.expectCapabilitiesWith(NET_CAPABILITY_NOT_METERED, mCellNetworkAgent);
        assertFalse(mCm.isActiveNetworkMetered());
        assertFalse(mCm.isActiveNetworkMeteredForUid(Process.myUid()));

        mService.setVpn(UserHandle.myUserId(), oldVpn);
        mCm.unregisterNetworkCallback(callback);
    }
}