Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit a772c209 authored by Mike Yu's avatar Mike Yu
Browse files

Make private DNS connect timeout configurable

It could take time to connect to a private DNS server if the system
allows 6 syn-retransmissions (net.ipv4.tcp_syn_retries = 6), which
can take time more than 2 minutes.

This change allows us to configure the timeout value via dnsresolver
binder service, and keep the default timeout value the same as the
original design.

Bug: 120182528
Bug: 141218721
Test: atest --include-subdirs packages/modules/DnsResolver
Test: m com.android.resolv
      adb install com.android.resolv
      rebooted
Change-Id: I8711a31172cfc671bf348191db363e7863831470
parent 0a423e4b
Loading
Loading
Loading
Loading
+1 −5
Original line number Diff line number Diff line
@@ -109,11 +109,7 @@ bool AddressComparator::operator() (const DnsTlsServer& x, const DnsTlsServer& y

// Returns a tuple of references to the elements of s.
auto make_tie(const DnsTlsServer& s) {
    return std::tie(
        s.ss,
        s.name,
        s.protocol
    );
    return std::tie(s.ss, s.name, s.protocol, s.connectTimeout);
}

bool DnsTlsServer::operator <(const DnsTlsServer& other) const {
+6 −0
Original line number Diff line number Diff line
@@ -17,6 +17,7 @@
#ifndef _DNS_DNSTLSSERVER_H
#define _DNS_DNSTLSSERVER_H

#include <chrono>
#include <set>
#include <string>
#include <vector>
@@ -58,6 +59,11 @@ struct DnsTlsServer {
    // Placeholder.  More protocols might be defined in the future.
    int protocol = IPPROTO_TCP;

    // The time to wait for the attempt on connecting to the server.
    // Set the default value 127 seconds to be consistent with TCP connect timeout.
    // (presume net.ipv4.tcp_syn_retries = 6)
    std::chrono::milliseconds connectTimeout = std::chrono::milliseconds(127 * 1000);

    // Exact comparison of DnsTlsServer objects
    bool operator<(const DnsTlsServer& other) const;
    bool operator==(const DnsTlsServer& other) const;
+21 −16
Original line number Diff line number Diff line
@@ -59,16 +59,14 @@ namespace {

constexpr const char kCaCertDir[] = "/system/etc/security/cacerts";

int waitForReading(int fd) {
    struct pollfd fds = { .fd = fd, .events = POLLIN };
    const int ret = TEMP_FAILURE_RETRY(poll(&fds, 1, -1));
    return ret;
int waitForReading(int fd, int timeoutMs = -1) {
    pollfd fds = {.fd = fd, .events = POLLIN};
    return TEMP_FAILURE_RETRY(poll(&fds, 1, timeoutMs));
}

int waitForWriting(int fd) {
    struct pollfd fds = { .fd = fd, .events = POLLOUT };
    const int ret = TEMP_FAILURE_RETRY(poll(&fds, 1, -1));
    return ret;
int waitForWriting(int fd, int timeoutMs = -1) {
    pollfd fds = {.fd = fd, .events = POLLOUT};
    return TEMP_FAILURE_RETRY(poll(&fds, 1, timeoutMs));
}

std::string markToFwmarkString(unsigned mMark) {
@@ -250,14 +248,21 @@ bssl::UniquePtr<SSL> DnsTlsSocket::sslConnect(int fd) {
        const int ssl_err = SSL_get_error(ssl.get(), ret);
        switch (ssl_err) {
            case SSL_ERROR_WANT_READ:
                if (waitForReading(fd) != 1) {
                    PLOG(WARNING) << "SSL_connect read error, " << markToFwmarkString(mMark);
                // SSL_ERROR_WANT_READ is returned because the application data has been sent during
                // the TCP connection handshake, the device is waiting for the SSL handshake reply
                // from the server.
                if (int err = waitForReading(fd, mServer.connectTimeout.count()); err <= 0) {
                    PLOG(WARNING) << "SSL_connect read error " << err << ", "
                                  << markToFwmarkString(mMark);
                    return nullptr;
                }
                break;
            case SSL_ERROR_WANT_WRITE:
                if (waitForWriting(fd) != 1) {
                    PLOG(WARNING) << "SSL_connect write error, " << markToFwmarkString(mMark);
                // If no application data is sent during the TCP connection handshake, the
                // device is waiting for the connection established to perform SSL handshake.
                if (int err = waitForWriting(fd, mServer.connectTimeout.count()); err <= 0) {
                    PLOG(WARNING) << "SSL_connect write error " << err << ", "
                                  << markToFwmarkString(mMark);
                    return nullptr;
                }
                break;
@@ -291,8 +296,8 @@ bool DnsTlsSocket::sslWrite(const Slice buffer) {
            const int ssl_err = SSL_get_error(mSsl.get(), ret);
            switch (ssl_err) {
                case SSL_ERROR_WANT_WRITE:
                    if (waitForWriting(mSslFd.get()) != 1) {
                        LOG(DEBUG) << "SSL_write error";
                    if (int err = waitForWriting(mSslFd.get()); err <= 0) {
                        PLOG(WARNING) << "Poll failed in sslWrite, error " << err;
                        return false;
                    }
                    continue;
@@ -462,8 +467,8 @@ int DnsTlsSocket::sslRead(const Slice buffer, bool wait) {
        if (ret < 0) {
            const int ssl_err = SSL_get_error(mSsl.get(), ret);
            if (wait && ssl_err == SSL_ERROR_WANT_READ) {
                if (waitForReading(mSslFd.get()) != 1) {
                    LOG(DEBUG) << "Poll failed in sslRead: " << errno;
                if (int err = waitForReading(mSslFd.get()); err <= 0) {
                    PLOG(WARNING) << "Poll failed in sslRead, error " << err;
                    return SSL_ERROR_SYSCALL;
                }
                continue;
+13 −2
Original line number Diff line number Diff line
@@ -29,6 +29,8 @@
#include "netd_resolv/resolv.h"
#include "netdutils/BackoffSequence.h"

using std::chrono::milliseconds;

namespace android {
namespace net {

@@ -56,9 +58,9 @@ bool parseServer(const char* server, sockaddr_storage* parsed) {

int PrivateDnsConfiguration::set(int32_t netId, uint32_t mark,
                                 const std::vector<std::string>& servers, const std::string& name,
                                 const std::string& caCert) {
                                 const std::string& caCert, int32_t connectTimeoutMs) {
    LOG(DEBUG) << "PrivateDnsConfiguration::set(" << netId << ", 0x" << std::hex << mark << std::dec
               << ", " << servers.size() << ", " << name << ")";
               << ", " << servers.size() << ", " << name << ", " << connectTimeoutMs << "ms)";

    // Parse the list of servers that has been passed in
    std::set<DnsTlsServer> tlsServers;
@@ -70,6 +72,15 @@ int PrivateDnsConfiguration::set(int32_t netId, uint32_t mark,
        DnsTlsServer server(parsed);
        server.name = name;
        server.certificate = caCert;

        // connectTimeoutMs = 0: use the default timeout value.
        // connectTimeoutMs < 0: invalid timeout value.
        if (connectTimeoutMs > 0) {
            // Set a specific timeout value but limit it to be at least 1 second.
            server.connectTimeout =
                    (connectTimeoutMs < 1000) ? milliseconds(1000) : milliseconds(connectTimeoutMs);
        }

        tlsServers.insert(server);
    }

+2 −1
Original line number Diff line number Diff line
@@ -53,7 +53,8 @@ struct PrivateDnsStatus {
class PrivateDnsConfiguration {
  public:
    int set(int32_t netId, uint32_t mark, const std::vector<std::string>& servers,
            const std::string& name, const std::string& caCert) EXCLUDES(mPrivateDnsLock);
            const std::string& name, const std::string& caCert, int32_t connectTimeoutMs)
            EXCLUDES(mPrivateDnsLock);

    PrivateDnsStatus getStatus(unsigned netId) EXCLUDES(mPrivateDnsLock);

Loading