Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit f7717f5a authored by Mike Yu's avatar Mike Yu
Browse files

Remove validation threads tracking in PrivateDnsConfiguration

The validation threads tracking is not necessary since DnsTlsServer
itself can provide sufficient information to PrivateDnsConfiguration
to decide if a validation should start.

This change comprises:

[1] Remove mPrivateDnsValidateThreads and related functions
[2] Extend DnsTlsServer to tell if it is a present server for a
    network.
[3] PrivateDnsConfiguration reserves every DnsTlsServer object until
    being set to OFF mode or the network is destroyed. Make use of
    [2] to determine which servers should be active for the network.
[4] DnsTlsServers with identical IP address but different private
    DNS provider hostname are treated as different servers, and
    thus they have their own validation thread.
[5] Add a new state, Validation::success_but_expired, which is used
    to determine if a server which had passed the validation should
    be revalidated again.
[6] To fit in with [4], some related tests are modified.

Bug: 79727473
Test: cd packages/modules/DnsResolver && atest
Change-Id: I78afce543ea05be39c36d268576824e9ec798b12
parent fa985f71
Loading
Loading
Loading
Loading
+14 −1
Original line number Diff line number Diff line
@@ -28,7 +28,14 @@ namespace android {
namespace net {

// Validation status of a DNS over TLS server (on a specific netId).
enum class Validation : uint8_t { in_process, success, fail, unknown_server, unknown_netid };
enum class Validation : uint8_t {
    in_process,
    success,
    success_but_expired,
    fail,
    unknown_server,
    unknown_netid,
};

// DnsTlsServer represents a recursive resolver that supports, or may support, a
// secure protocol.
@@ -66,9 +73,15 @@ struct DnsTlsServer {
    Validation validationState() const { return mValidation; }
    void setValidationState(Validation val) { mValidation = val; }

    // Return whether or not the server can be used for a network. It depends on
    // the resolver configuration.
    bool active() const { return mActive; }
    void setActive(bool val) { mActive = val; }

  private:
    // State, unrelated to the comparison of DnsTlsServer objects.
    Validation mValidation = Validation::unknown_server;
    bool mActive = false;
};

// This comparison only checks the IP address.  It ignores ports, names, and fingerprints.
+37 −97
Original line number Diff line number Diff line
@@ -89,61 +89,30 @@ int PrivateDnsConfiguration::set(int32_t netId, uint32_t mark,
    } else {
        mPrivateDnsModes[netId] = PrivateDnsMode::OFF;
        mPrivateDnsTransports.erase(netId);
        mPrivateDnsValidateThreads.erase(netId);
        // TODO: As mPrivateDnsValidateThreads is reset, validation threads which haven't yet
        // finished are considered outdated. Consider signaling the outdated validation threads to
        // stop them from updating the state of PrivateDnsConfiguration (possibly disallow them to
        // report onPrivateDnsValidationEvent).
        // TODO: signal validation threads to stop.
        return 0;
    }

    // Create the tracker if it was not present
    auto netPair = mPrivateDnsTransports.find(netId);
    if (netPair == mPrivateDnsTransports.end()) {
        // No TLS tracker yet for this netId.
        bool added;
        std::tie(netPair, added) = mPrivateDnsTransports.emplace(netId, PrivateDnsTracker());
        if (!added) {
            LOG(ERROR) << "Memory error while recording private DNS for netId " << netId;
            return -ENOMEM;
        }
    }
    auto& tracker = netPair->second;

    // Remove any servers from the tracker that are not in |servers| exactly.
    for (auto it = tracker.begin(); it != tracker.end();) {
        if (tmp.find(it->first) == tmp.end()) {
            it = tracker.erase(it);
        } else {
            ++it;
        }
    }
    auto& tracker = mPrivateDnsTransports[netId];

    // Add any new or changed servers to the tracker, and initiate async checks for them.
    // Add the servers if not contained in tracker.
    for (const auto& [identity, server] : tmp) {
        if (needsValidation(tracker, server)) {
            // This is temporarily required. Consider the following scenario, for example,
            //   Step 1) A DoTServer (s1) is set for the network. A validation (v1) for s1 starts.
            //           tracker has s1 alone.
            //   Step 2) The configuration changes. DotServer2 (s2) is set for the network. A
            //           validation (v2) for s2 starts. tracker has s2 alone.
            //   Step 3) Assume v1 and v2 somehow block. Now, if the configuration changes back to
            //           set s1, there won't be a v1' starts because needValidateThread() will
            //           return false.
            //
            // If we didn't add servers to tracker before needValidateThread(), tracker would
            // become empty. We would report s1 validation failed.
        if (tracker.find(identity) == tracker.end()) {
            tracker[identity] = server;
        }
            tracker[identity].setValidationState(Validation::in_process);
            LOG(DEBUG) << "Server " << addrToString(&server.ss) << " marked as in_process on netId "
                       << netId << ". Tracker now has size " << tracker.size();
            // This judge must be after "tracker[server] = Validation::in_process;"
            if (!needValidateThread(server, netId)) {
                continue;
    }

    for (auto& [identity, server] : tracker) {
        const bool active = tmp.find(identity) != tmp.end();
        server.setActive(active);

        // For simplicity, deem the validation result of inactive servers as unreliable.
        if (!server.active() && server.validationState() == Validation::success) {
            updateServerState(identity, Validation::success_but_expired, netId);
        }

        if (needsValidation(server)) {
            updateServerState(identity, Validation::in_process, netId);
            startValidation(server, netId, mark);
        }
@@ -163,9 +132,11 @@ PrivateDnsStatus PrivateDnsConfiguration::getStatus(unsigned netId) {
    const auto netPair = mPrivateDnsTransports.find(netId);
    if (netPair != mPrivateDnsTransports.end()) {
        for (const auto& [_, server] : netPair->second) {
            if (server.active()) {
                status.serversMap.emplace(server, server.validationState());
            }
        }
    }

    return status;
}
@@ -175,7 +146,6 @@ void PrivateDnsConfiguration::clear(unsigned netId) {
    std::lock_guard guard(mPrivateDnsLock);
    mPrivateDnsModes.erase(netId);
    mPrivateDnsTransports.erase(netId);
    mPrivateDnsValidateThreads.erase(netId);
}

void PrivateDnsConfiguration::startValidation(const DnsTlsServer& server, unsigned netId,
@@ -216,12 +186,12 @@ void PrivateDnsConfiguration::startValidation(const DnsTlsServer& server, unsign
            }

            if (backoff.hasNextTimeout()) {
                // TODO: make the thread able to receive signals to shutdown early.
                std::this_thread::sleep_for(backoff.getNextTimeout());
            } else {
                break;
            }
        }
        this->cleanValidateThreadTracker(server, netId);
    });
    validate_thread.detach();
}
@@ -266,6 +236,11 @@ bool PrivateDnsConfiguration::recordPrivateDnsValidation(const DnsTlsServer& ser
                     << " was changed during private DNS validation";
        success = false;
        reevaluationStatus = DONT_REEVALUATE;
    } else if (!serverPair->second.active()) {
        LOG(WARNING) << "Server " << addrToString(&server.ss)
                     << " was removed from the configuration";
        success = false;
        reevaluationStatus = DONT_REEVALUATE;
    }

    // Send a validation event to NetdEventListenerService.
@@ -297,34 +272,6 @@ bool PrivateDnsConfiguration::recordPrivateDnsValidation(const DnsTlsServer& ser
    return reevaluationStatus;
}

bool PrivateDnsConfiguration::needValidateThread(const DnsTlsServer& server, unsigned netId)
        REQUIRES(mPrivateDnsLock) {
    // Create the thread tracker if it was not present
    auto threadPair = mPrivateDnsValidateThreads.find(netId);
    if (threadPair == mPrivateDnsValidateThreads.end()) {
        // No thread tracker yet for this netId.
        bool added;
        std::tie(threadPair, added) = mPrivateDnsValidateThreads.emplace(netId, ThreadTracker());
        if (!added) {
            LOG(ERROR) << "Memory error while needValidateThread for netId " << netId;
            return true;
        }
    }
    auto& threadTracker = threadPair->second;
    if (threadTracker.count(server)) {
        LOG(DEBUG) << "Server " << addrToString(&(server.ss))
                   << " validate thread is already running. Thread tracker now has size "
                   << threadTracker.size();
        return false;
    } else {
        threadTracker.insert(server);
        LOG(DEBUG) << "Server " << addrToString(&(server.ss))
                   << " validate thread is not running. Thread tracker now has size "
                   << threadTracker.size();
        return true;
    }
}

void PrivateDnsConfiguration::updateServerState(const ServerIdentity& identity, Validation state,
                                                uint32_t netId) {
    auto netPair = mPrivateDnsTransports.find(netId);
@@ -343,27 +290,20 @@ void PrivateDnsConfiguration::updateServerState(const ServerIdentity& identity,
    maybeNotifyObserver(identity.ip.toString(), state, netId);
}

void PrivateDnsConfiguration::cleanValidateThreadTracker(const DnsTlsServer& server,
                                                         unsigned netId) {
    std::lock_guard<std::mutex> guard(mPrivateDnsLock);
    LOG(DEBUG) << "cleanValidateThreadTracker Server " << addrToString(&(server.ss))
               << " validate thread is stopped.";
bool PrivateDnsConfiguration::needsValidation(const DnsTlsServer& server) {
    // The server is not expected to be used on the network.
    if (!server.active()) return false;

    auto threadPair = mPrivateDnsValidateThreads.find(netId);
    if (threadPair != mPrivateDnsValidateThreads.end()) {
        auto& threadTracker = threadPair->second;
        threadTracker.erase(server);
        LOG(DEBUG) << "Server " << addrToString(&(server.ss))
                   << " validate thread is stopped. Thread tracker now has size "
                   << threadTracker.size();
    }
}
    // The server is newly added.
    if (server.validationState() == Validation::unknown_server) return true;

bool PrivateDnsConfiguration::needsValidation(const PrivateDnsTracker& tracker,
                                              const DnsTlsServer& server) {
    const ServerIdentity identity = ServerIdentity(server);
    const auto& iter = tracker.find(identity);
    return (iter == tracker.end()) || (iter->second.validationState() == Validation::fail);
    // The server has failed at least one validation attempt. Give it another try.
    if (server.validationState() == Validation::fail) return true;

    // The previous validation result might be unreliable.
    if (server.validationState() == Validation::success_but_expired) return true;

    return false;
}

void PrivateDnsConfiguration::setObserver(Observer* observer) {
+7 −11
Original line number Diff line number Diff line
@@ -95,26 +95,22 @@ class PrivateDnsConfiguration {
    bool recordPrivateDnsValidation(const DnsTlsServer& server, unsigned netId, bool success)
            EXCLUDES(mPrivateDnsLock);

    bool needValidateThread(const DnsTlsServer& server, unsigned netId) REQUIRES(mPrivateDnsLock);
    void cleanValidateThreadTracker(const DnsTlsServer& server, unsigned netId)
            EXCLUDES(mPrivateDnsLock);

    // Start validation for newly added servers as well as any servers that have
    // landed in Validation::fail state. Note that servers that have failed
    // Decide if a validation for |server| is needed. Note that servers that have failed
    // multiple validation attempts but for which there is still a validating
    // thread running are marked as being Validation::in_process.
    bool needsValidation(const PrivateDnsTracker& tracker, const DnsTlsServer& server)
            REQUIRES(mPrivateDnsLock);
    bool needsValidation(const DnsTlsServer& server) REQUIRES(mPrivateDnsLock);

    void updateServerState(const ServerIdentity& identity, Validation state, uint32_t netId)
            REQUIRES(mPrivateDnsLock);

    std::mutex mPrivateDnsLock;
    std::map<unsigned, PrivateDnsMode> mPrivateDnsModes GUARDED_BY(mPrivateDnsLock);
    // Structure for tracking the validation status of servers on a specific netId.
    // Using the AddressComparator ensures at most one entry per IP address.

    // Contains all servers for a network, along with their current validation status.
    // In case a server is removed due to a configuration change, it remains in this map,
    // but is marked inactive.
    // Any pending validation threads will continue running because we have no way to cancel them.
    std::map<unsigned, PrivateDnsTracker> mPrivateDnsTransports GUARDED_BY(mPrivateDnsLock);
    std::map<unsigned, ThreadTracker> mPrivateDnsValidateThreads GUARDED_BY(mPrivateDnsLock);

    // For testing. The observer is notified of onValidationStateUpdate 1) when a validation is
    // about to begin or 2) when a validation finishes. If a validation finishes when in OFF mode
+2 −0
Original line number Diff line number Diff line
@@ -70,6 +70,8 @@ constexpr const char* validationStatusToString(Validation value) {
            return "in_process";
        case Validation::success:
            return "success";
        case Validation::success_but_expired:
            return "success_but_expired";
        case Validation::fail:
            return "fail";
        case Validation::unknown_server:
+6 −0
Original line number Diff line number Diff line
@@ -892,9 +892,15 @@ TEST_F(ServerTest, State) {
    checkEqual(s1, s2);
    s2.setValidationState(Validation::fail);
    checkEqual(s1, s2);
    s1.setActive(true);
    checkEqual(s1, s2);
    s2.setActive(false);
    checkEqual(s1, s2);

    EXPECT_EQ(s1.validationState(), Validation::success);
    EXPECT_EQ(s2.validationState(), Validation::fail);
    EXPECT_TRUE(s1.active());
    EXPECT_FALSE(s2.active());
}

TEST(QueryMapTest, Basic) {
Loading