Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit ddb7da6c authored by Chiachang Wang's avatar Chiachang Wang
Browse files

Clean up for async dns query in NetworkStack

This is a follow-up commit for aosp/944772 to
1. Update error message for UNE
2. Refactor identical code
3. Update comments

Bug: 123435238
Test: atest com.android.server.connectivity.NetworkMonitorTest
Change-Id: I11d5013497352cd32ff43fbdd88e39d12835277c
parent 6a842931
Loading
Loading
Loading
Loading
+77 −22
Original line number Diff line number Diff line
@@ -21,21 +21,25 @@ import static android.net.DnsResolver.TYPE_A;
import static android.net.DnsResolver.TYPE_AAAA;

import android.annotation.NonNull;
import android.annotation.Nullable;
import android.net.DnsResolver;
import android.net.Network;
import android.net.TrafficStats;
import android.net.util.Stopwatch;
import android.util.Log;

import com.android.internal.util.TrafficStatsConstants;
import com.android.server.connectivity.NetworkMonitor.DnsLogFunc;

import java.net.InetAddress;
import java.net.UnknownHostException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import java.util.concurrent.TimeoutException;

/**
 * Collection of utilities for dns query.
@@ -44,6 +48,7 @@ public class DnsUtils {
    // Decide what queries to make depending on what IP addresses are on the system.
    public static final int TYPE_ADDRCONFIG = -1;
    private static final String TAG = DnsUtils.class.getSimpleName();
    private static final boolean DBG = Log.isLoggable(TAG, Log.DEBUG);

    /**
     * Return both A and AAAA query results regardless the ip address type of the giving network.
@@ -51,27 +56,33 @@ public class DnsUtils {
     */
    @NonNull
    public static InetAddress[] getAllByName(@NonNull final DnsResolver dnsResolver,
            @NonNull final Network network, @NonNull String host, int timeout)
            throws UnknownHostException {
            @NonNull final Network network, @NonNull String host, int timeout,
            @NonNull final DnsLogFunc logger) throws UnknownHostException {
        final List<InetAddress> result = new ArrayList<InetAddress>();
        final StringBuilder errorMsg = new StringBuilder(host);

        try {
            result.addAll(Arrays.asList(
                    getAllByName(dnsResolver, network, host, TYPE_AAAA, FLAG_NO_CACHE_LOOKUP,
                    timeout)));
                    timeout, logger)));
        } catch (UnknownHostException e) {
            // Might happen if the host is v4-only, still need to query TYPE_A
            errorMsg.append(String.format(" (%s)%s", dnsTypeToStr(TYPE_AAAA), e.getMessage()));
        }
        try {
            result.addAll(Arrays.asList(
                    getAllByName(dnsResolver, network, host, TYPE_A, FLAG_NO_CACHE_LOOKUP,
                    timeout)));
                    timeout, logger)));
        } catch (UnknownHostException e) {
            // Might happen if the host is v6-only, still need to return AAAA answers
            errorMsg.append(String.format(" (%s)%s", dnsTypeToStr(TYPE_A), e.getMessage()));
        }

        if (result.size() == 0) {
            throw new UnknownHostException(host);
            logger.log("FAIL: " + errorMsg.toString());
            throw new UnknownHostException(errorMsg.toString());
        }
        logger.log("OK: " + host + " " + result.toString());
        return result.toArray(new InetAddress[0]);
    }

@@ -82,26 +93,34 @@ public class DnsUtils {
    @NonNull
    public static InetAddress[] getAllByName(@NonNull final DnsResolver dnsResolver,
            @NonNull final Network network, @NonNull final String host, int type, int flag,
            int timeoutMs) throws UnknownHostException {
        final CountDownLatch latch = new CountDownLatch(1);
        final AtomicReference<List<InetAddress>> resultRef = new AtomicReference<>();
            int timeoutMs, @Nullable final DnsLogFunc logger) throws UnknownHostException {
        final CompletableFuture<List<InetAddress>> resultRef = new CompletableFuture<>();
        final Stopwatch watch = new Stopwatch().start();


        final DnsResolver.Callback<List<InetAddress>> callback =
                new DnsResolver.Callback<List<InetAddress>>()  {
            @Override
            public void onAnswer(List<InetAddress> answer, int rcode) {
                if (rcode == 0) {
                    resultRef.set(answer);
                if (rcode == 0 && answer != null && answer.size() != 0) {
                    resultRef.complete(answer);
                } else {
                    resultRef.completeExceptionally(new UnknownHostException());
                }
                latch.countDown();
            }

            @Override
            public void onError(@NonNull DnsResolver.DnsException e) {
                Log.d(TAG, "DNS error resolving " + host + ": " + e.getMessage());
                latch.countDown();
                if (DBG) {
                    Log.d(TAG, "DNS error resolving " + host, e);
                }
                resultRef.completeExceptionally(e);
            }
        };
        // TODO: Investigate whether this is still useful.
        // The packets that actually do the DNS queries are sent by netd, but netd doesn't
        // look at the tag at all. Given that this is a library, the tag should be passed in by the
        // caller.
        final int oldTag = TrafficStats.getAndSetThreadStatsTag(
                TrafficStatsConstants.TAG_SYSTEM_PROBE);

@@ -115,16 +134,52 @@ public class DnsUtils {

        TrafficStats.setThreadStatsTag(oldTag);

        List<InetAddress> result = null;
        Exception exception = null;
        try {
            latch.await(timeoutMs, TimeUnit.MILLISECONDS);
        } catch (InterruptedException e) {
            result = resultRef.get(timeoutMs, TimeUnit.MILLISECONDS);
        } catch (ExecutionException e) {
            exception = e;
        } catch (TimeoutException | InterruptedException e) {
            exception = new UnknownHostException("Timeout");
        } finally {
            logDnsResult(result, watch.stop() /* latency */, logger, type,
                    exception != null ? exception.getMessage() : "" /* errorMsg */);
        }

        final List<InetAddress> result = resultRef.get();
        if (result == null || result.size() == 0) {
            throw new UnknownHostException(host);
        }
        if (null != exception) throw (UnknownHostException) exception;

        return result.toArray(new InetAddress[0]);
    }

    private static void logDnsResult(@Nullable final List<InetAddress> results, final long latency,
            @Nullable final DnsLogFunc logger, int type, @NonNull final String errorMsg) {
        if (logger == null) {
            return;
        }

        if (results != null && results.size() != 0) {
            final StringBuilder builder = new StringBuilder();
            for (InetAddress address : results) {
                builder.append(',').append(address.getHostAddress());
            }
            logger.log(String.format("%dms OK %s", latency, builder.substring(1)));
        } else {
            logger.log(String.format("%dms FAIL in type %s %s", latency, dnsTypeToStr(type),
                    errorMsg));
        }
    }

    private static String dnsTypeToStr(int type) {
        switch (type) {
            case TYPE_A:
                return "A";
            case TYPE_AAAA:
                return "AAAA";
            case TYPE_ADDRCONFIG:
                return "ADDRCONFIG";
            default:
        }
        return "UNDEFINED";
    }
}
+14 −13
Original line number Diff line number Diff line
@@ -1044,12 +1044,11 @@ public class NetworkMonitor extends StateMachine {
            try {
                // Do a blocking DNS resolution using the network-assigned nameservers.
                final InetAddress[] ips = DnsUtils.getAllByName(mDependencies.getDnsResolver(),
                        mCleartextDnsNetwork, mPrivateDnsProviderHostname, getDnsProbeTimeout());
                        mCleartextDnsNetwork, mPrivateDnsProviderHostname, getDnsProbeTimeout(),
                        str -> validationLog("Strict mode hostname resolution " + str));
                mPrivateDnsConfig = new PrivateDnsConfig(mPrivateDnsProviderHostname, ips);
                validationLog("Strict mode hostname resolved: " + mPrivateDnsConfig);
            } catch (UnknownHostException uhe) {
                mPrivateDnsConfig = null;
                validationLog("Strict mode hostname resolution failed: " + uhe.getMessage());
            }
            mEvaluationState.noteProbeResult(NETWORK_VALIDATION_PROBE_PRIVDNS,
                    (mPrivateDnsConfig != null) /* succeeded */);
@@ -1153,7 +1152,6 @@ public class NetworkMonitor extends StateMachine {
                    } else if (probeResult.isPartialConnectivity()) {
                        mEvaluationState.reportEvaluationResult(NETWORK_VALIDATION_RESULT_PARTIAL,
                                null /* redirectUrl */);
                        // Check if disable https probing needed.
                        maybeDisableHttpsProbing(mAcceptPartialConnectivity);
                        if (mAcceptPartialConnectivity) {
                            transitionTo(mEvaluatingPrivateDnsState);
@@ -1557,7 +1555,8 @@ public class NetworkMonitor extends StateMachine {
    protected InetAddress[] sendDnsProbeWithTimeout(String host, int timeoutMs)
                throws UnknownHostException {
        return DnsUtils.getAllByName(mDependencies.getDnsResolver(), mCleartextDnsNetwork, host,
                TYPE_ADDRCONFIG, FLAG_EMPTY, timeoutMs);
                TYPE_ADDRCONFIG, FLAG_EMPTY, timeoutMs,
                str -> validationLog(ValidationProbeEvent.PROBE_DNS, host, str));
    }

    /** Do a DNS resolution of the given server. */
@@ -1572,19 +1571,11 @@ public class NetworkMonitor extends StateMachine {
        String connectInfo;
        try {
            InetAddress[] addresses = sendDnsProbeWithTimeout(host, getDnsProbeTimeout());
            StringBuffer buffer = new StringBuffer();
            for (InetAddress address : addresses) {
                buffer.append(',').append(address.getHostAddress());
            }
            result = ValidationProbeEvent.DNS_SUCCESS;
            connectInfo = "OK " + buffer.substring(1);
        } catch (UnknownHostException e) {
            result = ValidationProbeEvent.DNS_FAILURE;
            connectInfo = "FAIL";
        }
        final long latency = watch.stop();
        validationLog(ValidationProbeEvent.PROBE_DNS, host,
                String.format("%dms %s", latency, connectInfo));
        logValidationProbe(latency, ValidationProbeEvent.PROBE_DNS, result);
    }

@@ -2175,4 +2166,14 @@ public class NetworkMonitor extends StateMachine {
        }
        mEvaluationState.noteProbeResult(probeResult, succeeded);
    }

    /**
     * Interface for logging dns results.
     */
    public interface DnsLogFunc {
        /**
         * Log function.
         */
        void log(String s);
    }
}
+24 −26
Original line number Diff line number Diff line
@@ -103,7 +103,8 @@ import org.mockito.Captor;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.mockito.Spy;
import org.mockito.verification.VerificationWithTimeout;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;

import java.io.IOException;
import java.net.HttpURLConnection;
@@ -252,25 +253,23 @@ public class NetworkMonitorTest {

            // Queries on mCleartextDnsNetwork using DnsResolver#query.
            doAnswer(invocation -> {
                String hostname = (String) invocation.getArgument(1);
                Executor executor = (Executor) invocation.getArgument(3);
                DnsResolver.Callback<List<InetAddress>> callback = invocation.getArgument(5);

                List<InetAddress> answer = getAnswer(invocation.getMock(), hostname);
                if (answer != null && answer.size() > 0) {
                    new Handler(Looper.getMainLooper()).post(() -> {
                        executor.execute(() -> callback.onAnswer(answer, 0));
                    });
                }
                // If no answers, do nothing. sendDnsProbeWithTimeout will time out and throw UHE.
                return null;
                return mockQuery(invocation, 1 /* posHostname */, 3 /* posExecutor */,
                        5 /* posCallback */);
            }).when(mDnsResolver).query(any(), any(), anyInt(), any(), any(), any());

            // Queries on mCleartextDnsNetwork using using DnsResolver#query with QueryType.
            // Queries on mCleartextDnsNetwork using DnsResolver#query with QueryType.
            doAnswer(invocation -> {
                String hostname = (String) invocation.getArgument(1);
                Executor executor = (Executor) invocation.getArgument(4);
                DnsResolver.Callback<List<InetAddress>> callback = invocation.getArgument(6);
                return mockQuery(invocation, 1 /* posHostname */, 4 /* posExecutor */,
                        6 /* posCallback */);
            }).when(mDnsResolver).query(any(), any(), anyInt(), anyInt(), any(), any(), any());
        }

        // Mocking queries on DnsResolver#query.
        private Answer mockQuery(InvocationOnMock invocation, int posHostname, int posExecutor,
                int posCallback) {
            String hostname = (String) invocation.getArgument(posHostname);
            Executor executor = (Executor) invocation.getArgument(posExecutor);
            DnsResolver.Callback<List<InetAddress>> callback = invocation.getArgument(posCallback);

            List<InetAddress> answer = getAnswer(invocation.getMock(), hostname);
            if (answer != null && answer.size() > 0) {
@@ -280,7 +279,6 @@ public class NetworkMonitorTest {
            }
            // If no answers, do nothing. sendDnsProbeWithTimeout will time out and throw UHE.
            return null;
            }).when(mDnsResolver).query(any(), any(), anyInt(), anyInt(), any(), any(), any());
        }
    }