Loading DnsTlsTransport.cpp +189 −44 Original line number Diff line number Diff line Loading @@ -18,14 +18,24 @@ #include "DnsTlsTransport.h" #include <span> #include <android-base/format.h> #include <android-base/logging.h> #include <android-base/result.h> #include <android-base/stringprintf.h> #include <arpa/inet.h> #include <arpa/nameser.h> #include <netdutils/Stopwatch.h> #include <netdutils/ThreadUtil.h> #include <private/android_filesystem_config.h> // AID_DNS #include <sys/poll.h> #include "DnsTlsSocketFactory.h" #include "Experiments.h" #include "IDnsTlsSocketFactory.h" #include "resolv_private.h" #include "util.h" using android::base::StringPrintf; using android::netdutils::setThreadName; Loading @@ -33,6 +43,113 @@ using android::netdutils::setThreadName; namespace android { namespace net { namespace { // Make a DNS query for the hostname "<random>-dnsotls-ds.metric.gstatic.com". std::vector<uint8_t> makeDnsQuery() { static const char kDnsSafeChars[] = "abcdefhijklmnopqrstuvwxyz" "ABCDEFHIJKLMNOPQRSTUVWXYZ" "0123456789"; const auto c = [](uint8_t rnd) -> uint8_t { return kDnsSafeChars[(rnd % std::size(kDnsSafeChars))]; }; uint8_t rnd[8]; arc4random_buf(rnd, std::size(rnd)); return std::vector<uint8_t>{ rnd[6], rnd[7], // [0-1] query ID 1, 0, // [2-3] flags; query[2] = 1 for recursion desired (RD). 0, 1, // [4-5] QDCOUNT (number of queries) 0, 0, // [6-7] ANCOUNT (number of answers) 0, 0, // [8-9] NSCOUNT (number of name server records) 0, 0, // [10-11] ARCOUNT (number of additional records) 17, c(rnd[0]), c(rnd[1]), c(rnd[2]), c(rnd[3]), c(rnd[4]), c(rnd[5]), '-', 'd', 'n', 's', 'o', 't', 'l', 's', '-', 'd', 's', 6, 'm', 'e', 't', 'r', 'i', 'c', 7, 'g', 's', 't', 'a', 't', 'i', 'c', 3, 'c', 'o', 'm', 0, // null terminator of FQDN (root TLD) 0, ns_t_aaaa, // QTYPE 0, ns_c_in // QCLASS }; } base::Result<void> checkDnsResponse(const std::span<const uint8_t> answer) { if (answer.size() < NS_HFIXEDSZ) { return Errorf("short response: {}", answer.size()); } const int qdcount = (answer[4] << 8) | answer[5]; if (qdcount != 1) { return Errorf("reply query count != 1: {}", qdcount); } const int ancount = (answer[6] << 8) | answer[7]; LOG(DEBUG) << "answer count: " << ancount; // TODO: Further validate the response contents (check for valid AAAA record, ...). // Note that currently, integration tests rely on this function accepting a // response with zero records. return {}; } // Sends |query| to the given server, and returns the DNS response. base::Result<void> sendUdpQuery(netdutils::IPAddress ip, uint32_t mark, std::span<const uint8_t> query) { const sockaddr_storage ss = netdutils::IPSockAddr(ip, 53); const sockaddr* nsap = reinterpret_cast<const sockaddr*>(&ss); const int nsaplen = sockaddrSize(nsap); const int sockType = SOCK_DGRAM | SOCK_NONBLOCK | SOCK_CLOEXEC; android::base::unique_fd fd{socket(nsap->sa_family, sockType, 0)}; if (fd < 0) { return ErrnoErrorf("socket failed"); } resolv_tag_socket(fd.get(), AID_DNS, NET_CONTEXT_INVALID_PID); if (setsockopt(fd.get(), SOL_SOCKET, SO_MARK, &mark, sizeof(mark)) < 0) { return ErrnoErrorf("setsockopt failed"); } if (connect(fd.get(), nsap, (socklen_t)nsaplen) < 0) { return ErrnoErrorf("connect failed"); } if (send(fd, query.data(), query.size(), 0) != query.size()) { return ErrnoErrorf("send failed"); } const int timeoutMs = 3000; while (true) { pollfd fds = {.fd = fd, .events = POLLIN}; const int n = TEMP_FAILURE_RETRY(poll(&fds, 1, timeoutMs)); if (n == 0) { return Errorf("poll timed out"); } if (n < 0) { return ErrnoErrorf("poll failed"); } if (fds.revents & (POLLIN | POLLERR)) { std::vector<uint8_t> buf(MAXPACKET); const int resplen = recv(fd, buf.data(), buf.size(), 0); if (resplen < 0) { return ErrnoErrorf("recvfrom failed"); } buf.resize(resplen); if (auto result = checkDnsResponse(buf); !result.ok()) { return Errorf("checkDnsResponse failed: {}", result.error().message()); } return {}; } } } } // namespace std::future<DnsTlsTransport::Result> DnsTlsTransport::query(const netdutils::Slice query) { std::lock_guard guard(mLock); Loading Loading @@ -160,65 +277,93 @@ DnsTlsTransport::~DnsTlsTransport() { // That may require moving it to DnsTlsDispatcher. bool DnsTlsTransport::validate(const DnsTlsServer& server, uint32_t mark) { LOG(DEBUG) << "Beginning validation with mark " << std::hex << mark; // Generate "<random>-dnsotls-ds.metric.gstatic.com", which we will lookup through |ss| in // order to prove that it is actually a working DNS over TLS server. static const char kDnsSafeChars[] = "abcdefhijklmnopqrstuvwxyz" "ABCDEFHIJKLMNOPQRSTUVWXYZ" "0123456789"; const auto c = [](uint8_t rnd) -> uint8_t { return kDnsSafeChars[(rnd % std::size(kDnsSafeChars))]; }; uint8_t rnd[8]; arc4random_buf(rnd, std::size(rnd)); // We could try to use res_mkquery() here, but it's basically the same. uint8_t query[] = { rnd[6], rnd[7], // [0-1] query ID 1, 0, // [2-3] flags; query[2] = 1 for recursion desired (RD). 0, 1, // [4-5] QDCOUNT (number of queries) 0, 0, // [6-7] ANCOUNT (number of answers) 0, 0, // [8-9] NSCOUNT (number of name server records) 0, 0, // [10-11] ARCOUNT (number of additional records) 17, c(rnd[0]), c(rnd[1]), c(rnd[2]), c(rnd[3]), c(rnd[4]), c(rnd[5]), '-', 'd', 'n', 's', 'o', 't', 'l', 's', '-', 'd', 's', 6, 'm', 'e', 't', 'r', 'i', 'c', 7, 'g', 's', 't', 'a', 't', 'i', 'c', 3, 'c', 'o', 'm', 0, // null terminator of FQDN (root TLD) 0, ns_t_aaaa, // QTYPE 0, ns_c_in // QCLASS }; const int qlen = std::size(query); int replylen = 0; const std::vector<uint8_t> query = makeDnsQuery(); DnsTlsSocketFactory factory; DnsTlsTransport transport(server, mark, &factory); auto r = transport.query(netdutils::Slice(query, qlen)).get(); // Send the initial query to warm up the connection. auto r = transport.query(netdutils::makeSlice(query)).get(); if (r.code != Response::success) { LOG(WARNING) << "query failed"; return false; } const std::vector<uint8_t>& recvbuf = r.response; if (recvbuf.size() < NS_HFIXEDSZ) { LOG(WARNING) << "short response: " << replylen; if (auto result = checkDnsResponse(r.response); !result.ok()) { LOG(WARNING) << "checkDnsResponse failed: " << result.error().message(); return false; } const int qdcount = (recvbuf[4] << 8) | recvbuf[5]; if (qdcount != 1) { LOG(WARNING) << "reply query count != 1: " << qdcount; return false; // If this validation is not for opportunistic mode, or the flags are not properly set, // the validation is done. If not, the validation will compare DoT probe latency and // UDP probe latency, and it will pass if: // dot_probe_latency < latencyFactor * udp_probe_latency + latencyOffsetMs // // For instance, with latencyFactor = 3 and latencyOffsetMs = 10, if UDP probe latency is 5 ms, // DoT probe latency must less than 25 ms. int latencyFactor = Experiments::getInstance()->getFlag("dot_validation_latency_factor", -1); int latencyOffsetMs = Experiments::getInstance()->getFlag("dot_validation_latency_offset_ms", -1); const bool shouldCompareUdpLatency = server.name.empty() && (latencyFactor >= 0 && latencyOffsetMs >= 0 && latencyFactor + latencyOffsetMs != 0); if (!shouldCompareUdpLatency) { return true; } const int ancount = (recvbuf[6] << 8) | recvbuf[7]; LOG(DEBUG) << "answer count: " << ancount; LOG(INFO) << fmt::format("Use flags: latencyFactor={}, latencyOffsetMs={}", latencyFactor, latencyOffsetMs); // TODO: Further validate the response contents (check for valid AAAA record, ...). // Note that currently, integration tests rely on this function accepting a // response with zero records. int64_t udpProbeTimeUs = 0; bool udpProbeGotAnswer = false; std::thread udpProbeThread([&] { // Can issue another probe if the first one fails or is lost. for (int i = 1; i < 3; i++) { netdutils::Stopwatch stopwatch; auto result = sendUdpQuery(server.addr().ip(), mark, query); udpProbeTimeUs = stopwatch.timeTakenUs(); udpProbeGotAnswer = result.ok(); LOG(INFO) << fmt::format("UDP probe for {} {}, took {:.3f}ms", server.toIpString(), (udpProbeGotAnswer ? "succeeded" : "failed"), udpProbeTimeUs / 1000.0); return true; if (udpProbeGotAnswer) { break; } LOG(WARNING) << "sendUdpQuery attempt " << i << " failed: " << result.error().message(); } }); int64_t dotProbeTimeUs = 0; bool dotProbeGotAnswer = false; std::thread dotProbeThread([&] { netdutils::Stopwatch stopwatch; auto r = transport.query(netdutils::makeSlice(query)).get(); dotProbeTimeUs = stopwatch.timeTakenUs(); if (r.code != Response::success) { LOG(WARNING) << "query failed"; } else { if (auto result = checkDnsResponse(r.response); !result.ok()) { LOG(WARNING) << "checkDnsResponse failed: " << result.error().message(); } else { dotProbeGotAnswer = true; } } LOG(INFO) << fmt::format("DoT probe for {} {}, took {:.3f}ms", server.toIpString(), (dotProbeGotAnswer ? "succeeded" : "failed"), dotProbeTimeUs / 1000.0); }); // TODO: If DoT probe thread finishes before UDP probe thread and dotProbeGotAnswer is false, // actively cancel UDP probe thread. dotProbeThread.join(); udpProbeThread.join(); if (!dotProbeGotAnswer) return false; if (!udpProbeGotAnswer) return true; return dotProbeTimeUs < (latencyFactor * udpProbeTimeUs + latencyOffsetMs * 1000); } } // end of namespace net Loading Experiments.h +11 −3 Original line number Diff line number Diff line Loading @@ -49,10 +49,18 @@ class Experiments { // TODO: Migrate other experiment flags to here. // (retry_count, retransmission_time_interval) static constexpr const char* const kExperimentFlagKeyList[] = { "keep_listening_udp", "parallel_lookup_release", "parallel_lookup_sleep_time", "sort_nameservers", "dot_async_handshake", "dot_connect_timeout_ms", "dot_maxtries", "dot_revalidation_threshold", "dot_xport_unusable_threshold", "keep_listening_udp", "parallel_lookup_release", "parallel_lookup_sleep_time", "sort_nameservers", "dot_async_handshake", "dot_connect_timeout_ms", "dot_maxtries", "dot_revalidation_threshold", "dot_xport_unusable_threshold", "dot_query_timeout_ms", "dot_validation_latency_factor", "dot_validation_latency_offset_ms", }; // This value is used in updateInternal as the default value if any flags can't be found. static constexpr int kFlagIntDefault = INT_MIN; Loading PrivateDnsConfigurationTest.cpp +4 −0 Original line number Diff line number Diff line Loading @@ -35,6 +35,8 @@ class PrivateDnsConfigurationTest : public ::testing::Test { ASSERT_TRUE(tls1.startServer()); ASSERT_TRUE(tls2.startServer()); ASSERT_TRUE(backend.startServer()); ASSERT_TRUE(backend1ForUdpProbe.startServer()); ASSERT_TRUE(backend2ForUdpProbe.startServer()); } void SetUp() { Loading Loading @@ -132,6 +134,8 @@ class PrivateDnsConfigurationTest : public ::testing::Test { inline static test::DnsTlsFrontend tls1{kServer1, "853", kBackend, "53"}; inline static test::DnsTlsFrontend tls2{kServer2, "853", kBackend, "53"}; inline static test::DNSResponder backend{kBackend, "53"}; inline static test::DNSResponder backend1ForUdpProbe{kServer1, "53"}; inline static test::DNSResponder backend2ForUdpProbe{kServer2, "53"}; }; TEST_F(PrivateDnsConfigurationTest, ValidationSuccess) { Loading res_send.cpp +54 −42 Original line number Diff line number Diff line Loading @@ -146,11 +146,12 @@ using android::netdutils::Stopwatch; static int send_vc(res_state statp, res_params* params, const uint8_t* buf, int buflen, uint8_t* ans, int anssiz, int* terrno, size_t ns, time_t* at, int* rcode, int* delay); static int setupUdpSocket(ResState* statp, const sockaddr* sockap, size_t addrIndex, int* terrno); static int send_dg(res_state statp, res_params* params, const uint8_t* buf, int buflen, uint8_t* ans, int anssiz, int* terrno, size_t* ns, int* v_circuit, int* gotsomewhere, time_t* at, int* rcode, int* delay); static void dump_error(const char*, const struct sockaddr*, int); static void dump_error(const char*, const struct sockaddr*); static int sock_eq(struct sockaddr*, struct sockaddr*); static int connect_with_timeout(int sock, const struct sockaddr* nsap, socklen_t salen, Loading Loading @@ -726,14 +727,14 @@ same_ns: errno = 0; if (random_bind(statp->tcp_nssock, nsap->sa_family) < 0) { *terrno = errno; dump_error("bind/vc", nsap, nsaplen); dump_error("bind/vc", nsap); statp->closeSockets(); return (0); } if (connect_with_timeout(statp->tcp_nssock, nsap, (socklen_t)nsaplen, get_timeout(statp, params, ns)) < 0) { *terrno = errno; dump_error("connect/vc", nsap, nsaplen); dump_error("connect/vc", nsap); statp->closeSockets(); /* * The way connect_with_timeout() is implemented prevents us from reliably Loading Loading @@ -932,7 +933,7 @@ retry: static std::vector<pollfd> extractUdpFdset(res_state statp, const short events = POLLIN) { std::vector<pollfd> fdset(statp->nsaddrs.size()); for (size_t i = 0; i < statp->nsaddrs.size(); ++i) { fdset[i] = {.fd = statp->nssocks[i], .events = events}; fdset[i] = {.fd = statp->udpsocks[i], .events = events}; } return fdset; } Loading Loading @@ -969,10 +970,10 @@ static Result<std::vector<int>> udpRetryingPollWrapper(res_state statp, int ns, android::net::Experiments::getInstance()->getFlag("keep_listening_udp", 0); if (keepListeningUdp) return udpRetryingPoll(statp, finish); if (int n = retrying_poll(statp->nssocks[ns], POLLIN, finish); n <= 0) { if (int n = retrying_poll(statp->udpsocks[ns], POLLIN, finish); n <= 0) { return ErrnoError(); } return std::vector<int>{statp->nssocks[ns]}; return std::vector<int>{statp->udpsocks[ns]}; } bool ignoreInvalidAnswer(res_state statp, const sockaddr_storage& from, const uint8_t* buf, Loading @@ -997,66 +998,76 @@ bool ignoreInvalidAnswer(res_state statp, const sockaddr_storage& from, const ui return false; } static int send_dg(res_state statp, res_params* params, const uint8_t* buf, int buflen, uint8_t* ans, int anssiz, int* terrno, size_t* ns, int* v_circuit, int* gotsomewhere, time_t* at, int* rcode, int* delay) { // It should never happen, but just in case. if (*ns >= statp->nsaddrs.size()) { LOG(ERROR) << __func__ << ": Out-of-bound indexing: " << ns; *terrno = EINVAL; return -1; } // return 1 when setup udp socket success. // return 0 when timeout , bind error, network error(ex: Protocol not supported ...). // return -1 when create socket fail, set socket option fail. static int setupUdpSocket(ResState* statp, const sockaddr* sockap, size_t addrIndex, int* terrno) { statp->udpsocks[addrIndex].reset(socket(sockap->sa_family, SOCK_DGRAM | SOCK_CLOEXEC, 0)); *at = time(nullptr); *delay = 0; const sockaddr_storage ss = statp->nsaddrs[*ns]; const sockaddr* nsap = reinterpret_cast<const sockaddr*>(&ss); const int nsaplen = sockaddrSize(nsap); if (statp->nssocks[*ns] == -1) { statp->nssocks[*ns].reset(socket(nsap->sa_family, SOCK_DGRAM | SOCK_CLOEXEC, 0)); if (statp->nssocks[*ns] < 0) { if (statp->udpsocks[addrIndex] < 0) { *terrno = errno; PLOG(DEBUG) << __func__ << ": socket(dg): "; PLOG(ERROR) << __func__ << ": socket: "; switch (errno) { case EPROTONOSUPPORT: case EPFNOSUPPORT: case EAFNOSUPPORT: return (0); return 0; default: return (-1); return -1; } } const uid_t uid = statp->enforce_dns_uid ? AID_DNS : statp->uid; resolv_tag_socket(statp->nssocks[*ns], uid, statp->pid); resolv_tag_socket(statp->udpsocks[addrIndex], uid, statp->pid); if (statp->_mark != MARK_UNSET) { if (setsockopt(statp->nssocks[*ns], SOL_SOCKET, SO_MARK, &(statp->_mark), if (setsockopt(statp->udpsocks[addrIndex], SOL_SOCKET, SO_MARK, &(statp->_mark), sizeof(statp->_mark)) < 0) { *terrno = errno; statp->closeSockets(); return -1; } } if (random_bind(statp->udpsocks[addrIndex], sockap->sa_family) < 0) { *terrno = errno; dump_error("bind", sockap); statp->closeSockets(); return 0; } return 1; } static int send_dg(res_state statp, res_params* params, const uint8_t* buf, int buflen, uint8_t* ans, int anssiz, int* terrno, size_t* ns, int* v_circuit, int* gotsomewhere, time_t* at, int* rcode, int* delay) { // It should never happen, but just in case. if (*ns >= statp->nsaddrs.size()) { LOG(ERROR) << __func__ << ": Out-of-bound indexing: " << ns; *terrno = EINVAL; return -1; } *at = time(nullptr); *delay = 0; const sockaddr_storage ss = statp->nsaddrs[*ns]; const sockaddr* nsap = reinterpret_cast<const sockaddr*>(&ss); if (statp->udpsocks[*ns] == -1) { int result = setupUdpSocket(statp, nsap, *ns, terrno); if (result <= 0) return result; // Use a "connected" datagram socket to receive an ECONNREFUSED error // on the next socket operation when the server responds with an // ICMP port-unreachable error. This way we can detect the absence of // a nameserver without timing out. if (random_bind(statp->nssocks[*ns], nsap->sa_family) < 0) { *terrno = errno; dump_error("bind(dg)", nsap, nsaplen); statp->closeSockets(); return (0); } if (connect(statp->nssocks[*ns], nsap, (socklen_t)nsaplen) < 0) { if (connect(statp->udpsocks[*ns], nsap, sockaddrSize(nsap)) < 0) { *terrno = errno; dump_error("connect(dg)", nsap, nsaplen); dump_error("connect(dg)", nsap); statp->closeSockets(); return (0); } LOG(DEBUG) << __func__ << ": new DG socket"; } if (send(statp->nssocks[*ns], (const char*)buf, (size_t)buflen, 0) != buflen) { if (send(statp->udpsocks[*ns], (const char*)buf, (size_t)buflen, 0) != buflen) { *terrno = errno; PLOG(DEBUG) << __func__ << ": send: "; statp->closeSockets(); Loading Loading @@ -1150,7 +1161,7 @@ static int send_dg(res_state statp, res_params* params, const uint8_t* buf, int } } static void dump_error(const char* str, const struct sockaddr* address, int alen) { static void dump_error(const char* str, const struct sockaddr* address) { char hbuf[NI_MAXHOST]; char sbuf[NI_MAXSERV]; constexpr int niflags = NI_NUMERICHOST | NI_NUMERICSERV; Loading @@ -1158,7 +1169,8 @@ static void dump_error(const char* str, const struct sockaddr* address, int alen if (!WOULD_LOG(DEBUG)) return; if (getnameinfo(address, (socklen_t)alen, hbuf, sizeof(hbuf), sbuf, sizeof(sbuf), niflags)) { if (getnameinfo(address, sockaddrSize(address), hbuf, sizeof(hbuf), sbuf, sizeof(sbuf), niflags)) { strncpy(hbuf, "?", sizeof(hbuf) - 1); hbuf[sizeof(hbuf) - 1] = '\0'; strncpy(sbuf, "?", sizeof(sbuf) - 1); Loading resolv_private.h +2 −2 Original line number Diff line number Diff line Loading @@ -119,7 +119,7 @@ struct ResState { tcp_nssock.reset(); _flags &= ~RES_F_VC; for (auto& sock : nssocks) { for (auto& sock : udpsocks) { sock.reset(); } } Loading @@ -132,7 +132,7 @@ struct ResState { pid_t pid; // pid of the app that sent the DNS lookup std::vector<std::string> search_domains{}; // domains to search std::vector<android::netdutils::IPSockAddr> nsaddrs; android::base::unique_fd nssocks[MAXNS]; // UDP sockets to nameservers android::base::unique_fd udpsocks[MAXNS]; // UDP sockets to nameservers and mdns responsder unsigned ndots : 4 = 1; // threshold for initial abs. query unsigned _mark; // If non-0 SET_MARK to _mark on all request sockets android::base::unique_fd tcp_nssock; // TCP socket (but why not one per nameserver?) Loading Loading
DnsTlsTransport.cpp +189 −44 Original line number Diff line number Diff line Loading @@ -18,14 +18,24 @@ #include "DnsTlsTransport.h" #include <span> #include <android-base/format.h> #include <android-base/logging.h> #include <android-base/result.h> #include <android-base/stringprintf.h> #include <arpa/inet.h> #include <arpa/nameser.h> #include <netdutils/Stopwatch.h> #include <netdutils/ThreadUtil.h> #include <private/android_filesystem_config.h> // AID_DNS #include <sys/poll.h> #include "DnsTlsSocketFactory.h" #include "Experiments.h" #include "IDnsTlsSocketFactory.h" #include "resolv_private.h" #include "util.h" using android::base::StringPrintf; using android::netdutils::setThreadName; Loading @@ -33,6 +43,113 @@ using android::netdutils::setThreadName; namespace android { namespace net { namespace { // Make a DNS query for the hostname "<random>-dnsotls-ds.metric.gstatic.com". std::vector<uint8_t> makeDnsQuery() { static const char kDnsSafeChars[] = "abcdefhijklmnopqrstuvwxyz" "ABCDEFHIJKLMNOPQRSTUVWXYZ" "0123456789"; const auto c = [](uint8_t rnd) -> uint8_t { return kDnsSafeChars[(rnd % std::size(kDnsSafeChars))]; }; uint8_t rnd[8]; arc4random_buf(rnd, std::size(rnd)); return std::vector<uint8_t>{ rnd[6], rnd[7], // [0-1] query ID 1, 0, // [2-3] flags; query[2] = 1 for recursion desired (RD). 0, 1, // [4-5] QDCOUNT (number of queries) 0, 0, // [6-7] ANCOUNT (number of answers) 0, 0, // [8-9] NSCOUNT (number of name server records) 0, 0, // [10-11] ARCOUNT (number of additional records) 17, c(rnd[0]), c(rnd[1]), c(rnd[2]), c(rnd[3]), c(rnd[4]), c(rnd[5]), '-', 'd', 'n', 's', 'o', 't', 'l', 's', '-', 'd', 's', 6, 'm', 'e', 't', 'r', 'i', 'c', 7, 'g', 's', 't', 'a', 't', 'i', 'c', 3, 'c', 'o', 'm', 0, // null terminator of FQDN (root TLD) 0, ns_t_aaaa, // QTYPE 0, ns_c_in // QCLASS }; } base::Result<void> checkDnsResponse(const std::span<const uint8_t> answer) { if (answer.size() < NS_HFIXEDSZ) { return Errorf("short response: {}", answer.size()); } const int qdcount = (answer[4] << 8) | answer[5]; if (qdcount != 1) { return Errorf("reply query count != 1: {}", qdcount); } const int ancount = (answer[6] << 8) | answer[7]; LOG(DEBUG) << "answer count: " << ancount; // TODO: Further validate the response contents (check for valid AAAA record, ...). // Note that currently, integration tests rely on this function accepting a // response with zero records. return {}; } // Sends |query| to the given server, and returns the DNS response. base::Result<void> sendUdpQuery(netdutils::IPAddress ip, uint32_t mark, std::span<const uint8_t> query) { const sockaddr_storage ss = netdutils::IPSockAddr(ip, 53); const sockaddr* nsap = reinterpret_cast<const sockaddr*>(&ss); const int nsaplen = sockaddrSize(nsap); const int sockType = SOCK_DGRAM | SOCK_NONBLOCK | SOCK_CLOEXEC; android::base::unique_fd fd{socket(nsap->sa_family, sockType, 0)}; if (fd < 0) { return ErrnoErrorf("socket failed"); } resolv_tag_socket(fd.get(), AID_DNS, NET_CONTEXT_INVALID_PID); if (setsockopt(fd.get(), SOL_SOCKET, SO_MARK, &mark, sizeof(mark)) < 0) { return ErrnoErrorf("setsockopt failed"); } if (connect(fd.get(), nsap, (socklen_t)nsaplen) < 0) { return ErrnoErrorf("connect failed"); } if (send(fd, query.data(), query.size(), 0) != query.size()) { return ErrnoErrorf("send failed"); } const int timeoutMs = 3000; while (true) { pollfd fds = {.fd = fd, .events = POLLIN}; const int n = TEMP_FAILURE_RETRY(poll(&fds, 1, timeoutMs)); if (n == 0) { return Errorf("poll timed out"); } if (n < 0) { return ErrnoErrorf("poll failed"); } if (fds.revents & (POLLIN | POLLERR)) { std::vector<uint8_t> buf(MAXPACKET); const int resplen = recv(fd, buf.data(), buf.size(), 0); if (resplen < 0) { return ErrnoErrorf("recvfrom failed"); } buf.resize(resplen); if (auto result = checkDnsResponse(buf); !result.ok()) { return Errorf("checkDnsResponse failed: {}", result.error().message()); } return {}; } } } } // namespace std::future<DnsTlsTransport::Result> DnsTlsTransport::query(const netdutils::Slice query) { std::lock_guard guard(mLock); Loading Loading @@ -160,65 +277,93 @@ DnsTlsTransport::~DnsTlsTransport() { // That may require moving it to DnsTlsDispatcher. bool DnsTlsTransport::validate(const DnsTlsServer& server, uint32_t mark) { LOG(DEBUG) << "Beginning validation with mark " << std::hex << mark; // Generate "<random>-dnsotls-ds.metric.gstatic.com", which we will lookup through |ss| in // order to prove that it is actually a working DNS over TLS server. static const char kDnsSafeChars[] = "abcdefhijklmnopqrstuvwxyz" "ABCDEFHIJKLMNOPQRSTUVWXYZ" "0123456789"; const auto c = [](uint8_t rnd) -> uint8_t { return kDnsSafeChars[(rnd % std::size(kDnsSafeChars))]; }; uint8_t rnd[8]; arc4random_buf(rnd, std::size(rnd)); // We could try to use res_mkquery() here, but it's basically the same. uint8_t query[] = { rnd[6], rnd[7], // [0-1] query ID 1, 0, // [2-3] flags; query[2] = 1 for recursion desired (RD). 0, 1, // [4-5] QDCOUNT (number of queries) 0, 0, // [6-7] ANCOUNT (number of answers) 0, 0, // [8-9] NSCOUNT (number of name server records) 0, 0, // [10-11] ARCOUNT (number of additional records) 17, c(rnd[0]), c(rnd[1]), c(rnd[2]), c(rnd[3]), c(rnd[4]), c(rnd[5]), '-', 'd', 'n', 's', 'o', 't', 'l', 's', '-', 'd', 's', 6, 'm', 'e', 't', 'r', 'i', 'c', 7, 'g', 's', 't', 'a', 't', 'i', 'c', 3, 'c', 'o', 'm', 0, // null terminator of FQDN (root TLD) 0, ns_t_aaaa, // QTYPE 0, ns_c_in // QCLASS }; const int qlen = std::size(query); int replylen = 0; const std::vector<uint8_t> query = makeDnsQuery(); DnsTlsSocketFactory factory; DnsTlsTransport transport(server, mark, &factory); auto r = transport.query(netdutils::Slice(query, qlen)).get(); // Send the initial query to warm up the connection. auto r = transport.query(netdutils::makeSlice(query)).get(); if (r.code != Response::success) { LOG(WARNING) << "query failed"; return false; } const std::vector<uint8_t>& recvbuf = r.response; if (recvbuf.size() < NS_HFIXEDSZ) { LOG(WARNING) << "short response: " << replylen; if (auto result = checkDnsResponse(r.response); !result.ok()) { LOG(WARNING) << "checkDnsResponse failed: " << result.error().message(); return false; } const int qdcount = (recvbuf[4] << 8) | recvbuf[5]; if (qdcount != 1) { LOG(WARNING) << "reply query count != 1: " << qdcount; return false; // If this validation is not for opportunistic mode, or the flags are not properly set, // the validation is done. If not, the validation will compare DoT probe latency and // UDP probe latency, and it will pass if: // dot_probe_latency < latencyFactor * udp_probe_latency + latencyOffsetMs // // For instance, with latencyFactor = 3 and latencyOffsetMs = 10, if UDP probe latency is 5 ms, // DoT probe latency must less than 25 ms. int latencyFactor = Experiments::getInstance()->getFlag("dot_validation_latency_factor", -1); int latencyOffsetMs = Experiments::getInstance()->getFlag("dot_validation_latency_offset_ms", -1); const bool shouldCompareUdpLatency = server.name.empty() && (latencyFactor >= 0 && latencyOffsetMs >= 0 && latencyFactor + latencyOffsetMs != 0); if (!shouldCompareUdpLatency) { return true; } const int ancount = (recvbuf[6] << 8) | recvbuf[7]; LOG(DEBUG) << "answer count: " << ancount; LOG(INFO) << fmt::format("Use flags: latencyFactor={}, latencyOffsetMs={}", latencyFactor, latencyOffsetMs); // TODO: Further validate the response contents (check for valid AAAA record, ...). // Note that currently, integration tests rely on this function accepting a // response with zero records. int64_t udpProbeTimeUs = 0; bool udpProbeGotAnswer = false; std::thread udpProbeThread([&] { // Can issue another probe if the first one fails or is lost. for (int i = 1; i < 3; i++) { netdutils::Stopwatch stopwatch; auto result = sendUdpQuery(server.addr().ip(), mark, query); udpProbeTimeUs = stopwatch.timeTakenUs(); udpProbeGotAnswer = result.ok(); LOG(INFO) << fmt::format("UDP probe for {} {}, took {:.3f}ms", server.toIpString(), (udpProbeGotAnswer ? "succeeded" : "failed"), udpProbeTimeUs / 1000.0); return true; if (udpProbeGotAnswer) { break; } LOG(WARNING) << "sendUdpQuery attempt " << i << " failed: " << result.error().message(); } }); int64_t dotProbeTimeUs = 0; bool dotProbeGotAnswer = false; std::thread dotProbeThread([&] { netdutils::Stopwatch stopwatch; auto r = transport.query(netdutils::makeSlice(query)).get(); dotProbeTimeUs = stopwatch.timeTakenUs(); if (r.code != Response::success) { LOG(WARNING) << "query failed"; } else { if (auto result = checkDnsResponse(r.response); !result.ok()) { LOG(WARNING) << "checkDnsResponse failed: " << result.error().message(); } else { dotProbeGotAnswer = true; } } LOG(INFO) << fmt::format("DoT probe for {} {}, took {:.3f}ms", server.toIpString(), (dotProbeGotAnswer ? "succeeded" : "failed"), dotProbeTimeUs / 1000.0); }); // TODO: If DoT probe thread finishes before UDP probe thread and dotProbeGotAnswer is false, // actively cancel UDP probe thread. dotProbeThread.join(); udpProbeThread.join(); if (!dotProbeGotAnswer) return false; if (!udpProbeGotAnswer) return true; return dotProbeTimeUs < (latencyFactor * udpProbeTimeUs + latencyOffsetMs * 1000); } } // end of namespace net Loading
Experiments.h +11 −3 Original line number Diff line number Diff line Loading @@ -49,10 +49,18 @@ class Experiments { // TODO: Migrate other experiment flags to here. // (retry_count, retransmission_time_interval) static constexpr const char* const kExperimentFlagKeyList[] = { "keep_listening_udp", "parallel_lookup_release", "parallel_lookup_sleep_time", "sort_nameservers", "dot_async_handshake", "dot_connect_timeout_ms", "dot_maxtries", "dot_revalidation_threshold", "dot_xport_unusable_threshold", "keep_listening_udp", "parallel_lookup_release", "parallel_lookup_sleep_time", "sort_nameservers", "dot_async_handshake", "dot_connect_timeout_ms", "dot_maxtries", "dot_revalidation_threshold", "dot_xport_unusable_threshold", "dot_query_timeout_ms", "dot_validation_latency_factor", "dot_validation_latency_offset_ms", }; // This value is used in updateInternal as the default value if any flags can't be found. static constexpr int kFlagIntDefault = INT_MIN; Loading
PrivateDnsConfigurationTest.cpp +4 −0 Original line number Diff line number Diff line Loading @@ -35,6 +35,8 @@ class PrivateDnsConfigurationTest : public ::testing::Test { ASSERT_TRUE(tls1.startServer()); ASSERT_TRUE(tls2.startServer()); ASSERT_TRUE(backend.startServer()); ASSERT_TRUE(backend1ForUdpProbe.startServer()); ASSERT_TRUE(backend2ForUdpProbe.startServer()); } void SetUp() { Loading Loading @@ -132,6 +134,8 @@ class PrivateDnsConfigurationTest : public ::testing::Test { inline static test::DnsTlsFrontend tls1{kServer1, "853", kBackend, "53"}; inline static test::DnsTlsFrontend tls2{kServer2, "853", kBackend, "53"}; inline static test::DNSResponder backend{kBackend, "53"}; inline static test::DNSResponder backend1ForUdpProbe{kServer1, "53"}; inline static test::DNSResponder backend2ForUdpProbe{kServer2, "53"}; }; TEST_F(PrivateDnsConfigurationTest, ValidationSuccess) { Loading
res_send.cpp +54 −42 Original line number Diff line number Diff line Loading @@ -146,11 +146,12 @@ using android::netdutils::Stopwatch; static int send_vc(res_state statp, res_params* params, const uint8_t* buf, int buflen, uint8_t* ans, int anssiz, int* terrno, size_t ns, time_t* at, int* rcode, int* delay); static int setupUdpSocket(ResState* statp, const sockaddr* sockap, size_t addrIndex, int* terrno); static int send_dg(res_state statp, res_params* params, const uint8_t* buf, int buflen, uint8_t* ans, int anssiz, int* terrno, size_t* ns, int* v_circuit, int* gotsomewhere, time_t* at, int* rcode, int* delay); static void dump_error(const char*, const struct sockaddr*, int); static void dump_error(const char*, const struct sockaddr*); static int sock_eq(struct sockaddr*, struct sockaddr*); static int connect_with_timeout(int sock, const struct sockaddr* nsap, socklen_t salen, Loading Loading @@ -726,14 +727,14 @@ same_ns: errno = 0; if (random_bind(statp->tcp_nssock, nsap->sa_family) < 0) { *terrno = errno; dump_error("bind/vc", nsap, nsaplen); dump_error("bind/vc", nsap); statp->closeSockets(); return (0); } if (connect_with_timeout(statp->tcp_nssock, nsap, (socklen_t)nsaplen, get_timeout(statp, params, ns)) < 0) { *terrno = errno; dump_error("connect/vc", nsap, nsaplen); dump_error("connect/vc", nsap); statp->closeSockets(); /* * The way connect_with_timeout() is implemented prevents us from reliably Loading Loading @@ -932,7 +933,7 @@ retry: static std::vector<pollfd> extractUdpFdset(res_state statp, const short events = POLLIN) { std::vector<pollfd> fdset(statp->nsaddrs.size()); for (size_t i = 0; i < statp->nsaddrs.size(); ++i) { fdset[i] = {.fd = statp->nssocks[i], .events = events}; fdset[i] = {.fd = statp->udpsocks[i], .events = events}; } return fdset; } Loading Loading @@ -969,10 +970,10 @@ static Result<std::vector<int>> udpRetryingPollWrapper(res_state statp, int ns, android::net::Experiments::getInstance()->getFlag("keep_listening_udp", 0); if (keepListeningUdp) return udpRetryingPoll(statp, finish); if (int n = retrying_poll(statp->nssocks[ns], POLLIN, finish); n <= 0) { if (int n = retrying_poll(statp->udpsocks[ns], POLLIN, finish); n <= 0) { return ErrnoError(); } return std::vector<int>{statp->nssocks[ns]}; return std::vector<int>{statp->udpsocks[ns]}; } bool ignoreInvalidAnswer(res_state statp, const sockaddr_storage& from, const uint8_t* buf, Loading @@ -997,66 +998,76 @@ bool ignoreInvalidAnswer(res_state statp, const sockaddr_storage& from, const ui return false; } static int send_dg(res_state statp, res_params* params, const uint8_t* buf, int buflen, uint8_t* ans, int anssiz, int* terrno, size_t* ns, int* v_circuit, int* gotsomewhere, time_t* at, int* rcode, int* delay) { // It should never happen, but just in case. if (*ns >= statp->nsaddrs.size()) { LOG(ERROR) << __func__ << ": Out-of-bound indexing: " << ns; *terrno = EINVAL; return -1; } // return 1 when setup udp socket success. // return 0 when timeout , bind error, network error(ex: Protocol not supported ...). // return -1 when create socket fail, set socket option fail. static int setupUdpSocket(ResState* statp, const sockaddr* sockap, size_t addrIndex, int* terrno) { statp->udpsocks[addrIndex].reset(socket(sockap->sa_family, SOCK_DGRAM | SOCK_CLOEXEC, 0)); *at = time(nullptr); *delay = 0; const sockaddr_storage ss = statp->nsaddrs[*ns]; const sockaddr* nsap = reinterpret_cast<const sockaddr*>(&ss); const int nsaplen = sockaddrSize(nsap); if (statp->nssocks[*ns] == -1) { statp->nssocks[*ns].reset(socket(nsap->sa_family, SOCK_DGRAM | SOCK_CLOEXEC, 0)); if (statp->nssocks[*ns] < 0) { if (statp->udpsocks[addrIndex] < 0) { *terrno = errno; PLOG(DEBUG) << __func__ << ": socket(dg): "; PLOG(ERROR) << __func__ << ": socket: "; switch (errno) { case EPROTONOSUPPORT: case EPFNOSUPPORT: case EAFNOSUPPORT: return (0); return 0; default: return (-1); return -1; } } const uid_t uid = statp->enforce_dns_uid ? AID_DNS : statp->uid; resolv_tag_socket(statp->nssocks[*ns], uid, statp->pid); resolv_tag_socket(statp->udpsocks[addrIndex], uid, statp->pid); if (statp->_mark != MARK_UNSET) { if (setsockopt(statp->nssocks[*ns], SOL_SOCKET, SO_MARK, &(statp->_mark), if (setsockopt(statp->udpsocks[addrIndex], SOL_SOCKET, SO_MARK, &(statp->_mark), sizeof(statp->_mark)) < 0) { *terrno = errno; statp->closeSockets(); return -1; } } if (random_bind(statp->udpsocks[addrIndex], sockap->sa_family) < 0) { *terrno = errno; dump_error("bind", sockap); statp->closeSockets(); return 0; } return 1; } static int send_dg(res_state statp, res_params* params, const uint8_t* buf, int buflen, uint8_t* ans, int anssiz, int* terrno, size_t* ns, int* v_circuit, int* gotsomewhere, time_t* at, int* rcode, int* delay) { // It should never happen, but just in case. if (*ns >= statp->nsaddrs.size()) { LOG(ERROR) << __func__ << ": Out-of-bound indexing: " << ns; *terrno = EINVAL; return -1; } *at = time(nullptr); *delay = 0; const sockaddr_storage ss = statp->nsaddrs[*ns]; const sockaddr* nsap = reinterpret_cast<const sockaddr*>(&ss); if (statp->udpsocks[*ns] == -1) { int result = setupUdpSocket(statp, nsap, *ns, terrno); if (result <= 0) return result; // Use a "connected" datagram socket to receive an ECONNREFUSED error // on the next socket operation when the server responds with an // ICMP port-unreachable error. This way we can detect the absence of // a nameserver without timing out. if (random_bind(statp->nssocks[*ns], nsap->sa_family) < 0) { *terrno = errno; dump_error("bind(dg)", nsap, nsaplen); statp->closeSockets(); return (0); } if (connect(statp->nssocks[*ns], nsap, (socklen_t)nsaplen) < 0) { if (connect(statp->udpsocks[*ns], nsap, sockaddrSize(nsap)) < 0) { *terrno = errno; dump_error("connect(dg)", nsap, nsaplen); dump_error("connect(dg)", nsap); statp->closeSockets(); return (0); } LOG(DEBUG) << __func__ << ": new DG socket"; } if (send(statp->nssocks[*ns], (const char*)buf, (size_t)buflen, 0) != buflen) { if (send(statp->udpsocks[*ns], (const char*)buf, (size_t)buflen, 0) != buflen) { *terrno = errno; PLOG(DEBUG) << __func__ << ": send: "; statp->closeSockets(); Loading Loading @@ -1150,7 +1161,7 @@ static int send_dg(res_state statp, res_params* params, const uint8_t* buf, int } } static void dump_error(const char* str, const struct sockaddr* address, int alen) { static void dump_error(const char* str, const struct sockaddr* address) { char hbuf[NI_MAXHOST]; char sbuf[NI_MAXSERV]; constexpr int niflags = NI_NUMERICHOST | NI_NUMERICSERV; Loading @@ -1158,7 +1169,8 @@ static void dump_error(const char* str, const struct sockaddr* address, int alen if (!WOULD_LOG(DEBUG)) return; if (getnameinfo(address, (socklen_t)alen, hbuf, sizeof(hbuf), sbuf, sizeof(sbuf), niflags)) { if (getnameinfo(address, sockaddrSize(address), hbuf, sizeof(hbuf), sbuf, sizeof(sbuf), niflags)) { strncpy(hbuf, "?", sizeof(hbuf) - 1); hbuf[sizeof(hbuf) - 1] = '\0'; strncpy(sbuf, "?", sizeof(sbuf) - 1); Loading
resolv_private.h +2 −2 Original line number Diff line number Diff line Loading @@ -119,7 +119,7 @@ struct ResState { tcp_nssock.reset(); _flags &= ~RES_F_VC; for (auto& sock : nssocks) { for (auto& sock : udpsocks) { sock.reset(); } } Loading @@ -132,7 +132,7 @@ struct ResState { pid_t pid; // pid of the app that sent the DNS lookup std::vector<std::string> search_domains{}; // domains to search std::vector<android::netdutils::IPSockAddr> nsaddrs; android::base::unique_fd nssocks[MAXNS]; // UDP sockets to nameservers android::base::unique_fd udpsocks[MAXNS]; // UDP sockets to nameservers and mdns responsder unsigned ndots : 4 = 1; // threshold for initial abs. query unsigned _mark; // If non-0 SET_MARK to _mark on all request sockets android::base::unique_fd tcp_nssock; // TCP socket (but why not one per nameserver?) Loading