Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit b4b83acf authored by Shai Barack's avatar Shai Barack
Browse files

Store Messages directly in priority queues

Remove MessageNode indirection to eliminate a potential performance penalty.

To avoid ABA, we no longer recycle Messages in concurrent MessageQueue code.
The ABA happens because skiplist removal is by Message instance identity.
Potential ABA:
1. Thread 1 enqueues Message A
2. Thread 2 calls remove that would match A, starts iterating on skiplist, finds A
3. Thread 3 wakes up, polls A from queue, recycles it
4. Thread 3 obtains A from pool and enqueues it as a semantically new Message
5. Thread 2 remove()s A from skiplist

Bug: 415954362
Flag: build.RELEASE_PACKAGE_MESSAGEQUEUE_IMPLEMENTATION
Change-Id: Ib3a7ff94c5234c1d302cb4c9b04d1f963293641f
parent ee0b88cc
Loading
Loading
Loading
Loading
+102 −137
Original line number Diff line number Diff line
@@ -285,10 +285,10 @@ public final class MessageQueue {
        }
    }

    static final class EnqueueOrder implements Comparator<MessageNode> {
    static final class EnqueueOrder implements Comparator<Message> {
        @Override
        public int compare(MessageNode n1, MessageNode n2) {
            return compareMessages(n1.mMessage, n2.mMessage);
        public int compare(Message m1, Message m2) {
            return compareMessages(m1, m2);
        }
    }

@@ -311,15 +311,15 @@ public final class MessageQueue {
        return 0;
    }

    private static boolean isBarrier(MessageNode msgNode) {
        return msgNode != null && msgNode.mMessage.target == null;
    private static boolean isBarrier(Message msg) {
        return msg != null && msg.target == null;
    }

    static final class MatchDeliverableMessages extends MessageCompare {
        @Override
        public boolean compareMessage(MessageNode n, Handler h, int what, Object object, Runnable r,
        public boolean compareMessage(Message m, Handler h, int what, Object object, Runnable r,
                long when) {
            return n.mMessage.when <= when;
            return m.when <= when;
        }
    }
    private final MatchDeliverableMessages mMatchDeliverableMessages =
@@ -332,13 +332,13 @@ public final class MessageQueue {
            return false;
        }

        final MessageNode msgNode = first(mPriorityQueue);
        if (msgNode != null && msgNode.mMessage.when <= now) {
        final Message msg = first(mPriorityQueue);
        if (msg != null && msg.when <= now) {
            return false;
        }

        final MessageNode asyncMsgNode = first(mAsyncPriorityQueue);
        if (asyncMsgNode != null && asyncMsgNode.mMessage.when <= now) {
        final Message asyncMsg = first(mAsyncPriorityQueue);
        if (asyncMsg != null && asyncMsg.when <= now) {
            return false;
        }

@@ -755,27 +755,25 @@ public final class MessageQueue {
             */

            /* Get the first node from each queue */
            MessageNode msgNode = first(mPriorityQueue);
            MessageNode asyncMsgNode = first(mAsyncPriorityQueue);
            Message msg = first(mPriorityQueue);
            Message asyncMsg = first(mAsyncPriorityQueue);
            final long now = SystemClock.uptimeMillis();

            if (DEBUG) {
                if (msgNode != null) {
                    Message msg = msgNode.mMessage;
                if (msg != null) {
                    Log.d(TAG_C, "Next found node"
                            + " what: " + msg.what
                            + " when: " + msg.when
                            + " seq: " + msgNode.mMessage.insertSeq
                            + " barrier: " + isBarrier(msgNode)
                            + " seq: " + msg.insertSeq
                            + " barrier: " + isBarrier(msg)
                            + " now: " + now);
                }
                if (asyncMsgNode != null) {
                    Message msg = asyncMsgNode.mMessage;
                if (asyncMsg != null) {
                    Log.d(TAG_C, "Next found async node"
                            + " what: " + msg.what
                            + " when: " + msg.when
                            + " seq: " + asyncMsgNode.mMessage.insertSeq
                            + " barrier: " + isBarrier(asyncMsgNode)
                            + " what: " + asyncMsg.what
                            + " when: " + asyncMsg.when
                            + " seq: " + asyncMsg.insertSeq
                            + " barrier: " + isBarrier(asyncMsg)
                            + " now: " + now);
                }
            }
@@ -783,34 +781,37 @@ public final class MessageQueue {
            /*
             * the node which we will return, null if none are ready
             */
            MessageNode found = null;
            Message found = null;
            /*
             * The node from which we will determine our next wakeup time.
             * Null indicates there is no next message ready. If we found a node,
             * we can leave this null as Looper will call us again after delivering
             * the message.
             */
            MessageNode next = null;
            Message next = null;

            /*
             * If we have a barrier we should return the async node (if it exists and is ready)
             */
            if (isBarrier(msgNode)) {
                if (asyncMsgNode != null && (returnEarliest || now >= asyncMsgNode.mMessage.when)) {
                    found = asyncMsgNode;
            if (isBarrier(msg)) {
                if (asyncMsg != null && (returnEarliest || now >= asyncMsg.when)) {
                    found = asyncMsg;
                } else {
                    next = asyncMsgNode;
                    next = asyncMsg;
                }
            } else { /* No barrier. */
                MessageNode earliest;
                /*
                 * If we have two messages, pick the earliest option from either queue.
                 * Otherwise grab whichever node is non-null. If both are null we'll fall through.
                 */
                earliest = pickEarliestNode(msgNode, asyncMsgNode);
                // Pick the earliest of the next sync and async messages, if any.
                Message earliest = msg;
                if (msg == null) {
                    earliest = asyncMsg;
                } else if (asyncMsg != null) {
                    if (compareMessages(msg, asyncMsg) > 0) {
                        earliest = asyncMsg;
                    }
                }

                if (earliest != null) {
                    if (returnEarliest || now >= earliest.mMessage.when) {
                    if (returnEarliest || now >= earliest.when) {
                        found = earliest;
                    } else {
                        next = earliest;
@@ -820,25 +821,23 @@ public final class MessageQueue {

            if (DEBUG) {
                if (found != null) {
                    Message msg = found.mMessage;
                    Log.d(TAG_C, "Will deliver node"
                            + " what: " + msg.what
                            + " when: " + msg.when
                            + " seq: " + found.mMessage.insertSeq
                            + " what: " + found.what
                            + " when: " + found.when
                            + " seq: " + found.insertSeq
                            + " barrier: " + isBarrier(found)
                            + " async: " + found.mMessage.isAsynchronous()
                            + " async: " + found.isAsynchronous()
                            + " now: " + now);
                } else {
                    Log.d(TAG_C, "No node to deliver");
                }
                if (next != null) {
                    Message msg = next.mMessage;
                    Log.d(TAG_C, "Next node"
                            + " what: " + msg.what
                            + " when: " + msg.when
                            + " seq: " + next.mMessage.insertSeq
                            + " what: " + next.what
                            + " when: " + next.when
                            + " seq: " + next.insertSeq
                            + " barrier: " + isBarrier(next)
                            + " async: " + next.mMessage.isAsynchronous()
                            + " async: " + next.isAsynchronous()
                            + " now: " + now);
                } else {
                    Log.d(TAG_C, "No next node");
@@ -865,7 +864,7 @@ public final class MessageQueue {
                    }
                } else {
                    /* Message not ready, or we found one to deliver already, set a timeout */
                    long nextMessageWhen = next.mMessage.when;
                    long nextMessageWhen = next.when;
                    if (nextMessageWhen > now) {
                        mNextPollTimeoutMillis = (int) Math.min(nextMessageWhen - now,
                                Integer.MAX_VALUE);
@@ -899,7 +898,7 @@ public final class MessageQueue {
                        continue;
                    }

                    return found.mMessage;
                    return found;
                }
                return null;
            }
@@ -1277,9 +1276,8 @@ public final class MessageQueue {
        }

        @Override
        public boolean compareMessage(MessageNode n, Handler h, int what, Object object, Runnable r,
        public boolean compareMessage(Message m, Handler h, int what, Object object, Runnable r,
                long when) {
            final Message m = n.mMessage;
            if (m.target == null && m.arg1 == mBarrierToken) {
                return true;
            }
@@ -1292,11 +1290,10 @@ public final class MessageQueue {
        final MatchBarrierToken matchBarrierToken = new MatchBarrierToken(token);

        // Retain the first element to see if we are currently stuck on a barrier.
        final MessageNode first = first(mPriorityQueue);
        final Message m = first(mPriorityQueue);

        removed = findOrRemoveMessages(null, 0, null, null, 0, matchBarrierToken, true);
        if (removed && first != null) {
            Message m = first.mMessage;
        if (removed && m != null) {
            if (m.target == null && m.arg1 == token) {
                /* Wake up next() in case it was sleeping on this barrier. */
                concurrentWake();
@@ -1554,10 +1551,7 @@ public final class MessageQueue {
        if (sUseConcurrent) {
            // Call nextMessage to get the stack drained into our priority queues
            nextMessage(true, false);

            MessageNode queueNode = first(mPriorityQueue);

            return (isBarrier(queueNode));
            return (isBarrier(first(mPriorityQueue)));
        } else {
            Message msg = mMessages;
            return msg != null && msg.target == null;
@@ -1566,9 +1560,8 @@ public final class MessageQueue {

    static final class MatchHandlerWhatAndObject extends MessageCompare {
        @Override
        public boolean compareMessage(MessageNode n, Handler h, int what, Object object, Runnable r,
        public boolean compareMessage(Message m, Handler h, int what, Object object, Runnable r,
                long when) {
            final Message m = n.mMessage;
            if (m.target == h && m.what == what && (object == null || m.obj == object)) {
                return true;
            }
@@ -1609,9 +1602,8 @@ public final class MessageQueue {

    static final class MatchHandlerWhatAndObjectEquals extends MessageCompare {
        @Override
        public boolean compareMessage(MessageNode n, Handler h, int what, Object object, Runnable r,
        public boolean compareMessage(Message m, Handler h, int what, Object object, Runnable r,
                long when) {
            final Message m = n.mMessage;
            if (m.target == h && m.what == what && (object == null || object.equals(m.obj))) {
                return true;
            }
@@ -1652,9 +1644,8 @@ public final class MessageQueue {

    static final class MatchHandlerRunnableAndObject extends MessageCompare {
        @Override
        public boolean compareMessage(MessageNode n, Handler h, int what, Object object, Runnable r,
        public boolean compareMessage(Message m, Handler h, int what, Object object, Runnable r,
                long when) {
            final Message m = n.mMessage;
            if (m.target == h && m.callback == r && (object == null || m.obj == object)) {
                return true;
            }
@@ -1696,9 +1687,9 @@ public final class MessageQueue {

    static final class MatchHandler extends MessageCompare {
        @Override
        public boolean compareMessage(MessageNode n, Handler h, int what, Object object, Runnable r,
        public boolean compareMessage(Message m, Handler h, int what, Object object, Runnable r,
                long when) {
            return n.mMessage.target == h;
            return m.target == h;
        }
    }
    private final MatchHandler mMatchHandler = new MatchHandler();
@@ -1915,9 +1906,8 @@ public final class MessageQueue {

    static final class MatchHandlerRunnableAndObjectEquals extends MessageCompare {
        @Override
        public boolean compareMessage(MessageNode n, Handler h, int what, Object object, Runnable r,
        public boolean compareMessage(Message m, Handler h, int what, Object object, Runnable r,
                long when) {
            final Message m = n.mMessage;
            if (m.target == h && m.callback == r && (object == null || object.equals(m.obj))) {
                return true;
            }
@@ -1990,9 +1980,8 @@ public final class MessageQueue {

    static final class MatchHandlerAndObject extends MessageCompare {
        @Override
        public boolean compareMessage(MessageNode n, Handler h, int what, Object object, Runnable r,
        public boolean compareMessage(Message m, Handler h, int what, Object object, Runnable r,
                long when) {
            final Message m = n.mMessage;
            if (m.target == h && (object == null || m.obj == object)) {
                return true;
            }
@@ -2063,9 +2052,8 @@ public final class MessageQueue {

    static final class MatchHandlerAndObjectEquals extends MessageCompare {
        @Override
        public boolean compareMessage(MessageNode n, Handler h, int what, Object object, Runnable r,
        public boolean compareMessage(Message m, Handler h, int what, Object object, Runnable r,
                long when) {
            final Message m = n.mMessage;
            if (m.target == h && (object == null || object.equals(m.obj))) {
                return true;
            }
@@ -2185,7 +2173,7 @@ public final class MessageQueue {

    static final class MatchAllMessages extends MessageCompare {
        @Override
        public boolean compareMessage(MessageNode n, Handler h, int what, Object object, Runnable r,
        public boolean compareMessage(Message m, Handler h, int what, Object object, Runnable r,
                long when) {
            return true;
        }
@@ -2197,11 +2185,9 @@ public final class MessageQueue {

    static final class MatchAllFutureMessages extends MessageCompare {
        @Override
        public boolean compareMessage(MessageNode n, Handler h, int what, Object object, Runnable r,
        public boolean compareMessage(Message m, Handler h, int what, Object object, Runnable r,
                long when) {
            final Message m = n.mMessage;

            return n.mMessage.when > when;
            return m.when > when;
        }
    }
    private static final MatchAllFutureMessages sMatchAllFutureMessages =
@@ -2213,26 +2199,22 @@ public final class MessageQueue {

    @NeverCompile
    private void printPriorityQueueNodes() {
        Iterator<MessageNode> iterator = mPriorityQueue.iterator();

        Log.d(TAG_C, "* Dump priority queue");
        while (iterator.hasNext()) {
            MessageNode msgNode = iterator.next();
        for (Message msg : mPriorityQueue) {
            Log.d(TAG_C,
                    "** MessageNode what: " + msgNode.mMessage.what
                    + " when " + msgNode.mMessage.when
                    + " seq: " + msgNode.mMessage.insertSeq);
                    "** Message what: " + msg.what
                    + " when " + msg.when
                    + " seq: " + msg.insertSeq);
        }
    }

    @NeverCompile
    private int dumpPriorityQueue(ConcurrentSkipListSet<MessageNode> queue, Printer pw,
    private int dumpPriorityQueue(ConcurrentSkipListSet<Message> queue, Printer pw,
            String prefix, Handler h, int n) {
        int count = 0;
        long now = SystemClock.uptimeMillis();

        for (MessageNode msgNode : queue) {
            Message msg = msgNode.mMessage;
        for (Message msg : queue) {
            if (h == null || h == msg.target) {
                pw.println(prefix + "Message " + (n + count) + ": " + msg.toString(now));
            }
@@ -2290,12 +2272,10 @@ public final class MessageQueue {
    }

    @NeverCompile
    private int dumpPriorityQueue(ConcurrentSkipListSet<MessageNode> queue,
    private int dumpPriorityQueue(ConcurrentSkipListSet<Message> queue,
            ProtoOutputStream proto) {
        int count = 0;

        for (MessageNode msgNode : queue) {
            Message msg = msgNode.mMessage;
        for (Message msg : queue) {
            msg.dumpDebug(proto, MessageQueueProto.MESSAGES);
            count++;
        }
@@ -2442,34 +2422,23 @@ public final class MessageQueue {
     * ConcurrentMessageQueue specific classes methods and variables
     */
    /* Helper to choose the correct queue to insert into. */
    private void insertIntoPriorityQueue(MessageNode msgNode) {
        if (msgNode.mMessage.isAsynchronous()) {
            mAsyncPriorityQueue.add(msgNode);
    private void insertIntoPriorityQueue(Message msg) {
        if (msg.isAsynchronous()) {
            mAsyncPriorityQueue.add(msg);
        } else {
            mPriorityQueue.add(msgNode);
            mPriorityQueue.add(msg);
        }
    }

    private boolean removeFromPriorityQueue(MessageNode msgNode) {
        if (msgNode.mMessage.isAsynchronous()) {
            return mAsyncPriorityQueue.remove(msgNode);
    private boolean removeFromPriorityQueue(Message msg) {
        if (msg.isAsynchronous()) {
            return mAsyncPriorityQueue.remove(msg);
        } else {
            return mPriorityQueue.remove(msgNode);
        }
    }

    private MessageNode pickEarliestNode(MessageNode nodeA, MessageNode nodeB) {
        if (nodeA != null && nodeB != null) {
            if (compareMessages(nodeA.mMessage, nodeB.mMessage) < 0) {
                return nodeA;
            return mPriorityQueue.remove(msg);
        }
            return nodeB;
    }

        return nodeA != null ? nodeA : nodeB;
    }

    private static MessageNode first(ConcurrentSkipListSet<MessageNode> queue) {
    private static Message first(ConcurrentSkipListSet<Message> queue) {
        try {
            return queue.first();
        } catch (NoSuchElementException e) {
@@ -2494,7 +2463,7 @@ public final class MessageQueue {
        while (oldTop.isMessageNode()) {
            MessageNode oldTopMessageNode = (MessageNode) oldTop;
            if (oldTopMessageNode.removeFromStack()) {
                insertIntoPriorityQueue(oldTopMessageNode);
                insertIntoPriorityQueue(oldTopMessageNode.mMessage);
            }
            MessageNode inserted = oldTopMessageNode;
            oldTop = oldTopMessageNode.mNext;
@@ -2636,10 +2605,9 @@ public final class MessageQueue {
            }
        }

        MessageNode(@NonNull Message message, long insertSeq) {
        MessageNode(@NonNull Message message) {
            super(STACK_NODE_MESSAGE);
            mMessage = message;
            message.insertSeq = insertSeq;
        }

        boolean removeFromStack() {
@@ -2669,10 +2637,10 @@ public final class MessageQueue {
    private static final VarHandle sState;

    private volatile StackNode mStateValue = sStackStateParked;
    private final ConcurrentSkipListSet<MessageNode> mPriorityQueue =
            new ConcurrentSkipListSet<MessageNode>(sEnqueueOrder);
    private final ConcurrentSkipListSet<MessageNode> mAsyncPriorityQueue =
            new ConcurrentSkipListSet<MessageNode>(sEnqueueOrder);
    private final ConcurrentSkipListSet<Message> mPriorityQueue =
            new ConcurrentSkipListSet<Message>(sEnqueueOrder);
    private final ConcurrentSkipListSet<Message> mAsyncPriorityQueue =
            new ConcurrentSkipListSet<Message>(sEnqueueOrder);

    /*
     * This helps us ensure that messages with the same timestamp are inserted in FIFO order.
@@ -2917,19 +2885,17 @@ public final class MessageQueue {
    private boolean enqueueMessageUnchecked(@NonNull Message msg, long when) {
        long seq = when != 0 ? ((long) sNextInsertSeq.getAndAdd(this, 1L) + 1L)
                : ((long) sNextFrontInsertSeq.getAndAdd(this, -1L) - 1L);
        /* TODO: Add a MessageNode member to Message so we can avoid this allocation */
        MessageNode node = new MessageNode(msg, seq);
        msg.when = when;
        msg.insertSeq = seq;
        msg.markInUse();
        incAndTraceMessageCount(msg, when);

        if (DEBUG) {
            Log.d(TAG_C, "Insert message"
                    + " what: " + msg.what
                    + " when: " + msg.when
                    + " seq: " + node.mMessage.insertSeq
                    + " barrier: " + isBarrier(node)
                    + " async: " + node.mMessage.isAsynchronous()
                    + " seq: " + msg.insertSeq
                    + " barrier: " + isBarrier(msg)
                    + " async: " + msg.isAsynchronous()
                    + " now: " + SystemClock.uptimeMillis());
        }

@@ -2945,8 +2911,8 @@ public final class MessageQueue {
                return false;
            }

            node.removeFromStack();
            insertIntoPriorityQueue(node);
            insertIntoPriorityQueue(msg);
            incAndTraceMessageCount(msg, when);
            /*
             * We still need to do this even though we are the current thread,
             * otherwise next() may sleep indefinitely.
@@ -2958,6 +2924,7 @@ public final class MessageQueue {
            return true;
        }

        MessageNode node = new MessageNode(msg);
        while (true) {
            StackNode old = (StackNode) sState.getVolatile(this);
            boolean wakeNeeded;
@@ -2986,7 +2953,7 @@ public final class MessageQueue {
                case STACK_NODE_TIMEDPARK:
                    node.mBottomOfStack = (StateNode) old;
                    inactive = true;
                    wakeNeeded = mStackStateTimedPark.mWhenToWake >= node.mMessage.when;
                    wakeNeeded = mStackStateTimedPark.mWhenToWake >= msg.when;
                    node.mWokeUp = wakeNeeded;
                    break;

@@ -3017,6 +2984,7 @@ public final class MessageQueue {
                        mMessageCounts.incrementQueued();
                    }
                }
                incAndTraceMessageCount(msg, when);
                return true;
            }
        }
@@ -3026,7 +2994,7 @@ public final class MessageQueue {
     * This class is used to find matches for hasMessages() and removeMessages()
     */
    abstract static class MessageCompare {
        public abstract boolean compareMessage(MessageNode n, Handler h, int what, Object object,
        public abstract boolean compareMessage(Message m, Handler h, int what, Object object,
                Runnable r, long when);
    }

@@ -3065,14 +3033,15 @@ public final class MessageQueue {
        MessageNode p = (MessageNode) top;

        while (true) {
            if (compare.compareMessage(p, h, what, object, r, when)) {
            final Message msg = p.mMessage;
            if (compare.compareMessage(msg, h, what, object, r, when)) {
                found = true;
                if (DEBUG) {
                    Log.d(TAG_C, "stackHasMessages node matches");
                }
                if (removeMatches) {
                    if (p.removeFromStack()) {
                        p.mMessage.recycleUnchecked();
                        msg.recycleUnchecked();
                        decAndTraceMessageCount();
                        if (mMessageCounts.incrementCancelled()) {
                            concurrentWake();
@@ -3103,21 +3072,17 @@ public final class MessageQueue {
        return found;
    }

    private boolean priorityQueueHasMessage(ConcurrentSkipListSet<MessageNode> queue, Handler h,
    private boolean priorityQueueHasMessage(ConcurrentSkipListSet<Message> queue, Handler h,
            int what, Object object, Runnable r, long when, MessageCompare compare,
            boolean removeMatches) {
        Iterator<MessageNode> iterator = queue.iterator();
        boolean found = false;

        while (iterator.hasNext()) {
            MessageNode msg = iterator.next();

        for (Message msg : queue) {
            if (compare.compareMessage(msg, h, what, object, r, when)) {
                if (removeMatches) {
                    found = true;
                    if (queue.remove(msg)) {
                        msg.mMessage.recycleUnchecked();
                        msg.recycleUnchecked();
                        decAndTraceMessageCount();
                        found = true;
                    }
                } else {
                    return true;