Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 2eece3f5 authored by Mike Yu's avatar Mike Yu Committed by Gerrit Code Review
Browse files

Merge "Remove validation threads tracking in PrivateDnsConfiguration"

parents 58eff5d4 f7717f5a
Loading
Loading
Loading
Loading
+14 −1
Original line number Original line Diff line number Diff line
@@ -28,7 +28,14 @@ namespace android {
namespace net {
namespace net {


// Validation status of a DNS over TLS server (on a specific netId).
// Validation status of a DNS over TLS server (on a specific netId).
enum class Validation : uint8_t { in_process, success, fail, unknown_server, unknown_netid };
enum class Validation : uint8_t {
    in_process,
    success,
    success_but_expired,
    fail,
    unknown_server,
    unknown_netid,
};


// DnsTlsServer represents a recursive resolver that supports, or may support, a
// DnsTlsServer represents a recursive resolver that supports, or may support, a
// secure protocol.
// secure protocol.
@@ -66,9 +73,15 @@ struct DnsTlsServer {
    Validation validationState() const { return mValidation; }
    Validation validationState() const { return mValidation; }
    void setValidationState(Validation val) { mValidation = val; }
    void setValidationState(Validation val) { mValidation = val; }


    // Return whether or not the server can be used for a network. It depends on
    // the resolver configuration.
    bool active() const { return mActive; }
    void setActive(bool val) { mActive = val; }

  private:
  private:
    // State, unrelated to the comparison of DnsTlsServer objects.
    // State, unrelated to the comparison of DnsTlsServer objects.
    Validation mValidation = Validation::unknown_server;
    Validation mValidation = Validation::unknown_server;
    bool mActive = false;
};
};


// This comparison only checks the IP address.  It ignores ports, names, and fingerprints.
// This comparison only checks the IP address.  It ignores ports, names, and fingerprints.
+37 −97
Original line number Original line Diff line number Diff line
@@ -89,61 +89,30 @@ int PrivateDnsConfiguration::set(int32_t netId, uint32_t mark,
    } else {
    } else {
        mPrivateDnsModes[netId] = PrivateDnsMode::OFF;
        mPrivateDnsModes[netId] = PrivateDnsMode::OFF;
        mPrivateDnsTransports.erase(netId);
        mPrivateDnsTransports.erase(netId);
        mPrivateDnsValidateThreads.erase(netId);
        // TODO: signal validation threads to stop.
        // TODO: As mPrivateDnsValidateThreads is reset, validation threads which haven't yet
        // finished are considered outdated. Consider signaling the outdated validation threads to
        // stop them from updating the state of PrivateDnsConfiguration (possibly disallow them to
        // report onPrivateDnsValidationEvent).
        return 0;
        return 0;
    }
    }


    // Create the tracker if it was not present
    // Create the tracker if it was not present
    auto netPair = mPrivateDnsTransports.find(netId);
    auto& tracker = mPrivateDnsTransports[netId];
    if (netPair == mPrivateDnsTransports.end()) {
        // No TLS tracker yet for this netId.
        bool added;
        std::tie(netPair, added) = mPrivateDnsTransports.emplace(netId, PrivateDnsTracker());
        if (!added) {
            LOG(ERROR) << "Memory error while recording private DNS for netId " << netId;
            return -ENOMEM;
        }
    }
    auto& tracker = netPair->second;

    // Remove any servers from the tracker that are not in |servers| exactly.
    for (auto it = tracker.begin(); it != tracker.end();) {
        if (tmp.find(it->first) == tmp.end()) {
            it = tracker.erase(it);
        } else {
            ++it;
        }
    }


    // Add any new or changed servers to the tracker, and initiate async checks for them.
    // Add the servers if not contained in tracker.
    for (const auto& [identity, server] : tmp) {
    for (const auto& [identity, server] : tmp) {
        if (needsValidation(tracker, server)) {
            // This is temporarily required. Consider the following scenario, for example,
            //   Step 1) A DoTServer (s1) is set for the network. A validation (v1) for s1 starts.
            //           tracker has s1 alone.
            //   Step 2) The configuration changes. DotServer2 (s2) is set for the network. A
            //           validation (v2) for s2 starts. tracker has s2 alone.
            //   Step 3) Assume v1 and v2 somehow block. Now, if the configuration changes back to
            //           set s1, there won't be a v1' starts because needValidateThread() will
            //           return false.
            //
            // If we didn't add servers to tracker before needValidateThread(), tracker would
            // become empty. We would report s1 validation failed.
        if (tracker.find(identity) == tracker.end()) {
        if (tracker.find(identity) == tracker.end()) {
            tracker[identity] = server;
            tracker[identity] = server;
        }
        }
            tracker[identity].setValidationState(Validation::in_process);
            LOG(DEBUG) << "Server " << addrToString(&server.ss) << " marked as in_process on netId "
                       << netId << ". Tracker now has size " << tracker.size();
            // This judge must be after "tracker[server] = Validation::in_process;"
            if (!needValidateThread(server, netId)) {
                continue;
    }
    }


    for (auto& [identity, server] : tracker) {
        const bool active = tmp.find(identity) != tmp.end();
        server.setActive(active);

        // For simplicity, deem the validation result of inactive servers as unreliable.
        if (!server.active() && server.validationState() == Validation::success) {
            updateServerState(identity, Validation::success_but_expired, netId);
        }

        if (needsValidation(server)) {
            updateServerState(identity, Validation::in_process, netId);
            updateServerState(identity, Validation::in_process, netId);
            startValidation(server, netId, mark);
            startValidation(server, netId, mark);
        }
        }
@@ -163,9 +132,11 @@ PrivateDnsStatus PrivateDnsConfiguration::getStatus(unsigned netId) {
    const auto netPair = mPrivateDnsTransports.find(netId);
    const auto netPair = mPrivateDnsTransports.find(netId);
    if (netPair != mPrivateDnsTransports.end()) {
    if (netPair != mPrivateDnsTransports.end()) {
        for (const auto& [_, server] : netPair->second) {
        for (const auto& [_, server] : netPair->second) {
            if (server.active()) {
                status.serversMap.emplace(server, server.validationState());
                status.serversMap.emplace(server, server.validationState());
            }
            }
        }
        }
    }


    return status;
    return status;
}
}
@@ -175,7 +146,6 @@ void PrivateDnsConfiguration::clear(unsigned netId) {
    std::lock_guard guard(mPrivateDnsLock);
    std::lock_guard guard(mPrivateDnsLock);
    mPrivateDnsModes.erase(netId);
    mPrivateDnsModes.erase(netId);
    mPrivateDnsTransports.erase(netId);
    mPrivateDnsTransports.erase(netId);
    mPrivateDnsValidateThreads.erase(netId);
}
}


void PrivateDnsConfiguration::startValidation(const DnsTlsServer& server, unsigned netId,
void PrivateDnsConfiguration::startValidation(const DnsTlsServer& server, unsigned netId,
@@ -216,12 +186,12 @@ void PrivateDnsConfiguration::startValidation(const DnsTlsServer& server, unsign
            }
            }


            if (backoff.hasNextTimeout()) {
            if (backoff.hasNextTimeout()) {
                // TODO: make the thread able to receive signals to shutdown early.
                std::this_thread::sleep_for(backoff.getNextTimeout());
                std::this_thread::sleep_for(backoff.getNextTimeout());
            } else {
            } else {
                break;
                break;
            }
            }
        }
        }
        this->cleanValidateThreadTracker(server, netId);
    });
    });
    validate_thread.detach();
    validate_thread.detach();
}
}
@@ -266,6 +236,11 @@ bool PrivateDnsConfiguration::recordPrivateDnsValidation(const DnsTlsServer& ser
                     << " was changed during private DNS validation";
                     << " was changed during private DNS validation";
        success = false;
        success = false;
        reevaluationStatus = DONT_REEVALUATE;
        reevaluationStatus = DONT_REEVALUATE;
    } else if (!serverPair->second.active()) {
        LOG(WARNING) << "Server " << addrToString(&server.ss)
                     << " was removed from the configuration";
        success = false;
        reevaluationStatus = DONT_REEVALUATE;
    }
    }


    // Send a validation event to NetdEventListenerService.
    // Send a validation event to NetdEventListenerService.
@@ -297,34 +272,6 @@ bool PrivateDnsConfiguration::recordPrivateDnsValidation(const DnsTlsServer& ser
    return reevaluationStatus;
    return reevaluationStatus;
}
}


bool PrivateDnsConfiguration::needValidateThread(const DnsTlsServer& server, unsigned netId)
        REQUIRES(mPrivateDnsLock) {
    // Create the thread tracker if it was not present
    auto threadPair = mPrivateDnsValidateThreads.find(netId);
    if (threadPair == mPrivateDnsValidateThreads.end()) {
        // No thread tracker yet for this netId.
        bool added;
        std::tie(threadPair, added) = mPrivateDnsValidateThreads.emplace(netId, ThreadTracker());
        if (!added) {
            LOG(ERROR) << "Memory error while needValidateThread for netId " << netId;
            return true;
        }
    }
    auto& threadTracker = threadPair->second;
    if (threadTracker.count(server)) {
        LOG(DEBUG) << "Server " << addrToString(&(server.ss))
                   << " validate thread is already running. Thread tracker now has size "
                   << threadTracker.size();
        return false;
    } else {
        threadTracker.insert(server);
        LOG(DEBUG) << "Server " << addrToString(&(server.ss))
                   << " validate thread is not running. Thread tracker now has size "
                   << threadTracker.size();
        return true;
    }
}

void PrivateDnsConfiguration::updateServerState(const ServerIdentity& identity, Validation state,
void PrivateDnsConfiguration::updateServerState(const ServerIdentity& identity, Validation state,
                                                uint32_t netId) {
                                                uint32_t netId) {
    auto netPair = mPrivateDnsTransports.find(netId);
    auto netPair = mPrivateDnsTransports.find(netId);
@@ -343,27 +290,20 @@ void PrivateDnsConfiguration::updateServerState(const ServerIdentity& identity,
    maybeNotifyObserver(identity.ip.toString(), state, netId);
    maybeNotifyObserver(identity.ip.toString(), state, netId);
}
}


void PrivateDnsConfiguration::cleanValidateThreadTracker(const DnsTlsServer& server,
bool PrivateDnsConfiguration::needsValidation(const DnsTlsServer& server) {
                                                         unsigned netId) {
    // The server is not expected to be used on the network.
    std::lock_guard<std::mutex> guard(mPrivateDnsLock);
    if (!server.active()) return false;
    LOG(DEBUG) << "cleanValidateThreadTracker Server " << addrToString(&(server.ss))
               << " validate thread is stopped.";


    auto threadPair = mPrivateDnsValidateThreads.find(netId);
    // The server is newly added.
    if (threadPair != mPrivateDnsValidateThreads.end()) {
    if (server.validationState() == Validation::unknown_server) return true;
        auto& threadTracker = threadPair->second;
        threadTracker.erase(server);
        LOG(DEBUG) << "Server " << addrToString(&(server.ss))
                   << " validate thread is stopped. Thread tracker now has size "
                   << threadTracker.size();
    }
}


bool PrivateDnsConfiguration::needsValidation(const PrivateDnsTracker& tracker,
    // The server has failed at least one validation attempt. Give it another try.
                                              const DnsTlsServer& server) {
    if (server.validationState() == Validation::fail) return true;
    const ServerIdentity identity = ServerIdentity(server);

    const auto& iter = tracker.find(identity);
    // The previous validation result might be unreliable.
    return (iter == tracker.end()) || (iter->second.validationState() == Validation::fail);
    if (server.validationState() == Validation::success_but_expired) return true;

    return false;
}
}


void PrivateDnsConfiguration::setObserver(Observer* observer) {
void PrivateDnsConfiguration::setObserver(Observer* observer) {
+7 −11
Original line number Original line Diff line number Diff line
@@ -95,26 +95,22 @@ class PrivateDnsConfiguration {
    bool recordPrivateDnsValidation(const DnsTlsServer& server, unsigned netId, bool success)
    bool recordPrivateDnsValidation(const DnsTlsServer& server, unsigned netId, bool success)
            EXCLUDES(mPrivateDnsLock);
            EXCLUDES(mPrivateDnsLock);


    bool needValidateThread(const DnsTlsServer& server, unsigned netId) REQUIRES(mPrivateDnsLock);
    // Decide if a validation for |server| is needed. Note that servers that have failed
    void cleanValidateThreadTracker(const DnsTlsServer& server, unsigned netId)
            EXCLUDES(mPrivateDnsLock);

    // Start validation for newly added servers as well as any servers that have
    // landed in Validation::fail state. Note that servers that have failed
    // multiple validation attempts but for which there is still a validating
    // multiple validation attempts but for which there is still a validating
    // thread running are marked as being Validation::in_process.
    // thread running are marked as being Validation::in_process.
    bool needsValidation(const PrivateDnsTracker& tracker, const DnsTlsServer& server)
    bool needsValidation(const DnsTlsServer& server) REQUIRES(mPrivateDnsLock);
            REQUIRES(mPrivateDnsLock);


    void updateServerState(const ServerIdentity& identity, Validation state, uint32_t netId)
    void updateServerState(const ServerIdentity& identity, Validation state, uint32_t netId)
            REQUIRES(mPrivateDnsLock);
            REQUIRES(mPrivateDnsLock);


    std::mutex mPrivateDnsLock;
    std::mutex mPrivateDnsLock;
    std::map<unsigned, PrivateDnsMode> mPrivateDnsModes GUARDED_BY(mPrivateDnsLock);
    std::map<unsigned, PrivateDnsMode> mPrivateDnsModes GUARDED_BY(mPrivateDnsLock);
    // Structure for tracking the validation status of servers on a specific netId.

    // Using the AddressComparator ensures at most one entry per IP address.
    // Contains all servers for a network, along with their current validation status.
    // In case a server is removed due to a configuration change, it remains in this map,
    // but is marked inactive.
    // Any pending validation threads will continue running because we have no way to cancel them.
    std::map<unsigned, PrivateDnsTracker> mPrivateDnsTransports GUARDED_BY(mPrivateDnsLock);
    std::map<unsigned, PrivateDnsTracker> mPrivateDnsTransports GUARDED_BY(mPrivateDnsLock);
    std::map<unsigned, ThreadTracker> mPrivateDnsValidateThreads GUARDED_BY(mPrivateDnsLock);


    // For testing. The observer is notified of onValidationStateUpdate 1) when a validation is
    // For testing. The observer is notified of onValidationStateUpdate 1) when a validation is
    // about to begin or 2) when a validation finishes. If a validation finishes when in OFF mode
    // about to begin or 2) when a validation finishes. If a validation finishes when in OFF mode
+2 −0
Original line number Original line Diff line number Diff line
@@ -70,6 +70,8 @@ constexpr const char* validationStatusToString(Validation value) {
            return "in_process";
            return "in_process";
        case Validation::success:
        case Validation::success:
            return "success";
            return "success";
        case Validation::success_but_expired:
            return "success_but_expired";
        case Validation::fail:
        case Validation::fail:
            return "fail";
            return "fail";
        case Validation::unknown_server:
        case Validation::unknown_server:
+6 −0
Original line number Original line Diff line number Diff line
@@ -892,9 +892,15 @@ TEST_F(ServerTest, State) {
    checkEqual(s1, s2);
    checkEqual(s1, s2);
    s2.setValidationState(Validation::fail);
    s2.setValidationState(Validation::fail);
    checkEqual(s1, s2);
    checkEqual(s1, s2);
    s1.setActive(true);
    checkEqual(s1, s2);
    s2.setActive(false);
    checkEqual(s1, s2);


    EXPECT_EQ(s1.validationState(), Validation::success);
    EXPECT_EQ(s1.validationState(), Validation::success);
    EXPECT_EQ(s2.validationState(), Validation::fail);
    EXPECT_EQ(s2.validationState(), Validation::fail);
    EXPECT_TRUE(s1.active());
    EXPECT_FALSE(s2.active());
}
}


TEST(QueryMapTest, Basic) {
TEST(QueryMapTest, Basic) {
Loading