Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit ccfd499e authored by Ken Chen's avatar Ken Chen
Browse files

Refactor the setupDns to include addIpv{4|6}Dns

Move DNS server configurations and network forwardings to setupDns
together. Moving those implementation details to the same function not
only reduces code duplication, but also help us to focus on test logic.

Bug: 176507580
Test: atest
Change-Id: Ib9a9ffd303fb5a531586178d0f3de8709889d5dd
parent 34ed9eeb
Loading
Loading
Loading
Loading
+46 −55
Original line number Diff line number Diff line
@@ -6576,8 +6576,10 @@ class ResolverMultinetworkTest : public ResolverTest {
        if (mNextNetId == TEST_NETID_BASE + 256) mNextNetId = TEST_NETID_BASE;
        return mNextNetId++;
    }
    void setupDns(std::shared_ptr<test::DNSResponder> dnsServer, const char* host_name,
                  const char* ipv4_addr, const char* ipv6_addr, ScopedNetwork* nw);
    Result<std::shared_ptr<test::DNSResponder>> setupDns(ConnectivityType type, ScopedNetwork* nw,
                                                         const char* host_name,
                                                         const char* ipv4_addr,
                                                         const char* ipv6_addr);

  private:
    // Use a different netId because this class inherits from the class ResolverTest which
@@ -6643,13 +6645,19 @@ void ResolverMultinetworkTest::StartDns(test::DNSResponder& dns,
    }
}

void ResolverMultinetworkTest::setupDns(std::shared_ptr<test::DNSResponder> dnsServer,
                                        const char* host_name, const char* ipv4_addr,
                                        const char* ipv6_addr, ScopedNetwork* nw) {
    StartDns(*dnsServer,
Result<std::shared_ptr<test::DNSResponder>> ResolverMultinetworkTest::setupDns(
        ConnectivityType type, ScopedNetwork* nw, const char* host_name, const char* ipv4_addr,
        const char* ipv6_addr) {
    // Add a testing DNS server to networks.
    const Result<DnsServerPair> dnsSvPair =
            (type == ConnectivityType::V4) ? nw->addIpv4Dns() : nw->addIpv6Dns();
    if (!dnsSvPair.ok()) return Error() << dnsSvPair.error();

    StartDns(*dnsSvPair->dnsServer,
             {{host_name, ns_type::ns_t_a, ipv4_addr}, {host_name, ns_type::ns_t_aaaa, ipv6_addr}});
    ASSERT_TRUE(nw->setDnsConfiguration());
    ASSERT_TRUE(nw->startTunForwarder());
    if (!nw->setDnsConfiguration()) return Error() << "setDnsConfiguration() failed";
    if (!nw->startTunForwarder()) return Error() << "startTunForwarder() failed";
    return dnsSvPair->dnsServer;
}

Result<ResolverMultinetworkTest::DnsServerPair> ResolverMultinetworkTest::ScopedNetwork::addDns(
@@ -6884,24 +6892,15 @@ TEST_F(ResolverMultinetworkTest, DnsWithVpn) {
        ASSERT_RESULT_OK(bypassableVpnNetwork.addUser(TEST_UID));
        ASSERT_RESULT_OK(secureVpnNetwork.addUser(TEST_UID2));

        // Add a testing DNS server to networks.
        const Result<DnsServerPair> underlyingPair = (type == ConnectivityType::V4)
                                                             ? underlyingNetwork.addIpv4Dns()
                                                             : underlyingNetwork.addIpv6Dns();
        ASSERT_RESULT_OK(underlyingPair);
        const Result<DnsServerPair> bypassableVpnPair = (type == ConnectivityType::V4)
                                                                ? bypassableVpnNetwork.addIpv4Dns()
                                                                : bypassableVpnNetwork.addIpv6Dns();
        ASSERT_RESULT_OK(bypassableVpnPair);
        const Result<DnsServerPair> secureVpnPair = (type == ConnectivityType::V4)
                                                            ? secureVpnNetwork.addIpv4Dns()
                                                            : secureVpnNetwork.addIpv6Dns();
        ASSERT_RESULT_OK(secureVpnPair);
        // Set up resolver and start forwarding for networks.
        setupDns(underlyingPair->dnsServer, host_name, ipv4_addr, ipv6_addr, &underlyingNetwork);
        setupDns(bypassableVpnPair->dnsServer, host_name, ipv4_addr, ipv6_addr,
                 &bypassableVpnNetwork);
        setupDns(secureVpnPair->dnsServer, host_name, ipv4_addr, ipv6_addr, &secureVpnNetwork);
        auto underlyingNwDnsSv =
                setupDns(type, &underlyingNetwork, host_name, ipv4_addr, ipv6_addr);
        ASSERT_RESULT_OK(underlyingNwDnsSv);
        auto bypassableVpnDnsSv =
                setupDns(type, &bypassableVpnNetwork, host_name, ipv4_addr, ipv6_addr);
        ASSERT_RESULT_OK(bypassableVpnDnsSv);
        auto secureVpnDnsSv = setupDns(type, &secureVpnNetwork, host_name, ipv4_addr, ipv6_addr);
        ASSERT_RESULT_OK(secureVpnDnsSv);

        setDefaultNetwork(underlyingNetwork.netId());
        const unsigned underlyingNetId = underlyingNetwork.netId();
@@ -6931,13 +6930,13 @@ TEST_F(ResolverMultinetworkTest, DnsWithVpn) {
        } vpnWithDnsServerConfigs[]{
                // clang-format off
                // Queries use the bypassable VPN by default.
                {&defaultNetwork,       bypassableVpnNetId, bypassableVpnPair->dnsServer},
                {&defaultNetwork,       bypassableVpnNetId, *bypassableVpnDnsSv},
                // Choosing the underlying network works because the VPN is bypassable.
                {&underlyingNetwork,    underlyingNetId,    underlyingPair->dnsServer},
                {&underlyingNetwork,    underlyingNetId,    *underlyingNwDnsSv},
                // Selecting the VPN sends the query on the VPN.
                {&bypassableVpnNetwork, bypassableVpnNetId, bypassableVpnPair->dnsServer},
                {&bypassableVpnNetwork, bypassableVpnNetId, *bypassableVpnDnsSv},
                // TEST_UID does not have access to the secure VPN.
                {&secureVpnNetwork,     bypassableVpnNetId, bypassableVpnPair->dnsServer},
                {&secureVpnNetwork,     bypassableVpnNetId, *bypassableVpnDnsSv},
                // clang-format on
        };
        for (const auto& config : vpnWithDnsServerConfigs) {
@@ -6957,7 +6956,7 @@ TEST_F(ResolverMultinetworkTest, DnsWithVpn) {
            SCOPED_TRACE(fmt::format("Bypassble VPN without DnsServer, selectedNetwork = {}",
                                     selectedNetwork->name()));
            expectDnsWorksForUid(host_name, selectedNetwork->netId(), TEST_UID, result);
            expectDnsQueryCountsFn(result.size(), underlyingPair->dnsServer, underlyingNetId);
            expectDnsQueryCountsFn(result.size(), *underlyingNwDnsSv, underlyingNetId);
        }

        // The same test scenario as before plus enableVpnIsolation for secure VPN, TEST_UID2.
@@ -6973,7 +6972,7 @@ TEST_F(ResolverMultinetworkTest, DnsWithVpn) {
                SCOPED_TRACE(fmt::format("Secure VPN without DnsServer, selectedNetwork = {}",
                                         selectedNetwork->name()));
                expectDnsWorksForUid(host_name, selectedNetwork->netId(), TEST_UID2, result);
                expectDnsQueryCountsFn(result.size(), underlyingPair->dnsServer, underlyingNetId);
                expectDnsQueryCountsFn(result.size(), *underlyingNwDnsSv, underlyingNetId);
            }

            // Test secure VPN with DNS server.
@@ -6982,7 +6981,7 @@ TEST_F(ResolverMultinetworkTest, DnsWithVpn) {
                SCOPED_TRACE(fmt::format("Secure VPN with DnsServer, selectedNetwork = {}",
                                         selectedNetwork->name()));
                expectDnsWorksForUid(host_name, selectedNetwork->netId(), TEST_UID2, result);
                expectDnsQueryCountsFn(result.size(), secureVpnPair->dnsServer, secureVpnNetId);
                expectDnsQueryCountsFn(result.size(), *secureVpnDnsSv, secureVpnNetId);
            }

            if (enableVpnIsolation) {
@@ -7020,22 +7019,15 @@ TEST_F(ResolverMultinetworkTest, PerAppDefaultNetwork) {
        ASSERT_RESULT_OK(appDefaultNetwork.init());
        ASSERT_RESULT_OK(vpn.init());

        // Create testing DNS servers for each network.
        const Result<DnsServerPair> sysDefaultPair = (ipVersion == ConnectivityType::V4)
                                                             ? sysDefaultNetwork.addIpv4Dns()
                                                             : sysDefaultNetwork.addIpv6Dns();
        ASSERT_RESULT_OK(sysDefaultPair);
        const Result<DnsServerPair> appDefaultPair = (ipVersion == ConnectivityType::V4)
                                                             ? appDefaultNetwork.addIpv4Dns()
                                                             : appDefaultNetwork.addIpv6Dns();
        ASSERT_RESULT_OK(appDefaultPair);
        const Result<DnsServerPair> vpnPair =
                (ipVersion == ConnectivityType::V4) ? vpn.addIpv4Dns() : vpn.addIpv6Dns();
        ASSERT_RESULT_OK(vpnPair);
        // Set up resolver and start forwarding for networks.
        setupDns(sysDefaultPair->dnsServer, host_name, ipv4_addr, ipv6_addr, &sysDefaultNetwork);
        setupDns(appDefaultPair->dnsServer, host_name, ipv4_addr, ipv6_addr, &appDefaultNetwork);
        setupDns(vpnPair->dnsServer, host_name, ipv4_addr, ipv6_addr, &vpn);
        auto sysDefaultNwDnsSv =
                setupDns(ipVersion, &sysDefaultNetwork, host_name, ipv4_addr, ipv6_addr);
        ASSERT_RESULT_OK(sysDefaultNwDnsSv);
        auto appDefaultNwDnsSv =
                setupDns(ipVersion, &appDefaultNetwork, host_name, ipv4_addr, ipv6_addr);
        ASSERT_RESULT_OK(appDefaultNwDnsSv);
        auto vpnDnsSv = setupDns(ipVersion, &vpn, host_name, ipv4_addr, ipv6_addr);
        ASSERT_RESULT_OK(vpnDnsSv);

        const unsigned systemDefaultNetId = sysDefaultNetwork.netId();
        const unsigned appDefaultNetId = appDefaultNetwork.netId();
@@ -7059,12 +7051,11 @@ TEST_F(ResolverMultinetworkTest, PerAppDefaultNetwork) {

        // Test DNS query without selecting a network. --> use system default network.
        expectDnsWorksForUid(host_name, NETID_UNSET, TEST_UID, expectedDnsReply);
        expectDnsQueryCountsFn(expectedDnsReply.size(), sysDefaultPair->dnsServer,
                               systemDefaultNetId);
        expectDnsQueryCountsFn(expectedDnsReply.size(), *sysDefaultNwDnsSv, systemDefaultNetId);
        // Add user to app default network. --> use app default network.
        ASSERT_RESULT_OK(appDefaultNetwork.addUser(TEST_UID));
        expectDnsWorksForUid(host_name, NETID_UNSET, TEST_UID, expectedDnsReply);
        expectDnsQueryCountsFn(expectedDnsReply.size(), appDefaultPair->dnsServer, appDefaultNetId);
        expectDnsQueryCountsFn(expectedDnsReply.size(), *appDefaultNwDnsSv, appDefaultNetId);

        // Test DNS query with a selected network.
        // App default network applies to uid, vpn does not applies to uid.
@@ -7075,11 +7066,11 @@ TEST_F(ResolverMultinetworkTest, PerAppDefaultNetwork) {
        } vpnWithDnsServerConfigs[]{
                // clang-format off
                // App can select the system default network without any permission.
                {&sysDefaultNetwork, systemDefaultNetId, sysDefaultPair->dnsServer},
                {&sysDefaultNetwork, systemDefaultNetId, *sysDefaultNwDnsSv},
                // App can select the restricted network, since its uid was assigned to the network.
                {&appDefaultNetwork, appDefaultNetId, appDefaultPair->dnsServer},
                {&appDefaultNetwork, appDefaultNetId, *appDefaultNwDnsSv},
                // App does not have access to the VPN. --> fallback to app default network.
                {&vpn, appDefaultNetId, appDefaultPair->dnsServer},
                {&vpn, appDefaultNetId, *appDefaultNwDnsSv},
                // clang-format on
        };
        for (const auto& config : vpnWithDnsServerConfigs) {
@@ -7094,11 +7085,11 @@ TEST_F(ResolverMultinetworkTest, PerAppDefaultNetwork) {
        // App default network applies to uid, vpn applies to uid. --> use vpn.
        ASSERT_RESULT_OK(vpn.addUser(TEST_UID));
        expectDnsWorksForUid(host_name, vpn.netId(), TEST_UID, expectedDnsReply);
        expectDnsQueryCountsFn(expectedDnsReply.size(), vpnPair->dnsServer, vpnNetId);
        expectDnsQueryCountsFn(expectedDnsReply.size(), *vpnDnsSv, vpnNetId);

        // vpn without server. --> fallback to app default network.
        ASSERT_TRUE(vpn.clearDnsConfiguration());
        expectDnsWorksForUid(host_name, vpn.netId(), TEST_UID, expectedDnsReply);
        expectDnsQueryCountsFn(expectedDnsReply.size(), appDefaultPair->dnsServer, appDefaultNetId);
        expectDnsQueryCountsFn(expectedDnsReply.size(), *appDefaultNwDnsSv, appDefaultNetId);
    }
}