Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit ad96ef83 authored by Mike Yu's avatar Mike Yu
Browse files

Change most of PrivateDnsConfiguration methods to use ServerIdentity

Passing DnsTlsServer might be confusing because it's not straightforward
to know if a DnsTlsServer is a copy or onwed by PrivateDnsConfiguration.

This CL changes most of the methods to use ServerIdentity. The methods
can then get the corresponding DnsTlsServer by the new added method
getPrivateDns().

Bug: 186177613
Test: cd packages/modules/DnsResolver && atest
Change-Id: Ied4a4ee026862cd2c596586499cbfa7646eaaf2a
parent 05bf348e
Loading
Loading
Loading
Loading
+2 −2
Original line number Original line Diff line number Diff line
@@ -214,8 +214,8 @@ DnsTlsTransport::Response DnsTlsDispatcher::query(const DnsTlsServer& server, un
            // happens, the xport will be marked as unusable and DoT queries won't be sent to
            // happens, the xport will be marked as unusable and DoT queries won't be sent to
            // it anymore. Eventually, after IDLE_TIMEOUT, the xport will be destroyed, and
            // it anymore. Eventually, after IDLE_TIMEOUT, the xport will be destroyed, and
            // a new xport will be created.
            // a new xport will be created.
            const auto result =
            const auto result = PrivateDnsConfiguration::getInstance().requestValidation(
                    PrivateDnsConfiguration::getInstance().requestValidation(netId, server, mark);
                    netId, PrivateDnsConfiguration::ServerIdentity{server}, mark);
            LOG(WARNING) << "Requested validation for " << server.toIpString() << " with mark 0x"
            LOG(WARNING) << "Requested validation for " << server.toIpString() << " with mark 0x"
                         << std::hex << mark << ", "
                         << std::hex << mark << ", "
                         << (result.ok() ? "succeeded" : "failed: " + result.error().message());
                         << (result.ok() ? "succeeded" : "failed: " + result.error().message());
+61 −47
Original line number Original line Diff line number Diff line
@@ -110,7 +110,7 @@ int PrivateDnsConfiguration::set(int32_t netId, uint32_t mark,


        if (needsValidation(server)) {
        if (needsValidation(server)) {
            updateServerState(identity, Validation::in_process, netId);
            updateServerState(identity, Validation::in_process, netId);
            startValidation(server, netId, false);
            startValidation(identity, netId, false);
        }
        }
    }
    }


@@ -145,7 +145,7 @@ void PrivateDnsConfiguration::clear(unsigned netId) {
}
}


base::Result<void> PrivateDnsConfiguration::requestValidation(unsigned netId,
base::Result<void> PrivateDnsConfiguration::requestValidation(unsigned netId,
                                                              const DnsTlsServer& server,
                                                              const ServerIdentity& identity,
                                                              uint32_t mark) {
                                                              uint32_t mark) {
    std::lock_guard guard(mPrivateDnsLock);
    std::lock_guard guard(mPrivateDnsLock);


@@ -159,40 +159,39 @@ base::Result<void> PrivateDnsConfiguration::requestValidation(unsigned netId,
        return Errorf("Private DNS setting is not opportunistic mode");
        return Errorf("Private DNS setting is not opportunistic mode");
    }
    }


    auto netPair = mPrivateDnsTransports.find(netId);
    auto result = getPrivateDnsLocked(identity, netId);
    if (netPair == mPrivateDnsTransports.end()) {
    if (!result.ok()) {
        return Errorf("NetId not found in mPrivateDnsTransports");
        return result.error();
    }
    }


    auto& tracker = netPair->second;
    const DnsTlsServer* target = result.value();
    const ServerIdentity identity = ServerIdentity(server);
    auto it = tracker.find(identity);
    if (it == tracker.end()) {
        return Errorf("Server was removed");
    }


    const DnsTlsServer& target = it->second;
    if (!target->active()) return Errorf("Server is not active");


    if (!target.active()) return Errorf("Server is not active");
    if (target->validationState() != Validation::success) {

    if (target.validationState() != Validation::success) {
        return Errorf("Server validation state mismatched");
        return Errorf("Server validation state mismatched");
    }
    }


    // Don't run the validation if |mark| (from android_net_context.dns_mark) is different.
    // Don't run the validation if |mark| (from android_net_context.dns_mark) is different.
    // This is to protect validation from running on unexpected marks.
    // This is to protect validation from running on unexpected marks.
    // Validation should be associated with a mark gotten by system permission.
    // Validation should be associated with a mark gotten by system permission.
    if (target.mark != mark) return Errorf("Socket mark mismatched");
    if (target->mark != mark) return Errorf("Socket mark mismatched");


    updateServerState(identity, Validation::in_process, netId);
    updateServerState(identity, Validation::in_process, netId);
    startValidation(target, netId, true);
    startValidation(identity, netId, true);
    return {};
    return {};
}
}


void PrivateDnsConfiguration::startValidation(const DnsTlsServer& server, unsigned netId,
void PrivateDnsConfiguration::startValidation(const ServerIdentity& identity, unsigned netId,
                                              bool isRevalidation) REQUIRES(mPrivateDnsLock) {
                                              bool isRevalidation) {
    // Note that capturing |server|, |netId|, and |isRevalidation| in this lambda create copies.
    // This ensures that the thread sends probe at least once in case
    std::thread validate_thread([this, server, netId, isRevalidation] {
    // the server is removed before the thread starts running.
    // TODO: consider moving these code to the thread.
    const auto result = getPrivateDnsLocked(identity, netId);
    if (!result.ok()) return;
    DnsTlsServer server = *result.value();

    std::thread validate_thread([this, identity, server, netId, isRevalidation] {
        setThreadName(StringPrintf("TlsVerify_%u", netId).c_str());
        setThreadName(StringPrintf("TlsVerify_%u", netId).c_str());


        // cat /proc/sys/net/ipv4/tcp_syn_retries yields "6".
        // cat /proc/sys/net/ipv4/tcp_syn_retries yields "6".
@@ -223,7 +222,7 @@ void PrivateDnsConfiguration::startValidation(const DnsTlsServer& server, unsign
                         << server.toIpString();
                         << server.toIpString();


            const bool needs_reeval =
            const bool needs_reeval =
                    this->recordPrivateDnsValidation(server, netId, success, isRevalidation);
                    this->recordPrivateDnsValidation(identity, netId, success, isRevalidation);


            if (!needs_reeval) {
            if (!needs_reeval) {
                break;
                break;
@@ -240,11 +239,11 @@ void PrivateDnsConfiguration::startValidation(const DnsTlsServer& server, unsign
    validate_thread.detach();
    validate_thread.detach();
}
}


void PrivateDnsConfiguration::sendPrivateDnsValidationEvent(const DnsTlsServer& server,
void PrivateDnsConfiguration::sendPrivateDnsValidationEvent(const ServerIdentity& identity,
                                                            unsigned netId, bool success) {
                                                            unsigned netId, bool success) {
    LOG(DEBUG) << "Sending validation " << (success ? "success" : "failure") << " event on netId "
    LOG(DEBUG) << "Sending validation " << (success ? "success" : "failure") << " event on netId "
               << netId << " for " << server.toIpString() << " with hostname {" << server.name
               << netId << " for " << identity.sockaddr.ip().toString() << " with hostname {"
               << "}";
               << identity.provider << "}";
    // Send a validation event to NetdEventListenerService.
    // Send a validation event to NetdEventListenerService.
    const auto& listeners = ResolverEventReporter::getInstance().getListeners();
    const auto& listeners = ResolverEventReporter::getInstance().getListeners();
    if (listeners.empty()) {
    if (listeners.empty()) {
@@ -252,15 +251,16 @@ void PrivateDnsConfiguration::sendPrivateDnsValidationEvent(const DnsTlsServer&
                << "Validation event not sent since no INetdEventListener receiver is available.";
                << "Validation event not sent since no INetdEventListener receiver is available.";
    }
    }
    for (const auto& it : listeners) {
    for (const auto& it : listeners) {
        it->onPrivateDnsValidationEvent(netId, server.toIpString(), server.name, success);
        it->onPrivateDnsValidationEvent(netId, identity.sockaddr.ip().toString(), identity.provider,
                                        success);
    }
    }


    // Send a validation event to unsolicited event listeners.
    // Send a validation event to unsolicited event listeners.
    const auto& unsolEventListeners = ResolverEventReporter::getInstance().getUnsolEventListeners();
    const auto& unsolEventListeners = ResolverEventReporter::getInstance().getUnsolEventListeners();
    const PrivateDnsValidationEventParcel validationEvent = {
    const PrivateDnsValidationEventParcel validationEvent = {
            .netId = static_cast<int32_t>(netId),
            .netId = static_cast<int32_t>(netId),
            .ipAddress = server.toIpString(),
            .ipAddress = identity.sockaddr.ip().toString(),
            .hostname = server.name,
            .hostname = identity.provider,
            .validation = success ? IDnsResolverUnsolicitedEventListener::VALIDATION_RESULT_SUCCESS
            .validation = success ? IDnsResolverUnsolicitedEventListener::VALIDATION_RESULT_SUCCESS
                                  : IDnsResolverUnsolicitedEventListener::VALIDATION_RESULT_FAILURE,
                                  : IDnsResolverUnsolicitedEventListener::VALIDATION_RESULT_FAILURE,
    };
    };
@@ -269,11 +269,11 @@ void PrivateDnsConfiguration::sendPrivateDnsValidationEvent(const DnsTlsServer&
    }
    }
}
}


bool PrivateDnsConfiguration::recordPrivateDnsValidation(const DnsTlsServer& server, unsigned netId,
bool PrivateDnsConfiguration::recordPrivateDnsValidation(const ServerIdentity& identity,
                                                         bool success, bool isRevalidation) {
                                                         unsigned netId, bool success,
                                                         bool isRevalidation) {
    constexpr bool NEEDS_REEVALUATION = true;
    constexpr bool NEEDS_REEVALUATION = true;
    constexpr bool DONT_REEVALUATE = false;
    constexpr bool DONT_REEVALUATE = false;
    const ServerIdentity identity = ServerIdentity(server);


    std::lock_guard guard(mPrivateDnsLock);
    std::lock_guard guard(mPrivateDnsLock);


@@ -303,23 +303,19 @@ bool PrivateDnsConfiguration::recordPrivateDnsValidation(const DnsTlsServer& ser
    auto& tracker = netPair->second;
    auto& tracker = netPair->second;
    auto serverPair = tracker.find(identity);
    auto serverPair = tracker.find(identity);
    if (serverPair == tracker.end()) {
    if (serverPair == tracker.end()) {
        LOG(WARNING) << "Server " << server.toIpString()
        LOG(WARNING) << "Server " << identity.sockaddr.ip().toString()
                     << " was removed during private DNS validation";
                     << " was removed during private DNS validation";
        success = false;
        success = false;
        reevaluationStatus = DONT_REEVALUATE;
        reevaluationStatus = DONT_REEVALUATE;
    } else if (!(serverPair->second == server)) {
        LOG(WARNING) << "Server " << server.toIpString()
                     << " was changed during private DNS validation";
        success = false;
        reevaluationStatus = DONT_REEVALUATE;
    } else if (!serverPair->second.active()) {
    } else if (!serverPair->second.active()) {
        LOG(WARNING) << "Server " << server.toIpString() << " was removed from the configuration";
        LOG(WARNING) << "Server " << identity.sockaddr.ip().toString()
                     << " was removed from the configuration";
        success = false;
        success = false;
        reevaluationStatus = DONT_REEVALUATE;
        reevaluationStatus = DONT_REEVALUATE;
    }
    }


    // Send private dns validation result to listeners.
    // Send private dns validation result to listeners.
    sendPrivateDnsValidationEvent(server, netId, success);
    sendPrivateDnsValidationEvent(identity, netId, success);


    if (success) {
    if (success) {
        updateServerState(identity, Validation::success, netId);
        updateServerState(identity, Validation::success, netId);
@@ -338,19 +334,15 @@ bool PrivateDnsConfiguration::recordPrivateDnsValidation(const DnsTlsServer& ser


void PrivateDnsConfiguration::updateServerState(const ServerIdentity& identity, Validation state,
void PrivateDnsConfiguration::updateServerState(const ServerIdentity& identity, Validation state,
                                                uint32_t netId) {
                                                uint32_t netId) {
    auto netPair = mPrivateDnsTransports.find(netId);
    const auto result = getPrivateDnsLocked(identity, netId);
    if (netPair == mPrivateDnsTransports.end()) {
    if (!result.ok()) {
        notifyValidationStateUpdate(identity.sockaddr, Validation::fail, netId);
        notifyValidationStateUpdate(identity.sockaddr, Validation::fail, netId);
        return;
        return;
    }
    }


    auto& tracker = netPair->second;
    auto* server = result.value();
    if (tracker.find(identity) == tracker.end()) {
        notifyValidationStateUpdate(identity.sockaddr, Validation::fail, netId);
        return;
    }


    tracker[identity].setValidationState(state);
    server->setValidationState(state);
    notifyValidationStateUpdate(identity.sockaddr, state, netId);
    notifyValidationStateUpdate(identity.sockaddr, state, netId);


    RecordEntry record(netId, identity, state);
    RecordEntry record(netId, identity, state);
@@ -373,6 +365,28 @@ bool PrivateDnsConfiguration::needsValidation(const DnsTlsServer& server) {
    return false;
    return false;
}
}


base::Result<DnsTlsServer*> PrivateDnsConfiguration::getPrivateDns(const ServerIdentity& identity,
                                                                   unsigned netId) {
    std::lock_guard guard(mPrivateDnsLock);
    return getPrivateDnsLocked(identity, netId);
}

base::Result<DnsTlsServer*> PrivateDnsConfiguration::getPrivateDnsLocked(
        const ServerIdentity& identity, unsigned netId) {
    auto netPair = mPrivateDnsTransports.find(netId);
    if (netPair == mPrivateDnsTransports.end()) {
        return Errorf("Failed to get private DNS: netId {} not found", netId);
    }

    auto iter = netPair->second.find(identity);
    if (iter == netPair->second.end()) {
        return Errorf("Failed to get private DNS: server {{{}/{}}} not found", identity.sockaddr,
                      identity.provider);
    }

    return &iter->second;
}

void PrivateDnsConfiguration::setObserver(PrivateDnsValidationObserver* observer) {
void PrivateDnsConfiguration::setObserver(PrivateDnsValidationObserver* observer) {
    std::lock_guard guard(mPrivateDnsLock);
    std::lock_guard guard(mPrivateDnsLock);
    mObserver = observer;
    mObserver = observer;
+30 −21
Original line number Original line Diff line number Diff line
@@ -33,6 +33,7 @@
namespace android {
namespace android {
namespace net {
namespace net {


// TODO: decouple the dependency of DnsTlsServer.
struct PrivateDnsStatus {
struct PrivateDnsStatus {
    PrivateDnsMode mode;
    PrivateDnsMode mode;


@@ -53,24 +54,6 @@ struct PrivateDnsStatus {


class PrivateDnsConfiguration {
class PrivateDnsConfiguration {
  public:
  public:
    // The only instance of PrivateDnsConfiguration.
    static PrivateDnsConfiguration& getInstance() {
        static PrivateDnsConfiguration instance;
        return instance;
    }

    int set(int32_t netId, uint32_t mark, const std::vector<std::string>& servers,
            const std::string& name, const std::string& caCert) EXCLUDES(mPrivateDnsLock);

    PrivateDnsStatus getStatus(unsigned netId) const EXCLUDES(mPrivateDnsLock);

    void clear(unsigned netId) EXCLUDES(mPrivateDnsLock);

    // Request |server| to be revalidated on a connection tagged with |mark|.
    // Returns a Result to indicate if the request is accepted.
    base::Result<void> requestValidation(unsigned netId, const DnsTlsServer& server, uint32_t mark)
            EXCLUDES(mPrivateDnsLock);

    struct ServerIdentity {
    struct ServerIdentity {
        const netdutils::IPSockAddr sockaddr;
        const netdutils::IPSockAddr sockaddr;
        const std::string provider;
        const std::string provider;
@@ -86,6 +69,24 @@ class PrivateDnsConfiguration {
        }
        }
    };
    };


    // The only instance of PrivateDnsConfiguration.
    static PrivateDnsConfiguration& getInstance() {
        static PrivateDnsConfiguration instance;
        return instance;
    }

    int set(int32_t netId, uint32_t mark, const std::vector<std::string>& servers,
            const std::string& name, const std::string& caCert) EXCLUDES(mPrivateDnsLock);

    PrivateDnsStatus getStatus(unsigned netId) const EXCLUDES(mPrivateDnsLock);

    void clear(unsigned netId) EXCLUDES(mPrivateDnsLock);

    // Request the server to be revalidated on a connection tagged with |mark|.
    // Returns a Result to indicate if the request is accepted.
    base::Result<void> requestValidation(unsigned netId, const ServerIdentity& identity,
                                         uint32_t mark) EXCLUDES(mPrivateDnsLock);

    void setObserver(PrivateDnsValidationObserver* observer);
    void setObserver(PrivateDnsValidationObserver* observer);


    void dump(netdutils::DumpWriter& dw) const;
    void dump(netdutils::DumpWriter& dw) const;
@@ -97,23 +98,31 @@ class PrivateDnsConfiguration {


    // Launchs a thread to run the validation for |server| on the network |netId|.
    // Launchs a thread to run the validation for |server| on the network |netId|.
    // |isRevalidation| is true if this call is due to a revalidation request.
    // |isRevalidation| is true if this call is due to a revalidation request.
    void startValidation(const DnsTlsServer& server, unsigned netId, bool isRevalidation)
    void startValidation(const ServerIdentity& identity, unsigned netId, bool isRevalidation)
            REQUIRES(mPrivateDnsLock);
            REQUIRES(mPrivateDnsLock);


    bool recordPrivateDnsValidation(const DnsTlsServer& server, unsigned netId, bool success,
    bool recordPrivateDnsValidation(const ServerIdentity& identity, unsigned netId, bool success,
                                    bool isRevalidation) EXCLUDES(mPrivateDnsLock);
                                    bool isRevalidation) EXCLUDES(mPrivateDnsLock);


    void sendPrivateDnsValidationEvent(const DnsTlsServer& server, unsigned netId, bool success)
    void sendPrivateDnsValidationEvent(const ServerIdentity& identity, unsigned netId, bool success)
            REQUIRES(mPrivateDnsLock);
            REQUIRES(mPrivateDnsLock);


    // Decide if a validation for |server| is needed. Note that servers that have failed
    // Decide if a validation for |server| is needed. Note that servers that have failed
    // multiple validation attempts but for which there is still a validating
    // multiple validation attempts but for which there is still a validating
    // thread running are marked as being Validation::in_process.
    // thread running are marked as being Validation::in_process.
    // TODO: decouple the dependency of DnsTlsServer.
    bool needsValidation(const DnsTlsServer& server) REQUIRES(mPrivateDnsLock);
    bool needsValidation(const DnsTlsServer& server) REQUIRES(mPrivateDnsLock);


    void updateServerState(const ServerIdentity& identity, Validation state, uint32_t netId)
    void updateServerState(const ServerIdentity& identity, Validation state, uint32_t netId)
            REQUIRES(mPrivateDnsLock);
            REQUIRES(mPrivateDnsLock);


    // For testing.
    base::Result<DnsTlsServer*> getPrivateDns(const ServerIdentity& identity, unsigned netId)
            EXCLUDES(mPrivateDnsLock);

    base::Result<DnsTlsServer*> getPrivateDnsLocked(const ServerIdentity& identity, unsigned netId)
            REQUIRES(mPrivateDnsLock);

    mutable std::mutex mPrivateDnsLock;
    mutable std::mutex mPrivateDnsLock;
    std::map<unsigned, PrivateDnsMode> mPrivateDnsModes GUARDED_BY(mPrivateDnsLock);
    std::map<unsigned, PrivateDnsMode> mPrivateDnsModes GUARDED_BY(mPrivateDnsLock);


+33 −8
Original line number Original line Diff line number Diff line
@@ -28,6 +28,8 @@ using namespace std::chrono_literals;


class PrivateDnsConfigurationTest : public ::testing::Test {
class PrivateDnsConfigurationTest : public ::testing::Test {
  public:
  public:
    using ServerIdentity = PrivateDnsConfiguration::ServerIdentity;

    static void SetUpTestSuite() {
    static void SetUpTestSuite() {
        // stopServer() will be called in their destructor.
        // stopServer() will be called in their destructor.
        ASSERT_TRUE(tls1.startServer());
        ASSERT_TRUE(tls1.startServer());
@@ -100,6 +102,10 @@ class PrivateDnsConfigurationTest : public ::testing::Test {
        return (serverStateMap == mObserver.getServerStateMap());
        return (serverStateMap == mObserver.getServerStateMap());
    }
    }


    bool hasPrivateDnsServer(const ServerIdentity& identity, unsigned netId) {
        return mPdc.getPrivateDns(identity, netId).ok();
    }

    static constexpr uint32_t kNetId = 30;
    static constexpr uint32_t kNetId = 30;
    static constexpr uint32_t kMark = 30;
    static constexpr uint32_t kMark = 30;
    static constexpr char kBackend[] = "127.0.2.1";
    static constexpr char kBackend[] = "127.0.2.1";
@@ -230,8 +236,6 @@ TEST_F(PrivateDnsConfigurationTest, NoValidation) {
}
}


TEST_F(PrivateDnsConfigurationTest, ServerIdentity_Comparison) {
TEST_F(PrivateDnsConfigurationTest, ServerIdentity_Comparison) {
    using ServerIdentity = PrivateDnsConfiguration::ServerIdentity;

    DnsTlsServer server(netdutils::IPSockAddr::toIPSockAddr("127.0.0.1", 853));
    DnsTlsServer server(netdutils::IPSockAddr::toIPSockAddr("127.0.0.1", 853));
    server.name = "dns.example.com";
    server.name = "dns.example.com";


@@ -254,6 +258,7 @@ TEST_F(PrivateDnsConfigurationTest, ServerIdentity_Comparison) {


TEST_F(PrivateDnsConfigurationTest, RequestValidation) {
TEST_F(PrivateDnsConfigurationTest, RequestValidation) {
    const DnsTlsServer server(netdutils::IPSockAddr::toIPSockAddr(kServer1, 853));
    const DnsTlsServer server(netdutils::IPSockAddr::toIPSockAddr(kServer1, 853));
    const ServerIdentity identity(server);


    testing::InSequence seq;
    testing::InSequence seq;


@@ -281,18 +286,18 @@ TEST_F(PrivateDnsConfigurationTest, RequestValidation) {
            EXPECT_CALL(mObserver,
            EXPECT_CALL(mObserver,
                        onValidationStateUpdate(kServer1, Validation::in_process, kNetId));
                        onValidationStateUpdate(kServer1, Validation::in_process, kNetId));
            EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::success, kNetId));
            EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::success, kNetId));
            EXPECT_TRUE(mPdc.requestValidation(kNetId, server, kMark).ok());
            EXPECT_TRUE(mPdc.requestValidation(kNetId, identity, kMark).ok());
        } else if (config == "IN_PROGRESS") {
        } else if (config == "IN_PROGRESS") {
            EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::success, kNetId));
            EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::success, kNetId));
            EXPECT_FALSE(mPdc.requestValidation(kNetId, server, kMark).ok());
            EXPECT_FALSE(mPdc.requestValidation(kNetId, identity, kMark).ok());
        } else if (config == "FAIL") {
        } else if (config == "FAIL") {
            EXPECT_FALSE(mPdc.requestValidation(kNetId, server, kMark).ok());
            EXPECT_FALSE(mPdc.requestValidation(kNetId, identity, kMark).ok());
        }
        }


        // Resending the same request or requesting nonexistent servers are denied.
        // Resending the same request or requesting nonexistent servers are denied.
        EXPECT_FALSE(mPdc.requestValidation(kNetId, server, kMark).ok());
        EXPECT_FALSE(mPdc.requestValidation(kNetId, identity, kMark).ok());
        EXPECT_FALSE(mPdc.requestValidation(kNetId, server, kMark + 1).ok());
        EXPECT_FALSE(mPdc.requestValidation(kNetId, identity, kMark + 1).ok());
        EXPECT_FALSE(mPdc.requestValidation(kNetId + 1, server, kMark).ok());
        EXPECT_FALSE(mPdc.requestValidation(kNetId + 1, identity, kMark).ok());


        // Reset the test state.
        // Reset the test state.
        backend.setDeferredResp(false);
        backend.setDeferredResp(false);
@@ -306,6 +311,26 @@ TEST_F(PrivateDnsConfigurationTest, RequestValidation) {
    }
    }
}
}


TEST_F(PrivateDnsConfigurationTest, GetPrivateDns) {
    const DnsTlsServer server1(netdutils::IPSockAddr::toIPSockAddr(kServer1, 853));
    const DnsTlsServer server2(netdutils::IPSockAddr::toIPSockAddr(kServer2, 853));

    EXPECT_FALSE(hasPrivateDnsServer(ServerIdentity(server1), kNetId));
    EXPECT_FALSE(hasPrivateDnsServer(ServerIdentity(server2), kNetId));

    // Suppress the warning.
    EXPECT_CALL(mObserver, onValidationStateUpdate).Times(2);

    EXPECT_EQ(mPdc.set(kNetId, kMark, {kServer1}, {}, {}), 0);
    expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC);

    EXPECT_TRUE(hasPrivateDnsServer(ServerIdentity(server1), kNetId));
    EXPECT_FALSE(hasPrivateDnsServer(ServerIdentity(server2), kNetId));
    EXPECT_FALSE(hasPrivateDnsServer(ServerIdentity(server1), kNetId + 1));

    ASSERT_TRUE(PollForCondition([&]() { return mObserver.runningThreads == 0; }));
}

// TODO: add ValidationFail_Strict test.
// TODO: add ValidationFail_Strict test.


}  // namespace android::net
}  // namespace android::net