Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit b121cbf0 authored by Bruce Chen's avatar Bruce Chen Committed by Automerger Merge Worker
Browse files

Merge "Replace manual buffer handling with std::span" am: 094d9ab7

Original change: https://android-review.googlesource.com/c/platform/packages/modules/DnsResolver/+/1790127

Change-Id: I35c2c70c5af1eb311f50b0afa530d09bb4d110bd
parents 5d725143 094d9ab7
Loading
Loading
Loading
Loading
+22 −23
Original line number Diff line number Diff line
@@ -39,7 +39,6 @@
#include <cutils/multiuser.h>
#include <netdutils/InternetAddresses.h>
#include <netdutils/ResponseCode.h>
#include <netdutils/Slice.h>
#include <netdutils/Stopwatch.h>
#include <netdutils/ThreadUtil.h>
#include <private/android_filesystem_config.h>  // AID_SYSTEM
@@ -65,6 +64,7 @@ using aidl::android::net::metrics::INetdEventListener;
using aidl::android::net::resolv::aidl::DnsHealthEventParcel;
using aidl::android::net::resolv::aidl::IDnsResolverUnsolicitedEventListener;
using android::net::NetworkDnsEventReported;
using std::span;

namespace android {

@@ -147,11 +147,11 @@ void maybeFixupNetContext(android_net_context* ctx, pid_t pid) {
void addIpAddrWithinLimit(std::vector<std::string>* ip_addrs, const sockaddr* addr,
                          socklen_t addrlen);

int extractResNsendAnswers(const uint8_t* answer, size_t anslen, int ipType,
int extractResNsendAnswers(std::span<const uint8_t> answer, int ipType,
                           std::vector<std::string>* ip_addrs) {
    int total_ip_addr_count = 0;
    ns_msg handle;
    if (ns_initparse((const uint8_t*)answer, anslen, &handle) < 0) {
    if (ns_initparse(answer.data(), answer.size(), &handle) < 0) {
        return 0;
    }
    int ancount = ns_msg_count(handle, ns_s_an);
@@ -250,21 +250,20 @@ bool simpleStrtoul(const char* input, IntegralType* output, int base = 10) {
    return true;
}

bool setQueryId(uint8_t* msg, size_t msgLen, uint16_t query_id) {
    if (msgLen < sizeof(HEADER)) {
bool setQueryId(span<uint8_t> msg, uint16_t query_id) {
    if ((size_t)msg.size() < sizeof(HEADER)) {
        errno = EINVAL;
        return false;
    }
    auto hp = reinterpret_cast<HEADER*>(msg);
    auto hp = reinterpret_cast<HEADER*>(msg.data());
    hp->id = htons(query_id);
    return true;
}

bool parseQuery(const uint8_t* msg, size_t msgLen, uint16_t* query_id, int* rr_type,
                std::string* rr_name) {
bool parseQuery(span<const uint8_t> msg, uint16_t* query_id, int* rr_type, std::string* rr_name) {
    ns_msg handle;
    ns_rr rr;
    if (ns_initparse((const uint8_t*)msg, msgLen, &handle) < 0 ||
    if (ns_initparse(msg.data(), msg.size(), &handle) < 0 ||
        ns_parserr(&handle, ns_s_qd, 0, &rr) < 0) {
        return false;
    }
@@ -927,8 +926,8 @@ void DnsProxyListener::ResNSendHandler::run() {
    uint16_t original_query_id = 0;

    // TODO: Handle the case which is msg contains more than one query
    if (!parseQuery(msg.data(), msgLen, &original_query_id, &rr_type, &rr_name) ||
        !setQueryId(msg.data(), msgLen, arc4random_uniform(65536))) {
    if (!parseQuery({msg.data(), msgLen}, &original_query_id, &rr_type, &rr_name) ||
        !setQueryId({msg.data(), msgLen}, arc4random_uniform(65536))) {
        // If the query couldn't be parsed, block the request.
        LOG(WARNING) << "ResNSendHandler::run: resnsend: from UID " << uid << ", invalid query";
        sendBE32(mClient, -EINVAL);
@@ -938,21 +937,21 @@ void DnsProxyListener::ResNSendHandler::run() {
    // Send DNS query
    std::vector<uint8_t> ansBuf(MAXPACKET, 0);
    int rcode = ns_r_noerror;
    int nsendAns = -1;
    int ansLen = -1;
    NetworkDnsEventReported event;
    initDnsEvent(&event, mNetContext);
    if (queryLimiter.start(uid)) {
        if (evaluate_domain_name(mNetContext, rr_name.c_str())) {
            nsendAns = resolv_res_nsend(&mNetContext, msg.data(), msgLen, ansBuf.data(), MAXPACKET,
                                        &rcode, static_cast<ResNsendFlags>(mFlags), &event);
            ansLen = resolv_res_nsend(&mNetContext, {msg.data(), msgLen}, ansBuf, &rcode,
                                      static_cast<ResNsendFlags>(mFlags), &event);
        } else {
            nsendAns = -EAI_SYSTEM;
            ansLen = -EAI_SYSTEM;
        }
        queryLimiter.finish(uid);
    } else {
        LOG(WARNING) << "ResNSendHandler::run: resnsend: from UID " << uid
                     << ", max concurrent queries reached";
        nsendAns = -EBUSY;
        ansLen = -EBUSY;
    }

    const int32_t latencyUs = saturate_cast<int32_t>(s.timeTakenUs());
@@ -961,14 +960,14 @@ void DnsProxyListener::ResNSendHandler::run() {
    event.set_res_nsend_flags(static_cast<ResNsendFlags>(mFlags));

    // Fail, send -errno
    if (nsendAns < 0) {
        if (!sendBE32(mClient, nsendAns)) {
    if (ansLen < 0) {
        if (!sendBE32(mClient, ansLen)) {
            PLOG(WARNING) << "ResNSendHandler::run: resnsend: failed to send errno to uid " << uid
                          << " pid " << mClient->getPid();
        }
        if (rr_type == ns_t_a || rr_type == ns_t_aaaa) {
            reportDnsEvent(INetdEventListener::EVENT_RES_NSEND, mNetContext, latencyUs,
                           resNSendToAiError(nsendAns, rcode), event, rr_name);
                           resNSendToAiError(ansLen, rcode), event, rr_name);
        }
        return;
    }
@@ -981,8 +980,8 @@ void DnsProxyListener::ResNSendHandler::run() {
    }

    // Restore query id and send answer
    if (!setQueryId(ansBuf.data(), nsendAns, original_query_id) ||
        !sendLenAndData(mClient, nsendAns, ansBuf.data())) {
    if (!setQueryId({ansBuf.data(), ansLen}, original_query_id) ||
        !sendLenAndData(mClient, ansLen, ansBuf.data())) {
        PLOG(WARNING) << "ResNSendHandler::run: resnsend: failed to send answer to uid " << uid
                      << " pid " << mClient->getPid();
        return;
@@ -991,9 +990,9 @@ void DnsProxyListener::ResNSendHandler::run() {
    if (rr_type == ns_t_a || rr_type == ns_t_aaaa) {
        std::vector<std::string> ip_addrs;
        const int total_ip_addr_count =
                extractResNsendAnswers((uint8_t*)ansBuf.data(), nsendAns, rr_type, &ip_addrs);
                extractResNsendAnswers({ansBuf.data(), ansLen}, rr_type, &ip_addrs);
        reportDnsEvent(INetdEventListener::EVENT_RES_NSEND, mNetContext, latencyUs,
                       resNSendToAiError(nsendAns, rcode), event, rr_name, ip_addrs,
                       resNSendToAiError(ansLen, rcode), event, rr_name, ip_addrs,
                       total_ip_addr_count);
    }
}
+2 −1
Original line number Diff line number Diff line
@@ -128,7 +128,8 @@ DnsTlsTransport::Response DnsTlsDispatcher::query(const std::list<DnsTlsServer>&
        dnsQueryEvent->set_dns_server_index(serverCount++);
        dnsQueryEvent->set_ip_version(ipFamilyToIPVersion(server.ss.ss_family));
        dnsQueryEvent->set_protocol(PROTO_DOT);
        dnsQueryEvent->set_type(getQueryType(query.base(), query.size()));
        std::span<const uint8_t> msg(query.base(), query.size());
        dnsQueryEvent->set_type(getQueryType(msg));
        dnsQueryEvent->set_connected(connectTriggered);

        switch (code) {
+8 −13
Original line number Diff line number Diff line
@@ -1627,13 +1627,11 @@ QueryResult doQuery(const char* name, res_target* t, ResState* res,
    LOG(DEBUG) << __func__ << ": (" << cl << ", " << type << ")";

    uint8_t buf[MAXPACKET];

    int n = res_nmkquery(QUERY, name, cl, type, /*data=*/nullptr, /*datalen=*/0, buf, sizeof(buf),
                         res->netcontext_flags);
    int n = res_nmkquery(QUERY, name, cl, type, {}, buf, res->netcontext_flags);

    if (n > 0 &&
        (res->netcontext_flags & (NET_CONTEXT_FLAG_USE_DNS_OVER_TLS | NET_CONTEXT_FLAG_USE_EDNS))) {
        n = res_nopt(res, n, buf, sizeof(buf), anslen);
        n = res_nopt(res, n, buf, anslen);
    }

    NetworkDnsEventReported event;
@@ -1651,7 +1649,7 @@ QueryResult doQuery(const char* name, res_target* t, ResState* res,
    ResState res_temp = res->clone(&event);

    int rcode = NOERROR;
    n = res_nsend(&res_temp, buf, n, t->answer.data(), anslen, &rcode, 0, sleepTimeMs);
    n = res_nsend(&res_temp, {buf, n}, {t->answer.data(), anslen}, &rcode, 0, sleepTimeMs);
    if (n < 0 || hp->rcode != NOERROR || ntohs(hp->ancount) == 0) {
        // To ensure that the rcode handling is identical to res_queryN().
        if (rcode != RCODE_TIMEOUT) rcode = hp->rcode;
@@ -1660,9 +1658,8 @@ QueryResult doQuery(const char* name, res_target* t, ResState* res,
             (NET_CONTEXT_FLAG_USE_DNS_OVER_TLS | NET_CONTEXT_FLAG_USE_EDNS)) &&
            (res_temp.flags & RES_F_EDNS0ERR)) {
            LOG(DEBUG) << __func__ << ": retry without EDNS0";
            n = res_nmkquery(QUERY, name, cl, type, /*data=*/nullptr, /*datalen=*/0, buf,
                             sizeof(buf), res_temp.netcontext_flags);
            n = res_nsend(&res_temp, buf, n, t->answer.data(), anslen, &rcode, 0);
            n = res_nmkquery(QUERY, name, cl, type, {}, buf, res_temp.netcontext_flags);
            n = res_nsend(&res_temp, {buf, n}, {t->answer.data(), anslen}, &rcode, 0);
        }
    }

@@ -1761,21 +1758,19 @@ static int res_queryN(const char* name, res_target* target, ResState* res, int*
        const int anslen = t->answer.size();

        LOG(DEBUG) << __func__ << ": (" << cl << ", " << type << ")";

        n = res_nmkquery(QUERY, name, cl, type, /*data=*/nullptr, /*datalen=*/0, buf, sizeof(buf),
                         res->netcontext_flags);
        n = res_nmkquery(QUERY, name, cl, type, {}, buf, res->netcontext_flags);
        if (n > 0 &&
            (res->netcontext_flags &
             (NET_CONTEXT_FLAG_USE_DNS_OVER_TLS | NET_CONTEXT_FLAG_USE_EDNS)) &&
            !retried)  // TODO:  remove the retry flag and provide a sufficient test coverage.
            n = res_nopt(res, n, buf, sizeof(buf), anslen);
            n = res_nopt(res, n, buf, anslen);
        if (n <= 0) {
            LOG(ERROR) << __func__ << ": res_nmkquery failed";
            *herrno = NO_RECOVERY;
            return n;
        }

        n = res_nsend(res, buf, n, t->answer.data(), anslen, &rcode, 0);
        n = res_nsend(res, {buf, n}, {t->answer.data(), anslen}, &rcode, 0);
        if (n < 0 || hp->rcode != NOERROR || ntohs(hp->ancount) == 0) {
            // Record rcode from DNS response header only if no timeout.
            // Keep rcode timeout for reporting later if any.
+2 −2
Original line number Diff line number Diff line
@@ -632,7 +632,7 @@ static int dns_gethtbyname(ResState* res, const char* name, int addr_type, getna

    int he;
    const unsigned qclass = isMdnsResolution(res->flags) ? C_IN | C_UNICAST : C_IN;
    n = res_nsearch(res, name, qclass, type, buf->buf, (int)sizeof(buf->buf), &he);
    n = res_nsearch(res, name, qclass, type, {buf->buf, (int)sizeof(buf->buf)}, &he);
    if (n < 0) {
        LOG(DEBUG) << __func__ << ": res_nsearch failed (" << n << ")";
        // Return h_errno (he) to catch more detailed errors rather than EAI_NODATA.
@@ -694,7 +694,7 @@ static int dns_gethtbyaddr(const unsigned char* uaddr, int len, int af,

    ResState res(netcontext, event);
    int he;
    n = res_nquery(&res, qbuf, C_IN, T_PTR, buf->buf, (int)sizeof(buf->buf), &he);
    n = res_nquery(&res, qbuf, C_IN, T_PTR, {buf->buf, (int)sizeof(buf->buf)}, &he);
    if (n < 0) {
        LOG(DEBUG) << __func__ << ": res_nquery failed (" << n << ")";
        // Note that res_nquery() doesn't set the pair NETDB_INTERNAL and errno.
+23 −24
Original line number Diff line number Diff line
@@ -78,6 +78,7 @@ using android::net::PROTO_UDP;
using android::net::Protocol;
using android::netdutils::DumpWriter;
using android::netdutils::IPSockAddr;
using std::span;

/* This code implements a small and *simple* DNS resolver cache.
 *
@@ -773,14 +774,14 @@ static uint32_t answer_getNegativeTTL(ns_msg handle) {
 * In case of parse error zero (0) is returned which
 * indicates that the answer shall not be cached.
 */
static uint32_t answer_getTTL(const void* answer, int answerlen) {
static uint32_t answer_getTTL(span<const uint8_t> answer) {
    ns_msg handle;
    int ancount, n;
    uint32_t result, ttl;
    ns_rr rr;

    result = 0;
    if (ns_initparse((const uint8_t*) answer, answerlen, &handle) >= 0) {
    if (ns_initparse(answer.data(), answer.size(), &handle) >= 0) {
        // get number of answer records
        ancount = ns_msg_count(handle, ns_s_an);

@@ -840,13 +841,13 @@ static unsigned entry_hash(const Entry* e) {

/* initialize an Entry as a search key, this also checks the input query packet
 * returns 1 on success, or 0 in case of unsupported/malformed data */
static int entry_init_key(Entry* e, const void* query, int querylen) {
static int entry_init_key(Entry* e, span<const uint8_t> query) {
    DnsPacket pack[1];

    memset(e, 0, sizeof(*e));

    e->query = (const uint8_t*) query;
    e->querylen = querylen;
    e->query = query.data();
    e->querylen = query.size();
    e->hash = entry_hash(e);

    _dnsPacket_init(pack, e->query, e->querylen);
@@ -855,11 +856,11 @@ static int entry_init_key(Entry* e, const void* query, int querylen) {
}

/* allocate a new entry as a cache node */
static Entry* entry_alloc(const Entry* init, const void* answer, int answerlen) {
static Entry* entry_alloc(const Entry* init, span<const uint8_t> answer) {
    Entry* e;
    int size;

    size = sizeof(*e) + init->querylen + answerlen;
    size = sizeof(*e) + init->querylen + answer.size();
    e = (Entry*) calloc(size, 1);
    if (e == NULL) return e;

@@ -870,9 +871,9 @@ static Entry* entry_alloc(const Entry* init, const void* answer, int answerlen)
    memcpy((char*) e->query, init->query, e->querylen);

    e->answer = e->query + e->querylen;
    e->answerlen = answerlen;
    e->answerlen = answer.size();

    memcpy((char*) e->answer, answer, e->answerlen);
    memcpy((char*)e->answer, answer.data(), e->answerlen);

    return e;
}
@@ -1101,14 +1102,14 @@ static void cache_notify_waiting_tid_locked(struct Cache* cache, const Entry* ke
    }
}

void _resolv_cache_query_failed(unsigned netid, const void* query, int querylen, uint32_t flags) {
void _resolv_cache_query_failed(unsigned netid, span<const uint8_t> query, uint32_t flags) {
    // We should not notify with these flags.
    if (flags & (ANDROID_RESOLV_NO_CACHE_STORE | ANDROID_RESOLV_NO_CACHE_LOOKUP)) {
        return;
    }
    Entry key[1];

    if (!entry_init_key(key, query, querylen)) return;
    if (!entry_init_key(key, query)) return;

    std::lock_guard guard(cache_mutex);

@@ -1228,8 +1229,8 @@ static void _cache_remove_expired(Cache* cache) {
// Get a NetConfig associated with a network, or nullptr if not found.
static NetConfig* find_netconfig_locked(unsigned netid) REQUIRES(cache_mutex);

ResolvCacheStatus resolv_cache_lookup(unsigned netid, const void* query, int querylen, void* answer,
                                      int answersize, int* answerlen, uint32_t flags) {
ResolvCacheStatus resolv_cache_lookup(unsigned netid, span<const uint8_t> query,
                                      span<uint8_t> answer, int* answerlen, uint32_t flags) {
    // Skip cache lookup, return RESOLV_CACHE_NOTFOUND directly so that it is
    // possible to cache the answer of this query.
    // If ANDROID_RESOLV_NO_CACHE_STORE is set, return RESOLV_CACHE_SKIP to skip possible cache
@@ -1247,7 +1248,7 @@ ResolvCacheStatus resolv_cache_lookup(unsigned netid, const void* query, int que
    LOG(INFO) << __func__ << ": lookup";

    /* we don't cache malformed queries */
    if (!entry_init_key(&key, query, querylen)) {
    if (!entry_init_key(&key, query)) {
        LOG(INFO) << __func__ << ": unsupported query";
        return RESOLV_CACHE_UNSUPPORTED;
    }
@@ -1310,13 +1311,13 @@ ResolvCacheStatus resolv_cache_lookup(unsigned netid, const void* query, int que
    }

    *answerlen = e->answerlen;
    if (e->answerlen > answersize) {
    if (e->answerlen > answer.size()) {
        /* NOTE: we return UNSUPPORTED if the answer buffer is too short */
        LOG(INFO) << __func__ << ": ANSWER TOO LONG";
        return RESOLV_CACHE_UNSUPPORTED;
    }

    memcpy(answer, e->answer, e->answerlen);
    memcpy(answer.data(), e->answer, e->answerlen);

    /* bump up this entry to the top of the MRU list */
    if (e != cache->mru_list.mru_next) {
@@ -1328,8 +1329,7 @@ ResolvCacheStatus resolv_cache_lookup(unsigned netid, const void* query, int que
    return RESOLV_CACHE_FOUND;
}

int resolv_cache_add(unsigned netid, const void* query, int querylen, const void* answer,
                     int answerlen) {
int resolv_cache_add(unsigned netid, span<const uint8_t> query, span<const uint8_t> answer) {
    Entry key[1];
    Entry* e;
    Entry** lookup;
@@ -1338,7 +1338,7 @@ int resolv_cache_add(unsigned netid, const void* query, int querylen, const void

    /* don't assume that the query has already been cached
     */
    if (!entry_init_key(key, query, querylen)) {
    if (!entry_init_key(key, query)) {
        LOG(INFO) << __func__ << ": passed invalid query?";
        return -EINVAL;
    }
@@ -1375,9 +1375,9 @@ int resolv_cache_add(unsigned netid, const void* query, int querylen, const void
        }
    }

    ttl = answer_getTTL(answer, answerlen);
    ttl = answer_getTTL(answer);
    if (ttl > 0) {
        e = entry_alloc(key, answer, answerlen);
        e = entry_alloc(key, answer);
        if (e != NULL) {
            e->expires = ttl + _time_now();
            _cache_add_p(cache, lookup, e);
@@ -1886,13 +1886,12 @@ bool has_named_cache(unsigned netid) {
    return find_named_cache_locked(netid) != nullptr;
}

int resolv_cache_get_expiration(unsigned netid, const std::vector<char>& query,
                                time_t* expiration) {
int resolv_cache_get_expiration(unsigned netid, span<const uint8_t> query, time_t* expiration) {
    Entry key;
    *expiration = -1;

    // A malformed query is not allowed.
    if (!entry_init_key(&key, query.data(), query.size())) {
    if (!entry_init_key(&key, query)) {
        LOG(WARNING) << __func__ << ": unsupported query";
        return -EINVAL;
    }
Loading