Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 92cb6437 authored by Mike Yu's avatar Mike Yu Committed by Automerger Merge Worker
Browse files

Fix DnsTlsSocket to consume all pending data from ssl am: 5e1b9918

Original change: https://android-review.googlesource.com/c/platform/packages/modules/DnsResolver/+/1492679

Change-Id: I5b722fe879446da24b05bdb07d0abc36418682b9
parents 7f35c63d 5e1b9918
Loading
Loading
Loading
Loading
+14 −2
Original line number Diff line number Diff line
@@ -450,8 +450,20 @@ void DnsTlsSocket::loop() {
            break;
        }
        if (fds[SSLFD].revents & (POLLIN | POLLERR | POLLHUP)) {
            bool readFailed = false;

            // readResponse() only reads one DNS (and consumes exact bytes) from ssl.
            // Keep doing so until ssl has no pending data.
            // TODO: readResponse() can block until it reads a complete DNS response. Consider
            // refactoring it to not get blocked in any case.
            do {
                if (!readResponse()) {
                    LOG(DEBUG) << "SSL remote close or read error.";
                    readFailed = true;
                }
            } while (SSL_pending(mSsl.get()) > 0 && !readFailed);

            if (readFailed) {
                break;
            }
        }
+3 −0
Original line number Diff line number Diff line
@@ -137,6 +137,9 @@ class DnsTlsSocket : public IDnsTlsSocket {
    int sslRead(const netdutils::Slice buffer, bool wait) REQUIRES(mLock);

    bool sendQuery(const std::vector<uint8_t>& buf) REQUIRES(mLock);

    // Read one DNS response. It can potentially block until reading the exact bytes of
    // the response.
    bool readResponse() REQUIRES(mLock);

    // It is only used for DNS-OVER-TLS internal test.
+24 −9
Original line number Diff line number Diff line
@@ -223,6 +223,11 @@ void DnsTlsFrontend::requestHandler() {
                // client, including cleanup actions.
                queries_ += handleRequests(ssl.get(), client.get());
            }

            if (passiveClose_) {
                LOG(DEBUG) << "hold the current connection until next connection request";
                clientFd = std::move(client);
            }
        }
    }
    LOG(DEBUG) << "Ending loop";
@@ -230,6 +235,7 @@ void DnsTlsFrontend::requestHandler() {

int DnsTlsFrontend::handleRequests(SSL* ssl, int clientFd) {
    int queryCounts = 0;
    std::vector<uint8_t> reply;
    pollfd fds = {.fd = clientFd, .events = POLLIN};
    do {
        uint8_t queryHeader[2];
@@ -263,16 +269,25 @@ int DnsTlsFrontend::handleRequests(SSL* ssl, int clientFd) {
        uint8_t responseHeader[2];
        responseHeader[0] = rlen >> 8;
        responseHeader[1] = rlen;
        if (SSL_write(ssl, responseHeader, 2) != 2) {
            LOG(INFO) << "Failed to write response header";
            return queryCounts;
        reply.insert(reply.end(), responseHeader, responseHeader + 2);
        reply.insert(reply.end(), recv_buffer, recv_buffer + rlen);

        ++queryCounts;
        if (queryCounts >= delayQueries_) {
            break;
        }
        if (SSL_write(ssl, recv_buffer, rlen) != rlen) {
            LOG(INFO) << "Failed to write response body";
            return queryCounts;
    } while (poll(&fds, 1, delayQueriesTimeout_) > 0);

    if (queryCounts < delayQueries_) {
        LOG(WARNING) << "Expect " << delayQueries_ << " queries, but actually received "
                     << queryCounts << " queries";
    }

    const int replyLen = reply.size();
    LOG(DEBUG) << "Sending " << queryCounts << "queries at once, byte = " << replyLen;
    if (SSL_write(ssl, reply.data(), replyLen) != replyLen) {
        LOG(WARNING) << "Failed to write response body";
    }
        ++queryCounts;
    } while (poll(&fds, 1, 1) > 0);

    LOG(DEBUG) << __func__ << " return: " << queryCounts;
    return queryCounts;
+9 −0
Original line number Diff line number Diff line
@@ -62,6 +62,12 @@ class DnsTlsFrontend {
    void set_chain_length(int length) { chain_length_ = length; }
    void setHangOnHandshakeForTesting(bool hangOnHandshake) { hangOnHandshake_ = hangOnHandshake; }

    // Set DnsTlsFrontend to not reply any response until there are |delay| responses or timeout.
    void setDelayQueries(int delay) { delayQueries_ = delay; }
    void setDelayQueriesTimeout(int timeout) { delayQueriesTimeout_ = timeout; }

    void setPassiveClose(bool passiveClose) { passiveClose_ = passiveClose; }

    static constexpr char kDefaultListenAddr[] = "127.0.0.3";
    static constexpr char kDefaultListenService[] = "853";
    static constexpr char kDefaultBackendAddr[] = "127.0.0.3";
@@ -94,6 +100,9 @@ class DnsTlsFrontend {
    std::mutex update_mutex_;
    int chain_length_ = 1;
    std::atomic<bool> hangOnHandshake_ = false;
    std::atomic<int> delayQueries_ = 1;
    std::atomic<int> delayQueriesTimeout_ = 1;
    std::atomic<bool> passiveClose_ = false;
};

}  // namespace test
+61 −0
Original line number Diff line number Diff line
@@ -5501,6 +5501,67 @@ TEST_F(ResolverTest, DnsServerSelection) {
    } while (std::next_permutation(serverList.begin(), serverList.end()));
}

TEST_F(ResolverTest, MultipleDotQueriesInOnePacket) {
    constexpr char hostname1[] = "query1.example.com.";
    constexpr char hostname2[] = "query2.example.com.";
    const std::vector<DnsRecord> records = {
            {hostname1, ns_type::ns_t_a, "1.2.3.4"},
            {hostname2, ns_type::ns_t_a, "1.2.3.5"},
    };

    const std::string addr = getUniqueIPv4Address();
    test::DNSResponder dns(addr);
    StartDns(dns, records);
    test::DnsTlsFrontend tls(addr, "853", addr, "53");
    ASSERT_TRUE(tls.startServer());

    // Set up resolver to strict mode.
    auto parcel = DnsResponderClient::GetDefaultResolverParamsParcel();
    parcel.servers = {addr};
    parcel.tlsServers = {addr};
    parcel.tlsName = kDefaultPrivateDnsHostName;
    parcel.caCertificate = kCaCert;
    ASSERT_TRUE(mDnsClient.SetResolversFromParcel(parcel));
    EXPECT_TRUE(WaitForPrivateDnsValidation(tls.listen_address(), true));
    EXPECT_TRUE(tls.waitForQueries(1));
    tls.clearQueries();
    dns.clearQueries();

    const auto queryAndCheck = [&](const std::string& hostname,
                                   const std::vector<DnsRecord>& records) {
        SCOPED_TRACE(hostname);

        const addrinfo hints = {.ai_family = AF_INET, .ai_socktype = SOCK_DGRAM};
        auto [result, timeTakenMs] = safe_getaddrinfo_time_taken(hostname.c_str(), nullptr, hints);

        std::vector<std::string> expectedAnswers;
        for (const auto& r : records) {
            if (r.host_name == hostname) expectedAnswers.push_back(r.addr);
        }

        EXPECT_LE(timeTakenMs, 200);
        ASSERT_NE(result, nullptr);
        EXPECT_THAT(ToStrings(result), testing::UnorderedElementsAreArray(expectedAnswers));
    };

    // Set tls to reply DNS responses in one TCP packet and not to close the connection from its
    // side.
    tls.setDelayQueries(2);
    tls.setDelayQueriesTimeout(500);
    tls.setPassiveClose(true);

    // Start sending DNS requests at the same time.
    std::array<std::thread, 2> threads;
    threads[0] = std::thread(queryAndCheck, hostname1, records);
    threads[1] = std::thread(queryAndCheck, hostname2, records);

    threads[0].join();
    threads[1].join();

    // Also check no additional queries due to DoT reconnection.
    EXPECT_EQ(tls.queries(), 2);
}

// ResolverMultinetworkTest is used to verify multinetwork functionality. Here's how it works:
// The resolver sends queries to address A, and then there will be a TunForwarder helping forward
// the packets to address B, which is the address on which the testing server is listening. The