Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 4e4a2e0a authored by lifr's avatar lifr
Browse files

Add integration test for DNS answer RR with CNAMEs chain

Bug: 123376330
Test: resolv_integration_test
Change-Id: I74ba26f6a892f86e40b6b02611d7f9adee454fec
parent 36796f35
Loading
Loading
Loading
Loading
+55 −26
Original line number Diff line number Diff line
@@ -27,6 +27,7 @@
#include <sys/socket.h>
#include <sys/types.h>
#include <unistd.h>
#include <set>

#include <iostream>
#include <vector>
@@ -827,11 +828,12 @@ bool DNSResponder::handleDNSRequest(const char* buffer, ssize_t len,
            return makeErrorResponse(&header, ns_rcode::ns_r_notimpl, response,
                                     response_len);
        }

        if (!addAnswerRecords(question, &header.answers)) {
            return makeErrorResponse(&header, ns_rcode::ns_r_servfail, response,
                                     response_len);
            return makeErrorResponse(&header, ns_rcode::ns_r_servfail, response, response_len);
        }
    }

    header.qr = true;
    char* response_cur = header.write(response, response + *response_len);
    if (response_cur == nullptr) {
@@ -844,45 +846,73 @@ bool DNSResponder::handleDNSRequest(const char* buffer, ssize_t len,
bool DNSResponder::addAnswerRecords(const DNSQuestion& question,
                                    std::vector<DNSRecord>* answers) const {
    std::lock_guard guard(mappings_mutex_);
    auto it = mappings_.find(QueryKey(question.qname.name, question.qtype));
    if (it == mappings_.end()) {
    std::string rname = question.qname.name;
    std::vector<int> rtypes;

    if (question.qtype == ns_type::ns_t_a || question.qtype == ns_type::ns_t_aaaa)
        rtypes.push_back(ns_type::ns_t_cname);
    rtypes.push_back(question.qtype);
    for (int rtype : rtypes) {
        std::set<std::string> cnames_Loop;
        std::unordered_map<QueryKey, std::string, QueryKeyHash>::const_iterator it;
        while ((it = mappings_.find(QueryKey(rname, rtype))) != mappings_.end()) {
            if (rtype == ns_type::ns_t_cname) {
                // When detect CNAME infinite loops by cnames_Loop, it won't save the duplicate one.
                // As following, the query will stop on loop3 by detecting the same cname.
                // loop1.{"a.xxx.com", ns_type::ns_t_cname, "b.xxx.com"}(insert in answer record)
                // loop2.{"b.xxx.com", ns_type::ns_t_cname, "a.xxx.com"}(insert in answer record)
                // loop3.{"a.xxx.com", ns_type::ns_t_cname, "b.xxx.com"}(When the same cname record
                //   is found in cnames_Loop already, break the query loop.)
                if (cnames_Loop.find(it->first.name) != cnames_Loop.end()) break;
                cnames_Loop.insert(it->first.name);
            }
            DNSRecord record{
                    .name = {.name = it->first.name},
                    .rtype = it->first.type,
                    .rclass = ns_class::ns_c_in,
                    .ttl = 5,  // seconds
            };
            fillAnswerRdata(it->second, record);
            answers->push_back(std::move(record));
            if (rtype != ns_type::ns_t_cname) break;
            rname = it->second;
        }
    }

    if (answers->size() == 0) {
        // TODO(imaipi): handle correctly
        ALOGI("no mapping found for %s %s, lazily refusing to add an answer",
              question.qname.name.c_str(), dnstype2str(question.qtype));
    }

    return true;
}
    DBGLOG("mapping found for %s %s: %s", question.qname.name.c_str(), dnstype2str(question.qtype),
           it->second.c_str());
    DNSRecord record;
    record.name = question.qname;
    record.rtype = question.qtype;
    record.rclass = ns_class::ns_c_in;
    record.ttl = 5;  // seconds
    if (question.qtype == ns_type::ns_t_a) {

bool DNSResponder::fillAnswerRdata(const std::string& rdatastr, DNSRecord& record) const {
    if (record.rtype == ns_type::ns_t_a) {
        record.rdata.resize(4);
        if (inet_pton(AF_INET, it->second.c_str(), record.rdata.data()) != 1) {
            ALOGI("inet_pton(AF_INET, %s) failed", it->second.c_str());
        if (inet_pton(AF_INET, rdatastr.c_str(), record.rdata.data()) != 1) {
            ALOGI("inet_pton(AF_INET, %s) failed", rdatastr.c_str());
            return false;
        }
    } else if (question.qtype == ns_type::ns_t_aaaa) {
    } else if (record.rtype == ns_type::ns_t_aaaa) {
        record.rdata.resize(16);
        if (inet_pton(AF_INET6, it->second.c_str(), record.rdata.data()) != 1) {
            ALOGI("inet_pton(AF_INET6, %s) failed", it->second.c_str());
        if (inet_pton(AF_INET6, rdatastr.c_str(), record.rdata.data()) != 1) {
            ALOGI("inet_pton(AF_INET6, %s) failed", rdatastr.c_str());
            return false;
        }
    } else if (question.qtype == ns_type::ns_t_ptr) {
    } else if ((record.rtype == ns_type::ns_t_ptr) || (record.rtype == ns_type::ns_t_cname)) {
        constexpr char delimiter = '.';
        std::string name = it->second;
        std::string name = rdatastr;
        std::vector<char> rdata;

        // PTRDNAME field
        // Generating PTRDNAME field(section 3.3.12) or CNAME field(section 3.3.1) in rfc1035.
        // The "name" should be an absolute domain name which ends in a dot.
        if (name.back() != delimiter) {
            ALOGI("invalid absolute domain name");
            return false;
        }
        name.pop_back();  // remove the dot in tail

        for (const std::string& label : android::base::Split(name, {delimiter})) {
            // The length of label is limited to 63 octets or less. See RFC 1035 section 3.1.
            if (label.length() == 0 || label.length() > 63) {
@@ -902,10 +932,9 @@ bool DNSResponder::addAnswerRecords(const DNSQuestion& question,
        }
        record.rdata = move(rdata);
    } else {
        ALOGI("unhandled qtype %s", dnstype2str(question.qtype));
        ALOGI("unhandled qtype %s", dnstype2str(record.rtype));
        return false;
    }
    answers->push_back(std::move(record));
    return true;
}

+3 −2
Original line number Diff line number Diff line
@@ -113,8 +113,9 @@ class DNSResponder {
    bool handleDNSRequest(const char* buffer, ssize_t buffer_len,
                          char* response, size_t* response_len) const;

    bool addAnswerRecords(const DNSQuestion& question,
                          std::vector<DNSRecord>* answers) const;
    bool addAnswerRecords(const DNSQuestion& question, std::vector<DNSRecord>* answers) const;

    bool fillAnswerRdata(const std::string& rdatastr, DNSRecord& record) const;

    bool generateErrorResponse(DNSHeader* header, ns_rcode rcode,
                               char* response, size_t* response_len) const;
+138 −0
Original line number Diff line number Diff line
@@ -334,6 +334,80 @@ TEST_F(ResolverTest, GetHostByName) {
    EXPECT_TRUE(result->h_addr_list[1] == nullptr);
}

TEST_F(ResolverTest, GetHostByName_cnames) {
    constexpr char host_name[] = "host.example.com.";
    size_t cnamecount = 0;
    test::DNSResponder dns;

    const std::vector<DnsRecord> records = {
            {kHelloExampleCom, ns_type::ns_t_cname, "a.example.com."},
            {"a.example.com.", ns_type::ns_t_cname, "b.example.com."},
            {"b.example.com.", ns_type::ns_t_cname, "c.example.com."},
            {"c.example.com.", ns_type::ns_t_cname, "d.example.com."},
            {"d.example.com.", ns_type::ns_t_cname, "e.example.com."},
            {"e.example.com.", ns_type::ns_t_cname, host_name},
            {host_name, ns_type::ns_t_a, "1.2.3.3"},
            {host_name, ns_type::ns_t_aaaa, "2001:db8::42"},
    };
    StartDns(dns, records);
    ASSERT_TRUE(mDnsClient.SetResolversForNetwork());

    // using gethostbyname2() to resolve ipv4 hello.example.com. to 1.2.3.3
    // Ensure the v4 address and cnames are correct
    const hostent* result;
    result = gethostbyname2("hello", AF_INET);
    ASSERT_FALSE(result == nullptr);

    for (int i = 0; result != nullptr && result->h_aliases[i] != nullptr; i++) {
        std::string domain_name = records[i].host_name.substr(0, records[i].host_name.size() - 1);
        EXPECT_EQ(result->h_aliases[i], domain_name);
        cnamecount++;
    }
    // The size of "Non-cname type" record in DNS records is 2
    ASSERT_EQ(cnamecount, records.size() - 2);
    ASSERT_EQ(4, result->h_length);
    ASSERT_FALSE(result->h_addr_list[0] == nullptr);
    EXPECT_EQ("1.2.3.3", ToString(result));
    EXPECT_TRUE(result->h_addr_list[1] == nullptr);
    EXPECT_EQ(1U, dns.queries().size()) << dns.dumpQueries();

    // using gethostbyname2() to resolve ipv6 hello.example.com. to 2001:db8::42
    // Ensure the v6 address and cnames are correct
    cnamecount = 0;
    dns.clearQueries();
    result = gethostbyname2("hello", AF_INET6);
    for (unsigned i = 0; result != nullptr && result->h_aliases[i] != nullptr; i++) {
        std::string domain_name = records[i].host_name.substr(0, records[i].host_name.size() - 1);
        EXPECT_EQ(result->h_aliases[i], domain_name);
        cnamecount++;
    }
    // The size of "Non-cname type" DNS record in records is 2
    ASSERT_EQ(cnamecount, records.size() - 2);
    ASSERT_FALSE(result == nullptr);
    ASSERT_EQ(16, result->h_length);
    ASSERT_FALSE(result->h_addr_list[0] == nullptr);
    EXPECT_EQ("2001:db8::42", ToString(result));
    EXPECT_TRUE(result->h_addr_list[1] == nullptr);
}

TEST_F(ResolverTest, GetHostByName_cnamesInfiniteLoop) {
    test::DNSResponder dns;
    const std::vector<DnsRecord> records = {
            {kHelloExampleCom, ns_type::ns_t_cname, "a.example.com."},
            {"a.example.com.", ns_type::ns_t_cname, kHelloExampleCom},
    };
    StartDns(dns, records);
    ASSERT_TRUE(mDnsClient.SetResolversForNetwork());

    const hostent* result;
    result = gethostbyname2("hello", AF_INET);
    ASSERT_TRUE(result == nullptr);

    dns.clearQueries();
    result = gethostbyname2("hello", AF_INET6);
    ASSERT_TRUE(result == nullptr);
}

TEST_F(ResolverTest, GetHostByName_localhost) {
    constexpr char name_camelcase[] = "LocalHost";
    constexpr char name_ip6_dot[] = "ip6-localhost.";
@@ -677,6 +751,70 @@ TEST_F(ResolverTest, GetAddrInfoV4_deferred_resp) {
    t2.join();
}

TEST_F(ResolverTest, GetAddrInfo_cnames) {
    constexpr char host_name[] = "host.example.com.";
    test::DNSResponder dns;
    const std::vector<DnsRecord> records = {
            {kHelloExampleCom, ns_type::ns_t_cname, "a.example.com."},
            {"a.example.com.", ns_type::ns_t_cname, "b.example.com."},
            {"b.example.com.", ns_type::ns_t_cname, "c.example.com."},
            {"c.example.com.", ns_type::ns_t_cname, "d.example.com."},
            {"d.example.com.", ns_type::ns_t_cname, "e.example.com."},
            {"e.example.com.", ns_type::ns_t_cname, host_name},
            {host_name, ns_type::ns_t_a, "1.2.3.3"},
            {host_name, ns_type::ns_t_aaaa, "2001:db8::42"},
    };
    StartDns(dns, records);
    ASSERT_TRUE(mDnsClient.SetResolversForNetwork());

    addrinfo hints = {.ai_family = AF_INET};
    ScopedAddrinfo result = safe_getaddrinfo("hello", nullptr, &hints);
    EXPECT_TRUE(result != nullptr);
    EXPECT_EQ("1.2.3.3", ToString(result));

    dns.clearQueries();
    hints = {.ai_family = AF_INET6};
    result = safe_getaddrinfo("hello", nullptr, &hints);
    EXPECT_TRUE(result != nullptr);
    EXPECT_EQ("2001:db8::42", ToString(result));
}

TEST_F(ResolverTest, GetAddrInfo_cnamesNoIpAddress) {
    test::DNSResponder dns;
    const std::vector<DnsRecord> records = {
            {kHelloExampleCom, ns_type::ns_t_cname, "a.example.com."},
    };
    StartDns(dns, records);
    ASSERT_TRUE(mDnsClient.SetResolversForNetwork());

    addrinfo hints = {.ai_family = AF_INET};
    ScopedAddrinfo result = safe_getaddrinfo("hello", nullptr, &hints);
    EXPECT_TRUE(result == nullptr);

    dns.clearQueries();
    hints = {.ai_family = AF_INET6};
    result = safe_getaddrinfo("hello", nullptr, &hints);
    EXPECT_TRUE(result == nullptr);
}

TEST_F(ResolverTest, GetAddrInfo_cnamesIllegalRdata) {
    test::DNSResponder dns;
    const std::vector<DnsRecord> records = {
            {kHelloExampleCom, ns_type::ns_t_cname, ".!#?"},
    };
    StartDns(dns, records);
    ASSERT_TRUE(mDnsClient.SetResolversForNetwork());

    addrinfo hints = {.ai_family = AF_INET};
    ScopedAddrinfo result = safe_getaddrinfo("hello", nullptr, &hints);
    EXPECT_TRUE(result == nullptr);

    dns.clearQueries();
    hints = {.ai_family = AF_INET6};
    result = safe_getaddrinfo("hello", nullptr, &hints);
    EXPECT_TRUE(result == nullptr);
}

TEST_F(ResolverTest, MultidomainResolution) {
    constexpr char host_name[] = "nihao.example2.com.";
    std::vector<std::string> searchDomains = { "example1.com", "example2.com", "example3.com" };