Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 8fd80e92 authored by Jeff Brown's avatar Jeff Brown Committed by Android Git Automerger
Browse files

am 72ce4235: Merge "Fix possible race conditions during channel unregistration." into gingerbread

Merge commit '72ce4235' into gingerbread-plus-aosp

* commit '72ce4235':
  Fix possible race conditions during channel unregistration.
parents aea40e3c 72ce4235
Loading
Loading
Loading
Loading
+69 −30
Original line number Diff line number Diff line
@@ -76,10 +76,14 @@ private:
            STATUS_ZOMBIE
        };

        Connection(const sp<InputChannel>& inputChannel, const sp<PollLoop>& pollLoop);
        Connection(uint16_t id,
                const sp<InputChannel>& inputChannel, const sp<PollLoop>& pollLoop);

        inline const char* getInputChannelName() const { return inputChannel->getName().string(); }

        // A unique id for this connection.
        uint16_t id;

        Status status;

        sp<InputChannel> inputChannel;
@@ -91,29 +95,34 @@ private:
        // The sequence number of the current event being dispatched.
        // This is used as part of the finished token as a way to determine whether the finished
        // token is still valid before sending a finished signal back to the publisher.
        uint32_t messageSeqNum;
        uint16_t messageSeqNum;

        // True if a message has been received from the publisher but not yet finished.
        bool messageInProgress;
    };

    Mutex mLock;
    uint16_t mNextConnectionId;
    KeyedVector<int32_t, sp<Connection> > mConnectionsByReceiveFd;

    ssize_t getConnectionIndex(const sp<InputChannel>& inputChannel);

    static void handleInputChannelDisposed(JNIEnv* env,
            jobject inputChannelObj, const sp<InputChannel>& inputChannel, void* data);

    static bool handleReceiveCallback(int receiveFd, int events, void* data);

    static jlong generateFinishedToken(int32_t receiveFd, int32_t messageSeqNum);
    static jlong generateFinishedToken(int32_t receiveFd,
            uint16_t connectionId, uint16_t messageSeqNum);

    static void parseFinishedToken(jlong finishedToken,
            int32_t* outReceiveFd, uint32_t* outMessageIndex);
            int32_t* outReceiveFd, uint16_t* outConnectionId, uint16_t* outMessageIndex);
};

// ----------------------------------------------------------------------------

NativeInputQueue::NativeInputQueue() {
NativeInputQueue::NativeInputQueue() :
        mNextConnectionId(0) {
}

NativeInputQueue::~NativeInputQueue() {
@@ -134,18 +143,17 @@ status_t NativeInputQueue::registerInputChannel(JNIEnv* env, jobject inputChanne

    sp<PollLoop> pollLoop = android_os_MessageQueue_getPollLoop(env, messageQueueObj);

    int receiveFd;
    { // acquire lock
        AutoMutex _l(mLock);

        receiveFd = inputChannel->getReceivePipeFd();
        if (mConnectionsByReceiveFd.indexOfKey(receiveFd) >= 0) {
        if (getConnectionIndex(inputChannel) >= 0) {
            LOGW("Attempted to register already registered input channel '%s'",
                    inputChannel->getName().string());
            return BAD_VALUE;
        }

        sp<Connection> connection = new Connection(inputChannel, pollLoop);
        uint16_t connectionId = mNextConnectionId++;
        sp<Connection> connection = new Connection(connectionId, inputChannel, pollLoop);
        status_t result = connection->inputConsumer.initialize();
        if (result) {
            LOGW("Failed to initialize input consumer for input channel '%s', status=%d",
@@ -155,13 +163,14 @@ status_t NativeInputQueue::registerInputChannel(JNIEnv* env, jobject inputChanne

        connection->inputHandlerObjGlobal = env->NewGlobalRef(inputHandlerObj);

        int32_t receiveFd = inputChannel->getReceivePipeFd();
        mConnectionsByReceiveFd.add(receiveFd, connection);

        pollLoop->setCallback(receiveFd, POLLIN, handleReceiveCallback, this);
    } // release lock

    android_view_InputChannel_setDisposeCallback(env, inputChannelObj,
            handleInputChannelDisposed, this);

    pollLoop->setCallback(receiveFd, POLLIN, handleReceiveCallback, this);
    return OK;
}

@@ -177,38 +186,56 @@ status_t NativeInputQueue::unregisterInputChannel(JNIEnv* env, jobject inputChan
    LOGD("channel '%s' - Unregistered", inputChannel->getName().string());
#endif

    int32_t receiveFd;
    sp<Connection> connection;
    { // acquire lock
        AutoMutex _l(mLock);

        receiveFd = inputChannel->getReceivePipeFd();
        ssize_t connectionIndex = mConnectionsByReceiveFd.indexOfKey(receiveFd);
        ssize_t connectionIndex = getConnectionIndex(inputChannel);
        if (connectionIndex < 0) {
            LOGW("Attempted to unregister already unregistered input channel '%s'",
                    inputChannel->getName().string());
            return BAD_VALUE;
        }

        connection = mConnectionsByReceiveFd.valueAt(connectionIndex);
        sp<Connection> connection = mConnectionsByReceiveFd.valueAt(connectionIndex);
        mConnectionsByReceiveFd.removeItemsAt(connectionIndex);

        connection->status = Connection::STATUS_ZOMBIE;

        connection->pollLoop->removeCallback(inputChannel->getReceivePipeFd());

        env->DeleteGlobalRef(connection->inputHandlerObjGlobal);
        connection->inputHandlerObjGlobal = NULL;

        if (connection->messageInProgress) {
            LOGI("Sending finished signal for input channel '%s' since it is being unregistered "
                    "while an input message is still in progress.",
                    connection->getInputChannelName());
            connection->messageInProgress = false;
            connection->inputConsumer.sendFinishedSignal(); // ignoring result
        }
    } // release lock

    android_view_InputChannel_setDisposeCallback(env, inputChannelObj, NULL, NULL);

    connection->pollLoop->removeCallback(receiveFd);
    return OK;
}

ssize_t NativeInputQueue::getConnectionIndex(const sp<InputChannel>& inputChannel) {
    ssize_t connectionIndex = mConnectionsByReceiveFd.indexOfKey(inputChannel->getReceivePipeFd());
    if (connectionIndex >= 0) {
        sp<Connection> connection = mConnectionsByReceiveFd.valueAt(connectionIndex);
        if (connection->inputChannel.get() == inputChannel.get()) {
            return connectionIndex;
        }
    }

    return -1;
}

status_t NativeInputQueue::finished(JNIEnv* env, jlong finishedToken, bool ignoreSpuriousFinish) {
    int32_t receiveFd;
    uint32_t messageSeqNum;
    parseFinishedToken(finishedToken, &receiveFd, &messageSeqNum);
    uint16_t connectionId;
    uint16_t messageSeqNum;
    parseFinishedToken(finishedToken, &receiveFd, &connectionId, &messageSeqNum);

    { // acquire lock
        AutoMutex _l(mLock);
@@ -216,16 +243,25 @@ status_t NativeInputQueue::finished(JNIEnv* env, jlong finishedToken, bool ignor
        ssize_t connectionIndex = mConnectionsByReceiveFd.indexOfKey(receiveFd);
        if (connectionIndex < 0) {
            if (! ignoreSpuriousFinish) {
                LOGW("Attempted to finish input on channel that is no longer registered.");
                LOGI("Ignoring finish signal on channel that is no longer registered.");
            }
            return DEAD_OBJECT;
        }

        sp<Connection> connection = mConnectionsByReceiveFd.valueAt(connectionIndex);
        if (connectionId != connection->id) {
            if (! ignoreSpuriousFinish) {
                LOGI("Ignoring finish signal on channel that is no longer registered.");
            }
            return DEAD_OBJECT;
        }

        if (messageSeqNum != connection->messageSeqNum || ! connection->messageInProgress) {
            if (! ignoreSpuriousFinish) {
                LOGW("Attempted to finish input twice on channel '%s'.",
                        connection->getInputChannelName());
                LOGW("Attempted to finish input twice on channel '%s'.  "
                        "finished messageSeqNum=%d, current messageSeqNum=%d, messageInProgress=%d",
                        connection->getInputChannelName(),
                        messageSeqNum, connection->messageSeqNum, connection->messageInProgress);
            }
            return INVALID_OPERATION;
        }
@@ -312,7 +348,7 @@ bool NativeInputQueue::handleReceiveCallback(int receiveFd, int events, void* da
        connection->messageInProgress = true;
        connection->messageSeqNum += 1;

        finishedToken = generateFinishedToken(receiveFd, connection->messageSeqNum);
        finishedToken = generateFinishedToken(receiveFd, connection->id, connection->messageSeqNum);

        inputHandlerObjLocal = env->NewLocalRef(connection->inputHandlerObjGlobal);
    } // release lock
@@ -384,20 +420,23 @@ bool NativeInputQueue::handleReceiveCallback(int receiveFd, int events, void* da
    return true;
}

jlong NativeInputQueue::generateFinishedToken(int32_t receiveFd, int32_t messageSeqNum) {
    return (jlong(receiveFd) << 32) | jlong(messageSeqNum);
jlong NativeInputQueue::generateFinishedToken(int32_t receiveFd, uint16_t connectionId,
        uint16_t messageSeqNum) {
    return (jlong(receiveFd) << 32) | (jlong(connectionId) << 16) | jlong(messageSeqNum);
}

void NativeInputQueue::parseFinishedToken(jlong finishedToken,
        int32_t* outReceiveFd, uint32_t* outMessageIndex) {
        int32_t* outReceiveFd, uint16_t* outConnectionId, uint16_t* outMessageIndex) {
    *outReceiveFd = int32_t(finishedToken >> 32);
    *outMessageIndex = uint32_t(finishedToken & 0xffffffff);
    *outConnectionId = uint16_t(finishedToken >> 16);
    *outMessageIndex = uint16_t(finishedToken);
}

// ----------------------------------------------------------------------------

NativeInputQueue::Connection::Connection(const sp<InputChannel>& inputChannel, const sp<PollLoop>& pollLoop) :
    status(STATUS_NORMAL), inputChannel(inputChannel), inputConsumer(inputChannel),
NativeInputQueue::Connection::Connection(uint16_t id,
        const sp<InputChannel>& inputChannel, const sp<PollLoop>& pollLoop) :
    id(id), status(STATUS_NORMAL), inputChannel(inputChannel), inputConsumer(inputChannel),
    pollLoop(pollLoop), inputHandlerObjGlobal(NULL),
    messageSeqNum(0), messageInProgress(false) {
}
+2 −0
Original line number Diff line number Diff line
@@ -554,6 +554,8 @@ private:
    // All registered connections mapped by receive pipe file descriptor.
    KeyedVector<int, sp<Connection> > mConnectionsByReceiveFd;

    ssize_t getConnectionIndex(const sp<InputChannel>& inputChannel);

    // Active connections are connections that have a non-empty outbound queue.
    // We don't use a ref-counted pointer here because we explicitly abort connections
    // during unregistration which causes the connection's outbound queue to be cleared
+20 −12
Original line number Diff line number Diff line
@@ -433,8 +433,7 @@ void InputDispatcher::dispatchEventToCurrentInputTargetsLocked(nsecs_t currentTi
    for (size_t i = 0; i < mCurrentInputTargets.size(); i++) {
        const InputTarget& inputTarget = mCurrentInputTargets.itemAt(i);

        ssize_t connectionIndex = mConnectionsByReceiveFd.indexOfKey(
                inputTarget.inputChannel->getReceivePipeFd());
        ssize_t connectionIndex = getConnectionIndex(inputTarget.inputChannel);
        if (connectionIndex >= 0) {
            sp<Connection> connection = mConnectionsByReceiveFd.valueAt(connectionIndex);
            prepareDispatchCycleLocked(currentTime, connection, eventEntry, & inputTarget,
@@ -1367,12 +1366,10 @@ status_t InputDispatcher::registerInputChannel(const sp<InputChannel>& inputChan
    LOGD("channel '%s' ~ registerInputChannel", inputChannel->getName().string());
#endif

    int receiveFd;
    { // acquire lock
        AutoMutex _l(mLock);

        receiveFd = inputChannel->getReceivePipeFd();
        if (mConnectionsByReceiveFd.indexOfKey(receiveFd) >= 0) {
        if (getConnectionIndex(inputChannel) >= 0) {
            LOGW("Attempted to register already registered input channel '%s'",
                    inputChannel->getName().string());
            return BAD_VALUE;
@@ -1386,12 +1383,13 @@ status_t InputDispatcher::registerInputChannel(const sp<InputChannel>& inputChan
            return status;
        }

        int32_t receiveFd = inputChannel->getReceivePipeFd();
        mConnectionsByReceiveFd.add(receiveFd, connection);

        mPollLoop->setCallback(receiveFd, POLLIN, handleReceiveCallback, this);

        runCommandsLockedInterruptible();
    } // release lock

    mPollLoop->setCallback(receiveFd, POLLIN, handleReceiveCallback, this);
    return OK;
}

@@ -1400,12 +1398,10 @@ status_t InputDispatcher::unregisterInputChannel(const sp<InputChannel>& inputCh
    LOGD("channel '%s' ~ unregisterInputChannel", inputChannel->getName().string());
#endif

    int32_t receiveFd;
    { // acquire lock
        AutoMutex _l(mLock);

        receiveFd = inputChannel->getReceivePipeFd();
        ssize_t connectionIndex = mConnectionsByReceiveFd.indexOfKey(receiveFd);
        ssize_t connectionIndex = getConnectionIndex(inputChannel);
        if (connectionIndex < 0) {
            LOGW("Attempted to unregister already unregistered input channel '%s'",
                    inputChannel->getName().string());
@@ -1417,20 +1413,32 @@ status_t InputDispatcher::unregisterInputChannel(const sp<InputChannel>& inputCh

        connection->status = Connection::STATUS_ZOMBIE;

        mPollLoop->removeCallback(inputChannel->getReceivePipeFd());

        nsecs_t currentTime = now();
        abortDispatchCycleLocked(currentTime, connection, true /*broken*/);

        runCommandsLockedInterruptible();
    } // release lock

    mPollLoop->removeCallback(receiveFd);

    // Wake the poll loop because removing the connection may have changed the current
    // synchronization state.
    mPollLoop->wake();
    return OK;
}

ssize_t InputDispatcher::getConnectionIndex(const sp<InputChannel>& inputChannel) {
    ssize_t connectionIndex = mConnectionsByReceiveFd.indexOfKey(inputChannel->getReceivePipeFd());
    if (connectionIndex >= 0) {
        sp<Connection> connection = mConnectionsByReceiveFd.valueAt(connectionIndex);
        if (connection->inputChannel.get() == inputChannel.get()) {
            return connectionIndex;
        }
    }

    return -1;
}

void InputDispatcher::activateConnectionLocked(Connection* connection) {
    for (size_t i = 0; i < mActiveConnections.size(); i++) {
        if (mActiveConnections.itemAt(i) == connection) {
+10 −14
Original line number Diff line number Diff line
@@ -319,9 +319,9 @@ private:
    bool isScreenOn();
    bool isScreenBright();

    // Weak references to all currently registered input channels by receive fd.
    // Weak references to all currently registered input channels by connection pointer.
    Mutex mInputChannelRegistryLock;
    KeyedVector<int, jweak> mInputChannelObjWeakByReceiveFd;
    KeyedVector<InputChannel*, jweak> mInputChannelObjWeakTable;

    jobject getInputChannelObjLocal(JNIEnv* env, const sp<InputChannel>& inputChannel);

@@ -509,8 +509,7 @@ status_t NativeInputManager::registerInputChannel(JNIEnv* env,
    {
        AutoMutex _l(mInputChannelRegistryLock);

        ssize_t index = mInputChannelObjWeakByReceiveFd.indexOfKey(
                inputChannel->getReceivePipeFd());
        ssize_t index = mInputChannelObjWeakTable.indexOfKey(inputChannel.get());
        if (index >= 0) {
            LOGE("Input channel object '%s' has already been registered",
                    inputChannel->getName().string());
@@ -518,8 +517,7 @@ status_t NativeInputManager::registerInputChannel(JNIEnv* env,
            goto DeleteWeakRef;
        }

        mInputChannelObjWeakByReceiveFd.add(inputChannel->getReceivePipeFd(),
                inputChannelObjWeak);
        mInputChannelObjWeakTable.add(inputChannel.get(), inputChannelObjWeak);
    }

    status = mInputManager->registerInputChannel(inputChannel);
@@ -534,7 +532,7 @@ status_t NativeInputManager::registerInputChannel(JNIEnv* env,
    // Failed!
    {
        AutoMutex _l(mInputChannelRegistryLock);
        mInputChannelObjWeakByReceiveFd.removeItem(inputChannel->getReceivePipeFd());
        mInputChannelObjWeakTable.removeItem(inputChannel.get());
    }

DeleteWeakRef:
@@ -548,16 +546,15 @@ status_t NativeInputManager::unregisterInputChannel(JNIEnv* env,
    {
        AutoMutex _l(mInputChannelRegistryLock);

        ssize_t index = mInputChannelObjWeakByReceiveFd.indexOfKey(
                inputChannel->getReceivePipeFd());
        ssize_t index = mInputChannelObjWeakTable.indexOfKey(inputChannel.get());
        if (index < 0) {
            LOGE("Input channel object '%s' is not currently registered",
                    inputChannel->getName().string());
            return INVALID_OPERATION;
        }

        inputChannelObjWeak = mInputChannelObjWeakByReceiveFd.valueAt(index);
        mInputChannelObjWeakByReceiveFd.removeItemsAt(index);
        inputChannelObjWeak = mInputChannelObjWeakTable.valueAt(index);
        mInputChannelObjWeakTable.removeItemsAt(index);
    }

    env->DeleteWeakGlobalRef(inputChannelObjWeak);
@@ -572,13 +569,12 @@ jobject NativeInputManager::getInputChannelObjLocal(JNIEnv* env,
    {
        AutoMutex _l(mInputChannelRegistryLock);

        ssize_t index = mInputChannelObjWeakByReceiveFd.indexOfKey(
                inputChannel->getReceivePipeFd());
        ssize_t index = mInputChannelObjWeakTable.indexOfKey(inputChannel.get());
        if (index < 0) {
            return NULL;
        }

        jweak inputChannelObjWeak = mInputChannelObjWeakByReceiveFd.valueAt(index);
        jweak inputChannelObjWeak = mInputChannelObjWeakTable.valueAt(index);
        return env->NewLocalRef(inputChannelObjWeak);
    }
}