Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit fa985f71 authored by Mike Yu's avatar Mike Yu
Browse files

Extend DnsTlsServer to store validation state

This change also fixes a bug in PrivateDnsConfiguration which DoT
servers not valid for the network might somehow be counted as
validated servers. For instance, if a validation for DoT server
finishes after the server is removed, the server is mistakenly
deemed as a validated server.

Bug: 79727473
Test: cd packages/modules/DnsResolver && atest
Change-Id: Idee1f34f59dce1451b7b3e87fd20e6795af883ba
parent 3334a5e0
Loading
Loading
Loading
Loading
+14 −0
Original line number Diff line number Diff line
@@ -27,6 +27,9 @@
namespace android {
namespace net {

// Validation status of a DNS over TLS server (on a specific netId).
enum class Validation : uint8_t { in_process, success, fail, unknown_server, unknown_netid };

// DnsTlsServer represents a recursive resolver that supports, or may support, a
// secure protocol.
struct DnsTlsServer {
@@ -37,17 +40,21 @@ struct DnsTlsServer {
    DnsTlsServer(const sockaddr_storage& ss) : ss(ss) {}

    // The server location, including IP and port.
    // TODO: make it const.
    sockaddr_storage ss = {};

    // The server's hostname.  If this string is nonempty, the server must present a
    // certificate that indicates this name and has a valid chain to a trusted root CA.
    // TODO: make it const.
    std::string name;

    // The certificate of the CA that signed the server's certificate.
    // It is used to store temporary test CA certificate for internal tests.
    // TODO: make it const.
    std::string certificate;

    // Placeholder.  More protocols might be defined in the future.
    // TODO: make it const.
    int protocol = IPPROTO_TCP;

    // Exact comparison of DnsTlsServer objects
@@ -55,6 +62,13 @@ struct DnsTlsServer {
    bool operator==(const DnsTlsServer& other) const;

    bool wasExplicitlyConfigured() const;

    Validation validationState() const { return mValidation; }
    void setValidationState(Validation val) { mValidation = val; }

  private:
    // State, unrelated to the comparison of DnsTlsServer objects.
    Validation mValidation = Validation::unknown_server;
};

// This comparison only checks the IP address.  It ignores ports, names, and fingerprints.
+37 −28
Original line number Diff line number Diff line
@@ -69,7 +69,7 @@ int PrivateDnsConfiguration::set(int32_t netId, uint32_t mark,
               << ", " << servers.size() << ", " << name << ")";

    // Parse the list of servers that has been passed in
    std::set<DnsTlsServer> tlsServers;
    PrivateDnsTracker tmp;
    for (const auto& s : servers) {
        sockaddr_storage parsed;
        if (!parseServer(s.c_str(), &parsed)) {
@@ -78,13 +78,13 @@ int PrivateDnsConfiguration::set(int32_t netId, uint32_t mark,
        DnsTlsServer server(parsed);
        server.name = name;
        server.certificate = caCert;
        tlsServers.insert(server);
        tmp[ServerIdentity(server)] = server;
    }

    std::lock_guard guard(mPrivateDnsLock);
    if (!name.empty()) {
        mPrivateDnsModes[netId] = PrivateDnsMode::STRICT;
    } else if (!tlsServers.empty()) {
    } else if (!tmp.empty()) {
        mPrivateDnsModes[netId] = PrivateDnsMode::OPPORTUNISTIC;
    } else {
        mPrivateDnsModes[netId] = PrivateDnsMode::OFF;
@@ -112,7 +112,7 @@ int PrivateDnsConfiguration::set(int32_t netId, uint32_t mark,

    // Remove any servers from the tracker that are not in |servers| exactly.
    for (auto it = tracker.begin(); it != tracker.end();) {
        if (tlsServers.count(it->first) == 0) {
        if (tmp.find(it->first) == tmp.end()) {
            it = tracker.erase(it);
        } else {
            ++it;
@@ -120,7 +120,7 @@ int PrivateDnsConfiguration::set(int32_t netId, uint32_t mark,
    }

    // Add any new or changed servers to the tracker, and initiate async checks for them.
    for (const auto& server : tlsServers) {
    for (const auto& [identity, server] : tmp) {
        if (needsValidation(tracker, server)) {
            // This is temporarily required. Consider the following scenario, for example,
            //   Step 1) A DoTServer (s1) is set for the network. A validation (v1) for s1 starts.
@@ -133,7 +133,10 @@ int PrivateDnsConfiguration::set(int32_t netId, uint32_t mark,
            //
            // If we didn't add servers to tracker before needValidateThread(), tracker would
            // become empty. We would report s1 validation failed.
            tracker[server] = Validation::in_process;
            if (tracker.find(identity) == tracker.end()) {
                tracker[identity] = server;
            }
            tracker[identity].setValidationState(Validation::in_process);
            LOG(DEBUG) << "Server " << addrToString(&server.ss) << " marked as in_process on netId "
                       << netId << ". Tracker now has size " << tracker.size();
            // This judge must be after "tracker[server] = Validation::in_process;"
@@ -141,7 +144,7 @@ int PrivateDnsConfiguration::set(int32_t netId, uint32_t mark,
                continue;
            }

            updateServerState(server, Validation::in_process, netId);
            updateServerState(identity, Validation::in_process, netId);
            startValidation(server, netId, mark);
        }
    }
@@ -159,8 +162,8 @@ PrivateDnsStatus PrivateDnsConfiguration::getStatus(unsigned netId) {

    const auto netPair = mPrivateDnsTransports.find(netId);
    if (netPair != mPrivateDnsTransports.end()) {
        for (const auto& serverPair : netPair->second) {
            status.serversMap.emplace(serverPair.first, serverPair.second);
        for (const auto& [_, server] : netPair->second) {
            status.serversMap.emplace(server, server.validationState());
        }
    }

@@ -227,20 +230,21 @@ bool PrivateDnsConfiguration::recordPrivateDnsValidation(const DnsTlsServer& ser
                                                         bool success) {
    constexpr bool NEEDS_REEVALUATION = true;
    constexpr bool DONT_REEVALUATE = false;
    const ServerIdentity identity = ServerIdentity(server);

    std::lock_guard guard(mPrivateDnsLock);

    auto netPair = mPrivateDnsTransports.find(netId);
    if (netPair == mPrivateDnsTransports.end()) {
        LOG(WARNING) << "netId " << netId << " was erased during private DNS validation";
        maybeNotifyObserver(server, Validation::fail, netId);
        maybeNotifyObserver(identity.ip.toString(), Validation::fail, netId);
        return DONT_REEVALUATE;
    }

    const auto mode = mPrivateDnsModes.find(netId);
    if (mode == mPrivateDnsModes.end()) {
        LOG(WARNING) << "netId " << netId << " has no private DNS validation mode";
        maybeNotifyObserver(server, Validation::fail, netId);
        maybeNotifyObserver(identity.ip.toString(), Validation::fail, netId);
        return DONT_REEVALUATE;
    }
    const bool modeDoesReevaluation = (mode->second == PrivateDnsMode::STRICT);
@@ -249,16 +253,13 @@ bool PrivateDnsConfiguration::recordPrivateDnsValidation(const DnsTlsServer& ser
            (success || !modeDoesReevaluation) ? DONT_REEVALUATE : NEEDS_REEVALUATION;

    auto& tracker = netPair->second;
    auto serverPair = tracker.find(server);
    auto serverPair = tracker.find(identity);
    if (serverPair == tracker.end()) {
        // TODO: Consider not adding this server to the tracker since this server is not expected
        // to be one of the private DNS servers for this network now. This could prevent this
        // server from being included when dumping status.
        LOG(WARNING) << "Server " << addrToString(&server.ss)
                     << " was removed during private DNS validation";
        success = false;
        reevaluationStatus = DONT_REEVALUATE;
    } else if (!(serverPair->first == server)) {
    } else if (!(serverPair->second == server)) {
        // TODO: It doesn't seem correct to overwrite the tracker entry for
        // |server| down below in this circumstance... Fix this.
        LOG(WARNING) << "Server " << addrToString(&server.ss)
@@ -282,14 +283,14 @@ bool PrivateDnsConfiguration::recordPrivateDnsValidation(const DnsTlsServer& ser
    }

    if (success) {
        updateServerState(server, Validation::success, netId);
        updateServerState(identity, Validation::success, netId);
    } else {
        // Validation failure is expected if a user is on a captive portal.
        // TODO: Trigger a second validation attempt after captive portal login
        // succeeds.
        const auto result = (reevaluationStatus == NEEDS_REEVALUATION) ? Validation::in_process
                                                                       : Validation::fail;
        updateServerState(server, result, netId);
        updateServerState(identity, result, netId);
    }
    LOG(WARNING) << "Validation " << (success ? "success" : "failed");

@@ -324,15 +325,22 @@ bool PrivateDnsConfiguration::needValidateThread(const DnsTlsServer& server, uns
    }
}

void PrivateDnsConfiguration::updateServerState(const DnsTlsServer& server, Validation state,
void PrivateDnsConfiguration::updateServerState(const ServerIdentity& identity, Validation state,
                                                uint32_t netId) {
    auto netPair = mPrivateDnsTransports.find(netId);
    if (netPair != mPrivateDnsTransports.end()) {
    if (netPair == mPrivateDnsTransports.end()) {
        maybeNotifyObserver(identity.ip.toString(), Validation::fail, netId);
        return;
    }

    auto& tracker = netPair->second;
        tracker[server] = state;
    if (tracker.find(identity) == tracker.end()) {
        maybeNotifyObserver(identity.ip.toString(), Validation::fail, netId);
        return;
    }

    maybeNotifyObserver(server, state, netId);
    tracker[identity].setValidationState(state);
    maybeNotifyObserver(identity.ip.toString(), state, netId);
}

void PrivateDnsConfiguration::cleanValidateThreadTracker(const DnsTlsServer& server,
@@ -353,8 +361,9 @@ void PrivateDnsConfiguration::cleanValidateThreadTracker(const DnsTlsServer& ser

bool PrivateDnsConfiguration::needsValidation(const PrivateDnsTracker& tracker,
                                              const DnsTlsServer& server) {
    const auto& iter = tracker.find(server);
    return (iter == tracker.end()) || (iter->second == Validation::fail);
    const ServerIdentity identity = ServerIdentity(server);
    const auto& iter = tracker.find(identity);
    return (iter == tracker.end()) || (iter->second.validationState() == Validation::fail);
}

void PrivateDnsConfiguration::setObserver(Observer* observer) {
@@ -362,10 +371,10 @@ void PrivateDnsConfiguration::setObserver(Observer* observer) {
    mObserver = observer;
}

void PrivateDnsConfiguration::maybeNotifyObserver(const DnsTlsServer& server, Validation validation,
                                                  uint32_t netId) const {
void PrivateDnsConfiguration::maybeNotifyObserver(const std::string& serverIp,
                                                  Validation validation, uint32_t netId) const {
    if (mObserver) {
        mObserver->onValidationStateUpdate(addrToString(&server.ss), validation, netId);
        mObserver->onValidationStateUpdate(serverIp, validation, netId);
    }
}

+29 −10
Original line number Diff line number Diff line
@@ -22,6 +22,7 @@
#include <vector>

#include <android-base/thread_annotations.h>
#include <netdutils/InternetAddresses.h>

#include "DnsTlsServer.h"

@@ -31,11 +32,10 @@ namespace net {
// The DNS over TLS mode on a specific netId.
enum class PrivateDnsMode : uint8_t { OFF, OPPORTUNISTIC, STRICT };

// Validation status of a DNS over TLS server (on a specific netId).
enum class Validation : uint8_t { in_process, success, fail, unknown_server, unknown_netid };

struct PrivateDnsStatus {
    PrivateDnsMode mode;

    // TODO: change the type to std::vector<DnsTlsServer>.
    std::map<DnsTlsServer, Validation, AddressComparator> serversMap;

    std::list<DnsTlsServer> validatedServers() const {
@@ -65,8 +65,26 @@ class PrivateDnsConfiguration {

    void clear(unsigned netId) EXCLUDES(mPrivateDnsLock);

    struct ServerIdentity {
        const netdutils::IPAddress ip;
        const std::string name;
        const int protocol;

        explicit ServerIdentity(const DnsTlsServer& server)
            : ip(netdutils::IPSockAddr::toIPSockAddr(server.ss).ip()),
              name(server.name),
              protocol(server.protocol) {}

        bool operator<(const ServerIdentity& other) const {
            return std::tie(ip, name, protocol) < std::tie(other.ip, other.name, other.protocol);
        }
        bool operator==(const ServerIdentity& other) const {
            return std::tie(ip, name, protocol) == std::tie(other.ip, other.name, other.protocol);
        }
    };

  private:
    typedef std::map<DnsTlsServer, Validation, AddressComparator> PrivateDnsTracker;
    typedef std::map<ServerIdentity, DnsTlsServer> PrivateDnsTracker;
    typedef std::set<DnsTlsServer, AddressComparator> ThreadTracker;

    PrivateDnsConfiguration() = default;
@@ -88,7 +106,7 @@ class PrivateDnsConfiguration {
    bool needsValidation(const PrivateDnsTracker& tracker, const DnsTlsServer& server)
            REQUIRES(mPrivateDnsLock);

    void updateServerState(const DnsTlsServer& server, Validation state, uint32_t netId)
    void updateServerState(const ServerIdentity& identity, Validation state, uint32_t netId)
            REQUIRES(mPrivateDnsLock);

    std::mutex mPrivateDnsLock;
@@ -99,19 +117,20 @@ class PrivateDnsConfiguration {
    std::map<unsigned, ThreadTracker> mPrivateDnsValidateThreads GUARDED_BY(mPrivateDnsLock);

    // For testing. The observer is notified of onValidationStateUpdate 1) when a validation is
    // about to begin or 2) when a validation finishes.
    // WARNING: The Observer is notified while the lock is being held. Be careful not to call any
    // method of PrivateDnsConfiguration from the observer.
    // about to begin or 2) when a validation finishes. If a validation finishes when in OFF mode
    // or when the network has been destroyed, |validation| will be Validation::fail.
    // WARNING: The Observer is notified while the lock is being held. Be careful not to call
    // any method of PrivateDnsConfiguration from the observer.
    // TODO: fix the reentrancy problem.
    class Observer {
      public:
        virtual ~Observer(){};
        virtual void onValidationStateUpdate(const std::string& server, Validation validation,
        virtual void onValidationStateUpdate(const std::string& serverIp, Validation validation,
                                             uint32_t netId) = 0;
    };

    void setObserver(Observer* observer);
    void maybeNotifyObserver(const DnsTlsServer& server, Validation validation,
    void maybeNotifyObserver(const std::string& serverIp, Validation validation,
                             uint32_t netId) const REQUIRES(mPrivateDnsLock);

    Observer* mObserver GUARDED_BY(mPrivateDnsLock);
+37 −1
Original line number Diff line number Diff line
@@ -63,7 +63,8 @@ class PrivateDnsConfigurationTest : public ::testing::Test {
    class MockObserver : public PrivateDnsConfiguration::Observer {
      public:
        MOCK_METHOD(void, onValidationStateUpdate,
                    (const std::string& server, Validation validation, uint32_t netId), (override));
                    (const std::string& serverIp, Validation validation, uint32_t netId),
                    (override));

        std::map<std::string, Validation> getServerStateMap() const {
            std::lock_guard guard(lock);
@@ -172,6 +173,11 @@ TEST_F(PrivateDnsConfigurationTest, ValidationBlock) {
    backend.setDeferredResp(false);

    ASSERT_TRUE(PollForCondition([&]() { return mObserver.runningThreads == 0; }));

    // kServer1 is not a present server and thus should not be available from
    // PrivateDnsConfiguration::getStatus().
    mObserver.removeFromServerStateMap(kServer1);

    expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC);
}

@@ -218,6 +224,36 @@ TEST_F(PrivateDnsConfigurationTest, NoValidation) {
    expectStatus();
}

TEST_F(PrivateDnsConfigurationTest, ServerIdentity_Comparison) {
    using ServerIdentity = PrivateDnsConfiguration::ServerIdentity;

    DnsTlsServer server(netdutils::IPSockAddr::toIPSockAddr("127.0.0.1", 853));
    server.name = "dns.example.com";
    server.protocol = 1;

    // Different IP address (port is ignored).
    DnsTlsServer other = server;
    EXPECT_EQ(ServerIdentity(server), ServerIdentity(other));
    other.ss = netdutils::IPSockAddr::toIPSockAddr("127.0.0.1", 5353);
    EXPECT_EQ(ServerIdentity(server), ServerIdentity(other));
    other.ss = netdutils::IPSockAddr::toIPSockAddr("127.0.0.2", 853);
    EXPECT_NE(ServerIdentity(server), ServerIdentity(other));

    // Different provider hostname.
    other = server;
    EXPECT_EQ(ServerIdentity(server), ServerIdentity(other));
    other.name = "other.example.com";
    EXPECT_NE(ServerIdentity(server), ServerIdentity(other));
    other.name = "";
    EXPECT_NE(ServerIdentity(server), ServerIdentity(other));

    // Different protocol.
    other = server;
    EXPECT_EQ(ServerIdentity(server), ServerIdentity(other));
    other.protocol++;
    EXPECT_NE(ServerIdentity(server), ServerIdentity(other));
}

// TODO: add ValidationFail_Strict test.

}  // namespace android::net
+24 −0
Original line number Diff line number Diff line
@@ -800,6 +800,18 @@ void checkUnequal(const DnsTlsServer& s1, const DnsTlsServer& s2) {
    EXPECT_FALSE(s2 == s1);
}

void checkEqual(const DnsTlsServer& s1, const DnsTlsServer& s2) {
    EXPECT_TRUE(s1 == s1);
    EXPECT_TRUE(s2 == s2);
    EXPECT_TRUE(isAddressEqual(s1, s1));
    EXPECT_TRUE(isAddressEqual(s2, s2));

    EXPECT_FALSE(s1 < s2);
    EXPECT_FALSE(s2 < s1);
    EXPECT_TRUE(s1 == s2);
    EXPECT_TRUE(s2 == s1);
}

class ServerTest : public BaseTest {};

TEST_F(ServerTest, IPv4) {
@@ -873,6 +885,18 @@ TEST_F(ServerTest, Name) {
    EXPECT_TRUE(s2.wasExplicitlyConfigured());
}

TEST_F(ServerTest, State) {
    DnsTlsServer s1(V4ADDR1), s2(V4ADDR1);
    checkEqual(s1, s2);
    s1.setValidationState(Validation::success);
    checkEqual(s1, s2);
    s2.setValidationState(Validation::fail);
    checkEqual(s1, s2);

    EXPECT_EQ(s1.validationState(), Validation::success);
    EXPECT_EQ(s2.validationState(), Validation::fail);
}

TEST(QueryMapTest, Basic) {
    DnsTlsQueryMap map;