Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 2d5847b0 authored by Chiachang Wang's avatar Chiachang Wang Committed by Chalard Jean
Browse files

Use async dns query to resolve all addresses

Currently, it looks like private DNS server resolution uses
OneAddressPerFamilyNetwork and only returns one server address.
It should return all addresses. Use async dns api for this.

Bug: 123435238
Test: atest NetworkStacktests

Change-Id: I2b7e184d9b9800a83b55dceb73af69085668748c
(cherry picked from commit 40c5295c)
Merged-In: I2b7e184d9b9800a83b55dceb73af69085668748c
Merged-In: I9f50da3c8c2e3b12b29bc8844291e4bf1559cd1f
parent bdaa4de7
Loading
Loading
Loading
Loading
+119 −0
Original line number Original line Diff line number Diff line
/*
 * Copyright (C) 2019 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.android.networkstack.util;

import static android.net.DnsResolver.FLAG_NO_CACHE_LOOKUP;
import static android.net.DnsResolver.TYPE_A;
import static android.net.DnsResolver.TYPE_AAAA;

import android.annotation.NonNull;
import android.net.DnsResolver;
import android.net.Network;
import android.net.TrafficStats;
import android.util.Log;

import com.android.internal.util.TrafficStatsConstants;

import java.net.InetAddress;
import java.net.UnknownHostException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;

/**
 * Collection of utilities for dns query.
 */
public class DnsUtils {
    // Decide what queries to make depending on what IP addresses are on the system.
    public static final int TYPE_ADDRCONFIG = -1;
    private static final String TAG = DnsUtils.class.getSimpleName();

    /**
     * Return both A and AAAA query results regardless the ip address type of the giving network.
     * Used for probing in NetworkMonitor.
     */
    @NonNull
    public static InetAddress[] getAllByName(@NonNull final DnsResolver dnsResolver,
            @NonNull final Network network, @NonNull String host, int timeout)
            throws UnknownHostException {
        final List<InetAddress> result = new ArrayList<InetAddress>();

        result.addAll(Arrays.asList(
                getAllByName(dnsResolver, network, host, TYPE_AAAA, FLAG_NO_CACHE_LOOKUP,
                timeout)));
        result.addAll(Arrays.asList(
                getAllByName(dnsResolver, network, host, TYPE_A, FLAG_NO_CACHE_LOOKUP,
                timeout)));
        return result.toArray(new InetAddress[0]);
    }

    /**
     * Return dns query result based on the given QueryType(TYPE_A, TYPE_AAAA) or TYPE_ADDRCONFIG.
     * Used for probing in NetworkMonitor.
     */
    @NonNull
    public static InetAddress[] getAllByName(@NonNull final DnsResolver dnsResolver,
            @NonNull final Network network, @NonNull final String host, int type, int flag,
            int timeoutMs) throws UnknownHostException {
        final CountDownLatch latch = new CountDownLatch(1);
        final AtomicReference<List<InetAddress>> resultRef = new AtomicReference<>();

        final DnsResolver.Callback<List<InetAddress>> callback =
                new DnsResolver.Callback<List<InetAddress>>() {
            @Override
            public void onAnswer(List<InetAddress> answer, int rcode) {
                if (rcode == 0) {
                    resultRef.set(answer);
                }
                latch.countDown();
            }

            @Override
            public void onError(@NonNull DnsResolver.DnsException e) {
                Log.d(TAG, "DNS error resolving " + host + ": " + e.getMessage());
                latch.countDown();
            }
        };
        final int oldTag = TrafficStats.getAndSetThreadStatsTag(
                TrafficStatsConstants.TAG_SYSTEM_PROBE);

        if (type == TYPE_ADDRCONFIG) {
            dnsResolver.query(network, host, flag, r -> r.run(), null /* cancellationSignal */,
                    callback);
        } else {
            dnsResolver.query(network, host, type, flag, r -> r.run(),
                    null /* cancellationSignal */, callback);
        }

        TrafficStats.setThreadStatsTag(oldTag);

        try {
            latch.await(timeoutMs, TimeUnit.MILLISECONDS);
        } catch (InterruptedException e) {
        }

        final List<InetAddress> result = resultRef.get();
        if (result == null || result.size() == 0) {
            throw new UnknownHostException(host);
        }

        return result.toArray(new InetAddress[0]);
    }
}
+8 −36
Original line number Original line Diff line number Diff line
@@ -23,6 +23,7 @@ import static android.net.ConnectivityManager.EXTRA_CAPTIVE_PORTAL_PROBE_SPEC;
import static android.net.ConnectivityManager.EXTRA_CAPTIVE_PORTAL_URL;
import static android.net.ConnectivityManager.EXTRA_CAPTIVE_PORTAL_URL;
import static android.net.ConnectivityManager.TYPE_MOBILE;
import static android.net.ConnectivityManager.TYPE_MOBILE;
import static android.net.ConnectivityManager.TYPE_WIFI;
import static android.net.ConnectivityManager.TYPE_WIFI;
import static android.net.DnsResolver.FLAG_EMPTY;
import static android.net.INetworkMonitor.NETWORK_TEST_RESULT_INVALID;
import static android.net.INetworkMonitor.NETWORK_TEST_RESULT_INVALID;
import static android.net.INetworkMonitor.NETWORK_TEST_RESULT_PARTIAL_CONNECTIVITY;
import static android.net.INetworkMonitor.NETWORK_TEST_RESULT_PARTIAL_CONNECTIVITY;
import static android.net.NetworkCapabilities.NET_CAPABILITY_NOT_METERED;
import static android.net.NetworkCapabilities.NET_CAPABILITY_NOT_METERED;
@@ -56,6 +57,8 @@ import static android.net.util.NetworkStackUtils.CAPTIVE_PORTAL_USE_HTTPS;
import static android.net.util.NetworkStackUtils.NAMESPACE_CONNECTIVITY;
import static android.net.util.NetworkStackUtils.NAMESPACE_CONNECTIVITY;
import static android.net.util.NetworkStackUtils.isEmpty;
import static android.net.util.NetworkStackUtils.isEmpty;


import static com.android.networkstack.util.DnsUtils.TYPE_ADDRCONFIG;

import android.annotation.NonNull;
import android.annotation.NonNull;
import android.annotation.Nullable;
import android.annotation.Nullable;
import android.app.PendingIntent;
import android.app.PendingIntent;
@@ -113,6 +116,7 @@ import com.android.internal.util.TrafficStatsConstants;
import com.android.networkstack.R;
import com.android.networkstack.R;
import com.android.networkstack.metrics.DataStallDetectionStats;
import com.android.networkstack.metrics.DataStallDetectionStats;
import com.android.networkstack.metrics.DataStallStatsUtils;
import com.android.networkstack.metrics.DataStallStatsUtils;
import com.android.networkstack.util.DnsUtils;


import java.io.IOException;
import java.io.IOException;
import java.net.HttpURLConnection;
import java.net.HttpURLConnection;
@@ -129,7 +133,6 @@ import java.util.Random;
import java.util.UUID;
import java.util.UUID;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;
import java.util.function.Function;


/**
/**
@@ -994,8 +997,8 @@ public class NetworkMonitor extends StateMachine {
        private void resolveStrictModeHostname() {
        private void resolveStrictModeHostname() {
            try {
            try {
                // Do a blocking DNS resolution using the network-assigned nameservers.
                // Do a blocking DNS resolution using the network-assigned nameservers.
                final InetAddress[] ips = mCleartextDnsNetwork.getAllByName(
                final InetAddress[] ips = DnsUtils.getAllByName(mDependencies.getDnsResolver(),
                        mPrivateDnsProviderHostname);
                        mCleartextDnsNetwork, mPrivateDnsProviderHostname, getDnsProbeTimeout());
                mPrivateDnsConfig = new PrivateDnsConfig(mPrivateDnsProviderHostname, ips);
                mPrivateDnsConfig = new PrivateDnsConfig(mPrivateDnsProviderHostname, ips);
                validationLog("Strict mode hostname resolved: " + mPrivateDnsConfig);
                validationLog("Strict mode hostname resolved: " + mPrivateDnsConfig);
            } catch (UnknownHostException uhe) {
            } catch (UnknownHostException uhe) {
@@ -1489,39 +1492,8 @@ public class NetworkMonitor extends StateMachine {
    @VisibleForTesting
    @VisibleForTesting
    protected InetAddress[] sendDnsProbeWithTimeout(String host, int timeoutMs)
    protected InetAddress[] sendDnsProbeWithTimeout(String host, int timeoutMs)
                throws UnknownHostException {
                throws UnknownHostException {
        final CountDownLatch latch = new CountDownLatch(1);
        return DnsUtils.getAllByName(mDependencies.getDnsResolver(), mCleartextDnsNetwork, host,
        final AtomicReference<List<InetAddress>> resultRef = new AtomicReference<>();
                TYPE_ADDRCONFIG, FLAG_EMPTY, timeoutMs);
        final DnsResolver.Callback<List<InetAddress>> callback =
                    new DnsResolver.Callback<List<InetAddress>>() {
            public void onAnswer(List<InetAddress> answer, int rcode) {
                if (rcode == 0) {
                    resultRef.set(answer);
                }
                latch.countDown();
            }
            public void onError(@NonNull DnsResolver.DnsException e) {
                validationLog("DNS error resolving " + host + ": " + e.getMessage());
                latch.countDown();
            }
        };

        final int oldTag = TrafficStats.getAndSetThreadStatsTag(
                TrafficStatsConstants.TAG_SYSTEM_PROBE);
        mDependencies.getDnsResolver().query(mCleartextDnsNetwork, host, DnsResolver.FLAG_EMPTY,
                r -> r.run() /* executor */, null /* cancellationSignal */, callback);
        TrafficStats.setThreadStatsTag(oldTag);

        try {
            latch.await(timeoutMs, TimeUnit.MILLISECONDS);
        } catch (InterruptedException e) {
        }

        List<InetAddress> result = resultRef.get();
        if (result == null || result.size() == 0) {
            throw new UnknownHostException(host);
        }

        return result.toArray(new InetAddress[0]);
    }
    }


    /** Do a DNS resolution of the given server. */
    /** Do a DNS resolution of the given server. */
+16 −5
Original line number Original line Diff line number Diff line
@@ -226,11 +226,6 @@ public class NetworkMonitorTest {


        /** Starts mocking DNS queries. */
        /** Starts mocking DNS queries. */
        private void startMocking() throws UnknownHostException {
        private void startMocking() throws UnknownHostException {
            // Queries on mCleartextDnsNetwork using getAllByName.
            doAnswer(invocation -> {
                return getAllByName(invocation.getMock(), invocation.getArgument(0));
            }).when(mCleartextDnsNetwork).getAllByName(any());

            // Queries on mNetwork using getAllByName.
            // Queries on mNetwork using getAllByName.
            doAnswer(invocation -> {
            doAnswer(invocation -> {
                return getAllByName(invocation.getMock(), invocation.getArgument(0));
                return getAllByName(invocation.getMock(), invocation.getArgument(0));
@@ -251,6 +246,22 @@ public class NetworkMonitorTest {
                // If no answers, do nothing. sendDnsProbeWithTimeout will time out and throw UHE.
                // If no answers, do nothing. sendDnsProbeWithTimeout will time out and throw UHE.
                return null;
                return null;
            }).when(mDnsResolver).query(any(), any(), anyInt(), any(), any(), any());
            }).when(mDnsResolver).query(any(), any(), anyInt(), any(), any(), any());

            // Queries on mCleartextDnsNetwork using using DnsResolver#query with QueryType.
            doAnswer(invocation -> {
                String hostname = (String) invocation.getArgument(1);
                Executor executor = (Executor) invocation.getArgument(4);
                DnsResolver.Callback<List<InetAddress>> callback = invocation.getArgument(6);

                List<InetAddress> answer = getAnswer(invocation.getMock(), hostname);
                if (answer != null && answer.size() > 0) {
                    new Handler(Looper.getMainLooper()).post(() -> {
                        executor.execute(() -> callback.onAnswer(answer, 0));
                    });
                }
                // If no answers, do nothing. sendDnsProbeWithTimeout will time out and throw UHE.
                return null;
            }).when(mDnsResolver).query(any(), any(), anyInt(), anyInt(), any(), any(), any());
        }
        }
    }
    }