Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 07d6b5f1 authored by Lorenzo Colitti's avatar Lorenzo Colitti Committed by android-build-merger
Browse files

Merge "Add tests for strict mode private DNS validation." am: 6a4ae85e23

am: baee9317b5

Change-Id: I53ab6d9d30fa581347cf5aed4091d6a968946cd9
parents 56b770c3 7c656a5d
Loading
Loading
Loading
Loading
+165 −36
Original line number Diff line number Diff line
@@ -42,10 +42,10 @@ import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.anyInt;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doNothing;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.timeout;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
@@ -66,6 +66,7 @@ import android.net.NetworkCapabilities;
import android.net.NetworkInfo;
import android.net.captiveportal.CaptivePortalProbeResult;
import android.net.metrics.IpConnectivityLog;
import android.net.shared.PrivateDnsConfig;
import android.net.util.SharedLog;
import android.net.wifi.WifiInfo;
import android.net.wifi.WifiManager;
@@ -73,6 +74,7 @@ import android.os.Bundle;
import android.os.ConditionVariable;
import android.os.Handler;
import android.os.Looper;
import android.os.Process;
import android.os.RemoteException;
import android.os.SystemClock;
import android.provider.Settings;
@@ -132,6 +134,7 @@ public class NetworkMonitorTest {
    private @Mock NetworkMonitor.Dependencies mDependencies;
    private @Mock INetworkMonitorCallbacks mCallbacks;
    private @Spy Network mNetwork = new Network(TEST_NETID);
    private @Mock Network mNonPrivateDnsBypassNetwork;
    private @Mock DataStallStatsUtils mDataStallStatsUtils;
    private @Mock WifiInfo mWifiInfo;
    private @Captor ArgumentCaptor<String> mNetworkTestedRedirectUrlCaptor;
@@ -166,31 +169,93 @@ public class NetworkMonitorTest {
    private static final NetworkCapabilities NO_INTERNET_CAPABILITIES = new NetworkCapabilities()
            .addTransportType(NetworkCapabilities.TRANSPORT_CELLULAR);

    private void setDnsAnswers(String[] answers) throws UnknownHostException {
        if (answers == null) {
            doThrow(new UnknownHostException()).when(mNetwork).getAllByName(any());
            doNothing().when(mDnsResolver).query(any(), any(), anyInt(), any(), any(), any());
            return;
    /**
     * Fakes DNS responses.
     *
     * Allows test methods to configure the IP addresses that will be resolved by
     * Network#getAllByName and by DnsResolver#query.
     */
    class FakeDns {
        private final ArrayMap<String, List<InetAddress>> mAnswers = new ArrayMap<>();
        private boolean mNonBypassPrivateDnsWorking = true;

        /** Whether DNS queries on mNonBypassPrivateDnsWorking should succeed. */
        private void setNonBypassPrivateDnsWorking(boolean working) {
            mNonBypassPrivateDnsWorking = working;
        }

        /** Clears all DNS entries. */
        private synchronized void clearAll() {
            mAnswers.clear();
        }

        /** Returns the answer for a given name on the given mock network. */
        private synchronized List<InetAddress> getAnswer(Object mock, String hostname) {
            if (mock == mNonPrivateDnsBypassNetwork && !mNonBypassPrivateDnsWorking) {
                return null;
            }
            if (mAnswers.containsKey(hostname)) {
                return mAnswers.get(hostname);
            }
            return mAnswers.get("*");
        }

        /** Sets the answer for a given name. */
        private synchronized void setAnswer(String hostname, String[] answer)
                throws UnknownHostException {
            if (answer == null) {
                mAnswers.remove(hostname);
            } else {
                List<InetAddress> answerList = new ArrayList<>();
        for (String answer : answers) {
            answerList.add(InetAddresses.parseNumericAddress(answer));
                for (String addr : answer) {
                    answerList.add(InetAddresses.parseNumericAddress(addr));
                }
                mAnswers.put(hostname, answerList);
            }
        }

        /** Simulates a getAllByName call for the specified name on the specified mock network. */
        private InetAddress[] getAllByName(Object mock, String hostname)
                throws UnknownHostException {
            List<InetAddress> answer = getAnswer(mock, hostname);
            if (answer == null || answer.size() == 0) {
                throw new UnknownHostException(hostname);
            }
            return answer.toArray(new InetAddress[0]);
        }
        InetAddress[] answerArray = answerList.toArray(new InetAddress[0]);

        doReturn(answerArray).when(mNetwork).getAllByName(any());
        /** Starts mocking DNS queries. */
        private void startMocking() throws UnknownHostException {
            // Queries on mNetwork (i.e., bypassing private DNS) using getAllByName.
            doAnswer(invocation -> {
                return getAllByName(invocation.getMock(), invocation.getArgument(0));
            }).when(mNetwork).getAllByName(any());

        doAnswer((invocation) -> {
            // Queries on mNonBypassPrivateDnsNetwork using getAllByName.
            doAnswer(invocation -> {
                return getAllByName(invocation.getMock(), invocation.getArgument(0));
            }).when(mNonPrivateDnsBypassNetwork).getAllByName(any());

            // Queries on mNetwork (i.e., bypassing private DNS) using DnsResolver#query.
            doAnswer(invocation -> {
                String hostname = (String) invocation.getArgument(1);
                Executor executor = (Executor) invocation.getArgument(3);
                DnsResolver.Callback<List<InetAddress>> callback = invocation.getArgument(5);

                List<InetAddress> answer = getAnswer(invocation.getMock(), hostname);
                if (answer != null && answer.size() > 0) {
                    new Handler(Looper.getMainLooper()).post(() -> {
                executor.execute(() -> callback.onAnswer(answerList, 0));
                        executor.execute(() -> callback.onAnswer(answer, 0));
                    });
                }
                // If no answers, do nothing. sendDnsProbeWithTimeout will time out and throw UHE.
                return null;
        }).when(mDnsResolver).query(eq(mNetwork), any(), anyInt(), any(), any(), any());
            }).when(mDnsResolver).query(any(), any(), anyInt(), any(), any(), any());
        }
    }

    private FakeDns mFakeDns;

    @Before
    public void setUp() throws IOException {
        MockitoAnnotations.initMocks(this);
@@ -206,7 +271,7 @@ public class NetworkMonitorTest {
        when(mDependencies.getSetting(any(), eq(Settings.Global.CAPTIVE_PORTAL_HTTPS_URL), any()))
                .thenReturn(TEST_HTTPS_URL);

        doReturn(mNetwork).when(mNetwork).getPrivateDnsBypassingCopy();
        doReturn(mNetwork).when(mNonPrivateDnsBypassNetwork).getPrivateDnsBypassingCopy();

        when(mContext.getSystemService(Context.CONNECTIVITY_SERVICE)).thenReturn(mCm);
        when(mContext.getSystemService(Context.TELEPHONY_SERVICE)).thenReturn(mTelephony);
@@ -222,6 +287,9 @@ public class NetworkMonitorTest {
        setFallbackSpecs(null); // Test with no fallback spec by default
        when(mRandom.nextInt()).thenReturn(0);

        when(mResources.getInteger(eq(R.integer.config_captive_portal_dns_probe_timeout)))
                .thenReturn(500);

        doAnswer((invocation) -> {
            URL url = invocation.getArgument(0);
            switch(url.toString()) {
@@ -241,7 +309,9 @@ public class NetworkMonitorTest {
        when(mHttpConnection.getRequestProperties()).thenReturn(new ArrayMap<>());
        when(mHttpsConnection.getRequestProperties()).thenReturn(new ArrayMap<>());

        setDnsAnswers(new String[]{"2001:db8::1", "192.0.2.2"});
        mFakeDns = new FakeDns();
        mFakeDns.startMocking();
        mFakeDns.setAnswer("*", new String[]{"2001:db8::1", "192.0.2.2"});

        when(mContext.registerReceiver(any(BroadcastReceiver.class), any())).then((invocation) -> {
            mRegisteredReceivers.add(invocation.getArgument(0));
@@ -264,6 +334,7 @@ public class NetworkMonitorTest {

    @After
    public void tearDown() {
        mFakeDns.clearAll();
        assertTrue(mCreatedNetworkMonitors.size() > 0);
        // Make a local copy of mCreatedNetworkMonitors because during the iteration below,
        // WrappedNetworkMonitor#onQuitting will delete elements from it on the handler threads.
@@ -284,8 +355,8 @@ public class NetworkMonitorTest {
        private final ConditionVariable mQuitCv = new ConditionVariable(false);

        WrappedNetworkMonitor() {
                super(mContext, mCallbacks, mNetwork, mLogger, mValidationLogger, mDependencies,
                        mDataStallStatsUtils);
            super(mContext, mCallbacks, mNonPrivateDnsBypassNetwork, mLogger, mValidationLogger,
                    mDependencies, mDataStallStatsUtils);
        }

        @Override
@@ -314,23 +385,22 @@ public class NetworkMonitorTest {
        }
    }

    private WrappedNetworkMonitor makeMonitor() {
    private WrappedNetworkMonitor makeMonitor(NetworkCapabilities nc) {
        final WrappedNetworkMonitor nm = new WrappedNetworkMonitor();
        nm.start();
        setNetworkCapabilities(nm, nc);
        waitForIdle(nm.getHandler());
        mCreatedNetworkMonitors.add(nm);
        return nm;
    }

    private WrappedNetworkMonitor makeMeteredNetworkMonitor() {
        final WrappedNetworkMonitor nm = makeMonitor();
        setNetworkCapabilities(nm, METERED_CAPABILITIES);
        final WrappedNetworkMonitor nm = makeMonitor(METERED_CAPABILITIES);
        return nm;
    }

    private WrappedNetworkMonitor makeNotMeteredNetworkMonitor() {
        final WrappedNetworkMonitor nm = makeMonitor();
        setNetworkCapabilities(nm, NOT_METERED_CAPABILITIES);
        final WrappedNetworkMonitor nm = makeMonitor(NOT_METERED_CAPABILITIES);
        return nm;
    }

@@ -603,7 +673,7 @@ public class NetworkMonitorTest {
        setSslException(mHttpsConnection);
        setPortal302(mHttpConnection);

        final NetworkMonitor nm = makeMonitor();
        final NetworkMonitor nm = makeMonitor(METERED_CAPABILITIES);
        nm.notifyNetworkConnected(TEST_LINK_PROPERTIES, METERED_CAPABILITIES);

        verify(mCallbacks, timeout(HANDLER_TIMEOUT_MS).times(1))
@@ -637,6 +707,63 @@ public class NetworkMonitorTest {
        assertEquals(0, mRegisteredReceivers.size());
    }

    @Test
    public void testPrivateDnsSuccess() throws Exception {
        setStatus(mHttpsConnection, 204);
        setStatus(mHttpConnection, 204);
        mFakeDns.setAnswer("dns.google", new String[]{"2001:db8::53"});

        WrappedNetworkMonitor wnm = makeNotMeteredNetworkMonitor();
        wnm.notifyPrivateDnsSettingsChanged(new PrivateDnsConfig("dns.google", new InetAddress[0]));
        wnm.notifyNetworkConnected(TEST_LINK_PROPERTIES, NOT_METERED_CAPABILITIES);
        verify(mCallbacks, timeout(HANDLER_TIMEOUT_MS).times(1))
                .notifyNetworkTested(eq(NETWORK_TEST_RESULT_VALID), eq(null));
    }

    @Test
    public void testPrivateDnsResolutionRetryUpdate() throws Exception {
        // Set a private DNS hostname that doesn't resolve and expect validation to fail.
        mFakeDns.setAnswer("dns.google", new String[0]);
        setStatus(mHttpsConnection, 204);
        setStatus(mHttpConnection, 204);

        WrappedNetworkMonitor wnm = makeNotMeteredNetworkMonitor();
        wnm.notifyPrivateDnsSettingsChanged(new PrivateDnsConfig("dns.google", new InetAddress[0]));
        wnm.notifyNetworkConnected(TEST_LINK_PROPERTIES, NOT_METERED_CAPABILITIES);
        verify(mCallbacks, timeout(HANDLER_TIMEOUT_MS).times(1))
                .notifyNetworkTested(eq(NETWORK_TEST_RESULT_INVALID), eq(null));

        // Fix DNS and retry, expect validation to succeed.
        reset(mCallbacks);
        mFakeDns.setAnswer("dns.google", new String[]{"2001:db8::1"});

        wnm.forceReevaluation(Process.myUid());
        verify(mCallbacks, timeout(HANDLER_TIMEOUT_MS).times(1))
                .notifyNetworkTested(eq(NETWORK_TEST_RESULT_VALID), eq(null));

        // Change configuration to an invalid DNS name, expect validation to fail.
        reset(mCallbacks);
        mFakeDns.setAnswer("dns.bad", new String[0]);
        wnm.notifyPrivateDnsSettingsChanged(new PrivateDnsConfig("dns.bad", new InetAddress[0]));
        verify(mCallbacks, timeout(HANDLER_TIMEOUT_MS).times(1))
                .notifyNetworkTested(eq(NETWORK_TEST_RESULT_INVALID), eq(null));

        // Change configuration back to working again, but make private DNS not work.
        // Expect validation to fail.
        reset(mCallbacks);
        mFakeDns.setNonBypassPrivateDnsWorking(false);
        wnm.notifyPrivateDnsSettingsChanged(new PrivateDnsConfig("dns.google", new InetAddress[0]));
        verify(mCallbacks, timeout(HANDLER_TIMEOUT_MS).times(1))
                .notifyNetworkTested(eq(NETWORK_TEST_RESULT_INVALID), eq(null));

        // Make private DNS work again. Expect validation to succeed.
        reset(mCallbacks);
        mFakeDns.setNonBypassPrivateDnsWorking(true);
        wnm.forceReevaluation(Process.myUid());
        verify(mCallbacks, timeout(HANDLER_TIMEOUT_MS).times(1))
                .notifyNetworkTested(eq(NETWORK_TEST_RESULT_VALID), eq(null));
    }

    @Test
    public void testDataStall_StallSuspectedAndSendMetrics() throws IOException {
        WrappedNetworkMonitor wrappedMonitor = makeNotMeteredNetworkMonitor();
@@ -728,25 +855,27 @@ public class NetworkMonitorTest {
        WrappedNetworkMonitor wnm = makeNotMeteredNetworkMonitor();
        final int shortTimeoutMs = 200;

        // Clear the wildcard DNS response created in setUp.
        mFakeDns.setAnswer("*", null);

        String[] expected = new String[]{"2001:db8::"};
        setDnsAnswers(expected);
        mFakeDns.setAnswer("www.google.com", expected);
        InetAddress[] actual = wnm.sendDnsProbeWithTimeout("www.google.com", shortTimeoutMs);
        assertIpAddressArrayEquals(expected, actual);

        expected = new String[]{"2001:db8::", "192.0.2.1"};
        setDnsAnswers(expected);
        actual = wnm.sendDnsProbeWithTimeout("www.google.com", shortTimeoutMs);
        mFakeDns.setAnswer("www.googleapis.com", expected);
        actual = wnm.sendDnsProbeWithTimeout("www.googleapis.com", shortTimeoutMs);
        assertIpAddressArrayEquals(expected, actual);

        expected = new String[0];
        setDnsAnswers(expected);
        mFakeDns.setAnswer("www.google.com", new String[0]);
        try {
            wnm.sendDnsProbeWithTimeout("www.google.com", shortTimeoutMs);
            fail("No DNS results, expected UnknownHostException");
        } catch (UnknownHostException e) {
        }

        setDnsAnswers(null);
        mFakeDns.setAnswer("www.google.com", null);
        try {
            wnm.sendDnsProbeWithTimeout("www.google.com", shortTimeoutMs);
            fail("DNS query timed out, expected UnknownHostException");
@@ -841,7 +970,7 @@ public class NetworkMonitorTest {
    }

    private NetworkMonitor runNetworkTest(NetworkCapabilities nc, int testResult) {
        final NetworkMonitor monitor = makeMonitor();
        final NetworkMonitor monitor = makeMonitor(nc);
        monitor.notifyNetworkConnected(TEST_LINK_PROPERTIES, nc);
        try {
            verify(mCallbacks, timeout(HANDLER_TIMEOUT_MS).times(1))