Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit d22de9f1 authored by Chiachang Wang's avatar Chiachang Wang Committed by Gerrit Code Review
Browse files

Merge "Clean up for async dns query in NetworkStack"

parents fc129991 ddb7da6c
Loading
Loading
Loading
Loading
+77 −22
Original line number Diff line number Diff line
@@ -21,21 +21,25 @@ import static android.net.DnsResolver.TYPE_A;
import static android.net.DnsResolver.TYPE_AAAA;

import android.annotation.NonNull;
import android.annotation.Nullable;
import android.net.DnsResolver;
import android.net.Network;
import android.net.TrafficStats;
import android.net.util.Stopwatch;
import android.util.Log;

import com.android.internal.util.TrafficStatsConstants;
import com.android.server.connectivity.NetworkMonitor.DnsLogFunc;

import java.net.InetAddress;
import java.net.UnknownHostException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import java.util.concurrent.TimeoutException;

/**
 * Collection of utilities for dns query.
@@ -44,6 +48,7 @@ public class DnsUtils {
    // Decide what queries to make depending on what IP addresses are on the system.
    public static final int TYPE_ADDRCONFIG = -1;
    private static final String TAG = DnsUtils.class.getSimpleName();
    private static final boolean DBG = Log.isLoggable(TAG, Log.DEBUG);

    /**
     * Return both A and AAAA query results regardless the ip address type of the giving network.
@@ -51,27 +56,33 @@ public class DnsUtils {
     */
    @NonNull
    public static InetAddress[] getAllByName(@NonNull final DnsResolver dnsResolver,
            @NonNull final Network network, @NonNull String host, int timeout)
            throws UnknownHostException {
            @NonNull final Network network, @NonNull String host, int timeout,
            @NonNull final DnsLogFunc logger) throws UnknownHostException {
        final List<InetAddress> result = new ArrayList<InetAddress>();
        final StringBuilder errorMsg = new StringBuilder(host);

        try {
            result.addAll(Arrays.asList(
                    getAllByName(dnsResolver, network, host, TYPE_AAAA, FLAG_NO_CACHE_LOOKUP,
                    timeout)));
                    timeout, logger)));
        } catch (UnknownHostException e) {
            // Might happen if the host is v4-only, still need to query TYPE_A
            errorMsg.append(String.format(" (%s)%s", dnsTypeToStr(TYPE_AAAA), e.getMessage()));
        }
        try {
            result.addAll(Arrays.asList(
                    getAllByName(dnsResolver, network, host, TYPE_A, FLAG_NO_CACHE_LOOKUP,
                    timeout)));
                    timeout, logger)));
        } catch (UnknownHostException e) {
            // Might happen if the host is v6-only, still need to return AAAA answers
            errorMsg.append(String.format(" (%s)%s", dnsTypeToStr(TYPE_A), e.getMessage()));
        }

        if (result.size() == 0) {
            throw new UnknownHostException(host);
            logger.log("FAIL: " + errorMsg.toString());
            throw new UnknownHostException(errorMsg.toString());
        }
        logger.log("OK: " + host + " " + result.toString());
        return result.toArray(new InetAddress[0]);
    }

@@ -82,26 +93,34 @@ public class DnsUtils {
    @NonNull
    public static InetAddress[] getAllByName(@NonNull final DnsResolver dnsResolver,
            @NonNull final Network network, @NonNull final String host, int type, int flag,
            int timeoutMs) throws UnknownHostException {
        final CountDownLatch latch = new CountDownLatch(1);
        final AtomicReference<List<InetAddress>> resultRef = new AtomicReference<>();
            int timeoutMs, @Nullable final DnsLogFunc logger) throws UnknownHostException {
        final CompletableFuture<List<InetAddress>> resultRef = new CompletableFuture<>();
        final Stopwatch watch = new Stopwatch().start();


        final DnsResolver.Callback<List<InetAddress>> callback =
                new DnsResolver.Callback<List<InetAddress>>()  {
            @Override
            public void onAnswer(List<InetAddress> answer, int rcode) {
                if (rcode == 0) {
                    resultRef.set(answer);
                if (rcode == 0 && answer != null && answer.size() != 0) {
                    resultRef.complete(answer);
                } else {
                    resultRef.completeExceptionally(new UnknownHostException());
                }
                latch.countDown();
            }

            @Override
            public void onError(@NonNull DnsResolver.DnsException e) {
                Log.d(TAG, "DNS error resolving " + host + ": " + e.getMessage());
                latch.countDown();
                if (DBG) {
                    Log.d(TAG, "DNS error resolving " + host, e);
                }
                resultRef.completeExceptionally(e);
            }
        };
        // TODO: Investigate whether this is still useful.
        // The packets that actually do the DNS queries are sent by netd, but netd doesn't
        // look at the tag at all. Given that this is a library, the tag should be passed in by the
        // caller.
        final int oldTag = TrafficStats.getAndSetThreadStatsTag(
                TrafficStatsConstants.TAG_SYSTEM_PROBE);

@@ -115,16 +134,52 @@ public class DnsUtils {

        TrafficStats.setThreadStatsTag(oldTag);

        List<InetAddress> result = null;
        Exception exception = null;
        try {
            latch.await(timeoutMs, TimeUnit.MILLISECONDS);
        } catch (InterruptedException e) {
            result = resultRef.get(timeoutMs, TimeUnit.MILLISECONDS);
        } catch (ExecutionException e) {
            exception = e;
        } catch (TimeoutException | InterruptedException e) {
            exception = new UnknownHostException("Timeout");
        } finally {
            logDnsResult(result, watch.stop() /* latency */, logger, type,
                    exception != null ? exception.getMessage() : "" /* errorMsg */);
        }

        final List<InetAddress> result = resultRef.get();
        if (result == null || result.size() == 0) {
            throw new UnknownHostException(host);
        }
        if (null != exception) throw (UnknownHostException) exception;

        return result.toArray(new InetAddress[0]);
    }

    private static void logDnsResult(@Nullable final List<InetAddress> results, final long latency,
            @Nullable final DnsLogFunc logger, int type, @NonNull final String errorMsg) {
        if (logger == null) {
            return;
        }

        if (results != null && results.size() != 0) {
            final StringBuilder builder = new StringBuilder();
            for (InetAddress address : results) {
                builder.append(',').append(address.getHostAddress());
            }
            logger.log(String.format("%dms OK %s", latency, builder.substring(1)));
        } else {
            logger.log(String.format("%dms FAIL in type %s %s", latency, dnsTypeToStr(type),
                    errorMsg));
        }
    }

    private static String dnsTypeToStr(int type) {
        switch (type) {
            case TYPE_A:
                return "A";
            case TYPE_AAAA:
                return "AAAA";
            case TYPE_ADDRCONFIG:
                return "ADDRCONFIG";
            default:
        }
        return "UNDEFINED";
    }
}
+14 −13
Original line number Diff line number Diff line
@@ -1044,12 +1044,11 @@ public class NetworkMonitor extends StateMachine {
            try {
                // Do a blocking DNS resolution using the network-assigned nameservers.
                final InetAddress[] ips = DnsUtils.getAllByName(mDependencies.getDnsResolver(),
                        mCleartextDnsNetwork, mPrivateDnsProviderHostname, getDnsProbeTimeout());
                        mCleartextDnsNetwork, mPrivateDnsProviderHostname, getDnsProbeTimeout(),
                        str -> validationLog("Strict mode hostname resolution " + str));
                mPrivateDnsConfig = new PrivateDnsConfig(mPrivateDnsProviderHostname, ips);
                validationLog("Strict mode hostname resolved: " + mPrivateDnsConfig);
            } catch (UnknownHostException uhe) {
                mPrivateDnsConfig = null;
                validationLog("Strict mode hostname resolution failed: " + uhe.getMessage());
            }
            mEvaluationState.noteProbeResult(NETWORK_VALIDATION_PROBE_PRIVDNS,
                    (mPrivateDnsConfig != null) /* succeeded */);
@@ -1153,7 +1152,6 @@ public class NetworkMonitor extends StateMachine {
                    } else if (probeResult.isPartialConnectivity()) {
                        mEvaluationState.reportEvaluationResult(NETWORK_VALIDATION_RESULT_PARTIAL,
                                null /* redirectUrl */);
                        // Check if disable https probing needed.
                        maybeDisableHttpsProbing(mAcceptPartialConnectivity);
                        if (mAcceptPartialConnectivity) {
                            transitionTo(mEvaluatingPrivateDnsState);
@@ -1557,7 +1555,8 @@ public class NetworkMonitor extends StateMachine {
    protected InetAddress[] sendDnsProbeWithTimeout(String host, int timeoutMs)
                throws UnknownHostException {
        return DnsUtils.getAllByName(mDependencies.getDnsResolver(), mCleartextDnsNetwork, host,
                TYPE_ADDRCONFIG, FLAG_EMPTY, timeoutMs);
                TYPE_ADDRCONFIG, FLAG_EMPTY, timeoutMs,
                str -> validationLog(ValidationProbeEvent.PROBE_DNS, host, str));
    }

    /** Do a DNS resolution of the given server. */
@@ -1572,19 +1571,11 @@ public class NetworkMonitor extends StateMachine {
        String connectInfo;
        try {
            InetAddress[] addresses = sendDnsProbeWithTimeout(host, getDnsProbeTimeout());
            StringBuffer buffer = new StringBuffer();
            for (InetAddress address : addresses) {
                buffer.append(',').append(address.getHostAddress());
            }
            result = ValidationProbeEvent.DNS_SUCCESS;
            connectInfo = "OK " + buffer.substring(1);
        } catch (UnknownHostException e) {
            result = ValidationProbeEvent.DNS_FAILURE;
            connectInfo = "FAIL";
        }
        final long latency = watch.stop();
        validationLog(ValidationProbeEvent.PROBE_DNS, host,
                String.format("%dms %s", latency, connectInfo));
        logValidationProbe(latency, ValidationProbeEvent.PROBE_DNS, result);
    }

@@ -2175,4 +2166,14 @@ public class NetworkMonitor extends StateMachine {
        }
        mEvaluationState.noteProbeResult(probeResult, succeeded);
    }

    /**
     * Interface for logging dns results.
     */
    public interface DnsLogFunc {
        /**
         * Log function.
         */
        void log(String s);
    }
}
+24 −26
Original line number Diff line number Diff line
@@ -103,7 +103,8 @@ import org.mockito.Captor;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.mockito.Spy;
import org.mockito.verification.VerificationWithTimeout;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;

import java.io.IOException;
import java.net.HttpURLConnection;
@@ -252,25 +253,23 @@ public class NetworkMonitorTest {

            // Queries on mCleartextDnsNetwork using DnsResolver#query.
            doAnswer(invocation -> {
                String hostname = (String) invocation.getArgument(1);
                Executor executor = (Executor) invocation.getArgument(3);
                DnsResolver.Callback<List<InetAddress>> callback = invocation.getArgument(5);

                List<InetAddress> answer = getAnswer(invocation.getMock(), hostname);
                if (answer != null && answer.size() > 0) {
                    new Handler(Looper.getMainLooper()).post(() -> {
                        executor.execute(() -> callback.onAnswer(answer, 0));
                    });
                }
                // If no answers, do nothing. sendDnsProbeWithTimeout will time out and throw UHE.
                return null;
                return mockQuery(invocation, 1 /* posHostname */, 3 /* posExecutor */,
                        5 /* posCallback */);
            }).when(mDnsResolver).query(any(), any(), anyInt(), any(), any(), any());

            // Queries on mCleartextDnsNetwork using using DnsResolver#query with QueryType.
            // Queries on mCleartextDnsNetwork using DnsResolver#query with QueryType.
            doAnswer(invocation -> {
                String hostname = (String) invocation.getArgument(1);
                Executor executor = (Executor) invocation.getArgument(4);
                DnsResolver.Callback<List<InetAddress>> callback = invocation.getArgument(6);
                return mockQuery(invocation, 1 /* posHostname */, 4 /* posExecutor */,
                        6 /* posCallback */);
            }).when(mDnsResolver).query(any(), any(), anyInt(), anyInt(), any(), any(), any());
        }

        // Mocking queries on DnsResolver#query.
        private Answer mockQuery(InvocationOnMock invocation, int posHostname, int posExecutor,
                int posCallback) {
            String hostname = (String) invocation.getArgument(posHostname);
            Executor executor = (Executor) invocation.getArgument(posExecutor);
            DnsResolver.Callback<List<InetAddress>> callback = invocation.getArgument(posCallback);

            List<InetAddress> answer = getAnswer(invocation.getMock(), hostname);
            if (answer != null && answer.size() > 0) {
@@ -280,7 +279,6 @@ public class NetworkMonitorTest {
            }
            // If no answers, do nothing. sendDnsProbeWithTimeout will time out and throw UHE.
            return null;
            }).when(mDnsResolver).query(any(), any(), anyInt(), anyInt(), any(), any(), any());
        }
    }