Loading DnsTlsServer.h +14 −0 Original line number Original line Diff line number Diff line Loading @@ -27,6 +27,9 @@ namespace android { namespace android { namespace net { namespace net { // Validation status of a DNS over TLS server (on a specific netId). enum class Validation : uint8_t { in_process, success, fail, unknown_server, unknown_netid }; // DnsTlsServer represents a recursive resolver that supports, or may support, a // DnsTlsServer represents a recursive resolver that supports, or may support, a // secure protocol. // secure protocol. struct DnsTlsServer { struct DnsTlsServer { Loading @@ -37,17 +40,21 @@ struct DnsTlsServer { DnsTlsServer(const sockaddr_storage& ss) : ss(ss) {} DnsTlsServer(const sockaddr_storage& ss) : ss(ss) {} // The server location, including IP and port. // The server location, including IP and port. // TODO: make it const. sockaddr_storage ss = {}; sockaddr_storage ss = {}; // The server's hostname. If this string is nonempty, the server must present a // The server's hostname. If this string is nonempty, the server must present a // certificate that indicates this name and has a valid chain to a trusted root CA. // certificate that indicates this name and has a valid chain to a trusted root CA. // TODO: make it const. std::string name; std::string name; // The certificate of the CA that signed the server's certificate. // The certificate of the CA that signed the server's certificate. // It is used to store temporary test CA certificate for internal tests. // It is used to store temporary test CA certificate for internal tests. // TODO: make it const. std::string certificate; std::string certificate; // Placeholder. More protocols might be defined in the future. // Placeholder. More protocols might be defined in the future. // TODO: make it const. int protocol = IPPROTO_TCP; int protocol = IPPROTO_TCP; // Exact comparison of DnsTlsServer objects // Exact comparison of DnsTlsServer objects Loading @@ -55,6 +62,13 @@ struct DnsTlsServer { bool operator==(const DnsTlsServer& other) const; bool operator==(const DnsTlsServer& other) const; bool wasExplicitlyConfigured() const; bool wasExplicitlyConfigured() const; Validation validationState() const { return mValidation; } void setValidationState(Validation val) { mValidation = val; } private: // State, unrelated to the comparison of DnsTlsServer objects. Validation mValidation = Validation::unknown_server; }; }; // This comparison only checks the IP address. It ignores ports, names, and fingerprints. // This comparison only checks the IP address. It ignores ports, names, and fingerprints. Loading PrivateDnsConfiguration.cpp +37 −28 Original line number Original line Diff line number Diff line Loading @@ -69,7 +69,7 @@ int PrivateDnsConfiguration::set(int32_t netId, uint32_t mark, << ", " << servers.size() << ", " << name << ")"; << ", " << servers.size() << ", " << name << ")"; // Parse the list of servers that has been passed in // Parse the list of servers that has been passed in std::set<DnsTlsServer> tlsServers; PrivateDnsTracker tmp; for (const auto& s : servers) { for (const auto& s : servers) { sockaddr_storage parsed; sockaddr_storage parsed; if (!parseServer(s.c_str(), &parsed)) { if (!parseServer(s.c_str(), &parsed)) { Loading @@ -78,13 +78,13 @@ int PrivateDnsConfiguration::set(int32_t netId, uint32_t mark, DnsTlsServer server(parsed); DnsTlsServer server(parsed); server.name = name; server.name = name; server.certificate = caCert; server.certificate = caCert; tlsServers.insert(server); tmp[ServerIdentity(server)] = server; } } std::lock_guard guard(mPrivateDnsLock); std::lock_guard guard(mPrivateDnsLock); if (!name.empty()) { if (!name.empty()) { mPrivateDnsModes[netId] = PrivateDnsMode::STRICT; mPrivateDnsModes[netId] = PrivateDnsMode::STRICT; } else if (!tlsServers.empty()) { } else if (!tmp.empty()) { mPrivateDnsModes[netId] = PrivateDnsMode::OPPORTUNISTIC; mPrivateDnsModes[netId] = PrivateDnsMode::OPPORTUNISTIC; } else { } else { mPrivateDnsModes[netId] = PrivateDnsMode::OFF; mPrivateDnsModes[netId] = PrivateDnsMode::OFF; Loading Loading @@ -112,7 +112,7 @@ int PrivateDnsConfiguration::set(int32_t netId, uint32_t mark, // Remove any servers from the tracker that are not in |servers| exactly. // Remove any servers from the tracker that are not in |servers| exactly. for (auto it = tracker.begin(); it != tracker.end();) { for (auto it = tracker.begin(); it != tracker.end();) { if (tlsServers.count(it->first) == 0) { if (tmp.find(it->first) == tmp.end()) { it = tracker.erase(it); it = tracker.erase(it); } else { } else { ++it; ++it; Loading @@ -120,7 +120,7 @@ int PrivateDnsConfiguration::set(int32_t netId, uint32_t mark, } } // Add any new or changed servers to the tracker, and initiate async checks for them. // Add any new or changed servers to the tracker, and initiate async checks for them. for (const auto& server : tlsServers) { for (const auto& [identity, server] : tmp) { if (needsValidation(tracker, server)) { if (needsValidation(tracker, server)) { // This is temporarily required. Consider the following scenario, for example, // This is temporarily required. Consider the following scenario, for example, // Step 1) A DoTServer (s1) is set for the network. A validation (v1) for s1 starts. // Step 1) A DoTServer (s1) is set for the network. A validation (v1) for s1 starts. Loading @@ -133,7 +133,10 @@ int PrivateDnsConfiguration::set(int32_t netId, uint32_t mark, // // // If we didn't add servers to tracker before needValidateThread(), tracker would // If we didn't add servers to tracker before needValidateThread(), tracker would // become empty. We would report s1 validation failed. // become empty. We would report s1 validation failed. tracker[server] = Validation::in_process; if (tracker.find(identity) == tracker.end()) { tracker[identity] = server; } tracker[identity].setValidationState(Validation::in_process); LOG(DEBUG) << "Server " << addrToString(&server.ss) << " marked as in_process on netId " LOG(DEBUG) << "Server " << addrToString(&server.ss) << " marked as in_process on netId " << netId << ". Tracker now has size " << tracker.size(); << netId << ". Tracker now has size " << tracker.size(); // This judge must be after "tracker[server] = Validation::in_process;" // This judge must be after "tracker[server] = Validation::in_process;" Loading @@ -141,7 +144,7 @@ int PrivateDnsConfiguration::set(int32_t netId, uint32_t mark, continue; continue; } } updateServerState(server, Validation::in_process, netId); updateServerState(identity, Validation::in_process, netId); startValidation(server, netId, mark); startValidation(server, netId, mark); } } } } Loading @@ -159,8 +162,8 @@ PrivateDnsStatus PrivateDnsConfiguration::getStatus(unsigned netId) { const auto netPair = mPrivateDnsTransports.find(netId); const auto netPair = mPrivateDnsTransports.find(netId); if (netPair != mPrivateDnsTransports.end()) { if (netPair != mPrivateDnsTransports.end()) { for (const auto& serverPair : netPair->second) { for (const auto& [_, server] : netPair->second) { status.serversMap.emplace(serverPair.first, serverPair.second); status.serversMap.emplace(server, server.validationState()); } } } } Loading Loading @@ -227,20 +230,21 @@ bool PrivateDnsConfiguration::recordPrivateDnsValidation(const DnsTlsServer& ser bool success) { bool success) { constexpr bool NEEDS_REEVALUATION = true; constexpr bool NEEDS_REEVALUATION = true; constexpr bool DONT_REEVALUATE = false; constexpr bool DONT_REEVALUATE = false; const ServerIdentity identity = ServerIdentity(server); std::lock_guard guard(mPrivateDnsLock); std::lock_guard guard(mPrivateDnsLock); auto netPair = mPrivateDnsTransports.find(netId); auto netPair = mPrivateDnsTransports.find(netId); if (netPair == mPrivateDnsTransports.end()) { if (netPair == mPrivateDnsTransports.end()) { LOG(WARNING) << "netId " << netId << " was erased during private DNS validation"; LOG(WARNING) << "netId " << netId << " was erased during private DNS validation"; maybeNotifyObserver(server, Validation::fail, netId); maybeNotifyObserver(identity.ip.toString(), Validation::fail, netId); return DONT_REEVALUATE; return DONT_REEVALUATE; } } const auto mode = mPrivateDnsModes.find(netId); const auto mode = mPrivateDnsModes.find(netId); if (mode == mPrivateDnsModes.end()) { if (mode == mPrivateDnsModes.end()) { LOG(WARNING) << "netId " << netId << " has no private DNS validation mode"; LOG(WARNING) << "netId " << netId << " has no private DNS validation mode"; maybeNotifyObserver(server, Validation::fail, netId); maybeNotifyObserver(identity.ip.toString(), Validation::fail, netId); return DONT_REEVALUATE; return DONT_REEVALUATE; } } const bool modeDoesReevaluation = (mode->second == PrivateDnsMode::STRICT); const bool modeDoesReevaluation = (mode->second == PrivateDnsMode::STRICT); Loading @@ -249,16 +253,13 @@ bool PrivateDnsConfiguration::recordPrivateDnsValidation(const DnsTlsServer& ser (success || !modeDoesReevaluation) ? DONT_REEVALUATE : NEEDS_REEVALUATION; (success || !modeDoesReevaluation) ? DONT_REEVALUATE : NEEDS_REEVALUATION; auto& tracker = netPair->second; auto& tracker = netPair->second; auto serverPair = tracker.find(server); auto serverPair = tracker.find(identity); if (serverPair == tracker.end()) { if (serverPair == tracker.end()) { // TODO: Consider not adding this server to the tracker since this server is not expected // to be one of the private DNS servers for this network now. This could prevent this // server from being included when dumping status. LOG(WARNING) << "Server " << addrToString(&server.ss) LOG(WARNING) << "Server " << addrToString(&server.ss) << " was removed during private DNS validation"; << " was removed during private DNS validation"; success = false; success = false; reevaluationStatus = DONT_REEVALUATE; reevaluationStatus = DONT_REEVALUATE; } else if (!(serverPair->first == server)) { } else if (!(serverPair->second == server)) { // TODO: It doesn't seem correct to overwrite the tracker entry for // TODO: It doesn't seem correct to overwrite the tracker entry for // |server| down below in this circumstance... Fix this. // |server| down below in this circumstance... Fix this. LOG(WARNING) << "Server " << addrToString(&server.ss) LOG(WARNING) << "Server " << addrToString(&server.ss) Loading @@ -282,14 +283,14 @@ bool PrivateDnsConfiguration::recordPrivateDnsValidation(const DnsTlsServer& ser } } if (success) { if (success) { updateServerState(server, Validation::success, netId); updateServerState(identity, Validation::success, netId); } else { } else { // Validation failure is expected if a user is on a captive portal. // Validation failure is expected if a user is on a captive portal. // TODO: Trigger a second validation attempt after captive portal login // TODO: Trigger a second validation attempt after captive portal login // succeeds. // succeeds. const auto result = (reevaluationStatus == NEEDS_REEVALUATION) ? Validation::in_process const auto result = (reevaluationStatus == NEEDS_REEVALUATION) ? Validation::in_process : Validation::fail; : Validation::fail; updateServerState(server, result, netId); updateServerState(identity, result, netId); } } LOG(WARNING) << "Validation " << (success ? "success" : "failed"); LOG(WARNING) << "Validation " << (success ? "success" : "failed"); Loading Loading @@ -324,15 +325,22 @@ bool PrivateDnsConfiguration::needValidateThread(const DnsTlsServer& server, uns } } } } void PrivateDnsConfiguration::updateServerState(const DnsTlsServer& server, Validation state, void PrivateDnsConfiguration::updateServerState(const ServerIdentity& identity, Validation state, uint32_t netId) { uint32_t netId) { auto netPair = mPrivateDnsTransports.find(netId); auto netPair = mPrivateDnsTransports.find(netId); if (netPair != mPrivateDnsTransports.end()) { if (netPair == mPrivateDnsTransports.end()) { maybeNotifyObserver(identity.ip.toString(), Validation::fail, netId); return; } auto& tracker = netPair->second; auto& tracker = netPair->second; tracker[server] = state; if (tracker.find(identity) == tracker.end()) { maybeNotifyObserver(identity.ip.toString(), Validation::fail, netId); return; } } maybeNotifyObserver(server, state, netId); tracker[identity].setValidationState(state); maybeNotifyObserver(identity.ip.toString(), state, netId); } } void PrivateDnsConfiguration::cleanValidateThreadTracker(const DnsTlsServer& server, void PrivateDnsConfiguration::cleanValidateThreadTracker(const DnsTlsServer& server, Loading @@ -353,8 +361,9 @@ void PrivateDnsConfiguration::cleanValidateThreadTracker(const DnsTlsServer& ser bool PrivateDnsConfiguration::needsValidation(const PrivateDnsTracker& tracker, bool PrivateDnsConfiguration::needsValidation(const PrivateDnsTracker& tracker, const DnsTlsServer& server) { const DnsTlsServer& server) { const auto& iter = tracker.find(server); const ServerIdentity identity = ServerIdentity(server); return (iter == tracker.end()) || (iter->second == Validation::fail); const auto& iter = tracker.find(identity); return (iter == tracker.end()) || (iter->second.validationState() == Validation::fail); } } void PrivateDnsConfiguration::setObserver(Observer* observer) { void PrivateDnsConfiguration::setObserver(Observer* observer) { Loading @@ -362,10 +371,10 @@ void PrivateDnsConfiguration::setObserver(Observer* observer) { mObserver = observer; mObserver = observer; } } void PrivateDnsConfiguration::maybeNotifyObserver(const DnsTlsServer& server, Validation validation, void PrivateDnsConfiguration::maybeNotifyObserver(const std::string& serverIp, uint32_t netId) const { Validation validation, uint32_t netId) const { if (mObserver) { if (mObserver) { mObserver->onValidationStateUpdate(addrToString(&server.ss), validation, netId); mObserver->onValidationStateUpdate(serverIp, validation, netId); } } } } Loading PrivateDnsConfiguration.h +29 −10 Original line number Original line Diff line number Diff line Loading @@ -22,6 +22,7 @@ #include <vector> #include <vector> #include <android-base/thread_annotations.h> #include <android-base/thread_annotations.h> #include <netdutils/InternetAddresses.h> #include "DnsTlsServer.h" #include "DnsTlsServer.h" Loading @@ -31,11 +32,10 @@ namespace net { // The DNS over TLS mode on a specific netId. // The DNS over TLS mode on a specific netId. enum class PrivateDnsMode : uint8_t { OFF, OPPORTUNISTIC, STRICT }; enum class PrivateDnsMode : uint8_t { OFF, OPPORTUNISTIC, STRICT }; // Validation status of a DNS over TLS server (on a specific netId). enum class Validation : uint8_t { in_process, success, fail, unknown_server, unknown_netid }; struct PrivateDnsStatus { struct PrivateDnsStatus { PrivateDnsMode mode; PrivateDnsMode mode; // TODO: change the type to std::vector<DnsTlsServer>. std::map<DnsTlsServer, Validation, AddressComparator> serversMap; std::map<DnsTlsServer, Validation, AddressComparator> serversMap; std::list<DnsTlsServer> validatedServers() const { std::list<DnsTlsServer> validatedServers() const { Loading Loading @@ -65,8 +65,26 @@ class PrivateDnsConfiguration { void clear(unsigned netId) EXCLUDES(mPrivateDnsLock); void clear(unsigned netId) EXCLUDES(mPrivateDnsLock); struct ServerIdentity { const netdutils::IPAddress ip; const std::string name; const int protocol; explicit ServerIdentity(const DnsTlsServer& server) : ip(netdutils::IPSockAddr::toIPSockAddr(server.ss).ip()), name(server.name), protocol(server.protocol) {} bool operator<(const ServerIdentity& other) const { return std::tie(ip, name, protocol) < std::tie(other.ip, other.name, other.protocol); } bool operator==(const ServerIdentity& other) const { return std::tie(ip, name, protocol) == std::tie(other.ip, other.name, other.protocol); } }; private: private: typedef std::map<DnsTlsServer, Validation, AddressComparator> PrivateDnsTracker; typedef std::map<ServerIdentity, DnsTlsServer> PrivateDnsTracker; typedef std::set<DnsTlsServer, AddressComparator> ThreadTracker; typedef std::set<DnsTlsServer, AddressComparator> ThreadTracker; PrivateDnsConfiguration() = default; PrivateDnsConfiguration() = default; Loading @@ -88,7 +106,7 @@ class PrivateDnsConfiguration { bool needsValidation(const PrivateDnsTracker& tracker, const DnsTlsServer& server) bool needsValidation(const PrivateDnsTracker& tracker, const DnsTlsServer& server) REQUIRES(mPrivateDnsLock); REQUIRES(mPrivateDnsLock); void updateServerState(const DnsTlsServer& server, Validation state, uint32_t netId) void updateServerState(const ServerIdentity& identity, Validation state, uint32_t netId) REQUIRES(mPrivateDnsLock); REQUIRES(mPrivateDnsLock); std::mutex mPrivateDnsLock; std::mutex mPrivateDnsLock; Loading @@ -99,19 +117,20 @@ class PrivateDnsConfiguration { std::map<unsigned, ThreadTracker> mPrivateDnsValidateThreads GUARDED_BY(mPrivateDnsLock); std::map<unsigned, ThreadTracker> mPrivateDnsValidateThreads GUARDED_BY(mPrivateDnsLock); // For testing. The observer is notified of onValidationStateUpdate 1) when a validation is // For testing. The observer is notified of onValidationStateUpdate 1) when a validation is // about to begin or 2) when a validation finishes. // about to begin or 2) when a validation finishes. If a validation finishes when in OFF mode // WARNING: The Observer is notified while the lock is being held. Be careful not to call any // or when the network has been destroyed, |validation| will be Validation::fail. // method of PrivateDnsConfiguration from the observer. // WARNING: The Observer is notified while the lock is being held. Be careful not to call // any method of PrivateDnsConfiguration from the observer. // TODO: fix the reentrancy problem. // TODO: fix the reentrancy problem. class Observer { class Observer { public: public: virtual ~Observer(){}; virtual ~Observer(){}; virtual void onValidationStateUpdate(const std::string& server, Validation validation, virtual void onValidationStateUpdate(const std::string& serverIp, Validation validation, uint32_t netId) = 0; uint32_t netId) = 0; }; }; void setObserver(Observer* observer); void setObserver(Observer* observer); void maybeNotifyObserver(const DnsTlsServer& server, Validation validation, void maybeNotifyObserver(const std::string& serverIp, Validation validation, uint32_t netId) const REQUIRES(mPrivateDnsLock); uint32_t netId) const REQUIRES(mPrivateDnsLock); Observer* mObserver GUARDED_BY(mPrivateDnsLock); Observer* mObserver GUARDED_BY(mPrivateDnsLock); Loading PrivateDnsConfigurationTest.cpp +37 −1 Original line number Original line Diff line number Diff line Loading @@ -63,7 +63,8 @@ class PrivateDnsConfigurationTest : public ::testing::Test { class MockObserver : public PrivateDnsConfiguration::Observer { class MockObserver : public PrivateDnsConfiguration::Observer { public: public: MOCK_METHOD(void, onValidationStateUpdate, MOCK_METHOD(void, onValidationStateUpdate, (const std::string& server, Validation validation, uint32_t netId), (override)); (const std::string& serverIp, Validation validation, uint32_t netId), (override)); std::map<std::string, Validation> getServerStateMap() const { std::map<std::string, Validation> getServerStateMap() const { std::lock_guard guard(lock); std::lock_guard guard(lock); Loading Loading @@ -172,6 +173,11 @@ TEST_F(PrivateDnsConfigurationTest, ValidationBlock) { backend.setDeferredResp(false); backend.setDeferredResp(false); ASSERT_TRUE(PollForCondition([&]() { return mObserver.runningThreads == 0; })); ASSERT_TRUE(PollForCondition([&]() { return mObserver.runningThreads == 0; })); // kServer1 is not a present server and thus should not be available from // PrivateDnsConfiguration::getStatus(). mObserver.removeFromServerStateMap(kServer1); expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC); expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC); } } Loading Loading @@ -218,6 +224,36 @@ TEST_F(PrivateDnsConfigurationTest, NoValidation) { expectStatus(); expectStatus(); } } TEST_F(PrivateDnsConfigurationTest, ServerIdentity_Comparison) { using ServerIdentity = PrivateDnsConfiguration::ServerIdentity; DnsTlsServer server(netdutils::IPSockAddr::toIPSockAddr("127.0.0.1", 853)); server.name = "dns.example.com"; server.protocol = 1; // Different IP address (port is ignored). DnsTlsServer other = server; EXPECT_EQ(ServerIdentity(server), ServerIdentity(other)); other.ss = netdutils::IPSockAddr::toIPSockAddr("127.0.0.1", 5353); EXPECT_EQ(ServerIdentity(server), ServerIdentity(other)); other.ss = netdutils::IPSockAddr::toIPSockAddr("127.0.0.2", 853); EXPECT_NE(ServerIdentity(server), ServerIdentity(other)); // Different provider hostname. other = server; EXPECT_EQ(ServerIdentity(server), ServerIdentity(other)); other.name = "other.example.com"; EXPECT_NE(ServerIdentity(server), ServerIdentity(other)); other.name = ""; EXPECT_NE(ServerIdentity(server), ServerIdentity(other)); // Different protocol. other = server; EXPECT_EQ(ServerIdentity(server), ServerIdentity(other)); other.protocol++; EXPECT_NE(ServerIdentity(server), ServerIdentity(other)); } // TODO: add ValidationFail_Strict test. // TODO: add ValidationFail_Strict test. } // namespace android::net } // namespace android::net resolv_tls_unit_test.cpp +24 −0 Original line number Original line Diff line number Diff line Loading @@ -800,6 +800,18 @@ void checkUnequal(const DnsTlsServer& s1, const DnsTlsServer& s2) { EXPECT_FALSE(s2 == s1); EXPECT_FALSE(s2 == s1); } } void checkEqual(const DnsTlsServer& s1, const DnsTlsServer& s2) { EXPECT_TRUE(s1 == s1); EXPECT_TRUE(s2 == s2); EXPECT_TRUE(isAddressEqual(s1, s1)); EXPECT_TRUE(isAddressEqual(s2, s2)); EXPECT_FALSE(s1 < s2); EXPECT_FALSE(s2 < s1); EXPECT_TRUE(s1 == s2); EXPECT_TRUE(s2 == s1); } class ServerTest : public BaseTest {}; class ServerTest : public BaseTest {}; TEST_F(ServerTest, IPv4) { TEST_F(ServerTest, IPv4) { Loading Loading @@ -873,6 +885,18 @@ TEST_F(ServerTest, Name) { EXPECT_TRUE(s2.wasExplicitlyConfigured()); EXPECT_TRUE(s2.wasExplicitlyConfigured()); } } TEST_F(ServerTest, State) { DnsTlsServer s1(V4ADDR1), s2(V4ADDR1); checkEqual(s1, s2); s1.setValidationState(Validation::success); checkEqual(s1, s2); s2.setValidationState(Validation::fail); checkEqual(s1, s2); EXPECT_EQ(s1.validationState(), Validation::success); EXPECT_EQ(s2.validationState(), Validation::fail); } TEST(QueryMapTest, Basic) { TEST(QueryMapTest, Basic) { DnsTlsQueryMap map; DnsTlsQueryMap map; Loading Loading
DnsTlsServer.h +14 −0 Original line number Original line Diff line number Diff line Loading @@ -27,6 +27,9 @@ namespace android { namespace android { namespace net { namespace net { // Validation status of a DNS over TLS server (on a specific netId). enum class Validation : uint8_t { in_process, success, fail, unknown_server, unknown_netid }; // DnsTlsServer represents a recursive resolver that supports, or may support, a // DnsTlsServer represents a recursive resolver that supports, or may support, a // secure protocol. // secure protocol. struct DnsTlsServer { struct DnsTlsServer { Loading @@ -37,17 +40,21 @@ struct DnsTlsServer { DnsTlsServer(const sockaddr_storage& ss) : ss(ss) {} DnsTlsServer(const sockaddr_storage& ss) : ss(ss) {} // The server location, including IP and port. // The server location, including IP and port. // TODO: make it const. sockaddr_storage ss = {}; sockaddr_storage ss = {}; // The server's hostname. If this string is nonempty, the server must present a // The server's hostname. If this string is nonempty, the server must present a // certificate that indicates this name and has a valid chain to a trusted root CA. // certificate that indicates this name and has a valid chain to a trusted root CA. // TODO: make it const. std::string name; std::string name; // The certificate of the CA that signed the server's certificate. // The certificate of the CA that signed the server's certificate. // It is used to store temporary test CA certificate for internal tests. // It is used to store temporary test CA certificate for internal tests. // TODO: make it const. std::string certificate; std::string certificate; // Placeholder. More protocols might be defined in the future. // Placeholder. More protocols might be defined in the future. // TODO: make it const. int protocol = IPPROTO_TCP; int protocol = IPPROTO_TCP; // Exact comparison of DnsTlsServer objects // Exact comparison of DnsTlsServer objects Loading @@ -55,6 +62,13 @@ struct DnsTlsServer { bool operator==(const DnsTlsServer& other) const; bool operator==(const DnsTlsServer& other) const; bool wasExplicitlyConfigured() const; bool wasExplicitlyConfigured() const; Validation validationState() const { return mValidation; } void setValidationState(Validation val) { mValidation = val; } private: // State, unrelated to the comparison of DnsTlsServer objects. Validation mValidation = Validation::unknown_server; }; }; // This comparison only checks the IP address. It ignores ports, names, and fingerprints. // This comparison only checks the IP address. It ignores ports, names, and fingerprints. Loading
PrivateDnsConfiguration.cpp +37 −28 Original line number Original line Diff line number Diff line Loading @@ -69,7 +69,7 @@ int PrivateDnsConfiguration::set(int32_t netId, uint32_t mark, << ", " << servers.size() << ", " << name << ")"; << ", " << servers.size() << ", " << name << ")"; // Parse the list of servers that has been passed in // Parse the list of servers that has been passed in std::set<DnsTlsServer> tlsServers; PrivateDnsTracker tmp; for (const auto& s : servers) { for (const auto& s : servers) { sockaddr_storage parsed; sockaddr_storage parsed; if (!parseServer(s.c_str(), &parsed)) { if (!parseServer(s.c_str(), &parsed)) { Loading @@ -78,13 +78,13 @@ int PrivateDnsConfiguration::set(int32_t netId, uint32_t mark, DnsTlsServer server(parsed); DnsTlsServer server(parsed); server.name = name; server.name = name; server.certificate = caCert; server.certificate = caCert; tlsServers.insert(server); tmp[ServerIdentity(server)] = server; } } std::lock_guard guard(mPrivateDnsLock); std::lock_guard guard(mPrivateDnsLock); if (!name.empty()) { if (!name.empty()) { mPrivateDnsModes[netId] = PrivateDnsMode::STRICT; mPrivateDnsModes[netId] = PrivateDnsMode::STRICT; } else if (!tlsServers.empty()) { } else if (!tmp.empty()) { mPrivateDnsModes[netId] = PrivateDnsMode::OPPORTUNISTIC; mPrivateDnsModes[netId] = PrivateDnsMode::OPPORTUNISTIC; } else { } else { mPrivateDnsModes[netId] = PrivateDnsMode::OFF; mPrivateDnsModes[netId] = PrivateDnsMode::OFF; Loading Loading @@ -112,7 +112,7 @@ int PrivateDnsConfiguration::set(int32_t netId, uint32_t mark, // Remove any servers from the tracker that are not in |servers| exactly. // Remove any servers from the tracker that are not in |servers| exactly. for (auto it = tracker.begin(); it != tracker.end();) { for (auto it = tracker.begin(); it != tracker.end();) { if (tlsServers.count(it->first) == 0) { if (tmp.find(it->first) == tmp.end()) { it = tracker.erase(it); it = tracker.erase(it); } else { } else { ++it; ++it; Loading @@ -120,7 +120,7 @@ int PrivateDnsConfiguration::set(int32_t netId, uint32_t mark, } } // Add any new or changed servers to the tracker, and initiate async checks for them. // Add any new or changed servers to the tracker, and initiate async checks for them. for (const auto& server : tlsServers) { for (const auto& [identity, server] : tmp) { if (needsValidation(tracker, server)) { if (needsValidation(tracker, server)) { // This is temporarily required. Consider the following scenario, for example, // This is temporarily required. Consider the following scenario, for example, // Step 1) A DoTServer (s1) is set for the network. A validation (v1) for s1 starts. // Step 1) A DoTServer (s1) is set for the network. A validation (v1) for s1 starts. Loading @@ -133,7 +133,10 @@ int PrivateDnsConfiguration::set(int32_t netId, uint32_t mark, // // // If we didn't add servers to tracker before needValidateThread(), tracker would // If we didn't add servers to tracker before needValidateThread(), tracker would // become empty. We would report s1 validation failed. // become empty. We would report s1 validation failed. tracker[server] = Validation::in_process; if (tracker.find(identity) == tracker.end()) { tracker[identity] = server; } tracker[identity].setValidationState(Validation::in_process); LOG(DEBUG) << "Server " << addrToString(&server.ss) << " marked as in_process on netId " LOG(DEBUG) << "Server " << addrToString(&server.ss) << " marked as in_process on netId " << netId << ". Tracker now has size " << tracker.size(); << netId << ". Tracker now has size " << tracker.size(); // This judge must be after "tracker[server] = Validation::in_process;" // This judge must be after "tracker[server] = Validation::in_process;" Loading @@ -141,7 +144,7 @@ int PrivateDnsConfiguration::set(int32_t netId, uint32_t mark, continue; continue; } } updateServerState(server, Validation::in_process, netId); updateServerState(identity, Validation::in_process, netId); startValidation(server, netId, mark); startValidation(server, netId, mark); } } } } Loading @@ -159,8 +162,8 @@ PrivateDnsStatus PrivateDnsConfiguration::getStatus(unsigned netId) { const auto netPair = mPrivateDnsTransports.find(netId); const auto netPair = mPrivateDnsTransports.find(netId); if (netPair != mPrivateDnsTransports.end()) { if (netPair != mPrivateDnsTransports.end()) { for (const auto& serverPair : netPair->second) { for (const auto& [_, server] : netPair->second) { status.serversMap.emplace(serverPair.first, serverPair.second); status.serversMap.emplace(server, server.validationState()); } } } } Loading Loading @@ -227,20 +230,21 @@ bool PrivateDnsConfiguration::recordPrivateDnsValidation(const DnsTlsServer& ser bool success) { bool success) { constexpr bool NEEDS_REEVALUATION = true; constexpr bool NEEDS_REEVALUATION = true; constexpr bool DONT_REEVALUATE = false; constexpr bool DONT_REEVALUATE = false; const ServerIdentity identity = ServerIdentity(server); std::lock_guard guard(mPrivateDnsLock); std::lock_guard guard(mPrivateDnsLock); auto netPair = mPrivateDnsTransports.find(netId); auto netPair = mPrivateDnsTransports.find(netId); if (netPair == mPrivateDnsTransports.end()) { if (netPair == mPrivateDnsTransports.end()) { LOG(WARNING) << "netId " << netId << " was erased during private DNS validation"; LOG(WARNING) << "netId " << netId << " was erased during private DNS validation"; maybeNotifyObserver(server, Validation::fail, netId); maybeNotifyObserver(identity.ip.toString(), Validation::fail, netId); return DONT_REEVALUATE; return DONT_REEVALUATE; } } const auto mode = mPrivateDnsModes.find(netId); const auto mode = mPrivateDnsModes.find(netId); if (mode == mPrivateDnsModes.end()) { if (mode == mPrivateDnsModes.end()) { LOG(WARNING) << "netId " << netId << " has no private DNS validation mode"; LOG(WARNING) << "netId " << netId << " has no private DNS validation mode"; maybeNotifyObserver(server, Validation::fail, netId); maybeNotifyObserver(identity.ip.toString(), Validation::fail, netId); return DONT_REEVALUATE; return DONT_REEVALUATE; } } const bool modeDoesReevaluation = (mode->second == PrivateDnsMode::STRICT); const bool modeDoesReevaluation = (mode->second == PrivateDnsMode::STRICT); Loading @@ -249,16 +253,13 @@ bool PrivateDnsConfiguration::recordPrivateDnsValidation(const DnsTlsServer& ser (success || !modeDoesReevaluation) ? DONT_REEVALUATE : NEEDS_REEVALUATION; (success || !modeDoesReevaluation) ? DONT_REEVALUATE : NEEDS_REEVALUATION; auto& tracker = netPair->second; auto& tracker = netPair->second; auto serverPair = tracker.find(server); auto serverPair = tracker.find(identity); if (serverPair == tracker.end()) { if (serverPair == tracker.end()) { // TODO: Consider not adding this server to the tracker since this server is not expected // to be one of the private DNS servers for this network now. This could prevent this // server from being included when dumping status. LOG(WARNING) << "Server " << addrToString(&server.ss) LOG(WARNING) << "Server " << addrToString(&server.ss) << " was removed during private DNS validation"; << " was removed during private DNS validation"; success = false; success = false; reevaluationStatus = DONT_REEVALUATE; reevaluationStatus = DONT_REEVALUATE; } else if (!(serverPair->first == server)) { } else if (!(serverPair->second == server)) { // TODO: It doesn't seem correct to overwrite the tracker entry for // TODO: It doesn't seem correct to overwrite the tracker entry for // |server| down below in this circumstance... Fix this. // |server| down below in this circumstance... Fix this. LOG(WARNING) << "Server " << addrToString(&server.ss) LOG(WARNING) << "Server " << addrToString(&server.ss) Loading @@ -282,14 +283,14 @@ bool PrivateDnsConfiguration::recordPrivateDnsValidation(const DnsTlsServer& ser } } if (success) { if (success) { updateServerState(server, Validation::success, netId); updateServerState(identity, Validation::success, netId); } else { } else { // Validation failure is expected if a user is on a captive portal. // Validation failure is expected if a user is on a captive portal. // TODO: Trigger a second validation attempt after captive portal login // TODO: Trigger a second validation attempt after captive portal login // succeeds. // succeeds. const auto result = (reevaluationStatus == NEEDS_REEVALUATION) ? Validation::in_process const auto result = (reevaluationStatus == NEEDS_REEVALUATION) ? Validation::in_process : Validation::fail; : Validation::fail; updateServerState(server, result, netId); updateServerState(identity, result, netId); } } LOG(WARNING) << "Validation " << (success ? "success" : "failed"); LOG(WARNING) << "Validation " << (success ? "success" : "failed"); Loading Loading @@ -324,15 +325,22 @@ bool PrivateDnsConfiguration::needValidateThread(const DnsTlsServer& server, uns } } } } void PrivateDnsConfiguration::updateServerState(const DnsTlsServer& server, Validation state, void PrivateDnsConfiguration::updateServerState(const ServerIdentity& identity, Validation state, uint32_t netId) { uint32_t netId) { auto netPair = mPrivateDnsTransports.find(netId); auto netPair = mPrivateDnsTransports.find(netId); if (netPair != mPrivateDnsTransports.end()) { if (netPair == mPrivateDnsTransports.end()) { maybeNotifyObserver(identity.ip.toString(), Validation::fail, netId); return; } auto& tracker = netPair->second; auto& tracker = netPair->second; tracker[server] = state; if (tracker.find(identity) == tracker.end()) { maybeNotifyObserver(identity.ip.toString(), Validation::fail, netId); return; } } maybeNotifyObserver(server, state, netId); tracker[identity].setValidationState(state); maybeNotifyObserver(identity.ip.toString(), state, netId); } } void PrivateDnsConfiguration::cleanValidateThreadTracker(const DnsTlsServer& server, void PrivateDnsConfiguration::cleanValidateThreadTracker(const DnsTlsServer& server, Loading @@ -353,8 +361,9 @@ void PrivateDnsConfiguration::cleanValidateThreadTracker(const DnsTlsServer& ser bool PrivateDnsConfiguration::needsValidation(const PrivateDnsTracker& tracker, bool PrivateDnsConfiguration::needsValidation(const PrivateDnsTracker& tracker, const DnsTlsServer& server) { const DnsTlsServer& server) { const auto& iter = tracker.find(server); const ServerIdentity identity = ServerIdentity(server); return (iter == tracker.end()) || (iter->second == Validation::fail); const auto& iter = tracker.find(identity); return (iter == tracker.end()) || (iter->second.validationState() == Validation::fail); } } void PrivateDnsConfiguration::setObserver(Observer* observer) { void PrivateDnsConfiguration::setObserver(Observer* observer) { Loading @@ -362,10 +371,10 @@ void PrivateDnsConfiguration::setObserver(Observer* observer) { mObserver = observer; mObserver = observer; } } void PrivateDnsConfiguration::maybeNotifyObserver(const DnsTlsServer& server, Validation validation, void PrivateDnsConfiguration::maybeNotifyObserver(const std::string& serverIp, uint32_t netId) const { Validation validation, uint32_t netId) const { if (mObserver) { if (mObserver) { mObserver->onValidationStateUpdate(addrToString(&server.ss), validation, netId); mObserver->onValidationStateUpdate(serverIp, validation, netId); } } } } Loading
PrivateDnsConfiguration.h +29 −10 Original line number Original line Diff line number Diff line Loading @@ -22,6 +22,7 @@ #include <vector> #include <vector> #include <android-base/thread_annotations.h> #include <android-base/thread_annotations.h> #include <netdutils/InternetAddresses.h> #include "DnsTlsServer.h" #include "DnsTlsServer.h" Loading @@ -31,11 +32,10 @@ namespace net { // The DNS over TLS mode on a specific netId. // The DNS over TLS mode on a specific netId. enum class PrivateDnsMode : uint8_t { OFF, OPPORTUNISTIC, STRICT }; enum class PrivateDnsMode : uint8_t { OFF, OPPORTUNISTIC, STRICT }; // Validation status of a DNS over TLS server (on a specific netId). enum class Validation : uint8_t { in_process, success, fail, unknown_server, unknown_netid }; struct PrivateDnsStatus { struct PrivateDnsStatus { PrivateDnsMode mode; PrivateDnsMode mode; // TODO: change the type to std::vector<DnsTlsServer>. std::map<DnsTlsServer, Validation, AddressComparator> serversMap; std::map<DnsTlsServer, Validation, AddressComparator> serversMap; std::list<DnsTlsServer> validatedServers() const { std::list<DnsTlsServer> validatedServers() const { Loading Loading @@ -65,8 +65,26 @@ class PrivateDnsConfiguration { void clear(unsigned netId) EXCLUDES(mPrivateDnsLock); void clear(unsigned netId) EXCLUDES(mPrivateDnsLock); struct ServerIdentity { const netdutils::IPAddress ip; const std::string name; const int protocol; explicit ServerIdentity(const DnsTlsServer& server) : ip(netdutils::IPSockAddr::toIPSockAddr(server.ss).ip()), name(server.name), protocol(server.protocol) {} bool operator<(const ServerIdentity& other) const { return std::tie(ip, name, protocol) < std::tie(other.ip, other.name, other.protocol); } bool operator==(const ServerIdentity& other) const { return std::tie(ip, name, protocol) == std::tie(other.ip, other.name, other.protocol); } }; private: private: typedef std::map<DnsTlsServer, Validation, AddressComparator> PrivateDnsTracker; typedef std::map<ServerIdentity, DnsTlsServer> PrivateDnsTracker; typedef std::set<DnsTlsServer, AddressComparator> ThreadTracker; typedef std::set<DnsTlsServer, AddressComparator> ThreadTracker; PrivateDnsConfiguration() = default; PrivateDnsConfiguration() = default; Loading @@ -88,7 +106,7 @@ class PrivateDnsConfiguration { bool needsValidation(const PrivateDnsTracker& tracker, const DnsTlsServer& server) bool needsValidation(const PrivateDnsTracker& tracker, const DnsTlsServer& server) REQUIRES(mPrivateDnsLock); REQUIRES(mPrivateDnsLock); void updateServerState(const DnsTlsServer& server, Validation state, uint32_t netId) void updateServerState(const ServerIdentity& identity, Validation state, uint32_t netId) REQUIRES(mPrivateDnsLock); REQUIRES(mPrivateDnsLock); std::mutex mPrivateDnsLock; std::mutex mPrivateDnsLock; Loading @@ -99,19 +117,20 @@ class PrivateDnsConfiguration { std::map<unsigned, ThreadTracker> mPrivateDnsValidateThreads GUARDED_BY(mPrivateDnsLock); std::map<unsigned, ThreadTracker> mPrivateDnsValidateThreads GUARDED_BY(mPrivateDnsLock); // For testing. The observer is notified of onValidationStateUpdate 1) when a validation is // For testing. The observer is notified of onValidationStateUpdate 1) when a validation is // about to begin or 2) when a validation finishes. // about to begin or 2) when a validation finishes. If a validation finishes when in OFF mode // WARNING: The Observer is notified while the lock is being held. Be careful not to call any // or when the network has been destroyed, |validation| will be Validation::fail. // method of PrivateDnsConfiguration from the observer. // WARNING: The Observer is notified while the lock is being held. Be careful not to call // any method of PrivateDnsConfiguration from the observer. // TODO: fix the reentrancy problem. // TODO: fix the reentrancy problem. class Observer { class Observer { public: public: virtual ~Observer(){}; virtual ~Observer(){}; virtual void onValidationStateUpdate(const std::string& server, Validation validation, virtual void onValidationStateUpdate(const std::string& serverIp, Validation validation, uint32_t netId) = 0; uint32_t netId) = 0; }; }; void setObserver(Observer* observer); void setObserver(Observer* observer); void maybeNotifyObserver(const DnsTlsServer& server, Validation validation, void maybeNotifyObserver(const std::string& serverIp, Validation validation, uint32_t netId) const REQUIRES(mPrivateDnsLock); uint32_t netId) const REQUIRES(mPrivateDnsLock); Observer* mObserver GUARDED_BY(mPrivateDnsLock); Observer* mObserver GUARDED_BY(mPrivateDnsLock); Loading
PrivateDnsConfigurationTest.cpp +37 −1 Original line number Original line Diff line number Diff line Loading @@ -63,7 +63,8 @@ class PrivateDnsConfigurationTest : public ::testing::Test { class MockObserver : public PrivateDnsConfiguration::Observer { class MockObserver : public PrivateDnsConfiguration::Observer { public: public: MOCK_METHOD(void, onValidationStateUpdate, MOCK_METHOD(void, onValidationStateUpdate, (const std::string& server, Validation validation, uint32_t netId), (override)); (const std::string& serverIp, Validation validation, uint32_t netId), (override)); std::map<std::string, Validation> getServerStateMap() const { std::map<std::string, Validation> getServerStateMap() const { std::lock_guard guard(lock); std::lock_guard guard(lock); Loading Loading @@ -172,6 +173,11 @@ TEST_F(PrivateDnsConfigurationTest, ValidationBlock) { backend.setDeferredResp(false); backend.setDeferredResp(false); ASSERT_TRUE(PollForCondition([&]() { return mObserver.runningThreads == 0; })); ASSERT_TRUE(PollForCondition([&]() { return mObserver.runningThreads == 0; })); // kServer1 is not a present server and thus should not be available from // PrivateDnsConfiguration::getStatus(). mObserver.removeFromServerStateMap(kServer1); expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC); expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC); } } Loading Loading @@ -218,6 +224,36 @@ TEST_F(PrivateDnsConfigurationTest, NoValidation) { expectStatus(); expectStatus(); } } TEST_F(PrivateDnsConfigurationTest, ServerIdentity_Comparison) { using ServerIdentity = PrivateDnsConfiguration::ServerIdentity; DnsTlsServer server(netdutils::IPSockAddr::toIPSockAddr("127.0.0.1", 853)); server.name = "dns.example.com"; server.protocol = 1; // Different IP address (port is ignored). DnsTlsServer other = server; EXPECT_EQ(ServerIdentity(server), ServerIdentity(other)); other.ss = netdutils::IPSockAddr::toIPSockAddr("127.0.0.1", 5353); EXPECT_EQ(ServerIdentity(server), ServerIdentity(other)); other.ss = netdutils::IPSockAddr::toIPSockAddr("127.0.0.2", 853); EXPECT_NE(ServerIdentity(server), ServerIdentity(other)); // Different provider hostname. other = server; EXPECT_EQ(ServerIdentity(server), ServerIdentity(other)); other.name = "other.example.com"; EXPECT_NE(ServerIdentity(server), ServerIdentity(other)); other.name = ""; EXPECT_NE(ServerIdentity(server), ServerIdentity(other)); // Different protocol. other = server; EXPECT_EQ(ServerIdentity(server), ServerIdentity(other)); other.protocol++; EXPECT_NE(ServerIdentity(server), ServerIdentity(other)); } // TODO: add ValidationFail_Strict test. // TODO: add ValidationFail_Strict test. } // namespace android::net } // namespace android::net
resolv_tls_unit_test.cpp +24 −0 Original line number Original line Diff line number Diff line Loading @@ -800,6 +800,18 @@ void checkUnequal(const DnsTlsServer& s1, const DnsTlsServer& s2) { EXPECT_FALSE(s2 == s1); EXPECT_FALSE(s2 == s1); } } void checkEqual(const DnsTlsServer& s1, const DnsTlsServer& s2) { EXPECT_TRUE(s1 == s1); EXPECT_TRUE(s2 == s2); EXPECT_TRUE(isAddressEqual(s1, s1)); EXPECT_TRUE(isAddressEqual(s2, s2)); EXPECT_FALSE(s1 < s2); EXPECT_FALSE(s2 < s1); EXPECT_TRUE(s1 == s2); EXPECT_TRUE(s2 == s1); } class ServerTest : public BaseTest {}; class ServerTest : public BaseTest {}; TEST_F(ServerTest, IPv4) { TEST_F(ServerTest, IPv4) { Loading Loading @@ -873,6 +885,18 @@ TEST_F(ServerTest, Name) { EXPECT_TRUE(s2.wasExplicitlyConfigured()); EXPECT_TRUE(s2.wasExplicitlyConfigured()); } } TEST_F(ServerTest, State) { DnsTlsServer s1(V4ADDR1), s2(V4ADDR1); checkEqual(s1, s2); s1.setValidationState(Validation::success); checkEqual(s1, s2); s2.setValidationState(Validation::fail); checkEqual(s1, s2); EXPECT_EQ(s1.validationState(), Validation::success); EXPECT_EQ(s2.validationState(), Validation::fail); } TEST(QueryMapTest, Basic) { TEST(QueryMapTest, Basic) { DnsTlsQueryMap map; DnsTlsQueryMap map; Loading