Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 82ae84b9 authored by Mike Yu's avatar Mike Yu
Browse files

Implement DoT revalidation

The revalidation starts from DnsTlsDispatcher, which uses a counter
to count the number of continuous network_error failures of a
DoT server. The mechanism works for private DNS opportunistic mode.

- Once the counter reaches dot_revalidation_threshold, DnsTlsDispatcher
  sends a revalidation request to PrivateDnsConfiguration to validate
  the DoT server.
- Once the counter reaches dot_xport_unusable_threshold, DnsTlsDispatcher
  marks the transport of the DoT server as unusable. The DoT server
  won't be used for at least 5 minutes.

DoT revalidation runs when the following conditions are met:
  [1] the private DNS setting is opportunistic mode
  [2] the requested DoT server is valid to be used on the network
  [3] the requested DoT server is currently marked as Validation::success

The above mechanism runs when the feature flag "dot_revalidation_threshold"
is a positive, non-zero value; the flag is -1 when the mechanism is
disabled.

Bug: 79727473
Test: atest when all the flags off
        dot_revalidation_threshold: -1
        dot_async_handshake: 0
        dot_xport_unusable_threshold: -1
        dot_maxtries: 3
        parallel_lookup_sleep_time: 2
        dot_connect_timeout_ms: 127000
        parallel_lookup_release: 0
        sort_nameservers: 0
        keep_listening_udp: 0

Test: atest when all the flags on
        dot_revalidation_threshold: 10
        dot_async_handshake: 1
        dot_xport_unusable_threshold: 20
        dot_maxtries: 1
        parallel_lookup_sleep_time: 2
        dot_connect_timeout_ms: 10000
        parallel_lookup_release: 1
        sort_nameservers: 1
        keep_listening_udp: 1

Change-Id: Id442529468d63156a9aebf30ea5f142dfa689a97
parent 9310ca22
Loading
Loading
Loading
Loading
+100 −16
Original line number Diff line number Diff line
@@ -21,6 +21,8 @@
#include <netdutils/Stopwatch.h>

#include "DnsTlsSocketFactory.h"
#include "Experiments.h"
#include "PrivateDnsConfiguration.h"
#include "resolv_cache.h"
#include "resolv_private.h"
#include "stats.pb.h"
@@ -46,8 +48,8 @@ DnsTlsDispatcher& DnsTlsDispatcher::getInstance() {
    return instance;
}

std::list<DnsTlsServer> DnsTlsDispatcher::getOrderedServerList(
        const std::list<DnsTlsServer> &tlsServers, unsigned mark) const {
std::list<DnsTlsServer> DnsTlsDispatcher::getOrderedAndUsableServerList(
        const std::list<DnsTlsServer>& tlsServers, unsigned netId, unsigned mark) {
    // Our preferred DnsTlsServer order is:
    //     1) reuse existing IPv6 connections
    //     2) reuse existing IPv4 connections
@@ -65,7 +67,16 @@ std::list<DnsTlsServer> DnsTlsDispatcher::getOrderedServerList(

        for (const auto& tlsServer : tlsServers) {
            const Key key = std::make_pair(mark, tlsServer);
            if (mStore.find(key) != mStore.end()) {
            if (const Transport* xport = getTransport(key); xport != nullptr) {
                // DoT revalidation specific feature.
                if (!xport->usable()) {
                    // Don't use this xport. It will be removed after timeout
                    // (IDLE_TIMEOUT minutes).
                    LOG(DEBUG) << "Skip using DoT server " << tlsServer.toIpString() << " on "
                               << netId;
                    continue;
                }

                switch (tlsServer.ss.ss_family) {
                    case AF_INET:
                        existing4.push_back(tlsServer);
@@ -97,19 +108,21 @@ std::list<DnsTlsServer> DnsTlsDispatcher::getOrderedServerList(
DnsTlsTransport::Response DnsTlsDispatcher::query(const std::list<DnsTlsServer>& tlsServers,
                                                  res_state statp, const Slice query,
                                                  const Slice ans, int* resplen) {
    const std::list<DnsTlsServer> orderedServers(getOrderedServerList(tlsServers, statp->_mark));
    const std::list<DnsTlsServer> servers(
            getOrderedAndUsableServerList(tlsServers, statp->netid, statp->_mark));

    if (orderedServers.empty()) LOG(WARNING) << "Empty DnsTlsServer list";
    if (servers.empty()) LOG(WARNING) << "No usable DnsTlsServers";

    DnsTlsTransport::Response code = DnsTlsTransport::Response::internal_error;
    int serverCount = 0;
    for (const auto& server : orderedServers) {
    for (const auto& server : servers) {
        DnsQueryEvent* dnsQueryEvent =
                statp->event->mutable_dns_query_events()->add_dns_query_event();

        bool connectTriggered = false;
        Stopwatch queryStopwatch;
        code = this->query(server, statp->_mark, query, ans, resplen, &connectTriggered);
        code = this->query(server, statp->netid, statp->_mark, query, ans, resplen,
                           &connectTriggered);

        dnsQueryEvent->set_latency_micros(saturate_cast<int32_t>(queryStopwatch.timeTakenUs()));
        dnsQueryEvent->set_dns_server_index(serverCount++);
@@ -148,9 +161,9 @@ DnsTlsTransport::Response DnsTlsDispatcher::query(const std::list<DnsTlsServer>&
    return code;
}

DnsTlsTransport::Response DnsTlsDispatcher::query(const DnsTlsServer& server, unsigned mark,
                                                  const Slice query, const Slice ans, int* resplen,
                                                  bool* connectTriggered) {
DnsTlsTransport::Response DnsTlsDispatcher::query(const DnsTlsServer& server, unsigned netId,
                                                  unsigned mark, const Slice query, const Slice ans,
                                                  int* resplen, bool* connectTriggered) {
    // TODO: This can cause the resolver to create multiple connections to the same DoT server
    // merely due to different mark, such as the bit explicitlySelected unset.
    // See if we can save them and just create one connection for one DoT server.
@@ -158,12 +171,8 @@ DnsTlsTransport::Response DnsTlsDispatcher::query(const DnsTlsServer& server, un
    Transport* xport;
    {
        std::lock_guard guard(sLock);
        auto it = mStore.find(key);
        if (it == mStore.end()) {
            xport = new Transport(server, mark, mFactory.get());
            mStore[key].reset(xport);
        } else {
            xport = it->second.get();
        if (xport = getTransport(key); xport == nullptr) {
            xport = addTransport(server, mark);
        }
        ++xport->useCount;
    }
@@ -198,6 +207,23 @@ DnsTlsTransport::Response DnsTlsDispatcher::query(const DnsTlsServer& server, un
        std::lock_guard guard(sLock);
        --xport->useCount;
        xport->lastUsed = now;

        // DoT revalidation specific feature.
        if (xport->checkRevalidationNecessary(code)) {
            // Even if the revalidation passes, it doesn't guarantee that DoT queries
            // to the xport can stop failing because revalidation creates a new connection
            // to probe while the xport still uses an existing connection. So far, there isn't
            // a feasible way to force the xport to disconnect the connection. If the case
            // happens, the xport will be marked as unusable and DoT queries won't be sent to
            // it anymore. Eventually, after IDLE_TIMEOUT, the xport will be destroyed, and
            // a new xport will be created.
            const auto result =
                    PrivateDnsConfiguration::getInstance().requestValidation(netId, server, mark);
            LOG(WARNING) << "Requested validation for " << server.toIpString() << " with mark 0x"
                         << std::hex << mark << ", "
                         << (result.ok() ? "succeeded" : "failed: " + result.error().message());
        }

        cleanup(now);
    }
    return code;
@@ -222,5 +248,63 @@ void DnsTlsDispatcher::cleanup(std::chrono::time_point<std::chrono::steady_clock
    mLastCleanup = now;
}

// Returns the Transport cached in mStore for the (mark, server) key, creating and caching a new
// one if none exists yet. A new Transport is configured from the "dot_revalidation_threshold"
// and "dot_xport_unusable_threshold" experiment flags.
DnsTlsDispatcher::Transport* DnsTlsDispatcher::addTransport(const DnsTlsServer& server,
                                                            unsigned mark) {
    const Key key = std::make_pair(mark, server);
    if (Transport* const existing = getTransport(key); existing != nullptr) {
        return existing;
    }

    const Experiments* const experiments = Experiments::getInstance();
    int revalidationThr = experiments->getFlag("dot_revalidation_threshold",
                                               Transport::kDotRevalidationThreshold);
    int xportUnusableThr = experiments->getFlag("dot_xport_unusable_threshold",
                                                Transport::kDotXportUnusableThreshold);

    // DoT revalidation is only turned on when both thresholds are positive and the server has
    // no hostname (treated here as the opportunistic-mode case). In every other case both
    // thresholds are normalized to -1 so that improperly set flags cannot take effect.
    const bool enableRevalidation =
            revalidationThr > 0 && xportUnusableThr > 0 && server.name.empty();
    if (!enableRevalidation) {
        revalidationThr = -1;
        xportUnusableThr = -1;
    }

    Transport* const created = new Transport(server, mark, mFactory.get(), enableRevalidation,
                                             revalidationThr, xportUnusableThr);
    LOG(DEBUG) << "Transport is initialized with { " << revalidationThr << ", " << xportUnusableThr
               << "}"
               << " for server { " << server.toIpString() << "/" << server.name << " }";

    // mStore takes ownership of the newly created Transport.
    mStore[key].reset(created);

    return created;
}

// Looks up |key| in mStore. Returns the cached Transport, or nullptr if there is no entry.
DnsTlsDispatcher::Transport* DnsTlsDispatcher::getTransport(const Key& key) {
    const auto entry = mStore.find(key);
    if (entry == mStore.end()) return nullptr;
    return entry->second.get();
}

// Updates the continuous-failure counter with the result |code| of the latest query and reports
// whether a revalidation should be requested now. Always returns false when revalidation is
// disabled for this Transport.
bool DnsTlsDispatcher::Transport::checkRevalidationNecessary(DnsTlsTransport::Response code) {
    if (!revalidationEnabled) return false;

    // Only network_error extends the failure streak; any other response resets it.
    const bool isNetworkError = (code == DnsTlsTransport::Response::network_error);
    continuousfailureCount = isNetworkError ? continuousfailureCount + 1 : 0;

    // revalidationEnabled being true implies triggerThreshold is greater than 0, so a freshly
    // reset counter can never match it. The request fires only at the moment the streak hits
    // the threshold exactly, and only while this Transport is still usable.
    return usable() && continuousfailureCount == triggerThreshold;
}

// A Transport is always usable when revalidation is disabled. Otherwise it is usable only while
// the continuous-failure count is below unusableThreshold.
bool DnsTlsDispatcher::Transport::usable() const {
    return !revalidationEnabled || continuousfailureCount < unusableThreshold;
}

}  // end of namespace net
}  // end of namespace android
+45 −7
Original line number Diff line number Diff line
@@ -36,6 +36,7 @@ namespace net {

// This is a singleton class that manages the collection of active DnsTlsTransports.
// Queries made here are dispatched to an existing or newly constructed DnsTlsTransport.
// TODO: PrivateDnsValidationObserver is not implemented in this class. Remove it.
class DnsTlsDispatcher : public PrivateDnsValidationObserver {
  public:
    // Constructor with dependency injection for testing.
@@ -57,7 +58,7 @@ class DnsTlsDispatcher : public PrivateDnsValidationObserver {
    // and writes the response into |ans|, and indicates the number of bytes written in |resplen|.
    // If the whole procedure above triggers (or experiences) any new connection, |connectTriggered|
    // is set. Returns a success or error code.
    DnsTlsTransport::Response query(const DnsTlsServer& server, unsigned mark,
    DnsTlsTransport::Response query(const DnsTlsServer& server, unsigned netId, unsigned mark,
                                    const netdutils::Slice query, const netdutils::Slice ans,
                                    int* _Nonnull resplen, bool* _Nonnull connectTriggered);

@@ -78,8 +79,12 @@ class DnsTlsDispatcher : public PrivateDnsValidationObserver {
    // Transport is a thin wrapper around DnsTlsTransport, adding reference counting and
    // usage monitoring so we can expire idle sessions from the cache.
    struct Transport {
        Transport(const DnsTlsServer& server, unsigned mark, IDnsTlsSocketFactory* _Nonnull factory)
            : transport(server, mark, factory) {}
        Transport(const DnsTlsServer& server, unsigned mark, IDnsTlsSocketFactory* _Nonnull factory,
                  bool revalidationEnabled, int triggerThr, int unusableThr)
            : transport(server, mark, factory),
              revalidationEnabled(revalidationEnabled),
              triggerThreshold(triggerThr),
              unusableThreshold(unusableThr) {}
        // DnsTlsTransport is thread-safe, so it doesn't need to be guarded.
        DnsTlsTransport transport;
        // This use counter and timestamp are used to ensure that only idle sessions are
@@ -87,11 +92,44 @@ class DnsTlsDispatcher : public PrivateDnsValidationObserver {
        int useCount GUARDED_BY(sLock) = 0;
        // lastUsed is only guaranteed to be meaningful after useCount is decremented to zero.
        std::chrono::time_point<std::chrono::steady_clock> lastUsed GUARDED_BY(sLock);

        // If DoT revalidation is disabled, it returns true; otherwise, it returns
        // whether or not this Transport is usable.
        bool usable() const REQUIRES(sLock);

        bool checkRevalidationNecessary(DnsTlsTransport::Response code) REQUIRES(sLock);

        static constexpr int kDotRevalidationThreshold = -1;
        static constexpr int kDotXportUnusableThreshold = -1;

      private:
        // Used to track if this Transport is usable.
        int continuousfailureCount GUARDED_BY(sLock) = 0;

        // Used to indicate whether DoT revalidation is enabled for this Transport.
        // The value is set to true only if:
        //    1. both triggerThreshold and unusableThreshold are  positive values.
        //    2. private DNS mode is opportunistic.
        const bool revalidationEnabled;

        // The number of continuous failures to trigger a validation. It takes effect when DoT
        // revalidation is on. If the value is not a positive value, DoT revalidation is disabled.
        // Note that it must be at least 10, or it breaks ConnectTlsServerTimeout_ConcurrentQueries
        // test.
        const int triggerThreshold;

        // The threshold to determine if this Transport is considered unusable.
        // If continuousfailureCount reaches this value, this Transport is no longer used. It
        // takes effect when DoT revalidation is on. If the value is not a positive value, DoT
        // revalidation is disabled.
        const int unusableThreshold;
    };

    Transport* _Nullable addTransport(const DnsTlsServer& server, unsigned mark) REQUIRES(sLock);
    Transport* _Nullable getTransport(const Key& key) REQUIRES(sLock);

    // Cache of reusable DnsTlsTransports.  Transports stay in cache as long as
    // they are in use and for a few minutes after.
    // The key is a (mark, server) pair.  The mark is first for lexicographic comparison speed.
    std::map<Key, std::unique_ptr<Transport>> mStore GUARDED_BY(sLock);

    // The last time we did a cleanup.  For efficiency, we only perform a cleanup once every
@@ -102,9 +140,9 @@ class DnsTlsDispatcher : public PrivateDnsValidationObserver {
    // This function performs a linear scan of mStore.
    void cleanup(std::chrono::time_point<std::chrono::steady_clock> now) REQUIRES(sLock);

    // Return a sorted list of DnsTlsServers in preference order.
    std::list<DnsTlsServer> getOrderedServerList(const std::list<DnsTlsServer>& tlsServers,
                                                 unsigned mark) const;
    // Return a sorted list of usable DnsTlsServers in preference order.
    std::list<DnsTlsServer> getOrderedAndUsableServerList(const std::list<DnsTlsServer>& tlsServers,
                                                          unsigned netId, unsigned mark);

    // Trivial factory for DnsTlsSockets.  Dependency injection is only used for testing.
    std::unique_ptr<IDnsTlsSocketFactory> mFactory;
+3 −3
Original line number Diff line number Diff line
@@ -51,7 +51,7 @@ class Experiments {
    static constexpr const char* const kExperimentFlagKeyList[] = {
            "keep_listening_udp", "parallel_lookup_release",    "parallel_lookup_sleep_time",
            "sort_nameservers",   "dot_async_handshake",        "dot_connect_timeout_ms",
            "dot_maxtries",
            "dot_maxtries",       "dot_revalidation_threshold", "dot_xport_unusable_threshold",
    };
    // This value is used in updateInternal as the default value if any flags can't be found.
    static constexpr int kFlagIntDefault = INT_MIN;
+41 −20
Original line number Diff line number Diff line
@@ -110,14 +110,14 @@ int PrivateDnsConfiguration::set(int32_t netId, uint32_t mark,

        if (needsValidation(server)) {
            updateServerState(identity, Validation::in_process, netId);
            startValidation(server, netId);
            startValidation(server, netId, false);
        }
    }

    return 0;
}

PrivateDnsStatus PrivateDnsConfiguration::getStatus(unsigned netId) {
PrivateDnsStatus PrivateDnsConfiguration::getStatus(unsigned netId) const {
    PrivateDnsStatus status{PrivateDnsMode::OFF, {}};
    std::lock_guard guard(mPrivateDnsLock);

@@ -144,41 +144,55 @@ void PrivateDnsConfiguration::clear(unsigned netId) {
    mPrivateDnsTransports.erase(netId);
}

bool PrivateDnsConfiguration::requestValidation(unsigned netId, const DnsTlsServer& server,
base::Result<void> PrivateDnsConfiguration::requestValidation(unsigned netId,
                                                              const DnsTlsServer& server,
                                                              uint32_t mark) {
    std::lock_guard guard(mPrivateDnsLock);

    // Running revalidation requires to mark the server as in_process, which means the server
    // won't be used until the validation passes. It's necessary and safe to run revalidation
    // when in private DNS opportunistic mode, because there's a fallback mechanics even if
    // all of the private DNS servers are in in_process state.
    if (auto it = mPrivateDnsModes.find(netId); it == mPrivateDnsModes.end()) {
        return Errorf("NetId not found in mPrivateDnsModes");
    } else if (it->second != PrivateDnsMode::OPPORTUNISTIC) {
        return Errorf("Private DNS setting is not opportunistic mode");
    }

    auto netPair = mPrivateDnsTransports.find(netId);
    if (netPair == mPrivateDnsTransports.end()) {
        return false;
        return Errorf("NetId not found in mPrivateDnsTransports");
    }

    auto& tracker = netPair->second;
    const ServerIdentity identity = ServerIdentity(server);
    auto it = tracker.find(identity);
    if (it == tracker.end()) {
        return false;
        return Errorf("Server was removed");
    }

    const DnsTlsServer& target = it->second;

    if (!target.active()) return false;
    if (!target.active()) return Errorf("Server is not active");

    if (target.validationState() != Validation::success) return false;
    if (target.validationState() != Validation::success) {
        return Errorf("Server validation state mismatched");
    }

    // Don't run the validation if |mark| (from android_net_context.dns_mark) is different.
    // This is to protect validation from running on unexpected marks.
    // Validation should be associated with a mark gotten by system permission.
    if (target.mark != mark) return false;
    if (target.mark != mark) return Errorf("Socket mark mismatched");

    updateServerState(identity, Validation::in_process, netId);
    startValidation(target, netId);
    return true;
    startValidation(target, netId, true);
    return {};
}

void PrivateDnsConfiguration::startValidation(const DnsTlsServer& server, unsigned netId)
        REQUIRES(mPrivateDnsLock) {
    // Note that capturing |server| and |netId| in this lambda create copies.
    std::thread validate_thread([this, server, netId] {
void PrivateDnsConfiguration::startValidation(const DnsTlsServer& server, unsigned netId,
                                              bool isRevalidation) REQUIRES(mPrivateDnsLock) {
    // Note that capturing |server|, |netId|, and |isRevalidation| in this lambda create copies.
    std::thread validate_thread([this, server, netId, isRevalidation] {
        setThreadName(StringPrintf("TlsVerify_%u", netId).c_str());

        // cat /proc/sys/net/ipv4/tcp_syn_retries yields "6".
@@ -208,7 +222,9 @@ void PrivateDnsConfiguration::startValidation(const DnsTlsServer& server, unsign
            LOG(WARNING) << "validateDnsTlsServer returned " << success << " for "
                         << server.toIpString();

            const bool needs_reeval = this->recordPrivateDnsValidation(server, netId, success);
            const bool needs_reeval =
                    this->recordPrivateDnsValidation(server, netId, success, isRevalidation);

            if (!needs_reeval) {
                break;
            }
@@ -254,7 +270,7 @@ void PrivateDnsConfiguration::sendPrivateDnsValidationEvent(const DnsTlsServer&
}

bool PrivateDnsConfiguration::recordPrivateDnsValidation(const DnsTlsServer& server, unsigned netId,
                                                         bool success) {
                                                         bool success, bool isRevalidation) {
    constexpr bool NEEDS_REEVALUATION = true;
    constexpr bool DONT_REEVALUATE = false;
    const ServerIdentity identity = ServerIdentity(server);
@@ -274,10 +290,15 @@ bool PrivateDnsConfiguration::recordPrivateDnsValidation(const DnsTlsServer& ser
        notifyValidationStateUpdate(identity.ip.toString(), Validation::fail, netId);
        return DONT_REEVALUATE;
    }
    const bool modeDoesReevaluation = (mode->second == PrivateDnsMode::STRICT);

    bool reevaluationStatus =
            (success || !modeDoesReevaluation) ? DONT_REEVALUATE : NEEDS_REEVALUATION;
    bool reevaluationStatus = NEEDS_REEVALUATION;
    if (success) {
        reevaluationStatus = DONT_REEVALUATE;
    } else if (mode->second == PrivateDnsMode::OFF) {
        reevaluationStatus = DONT_REEVALUATE;
    } else if (mode->second == PrivateDnsMode::OPPORTUNISTIC && !isRevalidation) {
        reevaluationStatus = DONT_REEVALUATE;
    }

    auto& tracker = netPair->second;
    auto serverPair = tracker.find(identity);
+11 −7
Original line number Diff line number Diff line
@@ -21,6 +21,7 @@
#include <mutex>
#include <vector>

#include <android-base/result.h>
#include <android-base/thread_annotations.h>
#include <netdutils/DumpWriter.h>
#include <netdutils/InternetAddresses.h>
@@ -61,13 +62,13 @@ class PrivateDnsConfiguration {
    int set(int32_t netId, uint32_t mark, const std::vector<std::string>& servers,
            const std::string& name, const std::string& caCert) EXCLUDES(mPrivateDnsLock);

    PrivateDnsStatus getStatus(unsigned netId) EXCLUDES(mPrivateDnsLock);
    PrivateDnsStatus getStatus(unsigned netId) const EXCLUDES(mPrivateDnsLock);

    void clear(unsigned netId) EXCLUDES(mPrivateDnsLock);

    // Request |server| to be revalidated on a connection tagged with |mark|.
    // Return true if the request is accepted; otherwise, return false.
    bool requestValidation(unsigned netId, const DnsTlsServer& server, uint32_t mark)
    // Returns a Result to indicate if the request is accepted.
    base::Result<void> requestValidation(unsigned netId, const DnsTlsServer& server, uint32_t mark)
            EXCLUDES(mPrivateDnsLock);

    struct ServerIdentity {
@@ -98,10 +99,13 @@ class PrivateDnsConfiguration {

    PrivateDnsConfiguration() = default;

    void startValidation(const DnsTlsServer& server, unsigned netId) REQUIRES(mPrivateDnsLock);
    // Launches a thread to run the validation for |server| on the network |netId|.
    // |isRevalidation| is true if this call is due to a revalidation request.
    void startValidation(const DnsTlsServer& server, unsigned netId, bool isRevalidation)
            REQUIRES(mPrivateDnsLock);

    bool recordPrivateDnsValidation(const DnsTlsServer& server, unsigned netId, bool success)
            EXCLUDES(mPrivateDnsLock);
    bool recordPrivateDnsValidation(const DnsTlsServer& server, unsigned netId, bool success,
                                    bool isRevalidation) EXCLUDES(mPrivateDnsLock);

    void sendPrivateDnsValidationEvent(const DnsTlsServer& server, unsigned netId, bool success)
            REQUIRES(mPrivateDnsLock);
@@ -114,7 +118,7 @@ class PrivateDnsConfiguration {
    void updateServerState(const ServerIdentity& identity, Validation state, uint32_t netId)
            REQUIRES(mPrivateDnsLock);

    std::mutex mPrivateDnsLock;
    mutable std::mutex mPrivateDnsLock;
    std::map<unsigned, PrivateDnsMode> mPrivateDnsModes GUARDED_BY(mPrivateDnsLock);

    // Contains all servers for a network, along with their current validation status.
Loading