Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 2078b2bd authored by Ken Chen's avatar Ken Chen Committed by Gerrit Code Review
Browse files

Merge "Add private DNS validation thread tracker"

parents 90d4e68d 1d68adbb
Loading
Loading
Loading
Loading
+51 −0
Original line number Original line Diff line number Diff line
@@ -97,6 +97,7 @@ int PrivateDnsConfiguration::set(int32_t netId, uint32_t mark,
        mPrivateDnsModes[netId] = PrivateDnsMode::OFF;
        mPrivateDnsModes[netId] = PrivateDnsMode::OFF;
        mPrivateDnsTransports.erase(netId);
        mPrivateDnsTransports.erase(netId);
        resolv_stats_set_servers_for_dot(netId, {});
        resolv_stats_set_servers_for_dot(netId, {});
        mPrivateDnsValidateThreads.erase(netId);
        return 0;
        return 0;
    }
    }


@@ -155,6 +156,7 @@ void PrivateDnsConfiguration::clear(unsigned netId) {
    std::lock_guard guard(mPrivateDnsLock);
    std::lock_guard guard(mPrivateDnsLock);
    mPrivateDnsModes.erase(netId);
    mPrivateDnsModes.erase(netId);
    mPrivateDnsTransports.erase(netId);
    mPrivateDnsTransports.erase(netId);
    mPrivateDnsValidateThreads.erase(netId);
}
}


void PrivateDnsConfiguration::validatePrivateDnsProvider(const DnsTlsServer& server,
void PrivateDnsConfiguration::validatePrivateDnsProvider(const DnsTlsServer& server,
@@ -163,6 +165,10 @@ void PrivateDnsConfiguration::validatePrivateDnsProvider(const DnsTlsServer& ser
    tracker[server] = Validation::in_process;
    tracker[server] = Validation::in_process;
    LOG(DEBUG) << "Server " << addrToString(&server.ss) << " marked as in_process on netId "
    LOG(DEBUG) << "Server " << addrToString(&server.ss) << " marked as in_process on netId "
               << netId << ". Tracker now has size " << tracker.size();
               << netId << ". Tracker now has size " << tracker.size();
    // This judge must be after "tracker[server] = Validation::in_process;"
    if (!needValidateThread(server, netId)) {
        return;
    }


    // Note that capturing |server| and |netId| in this lambda create copies.
    // Note that capturing |server| and |netId| in this lambda create copies.
    std::thread validate_thread([this, server, netId, mark] {
    std::thread validate_thread([this, server, netId, mark] {
@@ -205,6 +211,7 @@ void PrivateDnsConfiguration::validatePrivateDnsProvider(const DnsTlsServer& ser
                break;
                break;
            }
            }
        }
        }
        this->cleanValidateThreadTracker(server, netId);
    });
    });
    validate_thread.detach();
    validate_thread.detach();
}
}
@@ -279,6 +286,50 @@ bool PrivateDnsConfiguration::recordPrivateDnsValidation(const DnsTlsServer& ser
    return reevaluationStatus;
    return reevaluationStatus;
}
}


bool PrivateDnsConfiguration::needValidateThread(const DnsTlsServer& server, unsigned netId)
        REQUIRES(mPrivateDnsLock) {
    // Create the thread tracker if it was not present
    auto threadPair = mPrivateDnsValidateThreads.find(netId);
    if (threadPair == mPrivateDnsValidateThreads.end()) {
        // No thread tracker yet for this netId.
        bool added;
        std::tie(threadPair, added) = mPrivateDnsValidateThreads.emplace(netId, ThreadTracker());
        if (!added) {
            LOG(ERROR) << "Memory error while needValidateThread for netId " << netId;
            return true;
        }
    }
    auto& threadTracker = threadPair->second;
    if (threadTracker.count(server)) {
        LOG(DEBUG) << "Server " << addrToString(&(server.ss))
                   << " validate thread is already running. Thread tracker now has size "
                   << threadTracker.size();
        return false;
    } else {
        threadTracker.insert(server);
        LOG(DEBUG) << "Server " << addrToString(&(server.ss))
                   << " validate thread is not running. Thread tracker now has size "
                   << threadTracker.size();
        return true;
    }
}

void PrivateDnsConfiguration::cleanValidateThreadTracker(const DnsTlsServer& server,
                                                         unsigned netId) {
    std::lock_guard<std::mutex> guard(mPrivateDnsLock);
    LOG(DEBUG) << "cleanValidateThreadTracker Server " << addrToString(&(server.ss))
               << " validate thread is stopped.";

    auto threadPair = mPrivateDnsValidateThreads.find(netId);
    if (threadPair != mPrivateDnsValidateThreads.end()) {
        auto& threadTracker = threadPair->second;
        threadTracker.erase(server);
        LOG(DEBUG) << "Server " << addrToString(&(server.ss))
                   << " validate thread is stopped. Thread tracker now has size "
                   << threadTracker.size();
    }
}

// Start validation for newly added servers as well as any servers that have
// Start validation for newly added servers as well as any servers that have
// landed in Validation::fail state. Note that servers that have failed
// landed in Validation::fail state. Note that servers that have failed
// multiple validation attempts but for which there is still a validating
// multiple validation attempts but for which there is still a validating
+5 −0
Original line number Original line Diff line number Diff line
@@ -62,12 +62,16 @@ class PrivateDnsConfiguration {


  private:
  private:
    typedef std::map<DnsTlsServer, Validation, AddressComparator> PrivateDnsTracker;
    typedef std::map<DnsTlsServer, Validation, AddressComparator> PrivateDnsTracker;
    typedef std::set<DnsTlsServer, AddressComparator> ThreadTracker;


    void validatePrivateDnsProvider(const DnsTlsServer& server, PrivateDnsTracker& tracker,
    void validatePrivateDnsProvider(const DnsTlsServer& server, PrivateDnsTracker& tracker,
                                    unsigned netId, uint32_t mark) REQUIRES(mPrivateDnsLock);
                                    unsigned netId, uint32_t mark) REQUIRES(mPrivateDnsLock);


    bool recordPrivateDnsValidation(const DnsTlsServer& server, unsigned netId, bool success);
    bool recordPrivateDnsValidation(const DnsTlsServer& server, unsigned netId, bool success);


    bool needValidateThread(const DnsTlsServer& server, unsigned netId) REQUIRES(mPrivateDnsLock);
    void cleanValidateThreadTracker(const DnsTlsServer& server, unsigned netId);

    // Start validation for newly added servers as well as any servers that have
    // Start validation for newly added servers as well as any servers that have
    // landed in Validation::fail state. Note that servers that have failed
    // landed in Validation::fail state. Note that servers that have failed
    // multiple validation attempts but for which there is still a validating
    // multiple validation attempts but for which there is still a validating
@@ -79,6 +83,7 @@ class PrivateDnsConfiguration {
    // Structure for tracking the validation status of servers on a specific netId.
    // Structure for tracking the validation status of servers on a specific netId.
    // Using the AddressComparator ensures at most one entry per IP address.
    // Using the AddressComparator ensures at most one entry per IP address.
    std::map<unsigned, PrivateDnsTracker> mPrivateDnsTransports GUARDED_BY(mPrivateDnsLock);
    std::map<unsigned, PrivateDnsTracker> mPrivateDnsTransports GUARDED_BY(mPrivateDnsLock);
    std::map<unsigned, ThreadTracker> mPrivateDnsValidateThreads GUARDED_BY(mPrivateDnsLock);
};
};


extern PrivateDnsConfiguration gPrivateDnsConfiguration;
extern PrivateDnsConfiguration gPrivateDnsConfiguration;