Loading DnsTlsQueryMap.cpp +8 −1 Original line number Diff line number Diff line Loading @@ -20,9 +20,16 @@ #include <android-base/logging.h> #include "Experiments.h" namespace android { namespace net { DnsTlsQueryMap::DnsTlsQueryMap() { mMaxTries = Experiments::getInstance()->getFlag("dot_maxtries", kMaxTries); if (mMaxTries < 1) mMaxTries = 1; } std::unique_ptr<DnsTlsQueryMap::QueryFuture> DnsTlsQueryMap::recordQuery( const netdutils::Slice query) { std::lock_guard guard(mLock); Loading Loading @@ -67,7 +74,7 @@ void DnsTlsQueryMap::cleanup() { std::lock_guard guard(mLock); for (auto it = mQueries.begin(); it != mQueries.end();) { auto& p = it->second; if (p.tries >= kMaxTries) { if (p.tries >= mMaxTries) { expire(&p); it = mQueries.erase(it); } else { Loading DnsTlsQueryMap.h +3 −0 Original line number Diff line number Diff line Loading @@ -36,6 +36,8 @@ class DnsTlsQueryMap { public: enum class Response : uint8_t { success, network_error, limit_error, internal_error }; DnsTlsQueryMap(); struct Query { // The new ID number assigned to this query. uint16_t newId; Loading Loading @@ -80,6 +82,7 @@ class DnsTlsQueryMap { // The maximum number of times we will send a query before abandoning it. static constexpr int kMaxTries = 3; int mMaxTries; private: std::mutex mLock; Loading DnsTlsServer.cpp +1 −1 Original line number Diff line number Diff line Loading @@ -109,7 +109,7 @@ bool AddressComparator::operator() (const DnsTlsServer& x, const DnsTlsServer& y // Returns a tuple of references to the elements of s. auto make_tie(const DnsTlsServer& s) { return std::tie(s.ss, s.name, s.protocol, s.connectTimeout); return std::tie(s.ss, s.name, s.protocol); } bool DnsTlsServer::operator <(const DnsTlsServer& other) const { Loading DnsTlsServer.h +0 −8 Original line number Diff line number Diff line Loading @@ -16,7 +16,6 @@ #pragma once #include <chrono> #include <set> #include <string> #include <vector> Loading Loading @@ -51,13 +50,6 @@ struct DnsTlsServer { // Placeholder. More protocols might be defined in the future. int protocol = IPPROTO_TCP; // The time to wait for the attempt on connecting to the server. // Set the default value 127 seconds to be consistent with TCP connect timeout. // (presume net.ipv4.tcp_syn_retries = 6) static constexpr std::chrono::milliseconds kDotConnectTimeoutMs = std::chrono::milliseconds(127 * 1000); std::chrono::milliseconds connectTimeout = kDotConnectTimeoutMs; // Exact comparison of DnsTlsServer objects bool operator<(const DnsTlsServer& other) const; bool operator==(const DnsTlsServer& other) const; Loading DnsTlsSocket.cpp +110 −14 Original line number Diff line number Diff line Loading @@ -37,12 +37,14 @@ #include <netdutils/SocketOption.h> #include <netdutils/ThreadUtil.h> #include "Experiments.h" #include "netd_resolv/resolv.h" #include "private/android_filesystem_config.h" // AID_DNS #include "resolv_private.h" namespace android { using android::net::Experiments; using base::StringPrintf; using netdutils::enableSockopt; using netdutils::enableTcpKeepAlives; Loading Loading @@ -172,6 +174,15 @@ bool DnsTlsSocket::initialize() { mCache->prepareSslContext(mSslCtx.get()); mEventFd.reset(eventfd(0, EFD_NONBLOCK | EFD_CLOEXEC)); mShutdownEvent.reset(eventfd(0, EFD_NONBLOCK | EFD_CLOEXEC)); const Experiments* const instance = Experiments::getInstance(); mConnectTimeoutMs = instance->getFlag("dot_connect_timeout_ms", kDotConnectTimeoutMs); if (mConnectTimeoutMs < 1000) mConnectTimeoutMs = 1000; mAsyncHandshake = instance->getFlag("dot_async_handshake", 0); LOG(DEBUG) << "DnsTlsSocket is initialized with { mConnectTimeoutMs: " << mConnectTimeoutMs << ", mAsyncHandshake: " << mAsyncHandshake << " }"; transitionState(State::UNINITIALIZED, State::INITIALIZED); Loading @@ -186,17 +197,18 @@ bool DnsTlsSocket::startHandshake() { } transitionState(State::INITIALIZED, State::CONNECTING); // Connect Status status = tcpConnect(); if (!status.ok()) { if (!mAsyncHandshake) { if (Status status = tcpConnect(); !status.ok()) { transitionState(State::CONNECTING, State::WAIT_FOR_DELETE); LOG(WARNING) << "TCP Handshake failed: " << status.code(); return false; } mSsl = sslConnect(mSslFd.get()); if (!mSsl) { if (mSsl = sslConnect(mSslFd.get()); !mSsl) { transitionState(State::CONNECTING, State::WAIT_FOR_DELETE); LOG(WARNING) << "TLS Handshake failed"; return false; } } // Start the I/O loop. mLoopThread.reset(new std::thread(&DnsTlsSocket::loop, this)); Loading @@ -204,7 +216,7 @@ bool DnsTlsSocket::startHandshake() { return true; } bssl::UniquePtr<SSL> DnsTlsSocket::sslConnect(int fd) { bssl::UniquePtr<SSL> DnsTlsSocket::prepareForSslConnect(int fd) { if (!mSslCtx) { LOG(ERROR) << "Internal error: context is null in sslConnect"; return nullptr; Loading Loading @@ -247,6 +259,15 @@ bssl::UniquePtr<SSL> DnsTlsSocket::sslConnect(int fd) { LOG(DEBUG) << "No session available"; } return ssl; } bssl::UniquePtr<SSL> DnsTlsSocket::sslConnect(int fd) { bssl::UniquePtr<SSL> ssl; if (ssl = prepareForSslConnect(fd); !ssl) { return nullptr; } for (;;) { LOG(DEBUG) << " Calling SSL_connect with mark 0x" << std::hex << mMark; int ret = SSL_connect(ssl.get()); Loading @@ -259,7 +280,7 @@ bssl::UniquePtr<SSL> DnsTlsSocket::sslConnect(int fd) { // SSL_ERROR_WANT_READ is returned because the application data has been sent during // the TCP connection handshake, the device is waiting for the SSL handshake reply // from the server. if (int err = waitForReading(fd, mServer.connectTimeout.count()); err <= 0) { if (int err = waitForReading(fd, mConnectTimeoutMs); err <= 0) { PLOG(WARNING) << "SSL_connect read error " << err << ", mark 0x" << std::hex << mMark; return nullptr; Loading @@ -268,7 +289,7 @@ bssl::UniquePtr<SSL> DnsTlsSocket::sslConnect(int fd) { case SSL_ERROR_WANT_WRITE: // If no application data is sent during the TCP connection handshake, the // device is waiting for the connection established to perform SSL handshake. if (int err = waitForWriting(fd, mServer.connectTimeout.count()); err <= 0) { if (int err = waitForWriting(fd, mConnectTimeoutMs); err <= 0) { PLOG(WARNING) << "SSL_connect write error " << err << ", mark 0x" << std::hex << mMark; return nullptr; Loading @@ -286,6 +307,59 @@ bssl::UniquePtr<SSL> DnsTlsSocket::sslConnect(int fd) { return ssl; } bssl::UniquePtr<SSL> DnsTlsSocket::sslConnectV2(int fd) { bssl::UniquePtr<SSL> ssl; if (ssl = prepareForSslConnect(fd); !ssl) { return nullptr; } for (;;) { LOG(DEBUG) << " Calling SSL_connect with mark 0x" << std::hex << mMark; int ret = SSL_connect(ssl.get()); LOG(DEBUG) << " SSL_connect returned " << ret << " with mark 0x" << std::hex << mMark; if (ret == 1) break; // SSL handshake complete; enum { SSLFD = 0, EVENTFD = 1 }; pollfd fds[2] = { {.fd = mSslFd.get(), .events = 0}, {.fd = mShutdownEvent.get(), .events = POLLIN}, }; const int ssl_err = SSL_get_error(ssl.get(), ret); switch (ssl_err) { case SSL_ERROR_WANT_READ: fds[SSLFD].events = POLLIN; break; case SSL_ERROR_WANT_WRITE: fds[SSLFD].events = POLLOUT; break; default: PLOG(WARNING) << "SSL_connect ssl error =" << ssl_err << ", mark 0x" << std::hex << mMark; return nullptr; } int n = TEMP_FAILURE_RETRY(poll(fds, std::size(fds), mConnectTimeoutMs)); if (n <= 0) { PLOG(WARNING) << ((n == 0) ? "handshake timeout" : "Poll failed"); return nullptr; } if (fds[EVENTFD].revents & (POLLIN | POLLERR)) { LOG(WARNING) << "Got shutdown request during handshake"; return nullptr; } if (fds[SSLFD].revents & POLLERR) { LOG(WARNING) << "Got POLLERR on SSLFD during handshake"; return nullptr; } } LOG(DEBUG) << mMark << " handshake complete"; return ssl; } void DnsTlsSocket::sslDisconnect() { if (mSsl) { SSL_shutdown(mSsl.get()); Loading Loading @@ -326,9 +400,26 @@ void DnsTlsSocket::loop() { std::deque<std::vector<uint8_t>> q; const int timeout_msecs = DnsTlsSocket::kIdleTimeout.count() * 1000; transitionState(State::CONNECTING, State::CONNECTED); setThreadName(StringPrintf("TlsListen_%u", mMark & 0xffff).c_str()); if (mAsyncHandshake) { if (Status status = tcpConnect(); !status.ok()) { LOG(WARNING) << "TCP Handshake failed: " << status.code(); mObserver->onClosed(); transitionState(State::CONNECTING, State::WAIT_FOR_DELETE); return; } if (mSsl = sslConnectV2(mSslFd.get()); !mSsl) { LOG(WARNING) << "TLS Handshake failed"; mObserver->onClosed(); transitionState(State::CONNECTING, State::WAIT_FOR_DELETE); return; } LOG(DEBUG) << "Handshaking succeeded"; } transitionState(State::CONNECTING, State::CONNECTED); while (true) { // poll() ignores negative fds struct pollfd fds[2] = { { .fd = -1 }, { .fd = -1 } }; Loading Loading @@ -446,6 +537,11 @@ void DnsTlsSocket::requestLoopShutdown() { // Write a negative number to the eventfd. This triggers an immediate shutdown. incrementEventFd(INT64_MIN); } if (mShutdownEvent != -1) { if (eventfd_write(mShutdownEvent.get(), INT64_MIN) == -1) { PLOG(ERROR) << "Failed to write to mShutdownEvent"; } } } bool DnsTlsSocket::incrementEventFd(const int64_t count) { Loading Loading
DnsTlsQueryMap.cpp +8 −1 Original line number Diff line number Diff line Loading @@ -20,9 +20,16 @@ #include <android-base/logging.h> #include "Experiments.h" namespace android { namespace net { DnsTlsQueryMap::DnsTlsQueryMap() { mMaxTries = Experiments::getInstance()->getFlag("dot_maxtries", kMaxTries); if (mMaxTries < 1) mMaxTries = 1; } std::unique_ptr<DnsTlsQueryMap::QueryFuture> DnsTlsQueryMap::recordQuery( const netdutils::Slice query) { std::lock_guard guard(mLock); Loading Loading @@ -67,7 +74,7 @@ void DnsTlsQueryMap::cleanup() { std::lock_guard guard(mLock); for (auto it = mQueries.begin(); it != mQueries.end();) { auto& p = it->second; if (p.tries >= kMaxTries) { if (p.tries >= mMaxTries) { expire(&p); it = mQueries.erase(it); } else { Loading
DnsTlsQueryMap.h +3 −0 Original line number Diff line number Diff line Loading @@ -36,6 +36,8 @@ class DnsTlsQueryMap { public: enum class Response : uint8_t { success, network_error, limit_error, internal_error }; DnsTlsQueryMap(); struct Query { // The new ID number assigned to this query. uint16_t newId; Loading Loading @@ -80,6 +82,7 @@ class DnsTlsQueryMap { // The maximum number of times we will send a query before abandoning it. static constexpr int kMaxTries = 3; int mMaxTries; private: std::mutex mLock; Loading
DnsTlsServer.cpp +1 −1 Original line number Diff line number Diff line Loading @@ -109,7 +109,7 @@ bool AddressComparator::operator() (const DnsTlsServer& x, const DnsTlsServer& y // Returns a tuple of references to the elements of s. auto make_tie(const DnsTlsServer& s) { return std::tie(s.ss, s.name, s.protocol, s.connectTimeout); return std::tie(s.ss, s.name, s.protocol); } bool DnsTlsServer::operator <(const DnsTlsServer& other) const { Loading
DnsTlsServer.h +0 −8 Original line number Diff line number Diff line Loading @@ -16,7 +16,6 @@ #pragma once #include <chrono> #include <set> #include <string> #include <vector> Loading Loading @@ -51,13 +50,6 @@ struct DnsTlsServer { // Placeholder. More protocols might be defined in the future. int protocol = IPPROTO_TCP; // The time to wait for the attempt on connecting to the server. // Set the default value 127 seconds to be consistent with TCP connect timeout. // (presume net.ipv4.tcp_syn_retries = 6) static constexpr std::chrono::milliseconds kDotConnectTimeoutMs = std::chrono::milliseconds(127 * 1000); std::chrono::milliseconds connectTimeout = kDotConnectTimeoutMs; // Exact comparison of DnsTlsServer objects bool operator<(const DnsTlsServer& other) const; bool operator==(const DnsTlsServer& other) const; Loading
DnsTlsSocket.cpp +110 −14 Original line number Diff line number Diff line Loading @@ -37,12 +37,14 @@ #include <netdutils/SocketOption.h> #include <netdutils/ThreadUtil.h> #include "Experiments.h" #include "netd_resolv/resolv.h" #include "private/android_filesystem_config.h" // AID_DNS #include "resolv_private.h" namespace android { using android::net::Experiments; using base::StringPrintf; using netdutils::enableSockopt; using netdutils::enableTcpKeepAlives; Loading Loading @@ -172,6 +174,15 @@ bool DnsTlsSocket::initialize() { mCache->prepareSslContext(mSslCtx.get()); mEventFd.reset(eventfd(0, EFD_NONBLOCK | EFD_CLOEXEC)); mShutdownEvent.reset(eventfd(0, EFD_NONBLOCK | EFD_CLOEXEC)); const Experiments* const instance = Experiments::getInstance(); mConnectTimeoutMs = instance->getFlag("dot_connect_timeout_ms", kDotConnectTimeoutMs); if (mConnectTimeoutMs < 1000) mConnectTimeoutMs = 1000; mAsyncHandshake = instance->getFlag("dot_async_handshake", 0); LOG(DEBUG) << "DnsTlsSocket is initialized with { mConnectTimeoutMs: " << mConnectTimeoutMs << ", mAsyncHandshake: " << mAsyncHandshake << " }"; transitionState(State::UNINITIALIZED, State::INITIALIZED); Loading @@ -186,17 +197,18 @@ bool DnsTlsSocket::startHandshake() { } transitionState(State::INITIALIZED, State::CONNECTING); // Connect Status status = tcpConnect(); if (!status.ok()) { if (!mAsyncHandshake) { if (Status status = tcpConnect(); !status.ok()) { transitionState(State::CONNECTING, State::WAIT_FOR_DELETE); LOG(WARNING) << "TCP Handshake failed: " << status.code(); return false; } mSsl = sslConnect(mSslFd.get()); if (!mSsl) { if (mSsl = sslConnect(mSslFd.get()); !mSsl) { transitionState(State::CONNECTING, State::WAIT_FOR_DELETE); LOG(WARNING) << "TLS Handshake failed"; return false; } } // Start the I/O loop. mLoopThread.reset(new std::thread(&DnsTlsSocket::loop, this)); Loading @@ -204,7 +216,7 @@ bool DnsTlsSocket::startHandshake() { return true; } bssl::UniquePtr<SSL> DnsTlsSocket::sslConnect(int fd) { bssl::UniquePtr<SSL> DnsTlsSocket::prepareForSslConnect(int fd) { if (!mSslCtx) { LOG(ERROR) << "Internal error: context is null in sslConnect"; return nullptr; Loading Loading @@ -247,6 +259,15 @@ bssl::UniquePtr<SSL> DnsTlsSocket::sslConnect(int fd) { LOG(DEBUG) << "No session available"; } return ssl; } bssl::UniquePtr<SSL> DnsTlsSocket::sslConnect(int fd) { bssl::UniquePtr<SSL> ssl; if (ssl = prepareForSslConnect(fd); !ssl) { return nullptr; } for (;;) { LOG(DEBUG) << " Calling SSL_connect with mark 0x" << std::hex << mMark; int ret = SSL_connect(ssl.get()); Loading @@ -259,7 +280,7 @@ bssl::UniquePtr<SSL> DnsTlsSocket::sslConnect(int fd) { // SSL_ERROR_WANT_READ is returned because the application data has been sent during // the TCP connection handshake, the device is waiting for the SSL handshake reply // from the server. if (int err = waitForReading(fd, mServer.connectTimeout.count()); err <= 0) { if (int err = waitForReading(fd, mConnectTimeoutMs); err <= 0) { PLOG(WARNING) << "SSL_connect read error " << err << ", mark 0x" << std::hex << mMark; return nullptr; Loading @@ -268,7 +289,7 @@ bssl::UniquePtr<SSL> DnsTlsSocket::sslConnect(int fd) { case SSL_ERROR_WANT_WRITE: // If no application data is sent during the TCP connection handshake, the // device is waiting for the connection established to perform SSL handshake. if (int err = waitForWriting(fd, mServer.connectTimeout.count()); err <= 0) { if (int err = waitForWriting(fd, mConnectTimeoutMs); err <= 0) { PLOG(WARNING) << "SSL_connect write error " << err << ", mark 0x" << std::hex << mMark; return nullptr; Loading @@ -286,6 +307,59 @@ bssl::UniquePtr<SSL> DnsTlsSocket::sslConnect(int fd) { return ssl; } bssl::UniquePtr<SSL> DnsTlsSocket::sslConnectV2(int fd) { bssl::UniquePtr<SSL> ssl; if (ssl = prepareForSslConnect(fd); !ssl) { return nullptr; } for (;;) { LOG(DEBUG) << " Calling SSL_connect with mark 0x" << std::hex << mMark; int ret = SSL_connect(ssl.get()); LOG(DEBUG) << " SSL_connect returned " << ret << " with mark 0x" << std::hex << mMark; if (ret == 1) break; // SSL handshake complete; enum { SSLFD = 0, EVENTFD = 1 }; pollfd fds[2] = { {.fd = mSslFd.get(), .events = 0}, {.fd = mShutdownEvent.get(), .events = POLLIN}, }; const int ssl_err = SSL_get_error(ssl.get(), ret); switch (ssl_err) { case SSL_ERROR_WANT_READ: fds[SSLFD].events = POLLIN; break; case SSL_ERROR_WANT_WRITE: fds[SSLFD].events = POLLOUT; break; default: PLOG(WARNING) << "SSL_connect ssl error =" << ssl_err << ", mark 0x" << std::hex << mMark; return nullptr; } int n = TEMP_FAILURE_RETRY(poll(fds, std::size(fds), mConnectTimeoutMs)); if (n <= 0) { PLOG(WARNING) << ((n == 0) ? "handshake timeout" : "Poll failed"); return nullptr; } if (fds[EVENTFD].revents & (POLLIN | POLLERR)) { LOG(WARNING) << "Got shutdown request during handshake"; return nullptr; } if (fds[SSLFD].revents & POLLERR) { LOG(WARNING) << "Got POLLERR on SSLFD during handshake"; return nullptr; } } LOG(DEBUG) << mMark << " handshake complete"; return ssl; } void DnsTlsSocket::sslDisconnect() { if (mSsl) { SSL_shutdown(mSsl.get()); Loading Loading @@ -326,9 +400,26 @@ void DnsTlsSocket::loop() { std::deque<std::vector<uint8_t>> q; const int timeout_msecs = DnsTlsSocket::kIdleTimeout.count() * 1000; transitionState(State::CONNECTING, State::CONNECTED); setThreadName(StringPrintf("TlsListen_%u", mMark & 0xffff).c_str()); if (mAsyncHandshake) { if (Status status = tcpConnect(); !status.ok()) { LOG(WARNING) << "TCP Handshake failed: " << status.code(); mObserver->onClosed(); transitionState(State::CONNECTING, State::WAIT_FOR_DELETE); return; } if (mSsl = sslConnectV2(mSslFd.get()); !mSsl) { LOG(WARNING) << "TLS Handshake failed"; mObserver->onClosed(); transitionState(State::CONNECTING, State::WAIT_FOR_DELETE); return; } LOG(DEBUG) << "Handshaking succeeded"; } transitionState(State::CONNECTING, State::CONNECTED); while (true) { // poll() ignores negative fds struct pollfd fds[2] = { { .fd = -1 }, { .fd = -1 } }; Loading Loading @@ -446,6 +537,11 @@ void DnsTlsSocket::requestLoopShutdown() { // Write a negative number to the eventfd. This triggers an immediate shutdown. incrementEventFd(INT64_MIN); } if (mShutdownEvent != -1) { if (eventfd_write(mShutdownEvent.get(), INT64_MIN) == -1) { PLOG(ERROR) << "Failed to write to mShutdownEvent"; } } } bool DnsTlsSocket::incrementEventFd(const int64_t count) { Loading