Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit ad96ef83 authored by Mike Yu's avatar Mike Yu
Browse files

Change most of PrivateDnsConfiguration methods to use ServerIdentity

Passing DnsTlsServer might be confusing because it's not straightforward
to know if a DnsTlsServer is a copy or onwed by PrivateDnsConfiguration.

This CL changes most of the methods to use ServerIdentity. The methods
can then get the corresponding DnsTlsServer by the new added method
getPrivateDns().

Bug: 186177613
Test: cd packages/modules/DnsResolver && atest
Change-Id: Ied4a4ee026862cd2c596586499cbfa7646eaaf2a
parent 05bf348e
Loading
Loading
Loading
Loading
+2 −2
Original line number Diff line number Diff line
@@ -214,8 +214,8 @@ DnsTlsTransport::Response DnsTlsDispatcher::query(const DnsTlsServer& server, un
            // happens, the xport will be marked as unusable and DoT queries won't be sent to
            // it anymore. Eventually, after IDLE_TIMEOUT, the xport will be destroyed, and
            // a new xport will be created.
            const auto result =
                    PrivateDnsConfiguration::getInstance().requestValidation(netId, server, mark);
            const auto result = PrivateDnsConfiguration::getInstance().requestValidation(
                    netId, PrivateDnsConfiguration::ServerIdentity{server}, mark);
            LOG(WARNING) << "Requested validation for " << server.toIpString() << " with mark 0x"
                         << std::hex << mark << ", "
                         << (result.ok() ? "succeeded" : "failed: " + result.error().message());
+61 −47
Original line number Diff line number Diff line
@@ -110,7 +110,7 @@ int PrivateDnsConfiguration::set(int32_t netId, uint32_t mark,

        if (needsValidation(server)) {
            updateServerState(identity, Validation::in_process, netId);
            startValidation(server, netId, false);
            startValidation(identity, netId, false);
        }
    }

@@ -145,7 +145,7 @@ void PrivateDnsConfiguration::clear(unsigned netId) {
}

base::Result<void> PrivateDnsConfiguration::requestValidation(unsigned netId,
                                                              const DnsTlsServer& server,
                                                              const ServerIdentity& identity,
                                                              uint32_t mark) {
    std::lock_guard guard(mPrivateDnsLock);

@@ -159,40 +159,39 @@ base::Result<void> PrivateDnsConfiguration::requestValidation(unsigned netId,
        return Errorf("Private DNS setting is not opportunistic mode");
    }

    auto netPair = mPrivateDnsTransports.find(netId);
    if (netPair == mPrivateDnsTransports.end()) {
        return Errorf("NetId not found in mPrivateDnsTransports");
    auto result = getPrivateDnsLocked(identity, netId);
    if (!result.ok()) {
        return result.error();
    }

    auto& tracker = netPair->second;
    const ServerIdentity identity = ServerIdentity(server);
    auto it = tracker.find(identity);
    if (it == tracker.end()) {
        return Errorf("Server was removed");
    }
    const DnsTlsServer* target = result.value();

    const DnsTlsServer& target = it->second;
    if (!target->active()) return Errorf("Server is not active");

    if (!target.active()) return Errorf("Server is not active");

    if (target.validationState() != Validation::success) {
    if (target->validationState() != Validation::success) {
        return Errorf("Server validation state mismatched");
    }

    // Don't run the validation if |mark| (from android_net_context.dns_mark) is different.
    // This is to protect validation from running on unexpected marks.
    // Validation should be associated with a mark gotten by system permission.
    if (target.mark != mark) return Errorf("Socket mark mismatched");
    if (target->mark != mark) return Errorf("Socket mark mismatched");

    updateServerState(identity, Validation::in_process, netId);
    startValidation(target, netId, true);
    startValidation(identity, netId, true);
    return {};
}

void PrivateDnsConfiguration::startValidation(const DnsTlsServer& server, unsigned netId,
                                              bool isRevalidation) REQUIRES(mPrivateDnsLock) {
    // Note that capturing |server|, |netId|, and |isRevalidation| in this lambda create copies.
    std::thread validate_thread([this, server, netId, isRevalidation] {
void PrivateDnsConfiguration::startValidation(const ServerIdentity& identity, unsigned netId,
                                              bool isRevalidation) {
    // This ensures that the thread sends probe at least once in case
    // the server is removed before the thread starts running.
    // TODO: consider moving these code to the thread.
    const auto result = getPrivateDnsLocked(identity, netId);
    if (!result.ok()) return;
    DnsTlsServer server = *result.value();

    std::thread validate_thread([this, identity, server, netId, isRevalidation] {
        setThreadName(StringPrintf("TlsVerify_%u", netId).c_str());

        // cat /proc/sys/net/ipv4/tcp_syn_retries yields "6".
@@ -223,7 +222,7 @@ void PrivateDnsConfiguration::startValidation(const DnsTlsServer& server, unsign
                         << server.toIpString();

            const bool needs_reeval =
                    this->recordPrivateDnsValidation(server, netId, success, isRevalidation);
                    this->recordPrivateDnsValidation(identity, netId, success, isRevalidation);

            if (!needs_reeval) {
                break;
@@ -240,11 +239,11 @@ void PrivateDnsConfiguration::startValidation(const DnsTlsServer& server, unsign
    validate_thread.detach();
}

void PrivateDnsConfiguration::sendPrivateDnsValidationEvent(const DnsTlsServer& server,
void PrivateDnsConfiguration::sendPrivateDnsValidationEvent(const ServerIdentity& identity,
                                                            unsigned netId, bool success) {
    LOG(DEBUG) << "Sending validation " << (success ? "success" : "failure") << " event on netId "
               << netId << " for " << server.toIpString() << " with hostname {" << server.name
               << "}";
               << netId << " for " << identity.sockaddr.ip().toString() << " with hostname {"
               << identity.provider << "}";
    // Send a validation event to NetdEventListenerService.
    const auto& listeners = ResolverEventReporter::getInstance().getListeners();
    if (listeners.empty()) {
@@ -252,15 +251,16 @@ void PrivateDnsConfiguration::sendPrivateDnsValidationEvent(const DnsTlsServer&
                << "Validation event not sent since no INetdEventListener receiver is available.";
    }
    for (const auto& it : listeners) {
        it->onPrivateDnsValidationEvent(netId, server.toIpString(), server.name, success);
        it->onPrivateDnsValidationEvent(netId, identity.sockaddr.ip().toString(), identity.provider,
                                        success);
    }

    // Send a validation event to unsolicited event listeners.
    const auto& unsolEventListeners = ResolverEventReporter::getInstance().getUnsolEventListeners();
    const PrivateDnsValidationEventParcel validationEvent = {
            .netId = static_cast<int32_t>(netId),
            .ipAddress = server.toIpString(),
            .hostname = server.name,
            .ipAddress = identity.sockaddr.ip().toString(),
            .hostname = identity.provider,
            .validation = success ? IDnsResolverUnsolicitedEventListener::VALIDATION_RESULT_SUCCESS
                                  : IDnsResolverUnsolicitedEventListener::VALIDATION_RESULT_FAILURE,
    };
@@ -269,11 +269,11 @@ void PrivateDnsConfiguration::sendPrivateDnsValidationEvent(const DnsTlsServer&
    }
}

bool PrivateDnsConfiguration::recordPrivateDnsValidation(const DnsTlsServer& server, unsigned netId,
                                                         bool success, bool isRevalidation) {
bool PrivateDnsConfiguration::recordPrivateDnsValidation(const ServerIdentity& identity,
                                                         unsigned netId, bool success,
                                                         bool isRevalidation) {
    constexpr bool NEEDS_REEVALUATION = true;
    constexpr bool DONT_REEVALUATE = false;
    const ServerIdentity identity = ServerIdentity(server);

    std::lock_guard guard(mPrivateDnsLock);

@@ -303,23 +303,19 @@ bool PrivateDnsConfiguration::recordPrivateDnsValidation(const DnsTlsServer& ser
    auto& tracker = netPair->second;
    auto serverPair = tracker.find(identity);
    if (serverPair == tracker.end()) {
        LOG(WARNING) << "Server " << server.toIpString()
        LOG(WARNING) << "Server " << identity.sockaddr.ip().toString()
                     << " was removed during private DNS validation";
        success = false;
        reevaluationStatus = DONT_REEVALUATE;
    } else if (!(serverPair->second == server)) {
        LOG(WARNING) << "Server " << server.toIpString()
                     << " was changed during private DNS validation";
        success = false;
        reevaluationStatus = DONT_REEVALUATE;
    } else if (!serverPair->second.active()) {
        LOG(WARNING) << "Server " << server.toIpString() << " was removed from the configuration";
        LOG(WARNING) << "Server " << identity.sockaddr.ip().toString()
                     << " was removed from the configuration";
        success = false;
        reevaluationStatus = DONT_REEVALUATE;
    }

    // Send private dns validation result to listeners.
    sendPrivateDnsValidationEvent(server, netId, success);
    sendPrivateDnsValidationEvent(identity, netId, success);

    if (success) {
        updateServerState(identity, Validation::success, netId);
@@ -338,19 +334,15 @@ bool PrivateDnsConfiguration::recordPrivateDnsValidation(const DnsTlsServer& ser

void PrivateDnsConfiguration::updateServerState(const ServerIdentity& identity, Validation state,
                                                uint32_t netId) {
    auto netPair = mPrivateDnsTransports.find(netId);
    if (netPair == mPrivateDnsTransports.end()) {
    const auto result = getPrivateDnsLocked(identity, netId);
    if (!result.ok()) {
        notifyValidationStateUpdate(identity.sockaddr, Validation::fail, netId);
        return;
    }

    auto& tracker = netPair->second;
    if (tracker.find(identity) == tracker.end()) {
        notifyValidationStateUpdate(identity.sockaddr, Validation::fail, netId);
        return;
    }
    auto* server = result.value();

    tracker[identity].setValidationState(state);
    server->setValidationState(state);
    notifyValidationStateUpdate(identity.sockaddr, state, netId);

    RecordEntry record(netId, identity, state);
@@ -373,6 +365,28 @@ bool PrivateDnsConfiguration::needsValidation(const DnsTlsServer& server) {
    return false;
}

base::Result<DnsTlsServer*> PrivateDnsConfiguration::getPrivateDns(const ServerIdentity& identity,
                                                                   unsigned netId) {
    std::lock_guard guard(mPrivateDnsLock);
    return getPrivateDnsLocked(identity, netId);
}

base::Result<DnsTlsServer*> PrivateDnsConfiguration::getPrivateDnsLocked(
        const ServerIdentity& identity, unsigned netId) {
    auto netPair = mPrivateDnsTransports.find(netId);
    if (netPair == mPrivateDnsTransports.end()) {
        return Errorf("Failed to get private DNS: netId {} not found", netId);
    }

    auto iter = netPair->second.find(identity);
    if (iter == netPair->second.end()) {
        return Errorf("Failed to get private DNS: server {{{}/{}}} not found", identity.sockaddr,
                      identity.provider);
    }

    return &iter->second;
}

void PrivateDnsConfiguration::setObserver(PrivateDnsValidationObserver* observer) {
    std::lock_guard guard(mPrivateDnsLock);
    mObserver = observer;
+30 −21
Original line number Diff line number Diff line
@@ -33,6 +33,7 @@
namespace android {
namespace net {

// TODO: decouple the dependency of DnsTlsServer.
struct PrivateDnsStatus {
    PrivateDnsMode mode;

@@ -53,24 +54,6 @@ struct PrivateDnsStatus {

class PrivateDnsConfiguration {
  public:
    // The only instance of PrivateDnsConfiguration.
    static PrivateDnsConfiguration& getInstance() {
        static PrivateDnsConfiguration instance;
        return instance;
    }

    int set(int32_t netId, uint32_t mark, const std::vector<std::string>& servers,
            const std::string& name, const std::string& caCert) EXCLUDES(mPrivateDnsLock);

    PrivateDnsStatus getStatus(unsigned netId) const EXCLUDES(mPrivateDnsLock);

    void clear(unsigned netId) EXCLUDES(mPrivateDnsLock);

    // Request |server| to be revalidated on a connection tagged with |mark|.
    // Returns a Result to indicate if the request is accepted.
    base::Result<void> requestValidation(unsigned netId, const DnsTlsServer& server, uint32_t mark)
            EXCLUDES(mPrivateDnsLock);

    struct ServerIdentity {
        const netdutils::IPSockAddr sockaddr;
        const std::string provider;
@@ -86,6 +69,24 @@ class PrivateDnsConfiguration {
        }
    };

    // The only instance of PrivateDnsConfiguration.
    static PrivateDnsConfiguration& getInstance() {
        static PrivateDnsConfiguration instance;
        return instance;
    }

    int set(int32_t netId, uint32_t mark, const std::vector<std::string>& servers,
            const std::string& name, const std::string& caCert) EXCLUDES(mPrivateDnsLock);

    PrivateDnsStatus getStatus(unsigned netId) const EXCLUDES(mPrivateDnsLock);

    void clear(unsigned netId) EXCLUDES(mPrivateDnsLock);

    // Request the server to be revalidated on a connection tagged with |mark|.
    // Returns a Result to indicate if the request is accepted.
    base::Result<void> requestValidation(unsigned netId, const ServerIdentity& identity,
                                         uint32_t mark) EXCLUDES(mPrivateDnsLock);

    void setObserver(PrivateDnsValidationObserver* observer);

    void dump(netdutils::DumpWriter& dw) const;
@@ -97,23 +98,31 @@ class PrivateDnsConfiguration {

    // Launchs a thread to run the validation for |server| on the network |netId|.
    // |isRevalidation| is true if this call is due to a revalidation request.
    void startValidation(const DnsTlsServer& server, unsigned netId, bool isRevalidation)
    void startValidation(const ServerIdentity& identity, unsigned netId, bool isRevalidation)
            REQUIRES(mPrivateDnsLock);

    bool recordPrivateDnsValidation(const DnsTlsServer& server, unsigned netId, bool success,
    bool recordPrivateDnsValidation(const ServerIdentity& identity, unsigned netId, bool success,
                                    bool isRevalidation) EXCLUDES(mPrivateDnsLock);

    void sendPrivateDnsValidationEvent(const DnsTlsServer& server, unsigned netId, bool success)
    void sendPrivateDnsValidationEvent(const ServerIdentity& identity, unsigned netId, bool success)
            REQUIRES(mPrivateDnsLock);

    // Decide if a validation for |server| is needed. Note that servers that have failed
    // multiple validation attempts but for which there is still a validating
    // thread running are marked as being Validation::in_process.
    // TODO: decouple the dependency of DnsTlsServer.
    bool needsValidation(const DnsTlsServer& server) REQUIRES(mPrivateDnsLock);

    void updateServerState(const ServerIdentity& identity, Validation state, uint32_t netId)
            REQUIRES(mPrivateDnsLock);

    // For testing.
    base::Result<DnsTlsServer*> getPrivateDns(const ServerIdentity& identity, unsigned netId)
            EXCLUDES(mPrivateDnsLock);

    base::Result<DnsTlsServer*> getPrivateDnsLocked(const ServerIdentity& identity, unsigned netId)
            REQUIRES(mPrivateDnsLock);

    mutable std::mutex mPrivateDnsLock;
    std::map<unsigned, PrivateDnsMode> mPrivateDnsModes GUARDED_BY(mPrivateDnsLock);

+33 −8
Original line number Diff line number Diff line
@@ -28,6 +28,8 @@ using namespace std::chrono_literals;

class PrivateDnsConfigurationTest : public ::testing::Test {
  public:
    using ServerIdentity = PrivateDnsConfiguration::ServerIdentity;

    static void SetUpTestSuite() {
        // stopServer() will be called in their destructor.
        ASSERT_TRUE(tls1.startServer());
@@ -100,6 +102,10 @@ class PrivateDnsConfigurationTest : public ::testing::Test {
        return (serverStateMap == mObserver.getServerStateMap());
    }

    bool hasPrivateDnsServer(const ServerIdentity& identity, unsigned netId) {
        return mPdc.getPrivateDns(identity, netId).ok();
    }

    static constexpr uint32_t kNetId = 30;
    static constexpr uint32_t kMark = 30;
    static constexpr char kBackend[] = "127.0.2.1";
@@ -230,8 +236,6 @@ TEST_F(PrivateDnsConfigurationTest, NoValidation) {
}

TEST_F(PrivateDnsConfigurationTest, ServerIdentity_Comparison) {
    using ServerIdentity = PrivateDnsConfiguration::ServerIdentity;

    DnsTlsServer server(netdutils::IPSockAddr::toIPSockAddr("127.0.0.1", 853));
    server.name = "dns.example.com";

@@ -254,6 +258,7 @@ TEST_F(PrivateDnsConfigurationTest, ServerIdentity_Comparison) {

TEST_F(PrivateDnsConfigurationTest, RequestValidation) {
    const DnsTlsServer server(netdutils::IPSockAddr::toIPSockAddr(kServer1, 853));
    const ServerIdentity identity(server);

    testing::InSequence seq;

@@ -281,18 +286,18 @@ TEST_F(PrivateDnsConfigurationTest, RequestValidation) {
            EXPECT_CALL(mObserver,
                        onValidationStateUpdate(kServer1, Validation::in_process, kNetId));
            EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::success, kNetId));
            EXPECT_TRUE(mPdc.requestValidation(kNetId, server, kMark).ok());
            EXPECT_TRUE(mPdc.requestValidation(kNetId, identity, kMark).ok());
        } else if (config == "IN_PROGRESS") {
            EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::success, kNetId));
            EXPECT_FALSE(mPdc.requestValidation(kNetId, server, kMark).ok());
            EXPECT_FALSE(mPdc.requestValidation(kNetId, identity, kMark).ok());
        } else if (config == "FAIL") {
            EXPECT_FALSE(mPdc.requestValidation(kNetId, server, kMark).ok());
            EXPECT_FALSE(mPdc.requestValidation(kNetId, identity, kMark).ok());
        }

        // Resending the same request or requesting nonexistent servers are denied.
        EXPECT_FALSE(mPdc.requestValidation(kNetId, server, kMark).ok());
        EXPECT_FALSE(mPdc.requestValidation(kNetId, server, kMark + 1).ok());
        EXPECT_FALSE(mPdc.requestValidation(kNetId + 1, server, kMark).ok());
        EXPECT_FALSE(mPdc.requestValidation(kNetId, identity, kMark).ok());
        EXPECT_FALSE(mPdc.requestValidation(kNetId, identity, kMark + 1).ok());
        EXPECT_FALSE(mPdc.requestValidation(kNetId + 1, identity, kMark).ok());

        // Reset the test state.
        backend.setDeferredResp(false);
@@ -306,6 +311,26 @@ TEST_F(PrivateDnsConfigurationTest, RequestValidation) {
    }
}

TEST_F(PrivateDnsConfigurationTest, GetPrivateDns) {
    const DnsTlsServer server1(netdutils::IPSockAddr::toIPSockAddr(kServer1, 853));
    const DnsTlsServer server2(netdutils::IPSockAddr::toIPSockAddr(kServer2, 853));

    EXPECT_FALSE(hasPrivateDnsServer(ServerIdentity(server1), kNetId));
    EXPECT_FALSE(hasPrivateDnsServer(ServerIdentity(server2), kNetId));

    // Suppress the warning.
    EXPECT_CALL(mObserver, onValidationStateUpdate).Times(2);

    EXPECT_EQ(mPdc.set(kNetId, kMark, {kServer1}, {}, {}), 0);
    expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC);

    EXPECT_TRUE(hasPrivateDnsServer(ServerIdentity(server1), kNetId));
    EXPECT_FALSE(hasPrivateDnsServer(ServerIdentity(server2), kNetId));
    EXPECT_FALSE(hasPrivateDnsServer(ServerIdentity(server1), kNetId + 1));

    ASSERT_TRUE(PollForCondition([&]() { return mObserver.runningThreads == 0; }));
}

// TODO: add ValidationFail_Strict test.

}  // namespace android::net