Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit ab3611bc authored by Lorenzo Colitti's avatar Lorenzo Colitti
Browse files

Add tests for strict mode private DNS validation.

Test successful and failed validation, and updating the config.
In order to do this, add a FakeDns class so we can change DNS
responses dynamically while the test is running.

Also a couple of minor fixes:
1. Make sure the DNS timeout is set. Before this CL, it was
   always 0. Not sure why. It does seem to be set to the default
   value (12500) when actually running on device. We didn't
   catch this because the only tests that use the timeout set it
   explicitly.
2. Make runNetworkTest a bit more realistic: always send
   NetworkCapabilities *before* calling notifyNetworkConnected.
   This is what ConnectivityService does.

Bug: 122652057
Test: atest FrameworksNetTests NetworkStackTests
Test: atest --generate-new-metrics 50 NetworkStackTests:com.android.server.connectivity.NetworkMonitorTest
Change-Id: Ifd6694262501874f3261c864a049cb35c6afb9c8
Merged-In: Ifd6694262501874f3261c864a049cb35c6afb9c8
(cherry picked from commit 89909bef)
parent 64c39a18
Loading
Loading
Loading
Loading
+165 −36
Original line number Diff line number Diff line
@@ -42,10 +42,10 @@ import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.anyInt;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doNothing;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.timeout;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
@@ -66,6 +66,7 @@ import android.net.NetworkCapabilities;
import android.net.NetworkInfo;
import android.net.captiveportal.CaptivePortalProbeResult;
import android.net.metrics.IpConnectivityLog;
import android.net.shared.PrivateDnsConfig;
import android.net.util.SharedLog;
import android.net.wifi.WifiInfo;
import android.net.wifi.WifiManager;
@@ -73,6 +74,7 @@ import android.os.Bundle;
import android.os.ConditionVariable;
import android.os.Handler;
import android.os.Looper;
import android.os.Process;
import android.os.RemoteException;
import android.os.SystemClock;
import android.provider.Settings;
@@ -132,6 +134,7 @@ public class NetworkMonitorTest {
    private @Mock NetworkMonitor.Dependencies mDependencies;
    private @Mock INetworkMonitorCallbacks mCallbacks;
    private @Spy Network mNetwork = new Network(TEST_NETID);
    private @Mock Network mNonPrivateDnsBypassNetwork;
    private @Mock DataStallStatsUtils mDataStallStatsUtils;
    private @Mock WifiInfo mWifiInfo;
    private @Captor ArgumentCaptor<String> mNetworkTestedRedirectUrlCaptor;
@@ -166,31 +169,93 @@ public class NetworkMonitorTest {
    private static final NetworkCapabilities NO_INTERNET_CAPABILITIES = new NetworkCapabilities()
            .addTransportType(NetworkCapabilities.TRANSPORT_CELLULAR);

    private void setDnsAnswers(String[] answers) throws UnknownHostException {
        if (answers == null) {
            doThrow(new UnknownHostException()).when(mNetwork).getAllByName(any());
            doNothing().when(mDnsResolver).query(any(), any(), anyInt(), any(), any(), any());
            return;
    /**
     * Fakes DNS responses.
     *
     * Allows test methods to configure the IP addresses that will be resolved by
     * Network#getAllByName and by DnsResolver#query.
     */
    class FakeDns {
        private final ArrayMap<String, List<InetAddress>> mAnswers = new ArrayMap<>();
        private boolean mNonBypassPrivateDnsWorking = true;

        /** Whether DNS queries on mNonBypassPrivateDnsWorking should succeed. */
        private void setNonBypassPrivateDnsWorking(boolean working) {
            mNonBypassPrivateDnsWorking = working;
        }

        /** Clears all DNS entries. */
        private synchronized void clearAll() {
            mAnswers.clear();
        }

        /** Returns the answer for a given name on the given mock network. */
        private synchronized List<InetAddress> getAnswer(Object mock, String hostname) {
            if (mock == mNonPrivateDnsBypassNetwork && !mNonBypassPrivateDnsWorking) {
                return null;
            }
            if (mAnswers.containsKey(hostname)) {
                return mAnswers.get(hostname);
            }
            return mAnswers.get("*");
        }

        /** Sets the answer for a given name. */
        private synchronized void setAnswer(String hostname, String[] answer)
                throws UnknownHostException {
            if (answer == null) {
                mAnswers.remove(hostname);
            } else {
                List<InetAddress> answerList = new ArrayList<>();
        for (String answer : answers) {
            answerList.add(InetAddresses.parseNumericAddress(answer));
                for (String addr : answer) {
                    answerList.add(InetAddresses.parseNumericAddress(addr));
                }
                mAnswers.put(hostname, answerList);
            }
        }

        /** Simulates a getAllByName call for the specified name on the specified mock network. */
        private InetAddress[] getAllByName(Object mock, String hostname)
                throws UnknownHostException {
            List<InetAddress> answer = getAnswer(mock, hostname);
            if (answer == null || answer.size() == 0) {
                throw new UnknownHostException(hostname);
            }
            return answer.toArray(new InetAddress[0]);
        }
        InetAddress[] answerArray = answerList.toArray(new InetAddress[0]);

        doReturn(answerArray).when(mNetwork).getAllByName(any());
        /** Starts mocking DNS queries. */
        private void startMocking() throws UnknownHostException {
            // Queries on mNetwork (i.e., bypassing private DNS) using getAllByName.
            doAnswer(invocation -> {
                return getAllByName(invocation.getMock(), invocation.getArgument(0));
            }).when(mNetwork).getAllByName(any());

        doAnswer((invocation) -> {
            // Queries on mNonBypassPrivateDnsNetwork using getAllByName.
            doAnswer(invocation -> {
                return getAllByName(invocation.getMock(), invocation.getArgument(0));
            }).when(mNonPrivateDnsBypassNetwork).getAllByName(any());

            // Queries on mNetwork (i.e., bypassing private DNS) using DnsResolver#query.
            doAnswer(invocation -> {
                String hostname = (String) invocation.getArgument(1);
                Executor executor = (Executor) invocation.getArgument(3);
                DnsResolver.Callback<List<InetAddress>> callback = invocation.getArgument(5);

                List<InetAddress> answer = getAnswer(invocation.getMock(), hostname);
                if (answer != null && answer.size() > 0) {
                    new Handler(Looper.getMainLooper()).post(() -> {
                executor.execute(() -> callback.onAnswer(answerList, 0));
                        executor.execute(() -> callback.onAnswer(answer, 0));
                    });
                }
                // If no answers, do nothing. sendDnsProbeWithTimeout will time out and throw UHE.
                return null;
        }).when(mDnsResolver).query(eq(mNetwork), any(), anyInt(), any(), any(), any());
            }).when(mDnsResolver).query(any(), any(), anyInt(), any(), any(), any());
        }
    }

    private FakeDns mFakeDns;

    @Before
    public void setUp() throws IOException {
        MockitoAnnotations.initMocks(this);
@@ -206,7 +271,7 @@ public class NetworkMonitorTest {
        when(mDependencies.getSetting(any(), eq(Settings.Global.CAPTIVE_PORTAL_HTTPS_URL), any()))
                .thenReturn(TEST_HTTPS_URL);

        doReturn(mNetwork).when(mNetwork).getPrivateDnsBypassingCopy();
        doReturn(mNetwork).when(mNonPrivateDnsBypassNetwork).getPrivateDnsBypassingCopy();

        when(mContext.getSystemService(Context.CONNECTIVITY_SERVICE)).thenReturn(mCm);
        when(mContext.getSystemService(Context.TELEPHONY_SERVICE)).thenReturn(mTelephony);
@@ -222,6 +287,9 @@ public class NetworkMonitorTest {
        setFallbackSpecs(null); // Test with no fallback spec by default
        when(mRandom.nextInt()).thenReturn(0);

        when(mResources.getInteger(eq(R.integer.config_captive_portal_dns_probe_timeout)))
                .thenReturn(500);

        doAnswer((invocation) -> {
            URL url = invocation.getArgument(0);
            switch(url.toString()) {
@@ -241,7 +309,9 @@ public class NetworkMonitorTest {
        when(mHttpConnection.getRequestProperties()).thenReturn(new ArrayMap<>());
        when(mHttpsConnection.getRequestProperties()).thenReturn(new ArrayMap<>());

        setDnsAnswers(new String[]{"2001:db8::1", "192.0.2.2"});
        mFakeDns = new FakeDns();
        mFakeDns.startMocking();
        mFakeDns.setAnswer("*", new String[]{"2001:db8::1", "192.0.2.2"});

        when(mContext.registerReceiver(any(BroadcastReceiver.class), any())).then((invocation) -> {
            mRegisteredReceivers.add(invocation.getArgument(0));
@@ -264,6 +334,7 @@ public class NetworkMonitorTest {

    @After
    public void tearDown() {
        mFakeDns.clearAll();
        assertTrue(mCreatedNetworkMonitors.size() > 0);
        // Make a local copy of mCreatedNetworkMonitors because during the iteration below,
        // WrappedNetworkMonitor#onQuitting will delete elements from it on the handler threads.
@@ -284,8 +355,8 @@ public class NetworkMonitorTest {
        private final ConditionVariable mQuitCv = new ConditionVariable(false);

        WrappedNetworkMonitor() {
                super(mContext, mCallbacks, mNetwork, mLogger, mValidationLogger, mDependencies,
                        mDataStallStatsUtils);
            super(mContext, mCallbacks, mNonPrivateDnsBypassNetwork, mLogger, mValidationLogger,
                    mDependencies, mDataStallStatsUtils);
        }

        @Override
@@ -314,23 +385,22 @@ public class NetworkMonitorTest {
        }
    }

    private WrappedNetworkMonitor makeMonitor() {
    private WrappedNetworkMonitor makeMonitor(NetworkCapabilities nc) {
        final WrappedNetworkMonitor nm = new WrappedNetworkMonitor();
        nm.start();
        setNetworkCapabilities(nm, nc);
        waitForIdle(nm.getHandler());
        mCreatedNetworkMonitors.add(nm);
        return nm;
    }

    private WrappedNetworkMonitor makeMeteredNetworkMonitor() {
        final WrappedNetworkMonitor nm = makeMonitor();
        setNetworkCapabilities(nm, METERED_CAPABILITIES);
        final WrappedNetworkMonitor nm = makeMonitor(METERED_CAPABILITIES);
        return nm;
    }

    private WrappedNetworkMonitor makeNotMeteredNetworkMonitor() {
        final WrappedNetworkMonitor nm = makeMonitor();
        setNetworkCapabilities(nm, NOT_METERED_CAPABILITIES);
        final WrappedNetworkMonitor nm = makeMonitor(NOT_METERED_CAPABILITIES);
        return nm;
    }

@@ -603,7 +673,7 @@ public class NetworkMonitorTest {
        setSslException(mHttpsConnection);
        setPortal302(mHttpConnection);

        final NetworkMonitor nm = makeMonitor();
        final NetworkMonitor nm = makeMonitor(METERED_CAPABILITIES);
        nm.notifyNetworkConnected(TEST_LINK_PROPERTIES, METERED_CAPABILITIES);

        verify(mCallbacks, timeout(HANDLER_TIMEOUT_MS).times(1))
@@ -637,6 +707,63 @@ public class NetworkMonitorTest {
        assertEquals(0, mRegisteredReceivers.size());
    }

    @Test
    public void testPrivateDnsSuccess() throws Exception {
        setStatus(mHttpsConnection, 204);
        setStatus(mHttpConnection, 204);
        mFakeDns.setAnswer("dns.google", new String[]{"2001:db8::53"});

        WrappedNetworkMonitor wnm = makeNotMeteredNetworkMonitor();
        wnm.notifyPrivateDnsSettingsChanged(new PrivateDnsConfig("dns.google", new InetAddress[0]));
        wnm.notifyNetworkConnected(TEST_LINK_PROPERTIES, NOT_METERED_CAPABILITIES);
        verify(mCallbacks, timeout(HANDLER_TIMEOUT_MS).times(1))
                .notifyNetworkTested(eq(NETWORK_TEST_RESULT_VALID), eq(null));
    }

    @Test
    public void testPrivateDnsResolutionRetryUpdate() throws Exception {
        // Set a private DNS hostname that doesn't resolve and expect validation to fail.
        mFakeDns.setAnswer("dns.google", new String[0]);
        setStatus(mHttpsConnection, 204);
        setStatus(mHttpConnection, 204);

        WrappedNetworkMonitor wnm = makeNotMeteredNetworkMonitor();
        wnm.notifyPrivateDnsSettingsChanged(new PrivateDnsConfig("dns.google", new InetAddress[0]));
        wnm.notifyNetworkConnected(TEST_LINK_PROPERTIES, NOT_METERED_CAPABILITIES);
        verify(mCallbacks, timeout(HANDLER_TIMEOUT_MS).times(1))
                .notifyNetworkTested(eq(NETWORK_TEST_RESULT_INVALID), eq(null));

        // Fix DNS and retry, expect validation to succeed.
        reset(mCallbacks);
        mFakeDns.setAnswer("dns.google", new String[]{"2001:db8::1"});

        wnm.forceReevaluation(Process.myUid());
        verify(mCallbacks, timeout(HANDLER_TIMEOUT_MS).times(1))
                .notifyNetworkTested(eq(NETWORK_TEST_RESULT_VALID), eq(null));

        // Change configuration to an invalid DNS name, expect validation to fail.
        reset(mCallbacks);
        mFakeDns.setAnswer("dns.bad", new String[0]);
        wnm.notifyPrivateDnsSettingsChanged(new PrivateDnsConfig("dns.bad", new InetAddress[0]));
        verify(mCallbacks, timeout(HANDLER_TIMEOUT_MS).times(1))
                .notifyNetworkTested(eq(NETWORK_TEST_RESULT_INVALID), eq(null));

        // Change configuration back to working again, but make private DNS not work.
        // Expect validation to fail.
        reset(mCallbacks);
        mFakeDns.setNonBypassPrivateDnsWorking(false);
        wnm.notifyPrivateDnsSettingsChanged(new PrivateDnsConfig("dns.google", new InetAddress[0]));
        verify(mCallbacks, timeout(HANDLER_TIMEOUT_MS).times(1))
                .notifyNetworkTested(eq(NETWORK_TEST_RESULT_INVALID), eq(null));

        // Make private DNS work again. Expect validation to succeed.
        reset(mCallbacks);
        mFakeDns.setNonBypassPrivateDnsWorking(true);
        wnm.forceReevaluation(Process.myUid());
        verify(mCallbacks, timeout(HANDLER_TIMEOUT_MS).times(1))
                .notifyNetworkTested(eq(NETWORK_TEST_RESULT_VALID), eq(null));
    }

    @Test
    public void testDataStall_StallSuspectedAndSendMetrics() throws IOException {
        WrappedNetworkMonitor wrappedMonitor = makeNotMeteredNetworkMonitor();
@@ -728,25 +855,27 @@ public class NetworkMonitorTest {
        WrappedNetworkMonitor wnm = makeNotMeteredNetworkMonitor();
        final int shortTimeoutMs = 200;

        // Clear the wildcard DNS response created in setUp.
        mFakeDns.setAnswer("*", null);

        String[] expected = new String[]{"2001:db8::"};
        setDnsAnswers(expected);
        mFakeDns.setAnswer("www.google.com", expected);
        InetAddress[] actual = wnm.sendDnsProbeWithTimeout("www.google.com", shortTimeoutMs);
        assertIpAddressArrayEquals(expected, actual);

        expected = new String[]{"2001:db8::", "192.0.2.1"};
        setDnsAnswers(expected);
        actual = wnm.sendDnsProbeWithTimeout("www.google.com", shortTimeoutMs);
        mFakeDns.setAnswer("www.googleapis.com", expected);
        actual = wnm.sendDnsProbeWithTimeout("www.googleapis.com", shortTimeoutMs);
        assertIpAddressArrayEquals(expected, actual);

        expected = new String[0];
        setDnsAnswers(expected);
        mFakeDns.setAnswer("www.google.com", new String[0]);
        try {
            wnm.sendDnsProbeWithTimeout("www.google.com", shortTimeoutMs);
            fail("No DNS results, expected UnknownHostException");
        } catch (UnknownHostException e) {
        }

        setDnsAnswers(null);
        mFakeDns.setAnswer("www.google.com", null);
        try {
            wnm.sendDnsProbeWithTimeout("www.google.com", shortTimeoutMs);
            fail("DNS query timed out, expected UnknownHostException");
@@ -841,7 +970,7 @@ public class NetworkMonitorTest {
    }

    private NetworkMonitor runNetworkTest(NetworkCapabilities nc, int testResult) {
        final NetworkMonitor monitor = makeMonitor();
        final NetworkMonitor monitor = makeMonitor(nc);
        monitor.notifyNetworkConnected(TEST_LINK_PROPERTIES, nc);
        try {
            verify(mCallbacks, timeout(HANDLER_TIMEOUT_MS).times(1))