Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit d9e82474 authored by wilsonshih's avatar wilsonshih
Browse files

Tracking task snapshot usage in client process.(3/N)

Track hardware buffer usage within a TaskSnapshot in the client-side
process.
Cache task snapshot objects in TaskSnapshotManager and reuse existing
snapshots based on capture time, to prevent the system server from
always returning a new snapshot object to client if possible.

Flag: com.android.window.flags.reduce_task_snapshot_memory_usage
Bug: 238206323
Test: atest TaskSnapshotControllerTest
Change-Id: I81928a2ff326300441b2aa5030fa3b44f975dc86
parent 904b96b1
Loading
Loading
Loading
Loading
+6 −3
Original line number Diff line number Diff line
@@ -24,13 +24,16 @@ interface ITaskSnapshotManager {
     * Fetches the snapshot for the task with the given id.
     *
     * @param taskId the id of the task to retrieve for
     * @param latestCaptureTime the elapsed time of the latest taskSnapshot captured
     * @param retrieveResolution the resolution we want to load.
     *
     * @throws RemoteException
     * @return a graphic buffer representing a screenshot of a task, or {@code null} if no
     *         screenshot can be found.
     * @return a graphic buffer representing a screenshot of a task, it returns {@code null} if no
     *         screenshot can be found, but if latestCaptureTime is equals or greater than 0, then
     *         the client should reuse the existing snapshot.
     */
    android.window.TaskSnapshot getTaskSnapshot(int taskId, int retrieveResolution);
    android.window.TaskSnapshot getTaskSnapshot(int taskId, long latestCaptureTime,
            int retrieveResolution);

    /**
     * Requests for a new snapshot to be taken for the task with the given id, storing it in the
+17 −4
Original line number Diff line number Diff line
@@ -32,7 +32,6 @@ import android.hardware.HardwareBuffer;
import android.os.Build;
import android.os.Parcel;
import android.os.Parcelable;
import android.os.SystemClock;
import android.util.DisplayMetrics;
import android.view.Surface;
import android.view.SurfaceControl;
@@ -43,6 +42,7 @@ import com.android.internal.policy.TransitionAnimation;

import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.ref.WeakReference;
import java.util.function.Consumer;

/**
@@ -52,8 +52,8 @@ import java.util.function.Consumer;
public class TaskSnapshot implements Parcelable {
    // Identifier of this snapshot
    private final long mId;
    // The elapsed real time (in nanoseconds) when this snapshot was captured, not intended for use outside the
    // process in which the snapshot was taken (ie. this is not parceled)
    // The elapsed real time (in nanoseconds) when this snapshot was captured or loaded from disk
    // since boot.
    private final long mCaptureTime;
    // Top activity in task when snapshot was taken
    private final ComponentName mTopActivityComponent;
@@ -86,6 +86,7 @@ public class TaskSnapshot implements Parcelable {
    private int mInternalReferences;
    private int mWriteToParcelCount;
    private Consumer<HardwareBuffer> mSafeSnapshotReleaser;
    private WeakReference<TaskSnapshotManager.SnapshotTracker> mSnapshotTracker;

    /** Keep in cache, doesn't need reference. */
    public static final int REFERENCE_NONE = 0;
@@ -141,7 +142,6 @@ public class TaskSnapshot implements Parcelable {

    private TaskSnapshot(Parcel source) {
        mId = source.readLong();
        mCaptureTime = SystemClock.elapsedRealtimeNanos();
        mTopActivityComponent = ComponentName.readFromParcel(source);
        mSnapshot = source.readTypedObject(HardwareBuffer.CREATOR);
        int colorSpaceId = source.readInt();
@@ -162,6 +162,7 @@ public class TaskSnapshot implements Parcelable {
        mUiMode = source.readInt();
        int densityDpi = source.readInt();
        mDensityDpi = densityDpi > 0 ? densityDpi : DisplayMetrics.DENSITY_DEVICE_STABLE;
        mCaptureTime = source.readLong();
    }

    /**
@@ -268,6 +269,13 @@ public class TaskSnapshot implements Parcelable {
     */
    public void closeBuffer() {
        if (isBufferValid()) {
            if (mSnapshotTracker != null) {
                final TaskSnapshotManager.SnapshotTracker tracker = mSnapshotTracker.get();
                if (tracker != null) {
                    TaskSnapshotManager.getInstance().removeTracker(tracker);
                    mSnapshotTracker.clear();
                }
            }
            mSnapshot.close();
        }
    }
@@ -440,6 +448,7 @@ public class TaskSnapshot implements Parcelable {
        dest.writeBoolean(mHasImeSurface);
        dest.writeInt(mUiMode);
        dest.writeInt(mDensityDpi);
        dest.writeLong(mCaptureTime);
        synchronized (this) {
            if ((mInternalReferences & REFERENCE_WRITE_TO_PARCEL) != 0) {
                mWriteToParcelCount--;
@@ -484,6 +493,10 @@ public class TaskSnapshot implements Parcelable {
                + " mDensityDpi=" + mDensityDpi;
    }

    void setSnapshotTracker(TaskSnapshotManager.SnapshotTracker tracker) {
        mSnapshotTracker = new WeakReference<>(tracker);
    }

    /**
     * Adds a reference when the object is held somewhere.
     * Only used in core.
+231 −2
Original line number Diff line number Diff line
@@ -21,11 +21,21 @@ import android.annotation.NonNull;
import android.annotation.RequiresPermission;
import android.app.ActivityTaskManager;
import android.os.RemoteException;
import android.system.SystemCleaner;
import android.util.AndroidRuntimeException;
import android.util.Singleton;
import android.util.Slog;
import android.util.SparseArray;

import com.android.internal.annotations.GuardedBy;

import java.io.PrintWriter;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.ref.Cleaner;
import java.lang.ref.WeakReference;
import java.util.Comparator;
import java.util.TreeSet;

/**
 * Retrieve or request app snapshots in system.
@@ -64,6 +74,10 @@ public class TaskSnapshotManager {
    @Retention(RetentionPolicy.SOURCE)
    public @interface Resolution {}

    private final Object mLock = new Object();
    @GuardedBy("mLock")
    private final GlobalSnapshotTracker mGlobalSnapshotTracker = new GlobalSnapshotTracker();
    static final Cleaner sCleaner = SystemCleaner.cleaner();
    private static final TaskSnapshotManager sInstance = new TaskSnapshotManager();
    private TaskSnapshotManager() { }

@@ -83,12 +97,31 @@ public class TaskSnapshotManager {
    public TaskSnapshot getTaskSnapshot(int taskId, @Resolution int retrieveResolution)
            throws RemoteException {
        final TaskSnapshot t;
        final long captureTime;
        final TaskSnapshot previousSnapshot;
        validateResolution(retrieveResolution);
        synchronized (mLock) {
            // Gets the latest snapshot from the local cache. This can be used to prevent the system
            // server from returning another snapshot that is the same as the local one.
            final SnapshotTracker st = mGlobalSnapshotTracker.peekLatestSnapshot(
                    taskId, retrieveResolution);
            // Create a temporary reference so the snapshot won't be cleared during IPC call.
            previousSnapshot = st != null ? st.mSnapshot.get() : null;
            captureTime = previousSnapshot != null ? st.mCaptureTime : -1;
        }
        try {
            t = ISnapshotManagerSingleton.get().getTaskSnapshot(taskId, retrieveResolution);
            t = ISnapshotManagerSingleton.get().getTaskSnapshot(taskId,
                    captureTime, retrieveResolution);
        } catch (RemoteException r) {
            Slog.e(TAG, "getTaskSnapshot fail: " + r);
            throw r;
        }
        if (t == null) {
            return previousSnapshot;
        }
        synchronized (mLock) {
            mGlobalSnapshotTracker.createTracker(taskId, t);
        }
        return t;
    }

@@ -109,9 +142,14 @@ public class TaskSnapshotManager {
        try {
            t = ISnapshotManagerSingleton.get().takeTaskSnapshot(taskId, updateCache);
        } catch (RemoteException r) {
            Slog.e(TAG, "getTaskSnapshot fail: " + r);
            Slog.e(TAG, "takeTaskSnapshot fail: " + r);
            throw r;
        }
        if (t != null) {
            synchronized (mLock) {
                mGlobalSnapshotTracker.createTracker(taskId, t);
            }
        }
        return t;
    }

@@ -151,4 +189,195 @@ public class TaskSnapshotManager {
                    }
                }
            };

    void removeTracker(SnapshotTracker tracker) {
        synchronized (mLock) {
            mGlobalSnapshotTracker.removeTracker(tracker);
        }
    }

    /**
     * Dump snapshot usage in the process.
     */
    public void dump(PrintWriter pw) {
        synchronized (mLock) {
            mGlobalSnapshotTracker.dump(pw);
        }
    }

    /**
     * Util method, validate requested resolution.
     */
    public static void validateResolution(int resolution) {
        switch (resolution) {
            case RESOLUTION_ANY:
            case RESOLUTION_HIGH:
            case RESOLUTION_LOW:
                return;
            default:
                throw new IllegalArgumentException("Invalidate resolution=" + resolution);
        }
    }

    private static class GlobalSnapshotTracker {
        final SparseArray<SingleTaskTracker> mSnapshotTrackers = new SparseArray<>();

        void createTracker(int taskId, TaskSnapshot snapshot) {
            SingleTaskTracker taskTracker = mSnapshotTrackers.get(taskId);
            if (taskTracker == null) {
                taskTracker = new SingleTaskTracker();
                mSnapshotTrackers.put(taskId, taskTracker);
            }
            final SnapshotTracker tracker = new SnapshotTracker(taskId, snapshot);
            taskTracker.addTracker(tracker);
            sCleaner.register(snapshot, () -> removeTracker(tracker));
        }

        void removeTracker(SnapshotTracker tracker) {
            if (tracker == null) {
                return;
            }
            final int taskId = tracker.mTaskId;
            final SingleTaskTracker taskTracker = mSnapshotTrackers.get(taskId);
            if (taskTracker == null) {
                return;
            }
            taskTracker.stopTrack(tracker);
            if (taskTracker.isEmpty()) {
                mSnapshotTrackers.remove(taskId);
            }
        }

        SnapshotTracker peekLatestSnapshot(int taskId, @Resolution int resolution) {
            final SingleTaskTracker stt = mSnapshotTrackers.get(taskId);
            if (stt == null) {
                // shouldn't happen
                return null;
            }
            return stt.peekLatestSnapshot(resolution);
        }

        /**
         * Dump snapshot usage in the process.
         */
        void dump(PrintWriter pw) {
            if (mSnapshotTrackers.size() == 0) {
                return;
            }
            pw.println("");
            pw.println("Task Snapshot Usage:");
            for (int i = mSnapshotTrackers.size() - 1; i >= 0; --i) {
                mSnapshotTrackers.valueAt(i).dump(pw);
            }
        }

        static class SingleTaskTracker {
            final TreeSet<SnapshotTracker> mHighResSortedTrackers = new TreeSet<>(TRACKER_ORDER);
            final TreeSet<SnapshotTracker> mLowResSortedTrackers  = new TreeSet<>(TRACKER_ORDER);

            static final Comparator<SnapshotTracker> TRACKER_ORDER = new Comparator<>() {
                @Override
                public int compare(SnapshotTracker s1, SnapshotTracker s2) {
                    if (s1.mCaptureTime < s2.mCaptureTime) {
                        return 1;
                    } else if (s1.mCaptureTime > s2.mCaptureTime) {
                        return -1;
                    }
                    return 0;
                }
            };

            void addTracker(@NonNull SnapshotTracker tracker) {
                final TreeSet<SnapshotTracker> targetingSet = tracker.mIsLowResolution
                        ? mLowResSortedTrackers : mHighResSortedTrackers;
                targetingSet.add(tracker);
            }

            SnapshotTracker peekLatestSnapshot(@Resolution int resolution) {
                if (resolution == RESOLUTION_ANY) {
                    final SnapshotTracker hFirst = peekFirst(mHighResSortedTrackers);
                    final SnapshotTracker lFirst = peekFirst(mLowResSortedTrackers);
                    if (hFirst != null && lFirst != null) {
                        return hFirst.mCaptureTime > lFirst.mCaptureTime ? hFirst : lFirst;
                    } else {
                        return hFirst != null ? hFirst : lFirst;
                    }
                }
                final TreeSet<SnapshotTracker> targetingSet = resolution == RESOLUTION_LOW
                        ? mLowResSortedTrackers : mHighResSortedTrackers;
                return peekFirst(targetingSet);
            }

            private static SnapshotTracker peekFirst(TreeSet<SnapshotTracker> targetingSet) {
                return targetingSet.isEmpty() ? null : targetingSet.getFirst();
            }

            void stopTrack(SnapshotTracker tracker) {
                final TreeSet<SnapshotTracker> targetingSet = tracker.mIsLowResolution
                        ? mLowResSortedTrackers : mHighResSortedTrackers;
                targetingSet.remove(tracker);
            }

            boolean isEmpty() {
                return mHighResSortedTrackers.isEmpty() && mLowResSortedTrackers.isEmpty();
            }

            void dump(PrintWriter pw) {
                for (SnapshotTracker highResSortedTracker : mHighResSortedTrackers) {
                    highResSortedTracker.dump(pw);
                }
                for (SnapshotTracker lowResSortedTracker : mLowResSortedTrackers) {
                    lowResSortedTracker.dump(pw);
                }
            }
        }
    }

    // Tracking the snapshot usage, call getStackTrace() to know where it was created.
    static class SnapshotTracker extends AndroidRuntimeException {
        private static final int ROOT_STACK_TRACE_COUNT = 4;
        final int mTaskId;
        final long mSnapshotId;
        final long mCaptureTime;
        final boolean mIsLowResolution;
        final WeakReference<TaskSnapshot> mSnapshot;

        SnapshotTracker(int taskId, TaskSnapshot snapshot) {
            super();
            mTaskId = taskId;
            mSnapshotId = snapshot.getId();
            mCaptureTime = snapshot.getCaptureTime();
            mIsLowResolution = snapshot.isLowResolution();
            snapshot.setSnapshotTracker(this);
            mSnapshot = new WeakReference<>(snapshot);
        }

        @Override
        public String getMessage() {
            return "SnapshotTracker: @" + hashCode()
                    + " {TaskId=" + mTaskId + ", SnapshotId=" + mSnapshotId + " mCaptureTime="
                    + mCaptureTime + ", isLowResolution=" + mIsLowResolution + "}";
        }

        void dump(PrintWriter pw) {
            final StringBuilder builder = buildDumpString(this);
            pw.println("  taskId=" + mTaskId + ", SnapshotID=" + mSnapshotId
                    + ", isLowResolution=" + mIsLowResolution);
            pw.println("   Get from=" + builder);
            pw.println("");
        }

        static StringBuilder buildDumpString(AndroidRuntimeException dump) {
            final StackTraceElement[] stackTrace = dump.getStackTrace();
            final int count = Math.min(stackTrace.length, ROOT_STACK_TRACE_COUNT);
            final StringBuilder builder = new StringBuilder();
            for (int i = 0; i < count; ++i) {
                builder.append(stackTrace[i]);
                if (i + 1 < count) {
                    builder.append(" ");
                }
            }
            return builder;
        }
    }
}
+34 −23
Original line number Diff line number Diff line
@@ -38,6 +38,7 @@ import android.window.ITaskSnapshotManager;
import android.window.TaskSnapshot;
import android.window.TaskSnapshotManager;

import com.android.internal.annotations.VisibleForTesting;
import com.android.window.flags.Flags;

import java.io.PrintWriter;
@@ -266,28 +267,46 @@ class SnapshotController {
        mSnapshotPersistQueue.dump(pw, prefix);
    }

    /**
     * Util method, validate requested resolution.
     */
    private static void validateResolution(int resolution) {
        switch (resolution) {
            case TaskSnapshotManager.RESOLUTION_ANY:
            case TaskSnapshotManager.RESOLUTION_HIGH:
            case TaskSnapshotManager.RESOLUTION_LOW:
                return;
            default:
                throw new IllegalArgumentException("Invalidate resolution=" + resolution);
    @VisibleForTesting
    TaskSnapshot getTaskSnapshotInner(int taskId, Task task, long latestCaptureTime,
            @TaskSnapshotManager.Resolution int retrieveResolution) {
        synchronized (mService.mGlobalLock) {
            final TaskSnapshot snapshot = mTaskSnapshotController.getSnapshot(
                    taskId, retrieveResolution);
            if (snapshot != null) {
                if (snapshot.getCaptureTime() > latestCaptureTime) {
                    snapshot.addReference(TaskSnapshot.REFERENCE_WRITE_TO_PARCEL);
                    return snapshot;
                } else {
                    return null;
                }
            }
            if (latestCaptureTime > 0) {
                // Return null if the client already has the latest snapshot.
                final TaskSnapshot inCacheSnapshot = mTaskSnapshotController.getSnapshot(
                        taskId, TaskSnapshotManager.RESOLUTION_ANY);
                if (inCacheSnapshot != null) {
                    if (inCacheSnapshot.getCaptureTime() <= latestCaptureTime) {
                        return null;
                    }
                }
            }
        }
        final boolean isLowResolution =
                retrieveResolution == TaskSnapshotManager.RESOLUTION_LOW;
        // Don't call this while holding the lock as this operation might hit the disk.
        return mTaskSnapshotController.getSnapshotFromDisk(taskId,
                task.mUserId, isLowResolution, TaskSnapshot.REFERENCE_WRITE_TO_PARCEL);
    }

    class SnapshotManagerService extends ITaskSnapshotManager.Stub {

        @Override
        public TaskSnapshot getTaskSnapshot(int taskId,
        public TaskSnapshot getTaskSnapshot(int taskId, long latestCaptureTime,
                @TaskSnapshotManager.Resolution int retrieveResolution) {
            final long ident = Binder.clearCallingIdentity();
            try {
                validateResolution(retrieveResolution);
                TaskSnapshotManager.validateResolution(retrieveResolution);
                final Task task;
                synchronized (mService.mGlobalLock) {
                    task = mService.mRoot.anyTaskForId(taskId,
@@ -296,17 +315,9 @@ class SnapshotController {
                        Slog.w(TAG, "getTaskSnapshot: taskId=" + taskId + " not found");
                        return null;
                    }
                    final TaskSnapshot snapshot = mTaskSnapshotController.getSnapshot(
                                taskId, retrieveResolution, TaskSnapshot.REFERENCE_WRITE_TO_PARCEL);
                    if (snapshot != null) {
                        return snapshot;
                    }
                }
                final boolean isLowResolution =
                        retrieveResolution == TaskSnapshotManager.RESOLUTION_LOW;
                // Don't call this while holding the lock as this operation might hit the disk.
                return mTaskSnapshotController.getSnapshotFromDisk(taskId,
                        task.mUserId, isLowResolution, TaskSnapshot.REFERENCE_WRITE_TO_PARCEL);
                return SnapshotController.this.getTaskSnapshotInner(taskId, task, latestCaptureTime,
                        retrieveResolution);
            } finally {
                Binder.restoreCallingIdentity(ident);
            }
+54 −0
Original line number Diff line number Diff line
@@ -28,9 +28,12 @@ import static com.android.server.wm.TaskSnapshotController.SNAPSHOT_MODE_REAL;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotEquals;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyBoolean;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.clearInvocations;
import static org.mockito.Mockito.never;
@@ -47,6 +50,7 @@ import android.graphics.Rect;
import android.hardware.HardwareBuffer;
import android.platform.test.annotations.Presubmit;
import android.window.TaskSnapshot;
import android.window.TaskSnapshotManager;

import androidx.test.filters.SmallTest;

@@ -247,4 +251,54 @@ public class TaskSnapshotControllerTest extends WindowTestsBase {
        waitHandlerIdle(mWm.mH);
        verify(mWm.mTaskSnapshotController.mCache).putSnapshot(eq(task), any());
    }

    @Test
    public void testGetTaskSnapshotFromClient() {
        spyOn(mWm.mTaskSnapshotController.mCache);
        spyOn(mWm.mTaskSnapshotController);
        final long captureTime = 100;
        final WindowState normalWindow = newWindowBuilder("normalWindow",
                FIRST_APPLICATION_WINDOW).setDisplay(mDisplayContent).build();
        final Task task = normalWindow.mActivityRecord.getTask();

        final TaskSnapshot diskSnapshot = new TaskSnapshotPersisterTestBase.TaskSnapshotBuilder()
                .setTopActivityComponent(normalWindow.mActivityRecord.mActivityComponent)
                .build();
        doReturn(diskSnapshot).when(mWm.mTaskSnapshotController)
                .getSnapshotFromDisk(anyInt(), anyInt(), anyBoolean(), anyInt());
        doReturn(null).when(mWm.mTaskSnapshotController.mCache)
                .getSnapshot(anyInt(), anyInt(), anyInt());

        // Client process doesn't has snapshot.
        TaskSnapshot result = mWm.mSnapshotController.getTaskSnapshotInner(task.mTaskId, task,
                -1 /* latestCaptureTime */, TaskSnapshotManager.RESOLUTION_ANY);
        assertEquals(result, diskSnapshot);

        // Put snapshot in cache
        final TaskSnapshot snapshot = new TaskSnapshotPersisterTestBase.TaskSnapshotBuilder()
                .setTopActivityComponent(normalWindow.mActivityRecord.mActivityComponent)
                .setCaptureTime(captureTime).build();
        doReturn(snapshot).when(mWm.mTaskSnapshotController.mCache)
                .getSnapshot(anyInt(), anyInt(), anyInt());

        // Client process doesn't has snapshot.
        result = mWm.mSnapshotController.getTaskSnapshotInner(task.mTaskId, task,
                -1 /* latestCaptureTime */, TaskSnapshotManager.RESOLUTION_ANY);
        assertEquals(result, snapshot);

        // Snapshot in client process is older than in system server.
        result = mWm.mSnapshotController.getTaskSnapshotInner(task.mTaskId, task,
                captureTime - 10 /* latestCaptureTime */, TaskSnapshotManager.RESOLUTION_ANY);
        assertEquals(result, snapshot);

        // Snapshot in client process is the same as in system server.
        result = mWm.mSnapshotController.getTaskSnapshotInner(task.mTaskId, task,
                captureTime, TaskSnapshotManager.RESOLUTION_ANY);
        assertNull(result);

        // Snapshot in client process is newer than in system server?
        result = mWm.mSnapshotController.getTaskSnapshotInner(task.mTaskId, task,
                captureTime + 10 /* latestCaptureTime */, TaskSnapshotManager.RESOLUTION_ANY);
        assertNull(result);
    }
}
Loading