Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 6ce587d2 authored by Mike Yu's avatar Mike Yu Committed by Luke Huang
Browse files

Support multinetwork tests

The integration test used to set up testing DNS servers on loopback
interface. To support testing functionality for multinetwork, make
the test able to send queries to a TUN interface, and the queries
will be forwarded to the testing DNS servers.

To forward packets, implement a forwarder which can translate packets
(v4-to-v4 or v6-to-v6) between the resolver and testing DNS servers
and can forward packets to each other.

Also add three tests:
  GetAddrInfo_AI_ADDRCONFIG
  NetworkDestroyedDuringQueryInFlight
  OneCachePerNetwork

And remove unused libraries from the test:
  libnetd_test_tun_interface
  libnetd_test_utils

Test: cd packages/modules/DnsResolver && atest
Change-Id: I52a52ce59373bc8b9462064c0409b657696c379f
parent 568ed6c9
Loading
Loading
Loading
Loading
+2 −2
Original line number Original line Diff line number Diff line
@@ -151,6 +151,7 @@ cc_test {
        "dns_responder/dns_responder.cpp",
        "dns_responder/dns_responder.cpp",
        "dnsresolver_binder_test.cpp",
        "dnsresolver_binder_test.cpp",
        "resolv_integration_test.cpp",
        "resolv_integration_test.cpp",
        "tun_forwarder.cpp",
    ],
    ],
    header_libs: [
    header_libs: [
        "dnsproxyd_protocol_headers",
        "dnsproxyd_protocol_headers",
@@ -170,13 +171,12 @@ cc_test {
        "libnetd_test_dnsresponder_ndk",
        "libnetd_test_dnsresponder_ndk",
        "libnetd_test_metrics_listener",
        "libnetd_test_metrics_listener",
        "libnetd_test_resolv_utils",
        "libnetd_test_resolv_utils",
        "libnetd_test_tun_interface",
        "libnetd_test_utils",
        "libnetdutils",
        "libnetdutils",
        "libssl",
        "libssl",
        "libutils",
        "libutils",
        "netd_aidl_interface-ndk_platform",
        "netd_aidl_interface-ndk_platform",
        "netd_event_listener_interface-ndk_platform",
        "netd_event_listener_interface-ndk_platform",
        "libipchecksum",
    ],
    ],
    // This test talks to the DnsResolver module over a binary protocol on a socket, so keep it as
    // This test talks to the DnsResolver module over a binary protocol on a socket, so keep it as
    // multilib setting is worth because we might be able to get some coverage for the case where
    // multilib setting is worth because we might be able to get some coverage for the case where
+17 −1
Original line number Original line Diff line number Diff line
@@ -175,6 +175,7 @@ class DNSResponder {
    void setResponseProbability(double response_probability);
    void setResponseProbability(double response_probability);
    void setResponseProbability(double response_probability, int protocol);
    void setResponseProbability(double response_probability, int protocol);
    void setResponseDelayMs(unsigned);
    void setResponseDelayMs(unsigned);
    void setErrorRcode(ns_rcode error_rcode) { error_rcode_ = error_rcode; }
    void setEdns(Edns edns);
    void setEdns(Edns edns);
    void setTtl(unsigned ttl);
    void setTtl(unsigned ttl);
    bool running() const;
    bool running() const;
@@ -190,6 +191,16 @@ class DNSResponder {
    void setDeferredResp(bool deferred_resp);
    void setDeferredResp(bool deferred_resp);
    static bool fillRdata(const std::string& rdatastr, DNSRecord& record);
    static bool fillRdata(const std::string& rdatastr, DNSRecord& record);


    // These functions are helpers for binding the listening sockets to a specific network, which
    // is necessary only for multinetwork tests. Since binding sockets to a network requires
    // the dependency of libnetd_client, and DNSResponder is also widely used in other tests like
    // resolv_unit_test which doesn't need that dependency, so expose the socket fds to let the
    // callers perform binding operations by themselves. Callers MUST not close the fds.
    void setNetwork(unsigned netId) { mNetId = netId; }
    std::optional<unsigned> getNetwork() const { return mNetId; }
    int getUdpSocket() const { return udp_socket_.get(); }
    int getTcpSocket() const { return tcp_socket_.get(); }

    // TODO: Make DNSResponder record unknown queries in a vector for improving the debugging.
    // TODO: Make DNSResponder record unknown queries in a vector for improving the debugging.
    // Unit test could dump the unexpected query for further debug if any unexpected failure.
    // Unit test could dump the unexpected query for further debug if any unexpected failure.


@@ -284,8 +295,10 @@ class DNSResponder {
    // Address and service to listen on TCP and UDP.
    // Address and service to listen on TCP and UDP.
    const std::string listen_address_;
    const std::string listen_address_;
    const std::string listen_service_;
    const std::string listen_service_;

    // TODO: Consider refactoring atomic members of this class to a single big mutex.
    // Error code to return for requests for an unknown name.
    // Error code to return for requests for an unknown name.
    const ns_rcode error_rcode_;
    ns_rcode error_rcode_;
    // Mapping type the DNS server used to build the response.
    // Mapping type the DNS server used to build the response.
    const MappingType mapping_type_;
    const MappingType mapping_type_;
    // Probability that a valid response on TCP is being sent instead of
    // Probability that a valid response on TCP is being sent instead of
@@ -340,6 +353,9 @@ class DNSResponder {
    std::condition_variable cv_for_deferred_resp_;
    std::condition_variable cv_for_deferred_resp_;
    std::mutex cv_mutex_for_deferred_resp_;
    std::mutex cv_mutex_for_deferred_resp_;
    bool deferred_resp_ GUARDED_BY(cv_mutex_for_deferred_resp_) = false;
    bool deferred_resp_ GUARDED_BY(cv_mutex_for_deferred_resp_) = false;

    // The network to which the listening sockets will be bound.
    std::optional<unsigned> mNetId;
};
};


}  // namespace test
}  // namespace test
+339 −1
Original line number Original line Diff line number Diff line
@@ -20,6 +20,7 @@
#include <android-base/logging.h>
#include <android-base/logging.h>
#include <android-base/parseint.h>
#include <android-base/parseint.h>
#include <android-base/properties.h>
#include <android-base/properties.h>
#include <android-base/result.h>
#include <android-base/stringprintf.h>
#include <android-base/stringprintf.h>
#include <android-base/unique_fd.h>
#include <android-base/unique_fd.h>
#include <android/multinetwork.h>  // ResNsendFlags
#include <android/multinetwork.h>  // ResNsendFlags
@@ -62,13 +63,13 @@
#include "netid_client.h"  // NETID_UNSET
#include "netid_client.h"  // NETID_UNSET
#include "params.h"        // MAXNS
#include "params.h"        // MAXNS
#include "stats.h"         // RCODE_TIMEOUT
#include "stats.h"         // RCODE_TIMEOUT
#include "test_utils.h"
#include "tests/dns_metrics_listener/dns_metrics_listener.h"
#include "tests/dns_metrics_listener/dns_metrics_listener.h"
#include "tests/dns_responder/dns_responder.h"
#include "tests/dns_responder/dns_responder.h"
#include "tests/dns_responder/dns_responder_client_ndk.h"
#include "tests/dns_responder/dns_responder_client_ndk.h"
#include "tests/dns_responder/dns_tls_certificate.h"
#include "tests/dns_responder/dns_tls_certificate.h"
#include "tests/dns_responder/dns_tls_frontend.h"
#include "tests/dns_responder/dns_tls_frontend.h"
#include "tests/resolv_test_utils.h"
#include "tests/resolv_test_utils.h"
#include "tests/tun_forwarder.h"


// Valid VPN netId range is 100 ~ 65535
// Valid VPN netId range is 100 ~ 65535
constexpr int TEST_VPN_NETID = 65502;
constexpr int TEST_VPN_NETID = 65502;
@@ -86,10 +87,13 @@ using aidl::android::net::IDnsResolver;
using aidl::android::net::INetd;
using aidl::android::net::INetd;
using aidl::android::net::ResolverParamsParcel;
using aidl::android::net::ResolverParamsParcel;
using aidl::android::net::metrics::INetdEventListener;
using aidl::android::net::metrics::INetdEventListener;
using android::base::Error;
using android::base::ParseInt;
using android::base::ParseInt;
using android::base::Result;
using android::base::StringPrintf;
using android::base::StringPrintf;
using android::base::unique_fd;
using android::base::unique_fd;
using android::net::ResolverStats;
using android::net::ResolverStats;
using android::net::TunForwarder;
using android::net::metrics::DnsMetricsListener;
using android::net::metrics::DnsMetricsListener;
using android::netdutils::enableSockopt;
using android::netdutils::enableSockopt;
using android::netdutils::makeSlice;
using android::netdutils::makeSlice;
@@ -5116,3 +5120,337 @@ TEST_F(ResolverTest, BlockDnsQueryUidDoesNotLeadToBadServer) {
    EXPECT_EQ(dns1.queries().size(), 0U);
    EXPECT_EQ(dns1.queries().size(), 0U);
    EXPECT_EQ(dns2.queries().size(), 0U);
    EXPECT_EQ(dns2.queries().size(), 0U);
}
}

// ResolverMultinetworkTest is used to verify multinetwork functionality. Here's how it works:
// The resolver sends queries to address A, and then there will be a TunForwarder helping forward
// the packets to address B, which is the address on which the testing server is listening. The
// answer packets responded from the testing server go through the reverse path back to the
// resolver.
//
// To achieve the that, it needs to set up a interface with routing rules. Tests are not
// supposed to initiate DNS servers on their own; instead, some utilities are added to the class to
// help the setup.
//
// An example of how to use it:
// TEST_F() {
//     ScopedNetwork network = CreateScopedNetwork(V4);
//     network.init();
//
//     auto dns = network.addIpv4Dns();
//     StartDns(dns.dnsServer, {});
//
//     setResolverConfiguration(...);
//     network.startTunForwarder();
//
//     // Send queries here
// }

class ResolverMultinetworkTest : public ResolverTest {
  protected:
    enum class ConnectivityType { V4, V6, V4V6 };

    struct DnsServerPair {
        test::DNSResponder& dnsServer;
        std::string dnsAddr;  // The DNS server address used for setResolverConfiguration().
        // TODO: Add test::DnsTlsFrontend* and std::string for DoT.
    };

    class ScopedNetwork {
      public:
        ScopedNetwork(unsigned netId, ConnectivityType type, INetd* netdSrv,
                      IDnsResolver* dnsResolvSrv)
            : mNetId(netId),
              mConnectivityType(type),
              mNetdSrv(netdSrv),
              mDnsResolvSrv(dnsResolvSrv) {
            mIfname = StringPrintf("testtun%d", netId);
        }
        ~ScopedNetwork() { destroy(); }

        Result<void> init();
        void destroy();
        Result<DnsServerPair> addIpv4Dns() { return addDns(ConnectivityType::V4); }
        Result<DnsServerPair> addIpv6Dns() { return addDns(ConnectivityType::V6); }
        bool startTunForwarder() { return mTunForwarder->startForwarding(); }
        unsigned netId() const { return mNetId; }

      private:
        Result<DnsServerPair> addDns(ConnectivityType connectivity);
        std::string makeIpv4AddrString(unsigned n) const {
            return StringPrintf("192.168.%u.%u", mNetId, n);
        }
        std::string makeIpv6AddrString(unsigned n) const {
            return StringPrintf("2001:db8:%u::%u", mNetId, n);
        }

        const unsigned mNetId;
        const ConnectivityType mConnectivityType;
        INetd* mNetdSrv;
        IDnsResolver* mDnsResolvSrv;

        std::string mIfname;
        std::unique_ptr<TunForwarder> mTunForwarder;
        std::vector<std::unique_ptr<test::DNSResponder>> mDnsServers;
        // TODO: Add std::vector<std::unique_ptr<test::DnsTlsFrontend>>
    };

    void SetUp() override {
        ResolverTest::SetUp();
        ASSERT_NE(mDnsClient.netdService(), nullptr);
        ASSERT_NE(mDnsClient.resolvService(), nullptr);
    }

    void TearDown() override { ResolverTest::TearDown(); }

    ScopedNetwork CreateScopedNetwork(ConnectivityType type);
    void StartDns(test::DNSResponder& dns, const std::vector<DnsRecord>& records);

    unsigned getFreeNetId() { return mNextNetId++; }

  private:
    // Use a different netId because this class inherits from the class ResolverTest which
    // always creates TEST_NETID in setup. It's incremented when CreateScopedNetwork() is called.
    // Note: Don't create more than 20 networks in the class since 51 is used for the dummy network.
    unsigned mNextNetId = 31;
};

ResolverMultinetworkTest::ScopedNetwork ResolverMultinetworkTest::CreateScopedNetwork(
        ConnectivityType type) {
    return {getFreeNetId(), type, mDnsClient.netdService(), mDnsClient.resolvService()};
}

Result<void> ResolverMultinetworkTest::ScopedNetwork::init() {
    unique_fd ufd = TunForwarder::createTun(mIfname);
    if (!ufd.ok()) {
        return Errorf("createTun for {} failed", mIfname);
    }
    mTunForwarder = std::make_unique<TunForwarder>(std::move(ufd));

    if (auto r = mNetdSrv->networkCreatePhysical(mNetId, INetd::PERMISSION_SYSTEM); !r.isOk()) {
        return Error() << r.getMessage();
    }
    if (auto r = mDnsResolvSrv->createNetworkCache(mNetId); !r.isOk()) {
        return Error() << r.getMessage();
    }
    if (auto r = mNetdSrv->networkAddInterface(mNetId, mIfname); !r.isOk()) {
        return Error() << r.getMessage();
    }

    if (mConnectivityType == ConnectivityType::V4 || mConnectivityType == ConnectivityType::V4V6) {
        const std::string v4Addr = makeIpv4AddrString(1);
        if (auto r = mNetdSrv->interfaceAddAddress(mIfname, v4Addr, 32); !r.isOk()) {
            return Error() << r.getMessage();
        }
        if (auto r = mNetdSrv->networkAddRoute(mNetId, mIfname, "0.0.0.0/0", ""); !r.isOk()) {
            return Error() << r.getMessage();
        }
    }
    if (mConnectivityType == ConnectivityType::V6 || mConnectivityType == ConnectivityType::V4V6) {
        const std::string v6Addr = makeIpv6AddrString(1);
        if (auto r = mNetdSrv->interfaceAddAddress(mIfname, v6Addr, 128); !r.isOk()) {
            return Error() << r.getMessage();
        }
        if (auto r = mNetdSrv->networkAddRoute(mNetId, mIfname, "::/0", ""); !r.isOk()) {
            return Error() << r.getMessage();
        }
    }

    return {};
}

void ResolverMultinetworkTest::ScopedNetwork::destroy() {
    mNetdSrv->networkDestroy(mNetId);
    mDnsResolvSrv->destroyNetworkCache(mNetId);
}

void ResolverMultinetworkTest::StartDns(test::DNSResponder& dns,
                                        const std::vector<DnsRecord>& records) {
    ResolverTest::StartDns(dns, records);

    // Bind the DNSResponder's sockets to the network if specified.
    if (std::optional<unsigned> netId = dns.getNetwork(); netId.has_value()) {
        setNetworkForSocket(netId.value(), dns.getUdpSocket());
        setNetworkForSocket(netId.value(), dns.getTcpSocket());
    }
}

Result<ResolverMultinetworkTest::DnsServerPair> ResolverMultinetworkTest::ScopedNetwork::addDns(
        ConnectivityType type) {
    const int index = mDnsServers.size();
    const int prefixLen = (type == ConnectivityType::V4) ? 32 : 128;

    const std::function<std::string(unsigned)> makeIpString =
            std::bind((type == ConnectivityType::V4) ? &ScopedNetwork::makeIpv4AddrString
                                                     : &ScopedNetwork::makeIpv6AddrString,
                      this, std::placeholders::_1);

    std::string src1 = makeIpString(1);            // The address from which the resolver will send.
    std::string dst1 = makeIpString(index + 100);  // The address to which the resolver will send.
    std::string src2 = dst1;                       // The address translated from src1.
    std::string dst2 = makeIpString(index + 200);  // The address translated from dst2.

    if (!mTunForwarder->addForwardingRule({src1, dst1}, {src2, dst2}) ||
        !mTunForwarder->addForwardingRule({dst2, src2}, {dst1, src1})) {
        return Errorf("Failed to add the rules ({}, {}, {}, {})", src1, dst1, src2, dst2);
    }

    if (!mNetdSrv->interfaceAddAddress(mIfname, dst2, prefixLen).isOk()) {
        return Errorf("interfaceAddAddress({}, {}, {}) failed", mIfname, dst2, prefixLen);
    }

    // Create a DNSResponder instance.
    auto& dnsPtr = mDnsServers.emplace_back(std::make_unique<test::DNSResponder>(dst2));
    dnsPtr->setNetwork(mNetId);
    return DnsServerPair{
            .dnsServer = *dnsPtr,
            .dnsAddr = dst1,
    };
}

TEST_F(ResolverMultinetworkTest, GetAddrInfo_AI_ADDRCONFIG) {
    constexpr char host_name[] = "ohayou.example.com.";

    const std::array<ConnectivityType, 3> allTypes = {
            ConnectivityType::V4,
            ConnectivityType::V6,
            ConnectivityType::V4V6,
    };
    for (const auto& type : allTypes) {
        SCOPED_TRACE(StringPrintf("ConnectivityType: %d", type));

        // Create a network.
        ScopedNetwork network = CreateScopedNetwork(type);
        ASSERT_RESULT_OK(network.init());

        // Add a testing DNS server.
        const Result<DnsServerPair> dnsPair =
                (type == ConnectivityType::V4) ? network.addIpv4Dns() : network.addIpv6Dns();
        ASSERT_RESULT_OK(dnsPair);
        StartDns(dnsPair->dnsServer, {{host_name, ns_type::ns_t_a, "1.1.1.31"},
                                      {host_name, ns_type::ns_t_aaaa, "2001:db8:cafe:d00d::31"}});

        // Set up resolver and start forwarding.
        ResolverParamsParcel parcel = DnsResponderClient::GetDefaultResolverParamsParcel();
        parcel.tlsServers.clear();
        parcel.netId = network.netId();
        parcel.servers = {dnsPair->dnsAddr};
        ASSERT_TRUE(mDnsClient.SetResolversFromParcel(parcel));
        ASSERT_TRUE(network.startTunForwarder());

        const addrinfo hints = {
                .ai_flags = AI_ADDRCONFIG,
                .ai_family = AF_UNSPEC,
                .ai_socktype = SOCK_DGRAM,
        };
        addrinfo* raw_ai_result = nullptr;
        EXPECT_EQ(0, android_getaddrinfofornet(host_name, nullptr, &hints, network.netId(),
                                               MARK_UNSET, &raw_ai_result));
        ScopedAddrinfo ai_result(raw_ai_result);
        std::vector<std::string> result_strs = ToStrings(ai_result);
        std::vector<std::string> expectedResult;
        size_t expectedQueries = 0;

        if (type == ConnectivityType::V6 || type == ConnectivityType::V4V6) {
            expectedResult.emplace_back("2001:db8:cafe:d00d::31");
            expectedQueries++;
        }
        if (type == ConnectivityType::V4 || type == ConnectivityType::V4V6) {
            expectedResult.emplace_back("1.1.1.31");
            expectedQueries++;
        }
        EXPECT_THAT(result_strs, testing::UnorderedElementsAreArray(expectedResult));
        EXPECT_EQ(GetNumQueries(dnsPair->dnsServer, host_name), expectedQueries);
    }
}

TEST_F(ResolverMultinetworkTest, NetworkDestroyedDuringQueryInFlight) {
    constexpr char host_name[] = "ohayou.example.com.";

    // Create a network and add an ipv4 DNS server.
    auto network =
            std::make_unique<ScopedNetwork>(getFreeNetId(), ConnectivityType::V4V6,
                                            mDnsClient.netdService(), mDnsClient.resolvService());
    ASSERT_RESULT_OK(network->init());
    const Result<DnsServerPair> dnsPair = network->addIpv4Dns();
    ASSERT_RESULT_OK(dnsPair);

    // Set the DNS server unresponsive.
    dnsPair->dnsServer.setResponseProbability(0.0);
    dnsPair->dnsServer.setErrorRcode(static_cast<ns_rcode>(-1));
    StartDns(dnsPair->dnsServer, {});

    // Set up resolver and start forwarding.
    ResolverParamsParcel parcel = DnsResponderClient::GetDefaultResolverParamsParcel();
    parcel.tlsServers.clear();
    parcel.netId = network->netId();
    parcel.servers = {dnsPair->dnsAddr};
    ASSERT_TRUE(mDnsClient.SetResolversFromParcel(parcel));
    ASSERT_TRUE(network->startTunForwarder());

    // Expect the things happening in order:
    // 1. The thread sends the query to the dns server which is unresponsive.
    // 2. The network is destroyed while the thread is waiting for the response from the dns server.
    // 3. After the dns server timeout, the thread retries but fails to connect.
    std::thread lookup([&]() {
        int fd = resNetworkQuery(network->netId(), host_name, ns_c_in, ns_t_a, 0);
        EXPECT_TRUE(fd != -1);
        expectAnswersNotValid(fd, -ETIMEDOUT);
    });

    // Tear down the network as soon as the dns server receives the query.
    const auto condition = [&]() { return GetNumQueries(dnsPair->dnsServer, host_name) == 1U; };
    EXPECT_TRUE(PollForCondition(condition));
    network.reset();

    lookup.join();
}

TEST_F(ResolverMultinetworkTest, OneCachePerNetwork) {
    SKIP_IF_REMOTE_VERSION_LESS_THAN(mDnsClient.resolvService(), 4);
    constexpr char host_name[] = "ohayou.example.com.";

    ScopedNetwork network1 = CreateScopedNetwork(ConnectivityType::V4V6);
    ScopedNetwork network2 = CreateScopedNetwork(ConnectivityType::V4V6);
    ASSERT_RESULT_OK(network1.init());
    ASSERT_RESULT_OK(network2.init());

    const Result<DnsServerPair> dnsPair1 = network1.addIpv4Dns();
    const Result<DnsServerPair> dnsPair2 = network2.addIpv4Dns();
    ASSERT_RESULT_OK(dnsPair1);
    ASSERT_RESULT_OK(dnsPair2);
    StartDns(dnsPair1->dnsServer, {{host_name, ns_type::ns_t_a, "1.1.1.31"}});
    StartDns(dnsPair2->dnsServer, {{host_name, ns_type::ns_t_a, "1.1.1.32"}});

    // Set up resolver for network 1 and start forwarding.
    ResolverParamsParcel parcel = DnsResponderClient::GetDefaultResolverParamsParcel();
    parcel.tlsServers.clear();
    parcel.netId = network1.netId();
    parcel.servers = {dnsPair1->dnsAddr};
    ASSERT_TRUE(mDnsClient.SetResolversFromParcel(parcel));
    ASSERT_TRUE(network1.startTunForwarder());

    // Set up resolver for network 2 and start forwarding.
    parcel.netId = network2.netId();
    parcel.servers = {dnsPair2->dnsAddr};
    ASSERT_TRUE(mDnsClient.SetResolversFromParcel(parcel));
    ASSERT_TRUE(network2.startTunForwarder());

    // Send the same queries to both networks.
    int fd1 = resNetworkQuery(network1.netId(), host_name, ns_c_in, ns_t_a, 0);
    int fd2 = resNetworkQuery(network2.netId(), host_name, ns_c_in, ns_t_a, 0);

    expectAnswersValid(fd1, AF_INET, "1.1.1.31");
    expectAnswersValid(fd2, AF_INET, "1.1.1.32");
    EXPECT_EQ(GetNumQueries(dnsPair1->dnsServer, host_name), 1U);
    EXPECT_EQ(GetNumQueries(dnsPair2->dnsServer, host_name), 1U);

    // Flush the cache of network 1, and send the queries again.
    EXPECT_TRUE(mDnsClient.resolvService()->flushNetworkCache(network1.netId()).isOk());
    fd1 = resNetworkQuery(network1.netId(), host_name, ns_c_in, ns_t_a, 0);
    fd2 = resNetworkQuery(network2.netId(), host_name, ns_c_in, ns_t_a, 0);

    expectAnswersValid(fd1, AF_INET, "1.1.1.31");
    expectAnswersValid(fd2, AF_INET, "1.1.1.32");
    EXPECT_EQ(GetNumQueries(dnsPair1->dnsServer, host_name), 2U);
    EXPECT_EQ(GetNumQueries(dnsPair2->dnsServer, host_name), 1U);
}
+419 −0

File added.

Preview size limit exceeded, changes collapsed.

tests/tun_forwarder.h

0 → 100644
+105 −0
Original line number Original line Diff line number Diff line
/*
 * Copyright (C) 2020 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 */

#pragma once

#include <map>
#include <thread>

#include <netinet/ip.h>

#include <android-base/result.h>
#include <android-base/unique_fd.h>
#include <netdutils/Slice.h>

namespace android::net {

// Given a TUN interface fd, TunForwarder reads packets from the fd, changes their IP header
// according to a set of forwarding rules (which can be set by addForwardingRule), and sends
// new packets back to the fd. Only IPv4 and IPv6 packets with recognized source and destination
// addresses are accepted; other packets are silently ignored.
class TunForwarder {
  public:
    TunForwarder(base::unique_fd tunFd);
    ~TunForwarder();

    bool addForwardingRule(const std::array<std::string, 2>& from,
                           const std::array<std::string, 2>& to);
    bool startForwarding();
    bool stopForwarding();

    static base::unique_fd createTun(const std::string& ifname);

  private:
    // TODO: Considering using IPAddress for v4pair and v6pair. This might requires adding
    // addr4() and addr6() as IPPrefix does.
    struct v4pair {
        static base::Result<v4pair> makePair(const std::array<std::string, 2>& addrs);
        v4pair() = default;
        v4pair(int32_t srcAddr, int32_t dstAddr) {
            src.s_addr = static_cast<in_addr_t>(srcAddr);
            dst.s_addr = static_cast<in_addr_t>(dstAddr);
        }
        in_addr src;
        in_addr dst;
        bool operator==(const v4pair& o) const;
        bool operator<(const v4pair& o) const;
    };

    struct v6pair {
        static base::Result<v6pair> makePair(const std::array<std::string, 2>& addrs);
        v6pair() = default;
        v6pair(const in6_addr& srcAddr, const in6_addr& dstAddr) : src(srcAddr), dst(dstAddr) {}
        in6_addr src;
        in6_addr dst;
        bool operator==(const v6pair& o) const;
        bool operator<(const v6pair& o) const;
    };

    void loop();
    void handlePacket(int fd) const;

    // Send a signal to terminate the loop thread.
    bool signalEventFd();

    // A series of functions to check the packet. Return error if the packet is neither UDP nor TCP.
    base::Result<void> validatePacket(netdutils::Slice tunPacket) const;
    base::Result<void> validateIpv4Packet(netdutils::Slice ipv4Packet) const;
    base::Result<void> validateIpv6Packet(netdutils::Slice ipv6Packet) const;
    base::Result<void> validateUdpPacket(netdutils::Slice udpPacket) const;
    base::Result<void> validateTcpPacket(netdutils::Slice tcpPacket) const;

    // The function assumes |tunPacket| is either UDP or TCP packet, changes the source/destination
    // addresses, and updates the checksum.
    base::Result<void> translatePacket(netdutils::Slice tunPacket) const;
    base::Result<void> translateIpv4Packet(netdutils::Slice ipv4Packet) const;
    base::Result<void> translateIpv6Packet(netdutils::Slice ipv6Packet) const;
    void translateUdpPacket(netdutils::Slice udpPacket, uint32_t oldPseudoSum,
                            uint32_t newPseudoSum) const;
    void translateTcpPacket(netdutils::Slice tcpPacket, uint32_t oldPseudoSum,
                            uint32_t newPseudoSum) const;

    std::thread mForwarder;
    base::unique_fd mTunFd;
    base::unique_fd mEventFd;
    std::map<v4pair, v4pair> mRulesIpv4;
    std::map<v6pair, v6pair> mRulesIpv6;

    static constexpr int kPollTimeoutMs = 5000;
};

}  // namespace android::net