Loading DnsTlsSocket.cpp +65 −47 Original line number Diff line number Diff line Loading @@ -19,14 +19,15 @@ #include "netd_resolv/DnsTlsSocket.h" #include <algorithm> #include <arpa/inet.h> #include <arpa/nameser.h> #include <errno.h> #include <linux/tcp.h> #include <openssl/err.h> #include <openssl/sha.h> #include <sys/eventfd.h> #include <sys/poll.h> #include <algorithm> #include "netd_resolv/DnsTlsSessionCache.h" #include "netd_resolv/IDnsTlsSocketObserver.h" Loading Loading @@ -163,14 +164,8 @@ bool DnsTlsSocket::initialize() { if (!mSsl) { return false; } int sv[2]; if (socketpair(AF_LOCAL, SOCK_SEQPACKET, 0, sv)) { return false; } // The two sockets are perfectly symmetrical, so the choice of which one is // "in" and which one is "out" is arbitrary. mIpcInFd.reset(sv[0]); mIpcOutFd.reset(sv[1]); mEventFd.reset(eventfd(0, EFD_NONBLOCK | EFD_CLOEXEC)); // Start the I/O loop. mLoopThread.reset(new std::thread(&DnsTlsSocket::loop, this)); Loading Loading @@ -338,26 +333,25 @@ bool DnsTlsSocket::sslWrite(const Slice buffer) { void DnsTlsSocket::loop() { std::lock_guard guard(mLock); // Buffer at most one query. Query q; std::deque<std::vector<uint8_t>> q; const int timeout_msecs = DnsTlsSocket::kIdleTimeout.count() * 1000; while (true) { // poll() ignores negative fds struct pollfd fds[2] = { { .fd = -1 }, { .fd = -1 } }; enum { SSLFD = 0, IPCFD = 1 }; enum { SSLFD = 0, EVENTFD = 1 }; // Always listen for a response from server. fds[SSLFD].fd = mSslFd.get(); fds[SSLFD].events = POLLIN; // If we have a pending query, also wait for space // to write it, otherwise listen for a new query. if (!q.query.empty()) { // If we have pending queries, wait for space to write one. // Otherwise, listen for new queries. if (!q.empty()) { fds[SSLFD].events |= POLLOUT; } else { fds[IPCFD].fd = mIpcOutFd.get(); fds[IPCFD].events = POLLIN; fds[EVENTFD].fd = mEventFd.get(); fds[EVENTFD].events = POLLIN; } const int s = TEMP_FAILURE_RETRY(poll(fds, std::size(fds), timeout_msecs)); Loading @@ -375,28 +369,44 @@ void DnsTlsSocket::loop() { break; } } if (fds[IPCFD].revents & (POLLIN | POLLERR)) { int res = read(mIpcOutFd.get(), &q, sizeof(q)); if (fds[EVENTFD].revents & (POLLIN | POLLERR)) { int64_t num_queries; ssize_t res = read(mEventFd.get(), &num_queries, sizeof(num_queries)); if (res < 0) { ALOGW("Error during IPC read"); ALOGW("Error during eventfd read"); break; } else if (res == 0) { ALOGV("IPC channel closed; disconnecting"); ALOGV("eventfd closed; disconnecting"); break; } else if (res != sizeof(num_queries)) { ALOGE("Int size mismatch: %zd != %zu", res, sizeof(num_queries)); break; } else if (res != sizeof(q)) { ALOGE("Struct size mismatch: %d != %zu", res, sizeof(q)); } else if (num_queries <= 0) { ALOGE("eventfd reads should always be positive"); break; } // Take ownership of all pending queries. (q is always empty here.) mQueue.swap(q); // The writing thread writes to mQueue and then increments mEventFd, so // there should be at least num_queries entries in mQueue. if (q.size() < (uint64_t) num_queries) { ALOGE("Synchronization error"); break; } } else if (fds[SSLFD].revents & POLLOUT) { // query cannot be null here. if (!sendQuery(q)) { // q cannot be empty here. // Sending the entire queue here would risk a TCP flow control deadlock, so // we only send a single query on each cycle of this loop. // TODO: Coalesce multiple pending queries if there is enough space in the // write buffer. if (!sendQuery(q.front())) { break; } q = Query(); // Reset q to empty q.pop_front(); } } ALOGV("Closing IPC read FD"); mIpcOutFd.reset(); ALOGV("Closing event FD"); mEventFd.reset(); ALOGV("Disconnecting"); sslDisconnect(); ALOGV("Calling onClosed"); Loading @@ -407,7 +417,12 @@ void DnsTlsSocket::loop() { DnsTlsSocket::~DnsTlsSocket() { ALOGV("Destructor"); // This will trigger an orderly shutdown in loop(). mIpcInFd.reset(); // In principle there is a data race here: If there is an I/O error in the network thread // simultaneous with a call to the destructor in a different thread, both threads could // attempt to call mEventFd.reset() at the same time. However, the implementation of // UniqueFd::reset appears to be thread-safe, and neither thread reads or writes mEventFd // after this point, so we don't expect an issue in practice. mEventFd.reset(); { // Wait for the orderly shutdown to complete. std::lock_guard guard(mLock); Loading @@ -425,12 +440,28 @@ DnsTlsSocket::~DnsTlsSocket() { } bool DnsTlsSocket::query(uint16_t id, const Slice query) { const Query q = { .id = id, .query = query }; if (!mIpcInFd) { if (!mEventFd) { return false; } int written = write(mIpcInFd.get(), &q, sizeof(q)); return written == sizeof(q); // Compose the entire message in a single buffer, so that it can be // sent as a single TLS record. std::vector<uint8_t> buf(query.size() + 4); // Write 2-byte length uint16_t len = query.size() + 2; // + 2 for the ID. buf[0] = len >> 8; buf[1] = len; // Write 2-byte ID buf[2] = id >> 8; buf[3] = id; // Copy body std::memcpy(buf.data() + 4, query.base(), query.size()); mQueue.push(std::move(buf)); // Increment the mEventFd counter by 1. constexpr int64_t num_queries = 1; int written = write(mEventFd.get(), &num_queries, sizeof(num_queries)); return written == sizeof(num_queries); } // Read exactly len bytes into buffer or fail with an SSL error code Loading Loading @@ -464,20 +495,7 @@ int DnsTlsSocket::sslRead(const Slice buffer, bool wait) { return SSL_ERROR_NONE; } bool DnsTlsSocket::sendQuery(const Query& q) { ALOGV("sending query"); // Compose the entire message in a single buffer, so that it can be // sent as a single TLS record. std::vector<uint8_t> buf(q.query.size() + 4); // Write 2-byte length uint16_t len = q.query.size() + 2; // + 2 for the ID. buf[0] = len >> 8; buf[1] = len; // Write 2-byte ID buf[2] = q.id >> 8; buf[3] = q.id; // Copy body std::memcpy(buf.data() + 4, q.query.base(), q.query.size()); bool DnsTlsSocket::sendQuery(const std::vector<uint8_t>& buf) { if (!sslWrite(netdutils::makeSlice(buf))) { return false; } Loading LockedQueue.h 0 → 100644 +52 −0 Original line number Diff line number Diff line /* * Copyright (C) 2019 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef _DNS_LOCKED_QUEUE_H #define _DNS_LOCKED_QUEUE_H #include <algorithm> #include <deque> #include <mutex> #include <android-base/thread_annotations.h> namespace android { namespace net { template <typename T> class LockedQueue { public: // Push an item onto the queue. void push(T item) { std::lock_guard guard(mLock); mQueue.push_front(std::move(item)); } // Swap out the contents of the queue void swap(std::deque<T>& other) { std::lock_guard guard(mLock); mQueue.swap(other); } private: std::mutex mLock; std::deque<T> mQueue GUARDED_BY(mLock); }; } // end of namespace net } // end of namespace android #endif // _DNS_LOCKEDQUEUE_H include/netd_resolv/DnsTlsSocket.h +12 −12 Original line number Diff line number Diff line Loading @@ -28,6 +28,7 @@ #include "DnsTlsServer.h" #include "IDnsTlsSocket.h" #include "LockedQueue.h" #include "params.h" namespace android { Loading Loading @@ -96,20 +97,19 @@ private: // will return SSL_ERROR_WANT_READ if there is no data from the server to read. int sslRead(const Slice buffer, bool wait) REQUIRES(mLock); struct Query { uint16_t id; Slice query; }; bool sendQuery(const Query& q) REQUIRES(mLock); bool sendQuery(const std::vector<uint8_t>& buf) REQUIRES(mLock); bool readResponse() REQUIRES(mLock); // SOCK_SEQPACKET socket pair used for sending queries from myriad query // threads to the SSL thread. EOF indicates a close request. // We have to use a socket pair (i.e. a pipe) because the SSL thread needs // to wait in poll() for input from either a remote server or a query thread. base::unique_fd mIpcInFd; base::unique_fd mIpcOutFd GUARDED_BY(mLock); // Queue of pending queries. query() pushes items onto the queue and notifies // the loop thread by incrementing mEventFd. loop() reads items off the queue. LockedQueue<std::vector<uint8_t>> mQueue; // eventfd socket used for notifying the SSL thread when queries are ready to send. // This socket acts similarly to an atomic counter, incremented by query() and cleared // by loop(). We have to use a socket because the SSL thread needs to wait in poll() // for input from either a remote server or a query thread. // EOF indicates a close request. base::unique_fd mEventFd; // SSL Socket fields. bssl::UniquePtr<SSL_CTX> mSslCtx GUARDED_BY(mLock); Loading Loading
DnsTlsSocket.cpp +65 −47 Original line number Diff line number Diff line Loading @@ -19,14 +19,15 @@ #include "netd_resolv/DnsTlsSocket.h" #include <algorithm> #include <arpa/inet.h> #include <arpa/nameser.h> #include <errno.h> #include <linux/tcp.h> #include <openssl/err.h> #include <openssl/sha.h> #include <sys/eventfd.h> #include <sys/poll.h> #include <algorithm> #include "netd_resolv/DnsTlsSessionCache.h" #include "netd_resolv/IDnsTlsSocketObserver.h" Loading Loading @@ -163,14 +164,8 @@ bool DnsTlsSocket::initialize() { if (!mSsl) { return false; } int sv[2]; if (socketpair(AF_LOCAL, SOCK_SEQPACKET, 0, sv)) { return false; } // The two sockets are perfectly symmetrical, so the choice of which one is // "in" and which one is "out" is arbitrary. mIpcInFd.reset(sv[0]); mIpcOutFd.reset(sv[1]); mEventFd.reset(eventfd(0, EFD_NONBLOCK | EFD_CLOEXEC)); // Start the I/O loop. mLoopThread.reset(new std::thread(&DnsTlsSocket::loop, this)); Loading Loading @@ -338,26 +333,25 @@ bool DnsTlsSocket::sslWrite(const Slice buffer) { void DnsTlsSocket::loop() { std::lock_guard guard(mLock); // Buffer at most one query. Query q; std::deque<std::vector<uint8_t>> q; const int timeout_msecs = DnsTlsSocket::kIdleTimeout.count() * 1000; while (true) { // poll() ignores negative fds struct pollfd fds[2] = { { .fd = -1 }, { .fd = -1 } }; enum { SSLFD = 0, IPCFD = 1 }; enum { SSLFD = 0, EVENTFD = 1 }; // Always listen for a response from server. fds[SSLFD].fd = mSslFd.get(); fds[SSLFD].events = POLLIN; // If we have a pending query, also wait for space // to write it, otherwise listen for a new query. if (!q.query.empty()) { // If we have pending queries, wait for space to write one. // Otherwise, listen for new queries. if (!q.empty()) { fds[SSLFD].events |= POLLOUT; } else { fds[IPCFD].fd = mIpcOutFd.get(); fds[IPCFD].events = POLLIN; fds[EVENTFD].fd = mEventFd.get(); fds[EVENTFD].events = POLLIN; } const int s = TEMP_FAILURE_RETRY(poll(fds, std::size(fds), timeout_msecs)); Loading @@ -375,28 +369,44 @@ void DnsTlsSocket::loop() { break; } } if (fds[IPCFD].revents & (POLLIN | POLLERR)) { int res = read(mIpcOutFd.get(), &q, sizeof(q)); if (fds[EVENTFD].revents & (POLLIN | POLLERR)) { int64_t num_queries; ssize_t res = read(mEventFd.get(), &num_queries, sizeof(num_queries)); if (res < 0) { ALOGW("Error during IPC read"); ALOGW("Error during eventfd read"); break; } else if (res == 0) { ALOGV("IPC channel closed; disconnecting"); ALOGV("eventfd closed; disconnecting"); break; } else if (res != sizeof(num_queries)) { ALOGE("Int size mismatch: %zd != %zu", res, sizeof(num_queries)); break; } else if (res != sizeof(q)) { ALOGE("Struct size mismatch: %d != %zu", res, sizeof(q)); } else if (num_queries <= 0) { ALOGE("eventfd reads should always be positive"); break; } // Take ownership of all pending queries. (q is always empty here.) mQueue.swap(q); // The writing thread writes to mQueue and then increments mEventFd, so // there should be at least num_queries entries in mQueue. if (q.size() < (uint64_t) num_queries) { ALOGE("Synchronization error"); break; } } else if (fds[SSLFD].revents & POLLOUT) { // query cannot be null here. if (!sendQuery(q)) { // q cannot be empty here. // Sending the entire queue here would risk a TCP flow control deadlock, so // we only send a single query on each cycle of this loop. // TODO: Coalesce multiple pending queries if there is enough space in the // write buffer. if (!sendQuery(q.front())) { break; } q = Query(); // Reset q to empty q.pop_front(); } } ALOGV("Closing IPC read FD"); mIpcOutFd.reset(); ALOGV("Closing event FD"); mEventFd.reset(); ALOGV("Disconnecting"); sslDisconnect(); ALOGV("Calling onClosed"); Loading @@ -407,7 +417,12 @@ void DnsTlsSocket::loop() { DnsTlsSocket::~DnsTlsSocket() { ALOGV("Destructor"); // This will trigger an orderly shutdown in loop(). mIpcInFd.reset(); // In principle there is a data race here: If there is an I/O error in the network thread // simultaneous with a call to the destructor in a different thread, both threads could // attempt to call mEventFd.reset() at the same time. However, the implementation of // UniqueFd::reset appears to be thread-safe, and neither thread reads or writes mEventFd // after this point, so we don't expect an issue in practice. mEventFd.reset(); { // Wait for the orderly shutdown to complete. std::lock_guard guard(mLock); Loading @@ -425,12 +440,28 @@ DnsTlsSocket::~DnsTlsSocket() { } bool DnsTlsSocket::query(uint16_t id, const Slice query) { const Query q = { .id = id, .query = query }; if (!mIpcInFd) { if (!mEventFd) { return false; } int written = write(mIpcInFd.get(), &q, sizeof(q)); return written == sizeof(q); // Compose the entire message in a single buffer, so that it can be // sent as a single TLS record. std::vector<uint8_t> buf(query.size() + 4); // Write 2-byte length uint16_t len = query.size() + 2; // + 2 for the ID. buf[0] = len >> 8; buf[1] = len; // Write 2-byte ID buf[2] = id >> 8; buf[3] = id; // Copy body std::memcpy(buf.data() + 4, query.base(), query.size()); mQueue.push(std::move(buf)); // Increment the mEventFd counter by 1. constexpr int64_t num_queries = 1; int written = write(mEventFd.get(), &num_queries, sizeof(num_queries)); return written == sizeof(num_queries); } // Read exactly len bytes into buffer or fail with an SSL error code Loading Loading @@ -464,20 +495,7 @@ int DnsTlsSocket::sslRead(const Slice buffer, bool wait) { return SSL_ERROR_NONE; } bool DnsTlsSocket::sendQuery(const Query& q) { ALOGV("sending query"); // Compose the entire message in a single buffer, so that it can be // sent as a single TLS record. std::vector<uint8_t> buf(q.query.size() + 4); // Write 2-byte length uint16_t len = q.query.size() + 2; // + 2 for the ID. buf[0] = len >> 8; buf[1] = len; // Write 2-byte ID buf[2] = q.id >> 8; buf[3] = q.id; // Copy body std::memcpy(buf.data() + 4, q.query.base(), q.query.size()); bool DnsTlsSocket::sendQuery(const std::vector<uint8_t>& buf) { if (!sslWrite(netdutils::makeSlice(buf))) { return false; } Loading
LockedQueue.h 0 → 100644 +52 −0 Original line number Diff line number Diff line /* * Copyright (C) 2019 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef _DNS_LOCKED_QUEUE_H #define _DNS_LOCKED_QUEUE_H #include <algorithm> #include <deque> #include <mutex> #include <android-base/thread_annotations.h> namespace android { namespace net { template <typename T> class LockedQueue { public: // Push an item onto the queue. void push(T item) { std::lock_guard guard(mLock); mQueue.push_front(std::move(item)); } // Swap out the contents of the queue void swap(std::deque<T>& other) { std::lock_guard guard(mLock); mQueue.swap(other); } private: std::mutex mLock; std::deque<T> mQueue GUARDED_BY(mLock); }; } // end of namespace net } // end of namespace android #endif // _DNS_LOCKEDQUEUE_H
include/netd_resolv/DnsTlsSocket.h +12 −12 Original line number Diff line number Diff line Loading @@ -28,6 +28,7 @@ #include "DnsTlsServer.h" #include "IDnsTlsSocket.h" #include "LockedQueue.h" #include "params.h" namespace android { Loading Loading @@ -96,20 +97,19 @@ private: // will return SSL_ERROR_WANT_READ if there is no data from the server to read. int sslRead(const Slice buffer, bool wait) REQUIRES(mLock); struct Query { uint16_t id; Slice query; }; bool sendQuery(const Query& q) REQUIRES(mLock); bool sendQuery(const std::vector<uint8_t>& buf) REQUIRES(mLock); bool readResponse() REQUIRES(mLock); // SOCK_SEQPACKET socket pair used for sending queries from myriad query // threads to the SSL thread. EOF indicates a close request. // We have to use a socket pair (i.e. a pipe) because the SSL thread needs // to wait in poll() for input from either a remote server or a query thread. base::unique_fd mIpcInFd; base::unique_fd mIpcOutFd GUARDED_BY(mLock); // Queue of pending queries. query() pushes items onto the queue and notifies // the loop thread by incrementing mEventFd. loop() reads items off the queue. LockedQueue<std::vector<uint8_t>> mQueue; // eventfd socket used for notifying the SSL thread when queries are ready to send. // This socket acts similarly to an atomic counter, incremented by query() and cleared // by loop(). We have to use a socket because the SSL thread needs to wait in poll() // for input from either a remote server or a query thread. // EOF indicates a close request. base::unique_fd mEventFd; // SSL Socket fields. bssl::UniquePtr<SSL_CTX> mSslCtx GUARDED_BY(mLock); Loading