Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit faf6f3b0 authored by Mike Yu's avatar Mike Yu Committed by Automerger Merge Worker
Browse files

Change most of PrivateDnsConfiguration methods to use ServerIdentity am:...

Change most of PrivateDnsConfiguration methods to use ServerIdentity am: ad96ef83 am: 3e6ad9c3 am: 29cb3923

Original change: https://android-review.googlesource.com/c/platform/packages/modules/DnsResolver/+/1600840

Change-Id: I28b485882a859caeba5b7a49f8c3bba53950d62b
parents 16dbcd4c 29cb3923
Loading
Loading
Loading
Loading
+2 −2
Original line number Diff line number Diff line
@@ -214,8 +214,8 @@ DnsTlsTransport::Response DnsTlsDispatcher::query(const DnsTlsServer& server, un
            // happens, the xport will be marked as unusable and DoT queries won't be sent to
            // it anymore. Eventually, after IDLE_TIMEOUT, the xport will be destroyed, and
            // a new xport will be created.
            const auto result =
                    PrivateDnsConfiguration::getInstance().requestValidation(netId, server, mark);
            const auto result = PrivateDnsConfiguration::getInstance().requestValidation(
                    netId, PrivateDnsConfiguration::ServerIdentity{server}, mark);
            LOG(WARNING) << "Requested validation for " << server.toIpString() << " with mark 0x"
                         << std::hex << mark << ", "
                         << (result.ok() ? "succeeded" : "failed: " + result.error().message());
+61 −47
Original line number Diff line number Diff line
@@ -110,7 +110,7 @@ int PrivateDnsConfiguration::set(int32_t netId, uint32_t mark,

        if (needsValidation(server)) {
            updateServerState(identity, Validation::in_process, netId);
            startValidation(server, netId, false);
            startValidation(identity, netId, false);
        }
    }

@@ -145,7 +145,7 @@ void PrivateDnsConfiguration::clear(unsigned netId) {
}

base::Result<void> PrivateDnsConfiguration::requestValidation(unsigned netId,
                                                              const DnsTlsServer& server,
                                                              const ServerIdentity& identity,
                                                              uint32_t mark) {
    std::lock_guard guard(mPrivateDnsLock);

@@ -159,40 +159,39 @@ base::Result<void> PrivateDnsConfiguration::requestValidation(unsigned netId,
        return Errorf("Private DNS setting is not opportunistic mode");
    }

    auto netPair = mPrivateDnsTransports.find(netId);
    if (netPair == mPrivateDnsTransports.end()) {
        return Errorf("NetId not found in mPrivateDnsTransports");
    auto result = getPrivateDnsLocked(identity, netId);
    if (!result.ok()) {
        return result.error();
    }

    auto& tracker = netPair->second;
    const ServerIdentity identity = ServerIdentity(server);
    auto it = tracker.find(identity);
    if (it == tracker.end()) {
        return Errorf("Server was removed");
    }
    const DnsTlsServer* target = result.value();

    const DnsTlsServer& target = it->second;
    if (!target->active()) return Errorf("Server is not active");

    if (!target.active()) return Errorf("Server is not active");

    if (target.validationState() != Validation::success) {
    if (target->validationState() != Validation::success) {
        return Errorf("Server validation state mismatched");
    }

    // Don't run the validation if |mark| (from android_net_context.dns_mark) is different.
    // This is to protect validation from running on unexpected marks.
    // Validation should be associated with a mark gotten by system permission.
    if (target.mark != mark) return Errorf("Socket mark mismatched");
    if (target->mark != mark) return Errorf("Socket mark mismatched");

    updateServerState(identity, Validation::in_process, netId);
    startValidation(target, netId, true);
    startValidation(identity, netId, true);
    return {};
}

void PrivateDnsConfiguration::startValidation(const DnsTlsServer& server, unsigned netId,
                                              bool isRevalidation) REQUIRES(mPrivateDnsLock) {
    // Note that capturing |server|, |netId|, and |isRevalidation| in this lambda create copies.
    std::thread validate_thread([this, server, netId, isRevalidation] {
void PrivateDnsConfiguration::startValidation(const ServerIdentity& identity, unsigned netId,
                                              bool isRevalidation) {
    // This ensures that the thread sends probe at least once in case
    // the server is removed before the thread starts running.
    // TODO: consider moving these code to the thread.
    const auto result = getPrivateDnsLocked(identity, netId);
    if (!result.ok()) return;
    DnsTlsServer server = *result.value();

    std::thread validate_thread([this, identity, server, netId, isRevalidation] {
        setThreadName(StringPrintf("TlsVerify_%u", netId).c_str());

        // cat /proc/sys/net/ipv4/tcp_syn_retries yields "6".
@@ -223,7 +222,7 @@ void PrivateDnsConfiguration::startValidation(const DnsTlsServer& server, unsign
                         << server.toIpString();

            const bool needs_reeval =
                    this->recordPrivateDnsValidation(server, netId, success, isRevalidation);
                    this->recordPrivateDnsValidation(identity, netId, success, isRevalidation);

            if (!needs_reeval) {
                break;
@@ -240,11 +239,11 @@ void PrivateDnsConfiguration::startValidation(const DnsTlsServer& server, unsign
    validate_thread.detach();
}

void PrivateDnsConfiguration::sendPrivateDnsValidationEvent(const DnsTlsServer& server,
void PrivateDnsConfiguration::sendPrivateDnsValidationEvent(const ServerIdentity& identity,
                                                            unsigned netId, bool success) {
    LOG(DEBUG) << "Sending validation " << (success ? "success" : "failure") << " event on netId "
               << netId << " for " << server.toIpString() << " with hostname {" << server.name
               << "}";
               << netId << " for " << identity.sockaddr.ip().toString() << " with hostname {"
               << identity.provider << "}";
    // Send a validation event to NetdEventListenerService.
    const auto& listeners = ResolverEventReporter::getInstance().getListeners();
    if (listeners.empty()) {
@@ -252,15 +251,16 @@ void PrivateDnsConfiguration::sendPrivateDnsValidationEvent(const DnsTlsServer&
                << "Validation event not sent since no INetdEventListener receiver is available.";
    }
    for (const auto& it : listeners) {
        it->onPrivateDnsValidationEvent(netId, server.toIpString(), server.name, success);
        it->onPrivateDnsValidationEvent(netId, identity.sockaddr.ip().toString(), identity.provider,
                                        success);
    }

    // Send a validation event to unsolicited event listeners.
    const auto& unsolEventListeners = ResolverEventReporter::getInstance().getUnsolEventListeners();
    const PrivateDnsValidationEventParcel validationEvent = {
            .netId = static_cast<int32_t>(netId),
            .ipAddress = server.toIpString(),
            .hostname = server.name,
            .ipAddress = identity.sockaddr.ip().toString(),
            .hostname = identity.provider,
            .validation = success ? IDnsResolverUnsolicitedEventListener::VALIDATION_RESULT_SUCCESS
                                  : IDnsResolverUnsolicitedEventListener::VALIDATION_RESULT_FAILURE,
    };
@@ -269,11 +269,11 @@ void PrivateDnsConfiguration::sendPrivateDnsValidationEvent(const DnsTlsServer&
    }
}

bool PrivateDnsConfiguration::recordPrivateDnsValidation(const DnsTlsServer& server, unsigned netId,
                                                         bool success, bool isRevalidation) {
bool PrivateDnsConfiguration::recordPrivateDnsValidation(const ServerIdentity& identity,
                                                         unsigned netId, bool success,
                                                         bool isRevalidation) {
    constexpr bool NEEDS_REEVALUATION = true;
    constexpr bool DONT_REEVALUATE = false;
    const ServerIdentity identity = ServerIdentity(server);

    std::lock_guard guard(mPrivateDnsLock);

@@ -303,23 +303,19 @@ bool PrivateDnsConfiguration::recordPrivateDnsValidation(const DnsTlsServer& ser
    auto& tracker = netPair->second;
    auto serverPair = tracker.find(identity);
    if (serverPair == tracker.end()) {
        LOG(WARNING) << "Server " << server.toIpString()
        LOG(WARNING) << "Server " << identity.sockaddr.ip().toString()
                     << " was removed during private DNS validation";
        success = false;
        reevaluationStatus = DONT_REEVALUATE;
    } else if (!(serverPair->second == server)) {
        LOG(WARNING) << "Server " << server.toIpString()
                     << " was changed during private DNS validation";
        success = false;
        reevaluationStatus = DONT_REEVALUATE;
    } else if (!serverPair->second.active()) {
        LOG(WARNING) << "Server " << server.toIpString() << " was removed from the configuration";
        LOG(WARNING) << "Server " << identity.sockaddr.ip().toString()
                     << " was removed from the configuration";
        success = false;
        reevaluationStatus = DONT_REEVALUATE;
    }

    // Send private dns validation result to listeners.
    sendPrivateDnsValidationEvent(server, netId, success);
    sendPrivateDnsValidationEvent(identity, netId, success);

    if (success) {
        updateServerState(identity, Validation::success, netId);
@@ -338,19 +334,15 @@ bool PrivateDnsConfiguration::recordPrivateDnsValidation(const DnsTlsServer& ser

void PrivateDnsConfiguration::updateServerState(const ServerIdentity& identity, Validation state,
                                                uint32_t netId) {
    auto netPair = mPrivateDnsTransports.find(netId);
    if (netPair == mPrivateDnsTransports.end()) {
    const auto result = getPrivateDnsLocked(identity, netId);
    if (!result.ok()) {
        notifyValidationStateUpdate(identity.sockaddr, Validation::fail, netId);
        return;
    }

    auto& tracker = netPair->second;
    if (tracker.find(identity) == tracker.end()) {
        notifyValidationStateUpdate(identity.sockaddr, Validation::fail, netId);
        return;
    }
    auto* server = result.value();

    tracker[identity].setValidationState(state);
    server->setValidationState(state);
    notifyValidationStateUpdate(identity.sockaddr, state, netId);

    RecordEntry record(netId, identity, state);
@@ -373,6 +365,28 @@ bool PrivateDnsConfiguration::needsValidation(const DnsTlsServer& server) {
    return false;
}

base::Result<DnsTlsServer*> PrivateDnsConfiguration::getPrivateDns(const ServerIdentity& identity,
                                                                   unsigned netId) {
    std::lock_guard guard(mPrivateDnsLock);
    return getPrivateDnsLocked(identity, netId);
}

base::Result<DnsTlsServer*> PrivateDnsConfiguration::getPrivateDnsLocked(
        const ServerIdentity& identity, unsigned netId) {
    auto netPair = mPrivateDnsTransports.find(netId);
    if (netPair == mPrivateDnsTransports.end()) {
        return Errorf("Failed to get private DNS: netId {} not found", netId);
    }

    auto iter = netPair->second.find(identity);
    if (iter == netPair->second.end()) {
        return Errorf("Failed to get private DNS: server {{{}/{}}} not found", identity.sockaddr,
                      identity.provider);
    }

    return &iter->second;
}

void PrivateDnsConfiguration::setObserver(PrivateDnsValidationObserver* observer) {
    std::lock_guard guard(mPrivateDnsLock);
    mObserver = observer;
+30 −21
Original line number Diff line number Diff line
@@ -33,6 +33,7 @@
namespace android {
namespace net {

// TODO: decouple the dependency of DnsTlsServer.
struct PrivateDnsStatus {
    PrivateDnsMode mode;

@@ -53,24 +54,6 @@ struct PrivateDnsStatus {

class PrivateDnsConfiguration {
  public:
    // The only instance of PrivateDnsConfiguration.
    static PrivateDnsConfiguration& getInstance() {
        static PrivateDnsConfiguration instance;
        return instance;
    }

    int set(int32_t netId, uint32_t mark, const std::vector<std::string>& servers,
            const std::string& name, const std::string& caCert) EXCLUDES(mPrivateDnsLock);

    PrivateDnsStatus getStatus(unsigned netId) const EXCLUDES(mPrivateDnsLock);

    void clear(unsigned netId) EXCLUDES(mPrivateDnsLock);

    // Request |server| to be revalidated on a connection tagged with |mark|.
    // Returns a Result to indicate if the request is accepted.
    base::Result<void> requestValidation(unsigned netId, const DnsTlsServer& server, uint32_t mark)
            EXCLUDES(mPrivateDnsLock);

    struct ServerIdentity {
        const netdutils::IPSockAddr sockaddr;
        const std::string provider;
@@ -86,6 +69,24 @@ class PrivateDnsConfiguration {
        }
    };

    // The only instance of PrivateDnsConfiguration.
    static PrivateDnsConfiguration& getInstance() {
        static PrivateDnsConfiguration instance;
        return instance;
    }

    int set(int32_t netId, uint32_t mark, const std::vector<std::string>& servers,
            const std::string& name, const std::string& caCert) EXCLUDES(mPrivateDnsLock);

    PrivateDnsStatus getStatus(unsigned netId) const EXCLUDES(mPrivateDnsLock);

    void clear(unsigned netId) EXCLUDES(mPrivateDnsLock);

    // Request the server to be revalidated on a connection tagged with |mark|.
    // Returns a Result to indicate if the request is accepted.
    base::Result<void> requestValidation(unsigned netId, const ServerIdentity& identity,
                                         uint32_t mark) EXCLUDES(mPrivateDnsLock);

    void setObserver(PrivateDnsValidationObserver* observer);

    void dump(netdutils::DumpWriter& dw) const;
@@ -97,23 +98,31 @@ class PrivateDnsConfiguration {

    // Launchs a thread to run the validation for |server| on the network |netId|.
    // |isRevalidation| is true if this call is due to a revalidation request.
    void startValidation(const DnsTlsServer& server, unsigned netId, bool isRevalidation)
    void startValidation(const ServerIdentity& identity, unsigned netId, bool isRevalidation)
            REQUIRES(mPrivateDnsLock);

    bool recordPrivateDnsValidation(const DnsTlsServer& server, unsigned netId, bool success,
    bool recordPrivateDnsValidation(const ServerIdentity& identity, unsigned netId, bool success,
                                    bool isRevalidation) EXCLUDES(mPrivateDnsLock);

    void sendPrivateDnsValidationEvent(const DnsTlsServer& server, unsigned netId, bool success)
    void sendPrivateDnsValidationEvent(const ServerIdentity& identity, unsigned netId, bool success)
            REQUIRES(mPrivateDnsLock);

    // Decide if a validation for |server| is needed. Note that servers that have failed
    // multiple validation attempts but for which there is still a validating
    // thread running are marked as being Validation::in_process.
    // TODO: decouple the dependency of DnsTlsServer.
    bool needsValidation(const DnsTlsServer& server) REQUIRES(mPrivateDnsLock);

    void updateServerState(const ServerIdentity& identity, Validation state, uint32_t netId)
            REQUIRES(mPrivateDnsLock);

    // For testing.
    base::Result<DnsTlsServer*> getPrivateDns(const ServerIdentity& identity, unsigned netId)
            EXCLUDES(mPrivateDnsLock);

    base::Result<DnsTlsServer*> getPrivateDnsLocked(const ServerIdentity& identity, unsigned netId)
            REQUIRES(mPrivateDnsLock);

    mutable std::mutex mPrivateDnsLock;
    std::map<unsigned, PrivateDnsMode> mPrivateDnsModes GUARDED_BY(mPrivateDnsLock);

+33 −8
Original line number Diff line number Diff line
@@ -28,6 +28,8 @@ using namespace std::chrono_literals;

class PrivateDnsConfigurationTest : public ::testing::Test {
  public:
    using ServerIdentity = PrivateDnsConfiguration::ServerIdentity;

    static void SetUpTestSuite() {
        // stopServer() will be called in their destructor.
        ASSERT_TRUE(tls1.startServer());
@@ -100,6 +102,10 @@ class PrivateDnsConfigurationTest : public ::testing::Test {
        return (serverStateMap == mObserver.getServerStateMap());
    }

    bool hasPrivateDnsServer(const ServerIdentity& identity, unsigned netId) {
        return mPdc.getPrivateDns(identity, netId).ok();
    }

    static constexpr uint32_t kNetId = 30;
    static constexpr uint32_t kMark = 30;
    static constexpr char kBackend[] = "127.0.2.1";
@@ -230,8 +236,6 @@ TEST_F(PrivateDnsConfigurationTest, NoValidation) {
}

TEST_F(PrivateDnsConfigurationTest, ServerIdentity_Comparison) {
    using ServerIdentity = PrivateDnsConfiguration::ServerIdentity;

    DnsTlsServer server(netdutils::IPSockAddr::toIPSockAddr("127.0.0.1", 853));
    server.name = "dns.example.com";

@@ -254,6 +258,7 @@ TEST_F(PrivateDnsConfigurationTest, ServerIdentity_Comparison) {

TEST_F(PrivateDnsConfigurationTest, RequestValidation) {
    const DnsTlsServer server(netdutils::IPSockAddr::toIPSockAddr(kServer1, 853));
    const ServerIdentity identity(server);

    testing::InSequence seq;

@@ -281,18 +286,18 @@ TEST_F(PrivateDnsConfigurationTest, RequestValidation) {
            EXPECT_CALL(mObserver,
                        onValidationStateUpdate(kServer1, Validation::in_process, kNetId));
            EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::success, kNetId));
            EXPECT_TRUE(mPdc.requestValidation(kNetId, server, kMark).ok());
            EXPECT_TRUE(mPdc.requestValidation(kNetId, identity, kMark).ok());
        } else if (config == "IN_PROGRESS") {
            EXPECT_CALL(mObserver, onValidationStateUpdate(kServer1, Validation::success, kNetId));
            EXPECT_FALSE(mPdc.requestValidation(kNetId, server, kMark).ok());
            EXPECT_FALSE(mPdc.requestValidation(kNetId, identity, kMark).ok());
        } else if (config == "FAIL") {
            EXPECT_FALSE(mPdc.requestValidation(kNetId, server, kMark).ok());
            EXPECT_FALSE(mPdc.requestValidation(kNetId, identity, kMark).ok());
        }

        // Resending the same request or requesting nonexistent servers are denied.
        EXPECT_FALSE(mPdc.requestValidation(kNetId, server, kMark).ok());
        EXPECT_FALSE(mPdc.requestValidation(kNetId, server, kMark + 1).ok());
        EXPECT_FALSE(mPdc.requestValidation(kNetId + 1, server, kMark).ok());
        EXPECT_FALSE(mPdc.requestValidation(kNetId, identity, kMark).ok());
        EXPECT_FALSE(mPdc.requestValidation(kNetId, identity, kMark + 1).ok());
        EXPECT_FALSE(mPdc.requestValidation(kNetId + 1, identity, kMark).ok());

        // Reset the test state.
        backend.setDeferredResp(false);
@@ -306,6 +311,26 @@ TEST_F(PrivateDnsConfigurationTest, RequestValidation) {
    }
}

TEST_F(PrivateDnsConfigurationTest, GetPrivateDns) {
    const DnsTlsServer server1(netdutils::IPSockAddr::toIPSockAddr(kServer1, 853));
    const DnsTlsServer server2(netdutils::IPSockAddr::toIPSockAddr(kServer2, 853));

    EXPECT_FALSE(hasPrivateDnsServer(ServerIdentity(server1), kNetId));
    EXPECT_FALSE(hasPrivateDnsServer(ServerIdentity(server2), kNetId));

    // Suppress the warning.
    EXPECT_CALL(mObserver, onValidationStateUpdate).Times(2);

    EXPECT_EQ(mPdc.set(kNetId, kMark, {kServer1}, {}, {}), 0);
    expectPrivateDnsStatus(PrivateDnsMode::OPPORTUNISTIC);

    EXPECT_TRUE(hasPrivateDnsServer(ServerIdentity(server1), kNetId));
    EXPECT_FALSE(hasPrivateDnsServer(ServerIdentity(server2), kNetId));
    EXPECT_FALSE(hasPrivateDnsServer(ServerIdentity(server1), kNetId + 1));

    ASSERT_TRUE(PollForCondition([&]() { return mObserver.runningThreads == 0; }));
}

// TODO: add ValidationFail_Strict test.

}  // namespace android::net