Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit e2af226b authored by Bernie Innocenti's avatar Bernie Innocenti
Browse files

Fix refcounting on thread creation failure path for android_res_nsend()

SocketClient instances are manually reference-counted. DnsProxyListener
passes them down to handler threads, and expects them to decrement the
reference coint when done.

On thread creation failure, tryThreadOrError() deletes the handler and
decrements the reference:

void tryThreadOrError(SocketClient* cli, T* handler) {
    const int rval = netdutils::threadLaunch(handler);
    if (rval == 0) {
        // SocketClient decRef() happens in the handler's run() method.
        return;
    }
    ...
    delete handler;  // Calls decRef() in ~ResNSendHandler!!!
    cli->decRef();
}

However, the assumption that decRef() is done in the handler's run()
method was not true in all cases: ResNSendHandler, added in Android Q,
actually decrements the client's reference count in its destructor.

Since tryThreadOrError() decrements the counter twice on the error path,
the SocketClient self-deletes too early, while it is still referenced by
the SocketListener.

Triggering this condition requires an unresponsive network and multiple
apps sending queries, until netd runs out of threads (each app is
limited to 256 concurrent queries by queryLimiter). At least one app
should be using android_res_nsend().

The fix consists in moving ownership of SocketClient to a base class of
all handlers, with strong encapsulation of the manual reference
counting, RAII-style. This simplifies DnsProxyListener while making
further refcounting bugs unlikely to occur in the future.

DnsProxyListener remains dangerously open to memory leaks in error paths
due to manual memory management. We should take care of this in a
followup change.

Bug: 169105756
Change-Id: I9241a094653651e1bdda79eb753e7a53e5e51d8f
parent 0c7b1ca2
Loading
Loading
Loading
Loading
+21 −42
Original line number Original line Diff line number Diff line
@@ -84,25 +84,6 @@ void logArguments(int argc, char** argv) {
    }
    }
}
}


template<typename T>
void tryThreadOrError(SocketClient* cli, T* handler) {
    cli->incRef();

    const int rval = netdutils::threadLaunch(handler);
    if (rval == 0) {
        // SocketClient decRef() happens in the handler's run() method.
        return;
    }

    char* msg = nullptr;
    asprintf(&msg, "%s (%d)", strerror(-rval), -rval);
    cli->sendMsg(ResponseCode::OperationFailed, msg, false);
    free(msg);

    delete handler;
    cli->decRef();
}

bool checkAndClearUseLocalNameserversFlag(unsigned* netid) {
bool checkAndClearUseLocalNameserversFlag(unsigned* netid) {
    if (netid == nullptr || ((*netid) & NETID_USE_LOCAL_NAMESERVERS) == 0) {
    if (netid == nullptr || ((*netid) & NETID_USE_LOCAL_NAMESERVERS) == 0) {
        return false;
        return false;
@@ -563,10 +544,23 @@ DnsProxyListener::DnsProxyListener() : FrameworkListener(SOCKET_NAME) {
    registerCmd(new GetDnsNetIdCommand());
    registerCmd(new GetDnsNetIdCommand());
}
}


void DnsProxyListener::Handler::spawn() {
    const int rval = netdutils::threadLaunch(this);
    if (rval == 0) {
        return;
    }

    char* msg = nullptr;
    asprintf(&msg, "%s (%d)", strerror(-rval), -rval);
    mClient->sendMsg(ResponseCode::OperationFailed, msg, false);
    free(msg);
    delete this;
}

DnsProxyListener::GetAddrInfoHandler::GetAddrInfoHandler(SocketClient* c, char* host, char* service,
DnsProxyListener::GetAddrInfoHandler::GetAddrInfoHandler(SocketClient* c, char* host, char* service,
                                                         addrinfo* hints,
                                                         addrinfo* hints,
                                                         const android_net_context& netcontext)
                                                         const android_net_context& netcontext)
    : mClient(c), mHost(host), mService(service), mHints(hints), mNetContext(netcontext) {}
    : Handler(c), mHost(host), mService(service), mHints(hints), mNetContext(netcontext) {}


DnsProxyListener::GetAddrInfoHandler::~GetAddrInfoHandler() {
DnsProxyListener::GetAddrInfoHandler::~GetAddrInfoHandler() {
    free(mHost);
    free(mHost);
@@ -763,7 +757,6 @@ void DnsProxyListener::GetAddrInfoHandler::run() {
    reportDnsEvent(INetdEventListener::EVENT_GETADDRINFO, mNetContext, latencyUs, rv, event, mHost,
    reportDnsEvent(INetdEventListener::EVENT_GETADDRINFO, mNetContext, latencyUs, rv, event, mHost,
                   ip_addrs, total_ip_addr_count);
                   ip_addrs, total_ip_addr_count);
    freeaddrinfo(result);
    freeaddrinfo(result);
    mClient->decRef();
}
}


std::string DnsProxyListener::GetAddrInfoHandler::threadName() {
std::string DnsProxyListener::GetAddrInfoHandler::threadName() {
@@ -841,9 +834,7 @@ int DnsProxyListener::GetAddrInfoCmd::runCommand(SocketClient *cli,
        hints->ai_protocol = ai_protocol;
        hints->ai_protocol = ai_protocol;
    }
    }


    DnsProxyListener::GetAddrInfoHandler* handler =
    (new GetAddrInfoHandler(cli, name, service, hints, netcontext))->spawn();
            new DnsProxyListener::GetAddrInfoHandler(cli, name, service, hints, netcontext);
    tryThreadOrError(cli, handler);
    return 0;
    return 0;
}
}


@@ -888,19 +879,13 @@ int DnsProxyListener::ResNSendCommand::runCommand(SocketClient* cli, int argc, c
        netcontext.flags |= NET_CONTEXT_FLAG_USE_LOCAL_NAMESERVERS;
        netcontext.flags |= NET_CONTEXT_FLAG_USE_LOCAL_NAMESERVERS;
    }
    }


    DnsProxyListener::ResNSendHandler* handler =
    (new ResNSendHandler(cli, argv[3], flags, netcontext))->spawn();
            new DnsProxyListener::ResNSendHandler(cli, argv[3], flags, netcontext);
    tryThreadOrError(cli, handler);
    return 0;
    return 0;
}
}


DnsProxyListener::ResNSendHandler::ResNSendHandler(SocketClient* c, std::string msg, uint32_t flags,
DnsProxyListener::ResNSendHandler::ResNSendHandler(SocketClient* c, std::string msg, uint32_t flags,
                                                   const android_net_context& netcontext)
                                                   const android_net_context& netcontext)
    : mClient(c), mMsg(std::move(msg)), mFlags(flags), mNetContext(netcontext) {}
    : Handler(c), mMsg(std::move(msg)), mFlags(flags), mNetContext(netcontext) {}

DnsProxyListener::ResNSendHandler::~ResNSendHandler() {
    mClient->decRef();
}


void DnsProxyListener::ResNSendHandler::run() {
void DnsProxyListener::ResNSendHandler::run() {
    LOG(DEBUG) << "ResNSendHandler::run: " << mFlags << " / {" << mNetContext.app_netid << " "
    LOG(DEBUG) << "ResNSendHandler::run: " << mFlags << " / {" << mNetContext.app_netid << " "
@@ -1090,15 +1075,13 @@ int DnsProxyListener::GetHostByNameCmd::runCommand(SocketClient *cli,
        netcontext.flags |= NET_CONTEXT_FLAG_USE_LOCAL_NAMESERVERS;
        netcontext.flags |= NET_CONTEXT_FLAG_USE_LOCAL_NAMESERVERS;
    }
    }


    DnsProxyListener::GetHostByNameHandler* handler =
    (new GetHostByNameHandler(cli, name, af, netcontext))->spawn();
            new DnsProxyListener::GetHostByNameHandler(cli, name, af, netcontext);
    tryThreadOrError(cli, handler);
    return 0;
    return 0;
}
}


DnsProxyListener::GetHostByNameHandler::GetHostByNameHandler(SocketClient* c, char* name, int af,
DnsProxyListener::GetHostByNameHandler::GetHostByNameHandler(SocketClient* c, char* name, int af,
                                                             const android_net_context& netcontext)
                                                             const android_net_context& netcontext)
    : mClient(c), mName(name), mAf(af), mNetContext(netcontext) {}
    : Handler(c), mName(name), mAf(af), mNetContext(netcontext) {}


DnsProxyListener::GetHostByNameHandler::~GetHostByNameHandler() {
DnsProxyListener::GetHostByNameHandler::~GetHostByNameHandler() {
    free(mName);
    free(mName);
@@ -1190,7 +1173,6 @@ void DnsProxyListener::GetHostByNameHandler::run() {
    const int total_ip_addr_count = extractGetHostByNameAnswers(hp, &ip_addrs);
    const int total_ip_addr_count = extractGetHostByNameAnswers(hp, &ip_addrs);
    reportDnsEvent(INetdEventListener::EVENT_GETHOSTBYNAME, mNetContext, latencyUs, rv, event,
    reportDnsEvent(INetdEventListener::EVENT_GETHOSTBYNAME, mNetContext, latencyUs, rv, event,
                   mName, ip_addrs, total_ip_addr_count);
                   mName, ip_addrs, total_ip_addr_count);
    mClient->decRef();
}
}


std::string DnsProxyListener::GetHostByNameHandler::threadName() {
std::string DnsProxyListener::GetHostByNameHandler::threadName() {
@@ -1242,16 +1224,14 @@ int DnsProxyListener::GetHostByAddrCmd::runCommand(SocketClient *cli,
        netcontext.flags |= NET_CONTEXT_FLAG_USE_LOCAL_NAMESERVERS;
        netcontext.flags |= NET_CONTEXT_FLAG_USE_LOCAL_NAMESERVERS;
    }
    }


    DnsProxyListener::GetHostByAddrHandler* handler = new DnsProxyListener::GetHostByAddrHandler(
    (new GetHostByAddrHandler(cli, addr, addrLen, addrFamily, netcontext))->spawn();
            cli, addr, addrLen, addrFamily, netcontext);
    tryThreadOrError(cli, handler);
    return 0;
    return 0;
}
}


DnsProxyListener::GetHostByAddrHandler::GetHostByAddrHandler(SocketClient* c, void* address,
DnsProxyListener::GetHostByAddrHandler::GetHostByAddrHandler(SocketClient* c, void* address,
                                                             int addressLen, int addressFamily,
                                                             int addressLen, int addressFamily,
                                                             const android_net_context& netcontext)
                                                             const android_net_context& netcontext)
    : mClient(c),
    : Handler(c),
      mAddress(address),
      mAddress(address),
      mAddressLen(addressLen),
      mAddressLen(addressLen),
      mAddressFamily(addressFamily),
      mAddressFamily(addressFamily),
@@ -1351,7 +1331,6 @@ void DnsProxyListener::GetHostByAddrHandler::run() {


    reportDnsEvent(INetdEventListener::EVENT_GETHOSTBYADDR, mNetContext, latencyUs, rv, event,
    reportDnsEvent(INetdEventListener::EVENT_GETHOSTBYADDR, mNetContext, latencyUs, rv, event,
                   (hp && hp->h_name) ? hp->h_name : "null", {}, 0);
                   (hp && hp->h_name) ? hp->h_name : "null", {}, 0);
    mClient->decRef();
}
}


std::string DnsProxyListener::GetHostByAddrHandler::threadName() {
std::string DnsProxyListener::GetHostByAddrHandler::threadName() {
+33 −21
Original line number Original line Diff line number Diff line
@@ -38,6 +38,23 @@ class DnsProxyListener : public FrameworkListener {
    static constexpr const char* SOCKET_NAME = "dnsproxyd";
    static constexpr const char* SOCKET_NAME = "dnsproxyd";


  private:
  private:
    class Handler {
      public:
        Handler(SocketClient* c) : mClient(c) { mClient->incRef(); }
        virtual ~Handler() { mClient->decRef(); }
        void operator=(const Handler&) = delete;

        // Attept to spawn the worker thread, or return an error to the client.
        // The Handler instance will self-delete in either case.
        void spawn();

        virtual void run() = 0;
        virtual std::string threadName() = 0;

        SocketClient* mClient;  // ref-counted
    };

    /* ------ getaddrinfo ------*/
    class GetAddrInfoCmd : public FrameworkCommand {
    class GetAddrInfoCmd : public FrameworkCommand {
      public:
      public:
        GetAddrInfoCmd();
        GetAddrInfoCmd();
@@ -45,21 +62,19 @@ class DnsProxyListener : public FrameworkListener {
        int runCommand(SocketClient* c, int argc, char** argv) override;
        int runCommand(SocketClient* c, int argc, char** argv) override;
    };
    };


    /* ------ getaddrinfo ------*/
    class GetAddrInfoHandler : public Handler {
    class GetAddrInfoHandler {
      public:
      public:
        // Note: All of host, service, and hints may be NULL
        // Note: All of host, service, and hints may be NULL
        GetAddrInfoHandler(SocketClient* c, char* host, char* service, addrinfo* hints,
        GetAddrInfoHandler(SocketClient* c, char* host, char* service, addrinfo* hints,
                           const android_net_context& netcontext);
                           const android_net_context& netcontext);
        ~GetAddrInfoHandler();
        ~GetAddrInfoHandler() override;


        void run();
        void run() override;
        std::string threadName();
        std::string threadName() override;


      private:
      private:
        void doDns64Synthesis(int32_t* rv, addrinfo** res, NetworkDnsEventReported* event);
        void doDns64Synthesis(int32_t* rv, addrinfo** res, NetworkDnsEventReported* event);


        SocketClient* mClient;  // ref counted
        char* mHost;            // owned. TODO: convert to std::string.
        char* mHost;            // owned. TODO: convert to std::string.
        char* mService;         // owned. TODO: convert to std::string.
        char* mService;         // owned. TODO: convert to std::string.
        addrinfo* mHints;       // owned
        addrinfo* mHints;       // owned
@@ -74,20 +89,19 @@ class DnsProxyListener : public FrameworkListener {
        int runCommand(SocketClient* c, int argc, char** argv) override;
        int runCommand(SocketClient* c, int argc, char** argv) override;
    };
    };


    class GetHostByNameHandler {
    class GetHostByNameHandler : public Handler {
      public:
      public:
        GetHostByNameHandler(SocketClient* c, char* name, int af,
        GetHostByNameHandler(SocketClient* c, char* name, int af,
                             const android_net_context& netcontext);
                             const android_net_context& netcontext);
        ~GetHostByNameHandler();
        ~GetHostByNameHandler() override;


        void run();
        void run() override;
        std::string threadName();
        std::string threadName() override;


      private:
      private:
        void doDns64Synthesis(int32_t* rv, hostent* hbuf, char* buf, size_t buflen, hostent** hpp,
        void doDns64Synthesis(int32_t* rv, hostent* hbuf, char* buf, size_t buflen, hostent** hpp,
                              NetworkDnsEventReported* event);
                              NetworkDnsEventReported* event);


        SocketClient* mClient;  // ref counted
        char* mName;            // owned. TODO: convert to std::string.
        char* mName;            // owned. TODO: convert to std::string.
        int mAf;
        int mAf;
        android_net_context mNetContext;
        android_net_context mNetContext;
@@ -101,20 +115,19 @@ class DnsProxyListener : public FrameworkListener {
        int runCommand(SocketClient* c, int argc, char** argv) override;
        int runCommand(SocketClient* c, int argc, char** argv) override;
    };
    };


    class GetHostByAddrHandler {
    class GetHostByAddrHandler : public Handler {
      public:
      public:
        GetHostByAddrHandler(SocketClient* c, void* address, int addressLen, int addressFamily,
        GetHostByAddrHandler(SocketClient* c, void* address, int addressLen, int addressFamily,
                             const android_net_context& netcontext);
                             const android_net_context& netcontext);
        ~GetHostByAddrHandler();
        ~GetHostByAddrHandler() override;


        void run();
        void run() override;
        std::string threadName();
        std::string threadName() override;


      private:
      private:
        void doDns64ReverseLookup(hostent* hbuf, char* buf, size_t buflen, hostent** hpp,
        void doDns64ReverseLookup(hostent* hbuf, char* buf, size_t buflen, hostent** hpp,
                                  NetworkDnsEventReported* event);
                                  NetworkDnsEventReported* event);


        SocketClient* mClient;  // ref counted
        void* mAddress;         // address to lookup; owned
        void* mAddress;         // address to lookup; owned
        int mAddressLen;        // length of address to look up
        int mAddressLen;        // length of address to look up
        int mAddressFamily;     // address family
        int mAddressFamily;     // address family
@@ -129,17 +142,16 @@ class DnsProxyListener : public FrameworkListener {
        int runCommand(SocketClient* c, int argc, char** argv) override;
        int runCommand(SocketClient* c, int argc, char** argv) override;
    };
    };


    class ResNSendHandler {
    class ResNSendHandler : public Handler {
      public:
      public:
        ResNSendHandler(SocketClient* c, std::string msg, uint32_t flags,
        ResNSendHandler(SocketClient* c, std::string msg, uint32_t flags,
                        const android_net_context& netcontext);
                        const android_net_context& netcontext);
        ~ResNSendHandler();
        ~ResNSendHandler() override = default;


        void run();
        void run() override;
        std::string threadName();
        std::string threadName() override;


      private:
      private:
        SocketClient* mClient;  // ref counted
        std::string mMsg;
        std::string mMsg;
        uint32_t mFlags;
        uint32_t mFlags;
        android_net_context mNetContext;
        android_net_context mNetContext;