Loading DnsTlsSocket.cpp +37 −8 Original line number Diff line number Diff line Loading @@ -81,7 +81,7 @@ Status DnsTlsSocket::tcpConnect() { mSslFd.reset(socket(mServer.ss.ss_family, type, mServer.protocol)); if (mSslFd.get() == -1) { LOG(ERROR) << "Failed to create socket"; PLOG(ERROR) << "Failed to create socket"; return Status(errno); } Loading @@ -89,9 +89,10 @@ Status DnsTlsSocket::tcpConnect() { const socklen_t len = sizeof(mMark); if (setsockopt(mSslFd.get(), SOL_SOCKET, SO_MARK, &mMark, len) == -1) { LOG(ERROR) << "Failed to set socket mark"; const int err = errno; PLOG(ERROR) << "Failed to set socket mark"; mSslFd.reset(); return Status(errno); return Status(err); } const Status tfo = enableSockopt(mSslFd.get(), SOL_TCP, TCP_FASTOPEN_CONNECT); Loading @@ -105,9 +106,10 @@ Status DnsTlsSocket::tcpConnect() { if (connect(mSslFd.get(), reinterpret_cast<const struct sockaddr *>(&mServer.ss), sizeof(mServer.ss)) != 0 && errno != EINPROGRESS) { LOG(DEBUG) << "Socket failed to connect"; const int err = errno; PLOG(ERROR) << "Socket failed to connect"; mSslFd.reset(); return Status(errno); return Status(err); } return netdutils::status::ok; Loading Loading @@ -169,18 +171,33 @@ bool DnsTlsSocket::initialize() { // Enable session cache mCache->prepareSslContext(mSslCtx.get()); mEventFd.reset(eventfd(0, EFD_NONBLOCK | EFD_CLOEXEC)); transitionState(State::UNINITIALIZED, State::INITIALIZED); return true; } bool DnsTlsSocket::startHandshake() { std::lock_guard guard(mLock); if (mState != State::INITIALIZED) { LOG(ERROR) << "Calling startHandshake in unexpected state " << static_cast<int>(mState); return false; } transitionState(State::INITIALIZED, State::CONNECTING); // Connect Status status = tcpConnect(); if (!status.ok()) { transitionState(State::CONNECTING, State::WAIT_FOR_DELETE); return false; } mSsl = sslConnect(mSslFd.get()); if (!mSsl) { transitionState(State::CONNECTING, State::WAIT_FOR_DELETE); return false; } mEventFd.reset(eventfd(0, EFD_NONBLOCK | EFD_CLOEXEC)); // Start the I/O loop. mLoopThread.reset(new std::thread(&DnsTlsSocket::loop, this)); Loading Loading @@ -309,7 +326,9 @@ void DnsTlsSocket::loop() { std::deque<std::vector<uint8_t>> q; const int timeout_msecs = DnsTlsSocket::kIdleTimeout.count() * 1000; transitionState(State::CONNECTING, State::CONNECTED); setThreadName(StringPrintf("TlsListen_%u", mMark & 0xffff).c_str()); while (true) { // poll() ignores negative fds struct pollfd fds[2] = { { .fd = -1 }, { .fd = -1 } }; Loading @@ -336,7 +355,7 @@ void DnsTlsSocket::loop() { break; } if (s < 0) { LOG(DEBUG) << "Poll failed: " << errno; PLOG(DEBUG) << "Poll failed"; break; } if (fds[SSLFD].revents & (POLLIN | POLLERR | POLLHUP)) { Loading Loading @@ -379,6 +398,7 @@ void DnsTlsSocket::loop() { sslDisconnect(); LOG(DEBUG) << "Calling onClosed"; mObserver->onClosed(); transitionState(State::CONNECTED, State::WAIT_FOR_DELETE); LOG(DEBUG) << "Ending loop"; } Loading Loading @@ -441,6 +461,15 @@ bool DnsTlsSocket::incrementEventFd(const int64_t count) { return true; } void DnsTlsSocket::transitionState(State from, State to) { if (mState != from) { LOG(WARNING) << "BUG: transitioning from an unexpected state " << static_cast<int>(mState) << ", expect: from " << static_cast<int>(from) << " to " << static_cast<int>(to); } mState = to; } // Read exactly len bytes into buffer or fail with an SSL error code int DnsTlsSocket::sslRead(const Slice buffer, bool wait) { size_t remaining = buffer.size(); Loading DnsTlsSocket.h +40 −1 Original line number Diff line number Diff line Loading @@ -44,18 +44,53 @@ class DnsTlsSessionCache; // or the destructor in a callback. Doing so will result in deadlocks. // This class may call the observer at any time after initialize(), until the destructor // returns (but not after). // // Calls to IDnsTlsSocketObserver in a DnsTlsSocket life cycle: // // UNINITIALIZED // | // v // INITIALIZED // | // v // +----CONNECTING------+ // Handshake fails | | Handshake succeeds // | | // | v // | +---> CONNECTED --+ // | | | | // | +-----------+ | Idle timeout // | Send/Recv queries | onClose() // | onResponse() | // | | // | | // +--> WAIT_FOR_DELETE <-----+ // // // TODO: Add onHandshakeFinished() for handshake results. class DnsTlsSocket : public IDnsTlsSocket { public: enum class State { UNINITIALIZED, INITIALIZED, CONNECTING, CONNECTED, WAIT_FOR_DELETE, }; DnsTlsSocket(const DnsTlsServer& server, unsigned mark, IDnsTlsSocketObserver* _Nonnull observer, DnsTlsSessionCache* _Nonnull cache) : mMark(mark), mServer(server), mObserver(observer), mCache(cache) {} ~DnsTlsSocket(); // Creates the SSL context for this session and connect. Returns false on failure. // Creates the SSL context for this session. Returns false on failure. // This method should be called after construction and before use of a DnsTlsSocket. // Only call this method once per DnsTlsSocket. bool initialize() EXCLUDES(mLock); // A blocking call to start handshaking until it finishes. bool startHandshake() EXCLUDES(mLock); // Send a query on the provided SSL socket. |query| contains // the body of a query, not including the ID header. This function will typically return before // the query is actually sent. If this function fails, DnsTlsSocketObserver will be Loading Loading @@ -112,6 +147,9 @@ class DnsTlsSocket : public IDnsTlsSocket { // the loop thread by incrementing mEventFd. loop() reads items off the queue. LockedQueue<std::vector<uint8_t>> mQueue; // Transition the state from expected state |from| to new state |to|. void transitionState(State from, State to) REQUIRES(mLock); // eventfd socket used for notifying the SSL thread when queries are ready to send. // This socket acts similarly to an atomic counter, incremented by query() and cleared // by loop(). We have to use a socket because the SSL thread needs to wait in poll() Loading @@ -131,6 +169,7 @@ class DnsTlsSocket : public IDnsTlsSocket { const DnsTlsServer mServer; IDnsTlsSocketObserver* _Nonnull const mObserver; DnsTlsSessionCache* _Nonnull const mCache; State mState GUARDED_BY(mLock) = State::UNINITIALIZED; }; } // end of namespace net Loading DnsTlsTransport.cpp +6 −1 Original line number Diff line number Diff line Loading @@ -70,9 +70,14 @@ bool DnsTlsTransport::sendQuery(const DnsTlsQueryMap::Query& q) { void DnsTlsTransport::doConnect() { LOG(DEBUG) << "Constructing new socket"; mSocket = mFactory->createDnsTlsSocket(mServer, mMark, this, &mCache); bool success = true; if (mSocket.get() == nullptr || !mSocket->startHandshake()) { success = false; } mConnectCounter++; if (mSocket) { if (success) { auto queries = mQueries.getAll(); LOG(DEBUG) << "Initialization succeeded. Reissuing " << queries.size() << " queries."; for(auto& q : queries) { Loading IDnsTlsSocket.h +2 −0 Original line number Diff line number Diff line Loading @@ -40,6 +40,8 @@ class IDnsTlsSocket { // notified that the socket is closed. // Note that a true return value indicates successful sending, not receipt of a response. virtual bool query(uint16_t id, const netdutils::Slice query) = 0; virtual bool startHandshake() = 0; }; } // end of namespace net Loading resolv_tls_unit_test.cpp +69 −19 Original line number Diff line number Diff line Loading @@ -22,6 +22,7 @@ #include <android-base/logging.h> #include <android-base/macros.h> #include <gmock/gmock.h> #include <gtest/gtest.h> #include <netdutils/Slice.h> Loading @@ -39,8 +40,9 @@ namespace android { namespace net { using netdutils::Slice; using netdutils::makeSlice; using netdutils::Slice; using ::testing::NiceMock; typedef std::vector<uint8_t> bytevec; Loading Loading @@ -134,6 +136,7 @@ class FakeSocketEcho : public IDnsTlsSocket { std::thread(&IDnsTlsSocketObserver::onResponse, mObserver, make_echo(id, query)).detach(); return true; } bool startHandshake() override { return true; } private: IDnsTlsSocketObserver* const mObserver; Loading Loading @@ -169,6 +172,7 @@ class FakeSocketId : public IDnsTlsSocket { std::thread(&IDnsTlsSocketObserver::onResponse, mObserver, response).detach(); return true; } bool startHandshake() override { return true; } private: IDnsTlsSocketObserver* const mObserver; Loading Loading @@ -216,9 +220,15 @@ TEST_F(TransportTest, RacingQueries_10000) { class FakeSocketDelay : public IDnsTlsSocket { public: explicit FakeSocketDelay(IDnsTlsSocketObserver* observer) : mObserver(observer) {} ~FakeSocketDelay() { std::lock_guard guard(mLock); } static size_t sDelay; static bool sReverse; ~FakeSocketDelay() { std::lock_guard guard(mLock); sDelay = 1; sReverse = false; sConnectable = true; } inline static size_t sDelay = 1; inline static bool sReverse = false; inline static bool sConnectable = true; bool query(uint16_t id, const Slice query) override { LOG(DEBUG) << "FakeSocketDelay got query with ID " << int(id); Loading @@ -236,6 +246,7 @@ class FakeSocketDelay : public IDnsTlsSocket { } return true; } bool startHandshake() override { return sConnectable; } private: void sendResponses() { Loading @@ -256,9 +267,6 @@ class FakeSocketDelay : public IDnsTlsSocket { std::vector<bytevec> mResponses GUARDED_BY(mLock); }; size_t FakeSocketDelay::sDelay; bool FakeSocketDelay::sReverse; TEST_F(TransportTest, ParallelColliding) { FakeSocketDelay::sDelay = 10; FakeSocketDelay::sReverse = false; Loading Loading @@ -424,13 +432,24 @@ class NullSocketFactory : public IDnsTlsSocketFactory { }; TEST_F(TransportTest, ConnectFail) { NullSocketFactory factory; DnsTlsTransport transport(SERVER1, MARK, &factory); auto r = transport.query(makeSlice(QUERY)).get(); // Failure on creating socket. NullSocketFactory factory1; DnsTlsTransport transport1(SERVER1, MARK, &factory1); auto r = transport1.query(makeSlice(QUERY)).get(); EXPECT_EQ(DnsTlsTransport::Response::network_error, r.code); EXPECT_TRUE(r.response.empty()); EXPECT_EQ(transport.getConnectCounter(), 1); EXPECT_EQ(transport1.getConnectCounter(), 1); // Failure on handshaking. FakeSocketDelay::sConnectable = false; FakeSocketFactory<FakeSocketDelay> factory2; DnsTlsTransport transport2(SERVER1, MARK, &factory2); r = transport2.query(makeSlice(QUERY)).get(); EXPECT_EQ(DnsTlsTransport::Response::network_error, r.code); EXPECT_TRUE(r.response.empty()); EXPECT_EQ(transport2.getConnectCounter(), 1); } // Simulate a socket that connects but then immediately receives a server Loading @@ -444,6 +463,7 @@ class FakeSocketClose : public IDnsTlsSocket { const Slice query ATTRIBUTE_UNUSED) override { return true; } bool startHandshake() override { return true; } private: std::thread mCloser; Loading Loading @@ -506,6 +526,7 @@ class FakeSocketLimited : public IDnsTlsSocket { } return mQueries <= sLimit; } bool startHandshake() override { return true; } private: void sendClose() { Loading Loading @@ -632,6 +653,7 @@ class FakeSocketGarbage : public IDnsTlsSocket { mThreads.emplace_back(&IDnsTlsSocketObserver::onResponse, mObserver, make_query(id + 1, query.size() + 2)); return true; } bool startHandshake() override { return true; } private: std::mutex mLock; Loading Loading @@ -959,12 +981,10 @@ TEST(QueryMapTest, FillHole) { EXPECT_FALSE(map.recordQuery(makeSlice(QUERY))); } class StubObserver : public IDnsTlsSocketObserver { class MockDnsTlsSocketObserver : public IDnsTlsSocketObserver { public: bool closed = false; void onResponse(std::vector<uint8_t>) override {} void onClosed() override { closed = true; } MOCK_METHOD(void, onClosed, (), (override)); MOCK_METHOD(void, onResponse, (std::vector<uint8_t>), (override)); }; TEST(DnsTlsSocketTest, SlowDestructor) { Loading @@ -980,23 +1000,53 @@ TEST(DnsTlsSocketTest, SlowDestructor) { DnsTlsServer server; parseServer(tls_addr, 8530, &server.ss); StubObserver observer; ASSERT_FALSE(observer.closed); MockDnsTlsSocketObserver observer; DnsTlsSessionCache cache; auto socket = std::make_unique<DnsTlsSocket>(server, MARK, &observer, &cache); ASSERT_TRUE(socket->initialize()); ASSERT_TRUE(socket->startHandshake()); // Test: Time the socket destructor. This should be fast. auto before = std::chrono::steady_clock::now(); EXPECT_CALL(observer, onClosed); socket.reset(); auto after = std::chrono::steady_clock::now(); auto delay = after - before; LOG(DEBUG) << "Shutdown took " << delay / std::chrono::nanoseconds{1} << "ns"; EXPECT_TRUE(observer.closed); // Shutdown should complete in milliseconds, but if the shutdown signal is lost // it will wait for the timeout, which is expected to take 20seconds. EXPECT_LT(delay, std::chrono::seconds{5}); } TEST(DnsTlsSocketTest, StartHandshake) { constexpr char tls_addr[] = "127.0.0.3"; constexpr char tls_port[] = "8530"; constexpr char backend_addr[] = "192.0.2.1"; constexpr char backend_port[] = "1"; test::DnsTlsFrontend tls(tls_addr, tls_port, backend_addr, backend_port); ASSERT_TRUE(tls.startServer()); DnsTlsServer server; parseServer(tls_addr, 8530, &server.ss); // Use NiceMock to suppress the "uninteresting calls" warning. // (onClose will be called when running |socket|'s destructor) NiceMock<MockDnsTlsSocketObserver> observer; DnsTlsSessionCache cache; auto socket = std::make_unique<DnsTlsSocket>(server, MARK, &observer, &cache); // Call the function before the call to initialize(). EXPECT_FALSE(socket->startHandshake()); // Call the function after the call to initialize(). EXPECT_TRUE(socket->initialize()); EXPECT_TRUE(socket->startHandshake()); // Call both of them again. EXPECT_FALSE(socket->initialize()); EXPECT_FALSE(socket->startHandshake()); } } // end of namespace net } // end of namespace android Loading
DnsTlsSocket.cpp +37 −8 Original line number Diff line number Diff line Loading @@ -81,7 +81,7 @@ Status DnsTlsSocket::tcpConnect() { mSslFd.reset(socket(mServer.ss.ss_family, type, mServer.protocol)); if (mSslFd.get() == -1) { LOG(ERROR) << "Failed to create socket"; PLOG(ERROR) << "Failed to create socket"; return Status(errno); } Loading @@ -89,9 +89,10 @@ Status DnsTlsSocket::tcpConnect() { const socklen_t len = sizeof(mMark); if (setsockopt(mSslFd.get(), SOL_SOCKET, SO_MARK, &mMark, len) == -1) { LOG(ERROR) << "Failed to set socket mark"; const int err = errno; PLOG(ERROR) << "Failed to set socket mark"; mSslFd.reset(); return Status(errno); return Status(err); } const Status tfo = enableSockopt(mSslFd.get(), SOL_TCP, TCP_FASTOPEN_CONNECT); Loading @@ -105,9 +106,10 @@ Status DnsTlsSocket::tcpConnect() { if (connect(mSslFd.get(), reinterpret_cast<const struct sockaddr *>(&mServer.ss), sizeof(mServer.ss)) != 0 && errno != EINPROGRESS) { LOG(DEBUG) << "Socket failed to connect"; const int err = errno; PLOG(ERROR) << "Socket failed to connect"; mSslFd.reset(); return Status(errno); return Status(err); } return netdutils::status::ok; Loading Loading @@ -169,18 +171,33 @@ bool DnsTlsSocket::initialize() { // Enable session cache mCache->prepareSslContext(mSslCtx.get()); mEventFd.reset(eventfd(0, EFD_NONBLOCK | EFD_CLOEXEC)); transitionState(State::UNINITIALIZED, State::INITIALIZED); return true; } bool DnsTlsSocket::startHandshake() { std::lock_guard guard(mLock); if (mState != State::INITIALIZED) { LOG(ERROR) << "Calling startHandshake in unexpected state " << static_cast<int>(mState); return false; } transitionState(State::INITIALIZED, State::CONNECTING); // Connect Status status = tcpConnect(); if (!status.ok()) { transitionState(State::CONNECTING, State::WAIT_FOR_DELETE); return false; } mSsl = sslConnect(mSslFd.get()); if (!mSsl) { transitionState(State::CONNECTING, State::WAIT_FOR_DELETE); return false; } mEventFd.reset(eventfd(0, EFD_NONBLOCK | EFD_CLOEXEC)); // Start the I/O loop. mLoopThread.reset(new std::thread(&DnsTlsSocket::loop, this)); Loading Loading @@ -309,7 +326,9 @@ void DnsTlsSocket::loop() { std::deque<std::vector<uint8_t>> q; const int timeout_msecs = DnsTlsSocket::kIdleTimeout.count() * 1000; transitionState(State::CONNECTING, State::CONNECTED); setThreadName(StringPrintf("TlsListen_%u", mMark & 0xffff).c_str()); while (true) { // poll() ignores negative fds struct pollfd fds[2] = { { .fd = -1 }, { .fd = -1 } }; Loading @@ -336,7 +355,7 @@ void DnsTlsSocket::loop() { break; } if (s < 0) { LOG(DEBUG) << "Poll failed: " << errno; PLOG(DEBUG) << "Poll failed"; break; } if (fds[SSLFD].revents & (POLLIN | POLLERR | POLLHUP)) { Loading Loading @@ -379,6 +398,7 @@ void DnsTlsSocket::loop() { sslDisconnect(); LOG(DEBUG) << "Calling onClosed"; mObserver->onClosed(); transitionState(State::CONNECTED, State::WAIT_FOR_DELETE); LOG(DEBUG) << "Ending loop"; } Loading Loading @@ -441,6 +461,15 @@ bool DnsTlsSocket::incrementEventFd(const int64_t count) { return true; } void DnsTlsSocket::transitionState(State from, State to) { if (mState != from) { LOG(WARNING) << "BUG: transitioning from an unexpected state " << static_cast<int>(mState) << ", expect: from " << static_cast<int>(from) << " to " << static_cast<int>(to); } mState = to; } // Read exactly len bytes into buffer or fail with an SSL error code int DnsTlsSocket::sslRead(const Slice buffer, bool wait) { size_t remaining = buffer.size(); Loading
DnsTlsSocket.h +40 −1 Original line number Diff line number Diff line Loading @@ -44,18 +44,53 @@ class DnsTlsSessionCache; // or the destructor in a callback. Doing so will result in deadlocks. // This class may call the observer at any time after initialize(), until the destructor // returns (but not after). // // Calls to IDnsTlsSocketObserver in a DnsTlsSocket life cycle: // // UNINITIALIZED // | // v // INITIALIZED // | // v // +----CONNECTING------+ // Handshake fails | | Handshake succeeds // | | // | v // | +---> CONNECTED --+ // | | | | // | +-----------+ | Idle timeout // | Send/Recv queries | onClose() // | onResponse() | // | | // | | // +--> WAIT_FOR_DELETE <-----+ // // // TODO: Add onHandshakeFinished() for handshake results. class DnsTlsSocket : public IDnsTlsSocket { public: enum class State { UNINITIALIZED, INITIALIZED, CONNECTING, CONNECTED, WAIT_FOR_DELETE, }; DnsTlsSocket(const DnsTlsServer& server, unsigned mark, IDnsTlsSocketObserver* _Nonnull observer, DnsTlsSessionCache* _Nonnull cache) : mMark(mark), mServer(server), mObserver(observer), mCache(cache) {} ~DnsTlsSocket(); // Creates the SSL context for this session and connect. Returns false on failure. // Creates the SSL context for this session. Returns false on failure. // This method should be called after construction and before use of a DnsTlsSocket. // Only call this method once per DnsTlsSocket. bool initialize() EXCLUDES(mLock); // A blocking call to start handshaking until it finishes. bool startHandshake() EXCLUDES(mLock); // Send a query on the provided SSL socket. |query| contains // the body of a query, not including the ID header. This function will typically return before // the query is actually sent. If this function fails, DnsTlsSocketObserver will be Loading Loading @@ -112,6 +147,9 @@ class DnsTlsSocket : public IDnsTlsSocket { // the loop thread by incrementing mEventFd. loop() reads items off the queue. LockedQueue<std::vector<uint8_t>> mQueue; // Transition the state from expected state |from| to new state |to|. void transitionState(State from, State to) REQUIRES(mLock); // eventfd socket used for notifying the SSL thread when queries are ready to send. // This socket acts similarly to an atomic counter, incremented by query() and cleared // by loop(). We have to use a socket because the SSL thread needs to wait in poll() Loading @@ -131,6 +169,7 @@ class DnsTlsSocket : public IDnsTlsSocket { const DnsTlsServer mServer; IDnsTlsSocketObserver* _Nonnull const mObserver; DnsTlsSessionCache* _Nonnull const mCache; State mState GUARDED_BY(mLock) = State::UNINITIALIZED; }; } // end of namespace net Loading
DnsTlsTransport.cpp +6 −1 Original line number Diff line number Diff line Loading @@ -70,9 +70,14 @@ bool DnsTlsTransport::sendQuery(const DnsTlsQueryMap::Query& q) { void DnsTlsTransport::doConnect() { LOG(DEBUG) << "Constructing new socket"; mSocket = mFactory->createDnsTlsSocket(mServer, mMark, this, &mCache); bool success = true; if (mSocket.get() == nullptr || !mSocket->startHandshake()) { success = false; } mConnectCounter++; if (mSocket) { if (success) { auto queries = mQueries.getAll(); LOG(DEBUG) << "Initialization succeeded. Reissuing " << queries.size() << " queries."; for(auto& q : queries) { Loading
IDnsTlsSocket.h +2 −0 Original line number Diff line number Diff line Loading @@ -40,6 +40,8 @@ class IDnsTlsSocket { // notified that the socket is closed. // Note that a true return value indicates successful sending, not receipt of a response. virtual bool query(uint16_t id, const netdutils::Slice query) = 0; virtual bool startHandshake() = 0; }; } // end of namespace net Loading
resolv_tls_unit_test.cpp +69 −19 Original line number Diff line number Diff line Loading @@ -22,6 +22,7 @@ #include <android-base/logging.h> #include <android-base/macros.h> #include <gmock/gmock.h> #include <gtest/gtest.h> #include <netdutils/Slice.h> Loading @@ -39,8 +40,9 @@ namespace android { namespace net { using netdutils::Slice; using netdutils::makeSlice; using netdutils::Slice; using ::testing::NiceMock; typedef std::vector<uint8_t> bytevec; Loading Loading @@ -134,6 +136,7 @@ class FakeSocketEcho : public IDnsTlsSocket { std::thread(&IDnsTlsSocketObserver::onResponse, mObserver, make_echo(id, query)).detach(); return true; } bool startHandshake() override { return true; } private: IDnsTlsSocketObserver* const mObserver; Loading Loading @@ -169,6 +172,7 @@ class FakeSocketId : public IDnsTlsSocket { std::thread(&IDnsTlsSocketObserver::onResponse, mObserver, response).detach(); return true; } bool startHandshake() override { return true; } private: IDnsTlsSocketObserver* const mObserver; Loading Loading @@ -216,9 +220,15 @@ TEST_F(TransportTest, RacingQueries_10000) { class FakeSocketDelay : public IDnsTlsSocket { public: explicit FakeSocketDelay(IDnsTlsSocketObserver* observer) : mObserver(observer) {} ~FakeSocketDelay() { std::lock_guard guard(mLock); } static size_t sDelay; static bool sReverse; ~FakeSocketDelay() { std::lock_guard guard(mLock); sDelay = 1; sReverse = false; sConnectable = true; } inline static size_t sDelay = 1; inline static bool sReverse = false; inline static bool sConnectable = true; bool query(uint16_t id, const Slice query) override { LOG(DEBUG) << "FakeSocketDelay got query with ID " << int(id); Loading @@ -236,6 +246,7 @@ class FakeSocketDelay : public IDnsTlsSocket { } return true; } bool startHandshake() override { return sConnectable; } private: void sendResponses() { Loading @@ -256,9 +267,6 @@ class FakeSocketDelay : public IDnsTlsSocket { std::vector<bytevec> mResponses GUARDED_BY(mLock); }; size_t FakeSocketDelay::sDelay; bool FakeSocketDelay::sReverse; TEST_F(TransportTest, ParallelColliding) { FakeSocketDelay::sDelay = 10; FakeSocketDelay::sReverse = false; Loading Loading @@ -424,13 +432,24 @@ class NullSocketFactory : public IDnsTlsSocketFactory { }; TEST_F(TransportTest, ConnectFail) { NullSocketFactory factory; DnsTlsTransport transport(SERVER1, MARK, &factory); auto r = transport.query(makeSlice(QUERY)).get(); // Failure on creating socket. NullSocketFactory factory1; DnsTlsTransport transport1(SERVER1, MARK, &factory1); auto r = transport1.query(makeSlice(QUERY)).get(); EXPECT_EQ(DnsTlsTransport::Response::network_error, r.code); EXPECT_TRUE(r.response.empty()); EXPECT_EQ(transport.getConnectCounter(), 1); EXPECT_EQ(transport1.getConnectCounter(), 1); // Failure on handshaking. FakeSocketDelay::sConnectable = false; FakeSocketFactory<FakeSocketDelay> factory2; DnsTlsTransport transport2(SERVER1, MARK, &factory2); r = transport2.query(makeSlice(QUERY)).get(); EXPECT_EQ(DnsTlsTransport::Response::network_error, r.code); EXPECT_TRUE(r.response.empty()); EXPECT_EQ(transport2.getConnectCounter(), 1); } // Simulate a socket that connects but then immediately receives a server Loading @@ -444,6 +463,7 @@ class FakeSocketClose : public IDnsTlsSocket { const Slice query ATTRIBUTE_UNUSED) override { return true; } bool startHandshake() override { return true; } private: std::thread mCloser; Loading Loading @@ -506,6 +526,7 @@ class FakeSocketLimited : public IDnsTlsSocket { } return mQueries <= sLimit; } bool startHandshake() override { return true; } private: void sendClose() { Loading Loading @@ -632,6 +653,7 @@ class FakeSocketGarbage : public IDnsTlsSocket { mThreads.emplace_back(&IDnsTlsSocketObserver::onResponse, mObserver, make_query(id + 1, query.size() + 2)); return true; } bool startHandshake() override { return true; } private: std::mutex mLock; Loading Loading @@ -959,12 +981,10 @@ TEST(QueryMapTest, FillHole) { EXPECT_FALSE(map.recordQuery(makeSlice(QUERY))); } class StubObserver : public IDnsTlsSocketObserver { class MockDnsTlsSocketObserver : public IDnsTlsSocketObserver { public: bool closed = false; void onResponse(std::vector<uint8_t>) override {} void onClosed() override { closed = true; } MOCK_METHOD(void, onClosed, (), (override)); MOCK_METHOD(void, onResponse, (std::vector<uint8_t>), (override)); }; TEST(DnsTlsSocketTest, SlowDestructor) { Loading @@ -980,23 +1000,53 @@ TEST(DnsTlsSocketTest, SlowDestructor) { DnsTlsServer server; parseServer(tls_addr, 8530, &server.ss); StubObserver observer; ASSERT_FALSE(observer.closed); MockDnsTlsSocketObserver observer; DnsTlsSessionCache cache; auto socket = std::make_unique<DnsTlsSocket>(server, MARK, &observer, &cache); ASSERT_TRUE(socket->initialize()); ASSERT_TRUE(socket->startHandshake()); // Test: Time the socket destructor. This should be fast. auto before = std::chrono::steady_clock::now(); EXPECT_CALL(observer, onClosed); socket.reset(); auto after = std::chrono::steady_clock::now(); auto delay = after - before; LOG(DEBUG) << "Shutdown took " << delay / std::chrono::nanoseconds{1} << "ns"; EXPECT_TRUE(observer.closed); // Shutdown should complete in milliseconds, but if the shutdown signal is lost // it will wait for the timeout, which is expected to take 20seconds. EXPECT_LT(delay, std::chrono::seconds{5}); } TEST(DnsTlsSocketTest, StartHandshake) { constexpr char tls_addr[] = "127.0.0.3"; constexpr char tls_port[] = "8530"; constexpr char backend_addr[] = "192.0.2.1"; constexpr char backend_port[] = "1"; test::DnsTlsFrontend tls(tls_addr, tls_port, backend_addr, backend_port); ASSERT_TRUE(tls.startServer()); DnsTlsServer server; parseServer(tls_addr, 8530, &server.ss); // Use NiceMock to suppress the "uninteresting calls" warning. // (onClose will be called when running |socket|'s destructor) NiceMock<MockDnsTlsSocketObserver> observer; DnsTlsSessionCache cache; auto socket = std::make_unique<DnsTlsSocket>(server, MARK, &observer, &cache); // Call the function before the call to initialize(). EXPECT_FALSE(socket->startHandshake()); // Call the function after the call to initialize(). EXPECT_TRUE(socket->initialize()); EXPECT_TRUE(socket->startHandshake()); // Call both of them again. EXPECT_FALSE(socket->initialize()); EXPECT_FALSE(socket->startHandshake()); } } // end of namespace net } // end of namespace android