Loading DnsTlsServer.h +6 −0 Original line number Diff line number Diff line Loading @@ -74,6 +74,12 @@ struct DnsTlsServer { Validation validationState() const { return mValidation; } void setValidationState(Validation val) { mValidation = val; } // The socket mark used for validation. // Note that the mark of a connection to which the DnsResolver sends app's DNS requests can // be different. // TODO: make it const. uint32_t mark = 0; // Return whether or not the server can be used for a network. It depends on // the resolver configuration. bool active() const { return mActive; } Loading PrivateDnsConfiguration.cpp +32 −0 Original line number Diff line number Diff line Loading @@ -70,6 +70,7 @@ int PrivateDnsConfiguration::set(int32_t netId, uint32_t mark, DnsTlsServer server(parsed); server.name = name; server.certificate = caCert; server.mark = mark; tmp[ServerIdentity(server)] = server; } Loading Loading @@ -140,6 +141,37 @@ void PrivateDnsConfiguration::clear(unsigned netId) { mPrivateDnsTransports.erase(netId); } bool PrivateDnsConfiguration::requestValidation(unsigned netId, const DnsTlsServer& server, uint32_t mark) { std::lock_guard guard(mPrivateDnsLock); auto netPair = mPrivateDnsTransports.find(netId); if (netPair == mPrivateDnsTransports.end()) { return false; } auto& tracker = netPair->second; const ServerIdentity identity = ServerIdentity(server); auto it = tracker.find(identity); if (it == tracker.end()) { return false; } const DnsTlsServer& target = it->second; if (!target.active()) return false; if (target.validationState() != Validation::success) return false; // Don't run the validation if |mark| (from android_net_context.dns_mark) is different. // This is to protect validation from running on unexpected marks. // Validation should be associated with a mark gotten by system permission. if (target.mark != mark) return false; updateServerState(identity, Validation::in_process, netId); startValidation(target, netId, mark); return true; } void PrivateDnsConfiguration::startValidation(const DnsTlsServer& server, unsigned netId, uint32_t mark) REQUIRES(mPrivateDnsLock) { // Note that capturing |server| and |netId| in this lambda create copies. Loading PrivateDnsConfiguration.h +5 −0 Original line number Diff line number Diff line Loading @@ -65,6 +65,11 @@ class PrivateDnsConfiguration { void clear(unsigned netId) EXCLUDES(mPrivateDnsLock); // Request |server| to be revalidated on a connection tagged with |mark|. // Return true if the request is accepted; otherwise, return false. bool requestValidation(unsigned netId, const DnsTlsServer& server, uint32_t mark) EXCLUDES(mPrivateDnsLock); struct ServerIdentity { const netdutils::IPAddress ip; const std::string name; Loading PrivateDnsConfigurationTest.cpp +50 −0 Original line number Diff line number Diff line Loading @@ -254,6 +254,56 @@ TEST_F(PrivateDnsConfigurationTest, ServerIdentity_Comparison) { EXPECT_NE(ServerIdentity(server), ServerIdentity(other)); } TEST_F(PrivateDnsConfigurationTest, RequestValidation) { const DnsTlsServer server(netdutils::IPSockAddr::toIPSockAddr(kServer1, 853)); testing::InSequence seq; for (const std::string_view config : {"SUCCESS", "IN_PROGRESS", "FAIL"}) { SCOPED_TRACE(config); EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::in_process, kNetId)); if (config == "SUCCESS") { EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::success, kNetId)); } else if (config == "IN_PROGRESS") { backend.setDeferredResp(true); } else { // config = "FAIL" ASSERT_TRUE(backend.stopServer()); EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::fail, kNetId)); } EXPECT_EQ(mPdc.set(kNetId, kMark, {kServer1}, {}, {}), 0); expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC); // Wait until the validation state is transitioned. const int runningThreads = (config == "IN_PROGRESS") ? 1 : 0; ASSERT_TRUE(PollForCondition([&]() { return mObserver.runningThreads == runningThreads; })); bool requestAccepted = false; if (config == "SUCCESS") { EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::in_process, kNetId)); EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::success, kNetId)); requestAccepted = true; } else if (config == "IN_PROGRESS") { EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::success, kNetId)); } EXPECT_EQ(mPdc.requestValidation(kNetId, server, kMark), requestAccepted); // Resending the same request or requesting nonexistent servers are denied. EXPECT_FALSE(mPdc.requestValidation(kNetId, server, kMark)); EXPECT_FALSE(mPdc.requestValidation(kNetId, server, kMark + 1)); EXPECT_FALSE(mPdc.requestValidation(kNetId + 1, server, kMark)); // Reset the test state. backend.setDeferredResp(false); backend.startServer(); ASSERT_TRUE(PollForCondition([&]() { return mObserver.runningThreads == 0; })); mPdc.clear(kNetId); } } // TODO: add ValidationFail_Strict test. } // namespace android::net Loading
DnsTlsServer.h +6 −0 Original line number Diff line number Diff line Loading @@ -74,6 +74,12 @@ struct DnsTlsServer { Validation validationState() const { return mValidation; } void setValidationState(Validation val) { mValidation = val; } // The socket mark used for validation. // Note that the mark of a connection to which the DnsResolver sends app's DNS requests can // be different. // TODO: make it const. uint32_t mark = 0; // Return whether or not the server can be used for a network. It depends on // the resolver configuration. bool active() const { return mActive; } Loading
PrivateDnsConfiguration.cpp +32 −0 Original line number Diff line number Diff line Loading @@ -70,6 +70,7 @@ int PrivateDnsConfiguration::set(int32_t netId, uint32_t mark, DnsTlsServer server(parsed); server.name = name; server.certificate = caCert; server.mark = mark; tmp[ServerIdentity(server)] = server; } Loading Loading @@ -140,6 +141,37 @@ void PrivateDnsConfiguration::clear(unsigned netId) { mPrivateDnsTransports.erase(netId); } bool PrivateDnsConfiguration::requestValidation(unsigned netId, const DnsTlsServer& server, uint32_t mark) { std::lock_guard guard(mPrivateDnsLock); auto netPair = mPrivateDnsTransports.find(netId); if (netPair == mPrivateDnsTransports.end()) { return false; } auto& tracker = netPair->second; const ServerIdentity identity = ServerIdentity(server); auto it = tracker.find(identity); if (it == tracker.end()) { return false; } const DnsTlsServer& target = it->second; if (!target.active()) return false; if (target.validationState() != Validation::success) return false; // Don't run the validation if |mark| (from android_net_context.dns_mark) is different. // This is to protect validation from running on unexpected marks. // Validation should be associated with a mark gotten by system permission. if (target.mark != mark) return false; updateServerState(identity, Validation::in_process, netId); startValidation(target, netId, mark); return true; } void PrivateDnsConfiguration::startValidation(const DnsTlsServer& server, unsigned netId, uint32_t mark) REQUIRES(mPrivateDnsLock) { // Note that capturing |server| and |netId| in this lambda create copies. Loading
PrivateDnsConfiguration.h +5 −0 Original line number Diff line number Diff line Loading @@ -65,6 +65,11 @@ class PrivateDnsConfiguration { void clear(unsigned netId) EXCLUDES(mPrivateDnsLock); // Request |server| to be revalidated on a connection tagged with |mark|. // Return true if the request is accepted; otherwise, return false. bool requestValidation(unsigned netId, const DnsTlsServer& server, uint32_t mark) EXCLUDES(mPrivateDnsLock); struct ServerIdentity { const netdutils::IPAddress ip; const std::string name; Loading
PrivateDnsConfigurationTest.cpp +50 −0 Original line number Diff line number Diff line Loading @@ -254,6 +254,56 @@ TEST_F(PrivateDnsConfigurationTest, ServerIdentity_Comparison) { EXPECT_NE(ServerIdentity(server), ServerIdentity(other)); } TEST_F(PrivateDnsConfigurationTest, RequestValidation) { const DnsTlsServer server(netdutils::IPSockAddr::toIPSockAddr(kServer1, 853)); testing::InSequence seq; for (const std::string_view config : {"SUCCESS", "IN_PROGRESS", "FAIL"}) { SCOPED_TRACE(config); EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::in_process, kNetId)); if (config == "SUCCESS") { EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::success, kNetId)); } else if (config == "IN_PROGRESS") { backend.setDeferredResp(true); } else { // config = "FAIL" ASSERT_TRUE(backend.stopServer()); EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::fail, kNetId)); } EXPECT_EQ(mPdc.set(kNetId, kMark, {kServer1}, {}, {}), 0); expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC); // Wait until the validation state is transitioned. const int runningThreads = (config == "IN_PROGRESS") ? 1 : 0; ASSERT_TRUE(PollForCondition([&]() { return mObserver.runningThreads == runningThreads; })); bool requestAccepted = false; if (config == "SUCCESS") { EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::in_process, kNetId)); EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::success, kNetId)); requestAccepted = true; } else if (config == "IN_PROGRESS") { EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::success, kNetId)); } EXPECT_EQ(mPdc.requestValidation(kNetId, server, kMark), requestAccepted); // Resending the same request or requesting nonexistent servers are denied. EXPECT_FALSE(mPdc.requestValidation(kNetId, server, kMark)); EXPECT_FALSE(mPdc.requestValidation(kNetId, server, kMark + 1)); EXPECT_FALSE(mPdc.requestValidation(kNetId + 1, server, kMark)); // Reset the test state. backend.setDeferredResp(false); backend.startServer(); ASSERT_TRUE(PollForCondition([&]() { return mObserver.runningThreads == 0; })); mPdc.clear(kNetId); } } // TODO: add ValidationFail_Strict test. } // namespace android::net