Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 69686277 authored by Stefano Duo's avatar Stefano Duo
Browse files

getaddrinfo: propagate ResState via res_target

Code internal to getaddrinfo assumes that all query share the same
ResState. This is currently true, but will no longer be the case once we
properly support mDNS.

To keep things working:
1. Propagate ResState via res_target (as it's going to be possibly
   different for each query
2. Explicitly pass around the parameters we expect to be the same for
   each query

Bug: b:374756136
Change-Id: Ieba6ccb9d1089879576249b5d820c9115e5951dd
parent a1f5fa39
Loading
Loading
Loading
Loading
+33 −24
Original line number Diff line number Diff line
@@ -120,6 +120,8 @@ struct res_target {
    int qclass, qtype;                                                 // class and type of query
    std::vector<uint8_t> answer = std::vector<uint8_t>(MAXPACKET, 0);  // buffer to put answer
    int n = 0;                                                         // result length
    // ResState this query should be run within
    ResState* res_state;
};

static int explore_fqdn(const struct addrinfo*, const char*, const char*, struct addrinfo**,
@@ -151,9 +153,11 @@ static bool files_getaddrinfo(const size_t netid, const char* name, const addrin
static int _find_src_addr(const struct sockaddr*, struct sockaddr*, unsigned, uid_t,
                          bool allow_v6_linklocal);

static int res_searchN(const char* name, std::span<res_target> queries, ResState* res, int* herrno);
static int res_searchN(const char* name, std::span<res_target> queries,
                       std::span<std::string> search_domains, bool is_mdns,
                       android::net::NetworkDnsEventReported* event, int* herrno);
static int res_querydomainN(const char* name, const char* domain, std::span<res_target> queries,
                            ResState* res, int* herrno);
                            android::net::NetworkDnsEventReported* event, int* herrno);

const char* const ai_errlist[] = {
        "Success",
@@ -1403,7 +1407,9 @@ static int dns_getaddrinfo(const char* name, const addrinfo* pai,
                           addrinfo** rv, NetworkDnsEventReported* event) {
    std::vector<res_target> queries;
    ResState res(netcontext, app_socket, event);

    setMdnsFlag(name, res.netid, &(res.flags));
    bool is_mdns = isMdnsResolution(res.flags);

    bool query_ipv6 = false;
    bool query_ipv4 = false;
@@ -1417,9 +1423,8 @@ static int dns_getaddrinfo(const char* name, const addrinfo* pai,
            // system". However, bionic doesn't currently support getifaddrs, so checking for
            // connectivity is the next best thing.
            query_ipv6 = have_global_ipv6_connectivity(netcontext->app_mark, netcontext->uid) ||
                         (isMdnsResolution(res.flags) &&
                          have_local_ipv6_connectivity(netcontext->app_mark, netcontext->uid,
                                                       res.netid));
                         (is_mdns && have_local_ipv6_connectivity(netcontext->app_mark,
                                                                  netcontext->uid, res.netid));
            query_ipv4 = have_ipv4_connectivity(netcontext->app_mark, netcontext->uid);
        }
    } else if (pai->ai_family == AF_INET) {
@@ -1430,11 +1435,14 @@ static int dns_getaddrinfo(const char* name, const addrinfo* pai,
        return EAI_FAMILY;
    }

    resolv_populate_res_for_net(&res);

    if (query_ipv6) {
        res_target ipv6_query;
        ipv6_query.name = name;
        ipv6_query.qclass = C_IN;
        ipv6_query.qtype = T_AAAA;
        ipv6_query.res_state = &res;
        queries.push_back(ipv6_query);
    }
    if (query_ipv4) {
@@ -1442,16 +1450,16 @@ static int dns_getaddrinfo(const char* name, const addrinfo* pai,
        ipv4_query.name = name;
        ipv4_query.qclass = C_IN;
        ipv4_query.qtype = T_A;
        ipv4_query.res_state = &res;
        queries.push_back(ipv4_query);
    }
    if (queries.empty()) {
        return EAI_NODATA;
    }

    resolv_populate_res_for_net(&res);

    int he;
    if (res_searchN(name, queries, &res, &he) < 0) {
    // TODO: Refactor search_domains and event out of ResState (they really should not be there).
    if (res_searchN(name, queries, res.search_domains, is_mdns, res.event, &he) < 0) {
        // Return h_errno (he) to catch more detailed errors rather than EAI_NODATA.
        // Note that res_searchN() doesn't set the pair NETDB_INTERNAL and errno.
        // See also herrnoToAiErrno().
@@ -1687,15 +1695,15 @@ QueryResult doQuery(const char* name, res_target* t, ResState* res,

// This function runs doQuery() for each res_target in parallel.
// The `target`, which is set in dns_getaddrinfo(), contains at most two res_target.
static int res_queryN_parallel(const char* name, std::span<res_target> queries, ResState* res,
                               int* herrno) {
static int res_queryN_parallel(const char* name, std::span<res_target> queries,
                               android::net::NetworkDnsEventReported* event, int* herrno) {
    std::vector<std::future<QueryResult>> results;
    results.reserve(2);
    std::chrono::milliseconds sleepTimeMs{};
    bool is_first_iteration = true;
    for (auto& query : queries) {
        results.emplace_back(
                std::async(std::launch::async, doQuery, name, &query, res, sleepTimeMs));
        results.emplace_back(std::async(std::launch::async, doQuery, name, &query, query.res_state,
                                        sleepTimeMs));
        if (is_first_iteration) {
            // Avoiding gateways drop packets if queries are sent too close together
            // Only needed if we have multiple queries in a row.
@@ -1718,7 +1726,7 @@ static int res_queryN_parallel(const char* name, std::span<res_target> queries,
            *herrno = r.herrno;
            return -1;
        }
        res->event->MergeFrom(r.event);
        event->MergeFrom(r.event);
        ancount += r.ancount;
        rcode = r.rcode;
        errno = r.qerrno;
@@ -1738,8 +1746,9 @@ static int res_queryN_parallel(const char* name, std::span<res_target> queries,
 * If enabled, implement search rules until answer or unrecoverable failure
 * is detected.  Error code, if any, is left in *herrno.
 */
static int res_searchN(const char* name, std::span<res_target> queries, ResState* res,
                       int* herrno) {
static int res_searchN(const char* name, std::span<res_target> queries,
                       std::span<std::string> search_domains, bool is_mdns,
                       android::net::NetworkDnsEventReported* event, int* herrno) {
    const char* cp;
    HEADER* hp;
    uint32_t dots;
@@ -1760,7 +1769,7 @@ static int res_searchN(const char* name, std::span<res_target> queries, ResState
    // If there are dots in the name already, let's just give it a try 'as is'.
    saved_herrno = -1;
    if (dots >= NDOTS) {
        ret = res_querydomainN(name, NULL, queries, res, herrno);
        ret = res_querydomainN(name, NULL, queries, event, herrno);
        if (ret > 0) return (ret);
        saved_herrno = *herrno;
        tried_as_is++;
@@ -1772,9 +1781,9 @@ static int res_searchN(const char* name, std::span<res_target> queries, ResState
     * - there is at least one dot and there is no trailing dot.
     * - this is not a .local mDNS lookup.
     */
    if ((!dots || (dots && !trailing_dot)) && !isMdnsResolution(res->flags)) {
        for (const auto& domain : res->search_domains) {
            ret = res_querydomainN(name, domain.c_str(), queries, res, herrno);
    if ((!dots || (dots && !trailing_dot)) && !is_mdns) {
        for (const auto& domain : search_domains) {
            ret = res_querydomainN(name, domain.c_str(), queries, event, herrno);
            if (ret > 0) return ret;

            /*
@@ -1818,7 +1827,7 @@ static int res_searchN(const char* name, std::span<res_target> queries, ResState
     * name or whether it ends with a dot.
     */
    if (!tried_as_is) {
        ret = res_querydomainN(name, NULL, queries, res, herrno);
        ret = res_querydomainN(name, NULL, queries, event, herrno);
        if (ret > 0) return ret;
    }

@@ -1842,7 +1851,7 @@ static int res_searchN(const char* name, std::span<res_target> queries, ResState
// Perform a call on res_query on the concatenation of name and domain,
// removing a trailing dot from name if domain is NULL.
static int res_querydomainN(const char* name, const char* domain, std::span<res_target> queries,
                            ResState* res, int* herrno) {
                            android::net::NetworkDnsEventReported* event, int* herrno) {
    char nbuf[MAXDNAME];
    const char* longname = nbuf;
    size_t n, d;
@@ -1870,5 +1879,5 @@ static int res_querydomainN(const char* name, const char* domain, std::span<res_
        }
        snprintf(nbuf, sizeof(nbuf), "%s.%s", name, domain);
    }
    return res_queryN_parallel(longname, queries, res, herrno);
    return res_queryN_parallel(longname, queries, event, herrno);
}