Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 78d9ed12 authored by Mike Yu's avatar Mike Yu Committed by Automerger Merge Worker
Browse files

Apply IPrivateDnsServer to PrivateDnsConfiguration am: cf56d23d am: 4743607e am: f87a3001

Original change: https://android-review.googlesource.com/c/platform/packages/modules/DnsResolver/+/1690509

Change-Id: I75014cd118d4f74431d42dad2571455bb39733af
parents ba5b5f83 f87a3001
Loading
Loading
Loading
Loading
+25 −23
Original line number Diff line number Diff line
@@ -70,11 +70,11 @@ int PrivateDnsConfiguration::set(int32_t netId, uint32_t mark,
        if (!parseServer(s.c_str(), &parsed)) {
            return -EINVAL;
        }
        DnsTlsServer server(parsed);
        server.name = name;
        server.certificate = caCert;
        server.mark = mark;
        tmp[ServerIdentity(server)] = server;
        auto server = std::make_unique<DnsTlsServer>(parsed);
        server->name = name;
        server->certificate = caCert;
        server->mark = mark;
        tmp[ServerIdentity(*server)] = std::move(server);
    }

    std::lock_guard guard(mPrivateDnsLock);
@@ -93,22 +93,22 @@ int PrivateDnsConfiguration::set(int32_t netId, uint32_t mark,
    auto& tracker = mPrivateDnsTransports[netId];

    // Add the servers if not contained in tracker.
    for (const auto& [identity, server] : tmp) {
    for (auto& [identity, server] : tmp) {
        if (tracker.find(identity) == tracker.end()) {
            tracker[identity] = server;
            tracker[identity] = std::move(server);
        }
    }

    for (auto& [identity, server] : tracker) {
        const bool active = tmp.find(identity) != tmp.end();
        server.setActive(active);
        server->setActive(active);

        // For simplicity, deem the validation result of inactive servers as unreliable.
        if (!server.active() && server.validationState() == Validation::success) {
        if (!server->active() && server->validationState() == Validation::success) {
            updateServerState(identity, Validation::success_but_expired, netId);
        }

        if (needsValidation(server)) {
        if (needsValidation(*server)) {
            updateServerState(identity, Validation::in_process, netId);
            startValidation(identity, netId, false);
        }
@@ -128,9 +128,11 @@ PrivateDnsStatus PrivateDnsConfiguration::getStatus(unsigned netId) const {
    const auto netPair = mPrivateDnsTransports.find(netId);
    if (netPair != mPrivateDnsTransports.end()) {
        for (const auto& [_, server] : netPair->second) {
            if (server.active()) {
                status.serversMap.emplace(server, server.validationState());
            if (server->isDot() && server->active()) {
                DnsTlsServer& dotServer = *static_cast<DnsTlsServer*>(server.get());
                status.serversMap.emplace(dotServer, server->validationState());
            }
            // TODO: also add DoH server to the map.
        }
    }

@@ -164,18 +166,18 @@ base::Result<void> PrivateDnsConfiguration::requestValidation(unsigned netId,
        return result.error();
    }

    const DnsTlsServer* target = result.value();
    const IPrivateDnsServer* server = result.value();

    if (!target->active()) return Errorf("Server is not active");
    if (!server->active()) return Errorf("Server is not active");

    if (target->validationState() != Validation::success) {
    if (server->validationState() != Validation::success) {
        return Errorf("Server validation state mismatched");
    }

    // Don't run the validation if |mark| (from android_net_context.dns_mark) is different.
    // This is to protect validation from running on unexpected marks.
    // Validation should be associated with a mark gotten by system permission.
    if (target->validationMark() != mark) return Errorf("Socket mark mismatched");
    if (server->validationMark() != mark) return Errorf("Socket mark mismatched");

    updateServerState(identity, Validation::in_process, netId);
    startValidation(identity, netId, true);
@@ -189,7 +191,7 @@ void PrivateDnsConfiguration::startValidation(const ServerIdentity& identity, un
    // TODO: consider moving these code to the thread.
    const auto result = getPrivateDnsLocked(identity, netId);
    if (!result.ok()) return;
    DnsTlsServer server = *result.value();
    DnsTlsServer server = *static_cast<const DnsTlsServer*>(result.value());

    std::thread validate_thread([this, identity, server, netId, isRevalidation] {
        setThreadName(StringPrintf("TlsVerify_%u", netId).c_str());
@@ -307,7 +309,7 @@ bool PrivateDnsConfiguration::recordPrivateDnsValidation(const ServerIdentity& i
                     << " was removed during private DNS validation";
        success = false;
        reevaluationStatus = DONT_REEVALUATE;
    } else if (!serverPair->second.active()) {
    } else if (!serverPair->second->active()) {
        LOG(WARNING) << "Server " << identity.sockaddr.ip().toString()
                     << " was removed from the configuration";
        success = false;
@@ -349,7 +351,7 @@ void PrivateDnsConfiguration::updateServerState(const ServerIdentity& identity,
    mPrivateDnsLog.push(std::move(record));
}

bool PrivateDnsConfiguration::needsValidation(const DnsTlsServer& server) {
bool PrivateDnsConfiguration::needsValidation(const IPrivateDnsServer& server) const {
    // The server is not expected to be used on the network.
    if (!server.active()) return false;

@@ -365,13 +367,13 @@ bool PrivateDnsConfiguration::needsValidation(const DnsTlsServer& server) {
    return false;
}

base::Result<DnsTlsServer*> PrivateDnsConfiguration::getPrivateDns(const ServerIdentity& identity,
                                                                   unsigned netId) {
base::Result<IPrivateDnsServer*> PrivateDnsConfiguration::getPrivateDns(
        const ServerIdentity& identity, unsigned netId) {
    std::lock_guard guard(mPrivateDnsLock);
    return getPrivateDnsLocked(identity, netId);
}

base::Result<DnsTlsServer*> PrivateDnsConfiguration::getPrivateDnsLocked(
base::Result<IPrivateDnsServer*> PrivateDnsConfiguration::getPrivateDnsLocked(
        const ServerIdentity& identity, unsigned netId) {
    auto netPair = mPrivateDnsTransports.find(netId);
    if (netPair == mPrivateDnsTransports.end()) {
@@ -384,7 +386,7 @@ base::Result<DnsTlsServer*> PrivateDnsConfiguration::getPrivateDnsLocked(
                      identity.provider);
    }

    return &iter->second;
    return iter->second.get();
}

void PrivateDnsConfiguration::setObserver(PrivateDnsValidationObserver* observer) {
+7 −8
Original line number Diff line number Diff line
@@ -58,8 +58,8 @@ class PrivateDnsConfiguration {
        const netdutils::IPSockAddr sockaddr;
        const std::string provider;

        explicit ServerIdentity(const DnsTlsServer& server)
            : sockaddr(netdutils::IPSockAddr::toIPSockAddr(server.ss)), provider(server.name) {}
        explicit ServerIdentity(const IPrivateDnsServer& server)
            : sockaddr(server.addr()), provider(server.provider()) {}

        bool operator<(const ServerIdentity& other) const {
            return std::tie(sockaddr, provider) < std::tie(other.sockaddr, other.provider);
@@ -92,7 +92,7 @@ class PrivateDnsConfiguration {
    void dump(netdutils::DumpWriter& dw) const;

  private:
    typedef std::map<ServerIdentity, DnsTlsServer> PrivateDnsTracker;
    typedef std::map<ServerIdentity, std::unique_ptr<IPrivateDnsServer>> PrivateDnsTracker;

    PrivateDnsConfiguration() = default;

@@ -110,18 +110,17 @@ class PrivateDnsConfiguration {
    // Decide if a validation for |server| is needed. Note that servers that have failed
    // multiple validation attempts but for which there is still a validating
    // thread running are marked as being Validation::in_process.
    // TODO: decouple the dependency of DnsTlsServer.
    bool needsValidation(const DnsTlsServer& server) REQUIRES(mPrivateDnsLock);
    bool needsValidation(const IPrivateDnsServer& server) const REQUIRES(mPrivateDnsLock);

    void updateServerState(const ServerIdentity& identity, Validation state, uint32_t netId)
            REQUIRES(mPrivateDnsLock);

    // For testing.
    base::Result<DnsTlsServer*> getPrivateDns(const ServerIdentity& identity, unsigned netId)
    base::Result<IPrivateDnsServer*> getPrivateDns(const ServerIdentity& identity, unsigned netId)
            EXCLUDES(mPrivateDnsLock);

    base::Result<DnsTlsServer*> getPrivateDnsLocked(const ServerIdentity& identity, unsigned netId)
            REQUIRES(mPrivateDnsLock);
    base::Result<IPrivateDnsServer*> getPrivateDnsLocked(const ServerIdentity& identity,
                                                         unsigned netId) REQUIRES(mPrivateDnsLock);

    mutable std::mutex mPrivateDnsLock;
    std::map<unsigned, PrivateDnsMode> mPrivateDnsModes GUARDED_BY(mPrivateDnsLock);