Loading DnsTlsServer.cpp +6 −0 Original line number Diff line number Diff line Loading @@ -18,6 +18,8 @@ #include <algorithm> #include <netdutils/InternetAddresses.h> namespace { // Returns a tuple of references to the elements of a. Loading Loading @@ -124,5 +126,9 @@ bool DnsTlsServer::wasExplicitlyConfigured() const { return !name.empty(); } std::string DnsTlsServer::toIpString() const { return netdutils::IPSockAddr::toIPSockAddr(ss).ip().toString(); } } // namespace net } // namespace android DnsTlsServer.h +28 −0 Original line number Diff line number Diff line Loading @@ -27,6 +27,16 @@ namespace android { namespace net { // Validation status of a DNS over TLS server (on a specific netId). enum class Validation : uint8_t { in_process, success, success_but_expired, fail, unknown_server, unknown_netid, }; // DnsTlsServer represents a recursive resolver that supports, or may support, a // secure protocol. struct DnsTlsServer { Loading @@ -37,17 +47,21 @@ struct DnsTlsServer { DnsTlsServer(const sockaddr_storage& ss) : ss(ss) {} // The server location, including IP and port. // TODO: make it const. sockaddr_storage ss = {}; // The server's hostname. If this string is nonempty, the server must present a // certificate that indicates this name and has a valid chain to a trusted root CA. // TODO: make it const. std::string name; // The certificate of the CA that signed the server's certificate. // It is used to store temporary test CA certificate for internal tests. // TODO: make it const. std::string certificate; // Placeholder. More protocols might be defined in the future. // TODO: make it const. int protocol = IPPROTO_TCP; // Exact comparison of DnsTlsServer objects Loading @@ -55,6 +69,20 @@ struct DnsTlsServer { bool operator==(const DnsTlsServer& other) const; bool wasExplicitlyConfigured() const; std::string toIpString() const; Validation validationState() const { return mValidation; } void setValidationState(Validation val) { mValidation = val; } // Return whether or not the server can be used for a network. It depends on // the resolver configuration. bool active() const { return mActive; } void setActive(bool val) { mActive = val; } private: // State, unrelated to the comparison of DnsTlsServer objects. Validation mValidation = Validation::unknown_server; bool mActive = false; }; // This comparison only checks the IP address. It ignores ports, names, and fingerprints. Loading PrivateDnsConfiguration.cpp +69 −129 Original line number Diff line number Diff line Loading @@ -20,7 +20,6 @@ #include <android-base/logging.h> #include <android-base/stringprintf.h> #include <netdb.h> #include <netdutils/ThreadUtil.h> #include <sys/socket.h> Loading @@ -37,13 +36,6 @@ using std::chrono::milliseconds; namespace android { namespace net { std::string addrToString(const sockaddr_storage* addr) { char out[INET6_ADDRSTRLEN] = {0}; getnameinfo((const sockaddr*) addr, sizeof(sockaddr_storage), out, INET6_ADDRSTRLEN, nullptr, 0, NI_NUMERICHOST); return std::string(out); } bool parseServer(const char* server, sockaddr_storage* parsed) { addrinfo hints = { .ai_flags = AI_NUMERICHOST | AI_NUMERICSERV, Loading @@ -69,7 +61,7 @@ int PrivateDnsConfiguration::set(int32_t netId, uint32_t mark, << ", " << servers.size() << ", " << name << ")"; // Parse the list of servers that has been passed in std::set<DnsTlsServer> tlsServers; PrivateDnsTracker tmp; for (const auto& s : servers) { sockaddr_storage parsed; if (!parseServer(s.c_str(), &parsed)) { Loading @@ -78,70 +70,42 @@ int PrivateDnsConfiguration::set(int32_t netId, uint32_t mark, DnsTlsServer server(parsed); server.name = name; server.certificate = caCert; tlsServers.insert(server); tmp[ServerIdentity(server)] = server; } std::lock_guard guard(mPrivateDnsLock); if (!name.empty()) { mPrivateDnsModes[netId] = PrivateDnsMode::STRICT; } else if (!tlsServers.empty()) { } else if (!tmp.empty()) { mPrivateDnsModes[netId] = PrivateDnsMode::OPPORTUNISTIC; } else { mPrivateDnsModes[netId] = PrivateDnsMode::OFF; mPrivateDnsTransports.erase(netId); mPrivateDnsValidateThreads.erase(netId); // TODO: As mPrivateDnsValidateThreads is reset, validation threads which haven't yet // finished are considered outdated. Consider signaling the outdated validation threads to // stop them from updating the state of PrivateDnsConfiguration (possibly disallow them to // report onPrivateDnsValidationEvent). // TODO: signal validation threads to stop. return 0; } // Create the tracker if it was not present auto netPair = mPrivateDnsTransports.find(netId); if (netPair == mPrivateDnsTransports.end()) { // No TLS tracker yet for this netId. bool added; std::tie(netPair, added) = mPrivateDnsTransports.emplace(netId, PrivateDnsTracker()); if (!added) { LOG(ERROR) << "Memory error while recording private DNS for netId " << netId; return -ENOMEM; } } auto& tracker = netPair->second; auto& tracker = mPrivateDnsTransports[netId]; // Remove any servers from the tracker that are not in |servers| exactly. for (auto it = tracker.begin(); it != tracker.end();) { if (tlsServers.count(it->first) == 0) { it = tracker.erase(it); } else { ++it; // Add the servers if not contained in tracker. for (const auto& [identity, server] : tmp) { if (tracker.find(identity) == tracker.end()) { tracker[identity] = server; } } // Add any new or changed servers to the tracker, and initiate async checks for them. for (const auto& server : tlsServers) { if (needsValidation(tracker, server)) { // This is temporarily required. Consider the following scenario, for example, // Step 1) A DoTServer (s1) is set for the network. A validation (v1) for s1 starts. // tracker has s1 alone. // Step 2) The configuration changes. DotServer2 (s2) is set for the network. A // validation (v2) for s2 starts. tracker has s2 alone. // Step 3) Assume v1 and v2 somehow block. Now, if the configuration changes back to // set s1, there won't be a v1' starts because needValidateThread() will // return false. // // If we didn't add servers to tracker before needValidateThread(), tracker would // become empty. We would report s1 validation failed. tracker[server] = Validation::in_process; LOG(DEBUG) << "Server " << addrToString(&server.ss) << " marked as in_process on netId " << netId << ". Tracker now has size " << tracker.size(); // This judge must be after "tracker[server] = Validation::in_process;" if (!needValidateThread(server, netId)) { continue; for (auto& [identity, server] : tracker) { const bool active = tmp.find(identity) != tmp.end(); server.setActive(active); // For simplicity, deem the validation result of inactive servers as unreliable. if (!server.active() && server.validationState() == Validation::success) { updateServerState(identity, Validation::success_but_expired, netId); } updateServerState(server, Validation::in_process, netId); if (needsValidation(server)) { updateServerState(identity, Validation::in_process, netId); startValidation(server, netId, mark); } } Loading @@ -159,8 +123,10 @@ PrivateDnsStatus PrivateDnsConfiguration::getStatus(unsigned netId) { const auto netPair = mPrivateDnsTransports.find(netId); if (netPair != mPrivateDnsTransports.end()) { for (const auto& serverPair : netPair->second) { status.serversMap.emplace(serverPair.first, serverPair.second); for (const auto& [_, server] : netPair->second) { if (server.active()) { status.serversMap.emplace(server, server.validationState()); } } } Loading @@ -172,7 +138,6 @@ void PrivateDnsConfiguration::clear(unsigned netId) { std::lock_guard guard(mPrivateDnsLock); mPrivateDnsModes.erase(netId); mPrivateDnsTransports.erase(netId); mPrivateDnsValidateThreads.erase(netId); } void PrivateDnsConfiguration::startValidation(const DnsTlsServer& server, unsigned netId, Loading Loading @@ -205,7 +170,7 @@ void PrivateDnsConfiguration::startValidation(const DnsTlsServer& server, unsign LOG(WARNING) << "Validating DnsTlsServer on netId " << netId; const bool success = DnsTlsTransport::validate(server, netId, mark); LOG(DEBUG) << "validateDnsTlsServer returned " << success << " for " << addrToString(&server.ss); << server.toIpString(); const bool needs_reeval = this->recordPrivateDnsValidation(server, netId, success); if (!needs_reeval) { Loading @@ -213,12 +178,12 @@ void PrivateDnsConfiguration::startValidation(const DnsTlsServer& server, unsign } if (backoff.hasNextTimeout()) { // TODO: make the thread able to receive signals to shutdown early. std::this_thread::sleep_for(backoff.getNextTimeout()); } else { break; } } this->cleanValidateThreadTracker(server, netId); }); validate_thread.detach(); } Loading @@ -227,20 +192,21 @@ bool PrivateDnsConfiguration::recordPrivateDnsValidation(const DnsTlsServer& ser bool success) { constexpr bool NEEDS_REEVALUATION = true; constexpr bool DONT_REEVALUATE = false; const ServerIdentity identity = ServerIdentity(server); std::lock_guard guard(mPrivateDnsLock); auto netPair = mPrivateDnsTransports.find(netId); if (netPair == mPrivateDnsTransports.end()) { LOG(WARNING) << "netId " << netId << " was erased during private DNS validation"; maybeNotifyObserver(server, Validation::fail, netId); maybeNotifyObserver(identity.ip.toString(), Validation::fail, netId); return DONT_REEVALUATE; } const auto mode = mPrivateDnsModes.find(netId); if (mode == mPrivateDnsModes.end()) { LOG(WARNING) << "netId " << netId << " has no private DNS validation mode"; maybeNotifyObserver(server, Validation::fail, netId); maybeNotifyObserver(identity.ip.toString(), Validation::fail, netId); return DONT_REEVALUATE; } const bool modeDoesReevaluation = (mode->second == PrivateDnsMode::STRICT); Loading @@ -249,112 +215,86 @@ bool PrivateDnsConfiguration::recordPrivateDnsValidation(const DnsTlsServer& ser (success || !modeDoesReevaluation) ? DONT_REEVALUATE : NEEDS_REEVALUATION; auto& tracker = netPair->second; auto serverPair = tracker.find(server); auto serverPair = tracker.find(identity); if (serverPair == tracker.end()) { // TODO: Consider not adding this server to the tracker since this server is not expected // to be one of the private DNS servers for this network now. This could prevent this // server from being included when dumping status. LOG(WARNING) << "Server " << addrToString(&server.ss) LOG(WARNING) << "Server " << server.toIpString() << " was removed during private DNS validation"; success = false; reevaluationStatus = DONT_REEVALUATE; } else if (!(serverPair->first == server)) { } else if (!(serverPair->second == server)) { // TODO: It doesn't seem correct to overwrite the tracker entry for // |server| down below in this circumstance... Fix this. LOG(WARNING) << "Server " << addrToString(&server.ss) LOG(WARNING) << "Server " << server.toIpString() << " was changed during private DNS validation"; success = false; reevaluationStatus = DONT_REEVALUATE; } else if (!serverPair->second.active()) { LOG(WARNING) << "Server " << server.toIpString() << " was removed from the configuration"; success = false; reevaluationStatus = DONT_REEVALUATE; } // Send a validation event to NetdEventListenerService. const auto& listeners = ResolverEventReporter::getInstance().getListeners(); if (listeners.size() != 0) { for (const auto& it : listeners) { it->onPrivateDnsValidationEvent(netId, addrToString(&server.ss), server.name, success); it->onPrivateDnsValidationEvent(netId, server.toIpString(), server.name, success); } LOG(DEBUG) << "Sent validation " << (success ? "success" : "failure") << " event on netId " << netId << " for " << addrToString(&server.ss) << " with hostname {" << server.name << "}"; << netId << " for " << server.toIpString() << " with hostname {" << server.name << "}"; } else { LOG(ERROR) << "Validation event not sent since no INetdEventListener receiver is available."; } if (success) { updateServerState(server, Validation::success, netId); updateServerState(identity, Validation::success, netId); } else { // Validation failure is expected if a user is on a captive portal. // TODO: Trigger a second validation attempt after captive portal login // succeeds. const auto result = (reevaluationStatus == NEEDS_REEVALUATION) ? Validation::in_process : Validation::fail; updateServerState(server, result, netId); updateServerState(identity, result, netId); } LOG(WARNING) << "Validation " << (success ? "success" : "failed"); return reevaluationStatus; } bool PrivateDnsConfiguration::needValidateThread(const DnsTlsServer& server, unsigned netId) REQUIRES(mPrivateDnsLock) { // Create the thread tracker if it was not present auto threadPair = mPrivateDnsValidateThreads.find(netId); if (threadPair == mPrivateDnsValidateThreads.end()) { // No thread tracker yet for this netId. bool added; std::tie(threadPair, added) = mPrivateDnsValidateThreads.emplace(netId, ThreadTracker()); if (!added) { LOG(ERROR) << "Memory error while needValidateThread for netId " << netId; return true; } } auto& threadTracker = threadPair->second; if (threadTracker.count(server)) { LOG(DEBUG) << "Server " << addrToString(&(server.ss)) << " validate thread is already running. Thread tracker now has size " << threadTracker.size(); return false; } else { threadTracker.insert(server); LOG(DEBUG) << "Server " << addrToString(&(server.ss)) << " validate thread is not running. Thread tracker now has size " << threadTracker.size(); return true; } } void PrivateDnsConfiguration::updateServerState(const DnsTlsServer& server, Validation state, void PrivateDnsConfiguration::updateServerState(const ServerIdentity& identity, Validation state, uint32_t netId) { auto netPair = mPrivateDnsTransports.find(netId); if (netPair != mPrivateDnsTransports.end()) { if (netPair == mPrivateDnsTransports.end()) { maybeNotifyObserver(identity.ip.toString(), Validation::fail, netId); return; } auto& tracker = netPair->second; tracker[server] = state; if (tracker.find(identity) == tracker.end()) { maybeNotifyObserver(identity.ip.toString(), Validation::fail, netId); return; } maybeNotifyObserver(server, state, netId); tracker[identity].setValidationState(state); maybeNotifyObserver(identity.ip.toString(), state, netId); } void PrivateDnsConfiguration::cleanValidateThreadTracker(const DnsTlsServer& server, unsigned netId) { std::lock_guard<std::mutex> guard(mPrivateDnsLock); LOG(DEBUG) << "cleanValidateThreadTracker Server " << addrToString(&(server.ss)) << " validate thread is stopped."; bool PrivateDnsConfiguration::needsValidation(const DnsTlsServer& server) { // The server is not expected to be used on the network. if (!server.active()) return false; auto threadPair = mPrivateDnsValidateThreads.find(netId); if (threadPair != mPrivateDnsValidateThreads.end()) { auto& threadTracker = threadPair->second; threadTracker.erase(server); LOG(DEBUG) << "Server " << addrToString(&(server.ss)) << " validate thread is stopped. Thread tracker now has size " << threadTracker.size(); } } // The server is newly added. if (server.validationState() == Validation::unknown_server) return true; bool PrivateDnsConfiguration::needsValidation(const PrivateDnsTracker& tracker, const DnsTlsServer& server) { const auto& iter = tracker.find(server); return (iter == tracker.end()) || (iter->second == Validation::fail); // The server has failed at least one validation attempt. Give it another try. if (server.validationState() == Validation::fail) return true; // The previous validation result might be unreliable. if (server.validationState() == Validation::success_but_expired) return true; return false; } void PrivateDnsConfiguration::setObserver(Observer* observer) { Loading @@ -362,10 +302,10 @@ void PrivateDnsConfiguration::setObserver(Observer* observer) { mObserver = observer; } void PrivateDnsConfiguration::maybeNotifyObserver(const DnsTlsServer& server, Validation validation, uint32_t netId) const { void PrivateDnsConfiguration::maybeNotifyObserver(const std::string& serverIp, Validation validation, uint32_t netId) const { if (mObserver) { mObserver->onValidationStateUpdate(addrToString(&server.ss), validation, netId); mObserver->onValidationStateUpdate(serverIp, validation, netId); } } Loading PrivateDnsConfiguration.h +36 −21 Original line number Diff line number Diff line Loading @@ -22,6 +22,7 @@ #include <vector> #include <android-base/thread_annotations.h> #include <netdutils/InternetAddresses.h> #include "DnsTlsServer.h" Loading @@ -31,11 +32,10 @@ namespace net { // The DNS over TLS mode on a specific netId. enum class PrivateDnsMode : uint8_t { OFF, OPPORTUNISTIC, STRICT }; // Validation status of a DNS over TLS server (on a specific netId). enum class Validation : uint8_t { in_process, success, fail, unknown_server, unknown_netid }; struct PrivateDnsStatus { PrivateDnsMode mode; // TODO: change the type to std::vector<DnsTlsServer>. std::map<DnsTlsServer, Validation, AddressComparator> serversMap; std::list<DnsTlsServer> validatedServers() const { Loading Loading @@ -65,8 +65,26 @@ class PrivateDnsConfiguration { void clear(unsigned netId) EXCLUDES(mPrivateDnsLock); struct ServerIdentity { const netdutils::IPAddress ip; const std::string name; const int protocol; explicit ServerIdentity(const DnsTlsServer& server) : ip(netdutils::IPSockAddr::toIPSockAddr(server.ss).ip()), name(server.name), protocol(server.protocol) {} bool operator<(const ServerIdentity& other) const { return std::tie(ip, name, protocol) < std::tie(other.ip, other.name, other.protocol); } bool operator==(const ServerIdentity& other) const { return std::tie(ip, name, protocol) == std::tie(other.ip, other.name, other.protocol); } }; private: typedef std::map<DnsTlsServer, Validation, AddressComparator> PrivateDnsTracker; typedef std::map<ServerIdentity, DnsTlsServer> PrivateDnsTracker; typedef std::set<DnsTlsServer, AddressComparator> ThreadTracker; PrivateDnsConfiguration() = default; Loading @@ -77,41 +95,38 @@ class PrivateDnsConfiguration { bool recordPrivateDnsValidation(const DnsTlsServer& server, unsigned netId, bool success) EXCLUDES(mPrivateDnsLock); bool needValidateThread(const DnsTlsServer& server, unsigned netId) REQUIRES(mPrivateDnsLock); void cleanValidateThreadTracker(const DnsTlsServer& server, unsigned netId) EXCLUDES(mPrivateDnsLock); // Start validation for newly added servers as well as any servers that have // landed in Validation::fail state. Note that servers that have failed // Decide if a validation for |server| is needed. Note that servers that have failed // multiple validation attempts but for which there is still a validating // thread running are marked as being Validation::in_process. bool needsValidation(const PrivateDnsTracker& tracker, const DnsTlsServer& server) REQUIRES(mPrivateDnsLock); bool needsValidation(const DnsTlsServer& server) REQUIRES(mPrivateDnsLock); void updateServerState(const DnsTlsServer& server, Validation state, uint32_t netId) void updateServerState(const ServerIdentity& identity, Validation state, uint32_t netId) REQUIRES(mPrivateDnsLock); std::mutex mPrivateDnsLock; std::map<unsigned, PrivateDnsMode> mPrivateDnsModes GUARDED_BY(mPrivateDnsLock); // Structure for tracking the validation status of servers on a specific netId. // Using the AddressComparator ensures at most one entry per IP address. // Contains all servers for a network, along with their current validation status. // In case a server is removed due to a configuration change, it remains in this map, // but is marked inactive. // Any pending validation threads will continue running because we have no way to cancel them. std::map<unsigned, PrivateDnsTracker> mPrivateDnsTransports GUARDED_BY(mPrivateDnsLock); std::map<unsigned, ThreadTracker> mPrivateDnsValidateThreads GUARDED_BY(mPrivateDnsLock); // For testing. The observer is notified of onValidationStateUpdate 1) when a validation is // about to begin or 2) when a validation finishes. // WARNING: The Observer is notified while the lock is being held. Be careful not to call any // method of PrivateDnsConfiguration from the observer. // about to begin or 2) when a validation finishes. If a validation finishes when in OFF mode // or when the network has been destroyed, |validation| will be Validation::fail. // WARNING: The Observer is notified while the lock is being held. Be careful not to call // any method of PrivateDnsConfiguration from the observer. // TODO: fix the reentrancy problem. class Observer { public: virtual ~Observer(){}; virtual void onValidationStateUpdate(const std::string& server, Validation validation, virtual void onValidationStateUpdate(const std::string& serverIp, Validation validation, uint32_t netId) = 0; }; void setObserver(Observer* observer); void maybeNotifyObserver(const DnsTlsServer& server, Validation validation, void maybeNotifyObserver(const std::string& serverIp, Validation validation, uint32_t netId) const REQUIRES(mPrivateDnsLock); Observer* mObserver GUARDED_BY(mPrivateDnsLock); Loading PrivateDnsConfigurationTest.cpp +37 −1 Original line number Diff line number Diff line Loading @@ -63,7 +63,8 @@ class PrivateDnsConfigurationTest : public ::testing::Test { class MockObserver : public PrivateDnsConfiguration::Observer { public: MOCK_METHOD(void, onValidationStateUpdate, (const std::string& server, Validation validation, uint32_t netId), (override)); (const std::string& serverIp, Validation validation, uint32_t netId), (override)); std::map<std::string, Validation> getServerStateMap() const { std::lock_guard guard(lock); Loading Loading @@ -172,6 +173,11 @@ TEST_F(PrivateDnsConfigurationTest, ValidationBlock) { backend.setDeferredResp(false); ASSERT_TRUE(PollForCondition([&]() { return mObserver.runningThreads == 0; })); // kServer1 is not a present server and thus should not be available from // PrivateDnsConfiguration::getStatus(). mObserver.removeFromServerStateMap(kServer1); expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC); } Loading Loading @@ -218,6 +224,36 @@ TEST_F(PrivateDnsConfigurationTest, NoValidation) { expectStatus(); } TEST_F(PrivateDnsConfigurationTest, ServerIdentity_Comparison) { using ServerIdentity = PrivateDnsConfiguration::ServerIdentity; DnsTlsServer server(netdutils::IPSockAddr::toIPSockAddr("127.0.0.1", 853)); server.name = "dns.example.com"; server.protocol = 1; // Different IP address (port is ignored). DnsTlsServer other = server; EXPECT_EQ(ServerIdentity(server), ServerIdentity(other)); other.ss = netdutils::IPSockAddr::toIPSockAddr("127.0.0.1", 5353); EXPECT_EQ(ServerIdentity(server), ServerIdentity(other)); other.ss = netdutils::IPSockAddr::toIPSockAddr("127.0.0.2", 853); EXPECT_NE(ServerIdentity(server), ServerIdentity(other)); // Different provider hostname. other = server; EXPECT_EQ(ServerIdentity(server), ServerIdentity(other)); other.name = "other.example.com"; EXPECT_NE(ServerIdentity(server), ServerIdentity(other)); other.name = ""; EXPECT_NE(ServerIdentity(server), ServerIdentity(other)); // Different protocol. other = server; EXPECT_EQ(ServerIdentity(server), ServerIdentity(other)); other.protocol++; EXPECT_NE(ServerIdentity(server), ServerIdentity(other)); } // TODO: add ValidationFail_Strict test. } // namespace android::net Loading
DnsTlsServer.cpp +6 −0 Original line number Diff line number Diff line Loading @@ -18,6 +18,8 @@ #include <algorithm> #include <netdutils/InternetAddresses.h> namespace { // Returns a tuple of references to the elements of a. Loading Loading @@ -124,5 +126,9 @@ bool DnsTlsServer::wasExplicitlyConfigured() const { return !name.empty(); } std::string DnsTlsServer::toIpString() const { return netdutils::IPSockAddr::toIPSockAddr(ss).ip().toString(); } } // namespace net } // namespace android
DnsTlsServer.h +28 −0 Original line number Diff line number Diff line Loading @@ -27,6 +27,16 @@ namespace android { namespace net { // Validation status of a DNS over TLS server (on a specific netId). enum class Validation : uint8_t { in_process, success, success_but_expired, fail, unknown_server, unknown_netid, }; // DnsTlsServer represents a recursive resolver that supports, or may support, a // secure protocol. struct DnsTlsServer { Loading @@ -37,17 +47,21 @@ struct DnsTlsServer { DnsTlsServer(const sockaddr_storage& ss) : ss(ss) {} // The server location, including IP and port. // TODO: make it const. sockaddr_storage ss = {}; // The server's hostname. If this string is nonempty, the server must present a // certificate that indicates this name and has a valid chain to a trusted root CA. // TODO: make it const. std::string name; // The certificate of the CA that signed the server's certificate. // It is used to store temporary test CA certificate for internal tests. // TODO: make it const. std::string certificate; // Placeholder. More protocols might be defined in the future. // TODO: make it const. int protocol = IPPROTO_TCP; // Exact comparison of DnsTlsServer objects Loading @@ -55,6 +69,20 @@ struct DnsTlsServer { bool operator==(const DnsTlsServer& other) const; bool wasExplicitlyConfigured() const; std::string toIpString() const; Validation validationState() const { return mValidation; } void setValidationState(Validation val) { mValidation = val; } // Return whether or not the server can be used for a network. It depends on // the resolver configuration. bool active() const { return mActive; } void setActive(bool val) { mActive = val; } private: // State, unrelated to the comparison of DnsTlsServer objects. Validation mValidation = Validation::unknown_server; bool mActive = false; }; // This comparison only checks the IP address. It ignores ports, names, and fingerprints. Loading
PrivateDnsConfiguration.cpp +69 −129 Original line number Diff line number Diff line Loading @@ -20,7 +20,6 @@ #include <android-base/logging.h> #include <android-base/stringprintf.h> #include <netdb.h> #include <netdutils/ThreadUtil.h> #include <sys/socket.h> Loading @@ -37,13 +36,6 @@ using std::chrono::milliseconds; namespace android { namespace net { std::string addrToString(const sockaddr_storage* addr) { char out[INET6_ADDRSTRLEN] = {0}; getnameinfo((const sockaddr*) addr, sizeof(sockaddr_storage), out, INET6_ADDRSTRLEN, nullptr, 0, NI_NUMERICHOST); return std::string(out); } bool parseServer(const char* server, sockaddr_storage* parsed) { addrinfo hints = { .ai_flags = AI_NUMERICHOST | AI_NUMERICSERV, Loading @@ -69,7 +61,7 @@ int PrivateDnsConfiguration::set(int32_t netId, uint32_t mark, << ", " << servers.size() << ", " << name << ")"; // Parse the list of servers that has been passed in std::set<DnsTlsServer> tlsServers; PrivateDnsTracker tmp; for (const auto& s : servers) { sockaddr_storage parsed; if (!parseServer(s.c_str(), &parsed)) { Loading @@ -78,70 +70,42 @@ int PrivateDnsConfiguration::set(int32_t netId, uint32_t mark, DnsTlsServer server(parsed); server.name = name; server.certificate = caCert; tlsServers.insert(server); tmp[ServerIdentity(server)] = server; } std::lock_guard guard(mPrivateDnsLock); if (!name.empty()) { mPrivateDnsModes[netId] = PrivateDnsMode::STRICT; } else if (!tlsServers.empty()) { } else if (!tmp.empty()) { mPrivateDnsModes[netId] = PrivateDnsMode::OPPORTUNISTIC; } else { mPrivateDnsModes[netId] = PrivateDnsMode::OFF; mPrivateDnsTransports.erase(netId); mPrivateDnsValidateThreads.erase(netId); // TODO: As mPrivateDnsValidateThreads is reset, validation threads which haven't yet // finished are considered outdated. Consider signaling the outdated validation threads to // stop them from updating the state of PrivateDnsConfiguration (possibly disallow them to // report onPrivateDnsValidationEvent). // TODO: signal validation threads to stop. return 0; } // Create the tracker if it was not present auto netPair = mPrivateDnsTransports.find(netId); if (netPair == mPrivateDnsTransports.end()) { // No TLS tracker yet for this netId. bool added; std::tie(netPair, added) = mPrivateDnsTransports.emplace(netId, PrivateDnsTracker()); if (!added) { LOG(ERROR) << "Memory error while recording private DNS for netId " << netId; return -ENOMEM; } } auto& tracker = netPair->second; auto& tracker = mPrivateDnsTransports[netId]; // Remove any servers from the tracker that are not in |servers| exactly. for (auto it = tracker.begin(); it != tracker.end();) { if (tlsServers.count(it->first) == 0) { it = tracker.erase(it); } else { ++it; // Add the servers if not contained in tracker. for (const auto& [identity, server] : tmp) { if (tracker.find(identity) == tracker.end()) { tracker[identity] = server; } } // Add any new or changed servers to the tracker, and initiate async checks for them. for (const auto& server : tlsServers) { if (needsValidation(tracker, server)) { // This is temporarily required. Consider the following scenario, for example, // Step 1) A DoTServer (s1) is set for the network. A validation (v1) for s1 starts. // tracker has s1 alone. // Step 2) The configuration changes. DotServer2 (s2) is set for the network. A // validation (v2) for s2 starts. tracker has s2 alone. // Step 3) Assume v1 and v2 somehow block. Now, if the configuration changes back to // set s1, there won't be a v1' starts because needValidateThread() will // return false. // // If we didn't add servers to tracker before needValidateThread(), tracker would // become empty. We would report s1 validation failed. tracker[server] = Validation::in_process; LOG(DEBUG) << "Server " << addrToString(&server.ss) << " marked as in_process on netId " << netId << ". Tracker now has size " << tracker.size(); // This judge must be after "tracker[server] = Validation::in_process;" if (!needValidateThread(server, netId)) { continue; for (auto& [identity, server] : tracker) { const bool active = tmp.find(identity) != tmp.end(); server.setActive(active); // For simplicity, deem the validation result of inactive servers as unreliable. if (!server.active() && server.validationState() == Validation::success) { updateServerState(identity, Validation::success_but_expired, netId); } updateServerState(server, Validation::in_process, netId); if (needsValidation(server)) { updateServerState(identity, Validation::in_process, netId); startValidation(server, netId, mark); } } Loading @@ -159,8 +123,10 @@ PrivateDnsStatus PrivateDnsConfiguration::getStatus(unsigned netId) { const auto netPair = mPrivateDnsTransports.find(netId); if (netPair != mPrivateDnsTransports.end()) { for (const auto& serverPair : netPair->second) { status.serversMap.emplace(serverPair.first, serverPair.second); for (const auto& [_, server] : netPair->second) { if (server.active()) { status.serversMap.emplace(server, server.validationState()); } } } Loading @@ -172,7 +138,6 @@ void PrivateDnsConfiguration::clear(unsigned netId) { std::lock_guard guard(mPrivateDnsLock); mPrivateDnsModes.erase(netId); mPrivateDnsTransports.erase(netId); mPrivateDnsValidateThreads.erase(netId); } void PrivateDnsConfiguration::startValidation(const DnsTlsServer& server, unsigned netId, Loading Loading @@ -205,7 +170,7 @@ void PrivateDnsConfiguration::startValidation(const DnsTlsServer& server, unsign LOG(WARNING) << "Validating DnsTlsServer on netId " << netId; const bool success = DnsTlsTransport::validate(server, netId, mark); LOG(DEBUG) << "validateDnsTlsServer returned " << success << " for " << addrToString(&server.ss); << server.toIpString(); const bool needs_reeval = this->recordPrivateDnsValidation(server, netId, success); if (!needs_reeval) { Loading @@ -213,12 +178,12 @@ void PrivateDnsConfiguration::startValidation(const DnsTlsServer& server, unsign } if (backoff.hasNextTimeout()) { // TODO: make the thread able to receive signals to shutdown early. std::this_thread::sleep_for(backoff.getNextTimeout()); } else { break; } } this->cleanValidateThreadTracker(server, netId); }); validate_thread.detach(); } Loading @@ -227,20 +192,21 @@ bool PrivateDnsConfiguration::recordPrivateDnsValidation(const DnsTlsServer& ser bool success) { constexpr bool NEEDS_REEVALUATION = true; constexpr bool DONT_REEVALUATE = false; const ServerIdentity identity = ServerIdentity(server); std::lock_guard guard(mPrivateDnsLock); auto netPair = mPrivateDnsTransports.find(netId); if (netPair == mPrivateDnsTransports.end()) { LOG(WARNING) << "netId " << netId << " was erased during private DNS validation"; maybeNotifyObserver(server, Validation::fail, netId); maybeNotifyObserver(identity.ip.toString(), Validation::fail, netId); return DONT_REEVALUATE; } const auto mode = mPrivateDnsModes.find(netId); if (mode == mPrivateDnsModes.end()) { LOG(WARNING) << "netId " << netId << " has no private DNS validation mode"; maybeNotifyObserver(server, Validation::fail, netId); maybeNotifyObserver(identity.ip.toString(), Validation::fail, netId); return DONT_REEVALUATE; } const bool modeDoesReevaluation = (mode->second == PrivateDnsMode::STRICT); Loading @@ -249,112 +215,86 @@ bool PrivateDnsConfiguration::recordPrivateDnsValidation(const DnsTlsServer& ser (success || !modeDoesReevaluation) ? DONT_REEVALUATE : NEEDS_REEVALUATION; auto& tracker = netPair->second; auto serverPair = tracker.find(server); auto serverPair = tracker.find(identity); if (serverPair == tracker.end()) { // TODO: Consider not adding this server to the tracker since this server is not expected // to be one of the private DNS servers for this network now. This could prevent this // server from being included when dumping status. LOG(WARNING) << "Server " << addrToString(&server.ss) LOG(WARNING) << "Server " << server.toIpString() << " was removed during private DNS validation"; success = false; reevaluationStatus = DONT_REEVALUATE; } else if (!(serverPair->first == server)) { } else if (!(serverPair->second == server)) { // TODO: It doesn't seem correct to overwrite the tracker entry for // |server| down below in this circumstance... Fix this. LOG(WARNING) << "Server " << addrToString(&server.ss) LOG(WARNING) << "Server " << server.toIpString() << " was changed during private DNS validation"; success = false; reevaluationStatus = DONT_REEVALUATE; } else if (!serverPair->second.active()) { LOG(WARNING) << "Server " << server.toIpString() << " was removed from the configuration"; success = false; reevaluationStatus = DONT_REEVALUATE; } // Send a validation event to NetdEventListenerService. const auto& listeners = ResolverEventReporter::getInstance().getListeners(); if (listeners.size() != 0) { for (const auto& it : listeners) { it->onPrivateDnsValidationEvent(netId, addrToString(&server.ss), server.name, success); it->onPrivateDnsValidationEvent(netId, server.toIpString(), server.name, success); } LOG(DEBUG) << "Sent validation " << (success ? "success" : "failure") << " event on netId " << netId << " for " << addrToString(&server.ss) << " with hostname {" << server.name << "}"; << netId << " for " << server.toIpString() << " with hostname {" << server.name << "}"; } else { LOG(ERROR) << "Validation event not sent since no INetdEventListener receiver is available."; } if (success) { updateServerState(server, Validation::success, netId); updateServerState(identity, Validation::success, netId); } else { // Validation failure is expected if a user is on a captive portal. // TODO: Trigger a second validation attempt after captive portal login // succeeds. const auto result = (reevaluationStatus == NEEDS_REEVALUATION) ? Validation::in_process : Validation::fail; updateServerState(server, result, netId); updateServerState(identity, result, netId); } LOG(WARNING) << "Validation " << (success ? "success" : "failed"); return reevaluationStatus; } bool PrivateDnsConfiguration::needValidateThread(const DnsTlsServer& server, unsigned netId) REQUIRES(mPrivateDnsLock) { // Create the thread tracker if it was not present auto threadPair = mPrivateDnsValidateThreads.find(netId); if (threadPair == mPrivateDnsValidateThreads.end()) { // No thread tracker yet for this netId. bool added; std::tie(threadPair, added) = mPrivateDnsValidateThreads.emplace(netId, ThreadTracker()); if (!added) { LOG(ERROR) << "Memory error while needValidateThread for netId " << netId; return true; } } auto& threadTracker = threadPair->second; if (threadTracker.count(server)) { LOG(DEBUG) << "Server " << addrToString(&(server.ss)) << " validate thread is already running. Thread tracker now has size " << threadTracker.size(); return false; } else { threadTracker.insert(server); LOG(DEBUG) << "Server " << addrToString(&(server.ss)) << " validate thread is not running. Thread tracker now has size " << threadTracker.size(); return true; } } void PrivateDnsConfiguration::updateServerState(const DnsTlsServer& server, Validation state, void PrivateDnsConfiguration::updateServerState(const ServerIdentity& identity, Validation state, uint32_t netId) { auto netPair = mPrivateDnsTransports.find(netId); if (netPair != mPrivateDnsTransports.end()) { if (netPair == mPrivateDnsTransports.end()) { maybeNotifyObserver(identity.ip.toString(), Validation::fail, netId); return; } auto& tracker = netPair->second; tracker[server] = state; if (tracker.find(identity) == tracker.end()) { maybeNotifyObserver(identity.ip.toString(), Validation::fail, netId); return; } maybeNotifyObserver(server, state, netId); tracker[identity].setValidationState(state); maybeNotifyObserver(identity.ip.toString(), state, netId); } void PrivateDnsConfiguration::cleanValidateThreadTracker(const DnsTlsServer& server, unsigned netId) { std::lock_guard<std::mutex> guard(mPrivateDnsLock); LOG(DEBUG) << "cleanValidateThreadTracker Server " << addrToString(&(server.ss)) << " validate thread is stopped."; bool PrivateDnsConfiguration::needsValidation(const DnsTlsServer& server) { // The server is not expected to be used on the network. if (!server.active()) return false; auto threadPair = mPrivateDnsValidateThreads.find(netId); if (threadPair != mPrivateDnsValidateThreads.end()) { auto& threadTracker = threadPair->second; threadTracker.erase(server); LOG(DEBUG) << "Server " << addrToString(&(server.ss)) << " validate thread is stopped. Thread tracker now has size " << threadTracker.size(); } } // The server is newly added. if (server.validationState() == Validation::unknown_server) return true; bool PrivateDnsConfiguration::needsValidation(const PrivateDnsTracker& tracker, const DnsTlsServer& server) { const auto& iter = tracker.find(server); return (iter == tracker.end()) || (iter->second == Validation::fail); // The server has failed at least one validation attempt. Give it another try. if (server.validationState() == Validation::fail) return true; // The previous validation result might be unreliable. if (server.validationState() == Validation::success_but_expired) return true; return false; } void PrivateDnsConfiguration::setObserver(Observer* observer) { Loading @@ -362,10 +302,10 @@ void PrivateDnsConfiguration::setObserver(Observer* observer) { mObserver = observer; } void PrivateDnsConfiguration::maybeNotifyObserver(const DnsTlsServer& server, Validation validation, uint32_t netId) const { void PrivateDnsConfiguration::maybeNotifyObserver(const std::string& serverIp, Validation validation, uint32_t netId) const { if (mObserver) { mObserver->onValidationStateUpdate(addrToString(&server.ss), validation, netId); mObserver->onValidationStateUpdate(serverIp, validation, netId); } } Loading
PrivateDnsConfiguration.h +36 −21 Original line number Diff line number Diff line Loading @@ -22,6 +22,7 @@ #include <vector> #include <android-base/thread_annotations.h> #include <netdutils/InternetAddresses.h> #include "DnsTlsServer.h" Loading @@ -31,11 +32,10 @@ namespace net { // The DNS over TLS mode on a specific netId. enum class PrivateDnsMode : uint8_t { OFF, OPPORTUNISTIC, STRICT }; // Validation status of a DNS over TLS server (on a specific netId). enum class Validation : uint8_t { in_process, success, fail, unknown_server, unknown_netid }; struct PrivateDnsStatus { PrivateDnsMode mode; // TODO: change the type to std::vector<DnsTlsServer>. std::map<DnsTlsServer, Validation, AddressComparator> serversMap; std::list<DnsTlsServer> validatedServers() const { Loading Loading @@ -65,8 +65,26 @@ class PrivateDnsConfiguration { void clear(unsigned netId) EXCLUDES(mPrivateDnsLock); struct ServerIdentity { const netdutils::IPAddress ip; const std::string name; const int protocol; explicit ServerIdentity(const DnsTlsServer& server) : ip(netdutils::IPSockAddr::toIPSockAddr(server.ss).ip()), name(server.name), protocol(server.protocol) {} bool operator<(const ServerIdentity& other) const { return std::tie(ip, name, protocol) < std::tie(other.ip, other.name, other.protocol); } bool operator==(const ServerIdentity& other) const { return std::tie(ip, name, protocol) == std::tie(other.ip, other.name, other.protocol); } }; private: typedef std::map<DnsTlsServer, Validation, AddressComparator> PrivateDnsTracker; typedef std::map<ServerIdentity, DnsTlsServer> PrivateDnsTracker; typedef std::set<DnsTlsServer, AddressComparator> ThreadTracker; PrivateDnsConfiguration() = default; Loading @@ -77,41 +95,38 @@ class PrivateDnsConfiguration { bool recordPrivateDnsValidation(const DnsTlsServer& server, unsigned netId, bool success) EXCLUDES(mPrivateDnsLock); bool needValidateThread(const DnsTlsServer& server, unsigned netId) REQUIRES(mPrivateDnsLock); void cleanValidateThreadTracker(const DnsTlsServer& server, unsigned netId) EXCLUDES(mPrivateDnsLock); // Start validation for newly added servers as well as any servers that have // landed in Validation::fail state. Note that servers that have failed // Decide if a validation for |server| is needed. Note that servers that have failed // multiple validation attempts but for which there is still a validating // thread running are marked as being Validation::in_process. bool needsValidation(const PrivateDnsTracker& tracker, const DnsTlsServer& server) REQUIRES(mPrivateDnsLock); bool needsValidation(const DnsTlsServer& server) REQUIRES(mPrivateDnsLock); void updateServerState(const DnsTlsServer& server, Validation state, uint32_t netId) void updateServerState(const ServerIdentity& identity, Validation state, uint32_t netId) REQUIRES(mPrivateDnsLock); std::mutex mPrivateDnsLock; std::map<unsigned, PrivateDnsMode> mPrivateDnsModes GUARDED_BY(mPrivateDnsLock); // Structure for tracking the validation status of servers on a specific netId. // Using the AddressComparator ensures at most one entry per IP address. // Contains all servers for a network, along with their current validation status. // In case a server is removed due to a configuration change, it remains in this map, // but is marked inactive. // Any pending validation threads will continue running because we have no way to cancel them. std::map<unsigned, PrivateDnsTracker> mPrivateDnsTransports GUARDED_BY(mPrivateDnsLock); std::map<unsigned, ThreadTracker> mPrivateDnsValidateThreads GUARDED_BY(mPrivateDnsLock); // For testing. The observer is notified of onValidationStateUpdate 1) when a validation is // about to begin or 2) when a validation finishes. // WARNING: The Observer is notified while the lock is being held. Be careful not to call any // method of PrivateDnsConfiguration from the observer. // about to begin or 2) when a validation finishes. If a validation finishes when in OFF mode // or when the network has been destroyed, |validation| will be Validation::fail. // WARNING: The Observer is notified while the lock is being held. Be careful not to call // any method of PrivateDnsConfiguration from the observer. // TODO: fix the reentrancy problem. class Observer { public: virtual ~Observer(){}; virtual void onValidationStateUpdate(const std::string& server, Validation validation, virtual void onValidationStateUpdate(const std::string& serverIp, Validation validation, uint32_t netId) = 0; }; void setObserver(Observer* observer); void maybeNotifyObserver(const DnsTlsServer& server, Validation validation, void maybeNotifyObserver(const std::string& serverIp, Validation validation, uint32_t netId) const REQUIRES(mPrivateDnsLock); Observer* mObserver GUARDED_BY(mPrivateDnsLock); Loading
PrivateDnsConfigurationTest.cpp +37 −1 Original line number Diff line number Diff line Loading @@ -63,7 +63,8 @@ class PrivateDnsConfigurationTest : public ::testing::Test { class MockObserver : public PrivateDnsConfiguration::Observer { public: MOCK_METHOD(void, onValidationStateUpdate, (const std::string& server, Validation validation, uint32_t netId), (override)); (const std::string& serverIp, Validation validation, uint32_t netId), (override)); std::map<std::string, Validation> getServerStateMap() const { std::lock_guard guard(lock); Loading Loading @@ -172,6 +173,11 @@ TEST_F(PrivateDnsConfigurationTest, ValidationBlock) { backend.setDeferredResp(false); ASSERT_TRUE(PollForCondition([&]() { return mObserver.runningThreads == 0; })); // kServer1 is not a present server and thus should not be available from // PrivateDnsConfiguration::getStatus(). mObserver.removeFromServerStateMap(kServer1); expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC); } Loading Loading @@ -218,6 +224,36 @@ TEST_F(PrivateDnsConfigurationTest, NoValidation) { expectStatus(); } TEST_F(PrivateDnsConfigurationTest, ServerIdentity_Comparison) { using ServerIdentity = PrivateDnsConfiguration::ServerIdentity; DnsTlsServer server(netdutils::IPSockAddr::toIPSockAddr("127.0.0.1", 853)); server.name = "dns.example.com"; server.protocol = 1; // Different IP address (port is ignored). DnsTlsServer other = server; EXPECT_EQ(ServerIdentity(server), ServerIdentity(other)); other.ss = netdutils::IPSockAddr::toIPSockAddr("127.0.0.1", 5353); EXPECT_EQ(ServerIdentity(server), ServerIdentity(other)); other.ss = netdutils::IPSockAddr::toIPSockAddr("127.0.0.2", 853); EXPECT_NE(ServerIdentity(server), ServerIdentity(other)); // Different provider hostname. other = server; EXPECT_EQ(ServerIdentity(server), ServerIdentity(other)); other.name = "other.example.com"; EXPECT_NE(ServerIdentity(server), ServerIdentity(other)); other.name = ""; EXPECT_NE(ServerIdentity(server), ServerIdentity(other)); // Different protocol. other = server; EXPECT_EQ(ServerIdentity(server), ServerIdentity(other)); other.protocol++; EXPECT_NE(ServerIdentity(server), ServerIdentity(other)); } // TODO: add ValidationFail_Strict test. } // namespace android::net