Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit ae02a1f4 authored by Siarhei Vishniakou's avatar Siarhei Vishniakou
Browse files

Store connections by token instead of by fd

The connections are currently stored by fd. If a connection is removed
via 'removeInputChannel', it is possible to re-create the same
connection and have it keyed by the same fd. When this happens, a race
condition may occur where a socket hangup on this fd would cause the
removal of a newly registered connection.

In this refactor, the connections are no longer stored by fd. The looper
interface for adding fds has two versions:
1) the old one that we are currently using, which is marked as 'do not
use'
2) the new one where a callback object is provided instead.

In this CL, we switch to the new version of the callback.

There is now also no need to store the inputchannels in a separate
structure, because we can use the connections collection that's now
keyed by token to find them.

In a future refactor, we should switch to using 'unique_ptr' for the
inputchannels. Most of the time when we are looking for an input
channel, we are actually interested in finding the corresponding
connection.

If we switch Connection to shared_ptr, we can also look into switching
LooperEventCallback to store a weak pointer to a connection instead of
storing the connection token. This should speed up the handling of
events, by avoiding a map lookup.

Test: ./reinitinput.sh. Observe that it doesnt finish after this patch
Test: atest inputflinger_tests
Bug: 182478748

Change-Id: I601f765eebfadcaeff3661a10a10c4a4f0477389
parent 96e84d98
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -229,7 +229,7 @@ public:
    InputChannel(const InputChannel& other)
          : mName(other.mName), mFd(::dup(other.mFd)), mToken(other.mToken){};
    InputChannel(const std::string name, android::base::unique_fd fd, sp<IBinder> token);
    virtual ~InputChannel();
    ~InputChannel() override;
    /**
     * Create a pair of input channels.
     * The two returned input channels are equivalent, and are labeled as "server" and "client"
+104 −116
Original line number Diff line number Diff line
@@ -283,27 +283,6 @@ static V getValueByKey(const std::unordered_map<K, V>& map, K key) {
    return it != map.end() ? it->second : V{};
}

/**
 * Find the entry in std::unordered_map by value, and remove it.
 * If more than one entry has the same value, then all matching
 * key-value pairs will be removed.
 *
 * Return true if at least one value has been removed.
 */
template <typename K, typename V>
static bool removeByValue(std::unordered_map<K, V>& map, const V& value) {
    bool removed = false;
    for (auto it = map.begin(); it != map.end();) {
        if (it->second == value) {
            it = map.erase(it);
            removed = true;
        } else {
            it++;
        }
    }
    return removed;
}

static bool haveSameToken(const sp<InputWindowHandle>& first, const sp<InputWindowHandle>& second) {
    if (first == second) {
        return true;
@@ -507,8 +486,8 @@ InputDispatcher::~InputDispatcher() {
        drainInboundQueueLocked();
    }

    while (!mConnectionsByFd.empty()) {
        sp<Connection> connection = mConnectionsByFd.begin()->second;
    while (!mConnectionsByToken.empty()) {
        sp<Connection> connection = mConnectionsByToken.begin()->second;
        removeInputChannel(connection->inputChannel->getConnectionToken());
    }
}
@@ -3297,21 +3276,16 @@ void InputDispatcher::releaseDispatchEntry(DispatchEntry* dispatchEntry) {
    delete dispatchEntry;
}

int InputDispatcher::handleReceiveCallback(int fd, int events, void* data) {
    InputDispatcher* d = static_cast<InputDispatcher*>(data);

    { // acquire lock
        std::scoped_lock _l(d->mLock);

        if (d->mConnectionsByFd.find(fd) == d->mConnectionsByFd.end()) {
            ALOGE("Received spurious receive callback for unknown input channel.  "
                  "fd=%d, events=0x%x",
                  fd, events);
int InputDispatcher::handleReceiveCallback(int events, sp<IBinder> connectionToken) {
    std::scoped_lock _l(mLock);
    sp<Connection> connection = getConnectionLocked(connectionToken);
    if (connection == nullptr) {
        ALOGW("Received looper callback for unknown input channel token %p.  events=0x%x",
              connectionToken.get(), events);
        return 0; // remove the callback
    }

    bool notify;
        sp<Connection> connection = d->mConnectionsByFd[fd];
    if (!(events & (ALOOPER_EVENT_ERROR | ALOOPER_EVENT_HANGUP))) {
        if (!(events & ALOOPER_EVENT_INPUT)) {
            ALOGW("channel '%s' ~ Received spurious callback for unhandled poll event.  "
@@ -3334,15 +3308,15 @@ int InputDispatcher::handleReceiveCallback(int fd, int events, void* data) {
            if (std::holds_alternative<InputPublisher::Finished>(*result)) {
                const InputPublisher::Finished& finish =
                        std::get<InputPublisher::Finished>(*result);
                    d->finishDispatchCycleLocked(currentTime, connection, finish.seq,
                                                 finish.handled, finish.consumeTime);
                finishDispatchCycleLocked(currentTime, connection, finish.seq, finish.handled,
                                          finish.consumeTime);
            } else if (std::holds_alternative<InputPublisher::Timeline>(*result)) {
                // TODO(b/167947340): Report this data to LatencyTracker
            }
            gotOne = true;
        }
        if (gotOne) {
                d->runCommandsLockedInterruptible();
            runCommandsLockedInterruptible();
            if (status == WOULD_BLOCK) {
                return 1;
            }
@@ -3358,25 +3332,22 @@ int InputDispatcher::handleReceiveCallback(int fd, int events, void* data) {
        // Monitor channels are never explicitly unregistered.
        // We do it automatically when the remote endpoint is closed so don't warn about them.
        const bool stillHaveWindowHandle =
                    d->getWindowHandleLocked(connection->inputChannel->getConnectionToken()) !=
                    nullptr;
                getWindowHandleLocked(connection->inputChannel->getConnectionToken()) != nullptr;
        notify = !connection->monitor && stillHaveWindowHandle;
        if (notify) {
                ALOGW("channel '%s' ~ Consumer closed input channel or an error occurred.  "
                      "events=0x%x",
            ALOGW("channel '%s' ~ Consumer closed input channel or an error occurred.  events=0x%x",
                  connection->getInputChannelName().c_str(), events);
        }
    }

    // Remove the channel.
        d->removeInputChannelLocked(connection->inputChannel->getConnectionToken(), notify);
    removeInputChannelLocked(connection->inputChannel->getConnectionToken(), notify);
    return 0; // remove the callback
    }             // release lock
}

void InputDispatcher::synthesizeCancelationEventsForAllConnectionsLocked(
        const CancelationOptions& options) {
    for (const auto& [fd, connection] : mConnectionsByFd) {
    for (const auto& [token, connection] : mConnectionsByToken) {
        synthesizeCancelationEventsForConnectionLocked(connection, options);
    }
}
@@ -4342,11 +4313,11 @@ bool InputDispatcher::hasResponsiveConnectionLocked(InputWindowHandle& windowHan

std::shared_ptr<InputChannel> InputDispatcher::getInputChannelLocked(
        const sp<IBinder>& token) const {
    size_t count = mInputChannelsByToken.count(token);
    if (count == 0) {
    auto connectionIt = mConnectionsByToken.find(token);
    if (connectionIt == mConnectionsByToken.end()) {
        return nullptr;
    }
    return mInputChannelsByToken.at(token);
    return connectionIt->second->inputChannel;
}

void InputDispatcher::updateWindowHandlesForDisplayLocked(
@@ -4996,13 +4967,13 @@ void InputDispatcher::dumpDispatchStateLocked(std::string& dump) {
        dump += INDENT "ReplacedKeys: <empty>\n";
    }

    if (!mConnectionsByFd.empty()) {
    if (!mConnectionsByToken.empty()) {
        dump += INDENT "Connections:\n";
        for (const auto& pair : mConnectionsByFd) {
            const sp<Connection>& connection = pair.second;
        for (const auto& [token, connection] : mConnectionsByToken) {
            dump += StringPrintf(INDENT2 "%i: channelName='%s', windowName='%s', "
                                         "status=%s, monitor=%s, responsive=%s\n",
                                 pair.first, connection->getInputChannelName().c_str(),
                                 connection->inputChannel->getFd().get(),
                                 connection->getInputChannelName().c_str(),
                                 connection->getWindowName().c_str(), connection->getStatusLabel(),
                                 toString(connection->monitor), toString(connection->responsive));

@@ -5050,14 +5021,23 @@ void InputDispatcher::dumpMonitors(std::string& dump, const std::vector<Monitor>
    }
}

class LooperEventCallback : public LooperCallback {
public:
    LooperEventCallback(std::function<int(int events)> callback) : mCallback(callback) {}
    int handleEvent(int /*fd*/, int events, void* /*data*/) override { return mCallback(events); }

private:
    std::function<int(int events)> mCallback;
};

Result<std::unique_ptr<InputChannel>> InputDispatcher::createInputChannel(const std::string& name) {
#if DEBUG_CHANNEL_CREATION
    ALOGD("channel '%s' ~ createInputChannel", name.c_str());
#endif

    std::shared_ptr<InputChannel> serverChannel;
    std::unique_ptr<InputChannel> serverChannel;
    std::unique_ptr<InputChannel> clientChannel;
    status_t result = openInputChannelPair(name, serverChannel, clientChannel);
    status_t result = InputChannel::openInputChannelPair(name, serverChannel, clientChannel);

    if (result) {
        return base::Error(result) << "Failed to open input channel pair with name " << name;
@@ -5065,13 +5045,20 @@ Result<std::unique_ptr<InputChannel>> InputDispatcher::createInputChannel(const

    { // acquire lock
        std::scoped_lock _l(mLock);
        sp<Connection> connection = new Connection(serverChannel, false /*monitor*/, mIdGenerator);

        const sp<IBinder>& token = serverChannel->getConnectionToken();
        int fd = serverChannel->getFd();
        mConnectionsByFd[fd] = connection;
        mInputChannelsByToken[serverChannel->getConnectionToken()] = serverChannel;
        sp<Connection> connection =
                new Connection(std::move(serverChannel), false /*monitor*/, mIdGenerator);

        if (mConnectionsByToken.find(token) != mConnectionsByToken.end()) {
            ALOGE("Created a new connection, but the token %p is already known", token.get());
        }
        mConnectionsByToken.emplace(token, connection);

        std::function<int(int events)> callback = std::bind(&InputDispatcher::handleReceiveCallback,
                                                            this, std::placeholders::_1, token);

        mLooper->addFd(fd, 0, ALOOPER_EVENT_INPUT, handleReceiveCallback, this);
        mLooper->addFd(fd, 0, ALOOPER_EVENT_INPUT, new LooperEventCallback(callback), nullptr);
    } // release lock

    // Wake the looper because some connections have changed.
@@ -5099,18 +5086,21 @@ Result<std::unique_ptr<InputChannel>> InputDispatcher::createInputMonitor(int32_
        }

        sp<Connection> connection = new Connection(serverChannel, true /*monitor*/, mIdGenerator);

        const sp<IBinder>& token = serverChannel->getConnectionToken();
        const int fd = serverChannel->getFd();
        mConnectionsByFd[fd] = connection;
        mInputChannelsByToken[serverChannel->getConnectionToken()] = serverChannel;

        if (mConnectionsByToken.find(token) != mConnectionsByToken.end()) {
            ALOGE("Created a new connection, but the token %p is already known", token.get());
        }
        mConnectionsByToken.emplace(token, connection);
        std::function<int(int events)> callback = std::bind(&InputDispatcher::handleReceiveCallback,
                                                            this, std::placeholders::_1, token);

        auto& monitorsByDisplay =
                isGestureMonitor ? mGestureMonitorsByDisplay : mGlobalMonitorsByDisplay;
        monitorsByDisplay[displayId].emplace_back(serverChannel, pid);

        mLooper->addFd(fd, 0, ALOOPER_EVENT_INPUT, handleReceiveCallback, this);
        ALOGI("Created monitor %s for display %" PRId32 ", gesture=%s, pid=%" PRId32, name.c_str(),
              displayId, toString(isGestureMonitor), pid);
        mLooper->addFd(fd, 0, ALOOPER_EVENT_INPUT, new LooperEventCallback(callback), nullptr);
    }

    // Wake the looper because some connections have changed.
@@ -5143,7 +5133,6 @@ status_t InputDispatcher::removeInputChannelLocked(const sp<IBinder>& connection
    }

    removeConnectionLocked(connection);
    mInputChannelsByToken.erase(connectionToken);

    if (connection->monitor) {
        removeMonitorChannelLocked(connectionToken);
@@ -5301,9 +5290,8 @@ sp<Connection> InputDispatcher::getConnectionLocked(const sp<IBinder>& inputConn
        return nullptr;
    }

    for (const auto& pair : mConnectionsByFd) {
        const sp<Connection>& connection = pair.second;
        if (connection->inputChannel->getConnectionToken() == inputConnectionToken) {
    for (const auto& [token, connection] : mConnectionsByToken) {
        if (token == inputConnectionToken) {
            return connection;
        }
    }
@@ -5321,7 +5309,7 @@ std::string InputDispatcher::getConnectionNameLocked(const sp<IBinder>& connecti

void InputDispatcher::removeConnectionLocked(const sp<Connection>& connection) {
    mAnrTracker.eraseToken(connection->inputChannel->getConnectionToken());
    removeByValue(mConnectionsByFd, connection);
    mConnectionsByToken.erase(connection->inputChannel->getConnectionToken());
}

void InputDispatcher::onDispatchCycleFinishedLocked(nsecs_t currentTime,
+5 −6
Original line number Diff line number Diff line
@@ -211,9 +211,6 @@ private:
                                                    bool addPortalWindows = false,
                                                    bool ignoreDragWindow = false) REQUIRES(mLock);

    // All registered connections mapped by channel file descriptor.
    std::unordered_map<int, sp<Connection>> mConnectionsByFd GUARDED_BY(mLock);

    sp<Connection> getConnectionLocked(const sp<IBinder>& inputConnectionToken) const
            REQUIRES(mLock);

@@ -225,8 +222,10 @@ private:
    struct StrongPointerHash {
        std::size_t operator()(const sp<T>& b) const { return std::hash<T*>{}(b.get()); }
    };
    std::unordered_map<sp<IBinder>, std::shared_ptr<InputChannel>, StrongPointerHash<IBinder>>
            mInputChannelsByToken GUARDED_BY(mLock);

    // All registered connections mapped by input channel token.
    std::unordered_map<sp<IBinder>, sp<Connection>, StrongPointerHash<IBinder>> mConnectionsByToken
            GUARDED_BY(mLock);

    // Finds the display ID of the gesture monitor identified by the provided token.
    std::optional<int32_t> findGestureMonitorDisplayByTokenLocked(const sp<IBinder>& token)
@@ -544,7 +543,7 @@ private:
                                        bool notify) REQUIRES(mLock);
    void drainDispatchQueue(std::deque<DispatchEntry*>& queue);
    void releaseDispatchEntry(DispatchEntry* dispatchEntry);
    static int handleReceiveCallback(int fd, int events, void* data);
    int handleReceiveCallback(int events, sp<IBinder> connectionToken);
    // The action sent should only be of type AMOTION_EVENT_*
    void dispatchPointerDownOutsideFocus(uint32_t source, int32_t action,
                                         const sp<IBinder>& newToken) REQUIRES(mLock);