Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit c07ae5c2 authored by Ken Chen's avatar Ken Chen Committed by Gerrit Code Review
Browse files

Merge "Refactor the setupDns to include addIpv{4|6}Dns"

parents b929ae1d ccfd499e
Loading
Loading
Loading
Loading
+46 −55
Original line number Diff line number Diff line
@@ -6578,8 +6578,10 @@ class ResolverMultinetworkTest : public ResolverTest {
        if (mNextNetId == TEST_NETID_BASE + 256) mNextNetId = TEST_NETID_BASE;
        return mNextNetId++;
    }
    void setupDns(std::shared_ptr<test::DNSResponder> dnsServer, const char* host_name,
                  const char* ipv4_addr, const char* ipv6_addr, ScopedNetwork* nw);
    Result<std::shared_ptr<test::DNSResponder>> setupDns(ConnectivityType type, ScopedNetwork* nw,
                                                         const char* host_name,
                                                         const char* ipv4_addr,
                                                         const char* ipv6_addr);

  private:
    // Use a different netId because this class inherits from the class ResolverTest which
@@ -6645,13 +6647,19 @@ void ResolverMultinetworkTest::StartDns(test::DNSResponder& dns,
    }
}

void ResolverMultinetworkTest::setupDns(std::shared_ptr<test::DNSResponder> dnsServer,
                                        const char* host_name, const char* ipv4_addr,
                                        const char* ipv6_addr, ScopedNetwork* nw) {
    StartDns(*dnsServer,
Result<std::shared_ptr<test::DNSResponder>> ResolverMultinetworkTest::setupDns(
        ConnectivityType type, ScopedNetwork* nw, const char* host_name, const char* ipv4_addr,
        const char* ipv6_addr) {
    // Add a testing DNS server to networks.
    const Result<DnsServerPair> dnsSvPair =
            (type == ConnectivityType::V4) ? nw->addIpv4Dns() : nw->addIpv6Dns();
    if (!dnsSvPair.ok()) return Error() << dnsSvPair.error();

    StartDns(*dnsSvPair->dnsServer,
             {{host_name, ns_type::ns_t_a, ipv4_addr}, {host_name, ns_type::ns_t_aaaa, ipv6_addr}});
    ASSERT_TRUE(nw->setDnsConfiguration());
    ASSERT_TRUE(nw->startTunForwarder());
    if (!nw->setDnsConfiguration()) return Error() << "setDnsConfiguration() failed";
    if (!nw->startTunForwarder()) return Error() << "startTunForwarder() failed";
    return dnsSvPair->dnsServer;
}

Result<ResolverMultinetworkTest::DnsServerPair> ResolverMultinetworkTest::ScopedNetwork::addDns(
@@ -6886,24 +6894,15 @@ TEST_F(ResolverMultinetworkTest, DnsWithVpn) {
        ASSERT_RESULT_OK(bypassableVpnNetwork.addUser(TEST_UID));
        ASSERT_RESULT_OK(secureVpnNetwork.addUser(TEST_UID2));

        // Add a testing DNS server to networks.
        const Result<DnsServerPair> underlyingPair = (type == ConnectivityType::V4)
                                                             ? underlyingNetwork.addIpv4Dns()
                                                             : underlyingNetwork.addIpv6Dns();
        ASSERT_RESULT_OK(underlyingPair);
        const Result<DnsServerPair> bypassableVpnPair = (type == ConnectivityType::V4)
                                                                ? bypassableVpnNetwork.addIpv4Dns()
                                                                : bypassableVpnNetwork.addIpv6Dns();
        ASSERT_RESULT_OK(bypassableVpnPair);
        const Result<DnsServerPair> secureVpnPair = (type == ConnectivityType::V4)
                                                            ? secureVpnNetwork.addIpv4Dns()
                                                            : secureVpnNetwork.addIpv6Dns();
        ASSERT_RESULT_OK(secureVpnPair);
        // Set up resolver and start forwarding for networks.
        setupDns(underlyingPair->dnsServer, host_name, ipv4_addr, ipv6_addr, &underlyingNetwork);
        setupDns(bypassableVpnPair->dnsServer, host_name, ipv4_addr, ipv6_addr,
                 &bypassableVpnNetwork);
        setupDns(secureVpnPair->dnsServer, host_name, ipv4_addr, ipv6_addr, &secureVpnNetwork);
        auto underlyingNwDnsSv =
                setupDns(type, &underlyingNetwork, host_name, ipv4_addr, ipv6_addr);
        ASSERT_RESULT_OK(underlyingNwDnsSv);
        auto bypassableVpnDnsSv =
                setupDns(type, &bypassableVpnNetwork, host_name, ipv4_addr, ipv6_addr);
        ASSERT_RESULT_OK(bypassableVpnDnsSv);
        auto secureVpnDnsSv = setupDns(type, &secureVpnNetwork, host_name, ipv4_addr, ipv6_addr);
        ASSERT_RESULT_OK(secureVpnDnsSv);

        setDefaultNetwork(underlyingNetwork.netId());
        const unsigned underlyingNetId = underlyingNetwork.netId();
@@ -6933,13 +6932,13 @@ TEST_F(ResolverMultinetworkTest, DnsWithVpn) {
        } vpnWithDnsServerConfigs[]{
                // clang-format off
                // Queries use the bypassable VPN by default.
                {&defaultNetwork,       bypassableVpnNetId, bypassableVpnPair->dnsServer},
                {&defaultNetwork,       bypassableVpnNetId, *bypassableVpnDnsSv},
                // Choosing the underlying network works because the VPN is bypassable.
                {&underlyingNetwork,    underlyingNetId,    underlyingPair->dnsServer},
                {&underlyingNetwork,    underlyingNetId,    *underlyingNwDnsSv},
                // Selecting the VPN sends the query on the VPN.
                {&bypassableVpnNetwork, bypassableVpnNetId, bypassableVpnPair->dnsServer},
                {&bypassableVpnNetwork, bypassableVpnNetId, *bypassableVpnDnsSv},
                // TEST_UID does not have access to the secure VPN.
                {&secureVpnNetwork,     bypassableVpnNetId, bypassableVpnPair->dnsServer},
                {&secureVpnNetwork,     bypassableVpnNetId, *bypassableVpnDnsSv},
                // clang-format on
        };
        for (const auto& config : vpnWithDnsServerConfigs) {
@@ -6959,7 +6958,7 @@ TEST_F(ResolverMultinetworkTest, DnsWithVpn) {
            SCOPED_TRACE(fmt::format("Bypassble VPN without DnsServer, selectedNetwork = {}",
                                     selectedNetwork->name()));
            expectDnsWorksForUid(host_name, selectedNetwork->netId(), TEST_UID, result);
            expectDnsQueryCountsFn(result.size(), underlyingPair->dnsServer, underlyingNetId);
            expectDnsQueryCountsFn(result.size(), *underlyingNwDnsSv, underlyingNetId);
        }

        // The same test scenario as before plus enableVpnIsolation for secure VPN, TEST_UID2.
@@ -6975,7 +6974,7 @@ TEST_F(ResolverMultinetworkTest, DnsWithVpn) {
                SCOPED_TRACE(fmt::format("Secure VPN without DnsServer, selectedNetwork = {}",
                                         selectedNetwork->name()));
                expectDnsWorksForUid(host_name, selectedNetwork->netId(), TEST_UID2, result);
                expectDnsQueryCountsFn(result.size(), underlyingPair->dnsServer, underlyingNetId);
                expectDnsQueryCountsFn(result.size(), *underlyingNwDnsSv, underlyingNetId);
            }

            // Test secure VPN with DNS server.
@@ -6984,7 +6983,7 @@ TEST_F(ResolverMultinetworkTest, DnsWithVpn) {
                SCOPED_TRACE(fmt::format("Secure VPN with DnsServer, selectedNetwork = {}",
                                         selectedNetwork->name()));
                expectDnsWorksForUid(host_name, selectedNetwork->netId(), TEST_UID2, result);
                expectDnsQueryCountsFn(result.size(), secureVpnPair->dnsServer, secureVpnNetId);
                expectDnsQueryCountsFn(result.size(), *secureVpnDnsSv, secureVpnNetId);
            }

            if (enableVpnIsolation) {
@@ -7022,22 +7021,15 @@ TEST_F(ResolverMultinetworkTest, PerAppDefaultNetwork) {
        ASSERT_RESULT_OK(appDefaultNetwork.init());
        ASSERT_RESULT_OK(vpn.init());

        // Create testing DNS servers for each network.
        const Result<DnsServerPair> sysDefaultPair = (ipVersion == ConnectivityType::V4)
                                                             ? sysDefaultNetwork.addIpv4Dns()
                                                             : sysDefaultNetwork.addIpv6Dns();
        ASSERT_RESULT_OK(sysDefaultPair);
        const Result<DnsServerPair> appDefaultPair = (ipVersion == ConnectivityType::V4)
                                                             ? appDefaultNetwork.addIpv4Dns()
                                                             : appDefaultNetwork.addIpv6Dns();
        ASSERT_RESULT_OK(appDefaultPair);
        const Result<DnsServerPair> vpnPair =
                (ipVersion == ConnectivityType::V4) ? vpn.addIpv4Dns() : vpn.addIpv6Dns();
        ASSERT_RESULT_OK(vpnPair);
        // Set up resolver and start forwarding for networks.
        setupDns(sysDefaultPair->dnsServer, host_name, ipv4_addr, ipv6_addr, &sysDefaultNetwork);
        setupDns(appDefaultPair->dnsServer, host_name, ipv4_addr, ipv6_addr, &appDefaultNetwork);
        setupDns(vpnPair->dnsServer, host_name, ipv4_addr, ipv6_addr, &vpn);
        auto sysDefaultNwDnsSv =
                setupDns(ipVersion, &sysDefaultNetwork, host_name, ipv4_addr, ipv6_addr);
        ASSERT_RESULT_OK(sysDefaultNwDnsSv);
        auto appDefaultNwDnsSv =
                setupDns(ipVersion, &appDefaultNetwork, host_name, ipv4_addr, ipv6_addr);
        ASSERT_RESULT_OK(appDefaultNwDnsSv);
        auto vpnDnsSv = setupDns(ipVersion, &vpn, host_name, ipv4_addr, ipv6_addr);
        ASSERT_RESULT_OK(vpnDnsSv);

        const unsigned systemDefaultNetId = sysDefaultNetwork.netId();
        const unsigned appDefaultNetId = appDefaultNetwork.netId();
@@ -7061,12 +7053,11 @@ TEST_F(ResolverMultinetworkTest, PerAppDefaultNetwork) {

        // Test DNS query without selecting a network. --> use system default network.
        expectDnsWorksForUid(host_name, NETID_UNSET, TEST_UID, expectedDnsReply);
        expectDnsQueryCountsFn(expectedDnsReply.size(), sysDefaultPair->dnsServer,
                               systemDefaultNetId);
        expectDnsQueryCountsFn(expectedDnsReply.size(), *sysDefaultNwDnsSv, systemDefaultNetId);
        // Add user to app default network. --> use app default network.
        ASSERT_RESULT_OK(appDefaultNetwork.addUser(TEST_UID));
        expectDnsWorksForUid(host_name, NETID_UNSET, TEST_UID, expectedDnsReply);
        expectDnsQueryCountsFn(expectedDnsReply.size(), appDefaultPair->dnsServer, appDefaultNetId);
        expectDnsQueryCountsFn(expectedDnsReply.size(), *appDefaultNwDnsSv, appDefaultNetId);

        // Test DNS query with a selected network.
        // App default network applies to uid, vpn does not applies to uid.
@@ -7077,11 +7068,11 @@ TEST_F(ResolverMultinetworkTest, PerAppDefaultNetwork) {
        } vpnWithDnsServerConfigs[]{
                // clang-format off
                // App can select the system default network without any permission.
                {&sysDefaultNetwork, systemDefaultNetId, sysDefaultPair->dnsServer},
                {&sysDefaultNetwork, systemDefaultNetId, *sysDefaultNwDnsSv},
                // App can select the restricted network, since its uid was assigned to the network.
                {&appDefaultNetwork, appDefaultNetId, appDefaultPair->dnsServer},
                {&appDefaultNetwork, appDefaultNetId, *appDefaultNwDnsSv},
                // App does not have access to the VPN. --> fallback to app default network.
                {&vpn, appDefaultNetId, appDefaultPair->dnsServer},
                {&vpn, appDefaultNetId, *appDefaultNwDnsSv},
                // clang-format on
        };
        for (const auto& config : vpnWithDnsServerConfigs) {
@@ -7096,11 +7087,11 @@ TEST_F(ResolverMultinetworkTest, PerAppDefaultNetwork) {
        // App default network applies to uid, vpn applies to uid. --> use vpn.
        ASSERT_RESULT_OK(vpn.addUser(TEST_UID));
        expectDnsWorksForUid(host_name, vpn.netId(), TEST_UID, expectedDnsReply);
        expectDnsQueryCountsFn(expectedDnsReply.size(), vpnPair->dnsServer, vpnNetId);
        expectDnsQueryCountsFn(expectedDnsReply.size(), *vpnDnsSv, vpnNetId);

        // vpn without server. --> fallback to app default network.
        ASSERT_TRUE(vpn.clearDnsConfiguration());
        expectDnsWorksForUid(host_name, vpn.netId(), TEST_UID, expectedDnsReply);
        expectDnsQueryCountsFn(expectedDnsReply.size(), appDefaultPair->dnsServer, appDefaultNetId);
        expectDnsQueryCountsFn(expectedDnsReply.size(), *appDefaultNwDnsSv, appDefaultNetId);
    }
}