Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 3e1790b2 authored by Mike Yu's avatar Mike Yu Committed by Automerger Merge Worker
Browse files

Some renaming code for DoT in PrivateDnsConfiguration am: 2bee9337

parents 75a89452 2bee9337
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -229,7 +229,7 @@ DnsTlsTransport::Response DnsTlsDispatcher::query(const DnsTlsServer& server, un
            // happens, the xport will be marked as unusable and DoT queries won't be sent to
            // it anymore. Eventually, after IDLE_TIMEOUT, the xport will be destroyed, and
            // a new xport will be created.
            const auto result = PrivateDnsConfiguration::getInstance().requestValidation(
            const auto result = PrivateDnsConfiguration::getInstance().requestDotValidation(
                    netId, PrivateDnsConfiguration::ServerIdentity{server}, mark);
            LOG(WARNING) << "Requested validation for " << server.toIpString() << " with mark 0x"
                         << std::hex << mark << ", "
+28 −29
Original line number Diff line number Diff line
@@ -96,7 +96,7 @@ int PrivateDnsConfiguration::setDot(int32_t netId, uint32_t mark,
                                    const std::vector<std::string>& servers,
                                    const std::string& name, const std::string& caCert) {
    // Parse the list of servers that has been passed in
    PrivateDnsTracker tmp;
    std::map<ServerIdentity, DnsTlsServer> tmp;
    for (const auto& s : servers) {
        // The IP addresses are guaranteed to be valid.
        DnsTlsServer server(IPAddress::forString(s));
@@ -107,7 +107,7 @@ int PrivateDnsConfiguration::setDot(int32_t netId, uint32_t mark,
    }

    // Create the tracker if it was not present
    auto& tracker = mPrivateDnsTransports[netId];
    auto& tracker = mDotTracker[netId];

    // Add the servers if not contained in tracker.
    for (const auto& [identity, server] : tmp) {
@@ -127,7 +127,7 @@ int PrivateDnsConfiguration::setDot(int32_t netId, uint32_t mark,

        if (needsValidation(server)) {
            updateServerState(identity, Validation::in_process, netId);
            startValidation(identity, netId, false);
            startDotValidation(identity, netId, false);
        }
    }

@@ -135,7 +135,7 @@ int PrivateDnsConfiguration::setDot(int32_t netId, uint32_t mark,
}

void PrivateDnsConfiguration::clearDot(int32_t netId) {
    mPrivateDnsTransports.erase(netId);
    mDotTracker.erase(netId);
    resolv_stats_set_addrs(netId, PROTO_DOT, {}, kDotPort);
}

@@ -151,8 +151,8 @@ PrivateDnsStatus PrivateDnsConfiguration::getStatus(unsigned netId) const {
    if (mode == mPrivateDnsModes.end()) return status;
    status.mode = mode->second;

    const auto netPair = mPrivateDnsTransports.find(netId);
    if (netPair != mPrivateDnsTransports.end()) {
    const auto netPair = mDotTracker.find(netId);
    if (netPair != mDotTracker.end()) {
        for (const auto& [_, server] : netPair->second) {
            if (server.active()) {
                status.dotServersMap.emplace(server, server.validationState());
@@ -181,7 +181,7 @@ void PrivateDnsConfiguration::clear(unsigned netId) {
    mCv.notify_all();
}

base::Result<void> PrivateDnsConfiguration::requestValidation(unsigned netId,
base::Result<void> PrivateDnsConfiguration::requestDotValidation(unsigned netId,
                                                                 const ServerIdentity& identity,
                                                                 uint32_t mark) {
    std::lock_guard guard(mPrivateDnsLock);
@@ -196,7 +196,7 @@ base::Result<void> PrivateDnsConfiguration::requestValidation(unsigned netId,
        return Errorf("Private DNS setting is not opportunistic mode");
    }

    auto result = getPrivateDnsLocked(identity, netId);
    auto result = getDotServerLocked(identity, netId);
    if (!result.ok()) {
        return result.error();
    }
@@ -215,16 +215,16 @@ base::Result<void> PrivateDnsConfiguration::requestValidation(unsigned netId,
    if (target->validationMark() != mark) return Errorf("Socket mark mismatched");

    updateServerState(identity, Validation::in_process, netId);
    startValidation(identity, netId, true);
    startDotValidation(identity, netId, true);
    return {};
}

void PrivateDnsConfiguration::startValidation(const ServerIdentity& identity, unsigned netId,
void PrivateDnsConfiguration::startDotValidation(const ServerIdentity& identity, unsigned netId,
                                                 bool isRevalidation) {
    // This ensures that the thread sends probe at least once in case
    // the server is removed before the thread starts running.
    // TODO: consider moving these code to the thread.
    const auto result = getPrivateDnsLocked(identity, netId);
    const auto result = getDotServerLocked(identity, netId);
    if (!result.ok()) return;
    DnsTlsServer server = *result.value();

@@ -256,7 +256,7 @@ void PrivateDnsConfiguration::startValidation(const ServerIdentity& identity, un
                         << server.toIpString();

            const bool needs_reeval =
                    this->recordPrivateDnsValidation(identity, netId, success, isRevalidation);
                    this->recordDotValidation(identity, netId, success, isRevalidation);

            if (!needs_reeval || !backoff.hasNextTimeout()) {
                break;
@@ -309,16 +309,15 @@ void PrivateDnsConfiguration::sendPrivateDnsValidationEvent(const ServerIdentity
    }
}

bool PrivateDnsConfiguration::recordPrivateDnsValidation(const ServerIdentity& identity,
                                                         unsigned netId, bool success,
                                                         bool isRevalidation) {
bool PrivateDnsConfiguration::recordDotValidation(const ServerIdentity& identity, unsigned netId,
                                                  bool success, bool isRevalidation) {
    constexpr bool NEEDS_REEVALUATION = true;
    constexpr bool DONT_REEVALUATE = false;

    std::lock_guard guard(mPrivateDnsLock);

    auto netPair = mPrivateDnsTransports.find(netId);
    if (netPair == mPrivateDnsTransports.end()) {
    auto netPair = mDotTracker.find(netId);
    if (netPair == mDotTracker.end()) {
        LOG(WARNING) << "netId " << netId << " was erased during private DNS validation";
        notifyValidationStateUpdate(identity.sockaddr, Validation::fail, netId);
        return DONT_REEVALUATE;
@@ -373,7 +372,7 @@ bool PrivateDnsConfiguration::recordPrivateDnsValidation(const ServerIdentity& i

void PrivateDnsConfiguration::updateServerState(const ServerIdentity& identity, Validation state,
                                                uint32_t netId) {
    const auto result = getPrivateDnsLocked(identity, netId);
    const auto result = getDotServerLocked(identity, netId);
    if (!result.ok()) {
        notifyValidationStateUpdate(identity.sockaddr, Validation::fail, netId);
        return;
@@ -404,16 +403,16 @@ bool PrivateDnsConfiguration::needsValidation(const DnsTlsServer& server) const
    return false;
}

base::Result<DnsTlsServer*> PrivateDnsConfiguration::getPrivateDns(const ServerIdentity& identity,
base::Result<DnsTlsServer*> PrivateDnsConfiguration::getDotServer(const ServerIdentity& identity,
                                                                  unsigned netId) {
    std::lock_guard guard(mPrivateDnsLock);
    return getPrivateDnsLocked(identity, netId);
    return getDotServerLocked(identity, netId);
}

base::Result<DnsTlsServer*> PrivateDnsConfiguration::getPrivateDnsLocked(
base::Result<DnsTlsServer*> PrivateDnsConfiguration::getDotServerLocked(
        const ServerIdentity& identity, unsigned netId) {
    auto netPair = mPrivateDnsTransports.find(netId);
    if (netPair == mPrivateDnsTransports.end()) {
    auto netPair = mDotTracker.find(netId);
    if (netPair == mDotTracker.end()) {
        return Errorf("Failed to get private DNS: netId {} not found", netId);
    }

@@ -605,8 +604,8 @@ bool PrivateDnsConfiguration::needReportEvent(uint32_t netId, ServerIdentity ide
    switch (identity.sockaddr.port()) {
        // DoH
        case kDohPort: {
            auto netPair = mPrivateDnsTransports.find(netId);
            if (netPair == mPrivateDnsTransports.end()) return true;
            auto netPair = mDotTracker.find(netId);
            if (netPair == mDotTracker.end()) return true;
            for (const auto& [id, server] : netPair->second) {
                if ((identity.sockaddr.ip() == id.sockaddr.ip()) &&
                    (identity.sockaddr.port() != id.sockaddr.port()) &&
+15 −16
Original line number Diff line number Diff line
@@ -113,7 +113,7 @@ class PrivateDnsConfiguration {

    // Request the server to be revalidated on a connection tagged with |mark|.
    // Returns a Result to indicate if the request is accepted.
    base::Result<void> requestValidation(unsigned netId, const ServerIdentity& identity,
    base::Result<void> requestDotValidation(unsigned netId, const ServerIdentity& identity,
                                            uint32_t mark) EXCLUDES(mPrivateDnsLock);

    void setObserver(PrivateDnsValidationObserver* observer);
@@ -127,8 +127,6 @@ class PrivateDnsConfiguration {
            EXCLUDES(mPrivateDnsLock);

  private:
    typedef std::map<ServerIdentity, DnsTlsServer> PrivateDnsTracker;

    PrivateDnsConfiguration() = default;

    int setDot(int32_t netId, uint32_t mark, const std::vector<std::string>& servers,
@@ -136,12 +134,19 @@ class PrivateDnsConfiguration {

    void clearDot(int32_t netId) REQUIRES(mPrivateDnsLock);

    // Launchs a thread to run the validation for |server| on the network |netId|.
    // For testing.
    base::Result<DnsTlsServer*> getDotServer(const ServerIdentity& identity, unsigned netId)
            EXCLUDES(mPrivateDnsLock);

    base::Result<DnsTlsServer*> getDotServerLocked(const ServerIdentity& identity, unsigned netId)
            REQUIRES(mPrivateDnsLock);

    // Launchs a thread to run the validation for the DoT server |server| on the network |netId|.
    // |isRevalidation| is true if this call is due to a revalidation request.
    void startValidation(const ServerIdentity& identity, unsigned netId, bool isRevalidation)
    void startDotValidation(const ServerIdentity& identity, unsigned netId, bool isRevalidation)
            REQUIRES(mPrivateDnsLock);

    bool recordPrivateDnsValidation(const ServerIdentity& identity, unsigned netId, bool success,
    bool recordDotValidation(const ServerIdentity& identity, unsigned netId, bool success,
                             bool isRevalidation) EXCLUDES(mPrivateDnsLock);

    void sendPrivateDnsValidationEvent(const ServerIdentity& identity, unsigned netId,
@@ -155,13 +160,6 @@ class PrivateDnsConfiguration {
    void updateServerState(const ServerIdentity& identity, Validation state, uint32_t netId)
            REQUIRES(mPrivateDnsLock);

    // For testing.
    base::Result<DnsTlsServer*> getPrivateDns(const ServerIdentity& identity, unsigned netId)
            EXCLUDES(mPrivateDnsLock);

    base::Result<DnsTlsServer*> getPrivateDnsLocked(const ServerIdentity& identity, unsigned netId)
            REQUIRES(mPrivateDnsLock);

    void initDohLocked() REQUIRES(mPrivateDnsLock);
    int setDoh(int32_t netId, uint32_t mark, const std::vector<std::string>& servers,
               const std::string& name, const std::string& caCert) REQUIRES(mPrivateDnsLock);
@@ -174,7 +172,8 @@ class PrivateDnsConfiguration {
    // In case a server is removed due to a configuration change, it remains in this map,
    // but is marked inactive.
    // Any pending validation threads will continue running because we have no way to cancel them.
    std::map<unsigned, PrivateDnsTracker> mPrivateDnsTransports GUARDED_BY(mPrivateDnsLock);
    std::map<unsigned, std::map<ServerIdentity, DnsTlsServer>> mDotTracker
            GUARDED_BY(mPrivateDnsLock);

    void notifyValidationStateUpdate(const netdutils::IPSockAddr& sockaddr, Validation validation,
                                     uint32_t netId) const REQUIRES(mPrivateDnsLock);
+8 −8
Original line number Diff line number Diff line
@@ -124,7 +124,7 @@ class PrivateDnsConfigurationTest : public NetNativeTestBase {
    }

    bool hasPrivateDnsServer(const ServerIdentity& identity, unsigned netId) {
        return mPdc.getPrivateDns(identity, netId).ok();
        return mPdc.getDotServer(identity, netId).ok();
    }

    static constexpr uint32_t kNetId = 30;
@@ -198,7 +198,7 @@ TEST_F(PrivateDnsConfigurationTest, Revalidation_Opportunistic) {
        backend.startServer();
    });
    backend.stopServer();
    EXPECT_TRUE(mPdc.requestValidation(kNetId, ServerIdentity(server), kMark).ok());
    EXPECT_TRUE(mPdc.requestDotValidation(kNetId, ServerIdentity(server), kMark).ok());

    t.join();
    expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC);
@@ -343,18 +343,18 @@ TEST_F(PrivateDnsConfigurationTest, RequestValidation) {
            EXPECT_CALL(mObserver,
                        onValidationStateUpdate(kServer1, Validation::in_process, kNetId));
            EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::success, kNetId));
            EXPECT_TRUE(mPdc.requestValidation(kNetId, identity, kMark).ok());
            EXPECT_TRUE(mPdc.requestDotValidation(kNetId, identity, kMark).ok());
        } else if (config == "IN_PROGRESS") {
            EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::success, kNetId));
            EXPECT_FALSE(mPdc.requestValidation(kNetId, identity, kMark).ok());
            EXPECT_FALSE(mPdc.requestDotValidation(kNetId, identity, kMark).ok());
        } else if (config == "FAIL") {
            EXPECT_FALSE(mPdc.requestValidation(kNetId, identity, kMark).ok());
            EXPECT_FALSE(mPdc.requestDotValidation(kNetId, identity, kMark).ok());
        }

        // Resending the same request or requesting nonexistent servers are denied.
        EXPECT_FALSE(mPdc.requestValidation(kNetId, identity, kMark).ok());
        EXPECT_FALSE(mPdc.requestValidation(kNetId, identity, kMark + 1).ok());
        EXPECT_FALSE(mPdc.requestValidation(kNetId + 1, identity, kMark).ok());
        EXPECT_FALSE(mPdc.requestDotValidation(kNetId, identity, kMark).ok());
        EXPECT_FALSE(mPdc.requestDotValidation(kNetId, identity, kMark + 1).ok());
        EXPECT_FALSE(mPdc.requestDotValidation(kNetId + 1, identity, kMark).ok());

        // Reset the test state.
        backend.setDeferredResp(false);