Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 2b5e9544 authored by Mike Yu's avatar Mike Yu Committed by Gerrit Code Review
Browse files

Merge "Support sending validation request to PrivateDnsConfiguration"

parents c91ba94a e60ab415
Loading
Loading
Loading
Loading
+6 −0
Original line number Diff line number Diff line
@@ -74,6 +74,12 @@ struct DnsTlsServer {
    Validation validationState() const { return mValidation; }
    void setValidationState(Validation val) { mValidation = val; }

    // The socket mark used for validation.
    // Note that the mark of a connection to which the DnsResolver sends app's DNS requests can
    // be different.
    // TODO: make it const.
    uint32_t mark = 0;

    // Return whether or not the server can be used for a network. It depends on
    // the resolver configuration.
    bool active() const { return mActive; }
+32 −0
Original line number Diff line number Diff line
@@ -70,6 +70,7 @@ int PrivateDnsConfiguration::set(int32_t netId, uint32_t mark,
        DnsTlsServer server(parsed);
        server.name = name;
        server.certificate = caCert;
        server.mark = mark;
        tmp[ServerIdentity(server)] = server;
    }

@@ -140,6 +141,37 @@ void PrivateDnsConfiguration::clear(unsigned netId) {
    mPrivateDnsTransports.erase(netId);
}

bool PrivateDnsConfiguration::requestValidation(unsigned netId, const DnsTlsServer& server,
                                                uint32_t mark) {
    std::lock_guard guard(mPrivateDnsLock);
    auto netPair = mPrivateDnsTransports.find(netId);
    if (netPair == mPrivateDnsTransports.end()) {
        return false;
    }

    auto& tracker = netPair->second;
    const ServerIdentity identity = ServerIdentity(server);
    auto it = tracker.find(identity);
    if (it == tracker.end()) {
        return false;
    }

    const DnsTlsServer& target = it->second;

    if (!target.active()) return false;

    if (target.validationState() != Validation::success) return false;

    // Don't run the validation if |mark| (from android_net_context.dns_mark) is different.
    // This is to protect validation from running on unexpected marks.
    // Validation should be associated with a mark gotten by system permission.
    if (target.mark != mark) return false;

    updateServerState(identity, Validation::in_process, netId);
    startValidation(target, netId, mark);
    return true;
}

void PrivateDnsConfiguration::startValidation(const DnsTlsServer& server, unsigned netId,
                                              uint32_t mark) REQUIRES(mPrivateDnsLock) {
    // Note that capturing |server| and |netId| in this lambda create copies.
+5 −0
Original line number Diff line number Diff line
@@ -65,6 +65,11 @@ class PrivateDnsConfiguration {

    void clear(unsigned netId) EXCLUDES(mPrivateDnsLock);

    // Request |server| to be revalidated on a connection tagged with |mark|.
    // Return true if the request is accepted; otherwise, return false.
    bool requestValidation(unsigned netId, const DnsTlsServer& server, uint32_t mark)
            EXCLUDES(mPrivateDnsLock);

    struct ServerIdentity {
        const netdutils::IPAddress ip;
        const std::string name;
+50 −0
Original line number Diff line number Diff line
@@ -254,6 +254,56 @@ TEST_F(PrivateDnsConfigurationTest, ServerIdentity_Comparison) {
    EXPECT_NE(ServerIdentity(server), ServerIdentity(other));
}

TEST_F(PrivateDnsConfigurationTest, RequestValidation) {
    const DnsTlsServer server(netdutils::IPSockAddr::toIPSockAddr(kServer1, 853));

    testing::InSequence seq;

    for (const std::string_view config : {"SUCCESS", "IN_PROGRESS", "FAIL"}) {
        SCOPED_TRACE(config);

        EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::in_process, kNetId));
        if (config == "SUCCESS") {
            EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::success, kNetId));
        } else if (config == "IN_PROGRESS") {
            backend.setDeferredResp(true);
        } else {
            // config = "FAIL"
            ASSERT_TRUE(backend.stopServer());
            EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::fail, kNetId));
        }
        EXPECT_EQ(mPdc.set(kNetId, kMark, {kServer1}, {}, {}), 0);
        expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC);

        // Wait until the validation state is transitioned.
        const int runningThreads = (config == "IN_PROGRESS") ? 1 : 0;
        ASSERT_TRUE(PollForCondition([&]() { return mObserver.runningThreads == runningThreads; }));

        bool requestAccepted = false;
        if (config == "SUCCESS") {
            EXPECT_CALL(mObserver,
                        onValidationStateUpdate(kServer1, Validation::in_process, kNetId));
            EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::success, kNetId));
            requestAccepted = true;
        } else if (config == "IN_PROGRESS") {
            EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::success, kNetId));
        }

        EXPECT_EQ(mPdc.requestValidation(kNetId, server, kMark), requestAccepted);

        // Resending the same request or requesting nonexistent servers are denied.
        EXPECT_FALSE(mPdc.requestValidation(kNetId, server, kMark));
        EXPECT_FALSE(mPdc.requestValidation(kNetId, server, kMark + 1));
        EXPECT_FALSE(mPdc.requestValidation(kNetId + 1, server, kMark));

        // Reset the test state.
        backend.setDeferredResp(false);
        backend.startServer();
        ASSERT_TRUE(PollForCondition([&]() { return mObserver.runningThreads == 0; }));
        mPdc.clear(kNetId);
    }
}

// TODO: add ValidationFail_Strict test.

}  // namespace android::net