Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit cf56d23d authored by Mike Yu's avatar Mike Yu
Browse files

Apply IPrivateDnsServer to PrivateDnsConfiguration

Change to use IPrivateDnsServer instead of DnsTlsServer as much as
possible.

Bug: 186177613
Test: cd packages/modules/DnsResolver && atest
Change-Id: I61081a0db5f53311f4335748f980467fcd4bd3e3
parent 690b19fa
Loading
Loading
Loading
Loading
+25 −23
Original line number Diff line number Diff line
@@ -70,11 +70,11 @@ int PrivateDnsConfiguration::set(int32_t netId, uint32_t mark,
        if (!parseServer(s.c_str(), &parsed)) {
            return -EINVAL;
        }
        DnsTlsServer server(parsed);
        server.name = name;
        server.certificate = caCert;
        server.mark = mark;
        tmp[ServerIdentity(server)] = server;
        auto server = std::make_unique<DnsTlsServer>(parsed);
        server->name = name;
        server->certificate = caCert;
        server->mark = mark;
        tmp[ServerIdentity(*server)] = std::move(server);
    }

    std::lock_guard guard(mPrivateDnsLock);
@@ -93,22 +93,22 @@ int PrivateDnsConfiguration::set(int32_t netId, uint32_t mark,
    auto& tracker = mPrivateDnsTransports[netId];

    // Add the servers if not contained in tracker.
    for (const auto& [identity, server] : tmp) {
    for (auto& [identity, server] : tmp) {
        if (tracker.find(identity) == tracker.end()) {
            tracker[identity] = server;
            tracker[identity] = std::move(server);
        }
    }

    for (auto& [identity, server] : tracker) {
        const bool active = tmp.find(identity) != tmp.end();
        server.setActive(active);
        server->setActive(active);

        // For simplicity, deem the validation result of inactive servers as unreliable.
        if (!server.active() && server.validationState() == Validation::success) {
        if (!server->active() && server->validationState() == Validation::success) {
            updateServerState(identity, Validation::success_but_expired, netId);
        }

        if (needsValidation(server)) {
        if (needsValidation(*server)) {
            updateServerState(identity, Validation::in_process, netId);
            startValidation(identity, netId, false);
        }
@@ -128,9 +128,11 @@ PrivateDnsStatus PrivateDnsConfiguration::getStatus(unsigned netId) const {
    const auto netPair = mPrivateDnsTransports.find(netId);
    if (netPair != mPrivateDnsTransports.end()) {
        for (const auto& [_, server] : netPair->second) {
            if (server.active()) {
                status.serversMap.emplace(server, server.validationState());
            if (server->isDot() && server->active()) {
                DnsTlsServer& dotServer = *static_cast<DnsTlsServer*>(server.get());
                status.serversMap.emplace(dotServer, server->validationState());
            }
            // TODO: also add DoH server to the map.
        }
    }

@@ -164,18 +166,18 @@ base::Result<void> PrivateDnsConfiguration::requestValidation(unsigned netId,
        return result.error();
    }

    const DnsTlsServer* target = result.value();
    const IPrivateDnsServer* server = result.value();

    if (!target->active()) return Errorf("Server is not active");
    if (!server->active()) return Errorf("Server is not active");

    if (target->validationState() != Validation::success) {
    if (server->validationState() != Validation::success) {
        return Errorf("Server validation state mismatched");
    }

    // Don't run the validation if |mark| (from android_net_context.dns_mark) is different.
    // This is to protect validation from running on unexpected marks.
    // Validation should be associated with a mark gotten by system permission.
    if (target->validationMark() != mark) return Errorf("Socket mark mismatched");
    if (server->validationMark() != mark) return Errorf("Socket mark mismatched");

    updateServerState(identity, Validation::in_process, netId);
    startValidation(identity, netId, true);
@@ -189,7 +191,7 @@ void PrivateDnsConfiguration::startValidation(const ServerIdentity& identity, un
    // TODO: consider moving these code to the thread.
    const auto result = getPrivateDnsLocked(identity, netId);
    if (!result.ok()) return;
    DnsTlsServer server = *result.value();
    DnsTlsServer server = *static_cast<const DnsTlsServer*>(result.value());

    std::thread validate_thread([this, identity, server, netId, isRevalidation] {
        setThreadName(StringPrintf("TlsVerify_%u", netId).c_str());
@@ -307,7 +309,7 @@ bool PrivateDnsConfiguration::recordPrivateDnsValidation(const ServerIdentity& i
                     << " was removed during private DNS validation";
        success = false;
        reevaluationStatus = DONT_REEVALUATE;
    } else if (!serverPair->second.active()) {
    } else if (!serverPair->second->active()) {
        LOG(WARNING) << "Server " << identity.sockaddr.ip().toString()
                     << " was removed from the configuration";
        success = false;
@@ -349,7 +351,7 @@ void PrivateDnsConfiguration::updateServerState(const ServerIdentity& identity,
    mPrivateDnsLog.push(std::move(record));
}

bool PrivateDnsConfiguration::needsValidation(const DnsTlsServer& server) {
bool PrivateDnsConfiguration::needsValidation(const IPrivateDnsServer& server) const {
    // The server is not expected to be used on the network.
    if (!server.active()) return false;

@@ -365,13 +367,13 @@ bool PrivateDnsConfiguration::needsValidation(const DnsTlsServer& server) {
    return false;
}

base::Result<DnsTlsServer*> PrivateDnsConfiguration::getPrivateDns(const ServerIdentity& identity,
                                                                   unsigned netId) {
base::Result<IPrivateDnsServer*> PrivateDnsConfiguration::getPrivateDns(
        const ServerIdentity& identity, unsigned netId) {
    std::lock_guard guard(mPrivateDnsLock);
    return getPrivateDnsLocked(identity, netId);
}

base::Result<DnsTlsServer*> PrivateDnsConfiguration::getPrivateDnsLocked(
base::Result<IPrivateDnsServer*> PrivateDnsConfiguration::getPrivateDnsLocked(
        const ServerIdentity& identity, unsigned netId) {
    auto netPair = mPrivateDnsTransports.find(netId);
    if (netPair == mPrivateDnsTransports.end()) {
@@ -384,7 +386,7 @@ base::Result<DnsTlsServer*> PrivateDnsConfiguration::getPrivateDnsLocked(
                      identity.provider);
    }

    return &iter->second;
    return iter->second.get();
}

void PrivateDnsConfiguration::setObserver(PrivateDnsValidationObserver* observer) {
+7 −8
Original line number Diff line number Diff line
@@ -58,8 +58,8 @@ class PrivateDnsConfiguration {
        const netdutils::IPSockAddr sockaddr;
        const std::string provider;

        explicit ServerIdentity(const DnsTlsServer& server)
            : sockaddr(netdutils::IPSockAddr::toIPSockAddr(server.ss)), provider(server.name) {}
        explicit ServerIdentity(const IPrivateDnsServer& server)
            : sockaddr(server.addr()), provider(server.provider()) {}

        bool operator<(const ServerIdentity& other) const {
            return std::tie(sockaddr, provider) < std::tie(other.sockaddr, other.provider);
@@ -92,7 +92,7 @@ class PrivateDnsConfiguration {
    void dump(netdutils::DumpWriter& dw) const;

  private:
    typedef std::map<ServerIdentity, DnsTlsServer> PrivateDnsTracker;
    typedef std::map<ServerIdentity, std::unique_ptr<IPrivateDnsServer>> PrivateDnsTracker;

    PrivateDnsConfiguration() = default;

@@ -110,18 +110,17 @@ class PrivateDnsConfiguration {
    // Decide if a validation for |server| is needed. Note that servers that have failed
    // multiple validation attempts but for which there is still a validating
    // thread running are marked as being Validation::in_process.
    // TODO: decouple the dependency of DnsTlsServer.
    bool needsValidation(const DnsTlsServer& server) REQUIRES(mPrivateDnsLock);
    bool needsValidation(const IPrivateDnsServer& server) const REQUIRES(mPrivateDnsLock);

    void updateServerState(const ServerIdentity& identity, Validation state, uint32_t netId)
            REQUIRES(mPrivateDnsLock);

    // For testing.
    base::Result<DnsTlsServer*> getPrivateDns(const ServerIdentity& identity, unsigned netId)
    base::Result<IPrivateDnsServer*> getPrivateDns(const ServerIdentity& identity, unsigned netId)
            EXCLUDES(mPrivateDnsLock);

    base::Result<DnsTlsServer*> getPrivateDnsLocked(const ServerIdentity& identity, unsigned netId)
            REQUIRES(mPrivateDnsLock);
    base::Result<IPrivateDnsServer*> getPrivateDnsLocked(const ServerIdentity& identity,
                                                         unsigned netId) REQUIRES(mPrivateDnsLock);

    mutable std::mutex mPrivateDnsLock;
    std::map<unsigned, PrivateDnsMode> mPrivateDnsModes GUARDED_BY(mPrivateDnsLock);