Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit e37ac228 authored by Luke Huang's avatar Luke Huang Committed by Automerger Merge Worker
Browse files

Link Rust DoH into DnsResolver with default off am: 2fe9c73f

Original change: https://android-review.googlesource.com/c/platform/packages/modules/DnsResolver/+/1756590

Change-Id: I0acaa411a337f1005ca5685a913362514e073708
parents 7e7ca876 2fe9c73f
Loading
Loading
Loading
Loading
+1 −0
Original line number Original line Diff line number Diff line
@@ -192,6 +192,7 @@ cc_library {
        "libbase",
        "libbase",
        "libcutils",
        "libcutils",
        "libnetdutils",
        "libnetdutils",
        "libdoh_ffi",
        "libprotobuf-cpp-lite",
        "libprotobuf-cpp-lite",
        "libstatslog_resolv",
        "libstatslog_resolv",
        "libstatspush_compat",
        "libstatspush_compat",
+1 −0
Original line number Original line Diff line number Diff line
@@ -81,6 +81,7 @@ DnsResolver::DnsResolver() {
    auto& dnsTlsDispatcher = DnsTlsDispatcher::getInstance();
    auto& dnsTlsDispatcher = DnsTlsDispatcher::getInstance();
    auto& privateDnsConfiguration = PrivateDnsConfiguration::getInstance();
    auto& privateDnsConfiguration = PrivateDnsConfiguration::getInstance();
    privateDnsConfiguration.setObserver(&dnsTlsDispatcher);
    privateDnsConfiguration.setObserver(&dnsTlsDispatcher);
    if (isDoHEnabled()) privateDnsConfiguration.initDoh();
}
}


bool DnsResolver::start() {
bool DnsResolver::start() {
+138 −3
Original line number Original line Diff line number Diff line
@@ -21,18 +21,22 @@
#include <android-base/format.h>
#include <android-base/format.h>
#include <android-base/logging.h>
#include <android-base/logging.h>
#include <android-base/stringprintf.h>
#include <android-base/stringprintf.h>
#include <netdutils/Slice.h>
#include <netdutils/ThreadUtil.h>
#include <netdutils/ThreadUtil.h>
#include <sys/socket.h>
#include <sys/socket.h>


#include "DnsTlsTransport.h"
#include "DnsTlsTransport.h"
#include "ResolverEventReporter.h"
#include "ResolverEventReporter.h"
#include "doh.h"
#include "netd_resolv/resolv.h"
#include "netd_resolv/resolv.h"
#include "resolv_private.h"
#include "util.h"
#include "util.h"


using aidl::android::net::resolv::aidl::IDnsResolverUnsolicitedEventListener;
using aidl::android::net::resolv::aidl::IDnsResolverUnsolicitedEventListener;
using aidl::android::net::resolv::aidl::PrivateDnsValidationEventParcel;
using aidl::android::net::resolv::aidl::PrivateDnsValidationEventParcel;
using android::base::StringPrintf;
using android::base::StringPrintf;
using android::netdutils::setThreadName;
using android::netdutils::setThreadName;
using android::netdutils::Slice;
using std::chrono::milliseconds;
using std::chrono::milliseconds;


namespace android {
namespace android {
@@ -238,9 +242,9 @@ void PrivateDnsConfiguration::startValidation(const ServerIdentity& identity, un
}
}


void PrivateDnsConfiguration::sendPrivateDnsValidationEvent(const ServerIdentity& identity,
void PrivateDnsConfiguration::sendPrivateDnsValidationEvent(const ServerIdentity& identity,
                                                            unsigned netId, bool success) {
                                                            unsigned netId, bool success) const {
    LOG(DEBUG) << "Sending validation " << (success ? "success" : "failure") << " event on netId "
    LOG(DEBUG) << "Sending validation " << (success ? "success" : "failure") << " event on netId "
               << netId << " for " << identity.sockaddr.ip().toString() << " with hostname {"
               << netId << " for " << identity.sockaddr.toString() << " with hostname {"
               << identity.provider << "}";
               << identity.provider << "}";
    // Send a validation event to NetdEventListenerService.
    // Send a validation event to NetdEventListenerService.
    const auto& listeners = ResolverEventReporter::getInstance().getListeners();
    const auto& listeners = ResolverEventReporter::getInstance().getListeners();
@@ -313,7 +317,9 @@ bool PrivateDnsConfiguration::recordPrivateDnsValidation(const ServerIdentity& i
    }
    }


    // Send private dns validation result to listeners.
    // Send private dns validation result to listeners.
    if (needReportEvent(netId, identity, success)) {
        sendPrivateDnsValidationEvent(identity, netId, success);
        sendPrivateDnsValidationEvent(identity, netId, success);
    }


    if (success) {
    if (success) {
        updateServerState(identity, Validation::success, netId);
        updateServerState(identity, Validation::success, netId);
@@ -411,5 +417,134 @@ void PrivateDnsConfiguration::dump(netdutils::DumpWriter& dw) const {
    dw.blankline();
    dw.blankline();
}
}


void PrivateDnsConfiguration::initDoh() {
    std::lock_guard guard(mPrivateDnsLock);
    initDohLocked();
}

void PrivateDnsConfiguration::initDohLocked() {
    if (mDohDispatcher != nullptr) return;
    mDohDispatcher = doh_dispatcher_new(
            [](uint32_t net_id, bool success, const char* ip_addr, const char* host) {
                android::net::PrivateDnsConfiguration::getInstance().onDohStatusUpdate(
                        net_id, success, ip_addr, host);
            });
}

int PrivateDnsConfiguration::setDoh(int32_t netId, uint32_t mark,
                                    const std::vector<std::string>& servers,
                                    const std::string& name, const std::string& caCert) {
    if (servers.empty()) return 0;
    LOG(DEBUG) << "PrivateDnsConfiguration::setDoh(" << netId << ", 0x" << std::hex << mark
               << std::dec << ", " << servers.size() << ", " << name << ")";
    std::lock_guard guard(mPrivateDnsLock);

    initDohLocked();

    // TODO: 1. Improve how to choose the server
    // TODO: 2. Support multiple servers
    for (const auto& entry : mAvailableDoHProviders) {
        const auto& doh = entry.getDohIdentity(servers, name);
        if (!doh.ok()) continue;

        auto it = mDohTracker.find(netId);
        // Skip if the same server already exists and its status == success.
        if (it != mDohTracker.end() && it->second == doh.value() &&
            it->second.status == Validation::success) {
            return 0;
        }
        const auto& [dohIt, _] = mDohTracker.insert_or_assign(netId, doh.value());
        const auto& dohId = dohIt->second;

        RecordEntry record(netId, {netdutils::IPSockAddr::toIPSockAddr(dohId.ipAddr, 443), name},
                           dohId.status);
        mPrivateDnsLog.push(std::move(record));
        return doh_net_new(mDohDispatcher, netId, dohId.httpsTemplate.c_str(), dohId.host.c_str(),
                           dohId.ipAddr.c_str(), mark, caCert.c_str(), 3000);
    }

    LOG(INFO) << __func__ << "No suitable DoH server found";
    return 0;
}

void PrivateDnsConfiguration::clearDoh(unsigned netId) {
    LOG(DEBUG) << "PrivateDnsConfiguration::clearDoh (" << netId << ")";
    std::lock_guard guard(mPrivateDnsLock);
    if (mDohDispatcher != nullptr) doh_net_delete(mDohDispatcher, netId);
    mDohTracker.erase(netId);
}

ssize_t PrivateDnsConfiguration::dohQuery(unsigned netId, const Slice query, const Slice answer,
                                          uint64_t timeoutMs) {
    {
        std::lock_guard guard(mPrivateDnsLock);
        // It's safe because mDohDispatcher won't be deleted after initializing.
        if (mDohDispatcher == nullptr) return RESULT_CAN_NOT_SEND;
    }
    return doh_query(mDohDispatcher, netId, query.base(), query.size(), answer.base(),
                     answer.size(), timeoutMs);
}

void PrivateDnsConfiguration::onDohStatusUpdate(uint32_t netId, bool success, const char* ipAddr,
                                                const char* host) {
    LOG(INFO) << __func__ << netId << ", " << success << ", " << ipAddr << ", " << host;
    std::lock_guard guard(mPrivateDnsLock);
    // Update the server status.
    auto it = mDohTracker.find(netId);
    if (it == mDohTracker.end() || (it->second.ipAddr != ipAddr && it->second.host != host)) {
        LOG(WARNING) << __func__ << "obsolete event";
        return;
    }
    Validation status = success ? Validation::success : Validation::fail;
    it->second.status = status;
    // Send the events to registered listeners.
    ServerIdentity identity = {netdutils::IPSockAddr::toIPSockAddr(ipAddr, 443), host};
    if (needReportEvent(netId, identity, success)) {
        sendPrivateDnsValidationEvent(identity, netId, success);
    }
    // Add log.
    RecordEntry record(netId, identity, status);
    mPrivateDnsLog.push(std::move(record));
}

bool PrivateDnsConfiguration::needReportEvent(uint32_t netId, ServerIdentity identity,
                                              bool success) const {
    // If the result is success or DoH is not enable, no concern to report the events.
    if (success || !isDoHEnabled()) return true;
    // If the result is failure, check another transport's status to determine if we should report
    // the event.
    switch (identity.sockaddr.port()) {
        // DoH
        case 443: {
            auto netPair = mPrivateDnsTransports.find(netId);
            if (netPair == mPrivateDnsTransports.end()) return true;
            for (const auto& [id, server] : netPair->second) {
                if ((identity.sockaddr.ip() == id.sockaddr.ip()) &&
                    (identity.sockaddr.port() != id.sockaddr.port()) &&
                    (server->validationState() == Validation::success)) {
                    LOG(DEBUG) << __func__
                               << "skip reporting DoH validation failure event, server addr: " +
                                          identity.sockaddr.ip().toString();
                    return false;
                }
            }
            break;
        }
        // DoT
        case 853: {
            auto it = mDohTracker.find(netId);
            if (it == mDohTracker.end()) return true;
            if (it->second == identity && it->second.status == Validation::success) {
                LOG(DEBUG) << __func__
                           << "skip reporting DoT validation failure event, server addr: " +
                                      identity.sockaddr.ip().toString();
                return false;
            }
            break;
        }
    }
    return true;
}

}  // namespace net
}  // namespace net
}  // namespace android
}  // namespace android
+81 −2
Original line number Original line Diff line number Diff line
@@ -16,20 +16,25 @@


#pragma once
#pragma once


#include <array>
#include <list>
#include <list>
#include <map>
#include <map>
#include <mutex>
#include <mutex>
#include <vector>
#include <vector>


#include <android-base/format.h>
#include <android-base/logging.h>
#include <android-base/result.h>
#include <android-base/result.h>
#include <android-base/thread_annotations.h>
#include <android-base/thread_annotations.h>
#include <netdutils/BackoffSequence.h>
#include <netdutils/BackoffSequence.h>
#include <netdutils/DumpWriter.h>
#include <netdutils/DumpWriter.h>
#include <netdutils/InternetAddresses.h>
#include <netdutils/InternetAddresses.h>
#include <netdutils/Slice.h>


#include "DnsTlsServer.h"
#include "DnsTlsServer.h"
#include "LockedQueue.h"
#include "LockedQueue.h"
#include "PrivateDnsValidationObserver.h"
#include "PrivateDnsValidationObserver.h"
#include "doh.h"


namespace android {
namespace android {
namespace net {
namespace net {
@@ -61,6 +66,8 @@ class PrivateDnsConfiguration {


        explicit ServerIdentity(const IPrivateDnsServer& server)
        explicit ServerIdentity(const IPrivateDnsServer& server)
            : sockaddr(server.addr()), provider(server.provider()) {}
            : sockaddr(server.addr()), provider(server.provider()) {}
        ServerIdentity(const netdutils::IPSockAddr& addr, const std::string& host)
            : sockaddr(addr), provider(host) {}


        bool operator<(const ServerIdentity& other) const {
        bool operator<(const ServerIdentity& other) const {
            return std::tie(sockaddr, provider) < std::tie(other.sockaddr, other.provider);
            return std::tie(sockaddr, provider) < std::tie(other.sockaddr, other.provider);
@@ -79,10 +86,20 @@ class PrivateDnsConfiguration {
    int set(int32_t netId, uint32_t mark, const std::vector<std::string>& servers,
    int set(int32_t netId, uint32_t mark, const std::vector<std::string>& servers,
            const std::string& name, const std::string& caCert) EXCLUDES(mPrivateDnsLock);
            const std::string& name, const std::string& caCert) EXCLUDES(mPrivateDnsLock);


    void initDoh() EXCLUDES(mPrivateDnsLock);

    int setDoh(int32_t netId, uint32_t mark, const std::vector<std::string>& servers,
               const std::string& name, const std::string& caCert) EXCLUDES(mPrivateDnsLock);

    PrivateDnsStatus getStatus(unsigned netId) const EXCLUDES(mPrivateDnsLock);
    PrivateDnsStatus getStatus(unsigned netId) const EXCLUDES(mPrivateDnsLock);


    void clear(unsigned netId) EXCLUDES(mPrivateDnsLock);
    void clear(unsigned netId) EXCLUDES(mPrivateDnsLock);


    void clearDoh(unsigned netId) EXCLUDES(mPrivateDnsLock);

    ssize_t dohQuery(unsigned netId, const netdutils::Slice query, const netdutils::Slice answer,
                     uint64_t timeoutMs) EXCLUDES(mPrivateDnsLock);

    // Request the server to be revalidated on a connection tagged with |mark|.
    // Request the server to be revalidated on a connection tagged with |mark|.
    // Returns a Result to indicate if the request is accepted.
    // Returns a Result to indicate if the request is accepted.
    base::Result<void> requestValidation(unsigned netId, const ServerIdentity& identity,
    base::Result<void> requestValidation(unsigned netId, const ServerIdentity& identity,
@@ -92,6 +109,9 @@ class PrivateDnsConfiguration {


    void dump(netdutils::DumpWriter& dw) const;
    void dump(netdutils::DumpWriter& dw) const;


    void onDohStatusUpdate(uint32_t netId, bool success, const char* ipAddr, const char* host)
            EXCLUDES(mPrivateDnsLock);

  private:
  private:
    typedef std::map<ServerIdentity, std::unique_ptr<IPrivateDnsServer>> PrivateDnsTracker;
    typedef std::map<ServerIdentity, std::unique_ptr<IPrivateDnsServer>> PrivateDnsTracker;


@@ -105,8 +125,8 @@ class PrivateDnsConfiguration {
    bool recordPrivateDnsValidation(const ServerIdentity& identity, unsigned netId, bool success,
    bool recordPrivateDnsValidation(const ServerIdentity& identity, unsigned netId, bool success,
                                    bool isRevalidation) EXCLUDES(mPrivateDnsLock);
                                    bool isRevalidation) EXCLUDES(mPrivateDnsLock);


    void sendPrivateDnsValidationEvent(const ServerIdentity& identity, unsigned netId, bool success)
    void sendPrivateDnsValidationEvent(const ServerIdentity& identity, unsigned netId,
            REQUIRES(mPrivateDnsLock);
                                       bool success) const REQUIRES(mPrivateDnsLock);


    // Decide if a validation for |server| is needed. Note that servers that have failed
    // Decide if a validation for |server| is needed. Note that servers that have failed
    // multiple validation attempts but for which there is still a validating
    // multiple validation attempts but for which there is still a validating
@@ -123,6 +143,8 @@ class PrivateDnsConfiguration {
    base::Result<IPrivateDnsServer*> getPrivateDnsLocked(const ServerIdentity& identity,
    base::Result<IPrivateDnsServer*> getPrivateDnsLocked(const ServerIdentity& identity,
                                                         unsigned netId) REQUIRES(mPrivateDnsLock);
                                                         unsigned netId) REQUIRES(mPrivateDnsLock);


    void initDohLocked() REQUIRES(mPrivateDnsLock);

    mutable std::mutex mPrivateDnsLock;
    mutable std::mutex mPrivateDnsLock;
    std::map<unsigned, PrivateDnsMode> mPrivateDnsModes GUARDED_BY(mPrivateDnsLock);
    std::map<unsigned, PrivateDnsMode> mPrivateDnsModes GUARDED_BY(mPrivateDnsLock);


@@ -135,9 +157,14 @@ class PrivateDnsConfiguration {
    void notifyValidationStateUpdate(const netdutils::IPSockAddr& sockaddr, Validation validation,
    void notifyValidationStateUpdate(const netdutils::IPSockAddr& sockaddr, Validation validation,
                                     uint32_t netId) const REQUIRES(mPrivateDnsLock);
                                     uint32_t netId) const REQUIRES(mPrivateDnsLock);


    bool needReportEvent(uint32_t netId, ServerIdentity identity, bool success) const
            REQUIRES(mPrivateDnsLock);

    // TODO: fix the reentrancy problem.
    // TODO: fix the reentrancy problem.
    PrivateDnsValidationObserver* mObserver GUARDED_BY(mPrivateDnsLock);
    PrivateDnsValidationObserver* mObserver GUARDED_BY(mPrivateDnsLock);


    DohDispatcher* mDohDispatcher;

    friend class PrivateDnsConfigurationTest;
    friend class PrivateDnsConfigurationTest;


    // It's not const because PrivateDnsConfigurationTest needs to override it.
    // It's not const because PrivateDnsConfigurationTest needs to override it.
@@ -147,6 +174,58 @@ class PrivateDnsConfiguration {
                    .withInitialRetransmissionTime(std::chrono::seconds(60))
                    .withInitialRetransmissionTime(std::chrono::seconds(60))
                    .withMaximumRetransmissionTime(std::chrono::seconds(3600));
                    .withMaximumRetransmissionTime(std::chrono::seconds(3600));


    struct DohIdentity {
        std::string httpsTemplate;
        std::string ipAddr;
        std::string host;
        Validation status;
        bool operator<(const DohIdentity& other) const {
            return std::tie(ipAddr, host) < std::tie(other.ipAddr, other.host);
        }
        bool operator==(const DohIdentity& other) const {
            return std::tie(ipAddr, host) == std::tie(other.ipAddr, other.host);
        }
        bool operator<(const ServerIdentity& other) const {
            std::string otherIp = other.sockaddr.ip().toString();
            return std::tie(ipAddr, host) < std::tie(otherIp, other.provider);
        }
        bool operator==(const ServerIdentity& other) const {
            std::string otherIp = other.sockaddr.ip().toString();
            return std::tie(ipAddr, host) == std::tie(otherIp, other.provider);
        }
    };

    struct DohProviderEntry {
        std::string provider;
        std::set<std::string> ips;
        std::string host;
        std::string httpsTemplate;
        base::Result<DohIdentity> getDohIdentity(const std::vector<std::string>& ips,
                                                 const std::string& host) const {
            if (!host.empty() && this->host != host) return Errorf("host {} not matched", host);
            for (const auto& ip : ips) {
                if (this->ips.find(ip) == this->ips.end()) continue;
                LOG(INFO) << fmt::format("getDohIdentity: {} {}", ip, host);
                // Only pick the first one for now.
                return DohIdentity{httpsTemplate, ip, host, Validation::in_process};
            }
            return Errorf("server not matched");
        };
    };

    // TODO: Move below DoH relevant stuff into Rust implementation.
    std::map<unsigned, DohIdentity> mDohTracker GUARDED_BY(mPrivateDnsLock);
    std::array<DohProviderEntry, 2> mAvailableDoHProviders = {{
            {"Google",
             {"2001:4860:4860::8888", "2001:4860:4860::8844", "8.8.8.8", "8.8.4.4"},
             "dns.google",
             "https://dns.google/dns-query"},
            {"Cloudflare",
             {"2606:4700::6810:f8f9", "2606:4700::6810:f9f9", "104.16.248.249", "104.16.249.249"},
             "cloudflare-dns.com",
             "https://cloudflare-dns.com/dns-query"},
    }};

    struct RecordEntry {
    struct RecordEntry {
        RecordEntry(uint32_t netId, const ServerIdentity& identity, Validation state)
        RecordEntry(uint32_t netId, const ServerIdentity& identity, Validation state)
            : netId(netId), serverIdentity(identity), state(state) {}
            : netId(netId), serverIdentity(identity), state(state) {}
+14 −3
Original line number Original line Diff line number Diff line
@@ -34,6 +34,7 @@
#include "ResolverStats.h"
#include "ResolverStats.h"
#include "resolv_cache.h"
#include "resolv_cache.h"
#include "stats.h"
#include "stats.h"
#include "util.h"


using aidl::android::net::ResolverParamsParcel;
using aidl::android::net::ResolverParamsParcel;
using aidl::android::net::resolv::aidl::IDnsResolverUnsolicitedEventListener;
using aidl::android::net::resolv::aidl::IDnsResolverUnsolicitedEventListener;
@@ -169,6 +170,7 @@ void ResolverController::destroyNetworkCache(unsigned netId) {
    resolv_delete_cache_for_net(netId);
    resolv_delete_cache_for_net(netId);
    mDns64Configuration.stopPrefixDiscovery(netId);
    mDns64Configuration.stopPrefixDiscovery(netId);
    PrivateDnsConfiguration::getInstance().clear(netId);
    PrivateDnsConfiguration::getInstance().clear(netId);
    if (isDoHEnabled()) PrivateDnsConfiguration::getInstance().clearDoh(netId);


    // Don't get this instance in PrivateDnsConfiguration. It's probe to deadlock.
    // Don't get this instance in PrivateDnsConfiguration. It's probe to deadlock.
    DnsTlsDispatcher::getInstance().forceCleanup(netId);
    DnsTlsDispatcher::getInstance().forceCleanup(netId);
@@ -206,8 +208,8 @@ int ResolverController::setResolverConfiguration(const ResolverParamsParcel& res
    // through a different network. For example, on a VPN with no DNS servers (Do53), if the VPN
    // through a different network. For example, on a VPN with no DNS servers (Do53), if the VPN
    // applies to UID 0, dns_mark is assigned for default network rathan the VPN. (note that it's
    // applies to UID 0, dns_mark is assigned for default network rathan the VPN. (note that it's
    // possible that a VPN doesn't have any DNS servers but DoT servers in DNS strict mode)
    // possible that a VPN doesn't have any DNS servers but DoT servers in DNS strict mode)
    const int err = PrivateDnsConfiguration::getInstance().set(
    int err = PrivateDnsConfiguration::getInstance().set(resolverParams.netId, netcontext.app_mark,
            resolverParams.netId, netcontext.app_mark, tlsServers, resolverParams.tlsName,
                                                         tlsServers, resolverParams.tlsName,
                                                         resolverParams.caCertificate);
                                                         resolverParams.caCertificate);


    if (err != 0) {
    if (err != 0) {
@@ -225,6 +227,15 @@ int ResolverController::setResolverConfiguration(const ResolverParamsParcel& res
        return err;
        return err;
    }
    }


    if (isDoHEnabled())
        err = PrivateDnsConfiguration::getInstance().setDoh(
                resolverParams.netId, netcontext.app_mark, tlsServers, resolverParams.tlsName,
                resolverParams.caCertificate);

    if (err != 0) {
        return err;
    }

    res_params res_params = {};
    res_params res_params = {};
    res_params.sample_validity = resolverParams.sampleValiditySeconds;
    res_params.sample_validity = resolverParams.sampleValiditySeconds;
    res_params.success_threshold = resolverParams.successThreshold;
    res_params.success_threshold = resolverParams.successThreshold;
Loading