Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 420ee62b authored by Luke Huang
Browse files

Keep listening for UDP responses after sending the next retry

Bug: 135717624
Test: atest
Change-Id: I3c970b638fb763ea209a62624f23c5ea85370fe9
parent 44e38183
Loading
Loading
Loading
Loading
+163 −101
Original line number Diff line number Diff line
@@ -95,6 +95,7 @@
#include <unistd.h>

#include <android-base/logging.h>
#include <android-base/result.h>
#include <android/multinetwork.h>  // ResNsendFlags

#include <netdutils/Slice.h>
@@ -114,6 +115,8 @@
#include "util.h"

// TODO: use the namespace something like android::netd_resolv for libnetd_resolv
using android::base::ErrnoError;
using android::base::Result;
using android::net::CacheStatus;
using android::net::DnsQueryEvent;
using android::net::DnsTlsDispatcher;
@@ -142,7 +145,7 @@ static int send_vc(res_state statp, res_params* params, const uint8_t* buf, int
                   uint8_t* ans, int anssiz, int* terrno, size_t ns, time_t* at, int* rcode,
                   int* delay);
static int send_dg(res_state statp, res_params* params, const uint8_t* buf, int buflen,
                   uint8_t* ans, int anssiz, int* terrno, size_t ns, int* v_circuit,
                   uint8_t* ans, int anssiz, int* terrno, size_t* ns, int* v_circuit,
                   int* gotsomewhere, time_t* at, int* rcode, int* delay);

static void dump_error(const char*, const struct sockaddr*, int);
@@ -286,21 +289,23 @@ static void res_set_usable_server(int selectedServer, int nscount, bool usable_s
    }
}

// Looks up the nameserver address in res.nsaddrs[], returns true if found, otherwise false.
static bool res_ourserver_p(res_state statp, const sockaddr* sa) {
// Looks up the nameserver address in res.nsaddrs[], returns the ns number if found, otherwise -1.
static int res_ourserver_p(res_state statp, const sockaddr* sa) {
    const sockaddr_in *inp, *srv;
    const sockaddr_in6 *in6p, *srv6;

    int ns = 0;
    switch (sa->sa_family) {
        case AF_INET:
            inp = (const struct sockaddr_in*) (const void*) sa;

            for (const IPSockAddr& ipsa : statp->nsaddrs) {
                sockaddr_storage ss = ipsa;
                srv = reinterpret_cast<sockaddr_in*>(&ss);
                if (srv->sin_family == inp->sin_family && srv->sin_port == inp->sin_port &&
                    (srv->sin_addr.s_addr == INADDR_ANY ||
                     srv->sin_addr.s_addr == inp->sin_addr.s_addr))
                    return true;
                    return ns;
                ++ns;
            }
            break;
        case AF_INET6:
@@ -314,13 +319,14 @@ static bool res_ourserver_p(res_state statp, const sockaddr* sa) {
#endif
                    (IN6_IS_ADDR_UNSPECIFIED(&srv6->sin6_addr) ||
                     IN6_ARE_ADDR_EQUAL(&srv6->sin6_addr, &in6p->sin6_addr)))
                    return true;
                    return ns;
                ++ns;
            }
            break;
        default:
            break;
    }
    return false;
    return -1;
}

/* int
@@ -498,18 +504,19 @@ int res_nsend(res_state statp, const uint8_t* buf, int buflen, uint8_t* ans, int
                       << ") address = " << serverSockAddr.toString();

            ::android::net::Protocol query_proto = useTcp ? PROTO_TCP : PROTO_UDP;
            time_t now = 0;
            time_t query_time = 0;
            int delay = 0;
            bool fallbackTCP = false;
            const bool shouldRecordStats = (attempt == 0);
            int resplen;
            Stopwatch queryStopwatch;
            int retry_count_for_event = 0;
            size_t actualNs = ns;
            if (useTcp) {
                // TCP; at most one attempt per server.
                attempt = retryTimes;
                resplen = send_vc(statp, &params, buf, buflen, ans, anssiz, &terrno, ns, &now,
                                  rcode, &delay);
                resplen = send_vc(statp, &params, buf, buflen, ans, anssiz, &terrno, ns,
                                  &query_time, rcode, &delay);

                if (buflen <= PACKETSZ && resplen <= 0 &&
                    statp->tc_mode == aidl::android::net::IDnsResolver::TC_MODE_UDP_TCP) {
@@ -520,18 +527,24 @@ int res_nsend(res_state statp, const uint8_t* buf, int buflen, uint8_t* ans, int
                LOG(INFO) << __func__ << ": used send_vc " << resplen;
            } else {
                // UDP
                resplen = send_dg(statp, &params, buf, buflen, ans, anssiz, &terrno, ns, &useTcp,
                                  &gotsomewhere, &now, rcode, &delay);
                resplen = send_dg(statp, &params, buf, buflen, ans, anssiz, &terrno, &actualNs,
                                  &useTcp, &gotsomewhere, &query_time, rcode, &delay);
                fallbackTCP = useTcp ? true : false;
                retry_count_for_event = attempt;
                LOG(INFO) << __func__ << ": used send_dg " << resplen;
            }

            const IPSockAddr& receivedServerAddr = statp->nsaddrs[actualNs];
            DnsQueryEvent* dnsQueryEvent = addDnsQueryEvent(statp->event);
            dnsQueryEvent->set_cache_hit(static_cast<CacheStatus>(cache_status));
            dnsQueryEvent->set_latency_micros(saturate_cast<int32_t>(queryStopwatch.timeTakenUs()));
            dnsQueryEvent->set_dns_server_index(ns);
            dnsQueryEvent->set_ip_version(ipFamilyToIPVersion(serverSockAddr.family()));
            // When |retryTimes| > 1, we cannot actually know the correct latency value if we
            // received the answer from the previous server. So temporarily set the latency as -1 if
            // that condition happened.
            // TODO: make the latency value accurate.
            dnsQueryEvent->set_latency_micros(
                    (actualNs == ns) ? saturate_cast<int32_t>(queryStopwatch.timeTakenUs()) : -1);
            dnsQueryEvent->set_dns_server_index(actualNs);
            dnsQueryEvent->set_ip_version(ipFamilyToIPVersion(receivedServerAddr.family()));
            dnsQueryEvent->set_retry_times(retry_count_for_event);
            dnsQueryEvent->set_rcode(static_cast<NsRcode>(*rcode));
            dnsQueryEvent->set_protocol(query_proto);
@@ -542,10 +555,13 @@ int res_nsend(res_state statp, const uint8_t* buf, int buflen, uint8_t* ans, int
            // SERVFAIL or times out) do not unduly affect the stats.
            if (shouldRecordStats) {
                res_sample sample;
                res_stats_set_sample(&sample, now, *rcode, delay);
                res_stats_set_sample(&sample, query_time, *rcode, delay);
                // KeepListening UDP mechanism is incompatible with usable_servers of legacy stats,
                // so keep the old logic for now.
                // TODO: Replace usable_servers of legacy stats with new one.
                resolv_cache_add_resolver_stats_sample(statp->netid, revision_id, serverSockAddr,
                                                       sample, params.max_samples);
                resolv_stats_add(statp->netid, serverSockAddr, dnsQueryEvent);
                resolv_stats_add(statp->netid, receivedServerAddr, dnsQueryEvent);
            }

            if (resplen == 0) continue;
@@ -629,7 +645,7 @@ static int send_vc(res_state statp, res_params* params, const uint8_t* buf, int
same_ns:
    truncating = 0;

    struct timespec now = evNowTime();
    struct timespec start_time = evNowTime();

    /* Are we still talking to whom we want to talk to? */
    if (statp->tcp_nssock >= 0 && (statp->_flags & RES_F_VC) != 0) {
@@ -803,7 +819,7 @@ read_len:
     */
    if (resplen > 0) {
        struct timespec done = evNowTime();
        *delay = res_stats_calculate_rtt(&done, &now);
        *delay = res_stats_calculate_rtt(&done, &start_time);
        *rcode = anhp->rcode;
    }
    return (resplen);
@@ -873,8 +889,53 @@ retry:
    return n;
}

// Builds the pollfd set used to wait on the UDP sockets of every configured nameserver.
//
// One entry is produced per element of |statp->nsaddrs|, in index order. Entry i uses the
// descriptor stored in |statp->nssocks[i]| and the requested |events| mask; |revents| starts
// out zeroed.
// NOTE(review): slots whose socket has not been opened yet presumably hold -1, which poll()
// skips -- confirm against the nssocks definition.
static std::vector<pollfd> extractUdpFdset(res_state statp, const short events = POLLIN) {
    const size_t serverCount = statp->nsaddrs.size();
    std::vector<pollfd> pollSet;
    pollSet.reserve(serverCount);
    for (size_t idx = 0; idx < serverCount; ++idx) {
        const pollfd entry = {.fd = statp->nssocks[idx], .events = events};
        pollSet.push_back(entry);
    }
    return pollSet;
}

// Waits until the deadline |finish| for activity on any of the UDP nameserver sockets.
//
// The wait is restarted (with a recomputed timeout) whenever ppoll() is interrupted by a
// signal. Returns:
//   - the descriptors that reported POLLIN or POLLERR, when ppoll() reports ready sockets;
//   - an error carrying ETIMEDOUT when the deadline passes with nothing ready;
//   - an error carrying the ppoll() errno for any other failure.
static Result<std::vector<int>> udpRetryingPoll(res_state statp, const timespec* finish) {
    while (true) {
        LOG(DEBUG) << __func__ << ": poll";

        // Time left until the deadline, clamped to zero once the deadline has passed.
        const timespec now = evNowTime();
        timespec remaining;
        if (evCmpTime(*finish, now) > 0) {
            remaining = evSubTime(*finish, now);
        } else {
            remaining = evConsTime(0L, 0L);
        }

        std::vector<pollfd> watched = extractUdpFdset(statp);
        const int ready = ppoll(watched.data(), watched.size(), &remaining, /*sigmask=*/nullptr);

        // Interrupted by a signal: simply wait again.
        if (ready < 0 && errno == EINTR) continue;

        if (ready <= 0) {
            // ppoll() does not set errno on timeout, so report it explicitly.
            if (ready == 0) errno = ETIMEDOUT;
            PLOG(INFO) << __func__ << ": failed";
            return ErrnoError();
        }

        // Sockets with a pending error are returned as well so the caller's read surfaces it.
        std::vector<int> readable;
        for (const auto& entry : watched) {
            if ((entry.revents & (POLLIN | POLLERR)) != 0) {
                readable.push_back(entry.fd);
            }
        }
        LOG(DEBUG) << __func__ << ": "
                   << " returning fd size: " << readable.size();
        return readable;
    }
}

// Chooses how to wait for a UDP reply, depending on the "keep_listening_udp" experiment flag.
//
// Flag non-zero: wait on the sockets of all nameservers via udpRetryingPoll().
// Flag zero:     wait only on the socket of nameserver |ns| via retrying_poll(); on success the
//                returned list holds just that one descriptor.
// A non-positive result from retrying_poll() is turned into an error built from errno.
// NOTE(review): callers tell timeouts apart by checking for ETIMEDOUT, so this presumably
// relies on retrying_poll() setting errno on timeout -- confirm.
static Result<std::vector<int>> udpRetryingPollWrapper(res_state statp, int ns,
                                                       const timespec* finish) {
    if (getExperimentFlagInt("keep_listening_udp", 0)) {
        return udpRetryingPoll(statp, finish);
    }

    const int ready = retrying_poll(statp->nssocks[ns], POLLIN, finish);
    if (ready <= 0) {
        return ErrnoError();
    }
    return std::vector<int>{statp->nssocks[ns]};
}

bool ignoreInvalidAnswer(res_state statp, const sockaddr_storage& from, const uint8_t* buf,
                         int buflen, uint8_t* ans, int anssiz) {
                         int buflen, uint8_t* ans, int anssiz, int* receivedFromNs) {
    const HEADER* hp = (const HEADER*)(const void*)buf;
    HEADER* anhp = (HEADER*)(void*)ans;
    if (hp->id != anhp->id) {
@@ -882,7 +943,7 @@ bool ignoreInvalidAnswer(res_state statp, const sockaddr_storage& from, const ui
        LOG(DEBUG) << __func__ << ": old answer:";
        return true;
    }
    if (!res_ourserver_p(statp, (sockaddr*)(void*)&from)) {
    if (*receivedFromNs = res_ourserver_p(statp, (sockaddr*)(void*)&from); *receivedFromNs < 0) {
        // response from wrong server? ignore it.
        LOG(DEBUG) << __func__ << ": not our server:";
        return true;
@@ -896,23 +957,23 @@ bool ignoreInvalidAnswer(res_state statp, const sockaddr_storage& from, const ui
}

static int send_dg(res_state statp, res_params* params, const uint8_t* buf, int buflen,
                   uint8_t* ans, int anssiz, int* terrno, size_t ns, int* v_circuit,
                   uint8_t* ans, int anssiz, int* terrno, size_t* ns, int* v_circuit,
                   int* gotsomewhere, time_t* at, int* rcode, int* delay) {
    // It should never happen, but just in case.
    if (ns >= statp->nsaddrs.size()) {
    if (*ns >= statp->nsaddrs.size()) {
        LOG(ERROR) << __func__ << ": Out-of-bound indexing: " << ns;
        return -1;
    }

    *at = time(nullptr);
    *delay = 0;
    const sockaddr_storage ss = statp->nsaddrs[ns];
    const sockaddr_storage ss = statp->nsaddrs[*ns];
    const sockaddr* nsap = reinterpret_cast<const sockaddr*>(&ss);
    const int nsaplen = sockaddrSize(nsap);

    if (statp->nssocks[ns] == -1) {
        statp->nssocks[ns].reset(socket(nsap->sa_family, SOCK_DGRAM | SOCK_CLOEXEC, 0));
        if (statp->nssocks[ns] < 0) {
    if (statp->nssocks[*ns] == -1) {
        statp->nssocks[*ns].reset(socket(nsap->sa_family, SOCK_DGRAM | SOCK_CLOEXEC, 0));
        if (statp->nssocks[*ns] < 0) {
            switch (errno) {
                case EPROTONOSUPPORT:
                case EPFNOSUPPORT:
@@ -926,9 +987,9 @@ static int send_dg(res_state statp, res_params* params, const uint8_t* buf, int
            }
        }

        resolv_tag_socket(statp->nssocks[ns], statp->uid, statp->pid);
        resolv_tag_socket(statp->nssocks[*ns], statp->uid, statp->pid);
        if (statp->_mark != MARK_UNSET) {
            if (setsockopt(statp->nssocks[ns], SOL_SOCKET, SO_MARK, &(statp->_mark),
            if (setsockopt(statp->nssocks[*ns], SOL_SOCKET, SO_MARK, &(statp->_mark),
                           sizeof(statp->_mark)) < 0) {
                statp->closeSockets();
                return -1;
@@ -938,62 +999,64 @@ static int send_dg(res_state statp, res_params* params, const uint8_t* buf, int
        // on the next socket operation when the server responds with an
        // ICMP port-unreachable error. This way we can detect the absence of
        // a nameserver without timing out.
        if (random_bind(statp->nssocks[ns], nsap->sa_family) < 0) {
        if (random_bind(statp->nssocks[*ns], nsap->sa_family) < 0) {
            dump_error("bind(dg)", nsap, nsaplen);
            statp->closeSockets();
            return (0);
        }
        if (connect(statp->nssocks[ns], nsap, (socklen_t)nsaplen) < 0) {
        if (connect(statp->nssocks[*ns], nsap, (socklen_t)nsaplen) < 0) {
            dump_error("connect(dg)", nsap, nsaplen);
            statp->closeSockets();
            return (0);
        }
        LOG(DEBUG) << __func__ << ": new DG socket";
    }
    if (send(statp->nssocks[ns], (const char*)buf, (size_t)buflen, 0) != buflen) {
    if (send(statp->nssocks[*ns], (const char*)buf, (size_t)buflen, 0) != buflen) {
        PLOG(DEBUG) << __func__ << ": send: ";
        statp->closeSockets();
        return 0;
    }

    timespec timeout = get_timeout(statp, params, ns);
    timespec now = evNowTime();
    timespec finish = evAddTime(now, timeout);
    timespec timeout = get_timeout(statp, params, *ns);
    timespec start_time = evNowTime();
    timespec finish = evAddTime(start_time, timeout);
    for (;;) {
        // Wait for reply.
        int n = retrying_poll(statp->nssocks[ns], POLLIN, &finish);
        if (n == 0) {
            *rcode = RCODE_TIMEOUT;
            LOG(DEBUG) << __func__ << ": timeout";
            *gotsomewhere = 1;
        auto result = udpRetryingPollWrapper(statp, *ns, &finish);

        if (!result.has_value()) {
            const bool isTimeout = (result.error().code() == ETIMEDOUT);
            *rcode = (isTimeout) ? RCODE_TIMEOUT : *rcode;
            *gotsomewhere = (isTimeout) ? 1 : *gotsomewhere;
            // Leave the UDP sockets open on timeout so we can keep listening for
            // a late response from this server while retrying on the next server.
            if (!isTimeout) statp->closeSockets();
            LOG(DEBUG) << __func__ << ": " << (isTimeout) ? "timeout" : "poll";
            return 0;
        }
        if (n < 0) {
            PLOG(DEBUG) << __func__ << ": poll: ";
            statp->closeSockets();
            return 0;
        }

        errno = 0;
        bool needRetry = false;
        for (int fd : result.value()) {
            needRetry = false;
            sockaddr_storage from;
            socklen_t fromlen = sizeof(from);
        int resplen = recvfrom(statp->nssocks[ns], (char*)ans, (size_t)anssiz, 0,
                               (sockaddr*)(void*)&from, &fromlen);
            int resplen =
                    recvfrom(fd, (char*)ans, (size_t)anssiz, 0, (sockaddr*)(void*)&from, &fromlen);
            if (resplen <= 0) {
                PLOG(DEBUG) << __func__ << ": recvfrom: ";
            statp->closeSockets();
            return 0;
                continue;
            }
            *gotsomewhere = 1;
            if (resplen < HFIXEDSZ) {
                // Undersized message.
                LOG(DEBUG) << __func__ << ": undersized: " << resplen;
                *terrno = EMSGSIZE;
            statp->closeSockets();
            return 0;
                continue;
            }

        if (ignoreInvalidAnswer(statp, from, buf, buflen, ans, anssiz)) {
            int receivedFromNs = *ns;
            if (needRetry =
                        ignoreInvalidAnswer(statp, from, buf, buflen, ans, anssiz, &receivedFromNs);
                needRetry) {
                res_pquery(ans, (resplen > anssiz) ? anssiz : resplen);
                continue;
            }
@@ -1007,34 +1070,33 @@ static int send_dg(res_state statp, res_params* params, const uint8_t* buf, int
                res_pquery(ans, (resplen > anssiz) ? anssiz : resplen);
                // record the error
                statp->_flags |= RES_F_EDNS0ERR;
            statp->closeSockets();
            return 0;
                continue;
            }

            timespec done = evNowTime();
        *delay = res_stats_calculate_rtt(&done, &now);
            *delay = res_stats_calculate_rtt(&done, &start_time);
            if (anhp->rcode == SERVFAIL || anhp->rcode == NOTIMP || anhp->rcode == REFUSED) {
                LOG(DEBUG) << __func__ << ": server rejected query:";
                res_pquery(ans, (resplen > anssiz) ? anssiz : resplen);
            statp->closeSockets();
                *rcode = anhp->rcode;
            return 0;
                continue;
            }
            if (anhp->tc) {
                // To get the rest of answer,
                // use TCP with same server.
                LOG(DEBUG) << __func__ << ": truncated answer";
                *v_circuit = 1;
            statp->closeSockets();
                return 1;
            }
            // All is well, or the error is fatal. Signal that the
            // next nameserver ought not be tried.
        if (resplen > 0) {

            *rcode = anhp->rcode;
        }
            *ns = receivedFromNs;
            return resplen;
        }
        if (!needRetry) return 0;
    }
}

static void dump_error(const char* str, const struct sockaddr* address, int alen) {
+7 −1
Original line number Diff line number Diff line
@@ -27,9 +27,10 @@
#include <sys/socket.h>
#include <sys/types.h>
#include <unistd.h>
#include <set>

#include <chrono>
#include <iostream>
#include <set>
#include <vector>

#define LOG_TAG "DNSResponder"
@@ -515,6 +516,10 @@ void DNSResponder::setResponseProbability(double response_probability) {
    setResponseProbability(response_probability, IPPROTO_UDP);
}

// Stores |timeMs| (milliseconds) as the delay kept in |response_delayed_ms_|.
void DNSResponder::setResponseDelayMs(unsigned timeMs) {
    response_delayed_ms_.store(timeMs);
}

// Set response probability on specific protocol. It's caller's duty to ensure that the |protocol|
// can be supported by DNSResponder.
void DNSResponder::setResponseProbability(double response_probability, int protocol) {
@@ -1102,6 +1107,7 @@ void DNSResponder::handleQuery(int protocol) {
    size_t response_len = sizeof(response);
    // TODO: check whether sending malformed packets to DnsResponder
    if (handleDNSRequest(buffer, len, protocol, response, &response_len) && response_len > 0) {
        std::this_thread::sleep_for(std::chrono::milliseconds(response_delayed_ms_));
        // place wait_for after handleDNSRequest() so we can check the number of queries in
        // test case before it got responded.
        std::unique_lock guard(cv_mutex_for_deferred_resp_);
+3 −0
Original line number Diff line number Diff line
@@ -174,6 +174,7 @@ class DNSResponder {

    void setResponseProbability(double response_probability);
    void setResponseProbability(double response_probability, int protocol);
    void setResponseDelayMs(unsigned);
    void setEdns(Edns edns);
    void setTtl(unsigned ttl);
    bool running() const;
@@ -296,6 +297,8 @@ class DNSResponder {

    std::atomic<unsigned> answer_record_ttl_sec_ = kAnswerRecordTtlSec;

    // Time, in milliseconds, that handleQuery() sleeps after a request has been handled
    // successfully. Set through setResponseDelayMs(); the default of 0 means no delay.
    std::atomic<unsigned> response_delayed_ms_ = 0;

    // Maximum number of fds for epoll.
    const int EPOLL_MAX_EVENTS = 2;

+40 −0
Original line number Diff line number Diff line
@@ -4715,3 +4715,43 @@ TEST_P(ResolverParameterizedTest, TruncatedResponse) {
    EXPECT_EQ(1U, GetNumQueriesForProtocol(dns, IPPROTO_UDP, kHelloExampleCom));
    EXPECT_EQ(1U, GetNumQueriesForProtocol(dns, IPPROTO_TCP, kHelloExampleCom));
}

// Checks that, with the keep_listening_udp flag enabled, a lookup still succeeds when the only
// answer arrives from a server after the resolver's timeout for that server has expired.
TEST_F(ResolverTest, KeepListeningUDP) {
    constexpr char listen_addr1[] = "127.0.0.4";
    constexpr char listen_addr2[] = "127.0.0.5";
    constexpr char host_name[] = "howdy.example.com.";
    const std::vector<DnsRecord> records = {
            {host_name, ns_type::ns_t_aaaa, "::1.2.3.4"},
    };
    // Resolver parameters; the last two entries are the ones this test depends on.
    const std::vector<int> params = {300, 25, 8, 8, 1000 /* BASE_TIMEOUT_MSEC */,
                                     1 /* retry count */};
    // Deliberately longer than BASE_TIMEOUT_MSEC so the first attempt times out.
    const int delayTimeMs = 1500;

    // Server on |listen_addr2|: response probability 0, i.e. it is meant to stay silent.
    test::DNSResponder neverRespondDns(listen_addr2, "53", static_cast<ns_rcode>(-1));
    neverRespondDns.setResponseProbability(0.0);
    StartDns(neverRespondDns, records);

    ASSERT_TRUE(mDnsClient.SetResolversForNetwork({listen_addr1, listen_addr2},
                                                  kDefaultSearchDomains, params));
    // Two DNS servers take part in this test:
    //   - |delayedDns| (listen_addr1) waits |delayTimeMs| before answering each request.
    //   - |neverRespondDns| (listen_addr2) never answers.
    // First attempt: the resolver queries |delayedDns| and times out, because
    // |delayTimeMs| is larger than the DNS timeout.
    // Second attempt: the resolver queries |neverRespondDns| while listening on the sockets of
    // both servers, and receives the late answer from |delayedDns|.
    const std::string udpKeepListeningFlag("persist.device_config.netd_native.keep_listening_udp");

    // Enable the experiment flag through its system property for the rest of this test.
    ScopedSystemProperties scopedSystemProperties(udpKeepListeningFlag, "1");
    test::DNSResponder delayedDns(listen_addr1);
    delayedDns.setResponseDelayMs(delayTimeMs);
    StartDns(delayedDns, records);

    // Restrict the lookup to AF_INET6 so that the resolver performs only one round of queries.
    const addrinfo hints = {.ai_family = AF_INET6, .ai_socktype = SOCK_DGRAM};
    ScopedAddrinfo result = safe_getaddrinfo(host_name, nullptr, &hints);
    EXPECT_TRUE(result != nullptr);

    // The answer must be the AAAA record served by |delayedDns|.
    std::string result_str = ToString(result);
    EXPECT_TRUE(result_str == "::1.2.3.4") << ", result_str='" << result_str << "'";
}