Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit d97af93c authored by android-build-team Robot's avatar android-build-team Robot
Browse files

Snap for 6981746 from 08d1d5e7 to sc-release

Change-Id: I7c6df914fdafe1ea3a8f11066f44e8816600c065
parents 7c6de472 08d1d5e7
Loading
Loading
Loading
Loading
+14 −2
Original line number Diff line number Diff line
@@ -450,8 +450,20 @@ void DnsTlsSocket::loop() {
            break;
        }
        if (fds[SSLFD].revents & (POLLIN | POLLERR | POLLHUP)) {
            bool readFailed = false;

            // readResponse() only reads one DNS (and consumes exact bytes) from ssl.
            // Keep doing so until ssl has no pending data.
            // TODO: readResponse() can block until it reads a complete DNS response. Consider
            // refactoring it to not get blocked in any case.
            do {
                if (!readResponse()) {
                    LOG(DEBUG) << "SSL remote close or read error.";
                    readFailed = true;
                }
            } while (SSL_pending(mSsl.get()) > 0 && !readFailed);

            if (readFailed) {
                break;
            }
        }
+3 −0
Original line number Diff line number Diff line
@@ -137,6 +137,9 @@ class DnsTlsSocket : public IDnsTlsSocket {
    int sslRead(const netdutils::Slice buffer, bool wait) REQUIRES(mLock);

    bool sendQuery(const std::vector<uint8_t>& buf) REQUIRES(mLock);

    // Read one DNS response. It can potentially block until reading the exact bytes of
    // the response.
    bool readResponse() REQUIRES(mLock);

    // It is only used for DNS-OVER-TLS internal test.
+24 −9
Original line number Diff line number Diff line
@@ -223,6 +223,11 @@ void DnsTlsFrontend::requestHandler() {
                // client, including cleanup actions.
                queries_ += handleRequests(ssl.get(), client.get());
            }

            if (passiveClose_) {
                LOG(DEBUG) << "hold the current connection until next connection request";
                clientFd = std::move(client);
            }
        }
    }
    LOG(DEBUG) << "Ending loop";
@@ -230,6 +235,7 @@ void DnsTlsFrontend::requestHandler() {

int DnsTlsFrontend::handleRequests(SSL* ssl, int clientFd) {
    int queryCounts = 0;
    std::vector<uint8_t> reply;
    pollfd fds = {.fd = clientFd, .events = POLLIN};
    do {
        uint8_t queryHeader[2];
@@ -263,16 +269,25 @@ int DnsTlsFrontend::handleRequests(SSL* ssl, int clientFd) {
        uint8_t responseHeader[2];
        responseHeader[0] = rlen >> 8;
        responseHeader[1] = rlen;
        if (SSL_write(ssl, responseHeader, 2) != 2) {
            LOG(INFO) << "Failed to write response header";
            return queryCounts;
        reply.insert(reply.end(), responseHeader, responseHeader + 2);
        reply.insert(reply.end(), recv_buffer, recv_buffer + rlen);

        ++queryCounts;
        if (queryCounts >= delayQueries_) {
            break;
        }
        if (SSL_write(ssl, recv_buffer, rlen) != rlen) {
            LOG(INFO) << "Failed to write response body";
            return queryCounts;
    } while (poll(&fds, 1, delayQueriesTimeout_) > 0);

    if (queryCounts < delayQueries_) {
        LOG(WARNING) << "Expect " << delayQueries_ << " queries, but actually received "
                     << queryCounts << " queries";
    }

    const int replyLen = reply.size();
    LOG(DEBUG) << "Sending " << queryCounts << "queries at once, byte = " << replyLen;
    if (SSL_write(ssl, reply.data(), replyLen) != replyLen) {
        LOG(WARNING) << "Failed to write response body";
    }
        ++queryCounts;
    } while (poll(&fds, 1, 1) > 0);

    LOG(DEBUG) << __func__ << " return: " << queryCounts;
    return queryCounts;
+9 −0
Original line number Diff line number Diff line
@@ -62,6 +62,12 @@ class DnsTlsFrontend {
    void set_chain_length(int length) { chain_length_ = length; }
    void setHangOnHandshakeForTesting(bool hangOnHandshake) { hangOnHandshake_ = hangOnHandshake; }

    // Set DnsTlsFrontend to not reply any response until there are |delay| responses or timeout.
    void setDelayQueries(int delay) { delayQueries_ = delay; }
    void setDelayQueriesTimeout(int timeout) { delayQueriesTimeout_ = timeout; }

    void setPassiveClose(bool passiveClose) { passiveClose_ = passiveClose; }

    static constexpr char kDefaultListenAddr[] = "127.0.0.3";
    static constexpr char kDefaultListenService[] = "853";
    static constexpr char kDefaultBackendAddr[] = "127.0.0.3";
@@ -94,6 +100,9 @@ class DnsTlsFrontend {
    std::mutex update_mutex_;
    int chain_length_ = 1;
    std::atomic<bool> hangOnHandshake_ = false;
    std::atomic<int> delayQueries_ = 1;
    std::atomic<int> delayQueriesTimeout_ = 1;
    std::atomic<bool> passiveClose_ = false;
};

}  // namespace test
+61 −0
Original line number Diff line number Diff line
@@ -5501,6 +5501,67 @@ TEST_F(ResolverTest, DnsServerSelection) {
    } while (std::next_permutation(serverList.begin(), serverList.end()));
}

TEST_F(ResolverTest, MultipleDotQueriesInOnePacket) {
    constexpr char hostname1[] = "query1.example.com.";
    constexpr char hostname2[] = "query2.example.com.";
    const std::vector<DnsRecord> records = {
            {hostname1, ns_type::ns_t_a, "1.2.3.4"},
            {hostname2, ns_type::ns_t_a, "1.2.3.5"},
    };

    const std::string addr = getUniqueIPv4Address();
    test::DNSResponder dns(addr);
    StartDns(dns, records);
    test::DnsTlsFrontend tls(addr, "853", addr, "53");
    ASSERT_TRUE(tls.startServer());

    // Set up resolver to strict mode.
    auto parcel = DnsResponderClient::GetDefaultResolverParamsParcel();
    parcel.servers = {addr};
    parcel.tlsServers = {addr};
    parcel.tlsName = kDefaultPrivateDnsHostName;
    parcel.caCertificate = kCaCert;
    ASSERT_TRUE(mDnsClient.SetResolversFromParcel(parcel));
    EXPECT_TRUE(WaitForPrivateDnsValidation(tls.listen_address(), true));
    EXPECT_TRUE(tls.waitForQueries(1));
    tls.clearQueries();
    dns.clearQueries();

    const auto queryAndCheck = [&](const std::string& hostname,
                                   const std::vector<DnsRecord>& records) {
        SCOPED_TRACE(hostname);

        const addrinfo hints = {.ai_family = AF_INET, .ai_socktype = SOCK_DGRAM};
        auto [result, timeTakenMs] = safe_getaddrinfo_time_taken(hostname.c_str(), nullptr, hints);

        std::vector<std::string> expectedAnswers;
        for (const auto& r : records) {
            if (r.host_name == hostname) expectedAnswers.push_back(r.addr);
        }

        EXPECT_LE(timeTakenMs, 200);
        ASSERT_NE(result, nullptr);
        EXPECT_THAT(ToStrings(result), testing::UnorderedElementsAreArray(expectedAnswers));
    };

    // Set tls to reply DNS responses in one TCP packet and not to close the connection from its
    // side.
    tls.setDelayQueries(2);
    tls.setDelayQueriesTimeout(500);
    tls.setPassiveClose(true);

    // Start sending DNS requests at the same time.
    std::array<std::thread, 2> threads;
    threads[0] = std::thread(queryAndCheck, hostname1, records);
    threads[1] = std::thread(queryAndCheck, hostname2, records);

    threads[0].join();
    threads[1].join();

    // Also check no additional queries due to DoT reconnection.
    EXPECT_EQ(tls.queries(), 2);
}

// ResolverMultinetworkTest is used to verify multinetwork functionality. Here's how it works:
// The resolver sends queries to address A, and then there will be a TunForwarder helping forward
// the packets to address B, which is the address on which the testing server is listening. The