Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 8187f12d authored by Treehugger Robot's avatar Treehugger Robot Committed by Gerrit Code Review
Browse files

Merge "Copy queries synchronously in DnsTlsSocket"

parents ba7bef92 2187abe0
Loading
Loading
Loading
Loading
+65 −47
Original line number Diff line number Diff line
@@ -19,14 +19,15 @@

#include "netd_resolv/DnsTlsSocket.h"

#include <algorithm>
#include <arpa/inet.h>
#include <arpa/nameser.h>
#include <errno.h>
#include <linux/tcp.h>
#include <openssl/err.h>
#include <openssl/sha.h>
#include <sys/eventfd.h>
#include <sys/poll.h>
#include <algorithm>

#include "netd_resolv/DnsTlsSessionCache.h"
#include "netd_resolv/IDnsTlsSocketObserver.h"
@@ -163,14 +164,8 @@ bool DnsTlsSocket::initialize() {
    if (!mSsl) {
        return false;
    }
    int sv[2];
    if (socketpair(AF_LOCAL, SOCK_SEQPACKET, 0, sv)) {
        return false;
    }
    // The two sockets are perfectly symmetrical, so the choice of which one is
    // "in" and which one is "out" is arbitrary.
    mIpcInFd.reset(sv[0]);
    mIpcOutFd.reset(sv[1]);

    mEventFd.reset(eventfd(0, EFD_NONBLOCK | EFD_CLOEXEC));

    // Start the I/O loop.
    mLoopThread.reset(new std::thread(&DnsTlsSocket::loop, this));
@@ -338,26 +333,25 @@ bool DnsTlsSocket::sslWrite(const Slice buffer) {

void DnsTlsSocket::loop() {
    std::lock_guard guard(mLock);
    // Buffer at most one query.
    Query q;
    std::deque<std::vector<uint8_t>> q;

    const int timeout_msecs = DnsTlsSocket::kIdleTimeout.count() * 1000;
    while (true) {
        // poll() ignores negative fds
        struct pollfd fds[2] = { { .fd = -1 }, { .fd = -1 } };
        enum { SSLFD = 0, IPCFD = 1 };
        enum { SSLFD = 0, EVENTFD = 1 };

        // Always listen for a response from server.
        fds[SSLFD].fd = mSslFd.get();
        fds[SSLFD].events = POLLIN;

        // If we have a pending query, also wait for space
        // to write it, otherwise listen for a new query.
        if (!q.query.empty()) {
        // If we have pending queries, wait for space to write one.
        // Otherwise, listen for new queries.
        if (!q.empty()) {
            fds[SSLFD].events |= POLLOUT;
        } else {
            fds[IPCFD].fd = mIpcOutFd.get();
            fds[IPCFD].events = POLLIN;
            fds[EVENTFD].fd = mEventFd.get();
            fds[EVENTFD].events = POLLIN;
        }

        const int s = TEMP_FAILURE_RETRY(poll(fds, std::size(fds), timeout_msecs));
@@ -375,28 +369,44 @@ void DnsTlsSocket::loop() {
                break;
            }
        }
        if (fds[IPCFD].revents & (POLLIN | POLLERR)) {
            int res = read(mIpcOutFd.get(), &q, sizeof(q));
        if (fds[EVENTFD].revents & (POLLIN | POLLERR)) {
            int64_t num_queries;
            ssize_t res = read(mEventFd.get(), &num_queries, sizeof(num_queries));
            if (res < 0) {
                ALOGW("Error during IPC read");
                ALOGW("Error during eventfd read");
                break;
            } else if (res == 0) {
                ALOGV("IPC channel closed; disconnecting");
                ALOGV("eventfd closed; disconnecting");
                break;
            } else if (res != sizeof(num_queries)) {
                ALOGE("Int size mismatch: %zd != %zu", res, sizeof(num_queries));
                break;
            } else if (res != sizeof(q)) {
                ALOGE("Struct size mismatch: %d != %zu", res, sizeof(q));
            } else if (num_queries <= 0) {
                ALOGE("eventfd reads should always be positive");
                break;
            }
            // Take ownership of all pending queries.  (q is always empty here.)
            mQueue.swap(q);
            // The writing thread writes to mQueue and then increments mEventFd, so
            // there should be at least num_queries entries in mQueue.
            if (q.size() < (uint64_t) num_queries) {
                ALOGE("Synchronization error");
                break;
            }
        } else if (fds[SSLFD].revents & POLLOUT) {
            // query cannot be null here.
            if (!sendQuery(q)) {
            // q cannot be empty here.
            // Sending the entire queue here would risk a TCP flow control deadlock, so
            // we only send a single query on each cycle of this loop.
            // TODO: Coalesce multiple pending queries if there is enough space in the
            // write buffer.
            if (!sendQuery(q.front())) {
                break;
            }
            q = Query();  // Reset q to empty
            q.pop_front();
        }
    }
    ALOGV("Closing IPC read FD");
    mIpcOutFd.reset();
    ALOGV("Closing event FD");
    mEventFd.reset();
    ALOGV("Disconnecting");
    sslDisconnect();
    ALOGV("Calling onClosed");
@@ -407,7 +417,12 @@ void DnsTlsSocket::loop() {
DnsTlsSocket::~DnsTlsSocket() {
    ALOGV("Destructor");
    // This will trigger an orderly shutdown in loop().
    mIpcInFd.reset();
    // In principle there is a data race here: If there is an I/O error in the network thread
    // simultaneous with a call to the destructor in a different thread, both threads could
    // attempt to call mEventFd.reset() at the same time.  However, the implementation of
    // UniqueFd::reset appears to be thread-safe, and neither thread reads or writes mEventFd
    // after this point, so we don't expect an issue in practice.
    mEventFd.reset();
    {
        // Wait for the orderly shutdown to complete.
        std::lock_guard guard(mLock);
@@ -425,12 +440,28 @@ DnsTlsSocket::~DnsTlsSocket() {
}

bool DnsTlsSocket::query(uint16_t id, const Slice query) {
    const Query q = { .id = id, .query = query };
    if (!mIpcInFd) {
    if (!mEventFd) {
        return false;
    }
    int written = write(mIpcInFd.get(), &q, sizeof(q));
    return written == sizeof(q);

    // Compose the entire message in a single buffer, so that it can be
    // sent as a single TLS record.
    std::vector<uint8_t> buf(query.size() + 4);
    // Write 2-byte length
    uint16_t len = query.size() + 2;  // + 2 for the ID.
    buf[0] = len >> 8;
    buf[1] = len;
    // Write 2-byte ID
    buf[2] = id >> 8;
    buf[3] = id;
    // Copy body
    std::memcpy(buf.data() + 4, query.base(), query.size());

    mQueue.push(std::move(buf));
    // Increment the mEventFd counter by 1.
    constexpr int64_t num_queries = 1;
    int written = write(mEventFd.get(), &num_queries, sizeof(num_queries));
    return written == sizeof(num_queries);
}

// Read exactly len bytes into buffer or fail with an SSL error code
@@ -464,20 +495,7 @@ int DnsTlsSocket::sslRead(const Slice buffer, bool wait) {
    return SSL_ERROR_NONE;
}

bool DnsTlsSocket::sendQuery(const Query& q) {
    ALOGV("sending query");
    // Compose the entire message in a single buffer, so that it can be
    // sent as a single TLS record.
    std::vector<uint8_t> buf(q.query.size() + 4);
    // Write 2-byte length
    uint16_t len = q.query.size() + 2; // + 2 for the ID.
    buf[0] = len >> 8;
    buf[1] = len;
    // Write 2-byte ID
    buf[2] = q.id >> 8;
    buf[3] = q.id;
    // Copy body
    std::memcpy(buf.data() + 4, q.query.base(), q.query.size());
bool DnsTlsSocket::sendQuery(const std::vector<uint8_t>& buf) {
    if (!sslWrite(netdutils::makeSlice(buf))) {
        return false;
    }

LockedQueue.h

0 → 100644
+52 −0
Original line number Diff line number Diff line
/*
 * Copyright (C) 2019 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef _DNS_LOCKED_QUEUE_H
#define _DNS_LOCKED_QUEUE_H

#include <algorithm>
#include <deque>
#include <mutex>

#include <android-base/thread_annotations.h>

namespace android {
namespace net {

template <typename T>
class LockedQueue {
  public:
    // Push an item onto the queue.
    void push(T item) {
        std::lock_guard guard(mLock);
        mQueue.push_front(std::move(item));
    }

    // Swap out the contents of the queue
    void swap(std::deque<T>& other) {
        std::lock_guard guard(mLock);
        mQueue.swap(other);
    }

  private:
    std::mutex mLock;
    std::deque<T> mQueue GUARDED_BY(mLock);
};

}  // end of namespace net
}  // end of namespace android

#endif  // _DNS_LOCKEDQUEUE_H
+12 −12
Original line number Diff line number Diff line
@@ -28,6 +28,7 @@

#include "DnsTlsServer.h"
#include "IDnsTlsSocket.h"
#include "LockedQueue.h"
#include "params.h"

namespace android {
@@ -96,20 +97,19 @@ private:
    // will return SSL_ERROR_WANT_READ if there is no data from the server to read.
    int sslRead(const Slice buffer, bool wait) REQUIRES(mLock);

    struct Query {
        uint16_t id;
        Slice query;
    };

    bool sendQuery(const Query& q) REQUIRES(mLock);
    bool sendQuery(const std::vector<uint8_t>& buf) REQUIRES(mLock);
    bool readResponse() REQUIRES(mLock);

    // SOCK_SEQPACKET socket pair used for sending queries from myriad query
    // threads to the SSL thread.  EOF indicates a close request.
    // We have to use a socket pair (i.e. a pipe) because the SSL thread needs
    // to wait in poll() for input from either a remote server or a query thread.
    base::unique_fd mIpcInFd;
    base::unique_fd mIpcOutFd GUARDED_BY(mLock);
    // Queue of pending queries.  query() pushes items onto the queue and notifies
    // the loop thread by incrementing mEventFd.  loop() reads items off the queue.
    LockedQueue<std::vector<uint8_t>> mQueue;

    // eventfd socket used for notifying the SSL thread when queries are ready to send.
    // This socket acts similarly to an atomic counter, incremented by query() and cleared
    // by loop().  We have to use a socket because the SSL thread needs to wait in poll()
    // for input from either a remote server or a query thread.
    // EOF indicates a close request.
    base::unique_fd mEventFd;

    // SSL Socket fields.
    bssl::UniquePtr<SSL_CTX> mSslCtx GUARDED_BY(mLock);