Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit fc26c8a9 authored by Chalard Jean's avatar Chalard Jean
Browse files

Add a provider to VPN

Test: FrameworksNetTests NetworkStackTests
Change-Id: I982543cdee358bb62d3b56a7fd9d71dc18908b65
parent 871782de
Loading
Loading
Loading
Loading
+4 −4
Original line number Diff line number Diff line
@@ -5140,7 +5140,7 @@ public class ConnectivityService extends IConnectivityManager.Stub
        }
    }

    private void onUserStart(int userId) {
    private void onUserStarted(int userId) {
        synchronized (mVpns) {
            Vpn userVpn = mVpns.get(userId);
            if (userVpn != null) {
@@ -5155,7 +5155,7 @@ public class ConnectivityService extends IConnectivityManager.Stub
        }
    }

    private void onUserStop(int userId) {
    private void onUserStopped(int userId) {
        synchronized (mVpns) {
            Vpn userVpn = mVpns.get(userId);
            if (userVpn == null) {
@@ -5272,9 +5272,9 @@ public class ConnectivityService extends IConnectivityManager.Stub
            if (userId == UserHandle.USER_NULL) return;

            if (Intent.ACTION_USER_STARTED.equals(action)) {
                onUserStart(userId);
                onUserStarted(userId);
            } else if (Intent.ACTION_USER_STOPPED.equals(action)) {
                onUserStop(userId);
                onUserStopped(userId);
            } else if (Intent.ACTION_USER_ADDED.equals(action)) {
                onUserAdded(userId);
            } else if (Intent.ACTION_USER_REMOVED.equals(action)) {
+10 −3
Original line number Diff line number Diff line
@@ -151,7 +151,7 @@ import java.util.concurrent.atomic.AtomicInteger;
public class Vpn {
    private static final String NETWORKTYPE = "VPN";
    private static final String TAG = "Vpn";
    private static final String VPN_AGENT_NAME = "VpnNetworkAgent";
    private static final String VPN_PROVIDER_NAME_BASE = "VpnNetworkProvider:";
    private static final boolean LOGD = true;

    // Length of time (in milliseconds) that an app hosting an always-on VPN is placed on
@@ -195,6 +195,7 @@ public class Vpn {
    private final INetworkManagementService mNetd;
    @VisibleForTesting
    protected VpnConfig mConfig;
    private final NetworkProvider mNetworkProvider;
    @VisibleForTesting
    protected NetworkAgent mNetworkAgent;
    private final Looper mLooper;
@@ -397,6 +398,10 @@ public class Vpn {
            Log.wtf(TAG, "Problem registering observer", e);
        }

        mNetworkProvider = new NetworkProvider(context, looper, VPN_PROVIDER_NAME_BASE + mUserId);
        // This constructor is called in onUserStart and registers the provider. The provider
        // will be unregistered in onUserStop.
        mConnectivityManager.registerNetworkProvider(mNetworkProvider);
        mLegacyState = LegacyVpnInfo.STATE_DISCONNECTED;
        mNetworkInfo = new NetworkInfo(ConnectivityManager.TYPE_VPN, 0 /* subtype */, NETWORKTYPE,
                "" /* subtypeName */);
@@ -1277,8 +1282,7 @@ public class Vpn {

        mNetworkAgent = new NetworkAgent(mContext, mLooper, NETWORKTYPE /* logtag */,
                mNetworkCapabilities, lp,
                ConnectivityConstants.VPN_DEFAULT_SCORE, networkAgentConfig,
                new NetworkProvider(mContext, mLooper, VPN_AGENT_NAME)) {
                ConnectivityConstants.VPN_DEFAULT_SCORE, networkAgentConfig, mNetworkProvider) {
            @Override
            public void unwanted() {
                // We are user controlled, not driven by NetworkRequest.
@@ -1639,6 +1643,9 @@ public class Vpn {

        // Quit any active connections
        agentDisconnect();

        // The provider has been registered in the constructor, which is called in onUserStart.
        mConnectivityManager.unregisterNetworkProvider(mNetworkProvider);
    }

    /**
+10 −2
Original line number Diff line number Diff line
@@ -41,6 +41,7 @@ import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyBoolean;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.argThat;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.doAnswer;
@@ -86,10 +87,10 @@ import android.os.Build.VERSION_CODES;
import android.os.Bundle;
import android.os.ConditionVariable;
import android.os.INetworkManagementService;
import android.os.Looper;
import android.os.Process;
import android.os.UserHandle;
import android.os.UserManager;
import android.os.test.TestLooper;
import android.provider.Settings;
import android.security.Credentials;
import android.security.KeyStore;
@@ -224,6 +225,8 @@ public class VpnTest {
                .thenReturn(mNotificationManager);
        when(mContext.getSystemService(eq(Context.CONNECTIVITY_SERVICE)))
                .thenReturn(mConnectivityManager);
        when(mContext.getSystemServiceName(eq(ConnectivityManager.class)))
                .thenReturn(Context.CONNECTIVITY_SERVICE);
        when(mContext.getSystemService(eq(Context.IPSEC_SERVICE))).thenReturn(mIpSecManager);
        when(mContext.getString(R.string.config_customVpnAlwaysOnDisconnectedDialogComponent))
                .thenReturn(Resources.getSystem().getString(
@@ -1286,8 +1289,13 @@ public class VpnTest {
        doReturn(UserHandle.of(userId)).when(asUserContext).getUser();
        when(mContext.createContextAsUser(eq(UserHandle.of(userId)), anyInt()))
                .thenReturn(asUserContext);
        return new Vpn(Looper.myLooper(), mContext, new TestDeps(), mNetService,
        final TestLooper testLooper = new TestLooper();
        final Vpn vpn = new Vpn(testLooper.getLooper(), mContext, new TestDeps(), mNetService,
                userId, mKeyStore, mSystemServices, mIkev2SessionCreator);
        verify(mConnectivityManager, times(1)).registerNetworkProvider(argThat(
                provider -> provider.getName().contains("VpnNetworkProvider")
        ));
        return vpn;
    }

    private static void assertBlocked(Vpn vpn, int... uids) {