Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit e48f7b54 authored by Mike Yu's avatar Mike Yu
Browse files

Refactor ResState to store nameserver addresses by IPSockAddr

IPSockAddr is more safer and convenient to store socket addresses,
to compare two socket addresses, and to make the code more readable.

The change also removes get_nsaddr(), a static function in res_send.cpp.

Bug: 137169582
Test: cd packages/modules/DnsResolver && atest
Change-Id: I694c293139b01a39c40cc50ba8c4f067a2ac4b07
parent 15791834
Loading
Loading
Loading
Loading
+6 −19
Original line number Original line Diff line number Diff line
@@ -1671,21 +1671,8 @@ void resolv_populate_res_for_net(ResState* statp) {
    NetConfig* info = find_netconfig_locked(statp->netid);
    NetConfig* info = find_netconfig_locked(statp->netid);
    if (info == nullptr) return;
    if (info == nullptr) return;


    // TODO: Convert nsaddrs[] to c++ container and remove the size-checking.
    statp->nscount = static_cast<int>(info->nameserverSockAddrs.size());
    const int serverNum = std::min(MAXNS, static_cast<int>(info->nameserverSockAddrs.size()));
    statp->nsaddrs = info->nameserverSockAddrs;

    for (int nserv = 0; nserv < serverNum; nserv++) {
        sockaddr_storage ss = info->nameserverSockAddrs.at(nserv);

        if (auto sockaddr_len = sockaddrSize(ss); sockaddr_len != 0) {
            memcpy(&statp->nsaddrs[nserv], &ss, sockaddr_len);
        } else {
            LOG(WARNING) << __func__ << ": can't get sa_len from "
                         << info->nameserverSockAddrs.at(nserv);
        }
    }

    statp->nscount = serverNum;
    statp->search_domains = info->search_domains;
    statp->search_domains = info->search_domains;
    statp->tc_mode = info->tc_mode;
    statp->tc_mode = info->tc_mode;
}
}
@@ -1811,18 +1798,18 @@ int resolv_cache_get_resolver_stats(unsigned netid, res_params* params, res_stat
    return -1;
    return -1;
}
}


void resolv_cache_add_resolver_stats_sample(unsigned netid, int revision_id, const sockaddr* sa,
void resolv_cache_add_resolver_stats_sample(unsigned netid, int revision_id,
                                            const IPSockAddr& serverSockAddr,
                                            const res_sample& sample, int max_samples) {
                                            const res_sample& sample, int max_samples) {
    if (max_samples <= 0 || sa == nullptr) return;
    if (max_samples <= 0) return;


    std::lock_guard guard(cache_mutex);
    std::lock_guard guard(cache_mutex);
    NetConfig* info = find_netconfig_locked(netid);
    NetConfig* info = find_netconfig_locked(netid);


    if (info && info->revision_id == revision_id) {
    if (info && info->revision_id == revision_id) {
        const int serverNum = std::min(MAXNS, static_cast<int>(info->nameserverSockAddrs.size()));
        const int serverNum = std::min(MAXNS, static_cast<int>(info->nameserverSockAddrs.size()));
        const IPSockAddr ipsa = IPSockAddr::toIPSockAddr(*sa);
        for (int ns = 0; ns < serverNum; ns++) {
        for (int ns = 0; ns < serverNum; ns++) {
            if (ipsa == info->nameserverSockAddrs.at(ns)) {
            if (serverSockAddr == info->nameserverSockAddrs[ns]) {
                res_cache_add_stats_sample_locked(&info->nsstats[ns], sample, max_samples);
                res_cache_add_stats_sample_locked(&info->nsstats[ns], sample, max_samples);
                return;
                return;
            }
            }
+1 −10
Original line number Original line Diff line number Diff line
@@ -97,17 +97,8 @@ void res_init(ResState* statp, const struct android_net_context* _Nonnull netcon
    statp->netid = netcontext->dns_netid;
    statp->netid = netcontext->dns_netid;
    statp->uid = netcontext->uid;
    statp->uid = netcontext->uid;
    statp->pid = netcontext->pid;
    statp->pid = netcontext->pid;
    statp->nscount = 1;
    statp->nscount = 0;
    statp->id = arc4random_uniform(65536);
    statp->id = arc4random_uniform(65536);
    // The following dummy initialization is probably useless because
    // it's overwritten later by resolv_populate_res_for_net().
    // TODO: check if it's safe to remove.
    const sockaddr_union u{
            .sin.sin_addr.s_addr = INADDR_ANY,
            .sin.sin_family = AF_INET,
            .sin.sin_port = htons(NAMESERVER_PORT),
    };
    memcpy(&statp->nsaddrs, &u, sizeof(u));


    for (auto& sock : statp->nssocks) {
    for (auto& sock : statp->nssocks) {
        sock.reset();
        sock.reset();
+40 −35
Original line number Original line Diff line number Diff line
@@ -138,11 +138,13 @@ using android::netdutils::Stopwatch;


static DnsTlsDispatcher sDnsTlsDispatcher;
static DnsTlsDispatcher sDnsTlsDispatcher;


static struct sockaddr* get_nsaddr(res_state, size_t);
static int send_vc(res_state statp, res_params* params, const uint8_t* buf, int buflen,
static int send_vc(res_state, res_params* params, const uint8_t*, int, uint8_t*, int, int*, int,
                   uint8_t* ans, int anssiz, int* terrno, size_t ns, time_t* at, int* rcode,
                   time_t*, int*, int*);
                   int* delay);
static int send_dg(res_state, res_params* params, const uint8_t*, int, uint8_t*, int, int*, int,
static int send_dg(res_state statp, res_params* params, const uint8_t* buf, int buflen,
                   int*, int*, time_t*, int*, int*);
                   uint8_t* ans, int anssiz, int* terrno, size_t ns, int* v_circuit,
                   int* gotsomewhere, time_t* at, int* rcode, int* delay);

static void dump_error(const char*, const struct sockaddr*, int);
static void dump_error(const char*, const struct sockaddr*, int);


static int sock_eq(struct sockaddr*, struct sockaddr*);
static int sock_eq(struct sockaddr*, struct sockaddr*);
@@ -288,13 +290,13 @@ static void res_set_usable_server(int selectedServer, int nscount, bool usable_s
static bool res_ourserver_p(res_state statp, const sockaddr* sa) {
static bool res_ourserver_p(res_state statp, const sockaddr* sa) {
    const sockaddr_in *inp, *srv;
    const sockaddr_in *inp, *srv;
    const sockaddr_in6 *in6p, *srv6;
    const sockaddr_in6 *in6p, *srv6;
    int ns;


    switch (sa->sa_family) {
    switch (sa->sa_family) {
        case AF_INET:
        case AF_INET:
            inp = (const struct sockaddr_in*) (const void*) sa;
            inp = (const struct sockaddr_in*) (const void*) sa;
            for (ns = 0; ns < statp->nscount; ns++) {
            for (const IPSockAddr& ipsa : statp->nsaddrs) {
                srv = (struct sockaddr_in*) (void*) get_nsaddr(statp, (size_t) ns);
                sockaddr_storage ss = ipsa;
                srv = reinterpret_cast<sockaddr_in*>(&ss);
                if (srv->sin_family == inp->sin_family && srv->sin_port == inp->sin_port &&
                if (srv->sin_family == inp->sin_family && srv->sin_port == inp->sin_port &&
                    (srv->sin_addr.s_addr == INADDR_ANY ||
                    (srv->sin_addr.s_addr == INADDR_ANY ||
                     srv->sin_addr.s_addr == inp->sin_addr.s_addr))
                     srv->sin_addr.s_addr == inp->sin_addr.s_addr))
@@ -303,8 +305,9 @@ static bool res_ourserver_p(res_state statp, const sockaddr* sa) {
            break;
            break;
        case AF_INET6:
        case AF_INET6:
            in6p = (const struct sockaddr_in6*) (const void*) sa;
            in6p = (const struct sockaddr_in6*) (const void*) sa;
            for (ns = 0; ns < statp->nscount; ns++) {
            for (const IPSockAddr& ipsa : statp->nsaddrs) {
                srv6 = (struct sockaddr_in6*) (void*) get_nsaddr(statp, (size_t) ns);
                sockaddr_storage ss = ipsa;
                srv6 = reinterpret_cast<sockaddr_in6*>(&ss);
                if (srv6->sin6_family == in6p->sin6_family && srv6->sin6_port == in6p->sin6_port &&
                if (srv6->sin6_family == in6p->sin6_family && srv6->sin6_port == in6p->sin6_port &&
#ifdef HAVE_SIN6_SCOPE_ID
#ifdef HAVE_SIN6_SCOPE_ID
                    (srv6->sin6_scope_id == 0 || srv6->sin6_scope_id == in6p->sin6_scope_id) &&
                    (srv6->sin6_scope_id == 0 || srv6->sin6_scope_id == in6p->sin6_scope_id) &&
@@ -484,20 +487,15 @@ int res_nsend(res_state statp, const uint8_t* buf, int buflen, uint8_t* ans, int
    int terrno = ETIMEDOUT;
    int terrno = ETIMEDOUT;


    for (int attempt = 0; attempt < retryTimes; ++attempt) {
    for (int attempt = 0; attempt < retryTimes; ++attempt) {
        for (int ns = 0; ns < statp->nscount; ++ns) {
        for (size_t ns = 0; ns < statp->nsaddrs.size(); ++ns) {
            if (!usable_servers[ns]) continue;
            if (!usable_servers[ns]) continue;


            *rcode = RCODE_INTERNAL_ERROR;
            *rcode = RCODE_INTERNAL_ERROR;


            // Get server addr
            // Get server addr
            const sockaddr* nsap = get_nsaddr(statp, ns);
            const IPSockAddr& serverSockAddr = statp->nsaddrs[ns];
            const int nsaplen = sockaddrSize(nsap);

            static const int niflags = NI_NUMERICHOST | NI_NUMERICSERV;
            char abuf[NI_MAXHOST];
            if (getnameinfo(nsap, (socklen_t)nsaplen, abuf, sizeof(abuf), NULL, 0, niflags) == 0)
            LOG(DEBUG) << __func__ << ": Querying server (# " << ns + 1
            LOG(DEBUG) << __func__ << ": Querying server (# " << ns + 1
                           << ") address = " << abuf;
                       << ") address = " << serverSockAddr.toString();


            ::android::net::Protocol query_proto = useTcp ? PROTO_TCP : PROTO_UDP;
            ::android::net::Protocol query_proto = useTcp ? PROTO_TCP : PROTO_UDP;
            time_t now = 0;
            time_t now = 0;
@@ -533,7 +531,7 @@ int res_nsend(res_state statp, const uint8_t* buf, int buflen, uint8_t* ans, int
            dnsQueryEvent->set_cache_hit(static_cast<CacheStatus>(cache_status));
            dnsQueryEvent->set_cache_hit(static_cast<CacheStatus>(cache_status));
            dnsQueryEvent->set_latency_micros(saturate_cast<int32_t>(queryStopwatch.timeTakenUs()));
            dnsQueryEvent->set_latency_micros(saturate_cast<int32_t>(queryStopwatch.timeTakenUs()));
            dnsQueryEvent->set_dns_server_index(ns);
            dnsQueryEvent->set_dns_server_index(ns);
            dnsQueryEvent->set_ip_version(ipFamilyToIPVersion(nsap->sa_family));
            dnsQueryEvent->set_ip_version(ipFamilyToIPVersion(serverSockAddr.family()));
            dnsQueryEvent->set_retry_times(retry_count_for_event);
            dnsQueryEvent->set_retry_times(retry_count_for_event);
            dnsQueryEvent->set_rcode(static_cast<NsRcode>(*rcode));
            dnsQueryEvent->set_rcode(static_cast<NsRcode>(*rcode));
            dnsQueryEvent->set_protocol(query_proto);
            dnsQueryEvent->set_protocol(query_proto);
@@ -545,9 +543,9 @@ int res_nsend(res_state statp, const uint8_t* buf, int buflen, uint8_t* ans, int
            if (shouldRecordStats) {
            if (shouldRecordStats) {
                res_sample sample;
                res_sample sample;
                _res_stats_set_sample(&sample, now, *rcode, delay);
                _res_stats_set_sample(&sample, now, *rcode, delay);
                resolv_cache_add_resolver_stats_sample(statp->netid, revision_id, nsap, sample,
                resolv_cache_add_resolver_stats_sample(statp->netid, revision_id, serverSockAddr,
                                                       params.max_samples);
                                                       sample, params.max_samples);
                resolv_stats_add(statp->netid, IPSockAddr::toIPSockAddr(*nsap), dnsQueryEvent);
                resolv_stats_add(statp->netid, serverSockAddr, dnsQueryEvent);
            }
            }


            if (resplen == 0) continue;
            if (resplen == 0) continue;
@@ -582,12 +580,6 @@ int res_nsend(res_state statp, const uint8_t* buf, int buflen, uint8_t* ans, int
    return -terrno;
    return -terrno;
}
}


/* Private */

static struct sockaddr* get_nsaddr(res_state statp, size_t n) {
    return (struct sockaddr*)(void*)&statp->nsaddrs[n];
}

static struct timespec get_timeout(res_state statp, const res_params* params, const int ns) {
static struct timespec get_timeout(res_state statp, const res_params* params, const int ns) {
    int msec;
    int msec;
    // Legacy algorithm which scales the timeout by nameserver number.
    // Legacy algorithm which scales the timeout by nameserver number.
@@ -610,7 +602,7 @@ static struct timespec get_timeout(res_state statp, const res_params* params, co
}
}


static int send_vc(res_state statp, res_params* params, const uint8_t* buf, int buflen,
static int send_vc(res_state statp, res_params* params, const uint8_t* buf, int buflen,
                   uint8_t* ans, int anssiz, int* terrno, int ns, time_t* at, int* rcode,
                   uint8_t* ans, int anssiz, int* terrno, size_t ns, time_t* at, int* rcode,
                   int* delay) {
                   int* delay) {
    *at = time(NULL);
    *at = time(NULL);
    *delay = 0;
    *delay = 0;
@@ -623,7 +615,14 @@ static int send_vc(res_state statp, res_params* params, const uint8_t* buf, int


    LOG(INFO) << __func__ << ": using send_vc";
    LOG(INFO) << __func__ << ": using send_vc";


    nsap = get_nsaddr(statp, (size_t) ns);
    // It should never happen, but just in case.
    if (ns >= statp->nsaddrs.size()) {
        LOG(ERROR) << __func__ << ": Out-of-bound indexing: " << ns;
        return -1;
    }

    sockaddr_storage ss = statp->nsaddrs[ns];
    nsap = reinterpret_cast<sockaddr*>(&ss);
    nsaplen = sockaddrSize(nsap);
    nsaplen = sockaddrSize(nsap);


    connreset = 0;
    connreset = 0;
@@ -897,12 +896,18 @@ bool ignoreInvalidAnswer(res_state statp, const sockaddr_storage& from, const ui
}
}


static int send_dg(res_state statp, res_params* params, const uint8_t* buf, int buflen,
static int send_dg(res_state statp, res_params* params, const uint8_t* buf, int buflen,
                   uint8_t* ans, int anssiz, int* terrno, int ns, int* v_circuit, int* gotsomewhere,
                   uint8_t* ans, int anssiz, int* terrno, size_t ns, int* v_circuit,
                   time_t* at, int* rcode, int* delay) {
                   int* gotsomewhere, time_t* at, int* rcode, int* delay) {
    // It should never happen, but just in case.
    if (ns >= statp->nsaddrs.size()) {
        LOG(ERROR) << __func__ << ": Out-of-bound indexing: " << ns;
        return -1;
    }

    *at = time(nullptr);
    *at = time(nullptr);
    *delay = 0;
    *delay = 0;

    const sockaddr_storage ss = statp->nsaddrs[ns];
    const sockaddr* nsap = get_nsaddr(statp, (size_t)ns);
    const sockaddr* nsap = reinterpret_cast<const sockaddr*>(&ss);
    const int nsaplen = sockaddrSize(nsap);
    const int nsaplen = sockaddrSize(nsap);


    if (statp->nssocks[ns] == -1) {
    if (statp->nssocks[ns] == -1) {
+5 −4
Original line number Original line Diff line number Diff line
@@ -38,6 +38,8 @@


using namespace std::chrono_literals;
using namespace std::chrono_literals;


using android::netdutils::IPSockAddr;

constexpr int TEST_NETID = 30;
constexpr int TEST_NETID = 30;
constexpr int TEST_NETID_2 = 31;
constexpr int TEST_NETID_2 = 31;
constexpr int DNS_PORT = 53;
constexpr int DNS_PORT = 53;
@@ -190,9 +192,9 @@ class ResolvCacheTest : public ::testing::Test {
        return resolv_set_nameservers(netId, setup.servers, setup.domains, setup.params);
        return resolv_set_nameservers(netId, setup.servers, setup.domains, setup.params);
    }
    }


    void cacheAddStats(uint32_t netId, int revision_id, const sockaddr* sa,
    void cacheAddStats(uint32_t netId, int revision_id, const IPSockAddr& ipsa,
                       const res_sample& sample, int max_samples) {
                       const res_sample& sample, int max_samples) {
        resolv_cache_add_resolver_stats_sample(netId, revision_id, sa, sample, max_samples);
        resolv_cache_add_resolver_stats_sample(netId, revision_id, ipsa, sample, max_samples);
    }
    }


    int cacheFlush(uint32_t netId) { return resolv_flush_cache_for_net(netId); }
    int cacheFlush(uint32_t netId) { return resolv_flush_cache_for_net(netId); }
@@ -735,8 +737,7 @@ TEST_F(ResolvCacheTest, FlushCache) {
    res_sample sample = {.at = time(NULL), .rtt = 100, .rcode = ns_r_noerror};
    res_sample sample = {.at = time(NULL), .rtt = 100, .rcode = ns_r_noerror};
    sockaddr_in sin = {.sin_family = AF_INET, .sin_port = htons(DNS_PORT)};
    sockaddr_in sin = {.sin_family = AF_INET, .sin_port = htons(DNS_PORT)};
    ASSERT_TRUE(inet_pton(AF_INET, setup.servers[0].c_str(), &sin.sin_addr));
    ASSERT_TRUE(inet_pton(AF_INET, setup.servers[0].c_str(), &sin.sin_addr));
    cacheAddStats(TEST_NETID, 1 /*revision_id*/, (const sockaddr*)&sin, sample,
    cacheAddStats(TEST_NETID, 1 /*revision_id*/, IPSockAddr(sin), sample, setup.params.max_samples);
                  setup.params.max_samples);


    const CacheStats cacheStats = {
    const CacheStats cacheStats = {
            .setup = setup,
            .setup = setup,
+5 −2
Original line number Original line Diff line number Diff line
@@ -55,6 +55,8 @@
#include <string>
#include <string>
#include <vector>
#include <vector>


#include <netdutils/InternetAddresses.h>

#include "DnsResolver.h"
#include "DnsResolver.h"
#include "netd_resolv/resolv.h"
#include "netd_resolv/resolv.h"
#include "params.h"
#include "params.h"
@@ -102,7 +104,7 @@ struct ResState {
    int nscount;                                // number of name srvers
    int nscount;                                // number of name srvers
    uint16_t id;                                // current message id
    uint16_t id;                                // current message id
    std::vector<std::string> search_domains{};  // domains to search
    std::vector<std::string> search_domains{};  // domains to search
    sockaddr_union nsaddrs[MAXNS];
    std::vector<android::netdutils::IPSockAddr> nsaddrs;
    android::base::unique_fd nssocks[MAXNS];    // UDP sockets to nameservers
    android::base::unique_fd nssocks[MAXNS];    // UDP sockets to nameservers
    unsigned ndots : 4;                         // threshold for initial abs. query
    unsigned ndots : 4;                         // threshold for initial abs. query
    unsigned _mark;                             // If non-0 SET_MARK to _mark on all request sockets
    unsigned _mark;                             // If non-0 SET_MARK to _mark on all request sockets
@@ -125,7 +127,8 @@ int resolv_cache_get_resolver_stats(unsigned netid, res_params* params, res_stat
/* Add a sample to the shared struct for the given netid and server, provided that the
/* Add a sample to the shared struct for the given netid and server, provided that the
 * revision_id of the stored servers has not changed.
 * revision_id of the stored servers has not changed.
 */
 */
void resolv_cache_add_resolver_stats_sample(unsigned netid, int revision_id, const sockaddr* sa,
void resolv_cache_add_resolver_stats_sample(unsigned netid, int revision_id,
                                            const android::netdutils::IPSockAddr& serverSockAddr,
                                            const res_sample& sample, int max_samples);
                                            const res_sample& sample, int max_samples);


// Calculate the round-trip-time from start time t0 and end time t1.
// Calculate the round-trip-time from start time t0 and end time t1.