Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit feed5008 authored by android-build-team Robot's avatar android-build-team Robot
Browse files

Snap for 7464528 from 00c45f49 to tm-release

Change-Id: I72f4182b31476e0966f048ec27c3df316a1ebfec
parents f6ff2a22 00c45f49
Loading
Loading
Loading
Loading
+189 −44
Original line number Diff line number Diff line
@@ -18,14 +18,24 @@

#include "DnsTlsTransport.h"

#include <span>

#include <android-base/format.h>
#include <android-base/logging.h>
#include <android-base/result.h>
#include <android-base/stringprintf.h>
#include <arpa/inet.h>
#include <arpa/nameser.h>
#include <netdutils/Stopwatch.h>
#include <netdutils/ThreadUtil.h>
#include <private/android_filesystem_config.h>  // AID_DNS
#include <sys/poll.h>

#include "DnsTlsSocketFactory.h"
#include "Experiments.h"
#include "IDnsTlsSocketFactory.h"
#include "resolv_private.h"
#include "util.h"

using android::base::StringPrintf;
using android::netdutils::setThreadName;
@@ -33,6 +43,113 @@ using android::netdutils::setThreadName;
namespace android {
namespace net {

namespace {

// Make a DNS query for the hostname "<random>-dnsotls-ds.metric.gstatic.com".
std::vector<uint8_t> makeDnsQuery() {
    static const char kDnsSafeChars[] =
            "abcdefhijklmnopqrstuvwxyz"
            "ABCDEFHIJKLMNOPQRSTUVWXYZ"
            "0123456789";
    const auto c = [](uint8_t rnd) -> uint8_t {
        return kDnsSafeChars[(rnd % std::size(kDnsSafeChars))];
    };
    uint8_t rnd[8];
    arc4random_buf(rnd, std::size(rnd));

    return std::vector<uint8_t>{
            rnd[6], rnd[7],  // [0-1]   query ID
            1,      0,       // [2-3]   flags; query[2] = 1 for recursion desired (RD).
            0,      1,       // [4-5]   QDCOUNT (number of queries)
            0,      0,       // [6-7]   ANCOUNT (number of answers)
            0,      0,       // [8-9]   NSCOUNT (number of name server records)
            0,      0,       // [10-11] ARCOUNT (number of additional records)
            17,     c(rnd[0]), c(rnd[1]), c(rnd[2]), c(rnd[3]), c(rnd[4]), c(rnd[5]), '-', 'd', 'n',
            's',    'o',       't',       'l',       's',       '-',       'd',       's', 6,   'm',
            'e',    't',       'r',       'i',       'c',       7,         'g',       's', 't', 'a',
            't',    'i',       'c',       3,         'c',       'o',       'm',
            0,                  // null terminator of FQDN (root TLD)
            0,      ns_t_aaaa,  // QTYPE
            0,      ns_c_in     // QCLASS
    };
}

base::Result<void> checkDnsResponse(const std::span<const uint8_t> answer) {
    if (answer.size() < NS_HFIXEDSZ) {
        return Errorf("short response: {}", answer.size());
    }

    const int qdcount = (answer[4] << 8) | answer[5];
    if (qdcount != 1) {
        return Errorf("reply query count != 1: {}", qdcount);
    }

    const int ancount = (answer[6] << 8) | answer[7];
    LOG(DEBUG) << "answer count: " << ancount;

    // TODO: Further validate the response contents (check for valid AAAA record, ...).
    // Note that currently, integration tests rely on this function accepting a
    // response with zero records.

    return {};
}

// Sends |query| to the given server, and returns the DNS response.
base::Result<void> sendUdpQuery(netdutils::IPAddress ip, uint32_t mark,
                                std::span<const uint8_t> query) {
    const sockaddr_storage ss = netdutils::IPSockAddr(ip, 53);
    const sockaddr* nsap = reinterpret_cast<const sockaddr*>(&ss);
    const int nsaplen = sockaddrSize(nsap);
    const int sockType = SOCK_DGRAM | SOCK_NONBLOCK | SOCK_CLOEXEC;
    android::base::unique_fd fd{socket(nsap->sa_family, sockType, 0)};
    if (fd < 0) {
        return ErrnoErrorf("socket failed");
    }

    resolv_tag_socket(fd.get(), AID_DNS, NET_CONTEXT_INVALID_PID);
    if (setsockopt(fd.get(), SOL_SOCKET, SO_MARK, &mark, sizeof(mark)) < 0) {
        return ErrnoErrorf("setsockopt failed");
    }

    if (connect(fd.get(), nsap, (socklen_t)nsaplen) < 0) {
        return ErrnoErrorf("connect failed");
    }

    if (send(fd, query.data(), query.size(), 0) != query.size()) {
        return ErrnoErrorf("send failed");
    }

    const int timeoutMs = 3000;
    while (true) {
        pollfd fds = {.fd = fd, .events = POLLIN};

        const int n = TEMP_FAILURE_RETRY(poll(&fds, 1, timeoutMs));
        if (n == 0) {
            return Errorf("poll timed out");
        }
        if (n < 0) {
            return ErrnoErrorf("poll failed");
        }
        if (fds.revents & (POLLIN | POLLERR)) {
            std::vector<uint8_t> buf(MAXPACKET);
            const int resplen = recv(fd, buf.data(), buf.size(), 0);

            if (resplen < 0) {
                return ErrnoErrorf("recvfrom failed");
            }

            buf.resize(resplen);
            if (auto result = checkDnsResponse(buf); !result.ok()) {
                return Errorf("checkDnsResponse failed: {}", result.error().message());
            }

            return {};
        }
    }
}

}  // namespace

std::future<DnsTlsTransport::Result> DnsTlsTransport::query(const netdutils::Slice query) {
    std::lock_guard guard(mLock);

@@ -160,65 +277,93 @@ DnsTlsTransport::~DnsTlsTransport() {
// That may require moving it to DnsTlsDispatcher.
bool DnsTlsTransport::validate(const DnsTlsServer& server, uint32_t mark) {
    LOG(DEBUG) << "Beginning validation with mark " << std::hex << mark;
    // Generate "<random>-dnsotls-ds.metric.gstatic.com", which we will lookup through |ss| in
    // order to prove that it is actually a working DNS over TLS server.
    static const char kDnsSafeChars[] =
            "abcdefhijklmnopqrstuvwxyz"
            "ABCDEFHIJKLMNOPQRSTUVWXYZ"
            "0123456789";
    const auto c = [](uint8_t rnd) -> uint8_t {
        return kDnsSafeChars[(rnd % std::size(kDnsSafeChars))];
    };
    uint8_t rnd[8];
    arc4random_buf(rnd, std::size(rnd));
    // We could try to use res_mkquery() here, but it's basically the same.
    uint8_t query[] = {
        rnd[6], rnd[7],  // [0-1]   query ID
        1, 0,  // [2-3]   flags; query[2] = 1 for recursion desired (RD).
        0, 1,  // [4-5]   QDCOUNT (number of queries)
        0, 0,  // [6-7]   ANCOUNT (number of answers)
        0, 0,  // [8-9]   NSCOUNT (number of name server records)
        0, 0,  // [10-11] ARCOUNT (number of additional records)
        17, c(rnd[0]), c(rnd[1]), c(rnd[2]), c(rnd[3]), c(rnd[4]), c(rnd[5]),
            '-', 'd', 'n', 's', 'o', 't', 'l', 's', '-', 'd', 's',
        6, 'm', 'e', 't', 'r', 'i', 'c',
        7, 'g', 's', 't', 'a', 't', 'i', 'c',
        3, 'c', 'o', 'm',
        0,  // null terminator of FQDN (root TLD)
        0, ns_t_aaaa,  // QTYPE
        0, ns_c_in     // QCLASS
    };
    const int qlen = std::size(query);

    int replylen = 0;
    const std::vector<uint8_t> query = makeDnsQuery();
    DnsTlsSocketFactory factory;
    DnsTlsTransport transport(server, mark, &factory);
    auto r = transport.query(netdutils::Slice(query, qlen)).get();

    // Send the initial query to warm up the connection.
    auto r = transport.query(netdutils::makeSlice(query)).get();
    if (r.code != Response::success) {
        LOG(WARNING) << "query failed";
        return false;
    }

    const std::vector<uint8_t>& recvbuf = r.response;
    if (recvbuf.size() < NS_HFIXEDSZ) {
        LOG(WARNING) << "short response: " << replylen;
    if (auto result = checkDnsResponse(r.response); !result.ok()) {
        LOG(WARNING) << "checkDnsResponse failed: " << result.error().message();
        return false;
    }

    const int qdcount = (recvbuf[4] << 8) | recvbuf[5];
    if (qdcount != 1) {
        LOG(WARNING) << "reply query count != 1: " << qdcount;
        return false;
    // If this validation is not for opportunistic mode, or the flags are not properly set,
    // the validation is done. If not, the validation will compare DoT probe latency and
    // UDP probe latency, and it will pass if:
    //   dot_probe_latency < latencyFactor * udp_probe_latency + latencyOffsetMs
    //
    // For instance, with latencyFactor = 3 and latencyOffsetMs = 10, if UDP probe latency is 5 ms,
    // DoT probe latency must less than 25 ms.
    int latencyFactor = Experiments::getInstance()->getFlag("dot_validation_latency_factor", -1);
    int latencyOffsetMs =
            Experiments::getInstance()->getFlag("dot_validation_latency_offset_ms", -1);
    const bool shouldCompareUdpLatency =
            server.name.empty() &&
            (latencyFactor >= 0 && latencyOffsetMs >= 0 && latencyFactor + latencyOffsetMs != 0);
    if (!shouldCompareUdpLatency) {
        return true;
    }

    const int ancount = (recvbuf[6] << 8) | recvbuf[7];
    LOG(DEBUG) << "answer count: " << ancount;
    LOG(INFO) << fmt::format("Use flags: latencyFactor={}, latencyOffsetMs={}", latencyFactor,
                             latencyOffsetMs);

    // TODO: Further validate the response contents (check for valid AAAA record, ...).
    // Note that currently, integration tests rely on this function accepting a
    // response with zero records.
    int64_t udpProbeTimeUs = 0;
    bool udpProbeGotAnswer = false;
    std::thread udpProbeThread([&] {
        // Can issue another probe if the first one fails or is lost.
        for (int i = 1; i < 3; i++) {
            netdutils::Stopwatch stopwatch;
            auto result = sendUdpQuery(server.addr().ip(), mark, query);
            udpProbeTimeUs = stopwatch.timeTakenUs();
            udpProbeGotAnswer = result.ok();
            LOG(INFO) << fmt::format("UDP probe for {} {}, took {:.3f}ms", server.toIpString(),
                                     (udpProbeGotAnswer ? "succeeded" : "failed"),
                                     udpProbeTimeUs / 1000.0);

    return true;
            if (udpProbeGotAnswer) {
                break;
            }
            LOG(WARNING) << "sendUdpQuery attempt " << i << " failed: " << result.error().message();
        }
    });

    int64_t dotProbeTimeUs = 0;
    bool dotProbeGotAnswer = false;
    std::thread dotProbeThread([&] {
        netdutils::Stopwatch stopwatch;
        auto r = transport.query(netdutils::makeSlice(query)).get();
        dotProbeTimeUs = stopwatch.timeTakenUs();

        if (r.code != Response::success) {
            LOG(WARNING) << "query failed";
        } else {
            if (auto result = checkDnsResponse(r.response); !result.ok()) {
                LOG(WARNING) << "checkDnsResponse failed: " << result.error().message();
            } else {
                dotProbeGotAnswer = true;
            }
        }

        LOG(INFO) << fmt::format("DoT probe for {} {}, took {:.3f}ms", server.toIpString(),
                                 (dotProbeGotAnswer ? "succeeded" : "failed"),
                                 dotProbeTimeUs / 1000.0);
    });

    // TODO: If DoT probe thread finishes before UDP probe thread and dotProbeGotAnswer is false,
    // actively cancel UDP probe thread.
    dotProbeThread.join();
    udpProbeThread.join();

    if (!dotProbeGotAnswer) return false;
    if (!udpProbeGotAnswer) return true;
    return dotProbeTimeUs < (latencyFactor * udpProbeTimeUs + latencyOffsetMs * 1000);
}

}  // end of namespace net
+11 −3
Original line number Diff line number Diff line
@@ -49,10 +49,18 @@ class Experiments {
    // TODO: Migrate other experiment flags to here.
    // (retry_count, retransmission_time_interval)
    static constexpr const char* const kExperimentFlagKeyList[] = {
            "keep_listening_udp",   "parallel_lookup_release",    "parallel_lookup_sleep_time",
            "sort_nameservers",     "dot_async_handshake",        "dot_connect_timeout_ms",
            "dot_maxtries",         "dot_revalidation_threshold", "dot_xport_unusable_threshold",
            "keep_listening_udp",
            "parallel_lookup_release",
            "parallel_lookup_sleep_time",
            "sort_nameservers",
            "dot_async_handshake",
            "dot_connect_timeout_ms",
            "dot_maxtries",
            "dot_revalidation_threshold",
            "dot_xport_unusable_threshold",
            "dot_query_timeout_ms",
            "dot_validation_latency_factor",
            "dot_validation_latency_offset_ms",
    };
    // This value is used in updateInternal as the default value if any flags can't be found.
    static constexpr int kFlagIntDefault = INT_MIN;
+4 −0
Original line number Diff line number Diff line
@@ -35,6 +35,8 @@ class PrivateDnsConfigurationTest : public ::testing::Test {
        ASSERT_TRUE(tls1.startServer());
        ASSERT_TRUE(tls2.startServer());
        ASSERT_TRUE(backend.startServer());
        ASSERT_TRUE(backend1ForUdpProbe.startServer());
        ASSERT_TRUE(backend2ForUdpProbe.startServer());
    }

    void SetUp() {
@@ -132,6 +134,8 @@ class PrivateDnsConfigurationTest : public ::testing::Test {
    inline static test::DnsTlsFrontend tls1{kServer1, "853", kBackend, "53"};
    inline static test::DnsTlsFrontend tls2{kServer2, "853", kBackend, "53"};
    inline static test::DNSResponder backend{kBackend, "53"};
    inline static test::DNSResponder backend1ForUdpProbe{kServer1, "53"};
    inline static test::DNSResponder backend2ForUdpProbe{kServer2, "53"};
};

TEST_F(PrivateDnsConfigurationTest, ValidationSuccess) {
+54 −42
Original line number Diff line number Diff line
@@ -146,11 +146,12 @@ using android::netdutils::Stopwatch;
static int send_vc(res_state statp, res_params* params, const uint8_t* buf, int buflen,
                   uint8_t* ans, int anssiz, int* terrno, size_t ns, time_t* at, int* rcode,
                   int* delay);
static int setupUdpSocket(ResState* statp, const sockaddr* sockap, size_t addrIndex, int* terrno);
static int send_dg(res_state statp, res_params* params, const uint8_t* buf, int buflen,
                   uint8_t* ans, int anssiz, int* terrno, size_t* ns, int* v_circuit,
                   int* gotsomewhere, time_t* at, int* rcode, int* delay);

static void dump_error(const char*, const struct sockaddr*, int);
static void dump_error(const char*, const struct sockaddr*);

static int sock_eq(struct sockaddr*, struct sockaddr*);
static int connect_with_timeout(int sock, const struct sockaddr* nsap, socklen_t salen,
@@ -726,14 +727,14 @@ same_ns:
        errno = 0;
        if (random_bind(statp->tcp_nssock, nsap->sa_family) < 0) {
            *terrno = errno;
            dump_error("bind/vc", nsap, nsaplen);
            dump_error("bind/vc", nsap);
            statp->closeSockets();
            return (0);
        }
        if (connect_with_timeout(statp->tcp_nssock, nsap, (socklen_t)nsaplen,
                                 get_timeout(statp, params, ns)) < 0) {
            *terrno = errno;
            dump_error("connect/vc", nsap, nsaplen);
            dump_error("connect/vc", nsap);
            statp->closeSockets();
            /*
             * The way connect_with_timeout() is implemented prevents us from reliably
@@ -932,7 +933,7 @@ retry:
static std::vector<pollfd> extractUdpFdset(res_state statp, const short events = POLLIN) {
    std::vector<pollfd> fdset(statp->nsaddrs.size());
    for (size_t i = 0; i < statp->nsaddrs.size(); ++i) {
        fdset[i] = {.fd = statp->nssocks[i], .events = events};
        fdset[i] = {.fd = statp->udpsocks[i], .events = events};
    }
    return fdset;
}
@@ -969,10 +970,10 @@ static Result<std::vector<int>> udpRetryingPollWrapper(res_state statp, int ns,
            android::net::Experiments::getInstance()->getFlag("keep_listening_udp", 0);
    if (keepListeningUdp) return udpRetryingPoll(statp, finish);

    if (int n = retrying_poll(statp->nssocks[ns], POLLIN, finish); n <= 0) {
    if (int n = retrying_poll(statp->udpsocks[ns], POLLIN, finish); n <= 0) {
        return ErrnoError();
    }
    return std::vector<int>{statp->nssocks[ns]};
    return std::vector<int>{statp->udpsocks[ns]};
}

bool ignoreInvalidAnswer(res_state statp, const sockaddr_storage& from, const uint8_t* buf,
@@ -997,66 +998,76 @@ bool ignoreInvalidAnswer(res_state statp, const sockaddr_storage& from, const ui
    return false;
}

static int send_dg(res_state statp, res_params* params, const uint8_t* buf, int buflen,
                   uint8_t* ans, int anssiz, int* terrno, size_t* ns, int* v_circuit,
                   int* gotsomewhere, time_t* at, int* rcode, int* delay) {
    // It should never happen, but just in case.
    if (*ns >= statp->nsaddrs.size()) {
        LOG(ERROR) << __func__ << ": Out-of-bound indexing: " << ns;
        *terrno = EINVAL;
        return -1;
    }
// return 1 when setup udp socket success.
// return 0 when timeout , bind error, network error(ex: Protocol not supported ...).
// return -1 when create socket fail, set socket option fail.
static int setupUdpSocket(ResState* statp, const sockaddr* sockap, size_t addrIndex, int* terrno) {
    statp->udpsocks[addrIndex].reset(socket(sockap->sa_family, SOCK_DGRAM | SOCK_CLOEXEC, 0));

    *at = time(nullptr);
    *delay = 0;
    const sockaddr_storage ss = statp->nsaddrs[*ns];
    const sockaddr* nsap = reinterpret_cast<const sockaddr*>(&ss);
    const int nsaplen = sockaddrSize(nsap);

    if (statp->nssocks[*ns] == -1) {
        statp->nssocks[*ns].reset(socket(nsap->sa_family, SOCK_DGRAM | SOCK_CLOEXEC, 0));
        if (statp->nssocks[*ns] < 0) {
    if (statp->udpsocks[addrIndex] < 0) {
        *terrno = errno;
            PLOG(DEBUG) << __func__ << ": socket(dg): ";
        PLOG(ERROR) << __func__ << ": socket: ";
        switch (errno) {
            case EPROTONOSUPPORT:
            case EPFNOSUPPORT:
            case EAFNOSUPPORT:
                    return (0);
                return 0;
            default:
                    return (-1);
                return -1;
        }
    }

    const uid_t uid = statp->enforce_dns_uid ? AID_DNS : statp->uid;
        resolv_tag_socket(statp->nssocks[*ns], uid, statp->pid);
    resolv_tag_socket(statp->udpsocks[addrIndex], uid, statp->pid);
    if (statp->_mark != MARK_UNSET) {
            if (setsockopt(statp->nssocks[*ns], SOL_SOCKET, SO_MARK, &(statp->_mark),
        if (setsockopt(statp->udpsocks[addrIndex], SOL_SOCKET, SO_MARK, &(statp->_mark),
                       sizeof(statp->_mark)) < 0) {
            *terrno = errno;
            statp->closeSockets();
            return -1;
        }
    }

    if (random_bind(statp->udpsocks[addrIndex], sockap->sa_family) < 0) {
        *terrno = errno;
        dump_error("bind", sockap);
        statp->closeSockets();
        return 0;
    }
    return 1;
}

static int send_dg(res_state statp, res_params* params, const uint8_t* buf, int buflen,
                   uint8_t* ans, int anssiz, int* terrno, size_t* ns, int* v_circuit,
                   int* gotsomewhere, time_t* at, int* rcode, int* delay) {
    // It should never happen, but just in case.
    if (*ns >= statp->nsaddrs.size()) {
        LOG(ERROR) << __func__ << ": Out-of-bound indexing: " << ns;
        *terrno = EINVAL;
        return -1;
    }

    *at = time(nullptr);
    *delay = 0;
    const sockaddr_storage ss = statp->nsaddrs[*ns];
    const sockaddr* nsap = reinterpret_cast<const sockaddr*>(&ss);

    if (statp->udpsocks[*ns] == -1) {
        int result = setupUdpSocket(statp, nsap, *ns, terrno);
        if (result <= 0) return result;

        // Use a "connected" datagram socket to receive an ECONNREFUSED error
        // on the next socket operation when the server responds with an
        // ICMP port-unreachable error. This way we can detect the absence of
        // a nameserver without timing out.
        if (random_bind(statp->nssocks[*ns], nsap->sa_family) < 0) {
            *terrno = errno;
            dump_error("bind(dg)", nsap, nsaplen);
            statp->closeSockets();
            return (0);
        }
        if (connect(statp->nssocks[*ns], nsap, (socklen_t)nsaplen) < 0) {
        if (connect(statp->udpsocks[*ns], nsap, sockaddrSize(nsap)) < 0) {
            *terrno = errno;
            dump_error("connect(dg)", nsap, nsaplen);
            dump_error("connect(dg)", nsap);
            statp->closeSockets();
            return (0);
        }
        LOG(DEBUG) << __func__ << ": new DG socket";
    }
    if (send(statp->nssocks[*ns], (const char*)buf, (size_t)buflen, 0) != buflen) {
    if (send(statp->udpsocks[*ns], (const char*)buf, (size_t)buflen, 0) != buflen) {
        *terrno = errno;
        PLOG(DEBUG) << __func__ << ": send: ";
        statp->closeSockets();
@@ -1150,7 +1161,7 @@ static int send_dg(res_state statp, res_params* params, const uint8_t* buf, int
    }
}

static void dump_error(const char* str, const struct sockaddr* address, int alen) {
static void dump_error(const char* str, const struct sockaddr* address) {
    char hbuf[NI_MAXHOST];
    char sbuf[NI_MAXSERV];
    constexpr int niflags = NI_NUMERICHOST | NI_NUMERICSERV;
@@ -1158,7 +1169,8 @@ static void dump_error(const char* str, const struct sockaddr* address, int alen

    if (!WOULD_LOG(DEBUG)) return;

    if (getnameinfo(address, (socklen_t)alen, hbuf, sizeof(hbuf), sbuf, sizeof(sbuf), niflags)) {
    if (getnameinfo(address, sockaddrSize(address), hbuf, sizeof(hbuf), sbuf, sizeof(sbuf),
                    niflags)) {
        strncpy(hbuf, "?", sizeof(hbuf) - 1);
        hbuf[sizeof(hbuf) - 1] = '\0';
        strncpy(sbuf, "?", sizeof(sbuf) - 1);
+2 −2
Original line number Diff line number Diff line
@@ -119,7 +119,7 @@ struct ResState {
        tcp_nssock.reset();
        _flags &= ~RES_F_VC;

        for (auto& sock : nssocks) {
        for (auto& sock : udpsocks) {
            sock.reset();
        }
    }
@@ -132,7 +132,7 @@ struct ResState {
    pid_t pid;                                  // pid of the app that sent the DNS lookup
    std::vector<std::string> search_domains{};  // domains to search
    std::vector<android::netdutils::IPSockAddr> nsaddrs;
    android::base::unique_fd nssocks[MAXNS];    // UDP sockets to nameservers
    android::base::unique_fd udpsocks[MAXNS];    // UDP sockets to nameservers and mdns responsder
    unsigned ndots : 4 = 1;                     // threshold for initial abs. query
    unsigned _mark;                             // If non-0 SET_MARK to _mark on all request sockets
    android::base::unique_fd tcp_nssock;        // TCP socket (but why not one per nameserver?)
Loading