Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit ecacd5e0 authored by Chalard Jean's avatar Chalard Jean
Browse files

Move VPN allowed UIDs into NetworkCapabilities.

Test: runtest frameworks-net
Test: also new specific tests for this new code
Test: also tested with VPN app
Test: also cts passing
Change-Id: If0311bae2bf99dedac959febadecf4f92f3064b8
parent ce1a9d8f
Loading
Loading
Loading
Loading
+6 −0
Original line number Diff line number Diff line
@@ -29,6 +29,7 @@ import com.android.internal.util.AsyncChannel;
import com.android.internal.util.Protocol;

import java.util.ArrayList;
import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;

/**
@@ -224,6 +225,11 @@ public abstract class NetworkAgent extends Handler {
                Context.CONNECTIVITY_SERVICE);
        netId = cm.registerNetworkAgent(new Messenger(this), new NetworkInfo(ni),
                new LinkProperties(lp), new NetworkCapabilities(nc), score, misc);

        final Set<UidRange> uids = nc.getUids();
        if (null != uids) {
            addUidRanges(uids.toArray(new UidRange[uids.size()]));
        }
    }

    @Override
+179 −21
Original line number Diff line number Diff line
@@ -20,6 +20,7 @@ import android.annotation.IntDef;
import android.net.ConnectivityManager.NetworkCallback;
import android.os.Parcel;
import android.os.Parcelable;
import android.util.ArraySet;
import android.util.proto.ProtoOutputStream;

import com.android.internal.annotations.VisibleForTesting;
@@ -29,6 +30,7 @@ import com.android.internal.util.Preconditions;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.util.Objects;
import java.util.Set;
import java.util.StringJoiner;

/**
@@ -64,6 +66,7 @@ public final class NetworkCapabilities implements Parcelable {
            mLinkDownBandwidthKbps = nc.mLinkDownBandwidthKbps;
            mNetworkSpecifier = nc.mNetworkSpecifier;
            mSignalStrength = nc.mSignalStrength;
            mUids = nc.mUids;
        }
    }

@@ -77,6 +80,7 @@ public final class NetworkCapabilities implements Parcelable {
        mLinkUpBandwidthKbps = mLinkDownBandwidthKbps = LINK_BANDWIDTH_UNSPECIFIED;
        mNetworkSpecifier = null;
        mSignalStrength = SIGNAL_STRENGTH_UNSPECIFIED;
        mUids = null;
    }

    /**
@@ -836,6 +840,150 @@ public final class NetworkCapabilities implements Parcelable {
        return this.mSignalStrength == nc.mSignalStrength;
    }

    /**
     * List of UIDs this network applies to. No restriction if null.
     * <p>
     * This is typically (and at this time, only) used by VPN. This network is only available to
     * the UIDs in this list, and it is their default network. Apps in this list that wish to
     * bypass the VPN can do so iff the VPN app allows them to or if they are privileged. If this
     * member is null, then the network is not restricted by app UID. If it's an empty list, then
     * it means nobody can use it.
     * <p>
     * Please note that in principle a single app can be associated with multiple UIDs because
     * each app will have a different UID when it's run as a different (macro-)user. A single
     * macro user can only have a single active VPN app at any given time however.
     * <p>
     * Also please be aware this class does not try to enforce any normalization on this. Callers
     * can only alter the UIDs by setting them wholesale : this class does not provide any utility
     * to add or remove individual UIDs or ranges. If callers have any normalization needs on
     * their own (like requiring sortedness or no overlap) they need to enforce it
     * themselves. Some of the internal methods also assume this is normalized as in no adjacent
     * or overlapping ranges are present.
     *
     * @hide
     */
    private Set<UidRange> mUids = null;

    /**
     * Set the list of UIDs this network applies to.
     * This makes a copy of the set so that callers can't modify it after the call.
     * @hide
     */
    public NetworkCapabilities setUids(Set<UidRange> uids) {
        if (null == uids) {
            mUids = null;
        } else {
            mUids = new ArraySet<>(uids);
        }
        return this;
    }

    /**
     * Get the list of UIDs this network applies to.
     * This returns a copy of the set so that callers can't modify the original object.
     * @hide
     */
    public Set<UidRange> getUids() {
        return null == mUids ? null : new ArraySet<>(mUids);
    }

    /**
     * Test whether this network applies to this UID.
     * @hide
     */
    public boolean appliesToUid(int uid) {
        if (null == mUids) return true;
        for (UidRange range : mUids) {
            if (range.contains(uid)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Tests if the set of UIDs that this network applies to is the same of the passed set of UIDs.
     * <p>
     * This test only checks whether equal range objects are in both sets. It will
     * return false if the ranges are not exactly the same, even if the covered UIDs
     * are for an equivalent result.
     * <p>
     * Note that this method is not very optimized, which is fine as long as it's not used very
     * often.
     * <p>
     * nc is assumed nonnull.
     *
     * @hide
     */
    @VisibleForTesting
    public boolean equalsUids(NetworkCapabilities nc) {
        Set<UidRange> comparedUids = nc.mUids;
        if (null == comparedUids) return null == mUids;
        if (null == mUids) return false;
        // Make a copy so it can be mutated to check that all ranges in mUids
        // also are in uids.
        final Set<UidRange> uids = new ArraySet<>(mUids);
        for (UidRange range : comparedUids) {
            if (!uids.contains(range)) {
                return false;
            }
            uids.remove(range);
        }
        return uids.isEmpty();
    }

    /**
     * Test whether the passed NetworkCapabilities satisfies the UIDs this capabilities require.
     *
     * This is called on the NetworkCapabilities embedded in a request with the capabilities
     * of an available network.
     * nc is assumed nonnull.
     * @see #appliesToUid
     * @hide
     */
    public boolean satisfiedByUids(NetworkCapabilities nc) {
        if (null == nc.mUids) return true; // The network satisfies everything.
        if (null == mUids) return false; // Not everything allowed but requires everything
        for (UidRange requiredRange : mUids) {
            if (!nc.appliesToUidRange(requiredRange)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Returns whether this network applies to the passed ranges.
     * This assumes that to apply, the passed range has to be entirely contained
     * within one of the ranges this network applies to. If the ranges are not normalized,
     * this method may return false even though all required UIDs are covered because no
     * single range contained them all.
     * @hide
     */
    @VisibleForTesting
    public boolean appliesToUidRange(UidRange requiredRange) {
        if (null == mUids) return true;
        for (UidRange uidRange : mUids) {
            if (uidRange.containsRange(requiredRange)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Combine the UIDs this network currently applies to with the UIDs the passed
     * NetworkCapabilities apply to.
     * nc is assumed nonnull.
     */
    private void combineUids(NetworkCapabilities nc) {
        if (null == nc.mUids || null == mUids) {
            mUids = null;
            return;
        }
        mUids.addAll(nc.mUids);
    }

    /**
     * Combine a set of Capabilities to this one.  Useful for coming up with the complete set
     * @hide
@@ -846,6 +994,7 @@ public final class NetworkCapabilities implements Parcelable {
        combineLinkBandwidths(nc);
        combineSpecifiers(nc);
        combineSignalStrength(nc);
        combineUids(nc);
    }

    /**
@@ -858,12 +1007,13 @@ public final class NetworkCapabilities implements Parcelable {
     * @hide
     */
    private boolean satisfiedByNetworkCapabilities(NetworkCapabilities nc, boolean onlyImmutable) {
        return (nc != null &&
                satisfiedByNetCapabilities(nc, onlyImmutable) &&
                satisfiedByTransportTypes(nc) &&
                (onlyImmutable || satisfiedByLinkBandwidths(nc)) &&
                satisfiedBySpecifier(nc) &&
                (onlyImmutable || satisfiedBySignalStrength(nc)));
        return (nc != null
                && satisfiedByNetCapabilities(nc, onlyImmutable)
                && satisfiedByTransportTypes(nc)
                && (onlyImmutable || satisfiedByLinkBandwidths(nc))
                && satisfiedBySpecifier(nc)
                && (onlyImmutable || satisfiedBySignalStrength(nc))
                && (onlyImmutable || satisfiedByUids(nc)));
    }

    /**
@@ -947,23 +1097,25 @@ public final class NetworkCapabilities implements Parcelable {
    public boolean equals(Object obj) {
        if (obj == null || (obj instanceof NetworkCapabilities == false)) return false;
        NetworkCapabilities that = (NetworkCapabilities) obj;
        return (equalsNetCapabilities(that) &&
                equalsTransportTypes(that) &&
                equalsLinkBandwidths(that) &&
                equalsSignalStrength(that) &&
                equalsSpecifier(that));
        return (equalsNetCapabilities(that)
                && equalsTransportTypes(that)
                && equalsLinkBandwidths(that)
                && equalsSignalStrength(that)
                && equalsSpecifier(that)
                && equalsUids(that));
    }

    @Override
    public int hashCode() {
        return ((int)(mNetworkCapabilities & 0xFFFFFFFF) +
                ((int)(mNetworkCapabilities >> 32) * 3) +
                ((int)(mTransportTypes & 0xFFFFFFFF) * 5) +
                ((int)(mTransportTypes >> 32) * 7) +
                (mLinkUpBandwidthKbps * 11) +
                (mLinkDownBandwidthKbps * 13) +
                Objects.hashCode(mNetworkSpecifier) * 17 +
                (mSignalStrength * 19));
        return ((int) (mNetworkCapabilities & 0xFFFFFFFF)
                + ((int) (mNetworkCapabilities >> 32) * 3)
                + ((int) (mTransportTypes & 0xFFFFFFFF) * 5)
                + ((int) (mTransportTypes >> 32) * 7)
                + (mLinkUpBandwidthKbps * 11)
                + (mLinkDownBandwidthKbps * 13)
                + Objects.hashCode(mNetworkSpecifier) * 17
                + (mSignalStrength * 19)
                + Objects.hashCode(mUids) * 23);
    }

    @Override
@@ -978,6 +1130,7 @@ public final class NetworkCapabilities implements Parcelable {
        dest.writeInt(mLinkDownBandwidthKbps);
        dest.writeParcelable((Parcelable) mNetworkSpecifier, flags);
        dest.writeInt(mSignalStrength);
        dest.writeArraySet(new ArraySet<>(mUids));
    }

    public static final Creator<NetworkCapabilities> CREATOR =
@@ -992,6 +1145,8 @@ public final class NetworkCapabilities implements Parcelable {
                netCap.mLinkDownBandwidthKbps = in.readInt();
                netCap.mNetworkSpecifier = in.readParcelable(null);
                netCap.mSignalStrength = in.readInt();
                netCap.mUids = (ArraySet<UidRange>) in.readArraySet(
                        null /* ClassLoader, null for default */);
                return netCap;
            }
            @Override
@@ -1024,7 +1179,10 @@ public final class NetworkCapabilities implements Parcelable {

        String signalStrength = (hasSignalStrength() ? " SignalStrength: " + mSignalStrength : "");

        return "[" + transports + capabilities + upBand + dnBand + specifier + signalStrength + "]";
        String uids = (null != mUids ? " Uids: <" + mUids + ">" : "");

        return "[" + transports + capabilities + upBand + dnBand + specifier + signalStrength
            + uids + "]";
    }

    /**
+27 −55
Original line number Diff line number Diff line
@@ -163,19 +163,6 @@ public class Vpn {
     */
    private boolean mLockdown = false;

    /**
     * List of UIDs that are set to use this VPN by default. Normally, every UID in the user is
     * added to this set but that can be changed by adding allowed or disallowed applications. It
     * is non-null iff the VPN is connected.
     *
     * Unless the VPN has set allowBypass=true, these UIDs are forced into the VPN.
     *
     * @see VpnService.Builder#addAllowedApplication(String)
     * @see VpnService.Builder#addDisallowedApplication(String)
     */
    @GuardedBy("this")
    private Set<UidRange> mVpnUsers = null;

    /**
     * List of UIDs for which networking should be blocked until VPN is ready, during brief periods
     * when VPN is not running. For example, during system startup or after a crash.
@@ -688,7 +675,7 @@ public class Vpn {
                agentDisconnect();
                jniReset(mInterface);
                mInterface = null;
                mVpnUsers = null;
                mNetworkCapabilities.setUids(null);
            }

            // Revoke the connection or stop LegacyVpnRunner.
@@ -857,6 +844,8 @@ public class Vpn {
        NetworkMisc networkMisc = new NetworkMisc();
        networkMisc.allowBypass = mConfig.allowBypass && !mLockdown;

        mNetworkCapabilities.setUids(createUserAndRestrictedProfilesRanges(mUserHandle,
                mConfig.allowedApplications, mConfig.disallowedApplications));
        long token = Binder.clearCallingIdentity();
        try {
            mNetworkAgent = new NetworkAgent(mLooper, mContext, NETWORKTYPE /* logtag */,
@@ -869,11 +858,6 @@ public class Vpn {
        } finally {
            Binder.restoreCallingIdentity(token);
        }

        mVpnUsers = createUserAndRestrictedProfilesRanges(mUserHandle,
                mConfig.allowedApplications, mConfig.disallowedApplications);
        mNetworkAgent.addUidRanges(mVpnUsers.toArray(new UidRange[mVpnUsers.size()]));

        mNetworkInfo.setIsAvailable(true);
        updateState(DetailedState.CONNECTED, "agentConnect");
    }
@@ -953,7 +937,7 @@ public class Vpn {
        Connection oldConnection = mConnection;
        NetworkAgent oldNetworkAgent = mNetworkAgent;
        mNetworkAgent = null;
        Set<UidRange> oldUsers = mVpnUsers;
        Set<UidRange> oldUsers = mNetworkCapabilities.getUids();

        // Configure the interface. Abort if any of these steps fails.
        ParcelFileDescriptor tun = ParcelFileDescriptor.adoptFd(jniCreate(config.mtu));
@@ -1011,7 +995,7 @@ public class Vpn {
            // restore old state
            mConfig = oldConfig;
            mConnection = oldConnection;
            mVpnUsers = oldUsers;
            mNetworkCapabilities.setUids(oldUsers);
            mNetworkAgent = oldNetworkAgent;
            mInterface = oldInterface;
            throw e;
@@ -1131,10 +1115,12 @@ public class Vpn {

    // Returns the subset of the full list of active UID ranges the VPN applies to (mVpnUsers) that
    // apply to userHandle.
    private List<UidRange> uidRangesForUser(int userHandle) {
    static private List<UidRange> uidRangesForUser(int userHandle, Set<UidRange> existingRanges) {
        // UidRange#createForUser returns the entire range of UIDs available to a macro-user.
        // This is something like 0-99999 ; {@see UserHandle#PER_USER_RANGE}
        final UidRange userRange = UidRange.createForUser(userHandle);
        final List<UidRange> ranges = new ArrayList<UidRange>();
        for (UidRange range : mVpnUsers) {
        for (UidRange range : existingRanges) {
            if (userRange.containsRange(range)) {
                ranges.add(range);
            }
@@ -1142,28 +1128,20 @@ public class Vpn {
        return ranges;
    }

    private void removeVpnUserLocked(int userHandle) {
        if (mVpnUsers == null) {
            throw new IllegalStateException("VPN is not active");
        }
        final List<UidRange> ranges = uidRangesForUser(userHandle);
        if (mNetworkAgent != null) {
            mNetworkAgent.removeUidRanges(ranges.toArray(new UidRange[ranges.size()]));
        }
        mVpnUsers.removeAll(ranges);
    }

    public void onUserAdded(int userHandle) {
        // If the user is restricted tie them to the parent user's VPN
        UserInfo user = UserManager.get(mContext).getUserInfo(userHandle);
        if (user.isRestricted() && user.restrictedProfileParentId == mUserHandle) {
            synchronized(Vpn.this) {
                if (mVpnUsers != null) {
                final Set<UidRange> existingRanges = mNetworkCapabilities.getUids();
                if (existingRanges != null) {
                    try {
                        addUserToRanges(mVpnUsers, userHandle, mConfig.allowedApplications,
                        addUserToRanges(existingRanges, userHandle, mConfig.allowedApplications,
                                mConfig.disallowedApplications);
                        mNetworkCapabilities.setUids(existingRanges);
                        if (mNetworkAgent != null) {
                            final List<UidRange> ranges = uidRangesForUser(userHandle);
                            final List<UidRange> ranges =
                                uidRangesForUser(userHandle, mNetworkCapabilities.getUids());
                            mNetworkAgent.addUidRanges(ranges.toArray(new UidRange[ranges.size()]));
                        }
                    } catch (Exception e) {
@@ -1180,9 +1158,17 @@ public class Vpn {
        UserInfo user = UserManager.get(mContext).getUserInfo(userHandle);
        if (user.isRestricted() && user.restrictedProfileParentId == mUserHandle) {
            synchronized(Vpn.this) {
                if (mVpnUsers != null) {
                final Set<UidRange> existingRanges = mNetworkCapabilities.getUids();
                if (existingRanges != null) {
                    try {
                        removeVpnUserLocked(userHandle);
                        final List<UidRange> removedRanges =
                            uidRangesForUser(userHandle, existingRanges);
                        if (mNetworkAgent != null) {
                            mNetworkAgent.removeUidRanges(removedRanges.toArray(
                                new UidRange[removedRanges.size()]));
                        }
                        existingRanges.removeAll(removedRanges);
                        mNetworkCapabilities.setUids(existingRanges);
                    } catch (Exception e) {
                        Log.wtf(TAG, "Failed to remove restricted user to owner", e);
                    }
@@ -1226,15 +1212,6 @@ public class Vpn {
    private void setVpnForcedLocked(boolean enforce) {
        final List<String> exemptedPackages =
                isNullOrLegacyVpn(mPackage) ? null : Collections.singletonList(mPackage);
        setVpnForcedWithExemptionsLocked(enforce, exemptedPackages);
    }

    /**
     * @see #setVpnForcedLocked
     */
    @GuardedBy("this")
    private void setVpnForcedWithExemptionsLocked(boolean enforce,
            @Nullable List<String> exemptedPackages) {
        final Set<UidRange> removedRanges = new ArraySet<>(mBlockedUsers);

        Set<UidRange> addedRanges = Collections.emptySet();
@@ -1314,7 +1291,7 @@ public class Vpn {
            synchronized (Vpn.this) {
                if (interfaze.equals(mInterface) && jniCheck(interfaze) == 0) {
                    mStatusIntent = null;
                    mVpnUsers = null;
                    mNetworkCapabilities.setUids(null);
                    mConfig = null;
                    mInterface = null;
                    if (mConnection != null) {
@@ -1433,12 +1410,7 @@ public class Vpn {
        if (!isRunningLocked()) {
            return false;
        }
        for (UidRange uidRange : mVpnUsers) {
            if (uidRange.contains(uid)) {
                return true;
            }
        }
        return false;
        return mNetworkCapabilities.appliesToUid(uid);
    }

    /**
+83 −0
Original line number Diff line number Diff line
@@ -34,12 +34,15 @@ import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotEquals;
import static org.junit.Assert.assertTrue;

import android.os.Parcel;
import android.support.test.runner.AndroidJUnit4;
import android.test.suitebuilder.annotation.SmallTest;
import android.util.ArraySet;

import org.junit.Test;
import org.junit.runner.RunWith;

import java.util.Set;

@RunWith(AndroidJUnit4.class)
@SmallTest
@@ -189,4 +192,84 @@ public class NetworkCapabilitiesTest {
        assertEquals(20, NetworkCapabilities
                .maxBandwidth(10, 20));
    }

    @Test
    public void testSetUids() {
        final NetworkCapabilities netCap = new NetworkCapabilities();
        final Set<UidRange> uids = new ArraySet<>();
        uids.add(new UidRange(50, 100));
        uids.add(new UidRange(3000, 4000));
        netCap.setUids(uids);
        assertTrue(netCap.appliesToUid(50));
        assertTrue(netCap.appliesToUid(80));
        assertTrue(netCap.appliesToUid(100));
        assertTrue(netCap.appliesToUid(3000));
        assertTrue(netCap.appliesToUid(3001));
        assertFalse(netCap.appliesToUid(10));
        assertFalse(netCap.appliesToUid(25));
        assertFalse(netCap.appliesToUid(49));
        assertFalse(netCap.appliesToUid(101));
        assertFalse(netCap.appliesToUid(2000));
        assertFalse(netCap.appliesToUid(100000));

        assertTrue(netCap.appliesToUidRange(new UidRange(50, 100)));
        assertTrue(netCap.appliesToUidRange(new UidRange(70, 72)));
        assertTrue(netCap.appliesToUidRange(new UidRange(3500, 3912)));
        assertFalse(netCap.appliesToUidRange(new UidRange(1, 100)));
        assertFalse(netCap.appliesToUidRange(new UidRange(49, 100)));
        assertFalse(netCap.appliesToUidRange(new UidRange(1, 10)));
        assertFalse(netCap.appliesToUidRange(new UidRange(60, 101)));
        assertFalse(netCap.appliesToUidRange(new UidRange(60, 3400)));

        NetworkCapabilities netCap2 = new NetworkCapabilities();
        assertFalse(netCap2.satisfiedByUids(netCap));
        assertFalse(netCap2.equalsUids(netCap));
        netCap2.setUids(uids);
        assertTrue(netCap2.satisfiedByUids(netCap));
        assertTrue(netCap.equalsUids(netCap2));
        assertTrue(netCap2.equalsUids(netCap));

        uids.add(new UidRange(600, 700));
        netCap2.setUids(uids);
        assertFalse(netCap2.satisfiedByUids(netCap));
        assertFalse(netCap.appliesToUid(650));
        assertTrue(netCap2.appliesToUid(650));
        netCap.combineCapabilities(netCap2);
        assertTrue(netCap2.satisfiedByUids(netCap));
        assertTrue(netCap.appliesToUid(650));
        assertFalse(netCap.appliesToUid(500));

        assertFalse(new NetworkCapabilities().satisfiedByUids(netCap));
        netCap.combineCapabilities(new NetworkCapabilities());
        assertTrue(netCap.appliesToUid(500));
        assertTrue(netCap.appliesToUidRange(new UidRange(1, 100000)));
        assertFalse(netCap2.appliesToUid(500));
        assertFalse(netCap2.appliesToUidRange(new UidRange(1, 100000)));
        assertTrue(new NetworkCapabilities().satisfiedByUids(netCap));
    }

    @Test
    public void testParcelNetworkCapabilities() {
        final Set<UidRange> uids = new ArraySet<>();
        uids.add(new UidRange(50, 100));
        uids.add(new UidRange(3000, 4000));
        final NetworkCapabilities netCap = new NetworkCapabilities()
            .addCapability(NET_CAPABILITY_INTERNET)
            .setUids(uids)
            .addCapability(NET_CAPABILITY_EIMS)
            .addCapability(NET_CAPABILITY_NOT_METERED);
        assertEqualsThroughMarshalling(netCap);
    }

    private void assertEqualsThroughMarshalling(NetworkCapabilities netCap) {
        Parcel p = Parcel.obtain();
        netCap.writeToParcel(p, /* flags */ 0);
        p.setDataPosition(0);
        byte[] marshalledData = p.marshall();

        p = Parcel.obtain();
        p.unmarshall(marshalledData, 0, marshalledData.length);
        p.setDataPosition(0);
        assertEquals(NetworkCapabilities.CREATOR.createFromParcel(p), netCap);
    }
}