Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 595f3b09 authored by Android Build Coastguard Worker's avatar Android Build Coastguard Worker
Browse files

Snap for 8701348 from cb670aff to mainline-uwb-release

Change-Id: I5d6472d837c554d2edb4086218b24b419a18089f
parents 86b76c13 cb670aff
Loading
Loading
Loading
Loading
+123 −25
Original line number Original line Diff line number Diff line
@@ -469,7 +469,7 @@ void logDnsQueryResult(const addrinfo* res) {
                              NI_NUMERICHOST);
                              NI_NUMERICHOST);
        if (!ret) {
        if (!ret) {
            LOG(DEBUG) << __func__ << ": [" << i << "] " << ai->ai_flags << " " << ai->ai_family
            LOG(DEBUG) << __func__ << ": [" << i << "] " << ai->ai_flags << " " << ai->ai_family
                       << " " << ai->ai_socktype << " " << ai->ai_protocol;
                       << " " << ai->ai_socktype << " " << ai->ai_protocol << " " << ip_addr;
        } else {
        } else {
            LOG(DEBUG) << __func__ << ": [" << i << "] numeric hostname translation fail " << ret;
            LOG(DEBUG) << __func__ << ": [" << i << "] numeric hostname translation fail " << ret;
        }
        }
@@ -523,30 +523,96 @@ bool synthesizeNat64PrefixWithARecord(const netdutils::IPPrefix& prefix, struct
    return true;
    return true;
}
}


bool synthesizeNat64PrefixWithARecord(const netdutils::IPPrefix& prefix, addrinfo* result) {
bool synthesizeNat64PrefixWithARecord(const netdutils::IPPrefix& prefix, addrinfo** res,
    if (result == nullptr) return false;
                                      bool unspecWantedButNoIPv6,
    if (!onlyNonSpecialUseIPv4Addresses(result)) return false;
                                      const android_net_context* netcontext) {
    if (*res == nullptr) return false;
    if (!onlyNonSpecialUseIPv4Addresses(*res)) return false;
    if (!isValidNat64Prefix(prefix)) return false;
    if (!isValidNat64Prefix(prefix)) return false;


    struct sockaddr_storage ss = netdutils::IPSockAddr(prefix.ip());
    const sockaddr_storage ss = netdutils::IPSockAddr(prefix.ip());
    struct sockaddr_in6* v6prefix = (struct sockaddr_in6*)&ss;
    const sockaddr_in6* v6prefix = (sockaddr_in6*)&ss;
    for (addrinfo* ai = result; ai; ai = ai->ai_next) {
    addrinfo* const head4 = *res;
        struct sockaddr_in sinOriginal = *(struct sockaddr_in*)ai->ai_addr;
    addrinfo* head6 = nullptr;
        struct sockaddr_in6* sin6 = (struct sockaddr_in6*)ai->ai_addr;
    addrinfo* cur6 = nullptr;
        memset(sin6, 0, sizeof(sockaddr_in6));

    // Build a synthesized AAAA addrinfo list from the queried A addrinfo list. Here is the diagram
    // for the relationship of pointers.
    //
    // head4: point to the first queried A addrinfo
    // |
    // v
    // +-------------+   +-------------+
    // | addrinfo4#1 |-->| addrinfo4#2 |--> .. queried A addrinfo(s) for DNS64 synthesis
    // +-------------+   +-------------+
    //                   ^
    //                   |
    //                   cur4: current worked-on queried A addrinfo
    //
    // head6: point to the first synthesized AAAA addrinfo
    // |
    // v
    // +-------------+   +-------------+
    // | addrinfo6#1 |-->| addrinfo6#2 |--> .. synthesized DNS64 AAAA addrinfo(s)
    // +-------------+   +-------------+
    //                   ^
    //                   |
    //                   cur6: current worked-on synthesized addrinfo
    //
    for (const addrinfo* cur4 = head4; cur4; cur4 = cur4->ai_next) {
        // Allocate a space for a synthesized AAAA addrinfo. Note that the addrinfo and sockaddr
        // occupy one contiguous block of memory and are allocated and freed as a single block.
        // See get_ai and freeaddrinfo in packages/modules/DnsResolver/getaddrinfo.cpp.
        addrinfo* sa = (addrinfo*)calloc(1, sizeof(addrinfo) + sizeof(sockaddr_in6));
        if (sa == nullptr) {
            LOG(ERROR) << "allocate memory failed for synthesized result";
            freeaddrinfo(head6);
            return false;
        }


        // Synthesize /96 NAT64 prefix in place. The space has reserved by get_ai() in
        // Initialize the synthesized AAAA addrinfo by the queried A addrinfo. The ai_addr will be
        // system/netd/resolv/getaddrinfo.cpp.
        // set lately.
        sa->ai_flags = cur4->ai_flags;
        sa->ai_family = AF_INET6;
        sa->ai_socktype = cur4->ai_socktype;
        sa->ai_protocol = cur4->ai_protocol;
        sa->ai_addrlen = sizeof(sockaddr_in6);
        sa->ai_addr = (sockaddr*)(sa + 1);
        sa->ai_canonname = nullptr;
        sa->ai_next = nullptr;

        if (cur4->ai_canonname != nullptr) {
            sa->ai_canonname = strdup(cur4->ai_canonname);
            if (sa->ai_canonname == nullptr) {
                LOG(ERROR) << "allocate memory failed for canonname";
                freeaddrinfo(sa);
                freeaddrinfo(head6);
                return false;
            }
        }

        // Synthesize /96 NAT64 prefix with the queried IPv4 address.
        const sockaddr_in* sin4 = (sockaddr_in*)cur4->ai_addr;
        sockaddr_in6* sin6 = (sockaddr_in6*)sa->ai_addr;
        sin6->sin6_addr = v6prefix->sin6_addr;
        sin6->sin6_addr = v6prefix->sin6_addr;
        sin6->sin6_addr.s6_addr32[3] = sinOriginal.sin_addr.s_addr;
        sin6->sin6_addr.s6_addr32[3] = sin4->sin_addr.s_addr;
        sin6->sin6_family = AF_INET6;
        sin6->sin6_family = AF_INET6;
        sin6->sin6_port = sinOriginal.sin_port;
        sin6->sin6_port = sin4->sin_port;
        ai->ai_addrlen = sizeof(struct sockaddr_in6);

        ai->ai_family = AF_INET6;
        // If the synthesized list is empty, this becomes the first element.
        if (head6 == nullptr) {
            head6 = sa;
        }

        // Add this element to the end of the synthesized list.
        if (cur6 != nullptr) {
            cur6->ai_next = sa;
        }
        cur6 = sa;


        if (WOULD_LOG(VERBOSE)) {
        if (WOULD_LOG(VERBOSE)) {
            char buf[INET6_ADDRSTRLEN];  // big enough for either IPv4 or IPv6
            char buf[INET6_ADDRSTRLEN];  // big enough for either IPv4 or IPv6
            inet_ntop(AF_INET, &sinOriginal.sin_addr.s_addr, buf, sizeof(buf));
            inet_ntop(AF_INET, &sin4->sin_addr.s_addr, buf, sizeof(buf));
            LOG(VERBOSE) << __func__ << ": DNS A record: " << buf;
            LOG(VERBOSE) << __func__ << ": DNS A record: " << buf;
            inet_ntop(AF_INET6, &v6prefix->sin6_addr, buf, sizeof(buf));
            inet_ntop(AF_INET6, &v6prefix->sin6_addr, buf, sizeof(buf));
            LOG(VERBOSE) << __func__ << ": NAT64 prefix: " << buf;
            LOG(VERBOSE) << __func__ << ": NAT64 prefix: " << buf;
@@ -554,7 +620,39 @@ bool synthesizeNat64PrefixWithARecord(const netdutils::IPPrefix& prefix, addrinf
            LOG(VERBOSE) << __func__ << ": DNS64 Synthesized AAAA record: " << buf;
            LOG(VERBOSE) << __func__ << ": DNS64 Synthesized AAAA record: " << buf;
        }
        }
    }
    }
    logDnsQueryResult(result);

    // Simply concatenate the synthesized AAAA addrinfo list and the queried A addrinfo list when
    // AF_UNSPEC is specified. In the other words, the IPv6 addresses are listed first and then
    // IPv4 addresses. For example:
    //     64:ff9b::102:304 (socktype=2, protocol=17) ->
    //     64:ff9b::102:304 (socktype=1, protocol=6) ->
    //     1.2.3.4 (socktype=2, protocol=17) ->
    //     1.2.3.4 (socktype=1, protocol=6)
    // Note that head6 and cur6 should be non-null because there was at least one IPv4 address
    // synthesized. From the above example, the synthesized addrinfo list puts IPv6 and IPv4 in
    // groups and sort by RFC 6724 later. This ordering is different from no synthesized case
    // because resolv_getaddrinfo() sorts results in explore_options. resolv_getaddrinfo() calls
    // explore_fqdn() many times by the different items of explore_options. It means that
    // resolv_rfc6724_sort() only sorts the results in each explore_options and concatenates each
    // results into one. For example, getaddrinfo() is called with null hints for a domain name
    // which has both IPv4 and IPv6 addresses. The address order of the result addrinfo may be:
    //     2001:db8::102:304 (socktype=2, protocol=17) -> 1.2.3.4 (socktype=2, protocol=17) ->
    //     2001:db8::102:304 (socktype=1, protocol=6) -> 1.2.3.4 (socktype=1, protocol=6)
    // In above example, the first two results come from one explore option and the last two come
    // from another one. They are sorted first, and then concatenate together to be the result.
    // See also resolv_getaddrinfo in packages/modules/DnsResolver/getaddrinfo.cpp.
    if (unspecWantedButNoIPv6) {
        cur6->ai_next = head4;
    } else {
        freeaddrinfo(head4);
    }

    // Sort the concatenated addresses by RFC 6724 section 2.1.
    struct addrinfo sorting_head = {.ai_next = head6};
    resolv_rfc6724_sort(&sorting_head, netcontext->app_mark, netcontext->uid);

    *res = sorting_head.ai_next;
    logDnsQueryResult(*res);
    return true;
    return true;
}
}


@@ -713,7 +811,7 @@ void DnsProxyListener::GetAddrInfoHandler::doDns64Synthesis(int32_t* rv, addrinf
        }
        }
    }
    }


    if (!synthesizeNat64PrefixWithARecord(prefix, *res)) {
    if (!synthesizeNat64PrefixWithARecord(prefix, res, unspecWantedButNoIPv6, &mNetContext)) {
        if (ipv6WantedButNoData) {
        if (ipv6WantedButNoData) {
            // If caller wants IPv6 answers but no data and failed to synthesize IPv6 answers,
            // If caller wants IPv6 answers but no data and failed to synthesize IPv6 answers,
            // don't return the IPv4 answers.
            // don't return the IPv4 answers.
@@ -727,9 +825,9 @@ void DnsProxyListener::GetAddrInfoHandler::doDns64Synthesis(int32_t* rv, addrinf
}
}


void DnsProxyListener::GetAddrInfoHandler::run() {
void DnsProxyListener::GetAddrInfoHandler::run() {
    LOG(DEBUG) << "GetAddrInfoHandler::run: {" << mNetContext.app_netid << " "
    LOG(INFO) << "GetAddrInfoHandler::run: {" << mNetContext.app_netid << " "
               << mNetContext.app_mark << " " << mNetContext.dns_netid << " "
              << mNetContext.app_mark << " " << mNetContext.dns_netid << " " << mNetContext.dns_mark
               << mNetContext.dns_mark << " " << mNetContext.uid << " " << mNetContext.flags << "}";
              << " " << mNetContext.uid << " " << mNetContext.flags << "}";


    addrinfo* result = nullptr;
    addrinfo* result = nullptr;
    Stopwatch s;
    Stopwatch s;
@@ -902,9 +1000,9 @@ DnsProxyListener::ResNSendHandler::ResNSendHandler(SocketClient* c, std::string
    : Handler(c), mMsg(std::move(msg)), mFlags(flags), mNetContext(netcontext) {}
    : Handler(c), mMsg(std::move(msg)), mFlags(flags), mNetContext(netcontext) {}


void DnsProxyListener::ResNSendHandler::run() {
void DnsProxyListener::ResNSendHandler::run() {
    LOG(DEBUG) << "ResNSendHandler::run: " << mFlags << " / {" << mNetContext.app_netid << " "
    LOG(INFO) << "ResNSendHandler::run: " << mFlags << " / {" << mNetContext.app_netid << " "
               << mNetContext.app_mark << " " << mNetContext.dns_netid << " "
              << mNetContext.app_mark << " " << mNetContext.dns_netid << " " << mNetContext.dns_mark
               << mNetContext.dns_mark << " " << mNetContext.uid << " " << mNetContext.flags << "}";
              << " " << mNetContext.uid << " " << mNetContext.flags << "}";


    Stopwatch s;
    Stopwatch s;
    maybeFixupNetContext(&mNetContext, mClient->getPid());
    maybeFixupNetContext(&mNetContext, mClient->getPid());
+3 −2
Original line number Original line Diff line number Diff line
@@ -31,8 +31,9 @@
bool resolv_init(const ResolverNetdCallbacks* callbacks) {
bool resolv_init(const ResolverNetdCallbacks* callbacks) {
    android::base::InitLogging(/*argv=*/nullptr);
    android::base::InitLogging(/*argv=*/nullptr);
    LOG(INFO) << __func__ << ": Initializing resolver";
    LOG(INFO) << __func__ << ": Initializing resolver";
    resolv_set_log_severity(android::base::WARNING);
    const bool isDebug = isUserDebugBuild();
    doh_init_logger(DOH_LOG_LEVEL_WARN);
    resolv_set_log_severity(isDebug ? android::base::INFO : android::base::WARNING);
    doh_init_logger(isDebug ? DOH_LOG_LEVEL_INFO : DOH_LOG_LEVEL_WARN);
    using android::net::gApiLevel;
    using android::net::gApiLevel;
    gApiLevel = getApiLevel();
    gApiLevel = getApiLevel();
    using android::net::gResNetdCallbacks;
    using android::net::gResNetdCallbacks;
+47 −49
Original line number Original line Diff line number Diff line
@@ -67,8 +67,7 @@ std::list<DnsTlsServer> DnsTlsDispatcher::getOrderedAndUsableServerList(


        for (const auto& tlsServer : tlsServers) {
        for (const auto& tlsServer : tlsServers) {
            const Key key = std::make_pair(mark, tlsServer);
            const Key key = std::make_pair(mark, tlsServer);
            if (const Transport* xport = getTransport(key); xport != nullptr) {
            if (Transport* xport = getTransport(key); xport != nullptr) {
                // DoT revalidation specific feature.
                if (!xport->usable()) {
                if (!xport->usable()) {
                    // Don't use this xport. It will be removed after timeout
                    // Don't use this xport. It will be removed after timeout
                    // (IDLE_TIMEOUT minutes).
                    // (IDLE_TIMEOUT minutes).
@@ -112,7 +111,13 @@ DnsTlsTransport::Response DnsTlsDispatcher::query(const std::list<DnsTlsServer>&
    const std::list<DnsTlsServer> servers(
    const std::list<DnsTlsServer> servers(
            getOrderedAndUsableServerList(tlsServers, statp->netid, statp->mark));
            getOrderedAndUsableServerList(tlsServers, statp->netid, statp->mark));


    if (servers.empty()) LOG(WARNING) << "No usable DnsTlsServers";
    if (servers.empty()) {
        LOG(WARNING) << "No usable DnsTlsServers";

        // Call maybeCleanup so the expired Transports can be removed as expected.
        std::lock_guard guard(sLock);
        maybeCleanup(std::chrono::steady_clock::now());
    }


    DnsTlsTransport::Response code = DnsTlsTransport::Response::internal_error;
    DnsTlsTransport::Response code = DnsTlsTransport::Response::internal_error;
    int serverCount = 0;
    int serverCount = 0;
@@ -209,9 +214,14 @@ DnsTlsTransport::Response DnsTlsDispatcher::query(const DnsTlsServer& server, un
        std::lock_guard guard(sLock);
        std::lock_guard guard(sLock);
        --xport->useCount;
        --xport->useCount;
        xport->lastUsed = now;
        xport->lastUsed = now;
        if (code == DnsTlsTransport::Response::network_error) {
            xport->continuousfailureCount++;
        } else {
            xport->continuousfailureCount = 0;
        }


        // DoT revalidation specific feature.
        // DoT revalidation specific feature.
        if (xport->checkRevalidationNecessary(code)) {
        if (xport->checkRevalidationNecessary()) {
            // Even if the revalidation passes, it doesn't guarantee that DoT queries
            // Even if the revalidation passes, it doesn't guarantee that DoT queries
            // to the xport can stop failing because revalidation creates a new connection
            // to the xport can stop failing because revalidation creates a new connection
            // to probe while the xport still uses an existing connection. So far, there isn't
            // to probe while the xport still uses an existing connection. So far, there isn't
@@ -226,14 +236,14 @@ DnsTlsTransport::Response DnsTlsDispatcher::query(const DnsTlsServer& server, un
                         << (result.ok() ? "succeeded" : "failed: " + result.error().message());
                         << (result.ok() ? "succeeded" : "failed: " + result.error().message());
        }
        }


        cleanup(now);
        maybeCleanup(now);
    }
    }
    return code;
    return code;
}
}


void DnsTlsDispatcher::forceCleanup(unsigned netId) {
void DnsTlsDispatcher::forceCleanup(unsigned netId) {
    std::lock_guard guard(sLock);
    std::lock_guard guard(sLock);
    forceCleanupLocked(netId);
    cleanup(std::chrono::steady_clock::now(), netId);
}
}


DnsTlsTransport::Result DnsTlsDispatcher::queryInternal(Transport& xport,
DnsTlsTransport::Result DnsTlsDispatcher::queryInternal(Transport& xport,
@@ -265,33 +275,26 @@ DnsTlsTransport::Result DnsTlsDispatcher::queryInternal(Transport& xport,


// This timeout effectively controls how long to keep SSL session tickets.
// This timeout effectively controls how long to keep SSL session tickets.
static constexpr std::chrono::minutes IDLE_TIMEOUT(5);
static constexpr std::chrono::minutes IDLE_TIMEOUT(5);
void DnsTlsDispatcher::cleanup(std::chrono::time_point<std::chrono::steady_clock> now) {
void DnsTlsDispatcher::maybeCleanup(std::chrono::time_point<std::chrono::steady_clock> now) {
    // To avoid scanning mStore after every query, return early if a cleanup has been
    // To avoid scanning mStore after every query, return early if a cleanup has been
    // performed recently.
    // performed recently.
    if (now - mLastCleanup < IDLE_TIMEOUT) {
    if (now - mLastCleanup < IDLE_TIMEOUT) {
        return;
        return;
    }
    }
    for (auto it = mStore.begin(); it != mStore.end();) {
    cleanup(now, std::nullopt);
        auto& s = it->second;
        if (s->useCount == 0 && now - s->lastUsed > IDLE_TIMEOUT) {
            it = mStore.erase(it);
        } else {
            ++it;
        }
    }
    mLastCleanup = now;
    mLastCleanup = now;
}
}


// TODO: unify forceCleanupLocked() and cleanup().
void DnsTlsDispatcher::cleanup(std::chrono::time_point<std::chrono::steady_clock> now,
void DnsTlsDispatcher::forceCleanupLocked(unsigned netId) {
                               std::optional<unsigned> netId) {
    for (auto it = mStore.begin(); it != mStore.end();) {
    std::erase_if(mStore, [&](const auto& item) REQUIRES(sLock) {
        auto& s = it->second;
        auto const& [_, xport] = item;
        if (s->useCount == 0 && s->mNetId == netId) {
        if (xport->useCount == 0) {
            it = mStore.erase(it);
            if (netId.has_value() && xport->mNetId == netId.value()) return true;
        } else {
            if (now - xport->lastUsed > IDLE_TIMEOUT) return true;
            ++it;
        }
        }
        }
        return false;
    });
}
}


DnsTlsDispatcher::Transport* DnsTlsDispatcher::addTransport(const DnsTlsServer& server,
DnsTlsDispatcher::Transport* DnsTlsDispatcher::addTransport(const DnsTlsServer& server,
@@ -308,12 +311,11 @@ DnsTlsDispatcher::Transport* DnsTlsDispatcher::addTransport(const DnsTlsServer&
    int queryTimeout = instance->getFlag("dot_query_timeout_ms", Transport::kDotQueryTimeoutMs);
    int queryTimeout = instance->getFlag("dot_query_timeout_ms", Transport::kDotQueryTimeoutMs);


    // Check and adjust the parameters if they are improperly set.
    // Check and adjust the parameters if they are improperly set.
    bool revalidationEnabled = false;
    const bool isForOpportunisticMode = server.name.empty();
    const bool isForOpportunisticMode = server.name.empty();
    if (triggerThr > 0 && unusableThr > 0 && isForOpportunisticMode) {
    if (triggerThr <= 0 || !isForOpportunisticMode) {
        revalidationEnabled = true;
    } else {
        triggerThr = -1;
        triggerThr = -1;
    }
    if (unusableThr <= 0 || !isForOpportunisticMode) {
        unusableThr = -1;
        unusableThr = -1;
    }
    }
    if (queryTimeout < 0) {
    if (queryTimeout < 0) {
@@ -322,9 +324,8 @@ DnsTlsDispatcher::Transport* DnsTlsDispatcher::addTransport(const DnsTlsServer&
        queryTimeout = 1000;
        queryTimeout = 1000;
    }
    }


    ret = new Transport(server, mark, netId, mFactory.get(), revalidationEnabled, triggerThr,
    ret = new Transport(server, mark, netId, mFactory.get(), triggerThr, unusableThr, queryTimeout);
                        unusableThr, queryTimeout);
    LOG(INFO) << "Transport is initialized with { " << triggerThr << ", " << unusableThr << ", "
    LOG(DEBUG) << "Transport is initialized with { " << triggerThr << ", " << unusableThr << ", "
              << queryTimeout << "ms }"
              << queryTimeout << "ms }"
              << " for server { " << server.toIpString() << "/" << server.name << " }";
              << " for server { " << server.toIpString() << "/" << server.name << " }";


@@ -338,26 +339,23 @@ DnsTlsDispatcher::Transport* DnsTlsDispatcher::getTransport(const Key& key) {
    return (it == mStore.end() ? nullptr : it->second.get());
    return (it == mStore.end() ? nullptr : it->second.get());
}
}


bool DnsTlsDispatcher::Transport::checkRevalidationNecessary(DnsTlsTransport::Response code) {
bool DnsTlsDispatcher::Transport::checkRevalidationNecessary() {
    if (!revalidationEnabled) return false;
    if (triggerThreshold <= 0) return false;
    if (continuousfailureCount < triggerThreshold) return false;
    if (isRevalidationThresholdReached) return false;


    if (code == DnsTlsTransport::Response::network_error) {
    isRevalidationThresholdReached = true;
        continuousfailureCount++;
    } else {
        continuousfailureCount = 0;
    }

    // triggerThreshold must be greater than 0 because the value of revalidationEnabled is true.
    if (usable() && continuousfailureCount == triggerThreshold) {
    return true;
    return true;
}
}
    return false;
}


bool DnsTlsDispatcher::Transport::usable() const {
bool DnsTlsDispatcher::Transport::usable() {
    if (!revalidationEnabled) return true;
    if (unusableThreshold <= 0) return true;


    return continuousfailureCount < unusableThreshold;
    if (continuousfailureCount >= unusableThreshold) {
        // Once reach the threshold, mark this Transport as unusable.
        isXportUnusableThresholdReached = true;
    }
    return !isXportUnusableThresholdReached;
}
}


}  // end of namespace net
}  // end of namespace net
+24 −23
Original line number Original line Diff line number Diff line
@@ -83,11 +83,10 @@ class DnsTlsDispatcher : public PrivateDnsValidationObserver {
    // usage monitoring so we can expire idle sessions from the cache.
    // usage monitoring so we can expire idle sessions from the cache.
    struct Transport {
    struct Transport {
        Transport(const DnsTlsServer& server, unsigned mark, unsigned netId,
        Transport(const DnsTlsServer& server, unsigned mark, unsigned netId,
                  IDnsTlsSocketFactory* _Nonnull factory, bool revalidationEnabled, int triggerThr,
                  IDnsTlsSocketFactory* _Nonnull factory, int triggerThr, int unusableThr,
                  int unusableThr, int timeout)
                  int timeout)
            : transport(server, mark, factory),
            : transport(server, mark, factory),
              mNetId(netId),
              mNetId(netId),
              revalidationEnabled(revalidationEnabled),
              triggerThreshold(triggerThr),
              triggerThreshold(triggerThr),
              unusableThreshold(unusableThr),
              unusableThreshold(unusableThr),
              mTimeout(timeout) {}
              mTimeout(timeout) {}
@@ -106,9 +105,12 @@ class DnsTlsDispatcher : public PrivateDnsValidationObserver {


        // If DoT revalidation is disabled, it returns true; otherwise, it returns
        // If DoT revalidation is disabled, it returns true; otherwise, it returns
        // whether or not this Transport is usable.
        // whether or not this Transport is usable.
        bool usable() const REQUIRES(sLock);
        bool usable() REQUIRES(sLock);


        bool checkRevalidationNecessary(DnsTlsTransport::Response code) REQUIRES(sLock);
        // Used to track if this Transport is usable.
        int continuousfailureCount GUARDED_BY(sLock) = 0;

        bool checkRevalidationNecessary() REQUIRES(sLock);


        std::chrono::milliseconds timeout() const { return mTimeout; }
        std::chrono::milliseconds timeout() const { return mTimeout; }


@@ -117,25 +119,24 @@ class DnsTlsDispatcher : public PrivateDnsValidationObserver {
        static constexpr int kDotQueryTimeoutMs = -1;
        static constexpr int kDotQueryTimeoutMs = -1;


      private:
      private:
        // Used to track if this Transport is usable.
        // The flag to record whether or not dot_revalidation_threshold is ever reached.
        int continuousfailureCount GUARDED_BY(sLock) = 0;
        bool isRevalidationThresholdReached GUARDED_BY(sLock) = false;


        // Used to indicate whether DoT revalidation is enabled for this Transport.
        // The flag to record whether or not dot_xport_unusable_threshold is ever reached.
        // The value is set to true only if:
        bool isXportUnusableThresholdReached GUARDED_BY(sLock) = false;
        //    1. both triggerThreshold and unusableThreshold are  positive values.
        //    2. private DNS mode is opportunistic.
        const bool revalidationEnabled;


        // The number of continuous failures to trigger a validation. It takes effect when DoT
        // If the number of continuous query timeouts reaches the threshold, mark the
        // revalidation is on. If the value is not a positive value, DoT revalidation is disabled.
        // server as unvalidated and trigger a validation.
        // Note that it must be at least 10, or it breaks ConnectTlsServerTimeout_ConcurrentQueries
        // If the value is not a positive value or private DNS mode is strict mode, no threshold is
        // test.
        // set. Note that it must be at least 10, or it breaks
        // ConnectTlsServerTimeout_ConcurrentQueries test.
        const int triggerThreshold;
        const int triggerThreshold;


        // The threshold to determine if this Transport is considered unusable.
        // The threshold to determine if this Transport is considered unusable.
        // If continuousfailureCount reaches this value, this Transport is no longer used. It
        // If the number of continuous query timeouts reaches the threshold, mark this
        // takes effect when DoT revalidation is on. If the value is not a positive value, DoT
        // Transport as unusable. An unusable Transport won't be used anymore.
        // revalidation is disabled.
        // If the value is not a positive value or private DNS mode is strict mode, no threshold is
        // set.
        const int unusableThreshold;
        const int unusableThreshold;


        // The time to await a future (the result of a DNS request) from the DnsTlsTransport
        // The time to await a future (the result of a DNS request) from the DnsTlsTransport
@@ -159,12 +160,12 @@ class DnsTlsDispatcher : public PrivateDnsValidationObserver {
    DnsTlsTransport::Result queryInternal(Transport& transport, const netdutils::Slice query)
    DnsTlsTransport::Result queryInternal(Transport& transport, const netdutils::Slice query)
            EXCLUDES(sLock);
            EXCLUDES(sLock);


    void maybeCleanup(std::chrono::time_point<std::chrono::steady_clock> now) REQUIRES(sLock);

    // Drop any cache entries whose useCount is zero and which have not been used recently.
    // Drop any cache entries whose useCount is zero and which have not been used recently.
    // This function performs a linear scan of mStore.
    // This function performs a linear scan of mStore.
    void cleanup(std::chrono::time_point<std::chrono::steady_clock> now) REQUIRES(sLock);
    void cleanup(std::chrono::time_point<std::chrono::steady_clock> now,

                 std::optional<unsigned> netId) REQUIRES(sLock);
    // Force dropping any Transports whose useCount is zero.
    void forceCleanupLocked(unsigned netId) REQUIRES(sLock);


    // Return a sorted list of usable DnsTlsServers in preference order.
    // Return a sorted list of usable DnsTlsServers in preference order.
    std::list<DnsTlsServer> getOrderedAndUsableServerList(const std::list<DnsTlsServer>& tlsServers,
    std::list<DnsTlsServer> getOrderedAndUsableServerList(const std::list<DnsTlsServer>& tlsServers,
+11 −5
Original line number Original line Diff line number Diff line
@@ -145,6 +145,9 @@ void PrivateDnsConfiguration::clear(unsigned netId) {
    std::lock_guard guard(mPrivateDnsLock);
    std::lock_guard guard(mPrivateDnsLock);
    mPrivateDnsModes.erase(netId);
    mPrivateDnsModes.erase(netId);
    mPrivateDnsTransports.erase(netId);
    mPrivateDnsTransports.erase(netId);

    // Notify the relevant private DNS validations, if they are waiting, to finish.
    mCv.notify_all();
}
}


base::Result<void> PrivateDnsConfiguration::requestValidation(unsigned netId,
base::Result<void> PrivateDnsConfiguration::requestValidation(unsigned netId,
@@ -224,14 +227,17 @@ void PrivateDnsConfiguration::startValidation(const ServerIdentity& identity, un
            const bool needs_reeval =
            const bool needs_reeval =
                    this->recordPrivateDnsValidation(identity, netId, success, isRevalidation);
                    this->recordPrivateDnsValidation(identity, netId, success, isRevalidation);


            if (!needs_reeval) {
            if (!needs_reeval || !backoff.hasNextTimeout()) {
                break;
                break;
            }
            }


            if (backoff.hasNextTimeout()) {
            std::unique_lock<std::mutex> cvGuard(mPrivateDnsLock);
                // TODO: make the thread able to receive signals to shutdown early.
            // If the timeout expired and the predicate still evaluates to false, wait_for returns
                std::this_thread::sleep_for(backoff.getNextTimeout());
            // false.
            } else {
            if (mCv.wait_for(cvGuard, backoff.getNextTimeout(),
                             [this, netId]() REQUIRES(mPrivateDnsLock) {
                                 return mPrivateDnsModes.find(netId) == mPrivateDnsModes.end();
                             })) {
                break;
                break;
            }
            }
        }
        }
Loading