Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 441d9378 authored by Mike Yu's avatar Mike Yu
Browse files

Add startHandshake to IDnsTlsSocket

This is a refactor change which separates the handshake code from
DnsTlsSocket::initialize(). The plan is that initialize() will
continue running on query threads but the code for connection
handshake will run on either query threads or loop threads depending
on a flag.

Bug: 149445907
Test: cd packages/modules/DnsResolver && atest
Change-Id: I262f978230fb1a01ca7963de03b64cb439a37eec
parent 62978a84
Loading
Loading
Loading
Loading
+17 −2
Original line number Diff line number Diff line
@@ -169,6 +169,23 @@ bool DnsTlsSocket::initialize() {
    // Enable session cache
    mCache->prepareSslContext(mSslCtx.get());

    mEventFd.reset(eventfd(0, EFD_NONBLOCK | EFD_CLOEXEC));

    return true;
}

bool DnsTlsSocket::startHandshake() {
    std::lock_guard guard(mLock);
    if (!mSslCtx) {
        LOG(ERROR) << "Calling startHandshake before initializing";
        return false;
    }

    if (mLoopThread) {
        LOG(WARNING) << "The loop thread has been created. Ignore the handshake request";
        return false;
    }

    // Connect
    Status status = tcpConnect();
    if (!status.ok()) {
@@ -179,8 +196,6 @@ bool DnsTlsSocket::initialize() {
        return false;
    }

    mEventFd.reset(eventfd(0, EFD_NONBLOCK | EFD_CLOEXEC));

    // Start the I/O loop.
    mLoopThread.reset(new std::thread(&DnsTlsSocket::loop, this));

+25 −1
Original line number Diff line number Diff line
@@ -44,6 +44,27 @@ class DnsTlsSessionCache;
// or the destructor in a callback.  Doing so will result in deadlocks.
// This class may call the observer at any time after initialize(), until the destructor
// returns (but not after).
//
// Calls to IDnsTlsSocketObserver in a DnsTlsSocket life cycle:
//
//                                    START
//                                      |
//                                      v
//                            +--startHandshake()--+
//            Handshake fails |                    | Handshake succeeds
//                            |                    |
//                            |                    v
//                            |        +--------> loop --+
//                            |        |           |     |
//                            |        +-----------+     | Idle timeout
//                            |   Send/Recv queries      | onClose()
//                            |   onResponse()           |
//                            |                          |
//                            |                          |
//                            +------> END <-------------+
//
//
// TODO: Add onHandshakeFinished() for handshake results.
class DnsTlsSocket : public IDnsTlsSocket {
  public:
    DnsTlsSocket(const DnsTlsServer& server, unsigned mark,
@@ -51,11 +72,14 @@ class DnsTlsSocket : public IDnsTlsSocket {
        : mMark(mark), mServer(server), mObserver(observer), mCache(cache) {}
    ~DnsTlsSocket();

    // Creates the SSL context for this session and connect.  Returns false on failure.
    // Creates the SSL context for this session. Returns false on failure.
    // This method should be called after construction and before use of a DnsTlsSocket.
    // Only call this method once per DnsTlsSocket.
    bool initialize() EXCLUDES(mLock);

    // A blocking call to start handshaking until it finishes.
    bool startHandshake() EXCLUDES(mLock);

    // Send a query on the provided SSL socket.  |query| contains
    // the body of a query, not including the ID header. This function will typically return before
    // the query is actually sent.  If this function fails, DnsTlsSocketObserver will be
+6 −1
Original line number Diff line number Diff line
@@ -70,9 +70,14 @@ bool DnsTlsTransport::sendQuery(const DnsTlsQueryMap::Query& q) {
void DnsTlsTransport::doConnect() {
    LOG(DEBUG) << "Constructing new socket";
    mSocket = mFactory->createDnsTlsSocket(mServer, mMark, this, &mCache);

    bool success = true;
    if (mSocket.get() == nullptr || !mSocket->startHandshake()) {
        success = false;
    }
    mConnectCounter++;

    if (mSocket) {
    if (success) {
        auto queries = mQueries.getAll();
        LOG(DEBUG) << "Initialization succeeded.  Reissuing " << queries.size() << " queries.";
        for(auto& q : queries) {
+2 −0
Original line number Diff line number Diff line
@@ -40,6 +40,8 @@ class IDnsTlsSocket {
    // notified that the socket is closed.
    // Note that a true return value indicates successful sending, not receipt of a response.
    virtual bool query(uint16_t id, const netdutils::Slice query) = 0;

    virtual bool startHandshake() = 0;
};

}  // end of namespace net
+60 −10
Original line number Diff line number Diff line
@@ -134,6 +134,7 @@ class FakeSocketEcho : public IDnsTlsSocket {
        std::thread(&IDnsTlsSocketObserver::onResponse, mObserver, make_echo(id, query)).detach();
        return true;
    }
    bool startHandshake() override { return true; }

  private:
    IDnsTlsSocketObserver* const mObserver;
@@ -169,6 +170,7 @@ class FakeSocketId : public IDnsTlsSocket {
        std::thread(&IDnsTlsSocketObserver::onResponse, mObserver, response).detach();
        return true;
    }
    bool startHandshake() override { return true; }

  private:
    IDnsTlsSocketObserver* const mObserver;
@@ -216,9 +218,15 @@ TEST_F(TransportTest, RacingQueries_10000) {
class FakeSocketDelay : public IDnsTlsSocket {
  public:
    explicit FakeSocketDelay(IDnsTlsSocketObserver* observer) : mObserver(observer) {}
    ~FakeSocketDelay() { std::lock_guard guard(mLock); }
    static size_t sDelay;
    static bool sReverse;
    ~FakeSocketDelay() {
        std::lock_guard guard(mLock);
        sDelay = 1;
        sReverse = false;
        sConnectable = true;
    }
    inline static size_t sDelay = 1;
    inline static bool sReverse = false;
    inline static bool sConnectable = true;

    bool query(uint16_t id, const Slice query) override {
        LOG(DEBUG) << "FakeSocketDelay got query with ID " << int(id);
@@ -236,6 +244,7 @@ class FakeSocketDelay : public IDnsTlsSocket {
        }
        return true;
    }
    bool startHandshake() override { return sConnectable; }

  private:
    void sendResponses() {
@@ -256,9 +265,6 @@ class FakeSocketDelay : public IDnsTlsSocket {
    std::vector<bytevec> mResponses GUARDED_BY(mLock);
};

size_t FakeSocketDelay::sDelay;
bool FakeSocketDelay::sReverse;

TEST_F(TransportTest, ParallelColliding) {
    FakeSocketDelay::sDelay = 10;
    FakeSocketDelay::sReverse = false;
@@ -424,13 +430,24 @@ class NullSocketFactory : public IDnsTlsSocketFactory {
};

TEST_F(TransportTest, ConnectFail) {
    NullSocketFactory factory;
    DnsTlsTransport transport(SERVER1, MARK, &factory);
    auto r = transport.query(makeSlice(QUERY)).get();
    // Failure on creating socket.
    NullSocketFactory factory1;
    DnsTlsTransport transport1(SERVER1, MARK, &factory1);
    auto r = transport1.query(makeSlice(QUERY)).get();

    EXPECT_EQ(DnsTlsTransport::Response::network_error, r.code);
    EXPECT_TRUE(r.response.empty());
    EXPECT_EQ(transport.getConnectCounter(), 1);
    EXPECT_EQ(transport1.getConnectCounter(), 1);

    // Failure on handshaking.
    FakeSocketDelay::sConnectable = false;
    FakeSocketFactory<FakeSocketDelay> factory2;
    DnsTlsTransport transport2(SERVER1, MARK, &factory2);
    r = transport2.query(makeSlice(QUERY)).get();

    EXPECT_EQ(DnsTlsTransport::Response::network_error, r.code);
    EXPECT_TRUE(r.response.empty());
    EXPECT_EQ(transport2.getConnectCounter(), 1);
}

// Simulate a socket that connects but then immediately receives a server
@@ -444,6 +461,7 @@ class FakeSocketClose : public IDnsTlsSocket {
               const Slice query ATTRIBUTE_UNUSED) override {
        return true;
    }
    bool startHandshake() override { return true; }

  private:
    std::thread mCloser;
@@ -506,6 +524,7 @@ class FakeSocketLimited : public IDnsTlsSocket {
        }
        return mQueries <= sLimit;
    }
    bool startHandshake() override { return true; }

  private:
    void sendClose() {
@@ -632,6 +651,7 @@ class FakeSocketGarbage : public IDnsTlsSocket {
        mThreads.emplace_back(&IDnsTlsSocketObserver::onResponse, mObserver, make_query(id + 1, query.size() + 2));
        return true;
    }
    bool startHandshake() override { return true; }

  private:
    std::mutex mLock;
@@ -985,6 +1005,7 @@ TEST(DnsTlsSocketTest, SlowDestructor) {
    DnsTlsSessionCache cache;
    auto socket = std::make_unique<DnsTlsSocket>(server, MARK, &observer, &cache);
    ASSERT_TRUE(socket->initialize());
    ASSERT_TRUE(socket->startHandshake());

    // Test: Time the socket destructor.  This should be fast.
    auto before = std::chrono::steady_clock::now();
@@ -998,5 +1019,34 @@ TEST(DnsTlsSocketTest, SlowDestructor) {
    EXPECT_LT(delay, std::chrono::seconds{5});
}

TEST(DnsTlsSocketTest, StartHandshake) {
    constexpr char tls_addr[] = "127.0.0.3";
    constexpr char tls_port[] = "8530";
    constexpr char backend_addr[] = "192.0.2.1";
    constexpr char backend_port[] = "1";

    test::DnsTlsFrontend tls(tls_addr, tls_port, backend_addr, backend_port);
    ASSERT_TRUE(tls.startServer());

    DnsTlsServer server;
    parseServer(tls_addr, 8530, &server.ss);

    StubObserver observer;
    ASSERT_FALSE(observer.closed);
    DnsTlsSessionCache cache;
    auto socket = std::make_unique<DnsTlsSocket>(server, MARK, &observer, &cache);

    // Call the function before the call to initialize().
    EXPECT_FALSE(socket->startHandshake());

    // Call the function after the call to initialize().
    EXPECT_TRUE(socket->initialize());
    EXPECT_TRUE(socket->startHandshake());

    // Call both of them again.
    EXPECT_FALSE(socket->initialize());
    EXPECT_FALSE(socket->startHandshake());
}

} // end of namespace net
} // end of namespace android