Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 1d4abc78 authored by Treehugger Robot's avatar Treehugger Robot Committed by Android (Google) Code Review
Browse files

Merge "Fix BoundServiceSession bookkeeping" into main

parents 1b014b55 8fb88fb8
Loading
Loading
Loading
Loading
+56 −20
Original line number Diff line number Diff line
@@ -22,6 +22,7 @@ import android.os.Trace;
import android.ravenwood.annotation.RavenwoodKeepWholeClass;
import android.util.ArrayMap;
import android.util.IndentingPrintWriter;
import android.util.IntArray;
import android.util.Slog;

import com.android.internal.annotations.GuardedBy;
@@ -30,11 +31,15 @@ import com.android.internal.util.IntPair;

import java.lang.ref.WeakReference;
import java.util.function.BiConsumer;
import java.util.function.Supplier;

/**
 * An implementation of {@link IBinderSession} on top of a {@link ConnectionRecord} that
 * is used to facilitate important binder calls to a bound remote service hosted by a process that
 * is eligible to get frozen by {@link ProcessStateController}.
 *
 * <p>This class simply keeps the count of ongoing transactions over a bound service's binder and
 * notifies {@link ProcessStateController} when the count changes to or from 0.
 */
@RavenwoodKeepWholeClass
public class BoundServiceSession implements IBinderSession {
@@ -42,6 +47,13 @@ public class BoundServiceSession implements IBinderSession {
    private static final int MAGIC_ID = 0xFBD_5E55;
    private static final String TRACE_TRACK = "bound_service_calls";

    /** Tag to use for all tags after we have {@link #MAX_UNIQUE_TAGS} tags within the session. */
    @VisibleForTesting
    static final String OVERFLOW_TAG = "_overflow_tags";
    /** Any new tags after this limit will be clubbed together under {@link #OVERFLOW_TAG}. */
    @VisibleForTesting
    static final int MAX_UNIQUE_TAGS = 127;

    // We don't hold a strong reference in case this object is held on for a long time after the
    // binding has gone away. This helps us easily avoid leaks and excess OomAdjuster updates
    // while remaining agnostic to binding state changes. This is also a convenient long-term choice
@@ -50,9 +62,17 @@ public class BoundServiceSession implements IBinderSession {
    private final BiConsumer<ConnectionRecord, Boolean> mProcessStateUpdater;
    private final String mDebugName;

    /**
     * For each unique tag, we generate a stable key which is simply the index of the counter
     * maintained in {@link #mCountByKey} array. We encapsulate this information in the generated
     * token which is returned to the client to pass in {@link #binderTransactionCompleted(long)}.
     */
    @VisibleForTesting
    @GuardedBy("this")
    final ArrayMap<String, Integer> mKeyByTag = new ArrayMap<>();
    @VisibleForTesting
    @GuardedBy("this")
    ArrayMap<String, Integer> mCountsByTag = null;
    final IntArray mCountByKey = new IntArray();

    @VisibleForTesting
    @GuardedBy("this")
@@ -73,7 +93,7 @@ public class BoundServiceSession implements IBinderSession {
        return IntPair.first(token) == MAGIC_ID;
    }

    private static int getKeyIndex(long token) {
    private static int getKey(long token) {
        return IntPair.second(token);
    }

@@ -88,10 +108,10 @@ public class BoundServiceSession implements IBinderSession {
        mProcessStateUpdater.accept(strongCr, mTotal > 0);
    }

    private void logTraceInstant(String message) {
    private void logTraceInstant(Supplier<String> messageSupplier) {
        if (Trace.isTagEnabled(Trace.TRACE_TAG_ACTIVITY_MANAGER)) {
            Trace.instantForTrack(Trace.TRACE_TAG_ACTIVITY_MANAGER, TRACE_TRACK, mDebugName
                    + ": " + message);
                    + ": " + messageSupplier.get());
        }
    }

@@ -101,9 +121,11 @@ public class BoundServiceSession implements IBinderSession {
        // cannot tell for which tag. We'll just reset all counts to 0 and propagate the same to
        // the underlying ConnectionRecord. This also ensures that there are no shenanigans that
        // the remote app can perform with the given token to remain unfrozen.
        logTraceInstant(errorMessage);
        Slog.wtfStack(TAG, errorMessage);
        mCountsByTag.clear();
        logTraceInstant(() -> errorMessage);
        Slog.wtfStack(TAG,
                errorMessage + ". Current keys: " + mKeyByTag + "; Counts: " + mCountByKey);
        mKeyByTag.clear();
        mCountByKey.clear();
        if (mTotal != 0) {
            mTotal = 0;
            maybePostProcessStateUpdate();
@@ -112,17 +134,31 @@ public class BoundServiceSession implements IBinderSession {

    @Override
    public long binderTransactionStarting(String debugTag) {
        logTraceInstant("+" + debugTag);
        synchronized (this) {
            if (mCountsByTag == null) {
                mCountsByTag = new ArrayMap<>(4);
            }
            mCountsByTag.merge(debugTag, 1, (old, _unused) -> old + 1);
            final int key;
            if (mKeyByTag.size() >= MAX_UNIQUE_TAGS) {
                // The values in mKeyByTag are always in the range [0, mKeyByTag.size() - 1].
                key = mKeyByTag.getOrDefault(debugTag, MAX_UNIQUE_TAGS);
                if (key == MAX_UNIQUE_TAGS && mKeyByTag.size() == MAX_UNIQUE_TAGS) {
                    Slog.wtfStack(TAG, "Too many tags supplied on " + mDebugName
                            + ". Current tag: " + debugTag + ". Existing map: " + mKeyByTag);
                    mKeyByTag.put(OVERFLOW_TAG, key);
                    mCountByKey.add(0);
                }
            } else {
                key = mKeyByTag.computeIfAbsent(debugTag, unused -> {
                    mCountByKey.add(0);
                    return mCountByKey.size() - 1;
                });
            }
            final long token = getToken(key);
            logTraceInstant(() -> "open(" + debugTag + ", " + token + ")");
            mCountByKey.set(key, mCountByKey.get(key) + 1);
            mTotal++;
            if (mTotal == 1) {
                maybePostProcessStateUpdate();
            }
            return getToken(mCountsByTag.indexOfKey(debugTag));
            return token;
        }
    }

@@ -135,15 +171,15 @@ public class BoundServiceSession implements IBinderSession {
                        + mDebugName);
                return;
            }
            final int keyIndex = getKeyIndex(token);
            if (mCountsByTag.size() <= keyIndex || mCountsByTag.valueAt(keyIndex) <= 0) {
                handleInvalidToken("Bad keyIndex " + keyIndex
            final int key = getKey(token);
            if (mCountByKey.size() <= key || mCountByKey.get(key) <= 0) {
                handleInvalidToken("Bad key " + key
                        + " received in binderTransactionCompleted! Closing all transactions on "
                        + mDebugName);
                return;
            }
            logTraceInstant("-" + mCountsByTag.keyAt(keyIndex));
            mCountsByTag.setValueAt(keyIndex, mCountsByTag.valueAt(keyIndex) - 1);
            logTraceInstant(() -> "close(" + key + ")");
            mCountByKey.set(key, mCountByKey.get(key) - 1);
            mTotal--;
            if (mTotal == 0) {
                maybePostProcessStateUpdate();
@@ -161,10 +197,10 @@ public class BoundServiceSession implements IBinderSession {
            ipw.print("Ongoing bound service calls: ");
            ipw.println(mTotal);

            if (mCountsByTag != null) {
            if (mTotal > 0) {
                ipw.increaseIndent();
                for (int i = 0; i < mCountsByTag.size(); i++) {
                    ipw.print(mCountsByTag.keyAt(i), mCountsByTag.valueAt(i));
                for (int i = 0; i < mKeyByTag.size(); i++) {
                    ipw.print(mKeyByTag.keyAt(i), mCountByKey.get(mKeyByTag.valueAt(i)));
                }
                ipw.println();
                ipw.decreaseIndent();
+157 −44
Original line number Diff line number Diff line
@@ -16,8 +16,13 @@

package com.android.server.am;

import static com.android.server.am.BoundServiceSession.MAX_UNIQUE_TAGS;
import static com.android.server.am.BoundServiceSession.OVERFLOW_TAG;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyBoolean;
import static org.mockito.Mockito.clearInvocations;
@@ -31,6 +36,8 @@ import android.platform.test.annotations.Presubmit;
import org.junit.Test;

import java.lang.ref.WeakReference;
import java.util.ArrayList;
import java.util.List;
import java.util.function.BiConsumer;

/**
@@ -39,7 +46,7 @@ import java.util.function.BiConsumer;
 * Build/Install/Run:
 *  atest FrameworksServicesTests:BoundServiceSessionTests
 * Or
 *  atest FrameworksServicesTestsRavenwood_ProcessStateController
 *  atest FrameworksServicesTestsRavenwood_ProcessStateController:BoundServiceSessionTests
 */
@Presubmit
public class BoundServiceSessionTests {
@@ -56,14 +63,14 @@ public class BoundServiceSessionTests {

    private static void assertSessionReset(BoundServiceSession session) {
        assertEquals(0, session.mTotal);
        assertEquals(0, session.mCountsByTag.size());
        assertEquals(0, session.mKeyByTag.size());
        assertEquals(0, session.mCountByKey.size());
    }

    @Test
    public void startingState() {
        final BoundServiceSession session = getNewBoundServiceSessionForTest();
        assertEquals(0, session.mTotal);
        assertNull(session.mCountsByTag);
        assertSessionReset(session);
    }

    @Test
@@ -78,14 +85,19 @@ public class BoundServiceSessionTests {

        final long token1 = session.binderTransactionStarting(testTags[1]);

        assertEquals(4, (int) session.mCountsByTag.get(testTags[0]));
        assertEquals(1, (int) session.mCountsByTag.get(testTags[1]));
        final int key0 = session.mKeyByTag.get(testTags[0]);
        final int key1 = session.mKeyByTag.get(testTags[1]);
        assertEquals(0, key0);
        assertEquals(1, key1);

        assertEquals(4, session.mCountByKey.get(key0));
        assertEquals(1, session.mCountByKey.get(key1));
        assertEquals(5, session.mTotal);

        session.binderTransactionCompleted(token1);

        assertEquals(4, (int) session.mCountsByTag.get(testTags[0]));
        assertEquals(0, (int) session.mCountsByTag.get(testTags[1]));
        assertEquals(4, session.mCountByKey.get(key0));
        assertEquals(0, session.mCountByKey.get(key1));
        assertEquals(4, session.mTotal);

        session.binderTransactionCompleted(token1);
@@ -101,8 +113,10 @@ public class BoundServiceSessionTests {
        session.binderTransactionStarting(testTag);
        session.binderTransactionStarting(testTag);

        int key = session.mKeyByTag.get(testTag);
        assertEquals(0, key);
        assertEquals(3, session.mTotal);
        assertEquals(3, (int) session.mCountsByTag.get(testTag));
        assertEquals(3, session.mCountByKey.get(key));

        session.binderTransactionCompleted(validToken + 1);
        assertSessionReset(session);
@@ -110,8 +124,10 @@ public class BoundServiceSessionTests {
        session.binderTransactionStarting(testTag);
        session.binderTransactionStarting(testTag);

        key = session.mKeyByTag.get(testTag);
        assertEquals(0, key);
        assertEquals(2, session.mTotal);
        assertEquals(2, (int) session.mCountsByTag.get(testTag));
        assertEquals(2, session.mCountByKey.get(key));

        session.binderTransactionCompleted(-1);
        assertSessionReset(session);
@@ -121,7 +137,7 @@ public class BoundServiceSessionTests {
    public void tokenConsistency() {
        final BoundServiceSession session = getNewBoundServiceSessionForTest();

        final String[] testTags = {"test0", "test1", "test2"};
        final String[] testTags = {"test5", "test1", "test2"};

        final long token0 = session.binderTransactionStarting(testTags[0]);
        final long token1 = session.binderTransactionStarting(testTags[1]);
@@ -132,6 +148,24 @@ public class BoundServiceSessionTests {
        assertEquals(token2, session.binderTransactionStarting(testTags[2]));
    }

    @Test
    public void tokenDistinctness() {
        final BoundServiceSession session = getNewBoundServiceSessionForTest();

        final String[] tags = {"test3", "test1", "otherTag", "test2"};
        final List<Long> tokens = new ArrayList<>();

        for (String tag: tags) {
            final long token = session.binderTransactionStarting(tag);
            final int index = tokens.indexOf(token);
            if (index >= 0) {
                fail("Duplicate token " + token + " found for tag " + tag
                        + ". Previously assigned to tag " + tags[index]);
            }
            tokens.add(token);
        }
    }

    @Test
    public void callsConsumerOnChangeFromZero() {
        final BoundServiceSession session = getNewBoundServiceSessionForTest();
@@ -201,57 +235,136 @@ public class BoundServiceSessionTests {
        verify(mMockConsumer, never()).accept(any(ConnectionRecord.class), anyBoolean());
    }

    @Test
    public void keyKeeping() {
        final BoundServiceSession session = getNewBoundServiceSessionForTest();

        final String[] tags = {"test3", "test1", "otherTag", "test2"};
        final List<Integer> keys = new ArrayList<>();

        for (String tag: tags) {
            assertFalse(session.mKeyByTag.containsKey(tag));
            session.binderTransactionStarting(tag);

            final int key = session.mKeyByTag.get(tag);
            final int index = keys.indexOf(key);
            if (index >= 0) {
                fail("Duplicate key " + key + " found for tag " + tag
                        + ". Previously assigned to tag " + tags[index]);
            }
            keys.add(key);
        }

        // Ensure that keys don't change once assigned.
        for (int i = 0; i < tags.length; i++) {
            assertEquals(keys.get(i), session.mKeyByTag.get(tags[i]));
            session.binderTransactionStarting(tags[i]);
            assertEquals(keys.get(i), session.mKeyByTag.get(tags[i]));
        }
    }

    @Test
    public void countKeeping() {
        final BoundServiceSession session = getNewBoundServiceSessionForTest();

        final String[] testTags = {"test0", "test1", "test2"};
        final String[] tags = {"test3", "otherTag", "test0"};
        final List<Long> tokens = new ArrayList<>();
        final List<Integer> keys = new ArrayList<>();

        final long token0 = session.binderTransactionStarting(testTags[0]);
        final long token1 = session.binderTransactionStarting(testTags[1]);
        final long token2 = session.binderTransactionStarting(testTags[2]);
        for (String tag: tags) {
            tokens.add(session.binderTransactionStarting(tag));
            keys.add(session.mKeyByTag.get(tag));
        }

        session.binderTransactionStarting(testTags[1]);
        session.binderTransactionStarting(testTags[2]);
        session.binderTransactionStarting(testTags[2]);
        session.binderTransactionStarting(tags[1]);
        session.binderTransactionStarting(tags[2]);
        session.binderTransactionStarting(tags[2]);

        assertEquals(1, (int) session.mCountsByTag.get(testTags[0]));
        assertEquals(2, (int) session.mCountsByTag.get(testTags[1]));
        assertEquals(3, (int) session.mCountsByTag.get(testTags[2]));
        assertEquals(1, session.mCountByKey.get(keys.get(0)));
        assertEquals(2, session.mCountByKey.get(keys.get(1)));
        assertEquals(3, session.mCountByKey.get(keys.get(2)));
        assertEquals(6, session.mTotal);

        session.binderTransactionCompleted(token0);
        session.binderTransactionCompleted(token1);
        session.binderTransactionCompleted(token2);
        session.binderTransactionCompleted(tokens.get(0));
        session.binderTransactionCompleted(tokens.get(1));
        session.binderTransactionCompleted(tokens.get(2));

        assertEquals(0, (int) session.mCountsByTag.get(testTags[0]));
        assertEquals(1, (int) session.mCountsByTag.get(testTags[1]));
        assertEquals(2, (int) session.mCountsByTag.get(testTags[2]));
        assertEquals(0, session.mCountByKey.get(keys.get(0)));
        assertEquals(1, session.mCountByKey.get(keys.get(1)));
        assertEquals(2, session.mCountByKey.get(keys.get(2)));
        assertEquals(3, session.mTotal);

        session.binderTransactionCompleted(token1);
        session.binderTransactionCompleted(token2);
        session.binderTransactionCompleted(tokens.get(1));
        session.binderTransactionCompleted(tokens.get(2));

        assertEquals(0, (int) session.mCountsByTag.get(testTags[0]));
        assertEquals(0, (int) session.mCountsByTag.get(testTags[1]));
        assertEquals(1, (int) session.mCountsByTag.get(testTags[2]));
        assertEquals(0, session.mCountByKey.get(keys.get(0)));
        assertEquals(0, session.mCountByKey.get(keys.get(1)));
        assertEquals(1, session.mCountByKey.get(keys.get(2)));
        assertEquals(1, session.mTotal);

        session.binderTransactionStarting(testTags[0]);
        session.binderTransactionStarting(testTags[1]);
        session.binderTransactionStarting(tags[0]);
        session.binderTransactionStarting(tags[1]);

        assertEquals(1, (int) session.mCountsByTag.get(testTags[0]));
        assertEquals(1, (int) session.mCountsByTag.get(testTags[1]));
        assertEquals(1, (int) session.mCountsByTag.get(testTags[2]));
        assertEquals(1, session.mCountByKey.get(keys.get(0)));
        assertEquals(1, session.mCountByKey.get(keys.get(1)));
        assertEquals(1, session.mCountByKey.get(keys.get(2)));
        assertEquals(3, session.mTotal);

        session.binderTransactionCompleted(token0);
        session.binderTransactionCompleted(token1);
        session.binderTransactionCompleted(token2);
        session.binderTransactionCompleted(tokens.get(0));
        session.binderTransactionCompleted(tokens.get(1));
        session.binderTransactionCompleted(tokens.get(2));

        assertEquals(0, (int) session.mCountsByTag.get(testTags[0]));
        assertEquals(0, (int) session.mCountsByTag.get(testTags[1]));
        assertEquals(0, (int) session.mCountsByTag.get(testTags[2]));
        assertEquals(0, session.mCountByKey.get(keys.get(0)));
        assertEquals(0, session.mCountByKey.get(keys.get(1)));
        assertEquals(0, session.mCountByKey.get(keys.get(2)));
        assertEquals(0, session.mTotal);
    }

    @Test
    public void overflowTags() {
        final BoundServiceSession session = getNewBoundServiceSessionForTest();

        final List<String> uniqueTags = new ArrayList<>(MAX_UNIQUE_TAGS);
        for (int i = 0; i < MAX_UNIQUE_TAGS; i++) {
            final String tag = "unique_tag" + i;
            uniqueTags.add(tag);
            session.binderTransactionStarting(tag);
            assertEquals(i, (int) session.mKeyByTag.getOrDefault(tag, -1));
            assertEquals(1, session.mCountByKey.get(i));
        }
        assertEquals(MAX_UNIQUE_TAGS, session.mCountByKey.size());
        assertEquals(MAX_UNIQUE_TAGS, session.mKeyByTag.size());
        assertFalse(session.mKeyByTag.containsKey(OVERFLOW_TAG));

        final String[] overflowTags = {"test3", "4overflow", "test0"};

        final long overflowToken0 = session.binderTransactionStarting(overflowTags[0]);
        session.binderTransactionStarting(overflowTags[0]);
        final long overflowToken1 = session.binderTransactionStarting(overflowTags[1]);
        final long overflowToken2 = session.binderTransactionStarting(overflowTags[2]);
        session.binderTransactionStarting(overflowTags[2]);

        assertTrue(session.mKeyByTag.containsKey(OVERFLOW_TAG));
        assertEquals(MAX_UNIQUE_TAGS + 1, session.mCountByKey.size());
        assertEquals(MAX_UNIQUE_TAGS + 1, session.mKeyByTag.size());
        assertEquals(5, session.mCountByKey.get(MAX_UNIQUE_TAGS));
        for (final String overflowTag : overflowTags) {
            assertFalse(session.mKeyByTag.containsKey(overflowTag));
        }

        for (int i = 0; i < MAX_UNIQUE_TAGS; i++) {
            final String tag = uniqueTags.get(i);
            session.binderTransactionStarting(tag);
            assertEquals(2, session.mCountByKey.get(i));
            assertEquals(5, session.mCountByKey.get(MAX_UNIQUE_TAGS));
        }

        session.binderTransactionCompleted(overflowToken0);
        session.binderTransactionCompleted(overflowToken1);
        session.binderTransactionCompleted(overflowToken2);
        assertEquals(2, session.mCountByKey.get(MAX_UNIQUE_TAGS));
        for (int i = 0; i < MAX_UNIQUE_TAGS; i++) {
            assertEquals(2, session.mCountByKey.get(i));
        }
    }
}