Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 5e1b9918 authored by Mike Yu's avatar Mike Yu
Browse files

Fix DnsTlsSocket to consume all pending data from ssl

When servers reply multiple DNS responses in one packet, DnsTlsSocket
handles only the first DNS response. The remaining DNS responses
are still in ssl buffer. This causes a bug that there will be at
least one DNS reponse in ssl buffer, which results in at least one
DNS request timeout.

This change fixes it by always consuming the data from ssl before
next poll.

Bug: 172778187
Bug: 168027339
Bug: 171413368
Test: cd packages/modules/DnsResolver && atest
Change-Id: I72807e43636a46d30df6a694bb906313a8de63f2
parent ac949907
Loading
Loading
Loading
Loading
+14 −2
Original line number Diff line number Diff line
@@ -450,8 +450,20 @@ void DnsTlsSocket::loop() {
            break;
        }
        if (fds[SSLFD].revents & (POLLIN | POLLERR | POLLHUP)) {
            bool readFailed = false;

            // readResponse() only reads one DNS (and consumes exact bytes) from ssl.
            // Keep doing so until ssl has no pending data.
            // TODO: readResponse() can block until it reads a complete DNS response. Consider
            // refactoring it to not get blocked in any case.
            do {
                if (!readResponse()) {
                    LOG(DEBUG) << "SSL remote close or read error.";
                    readFailed = true;
                }
            } while (SSL_pending(mSsl.get()) > 0 && !readFailed);

            if (readFailed) {
                break;
            }
        }
+3 −0
Original line number Diff line number Diff line
@@ -137,6 +137,9 @@ class DnsTlsSocket : public IDnsTlsSocket {
    int sslRead(const netdutils::Slice buffer, bool wait) REQUIRES(mLock);

    bool sendQuery(const std::vector<uint8_t>& buf) REQUIRES(mLock);

    // Read one DNS response. It can potentially block until reading the exact bytes of
    // the response.
    bool readResponse() REQUIRES(mLock);

    // It is only used for DNS-OVER-TLS internal test.
+24 −9
Original line number Diff line number Diff line
@@ -223,6 +223,11 @@ void DnsTlsFrontend::requestHandler() {
                // client, including cleanup actions.
                queries_ += handleRequests(ssl.get(), client.get());
            }

            if (passiveClose_) {
                LOG(DEBUG) << "hold the current connection until next connection request";
                clientFd = std::move(client);
            }
        }
    }
    LOG(DEBUG) << "Ending loop";
@@ -230,6 +235,7 @@ void DnsTlsFrontend::requestHandler() {

int DnsTlsFrontend::handleRequests(SSL* ssl, int clientFd) {
    int queryCounts = 0;
    std::vector<uint8_t> reply;
    pollfd fds = {.fd = clientFd, .events = POLLIN};
    do {
        uint8_t queryHeader[2];
@@ -263,16 +269,25 @@ int DnsTlsFrontend::handleRequests(SSL* ssl, int clientFd) {
        uint8_t responseHeader[2];
        responseHeader[0] = rlen >> 8;
        responseHeader[1] = rlen;
        if (SSL_write(ssl, responseHeader, 2) != 2) {
            LOG(INFO) << "Failed to write response header";
            return queryCounts;
        reply.insert(reply.end(), responseHeader, responseHeader + 2);
        reply.insert(reply.end(), recv_buffer, recv_buffer + rlen);

        ++queryCounts;
        if (queryCounts >= delayQueries_) {
            break;
        }
        if (SSL_write(ssl, recv_buffer, rlen) != rlen) {
            LOG(INFO) << "Failed to write response body";
            return queryCounts;
    } while (poll(&fds, 1, delayQueriesTimeout_) > 0);

    if (queryCounts < delayQueries_) {
        LOG(WARNING) << "Expect " << delayQueries_ << " queries, but actually received "
                     << queryCounts << " queries";
    }

    const int replyLen = reply.size();
    LOG(DEBUG) << "Sending " << queryCounts << "queries at once, byte = " << replyLen;
    if (SSL_write(ssl, reply.data(), replyLen) != replyLen) {
        LOG(WARNING) << "Failed to write response body";
    }
        ++queryCounts;
    } while (poll(&fds, 1, 1) > 0);

    LOG(DEBUG) << __func__ << " return: " << queryCounts;
    return queryCounts;
+9 −0
Original line number Diff line number Diff line
@@ -62,6 +62,12 @@ class DnsTlsFrontend {
    void set_chain_length(int length) { chain_length_ = length; }
    void setHangOnHandshakeForTesting(bool hangOnHandshake) { hangOnHandshake_ = hangOnHandshake; }

    // Set DnsTlsFrontend to not reply any response until there are |delay| responses or timeout.
    void setDelayQueries(int delay) { delayQueries_ = delay; }
    void setDelayQueriesTimeout(int timeout) { delayQueriesTimeout_ = timeout; }

    void setPassiveClose(bool passiveClose) { passiveClose_ = passiveClose; }

    static constexpr char kDefaultListenAddr[] = "127.0.0.3";
    static constexpr char kDefaultListenService[] = "853";
    static constexpr char kDefaultBackendAddr[] = "127.0.0.3";
@@ -94,6 +100,9 @@ class DnsTlsFrontend {
    std::mutex update_mutex_;
    int chain_length_ = 1;
    std::atomic<bool> hangOnHandshake_ = false;
    std::atomic<int> delayQueries_ = 1;
    std::atomic<int> delayQueriesTimeout_ = 1;
    std::atomic<bool> passiveClose_ = false;
};

}  // namespace test
+61 −0
Original line number Diff line number Diff line
@@ -5501,6 +5501,67 @@ TEST_F(ResolverTest, DnsServerSelection) {
    } while (std::next_permutation(serverList.begin(), serverList.end()));
}

TEST_F(ResolverTest, MultipleDotQueriesInOnePacket) {
    constexpr char hostname1[] = "query1.example.com.";
    constexpr char hostname2[] = "query2.example.com.";
    const std::vector<DnsRecord> records = {
            {hostname1, ns_type::ns_t_a, "1.2.3.4"},
            {hostname2, ns_type::ns_t_a, "1.2.3.5"},
    };

    const std::string addr = getUniqueIPv4Address();
    test::DNSResponder dns(addr);
    StartDns(dns, records);
    test::DnsTlsFrontend tls(addr, "853", addr, "53");
    ASSERT_TRUE(tls.startServer());

    // Set up resolver to strict mode.
    auto parcel = DnsResponderClient::GetDefaultResolverParamsParcel();
    parcel.servers = {addr};
    parcel.tlsServers = {addr};
    parcel.tlsName = kDefaultPrivateDnsHostName;
    parcel.caCertificate = kCaCert;
    ASSERT_TRUE(mDnsClient.SetResolversFromParcel(parcel));
    EXPECT_TRUE(WaitForPrivateDnsValidation(tls.listen_address(), true));
    EXPECT_TRUE(tls.waitForQueries(1));
    tls.clearQueries();
    dns.clearQueries();

    const auto queryAndCheck = [&](const std::string& hostname,
                                   const std::vector<DnsRecord>& records) {
        SCOPED_TRACE(hostname);

        const addrinfo hints = {.ai_family = AF_INET, .ai_socktype = SOCK_DGRAM};
        auto [result, timeTakenMs] = safe_getaddrinfo_time_taken(hostname.c_str(), nullptr, hints);

        std::vector<std::string> expectedAnswers;
        for (const auto& r : records) {
            if (r.host_name == hostname) expectedAnswers.push_back(r.addr);
        }

        EXPECT_LE(timeTakenMs, 200);
        ASSERT_NE(result, nullptr);
        EXPECT_THAT(ToStrings(result), testing::UnorderedElementsAreArray(expectedAnswers));
    };

    // Set tls to reply DNS responses in one TCP packet and not to close the connection from its
    // side.
    tls.setDelayQueries(2);
    tls.setDelayQueriesTimeout(500);
    tls.setPassiveClose(true);

    // Start sending DNS requests at the same time.
    std::array<std::thread, 2> threads;
    threads[0] = std::thread(queryAndCheck, hostname1, records);
    threads[1] = std::thread(queryAndCheck, hostname2, records);

    threads[0].join();
    threads[1].join();

    // Also check no additional queries due to DoT reconnection.
    EXPECT_EQ(tls.queries(), 2);
}

// ResolverMultinetworkTest is used to verify multinetwork functionality. Here's how it works:
// The resolver sends queries to address A, and then there will be a TunForwarder helping forward
// the packets to address B, which is the address on which the testing server is listening. The