Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 0c7b1ca2 authored by Mike Yu's avatar Mike Yu Committed by Gerrit Code Review
Browse files

Merge changes Iceb7bc7f,I37afed96,I6a0a4c96

* changes:
  Make DoT retries configurable
  Move connectTimeout to DnsTlsSocket
  Allow to do TLS handshake on DnsTlsSocket loop thread
parents b6867372 bb499099
Loading
Loading
Loading
Loading
+8 −1
Original line number Diff line number Diff line
@@ -20,9 +20,16 @@

#include <android-base/logging.h>

#include "Experiments.h"

namespace android {
namespace net {

DnsTlsQueryMap::DnsTlsQueryMap() {
    mMaxTries = Experiments::getInstance()->getFlag("dot_maxtries", kMaxTries);
    if (mMaxTries < 1) mMaxTries = 1;
}

std::unique_ptr<DnsTlsQueryMap::QueryFuture> DnsTlsQueryMap::recordQuery(
        const netdutils::Slice query) {
    std::lock_guard guard(mLock);
@@ -67,7 +74,7 @@ void DnsTlsQueryMap::cleanup() {
    std::lock_guard guard(mLock);
    for (auto it = mQueries.begin(); it != mQueries.end();) {
        auto& p = it->second;
        if (p.tries >= kMaxTries) {
        if (p.tries >= mMaxTries) {
            expire(&p);
            it = mQueries.erase(it);
        } else {
+3 −0
Original line number Diff line number Diff line
@@ -36,6 +36,8 @@ class DnsTlsQueryMap {
  public:
    enum class Response : uint8_t { success, network_error, limit_error, internal_error };

    DnsTlsQueryMap();

    struct Query {
        // The new ID number assigned to this query.
        uint16_t newId;
@@ -80,6 +82,7 @@ class DnsTlsQueryMap {

    // The maximum number of times we will send a query before abandoning it.
    static constexpr int kMaxTries = 3;
    int mMaxTries;

  private:
    std::mutex mLock;
+1 −1
Original line number Diff line number Diff line
@@ -109,7 +109,7 @@ bool AddressComparator::operator() (const DnsTlsServer& x, const DnsTlsServer& y

// Returns a tuple of references to the elements of s.
auto make_tie(const DnsTlsServer& s) {
    return std::tie(s.ss, s.name, s.protocol, s.connectTimeout);
    return std::tie(s.ss, s.name, s.protocol);
}

bool DnsTlsServer::operator <(const DnsTlsServer& other) const {
+0 −8
Original line number Diff line number Diff line
@@ -16,7 +16,6 @@

#pragma once

#include <chrono>
#include <set>
#include <string>
#include <vector>
@@ -51,13 +50,6 @@ struct DnsTlsServer {
    // Placeholder.  More protocols might be defined in the future.
    int protocol = IPPROTO_TCP;

    // The time to wait for the attempt on connecting to the server.
    // Set the default value 127 seconds to be consistent with TCP connect timeout.
    // (presume net.ipv4.tcp_syn_retries = 6)
    static constexpr std::chrono::milliseconds kDotConnectTimeoutMs =
            std::chrono::milliseconds(127 * 1000);
    std::chrono::milliseconds connectTimeout = kDotConnectTimeoutMs;

    // Exact comparison of DnsTlsServer objects
    bool operator<(const DnsTlsServer& other) const;
    bool operator==(const DnsTlsServer& other) const;
+110 −14
Original line number Diff line number Diff line
@@ -37,12 +37,14 @@
#include <netdutils/SocketOption.h>
#include <netdutils/ThreadUtil.h>

#include "Experiments.h"
#include "netd_resolv/resolv.h"
#include "private/android_filesystem_config.h"  // AID_DNS
#include "resolv_private.h"

namespace android {

using android::net::Experiments;
using base::StringPrintf;
using netdutils::enableSockopt;
using netdutils::enableTcpKeepAlives;
@@ -172,6 +174,15 @@ bool DnsTlsSocket::initialize() {
    mCache->prepareSslContext(mSslCtx.get());

    mEventFd.reset(eventfd(0, EFD_NONBLOCK | EFD_CLOEXEC));
    mShutdownEvent.reset(eventfd(0, EFD_NONBLOCK | EFD_CLOEXEC));

    const Experiments* const instance = Experiments::getInstance();
    mConnectTimeoutMs = instance->getFlag("dot_connect_timeout_ms", kDotConnectTimeoutMs);
    if (mConnectTimeoutMs < 1000) mConnectTimeoutMs = 1000;

    mAsyncHandshake = instance->getFlag("dot_async_handshake", 0);
    LOG(DEBUG) << "DnsTlsSocket is initialized with { mConnectTimeoutMs: " << mConnectTimeoutMs
               << ", mAsyncHandshake: " << mAsyncHandshake << " }";

    transitionState(State::UNINITIALIZED, State::INITIALIZED);

@@ -186,17 +197,18 @@ bool DnsTlsSocket::startHandshake() {
    }
    transitionState(State::INITIALIZED, State::CONNECTING);

    // Connect
    Status status = tcpConnect();
    if (!status.ok()) {
    if (!mAsyncHandshake) {
        if (Status status = tcpConnect(); !status.ok()) {
            transitionState(State::CONNECTING, State::WAIT_FOR_DELETE);
            LOG(WARNING) << "TCP Handshake failed: " << status.code();
            return false;
        }
    mSsl = sslConnect(mSslFd.get());
    if (!mSsl) {
        if (mSsl = sslConnect(mSslFd.get()); !mSsl) {
            transitionState(State::CONNECTING, State::WAIT_FOR_DELETE);
            LOG(WARNING) << "TLS Handshake failed";
            return false;
        }
    }

    // Start the I/O loop.
    mLoopThread.reset(new std::thread(&DnsTlsSocket::loop, this));
@@ -204,7 +216,7 @@ bool DnsTlsSocket::startHandshake() {
    return true;
}

bssl::UniquePtr<SSL> DnsTlsSocket::sslConnect(int fd) {
bssl::UniquePtr<SSL> DnsTlsSocket::prepareForSslConnect(int fd) {
    if (!mSslCtx) {
        LOG(ERROR) << "Internal error: context is null in sslConnect";
        return nullptr;
@@ -247,6 +259,15 @@ bssl::UniquePtr<SSL> DnsTlsSocket::sslConnect(int fd) {
        LOG(DEBUG) << "No session available";
    }

    return ssl;
}

bssl::UniquePtr<SSL> DnsTlsSocket::sslConnect(int fd) {
    bssl::UniquePtr<SSL> ssl;
    if (ssl = prepareForSslConnect(fd); !ssl) {
        return nullptr;
    }

    for (;;) {
        LOG(DEBUG) << " Calling SSL_connect with mark 0x" << std::hex << mMark;
        int ret = SSL_connect(ssl.get());
@@ -259,7 +280,7 @@ bssl::UniquePtr<SSL> DnsTlsSocket::sslConnect(int fd) {
                // SSL_ERROR_WANT_READ is returned because the application data has been sent during
                // the TCP connection handshake, the device is waiting for the SSL handshake reply
                // from the server.
                if (int err = waitForReading(fd, mServer.connectTimeout.count()); err <= 0) {
                if (int err = waitForReading(fd, mConnectTimeoutMs); err <= 0) {
                    PLOG(WARNING) << "SSL_connect read error " << err << ", mark 0x" << std::hex
                                  << mMark;
                    return nullptr;
@@ -268,7 +289,7 @@ bssl::UniquePtr<SSL> DnsTlsSocket::sslConnect(int fd) {
            case SSL_ERROR_WANT_WRITE:
                // If no application data is sent during the TCP connection handshake, the
                // device is waiting for the connection established to perform SSL handshake.
                if (int err = waitForWriting(fd, mServer.connectTimeout.count()); err <= 0) {
                if (int err = waitForWriting(fd, mConnectTimeoutMs); err <= 0) {
                    PLOG(WARNING) << "SSL_connect write error " << err << ", mark 0x" << std::hex
                                  << mMark;
                    return nullptr;
@@ -286,6 +307,59 @@ bssl::UniquePtr<SSL> DnsTlsSocket::sslConnect(int fd) {
    return ssl;
}

bssl::UniquePtr<SSL> DnsTlsSocket::sslConnectV2(int fd) {
    bssl::UniquePtr<SSL> ssl;
    if (ssl = prepareForSslConnect(fd); !ssl) {
        return nullptr;
    }

    for (;;) {
        LOG(DEBUG) << " Calling SSL_connect with mark 0x" << std::hex << mMark;
        int ret = SSL_connect(ssl.get());
        LOG(DEBUG) << " SSL_connect returned " << ret << " with mark 0x" << std::hex << mMark;
        if (ret == 1) break;  // SSL handshake complete;

        enum { SSLFD = 0, EVENTFD = 1 };
        pollfd fds[2] = {
                {.fd = mSslFd.get(), .events = 0},
                {.fd = mShutdownEvent.get(), .events = POLLIN},
        };

        const int ssl_err = SSL_get_error(ssl.get(), ret);
        switch (ssl_err) {
            case SSL_ERROR_WANT_READ:
                fds[SSLFD].events = POLLIN;
                break;
            case SSL_ERROR_WANT_WRITE:
                fds[SSLFD].events = POLLOUT;
                break;
            default:
                PLOG(WARNING) << "SSL_connect ssl error =" << ssl_err << ", mark 0x" << std::hex
                              << mMark;
                return nullptr;
        }

        int n = TEMP_FAILURE_RETRY(poll(fds, std::size(fds), mConnectTimeoutMs));
        if (n <= 0) {
            PLOG(WARNING) << ((n == 0) ? "handshake timeout" : "Poll failed");
            return nullptr;
        }

        if (fds[EVENTFD].revents & (POLLIN | POLLERR)) {
            LOG(WARNING) << "Got shutdown request during handshake";
            return nullptr;
        }
        if (fds[SSLFD].revents & POLLERR) {
            LOG(WARNING) << "Got POLLERR on SSLFD during handshake";
            return nullptr;
        }
    }

    LOG(DEBUG) << mMark << " handshake complete";

    return ssl;
}

void DnsTlsSocket::sslDisconnect() {
    if (mSsl) {
        SSL_shutdown(mSsl.get());
@@ -326,9 +400,26 @@ void DnsTlsSocket::loop() {
    std::deque<std::vector<uint8_t>> q;
    const int timeout_msecs = DnsTlsSocket::kIdleTimeout.count() * 1000;

    transitionState(State::CONNECTING, State::CONNECTED);
    setThreadName(StringPrintf("TlsListen_%u", mMark & 0xffff).c_str());

    if (mAsyncHandshake) {
        if (Status status = tcpConnect(); !status.ok()) {
            LOG(WARNING) << "TCP Handshake failed: " << status.code();
            mObserver->onClosed();
            transitionState(State::CONNECTING, State::WAIT_FOR_DELETE);
            return;
        }
        if (mSsl = sslConnectV2(mSslFd.get()); !mSsl) {
            LOG(WARNING) << "TLS Handshake failed";
            mObserver->onClosed();
            transitionState(State::CONNECTING, State::WAIT_FOR_DELETE);
            return;
        }
        LOG(DEBUG) << "Handshaking succeeded";
    }

    transitionState(State::CONNECTING, State::CONNECTED);

    while (true) {
        // poll() ignores negative fds
        struct pollfd fds[2] = { { .fd = -1 }, { .fd = -1 } };
@@ -446,6 +537,11 @@ void DnsTlsSocket::requestLoopShutdown() {
        // Write a negative number to the eventfd.  This triggers an immediate shutdown.
        incrementEventFd(INT64_MIN);
    }
    if (mShutdownEvent != -1) {
        if (eventfd_write(mShutdownEvent.get(), INT64_MIN) == -1) {
            PLOG(ERROR) << "Failed to write to mShutdownEvent";
        }
    }
}

bool DnsTlsSocket::incrementEventFd(const int64_t count) {
Loading