Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 59cd583e authored by Mike Yu's avatar Mike Yu
Browse files

Implement DNS probe in DoT validation

The purpose of this change is to avoid using DoT servers if they
are much slower than DNS servers. The mechanism is flag-guarded,
and it runs as part of DoT validation.

The mechanism works as follows:
1. Make use of the original DoT query to establish the connection.
2. Use the same DNS packet to issue a DoT query in parallel with a
   UDP query to the same DoT server.
3. If the UDP query fails or is lost, issue another one.
4. Compare the latencies of both queries and decide whether DoT
   validation can pass.

DoT validation passes if dot_latency is less than (a * udp_latency + b),
where a and b are configurable by dot_validation_latency_factor and
dot_validation_latency_offset_ms respectively.

Original change: https://android-review.googlesource.com/c/platform/packages/modules/DnsResolver/+/1733919

Bug: 188153519
Test: run resolv_integration_test twice
Test: run atest with all the flags off/on
	dot_validation_latency_factor: -1 / 3
	dot_validation_latency_offset_ms: -1 / 100
        sort_nameservers: 0 / 1
        dot_xport_unusable_threshold: -1 / 20
        dot_query_timeout_ms: -1 / 10000
        keep_listening_udp: 0 / 1
        parallel_lookup_sleep_time: 2 / 2
        dot_revalidation_threshold: -1 / 10
        dot_async_handshake: 0 / 1
        dot_maxtries: 3 / 1
        dot_connect_timeout_ms: 127000 / 10000
        parallel_lookup_release: UNSET / UNSET

Change-Id: I8507c409b0cb6e48655d54611256917392db69ac
Merged-In: I8507c409b0cb6e48655d54611256917392db69ac
parent b46f8fae
Loading
Loading
Loading
Loading
+189 −43
Original line number Diff line number Diff line
@@ -18,14 +18,24 @@

#include "DnsTlsTransport.h"

#include <span>

#include <android-base/format.h>
#include <android-base/logging.h>
#include <android-base/result.h>
#include <android-base/stringprintf.h>
#include <arpa/inet.h>
#include <arpa/nameser.h>
#include <netdutils/Stopwatch.h>
#include <netdutils/ThreadUtil.h>
#include <private/android_filesystem_config.h>  // AID_DNS
#include <sys/poll.h>

#include "DnsTlsSocketFactory.h"
#include "Experiments.h"
#include "IDnsTlsSocketFactory.h"
#include "resolv_private.h"
#include "util.h"

using android::base::StringPrintf;
using android::netdutils::setThreadName;
@@ -33,6 +43,113 @@ using android::netdutils::setThreadName;
namespace android {
namespace net {

namespace {

// Builds a raw DNS query packet for the hostname "<random>-dnsotls-ds.metric.gstatic.com",
// where <random> is six random alphanumeric characters. The query asks for an AAAA record
// in class IN with the recursion-desired (RD) flag set; the 16-bit query ID is random too.
std::vector<uint8_t> makeDnsQuery() {
    static const char kDnsSafeChars[] =
            "abcdefghijklmnopqrstuvwxyz"
            "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
            "0123456789";
    // std::size() on a string literal counts its trailing NUL. Exclude it, otherwise a random
    // label character could come out as '\0'.
    static constexpr size_t kNumDnsSafeChars = std::size(kDnsSafeChars) - 1;
    const auto c = [](uint8_t rnd) -> uint8_t { return kDnsSafeChars[rnd % kNumDnsSafeChars]; };

    // rnd[0..5] seed the random label characters; rnd[6..7] become the query ID.
    uint8_t rnd[8];
    arc4random_buf(rnd, std::size(rnd));

    return std::vector<uint8_t>{
            rnd[6], rnd[7],  // [0-1]   query ID
            1,      0,       // [2-3]   flags; query[2] = 1 for recursion desired (RD).
            0,      1,       // [4-5]   QDCOUNT (number of queries)
            0,      0,       // [6-7]   ANCOUNT (number of answers)
            0,      0,       // [8-9]   NSCOUNT (number of name server records)
            0,      0,       // [10-11] ARCOUNT (number of additional records)
            // First label: 6 random characters + "-dnsotls-ds" = 17 bytes.
            17,     c(rnd[0]), c(rnd[1]), c(rnd[2]), c(rnd[3]), c(rnd[4]), c(rnd[5]), '-', 'd', 'n',
            's',    'o',       't',       'l',       's',       '-',       'd',       's', 6,   'm',
            'e',    't',       'r',       'i',       'c',       7,         'g',       's', 't', 'a',
            't',    'i',       'c',       3,         'c',       'o',       'm',
            0,                  // null terminator of FQDN (root TLD)
            0,      ns_t_aaaa,  // QTYPE
            0,      ns_c_in     // QCLASS
    };
}

// Runs a minimal sanity check on a raw DNS response in |answer|: it must be at least a full
// DNS header (NS_HFIXEDSZ bytes) long and must carry exactly one question. Returns an error
// describing the first failed check, or success otherwise.
base::Result<void> checkDnsResponse(const std::span<const uint8_t> answer) {
    // Reads the big-endian 16-bit header field that starts at |offset|.
    const auto headerField = [&answer](size_t offset) -> int {
        return (answer[offset] << 8) | answer[offset + 1];
    };

    if (answer.size() < NS_HFIXEDSZ) {
        return Errorf("short response: {}", answer.size());
    }

    // QDCOUNT lives in bytes [4-5] of the header.
    if (const int qdcount = headerField(4); qdcount != 1) {
        return Errorf("reply query count != 1: {}", qdcount);
    }

    // ANCOUNT lives in bytes [6-7]; it is only logged, not validated.
    const int ancount = headerField(6);
    LOG(DEBUG) << "answer count: " << ancount;

    // TODO: Further validate the response contents (check for valid AAAA record, ...).
    // Note that currently, integration tests rely on this function accepting a
    // response with zero records.

    return {};
}

// Sends |query| over UDP to port 53 of |ip|, using a socket marked with |mark|, then waits
// for a single datagram and runs checkDnsResponse() on it. The response is not returned to the
// caller: the result only tells whether an acceptable response was received. Any socket error,
// a poll timeout, or a response rejected by checkDnsResponse() yields an error.
base::Result<void> sendUdpQuery(netdutils::IPAddress ip, uint32_t mark,
                                std::span<const uint8_t> query) {
    const sockaddr_storage ss = netdutils::IPSockAddr(ip, 53);
    const sockaddr* nsap = reinterpret_cast<const sockaddr*>(&ss);
    const int nsaplen = sockaddrSize(nsap);
    const int sockType = SOCK_DGRAM | SOCK_NONBLOCK | SOCK_CLOEXEC;
    android::base::unique_fd fd{socket(nsap->sa_family, sockType, 0)};
    if (fd.get() < 0) {
        return ErrnoErrorf("socket failed");
    }

    resolv_tag_socket(fd.get(), AID_DNS, NET_CONTEXT_INVALID_PID);
    if (setsockopt(fd.get(), SOL_SOCKET, SO_MARK, &mark, sizeof(mark)) < 0) {
        return ErrnoErrorf("setsockopt failed");
    }

    if (connect(fd.get(), nsap, static_cast<socklen_t>(nsaplen)) < 0) {
        return ErrnoErrorf("connect failed");
    }

    // Keep the result signed so that -1 is detected explicitly rather than through an
    // implicit signed-to-unsigned conversion.
    const ssize_t sent = send(fd.get(), query.data(), query.size(), 0);
    if (sent < 0 || static_cast<size_t>(sent) != query.size()) {
        return ErrnoErrorf("send failed");
    }

    const int timeoutMs = 3000;
    pollfd fds = {.fd = fd.get(), .events = POLLIN};

    const int n = TEMP_FAILURE_RETRY(poll(&fds, 1, timeoutMs));
    if (n == 0) {
        return Errorf("poll timed out");
    }
    if (n < 0) {
        return ErrnoErrorf("poll failed");
    }
    if ((fds.revents & (POLLIN | POLLERR)) == 0) {
        // Only POLLHUP/POLLNVAL can get here. Polling again would return immediately with the
        // same flags, so bail out instead of spinning.
        return Errorf("poll returned unexpected events: {}", static_cast<int>(fds.revents));
    }

    // On POLLERR, recv() fails and reports the pending socket error through errno.
    std::vector<uint8_t> buf(MAXPACKET);
    const ssize_t resplen = recv(fd.get(), buf.data(), buf.size(), 0);
    if (resplen < 0) {
        return ErrnoErrorf("recvfrom failed");
    }

    buf.resize(static_cast<size_t>(resplen));
    if (auto result = checkDnsResponse(buf); !result.ok()) {
        return Errorf("checkDnsResponse failed: {}", result.error().message());
    }

    return {};
}

}  // namespace

std::future<DnsTlsTransport::Result> DnsTlsTransport::query(const netdutils::Slice query) {
    std::lock_guard guard(mLock);

@@ -160,65 +277,94 @@ DnsTlsTransport::~DnsTlsTransport() {
// That may require moving it to DnsTlsDispatcher.
bool DnsTlsTransport::validate(const DnsTlsServer& server, uint32_t mark) {
    LOG(DEBUG) << "Beginning validation with mark " << std::hex << mark;
    // Generate "<random>-dnsotls-ds.metric.gstatic.com", which we will lookup through |ss| in
    // order to prove that it is actually a working DNS over TLS server.
    static const char kDnsSafeChars[] =
            "abcdefhijklmnopqrstuvwxyz"
            "ABCDEFHIJKLMNOPQRSTUVWXYZ"
            "0123456789";
    const auto c = [](uint8_t rnd) -> uint8_t {
        return kDnsSafeChars[(rnd % std::size(kDnsSafeChars))];
    };
    uint8_t rnd[8];
    arc4random_buf(rnd, std::size(rnd));
    // We could try to use res_mkquery() here, but it's basically the same.
    uint8_t query[] = {
        rnd[6], rnd[7],  // [0-1]   query ID
        1, 0,  // [2-3]   flags; query[2] = 1 for recursion desired (RD).
        0, 1,  // [4-5]   QDCOUNT (number of queries)
        0, 0,  // [6-7]   ANCOUNT (number of answers)
        0, 0,  // [8-9]   NSCOUNT (number of name server records)
        0, 0,  // [10-11] ARCOUNT (number of additional records)
        17, c(rnd[0]), c(rnd[1]), c(rnd[2]), c(rnd[3]), c(rnd[4]), c(rnd[5]),
            '-', 'd', 'n', 's', 'o', 't', 'l', 's', '-', 'd', 's',
        6, 'm', 'e', 't', 'r', 'i', 'c',
        7, 'g', 's', 't', 'a', 't', 'i', 'c',
        3, 'c', 'o', 'm',
        0,  // null terminator of FQDN (root TLD)
        0, ns_t_aaaa,  // QTYPE
        0, ns_c_in     // QCLASS
    };
    const int qlen = std::size(query);

    int replylen = 0;
    const std::vector<uint8_t> query = makeDnsQuery();
    DnsTlsSocketFactory factory;
    DnsTlsTransport transport(server, mark, &factory);
    auto r = transport.query(netdutils::Slice(query, qlen)).get();

    // Send the initial query to warm up the connection.
    auto r = transport.query(netdutils::makeSlice(query)).get();
    if (r.code != Response::success) {
        LOG(WARNING) << "query failed";
        return false;
    }

    const std::vector<uint8_t>& recvbuf = r.response;
    if (recvbuf.size() < NS_HFIXEDSZ) {
        LOG(WARNING) << "short response: " << replylen;
    if (auto result = checkDnsResponse(recvbuf); !result.ok()) {
        LOG(WARNING) << "checkDnsResponse failed: " << result.error().message();
        return false;
    }

    const int qdcount = (recvbuf[4] << 8) | recvbuf[5];
    if (qdcount != 1) {
        LOG(WARNING) << "reply query count != 1: " << qdcount;
        return false;
    // If this validation is not for opportunistic mode, or the flags are not properly set,
    // the validation is done. If not, the validation will compare DoT probe latency and
    // UDP probe latency, and it will pass if:
    //   dot_probe_latency < latencyFactor * udp_probe_latency * latencyOffsetMs
    //
    // For instance, with latencyFactor = 3 and latencyOffsetMs = 10, if UDP probe latency is 5 ms,
    // DoT probe latency must less than 25 ms.
    int latencyFactor = Experiments::getInstance()->getFlag("dot_validation_latency_factor", -1);
    int latencyOffsetMs =
            Experiments::getInstance()->getFlag("dot_validation_latency_offset_ms", -1);
    const bool shouldCompareUdpLatency =
            server.name.empty() &&
            (latencyFactor >= 0 && latencyOffsetMs >= 0 && latencyFactor + latencyOffsetMs != 0);
    if (!shouldCompareUdpLatency) {
        return true;
    }

    const int ancount = (recvbuf[6] << 8) | recvbuf[7];
    LOG(DEBUG) << "answer count: " << ancount;
    LOG(INFO) << fmt::format("Use flags: latencyFactor={}, latencyOffsetMs={}", latencyFactor,
                             latencyOffsetMs);

    // TODO: Further validate the response contents (check for valid AAAA record, ...).
    // Note that currently, integration tests rely on this function accepting a
    // response with zero records.
    int64_t udpProbeTimeUs = 0;
    bool udpProbeGotAnswer = false;
    std::thread udpProbeThread([&] {
        // Can issue another probe if the first one fails or is lost.
        for (int i = 1; i < 3; i++) {
            netdutils::Stopwatch stopwatch;
            auto result = sendUdpQuery(server.addr().ip(), mark, query);
            udpProbeTimeUs = stopwatch.timeTakenUs();
            udpProbeGotAnswer = result.ok();
            LOG(INFO) << fmt::format("UDP probe for {} {}, took {:.3f}ms", server.toIpString(),
                                     (udpProbeGotAnswer ? "succeeded" : "failed"),
                                     udpProbeTimeUs / 1000.0);

    return true;
            if (udpProbeGotAnswer) {
                break;
            }
            LOG(WARNING) << "sendUdpQuery attempt " << i << " failed: " << result.error().message();
        }
    });

    int64_t dotProbeTimeUs = 0;
    bool dotProbeGotAnswer = false;
    std::thread dotProbeThread([&] {
        netdutils::Stopwatch stopwatch;
        auto r = transport.query(netdutils::makeSlice(query)).get();
        dotProbeTimeUs = stopwatch.timeTakenUs();

        if (r.code != Response::success) {
            LOG(WARNING) << "query failed";
        } else {
            if (auto result = checkDnsResponse(recvbuf); !result.ok()) {
                LOG(WARNING) << "checkDnsResponse failed: " << result.error().message();
            } else {
                dotProbeGotAnswer = true;
            }
        }

        LOG(INFO) << fmt::format("DoT probe for {} {}, took {:.3f}ms", server.toIpString(),
                                 (dotProbeGotAnswer ? "succeeded" : "failed"),
                                 dotProbeTimeUs / 1000.0);
    });

    // TODO: If DoT probe thread finishes before UDP probe thread and dotProbeGotAnswer is false,
    // actively cancel UDP probe thread.
    dotProbeThread.join();
    udpProbeThread.join();

    if (!dotProbeGotAnswer) return false;
    if (!udpProbeGotAnswer) return true;
    return dotProbeTimeUs < (latencyFactor * udpProbeTimeUs + latencyOffsetMs * 1000);
}

}  // end of namespace net
+11 −3
Original line number Diff line number Diff line
@@ -49,10 +49,18 @@ class Experiments {
    // TODO: Migrate other experiment flags to here.
    // (retry_count, retransmission_time_interval)
    static constexpr const char* const kExperimentFlagKeyList[] = {
            "keep_listening_udp",   "parallel_lookup_release",    "parallel_lookup_sleep_time",
            "sort_nameservers",     "dot_async_handshake",        "dot_connect_timeout_ms",
            "dot_maxtries",         "dot_revalidation_threshold", "dot_xport_unusable_threshold",
            "keep_listening_udp",
            "parallel_lookup_release",
            "parallel_lookup_sleep_time",
            "sort_nameservers",
            "dot_async_handshake",
            "dot_connect_timeout_ms",
            "dot_maxtries",
            "dot_revalidation_threshold",
            "dot_xport_unusable_threshold",
            "dot_query_timeout_ms",
            "dot_validation_latency_factor",
            "dot_validation_latency_offset_ms",
    };
    // This value is used in updateInternal as the default value if any flags can't be found.
    static constexpr int kFlagIntDefault = INT_MIN;
+4 −0
Original line number Diff line number Diff line
@@ -35,6 +35,8 @@ class PrivateDnsConfigurationTest : public ::testing::Test {
        ASSERT_TRUE(tls1.startServer());
        ASSERT_TRUE(tls2.startServer());
        ASSERT_TRUE(backend.startServer());
        ASSERT_TRUE(backend1ForUdpProbe.startServer());
        ASSERT_TRUE(backend2ForUdpProbe.startServer());
    }

    void SetUp() {
@@ -132,6 +134,8 @@ class PrivateDnsConfigurationTest : public ::testing::Test {
    inline static test::DnsTlsFrontend tls1{kServer1, "853", kBackend, "53"};
    inline static test::DnsTlsFrontend tls2{kServer2, "853", kBackend, "53"};
    inline static test::DNSResponder backend{kBackend, "53"};
    inline static test::DNSResponder backend1ForUdpProbe{kServer1, "53"};
    inline static test::DNSResponder backend2ForUdpProbe{kServer2, "53"};
};

TEST_F(PrivateDnsConfigurationTest, ValidationSuccess) {
+24 −0
Original line number Diff line number Diff line
@@ -32,6 +32,7 @@
#include <android-base/logging.h>
#include <netdutils/InternetAddresses.h>
#include <netdutils/SocketOption.h>
#include "dns_responder.h"
#include "dns_tls_certificate.h"

using android::netdutils::enableSockopt;
@@ -235,7 +236,9 @@ void DnsTlsFrontend::requestHandler() {
int DnsTlsFrontend::handleRequests(SSL* ssl, int clientFd) {
    int queryCounts = 0;
    std::vector<uint8_t> reply;
    bool isDotProbe = false;
    pollfd fds = {.fd = clientFd, .events = POLLIN};
again:
    do {
        uint8_t queryHeader[2];
        if (SSL_read(ssl, &queryHeader, 2) != 2) {
@@ -258,6 +261,19 @@ int DnsTlsFrontend::handleRequests(SSL* ssl, int clientFd) {
            LOG(INFO) << "Failed to send query";
            return queryCounts;
        }

        if (!isDotProbe) {
            DNSHeader dnsHdr;
            dnsHdr.read((char*)query, (char*)query + qlen);
            for (const auto& question : dnsHdr.questions) {
                if (question.qname.name.find("dnsotls-ds.metric.gstatic.com") !=
                    std::string::npos) {
                    isDotProbe = true;
                    break;
                }
            }
        }

        const int max_size = 4096;
        uint8_t recv_buffer[max_size];
        int rlen = recv(backend_socket_.get(), recv_buffer, max_size, 0);
@@ -288,6 +304,14 @@ int DnsTlsFrontend::handleRequests(SSL* ssl, int clientFd) {
        LOG(WARNING) << "Failed to write response body";
    }

    // Poll again because the same DoT probe might be sent again.
    if (isDotProbe && queryCounts == 1) {
        int n = poll(&fds, 1, 50);
        if (n > 0 && fds.revents & POLLIN) {
            goto again;
        }
    }

    LOG(DEBUG) << __func__ << " return: " << queryCounts;
    return queryCounts;
}
+96 −3
Original line number Diff line number Diff line
@@ -88,7 +88,10 @@ const std::string kDotRevalidationThresholdFlag(
const std::string kDotXportUnusableThresholdFlag(
        "persist.device_config.netd_native.dot_xport_unusable_threshold");
const std::string kDotQueryTimeoutMsFlag("persist.device_config.netd_native.dot_query_timeout_ms");

const std::string kDotValidationLatencyFactorFlag(
        "persist.device_config.netd_native.dot_validation_latency_factor");
const std::string kDotValidationLatencyOffsetMsFlag(
        "persist.device_config.netd_native.dot_validation_latency_offset_ms");
// Semi-public Bionic hook used by the NDK (frameworks/base/native/android/net.c)
// Tested here for convenience.
extern "C" int android_getaddrinfofornet(const char* hostname, const char* servname,
@@ -4697,7 +4700,6 @@ TEST_F(ResolverTest, TlsServerRevalidation) {
        if (config.dnsMode == "STRICT") parcel.tlsName = kDefaultPrivateDnsHostName;
        ASSERT_TRUE(mDnsClient.SetResolversFromParcel(parcel));
        EXPECT_TRUE(WaitForPrivateDnsValidation(tls.listen_address(), true));
        EXPECT_TRUE(tls.waitForQueries(1));
        tls.clearQueries();
        dns.clearQueries();

@@ -4739,9 +4741,22 @@ TEST_F(ResolverTest, TlsServerRevalidation) {

        // Step 5 and 6.
        int expectedDotQueries = queries;
        int extraDnsProbe = 0;
        if (config.expectRevalidationHappen) {
            EXPECT_TRUE(WaitForPrivateDnsValidation(tls.listen_address(), true));
            expectedDotQueries++;

            // This test is sensitive to the number of queries sent in DoT validation.
            const std::string latencyFactor =
                    android::base::GetProperty(kDotValidationLatencyFactorFlag, "-1");
            const std::string latencyOffsetMs =
                    android::base::GetProperty(kDotValidationLatencyOffsetMsFlag, "-1");
            const bool dotValidationExtraProbes =
                    (latencyFactor != "-1" && latencyOffsetMs != "-1");
            if (dotValidationExtraProbes) {
                expectedDotQueries++;
                extraDnsProbe = 1;
            }
        }

        // Step 7 and 8.
@@ -4750,7 +4765,8 @@ TEST_F(ResolverTest, TlsServerRevalidation) {
        expectedDotQueries++;

        const int expectedDo53Queries =
                expectedDotQueries + (config.dnsMode == "OPPORTUNISTIC" ? queries : 0);
                expectedDotQueries +
                (config.dnsMode == "OPPORTUNISTIC" ? (queries + extraDnsProbe) : 0);

        if (config.expectDotUnusable) {
            // A DoT server can be deemed as unusable only in opportunistic mode. When it happens,
@@ -4762,6 +4778,83 @@ TEST_F(ResolverTest, TlsServerRevalidation) {
    }
}

// Checks that, in opportunistic mode, private DNS validation is rejected when the DoT server is
// much slower than the cleartext server, and that the UDP probe is not used in strict mode.
TEST_F(ResolverTest, TlsServerValidation_UdpProbe) {
    // Every DoT frontend below forwards to this backend, which answers after 200 ms.
    constexpr char backend_addr[] = "127.0.0.3";
    test::DNSResponder slowBackend(backend_addr);
    slowBackend.setResponseDelayMs(200);
    ASSERT_TRUE(slowBackend.startServer());

    static const struct TestConfig {
        int latencyFactor;
        int latencyOffsetMs;
        bool udpProbeLost;
        size_t expectedUdpProbes;
        bool expectedValidationPass;
    } testConfigs[] = {
            // clang-format off
            {-1, -1,  false, 0, true},
            {0,  0,   false, 0, true},
            {1,  10,  false, 1, false},
            {1,  10,  true,  2, false},
            {5,  300, false, 1, true},
            {5,  300, true,  2, true},
            // clang-format on
    };

    for (const auto& tc : testConfigs) {
        SCOPED_TRACE(fmt::format("testConfig: [{}, {}, {}]", tc.latencyFactor, tc.latencyOffsetMs,
                                 tc.udpProbeLost));

        // The cleartext responder on the DoT server's own address answers after 10 ms.
        const std::string serverAddr = getUniqueIPv4Address();
        test::DNSResponder dns(serverAddr, "53", static_cast<ns_rcode>(-1));
        test::DnsTlsFrontend tls(serverAddr, "853", backend_addr, "53");
        dns.setResponseDelayMs(10);
        ASSERT_TRUE(dns.startServer());
        ASSERT_TRUE(tls.startServer());

        ScopedSystemProperties factorProp(kDotValidationLatencyFactorFlag,
                                          std::to_string(tc.latencyFactor));
        ScopedSystemProperties offsetProp(kDotValidationLatencyOffsetMsFlag,
                                          std::to_string(tc.latencyOffsetMs));
        resetNetwork();

        // Left non-joinable unless this config asks for a lost UDP probe.
        std::thread probeDropper;
        if (tc.udpProbeLost) {
            probeDropper = std::thread([&dns]() {
                // Simulate that the first UDP probe is lost and the second UDP probe succeeds.
                dns.setResponseProbability(0.0);
                std::this_thread::sleep_for(std::chrono::seconds(2));
                dns.setResponseProbability(1.0);
            });
        }

        // Set up opportunistic mode, and wait for the validation complete.
        auto parcel = DnsResponderClient::GetDefaultResolverParamsParcel();
        parcel.servers = {serverAddr};
        parcel.tlsServers = {serverAddr};
        ASSERT_TRUE(mDnsClient.SetResolversFromParcel(parcel));

        // The timeout of WaitForPrivateDnsValidation is 5 seconds which is still enough for
        // the testcase of UDP probe lost because the retry of UDP probe happens after 3 seconds.
        EXPECT_TRUE(WaitForPrivateDnsValidation(tls.listen_address(), tc.expectedValidationPass));
        EXPECT_EQ(dns.queries().size(), tc.expectedUdpProbes);
        dns.clearQueries();

        // Test that Private DNS validation always pass in strict mode.
        parcel.tlsName = kDefaultPrivateDnsHostName;
        ASSERT_TRUE(mDnsClient.SetResolversFromParcel(parcel));
        EXPECT_TRUE(WaitForPrivateDnsValidation(tls.listen_address(), true));
        EXPECT_EQ(dns.queries().size(), 0U);

        if (probeDropper.joinable()) {
            probeDropper.join();
        }
    }
}

TEST_F(ResolverTest, FlushNetworkCache) {
    SKIP_IF_REMOTE_VERSION_LESS_THAN(mDnsClient.resolvService(), 4);
    test::DNSResponder dns;