Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 918acf86 authored by Mike Yu's avatar Mike Yu Committed by Gerrit Code Review
Browse files

Merge "Make private DNS connect timeout configurable"

parents 208b32cf a772c209
Loading
Loading
Loading
Loading
+1 −5
Original line number Diff line number Diff line
@@ -109,11 +109,7 @@ bool AddressComparator::operator() (const DnsTlsServer& x, const DnsTlsServer& y

// Returns a tuple of references to the elements of s.
auto make_tie(const DnsTlsServer& s) {
    return std::tie(
        s.ss,
        s.name,
        s.protocol
    );
    return std::tie(s.ss, s.name, s.protocol, s.connectTimeout);
}

bool DnsTlsServer::operator <(const DnsTlsServer& other) const {
+6 −0
Original line number Diff line number Diff line
@@ -17,6 +17,7 @@
#ifndef _DNS_DNSTLSSERVER_H
#define _DNS_DNSTLSSERVER_H

#include <chrono>
#include <set>
#include <string>
#include <vector>
@@ -58,6 +59,11 @@ struct DnsTlsServer {
    // Placeholder.  More protocols might be defined in the future.
    int protocol = IPPROTO_TCP;

    // The time to wait for the attempt on connecting to the server.
    // Set the default value 127 seconds to be consistent with TCP connect timeout.
    // (presume net.ipv4.tcp_syn_retries = 6)
    std::chrono::milliseconds connectTimeout = std::chrono::milliseconds(127 * 1000);

    // Exact comparison of DnsTlsServer objects
    bool operator<(const DnsTlsServer& other) const;
    bool operator==(const DnsTlsServer& other) const;
+21 −16
Original line number Diff line number Diff line
@@ -59,16 +59,14 @@ namespace {

constexpr const char kCaCertDir[] = "/system/etc/security/cacerts";

int waitForReading(int fd) {
    struct pollfd fds = { .fd = fd, .events = POLLIN };
    const int ret = TEMP_FAILURE_RETRY(poll(&fds, 1, -1));
    return ret;
int waitForReading(int fd, int timeoutMs = -1) {
    pollfd fds = {.fd = fd, .events = POLLIN};
    return TEMP_FAILURE_RETRY(poll(&fds, 1, timeoutMs));
}

int waitForWriting(int fd) {
    struct pollfd fds = { .fd = fd, .events = POLLOUT };
    const int ret = TEMP_FAILURE_RETRY(poll(&fds, 1, -1));
    return ret;
int waitForWriting(int fd, int timeoutMs = -1) {
    pollfd fds = {.fd = fd, .events = POLLOUT};
    return TEMP_FAILURE_RETRY(poll(&fds, 1, timeoutMs));
}

std::string markToFwmarkString(unsigned mMark) {
@@ -250,14 +248,21 @@ bssl::UniquePtr<SSL> DnsTlsSocket::sslConnect(int fd) {
        const int ssl_err = SSL_get_error(ssl.get(), ret);
        switch (ssl_err) {
            case SSL_ERROR_WANT_READ:
                if (waitForReading(fd) != 1) {
                    PLOG(WARNING) << "SSL_connect read error, " << markToFwmarkString(mMark);
                // SSL_ERROR_WANT_READ is returned because the application data has been sent during
                // the TCP connection handshake, the device is waiting for the SSL handshake reply
                // from the server.
                if (int err = waitForReading(fd, mServer.connectTimeout.count()); err <= 0) {
                    PLOG(WARNING) << "SSL_connect read error " << err << ", "
                                  << markToFwmarkString(mMark);
                    return nullptr;
                }
                break;
            case SSL_ERROR_WANT_WRITE:
                if (waitForWriting(fd) != 1) {
                    PLOG(WARNING) << "SSL_connect write error, " << markToFwmarkString(mMark);
                // If no application data is sent during the TCP connection handshake, the
                // device is waiting for the connection established to perform SSL handshake.
                if (int err = waitForWriting(fd, mServer.connectTimeout.count()); err <= 0) {
                    PLOG(WARNING) << "SSL_connect write error " << err << ", "
                                  << markToFwmarkString(mMark);
                    return nullptr;
                }
                break;
@@ -291,8 +296,8 @@ bool DnsTlsSocket::sslWrite(const Slice buffer) {
            const int ssl_err = SSL_get_error(mSsl.get(), ret);
            switch (ssl_err) {
                case SSL_ERROR_WANT_WRITE:
                    if (waitForWriting(mSslFd.get()) != 1) {
                        LOG(DEBUG) << "SSL_write error";
                    if (int err = waitForWriting(mSslFd.get()); err <= 0) {
                        PLOG(WARNING) << "Poll failed in sslWrite, error " << err;
                        return false;
                    }
                    continue;
@@ -462,8 +467,8 @@ int DnsTlsSocket::sslRead(const Slice buffer, bool wait) {
        if (ret < 0) {
            const int ssl_err = SSL_get_error(mSsl.get(), ret);
            if (wait && ssl_err == SSL_ERROR_WANT_READ) {
                if (waitForReading(mSslFd.get()) != 1) {
                    LOG(DEBUG) << "Poll failed in sslRead: " << errno;
                if (int err = waitForReading(mSslFd.get()); err <= 0) {
                    PLOG(WARNING) << "Poll failed in sslRead, error " << err;
                    return SSL_ERROR_SYSCALL;
                }
                continue;
+13 −2
Original line number Diff line number Diff line
@@ -29,6 +29,8 @@
#include "netd_resolv/resolv.h"
#include "netdutils/BackoffSequence.h"

using std::chrono::milliseconds;

namespace android {
namespace net {

@@ -59,9 +61,9 @@ bool parseServer(const char* server, sockaddr_storage* parsed) {

int PrivateDnsConfiguration::set(int32_t netId, uint32_t mark,
                                 const std::vector<std::string>& servers, const std::string& name,
                                 const std::string& caCert) {
                                 const std::string& caCert, int32_t connectTimeoutMs) {
    LOG(DEBUG) << "PrivateDnsConfiguration::set(" << netId << ", 0x" << std::hex << mark << std::dec
               << ", " << servers.size() << ", " << name << ")";
               << ", " << servers.size() << ", " << name << ", " << connectTimeoutMs << "ms)";

    // Parse the list of servers that has been passed in
    std::set<DnsTlsServer> tlsServers;
@@ -73,6 +75,15 @@ int PrivateDnsConfiguration::set(int32_t netId, uint32_t mark,
        DnsTlsServer server(parsed);
        server.name = name;
        server.certificate = caCert;

        // connectTimeoutMs = 0: use the default timeout value.
        // connectTimeoutMs < 0: invalid timeout value.
        if (connectTimeoutMs > 0) {
            // Set a specific timeout value but limit it to be at least 1 second.
            server.connectTimeout =
                    (connectTimeoutMs < 1000) ? milliseconds(1000) : milliseconds(connectTimeoutMs);
        }

        tlsServers.insert(server);
    }

+2 −1
Original line number Diff line number Diff line
@@ -53,7 +53,8 @@ struct PrivateDnsStatus {
class PrivateDnsConfiguration {
  public:
    int set(int32_t netId, uint32_t mark, const std::vector<std::string>& servers,
            const std::string& name, const std::string& caCert) EXCLUDES(mPrivateDnsLock);
            const std::string& name, const std::string& caCert, int32_t connectTimeoutMs)
            EXCLUDES(mPrivateDnsLock);

    PrivateDnsStatus getStatus(unsigned netId) EXCLUDES(mPrivateDnsLock);

Loading