Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 0a0870d8 authored by Luke Huang's avatar Luke Huang
Browse files

Do the A/AAAA lookup in parallel for getaddrinfo

Create threads for each A/AAAA lookup.
The functionality is disabled for now.

Bug: 2609013
Bug: 135717624
Bug: 151698212
Test: atest

Change-Id: I23cdcdf800d2f9ee42b5f19a5dd2045cdf61d41f
parent 420ee62b
Loading
Loading
Loading
Loading
+138 −28
Original line number Diff line number Diff line
@@ -53,6 +53,8 @@
#include <sys/un.h>
#include <unistd.h>

#include <future>

#include <android-base/logging.h>

#include "netd_resolv/resolv.h"
@@ -61,6 +63,7 @@
#include "res_init.h"
#include "resolv_cache.h"
#include "resolv_private.h"
#include "util.h"

#define ANY 0

@@ -1573,6 +1576,137 @@ static bool files_getaddrinfo(const size_t netid, const char* name, const addrin

/* resolver logic */

namespace {

constexpr int SLEEP_TIME_MS = 2;

int getHerrnoFromRcode(int rcode) {
    switch (rcode) {
        // Not defined in RFC.
        case RCODE_TIMEOUT:
            // DNS metrics monitors DNS query timeout.
            return NETD_RESOLV_H_ERRNO_EXT_TIMEOUT;  // extended h_errno.
        // Defined in RFC 1035 section 4.1.1.
        case NXDOMAIN:
            return HOST_NOT_FOUND;
        case SERVFAIL:
            return TRY_AGAIN;
        case NOERROR:
            return NO_DATA;
        case FORMERR:
        case NOTIMP:
        case REFUSED:
        default:
            return NO_RECOVERY;
    }
}

struct QueryResult {
    int ancount;
    int rcode;
    int herrno;
    NetworkDnsEventReported event;
};

QueryResult doQuery(const char* name, res_target* t, res_state res) {
    HEADER* hp = (HEADER*)(void*)t->answer.data();

    hp->rcode = NOERROR;  // default

    const int cl = t->qclass;
    const int type = t->qtype;
    const int anslen = t->answer.size();

    LOG(DEBUG) << __func__ << ": (" << cl << ", " << type << ")";

    uint8_t buf[MAXPACKET];

    int n = res_nmkquery(QUERY, name, cl, type, /*data=*/nullptr, /*datalen=*/0, buf, sizeof(buf),
                         res->netcontext_flags);

    if (n > 0 &&
        (res->netcontext_flags & (NET_CONTEXT_FLAG_USE_DNS_OVER_TLS | NET_CONTEXT_FLAG_USE_EDNS))) {
        n = res_nopt(res, n, buf, sizeof(buf), anslen);
    }

    NetworkDnsEventReported event;
    if (n <= 0) {
        LOG(ERROR) << __func__ << ": res_nmkquery failed";
        return {0, -1, NO_RECOVERY, event};
        return {
                .ancount = 0,
                .rcode = -1,
                .herrno = NO_RECOVERY,
                .event = event,
        };
    }

    ResState res_temp = fromResState(*res, &event);

    int rcode = NOERROR;
    n = res_nsend(&res_temp, buf, n, t->answer.data(), anslen, &rcode, 0);
    if (n < 0 || hp->rcode != NOERROR || ntohs(hp->ancount) == 0) {
        // if the query choked with EDNS0, retry without EDNS0
        if ((res_temp.netcontext_flags &
             (NET_CONTEXT_FLAG_USE_DNS_OVER_TLS | NET_CONTEXT_FLAG_USE_EDNS)) &&
            (res_temp._flags & RES_F_EDNS0ERR)) {
            LOG(DEBUG) << __func__ << ": retry without EDNS0";
            n = res_nmkquery(QUERY, name, cl, type, /*data=*/nullptr, /*datalen=*/0, buf,
                             sizeof(buf), res->netcontext_flags);
            n = res_nsend(&res_temp, buf, n, t->answer.data(), anslen, &rcode, 0);
        }
    }

    LOG(DEBUG) << __func__ << ": rcode=" << hp->rcode << ", ancount=" << ntohs(hp->ancount);

    t->n = n;
    return {
            .ancount = ntohs(hp->ancount),
            .rcode = rcode,
            .event = event,
    };
}

}  // namespace

static int res_queryN_parallel(const char* name, res_target* target, res_state res, int* herrno) {
    std::vector<std::future<QueryResult>> results;
    results.reserve(2);
    for (res_target* t = target; t; t = t->next) {
        results.emplace_back(std::async(std::launch::async, doQuery, name, t, res));
        // Avoiding gateways drop packets if queries are sent too close together
        if (t->next) usleep(SLEEP_TIME_MS * 1000);
    }

    int ancount = 0;
    int rcode = 0;

    for (auto& f : results) {
        const QueryResult& r = f.get();
        if (r.herrno == NO_RECOVERY) {
            *herrno = r.herrno;
            return -1;
        }
        res->event->MergeFrom(r.event);
        ancount += r.ancount;
        rcode = r.rcode;
    }

    if (ancount == 0) {
        *herrno = getHerrnoFromRcode(rcode);
        return -1;
    }

    return ancount;
}

static int res_queryN_wrapper(const char* name, res_target* target, res_state res, int* herrno) {
    const bool parallel_lookup = getExperimentFlagInt("parallel_lookup", 0);
    if (parallel_lookup) return res_queryN_parallel(name, target, res, herrno);

    return res_queryN(name, target, res, herrno);
}

/*
 * Formulate a normal query, send, and await answer.
 * Returned answer is placed in supplied buffer "answer".
@@ -1647,29 +1781,7 @@ static int res_queryN(const char* name, res_target* target, res_state res, int*
    }

    if (ancount == 0) {
        switch (rcode) {
            // Not defined in RFC.
            case RCODE_TIMEOUT:
                // DNS metrics monitors DNS query timeout.
                *herrno = NETD_RESOLV_H_ERRNO_EXT_TIMEOUT;  // extended h_errno.
                break;
            // Defined in RFC 1035 section 4.1.1.
            case NXDOMAIN:
                *herrno = HOST_NOT_FOUND;
                break;
            case SERVFAIL:
                *herrno = TRY_AGAIN;
                break;
            case NOERROR:
                *herrno = NO_DATA;
                break;
            case FORMERR:
            case NOTIMP:
            case REFUSED:
            default:
                *herrno = NO_RECOVERY;
                break;
        }
        *herrno = getHerrnoFromRcode(rcode);
        return -1;
    }
    return ancount;
@@ -1795,10 +1907,8 @@ static int res_searchN(const char* name, res_target* target, res_state res, int*
    return -1;
}

/*
 * Perform a call on res_query on the concatenation of name and domain,
 * removing a trailing dot from name if domain is NULL.
 */
// Perform a call on res_query on the concatenation of name and domain,
// removing a trailing dot from name if domain is NULL.
static int res_querydomainN(const char* name, const char* domain, res_target* target, res_state res,
                            int* herrno) {
    char nbuf[MAXDNAME];
@@ -1828,5 +1938,5 @@ static int res_querydomainN(const char* name, const char* domain, res_target* ta
        }
        snprintf(nbuf, sizeof(nbuf), "%s.%s", name, domain);
    }
    return res_queryN(longname, target, res, herrno);
    return res_queryN_wrapper(longname, target, res, herrno);
}
+23 −0
Original line number Diff line number Diff line
@@ -91,6 +91,7 @@

#include "netd_resolv/resolv.h"
#include "resolv_private.h"
#include "stats.pb.h"

void res_init(ResState* statp, const struct android_net_context* _Nonnull netcontext,
              android::net::NetworkDnsEventReported* _Nonnull event) {
@@ -108,3 +109,25 @@ void res_init(ResState* statp, const struct android_net_context* _Nonnull netcon
    statp->event = event;
    statp->netcontext_flags = netcontext->flags;
}

// TODO: Have some proper constructors for ResState instead of this method and res_init().
ResState fromResState(const ResState& other, android::net::NetworkDnsEventReported* event) {
    ResState resOutput;
    resOutput.netid = other.netid;
    resOutput.uid = other.uid;
    resOutput.pid = other.pid;
    resOutput.id = other.id;

    resOutput.nsaddrs = other.nsaddrs;

    for (auto& sock : resOutput.nssocks) {
        sock.reset();
    }

    resOutput.ndots = other.ndots;
    resOutput._mark = other._mark;
    resOutput.tcp_nssock.reset();
    resOutput.event = event;
    resOutput.netcontext_flags = other.netcontext_flags;
    return resOutput;
}
+2 −0
Original line number Diff line number Diff line
@@ -16,7 +16,9 @@
#pragma once

#include "resolv_private.h"
#include "stats.pb.h"

// TODO: make this a constructor for ResState
void res_init(ResState* res, const struct android_net_context* netcontext,
              android::net::NetworkDnsEventReported* event);
ResState fromResState(const ResState& other, android::net::NetworkDnsEventReported* event);
+48 −46
Original line number Diff line number Diff line
@@ -213,30 +213,28 @@ void DnsTlsFrontend::requestHandler() {
            SSL_set_fd(ssl.get(), client.get());

            LOG(DEBUG) << "Doing SSL handshake";
            bool success = false;
            if (SSL_accept(ssl.get()) <= 0) {
                LOG(INFO) << "SSL negotiation failure";
            } else {
                LOG(DEBUG) << "SSL handshake complete";
                success = handleOneRequest(ssl.get());
            }

            if (success) {
                // Increment queries_ as late as possible, because it represents
                // a query that is fully processed, and the response returned to the
                // client, including cleanup actions.
                ++queries_;
                queries_ += handleRequests(ssl.get(), client.get());
            }
        }
    }
    LOG(DEBUG) << "Ending loop";
}

bool DnsTlsFrontend::handleOneRequest(SSL* ssl) {
int DnsTlsFrontend::handleRequests(SSL* ssl, int clientFd) {
    int queryCounts = 0;
    pollfd fds = {.fd = clientFd, .events = POLLIN};
    do {
        uint8_t queryHeader[2];
        if (SSL_read(ssl, &queryHeader, 2) != 2) {
            LOG(INFO) << "Not enough header bytes";
        return false;
            return queryCounts;
        }
        const uint16_t qlen = (queryHeader[0] << 8) | queryHeader[1];
        uint8_t query[qlen];
@@ -245,34 +243,38 @@ bool DnsTlsFrontend::handleOneRequest(SSL* ssl) {
            int ret = SSL_read(ssl, query + qbytes, qlen - qbytes);
            if (ret <= 0) {
                LOG(INFO) << "Error while reading query";
            return false;
                return queryCounts;
            }
            qbytes += ret;
        }
        int sent = send(backend_socket_.get(), query, qlen, 0);
        if (sent != qlen) {
            LOG(INFO) << "Failed to send query";
        return false;
            return queryCounts;
        }
        const int max_size = 4096;
        uint8_t recv_buffer[max_size];
        int rlen = recv(backend_socket_.get(), recv_buffer, max_size, 0);
        if (rlen <= 0) {
            LOG(INFO) << "Failed to receive response";
        return false;
            return queryCounts;
        }
        uint8_t responseHeader[2];
        responseHeader[0] = rlen >> 8;
        responseHeader[1] = rlen;
        if (SSL_write(ssl, responseHeader, 2) != 2) {
            LOG(INFO) << "Failed to write response header";
        return false;
            return queryCounts;
        }
        if (SSL_write(ssl, recv_buffer, rlen) != rlen) {
            LOG(INFO) << "Failed to write response body";
        return false;
            return queryCounts;
        }
    return true;
        ++queryCounts;
    } while (poll(&fds, 1, 1) > 0);

    LOG(DEBUG) << __func__ << " return: " << queryCounts;
    return queryCounts;
}

bool DnsTlsFrontend::stopServer() {
+1 −1
Original line number Diff line number Diff line
@@ -69,7 +69,7 @@ class DnsTlsFrontend {

  private:
    void requestHandler();
    bool handleOneRequest(SSL* ssl);
    int handleRequests(SSL* ssl, int clientFd);

    // Trigger the handler thread to terminate.
    bool sendToEventFd();
Loading