Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit cd86cf9b authored by Mike Yu's avatar Mike Yu Committed by Automerger Merge Worker
Browse files

Support prioritizing DNS servers am: 61d17267

Original change: https://android-review.googlesource.com/c/platform/packages/modules/DnsResolver/+/1243491

Change-Id: I96359e44b18315f1235260026bb94bea92ef1310
parents 73722617 61d17267
Loading
Loading
Loading
Loading
+77 −5
Original line number Original line Diff line number Diff line
@@ -77,11 +77,14 @@ bool StatsData::operator==(const StatsData& o) const {
           std::tie(o.serverSockAddr, o.total, o.rcodeCounts, o.latencyUs);
           std::tie(o.serverSockAddr, o.total, o.rcodeCounts, o.latencyUs);
}
}


int StatsData::averageLatencyMs() const {
    return (total == 0) ? 0 : duration_cast<milliseconds>(latencyUs).count() / total;
}

std::string StatsData::toString() const {
std::string StatsData::toString() const {
    if (total == 0) return StringPrintf("%s <no data>", serverSockAddr.ip().toString().c_str());
    if (total == 0) return StringPrintf("%s <no data>", serverSockAddr.ip().toString().c_str());


    const auto now = std::chrono::steady_clock::now();
    const auto now = std::chrono::steady_clock::now();
    const int meanLatencyMs = duration_cast<milliseconds>(latencyUs).count() / total;
    const int lastUpdateSec = duration_cast<seconds>(now - lastUpdate).count();
    const int lastUpdateSec = duration_cast<seconds>(now - lastUpdate).count();
    std::string buf;
    std::string buf;
    for (const auto& [rcode, counts] : rcodeCounts) {
    for (const auto& [rcode, counts] : rcodeCounts) {
@@ -90,7 +93,7 @@ std::string StatsData::toString() const {
        }
        }
    }
    }
    return StringPrintf("%s (%d, %dms, [%s], %ds)", serverSockAddr.ip().toString().c_str(), total,
    return StringPrintf("%s (%d, %dms, [%s], %ds)", serverSockAddr.ip().toString().c_str(), total,
                        meanLatencyMs, buf.c_str(), lastUpdateSec);
                        averageLatencyMs(), buf.c_str(), lastUpdateSec);
}
}


StatsRecords::StatsRecords(const IPSockAddr& ipSockAddr, size_t size)
StatsRecords::StatsRecords(const IPSockAddr& ipSockAddr, size_t size)
@@ -104,6 +107,10 @@ void StatsRecords::push(const Record& record) {
        updateStatsData(mRecords.front(), false);
        updateStatsData(mRecords.front(), false);
        mRecords.pop_front();
        mRecords.pop_front();
    }
    }

    // Update the quality factors.
    mSkippedCount = 0;
    updatePenalty(record);
}
}


void StatsRecords::updateStatsData(const Record& record, const bool add) {
void StatsRecords::updateStatsData(const Record& record, const bool add) {
@@ -120,6 +127,41 @@ void StatsRecords::updateStatsData(const Record& record, const bool add) {
    mStatsData.lastUpdate = std::chrono::steady_clock::now();
    mStatsData.lastUpdate = std::chrono::steady_clock::now();
}
}


void StatsRecords::updatePenalty(const Record& record) {
    switch (record.rcode) {
        case NS_R_NO_ERROR:
        case NS_R_NXDOMAIN:
        case NS_R_NOTAUTH:
            mPenalty = 0;
            return;
        default:
            // NS_R_TIMEOUT and NS_R_INTERNAL_ERROR are in this case.
            if (mPenalty == 0) {
                mPenalty = 100;
            } else {
                // The evaluated quality drops more quickly when continuous failures happen.
                mPenalty = std::min(mPenalty * 2, kMaxQuality);
            }
            return;
    }
}

double StatsRecords::score() const {
    const int avgRtt = mStatsData.averageLatencyMs();

    // Set the lower bound to -1 in case of "avgRtt + mPenalty < mSkippedCount"
    //   1) when the server doesn't have any stats yet.
    //   2) when the sorting has been disabled while it was enabled before.
    int quality = std::clamp(avgRtt + mPenalty - mSkippedCount, -1, kMaxQuality);

    // Normalization.
    return static_cast<double>(kMaxQuality - quality) * 100 / kMaxQuality;
}

void StatsRecords::incrementSkippedCount() {
    mSkippedCount = std::min(mSkippedCount + 1, kMaxQuality);
}

bool DnsStats::setServers(const std::vector<netdutils::IPSockAddr>& servers, Protocol protocol) {
bool DnsStats::setServers(const std::vector<netdutils::IPSockAddr>& servers, Protocol protocol) {
    if (!ensureNoInvalidIp(servers)) return false;
    if (!ensureNoInvalidIp(servers)) return false;


@@ -147,6 +189,7 @@ bool DnsStats::setServers(const std::vector<netdutils::IPSockAddr>& servers, Pro
bool DnsStats::addStats(const IPSockAddr& ipSockAddr, const DnsQueryEvent& record) {
bool DnsStats::addStats(const IPSockAddr& ipSockAddr, const DnsQueryEvent& record) {
    if (ipSockAddr.ip() == INVALID_IPADDRESS) return false;
    if (ipSockAddr.ip() == INVALID_IPADDRESS) return false;


    bool added = false;
    for (auto& [serverSockAddr, statsRecords] : mStats[record.protocol()]) {
    for (auto& [serverSockAddr, statsRecords] : mStats[record.protocol()]) {
        if (serverSockAddr == ipSockAddr) {
        if (serverSockAddr == ipSockAddr) {
            const StatsRecords::Record rec = {
            const StatsRecords::Record rec = {
@@ -154,10 +197,36 @@ bool DnsStats::addStats(const IPSockAddr& ipSockAddr, const DnsQueryEvent& recor
                    .latencyUs = microseconds(record.latency_micros()),
                    .latencyUs = microseconds(record.latency_micros()),
            };
            };
            statsRecords.push(rec);
            statsRecords.push(rec);
            return true;
            added = true;
        } else {
            statsRecords.incrementSkippedCount();
        }
        }
    }
    }
    return false;

    return added;
}

std::vector<IPSockAddr> DnsStats::getSortedServers(Protocol protocol) const {
    // DoT unsupported. The handshake overhead is expensive, and the connection will hang for a
    // while. Need to figure out if it is worth doing for DoT servers.
    if (protocol == PROTO_DOT) return {};

    auto it = mStats.find(protocol);
    if (it == mStats.end()) return {};

    // Sorting on insertion in decreasing order.
    std::multimap<double, IPSockAddr, std::greater<double>> sortedData;
    for (const auto& [ip, statsRecords] : it->second) {
        sortedData.insert({statsRecords.score(), ip});
    }

    std::vector<IPSockAddr> ret;
    ret.reserve(sortedData.size());
    for (auto& [_, v] : sortedData) {
        ret.push_back(v);  // IPSockAddr is trivially-copyable.
    }

    return ret;
}
}


std::vector<StatsData> DnsStats::getStats(Protocol protocol) const {
std::vector<StatsData> DnsStats::getStats(Protocol protocol) const {
@@ -179,7 +248,10 @@ void DnsStats::dump(DumpWriter& dw) {
            return;
            return;
        }
        }
        for (const auto& [_, statsRecords] : statsMap) {
        for (const auto& [_, statsRecords] : statsMap) {
            dw.println("%s", statsRecords.getStatsData().toString().c_str());
            const StatsData& data = statsRecords.getStatsData();
            std::string str = data.toString();
            str += StringPrintf(" score{%.1f}", statsRecords.score());
            dw.println("%s", str.c_str());
        }
        }
    };
    };


+22 −1
Original line number Original line Diff line number Diff line
@@ -54,6 +54,7 @@ struct StatsData {
    // The last update timestamp.
    // The last update timestamp.
    std::chrono::time_point<std::chrono::steady_clock> lastUpdate;
    std::chrono::time_point<std::chrono::steady_clock> lastUpdate;


    int averageLatencyMs() const;
    std::string toString() const;
    std::string toString() const;


    // For testing.
    // For testing.
@@ -77,12 +78,31 @@ class StatsRecords {


    const StatsData& getStatsData() const { return mStatsData; }
    const StatsData& getStatsData() const { return mStatsData; }


    // Quantifies the quality based on the current quality factors and the latency, and normalize
    // the value to a score between 0 to 100.
    double score() const;

    void incrementSkippedCount();

  private:
  private:
    void updateStatsData(const Record& record, const bool add);
    void updateStatsData(const Record& record, const bool add);
    void updatePenalty(const Record& record);


    std::deque<Record> mRecords;
    std::deque<Record> mRecords;
    size_t mCapacity;
    size_t mCapacity;
    StatsData mStatsData;
    StatsData mStatsData;

    // A quality factor used to distinguish if the server can't be evaluated by latency alone, such
    // as instant failure on connect.
    int mPenalty = 0;

    // A quality factor used to prevent starvation.
    int mSkippedCount = 0;

    // The maximum of the quantified result. As the sorting is on the basis of server latency, limit
    // the maximal value of the quantity to 10000 in correspondence with the maximal cleartext
    // query timeout 10000 milliseconds. This helps normalize the value of the quality to a score.
    static constexpr int kMaxQuality = 10000;
};
};


// DnsStats class manages the statistics of DNS servers per netId.
// DnsStats class manages the statistics of DNS servers per netId.
@@ -98,13 +118,14 @@ class DnsStats {
    // Return true if |record| is successfully added into |server|'s stats; otherwise, return false.
    // Return true if |record| is successfully added into |server|'s stats; otherwise, return false.
    bool addStats(const netdutils::IPSockAddr& server, const DnsQueryEvent& record);
    bool addStats(const netdutils::IPSockAddr& server, const DnsQueryEvent& record);


    std::vector<netdutils::IPSockAddr> getSortedServers(Protocol protocol) const;

    void dump(netdutils::DumpWriter& dw);
    void dump(netdutils::DumpWriter& dw);


    // For testing.
    // For testing.
    std::vector<StatsData> getStats(Protocol protocol) const;
    std::vector<StatsData> getStats(Protocol protocol) const;


    // TODO: Compatible support for getResolverInfo().
    // TODO: Compatible support for getResolverInfo().
    // TODO: Support getSortedServers().


    static constexpr size_t kLogSize = 128;
    static constexpr size_t kLogSize = 128;


+113 −3
Original line number Original line Diff line number Diff line
@@ -52,6 +52,8 @@ StatsData makeStatsData(const IPSockAddr& server, const int total, const millise


}  // namespace
}  // namespace


// TODO: add StatsDataTest to ensure its methods return correct outputs.

class StatsRecordsTest : public ::testing::Test {};
class StatsRecordsTest : public ::testing::Test {};


TEST_F(StatsRecordsTest, PushRecord) {
TEST_F(StatsRecordsTest, PushRecord) {
@@ -95,9 +97,9 @@ class DnsStatsTest : public ::testing::Test {
    void verifyDumpOutput(const std::vector<StatsData>& tcpData,
    void verifyDumpOutput(const std::vector<StatsData>& tcpData,
                          const std::vector<StatsData>& udpData,
                          const std::vector<StatsData>& udpData,
                          const std::vector<StatsData>& dotData) {
                          const std::vector<StatsData>& dotData) {
        // A simple pattern to capture two matches:
        // A pattern to capture three matches:
        //     server address (empty allowed) and its statistics.
        //     server address (empty allowed), the statistics, and the score.
        const std::regex pattern(R"(\s{4,}([0-9a-fA-F:\.]*) ([<(].*[>)]))");
        const std::regex pattern(R"(\s{4,}([0-9a-fA-F:\.]*)[ ]?([<(].*[>)])[ ]?(\S*))");
        std::string dumpString = captureDumpOutput();
        std::string dumpString = captureDumpOutput();


        const auto check = [&](const std::vector<StatsData>& statsData, const std::string& protocol,
        const auto check = [&](const std::vector<StatsData>& statsData, const std::string& protocol,
@@ -111,6 +113,7 @@ class DnsStatsTest : public ::testing::Test {
                ASSERT_TRUE(std::regex_search(*dumpString, sm, pattern));
                ASSERT_TRUE(std::regex_search(*dumpString, sm, pattern));
                EXPECT_TRUE(sm[1].str().empty());
                EXPECT_TRUE(sm[1].str().empty());
                EXPECT_EQ(sm[2], "<no server>");
                EXPECT_EQ(sm[2], "<no server>");
                EXPECT_TRUE(sm[3].str().empty());
                *dumpString = sm.suffix();
                *dumpString = sm.suffix();
                return;
                return;
            }
            }
@@ -119,6 +122,7 @@ class DnsStatsTest : public ::testing::Test {
                ASSERT_TRUE(std::regex_search(*dumpString, sm, pattern));
                ASSERT_TRUE(std::regex_search(*dumpString, sm, pattern));
                EXPECT_EQ(sm[1], stats.serverSockAddr.ip().toString());
                EXPECT_EQ(sm[1], stats.serverSockAddr.ip().toString());
                EXPECT_FALSE(sm[2].str().empty());
                EXPECT_FALSE(sm[2].str().empty());
                EXPECT_FALSE(sm[3].str().empty());
                *dumpString = sm.suffix();
                *dumpString = sm.suffix();
            }
            }
        };
        };
@@ -379,4 +383,110 @@ TEST_F(DnsStatsTest, AddStatsRecords_100000) {
    verifyDumpOutput(expectedStats, expectedStats, expectedStats);
    verifyDumpOutput(expectedStats, expectedStats, expectedStats);
}
}


TEST_F(DnsStatsTest, GetServers_SortingByLatency) {
    const IPSockAddr server1 = IPSockAddr::toIPSockAddr("127.0.0.1", 53);
    const IPSockAddr server2 = IPSockAddr::toIPSockAddr("127.0.0.2", 53);
    const IPSockAddr server3 = IPSockAddr::toIPSockAddr("2001:db8:cafe:d00d::1", 53);
    const IPSockAddr server4 = IPSockAddr::toIPSockAddr("2001:db8:cafe:d00d::2", 53);

    // Return empty list before setup.
    EXPECT_THAT(mDnsStats.getSortedServers(PROTO_UDP), IsEmpty());

    // Before there's any stats, the list of the sorted servers is the same as the setup's one.
    EXPECT_TRUE(mDnsStats.setServers({server1, server2, server3, server4}, PROTO_UDP));
    EXPECT_TRUE(mDnsStats.setServers({server1, server2, server3, server4}, PROTO_DOT));
    EXPECT_THAT(mDnsStats.getSortedServers(PROTO_UDP),
                testing::ElementsAreArray({server1, server2, server3, server4}));

    // Add a record to server1. The qualities of the other servers increase.
    EXPECT_TRUE(mDnsStats.addStats(server1, makeDnsQueryEvent(PROTO_UDP, NS_R_NO_ERROR, 10ms)));
    EXPECT_THAT(mDnsStats.getSortedServers(PROTO_UDP),
                testing::ElementsAreArray({server2, server3, server4, server1}));

    // Add a record, with less repose time than server1, to server3.
    EXPECT_TRUE(mDnsStats.addStats(server3, makeDnsQueryEvent(PROTO_UDP, NS_R_NO_ERROR, 5ms)));
    EXPECT_THAT(mDnsStats.getSortedServers(PROTO_UDP),
                testing::ElementsAreArray({server2, server4, server3, server1}));

    // Even though server2 has zero response time, select server4 as the first server because it
    // doesn't have stats yet.
    EXPECT_TRUE(mDnsStats.addStats(server2, makeDnsQueryEvent(PROTO_UDP, NS_R_NO_ERROR, 0ms)));
    EXPECT_THAT(mDnsStats.getSortedServers(PROTO_UDP),
                testing::ElementsAreArray({server4, server2, server3, server1}));

    // Updating DoT record to server4 changes nothing.
    EXPECT_TRUE(mDnsStats.addStats(server4, makeDnsQueryEvent(PROTO_DOT, NS_R_NO_ERROR, 10ms)));
    EXPECT_THAT(mDnsStats.getSortedServers(PROTO_UDP),
                testing::ElementsAreArray({server4, server2, server3, server1}));

    // Add a record, with a very large value of respose time, to server4.
    EXPECT_TRUE(mDnsStats.addStats(server4, makeDnsQueryEvent(PROTO_UDP, NS_R_NO_ERROR, 500000ms)));
    EXPECT_THAT(mDnsStats.getSortedServers(PROTO_UDP),
                testing::ElementsAreArray({server2, server3, server1, server4}));

    // The list of the DNS servers changed.
    EXPECT_TRUE(mDnsStats.setServers({server2, server4}, PROTO_UDP));
    EXPECT_THAT(mDnsStats.getSortedServers(PROTO_UDP),
                testing::ElementsAreArray({server2, server4}));

    // It fails to add records to an non-existing server, and nothing is changed in getting
    // the sorted servers.
    EXPECT_FALSE(mDnsStats.addStats(server1, makeDnsQueryEvent(PROTO_UDP, NS_R_NO_ERROR, 10ms)));
    EXPECT_THAT(mDnsStats.getSortedServers(PROTO_UDP),
                testing::ElementsAreArray({server2, server4}));
}

TEST_F(DnsStatsTest, GetServers_DeprioritizingBadServers) {
    const IPSockAddr server1 = IPSockAddr::toIPSockAddr("127.0.0.1", 53);
    const IPSockAddr server2 = IPSockAddr::toIPSockAddr("127.0.0.2", 53);
    const IPSockAddr server3 = IPSockAddr::toIPSockAddr("127.0.0.3", 53);
    const IPSockAddr server4 = IPSockAddr::toIPSockAddr("127.0.0.4", 53);

    EXPECT_TRUE(mDnsStats.setServers({server1, server2, server3, server4}, PROTO_UDP));

    int server1Counts = 0;
    int server2Counts = 0;
    for (int i = 0; i < 5000; i++) {
        const auto servers = mDnsStats.getSortedServers(PROTO_UDP);
        EXPECT_EQ(servers.size(), 4U);
        if (servers[0] == server1) {
            // server1 is relatively slowly responsive.
            EXPECT_TRUE(mDnsStats.addStats(servers[0],
                                           makeDnsQueryEvent(PROTO_UDP, NS_R_NO_ERROR, 200ms)));
            server1Counts++;
        } else if (servers[0] == server2) {
            // server2 is relatively quickly responsive.
            EXPECT_TRUE(mDnsStats.addStats(servers[0],
                                           makeDnsQueryEvent(PROTO_UDP, NS_R_NO_ERROR, 100ms)));
            server2Counts++;
        } else if (servers[0] == server3) {
            // server3 always times out.
            EXPECT_TRUE(mDnsStats.addStats(servers[0],
                                           makeDnsQueryEvent(PROTO_UDP, NS_R_TIMEOUT, 1000ms)));
        } else if (servers[0] == server4) {
            // server4 is unusable.
            EXPECT_TRUE(mDnsStats.addStats(servers[0],
                                           makeDnsQueryEvent(PROTO_UDP, NS_R_INTERNAL_ERROR, 1ms)));
        }
    }

    const std::vector<StatsData> allStatsData = mDnsStats.getStats(PROTO_UDP);
    for (const auto& data : allStatsData) {
        EXPECT_EQ(data.rcodeCounts.size(), 1U);
        if (data.serverSockAddr == server1 || data.serverSockAddr == server2) {
            const auto it = data.rcodeCounts.find(NS_R_NO_ERROR);
            ASSERT_NE(it, data.rcodeCounts.end());
            EXPECT_GT(server2Counts, 2 * server1Counts);  // At least twice larger.
        } else if (data.serverSockAddr == server3) {
            const auto it = data.rcodeCounts.find(NS_R_TIMEOUT);
            ASSERT_NE(it, data.rcodeCounts.end());
            EXPECT_LT(it->second, 10);
        } else if (data.serverSockAddr == server4) {
            const auto it = data.rcodeCounts.find(NS_R_INTERNAL_ERROR);
            ASSERT_NE(it, data.rcodeCounts.end());
            EXPECT_LT(it->second, 10);
        }
    }
}

}  // namespace android::net
}  // namespace android::net
+2 −1
Original line number Original line Diff line number Diff line
@@ -49,7 +49,8 @@ class Experiments {
    // TODO: Migrate other experiment flags to here.
    // TODO: Migrate other experiment flags to here.
    // (retry_count, retransmission_time_interval, dot_connect_timeout_ms)
    // (retry_count, retransmission_time_interval, dot_connect_timeout_ms)
    static constexpr const char* const kExperimentFlagKeyList[] = {
    static constexpr const char* const kExperimentFlagKeyList[] = {
            "keep_listening_udp", "parallel_lookup", "parallel_lookup_sleep_time"};
            "keep_listening_udp", "parallel_lookup", "parallel_lookup_sleep_time",
            "sort_nameservers"};
    // This value is used in updateInternal as the default value if any flags can't be found.
    // This value is used in updateInternal as the default value if any flags can't be found.
    static constexpr int kFlagIntDefault = INT_MIN;
    static constexpr int kFlagIntDefault = INT_MIN;
    // For testing.
    // For testing.
+6 −1
Original line number Original line Diff line number Diff line
@@ -60,6 +60,7 @@
#include <server_configurable_flags/get_flags.h>
#include <server_configurable_flags/get_flags.h>


#include "DnsStats.h"
#include "DnsStats.h"
#include "Experiments.h"
#include "res_comp.h"
#include "res_comp.h"
#include "res_debug.h"
#include "res_debug.h"
#include "resolv_private.h"
#include "resolv_private.h"
@@ -69,6 +70,7 @@ using aidl::android::net::IDnsResolver;
using android::base::StringAppendF;
using android::base::StringAppendF;
using android::net::DnsQueryEvent;
using android::net::DnsQueryEvent;
using android::net::DnsStats;
using android::net::DnsStats;
using android::net::Experiments;
using android::net::PROTO_DOT;
using android::net::PROTO_DOT;
using android::net::PROTO_TCP;
using android::net::PROTO_TCP;
using android::net::PROTO_UDP;
using android::net::PROTO_UDP;
@@ -1682,7 +1684,10 @@ void resolv_populate_res_for_net(ResState* statp) {
    NetConfig* info = find_netconfig_locked(statp->netid);
    NetConfig* info = find_netconfig_locked(statp->netid);
    if (info == nullptr) return;
    if (info == nullptr) return;


    statp->nsaddrs = info->nameserverSockAddrs;
    const bool sortNameservers = Experiments::getInstance()->getFlag("sort_nameservers", 0);
    statp->sort_nameservers = sortNameservers;
    statp->nsaddrs = sortNameservers ? info->dnsStats.getSortedServers(PROTO_UDP)
                                     : info->nameserverSockAddrs;
    statp->search_domains = info->search_domains;
    statp->search_domains = info->search_domains;
    statp->tc_mode = info->tc_mode;
    statp->tc_mode = info->tc_mode;
    statp->enforce_dns_uid = info->enforceDnsUid;
    statp->enforce_dns_uid = info->enforceDnsUid;
Loading