Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 66eaa2db authored by Mike Yu's avatar Mike Yu Committed by Automerger Merge Worker
Browse files

Merge "Some cleaup in PrivateDnsConfiguration::startValidation" am: d347b3ba am: 3c28b5ed

Original change: https://android-review.googlesource.com/c/platform/packages/modules/DnsResolver/+/1529599

MUST ONLY BE SUBMITTED BY AUTOMERGER

Change-Id: I261dbb2b6709261cf21f82d56f260f3dc1d289e1
parents 96dac191 3c28b5ed
Loading
Loading
Loading
Loading
+4 −4
Original line number Diff line number Diff line
@@ -158,8 +158,8 @@ DnsTlsTransport::~DnsTlsTransport() {
// static
// TODO: Use this function to preheat the session cache.
// That may require moving it to DnsTlsDispatcher.
bool DnsTlsTransport::validate(const DnsTlsServer& server, unsigned netid, uint32_t mark) {
    LOG(DEBUG) << "Beginning validation on " << netid;
bool DnsTlsTransport::validate(const DnsTlsServer& server, uint32_t mark) {
    LOG(DEBUG) << "Beginning validation with mark " << std::hex << mark;
    // Generate "<random>-dnsotls-ds.metric.gstatic.com", which we will lookup through |ss| in
    // order to prove that it is actually a working DNS over TLS server.
    static const char kDnsSafeChars[] =
@@ -195,7 +195,7 @@ bool DnsTlsTransport::validate(const DnsTlsServer& server, unsigned netid, uint3
    DnsTlsTransport transport(server, mark, &factory);
    auto r = transport.query(netdutils::Slice(query, qlen)).get();
    if (r.code != Response::success) {
        LOG(DEBUG) << "query failed";
        LOG(WARNING) << "query failed";
        return false;
    }

@@ -212,7 +212,7 @@ bool DnsTlsTransport::validate(const DnsTlsServer& server, unsigned netid, uint3
    }

    const int ancount = (recvbuf[6] << 8) | recvbuf[7];
    LOG(DEBUG) << netid << " answer count: " << ancount;
    LOG(DEBUG) << "answer count: " << ancount;

    // TODO: Further validate the response contents (check for valid AAAA record, ...).
    // Note that currently, integration tests rely on this function accepting a
+2 −2
Original line number Diff line number Diff line
@@ -52,10 +52,10 @@ class DnsTlsTransport : public IDnsTlsSocketObserver {
    // Given a |query|, this method sends it to the server and returns the result asynchronously.
    std::future<Result> query(const netdutils::Slice query) EXCLUDES(mLock);

    // Check that a given TLS server is fully working on the specified netid.
    // Check that a given TLS server is fully working with a specified mark.
    // This function is used in ResolverController to ensure that we don't enable DNS over TLS
    // on networks where it doesn't actually work.
    static bool validate(const DnsTlsServer& server, unsigned netid, uint32_t mark);
    static bool validate(const DnsTlsServer& server, uint32_t mark);

    int getConnectCounter() const EXCLUDES(mLock);

+18 −16
Original line number Diff line number Diff line
@@ -107,7 +107,7 @@ int PrivateDnsConfiguration::set(int32_t netId, uint32_t mark,

        if (needsValidation(server)) {
            updateServerState(identity, Validation::in_process, netId);
            startValidation(server, netId, mark);
            startValidation(server, netId);
        }
    }

@@ -168,14 +168,14 @@ bool PrivateDnsConfiguration::requestValidation(unsigned netId, const DnsTlsServ
    if (target.mark != mark) return false;

    updateServerState(identity, Validation::in_process, netId);
    startValidation(target, netId, mark);
    startValidation(target, netId);
    return true;
}

void PrivateDnsConfiguration::startValidation(const DnsTlsServer& server, unsigned netId,
                                              uint32_t mark) REQUIRES(mPrivateDnsLock) {
void PrivateDnsConfiguration::startValidation(const DnsTlsServer& server, unsigned netId)
        REQUIRES(mPrivateDnsLock) {
    // Note that capturing |server| and |netId| in this lambda create copies.
    std::thread validate_thread([this, server, netId, mark] {
    std::thread validate_thread([this, server, netId] {
        setThreadName(StringPrintf("TlsVerify_%u", netId).c_str());

        // cat /proc/sys/net/ipv4/tcp_syn_retries yields "6".
@@ -199,9 +199,10 @@ void PrivateDnsConfiguration::startValidation(const DnsTlsServer& server, unsign
        while (true) {
            // ::validate() is a blocking call that performs network operations.
            // It can take milliseconds to minutes, up to the SYN retry limit.
            LOG(WARNING) << "Validating DnsTlsServer on netId " << netId;
            const bool success = DnsTlsTransport::validate(server, netId, mark);
            LOG(DEBUG) << "validateDnsTlsServer returned " << success << " for "
            LOG(WARNING) << "Validating DnsTlsServer " << server.toIpString() << " with mark 0x"
                         << std::hex << server.mark;
            const bool success = DnsTlsTransport::validate(server, server.mark);
            LOG(WARNING) << "validateDnsTlsServer returned " << success << " for "
                         << server.toIpString();

            const bool needs_reeval = this->recordPrivateDnsValidation(server, netId, success);
@@ -231,14 +232,14 @@ bool PrivateDnsConfiguration::recordPrivateDnsValidation(const DnsTlsServer& ser
    auto netPair = mPrivateDnsTransports.find(netId);
    if (netPair == mPrivateDnsTransports.end()) {
        LOG(WARNING) << "netId " << netId << " was erased during private DNS validation";
        maybeNotifyObserver(identity.ip.toString(), Validation::fail, netId);
        notifyValidationStateUpdate(identity.ip.toString(), Validation::fail, netId);
        return DONT_REEVALUATE;
    }

    const auto mode = mPrivateDnsModes.find(netId);
    if (mode == mPrivateDnsModes.end()) {
        LOG(WARNING) << "netId " << netId << " has no private DNS validation mode";
        maybeNotifyObserver(identity.ip.toString(), Validation::fail, netId);
        notifyValidationStateUpdate(identity.ip.toString(), Validation::fail, netId);
        return DONT_REEVALUATE;
    }
    const bool modeDoesReevaluation = (mode->second == PrivateDnsMode::STRICT);
@@ -299,18 +300,18 @@ void PrivateDnsConfiguration::updateServerState(const ServerIdentity& identity,
                                                uint32_t netId) {
    auto netPair = mPrivateDnsTransports.find(netId);
    if (netPair == mPrivateDnsTransports.end()) {
        maybeNotifyObserver(identity.ip.toString(), Validation::fail, netId);
        notifyValidationStateUpdate(identity.ip.toString(), Validation::fail, netId);
        return;
    }

    auto& tracker = netPair->second;
    if (tracker.find(identity) == tracker.end()) {
        maybeNotifyObserver(identity.ip.toString(), Validation::fail, netId);
        notifyValidationStateUpdate(identity.ip.toString(), Validation::fail, netId);
        return;
    }

    tracker[identity].setValidationState(state);
    maybeNotifyObserver(identity.ip.toString(), state, netId);
    notifyValidationStateUpdate(identity.ip.toString(), state, netId);
}

bool PrivateDnsConfiguration::needsValidation(const DnsTlsServer& server) {
@@ -334,8 +335,9 @@ void PrivateDnsConfiguration::setObserver(Observer* observer) {
    mObserver = observer;
}

void PrivateDnsConfiguration::maybeNotifyObserver(const std::string& serverIp,
                                                  Validation validation, uint32_t netId) const {
void PrivateDnsConfiguration::notifyValidationStateUpdate(const std::string& serverIp,
                                                          Validation validation,
                                                          uint32_t netId) const {
    if (mObserver) {
        mObserver->onValidationStateUpdate(serverIp, validation, netId);
    }
+3 −4
Original line number Diff line number Diff line
@@ -94,8 +94,7 @@ class PrivateDnsConfiguration {

    PrivateDnsConfiguration() = default;

    void startValidation(const DnsTlsServer& server, unsigned netId, uint32_t mark)
            REQUIRES(mPrivateDnsLock);
    void startValidation(const DnsTlsServer& server, unsigned netId) REQUIRES(mPrivateDnsLock);

    bool recordPrivateDnsValidation(const DnsTlsServer& server, unsigned netId, bool success)
            EXCLUDES(mPrivateDnsLock);
@@ -131,7 +130,7 @@ class PrivateDnsConfiguration {
    };

    void setObserver(Observer* observer);
    void maybeNotifyObserver(const std::string& serverIp, Validation validation,
    void notifyValidationStateUpdate(const std::string& serverIp, Validation validation,
                                     uint32_t netId) const REQUIRES(mPrivateDnsLock);

    Observer* mObserver GUARDED_BY(mPrivateDnsLock);