Loading DnsTlsServer.cpp +1 −5 Original line number Diff line number Diff line Loading @@ -109,11 +109,7 @@ bool AddressComparator::operator() (const DnsTlsServer& x, const DnsTlsServer& y // Returns a tuple of references to the elements of s. auto make_tie(const DnsTlsServer& s) { return std::tie( s.ss, s.name, s.protocol ); return std::tie(s.ss, s.name, s.protocol, s.connectTimeout); } bool DnsTlsServer::operator <(const DnsTlsServer& other) const { Loading DnsTlsServer.h +6 −0 Original line number Diff line number Diff line Loading @@ -17,6 +17,7 @@ #ifndef _DNS_DNSTLSSERVER_H #define _DNS_DNSTLSSERVER_H #include <chrono> #include <set> #include <string> #include <vector> Loading Loading @@ -58,6 +59,11 @@ struct DnsTlsServer { // Placeholder. More protocols might be defined in the future. int protocol = IPPROTO_TCP; // The time to wait for the attempt on connecting to the server. // Set the default value 127 seconds to be consistent with TCP connect timeout. // (presume net.ipv4.tcp_syn_retries = 6) std::chrono::milliseconds connectTimeout = std::chrono::milliseconds(127 * 1000); // Exact comparison of DnsTlsServer objects bool operator<(const DnsTlsServer& other) const; bool operator==(const DnsTlsServer& other) const; Loading DnsTlsSocket.cpp +21 −16 Original line number Diff line number Diff line Loading @@ -59,16 +59,14 @@ namespace { constexpr const char kCaCertDir[] = "/system/etc/security/cacerts"; int waitForReading(int fd) { struct pollfd fds = { .fd = fd, .events = POLLIN }; const int ret = TEMP_FAILURE_RETRY(poll(&fds, 1, -1)); return ret; int waitForReading(int fd, int timeoutMs = -1) { pollfd fds = {.fd = fd, .events = POLLIN}; return TEMP_FAILURE_RETRY(poll(&fds, 1, timeoutMs)); } int waitForWriting(int fd) { struct pollfd fds = { .fd = fd, .events = POLLOUT }; const int ret = TEMP_FAILURE_RETRY(poll(&fds, 1, -1)); return ret; int waitForWriting(int fd, int timeoutMs = -1) { pollfd fds = {.fd = fd, .events = POLLOUT}; return TEMP_FAILURE_RETRY(poll(&fds, 1, timeoutMs)); } std::string markToFwmarkString(unsigned mMark) { Loading Loading @@ -250,14 +248,21 @@ bssl::UniquePtr<SSL> DnsTlsSocket::sslConnect(int fd) { const int ssl_err = SSL_get_error(ssl.get(), ret); switch (ssl_err) { case SSL_ERROR_WANT_READ: if (waitForReading(fd) != 1) { PLOG(WARNING) << "SSL_connect read error, " << markToFwmarkString(mMark); // SSL_ERROR_WANT_READ is returned because the application data has been sent during // the TCP connection handshake, the device is waiting for the SSL handshake reply // from the server. if (int err = waitForReading(fd, mServer.connectTimeout.count()); err <= 0) { PLOG(WARNING) << "SSL_connect read error " << err << ", " << markToFwmarkString(mMark); return nullptr; } break; case SSL_ERROR_WANT_WRITE: if (waitForWriting(fd) != 1) { PLOG(WARNING) << "SSL_connect write error, " << markToFwmarkString(mMark); // If no application data is sent during the TCP connection handshake, the // device is waiting for the connection established to perform SSL handshake. if (int err = waitForWriting(fd, mServer.connectTimeout.count()); err <= 0) { PLOG(WARNING) << "SSL_connect write error " << err << ", " << markToFwmarkString(mMark); return nullptr; } break; Loading Loading @@ -291,8 +296,8 @@ bool DnsTlsSocket::sslWrite(const Slice buffer) { const int ssl_err = SSL_get_error(mSsl.get(), ret); switch (ssl_err) { case SSL_ERROR_WANT_WRITE: if (waitForWriting(mSslFd.get()) != 1) { LOG(DEBUG) << "SSL_write error"; if (int err = waitForWriting(mSslFd.get()); err <= 0) { PLOG(WARNING) << "Poll failed in sslWrite, error " << err; return false; } continue; Loading Loading @@ -462,8 +467,8 @@ int DnsTlsSocket::sslRead(const Slice buffer, bool wait) { if (ret < 0) { const int ssl_err = SSL_get_error(mSsl.get(), ret); if (wait && ssl_err == SSL_ERROR_WANT_READ) { if (waitForReading(mSslFd.get()) != 1) { LOG(DEBUG) << "Poll failed in sslRead: " << errno; if (int err = waitForReading(mSslFd.get()); err <= 0) { PLOG(WARNING) << "Poll failed in sslRead, error " << err; return SSL_ERROR_SYSCALL; } continue; Loading PrivateDnsConfiguration.cpp +13 −2 Original line number Diff line number Diff line Loading @@ -29,6 +29,8 @@ #include "netd_resolv/resolv.h" #include "netdutils/BackoffSequence.h" using std::chrono::milliseconds; namespace android { namespace net { Loading Loading @@ -59,9 +61,9 @@ bool parseServer(const char* server, sockaddr_storage* parsed) { int PrivateDnsConfiguration::set(int32_t netId, uint32_t mark, const std::vector<std::string>& servers, const std::string& name, const std::string& caCert) { const std::string& caCert, int32_t connectTimeoutMs) { LOG(DEBUG) << "PrivateDnsConfiguration::set(" << netId << ", 0x" << std::hex << mark << std::dec << ", " << servers.size() << ", " << name << ")"; << ", " << servers.size() << ", " << name << ", " << connectTimeoutMs << "ms)"; // Parse the list of servers that has been passed in std::set<DnsTlsServer> tlsServers; Loading @@ -73,6 +75,15 @@ int PrivateDnsConfiguration::set(int32_t netId, uint32_t mark, DnsTlsServer server(parsed); server.name = name; server.certificate = caCert; // connectTimeoutMs = 0: use the default timeout value. // connectTimeoutMs < 0: invalid timeout value. if (connectTimeoutMs > 0) { // Set a specific timeout value but limit it to be at least 1 second. server.connectTimeout = (connectTimeoutMs < 1000) ? milliseconds(1000) : milliseconds(connectTimeoutMs); } tlsServers.insert(server); } Loading PrivateDnsConfiguration.h +2 −1 Original line number Diff line number Diff line Loading @@ -53,7 +53,8 @@ struct PrivateDnsStatus { class PrivateDnsConfiguration { public: int set(int32_t netId, uint32_t mark, const std::vector<std::string>& servers, const std::string& name, const std::string& caCert) EXCLUDES(mPrivateDnsLock); const std::string& name, const std::string& caCert, int32_t connectTimeoutMs) EXCLUDES(mPrivateDnsLock); PrivateDnsStatus getStatus(unsigned netId) EXCLUDES(mPrivateDnsLock); Loading Loading
DnsTlsServer.cpp +1 −5 Original line number Diff line number Diff line Loading @@ -109,11 +109,7 @@ bool AddressComparator::operator() (const DnsTlsServer& x, const DnsTlsServer& y // Returns a tuple of references to the elements of s. auto make_tie(const DnsTlsServer& s) { return std::tie( s.ss, s.name, s.protocol ); return std::tie(s.ss, s.name, s.protocol, s.connectTimeout); } bool DnsTlsServer::operator <(const DnsTlsServer& other) const { Loading
DnsTlsServer.h +6 −0 Original line number Diff line number Diff line Loading @@ -17,6 +17,7 @@ #ifndef _DNS_DNSTLSSERVER_H #define _DNS_DNSTLSSERVER_H #include <chrono> #include <set> #include <string> #include <vector> Loading Loading @@ -58,6 +59,11 @@ struct DnsTlsServer { // Placeholder. More protocols might be defined in the future. int protocol = IPPROTO_TCP; // The time to wait for the attempt on connecting to the server. // Set the default value 127 seconds to be consistent with TCP connect timeout. // (presume net.ipv4.tcp_syn_retries = 6) std::chrono::milliseconds connectTimeout = std::chrono::milliseconds(127 * 1000); // Exact comparison of DnsTlsServer objects bool operator<(const DnsTlsServer& other) const; bool operator==(const DnsTlsServer& other) const; Loading
DnsTlsSocket.cpp +21 −16 Original line number Diff line number Diff line Loading @@ -59,16 +59,14 @@ namespace { constexpr const char kCaCertDir[] = "/system/etc/security/cacerts"; int waitForReading(int fd) { struct pollfd fds = { .fd = fd, .events = POLLIN }; const int ret = TEMP_FAILURE_RETRY(poll(&fds, 1, -1)); return ret; int waitForReading(int fd, int timeoutMs = -1) { pollfd fds = {.fd = fd, .events = POLLIN}; return TEMP_FAILURE_RETRY(poll(&fds, 1, timeoutMs)); } int waitForWriting(int fd) { struct pollfd fds = { .fd = fd, .events = POLLOUT }; const int ret = TEMP_FAILURE_RETRY(poll(&fds, 1, -1)); return ret; int waitForWriting(int fd, int timeoutMs = -1) { pollfd fds = {.fd = fd, .events = POLLOUT}; return TEMP_FAILURE_RETRY(poll(&fds, 1, timeoutMs)); } std::string markToFwmarkString(unsigned mMark) { Loading Loading @@ -250,14 +248,21 @@ bssl::UniquePtr<SSL> DnsTlsSocket::sslConnect(int fd) { const int ssl_err = SSL_get_error(ssl.get(), ret); switch (ssl_err) { case SSL_ERROR_WANT_READ: if (waitForReading(fd) != 1) { PLOG(WARNING) << "SSL_connect read error, " << markToFwmarkString(mMark); // SSL_ERROR_WANT_READ is returned because the application data has been sent during // the TCP connection handshake, the device is waiting for the SSL handshake reply // from the server. if (int err = waitForReading(fd, mServer.connectTimeout.count()); err <= 0) { PLOG(WARNING) << "SSL_connect read error " << err << ", " << markToFwmarkString(mMark); return nullptr; } break; case SSL_ERROR_WANT_WRITE: if (waitForWriting(fd) != 1) { PLOG(WARNING) << "SSL_connect write error, " << markToFwmarkString(mMark); // If no application data is sent during the TCP connection handshake, the // device is waiting for the connection established to perform SSL handshake. if (int err = waitForWriting(fd, mServer.connectTimeout.count()); err <= 0) { PLOG(WARNING) << "SSL_connect write error " << err << ", " << markToFwmarkString(mMark); return nullptr; } break; Loading Loading @@ -291,8 +296,8 @@ bool DnsTlsSocket::sslWrite(const Slice buffer) { const int ssl_err = SSL_get_error(mSsl.get(), ret); switch (ssl_err) { case SSL_ERROR_WANT_WRITE: if (waitForWriting(mSslFd.get()) != 1) { LOG(DEBUG) << "SSL_write error"; if (int err = waitForWriting(mSslFd.get()); err <= 0) { PLOG(WARNING) << "Poll failed in sslWrite, error " << err; return false; } continue; Loading Loading @@ -462,8 +467,8 @@ int DnsTlsSocket::sslRead(const Slice buffer, bool wait) { if (ret < 0) { const int ssl_err = SSL_get_error(mSsl.get(), ret); if (wait && ssl_err == SSL_ERROR_WANT_READ) { if (waitForReading(mSslFd.get()) != 1) { LOG(DEBUG) << "Poll failed in sslRead: " << errno; if (int err = waitForReading(mSslFd.get()); err <= 0) { PLOG(WARNING) << "Poll failed in sslRead, error " << err; return SSL_ERROR_SYSCALL; } continue; Loading
PrivateDnsConfiguration.cpp +13 −2 Original line number Diff line number Diff line Loading @@ -29,6 +29,8 @@ #include "netd_resolv/resolv.h" #include "netdutils/BackoffSequence.h" using std::chrono::milliseconds; namespace android { namespace net { Loading Loading @@ -59,9 +61,9 @@ bool parseServer(const char* server, sockaddr_storage* parsed) { int PrivateDnsConfiguration::set(int32_t netId, uint32_t mark, const std::vector<std::string>& servers, const std::string& name, const std::string& caCert) { const std::string& caCert, int32_t connectTimeoutMs) { LOG(DEBUG) << "PrivateDnsConfiguration::set(" << netId << ", 0x" << std::hex << mark << std::dec << ", " << servers.size() << ", " << name << ")"; << ", " << servers.size() << ", " << name << ", " << connectTimeoutMs << "ms)"; // Parse the list of servers that has been passed in std::set<DnsTlsServer> tlsServers; Loading @@ -73,6 +75,15 @@ int PrivateDnsConfiguration::set(int32_t netId, uint32_t mark, DnsTlsServer server(parsed); server.name = name; server.certificate = caCert; // connectTimeoutMs = 0: use the default timeout value. // connectTimeoutMs < 0: invalid timeout value. if (connectTimeoutMs > 0) { // Set a specific timeout value but limit it to be at least 1 second. server.connectTimeout = (connectTimeoutMs < 1000) ? milliseconds(1000) : milliseconds(connectTimeoutMs); } tlsServers.insert(server); } Loading
PrivateDnsConfiguration.h +2 −1 Original line number Diff line number Diff line Loading @@ -53,7 +53,8 @@ struct PrivateDnsStatus { class PrivateDnsConfiguration { public: int set(int32_t netId, uint32_t mark, const std::vector<std::string>& servers, const std::string& name, const std::string& caCert) EXCLUDES(mPrivateDnsLock); const std::string& name, const std::string& caCert, int32_t connectTimeoutMs) EXCLUDES(mPrivateDnsLock); PrivateDnsStatus getStatus(unsigned netId) EXCLUDES(mPrivateDnsLock); Loading