Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 32820e53 authored by Wei Sheng Shih's avatar Wei Sheng Shih Committed by Android (Google) Code Review
Browse files

Merge "Tracking task snapshot usage in client process.(3/N)" into main

parents 5267a70e d9e82474
Loading
Loading
Loading
Loading
+6 −3
Original line number Diff line number Diff line
@@ -24,13 +24,16 @@ interface ITaskSnapshotManager {
     * Fetches the snapshot for the task with the given id.
     *
     * @param taskId the id of the task to retrieve for
     * @param latestCaptureTime the elapsed time of the latest taskSnapshot captured
     * @param retrieveResolution the resolution we want to load.
     *
     * @throws RemoteException
     * @return a graphic buffer representing a screenshot of a task, or {@code null} if no
     *         screenshot can be found.
     * @return a graphic buffer representing a screenshot of a task, it returns {@code null} if no
     *         screenshot can be found, but if latestCaptureTime is equals or greater than 0, then
     *         the client should reuse the existing snapshot.
     */
    android.window.TaskSnapshot getTaskSnapshot(int taskId, int retrieveResolution);
    android.window.TaskSnapshot getTaskSnapshot(int taskId, long latestCaptureTime,
            int retrieveResolution);

    /**
     * Requests for a new snapshot to be taken for the task with the given id, storing it in the
+17 −4
Original line number Diff line number Diff line
@@ -32,7 +32,6 @@ import android.hardware.HardwareBuffer;
import android.os.Build;
import android.os.Parcel;
import android.os.Parcelable;
import android.os.SystemClock;
import android.util.DisplayMetrics;
import android.view.Surface;
import android.view.SurfaceControl;
@@ -43,6 +42,7 @@ import com.android.internal.policy.TransitionAnimation;

import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.ref.WeakReference;
import java.util.function.Consumer;

/**
@@ -52,8 +52,8 @@ import java.util.function.Consumer;
public class TaskSnapshot implements Parcelable {
    // Identifier of this snapshot
    private final long mId;
    // The elapsed real time (in nanoseconds) when this snapshot was captured, not intended for use outside the
    // process in which the snapshot was taken (ie. this is not parceled)
    // The elapsed real time (in nanoseconds) when this snapshot was captured or loaded from disk
    // since boot.
    private final long mCaptureTime;
    // Top activity in task when snapshot was taken
    private final ComponentName mTopActivityComponent;
@@ -86,6 +86,7 @@ public class TaskSnapshot implements Parcelable {
    private int mInternalReferences;
    private int mWriteToParcelCount;
    private Consumer<HardwareBuffer> mSafeSnapshotReleaser;
    private WeakReference<TaskSnapshotManager.SnapshotTracker> mSnapshotTracker;

    /** Keep in cache, doesn't need reference. */
    public static final int REFERENCE_NONE = 0;
@@ -141,7 +142,6 @@ public class TaskSnapshot implements Parcelable {

    private TaskSnapshot(Parcel source) {
        mId = source.readLong();
        mCaptureTime = SystemClock.elapsedRealtimeNanos();
        mTopActivityComponent = ComponentName.readFromParcel(source);
        mSnapshot = source.readTypedObject(HardwareBuffer.CREATOR);
        int colorSpaceId = source.readInt();
@@ -162,6 +162,7 @@ public class TaskSnapshot implements Parcelable {
        mUiMode = source.readInt();
        int densityDpi = source.readInt();
        mDensityDpi = densityDpi > 0 ? densityDpi : DisplayMetrics.DENSITY_DEVICE_STABLE;
        mCaptureTime = source.readLong();
    }

    /**
@@ -268,6 +269,13 @@ public class TaskSnapshot implements Parcelable {
     */
    public void closeBuffer() {
        if (isBufferValid()) {
            if (mSnapshotTracker != null) {
                final TaskSnapshotManager.SnapshotTracker tracker = mSnapshotTracker.get();
                if (tracker != null) {
                    TaskSnapshotManager.getInstance().removeTracker(tracker);
                    mSnapshotTracker.clear();
                }
            }
            mSnapshot.close();
        }
    }
@@ -440,6 +448,7 @@ public class TaskSnapshot implements Parcelable {
        dest.writeBoolean(mHasImeSurface);
        dest.writeInt(mUiMode);
        dest.writeInt(mDensityDpi);
        dest.writeLong(mCaptureTime);
        synchronized (this) {
            if ((mInternalReferences & REFERENCE_WRITE_TO_PARCEL) != 0) {
                mWriteToParcelCount--;
@@ -484,6 +493,10 @@ public class TaskSnapshot implements Parcelable {
                + " mDensityDpi=" + mDensityDpi;
    }

    void setSnapshotTracker(TaskSnapshotManager.SnapshotTracker tracker) {
        mSnapshotTracker = new WeakReference<>(tracker);
    }

    /**
     * Adds a reference when the object is held somewhere.
     * Only used in core.
+231 −2
Original line number Diff line number Diff line
@@ -21,11 +21,21 @@ import android.annotation.NonNull;
import android.annotation.RequiresPermission;
import android.app.ActivityTaskManager;
import android.os.RemoteException;
import android.system.SystemCleaner;
import android.util.AndroidRuntimeException;
import android.util.Singleton;
import android.util.Slog;
import android.util.SparseArray;

import com.android.internal.annotations.GuardedBy;

import java.io.PrintWriter;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.ref.Cleaner;
import java.lang.ref.WeakReference;
import java.util.Comparator;
import java.util.TreeSet;

/**
 * Retrieve or request app snapshots in system.
@@ -64,6 +74,10 @@ public class TaskSnapshotManager {
    @Retention(RetentionPolicy.SOURCE)
    public @interface Resolution {}

    private final Object mLock = new Object();
    @GuardedBy("mLock")
    private final GlobalSnapshotTracker mGlobalSnapshotTracker = new GlobalSnapshotTracker();
    static final Cleaner sCleaner = SystemCleaner.cleaner();
    private static final TaskSnapshotManager sInstance = new TaskSnapshotManager();
    private TaskSnapshotManager() { }

@@ -83,12 +97,31 @@ public class TaskSnapshotManager {
    public TaskSnapshot getTaskSnapshot(int taskId, @Resolution int retrieveResolution)
            throws RemoteException {
        final TaskSnapshot t;
        final long captureTime;
        final TaskSnapshot previousSnapshot;
        validateResolution(retrieveResolution);
        synchronized (mLock) {
            // Gets the latest snapshot from the local cache. This can be used to prevent the system
            // server from returning another snapshot that is the same as the local one.
            final SnapshotTracker st = mGlobalSnapshotTracker.peekLatestSnapshot(
                    taskId, retrieveResolution);
            // Create a temporary reference so the snapshot won't be cleared during IPC call.
            previousSnapshot = st != null ? st.mSnapshot.get() : null;
            captureTime = previousSnapshot != null ? st.mCaptureTime : -1;
        }
        try {
            t = ISnapshotManagerSingleton.get().getTaskSnapshot(taskId, retrieveResolution);
            t = ISnapshotManagerSingleton.get().getTaskSnapshot(taskId,
                    captureTime, retrieveResolution);
        } catch (RemoteException r) {
            Slog.e(TAG, "getTaskSnapshot fail: " + r);
            throw r;
        }
        if (t == null) {
            return previousSnapshot;
        }
        synchronized (mLock) {
            mGlobalSnapshotTracker.createTracker(taskId, t);
        }
        return t;
    }

@@ -109,9 +142,14 @@ public class TaskSnapshotManager {
        try {
            t = ISnapshotManagerSingleton.get().takeTaskSnapshot(taskId, updateCache);
        } catch (RemoteException r) {
            Slog.e(TAG, "getTaskSnapshot fail: " + r);
            Slog.e(TAG, "takeTaskSnapshot fail: " + r);
            throw r;
        }
        if (t != null) {
            synchronized (mLock) {
                mGlobalSnapshotTracker.createTracker(taskId, t);
            }
        }
        return t;
    }

@@ -151,4 +189,195 @@ public class TaskSnapshotManager {
                    }
                }
            };

    void removeTracker(SnapshotTracker tracker) {
        synchronized (mLock) {
            mGlobalSnapshotTracker.removeTracker(tracker);
        }
    }

    /**
     * Dump snapshot usage in the process.
     */
    public void dump(PrintWriter pw) {
        synchronized (mLock) {
            mGlobalSnapshotTracker.dump(pw);
        }
    }

    /**
     * Util method, validate requested resolution.
     */
    public static void validateResolution(int resolution) {
        switch (resolution) {
            case RESOLUTION_ANY:
            case RESOLUTION_HIGH:
            case RESOLUTION_LOW:
                return;
            default:
                throw new IllegalArgumentException("Invalidate resolution=" + resolution);
        }
    }

    private static class GlobalSnapshotTracker {
        final SparseArray<SingleTaskTracker> mSnapshotTrackers = new SparseArray<>();

        void createTracker(int taskId, TaskSnapshot snapshot) {
            SingleTaskTracker taskTracker = mSnapshotTrackers.get(taskId);
            if (taskTracker == null) {
                taskTracker = new SingleTaskTracker();
                mSnapshotTrackers.put(taskId, taskTracker);
            }
            final SnapshotTracker tracker = new SnapshotTracker(taskId, snapshot);
            taskTracker.addTracker(tracker);
            sCleaner.register(snapshot, () -> removeTracker(tracker));
        }

        void removeTracker(SnapshotTracker tracker) {
            if (tracker == null) {
                return;
            }
            final int taskId = tracker.mTaskId;
            final SingleTaskTracker taskTracker = mSnapshotTrackers.get(taskId);
            if (taskTracker == null) {
                return;
            }
            taskTracker.stopTrack(tracker);
            if (taskTracker.isEmpty()) {
                mSnapshotTrackers.remove(taskId);
            }
        }

        SnapshotTracker peekLatestSnapshot(int taskId, @Resolution int resolution) {
            final SingleTaskTracker stt = mSnapshotTrackers.get(taskId);
            if (stt == null) {
                // shouldn't happen
                return null;
            }
            return stt.peekLatestSnapshot(resolution);
        }

        /**
         * Dump snapshot usage in the process.
         */
        void dump(PrintWriter pw) {
            if (mSnapshotTrackers.size() == 0) {
                return;
            }
            pw.println("");
            pw.println("Task Snapshot Usage:");
            for (int i = mSnapshotTrackers.size() - 1; i >= 0; --i) {
                mSnapshotTrackers.valueAt(i).dump(pw);
            }
        }

        static class SingleTaskTracker {
            final TreeSet<SnapshotTracker> mHighResSortedTrackers = new TreeSet<>(TRACKER_ORDER);
            final TreeSet<SnapshotTracker> mLowResSortedTrackers  = new TreeSet<>(TRACKER_ORDER);

            static final Comparator<SnapshotTracker> TRACKER_ORDER = new Comparator<>() {
                @Override
                public int compare(SnapshotTracker s1, SnapshotTracker s2) {
                    if (s1.mCaptureTime < s2.mCaptureTime) {
                        return 1;
                    } else if (s1.mCaptureTime > s2.mCaptureTime) {
                        return -1;
                    }
                    return 0;
                }
            };

            void addTracker(@NonNull SnapshotTracker tracker) {
                final TreeSet<SnapshotTracker> targetingSet = tracker.mIsLowResolution
                        ? mLowResSortedTrackers : mHighResSortedTrackers;
                targetingSet.add(tracker);
            }

            SnapshotTracker peekLatestSnapshot(@Resolution int resolution) {
                if (resolution == RESOLUTION_ANY) {
                    final SnapshotTracker hFirst = peekFirst(mHighResSortedTrackers);
                    final SnapshotTracker lFirst = peekFirst(mLowResSortedTrackers);
                    if (hFirst != null && lFirst != null) {
                        return hFirst.mCaptureTime > lFirst.mCaptureTime ? hFirst : lFirst;
                    } else {
                        return hFirst != null ? hFirst : lFirst;
                    }
                }
                final TreeSet<SnapshotTracker> targetingSet = resolution == RESOLUTION_LOW
                        ? mLowResSortedTrackers : mHighResSortedTrackers;
                return peekFirst(targetingSet);
            }

            private static SnapshotTracker peekFirst(TreeSet<SnapshotTracker> targetingSet) {
                return targetingSet.isEmpty() ? null : targetingSet.getFirst();
            }

            void stopTrack(SnapshotTracker tracker) {
                final TreeSet<SnapshotTracker> targetingSet = tracker.mIsLowResolution
                        ? mLowResSortedTrackers : mHighResSortedTrackers;
                targetingSet.remove(tracker);
            }

            boolean isEmpty() {
                return mHighResSortedTrackers.isEmpty() && mLowResSortedTrackers.isEmpty();
            }

            void dump(PrintWriter pw) {
                for (SnapshotTracker highResSortedTracker : mHighResSortedTrackers) {
                    highResSortedTracker.dump(pw);
                }
                for (SnapshotTracker lowResSortedTracker : mLowResSortedTrackers) {
                    lowResSortedTracker.dump(pw);
                }
            }
        }
    }

    // Tracking the snapshot usage, call getStackTrace() to know where it was created.
    static class SnapshotTracker extends AndroidRuntimeException {
        private static final int ROOT_STACK_TRACE_COUNT = 4;
        final int mTaskId;
        final long mSnapshotId;
        final long mCaptureTime;
        final boolean mIsLowResolution;
        final WeakReference<TaskSnapshot> mSnapshot;

        SnapshotTracker(int taskId, TaskSnapshot snapshot) {
            super();
            mTaskId = taskId;
            mSnapshotId = snapshot.getId();
            mCaptureTime = snapshot.getCaptureTime();
            mIsLowResolution = snapshot.isLowResolution();
            snapshot.setSnapshotTracker(this);
            mSnapshot = new WeakReference<>(snapshot);
        }

        @Override
        public String getMessage() {
            return "SnapshotTracker: @" + hashCode()
                    + " {TaskId=" + mTaskId + ", SnapshotId=" + mSnapshotId + " mCaptureTime="
                    + mCaptureTime + ", isLowResolution=" + mIsLowResolution + "}";
        }

        void dump(PrintWriter pw) {
            final StringBuilder builder = buildDumpString(this);
            pw.println("  taskId=" + mTaskId + ", SnapshotID=" + mSnapshotId
                    + ", isLowResolution=" + mIsLowResolution);
            pw.println("   Get from=" + builder);
            pw.println("");
        }

        static StringBuilder buildDumpString(AndroidRuntimeException dump) {
            final StackTraceElement[] stackTrace = dump.getStackTrace();
            final int count = Math.min(stackTrace.length, ROOT_STACK_TRACE_COUNT);
            final StringBuilder builder = new StringBuilder();
            for (int i = 0; i < count; ++i) {
                builder.append(stackTrace[i]);
                if (i + 1 < count) {
                    builder.append(" ");
                }
            }
            return builder;
        }
    }
}
+34 −23
Original line number Diff line number Diff line
@@ -38,6 +38,7 @@ import android.window.ITaskSnapshotManager;
import android.window.TaskSnapshot;
import android.window.TaskSnapshotManager;

import com.android.internal.annotations.VisibleForTesting;
import com.android.window.flags.Flags;

import java.io.PrintWriter;
@@ -266,28 +267,46 @@ class SnapshotController {
        mSnapshotPersistQueue.dump(pw, prefix);
    }

    /**
     * Util method, validate requested resolution.
     */
    private static void validateResolution(int resolution) {
        switch (resolution) {
            case TaskSnapshotManager.RESOLUTION_ANY:
            case TaskSnapshotManager.RESOLUTION_HIGH:
            case TaskSnapshotManager.RESOLUTION_LOW:
                return;
            default:
                throw new IllegalArgumentException("Invalidate resolution=" + resolution);
    @VisibleForTesting
    TaskSnapshot getTaskSnapshotInner(int taskId, Task task, long latestCaptureTime,
            @TaskSnapshotManager.Resolution int retrieveResolution) {
        synchronized (mService.mGlobalLock) {
            final TaskSnapshot snapshot = mTaskSnapshotController.getSnapshot(
                    taskId, retrieveResolution);
            if (snapshot != null) {
                if (snapshot.getCaptureTime() > latestCaptureTime) {
                    snapshot.addReference(TaskSnapshot.REFERENCE_WRITE_TO_PARCEL);
                    return snapshot;
                } else {
                    return null;
                }
            }
            if (latestCaptureTime > 0) {
                // Return null if the client already has the latest snapshot.
                final TaskSnapshot inCacheSnapshot = mTaskSnapshotController.getSnapshot(
                        taskId, TaskSnapshotManager.RESOLUTION_ANY);
                if (inCacheSnapshot != null) {
                    if (inCacheSnapshot.getCaptureTime() <= latestCaptureTime) {
                        return null;
                    }
                }
            }
        }
        final boolean isLowResolution =
                retrieveResolution == TaskSnapshotManager.RESOLUTION_LOW;
        // Don't call this while holding the lock as this operation might hit the disk.
        return mTaskSnapshotController.getSnapshotFromDisk(taskId,
                task.mUserId, isLowResolution, TaskSnapshot.REFERENCE_WRITE_TO_PARCEL);
    }

    class SnapshotManagerService extends ITaskSnapshotManager.Stub {

        @Override
        public TaskSnapshot getTaskSnapshot(int taskId,
        public TaskSnapshot getTaskSnapshot(int taskId, long latestCaptureTime,
                @TaskSnapshotManager.Resolution int retrieveResolution) {
            final long ident = Binder.clearCallingIdentity();
            try {
                validateResolution(retrieveResolution);
                TaskSnapshotManager.validateResolution(retrieveResolution);
                final Task task;
                synchronized (mService.mGlobalLock) {
                    task = mService.mRoot.anyTaskForId(taskId,
@@ -296,17 +315,9 @@ class SnapshotController {
                        Slog.w(TAG, "getTaskSnapshot: taskId=" + taskId + " not found");
                        return null;
                    }
                    final TaskSnapshot snapshot = mTaskSnapshotController.getSnapshot(
                                taskId, retrieveResolution, TaskSnapshot.REFERENCE_WRITE_TO_PARCEL);
                    if (snapshot != null) {
                        return snapshot;
                    }
                }
                final boolean isLowResolution =
                        retrieveResolution == TaskSnapshotManager.RESOLUTION_LOW;
                // Don't call this while holding the lock as this operation might hit the disk.
                return mTaskSnapshotController.getSnapshotFromDisk(taskId,
                        task.mUserId, isLowResolution, TaskSnapshot.REFERENCE_WRITE_TO_PARCEL);
                return SnapshotController.this.getTaskSnapshotInner(taskId, task, latestCaptureTime,
                        retrieveResolution);
            } finally {
                Binder.restoreCallingIdentity(ident);
            }
+54 −0
Original line number Diff line number Diff line
@@ -28,9 +28,12 @@ import static com.android.server.wm.TaskSnapshotController.SNAPSHOT_MODE_REAL;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotEquals;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyBoolean;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.clearInvocations;
import static org.mockito.Mockito.never;
@@ -47,6 +50,7 @@ import android.graphics.Rect;
import android.hardware.HardwareBuffer;
import android.platform.test.annotations.Presubmit;
import android.window.TaskSnapshot;
import android.window.TaskSnapshotManager;

import androidx.test.filters.SmallTest;

@@ -247,4 +251,54 @@ public class TaskSnapshotControllerTest extends WindowTestsBase {
        waitHandlerIdle(mWm.mH);
        verify(mWm.mTaskSnapshotController.mCache).putSnapshot(eq(task), any());
    }

    @Test
    public void testGetTaskSnapshotFromClient() {
        spyOn(mWm.mTaskSnapshotController.mCache);
        spyOn(mWm.mTaskSnapshotController);
        final long captureTime = 100;
        final WindowState normalWindow = newWindowBuilder("normalWindow",
                FIRST_APPLICATION_WINDOW).setDisplay(mDisplayContent).build();
        final Task task = normalWindow.mActivityRecord.getTask();

        final TaskSnapshot diskSnapshot = new TaskSnapshotPersisterTestBase.TaskSnapshotBuilder()
                .setTopActivityComponent(normalWindow.mActivityRecord.mActivityComponent)
                .build();
        doReturn(diskSnapshot).when(mWm.mTaskSnapshotController)
                .getSnapshotFromDisk(anyInt(), anyInt(), anyBoolean(), anyInt());
        doReturn(null).when(mWm.mTaskSnapshotController.mCache)
                .getSnapshot(anyInt(), anyInt(), anyInt());

        // Client process doesn't has snapshot.
        TaskSnapshot result = mWm.mSnapshotController.getTaskSnapshotInner(task.mTaskId, task,
                -1 /* latestCaptureTime */, TaskSnapshotManager.RESOLUTION_ANY);
        assertEquals(result, diskSnapshot);

        // Put snapshot in cache
        final TaskSnapshot snapshot = new TaskSnapshotPersisterTestBase.TaskSnapshotBuilder()
                .setTopActivityComponent(normalWindow.mActivityRecord.mActivityComponent)
                .setCaptureTime(captureTime).build();
        doReturn(snapshot).when(mWm.mTaskSnapshotController.mCache)
                .getSnapshot(anyInt(), anyInt(), anyInt());

        // Client process doesn't has snapshot.
        result = mWm.mSnapshotController.getTaskSnapshotInner(task.mTaskId, task,
                -1 /* latestCaptureTime */, TaskSnapshotManager.RESOLUTION_ANY);
        assertEquals(result, snapshot);

        // Snapshot in client process is older than in system server.
        result = mWm.mSnapshotController.getTaskSnapshotInner(task.mTaskId, task,
                captureTime - 10 /* latestCaptureTime */, TaskSnapshotManager.RESOLUTION_ANY);
        assertEquals(result, snapshot);

        // Snapshot in client process is the same as in system server.
        result = mWm.mSnapshotController.getTaskSnapshotInner(task.mTaskId, task,
                captureTime, TaskSnapshotManager.RESOLUTION_ANY);
        assertNull(result);

        // Snapshot in client process is newer than in system server?
        result = mWm.mSnapshotController.getTaskSnapshotInner(task.mTaskId, task,
                captureTime + 10 /* latestCaptureTime */, TaskSnapshotManager.RESOLUTION_ANY);
        assertNull(result);
    }
}
Loading