Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 394d1722 authored by Mike Yu's avatar Mike Yu
Browse files

Change setDoh and clearDoh to private method

Now setDoh() is called in PrivateDnsConfiguration::set() and
clearDoh() is called in PrivateDnsConfiguration::clear().

Also add setDot() and clearDot().

Bug: 239659682
Test: atest
Change-Id: I10bbba46474e1e45fffbb96460c14c9814924711
parent c1d47171
Loading
Loading
Loading
Loading
+60 −29
Original line number Diff line number Diff line
@@ -46,40 +46,66 @@ using std::chrono::milliseconds;
namespace android {
namespace net {

namespace {

bool ensureNoInvalidIp(const std::vector<std::string>& servers) {
    IPAddress ip;
    for (const auto& s : servers) {
        if (!IPAddress::forString(s, &ip)) {
            LOG(WARNING) << "Invalid IP address: " << s;
            return false;
        }
    }
    return true;
}

}  // namespace

int PrivateDnsConfiguration::set(int32_t netId, uint32_t mark,
                                 const std::vector<std::string>& servers, const std::string& name,
                                 const std::string& caCert) {
    LOG(DEBUG) << "PrivateDnsConfiguration::set(" << netId << ", 0x" << std::hex << mark << std::dec
               << ", " << servers.size() << ", " << name << ")";

    // Parse the list of servers that has been passed in
    PrivateDnsTracker tmp;
    for (const auto& s : servers) {
        IPAddress ip;
        if (!IPAddress::forString(s, &ip)) {
            LOG(WARNING) << "Failed to parse server address (" << s << ")";
            return -EINVAL;
        }

        auto server = std::make_unique<DnsTlsServer>(ip);
        server->name = name;
        server->certificate = caCert;
        server->mark = mark;
        tmp[ServerIdentity(*server)] = std::move(server);
    }
    if (!ensureNoInvalidIp(servers)) return -EINVAL;

    std::lock_guard guard(mPrivateDnsLock);
    if (!name.empty()) {
        mPrivateDnsModes[netId] = PrivateDnsMode::STRICT;
    } else if (!tmp.empty()) {
    } else if (!servers.empty()) {
        mPrivateDnsModes[netId] = PrivateDnsMode::OPPORTUNISTIC;
    } else {
        mPrivateDnsModes[netId] = PrivateDnsMode::OFF;
        mPrivateDnsTransports.erase(netId);
        clearDot(netId);
        clearDoh(netId);
        return 0;
        // TODO: signal validation threads to stop.
    }

    if (int n = setDot(netId, mark, servers, name, caCert); n != 0) {
        return n;
    }
    if (isDoHEnabled()) {
        return setDoh(netId, mark, servers, name, caCert);
    }

    return 0;
}

int PrivateDnsConfiguration::setDot(int32_t netId, uint32_t mark,
                                    const std::vector<std::string>& servers,
                                    const std::string& name, const std::string& caCert) {
    // Parse the list of servers that has been passed in
    PrivateDnsTracker tmp;
    for (const auto& s : servers) {
        // The IP addresses are guaranteed to be valid.
        auto server = std::make_unique<DnsTlsServer>(IPAddress::forString(s));
        server->name = name;
        server->certificate = caCert;
        server->mark = mark;
        tmp[ServerIdentity(*server)] = std::move(server);
    }

    // Create the tracker if it was not present
    auto& tracker = mPrivateDnsTransports[netId];

@@ -105,9 +131,19 @@ int PrivateDnsConfiguration::set(int32_t netId, uint32_t mark,
        }
    }

    if (int n = resolv_stats_set_addrs(netId, PROTO_DOT, servers, kDotPort); n != 0) {
        LOG(WARNING) << "Failed to set DoT stats";
        return n;
    }

    return 0;
}

void PrivateDnsConfiguration::clearDot(int32_t netId) {
    mPrivateDnsTransports.erase(netId);
    resolv_stats_set_addrs(netId, PROTO_DOT, {}, kDotPort);
}

PrivateDnsStatus PrivateDnsConfiguration::getStatus(unsigned netId) const {
    PrivateDnsStatus status{
            .mode = PrivateDnsMode::OFF,
@@ -144,7 +180,8 @@ void PrivateDnsConfiguration::clear(unsigned netId) {
    LOG(DEBUG) << "PrivateDnsConfiguration::clear(" << netId << ")";
    std::lock_guard guard(mPrivateDnsLock);
    mPrivateDnsModes.erase(netId);
    mPrivateDnsTransports.erase(netId);
    clearDot(netId);
    clearDoh(netId);

    // Notify the relevant private DNS validations, if they are waiting, to finish.
    mCv.notify_all();
@@ -451,9 +488,8 @@ int PrivateDnsConfiguration::setDoh(int32_t netId, uint32_t mark,
                                    const std::string& name, const std::string& caCert) {
    LOG(DEBUG) << "PrivateDnsConfiguration::setDoh(" << netId << ", 0x" << std::hex << mark
               << std::dec << ", " << servers.size() << ", " << name << ")";
    std::lock_guard guard(mPrivateDnsLock);
    if (servers.empty()) {
        clearDohLocked(netId);
        clearDoh(netId);
        return 0;
    }

@@ -522,22 +558,17 @@ int PrivateDnsConfiguration::setDoh(int32_t netId, uint32_t mark,
    }

    LOG(INFO) << __func__ << ": No suitable DoH server found";
    clearDohLocked(netId);
    clearDoh(netId);
    return 0;
}

void PrivateDnsConfiguration::clearDohLocked(unsigned netId) {
    LOG(DEBUG) << "PrivateDnsConfiguration::clearDohLocked (" << netId << ")";
void PrivateDnsConfiguration::clearDoh(unsigned netId) {
    LOG(DEBUG) << "PrivateDnsConfiguration::clearDoh (" << netId << ")";
    if (mDohDispatcher != nullptr) doh_net_delete(mDohDispatcher, netId);
    mDohTracker.erase(netId);
    resolv_stats_set_addrs(netId, PROTO_DOH, {}, kDohPort);
}

void PrivateDnsConfiguration::clearDoh(unsigned netId) {
    std::lock_guard guard(mPrivateDnsLock);
    clearDohLocked(netId);
}

ssize_t PrivateDnsConfiguration::dohQuery(unsigned netId, const Slice query, const Slice answer,
                                          uint64_t timeoutMs) {
    {
+8 −6
Original line number Diff line number Diff line
@@ -105,15 +105,10 @@ class PrivateDnsConfiguration {

    void initDoh() EXCLUDES(mPrivateDnsLock);

    int setDoh(int32_t netId, uint32_t mark, const std::vector<std::string>& servers,
               const std::string& name, const std::string& caCert) EXCLUDES(mPrivateDnsLock);

    PrivateDnsStatus getStatus(unsigned netId) const EXCLUDES(mPrivateDnsLock);

    void clear(unsigned netId) EXCLUDES(mPrivateDnsLock);

    void clearDoh(unsigned netId) EXCLUDES(mPrivateDnsLock);

    ssize_t dohQuery(unsigned netId, const netdutils::Slice query, const netdutils::Slice answer,
                     uint64_t timeoutMs) EXCLUDES(mPrivateDnsLock);

@@ -137,6 +132,11 @@ class PrivateDnsConfiguration {

    PrivateDnsConfiguration() = default;

    int setDot(int32_t netId, uint32_t mark, const std::vector<std::string>& servers,
               const std::string& name, const std::string& caCert) REQUIRES(mPrivateDnsLock);

    void clearDot(int32_t netId) REQUIRES(mPrivateDnsLock);

    // Launchs a thread to run the validation for |server| on the network |netId|.
    // |isRevalidation| is true if this call is due to a revalidation request.
    void startValidation(const ServerIdentity& identity, unsigned netId, bool isRevalidation)
@@ -164,7 +164,9 @@ class PrivateDnsConfiguration {
                                                         unsigned netId) REQUIRES(mPrivateDnsLock);

    void initDohLocked() REQUIRES(mPrivateDnsLock);
    void clearDohLocked(unsigned netId) REQUIRES(mPrivateDnsLock);
    int setDoh(int32_t netId, uint32_t mark, const std::vector<std::string>& servers,
               const std::string& name, const std::string& caCert) REQUIRES(mPrivateDnsLock);
    void clearDoh(unsigned netId) REQUIRES(mPrivateDnsLock);

    mutable std::mutex mPrivateDnsLock;
    std::map<unsigned, PrivateDnsMode> mPrivateDnsModes GUARDED_BY(mPrivateDnsLock);
+6 −0
Original line number Diff line number Diff line
@@ -19,6 +19,7 @@
#include <netdutils/NetNativeTestBase.h>

#include "PrivateDnsConfiguration.h"
#include "resolv_cache.h"
#include "tests/dns_responder/dns_responder.h"
#include "tests/dns_responder/dns_tls_frontend.h"
#include "tests/resolv_test_utils.h"
@@ -74,8 +75,13 @@ class PrivateDnsConfigurationTest : public NetNativeTestBase {
                    std::lock_guard guard(mObserver.lock);
                    mObserver.serverStateMap[server] = validation;
                });

        // Create a NetConfig for stats.
        EXPECT_EQ(0, resolv_create_cache_for_net(kNetId));
    }

    void TearDown() { resolv_delete_cache_for_net(kNetId); }

  protected:
    class MockObserver : public PrivateDnsValidationObserver {
      public:
+0 −15
Original line number Diff line number Diff line
@@ -169,7 +169,6 @@ void ResolverController::destroyNetworkCache(unsigned netId) {
    resolv_delete_cache_for_net(netId);
    mDns64Configuration.stopPrefixDiscovery(netId);
    PrivateDnsConfiguration::getInstance().clear(netId);
    if (isDoHEnabled()) PrivateDnsConfiguration::getInstance().clearDoh(netId);

    // Don't get this instance in PrivateDnsConfiguration. It's probe to deadlock.
    DnsTlsDispatcher::getInstance().forceCleanup(netId);
@@ -215,11 +214,6 @@ int ResolverController::setResolverConfiguration(const ResolverParamsParcel& res
        return err;
    }

    if (err = resolv_stats_set_addrs(resolverParams.netId, PROTO_DOT, tlsServers, 853);
        err != 0) {
        return err;
    }

    if (is_mdns_supported_transport_types(resolverParams.transportTypes)) {
        if (err = resolv_stats_set_addrs(resolverParams.netId, PROTO_MDNS,
                                         {"ff02::fb", "224.0.0.251"}, 5353);
@@ -228,15 +222,6 @@ int ResolverController::setResolverConfiguration(const ResolverParamsParcel& res
        }
    }

    if (isDoHEnabled()) {
        err = privateDnsConfiguration.setDoh(resolverParams.netId, netcontext.app_mark, tlsServers,
                                             resolverParams.tlsName, resolverParams.caCertificate);

        if (err != 0) {
            return err;
        }
    }

    res_params res_params = {};
    res_params.sample_validity = resolverParams.sampleValiditySeconds;
    res_params.success_threshold = resolverParams.successThreshold;