Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 078a976a authored by Treehugger Robot's avatar Treehugger Robot Committed by Gerrit Code Review
Browse files

Merge "Refactor DnsResponderClient::GetResolverInfo"

parents e5d3eb9e 21975f31
Loading
Loading
Loading
Loading
+2 −1
Original line number Diff line number Diff line
@@ -100,6 +100,7 @@ int getDnsInfo(unsigned netId, std::vector<std::string>* servers, std::vector<st
    domains->clear();
    *params = res_params{};
    stats->clear();
    wait_for_pending_req_timeout_count->clear();
    int res_wait_for_pending_req_timeout_count;
    int revision_id = android_net_res_stats_get_info_for_net(
            netId, &nscount, res_servers, &dcount, res_domains, params, res_stats,
@@ -149,7 +150,7 @@ int getDnsInfo(unsigned netId, std::vector<std::string>* servers, std::vector<st
        domains->push_back(res_domains[i]);
    }

    (*wait_for_pending_req_timeout_count)[0] = res_wait_for_pending_req_timeout_count;
    wait_for_pending_req_timeout_count->push_back(res_wait_for_pending_req_timeout_count);
    return 0;
}

+45 −28
Original line number Diff line number Diff line
@@ -30,6 +30,8 @@ using aidl::android::net::IDnsResolver;
using aidl::android::net::INetd;
using aidl::android::net::ResolverOptionsParcel;
using aidl::android::net::ResolverParamsParcel;
using android::base::Error;
using android::base::Result;
using android::net::ResolverStats;

ResolverParams::Builder::Builder() {
@@ -66,36 +68,51 @@ void DnsResponderClient::SetupMappings(unsigned numHosts, const std::vector<std:
    }
}

bool DnsResponderClient::GetResolverInfo(aidl::android::net::IDnsResolver* dnsResolverService,
                                         unsigned netId, std::vector<std::string>* servers,
                                         std::vector<std::string>* domains,
                                         std::vector<std::string>* tlsServers, res_params* params,
                                         std::vector<ResolverStats>* stats,
                                         int* waitForPendingReqTimeoutCount) {
    using aidl::android::net::IDnsResolver;
    std::vector<int32_t> params32;
    std::vector<int32_t> stats32;
    std::vector<int32_t> waitForPendingReqTimeoutCount32{0};
    auto rv = dnsResolverService->getResolverInfo(netId, servers, domains, tlsServers, &params32,
                                                  &stats32, &waitForPendingReqTimeoutCount32);

    if (!rv.isOk() || params32.size() != static_cast<size_t>(IDnsResolver::RESOLVER_PARAMS_COUNT)) {
        return false;
    }
    *params = res_params{
            .sample_validity =
                    static_cast<uint16_t>(params32[IDnsResolver::RESOLVER_PARAMS_SAMPLE_VALIDITY]),
            .success_threshold =
                    static_cast<uint8_t>(params32[IDnsResolver::RESOLVER_PARAMS_SUCCESS_THRESHOLD]),
Result<ResolverInfo> DnsResponderClient::getResolverInfo() {
    std::vector<std::string> dnsServers;
    std::vector<std::string> domains;
    std::vector<std::string> dotServers;
    std::vector<int32_t> params;
    std::vector<int32_t> stats;
    std::vector<int32_t> waitForPendingReqTimeoutCount;
    auto rv = mDnsResolvSrv->getResolverInfo(TEST_NETID, &dnsServers, &domains, &dotServers,
                                             &params, &stats, &waitForPendingReqTimeoutCount);
    if (!rv.isOk()) {
        return Error() << "getResolverInfo failed: " << rv.getMessage();
    }
    if (stats.size() % IDnsResolver::RESOLVER_STATS_COUNT != 0) {
        return Error() << "Unexpected stats size: " << stats.size();
    }
    if (params.size() != IDnsResolver::RESOLVER_PARAMS_COUNT) {
        return Error() << "Unexpected params size: " << params.size();
    }
    if (waitForPendingReqTimeoutCount.size() != 1) {
        return Error() << "Unexpected waitForPendingReqTimeoutCount size: "
                       << waitForPendingReqTimeoutCount.size();
    }

    ResolverInfo out = {
            .dnsServers = std::move(dnsServers),
            .domains = std::move(domains),
            .dotServers = std::move(dotServers),
            .params{
                    .sample_validity = static_cast<uint16_t>(
                            params[IDnsResolver::RESOLVER_PARAMS_SAMPLE_VALIDITY]),
                    .success_threshold = static_cast<uint8_t>(
                            params[IDnsResolver::RESOLVER_PARAMS_SUCCESS_THRESHOLD]),
                    .min_samples =
                    static_cast<uint8_t>(params32[IDnsResolver::RESOLVER_PARAMS_MIN_SAMPLES]),
                            static_cast<uint8_t>(params[IDnsResolver::RESOLVER_PARAMS_MIN_SAMPLES]),
                    .max_samples =
                    static_cast<uint8_t>(params32[IDnsResolver::RESOLVER_PARAMS_MAX_SAMPLES]),
            .base_timeout_msec = params32[IDnsResolver::RESOLVER_PARAMS_BASE_TIMEOUT_MSEC],
            .retry_count = params32[IDnsResolver::RESOLVER_PARAMS_RETRY_COUNT],
                            static_cast<uint8_t>(params[IDnsResolver::RESOLVER_PARAMS_MAX_SAMPLES]),
                    .base_timeout_msec = params[IDnsResolver::RESOLVER_PARAMS_BASE_TIMEOUT_MSEC],
                    .retry_count = params[IDnsResolver::RESOLVER_PARAMS_RETRY_COUNT],
            },
            .stats = {},
            .waitForPendingReqTimeoutCount = waitForPendingReqTimeoutCount[0],
    };
    *waitForPendingReqTimeoutCount = waitForPendingReqTimeoutCount32[0];
    return ResolverStats::decodeAll(stats32, stats);
    ResolverStats::decodeAll(stats, &out.stats);

    return std::move(out);
}

bool DnsResponderClient::SetResolversForNetwork(const std::vector<std::string>& servers,
+12 −6
Original line number Diff line number Diff line
@@ -23,6 +23,7 @@

#include <android-base/format.h>
#include <android-base/logging.h>
#include <android-base/result.h>

#include <aidl/android/net/IDnsResolver.h>
#include <aidl/android/net/INetd.h>
@@ -47,6 +48,16 @@ inline constexpr char kDefaultSearchDomain[] = "example.com";
        }                                                                                          \
    } while (0)

// A thin wrapper to store the outputs of DnsResolver::getResolverInfo().
struct ResolverInfo {
    std::vector<std::string> dnsServers;
    std::vector<std::string> domains;
    std::vector<std::string> dotServers;
    res_params params;
    std::vector<android::net::ResolverStats> stats;
    int waitForPendingReqTimeoutCount;
};

class ResolverParams {
  public:
    class Builder {
@@ -132,12 +143,7 @@ class DnsResponderClient {
    static NativeNetworkConfig makeNativeNetworkConfig(int netId, NativeNetworkType networkType,
                                                       int permission, bool secure);

    static bool GetResolverInfo(aidl::android::net::IDnsResolver* dnsResolverService,
                                unsigned netId, std::vector<std::string>* servers,
                                std::vector<std::string>* domains,
                                std::vector<std::string>* tlsServers, res_params* params,
                                std::vector<android::net::ResolverStats>* stats,
                                int* waitForPendingReqTimeoutCount);
    android::base::Result<ResolverInfo> getResolverInfo();

    // Return a default resolver configuration for opportunistic mode.
    static aidl::android::net::ResolverParamsParcel GetDefaultResolverParamsParcel();
+56 −120
Original line number Diff line number Diff line
@@ -322,20 +322,13 @@ class ResolverTest : public NetNativeTestBase {

    bool expectStatsFromGetResolverInfo(const std::vector<NameserverStats>& nameserversStats,
                                        const StatsCmp cmp) {
        std::vector<std::string> res_servers;
        std::vector<std::string> res_domains;
        std::vector<std::string> res_tls_servers;
        res_params res_params;
        std::vector<ResolverStats> res_stats;
        int wait_for_pending_req_timeout_count;

        if (!DnsResponderClient::GetResolverInfo(mDnsClient.resolvService(), TEST_NETID,
                                                 &res_servers, &res_domains, &res_tls_servers,
                                                 &res_params, &res_stats,
                                                 &wait_for_pending_req_timeout_count)) {
            ADD_FAILURE() << "GetResolverInfo failed";
        const auto resolvInfo = mDnsClient.getResolverInfo();
        if (!resolvInfo.ok()) {
            ADD_FAILURE() << resolvInfo.error().message();
            return false;
        }
        const std::vector<std::string>& res_servers = resolvInfo.value().dnsServers;
        const std::vector<ResolverStats>& res_stats = resolvInfo.value().stats;

        if (res_servers.size() != res_stats.size()) {
            ADD_FAILURE() << fmt::format("res_servers.size() != res_stats.size(): {} != {}",
@@ -679,28 +672,21 @@ TEST_F(ResolverTest, GetHostByName_Binder) {
    EXPECT_EQ(mapping.ip4, ToString(result));
    EXPECT_TRUE(result->h_addr_list[1] == nullptr);

    std::vector<std::string> res_servers;
    std::vector<std::string> res_domains;
    std::vector<std::string> res_tls_servers;
    res_params res_params;
    std::vector<ResolverStats> res_stats;
    int wait_for_pending_req_timeout_count;
    ASSERT_TRUE(DnsResponderClient::GetResolverInfo(
            mDnsClient.resolvService(), TEST_NETID, &res_servers, &res_domains, &res_tls_servers,
            &res_params, &res_stats, &wait_for_pending_req_timeout_count));
    EXPECT_EQ(servers.size(), res_servers.size());
    EXPECT_EQ(domains.size(), res_domains.size());
    EXPECT_EQ(0U, res_tls_servers.size());
    EXPECT_EQ(resolverParams.sampleValiditySeconds, res_params.sample_validity);
    EXPECT_EQ(resolverParams.successThreshold, res_params.success_threshold);
    EXPECT_EQ(resolverParams.minSamples, res_params.min_samples);
    EXPECT_EQ(resolverParams.maxSamples, res_params.max_samples);
    EXPECT_EQ(resolverParams.baseTimeoutMsec, res_params.base_timeout_msec);
    EXPECT_EQ(resolverParams.retryCount, res_params.retry_count);
    EXPECT_EQ(servers.size(), res_stats.size());

    EXPECT_THAT(res_servers, testing::UnorderedElementsAreArray(servers));
    EXPECT_THAT(res_domains, testing::UnorderedElementsAreArray(domains));
    const auto resolvInfo = mDnsClient.getResolverInfo();
    ASSERT_RESULT_OK(resolvInfo);
    EXPECT_EQ(servers.size(), resolvInfo.value().dnsServers.size());
    EXPECT_EQ(domains.size(), resolvInfo.value().domains.size());
    EXPECT_TRUE(resolvInfo.value().dotServers.empty());
    EXPECT_EQ(resolverParams.sampleValiditySeconds, resolvInfo.value().params.sample_validity);
    EXPECT_EQ(resolverParams.successThreshold, resolvInfo.value().params.success_threshold);
    EXPECT_EQ(resolverParams.minSamples, resolvInfo.value().params.min_samples);
    EXPECT_EQ(resolverParams.maxSamples, resolvInfo.value().params.max_samples);
    EXPECT_EQ(resolverParams.baseTimeoutMsec, resolvInfo.value().params.base_timeout_msec);
    EXPECT_EQ(resolverParams.retryCount, resolvInfo.value().params.retry_count);
    EXPECT_EQ(servers.size(), resolvInfo.value().stats.size());

    EXPECT_THAT(resolvInfo.value().dnsServers, testing::UnorderedElementsAreArray(servers));
    EXPECT_THAT(resolvInfo.value().domains, testing::UnorderedElementsAreArray(domains));
}

TEST_F(ResolverTest, GetAddrInfo) {
@@ -889,17 +875,9 @@ TEST_F(ResolverTest, GetAddrInfoV4_deferred_resp) {
        EXPECT_EQ(0U, GetNumQueries(dns2, host_name_deferred));
        EXPECT_TRUE(result != nullptr);
        EXPECT_EQ("1.2.3.4", ToString(result));

        std::vector<std::string> res_servers;
        std::vector<std::string> res_domains;
        std::vector<std::string> res_tls_servers;
        res_params res_params;
        std::vector<ResolverStats> res_stats;
        int wait_for_pending_req_timeout_count;
        ASSERT_TRUE(DnsResponderClient::GetResolverInfo(
                mDnsClient.resolvService(), TEST_NETID, &res_servers, &res_domains,
                &res_tls_servers, &res_params, &res_stats, &wait_for_pending_req_timeout_count));
        EXPECT_EQ(0, wait_for_pending_req_timeout_count);
        const auto resolvInfo = mDnsClient.getResolverInfo();
        ASSERT_RESULT_OK(resolvInfo);
        EXPECT_EQ(0, resolvInfo.value().waitForPendingReqTimeoutCount);
    });

    // ensuring t2 and t3 handler functions are processed in order
@@ -1218,16 +1196,9 @@ TEST_F(ResolverTest, GetAddrInfoV6_concurrent) {
        thread.join();
    }

    std::vector<std::string> res_servers;
    std::vector<std::string> res_domains;
    std::vector<std::string> res_tls_servers;
    res_params res_params;
    std::vector<ResolverStats> res_stats;
    int wait_for_pending_req_timeout_count;
    ASSERT_TRUE(DnsResponderClient::GetResolverInfo(
            mDnsClient.resolvService(), TEST_NETID, &res_servers, &res_domains, &res_tls_servers,
            &res_params, &res_stats, &wait_for_pending_req_timeout_count));
    EXPECT_EQ(0, wait_for_pending_req_timeout_count);
    const auto resolvInfo = mDnsClient.getResolverInfo();
    ASSERT_RESULT_OK(resolvInfo);
    EXPECT_EQ(0, resolvInfo.value().waitForPendingReqTimeoutCount);
}

TEST_F(ResolverTest, SkipBadServersDueToInternalError) {
@@ -1514,22 +1485,15 @@ TEST_F(ResolverTest, GetAddrInfoFromCustTable_Modify) {

TEST_F(ResolverTest, EmptySetup) {
    ASSERT_TRUE(mDnsClient.SetResolversFromParcel(ResolverParamsParcel{.netId = TEST_NETID}));
    std::vector<std::string> res_servers;
    std::vector<std::string> res_domains;
    std::vector<std::string> res_tls_servers;
    res_params res_params;
    std::vector<ResolverStats> res_stats;
    int wait_for_pending_req_timeout_count;
    ASSERT_TRUE(DnsResponderClient::GetResolverInfo(
            mDnsClient.resolvService(), TEST_NETID, &res_servers, &res_domains, &res_tls_servers,
            &res_params, &res_stats, &wait_for_pending_req_timeout_count));
    EXPECT_EQ(0U, res_servers.size());
    EXPECT_EQ(0U, res_domains.size());
    EXPECT_EQ(0U, res_tls_servers.size());
    EXPECT_EQ(0U, res_params.sample_validity);
    EXPECT_EQ(0U, res_params.success_threshold);
    EXPECT_EQ(0U, res_params.min_samples);
    EXPECT_EQ(0U, res_params.max_samples);
    const auto resolvInfo = mDnsClient.getResolverInfo();
    ASSERT_RESULT_OK(resolvInfo);
    EXPECT_TRUE(resolvInfo.value().dnsServers.empty());
    EXPECT_TRUE(resolvInfo.value().domains.empty());
    EXPECT_TRUE(resolvInfo.value().dotServers.empty());
    EXPECT_EQ(0U, resolvInfo.value().params.sample_validity);
    EXPECT_EQ(0U, resolvInfo.value().params.success_threshold);
    EXPECT_EQ(0U, resolvInfo.value().params.min_samples);
    EXPECT_EQ(0U, resolvInfo.value().params.max_samples);
    // We don't check baseTimeoutMsec and retryCount because their value differ depending on
    // the experiment flags.
}
@@ -1568,24 +1532,6 @@ TEST_F(ResolverTest, SearchPathChange) {
    EXPECT_EQ("2001:db8::1:13", ToString(result));
}

namespace {

std::vector<std::string> getResolverDomains(aidl::android::net::IDnsResolver* dnsResolverService,
                                            unsigned netId) {
    std::vector<std::string> res_servers;
    std::vector<std::string> res_domains;
    std::vector<std::string> res_tls_servers;
    res_params res_params;
    std::vector<ResolverStats> res_stats;
    int wait_for_pending_req_timeout_count;
    DnsResponderClient::GetResolverInfo(dnsResolverService, netId, &res_servers, &res_domains,
                                        &res_tls_servers, &res_params, &res_stats,
                                        &wait_for_pending_req_timeout_count);
    return res_domains;
}

}  // namespace

TEST_F(ResolverTest, SearchPathPrune) {
    constexpr size_t DUPLICATED_DOMAIN_NUM = 3;
    constexpr char listen_addr[] = "127.0.0.13";
@@ -1630,7 +1576,9 @@ TEST_F(ResolverTest, SearchPathPrune) {
    EXPECT_EQ(1U, GetNumQueries(dns, host_name1));
    EXPECT_EQ("2001:db8::13", ToString(result));

    const auto& res_domains1 = getResolverDomains(mDnsClient.resolvService(), TEST_NETID);
    auto resolvInfo = mDnsClient.getResolverInfo();
    ASSERT_RESULT_OK(resolvInfo);
    const auto& res_domains1 = resolvInfo.value().domains;
    // Expect 1 valid domain, invalid domains are removed.
    ASSERT_EQ(1U, res_domains1.size());
    EXPECT_EQ(domian_name1, res_domains1[0]);
@@ -1647,7 +1595,9 @@ TEST_F(ResolverTest, SearchPathPrune) {
    EXPECT_EQ(1U, GetNumQueries(dns, host_name2));
    EXPECT_EQ("2001:db8::1:13", ToString(result));

    const auto& res_domains2 = getResolverDomains(mDnsClient.resolvService(), TEST_NETID);
    resolvInfo = mDnsClient.getResolverInfo();
    ASSERT_RESULT_OK(resolvInfo);
    const auto& res_domains2 = resolvInfo.value().domains;
    // Expect 4 valid domain, duplicate domains are removed.
    EXPECT_EQ(DUPLICATED_DOMAIN_NUM + 1U, res_domains2.size());
    EXPECT_THAT(
@@ -1702,23 +1652,17 @@ TEST_F(ResolverTest, MaxServerPrune_Binder) {
        LOG(INFO) << "private DNS validation on " << tls[i]->listen_address() << " done.";
    }

    std::vector<std::string> res_servers;
    std::vector<std::string> res_domains;
    std::vector<std::string> res_tls_servers;
    res_params res_params;
    std::vector<ResolverStats> res_stats;
    int wait_for_pending_req_timeout_count;
    ASSERT_TRUE(DnsResponderClient::GetResolverInfo(
            mDnsClient.resolvService(), TEST_NETID, &res_servers, &res_domains, &res_tls_servers,
            &res_params, &res_stats, &wait_for_pending_req_timeout_count));

    // Check the size of the stats and its contents.
    EXPECT_EQ(static_cast<size_t>(MAXNS), res_servers.size());
    EXPECT_EQ(static_cast<size_t>(MAXNS), res_tls_servers.size());
    EXPECT_EQ(static_cast<size_t>(MAXDNSRCH), res_domains.size());
    EXPECT_TRUE(std::equal(servers.begin(), servers.begin() + MAXNS, res_servers.begin()));
    EXPECT_TRUE(std::equal(servers.begin(), servers.begin() + MAXNS, res_tls_servers.begin()));
    EXPECT_TRUE(std::equal(domains.begin(), domains.begin() + MAXDNSRCH, res_domains.begin()));
    const auto resolvInfo = mDnsClient.getResolverInfo();
    ASSERT_RESULT_OK(resolvInfo);
    EXPECT_EQ(static_cast<size_t>(MAXNS), resolvInfo.value().dnsServers.size());
    EXPECT_EQ(static_cast<size_t>(MAXNS), resolvInfo.value().dotServers.size());
    EXPECT_EQ(static_cast<size_t>(MAXDNSRCH), resolvInfo.value().domains.size());
    EXPECT_TRUE(std::equal(servers.begin(), servers.begin() + MAXNS,
                           resolvInfo.value().dnsServers.begin()));
    EXPECT_TRUE(std::equal(servers.begin(), servers.begin() + MAXNS,
                           resolvInfo.value().dotServers.begin()));
    EXPECT_TRUE(std::equal(domains.begin(), domains.begin() + MAXDNSRCH,
                           resolvInfo.value().domains.begin()));
}

TEST_F(ResolverTest, ResolverStats) {
@@ -7651,17 +7595,9 @@ TEST_F(ResolverTest, NegativeValueInExperimentFlag) {
        setupParams.baseTimeoutMsec = config.baseTimeoutMsec;
        ASSERT_TRUE(mDnsClient.SetResolversFromParcel(setupParams));

        std::vector<std::string> res_servers;
        std::vector<std::string> res_domains;
        std::vector<std::string> res_tls_servers;
        res_params res_params;
        std::vector<ResolverStats> res_stats;
        int wait_for_pending_req_timeout_count;
        ASSERT_TRUE(DnsResponderClient::GetResolverInfo(
                mDnsClient.resolvService(), TEST_NETID, &res_servers, &res_domains,
                &res_tls_servers, &res_params, &res_stats, &wait_for_pending_req_timeout_count));

        EXPECT_EQ(config.expectedRetryCount, res_params.retry_count);
        EXPECT_EQ(config.expectedBaseTimeoutMsec, res_params.base_timeout_msec);
        const auto resolvInfo = mDnsClient.getResolverInfo();
        ASSERT_RESULT_OK(resolvInfo);
        EXPECT_EQ(config.expectedRetryCount, resolvInfo.value().params.retry_count);
        EXPECT_EQ(config.expectedBaseTimeoutMsec, resolvInfo.value().params.base_timeout_msec);
    }
}
+3 −14
Original line number Diff line number Diff line
@@ -25,13 +25,9 @@
#include <gtest/gtest.h>
#include <netdutils/NetNativeTestBase.h>

#include "ResolverStats.h"
#include "dns_responder/dns_responder_client_ndk.h"
#include "params.h"  // MAX_NS
#include "resolv_test_utils.h"

using android::net::ResolverStats;

class ResolverStressTest : public NetNativeTestBase {
  public:
    ResolverStressTest() { mDnsClient.SetUp(); }
@@ -79,16 +75,9 @@ class ResolverStressTest : public NetNativeTestBase {
        LOG(INFO) << fmt::format("{} hosts, {} threads, {} queries, {:E}s", num_hosts, num_threads,
                                 num_queries, std::chrono::duration<double>(t1 - t0).count());

        std::vector<std::string> res_servers;
        std::vector<std::string> res_domains;
        std::vector<std::string> res_tls_servers;
        res_params res_params;
        std::vector<ResolverStats> res_stats;
        int wait_for_pending_req_timeout_count;
        ASSERT_TRUE(DnsResponderClient::GetResolverInfo(
                mDnsClient.resolvService(), TEST_NETID, &res_servers, &res_domains,
                &res_tls_servers, &res_params, &res_stats, &wait_for_pending_req_timeout_count));
        EXPECT_EQ(0, wait_for_pending_req_timeout_count);
        const auto resolvInfo = mDnsClient.getResolverInfo();
        ASSERT_RESULT_OK(resolvInfo);
        EXPECT_EQ(0, resolvInfo.value().waitForPendingReqTimeoutCount);
    }

    DnsResponderClient mDnsClient;