Loading PrivateDnsConfiguration.cpp +13 −3 Original line number Original line Diff line number Diff line Loading @@ -109,7 +109,11 @@ int PrivateDnsConfiguration::set(int32_t netId, uint32_t mark, } } PrivateDnsStatus PrivateDnsConfiguration::getStatus(unsigned netId) const { PrivateDnsStatus PrivateDnsConfiguration::getStatus(unsigned netId) const { PrivateDnsStatus status{PrivateDnsMode::OFF, {}}; PrivateDnsStatus status{ .mode = PrivateDnsMode::OFF, .dotServersMap = {}, .dohServersMap = {}, }; std::lock_guard guard(mPrivateDnsLock); std::lock_guard guard(mPrivateDnsLock); const auto mode = mPrivateDnsModes.find(netId); const auto mode = mPrivateDnsModes.find(netId); Loading @@ -121,10 +125,16 @@ PrivateDnsStatus PrivateDnsConfiguration::getStatus(unsigned netId) const { for (const auto& [_, server] : netPair->second) { for (const auto& [_, server] : netPair->second) { if (server->isDot() && server->active()) { if (server->isDot() && server->active()) { DnsTlsServer& dotServer = *static_cast<DnsTlsServer*>(server.get()); DnsTlsServer& dotServer = *static_cast<DnsTlsServer*>(server.get()); status.serversMap.emplace(dotServer, server->validationState()); status.dotServersMap.emplace(dotServer, server->validationState()); } } } // TODO: also add DoH server to the map. } } auto it = mDohTracker.find(netId); if (it != mDohTracker.end()) { status.dohServersMap.emplace( netdutils::IPSockAddr::toIPSockAddr(it->second.ipAddr, kDohPort), it->second.status); } } return status; return status; Loading PrivateDnsConfiguration.h +13 −2 Original line number Original line Diff line number Diff line Loading @@ -44,18 +44,29 @@ struct PrivateDnsStatus { PrivateDnsMode mode; PrivateDnsMode mode; // TODO: change the type to std::vector<DnsTlsServer>. // TODO: change the type to std::vector<DnsTlsServer>. std::map<DnsTlsServer, Validation, AddressComparator> serversMap; std::map<DnsTlsServer, Validation, AddressComparator> dotServersMap; std::map<netdutils::IPSockAddr, Validation> dohServersMap; std::list<DnsTlsServer> validatedServers() const { std::list<DnsTlsServer> validatedServers() const { std::list<DnsTlsServer> servers; std::list<DnsTlsServer> servers; for (const auto& pair : serversMap) { for (const auto& pair : dotServersMap) { if (pair.second == Validation::success) { if (pair.second == Validation::success) { servers.push_back(pair.first); servers.push_back(pair.first); } } } } return servers; return servers; } } bool hasValidatedDohServers() const { for (const auto& [_, status] : dohServersMap) { if (status == Validation::success) { return true; } } return false; } }; }; class PrivateDnsConfiguration { class PrivateDnsConfiguration { Loading PrivateDnsConfigurationTest.cpp +2 −2 Original line number Original line Diff line number Diff line Loading @@ -110,7 +110,7 @@ class PrivateDnsConfigurationTest : public ::testing::Test { if (status.mode != mode) return false; if (status.mode != mode) return false; std::map<std::string, Validation> serverStateMap; std::map<std::string, Validation> serverStateMap; for (const auto& [server, validation] : status.serversMap) { for (const auto& [server, validation] : status.dotServersMap) { serverStateMap[ToString(&server.ss)] = validation; serverStateMap[ToString(&server.ss)] = validation; } } return (serverStateMap == mObserver.getServerStateMap()); return (serverStateMap == mObserver.getServerStateMap()); Loading Loading @@ -275,7 +275,7 @@ TEST_F(PrivateDnsConfigurationTest, NoValidation) { const auto expectStatus = [&]() { const auto expectStatus = [&]() { const PrivateDnsStatus status = mPdc.getStatus(kNetId); const PrivateDnsStatus status = mPdc.getStatus(kNetId); EXPECT_EQ(status.mode, PrivateDnsMode::OFF); EXPECT_EQ(status.mode, PrivateDnsMode::OFF); EXPECT_THAT(status.serversMap, testing::IsEmpty()); EXPECT_THAT(status.dotServersMap, testing::IsEmpty()); }; }; EXPECT_EQ(mPdc.set(kNetId, kMark, {"invalid_addr"}, {}, {}), -EINVAL); EXPECT_EQ(mPdc.set(kNetId, kMark, {"invalid_addr"}, {}, {}), -EINVAL); Loading ResolverController.cpp +4 −4 Original line number Original line Diff line number Diff line Loading @@ -268,7 +268,7 @@ int ResolverController::getResolverInfo(int32_t netId, std::vector<std::string>* ResolverStats::encodeAll(res_stats, stats); ResolverStats::encodeAll(res_stats, stats); const auto privateDnsStatus = PrivateDnsConfiguration::getInstance().getStatus(netId); const auto privateDnsStatus = PrivateDnsConfiguration::getInstance().getStatus(netId); for (const auto& [server, _] : privateDnsStatus.serversMap) { for (const auto& [server, _] : privateDnsStatus.dotServersMap) { tlsServers->push_back(server.toIpString()); tlsServers->push_back(server.toIpString()); } } Loading Loading @@ -362,13 +362,13 @@ void ResolverController::dump(DumpWriter& dw, unsigned netId) { mDns64Configuration.dump(dw, netId); mDns64Configuration.dump(dw, netId); const auto privateDnsStatus = PrivateDnsConfiguration::getInstance().getStatus(netId); const auto privateDnsStatus = PrivateDnsConfiguration::getInstance().getStatus(netId); dw.println("Private DNS mode: %s", getPrivateDnsModeString(privateDnsStatus.mode)); dw.println("Private DNS mode: %s", getPrivateDnsModeString(privateDnsStatus.mode)); if (privateDnsStatus.serversMap.size() == 0) { if (privateDnsStatus.dotServersMap.size() == 0) { dw.println("No Private DNS servers configured"); dw.println("No Private DNS servers configured"); } else { } else { dw.println("Private DNS configuration (%u entries)", dw.println("Private DNS configuration (%u entries)", static_cast<uint32_t>(privateDnsStatus.serversMap.size())); static_cast<uint32_t>(privateDnsStatus.dotServersMap.size())); dw.incIndent(); dw.incIndent(); for (const auto& [server, validation] : privateDnsStatus.serversMap) { for (const auto& [server, validation] : privateDnsStatus.dotServersMap) { dw.println("%s name{%s} status{%s}", server.toIpString().c_str(), dw.println("%s name{%s} status{%s}", server.toIpString().c_str(), server.name.c_str(), validationStatusToString(validation)); server.name.c_str(), validationStatusToString(validation)); } } Loading res_send.cpp +12 −8 Original line number Original line Diff line number Diff line Loading @@ -1337,7 +1337,7 @@ static int res_private_dns_send(ResState* statp, const Slice query, const Slice } } case PrivateDnsMode::OPPORTUNISTIC: { case PrivateDnsMode::OPPORTUNISTIC: { *fallback = true; *fallback = true; if (enableDoH) { if (enableDoH && privateDnsStatus.hasValidatedDohServers()) { result = res_doh_send(statp, query, answer, rcode); result = res_doh_send(statp, query, answer, rcode); if (result != DOH_RESULT_CAN_NOT_SEND) return result; if (result != DOH_RESULT_CAN_NOT_SEND) return result; } } Loading @@ -1346,7 +1346,7 @@ static int res_private_dns_send(ResState* statp, const Slice query, const Slice } } case PrivateDnsMode::STRICT: { case PrivateDnsMode::STRICT: { *fallback = false; *fallback = false; if (enableDoH) { if (enableDoH && privateDnsStatus.hasValidatedDohServers()) { result = res_doh_send(statp, query, answer, rcode); result = res_doh_send(statp, query, answer, rcode); if (result != DOH_RESULT_CAN_NOT_SEND) return result; if (result != DOH_RESULT_CAN_NOT_SEND) return result; } } Loading @@ -1366,15 +1366,19 @@ static int res_private_dns_send(ResState* statp, const Slice query, const Slice // default network change. // default network change. for (int i = 0; i < 42; i++) { for (int i = 0; i < 42; i++) { std::this_thread::sleep_for(std::chrono::milliseconds(100)); std::this_thread::sleep_for(std::chrono::milliseconds(100)); if (enableDoH) { result = res_doh_send(statp, query, answer, rcode); if (result != DOH_RESULT_CAN_NOT_SEND) return result; } // Calling getStatus() to merely check if there's any validated server seems // Calling getStatus() to merely check if there's any validated server seems // wasteful. Consider adding a new method in PrivateDnsConfiguration for speed // wasteful. Consider adding a new method in PrivateDnsConfiguration for speed // ups. // ups. if (!privateDnsConfiguration.getStatus(netId).validatedServers().empty()) { privateDnsStatus = privateDnsConfiguration.getStatus(netId); privateDnsStatus = privateDnsConfiguration.getStatus(netId); if (enableDoH && privateDnsStatus.hasValidatedDohServers()) { result = res_doh_send(statp, query, answer, rcode); if (result != DOH_RESULT_CAN_NOT_SEND) return result; } // Switch to use the DoT servers if they are validated. if (!privateDnsStatus.validatedServers().empty()) { break; break; } } } } Loading Loading
PrivateDnsConfiguration.cpp +13 −3 Original line number Original line Diff line number Diff line Loading @@ -109,7 +109,11 @@ int PrivateDnsConfiguration::set(int32_t netId, uint32_t mark, } } PrivateDnsStatus PrivateDnsConfiguration::getStatus(unsigned netId) const { PrivateDnsStatus PrivateDnsConfiguration::getStatus(unsigned netId) const { PrivateDnsStatus status{PrivateDnsMode::OFF, {}}; PrivateDnsStatus status{ .mode = PrivateDnsMode::OFF, .dotServersMap = {}, .dohServersMap = {}, }; std::lock_guard guard(mPrivateDnsLock); std::lock_guard guard(mPrivateDnsLock); const auto mode = mPrivateDnsModes.find(netId); const auto mode = mPrivateDnsModes.find(netId); Loading @@ -121,10 +125,16 @@ PrivateDnsStatus PrivateDnsConfiguration::getStatus(unsigned netId) const { for (const auto& [_, server] : netPair->second) { for (const auto& [_, server] : netPair->second) { if (server->isDot() && server->active()) { if (server->isDot() && server->active()) { DnsTlsServer& dotServer = *static_cast<DnsTlsServer*>(server.get()); DnsTlsServer& dotServer = *static_cast<DnsTlsServer*>(server.get()); status.serversMap.emplace(dotServer, server->validationState()); status.dotServersMap.emplace(dotServer, server->validationState()); } } } // TODO: also add DoH server to the map. } } auto it = mDohTracker.find(netId); if (it != mDohTracker.end()) { status.dohServersMap.emplace( netdutils::IPSockAddr::toIPSockAddr(it->second.ipAddr, kDohPort), it->second.status); } } return status; return status; Loading
PrivateDnsConfiguration.h +13 −2 Original line number Original line Diff line number Diff line Loading @@ -44,18 +44,29 @@ struct PrivateDnsStatus { PrivateDnsMode mode; PrivateDnsMode mode; // TODO: change the type to std::vector<DnsTlsServer>. // TODO: change the type to std::vector<DnsTlsServer>. std::map<DnsTlsServer, Validation, AddressComparator> serversMap; std::map<DnsTlsServer, Validation, AddressComparator> dotServersMap; std::map<netdutils::IPSockAddr, Validation> dohServersMap; std::list<DnsTlsServer> validatedServers() const { std::list<DnsTlsServer> validatedServers() const { std::list<DnsTlsServer> servers; std::list<DnsTlsServer> servers; for (const auto& pair : serversMap) { for (const auto& pair : dotServersMap) { if (pair.second == Validation::success) { if (pair.second == Validation::success) { servers.push_back(pair.first); servers.push_back(pair.first); } } } } return servers; return servers; } } bool hasValidatedDohServers() const { for (const auto& [_, status] : dohServersMap) { if (status == Validation::success) { return true; } } return false; } }; }; class PrivateDnsConfiguration { class PrivateDnsConfiguration { Loading
PrivateDnsConfigurationTest.cpp +2 −2 Original line number Original line Diff line number Diff line Loading @@ -110,7 +110,7 @@ class PrivateDnsConfigurationTest : public ::testing::Test { if (status.mode != mode) return false; if (status.mode != mode) return false; std::map<std::string, Validation> serverStateMap; std::map<std::string, Validation> serverStateMap; for (const auto& [server, validation] : status.serversMap) { for (const auto& [server, validation] : status.dotServersMap) { serverStateMap[ToString(&server.ss)] = validation; serverStateMap[ToString(&server.ss)] = validation; } } return (serverStateMap == mObserver.getServerStateMap()); return (serverStateMap == mObserver.getServerStateMap()); Loading Loading @@ -275,7 +275,7 @@ TEST_F(PrivateDnsConfigurationTest, NoValidation) { const auto expectStatus = [&]() { const auto expectStatus = [&]() { const PrivateDnsStatus status = mPdc.getStatus(kNetId); const PrivateDnsStatus status = mPdc.getStatus(kNetId); EXPECT_EQ(status.mode, PrivateDnsMode::OFF); EXPECT_EQ(status.mode, PrivateDnsMode::OFF); EXPECT_THAT(status.serversMap, testing::IsEmpty()); EXPECT_THAT(status.dotServersMap, testing::IsEmpty()); }; }; EXPECT_EQ(mPdc.set(kNetId, kMark, {"invalid_addr"}, {}, {}), -EINVAL); EXPECT_EQ(mPdc.set(kNetId, kMark, {"invalid_addr"}, {}, {}), -EINVAL); Loading
ResolverController.cpp +4 −4 Original line number Original line Diff line number Diff line Loading @@ -268,7 +268,7 @@ int ResolverController::getResolverInfo(int32_t netId, std::vector<std::string>* ResolverStats::encodeAll(res_stats, stats); ResolverStats::encodeAll(res_stats, stats); const auto privateDnsStatus = PrivateDnsConfiguration::getInstance().getStatus(netId); const auto privateDnsStatus = PrivateDnsConfiguration::getInstance().getStatus(netId); for (const auto& [server, _] : privateDnsStatus.serversMap) { for (const auto& [server, _] : privateDnsStatus.dotServersMap) { tlsServers->push_back(server.toIpString()); tlsServers->push_back(server.toIpString()); } } Loading Loading @@ -362,13 +362,13 @@ void ResolverController::dump(DumpWriter& dw, unsigned netId) { mDns64Configuration.dump(dw, netId); mDns64Configuration.dump(dw, netId); const auto privateDnsStatus = PrivateDnsConfiguration::getInstance().getStatus(netId); const auto privateDnsStatus = PrivateDnsConfiguration::getInstance().getStatus(netId); dw.println("Private DNS mode: %s", getPrivateDnsModeString(privateDnsStatus.mode)); dw.println("Private DNS mode: %s", getPrivateDnsModeString(privateDnsStatus.mode)); if (privateDnsStatus.serversMap.size() == 0) { if (privateDnsStatus.dotServersMap.size() == 0) { dw.println("No Private DNS servers configured"); dw.println("No Private DNS servers configured"); } else { } else { dw.println("Private DNS configuration (%u entries)", dw.println("Private DNS configuration (%u entries)", static_cast<uint32_t>(privateDnsStatus.serversMap.size())); static_cast<uint32_t>(privateDnsStatus.dotServersMap.size())); dw.incIndent(); dw.incIndent(); for (const auto& [server, validation] : privateDnsStatus.serversMap) { for (const auto& [server, validation] : privateDnsStatus.dotServersMap) { dw.println("%s name{%s} status{%s}", server.toIpString().c_str(), dw.println("%s name{%s} status{%s}", server.toIpString().c_str(), server.name.c_str(), validationStatusToString(validation)); server.name.c_str(), validationStatusToString(validation)); } } Loading
res_send.cpp +12 −8 Original line number Original line Diff line number Diff line Loading @@ -1337,7 +1337,7 @@ static int res_private_dns_send(ResState* statp, const Slice query, const Slice } } case PrivateDnsMode::OPPORTUNISTIC: { case PrivateDnsMode::OPPORTUNISTIC: { *fallback = true; *fallback = true; if (enableDoH) { if (enableDoH && privateDnsStatus.hasValidatedDohServers()) { result = res_doh_send(statp, query, answer, rcode); result = res_doh_send(statp, query, answer, rcode); if (result != DOH_RESULT_CAN_NOT_SEND) return result; if (result != DOH_RESULT_CAN_NOT_SEND) return result; } } Loading @@ -1346,7 +1346,7 @@ static int res_private_dns_send(ResState* statp, const Slice query, const Slice } } case PrivateDnsMode::STRICT: { case PrivateDnsMode::STRICT: { *fallback = false; *fallback = false; if (enableDoH) { if (enableDoH && privateDnsStatus.hasValidatedDohServers()) { result = res_doh_send(statp, query, answer, rcode); result = res_doh_send(statp, query, answer, rcode); if (result != DOH_RESULT_CAN_NOT_SEND) return result; if (result != DOH_RESULT_CAN_NOT_SEND) return result; } } Loading @@ -1366,15 +1366,19 @@ static int res_private_dns_send(ResState* statp, const Slice query, const Slice // default network change. // default network change. for (int i = 0; i < 42; i++) { for (int i = 0; i < 42; i++) { std::this_thread::sleep_for(std::chrono::milliseconds(100)); std::this_thread::sleep_for(std::chrono::milliseconds(100)); if (enableDoH) { result = res_doh_send(statp, query, answer, rcode); if (result != DOH_RESULT_CAN_NOT_SEND) return result; } // Calling getStatus() to merely check if there's any validated server seems // Calling getStatus() to merely check if there's any validated server seems // wasteful. Consider adding a new method in PrivateDnsConfiguration for speed // wasteful. Consider adding a new method in PrivateDnsConfiguration for speed // ups. // ups. if (!privateDnsConfiguration.getStatus(netId).validatedServers().empty()) { privateDnsStatus = privateDnsConfiguration.getStatus(netId); privateDnsStatus = privateDnsConfiguration.getStatus(netId); if (enableDoH && privateDnsStatus.hasValidatedDohServers()) { result = res_doh_send(statp, query, answer, rcode); if (result != DOH_RESULT_CAN_NOT_SEND) return result; } // Switch to use the DoT servers if they are validated. if (!privateDnsStatus.validatedServers().empty()) { break; break; } } } } Loading