Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 58eff5d4 authored by Mike Yu's avatar Mike Yu Committed by Gerrit Code Review
Browse files

Merge "Extend DnsTlsServer to store validation state"

parents 5bf80bee fa985f71
Loading
Loading
Loading
Loading
+14 −0
Original line number Diff line number Diff line
@@ -27,6 +27,9 @@
namespace android {
namespace net {

// Validation status of a DNS over TLS server (on a specific netId).
enum class Validation : uint8_t { in_process, success, fail, unknown_server, unknown_netid };

// DnsTlsServer represents a recursive resolver that supports, or may support, a
// secure protocol.
struct DnsTlsServer {
@@ -37,17 +40,21 @@ struct DnsTlsServer {
    DnsTlsServer(const sockaddr_storage& ss) : ss(ss) {}

    // The server location, including IP and port.
    // TODO: make it const.
    sockaddr_storage ss = {};

    // The server's hostname.  If this string is nonempty, the server must present a
    // certificate that indicates this name and has a valid chain to a trusted root CA.
    // TODO: make it const.
    std::string name;

    // The certificate of the CA that signed the server's certificate.
    // It is used to store temporary test CA certificate for internal tests.
    // TODO: make it const.
    std::string certificate;

    // Placeholder.  More protocols might be defined in the future.
    // TODO: make it const.
    int protocol = IPPROTO_TCP;

    // Exact comparison of DnsTlsServer objects
@@ -55,6 +62,13 @@ struct DnsTlsServer {
    bool operator==(const DnsTlsServer& other) const;

    bool wasExplicitlyConfigured() const;

    Validation validationState() const { return mValidation; }
    void setValidationState(Validation val) { mValidation = val; }

  private:
    // State, unrelated to the comparison of DnsTlsServer objects.
    Validation mValidation = Validation::unknown_server;
};

// This comparison only checks the IP address.  It ignores ports, names, and fingerprints.
+37 −28
Original line number Diff line number Diff line
@@ -69,7 +69,7 @@ int PrivateDnsConfiguration::set(int32_t netId, uint32_t mark,
               << ", " << servers.size() << ", " << name << ")";

    // Parse the list of servers that has been passed in
    std::set<DnsTlsServer> tlsServers;
    PrivateDnsTracker tmp;
    for (const auto& s : servers) {
        sockaddr_storage parsed;
        if (!parseServer(s.c_str(), &parsed)) {
@@ -78,13 +78,13 @@ int PrivateDnsConfiguration::set(int32_t netId, uint32_t mark,
        DnsTlsServer server(parsed);
        server.name = name;
        server.certificate = caCert;
        tlsServers.insert(server);
        tmp[ServerIdentity(server)] = server;
    }

    std::lock_guard guard(mPrivateDnsLock);
    if (!name.empty()) {
        mPrivateDnsModes[netId] = PrivateDnsMode::STRICT;
    } else if (!tlsServers.empty()) {
    } else if (!tmp.empty()) {
        mPrivateDnsModes[netId] = PrivateDnsMode::OPPORTUNISTIC;
    } else {
        mPrivateDnsModes[netId] = PrivateDnsMode::OFF;
@@ -112,7 +112,7 @@ int PrivateDnsConfiguration::set(int32_t netId, uint32_t mark,

    // Remove any servers from the tracker that are not in |servers| exactly.
    for (auto it = tracker.begin(); it != tracker.end();) {
        if (tlsServers.count(it->first) == 0) {
        if (tmp.find(it->first) == tmp.end()) {
            it = tracker.erase(it);
        } else {
            ++it;
@@ -120,7 +120,7 @@ int PrivateDnsConfiguration::set(int32_t netId, uint32_t mark,
    }

    // Add any new or changed servers to the tracker, and initiate async checks for them.
    for (const auto& server : tlsServers) {
    for (const auto& [identity, server] : tmp) {
        if (needsValidation(tracker, server)) {
            // This is temporarily required. Consider the following scenario, for example,
            //   Step 1) A DoTServer (s1) is set for the network. A validation (v1) for s1 starts.
@@ -133,7 +133,10 @@ int PrivateDnsConfiguration::set(int32_t netId, uint32_t mark,
            //
            // If we didn't add servers to tracker before needValidateThread(), tracker would
            // become empty. We would report s1 validation failed.
            tracker[server] = Validation::in_process;
            if (tracker.find(identity) == tracker.end()) {
                tracker[identity] = server;
            }
            tracker[identity].setValidationState(Validation::in_process);
            LOG(DEBUG) << "Server " << addrToString(&server.ss) << " marked as in_process on netId "
                       << netId << ". Tracker now has size " << tracker.size();
            // This judge must be after "tracker[server] = Validation::in_process;"
@@ -141,7 +144,7 @@ int PrivateDnsConfiguration::set(int32_t netId, uint32_t mark,
                continue;
            }

            updateServerState(server, Validation::in_process, netId);
            updateServerState(identity, Validation::in_process, netId);
            startValidation(server, netId, mark);
        }
    }
@@ -159,8 +162,8 @@ PrivateDnsStatus PrivateDnsConfiguration::getStatus(unsigned netId) {

    const auto netPair = mPrivateDnsTransports.find(netId);
    if (netPair != mPrivateDnsTransports.end()) {
        for (const auto& serverPair : netPair->second) {
            status.serversMap.emplace(serverPair.first, serverPair.second);
        for (const auto& [_, server] : netPair->second) {
            status.serversMap.emplace(server, server.validationState());
        }
    }

@@ -227,20 +230,21 @@ bool PrivateDnsConfiguration::recordPrivateDnsValidation(const DnsTlsServer& ser
                                                         bool success) {
    constexpr bool NEEDS_REEVALUATION = true;
    constexpr bool DONT_REEVALUATE = false;
    const ServerIdentity identity = ServerIdentity(server);

    std::lock_guard guard(mPrivateDnsLock);

    auto netPair = mPrivateDnsTransports.find(netId);
    if (netPair == mPrivateDnsTransports.end()) {
        LOG(WARNING) << "netId " << netId << " was erased during private DNS validation";
        maybeNotifyObserver(server, Validation::fail, netId);
        maybeNotifyObserver(identity.ip.toString(), Validation::fail, netId);
        return DONT_REEVALUATE;
    }

    const auto mode = mPrivateDnsModes.find(netId);
    if (mode == mPrivateDnsModes.end()) {
        LOG(WARNING) << "netId " << netId << " has no private DNS validation mode";
        maybeNotifyObserver(server, Validation::fail, netId);
        maybeNotifyObserver(identity.ip.toString(), Validation::fail, netId);
        return DONT_REEVALUATE;
    }
    const bool modeDoesReevaluation = (mode->second == PrivateDnsMode::STRICT);
@@ -249,16 +253,13 @@ bool PrivateDnsConfiguration::recordPrivateDnsValidation(const DnsTlsServer& ser
            (success || !modeDoesReevaluation) ? DONT_REEVALUATE : NEEDS_REEVALUATION;

    auto& tracker = netPair->second;
    auto serverPair = tracker.find(server);
    auto serverPair = tracker.find(identity);
    if (serverPair == tracker.end()) {
        // TODO: Consider not adding this server to the tracker since this server is not expected
        // to be one of the private DNS servers for this network now. This could prevent this
        // server from being included when dumping status.
        LOG(WARNING) << "Server " << addrToString(&server.ss)
                     << " was removed during private DNS validation";
        success = false;
        reevaluationStatus = DONT_REEVALUATE;
    } else if (!(serverPair->first == server)) {
    } else if (!(serverPair->second == server)) {
        // TODO: It doesn't seem correct to overwrite the tracker entry for
        // |server| down below in this circumstance... Fix this.
        LOG(WARNING) << "Server " << addrToString(&server.ss)
@@ -282,14 +283,14 @@ bool PrivateDnsConfiguration::recordPrivateDnsValidation(const DnsTlsServer& ser
    }

    if (success) {
        updateServerState(server, Validation::success, netId);
        updateServerState(identity, Validation::success, netId);
    } else {
        // Validation failure is expected if a user is on a captive portal.
        // TODO: Trigger a second validation attempt after captive portal login
        // succeeds.
        const auto result = (reevaluationStatus == NEEDS_REEVALUATION) ? Validation::in_process
                                                                       : Validation::fail;
        updateServerState(server, result, netId);
        updateServerState(identity, result, netId);
    }
    LOG(WARNING) << "Validation " << (success ? "success" : "failed");

@@ -324,15 +325,22 @@ bool PrivateDnsConfiguration::needValidateThread(const DnsTlsServer& server, uns
    }
}

void PrivateDnsConfiguration::updateServerState(const DnsTlsServer& server, Validation state,
void PrivateDnsConfiguration::updateServerState(const ServerIdentity& identity, Validation state,
                                                uint32_t netId) {
    auto netPair = mPrivateDnsTransports.find(netId);
    if (netPair != mPrivateDnsTransports.end()) {
    if (netPair == mPrivateDnsTransports.end()) {
        maybeNotifyObserver(identity.ip.toString(), Validation::fail, netId);
        return;
    }

    auto& tracker = netPair->second;
        tracker[server] = state;
    if (tracker.find(identity) == tracker.end()) {
        maybeNotifyObserver(identity.ip.toString(), Validation::fail, netId);
        return;
    }

    maybeNotifyObserver(server, state, netId);
    tracker[identity].setValidationState(state);
    maybeNotifyObserver(identity.ip.toString(), state, netId);
}

void PrivateDnsConfiguration::cleanValidateThreadTracker(const DnsTlsServer& server,
@@ -353,8 +361,9 @@ void PrivateDnsConfiguration::cleanValidateThreadTracker(const DnsTlsServer& ser

bool PrivateDnsConfiguration::needsValidation(const PrivateDnsTracker& tracker,
                                              const DnsTlsServer& server) {
    const auto& iter = tracker.find(server);
    return (iter == tracker.end()) || (iter->second == Validation::fail);
    const ServerIdentity identity = ServerIdentity(server);
    const auto& iter = tracker.find(identity);
    return (iter == tracker.end()) || (iter->second.validationState() == Validation::fail);
}

void PrivateDnsConfiguration::setObserver(Observer* observer) {
@@ -362,10 +371,10 @@ void PrivateDnsConfiguration::setObserver(Observer* observer) {
    mObserver = observer;
}

void PrivateDnsConfiguration::maybeNotifyObserver(const DnsTlsServer& server, Validation validation,
                                                  uint32_t netId) const {
void PrivateDnsConfiguration::maybeNotifyObserver(const std::string& serverIp,
                                                  Validation validation, uint32_t netId) const {
    if (mObserver) {
        mObserver->onValidationStateUpdate(addrToString(&server.ss), validation, netId);
        mObserver->onValidationStateUpdate(serverIp, validation, netId);
    }
}

+29 −10
Original line number Diff line number Diff line
@@ -22,6 +22,7 @@
#include <vector>

#include <android-base/thread_annotations.h>
#include <netdutils/InternetAddresses.h>

#include "DnsTlsServer.h"

@@ -31,11 +32,10 @@ namespace net {
// The DNS over TLS mode on a specific netId.
enum class PrivateDnsMode : uint8_t { OFF, OPPORTUNISTIC, STRICT };

// Validation status of a DNS over TLS server (on a specific netId).
enum class Validation : uint8_t { in_process, success, fail, unknown_server, unknown_netid };

struct PrivateDnsStatus {
    PrivateDnsMode mode;

    // TODO: change the type to std::vector<DnsTlsServer>.
    std::map<DnsTlsServer, Validation, AddressComparator> serversMap;

    std::list<DnsTlsServer> validatedServers() const {
@@ -65,8 +65,26 @@ class PrivateDnsConfiguration {

    void clear(unsigned netId) EXCLUDES(mPrivateDnsLock);

    struct ServerIdentity {
        const netdutils::IPAddress ip;
        const std::string name;
        const int protocol;

        explicit ServerIdentity(const DnsTlsServer& server)
            : ip(netdutils::IPSockAddr::toIPSockAddr(server.ss).ip()),
              name(server.name),
              protocol(server.protocol) {}

        bool operator<(const ServerIdentity& other) const {
            return std::tie(ip, name, protocol) < std::tie(other.ip, other.name, other.protocol);
        }
        bool operator==(const ServerIdentity& other) const {
            return std::tie(ip, name, protocol) == std::tie(other.ip, other.name, other.protocol);
        }
    };

  private:
    typedef std::map<DnsTlsServer, Validation, AddressComparator> PrivateDnsTracker;
    typedef std::map<ServerIdentity, DnsTlsServer> PrivateDnsTracker;
    typedef std::set<DnsTlsServer, AddressComparator> ThreadTracker;

    PrivateDnsConfiguration() = default;
@@ -88,7 +106,7 @@ class PrivateDnsConfiguration {
    bool needsValidation(const PrivateDnsTracker& tracker, const DnsTlsServer& server)
            REQUIRES(mPrivateDnsLock);

    void updateServerState(const DnsTlsServer& server, Validation state, uint32_t netId)
    void updateServerState(const ServerIdentity& identity, Validation state, uint32_t netId)
            REQUIRES(mPrivateDnsLock);

    std::mutex mPrivateDnsLock;
@@ -99,19 +117,20 @@ class PrivateDnsConfiguration {
    std::map<unsigned, ThreadTracker> mPrivateDnsValidateThreads GUARDED_BY(mPrivateDnsLock);

    // For testing. The observer is notified of onValidationStateUpdate 1) when a validation is
    // about to begin or 2) when a validation finishes.
    // WARNING: The Observer is notified while the lock is being held. Be careful not to call any
    // method of PrivateDnsConfiguration from the observer.
    // about to begin or 2) when a validation finishes. If a validation finishes when in OFF mode
    // or when the network has been destroyed, |validation| will be Validation::fail.
    // WARNING: The Observer is notified while the lock is being held. Be careful not to call
    // any method of PrivateDnsConfiguration from the observer.
    // TODO: fix the reentrancy problem.
    class Observer {
      public:
        virtual ~Observer(){};
        virtual void onValidationStateUpdate(const std::string& server, Validation validation,
        virtual void onValidationStateUpdate(const std::string& serverIp, Validation validation,
                                             uint32_t netId) = 0;
    };

    void setObserver(Observer* observer);
    void maybeNotifyObserver(const DnsTlsServer& server, Validation validation,
    void maybeNotifyObserver(const std::string& serverIp, Validation validation,
                             uint32_t netId) const REQUIRES(mPrivateDnsLock);

    Observer* mObserver GUARDED_BY(mPrivateDnsLock);
+37 −1
Original line number Diff line number Diff line
@@ -63,7 +63,8 @@ class PrivateDnsConfigurationTest : public ::testing::Test {
    class MockObserver : public PrivateDnsConfiguration::Observer {
      public:
        MOCK_METHOD(void, onValidationStateUpdate,
                    (const std::string& server, Validation validation, uint32_t netId), (override));
                    (const std::string& serverIp, Validation validation, uint32_t netId),
                    (override));

        std::map<std::string, Validation> getServerStateMap() const {
            std::lock_guard guard(lock);
@@ -172,6 +173,11 @@ TEST_F(PrivateDnsConfigurationTest, ValidationBlock) {
    backend.setDeferredResp(false);

    ASSERT_TRUE(PollForCondition([&]() { return mObserver.runningThreads == 0; }));

    // kServer1 is not a present server and thus should not be available from
    // PrivateDnsConfiguration::getStatus().
    mObserver.removeFromServerStateMap(kServer1);

    expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC);
}

@@ -218,6 +224,36 @@ TEST_F(PrivateDnsConfigurationTest, NoValidation) {
    expectStatus();
}

TEST_F(PrivateDnsConfigurationTest, ServerIdentity_Comparison) {
    using ServerIdentity = PrivateDnsConfiguration::ServerIdentity;

    DnsTlsServer server(netdutils::IPSockAddr::toIPSockAddr("127.0.0.1", 853));
    server.name = "dns.example.com";
    server.protocol = 1;

    // Different IP address (port is ignored).
    DnsTlsServer other = server;
    EXPECT_EQ(ServerIdentity(server), ServerIdentity(other));
    other.ss = netdutils::IPSockAddr::toIPSockAddr("127.0.0.1", 5353);
    EXPECT_EQ(ServerIdentity(server), ServerIdentity(other));
    other.ss = netdutils::IPSockAddr::toIPSockAddr("127.0.0.2", 853);
    EXPECT_NE(ServerIdentity(server), ServerIdentity(other));

    // Different provider hostname.
    other = server;
    EXPECT_EQ(ServerIdentity(server), ServerIdentity(other));
    other.name = "other.example.com";
    EXPECT_NE(ServerIdentity(server), ServerIdentity(other));
    other.name = "";
    EXPECT_NE(ServerIdentity(server), ServerIdentity(other));

    // Different protocol.
    other = server;
    EXPECT_EQ(ServerIdentity(server), ServerIdentity(other));
    other.protocol++;
    EXPECT_NE(ServerIdentity(server), ServerIdentity(other));
}

// TODO: add ValidationFail_Strict test.

}  // namespace android::net
+24 −0
Original line number Diff line number Diff line
@@ -800,6 +800,18 @@ void checkUnequal(const DnsTlsServer& s1, const DnsTlsServer& s2) {
    EXPECT_FALSE(s2 == s1);
}

void checkEqual(const DnsTlsServer& s1, const DnsTlsServer& s2) {
    EXPECT_TRUE(s1 == s1);
    EXPECT_TRUE(s2 == s2);
    EXPECT_TRUE(isAddressEqual(s1, s1));
    EXPECT_TRUE(isAddressEqual(s2, s2));

    EXPECT_FALSE(s1 < s2);
    EXPECT_FALSE(s2 < s1);
    EXPECT_TRUE(s1 == s2);
    EXPECT_TRUE(s2 == s1);
}

class ServerTest : public BaseTest {};

TEST_F(ServerTest, IPv4) {
@@ -873,6 +885,18 @@ TEST_F(ServerTest, Name) {
    EXPECT_TRUE(s2.wasExplicitlyConfigured());
}

TEST_F(ServerTest, State) {
    DnsTlsServer s1(V4ADDR1), s2(V4ADDR1);
    checkEqual(s1, s2);
    s1.setValidationState(Validation::success);
    checkEqual(s1, s2);
    s2.setValidationState(Validation::fail);
    checkEqual(s1, s2);

    EXPECT_EQ(s1.validationState(), Validation::success);
    EXPECT_EQ(s2.validationState(), Validation::fail);
}

TEST(QueryMapTest, Basic) {
    DnsTlsQueryMap map;