Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 0b67a11a authored by Mike Yu's avatar Mike Yu
Browse files

resolv: speed up the termination of loop thread in testing DNS and TLS servers

Use eventfd as a signal to notify the loop thread to terminate. The
loop thread can be terminated right after their stopServer() is called.

Also some minor changes:
[1] Use unique_fd for sockets to avoid fd leakage.
[2] Use ScopedAddrinfo in DnsTlsFrontend to prevent memory unfreed
    if getaddrinfo() fails.
[3] Remove the unnecessary call stopServer() in resolver_test since
    it is automatically called in the their destructor.
[4] Timeout value in the loop thread polling is set as -1 by default.

Before this change: resolv_integration_test takes 145283 ms
After this change: resolv_integration_test takes 125249 ms

Bug: 130686826
Test: runtest and netd_benchmark passed
Merged-In: I2810743d17f858a273157267495274195f04dcbc
Merged-In: Ieb729e899506caad50a8c613c4b9e75209229135
Change-Id: I97c44ff87c4adfa6242bbd3aa0970474004c16f4
(cherry picked from commit a2b8cb036ae4418dae1935dfe20833862786117c)
parent 91cf5292
Loading
Loading
Loading
Loading
+131 −92
Original line number Original line Diff line number Diff line
@@ -24,6 +24,7 @@
#include <stdlib.h>
#include <stdlib.h>
#include <string.h>
#include <string.h>
#include <sys/epoll.h>
#include <sys/epoll.h>
#include <sys/eventfd.h>
#include <sys/socket.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <sys/types.h>
#include <unistd.h>
#include <unistd.h>
@@ -37,15 +38,17 @@
#include <log/log.h>
#include <log/log.h>
#include <netdutils/SocketOption.h>
#include <netdutils/SocketOption.h>


#include "NetdConstants.h"

using android::netdutils::enableSockopt;
using android::netdutils::enableSockopt;


namespace test {
namespace test {


std::string errno2str() {
std::string errno2str() {
    char error_msg[512] = { 0 };
    char error_msg[512] = { 0 };
    if (strerror_r(errno, error_msg, sizeof(error_msg)))
    // It actually calls __gnu_strerror_r() which returns the type |char*| rather than |int|.
        return std::string();
    // PLOG is an option though it requires lots of changes from ALOGx() to LOG(x).
    return std::string(error_msg);
    return strerror_r(errno, error_msg, sizeof(error_msg));
}
}


#define APLOGI(fmt, ...) ALOGI(fmt ": [%d] %s", __VA_ARGS__, errno, errno2str().c_str())
#define APLOGI(fmt, ...) ALOGI(fmt ": [%d] %s", __VA_ARGS__, errno, errno2str().c_str())
@@ -576,7 +579,7 @@ void DNSResponder::setEdns(Edns edns) {
}
}


bool DNSResponder::running() const {
bool DNSResponder::running() const {
    return socket_ != -1;
    return socket_.get() != -1;
}
}


bool DNSResponder::startServer() {
bool DNSResponder::startServer() {
@@ -584,70 +587,72 @@ bool DNSResponder::startServer() {
        ALOGI("server already running");
        ALOGI("server already running");
        return false;
        return false;
    }
    }

    // Set up UDP socket.
    addrinfo ai_hints{
    addrinfo ai_hints{
        .ai_family = AF_UNSPEC,
        .ai_family = AF_UNSPEC,
        .ai_socktype = SOCK_DGRAM,
        .ai_socktype = SOCK_DGRAM,
        .ai_flags = AI_PASSIVE
        .ai_flags = AI_PASSIVE
    };
    };
    addrinfo* ai_res;
    addrinfo* ai_res = nullptr;
    int rv = getaddrinfo(listen_address_.c_str(), listen_service_.c_str(),
    int rv = getaddrinfo(listen_address_.c_str(), listen_service_.c_str(),
                         &ai_hints, &ai_res);
                         &ai_hints, &ai_res);
    ScopedAddrinfo ai_res_cleanup(ai_res);
    if (rv) {
    if (rv) {
        ALOGI("getaddrinfo(%s, %s) failed: %s", listen_address_.c_str(),
        ALOGI("getaddrinfo(%s, %s) failed: %s", listen_address_.c_str(),
            listen_service_.c_str(), gai_strerror(rv));
            listen_service_.c_str(), gai_strerror(rv));
        return false;
        return false;
    }
    }
    int s = -1;
    for (const addrinfo* ai = ai_res ; ai ; ai = ai->ai_next) {
    for (const addrinfo* ai = ai_res ; ai ; ai = ai->ai_next) {
        s = socket(ai->ai_family, ai->ai_socktype, ai->ai_protocol);
        android::base::unique_fd s(socket(ai->ai_family, ai->ai_socktype, ai->ai_protocol));
        if (s < 0) continue;
        if (s.get() < 0) {
        enableSockopt(s, SOL_SOCKET, SO_REUSEPORT).ignoreError();
            APLOGI("ignore creating socket %d failed", s.get());
        enableSockopt(s, SOL_SOCKET, SO_REUSEADDR).ignoreError();
        if (bind(s, ai->ai_addr, ai->ai_addrlen)) {
            APLOGI("bind failed for socket %d", s);
            close(s);
            s = -1;
            continue;
            continue;
        }
        }
        enableSockopt(s.get(), SOL_SOCKET, SO_REUSEPORT).ignoreError();
        enableSockopt(s.get(), SOL_SOCKET, SO_REUSEADDR).ignoreError();
        std::string host_str = addr2str(ai->ai_addr, ai->ai_addrlen);
        std::string host_str = addr2str(ai->ai_addr, ai->ai_addrlen);
        if (bind(s.get(), ai->ai_addr, ai->ai_addrlen)) {
            APLOGI("failed to bind UDP %s:%s", host_str.c_str(), listen_service_.c_str());
            continue;
        }
        ALOGI("bound to UDP %s:%s", host_str.c_str(), listen_service_.c_str());
        ALOGI("bound to UDP %s:%s", host_str.c_str(), listen_service_.c_str());
        socket_ = std::move(s);
        break;
        break;
    }
    }
    freeaddrinfo(ai_res);

    if (s < 0) {
    int flags = fcntl(socket_.get(), F_GETFL, 0);
        ALOGI("bind() failed");
    if (flags < 0) flags = 0;
    if (fcntl(socket_.get(), F_SETFL, flags | O_NONBLOCK) < 0) {
        APLOGI("fcntl(F_SETFL) failed for socket %d", socket_.get());
        return false;
        return false;
    }
    }


    int flags = fcntl(s, F_GETFL, 0);
    // Set up eventfd socket.
    if (flags < 0) flags = 0;
    event_fd_.reset(eventfd(0, EFD_NONBLOCK | EFD_CLOEXEC));
    if (fcntl(s, F_SETFL, flags | O_NONBLOCK) < 0) {
    if (event_fd_.get() == -1) {
        APLOGI("fcntl(F_SETFL) failed for socket %d", s);
        APLOGI("failed to create eventfd %d", event_fd_.get());
        close(s);
        return false;
        return false;
    }
    }


    int ep_fd = epoll_create1(EPOLL_CLOEXEC);
    // Set up epoll socket.
    if (ep_fd < 0) {
    epoll_fd_.reset(epoll_create1(EPOLL_CLOEXEC));
        char error_msg[512] = { 0 };
    if (epoll_fd_.get() < 0) {
        if (strerror_r(errno, error_msg, sizeof(error_msg)))
        APLOGI("epoll_create1() failed on fd %d", epoll_fd_.get());
            strncpy(error_msg, "UNKNOWN", sizeof(error_msg));
        APLOGI("epoll_create1() failed: %s", error_msg);
        close(s);
        return false;
        return false;
    }
    }
    epoll_event ev;

    ev.events = EPOLLIN;
    ALOGI("adding socket %d to epoll", socket_.get());
    ev.data.fd = s;
    if (!addFd(socket_.get(), EPOLLIN)) {
    if (epoll_ctl(ep_fd, EPOLL_CTL_ADD, s, &ev) < 0) {
        ALOGE("failed to add the socket %d to epoll", socket_.get());
        APLOGI("epoll_ctl() failed for socket %d", s);
        return false;
        close(ep_fd);
    }
        close(s);
    ALOGI("adding eventfd %d to epoll", event_fd_.get());
    if (!addFd(event_fd_.get(), EPOLLIN)) {
        ALOGE("failed to add the eventfd %d to epoll", event_fd_.get());
        return false;
        return false;
    }
    }


    epoll_fd_ = ep_fd;
    socket_ = s;
    {
    {
        std::lock_guard lock(update_mutex_);
        std::lock_guard lock(update_mutex_);
        handler_thread_ = std::thread(&DNSResponder::requestHandler, this);
        handler_thread_ = std::thread(&DNSResponder::requestHandler, this);
@@ -662,17 +667,13 @@ bool DNSResponder::stopServer() {
        ALOGI("server not running");
        ALOGI("server not running");
        return false;
        return false;
    }
    }
    if (terminate_) {
    ALOGI("stopping server");
        ALOGI("LOGIC ERROR");
    if (!sendToEventFd()) {
        return false;
        return false;
    }
    }
    ALOGI("stopping server");
    terminate_ = true;
    handler_thread_.join();
    handler_thread_.join();
    close(epoll_fd_);
    epoll_fd_.reset();
    close(socket_);
    socket_.reset();
    terminate_ = false;
    socket_ = -1;
    ALOGI("server stopped successfully");
    ALOGI("server stopped successfully");
    return true;
    return true;
}
}
@@ -697,59 +698,27 @@ void DNSResponder::clearQueries() {
}
}


void DNSResponder::requestHandler() {
void DNSResponder::requestHandler() {
    epoll_event evs[1];
    epoll_event evs[EPOLL_MAX_EVENTS];
    while (!terminate_) {
    while (true) {
        int n = epoll_wait(epoll_fd_, evs, 1, poll_timeout_ms_);
        int n = epoll_wait(epoll_fd_.get(), evs, EPOLL_MAX_EVENTS, poll_timeout_ms_);
        if (n == 0) continue;
        if (n == 0) continue;
        if (n < 0) {
        if (n < 0) {
            ALOGI("epoll_wait() failed");
            APLOGI("epoll_wait() failed, n=%d", n);
            // TODO(imaipi): terminate on error.
            return;
            return;
        }
        }
        char buffer[4096];
        sockaddr_storage sa;
        socklen_t sa_len = sizeof(sa);
        ssize_t len;
        do {
            len = recvfrom(socket_, buffer, sizeof(buffer), 0,
                           (sockaddr*) &sa, &sa_len);
        } while (len < 0 && (errno == EAGAIN || errno == EINTR));
        if (len <= 0) {
            ALOGI("recvfrom() failed");
            continue;
        }
        DBGLOG("read %zd bytes", len);
        std::lock_guard lock(cv_mutex_);
        char response[4096];
        size_t response_len = sizeof(response);
        if (handleDNSRequest(buffer, len, response, &response_len) &&
            response_len > 0) {
            // place wait_for after handleDNSRequest() so we can check the number of queries in
            // test case before it got responded.
            std::unique_lock guard(cv_mutex_for_deferred_resp_);
            cv_for_deferred_resp_.wait(guard, [this]() REQUIRES(cv_mutex_for_deferred_resp_) {
                return !deferred_resp_;
            });


            len = sendto(socket_, response, response_len, 0,
        for (int i = 0; i < n; i++) {
                         reinterpret_cast<const sockaddr*>(&sa), sa_len);
            const int fd = evs[i].data.fd;
            std::string host_str =
            const uint32_t events = evs[i].events;
                addr2str(reinterpret_cast<const sockaddr*>(&sa), sa_len);
            if (fd == event_fd_.get() && (events & (EPOLLIN | EPOLLERR))) {
            if (len > 0) {
                handleEventFd();
                DBGLOG("sent %zu bytes to %s", len, host_str.c_str());
                return;
            } else if (fd == socket_.get() && (events & (EPOLLIN | EPOLLERR))) {
                handleQuery();
            } else {
            } else {
                APLOGI("sendto() failed for %s", host_str.c_str());
                ALOGW("unexpected epoll events 0x%x on fd %d", events, fd);
            }
            }
            // Test that the response is actually a correct DNS message.
            const char* response_end = response + len;
            DNSHeader header;
            const char* cur = header.read(response, response_end);
            if (cur == nullptr) ALOGI("response is flawed");

        } else {
            ALOGI("not responding");
        }
        }
        cv.notify_one();
    }
    }
}
}


@@ -960,4 +929,74 @@ void DNSResponder::setDeferredResp(bool deferred_resp) {
    }
    }
}
}


bool DNSResponder::addFd(int fd, uint32_t events) {
    epoll_event ev;
    ev.events = events;
    ev.data.fd = fd;
    if (epoll_ctl(epoll_fd_.get(), EPOLL_CTL_ADD, fd, &ev) < 0) {
        APLOGI("epoll_ctl() for socket %d failed", fd);
        return false;
    }
    return true;
}

void DNSResponder::handleQuery() {
    char buffer[4096];
    sockaddr_storage sa;
    socklen_t sa_len = sizeof(sa);
    ssize_t len;
    do {
        len = recvfrom(socket_.get(), buffer, sizeof(buffer), 0, (sockaddr*)&sa, &sa_len);
    } while (len < 0 && (errno == EAGAIN || errno == EINTR));
    if (len <= 0) {
        APLOGI("recvfrom() failed, len=%zu", len);
        return;
    }
    DBGLOG("read %zd bytes", len);
    std::lock_guard lock(cv_mutex_);
    char response[4096];
    size_t response_len = sizeof(response);
    if (handleDNSRequest(buffer, len, response, &response_len) && response_len > 0) {
        // place wait_for after handleDNSRequest() so we can check the number of queries in
        // test case before it got responded.
        std::unique_lock guard(cv_mutex_for_deferred_resp_);
        cv_for_deferred_resp_.wait(
                guard, [this]() REQUIRES(cv_mutex_for_deferred_resp_) { return !deferred_resp_; });

        len = sendto(socket_.get(), response, response_len, 0,
                     reinterpret_cast<const sockaddr*>(&sa), sa_len);
        std::string host_str = addr2str(reinterpret_cast<const sockaddr*>(&sa), sa_len);
        if (len > 0) {
            DBGLOG("sent %zu bytes to %s", len, host_str.c_str());
        } else {
            APLOGI("sendto() failed for %s", host_str.c_str());
        }
        // Test that the response is actually a correct DNS message.
        const char* response_end = response + len;
        DNSHeader header;
        const char* cur = header.read(response, response_end);
        if (cur == nullptr) ALOGW("response is flawed");
    } else {
        ALOGW("not responding");
    }
    cv.notify_one();
    return;
}

bool DNSResponder::sendToEventFd() {
    const uint64_t data = 1;
    if (const ssize_t rt = write(event_fd_.get(), &data, sizeof(data)); rt != sizeof(data)) {
        APLOGI("failed to write eventfd, rt=%zd", rt);
        return false;
    }
    return true;
}

void DNSResponder::handleEventFd() {
    int64_t data;
    if (const ssize_t rt = read(event_fd_.get(), &data, sizeof(data)); rt != sizeof(data)) {
        APLOGI("ignore reading eventfd failed, rt=%zd", rt);
    }
}

}  // namespace test
}  // namespace test
+20 −6
Original line number Original line Diff line number Diff line
@@ -29,6 +29,7 @@
#include <vector>
#include <vector>


#include <android-base/thread_annotations.h>
#include <android-base/thread_annotations.h>
#include "android-base/unique_fd.h"


namespace test {
namespace test {


@@ -38,8 +39,7 @@ struct DNSRecord;


inline const std::string kDefaultListenAddr = "127.0.0.3";
inline const std::string kDefaultListenAddr = "127.0.0.3";
inline const std::string kDefaultListenService = "53";
inline const std::string kDefaultListenService = "53";
inline const int kDefaultPollTimoutMillis = 250;
inline const int kDefaultPollTimoutMillis = -1;
inline const ns_rcode kDefaultErrorCode = ns_rcode::ns_r_servfail;


/*
/*
 * Simple DNS responder, which replies to queries with the registered response
 * Simple DNS responder, which replies to queries with the registered response
@@ -122,6 +122,18 @@ class DNSResponder {
    bool makeErrorResponse(DNSHeader* header, ns_rcode rcode, char* response,
    bool makeErrorResponse(DNSHeader* header, ns_rcode rcode, char* response,
                           size_t* response_len) const;
                           size_t* response_len) const;


    // Add a new file descriptor to be polled by the handler thread.
    bool addFd(int fd, uint32_t events);

    // Read the query sent from the client and send the answer back to the client. It
    // makes sure the I/O communicated with the client is correct.
    void handleQuery();

    // Trigger the handler thread to terminate.
    bool sendToEventFd();

    // Used in the handler thread for the termination signal.
    void handleEventFd();


    // Address and service to listen on, currently limited to UDP.
    // Address and service to listen on, currently limited to UDP.
    const std::string listen_address_;
    const std::string listen_address_;
@@ -133,6 +145,8 @@ class DNSResponder {
    // Probability that a valid response is being sent instead of being sent
    // Probability that a valid response is being sent instead of being sent
    // instead of returning error_rcode_.
    // instead of returning error_rcode_.
    std::atomic<double> response_probability_ = 1.0;
    std::atomic<double> response_probability_ = 1.0;
    // Maximum number of fds for epoll.
    const int EPOLL_MAX_EVENTS = 2;


    // Control how the DNS server behaves when it receives the requests containing OPT RR.
    // Control how the DNS server behaves when it receives the requests containing OPT RR.
    // If it's set Edns::ON, the server can recognize and reply the response; if it's set
    // If it's set Edns::ON, the server can recognize and reply the response; if it's set
@@ -151,11 +165,11 @@ class DNSResponder {
        GUARDED_BY(queries_mutex_);
        GUARDED_BY(queries_mutex_);
    mutable std::mutex queries_mutex_;
    mutable std::mutex queries_mutex_;
    // Socket on which the server is listening.
    // Socket on which the server is listening.
    int socket_ = -1;
    android::base::unique_fd socket_;
    // File descriptor for epoll.
    // File descriptor for epoll.
    int epoll_fd_ = -1;
    android::base::unique_fd epoll_fd_;
    // Signal for request handler termination.
    // Eventfd used to signal for the handler thread termination.
    std::atomic<bool> terminate_ = false;
    android::base::unique_fd event_fd_;
    // Thread for handling incoming threads.
    // Thread for handling incoming threads.
    std::thread handler_thread_ GUARDED_BY(update_mutex_);
    std::thread handler_thread_ GUARDED_BY(update_mutex_);
    std::mutex update_mutex_;
    std::mutex update_mutex_;
+0 −7
Original line number Original line Diff line number Diff line
@@ -119,13 +119,6 @@ void DnsResponderClient::SetupDNSServers(unsigned num_servers, const std::vector
    }
    }
}
}


void DnsResponderClient::ShutdownDNSServers(std::vector<std::unique_ptr<test::DNSResponder>>* dns) {
    for (const auto& d : *dns) {
        d->stopServer();
    }
    dns->clear();
}

int DnsResponderClient::SetupOemNetwork() {
int DnsResponderClient::SetupOemNetwork() {
    mNetdSrv->networkDestroy(TEST_NETID);
    mNetdSrv->networkDestroy(TEST_NETID);
    mDnsResolvSrv->destroyNetworkCache(TEST_NETID);
    mDnsResolvSrv->destroyNetworkCache(TEST_NETID);
+0 −2
Original line number Original line Diff line number Diff line
@@ -87,8 +87,6 @@ public:
            std::vector<std::unique_ptr<test::DNSResponder>>* dns,
            std::vector<std::unique_ptr<test::DNSResponder>>* dns,
            std::vector<std::string>* servers);
            std::vector<std::string>* servers);


    static void ShutdownDNSServers(std::vector<std::unique_ptr<test::DNSResponder>>* dns);

    int SetupOemNetwork();
    int SetupOemNetwork();


    void TearDownOemNetwork(int oemNetId);
    void TearDownOemNetwork(int oemNetId);
+104 −85
Original line number Original line Diff line number Diff line
@@ -16,16 +16,15 @@


#include "dns_tls_frontend.h"
#include "dns_tls_frontend.h"


#include <netdb.h>
#include <stdio.h>
#include <unistd.h>
#include <sys/poll.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <arpa/inet.h>
#include <arpa/inet.h>
#include <netdb.h>
#include <openssl/err.h>
#include <openssl/err.h>
#include <openssl/evp.h>
#include <openssl/evp.h>
#include <openssl/ssl.h>
#include <openssl/ssl.h>
#include <sys/eventfd.h>
#include <sys/poll.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <unistd.h>
#include <unistd.h>


#define LOG_TAG "DnsTlsFrontend"
#define LOG_TAG "DnsTlsFrontend"
@@ -63,9 +62,7 @@ bool getSPKIDigest(const X509* cert, std::vector<uint8_t>* out) {


std::string errno2str() {
std::string errno2str() {
    char error_msg[512] = { 0 };
    char error_msg[512] = { 0 };
    if (strerror_r(errno, error_msg, sizeof(error_msg)))
    return strerror_r(errno, error_msg, sizeof(error_msg));
        return std::string();
    return std::string(error_msg);
}
}


#define APLOGI(fmt, ...) ALOGI(fmt ": [%d] %s", __VA_ARGS__, errno, errno2str().c_str())
#define APLOGI(fmt, ...) ALOGI(fmt ": [%d] %s", __VA_ARGS__, errno, errno2str().c_str())
@@ -197,65 +194,70 @@ bool DnsTlsFrontend::startServer() {
        .ai_socktype = SOCK_STREAM,
        .ai_socktype = SOCK_STREAM,
        .ai_flags = AI_PASSIVE
        .ai_flags = AI_PASSIVE
    };
    };
    addrinfo* frontend_ai_res;
    addrinfo* frontend_ai_res = nullptr;
    int rv = getaddrinfo(listen_address_.c_str(), listen_service_.c_str(),
    int rv = getaddrinfo(listen_address_.c_str(), listen_service_.c_str(),
                         &frontend_ai_hints, &frontend_ai_res);
                         &frontend_ai_hints, &frontend_ai_res);
    ScopedAddrinfo frontend_ai_res_cleanup(frontend_ai_res);
    if (rv) {
    if (rv) {
        ALOGE("frontend getaddrinfo(%s, %s) failed: %s", listen_address_.c_str(),
        ALOGE("frontend getaddrinfo(%s, %s) failed: %s", listen_address_.c_str(),
            listen_service_.c_str(), gai_strerror(rv));
            listen_service_.c_str(), gai_strerror(rv));
        return false;
        return false;
    }
    }


    int s = -1;
    for (const addrinfo* ai = frontend_ai_res ; ai ; ai = ai->ai_next) {
    for (const addrinfo* ai = frontend_ai_res ; ai ; ai = ai->ai_next) {
        s = socket(ai->ai_family, ai->ai_socktype, ai->ai_protocol);
        android::base::unique_fd s(socket(ai->ai_family, ai->ai_socktype, ai->ai_protocol));
        if (s < 0) continue;
        if (s.get() < 0) {
        enableSockopt(s, SOL_SOCKET, SO_REUSEPORT).ignoreError();
            APLOGI("ignore creating socket failed %d", s.get());
        enableSockopt(s, SOL_SOCKET, SO_REUSEADDR).ignoreError();
        if (bind(s, ai->ai_addr, ai->ai_addrlen)) {
            APLOGI("bind failed for socket %d", s);
            close(s);
            s = -1;
            continue;
            continue;
        }
        }
        enableSockopt(s.get(), SOL_SOCKET, SO_REUSEPORT).ignoreError();
        enableSockopt(s.get(), SOL_SOCKET, SO_REUSEADDR).ignoreError();
        std::string host_str = addr2str(ai->ai_addr, ai->ai_addrlen);
        std::string host_str = addr2str(ai->ai_addr, ai->ai_addrlen);
        if (bind(s.get(), ai->ai_addr, ai->ai_addrlen)) {
            APLOGI("failed to bind TCP %s:%s", host_str.c_str(), listen_service_.c_str());
            continue;
        }
        ALOGI("bound to TCP %s:%s", host_str.c_str(), listen_service_.c_str());
        ALOGI("bound to TCP %s:%s", host_str.c_str(), listen_service_.c_str());
        socket_ = std::move(s);
        break;
        break;
    }
    }
    freeaddrinfo(frontend_ai_res);
    if (s < 0) {
        ALOGE("server socket creation failed");
        return false;
    }


    if (listen(s, 1) < 0) {
    if (listen(socket_.get(), 1) < 0) {
        ALOGE("listen failed");
        APLOGI("failed to listen socket %d", socket_.get());
        return false;
        return false;
    }
    }


    socket_ = s;

    // Set up UDP client socket to backend.
    // Set up UDP client socket to backend.
    addrinfo backend_ai_hints{
    addrinfo backend_ai_hints{
        .ai_family = AF_UNSPEC,
        .ai_family = AF_UNSPEC,
        .ai_socktype = SOCK_DGRAM
        .ai_socktype = SOCK_DGRAM
    };
    };
    addrinfo* backend_ai_res;
    addrinfo* backend_ai_res = nullptr;
    rv = getaddrinfo(backend_address_.c_str(), backend_service_.c_str(),
    rv = getaddrinfo(backend_address_.c_str(), backend_service_.c_str(),
                         &backend_ai_hints, &backend_ai_res);
                         &backend_ai_hints, &backend_ai_res);
    ScopedAddrinfo backend_ai_res_cleanup(backend_ai_res);
    if (rv) {
    if (rv) {
        ALOGE("backend getaddrinfo(%s, %s) failed: %s", listen_address_.c_str(),
        ALOGE("backend getaddrinfo(%s, %s) failed: %s", listen_address_.c_str(),
            listen_service_.c_str(), gai_strerror(rv));
            listen_service_.c_str(), gai_strerror(rv));
        return false;
        return false;
    }
    }
    backend_socket_ = socket(backend_ai_res->ai_family, backend_ai_res->ai_socktype,
    backend_socket_.reset(socket(backend_ai_res->ai_family, backend_ai_res->ai_socktype,
        backend_ai_res->ai_protocol);
                                 backend_ai_res->ai_protocol));
    if (backend_socket_ < 0) {
    if (backend_socket_.get() < 0) {
        ALOGE("backend socket creation failed");
        APLOGI("backend socket %d creation failed", backend_socket_.get());
        return false;
    }

    // connect() always fails in the test DnsTlsSocketTest.SlowDestructor because of
    // no backend server. Don't check it.
    connect(backend_socket_.get(), backend_ai_res->ai_addr, backend_ai_res->ai_addrlen);

    // Set up eventfd socket.
    event_fd_.reset(eventfd(0, EFD_NONBLOCK | EFD_CLOEXEC));
    if (event_fd_.get() == -1) {
        APLOGI("failed to create eventfd %d", event_fd_.get());
        return false;
        return false;
    }
    }
    connect(backend_socket_, backend_ai_res->ai_addr, backend_ai_res->ai_addrlen);
    freeaddrinfo(backend_ai_res);


    {
    {
        std::lock_guard lock(update_mutex_);
        std::lock_guard lock(update_mutex_);
@@ -267,31 +269,36 @@ bool DnsTlsFrontend::startServer() {


void DnsTlsFrontend::requestHandler() {
void DnsTlsFrontend::requestHandler() {
    ALOGD("Request handler started");
    ALOGD("Request handler started");
    struct pollfd fds[1] = {{ .fd = socket_, .events = POLLIN }};
    enum { EVENT_FD = 0, LISTEN_FD = 1 };
    pollfd fds[2] = {{.fd = event_fd_.get(), .events = POLLIN},
                     {.fd = socket_.get(), .events = POLLIN}};

    while (true) {
        int poll_code = poll(fds, std::size(fds), -1);
        if (poll_code <= 0) {
            APLOGI("Poll failed with error %d", poll_code);
            break;
        }


    while (!terminate_) {
        if (fds[EVENT_FD].revents & (POLLIN | POLLERR)) {
        int poll_code = poll(fds, 1, 10 /* ms */);
            handleEventFd();
        if (poll_code == 0) {
            // Timeout.  Poll again.
            continue;
        } else if (poll_code < 0) {
            ALOGW("Poll failed with error %d", poll_code);
            // Error.
            break;
            break;
        }
        }
        if (fds[LISTEN_FD].revents & (POLLIN | POLLERR)) {
            sockaddr_storage addr;
            sockaddr_storage addr;
            socklen_t len = sizeof(addr);
            socklen_t len = sizeof(addr);


            ALOGD("Trying to accept a client");
            ALOGD("Trying to accept a client");
        int client = accept4(socket_, reinterpret_cast<sockaddr*>(&addr), &len, SOCK_CLOEXEC);
            android::base::unique_fd client(
        ALOGD("Got client socket %d", client);
                    accept4(socket_.get(), reinterpret_cast<sockaddr*>(&addr), &len, SOCK_CLOEXEC));
        if (client < 0) {
            if (client.get() < 0) {
                // Stop
                // Stop
                APLOGI("failed to accept client socket %d", client.get());
                break;
                break;
            }
            }


            bssl::UniquePtr<SSL> ssl(SSL_new(ctx_.get()));
            bssl::UniquePtr<SSL> ssl(SSL_new(ctx_.get()));
        SSL_set_fd(ssl.get(), client);
            SSL_set_fd(ssl.get(), client.get());


            ALOGD("Doing SSL handshake");
            ALOGD("Doing SSL handshake");
            bool success = false;
            bool success = false;
@@ -302,8 +309,6 @@ void DnsTlsFrontend::requestHandler() {
                success = handleOneRequest(ssl.get());
                success = handleOneRequest(ssl.get());
            }
            }


        close(client);

            if (success) {
            if (success) {
                // Increment queries_ as late as possible, because it represents
                // Increment queries_ as late as possible, because it represents
                // a query that is fully processed, and the response returned to the
                // a query that is fully processed, and the response returned to the
@@ -311,7 +316,8 @@ void DnsTlsFrontend::requestHandler() {
                ++queries_;
                ++queries_;
            }
            }
        }
        }
    ALOGD("Request handler terminating");
    }
    ALOGD("Ending loop");
}
}


bool DnsTlsFrontend::handleOneRequest(SSL* ssl) {
bool DnsTlsFrontend::handleOneRequest(SSL* ssl) {
@@ -331,14 +337,14 @@ bool DnsTlsFrontend::handleOneRequest(SSL* ssl) {
        }
        }
        qbytes += ret;
        qbytes += ret;
    }
    }
    int sent = send(backend_socket_, query, qlen, 0);
    int sent = send(backend_socket_.get(), query, qlen, 0);
    if (sent != qlen) {
    if (sent != qlen) {
        ALOGI("Failed to send query");
        ALOGI("Failed to send query");
        return false;
        return false;
    }
    }
    const int max_size = 4096;
    const int max_size = 4096;
    uint8_t recv_buffer[max_size];
    uint8_t recv_buffer[max_size];
    int rlen = recv(backend_socket_, recv_buffer, max_size, 0);
    int rlen = recv(backend_socket_.get(), recv_buffer, max_size, 0);
    if (rlen <= 0) {
    if (rlen <= 0) {
        ALOGI("Failed to receive response");
        ALOGI("Failed to receive response");
        return false;
        return false;
@@ -363,18 +369,15 @@ bool DnsTlsFrontend::stopServer() {
        ALOGI("server not running");
        ALOGI("server not running");
        return false;
        return false;
    }
    }
    if (terminate_) {

        ALOGI("LOGIC ERROR");
    ALOGI("stopping frontend");
    if (!sendToEventFd()) {
        return false;
        return false;
    }
    }
    ALOGI("stopping frontend");
    terminate_ = true;
    handler_thread_.join();
    handler_thread_.join();
    close(socket_);
    socket_.reset();
    close(backend_socket_);
    backend_socket_.reset();
    terminate_ = false;
    event_fd_.reset();
    socket_ = -1;
    backend_socket_ = -1;
    ctx_.reset();
    ctx_.reset();
    fingerprint_.clear();
    fingerprint_.clear();
    ALOGI("frontend stopped successfully");
    ALOGI("frontend stopped successfully");
@@ -399,4 +402,20 @@ bool DnsTlsFrontend::waitForQueries(int number, int timeoutMs) const {
    return false;
    return false;
}
}


bool DnsTlsFrontend::sendToEventFd() {
    const uint64_t data = 1;
    if (const ssize_t rt = write(event_fd_.get(), &data, sizeof(data)); rt != sizeof(data)) {
        APLOGI("failed to write eventfd, rt=%zd", rt);
        return false;
    }
    return true;
}

void DnsTlsFrontend::handleEventFd() {
    int64_t data;
    if (const ssize_t rt = read(event_fd_.get(), &data, sizeof(data)); rt != sizeof(data)) {
        APLOGI("ignore reading eventfd failed, rt=%zd", rt);
    }
}

}  // namespace test
}  // namespace test
Loading