Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 4ff97d7d authored by Mike Yu's avatar Mike Yu Committed by Gerrit Code Review
Browse files

Merge changes from topic "resolver_cache_test"

* changes:
  resolv: add some tests for resolver cache
  resolv: export more structs from dns_responder
parents 33c60d64 d1ec2548
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -186,6 +186,7 @@ cc_test {
    srcs: [
        "dns_tls_test.cpp",
        "libnetd_resolv_test.cpp",
        "res_cache_test.cpp",
    ],
    shared_libs: [
        "libbase",
+8 −91
Original line number Diff line number Diff line
@@ -144,20 +144,6 @@ const char* dnsclass2str(unsigned dnsclass) {
    return it->second;
}

struct DNSName {
    std::string name;
    const char* read(const char* buffer, const char* buffer_end);
    char* write(char* buffer, const char* buffer_end) const;
    const char* toString() const;
private:
    const char* parseField(const char* buffer, const char* buffer_end,
                           bool* last);
};

const char* DNSName::toString() const {
    return name.c_str();
}

const char* DNSName::read(const char* buffer, const char* buffer_end) {
    const char* cur = buffer;
    bool last = false;
@@ -201,8 +187,7 @@ char* DNSName::write(char* buffer, const char* buffer_end) const {
    return buffer_cur;
}

const char* DNSName::parseField(const char* buffer, const char* buffer_end,
                                bool* last) {
const char* DNSName::parseField(const char* buffer, const char* buffer_end, bool* last) {
    if (buffer + sizeof(uint8_t) > buffer_end) {
        LOG(ERROR) << "parsing failed at line " << __LINE__;
        return nullptr;
@@ -231,15 +216,6 @@ const char* DNSName::parseField(const char* buffer, const char* buffer_end,
    return nullptr;
}

struct DNSQuestion {
    DNSName qname;
    unsigned qtype;
    unsigned qclass;
    const char* read(const char* buffer, const char* buffer_end);
    char* write(char* buffer, const char* buffer_end) const;
    std::string toString() const;
};

const char* DNSQuestion::read(const char* buffer, const char* buffer_end) {
    const char* cur = qname.read(buffer, buffer_end);
    if (cur == nullptr) {
@@ -263,41 +239,17 @@ char* DNSQuestion::write(char* buffer, const char* buffer_end) const {
        return nullptr;
    }
    *reinterpret_cast<uint16_t*>(buffer_cur) = htons(qtype);
    *reinterpret_cast<uint16_t*>(buffer_cur + sizeof(uint16_t)) =
            htons(qclass);
    *reinterpret_cast<uint16_t*>(buffer_cur + sizeof(uint16_t)) = htons(qclass);
    return buffer_cur + 2 * sizeof(uint16_t);
}

std::string DNSQuestion::toString() const {
    char buffer[4096];
    int len = snprintf(buffer, sizeof(buffer), "Q<%s,%s,%s>", qname.toString(),
    int len = snprintf(buffer, sizeof(buffer), "Q<%s,%s,%s>", qname.name.c_str(),
                       dnstype2str(qtype), dnsclass2str(qclass));
    return std::string(buffer, len);
}

struct DNSRecord {
    DNSName name;
    unsigned rtype;
    unsigned rclass;
    unsigned ttl;
    std::vector<char> rdata;
    const char* read(const char* buffer, const char* buffer_end);
    char* write(char* buffer, const char* buffer_end) const;
    std::string toString() const;
private:
    struct IntFields {
        uint16_t rtype;
        uint16_t rclass;
        uint32_t ttl;
        uint16_t rdlen;
    } __attribute__((__packed__));

    const char* readIntFields(const char* buffer, const char* buffer_end,
            unsigned* rdlen);
    char* writeIntFields(unsigned rdlen, char* buffer,
                         const char* buffer_end) const;
};

const char* DNSRecord::read(const char* buffer, const char* buffer_end) {
    const char* cur = name.read(buffer, buffer_end);
    if (cur == nullptr) {
@@ -332,8 +284,8 @@ char* DNSRecord::write(char* buffer, const char* buffer_end) const {

std::string DNSRecord::toString() const {
    char buffer[4096];
    int len = snprintf(buffer, sizeof(buffer), "R<%s,%s,%s>", name.toString(),
                       dnstype2str(rtype), dnsclass2str(rclass));
    int len = snprintf(buffer, sizeof(buffer), "R<%s,%s,%s>", name.name.c_str(), dnstype2str(rtype),
                       dnsclass2str(rclass));
    return std::string(buffer, len);
}

@@ -365,47 +317,12 @@ char* DNSRecord::writeIntFields(unsigned rdlen, char* buffer,
    return buffer + sizeof(IntFields);
}

struct DNSHeader {
    unsigned id;
    bool ra;
    uint8_t rcode;
    bool qr;
    uint8_t opcode;
    bool aa;
    bool tr;
    bool rd;
    bool ad;
    std::vector<DNSQuestion> questions;
    std::vector<DNSRecord> answers;
    std::vector<DNSRecord> authorities;
    std::vector<DNSRecord> additionals;
    const char* read(const char* buffer, const char* buffer_end);
    char* write(char* buffer, const char* buffer_end) const;
    std::string toString() const;

private:
    struct Header {
        uint16_t id;
        uint8_t flags0;
        uint8_t flags1;
        uint16_t qdcount;
        uint16_t ancount;
        uint16_t nscount;
        uint16_t arcount;
    } __attribute__((__packed__));

    const char* readHeader(const char* buffer, const char* buffer_end,
                           unsigned* qdcount, unsigned* ancount,
                           unsigned* nscount, unsigned* arcount);
};

const char* DNSHeader::read(const char* buffer, const char* buffer_end) {
    unsigned qdcount;
    unsigned ancount;
    unsigned nscount;
    unsigned arcount;
    const char* cur = readHeader(buffer, buffer_end, &qdcount, &ancount,
                                 &nscount, &arcount);
    const char* cur = readHeader(buffer, buffer_end, &qdcount, &ancount, &nscount, &arcount);
    if (cur == nullptr) {
        LOG(ERROR) << "parsing failed at line " << __LINE__;
        return nullptr;
@@ -838,7 +755,7 @@ bool DNSResponder::addAnswerRecords(const DNSQuestion& question,
    return true;
}

bool DNSResponder::fillAnswerRdata(const std::string& rdatastr, DNSRecord& record) const {
bool DNSResponder::fillAnswerRdata(const std::string& rdatastr, DNSRecord& record) {
    if (record.rtype == ns_type::ns_t_a) {
        record.rdata.resize(4);
        if (inet_pton(AF_INET, rdatastr.c_str(), record.rdata.data()) != 1) {
+73 −5
Original line number Diff line number Diff line
@@ -33,9 +33,78 @@

namespace test {

struct DNSHeader;
struct DNSQuestion;
struct DNSRecord;
struct DNSName {
    std::string name;
    const char* read(const char* buffer, const char* buffer_end);
    char* write(char* buffer, const char* buffer_end) const;

  private:
    const char* parseField(const char* buffer, const char* buffer_end, bool* last);
};

struct DNSQuestion {
    DNSName qname;
    unsigned qtype;
    unsigned qclass;
    const char* read(const char* buffer, const char* buffer_end);
    char* write(char* buffer, const char* buffer_end) const;
    std::string toString() const;
};

struct DNSRecord {
    DNSName name;
    unsigned rtype;
    unsigned rclass;
    unsigned ttl;
    std::vector<char> rdata;
    const char* read(const char* buffer, const char* buffer_end);
    char* write(char* buffer, const char* buffer_end) const;
    std::string toString() const;

  private:
    struct IntFields {
        uint16_t rtype;
        uint16_t rclass;
        uint32_t ttl;
        uint16_t rdlen;
    } __attribute__((__packed__));

    const char* readIntFields(const char* buffer, const char* buffer_end, unsigned* rdlen);
    char* writeIntFields(unsigned rdlen, char* buffer, const char* buffer_end) const;
};

struct DNSHeader {
    unsigned id;
    bool ra;
    uint8_t rcode;
    bool qr;
    uint8_t opcode;
    bool aa;
    bool tr;
    bool rd;
    bool ad;
    std::vector<DNSQuestion> questions;
    std::vector<DNSRecord> answers;
    std::vector<DNSRecord> authorities;
    std::vector<DNSRecord> additionals;
    const char* read(const char* buffer, const char* buffer_end);
    char* write(char* buffer, const char* buffer_end) const;
    std::string toString() const;

  private:
    struct Header {
        uint16_t id;
        uint8_t flags0;
        uint8_t flags1;
        uint16_t qdcount;
        uint16_t ancount;
        uint16_t nscount;
        uint16_t arcount;
    } __attribute__((__packed__));

    const char* readHeader(const char* buffer, const char* buffer_end, unsigned* qdcount,
                           unsigned* ancount, unsigned* nscount, unsigned* arcount);
};

inline const std::string kDefaultListenAddr = "127.0.0.3";
inline const std::string kDefaultListenService = "53";
@@ -78,6 +147,7 @@ class DNSResponder {
    std::condition_variable& getCv() { return cv; }
    std::mutex& getCvMutex() { return cv_mutex_; }
    void setDeferredResp(bool deferred_resp);
    static bool fillAnswerRdata(const std::string& rdatastr, DNSRecord& record);

  private:
    // Key used for accessing mappings.
@@ -113,8 +183,6 @@ class DNSResponder {

    bool addAnswerRecords(const DNSQuestion& question, std::vector<DNSRecord>* answers) const;

    bool fillAnswerRdata(const std::string& rdatastr, DNSRecord& record) const;

    bool generateErrorResponse(DNSHeader* header, ns_rcode rcode,
                               char* response, size_t* response_len) const;
    bool makeErrorResponse(DNSHeader* header, ns_rcode rcode, char* response,
+58 −17
Original line number Diff line number Diff line
@@ -91,17 +91,17 @@ constexpr bool kDumpData = false;
 *     this will initialize the cache on first usage. the result can be NULL
 *     if the cache is disabled.
 *
 *   - the client calls _resolv_cache_lookup() before performing a query
 *   - the client calls resolv_cache_lookup() before performing a query
 *
 *     if the function returns RESOLV_CACHE_FOUND, a copy of the answer data
 *     has been copied into the client-provided answer buffer.
 *
 *     if the function returns RESOLV_CACHE_NOTFOUND, the client should perform
 *     a request normally, *then* call _resolv_cache_add() to add the received
 *     a request normally, *then* call resolv_cache_add() to add the received
 *     answer to the cache.
 *
 *     if the function returns RESOLV_CACHE_UNSUPPORTED, the client should
 *     perform a request normally, and *not* call _resolv_cache_add()
 *     perform a request normally, and *not* call resolv_cache_add()
 *
 *     note that RESOLV_CACHE_UNSUPPORTED is also returned if the answer buffer
 *     is too short to accomodate the cached result.
@@ -538,7 +538,7 @@ static int _dnsPacket_checkQuery(DnsPacket* packet) {
     * - there is no point for a query packet sent to a server
     *   to have the TC bit set, but the implementation might
     *   set the bit in the query buffer for its own needs
     *   between a _resolv_cache_lookup and a
     *   between a resolv_cache_lookup and a
     *   _resolv_cache_add. We should not freak out if this
     *   is the case.
     *
@@ -1407,9 +1407,8 @@ static void _cache_remove_expired(Cache* cache) {
// gets a resolv_cache_info associated with a network, or NULL if not found
static resolv_cache_info* find_cache_info_locked(unsigned netid) REQUIRES(cache_mutex);

ResolvCacheStatus _resolv_cache_lookup(unsigned netid, const void* query, int querylen,
                                       void* answer, int answersize, int* answerlen,
                                       uint32_t flags) {
ResolvCacheStatus resolv_cache_lookup(unsigned netid, const void* query, int querylen, void* answer,
                                      int answersize, int* answerlen, uint32_t flags) {
    // Skip cache lookup, return RESOLV_CACHE_NOTFOUND directly so that it is
    // possible to cache the answer of this query.
    // If ANDROID_RESOLV_NO_CACHE_STORE is set, return RESOLV_CACHE_SKIP to skip possible cache
@@ -1435,7 +1434,7 @@ ResolvCacheStatus _resolv_cache_lookup(unsigned netid, const void* query, int qu
    std::unique_lock lock(cache_mutex);
    ScopedAssumeLocked assume_lock(cache_mutex);
    cache = find_named_cache_locked(netid);
    if (cache == NULL) {
    if (cache == nullptr) {
        return RESOLV_CACHE_UNSUPPORTED;
    }

@@ -1512,7 +1511,7 @@ ResolvCacheStatus _resolv_cache_lookup(unsigned netid, const void* query, int qu
    return RESOLV_CACHE_FOUND;
}

void _resolv_cache_add(unsigned netid, const void* query, int querylen, const void* answer,
int resolv_cache_add(unsigned netid, const void* query, int querylen, const void* answer,
                     int answerlen) {
    Entry key[1];
    Entry* e;
@@ -1524,14 +1523,14 @@ void _resolv_cache_add(unsigned netid, const void* query, int querylen, const vo
     */
    if (!entry_init_key(key, query, querylen)) {
        LOG(INFO) << __func__ << ": passed invalid query?";
        return;
        return -EINVAL;
    }

    std::lock_guard guard(cache_mutex);

    cache = find_named_cache_locked(netid);
    if (cache == NULL) {
        return;
    if (cache == nullptr) {
        return -ENONET;
    }

    LOG(INFO) << __func__ << ": query:";
@@ -1549,7 +1548,7 @@ void _resolv_cache_add(unsigned netid, const void* query, int querylen, const vo
    if (e != NULL) {
        LOG(INFO) << __func__ << ": ALREADY IN CACHE (" << e << ") ? IGNORING ADD";
        _cache_notify_waiting_tid_locked(cache, key);
        return;
        return -EEXIST;
    }

    if (cache->num_entries >= cache->max_entries) {
@@ -1563,7 +1562,7 @@ void _resolv_cache_add(unsigned netid, const void* query, int querylen, const vo
        if (e != NULL) {
            LOG(INFO) << __func__ << ": ALREADY IN CACHE (" << e << ") ? IGNORING ADD";
            _cache_notify_waiting_tid_locked(cache, key);
            return;
            return -EEXIST;
        }
    }

@@ -1578,6 +1577,8 @@ void _resolv_cache_add(unsigned netid, const void* query, int querylen, const vo

    cache_dump_mru(cache);
    _cache_notify_waiting_tid_locked(cache, key);

    return 0;
}

// Head of the list of caches.
@@ -1676,8 +1677,8 @@ static void insert_cache_info_locked(struct resolv_cache_info* cache_info) {

static resolv_cache* find_named_cache_locked(unsigned netid) {
    resolv_cache_info* info = find_cache_info_locked(netid);
    if (info != NULL) return info->cache;
    return NULL;
    if (info != nullptr) return info->cache;
    return nullptr;
}

static resolv_cache_info* find_cache_info_locked(unsigned netid) {
@@ -2022,3 +2023,43 @@ void _resolv_cache_add_resolver_stats_sample(unsigned netid, int revision_id, in
        _res_cache_add_stats_sample_locked(&info->nsstats[ns], sample, max_samples);
    }
}

bool has_named_cache(unsigned netid) {
    std::lock_guard guard(cache_mutex);
    return find_named_cache_locked(netid) != nullptr;
}

int resolv_cache_get_expiration(unsigned netid, const std::vector<char> query, time_t* expiration) {
    Entry key;
    Entry** lookup;
    Entry* e;
    Cache* cache;
    *expiration = -1;

    // A malfored query is not allowed.
    if (!entry_init_key(&key, query.data(), query.size())) {
        LOG(WARNING) << __func__ << ": unsupported query";
        return -EINVAL;
    }

    // lookup cache.
    std::lock_guard guard(cache_mutex);
    if (cache = find_named_cache_locked(netid); cache == nullptr) {
        LOG(WARNING) << __func__ << ": cache not created in the network " << netid;
        return -ENONET;
    }
    lookup = _cache_lookup_p(cache, &key);
    e = *lookup;
    if (e == NULL) {
        LOG(WARNING) << __func__ << ": not in cache";
        return -ENODATA;
    }

    if (_time_now() >= e->expires) {
        LOG(WARNING) << __func__ << ": entry expired";
        return -ENODATA;
    }

    *expiration = e->expires;
    return 0;
}

res_cache_test.cpp

0 → 100644
+521 −0

File added.

Preview size limit exceeded, changes collapsed.

Loading