Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit d961c43a authored by Mike Yu's avatar Mike Yu Committed by Automerger Merge Worker
Browse files

Merge "Remove validation threads tracking in PrivateDnsConfiguration" am: 2eece3f5 am: 67941fd5

Original change: https://android-review.googlesource.com/c/platform/packages/modules/DnsResolver/+/1509770

Change-Id: I76068aefb1afdb24780aac7939e9e8cc1f4fbb74
parents c4a0ee1d 67941fd5
Loading
Loading
Loading
Loading
+14 −1
Original line number Diff line number Diff line
@@ -28,7 +28,14 @@ namespace android {
namespace net {

// Validation status of a DNS over TLS server (on a specific netId).
enum class Validation : uint8_t { in_process, success, fail, unknown_server, unknown_netid };
enum class Validation : uint8_t {
    in_process,
    success,
    success_but_expired,
    fail,
    unknown_server,
    unknown_netid,
};

// DnsTlsServer represents a recursive resolver that supports, or may support, a
// secure protocol.
@@ -66,9 +73,15 @@ struct DnsTlsServer {
    Validation validationState() const { return mValidation; }
    void setValidationState(Validation val) { mValidation = val; }

    // Return whether or not the server can be used for a network. It depends on
    // the resolver configuration.
    bool active() const { return mActive; }
    void setActive(bool val) { mActive = val; }

  private:
    // State, unrelated to the comparison of DnsTlsServer objects.
    Validation mValidation = Validation::unknown_server;
    bool mActive = false;
};

// This comparison only checks the IP address.  It ignores ports, names, and fingerprints.
+37 −97
Original line number Diff line number Diff line
@@ -89,61 +89,30 @@ int PrivateDnsConfiguration::set(int32_t netId, uint32_t mark,
    } else {
        mPrivateDnsModes[netId] = PrivateDnsMode::OFF;
        mPrivateDnsTransports.erase(netId);
        mPrivateDnsValidateThreads.erase(netId);
        // TODO: As mPrivateDnsValidateThreads is reset, validation threads which haven't yet
        // finished are considered outdated. Consider signaling the outdated validation threads to
        // stop them from updating the state of PrivateDnsConfiguration (possibly disallow them to
        // report onPrivateDnsValidationEvent).
        // TODO: signal validation threads to stop.
        return 0;
    }

    // Create the tracker if it was not present
    auto netPair = mPrivateDnsTransports.find(netId);
    if (netPair == mPrivateDnsTransports.end()) {
        // No TLS tracker yet for this netId.
        bool added;
        std::tie(netPair, added) = mPrivateDnsTransports.emplace(netId, PrivateDnsTracker());
        if (!added) {
            LOG(ERROR) << "Memory error while recording private DNS for netId " << netId;
            return -ENOMEM;
        }
    }
    auto& tracker = netPair->second;

    // Remove any servers from the tracker that are not in |servers| exactly.
    for (auto it = tracker.begin(); it != tracker.end();) {
        if (tmp.find(it->first) == tmp.end()) {
            it = tracker.erase(it);
        } else {
            ++it;
        }
    }
    auto& tracker = mPrivateDnsTransports[netId];

    // Add any new or changed servers to the tracker, and initiate async checks for them.
    // Add the servers if not contained in tracker.
    for (const auto& [identity, server] : tmp) {
        if (needsValidation(tracker, server)) {
            // This is temporarily required. Consider the following scenario, for example,
            //   Step 1) A DoTServer (s1) is set for the network. A validation (v1) for s1 starts.
            //           tracker has s1 alone.
            //   Step 2) The configuration changes. DotServer2 (s2) is set for the network. A
            //           validation (v2) for s2 starts. tracker has s2 alone.
            //   Step 3) Assume v1 and v2 somehow block. Now, if the configuration changes back to
            //           set s1, there won't be a v1' starts because needValidateThread() will
            //           return false.
            //
            // If we didn't add servers to tracker before needValidateThread(), tracker would
            // become empty. We would report s1 validation failed.
        if (tracker.find(identity) == tracker.end()) {
            tracker[identity] = server;
        }
            tracker[identity].setValidationState(Validation::in_process);
            LOG(DEBUG) << "Server " << addrToString(&server.ss) << " marked as in_process on netId "
                       << netId << ". Tracker now has size " << tracker.size();
            // This judge must be after "tracker[server] = Validation::in_process;"
            if (!needValidateThread(server, netId)) {
                continue;
    }

    for (auto& [identity, server] : tracker) {
        const bool active = tmp.find(identity) != tmp.end();
        server.setActive(active);

        // For simplicity, deem the validation result of inactive servers as unreliable.
        if (!server.active() && server.validationState() == Validation::success) {
            updateServerState(identity, Validation::success_but_expired, netId);
        }

        if (needsValidation(server)) {
            updateServerState(identity, Validation::in_process, netId);
            startValidation(server, netId, mark);
        }
@@ -163,9 +132,11 @@ PrivateDnsStatus PrivateDnsConfiguration::getStatus(unsigned netId) {
    const auto netPair = mPrivateDnsTransports.find(netId);
    if (netPair != mPrivateDnsTransports.end()) {
        for (const auto& [_, server] : netPair->second) {
            if (server.active()) {
                status.serversMap.emplace(server, server.validationState());
            }
        }
    }

    return status;
}
@@ -175,7 +146,6 @@ void PrivateDnsConfiguration::clear(unsigned netId) {
    std::lock_guard guard(mPrivateDnsLock);
    mPrivateDnsModes.erase(netId);
    mPrivateDnsTransports.erase(netId);
    mPrivateDnsValidateThreads.erase(netId);
}

void PrivateDnsConfiguration::startValidation(const DnsTlsServer& server, unsigned netId,
@@ -216,12 +186,12 @@ void PrivateDnsConfiguration::startValidation(const DnsTlsServer& server, unsign
            }

            if (backoff.hasNextTimeout()) {
                // TODO: make the thread able to receive signals to shutdown early.
                std::this_thread::sleep_for(backoff.getNextTimeout());
            } else {
                break;
            }
        }
        this->cleanValidateThreadTracker(server, netId);
    });
    validate_thread.detach();
}
@@ -266,6 +236,11 @@ bool PrivateDnsConfiguration::recordPrivateDnsValidation(const DnsTlsServer& ser
                     << " was changed during private DNS validation";
        success = false;
        reevaluationStatus = DONT_REEVALUATE;
    } else if (!serverPair->second.active()) {
        LOG(WARNING) << "Server " << addrToString(&server.ss)
                     << " was removed from the configuration";
        success = false;
        reevaluationStatus = DONT_REEVALUATE;
    }

    // Send a validation event to NetdEventListenerService.
@@ -297,34 +272,6 @@ bool PrivateDnsConfiguration::recordPrivateDnsValidation(const DnsTlsServer& ser
    return reevaluationStatus;
}

bool PrivateDnsConfiguration::needValidateThread(const DnsTlsServer& server, unsigned netId)
        REQUIRES(mPrivateDnsLock) {
    // Create the thread tracker if it was not present
    auto threadPair = mPrivateDnsValidateThreads.find(netId);
    if (threadPair == mPrivateDnsValidateThreads.end()) {
        // No thread tracker yet for this netId.
        bool added;
        std::tie(threadPair, added) = mPrivateDnsValidateThreads.emplace(netId, ThreadTracker());
        if (!added) {
            LOG(ERROR) << "Memory error while needValidateThread for netId " << netId;
            return true;
        }
    }
    auto& threadTracker = threadPair->second;
    if (threadTracker.count(server)) {
        LOG(DEBUG) << "Server " << addrToString(&(server.ss))
                   << " validate thread is already running. Thread tracker now has size "
                   << threadTracker.size();
        return false;
    } else {
        threadTracker.insert(server);
        LOG(DEBUG) << "Server " << addrToString(&(server.ss))
                   << " validate thread is not running. Thread tracker now has size "
                   << threadTracker.size();
        return true;
    }
}

void PrivateDnsConfiguration::updateServerState(const ServerIdentity& identity, Validation state,
                                                uint32_t netId) {
    auto netPair = mPrivateDnsTransports.find(netId);
@@ -343,27 +290,20 @@ void PrivateDnsConfiguration::updateServerState(const ServerIdentity& identity,
    maybeNotifyObserver(identity.ip.toString(), state, netId);
}

void PrivateDnsConfiguration::cleanValidateThreadTracker(const DnsTlsServer& server,
                                                         unsigned netId) {
    std::lock_guard<std::mutex> guard(mPrivateDnsLock);
    LOG(DEBUG) << "cleanValidateThreadTracker Server " << addrToString(&(server.ss))
               << " validate thread is stopped.";
bool PrivateDnsConfiguration::needsValidation(const DnsTlsServer& server) {
    // The server is not expected to be used on the network.
    if (!server.active()) return false;

    auto threadPair = mPrivateDnsValidateThreads.find(netId);
    if (threadPair != mPrivateDnsValidateThreads.end()) {
        auto& threadTracker = threadPair->second;
        threadTracker.erase(server);
        LOG(DEBUG) << "Server " << addrToString(&(server.ss))
                   << " validate thread is stopped. Thread tracker now has size "
                   << threadTracker.size();
    }
}
    // The server is newly added.
    if (server.validationState() == Validation::unknown_server) return true;

bool PrivateDnsConfiguration::needsValidation(const PrivateDnsTracker& tracker,
                                              const DnsTlsServer& server) {
    const ServerIdentity identity = ServerIdentity(server);
    const auto& iter = tracker.find(identity);
    return (iter == tracker.end()) || (iter->second.validationState() == Validation::fail);
    // The server has failed at least one validation attempt. Give it another try.
    if (server.validationState() == Validation::fail) return true;

    // The previous validation result might be unreliable.
    if (server.validationState() == Validation::success_but_expired) return true;

    return false;
}

void PrivateDnsConfiguration::setObserver(Observer* observer) {
+7 −11
Original line number Diff line number Diff line
@@ -95,26 +95,22 @@ class PrivateDnsConfiguration {
    bool recordPrivateDnsValidation(const DnsTlsServer& server, unsigned netId, bool success)
            EXCLUDES(mPrivateDnsLock);

    bool needValidateThread(const DnsTlsServer& server, unsigned netId) REQUIRES(mPrivateDnsLock);
    void cleanValidateThreadTracker(const DnsTlsServer& server, unsigned netId)
            EXCLUDES(mPrivateDnsLock);

    // Start validation for newly added servers as well as any servers that have
    // landed in Validation::fail state. Note that servers that have failed
    // Decide if a validation for |server| is needed. Note that servers that have failed
    // multiple validation attempts but for which there is still a validating
    // thread running are marked as being Validation::in_process.
    bool needsValidation(const PrivateDnsTracker& tracker, const DnsTlsServer& server)
            REQUIRES(mPrivateDnsLock);
    bool needsValidation(const DnsTlsServer& server) REQUIRES(mPrivateDnsLock);

    void updateServerState(const ServerIdentity& identity, Validation state, uint32_t netId)
            REQUIRES(mPrivateDnsLock);

    std::mutex mPrivateDnsLock;
    std::map<unsigned, PrivateDnsMode> mPrivateDnsModes GUARDED_BY(mPrivateDnsLock);
    // Structure for tracking the validation status of servers on a specific netId.
    // Using the AddressComparator ensures at most one entry per IP address.

    // Contains all servers for a network, along with their current validation status.
    // In case a server is removed due to a configuration change, it remains in this map,
    // but is marked inactive.
    // Any pending validation threads will continue running because we have no way to cancel them.
    std::map<unsigned, PrivateDnsTracker> mPrivateDnsTransports GUARDED_BY(mPrivateDnsLock);
    std::map<unsigned, ThreadTracker> mPrivateDnsValidateThreads GUARDED_BY(mPrivateDnsLock);

    // For testing. The observer is notified of onValidationStateUpdate 1) when a validation is
    // about to begin or 2) when a validation finishes. If a validation finishes when in OFF mode
+2 −0
Original line number Diff line number Diff line
@@ -70,6 +70,8 @@ constexpr const char* validationStatusToString(Validation value) {
            return "in_process";
        case Validation::success:
            return "success";
        case Validation::success_but_expired:
            return "success_but_expired";
        case Validation::fail:
            return "fail";
        case Validation::unknown_server:
+6 −0
Original line number Diff line number Diff line
@@ -892,9 +892,15 @@ TEST_F(ServerTest, State) {
    checkEqual(s1, s2);
    s2.setValidationState(Validation::fail);
    checkEqual(s1, s2);
    s1.setActive(true);
    checkEqual(s1, s2);
    s2.setActive(false);
    checkEqual(s1, s2);

    EXPECT_EQ(s1.validationState(), Validation::success);
    EXPECT_EQ(s2.validationState(), Validation::fail);
    EXPECT_TRUE(s1.active());
    EXPECT_FALSE(s2.active());
}

TEST(QueryMapTest, Basic) {
Loading