Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 4ca19e83 authored by Chad Brubaker's avatar Chad Brubaker
Browse files

Add per user VPN support

VPNs are now per user instead of global. A VPN set by user A routes only
user A's traffic and no other user can access it.

Change-Id: Ia66463637b6bd088b05768076a1db897fe95c46c
parent 12324b46
Loading
Loading
Loading
Loading
+7 −9
Original line number Diff line number Diff line
@@ -36,6 +36,7 @@ import java.net.Inet6Address;
import java.net.InetAddress;
import java.net.Socket;
import java.util.ArrayList;
import java.util.List;

/**
 * VpnService is a base class for applications to extend and build their
@@ -253,8 +254,8 @@ public class VpnService extends Service {
    public class Builder {

        private final VpnConfig mConfig = new VpnConfig();
        private final StringBuilder mAddresses = new StringBuilder();
        private final StringBuilder mRoutes = new StringBuilder();
        private final List<LinkAddress> mAddresses = new ArrayList<LinkAddress>();
        private final List<RouteInfo> mRoutes = new ArrayList<RouteInfo>();

        public Builder() {
            mConfig.user = VpnService.this.getClass().getName();
@@ -328,9 +329,7 @@ public class VpnService extends Service {
            if (address.isAnyLocalAddress()) {
                throw new IllegalArgumentException("Bad address");
            }

            mAddresses.append(' ')
                    .append(address.getHostAddress()).append('/').append(prefixLength);
            mAddresses.add(new LinkAddress(address, prefixLength));
            return this;
        }

@@ -364,8 +363,7 @@ public class VpnService extends Service {
                    }
                }
            }

            mRoutes.append(' ').append(address.getHostAddress()).append('/').append(prefixLength);
            mRoutes.add(new RouteInfo(new LinkAddress(address, prefixLength), null));
            return this;
        }

@@ -466,8 +464,8 @@ public class VpnService extends Service {
         * @see VpnService
         */
        public ParcelFileDescriptor establish() {
            mConfig.addresses = mAddresses.toString();
            mConfig.routes = mRoutes.toString();
            mConfig.addresses = mAddresses;
            mConfig.routes = mRoutes;

            try {
                return getService().establishVpn(mConfig);
+38 −6
Original line number Diff line number Diff line
@@ -21,10 +21,14 @@ import android.content.Context;
import android.content.Intent;
import android.os.Parcel;
import android.os.Parcelable;
import android.net.RouteInfo;
import android.net.LinkAddress;

import com.android.internal.util.Preconditions;

import java.net.InetAddress;
import java.util.List;
import java.util.ArrayList;

/**
 * A simple container used to carry information in VpnBuilder, VpnDialogs,
@@ -61,14 +65,42 @@ public class VpnConfig implements Parcelable {
    public String interfaze;
    public String session;
    public int mtu = -1;
    public String addresses;
    public String routes;
    public List<LinkAddress> addresses = new ArrayList<LinkAddress>();
    public List<RouteInfo> routes = new ArrayList<RouteInfo>();
    public List<String> dnsServers;
    public List<String> searchDomains;
    public PendingIntent configureIntent;
    public long startTime = -1;
    public boolean legacy;

    public void addLegacyRoutes(String routesStr) {
        if (routesStr.trim().equals("")) {
            return;
        }
        String[] routes = routesStr.trim().split(" ");
        for (String route : routes) {
            //each route is ip/prefix
            String[] split = route.split("/");
            RouteInfo info = new RouteInfo(new LinkAddress
                    (InetAddress.parseNumericAddress(split[0]), Integer.parseInt(split[1])), null);
            this.routes.add(info);
        }
    }

    public void addLegacyAddresses(String addressesStr) {
        if (addressesStr.trim().equals("")) {
            return;
        }
        String[] addresses = addressesStr.trim().split(" ");
        for (String address : addresses) {
            //each address is ip/prefix
            String[] split = address.split("/");
            LinkAddress addr = new LinkAddress(InetAddress.parseNumericAddress(split[0]),
                    Integer.parseInt(split[1]));
            this.addresses.add(addr);
        }
    }

    @Override
    public int describeContents() {
        return 0;
@@ -80,8 +112,8 @@ public class VpnConfig implements Parcelable {
        out.writeString(interfaze);
        out.writeString(session);
        out.writeInt(mtu);
        out.writeString(addresses);
        out.writeString(routes);
        out.writeTypedList(addresses);
        out.writeTypedList(routes);
        out.writeStringList(dnsServers);
        out.writeStringList(searchDomains);
        out.writeParcelable(configureIntent, flags);
@@ -98,8 +130,8 @@ public class VpnConfig implements Parcelable {
            config.interfaze = in.readString();
            config.session = in.readString();
            config.mtu = in.readInt();
            config.addresses = in.readString();
            config.routes = in.readString();
            in.readTypedList(config.addresses, LinkAddress.CREATOR);
            in.readTypedList(config.routes, RouteInfo.CREATOR);
            config.dnsServers = in.createStringArrayList();
            config.searchDomains = in.createStringArrayList();
            config.configureIntent = in.readParcelable(null);
+144 −14
Original line number Diff line number Diff line
@@ -97,6 +97,7 @@ import android.telephony.TelephonyManager;
import android.text.TextUtils;
import android.util.Slog;
import android.util.SparseIntArray;
import android.util.SparseArray;

import com.android.internal.R;
import com.android.internal.net.LegacyVpnInfo;
@@ -116,6 +117,8 @@ import com.android.server.net.LockdownVpnTracker;
import com.google.android.collect.Lists;
import com.google.android.collect.Sets;

import com.android.internal.annotations.GuardedBy;

import dalvik.system.DexClassLoader;

import java.io.FileDescriptor;
@@ -171,7 +174,8 @@ public class ConnectivityService extends IConnectivityManager.Stub {

    private KeyStore mKeyStore;

    private Vpn mVpn;
    @GuardedBy("mVpns")
    private final SparseArray<Vpn> mVpns = new SparseArray<Vpn>();
    private VpnCallback mVpnCallback = new VpnCallback();

    private boolean mLockdownEnabled;
@@ -583,10 +587,13 @@ public class ConnectivityService extends IConnectivityManager.Stub {
                                  mTethering.getTetherableWifiRegexs().length != 0 ||
                                  mTethering.getTetherableBluetoothRegexs().length != 0) &&
                                 mTethering.getUpstreamIfaceTypes().length != 0);
        //set up the listener for user state for creating user VPNs

        mVpn = new Vpn(mContext, mVpnCallback, mNetd, this);
        mVpn.startMonitoring(mContext, mTrackerHandler);

        IntentFilter intentFilter = new IntentFilter();
        intentFilter.addAction(Intent.ACTION_USER_STARTING);
        intentFilter.addAction(Intent.ACTION_USER_STOPPING);
        mContext.registerReceiverAsUser(
                mUserIntentReceiver, UserHandle.ALL, intentFilter, null, null);
        mClat = new Nat464Xlat(mContext, mNetd, this, mTrackerHandler);

        try {
@@ -2313,7 +2320,11 @@ public class ConnectivityService extends IConnectivityManager.Stub {
                            // Tell VPN the interface is down. It is a temporary
                            // but effective fix to make VPN aware of the change.
                            if ((resetMask & NetworkUtils.RESET_IPV4_ADDRESSES) != 0) {
                                mVpn.interfaceStatusChanged(iface, false);
                                synchronized(mVpns) {
                                    for (int i = 0; i < mVpns.size(); i++) {
                                        mVpns.valueAt(i).interfaceStatusChanged(iface, false);
                                    }
                                }
                            }
                        }
                        if (resetDns) {
@@ -2570,7 +2581,6 @@ public class ConnectivityService extends IConnectivityManager.Stub {

        try {
            mNetd.setDnsServersForInterface(iface, NetworkUtils.makeStrings(dnses), domains);
            mNetd.setDefaultInterfaceForDns(iface);
            for (InetAddress dns : dnses) {
                ++last;
                String key = "net.dns" + last;
@@ -3305,8 +3315,12 @@ public class ConnectivityService extends IConnectivityManager.Stub {
        throwIfLockdownEnabled();
        try {
            int type = mActiveDefaultNetwork;
            int user = UserHandle.getUserId(Binder.getCallingUid());
            if (ConnectivityManager.isNetworkTypeValid(type) && mNetTrackers[type] != null) {
                mVpn.protect(socket, mNetTrackers[type].getLinkProperties().getInterfaceName());
                synchronized(mVpns) {
                    mVpns.get(user).protect(socket,
                            mNetTrackers[type].getLinkProperties().getInterfaceName());
                }
                return true;
            }
        } catch (Exception e) {
@@ -3330,7 +3344,10 @@ public class ConnectivityService extends IConnectivityManager.Stub {
    @Override
    public boolean prepareVpn(String oldPackage, String newPackage) {
        throwIfLockdownEnabled();
        return mVpn.prepare(oldPackage, newPackage);
        int user = UserHandle.getUserId(Binder.getCallingUid());
        synchronized(mVpns) {
            return mVpns.get(user).prepare(oldPackage, newPackage);
        }
    }

    /**
@@ -3343,7 +3360,10 @@ public class ConnectivityService extends IConnectivityManager.Stub {
    @Override
    public ParcelFileDescriptor establishVpn(VpnConfig config) {
        throwIfLockdownEnabled();
        return mVpn.establish(config);
        int user = UserHandle.getUserId(Binder.getCallingUid());
        synchronized(mVpns) {
            return mVpns.get(user).establish(config);
        }
    }

    /**
@@ -3357,7 +3377,10 @@ public class ConnectivityService extends IConnectivityManager.Stub {
        if (egress == null) {
            throw new IllegalStateException("Missing active network connection");
        }
        mVpn.startLegacyVpn(profile, mKeyStore, egress);
        int user = UserHandle.getUserId(Binder.getCallingUid());
        synchronized(mVpns) {
            mVpns.get(user).startLegacyVpn(profile, mKeyStore, egress);
        }
    }

    /**
@@ -3369,7 +3392,10 @@ public class ConnectivityService extends IConnectivityManager.Stub {
    @Override
    public LegacyVpnInfo getLegacyVpnInfo() {
        throwIfLockdownEnabled();
        return mVpn.getLegacyVpnInfo();
        int user = UserHandle.getUserId(Binder.getCallingUid());
        synchronized(mVpns) {
            return mVpns.get(user).getLegacyVpnInfo();
        }
    }

    /**
@@ -3390,7 +3416,7 @@ public class ConnectivityService extends IConnectivityManager.Stub {
            mHandler.obtainMessage(EVENT_VPN_STATE_CHANGED, info).sendToTarget();
        }

        public void override(List<String> dnsServers, List<String> searchDomains) {
        public void override(String iface, List<String> dnsServers, List<String> searchDomains) {
            if (dnsServers == null) {
                restore();
                return;
@@ -3422,7 +3448,7 @@ public class ConnectivityService extends IConnectivityManager.Stub {

            // Apply DNS changes.
            synchronized (mDnsLock) {
                updateDnsLocked("VPN", "VPN", addresses, domains);
                updateDnsLocked("VPN", iface, addresses, domains);
                mDnsOverridden = true;
            }

@@ -3451,6 +3477,67 @@ public class ConnectivityService extends IConnectivityManager.Stub {
                }
            }
        }

        public void protect(ParcelFileDescriptor socket) {
            try {
                final int mark = mNetd.getMarkForProtect();
                NetworkUtils.markSocket(socket.getFd(), mark);
            } catch (RemoteException e) {
            }
        }

        public void setRoutes(String interfaze, List<RouteInfo> routes) {
            for (RouteInfo route : routes) {
                try {
                    mNetd.setMarkedForwardingRoute(interfaze, route);
                } catch (RemoteException e) {
                }
            }
        }

        public void setMarkedForwarding(String interfaze) {
            try {
                mNetd.setMarkedForwarding(interfaze);
            } catch (RemoteException e) {
            }
        }

        public void clearMarkedForwarding(String interfaze) {
            try {
                mNetd.clearMarkedForwarding(interfaze);
            } catch (RemoteException e) {
            }
        }

        public void addUserForwarding(String interfaze, int uid) {
            int uidStart = uid * UserHandle.PER_USER_RANGE;
            int uidEnd = uidStart + UserHandle.PER_USER_RANGE - 1;
            addUidForwarding(interfaze, uidStart, uidEnd);
        }

        public void clearUserForwarding(String interfaze, int uid) {
            int uidStart = uid * UserHandle.PER_USER_RANGE;
            int uidEnd = uidStart + UserHandle.PER_USER_RANGE - 1;
            clearUidForwarding(interfaze, uidStart, uidEnd);
        }

        public void addUidForwarding(String interfaze, int uidStart, int uidEnd) {
            try {
                mNetd.setUidRangeRoute(interfaze,uidStart, uidEnd);
                mNetd.setDnsInterfaceForUidRange(interfaze, uidStart, uidEnd);
            } catch (RemoteException e) {
            }

        }

        public void clearUidForwarding(String interfaze, int uidStart, int uidEnd) {
            try {
                mNetd.clearUidRangeRoute(interfaze, uidStart, uidEnd);
                mNetd.clearDnsInterfaceForUidRange(uidStart, uidEnd);
            } catch (RemoteException e) {
            }

        }
    }

    @Override
@@ -3471,7 +3558,11 @@ public class ConnectivityService extends IConnectivityManager.Stub {
            final String profileName = new String(mKeyStore.get(Credentials.LOCKDOWN_VPN));
            final VpnProfile profile = VpnProfile.decode(
                    profileName, mKeyStore.get(Credentials.VPN + profileName));
            setLockdownTracker(new LockdownVpnTracker(mContext, mNetd, this, mVpn, profile));
            int user = UserHandle.getUserId(Binder.getCallingUid());
            synchronized(mVpns) {
                setLockdownTracker(new LockdownVpnTracker(mContext, mNetd, this, mVpns.get(user),
                            profile));
            }
        } else {
            setLockdownTracker(null);
        }
@@ -4002,4 +4093,43 @@ public class ConnectivityService extends IConnectivityManager.Stub {

        return url;
    }

    private void onUserStart(int userId) {
        synchronized(mVpns) {
            Vpn userVpn = mVpns.get(userId);
            if (userVpn != null) {
                loge("Starting user already has a VPN");
                return;
            }
            userVpn = new Vpn(mContext, mVpnCallback, mNetd, this, userId);
            mVpns.put(userId, userVpn);
            userVpn.startMonitoring(mContext, mTrackerHandler);
        }
    }

    private void onUserStop(int userId) {
        synchronized(mVpns) {
            Vpn userVpn = mVpns.get(userId);
            if (userVpn == null) {
                loge("Stopping user has no VPN");
                return;
            }
            mVpns.delete(userId);
        }
    }

    private BroadcastReceiver mUserIntentReceiver = new BroadcastReceiver() {
        @Override
        public void onReceive(Context context, Intent intent) {
            final String action = intent.getAction();
            final int userId = intent.getIntExtra(Intent.EXTRA_USER_HANDLE, UserHandle.USER_NULL);
            if (userId == UserHandle.USER_NULL) return;

            if (Intent.ACTION_USER_STARTING.equals(action)) {
                onUserStart(userId);
            } else if (Intent.ACTION_USER_STOPPING.equals(action)) {
                onUserStop(userId);
            }
        }
    };
}
+110 −41

File changed.

Preview size limit exceeded, changes collapsed.

+15 −7
Original line number Diff line number Diff line
@@ -26,6 +26,7 @@ import android.content.Context;
import android.content.Intent;
import android.content.IntentFilter;
import android.net.LinkProperties;
import android.net.LinkAddress;
import android.net.NetworkInfo;
import android.net.NetworkInfo.DetailedState;
import android.net.NetworkInfo.State;
@@ -44,6 +45,8 @@ import com.android.server.ConnectivityService;
import com.android.server.EventLogTags;
import com.android.server.connectivity.Vpn;

import java.util.List;

/**
 * State tracker for lockdown mode. Watches for normal {@link NetworkInfo} to be
 * connected and kicks off VPN connection, managing any required {@code netd}
@@ -73,7 +76,7 @@ public class LockdownVpnTracker {

    private String mAcceptedEgressIface;
    private String mAcceptedIface;
    private String mAcceptedSourceAddr;
    private List<LinkAddress> mAcceptedSourceAddr;

    private int mErrorCount;

@@ -162,14 +165,15 @@ public class LockdownVpnTracker {

        } else if (vpnInfo.isConnected() && vpnConfig != null) {
            final String iface = vpnConfig.interfaze;
            final String sourceAddr = vpnConfig.addresses;
            final List<LinkAddress> sourceAddrs = vpnConfig.addresses;

            if (TextUtils.equals(iface, mAcceptedIface)
                    && TextUtils.equals(sourceAddr, mAcceptedSourceAddr)) {
                  && sourceAddrs.equals(mAcceptedSourceAddr)) {
                return;
            }

            Slog.d(TAG, "VPN connected using iface=" + iface + ", sourceAddr=" + sourceAddr);
            Slog.d(TAG, "VPN connected using iface=" + iface +
                    ", sourceAddr=" + sourceAddrs.toString());
            EventLogTags.writeLockdownVpnConnected(egressType);
            showNotification(R.string.vpn_lockdown_connected, R.drawable.vpn_connected);

@@ -177,11 +181,13 @@ public class LockdownVpnTracker {
                clearSourceRulesLocked();

                mNetService.setFirewallInterfaceRule(iface, true);
                mNetService.setFirewallEgressSourceRule(sourceAddr, true);
                for (LinkAddress addr : sourceAddrs) {
                    mNetService.setFirewallEgressSourceRule(addr.toString(), true);
                }

                mErrorCount = 0;
                mAcceptedIface = iface;
                mAcceptedSourceAddr = sourceAddr;
                mAcceptedSourceAddr = sourceAddrs;
            } catch (RemoteException e) {
                throw new RuntimeException("Problem setting firewall rules", e);
            }
@@ -263,7 +269,9 @@ public class LockdownVpnTracker {
                mAcceptedIface = null;
            }
            if (mAcceptedSourceAddr != null) {
                mNetService.setFirewallEgressSourceRule(mAcceptedSourceAddr, false);
                for (LinkAddress addr : mAcceptedSourceAddr) {
                    mNetService.setFirewallEgressSourceRule(addr.toString(), false);
                }
                mAcceptedSourceAddr = null;
            }
        } catch (RemoteException e) {