Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 453b5e4e authored by Mike Yu's avatar Mike Yu
Browse files

Change ServerIdentity to store IPSockAddr instead of IPAddress

The parameter protocol was added to ServerIdentity in order to
identify an unique private DNS server. This is unneeded after
this change since we can check the socket port to achieve the
same purpose.

Bug: 186177613
Test: cd packages/modules/DnsResolver && atest
Change-Id: Ifdad55ea817ff568a84271a5d7b77dd4fbe4772c
parent 08b2f2b7
Loading
Loading
Loading
Loading
+11 −11
Original line number Diff line number Diff line
@@ -280,14 +280,14 @@ bool PrivateDnsConfiguration::recordPrivateDnsValidation(const DnsTlsServer& ser
    auto netPair = mPrivateDnsTransports.find(netId);
    if (netPair == mPrivateDnsTransports.end()) {
        LOG(WARNING) << "netId " << netId << " was erased during private DNS validation";
        notifyValidationStateUpdate(identity.ip.toString(), Validation::fail, netId);
        notifyValidationStateUpdate(identity.sockaddr, Validation::fail, netId);
        return DONT_REEVALUATE;
    }

    const auto mode = mPrivateDnsModes.find(netId);
    if (mode == mPrivateDnsModes.end()) {
        LOG(WARNING) << "netId " << netId << " has no private DNS validation mode";
        notifyValidationStateUpdate(identity.ip.toString(), Validation::fail, netId);
        notifyValidationStateUpdate(identity.sockaddr, Validation::fail, netId);
        return DONT_REEVALUATE;
    }

@@ -342,18 +342,18 @@ void PrivateDnsConfiguration::updateServerState(const ServerIdentity& identity,
                                                uint32_t netId) {
    auto netPair = mPrivateDnsTransports.find(netId);
    if (netPair == mPrivateDnsTransports.end()) {
        notifyValidationStateUpdate(identity.ip.toString(), Validation::fail, netId);
        notifyValidationStateUpdate(identity.sockaddr, Validation::fail, netId);
        return;
    }

    auto& tracker = netPair->second;
    if (tracker.find(identity) == tracker.end()) {
        notifyValidationStateUpdate(identity.ip.toString(), Validation::fail, netId);
        notifyValidationStateUpdate(identity.sockaddr, Validation::fail, netId);
        return;
    }

    tracker[identity].setValidationState(state);
    notifyValidationStateUpdate(identity.ip.toString(), state, netId);
    notifyValidationStateUpdate(identity.sockaddr, state, netId);

    RecordEntry record(netId, identity, state);
    mPrivateDnsLog.push(std::move(record));
@@ -380,11 +380,11 @@ void PrivateDnsConfiguration::setObserver(PrivateDnsValidationObserver* observer
    mObserver = observer;
}

void PrivateDnsConfiguration::notifyValidationStateUpdate(const std::string& serverIp,
void PrivateDnsConfiguration::notifyValidationStateUpdate(const netdutils::IPSockAddr& sockaddr,
                                                          Validation validation,
                                                          uint32_t netId) const {
    if (mObserver) {
        mObserver->onValidationStateUpdate(serverIp, validation, netId);
        mObserver->onValidationStateUpdate(sockaddr.ip().toString(), validation, netId);
    }
}

@@ -393,10 +393,10 @@ void PrivateDnsConfiguration::dump(netdutils::DumpWriter& dw) const {
    netdutils::ScopedIndent indentStats(dw);

    for (const auto& record : mPrivateDnsLog.copy()) {
        dw.println(fmt::format("{} - netId={} PrivateDns={{{}/{}}} state={}",
                               timestampToString(record.timestamp), record.netId,
                               record.serverIdentity.ip.toString(), record.serverIdentity.name,
                               validationStatusToString(record.state)));
        dw.println(fmt::format(
                "{} - netId={} PrivateDns={{{}/{}}} state={}", timestampToString(record.timestamp),
                record.netId, record.serverIdentity.sockaddr.toString(),
                record.serverIdentity.provider, validationStatusToString(record.state)));
    }
    dw.blankline();
}
+6 −9
Original line number Diff line number Diff line
@@ -72,20 +72,17 @@ class PrivateDnsConfiguration {
            EXCLUDES(mPrivateDnsLock);

    struct ServerIdentity {
        const netdutils::IPAddress ip;
        const std::string name;
        const int protocol;
        const netdutils::IPSockAddr sockaddr;
        const std::string provider;

        explicit ServerIdentity(const DnsTlsServer& server)
            : ip(netdutils::IPSockAddr::toIPSockAddr(server.ss).ip()),
              name(server.name),
              protocol(server.protocol) {}
            : sockaddr(netdutils::IPSockAddr::toIPSockAddr(server.ss)), provider(server.name) {}

        bool operator<(const ServerIdentity& other) const {
            return std::tie(ip, name, protocol) < std::tie(other.ip, other.name, other.protocol);
            return std::tie(sockaddr, provider) < std::tie(other.sockaddr, other.provider);
        }
        bool operator==(const ServerIdentity& other) const {
            return std::tie(ip, name, protocol) == std::tie(other.ip, other.name, other.protocol);
            return std::tie(sockaddr, provider) == std::tie(other.sockaddr, other.provider);
        }
    };

@@ -127,7 +124,7 @@ class PrivateDnsConfiguration {
    // Any pending validation threads will continue running because we have no way to cancel them.
    std::map<unsigned, PrivateDnsTracker> mPrivateDnsTransports GUARDED_BY(mPrivateDnsLock);

    void notifyValidationStateUpdate(const std::string& serverIp, Validation validation,
    void notifyValidationStateUpdate(const netdutils::IPSockAddr& sockaddr, Validation validation,
                                     uint32_t netId) const REQUIRES(mPrivateDnsLock);

    // TODO: fix the reentrancy problem.
+2 −9
Original line number Diff line number Diff line
@@ -234,13 +234,12 @@ TEST_F(PrivateDnsConfigurationTest, ServerIdentity_Comparison) {

    DnsTlsServer server(netdutils::IPSockAddr::toIPSockAddr("127.0.0.1", 853));
    server.name = "dns.example.com";
    server.protocol = 1;

    // Different IP address (port is ignored).
    // Different socket address.
    DnsTlsServer other = server;
    EXPECT_EQ(ServerIdentity(server), ServerIdentity(other));
    other.ss = netdutils::IPSockAddr::toIPSockAddr("127.0.0.1", 5353);
    EXPECT_EQ(ServerIdentity(server), ServerIdentity(other));
    EXPECT_NE(ServerIdentity(server), ServerIdentity(other));
    other.ss = netdutils::IPSockAddr::toIPSockAddr("127.0.0.2", 853);
    EXPECT_NE(ServerIdentity(server), ServerIdentity(other));

@@ -251,12 +250,6 @@ TEST_F(PrivateDnsConfigurationTest, ServerIdentity_Comparison) {
    EXPECT_NE(ServerIdentity(server), ServerIdentity(other));
    other.name = "";
    EXPECT_NE(ServerIdentity(server), ServerIdentity(other));

    // Different protocol.
    other = server;
    EXPECT_EQ(ServerIdentity(server), ServerIdentity(other));
    other.protocol++;
    EXPECT_NE(ServerIdentity(server), ServerIdentity(other));
}

TEST_F(PrivateDnsConfigurationTest, RequestValidation) {