Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 61d17267 authored by Mike Yu's avatar Mike Yu
Browse files

Support prioritizing DNS servers

The change introduces a way to prioritize DNS servers on the basis of
DNS query response time, which aims to replace the current design that
is biased towards using the first DNS server assigned from networks.

The quality is evaluated based on the heuristics:
  - The more latency it is, the less likely it is used.
  - The longer time it is not used, the more likely it is used.

Compared to the current design, the proposed method detects bad DNS
servers more quickly. For instance, a server which is unreachable or
times out can be detected and deprioritized with few trials by backoff
penalty and abnormal latency.

Similar to the current design, a server which has been regarded as bad
quality can be used again, but it depends on how much worse it is. A
counter is used to count how many times a DNS server not being used,
which avoids from constantly using the same DNS server.

This change comprises:

[1] Allow the resolver to sort DNS servers on the basis of DNS query
    response time.
[2] Add an experiment flag to enable/disable the sorting.
[3] Show the result of the quantified quality of DNS servers in
    dumpsys dnsresolver.
[4] Add unit tests for DnsStats::getSortedServers().
[5] Revise the integration tests which are sensitive to the nameserver
    sorting, including two big changes in SkipBadServersDueToInternalError
    and SkipBadServersDueToTimeout and some minor changes.

Bug: 137169582
Test: ran resolv_unit_test
      ran resolv_integration_test with the sorting enabled
      ran resolv_integration_test with the sorting disabled
Change-Id: I24b6a317f135a942ce0ea310c81dfe658bada6a7
parent 6ce587d2
Loading
Loading
Loading
Loading
+77 −5
Original line number Diff line number Diff line
@@ -77,11 +77,14 @@ bool StatsData::operator==(const StatsData& o) const {
           std::tie(o.serverSockAddr, o.total, o.rcodeCounts, o.latencyUs);
}

int StatsData::averageLatencyMs() const {
    return (total == 0) ? 0 : duration_cast<milliseconds>(latencyUs).count() / total;
}

std::string StatsData::toString() const {
    if (total == 0) return StringPrintf("%s <no data>", serverSockAddr.ip().toString().c_str());

    const auto now = std::chrono::steady_clock::now();
    const int meanLatencyMs = duration_cast<milliseconds>(latencyUs).count() / total;
    const int lastUpdateSec = duration_cast<seconds>(now - lastUpdate).count();
    std::string buf;
    for (const auto& [rcode, counts] : rcodeCounts) {
@@ -90,7 +93,7 @@ std::string StatsData::toString() const {
        }
    }
    return StringPrintf("%s (%d, %dms, [%s], %ds)", serverSockAddr.ip().toString().c_str(), total,
                        meanLatencyMs, buf.c_str(), lastUpdateSec);
                        averageLatencyMs(), buf.c_str(), lastUpdateSec);
}

StatsRecords::StatsRecords(const IPSockAddr& ipSockAddr, size_t size)
@@ -104,6 +107,10 @@ void StatsRecords::push(const Record& record) {
        updateStatsData(mRecords.front(), false);
        mRecords.pop_front();
    }

    // Update the quality factors.
    mSkippedCount = 0;
    updatePenalty(record);
}

void StatsRecords::updateStatsData(const Record& record, const bool add) {
@@ -120,6 +127,41 @@ void StatsRecords::updateStatsData(const Record& record, const bool add) {
    mStatsData.lastUpdate = std::chrono::steady_clock::now();
}

void StatsRecords::updatePenalty(const Record& record) {
    switch (record.rcode) {
        case NS_R_NO_ERROR:
        case NS_R_NXDOMAIN:
        case NS_R_NOTAUTH:
            mPenalty = 0;
            return;
        default:
            // NS_R_TIMEOUT and NS_R_INTERNAL_ERROR are in this case.
            if (mPenalty == 0) {
                mPenalty = 100;
            } else {
                // The evaluated quality drops more quickly when continuous failures happen.
                mPenalty = std::min(mPenalty * 2, kMaxQuality);
            }
            return;
    }
}

double StatsRecords::score() const {
    const int avgRtt = mStatsData.averageLatencyMs();

    // Set the lower bound to -1 in case of "avgRtt + mPenalty < mSkippedCount"
    //   1) when the server doesn't have any stats yet.
    //   2) when the sorting has been disabled while it was enabled before.
    int quality = std::clamp(avgRtt + mPenalty - mSkippedCount, -1, kMaxQuality);

    // Normalization.
    return static_cast<double>(kMaxQuality - quality) * 100 / kMaxQuality;
}

void StatsRecords::incrementSkippedCount() {
    mSkippedCount = std::min(mSkippedCount + 1, kMaxQuality);
}

bool DnsStats::setServers(const std::vector<netdutils::IPSockAddr>& servers, Protocol protocol) {
    if (!ensureNoInvalidIp(servers)) return false;

@@ -147,6 +189,7 @@ bool DnsStats::setServers(const std::vector<netdutils::IPSockAddr>& servers, Pro
bool DnsStats::addStats(const IPSockAddr& ipSockAddr, const DnsQueryEvent& record) {
    if (ipSockAddr.ip() == INVALID_IPADDRESS) return false;

    bool added = false;
    for (auto& [serverSockAddr, statsRecords] : mStats[record.protocol()]) {
        if (serverSockAddr == ipSockAddr) {
            const StatsRecords::Record rec = {
@@ -154,10 +197,36 @@ bool DnsStats::addStats(const IPSockAddr& ipSockAddr, const DnsQueryEvent& recor
                    .latencyUs = microseconds(record.latency_micros()),
            };
            statsRecords.push(rec);
            return true;
            added = true;
        } else {
            statsRecords.incrementSkippedCount();
        }
    }
    return false;

    return added;
}

std::vector<IPSockAddr> DnsStats::getSortedServers(Protocol protocol) const {
    // DoT unsupported. The handshake overhead is expensive, and the connection will hang for a
    // while. Need to figure out if it is worth doing for DoT servers.
    if (protocol == PROTO_DOT) return {};

    auto it = mStats.find(protocol);
    if (it == mStats.end()) return {};

    // Sorting on insertion in decreasing order.
    std::multimap<double, IPSockAddr, std::greater<double>> sortedData;
    for (const auto& [ip, statsRecords] : it->second) {
        sortedData.insert({statsRecords.score(), ip});
    }

    std::vector<IPSockAddr> ret;
    ret.reserve(sortedData.size());
    for (auto& [_, v] : sortedData) {
        ret.push_back(v);  // IPSockAddr is trivially-copyable.
    }

    return ret;
}

std::vector<StatsData> DnsStats::getStats(Protocol protocol) const {
@@ -179,7 +248,10 @@ void DnsStats::dump(DumpWriter& dw) {
            return;
        }
        for (const auto& [_, statsRecords] : statsMap) {
            dw.println("%s", statsRecords.getStatsData().toString().c_str());
            const StatsData& data = statsRecords.getStatsData();
            std::string str = data.toString();
            str += StringPrintf(" score{%.1f}", statsRecords.score());
            dw.println("%s", str.c_str());
        }
    };

+22 −1
Original line number Diff line number Diff line
@@ -54,6 +54,7 @@ struct StatsData {
    // The last update timestamp.
    std::chrono::time_point<std::chrono::steady_clock> lastUpdate;

    int averageLatencyMs() const;
    std::string toString() const;

    // For testing.
@@ -77,12 +78,31 @@ class StatsRecords {

    const StatsData& getStatsData() const { return mStatsData; }

    // Quantifies the quality based on the current quality factors and the latency, and normalize
    // the value to a score between 0 to 100.
    double score() const;

    void incrementSkippedCount();

  private:
    void updateStatsData(const Record& record, const bool add);
    void updatePenalty(const Record& record);

    std::deque<Record> mRecords;
    size_t mCapacity;
    StatsData mStatsData;

    // A quality factor used to distinguish if the server can't be evaluated by latency alone, such
    // as instant failure on connect.
    int mPenalty = 0;

    // A quality factor used to prevent starvation.
    int mSkippedCount = 0;

    // The maximum of the quantified result. As the sorting is on the basis of server latency, limit
    // the maximal value of the quantity to 10000 in correspondence with the maximal cleartext
    // query timeout 10000 milliseconds. This helps normalize the value of the quality to a score.
    static constexpr int kMaxQuality = 10000;
};

// DnsStats class manages the statistics of DNS servers per netId.
@@ -98,13 +118,14 @@ class DnsStats {
    // Return true if |record| is successfully added into |server|'s stats; otherwise, return false.
    bool addStats(const netdutils::IPSockAddr& server, const DnsQueryEvent& record);

    std::vector<netdutils::IPSockAddr> getSortedServers(Protocol protocol) const;

    void dump(netdutils::DumpWriter& dw);

    // For testing.
    std::vector<StatsData> getStats(Protocol protocol) const;

    // TODO: Compatible support for getResolverInfo().
    // TODO: Support getSortedServers().

    static constexpr size_t kLogSize = 128;

+113 −3
Original line number Diff line number Diff line
@@ -52,6 +52,8 @@ StatsData makeStatsData(const IPSockAddr& server, const int total, const millise

}  // namespace

// TODO: add StatsDataTest to ensure its methods return correct outputs.

class StatsRecordsTest : public ::testing::Test {};

TEST_F(StatsRecordsTest, PushRecord) {
@@ -95,9 +97,9 @@ class DnsStatsTest : public ::testing::Test {
    void verifyDumpOutput(const std::vector<StatsData>& tcpData,
                          const std::vector<StatsData>& udpData,
                          const std::vector<StatsData>& dotData) {
        // A simple pattern to capture two matches:
        //     server address (empty allowed) and its statistics.
        const std::regex pattern(R"(\s{4,}([0-9a-fA-F:\.]*) ([<(].*[>)]))");
        // A pattern to capture three matches:
        //     server address (empty allowed), the statistics, and the score.
        const std::regex pattern(R"(\s{4,}([0-9a-fA-F:\.]*)[ ]?([<(].*[>)])[ ]?(\S*))");
        std::string dumpString = captureDumpOutput();

        const auto check = [&](const std::vector<StatsData>& statsData, const std::string& protocol,
@@ -111,6 +113,7 @@ class DnsStatsTest : public ::testing::Test {
                ASSERT_TRUE(std::regex_search(*dumpString, sm, pattern));
                EXPECT_TRUE(sm[1].str().empty());
                EXPECT_EQ(sm[2], "<no server>");
                EXPECT_TRUE(sm[3].str().empty());
                *dumpString = sm.suffix();
                return;
            }
@@ -119,6 +122,7 @@ class DnsStatsTest : public ::testing::Test {
                ASSERT_TRUE(std::regex_search(*dumpString, sm, pattern));
                EXPECT_EQ(sm[1], stats.serverSockAddr.ip().toString());
                EXPECT_FALSE(sm[2].str().empty());
                EXPECT_FALSE(sm[3].str().empty());
                *dumpString = sm.suffix();
            }
        };
@@ -379,4 +383,110 @@ TEST_F(DnsStatsTest, AddStatsRecords_100000) {
    verifyDumpOutput(expectedStats, expectedStats, expectedStats);
}

TEST_F(DnsStatsTest, GetServers_SortingByLatency) {
    const IPSockAddr server1 = IPSockAddr::toIPSockAddr("127.0.0.1", 53);
    const IPSockAddr server2 = IPSockAddr::toIPSockAddr("127.0.0.2", 53);
    const IPSockAddr server3 = IPSockAddr::toIPSockAddr("2001:db8:cafe:d00d::1", 53);
    const IPSockAddr server4 = IPSockAddr::toIPSockAddr("2001:db8:cafe:d00d::2", 53);

    // Return empty list before setup.
    EXPECT_THAT(mDnsStats.getSortedServers(PROTO_UDP), IsEmpty());

    // Before there's any stats, the list of the sorted servers is the same as the setup's one.
    EXPECT_TRUE(mDnsStats.setServers({server1, server2, server3, server4}, PROTO_UDP));
    EXPECT_TRUE(mDnsStats.setServers({server1, server2, server3, server4}, PROTO_DOT));
    EXPECT_THAT(mDnsStats.getSortedServers(PROTO_UDP),
                testing::ElementsAreArray({server1, server2, server3, server4}));

    // Add a record to server1. The qualities of the other servers increase.
    EXPECT_TRUE(mDnsStats.addStats(server1, makeDnsQueryEvent(PROTO_UDP, NS_R_NO_ERROR, 10ms)));
    EXPECT_THAT(mDnsStats.getSortedServers(PROTO_UDP),
                testing::ElementsAreArray({server2, server3, server4, server1}));

    // Add a record, with less repose time than server1, to server3.
    EXPECT_TRUE(mDnsStats.addStats(server3, makeDnsQueryEvent(PROTO_UDP, NS_R_NO_ERROR, 5ms)));
    EXPECT_THAT(mDnsStats.getSortedServers(PROTO_UDP),
                testing::ElementsAreArray({server2, server4, server3, server1}));

    // Even though server2 has zero response time, select server4 as the first server because it
    // doesn't have stats yet.
    EXPECT_TRUE(mDnsStats.addStats(server2, makeDnsQueryEvent(PROTO_UDP, NS_R_NO_ERROR, 0ms)));
    EXPECT_THAT(mDnsStats.getSortedServers(PROTO_UDP),
                testing::ElementsAreArray({server4, server2, server3, server1}));

    // Updating DoT record to server4 changes nothing.
    EXPECT_TRUE(mDnsStats.addStats(server4, makeDnsQueryEvent(PROTO_DOT, NS_R_NO_ERROR, 10ms)));
    EXPECT_THAT(mDnsStats.getSortedServers(PROTO_UDP),
                testing::ElementsAreArray({server4, server2, server3, server1}));

    // Add a record, with a very large value of respose time, to server4.
    EXPECT_TRUE(mDnsStats.addStats(server4, makeDnsQueryEvent(PROTO_UDP, NS_R_NO_ERROR, 500000ms)));
    EXPECT_THAT(mDnsStats.getSortedServers(PROTO_UDP),
                testing::ElementsAreArray({server2, server3, server1, server4}));

    // The list of the DNS servers changed.
    EXPECT_TRUE(mDnsStats.setServers({server2, server4}, PROTO_UDP));
    EXPECT_THAT(mDnsStats.getSortedServers(PROTO_UDP),
                testing::ElementsAreArray({server2, server4}));

    // It fails to add records to an non-existing server, and nothing is changed in getting
    // the sorted servers.
    EXPECT_FALSE(mDnsStats.addStats(server1, makeDnsQueryEvent(PROTO_UDP, NS_R_NO_ERROR, 10ms)));
    EXPECT_THAT(mDnsStats.getSortedServers(PROTO_UDP),
                testing::ElementsAreArray({server2, server4}));
}

TEST_F(DnsStatsTest, GetServers_DeprioritizingBadServers) {
    const IPSockAddr server1 = IPSockAddr::toIPSockAddr("127.0.0.1", 53);
    const IPSockAddr server2 = IPSockAddr::toIPSockAddr("127.0.0.2", 53);
    const IPSockAddr server3 = IPSockAddr::toIPSockAddr("127.0.0.3", 53);
    const IPSockAddr server4 = IPSockAddr::toIPSockAddr("127.0.0.4", 53);

    EXPECT_TRUE(mDnsStats.setServers({server1, server2, server3, server4}, PROTO_UDP));

    int server1Counts = 0;
    int server2Counts = 0;
    for (int i = 0; i < 5000; i++) {
        const auto servers = mDnsStats.getSortedServers(PROTO_UDP);
        EXPECT_EQ(servers.size(), 4U);
        if (servers[0] == server1) {
            // server1 is relatively slowly responsive.
            EXPECT_TRUE(mDnsStats.addStats(servers[0],
                                           makeDnsQueryEvent(PROTO_UDP, NS_R_NO_ERROR, 200ms)));
            server1Counts++;
        } else if (servers[0] == server2) {
            // server2 is relatively quickly responsive.
            EXPECT_TRUE(mDnsStats.addStats(servers[0],
                                           makeDnsQueryEvent(PROTO_UDP, NS_R_NO_ERROR, 100ms)));
            server2Counts++;
        } else if (servers[0] == server3) {
            // server3 always times out.
            EXPECT_TRUE(mDnsStats.addStats(servers[0],
                                           makeDnsQueryEvent(PROTO_UDP, NS_R_TIMEOUT, 1000ms)));
        } else if (servers[0] == server4) {
            // server4 is unusable.
            EXPECT_TRUE(mDnsStats.addStats(servers[0],
                                           makeDnsQueryEvent(PROTO_UDP, NS_R_INTERNAL_ERROR, 1ms)));
        }
    }

    const std::vector<StatsData> allStatsData = mDnsStats.getStats(PROTO_UDP);
    for (const auto& data : allStatsData) {
        EXPECT_EQ(data.rcodeCounts.size(), 1U);
        if (data.serverSockAddr == server1 || data.serverSockAddr == server2) {
            const auto it = data.rcodeCounts.find(NS_R_NO_ERROR);
            ASSERT_NE(it, data.rcodeCounts.end());
            EXPECT_GT(server2Counts, 2 * server1Counts);  // At least twice larger.
        } else if (data.serverSockAddr == server3) {
            const auto it = data.rcodeCounts.find(NS_R_TIMEOUT);
            ASSERT_NE(it, data.rcodeCounts.end());
            EXPECT_LT(it->second, 10);
        } else if (data.serverSockAddr == server4) {
            const auto it = data.rcodeCounts.find(NS_R_INTERNAL_ERROR);
            ASSERT_NE(it, data.rcodeCounts.end());
            EXPECT_LT(it->second, 10);
        }
    }
}

}  // namespace android::net
+2 −1
Original line number Diff line number Diff line
@@ -49,7 +49,8 @@ class Experiments {
    // TODO: Migrate other experiment flags to here.
    // (retry_count, retransmission_time_interval, dot_connect_timeout_ms)
    static constexpr const char* const kExperimentFlagKeyList[] = {
            "keep_listening_udp", "parallel_lookup", "parallel_lookup_sleep_time"};
            "keep_listening_udp", "parallel_lookup", "parallel_lookup_sleep_time",
            "sort_nameservers"};
    // This value is used in updateInternal as the default value if any flags can't be found.
    static constexpr int kFlagIntDefault = INT_MIN;
    // For testing.
+6 −1
Original line number Diff line number Diff line
@@ -60,6 +60,7 @@
#include <server_configurable_flags/get_flags.h>

#include "DnsStats.h"
#include "Experiments.h"
#include "res_comp.h"
#include "res_debug.h"
#include "resolv_private.h"
@@ -69,6 +70,7 @@ using aidl::android::net::IDnsResolver;
using android::base::StringAppendF;
using android::net::DnsQueryEvent;
using android::net::DnsStats;
using android::net::Experiments;
using android::net::PROTO_DOT;
using android::net::PROTO_TCP;
using android::net::PROTO_UDP;
@@ -1682,7 +1684,10 @@ void resolv_populate_res_for_net(ResState* statp) {
    NetConfig* info = find_netconfig_locked(statp->netid);
    if (info == nullptr) return;

    statp->nsaddrs = info->nameserverSockAddrs;
    const bool sortNameservers = Experiments::getInstance()->getFlag("sort_nameservers", 0);
    statp->sort_nameservers = sortNameservers;
    statp->nsaddrs = sortNameservers ? info->dnsStats.getSortedServers(PROTO_UDP)
                                     : info->nameserverSockAddrs;
    statp->search_domains = info->search_domains;
    statp->tc_mode = info->tc_mode;
    statp->enforce_dns_uid = info->enforceDnsUid;
Loading