Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 0c171dff authored by Mike Yu's avatar Mike Yu Committed by Automerger Merge Worker
Browse files

Implement DNS probe in DoT validation am: 59cd583e

Original change: https://googleplex-android-review.googlesource.com/c/platform/packages/modules/DnsResolver/+/14997552

Change-Id: I17b74357ef92e7d1def416ced775bcde60771983
parents 70a21cad 59cd583e
Loading
Loading
Loading
Loading
+189 −43
Original line number Diff line number Diff line
@@ -18,14 +18,24 @@

#include "DnsTlsTransport.h"

#include <span>

#include <android-base/format.h>
#include <android-base/logging.h>
#include <android-base/result.h>
#include <android-base/stringprintf.h>
#include <arpa/inet.h>
#include <arpa/nameser.h>
#include <netdutils/Stopwatch.h>
#include <netdutils/ThreadUtil.h>
#include <private/android_filesystem_config.h>  // AID_DNS
#include <sys/poll.h>

#include "DnsTlsSocketFactory.h"
#include "Experiments.h"
#include "IDnsTlsSocketFactory.h"
#include "resolv_private.h"
#include "util.h"

using android::base::StringPrintf;
using android::netdutils::setThreadName;
@@ -33,6 +43,113 @@ using android::netdutils::setThreadName;
namespace android {
namespace net {

namespace {

// Builds a wire-format DNS query for the hostname "<random>-dnsotls-ds.metric.gstatic.com".
// <random> is six characters picked from kDnsSafeChars, and the query ID is two random bytes.
// The query asks for an AAAA record in class IN with the recursion-desired (RD) flag set.
std::vector<uint8_t> makeDnsQuery() {
    static const char kDnsSafeChars[] =
            "abcdefhijklmnopqrstuvwxyz"
            "ABCDEFHIJKLMNOPQRSTUVWXYZ"
            "0123456789";
    // std::size() of a string literal array counts the trailing '\0'. Exclude it so that the
    // random part of the label can never contain a NUL byte.
    static constexpr size_t kNumDnsSafeChars = std::size(kDnsSafeChars) - 1;
    const auto c = [](uint8_t rnd) -> uint8_t { return kDnsSafeChars[rnd % kNumDnsSafeChars]; };
    // rnd[0..5] select the random label characters; rnd[6..7] become the query ID.
    uint8_t rnd[8];
    arc4random_buf(rnd, std::size(rnd));

    return std::vector<uint8_t>{
            rnd[6], rnd[7],  // [0-1]   query ID
            1,      0,       // [2-3]   flags; query[2] = 1 for recursion desired (RD).
            0,      1,       // [4-5]   QDCOUNT (number of queries)
            0,      0,       // [6-7]   ANCOUNT (number of answers)
            0,      0,       // [8-9]   NSCOUNT (number of name server records)
            0,      0,       // [10-11] ARCOUNT (number of additional records)
            // First label: 6 random characters followed by "-dnsotls-ds" (17 bytes in total).
            17,     c(rnd[0]), c(rnd[1]), c(rnd[2]), c(rnd[3]), c(rnd[4]), c(rnd[5]), '-', 'd', 'n',
            's',    'o',       't',       'l',       's',       '-',       'd',       's', 6,   'm',
            'e',    't',       'r',       'i',       'c',       7,         'g',       's', 't', 'a',
            't',    'i',       'c',       3,         'c',       'o',       'm',
            0,                  // null terminator of FQDN (root TLD)
            0,      ns_t_aaaa,  // QTYPE
            0,      ns_c_in     // QCLASS
    };
}

// Performs a minimal sanity check on the raw DNS message |answer|: it must be at least as long
// as the fixed DNS header (NS_HFIXEDSZ) and must carry exactly one question.
// Returns an error describing the first failed check, or success otherwise.
base::Result<void> checkDnsResponse(const std::span<const uint8_t> answer) {
    const size_t length = answer.size();
    if (length < NS_HFIXEDSZ) {
        return Errorf("short response: {}", length);
    }

    // Reads the two bytes at |offset| as one integer, high byte first.
    const auto readCount = [&answer](size_t offset) -> int {
        return (answer[offset] << 8) | answer[offset + 1];
    };

    // QDCOUNT occupies bytes 4-5 of the header.
    if (const int questions = readCount(4); questions != 1) {
        return Errorf("reply query count != 1: {}", questions);
    }

    // ANCOUNT occupies bytes 6-7 of the header; it is only logged, not validated.
    const int answers = readCount(6);
    LOG(DEBUG) << "answer count: " << answers;

    // TODO: Further validate the response contents (check for valid AAAA record, ...).
    // Note that currently, integration tests rely on this function accepting a
    // response with zero records.

    return {};
}

// Sends |query| over UDP to port 53 of |ip| using a socket marked with |mark|, waits up to
// 3 seconds for a reply, and runs checkDnsResponse() on whatever is received.
// Returns success if a reply arrived and passed the check; otherwise returns an error naming the
// step that failed. The response contents are not handed back to the caller.
base::Result<void> sendUdpQuery(netdutils::IPAddress ip, uint32_t mark,
                                std::span<const uint8_t> query) {
    const sockaddr_storage ss = netdutils::IPSockAddr(ip, 53);
    const sockaddr* nsap = reinterpret_cast<const sockaddr*>(&ss);
    const int nsaplen = sockaddrSize(nsap);
    const int sockType = SOCK_DGRAM | SOCK_NONBLOCK | SOCK_CLOEXEC;
    android::base::unique_fd fd{socket(nsap->sa_family, sockType, 0)};
    if (fd < 0) {
        return ErrnoErrorf("socket failed");
    }

    resolv_tag_socket(fd.get(), AID_DNS, NET_CONTEXT_INVALID_PID);
    if (setsockopt(fd.get(), SOL_SOCKET, SO_MARK, &mark, sizeof(mark)) < 0) {
        return ErrnoErrorf("setsockopt failed");
    }

    if (connect(fd.get(), nsap, (socklen_t)nsaplen) < 0) {
        return ErrnoErrorf("connect failed");
    }

    // Compare as signed values so that a -1 return from send() is handled explicitly rather
    // than through an implicit conversion to an unsigned type.
    if (send(fd, query.data(), query.size(), 0) != static_cast<ssize_t>(query.size())) {
        return ErrnoErrorf("send failed");
    }

    const int timeoutMs = 3000;
    pollfd fds = {.fd = fd, .events = POLLIN};

    const int n = TEMP_FAILURE_RETRY(poll(&fds, 1, timeoutMs));
    if (n == 0) {
        return Errorf("poll timed out");
    }
    if (n < 0) {
        return ErrnoErrorf("poll failed");
    }
    if (!(fds.revents & (POLLIN | POLLERR))) {
        // poll() reported only events we cannot act on (e.g. POLLHUP or POLLNVAL). Polling again
        // would return immediately with the same events, so fail instead of spinning.
        return Errorf("poll returned unexpected events: {}", static_cast<int>(fds.revents));
    }

    std::vector<uint8_t> buf(MAXPACKET);
    const int resplen = recv(fd, buf.data(), buf.size(), 0);

    if (resplen < 0) {
        return ErrnoErrorf("recvfrom failed");
    }

    // Shrink the buffer to the number of bytes actually received before validating it.
    buf.resize(resplen);
    if (auto result = checkDnsResponse(buf); !result.ok()) {
        return Errorf("checkDnsResponse failed: {}", result.error().message());
    }

    return {};
}

}  // namespace

std::future<DnsTlsTransport::Result> DnsTlsTransport::query(const netdutils::Slice query) {
    std::lock_guard guard(mLock);

@@ -160,65 +277,94 @@ DnsTlsTransport::~DnsTlsTransport() {
// That may require moving it to DnsTlsDispatcher.
bool DnsTlsTransport::validate(const DnsTlsServer& server, uint32_t mark) {
    LOG(DEBUG) << "Beginning validation with mark " << std::hex << mark;
    // Generate "<random>-dnsotls-ds.metric.gstatic.com", which we will lookup through |ss| in
    // order to prove that it is actually a working DNS over TLS server.
    static const char kDnsSafeChars[] =
            "abcdefhijklmnopqrstuvwxyz"
            "ABCDEFHIJKLMNOPQRSTUVWXYZ"
            "0123456789";
    const auto c = [](uint8_t rnd) -> uint8_t {
        return kDnsSafeChars[(rnd % std::size(kDnsSafeChars))];
    };
    uint8_t rnd[8];
    arc4random_buf(rnd, std::size(rnd));
    // We could try to use res_mkquery() here, but it's basically the same.
    uint8_t query[] = {
        rnd[6], rnd[7],  // [0-1]   query ID
        1, 0,  // [2-3]   flags; query[2] = 1 for recursion desired (RD).
        0, 1,  // [4-5]   QDCOUNT (number of queries)
        0, 0,  // [6-7]   ANCOUNT (number of answers)
        0, 0,  // [8-9]   NSCOUNT (number of name server records)
        0, 0,  // [10-11] ARCOUNT (number of additional records)
        17, c(rnd[0]), c(rnd[1]), c(rnd[2]), c(rnd[3]), c(rnd[4]), c(rnd[5]),
            '-', 'd', 'n', 's', 'o', 't', 'l', 's', '-', 'd', 's',
        6, 'm', 'e', 't', 'r', 'i', 'c',
        7, 'g', 's', 't', 'a', 't', 'i', 'c',
        3, 'c', 'o', 'm',
        0,  // null terminator of FQDN (root TLD)
        0, ns_t_aaaa,  // QTYPE
        0, ns_c_in     // QCLASS
    };
    const int qlen = std::size(query);

    int replylen = 0;
    const std::vector<uint8_t> query = makeDnsQuery();
    DnsTlsSocketFactory factory;
    DnsTlsTransport transport(server, mark, &factory);
    auto r = transport.query(netdutils::Slice(query, qlen)).get();

    // Send the initial query to warm up the connection.
    auto r = transport.query(netdutils::makeSlice(query)).get();
    if (r.code != Response::success) {
        LOG(WARNING) << "query failed";
        return false;
    }

    const std::vector<uint8_t>& recvbuf = r.response;
    if (recvbuf.size() < NS_HFIXEDSZ) {
        LOG(WARNING) << "short response: " << replylen;
    if (auto result = checkDnsResponse(recvbuf); !result.ok()) {
        LOG(WARNING) << "checkDnsResponse failed: " << result.error().message();
        return false;
    }

    const int qdcount = (recvbuf[4] << 8) | recvbuf[5];
    if (qdcount != 1) {
        LOG(WARNING) << "reply query count != 1: " << qdcount;
        return false;
    // If this validation is not for opportunistic mode, or the flags are not properly set,
    // the validation is done. If not, the validation will compare DoT probe latency and
    // UDP probe latency, and it will pass if:
    //   dot_probe_latency < latencyFactor * udp_probe_latency + latencyOffsetMs
    //
    // For instance, with latencyFactor = 3 and latencyOffsetMs = 10, if UDP probe latency is 5 ms,
    // DoT probe latency must be less than 25 ms.
    int latencyFactor = Experiments::getInstance()->getFlag("dot_validation_latency_factor", -1);
    int latencyOffsetMs =
            Experiments::getInstance()->getFlag("dot_validation_latency_offset_ms", -1);
    const bool shouldCompareUdpLatency =
            server.name.empty() &&
            (latencyFactor >= 0 && latencyOffsetMs >= 0 && latencyFactor + latencyOffsetMs != 0);
    if (!shouldCompareUdpLatency) {
        return true;
    }

    const int ancount = (recvbuf[6] << 8) | recvbuf[7];
    LOG(DEBUG) << "answer count: " << ancount;
    LOG(INFO) << fmt::format("Use flags: latencyFactor={}, latencyOffsetMs={}", latencyFactor,
                             latencyOffsetMs);

    // TODO: Further validate the response contents (check for valid AAAA record, ...).
    // Note that currently, integration tests rely on this function accepting a
    // response with zero records.
    int64_t udpProbeTimeUs = 0;
    bool udpProbeGotAnswer = false;
    std::thread udpProbeThread([&] {
        // Can issue another probe if the first one fails or is lost.
        for (int i = 1; i < 3; i++) {
            netdutils::Stopwatch stopwatch;
            auto result = sendUdpQuery(server.addr().ip(), mark, query);
            udpProbeTimeUs = stopwatch.timeTakenUs();
            udpProbeGotAnswer = result.ok();
            LOG(INFO) << fmt::format("UDP probe for {} {}, took {:.3f}ms", server.toIpString(),
                                     (udpProbeGotAnswer ? "succeeded" : "failed"),
                                     udpProbeTimeUs / 1000.0);

    return true;
            if (udpProbeGotAnswer) {
                break;
            }
            LOG(WARNING) << "sendUdpQuery attempt " << i << " failed: " << result.error().message();
        }
    });

    // Results of the DoT probe. They are written by dotProbeThread and must only be read after
    // that thread has been joined.
    int64_t dotProbeTimeUs = 0;
    bool dotProbeGotAnswer = false;
    std::thread dotProbeThread([&] {
        // Time a second query over the same transport that served the initial query above.
        netdutils::Stopwatch stopwatch;
        auto r = transport.query(netdutils::makeSlice(query)).get();
        dotProbeTimeUs = stopwatch.timeTakenUs();

        if (r.code != Response::success) {
            LOG(WARNING) << "query failed";
        } else {
            // Validate the response to THIS probe (r.response), not the response to the
            // initial query that was already checked before the probes started.
            if (auto result = checkDnsResponse(r.response); !result.ok()) {
                LOG(WARNING) << "checkDnsResponse failed: " << result.error().message();
            } else {
                dotProbeGotAnswer = true;
            }
        }

        LOG(INFO) << fmt::format("DoT probe for {} {}, took {:.3f}ms", server.toIpString(),
                                 (dotProbeGotAnswer ? "succeeded" : "failed"),
                                 dotProbeTimeUs / 1000.0);
    });

    // TODO: If DoT probe thread finishes before UDP probe thread and dotProbeGotAnswer is false,
    // actively cancel UDP probe thread.
    dotProbeThread.join();
    udpProbeThread.join();

    if (!dotProbeGotAnswer) return false;
    if (!udpProbeGotAnswer) return true;
    return dotProbeTimeUs < (latencyFactor * udpProbeTimeUs + latencyOffsetMs * 1000);
}

}  // end of namespace net
+11 −3
Original line number Diff line number Diff line
@@ -49,10 +49,18 @@ class Experiments {
    // TODO: Migrate other experiment flags to here.
    // (retry_count, retransmission_time_interval)
    static constexpr const char* const kExperimentFlagKeyList[] = {
            "keep_listening_udp",   "parallel_lookup_release",    "parallel_lookup_sleep_time",
            "sort_nameservers",     "dot_async_handshake",        "dot_connect_timeout_ms",
            "dot_maxtries",         "dot_revalidation_threshold", "dot_xport_unusable_threshold",
            "keep_listening_udp",
            "parallel_lookup_release",
            "parallel_lookup_sleep_time",
            "sort_nameservers",
            "dot_async_handshake",
            "dot_connect_timeout_ms",
            "dot_maxtries",
            "dot_revalidation_threshold",
            "dot_xport_unusable_threshold",
            "dot_query_timeout_ms",
            "dot_validation_latency_factor",
            "dot_validation_latency_offset_ms",
    };
    // This value is used in updateInternal as the default value if any flags can't be found.
    static constexpr int kFlagIntDefault = INT_MIN;
+4 −0
Original line number Diff line number Diff line
@@ -35,6 +35,8 @@ class PrivateDnsConfigurationTest : public ::testing::Test {
        ASSERT_TRUE(tls1.startServer());
        ASSERT_TRUE(tls2.startServer());
        ASSERT_TRUE(backend.startServer());
        ASSERT_TRUE(backend1ForUdpProbe.startServer());
        ASSERT_TRUE(backend2ForUdpProbe.startServer());
    }

    void SetUp() {
@@ -132,6 +134,8 @@ class PrivateDnsConfigurationTest : public ::testing::Test {
    inline static test::DnsTlsFrontend tls1{kServer1, "853", kBackend, "53"};
    inline static test::DnsTlsFrontend tls2{kServer2, "853", kBackend, "53"};
    inline static test::DNSResponder backend{kBackend, "53"};
    inline static test::DNSResponder backend1ForUdpProbe{kServer1, "53"};
    inline static test::DNSResponder backend2ForUdpProbe{kServer2, "53"};
};

TEST_F(PrivateDnsConfigurationTest, ValidationSuccess) {
+24 −0
Original line number Diff line number Diff line
@@ -32,6 +32,7 @@
#include <android-base/logging.h>
#include <netdutils/InternetAddresses.h>
#include <netdutils/SocketOption.h>
#include "dns_responder.h"
#include "dns_tls_certificate.h"

using android::netdutils::enableSockopt;
@@ -235,7 +236,9 @@ void DnsTlsFrontend::requestHandler() {
int DnsTlsFrontend::handleRequests(SSL* ssl, int clientFd) {
    int queryCounts = 0;
    std::vector<uint8_t> reply;
    bool isDotProbe = false;
    pollfd fds = {.fd = clientFd, .events = POLLIN};
again:
    do {
        uint8_t queryHeader[2];
        if (SSL_read(ssl, &queryHeader, 2) != 2) {
@@ -258,6 +261,19 @@ int DnsTlsFrontend::handleRequests(SSL* ssl, int clientFd) {
            LOG(INFO) << "Failed to send query";
            return queryCounts;
        }

        if (!isDotProbe) {
            DNSHeader dnsHdr;
            dnsHdr.read((char*)query, (char*)query + qlen);
            for (const auto& question : dnsHdr.questions) {
                if (question.qname.name.find("dnsotls-ds.metric.gstatic.com") !=
                    std::string::npos) {
                    isDotProbe = true;
                    break;
                }
            }
        }

        const int max_size = 4096;
        uint8_t recv_buffer[max_size];
        int rlen = recv(backend_socket_.get(), recv_buffer, max_size, 0);
@@ -288,6 +304,14 @@ int DnsTlsFrontend::handleRequests(SSL* ssl, int clientFd) {
        LOG(WARNING) << "Failed to write response body";
    }

    // Poll again because the same DoT probe might be sent again.
    if (isDotProbe && queryCounts == 1) {
        int n = poll(&fds, 1, 50);
        if (n > 0 && fds.revents & POLLIN) {
            goto again;
        }
    }

    LOG(DEBUG) << __func__ << " return: " << queryCounts;
    return queryCounts;
}
+96 −3
Original line number Diff line number Diff line
@@ -88,7 +88,10 @@ const std::string kDotRevalidationThresholdFlag(
const std::string kDotXportUnusableThresholdFlag(
        "persist.device_config.netd_native.dot_xport_unusable_threshold");
const std::string kDotQueryTimeoutMsFlag("persist.device_config.netd_native.dot_query_timeout_ms");

// System property names of the "dot_validation_latency_factor" and
// "dot_validation_latency_offset_ms" flags. Tests in this file read them with GetProperty() and
// override them with ScopedSystemProperties.
const std::string kDotValidationLatencyFactorFlag(
        "persist.device_config.netd_native.dot_validation_latency_factor");
const std::string kDotValidationLatencyOffsetMsFlag(
        "persist.device_config.netd_native.dot_validation_latency_offset_ms");
// Semi-public Bionic hook used by the NDK (frameworks/base/native/android/net.c)
// Tested here for convenience.
extern "C" int android_getaddrinfofornet(const char* hostname, const char* servname,
@@ -4697,7 +4700,6 @@ TEST_F(ResolverTest, TlsServerRevalidation) {
        if (config.dnsMode == "STRICT") parcel.tlsName = kDefaultPrivateDnsHostName;
        ASSERT_TRUE(mDnsClient.SetResolversFromParcel(parcel));
        EXPECT_TRUE(WaitForPrivateDnsValidation(tls.listen_address(), true));
        EXPECT_TRUE(tls.waitForQueries(1));
        tls.clearQueries();
        dns.clearQueries();

@@ -4739,9 +4741,22 @@ TEST_F(ResolverTest, TlsServerRevalidation) {

        // Step 5 and 6.
        int expectedDotQueries = queries;
        int extraDnsProbe = 0;
        if (config.expectRevalidationHappen) {
            EXPECT_TRUE(WaitForPrivateDnsValidation(tls.listen_address(), true));
            expectedDotQueries++;

            // This test is sensitive to the number of queries sent in DoT validation.
            const std::string latencyFactor =
                    android::base::GetProperty(kDotValidationLatencyFactorFlag, "-1");
            const std::string latencyOffsetMs =
                    android::base::GetProperty(kDotValidationLatencyOffsetMsFlag, "-1");
            const bool dotValidationExtraProbes =
                    (latencyFactor != "-1" && latencyOffsetMs != "-1");
            if (dotValidationExtraProbes) {
                expectedDotQueries++;
                extraDnsProbe = 1;
            }
        }

        // Step 7 and 8.
@@ -4750,7 +4765,8 @@ TEST_F(ResolverTest, TlsServerRevalidation) {
        expectedDotQueries++;

        const int expectedDo53Queries =
                expectedDotQueries + (config.dnsMode == "OPPORTUNISTIC" ? queries : 0);
                expectedDotQueries +
                (config.dnsMode == "OPPORTUNISTIC" ? (queries + extraDnsProbe) : 0);

        if (config.expectDotUnusable) {
            // A DoT server can be deemed as unusable only in opportunistic mode. When it happens,
@@ -4762,6 +4778,83 @@ TEST_F(ResolverTest, TlsServerRevalidation) {
    }
}

// Verifies that private DNS validation fails if DoT server is much slower than cleartext server.
//
// For every entry of testConfigs the test sets the two latency flags, configures one address as
// both the cleartext and the DoT server, and checks the validation result together with the
// number of queries seen by the cleartext responder. It then repeats the validation with a TLS
// hostname set and expects it to pass without any query reaching the cleartext responder.
TEST_F(ResolverTest, TlsServerValidation_UdpProbe) {
    // Backend used by the DoT frontend; its answers are delayed by 200 ms.
    constexpr char backend_addr[] = "127.0.0.3";
    test::DNSResponder backend(backend_addr);
    backend.setResponseDelayMs(200);
    ASSERT_TRUE(backend.startServer());

    static const struct TestConfig {
        int latencyFactor;            // value written to kDotValidationLatencyFactorFlag
        int latencyOffsetMs;          // value written to kDotValidationLatencyOffsetMsFlag
        bool udpProbeLost;            // whether the cleartext responder is silenced at first
        size_t expectedUdpProbes;     // expected number of queries seen by the cleartext responder
        bool expectedValidationPass;  // expected validation result in the first (no tlsName) phase
    } testConfigs[] = {
            // clang-format off
            {-1, -1,  false, 0, true},
            {0,  0,   false, 0, true},
            {1,  10,  false, 1, false},
            {1,  10,  true,  2, false},
            {5,  300, false, 1, true},
            {5,  300, true,  2, true},
            // clang-format on
    };

    for (const auto& config : testConfigs) {
        SCOPED_TRACE(fmt::format("testConfig: [{}, {}, {}]", config.latencyFactor,
                                 config.latencyOffsetMs, config.udpProbeLost));

        // A fresh address per iteration hosts both the cleartext responder (port 53, answers
        // delayed by 10 ms) and the DoT frontend (port 853, forwarding to the slow backend).
        const std::string addr = getUniqueIPv4Address();
        test::DNSResponder dns(addr, "53", static_cast<ns_rcode>(-1));
        test::DnsTlsFrontend tls(addr, "853", backend_addr, "53");
        dns.setResponseDelayMs(10);
        ASSERT_TRUE(dns.startServer());
        ASSERT_TRUE(tls.startServer());

        // The flag overrides are scoped objects, so they last until the end of this iteration.
        ScopedSystemProperties sp1(kDotValidationLatencyFactorFlag,
                                   std::to_string(config.latencyFactor));
        ScopedSystemProperties sp2(kDotValidationLatencyOffsetMsFlag,
                                   std::to_string(config.latencyOffsetMs));
        resetNetwork();

        // Optional helper thread that changes the responder's response probability while the
        // validation is running. It captures |dns| by reference and is joined at the end of the
        // iteration, before |dns| goes out of scope.
        std::unique_ptr<std::thread> thread;
        if (config.udpProbeLost) {
            thread.reset(new std::thread([&dns]() {
                // Simulate that the first UDP probe is lost and the second UDP probe succeeds.
                dns.setResponseProbability(0.0);
                std::this_thread::sleep_for(std::chrono::seconds(2));
                dns.setResponseProbability(1.0);
            }));
        }

        // Set up opportunistic mode, and wait for the validation complete.
        auto parcel = DnsResponderClient::GetDefaultResolverParamsParcel();
        parcel.servers = {addr};
        parcel.tlsServers = {addr};
        ASSERT_TRUE(mDnsClient.SetResolversFromParcel(parcel));

        // The timeout of WaitForPrivateDnsValidation is 5 seconds which is still enough for
        // the testcase of UDP probe lost because the retry of UDP probe happens after 3 seconds.
        EXPECT_TRUE(
                WaitForPrivateDnsValidation(tls.listen_address(), config.expectedValidationPass));
        EXPECT_EQ(dns.queries().size(), config.expectedUdpProbes);
        // Reset the counter so that the next phase only counts its own queries.
        dns.clearQueries();

        // Test that Private DNS validation always pass in strict mode.
        parcel.tlsName = kDefaultPrivateDnsHostName;
        ASSERT_TRUE(mDnsClient.SetResolversFromParcel(parcel));
        EXPECT_TRUE(WaitForPrivateDnsValidation(tls.listen_address(), true));
        EXPECT_EQ(dns.queries().size(), 0U);

        if (thread) {
            thread->join();
            thread.reset();
        }
    }
}

TEST_F(ResolverTest, FlushNetworkCache) {
    SKIP_IF_REMOTE_VERSION_LESS_THAN(mDnsClient.resolvService(), 4);
    test::DNSResponder dns;