Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 6bcde889 authored by Mike Yu's avatar Mike Yu Committed by Gerrit Code Review
Browse files

Merge changes I9262f65e,I5088d3c6,I47ab70aa,I262f9782

* changes:
  Rewrite DnsTlsSocketTest by gMock
  Add some states to trace DnsTlsSocket life cycle
  Preserve original errno before calling close()
  Add startHandshake to IDnsTlsSocket
parents 12e48b04 1b9069cd
Loading
Loading
Loading
Loading
+37 −8
Original line number Diff line number Diff line
@@ -81,7 +81,7 @@ Status DnsTlsSocket::tcpConnect() {

    mSslFd.reset(socket(mServer.ss.ss_family, type, mServer.protocol));
    if (mSslFd.get() == -1) {
        LOG(ERROR) << "Failed to create socket";
        PLOG(ERROR) << "Failed to create socket";
        return Status(errno);
    }

@@ -89,9 +89,10 @@ Status DnsTlsSocket::tcpConnect() {

    const socklen_t len = sizeof(mMark);
    if (setsockopt(mSslFd.get(), SOL_SOCKET, SO_MARK, &mMark, len) == -1) {
        LOG(ERROR) << "Failed to set socket mark";
        const int err = errno;
        PLOG(ERROR) << "Failed to set socket mark";
        mSslFd.reset();
        return Status(errno);
        return Status(err);
    }

    const Status tfo = enableSockopt(mSslFd.get(), SOL_TCP, TCP_FASTOPEN_CONNECT);
@@ -105,9 +106,10 @@ Status DnsTlsSocket::tcpConnect() {
    if (connect(mSslFd.get(), reinterpret_cast<const struct sockaddr *>(&mServer.ss),
                sizeof(mServer.ss)) != 0 &&
            errno != EINPROGRESS) {
        LOG(DEBUG) << "Socket failed to connect";
        const int err = errno;
        PLOG(ERROR) << "Socket failed to connect";
        mSslFd.reset();
        return Status(errno);
        return Status(err);
    }

    return netdutils::status::ok;
@@ -169,18 +171,33 @@ bool DnsTlsSocket::initialize() {
    // Enable session cache
    mCache->prepareSslContext(mSslCtx.get());

    mEventFd.reset(eventfd(0, EFD_NONBLOCK | EFD_CLOEXEC));

    transitionState(State::UNINITIALIZED, State::INITIALIZED);

    return true;
}

bool DnsTlsSocket::startHandshake() {
    std::lock_guard guard(mLock);
    if (mState != State::INITIALIZED) {
        LOG(ERROR) << "Calling startHandshake in unexpected state " << static_cast<int>(mState);
        return false;
    }
    transitionState(State::INITIALIZED, State::CONNECTING);

    // Connect
    Status status = tcpConnect();
    if (!status.ok()) {
        transitionState(State::CONNECTING, State::WAIT_FOR_DELETE);
        return false;
    }
    mSsl = sslConnect(mSslFd.get());
    if (!mSsl) {
        transitionState(State::CONNECTING, State::WAIT_FOR_DELETE);
        return false;
    }

    mEventFd.reset(eventfd(0, EFD_NONBLOCK | EFD_CLOEXEC));

    // Start the I/O loop.
    mLoopThread.reset(new std::thread(&DnsTlsSocket::loop, this));

@@ -309,7 +326,9 @@ void DnsTlsSocket::loop() {
    std::deque<std::vector<uint8_t>> q;
    const int timeout_msecs = DnsTlsSocket::kIdleTimeout.count() * 1000;

    transitionState(State::CONNECTING, State::CONNECTED);
    setThreadName(StringPrintf("TlsListen_%u", mMark & 0xffff).c_str());

    while (true) {
        // poll() ignores negative fds
        struct pollfd fds[2] = { { .fd = -1 }, { .fd = -1 } };
@@ -336,7 +355,7 @@ void DnsTlsSocket::loop() {
            break;
        }
        if (s < 0) {
            LOG(DEBUG) << "Poll failed: " << errno;
            PLOG(DEBUG) << "Poll failed";
            break;
        }
        if (fds[SSLFD].revents & (POLLIN | POLLERR | POLLHUP)) {
@@ -379,6 +398,7 @@ void DnsTlsSocket::loop() {
    sslDisconnect();
    LOG(DEBUG) << "Calling onClosed";
    mObserver->onClosed();
    transitionState(State::CONNECTED, State::WAIT_FOR_DELETE);
    LOG(DEBUG) << "Ending loop";
}

@@ -441,6 +461,15 @@ bool DnsTlsSocket::incrementEventFd(const int64_t count) {
    return true;
}

void DnsTlsSocket::transitionState(State from, State to) {
    if (mState != from) {
        LOG(WARNING) << "BUG: transitioning from an unexpected state " << static_cast<int>(mState)
                     << ", expect: from " << static_cast<int>(from) << " to "
                     << static_cast<int>(to);
    }
    mState = to;
}

// Read exactly len bytes into buffer or fail with an SSL error code
int DnsTlsSocket::sslRead(const Slice buffer, bool wait) {
    size_t remaining = buffer.size();
+40 −1
Original line number Diff line number Diff line
@@ -44,18 +44,53 @@ class DnsTlsSessionCache;
// or the destructor in a callback.  Doing so will result in deadlocks.
// This class may call the observer at any time after initialize(), until the destructor
// returns (but not after).
//
// Calls to IDnsTlsSocketObserver in a DnsTlsSocket life cycle:
//
//                                UNINITIALIZED
//                                      |
//                                      v
//                                 INITIALIZED
//                                      |
//                                      v
//                            +----CONNECTING------+
//            Handshake fails |                    | Handshake succeeds
//                            |                    |
//                            |                    v
//                            |        +---> CONNECTED --+
//                            |        |           |     |
//                            |        +-----------+     | Idle timeout
//                            |   Send/Recv queries      | onClose()
//                            |   onResponse()           |
//                            |                          |
//                            |                          |
//                            +--> WAIT_FOR_DELETE <-----+
//
//
// TODO: Add onHandshakeFinished() for handshake results.
class DnsTlsSocket : public IDnsTlsSocket {
  public:
    enum class State {
        UNINITIALIZED,
        INITIALIZED,
        CONNECTING,
        CONNECTED,
        WAIT_FOR_DELETE,
    };

    DnsTlsSocket(const DnsTlsServer& server, unsigned mark,
                 IDnsTlsSocketObserver* _Nonnull observer, DnsTlsSessionCache* _Nonnull cache)
        : mMark(mark), mServer(server), mObserver(observer), mCache(cache) {}
    ~DnsTlsSocket();

    // Creates the SSL context for this session and connect.  Returns false on failure.
    // Creates the SSL context for this session. Returns false on failure.
    // This method should be called after construction and before use of a DnsTlsSocket.
    // Only call this method once per DnsTlsSocket.
    bool initialize() EXCLUDES(mLock);

    // A blocking call to start handshaking until it finishes.
    bool startHandshake() EXCLUDES(mLock);

    // Send a query on the provided SSL socket.  |query| contains
    // the body of a query, not including the ID header. This function will typically return before
    // the query is actually sent.  If this function fails, DnsTlsSocketObserver will be
@@ -112,6 +147,9 @@ class DnsTlsSocket : public IDnsTlsSocket {
    // the loop thread by incrementing mEventFd.  loop() reads items off the queue.
    LockedQueue<std::vector<uint8_t>> mQueue;

    // Transition the state from expected state |from| to new state |to|.
    void transitionState(State from, State to) REQUIRES(mLock);

    // eventfd socket used for notifying the SSL thread when queries are ready to send.
    // This socket acts similarly to an atomic counter, incremented by query() and cleared
    // by loop().  We have to use a socket because the SSL thread needs to wait in poll()
@@ -131,6 +169,7 @@ class DnsTlsSocket : public IDnsTlsSocket {
    const DnsTlsServer mServer;
    IDnsTlsSocketObserver* _Nonnull const mObserver;
    DnsTlsSessionCache* _Nonnull const mCache;
    State mState GUARDED_BY(mLock) = State::UNINITIALIZED;
};

}  // end of namespace net
+6 −1
Original line number Diff line number Diff line
@@ -70,9 +70,14 @@ bool DnsTlsTransport::sendQuery(const DnsTlsQueryMap::Query& q) {
void DnsTlsTransport::doConnect() {
    LOG(DEBUG) << "Constructing new socket";
    mSocket = mFactory->createDnsTlsSocket(mServer, mMark, this, &mCache);

    bool success = true;
    if (mSocket.get() == nullptr || !mSocket->startHandshake()) {
        success = false;
    }
    mConnectCounter++;

    if (mSocket) {
    if (success) {
        auto queries = mQueries.getAll();
        LOG(DEBUG) << "Initialization succeeded.  Reissuing " << queries.size() << " queries.";
        for(auto& q : queries) {
+2 −0
Original line number Diff line number Diff line
@@ -40,6 +40,8 @@ class IDnsTlsSocket {
    // notified that the socket is closed.
    // Note that a true return value indicates successful sending, not receipt of a response.
    virtual bool query(uint16_t id, const netdutils::Slice query) = 0;

    virtual bool startHandshake() = 0;
};

}  // end of namespace net
+69 −19
Original line number Diff line number Diff line
@@ -22,6 +22,7 @@

#include <android-base/logging.h>
#include <android-base/macros.h>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include <netdutils/Slice.h>

@@ -39,8 +40,9 @@
namespace android {
namespace net {

using netdutils::Slice;
using netdutils::makeSlice;
using netdutils::Slice;
using ::testing::NiceMock;

typedef std::vector<uint8_t> bytevec;

@@ -134,6 +136,7 @@ class FakeSocketEcho : public IDnsTlsSocket {
        std::thread(&IDnsTlsSocketObserver::onResponse, mObserver, make_echo(id, query)).detach();
        return true;
    }
    bool startHandshake() override { return true; }

  private:
    IDnsTlsSocketObserver* const mObserver;
@@ -169,6 +172,7 @@ class FakeSocketId : public IDnsTlsSocket {
        std::thread(&IDnsTlsSocketObserver::onResponse, mObserver, response).detach();
        return true;
    }
    bool startHandshake() override { return true; }

  private:
    IDnsTlsSocketObserver* const mObserver;
@@ -216,9 +220,15 @@ TEST_F(TransportTest, RacingQueries_10000) {
class FakeSocketDelay : public IDnsTlsSocket {
  public:
    explicit FakeSocketDelay(IDnsTlsSocketObserver* observer) : mObserver(observer) {}
    ~FakeSocketDelay() { std::lock_guard guard(mLock); }
    static size_t sDelay;
    static bool sReverse;
    ~FakeSocketDelay() {
        std::lock_guard guard(mLock);
        sDelay = 1;
        sReverse = false;
        sConnectable = true;
    }
    inline static size_t sDelay = 1;
    inline static bool sReverse = false;
    inline static bool sConnectable = true;

    bool query(uint16_t id, const Slice query) override {
        LOG(DEBUG) << "FakeSocketDelay got query with ID " << int(id);
@@ -236,6 +246,7 @@ class FakeSocketDelay : public IDnsTlsSocket {
        }
        return true;
    }
    bool startHandshake() override { return sConnectable; }

  private:
    void sendResponses() {
@@ -256,9 +267,6 @@ class FakeSocketDelay : public IDnsTlsSocket {
    std::vector<bytevec> mResponses GUARDED_BY(mLock);
};

size_t FakeSocketDelay::sDelay;
bool FakeSocketDelay::sReverse;

TEST_F(TransportTest, ParallelColliding) {
    FakeSocketDelay::sDelay = 10;
    FakeSocketDelay::sReverse = false;
@@ -424,13 +432,24 @@ class NullSocketFactory : public IDnsTlsSocketFactory {
};

TEST_F(TransportTest, ConnectFail) {
    NullSocketFactory factory;
    DnsTlsTransport transport(SERVER1, MARK, &factory);
    auto r = transport.query(makeSlice(QUERY)).get();
    // Failure on creating socket.
    NullSocketFactory factory1;
    DnsTlsTransport transport1(SERVER1, MARK, &factory1);
    auto r = transport1.query(makeSlice(QUERY)).get();

    EXPECT_EQ(DnsTlsTransport::Response::network_error, r.code);
    EXPECT_TRUE(r.response.empty());
    EXPECT_EQ(transport.getConnectCounter(), 1);
    EXPECT_EQ(transport1.getConnectCounter(), 1);

    // Failure on handshaking.
    FakeSocketDelay::sConnectable = false;
    FakeSocketFactory<FakeSocketDelay> factory2;
    DnsTlsTransport transport2(SERVER1, MARK, &factory2);
    r = transport2.query(makeSlice(QUERY)).get();

    EXPECT_EQ(DnsTlsTransport::Response::network_error, r.code);
    EXPECT_TRUE(r.response.empty());
    EXPECT_EQ(transport2.getConnectCounter(), 1);
}

// Simulate a socket that connects but then immediately receives a server
@@ -444,6 +463,7 @@ class FakeSocketClose : public IDnsTlsSocket {
               const Slice query ATTRIBUTE_UNUSED) override {
        return true;
    }
    bool startHandshake() override { return true; }

  private:
    std::thread mCloser;
@@ -506,6 +526,7 @@ class FakeSocketLimited : public IDnsTlsSocket {
        }
        return mQueries <= sLimit;
    }
    bool startHandshake() override { return true; }

  private:
    void sendClose() {
@@ -632,6 +653,7 @@ class FakeSocketGarbage : public IDnsTlsSocket {
        mThreads.emplace_back(&IDnsTlsSocketObserver::onResponse, mObserver, make_query(id + 1, query.size() + 2));
        return true;
    }
    bool startHandshake() override { return true; }

  private:
    std::mutex mLock;
@@ -959,12 +981,10 @@ TEST(QueryMapTest, FillHole) {
    EXPECT_FALSE(map.recordQuery(makeSlice(QUERY)));
}

class StubObserver : public IDnsTlsSocketObserver {
class MockDnsTlsSocketObserver : public IDnsTlsSocketObserver {
  public:
    bool closed = false;
    void onResponse(std::vector<uint8_t>) override {}

    void onClosed() override { closed = true; }
    MOCK_METHOD(void, onClosed, (), (override));
    MOCK_METHOD(void, onResponse, (std::vector<uint8_t>), (override));
};

TEST(DnsTlsSocketTest, SlowDestructor) {
@@ -980,23 +1000,53 @@ TEST(DnsTlsSocketTest, SlowDestructor) {
    DnsTlsServer server;
    parseServer(tls_addr, 8530, &server.ss);

    StubObserver observer;
    ASSERT_FALSE(observer.closed);
    MockDnsTlsSocketObserver observer;
    DnsTlsSessionCache cache;
    auto socket = std::make_unique<DnsTlsSocket>(server, MARK, &observer, &cache);
    ASSERT_TRUE(socket->initialize());
    ASSERT_TRUE(socket->startHandshake());

    // Test: Time the socket destructor.  This should be fast.
    auto before = std::chrono::steady_clock::now();
    EXPECT_CALL(observer, onClosed);
    socket.reset();
    auto after = std::chrono::steady_clock::now();
    auto delay = after - before;
    LOG(DEBUG) << "Shutdown took " << delay / std::chrono::nanoseconds{1} << "ns";
    EXPECT_TRUE(observer.closed);
    // Shutdown should complete in milliseconds, but if the shutdown signal is lost
    // it will wait for the timeout, which is expected to take 20seconds.
    EXPECT_LT(delay, std::chrono::seconds{5});
}

TEST(DnsTlsSocketTest, StartHandshake) {
    constexpr char tls_addr[] = "127.0.0.3";
    constexpr char tls_port[] = "8530";
    constexpr char backend_addr[] = "192.0.2.1";
    constexpr char backend_port[] = "1";

    test::DnsTlsFrontend tls(tls_addr, tls_port, backend_addr, backend_port);
    ASSERT_TRUE(tls.startServer());

    DnsTlsServer server;
    parseServer(tls_addr, 8530, &server.ss);

    // Use NiceMock to suppress the "uninteresting calls" warning.
    // (onClose will be called when running |socket|'s destructor)
    NiceMock<MockDnsTlsSocketObserver> observer;
    DnsTlsSessionCache cache;
    auto socket = std::make_unique<DnsTlsSocket>(server, MARK, &observer, &cache);

    // Call the function before the call to initialize().
    EXPECT_FALSE(socket->startHandshake());

    // Call the function after the call to initialize().
    EXPECT_TRUE(socket->initialize());
    EXPECT_TRUE(socket->startHandshake());

    // Call both of them again.
    EXPECT_FALSE(socket->initialize());
    EXPECT_FALSE(socket->startHandshake());
}

} // end of namespace net
} // end of namespace android