Loading DnsTlsSocket.cpp +14 −2 Original line number Diff line number Diff line Loading @@ -450,8 +450,20 @@ void DnsTlsSocket::loop() { break; } if (fds[SSLFD].revents & (POLLIN | POLLERR | POLLHUP)) { bool readFailed = false; // readResponse() only reads one DNS (and consumes exact bytes) from ssl. // Keep doing so until ssl has no pending data. // TODO: readResponse() can block until it reads a complete DNS response. Consider // refactoring it to not get blocked in any case. do { if (!readResponse()) { LOG(DEBUG) << "SSL remote close or read error."; readFailed = true; } } while (SSL_pending(mSsl.get()) > 0 && !readFailed); if (readFailed) { break; } } Loading DnsTlsSocket.h +3 −0 Original line number Diff line number Diff line Loading @@ -137,6 +137,9 @@ class DnsTlsSocket : public IDnsTlsSocket { int sslRead(const netdutils::Slice buffer, bool wait) REQUIRES(mLock); bool sendQuery(const std::vector<uint8_t>& buf) REQUIRES(mLock); // Read one DNS response. It can potentially block until reading the exact bytes of // the response. bool readResponse() REQUIRES(mLock); // It is only used for DNS-OVER-TLS internal test. Loading tests/dns_responder/dns_tls_frontend.cpp +24 −9 Original line number Diff line number Diff line Loading @@ -223,6 +223,11 @@ void DnsTlsFrontend::requestHandler() { // client, including cleanup actions. queries_ += handleRequests(ssl.get(), client.get()); } if (passiveClose_) { LOG(DEBUG) << "hold the current connection until next connection request"; clientFd = std::move(client); } } } LOG(DEBUG) << "Ending loop"; Loading @@ -230,6 +235,7 @@ void DnsTlsFrontend::requestHandler() { int DnsTlsFrontend::handleRequests(SSL* ssl, int clientFd) { int queryCounts = 0; std::vector<uint8_t> reply; pollfd fds = {.fd = clientFd, .events = POLLIN}; do { uint8_t queryHeader[2]; Loading Loading @@ -263,16 +269,25 @@ int DnsTlsFrontend::handleRequests(SSL* ssl, int clientFd) { uint8_t responseHeader[2]; responseHeader[0] = rlen >> 8; responseHeader[1] = rlen; if (SSL_write(ssl, responseHeader, 2) != 2) { LOG(INFO) << "Failed to write response header"; return queryCounts; reply.insert(reply.end(), responseHeader, responseHeader + 2); reply.insert(reply.end(), recv_buffer, recv_buffer + rlen); ++queryCounts; if (queryCounts >= delayQueries_) { break; } if (SSL_write(ssl, recv_buffer, rlen) != rlen) { LOG(INFO) << "Failed to write response body"; return queryCounts; } while (poll(&fds, 1, delayQueriesTimeout_) > 0); if (queryCounts < delayQueries_) { LOG(WARNING) << "Expect " << delayQueries_ << " queries, but actually received " << queryCounts << " queries"; } const int replyLen = reply.size(); LOG(DEBUG) << "Sending " << queryCounts << "queries at once, byte = " << replyLen; if (SSL_write(ssl, reply.data(), replyLen) != replyLen) { LOG(WARNING) << "Failed to write response body"; } ++queryCounts; } while (poll(&fds, 1, 1) > 0); LOG(DEBUG) << __func__ << " return: " << queryCounts; return queryCounts; Loading tests/dns_responder/dns_tls_frontend.h +9 −0 Original line number Diff line number Diff line Loading @@ -62,6 +62,12 @@ class DnsTlsFrontend { void set_chain_length(int length) { chain_length_ = length; } void setHangOnHandshakeForTesting(bool hangOnHandshake) { hangOnHandshake_ = hangOnHandshake; } // Set DnsTlsFrontend to not reply any response until there are |delay| responses or timeout. void setDelayQueries(int delay) { delayQueries_ = delay; } void setDelayQueriesTimeout(int timeout) { delayQueriesTimeout_ = timeout; } void setPassiveClose(bool passiveClose) { passiveClose_ = passiveClose; } static constexpr char kDefaultListenAddr[] = "127.0.0.3"; static constexpr char kDefaultListenService[] = "853"; static constexpr char kDefaultBackendAddr[] = "127.0.0.3"; Loading Loading @@ -94,6 +100,9 @@ class DnsTlsFrontend { std::mutex update_mutex_; int chain_length_ = 1; std::atomic<bool> hangOnHandshake_ = false; std::atomic<int> delayQueries_ = 1; std::atomic<int> delayQueriesTimeout_ = 1; std::atomic<bool> passiveClose_ = false; }; } // namespace test Loading tests/resolv_integration_test.cpp +61 −0 Original line number Diff line number Diff line Loading @@ -5501,6 +5501,67 @@ TEST_F(ResolverTest, DnsServerSelection) { } while (std::next_permutation(serverList.begin(), serverList.end())); } TEST_F(ResolverTest, MultipleDotQueriesInOnePacket) { constexpr char hostname1[] = "query1.example.com."; constexpr char hostname2[] = "query2.example.com."; const std::vector<DnsRecord> records = { {hostname1, ns_type::ns_t_a, "1.2.3.4"}, {hostname2, ns_type::ns_t_a, "1.2.3.5"}, }; const std::string addr = getUniqueIPv4Address(); test::DNSResponder dns(addr); StartDns(dns, records); test::DnsTlsFrontend tls(addr, "853", addr, "53"); ASSERT_TRUE(tls.startServer()); // Set up resolver to strict mode. auto parcel = DnsResponderClient::GetDefaultResolverParamsParcel(); parcel.servers = {addr}; parcel.tlsServers = {addr}; parcel.tlsName = kDefaultPrivateDnsHostName; parcel.caCertificate = kCaCert; ASSERT_TRUE(mDnsClient.SetResolversFromParcel(parcel)); EXPECT_TRUE(WaitForPrivateDnsValidation(tls.listen_address(), true)); EXPECT_TRUE(tls.waitForQueries(1)); tls.clearQueries(); dns.clearQueries(); const auto queryAndCheck = [&](const std::string& hostname, const std::vector<DnsRecord>& records) { SCOPED_TRACE(hostname); const addrinfo hints = {.ai_family = AF_INET, .ai_socktype = SOCK_DGRAM}; auto [result, timeTakenMs] = safe_getaddrinfo_time_taken(hostname.c_str(), nullptr, hints); std::vector<std::string> expectedAnswers; for (const auto& r : records) { if (r.host_name == hostname) expectedAnswers.push_back(r.addr); } EXPECT_LE(timeTakenMs, 200); ASSERT_NE(result, nullptr); EXPECT_THAT(ToStrings(result), testing::UnorderedElementsAreArray(expectedAnswers)); }; // Set tls to reply DNS responses in one TCP packet and not to close the connection from its // side. tls.setDelayQueries(2); tls.setDelayQueriesTimeout(500); tls.setPassiveClose(true); // Start sending DNS requests at the same time. std::array<std::thread, 2> threads; threads[0] = std::thread(queryAndCheck, hostname1, records); threads[1] = std::thread(queryAndCheck, hostname2, records); threads[0].join(); threads[1].join(); // Also check no additional queries due to DoT reconnection. EXPECT_EQ(tls.queries(), 2); } // ResolverMultinetworkTest is used to verify multinetwork functionality. Here's how it works: // The resolver sends queries to address A, and then there will be a TunForwarder helping forward // the packets to address B, which is the address on which the testing server is listening. The Loading Loading
DnsTlsSocket.cpp +14 −2 Original line number Diff line number Diff line Loading @@ -450,8 +450,20 @@ void DnsTlsSocket::loop() { break; } if (fds[SSLFD].revents & (POLLIN | POLLERR | POLLHUP)) { bool readFailed = false; // readResponse() only reads one DNS (and consumes exact bytes) from ssl. // Keep doing so until ssl has no pending data. // TODO: readResponse() can block until it reads a complete DNS response. Consider // refactoring it to not get blocked in any case. do { if (!readResponse()) { LOG(DEBUG) << "SSL remote close or read error."; readFailed = true; } } while (SSL_pending(mSsl.get()) > 0 && !readFailed); if (readFailed) { break; } } Loading
DnsTlsSocket.h +3 −0 Original line number Diff line number Diff line Loading @@ -137,6 +137,9 @@ class DnsTlsSocket : public IDnsTlsSocket { int sslRead(const netdutils::Slice buffer, bool wait) REQUIRES(mLock); bool sendQuery(const std::vector<uint8_t>& buf) REQUIRES(mLock); // Read one DNS response. It can potentially block until reading the exact bytes of // the response. bool readResponse() REQUIRES(mLock); // It is only used for DNS-OVER-TLS internal test. Loading
tests/dns_responder/dns_tls_frontend.cpp +24 −9 Original line number Diff line number Diff line Loading @@ -223,6 +223,11 @@ void DnsTlsFrontend::requestHandler() { // client, including cleanup actions. queries_ += handleRequests(ssl.get(), client.get()); } if (passiveClose_) { LOG(DEBUG) << "hold the current connection until next connection request"; clientFd = std::move(client); } } } LOG(DEBUG) << "Ending loop"; Loading @@ -230,6 +235,7 @@ void DnsTlsFrontend::requestHandler() { int DnsTlsFrontend::handleRequests(SSL* ssl, int clientFd) { int queryCounts = 0; std::vector<uint8_t> reply; pollfd fds = {.fd = clientFd, .events = POLLIN}; do { uint8_t queryHeader[2]; Loading Loading @@ -263,16 +269,25 @@ int DnsTlsFrontend::handleRequests(SSL* ssl, int clientFd) { uint8_t responseHeader[2]; responseHeader[0] = rlen >> 8; responseHeader[1] = rlen; if (SSL_write(ssl, responseHeader, 2) != 2) { LOG(INFO) << "Failed to write response header"; return queryCounts; reply.insert(reply.end(), responseHeader, responseHeader + 2); reply.insert(reply.end(), recv_buffer, recv_buffer + rlen); ++queryCounts; if (queryCounts >= delayQueries_) { break; } if (SSL_write(ssl, recv_buffer, rlen) != rlen) { LOG(INFO) << "Failed to write response body"; return queryCounts; } while (poll(&fds, 1, delayQueriesTimeout_) > 0); if (queryCounts < delayQueries_) { LOG(WARNING) << "Expect " << delayQueries_ << " queries, but actually received " << queryCounts << " queries"; } const int replyLen = reply.size(); LOG(DEBUG) << "Sending " << queryCounts << "queries at once, byte = " << replyLen; if (SSL_write(ssl, reply.data(), replyLen) != replyLen) { LOG(WARNING) << "Failed to write response body"; } ++queryCounts; } while (poll(&fds, 1, 1) > 0); LOG(DEBUG) << __func__ << " return: " << queryCounts; return queryCounts; Loading
tests/dns_responder/dns_tls_frontend.h +9 −0 Original line number Diff line number Diff line Loading @@ -62,6 +62,12 @@ class DnsTlsFrontend { void set_chain_length(int length) { chain_length_ = length; } void setHangOnHandshakeForTesting(bool hangOnHandshake) { hangOnHandshake_ = hangOnHandshake; } // Set DnsTlsFrontend to not reply any response until there are |delay| responses or timeout. void setDelayQueries(int delay) { delayQueries_ = delay; } void setDelayQueriesTimeout(int timeout) { delayQueriesTimeout_ = timeout; } void setPassiveClose(bool passiveClose) { passiveClose_ = passiveClose; } static constexpr char kDefaultListenAddr[] = "127.0.0.3"; static constexpr char kDefaultListenService[] = "853"; static constexpr char kDefaultBackendAddr[] = "127.0.0.3"; Loading Loading @@ -94,6 +100,9 @@ class DnsTlsFrontend { std::mutex update_mutex_; int chain_length_ = 1; std::atomic<bool> hangOnHandshake_ = false; std::atomic<int> delayQueries_ = 1; std::atomic<int> delayQueriesTimeout_ = 1; std::atomic<bool> passiveClose_ = false; }; } // namespace test Loading
tests/resolv_integration_test.cpp +61 −0 Original line number Diff line number Diff line Loading @@ -5501,6 +5501,67 @@ TEST_F(ResolverTest, DnsServerSelection) { } while (std::next_permutation(serverList.begin(), serverList.end())); } TEST_F(ResolverTest, MultipleDotQueriesInOnePacket) { constexpr char hostname1[] = "query1.example.com."; constexpr char hostname2[] = "query2.example.com."; const std::vector<DnsRecord> records = { {hostname1, ns_type::ns_t_a, "1.2.3.4"}, {hostname2, ns_type::ns_t_a, "1.2.3.5"}, }; const std::string addr = getUniqueIPv4Address(); test::DNSResponder dns(addr); StartDns(dns, records); test::DnsTlsFrontend tls(addr, "853", addr, "53"); ASSERT_TRUE(tls.startServer()); // Set up resolver to strict mode. auto parcel = DnsResponderClient::GetDefaultResolverParamsParcel(); parcel.servers = {addr}; parcel.tlsServers = {addr}; parcel.tlsName = kDefaultPrivateDnsHostName; parcel.caCertificate = kCaCert; ASSERT_TRUE(mDnsClient.SetResolversFromParcel(parcel)); EXPECT_TRUE(WaitForPrivateDnsValidation(tls.listen_address(), true)); EXPECT_TRUE(tls.waitForQueries(1)); tls.clearQueries(); dns.clearQueries(); const auto queryAndCheck = [&](const std::string& hostname, const std::vector<DnsRecord>& records) { SCOPED_TRACE(hostname); const addrinfo hints = {.ai_family = AF_INET, .ai_socktype = SOCK_DGRAM}; auto [result, timeTakenMs] = safe_getaddrinfo_time_taken(hostname.c_str(), nullptr, hints); std::vector<std::string> expectedAnswers; for (const auto& r : records) { if (r.host_name == hostname) expectedAnswers.push_back(r.addr); } EXPECT_LE(timeTakenMs, 200); ASSERT_NE(result, nullptr); EXPECT_THAT(ToStrings(result), testing::UnorderedElementsAreArray(expectedAnswers)); }; // Set tls to reply DNS responses in one TCP packet and not to close the connection from its // side. tls.setDelayQueries(2); tls.setDelayQueriesTimeout(500); tls.setPassiveClose(true); // Start sending DNS requests at the same time. std::array<std::thread, 2> threads; threads[0] = std::thread(queryAndCheck, hostname1, records); threads[1] = std::thread(queryAndCheck, hostname2, records); threads[0].join(); threads[1].join(); // Also check no additional queries due to DoT reconnection. EXPECT_EQ(tls.queries(), 2); } // ResolverMultinetworkTest is used to verify multinetwork functionality. Here's how it works: // The resolver sends queries to address A, and then there will be a TunForwarder helping forward // the packets to address B, which is the address on which the testing server is listening. The Loading