Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 3334a5e0 authored by Mike Yu's avatar Mike Yu
Browse files

Add updateServerState method to PrivateDnsConfiguration

This change is expected to not change any behavior.

Bug: 79727473
Test: cd packages/modules/DnsResolver && atest
Change-Id: I8409f45fbd3f3af252bc33e5568af70f0c4eac4d
parent a60f48b6
Loading
Loading
Loading
Loading
+38 −26
Original line number Original line Diff line number Diff line
@@ -122,7 +122,27 @@ int PrivateDnsConfiguration::set(int32_t netId, uint32_t mark,
    // Add any new or changed servers to the tracker, and initiate async checks for them.
    // Add any new or changed servers to the tracker, and initiate async checks for them.
    for (const auto& server : tlsServers) {
    for (const auto& server : tlsServers) {
        if (needsValidation(tracker, server)) {
        if (needsValidation(tracker, server)) {
            validatePrivateDnsProvider(server, tracker, netId, mark);
            // This is temporarily required. Consider the following scenario, for example,
            //   Step 1) A DoTServer (s1) is set for the network. A validation (v1) for s1 starts.
            //           tracker has s1 alone.
            //   Step 2) The configuration changes. DotServer2 (s2) is set for the network. A
            //           validation (v2) for s2 starts. tracker has s2 alone.
            //   Step 3) Assume v1 and v2 somehow block. Now, if the configuration changes back to
            //           set s1, there won't be a v1' starts because needValidateThread() will
            //           return false.
            //
            // If we didn't add servers to tracker before needValidateThread(), tracker would
            // become empty. We would report s1 validation failed.
            tracker[server] = Validation::in_process;
            LOG(DEBUG) << "Server " << addrToString(&server.ss) << " marked as in_process on netId "
                       << netId << ". Tracker now has size " << tracker.size();
            // This judge must be after "tracker[server] = Validation::in_process;"
            if (!needValidateThread(server, netId)) {
                continue;
            }

            updateServerState(server, Validation::in_process, netId);
            startValidation(server, netId, mark);
        }
        }
    }
    }


@@ -155,19 +175,8 @@ void PrivateDnsConfiguration::clear(unsigned netId) {
    mPrivateDnsValidateThreads.erase(netId);
    mPrivateDnsValidateThreads.erase(netId);
}
}


void PrivateDnsConfiguration::validatePrivateDnsProvider(const DnsTlsServer& server,
void PrivateDnsConfiguration::startValidation(const DnsTlsServer& server, unsigned netId,
                                                         PrivateDnsTracker& tracker, unsigned netId,
                                              uint32_t mark) REQUIRES(mPrivateDnsLock) {
                                              uint32_t mark) REQUIRES(mPrivateDnsLock) {
    tracker[server] = Validation::in_process;
    LOG(DEBUG) << "Server " << addrToString(&server.ss) << " marked as in_process on netId "
               << netId << ". Tracker now has size " << tracker.size();
    // This judge must be after "tracker[server] = Validation::in_process;"
    if (!needValidateThread(server, netId)) {
        return;
    }

    maybeNotifyObserver(server, Validation::in_process, netId);

    // Note that capturing |server| and |netId| in this lambda create copies.
    // Note that capturing |server| and |netId| in this lambda create copies.
    std::thread validate_thread([this, server, netId, mark] {
    std::thread validate_thread([this, server, netId, mark] {
        setThreadName(StringPrintf("TlsVerify_%u", netId).c_str());
        setThreadName(StringPrintf("TlsVerify_%u", netId).c_str());
@@ -273,18 +282,14 @@ bool PrivateDnsConfiguration::recordPrivateDnsValidation(const DnsTlsServer& ser
    }
    }


    if (success) {
    if (success) {
        tracker[server] = Validation::success;
        updateServerState(server, Validation::success, netId);
        maybeNotifyObserver(server, Validation::success, netId);
    } else {
    } else {
        // Validation failure is expected if a user is on a captive portal.
        // Validation failure is expected if a user is on a captive portal.
        // TODO: Trigger a second validation attempt after captive portal login
        // TODO: Trigger a second validation attempt after captive portal login
        // succeeds.
        // succeeds.
        tracker[server] = (reevaluationStatus == NEEDS_REEVALUATION) ? Validation::in_process
        const auto result = (reevaluationStatus == NEEDS_REEVALUATION) ? Validation::in_process
                                                                       : Validation::fail;
                                                                       : Validation::fail;
        maybeNotifyObserver(server,
        updateServerState(server, result, netId);
                            (reevaluationStatus == NEEDS_REEVALUATION) ? Validation::in_process
                                                                       : Validation::fail,
                            netId);
    }
    }
    LOG(WARNING) << "Validation " << (success ? "success" : "failed");
    LOG(WARNING) << "Validation " << (success ? "success" : "failed");


@@ -319,6 +324,17 @@ bool PrivateDnsConfiguration::needValidateThread(const DnsTlsServer& server, uns
    }
    }
}
}


void PrivateDnsConfiguration::updateServerState(const DnsTlsServer& server, Validation state,
                                                uint32_t netId) {
    auto netPair = mPrivateDnsTransports.find(netId);
    if (netPair != mPrivateDnsTransports.end()) {
        auto& tracker = netPair->second;
        tracker[server] = state;
    }

    maybeNotifyObserver(server, state, netId);
}

void PrivateDnsConfiguration::cleanValidateThreadTracker(const DnsTlsServer& server,
void PrivateDnsConfiguration::cleanValidateThreadTracker(const DnsTlsServer& server,
                                                         unsigned netId) {
                                                         unsigned netId) {
    std::lock_guard<std::mutex> guard(mPrivateDnsLock);
    std::lock_guard<std::mutex> guard(mPrivateDnsLock);
@@ -335,10 +351,6 @@ void PrivateDnsConfiguration::cleanValidateThreadTracker(const DnsTlsServer& ser
    }
    }
}
}


// Start validation for newly added servers as well as any servers that have
// landed in Validation::fail state. Note that servers that have failed
// multiple validation attempts but for which there is still a validating
// thread running are marked as being Validation::in_process.
bool PrivateDnsConfiguration::needsValidation(const PrivateDnsTracker& tracker,
bool PrivateDnsConfiguration::needsValidation(const PrivateDnsTracker& tracker,
                                              const DnsTlsServer& server) {
                                              const DnsTlsServer& server) {
    const auto& iter = tracker.find(server);
    const auto& iter = tracker.find(server);
+14 −5
Original line number Original line Diff line number Diff line
@@ -71,19 +71,25 @@ class PrivateDnsConfiguration {


    PrivateDnsConfiguration() = default;
    PrivateDnsConfiguration() = default;


    void validatePrivateDnsProvider(const DnsTlsServer& server, PrivateDnsTracker& tracker,
    void startValidation(const DnsTlsServer& server, unsigned netId, uint32_t mark)
                                    unsigned netId, uint32_t mark) REQUIRES(mPrivateDnsLock);
            REQUIRES(mPrivateDnsLock);


    bool recordPrivateDnsValidation(const DnsTlsServer& server, unsigned netId, bool success);
    bool recordPrivateDnsValidation(const DnsTlsServer& server, unsigned netId, bool success)
            EXCLUDES(mPrivateDnsLock);


    bool needValidateThread(const DnsTlsServer& server, unsigned netId) REQUIRES(mPrivateDnsLock);
    bool needValidateThread(const DnsTlsServer& server, unsigned netId) REQUIRES(mPrivateDnsLock);
    void cleanValidateThreadTracker(const DnsTlsServer& server, unsigned netId);
    void cleanValidateThreadTracker(const DnsTlsServer& server, unsigned netId)
            EXCLUDES(mPrivateDnsLock);


    // Start validation for newly added servers as well as any servers that have
    // Start validation for newly added servers as well as any servers that have
    // landed in Validation::fail state. Note that servers that have failed
    // landed in Validation::fail state. Note that servers that have failed
    // multiple validation attempts but for which there is still a validating
    // multiple validation attempts but for which there is still a validating
    // thread running are marked as being Validation::in_process.
    // thread running are marked as being Validation::in_process.
    bool needsValidation(const PrivateDnsTracker& tracker, const DnsTlsServer& server);
    bool needsValidation(const PrivateDnsTracker& tracker, const DnsTlsServer& server)
            REQUIRES(mPrivateDnsLock);

    void updateServerState(const DnsTlsServer& server, Validation state, uint32_t netId)
            REQUIRES(mPrivateDnsLock);


    std::mutex mPrivateDnsLock;
    std::mutex mPrivateDnsLock;
    std::map<unsigned, PrivateDnsMode> mPrivateDnsModes GUARDED_BY(mPrivateDnsLock);
    std::map<unsigned, PrivateDnsMode> mPrivateDnsModes GUARDED_BY(mPrivateDnsLock);
@@ -94,6 +100,9 @@ class PrivateDnsConfiguration {


    // For testing. The observer is notified of onValidationStateUpdate 1) when a validation is
    // For testing. The observer is notified of onValidationStateUpdate 1) when a validation is
    // about to begin or 2) when a validation finishes.
    // about to begin or 2) when a validation finishes.
    // WARNING: The Observer is notified while the lock is being held. Be careful not to call any
    // method of PrivateDnsConfiguration from the observer.
    // TODO: fix the reentrancy problem.
    class Observer {
    class Observer {
      public:
      public:
        virtual ~Observer(){};
        virtual ~Observer(){};