Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 53195e0d authored by Tony Mak's avatar Tony Mak
Browse files

Fix memory leak in TCMS when TCSession.destroy() is not called.

TCMS stores sessions in a Map and remove it when the session is
destroyed. Memory leak happens if the client forget to call
TCSession.destroy() or does not have a chance to call before the process
is dead.

Solution:
1. Use linkToDeath() to remove the cached session when client process
   is dead. TCSessionID now contains a binder to make this possible.
2. Install a Cleaner to TCSession object such that destroy() is called
   whenever the session object is going to be GCed. This is needed
   because some clients may have a long lifecycle, e.g. apps that are
   bounded by system.

BUG:149012454

Test: Write an app that creates a TC session, but not calling destory().
      Then make sure TCMS removed the session in the following situations:
      1. The app is killed
      2. TCSession object is GCed (By forcing GC)
      By checking the output dumpsys textclassificataion, we can know
      that the session object is removed
Test: mts-tradefed run mts-extservices
Test: Sanity test: smart selection + smart replies

Change-Id: Ifb7dcb23e1f50d4b3e97a6ce40e63b57193f2892
parent 2cfc5bb1
Loading
Loading
Loading
Loading
+28 −2
Original line number Diff line number Diff line
@@ -16,6 +16,7 @@

package android.view.textclassifier;

import android.annotation.NonNull;
import android.annotation.WorkerThread;
import android.view.textclassifier.SelectionEvent.InvocationMethod;

@@ -23,6 +24,8 @@ import com.android.internal.util.Preconditions;

import java.util.Objects;

import sun.misc.Cleaner;

/**
 * Session-aware TextClassifier.
 */
@@ -35,6 +38,7 @@ final class TextClassificationSession implements TextClassifier {
    private final SelectionEventHelper mEventHelper;
    private final TextClassificationSessionId mSessionId;
    private final TextClassificationContext mClassificationContext;
    private final Cleaner mCleaner;

    private boolean mDestroyed;

@@ -44,6 +48,8 @@ final class TextClassificationSession implements TextClassifier {
        mSessionId = new TextClassificationSessionId();
        mEventHelper = new SelectionEventHelper(mSessionId, mClassificationContext);
        initializeRemoteSession();
        // This ensures destroy() is called if the client forgot to do so.
        mCleaner = Cleaner.create(this, new CleanerRunnable(mEventHelper, mDelegate));
    }

    @Override
@@ -114,8 +120,7 @@ final class TextClassificationSession implements TextClassifier {

    @Override
    public void destroy() {
        mEventHelper.endSession();
        mDelegate.destroy();
        mCleaner.clean();
        mDestroyed = true;
    }

@@ -258,4 +263,25 @@ final class TextClassificationSession implements TextClassifier {
            }
        }
    }

    // We use a static nested class here to avoid retaining the object reference of the outer
    // class. Otherwise. the Cleaner would never be triggered.
    private static class CleanerRunnable implements Runnable {
        @NonNull
        private final SelectionEventHelper mEventHelper;
        @NonNull
        private final TextClassifier mDelegate;

        CleanerRunnable(
                @NonNull SelectionEventHelper eventHelper, @NonNull TextClassifier delegate) {
            mEventHelper = Objects.requireNonNull(eventHelper);
            mDelegate = Objects.requireNonNull(delegate);
        }

        @Override
        public void run() {
            mEventHelper.endSession();
            mDelegate.destroy();
        }
    }
}
+29 −39
Original line number Diff line number Diff line
@@ -17,6 +17,8 @@
package android.view.textclassifier;

import android.annotation.NonNull;
import android.os.Binder;
import android.os.IBinder;
import android.os.Parcel;
import android.os.Parcelable;

@@ -28,7 +30,10 @@ import java.util.UUID;
 * This class represents the id of a text classification session.
 */
public final class TextClassificationSessionId implements Parcelable {
    private final @NonNull String mValue;
    @NonNull
    private final String mValue;
    @NonNull
    private final IBinder mToken;

    /**
     * Creates a new instance.
@@ -36,7 +41,7 @@ public final class TextClassificationSessionId implements Parcelable {
     * @hide
     */
    public TextClassificationSessionId() {
        this(UUID.randomUUID().toString());
        this(UUID.randomUUID().toString(), new Binder());
    }

    /**
@@ -46,34 +51,28 @@ public final class TextClassificationSessionId implements Parcelable {
     *
     * @hide
     */
    public TextClassificationSessionId(@NonNull String value) {
        mValue = value;
    public TextClassificationSessionId(@NonNull String value, @NonNull IBinder token) {
        mValue = Objects.requireNonNull(value);
        mToken = Objects.requireNonNull(token);
    }

    @Override
    public int hashCode() {
        final int prime = 31;
        int result = 1;
        result = prime * result + mValue.hashCode();
        return result;
    /** @hide */
    @NonNull
    public IBinder getToken() {
        return mToken;
    }

    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null) {
            return false;
        }
        if (getClass() != obj.getClass()) {
            return false;
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        TextClassificationSessionId that = (TextClassificationSessionId) o;
        return Objects.equals(mValue, that.mValue) && Objects.equals(mToken, that.mToken);
    }
        TextClassificationSessionId other = (TextClassificationSessionId) obj;
        if (!mValue.equals(other.mValue)) {
            return false;
        }
        return true;

    @Override
    public int hashCode() {
        return Objects.hash(mValue, mToken);
    }

    @Override
@@ -84,6 +83,7 @@ public final class TextClassificationSessionId implements Parcelable {
    @Override
    public void writeToParcel(Parcel parcel, int flags) {
        parcel.writeString(mValue);
        parcel.writeStrongBinder(mToken);
    }

    @Override
@@ -96,28 +96,18 @@ public final class TextClassificationSessionId implements Parcelable {
     *
     * @return The flattened id.
     */
    public @NonNull String flattenToString() {
    @NonNull
    public String flattenToString() {
        return mValue;
    }

    /**
     * Unflattens a print job id from a string.
     *
     * @param string The string.
     * @return The unflattened id, or null if the string is malformed.
     *
     * @hide
     */
    public static @NonNull TextClassificationSessionId unflattenFromString(@NonNull String string) {
        return new TextClassificationSessionId(string);
    }

    public static final @android.annotation.NonNull Parcelable.Creator<TextClassificationSessionId> CREATOR =
    @NonNull
    public static final Parcelable.Creator<TextClassificationSessionId> CREATOR =
            new Parcelable.Creator<TextClassificationSessionId>() {
                @Override
                public TextClassificationSessionId createFromParcel(Parcel parcel) {
                    return new TextClassificationSessionId(
                            Objects.requireNonNull(parcel.readString()));
                            parcel.readString(), parcel.readStrongBinder());
                }

                @Override
+86 −13
Original line number Diff line number Diff line
@@ -36,7 +36,7 @@ import android.service.textclassifier.ITextClassifierService;
import android.service.textclassifier.TextClassifierService;
import android.service.textclassifier.TextClassifierService.ConnectionState;
import android.text.TextUtils;
import android.util.LruCache;
import android.util.ArrayMap;
import android.util.Slog;
import android.util.SparseArray;
import android.view.textclassifier.ConversationActions;
@@ -65,6 +65,7 @@ import java.io.PrintWriter;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Queue;

@@ -146,11 +147,7 @@ public final class TextClassificationManagerService extends ITextClassifierServi
    private final Object mLock;
    @GuardedBy("mLock")
    final SparseArray<UserState> mUserStates = new SparseArray<>();
    // SystemTextClassifier.onDestroy() is not guaranteed to be called, use LruCache here
    // to avoid leak.
    @GuardedBy("mLock")
    private final LruCache<TextClassificationSessionId, TextClassificationContext>
            mSessionContextCache = new LruCache<>(40);
    private final SessionCache mSessionCache;
    private final TextClassificationConstants mSettings;
    @Nullable
    private final String mDefaultTextClassifierPackage;
@@ -165,6 +162,7 @@ public final class TextClassificationManagerService extends ITextClassifierServi
        PackageManager packageManager = mContext.getPackageManager();
        mDefaultTextClassifierPackage = packageManager.getDefaultTextClassifierPackageName();
        mSystemTextClassifierPackage = packageManager.getSystemTextClassifierPackageName();
        mSessionCache = new SessionCache(mLock);
    }

    private void startListenSettings() {
@@ -314,7 +312,7 @@ public final class TextClassificationManagerService extends ITextClassifierServi
                classificationContext.getUseDefaultTextClassifier(),
                service -> {
                    service.onCreateTextClassificationSession(classificationContext, sessionId);
                    mSessionContextCache.put(sessionId, classificationContext);
                    mSessionCache.put(sessionId, classificationContext);
                },
                "onCreateTextClassificationSession",
                NO_OP_CALLBACK);
@@ -326,14 +324,14 @@ public final class TextClassificationManagerService extends ITextClassifierServi
        Objects.requireNonNull(sessionId);

        synchronized (mLock) {
            TextClassificationContext textClassificationContext =
                    mSessionContextCache.get(sessionId);
            final StrippedTextClassificationContext textClassificationContext =
                    mSessionCache.get(sessionId);
            final int userId = textClassificationContext != null
                    ? textClassificationContext.getUserId()
                    ? textClassificationContext.userId
                    : UserHandle.getCallingUserId();
            final boolean useDefaultTextClassifier =
                    textClassificationContext != null
                            ? textClassificationContext.getUseDefaultTextClassifier()
                            ? textClassificationContext.useDefaultTextClassifier
                            : true;
            handleRequest(
                    userId,
@@ -342,7 +340,7 @@ public final class TextClassificationManagerService extends ITextClassifierServi
                    useDefaultTextClassifier,
                    service -> {
                        service.onDestroyTextClassificationSession(sessionId);
                        mSessionContextCache.remove(sessionId);
                        mSessionCache.remove(sessionId);
                    },
                    "onDestroyTextClassificationSession",
                    NO_OP_CALLBACK);
@@ -409,7 +407,7 @@ public final class TextClassificationManagerService extends ITextClassifierServi
                    pw.decreaseIndent();
                }
            }
            pw.println("Number of active sessions: " + mSessionContextCache.size());
            pw.println("Number of active sessions: " + mSessionCache.size());
        }
    }

@@ -568,6 +566,81 @@ public final class TextClassificationManagerService extends ITextClassifierServi
        }
    }

    /**
     * Stores the stripped down version of {@link TextClassificationContext}s, i.e. {@link
     * StrippedTextClassificationContext},  keyed by {@link TextClassificationSessionId}. Sessions
     * are cleaned up automatically when the client process is dead.
     */
    static final class SessionCache {
        @NonNull
        private final Object mLock;
        @NonNull
        @GuardedBy("mLock")
        private final Map<TextClassificationSessionId, StrippedTextClassificationContext> mCache =
                new ArrayMap<>();
        @NonNull
        @GuardedBy("mLock")
        private final Map<TextClassificationSessionId, DeathRecipient> mDeathRecipients =
                new ArrayMap<>();

        SessionCache(@NonNull Object lock) {
            mLock = Objects.requireNonNull(lock);
        }

        void put(@NonNull TextClassificationSessionId sessionId,
                @NonNull TextClassificationContext textClassificationContext) {
            synchronized (mLock) {
                mCache.put(sessionId,
                        new StrippedTextClassificationContext(textClassificationContext));
                try {
                    DeathRecipient deathRecipient = () -> remove(sessionId);
                    sessionId.getToken().linkToDeath(deathRecipient, /* flags= */ 0);
                    mDeathRecipients.put(sessionId, deathRecipient);
                } catch (RemoteException e) {
                    Slog.w(LOG_TAG, "SessionCache: Failed to link to death", e);
                }
            }
        }

        @Nullable
        StrippedTextClassificationContext get(@NonNull TextClassificationSessionId sessionId) {
            Objects.requireNonNull(sessionId);
            synchronized (mLock) {
                return mCache.get(sessionId);
            }
        }

        void remove(@NonNull TextClassificationSessionId sessionId) {
            Objects.requireNonNull(sessionId);
            synchronized (mLock) {
                DeathRecipient deathRecipient = mDeathRecipients.get(sessionId);
                if (deathRecipient != null) {
                    sessionId.getToken().unlinkToDeath(deathRecipient, /* flags= */ 0);
                }
                mDeathRecipients.remove(sessionId);
                mCache.remove(sessionId);
            }
        }

        int size() {
            synchronized (mLock) {
                return mCache.size();
            }
        }
    }

    /** A stripped down version of {@link TextClassificationContext}. */
    static class StrippedTextClassificationContext {
        @UserIdInt
        public final int userId;
        public final boolean useDefaultTextClassifier;

        StrippedTextClassificationContext(TextClassificationContext textClassificationContext) {
            userId = textClassificationContext.getUserId();
            useDefaultTextClassifier = textClassificationContext.getUseDefaultTextClassifier();
        }
    }

    private final class UserState {
        @UserIdInt
        final int mUserId;