Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit e60ab415 authored by Mike Yu's avatar Mike Yu
Browse files

Support sending validation request to PrivateDnsConfiguration

Extend PrivateDnsConfiguration to support validation request.

The request is deniable. If the request is denied, no validation
starts. Callers can know if requests are accepted by the return
value of the call.

This change also extends DnsTlsServer to store the mark used by
validation, which helps on preventing running validation with
an unexpected socket mark and resulting in updating wrong validation
state.

Bug: 79727473
Test: cd packages/modules/DnsResolver && atest
Change-Id: Ib92f6b4dd94ed426bf28cb9756d1514e34f16140
parent c9f623a1
Loading
Loading
Loading
Loading
+6 −0
Original line number Diff line number Diff line
@@ -74,6 +74,12 @@ struct DnsTlsServer {
    Validation validationState() const { return mValidation; }
    void setValidationState(Validation val) { mValidation = val; }

    // The socket mark used for validation.
    // Note that the mark of a connection to which the DnsResolver sends app's DNS requests can
    // be different.
    // TODO: make it const.
    uint32_t mark = 0;

    // Return whether or not the server can be used for a network. It depends on
    // the resolver configuration.
    bool active() const { return mActive; }
+32 −0
Original line number Diff line number Diff line
@@ -70,6 +70,7 @@ int PrivateDnsConfiguration::set(int32_t netId, uint32_t mark,
        DnsTlsServer server(parsed);
        server.name = name;
        server.certificate = caCert;
        server.mark = mark;
        tmp[ServerIdentity(server)] = server;
    }

@@ -140,6 +141,37 @@ void PrivateDnsConfiguration::clear(unsigned netId) {
    mPrivateDnsTransports.erase(netId);
}

bool PrivateDnsConfiguration::requestValidation(unsigned netId, const DnsTlsServer& server,
                                                uint32_t mark) {
    std::lock_guard guard(mPrivateDnsLock);
    auto netPair = mPrivateDnsTransports.find(netId);
    if (netPair == mPrivateDnsTransports.end()) {
        return false;
    }

    auto& tracker = netPair->second;
    const ServerIdentity identity = ServerIdentity(server);
    auto it = tracker.find(identity);
    if (it == tracker.end()) {
        return false;
    }

    const DnsTlsServer& target = it->second;

    if (!target.active()) return false;

    if (target.validationState() != Validation::success) return false;

    // Don't run the validation if |mark| (from android_net_context.dns_mark) is different.
    // This is to protect validation from running on unexpected marks.
    // Validation should be associated with a mark gotten by system permission.
    if (target.mark != mark) return false;

    updateServerState(identity, Validation::in_process, netId);
    startValidation(target, netId, mark);
    return true;
}

void PrivateDnsConfiguration::startValidation(const DnsTlsServer& server, unsigned netId,
                                              uint32_t mark) REQUIRES(mPrivateDnsLock) {
    // Note that capturing |server| and |netId| in this lambda create copies.
+5 −0
Original line number Diff line number Diff line
@@ -65,6 +65,11 @@ class PrivateDnsConfiguration {

    void clear(unsigned netId) EXCLUDES(mPrivateDnsLock);

    // Request |server| to be revalidated on a connection tagged with |mark|.
    // Return true if the request is accepted; otherwise, return false.
    bool requestValidation(unsigned netId, const DnsTlsServer& server, uint32_t mark)
            EXCLUDES(mPrivateDnsLock);

    struct ServerIdentity {
        const netdutils::IPAddress ip;
        const std::string name;
+50 −0
Original line number Diff line number Diff line
@@ -254,6 +254,56 @@ TEST_F(PrivateDnsConfigurationTest, ServerIdentity_Comparison) {
    EXPECT_NE(ServerIdentity(server), ServerIdentity(other));
}

TEST_F(PrivateDnsConfigurationTest, RequestValidation) {
    const DnsTlsServer server(netdutils::IPSockAddr::toIPSockAddr(kServer1, 853));

    testing::InSequence seq;

    for (const std::string_view config : {"SUCCESS", "IN_PROGRESS", "FAIL"}) {
        SCOPED_TRACE(config);

        EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::in_process, kNetId));
        if (config == "SUCCESS") {
            EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::success, kNetId));
        } else if (config == "IN_PROGRESS") {
            backend.setDeferredResp(true);
        } else {
            // config = "FAIL"
            ASSERT_TRUE(backend.stopServer());
            EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::fail, kNetId));
        }
        EXPECT_EQ(mPdc.set(kNetId, kMark, {kServer1}, {}, {}), 0);
        expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC);

        // Wait until the validation state is transitioned.
        const int runningThreads = (config == "IN_PROGRESS") ? 1 : 0;
        ASSERT_TRUE(PollForCondition([&]() { return mObserver.runningThreads == runningThreads; }));

        bool requestAccepted = false;
        if (config == "SUCCESS") {
            EXPECT_CALL(mObserver,
                        onValidationStateUpdate(kServer1, Validation::in_process, kNetId));
            EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::success, kNetId));
            requestAccepted = true;
        } else if (config == "IN_PROGRESS") {
            EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::success, kNetId));
        }

        EXPECT_EQ(mPdc.requestValidation(kNetId, server, kMark), requestAccepted);

        // Resending the same request or requesting nonexistent servers are denied.
        EXPECT_FALSE(mPdc.requestValidation(kNetId, server, kMark));
        EXPECT_FALSE(mPdc.requestValidation(kNetId, server, kMark + 1));
        EXPECT_FALSE(mPdc.requestValidation(kNetId + 1, server, kMark));

        // Reset the test state.
        backend.setDeferredResp(false);
        backend.startServer();
        ASSERT_TRUE(PollForCondition([&]() { return mObserver.runningThreads == 0; }));
        mPdc.clear(kNetId);
    }
}

// TODO: add ValidationFail_Strict test.

}  // namespace android::net