Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 0c5b5043 authored by Treehugger Robot's avatar Treehugger Robot Committed by Android (Google) Code Review
Browse files

Merge "Recycle Bitmap objects when closing resources in BundleUtil." into main

parents 6d32f057 fbeb4360
Loading
Loading
Loading
Loading
+30 −8
Original line number Diff line number Diff line
@@ -324,7 +324,9 @@ public class BundleUtil {
            public void onSuccess(TokenInfo tokenInfo) throws RemoteException {
                try {
                    responseCallback.onSuccess(tokenInfo);
                    if (tokenInfo.getInfoParams() != null) {
                        inferenceInfoStore.addInferenceInfoFromBundle(tokenInfo.getInfoParams());
                    }
                } finally {
                    future.complete(null);
                }
@@ -404,8 +406,22 @@ public class BundleUtil {
        }
    }

    private static void tryCloseParcelableArray(Parcelable[] parcelables) {
        for (Parcelable p : parcelables) {
            try {
                if (p instanceof ParcelFileDescriptor pfd) {
                    pfd.close();
                } else if (p instanceof Bitmap bitmap) {
                    bitmap.recycle();
                }
            } catch (Exception e) {
                Log.e(TAG, "Error closing a resource in a Parcelable array", e);
            }
        }
    }

    public static void tryCloseResource(Bundle bundle) {
        if (bundle == null || bundle.isEmpty() || !bundle.hasFileDescriptors()) {
        if (bundle == null || bundle.isEmpty()) {
            return;
        }

@@ -414,13 +430,19 @@ public class BundleUtil {

            try {
                // TODO(b/329898589) : This can be cleaned up after the flag passing is fixed.
                if (obj instanceof ParcelFileDescriptor) {
                    ((ParcelFileDescriptor) obj).close();
                } else if (obj instanceof CursorWindow) {
                    ((CursorWindow) obj).close();
                } else if (obj instanceof SharedMemory) {
                if (obj instanceof ParcelFileDescriptor pfd) {
                    pfd.close();
                } else if (obj instanceof CursorWindow cursorWindow) {
                    cursorWindow.close();
                } else if (obj instanceof SharedMemory sharedMemory) {
                    // TODO(b/331796886) : Shared memory should honour parcelable flags.
                    ((SharedMemory) obj).close();
                    sharedMemory.close();
                } else if (obj instanceof Bitmap bitmap) {
                    bitmap.recycle();
                } else if (obj instanceof Parcelable[] parcelables) {
                    tryCloseParcelableArray(parcelables);
                } else if (obj instanceof Bundle nestedBundle) {
                    tryCloseResource(nestedBundle);
                }
            } catch (Exception e) {
                Log.e(TAG, "Error closing resource with key: " + key, e);
+497 −0
Original line number Diff line number Diff line
/*
 * Copyright (C) 2024 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.android.server.ondeviceintelligence;

import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertThrows;
import static org.junit.Assert.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;

import android.app.ondeviceintelligence.IResponseCallback;
import android.app.ondeviceintelligence.IStreamingResponseCallback;
import android.app.ondeviceintelligence.ITokenInfoCallback;
import android.app.ondeviceintelligence.InferenceInfo;
import android.app.ondeviceintelligence.TokenInfo;
import android.database.CursorWindow;
import android.graphics.Bitmap;
import android.os.BadParcelableException;
import android.os.Binder;
import android.os.Bundle;
import android.os.Parcel;
import android.os.ParcelFileDescriptor;
import android.os.Parcelable;
import android.os.PersistableBundle;
import android.os.RemoteCallback;
import android.os.SharedMemory;
import android.system.ErrnoException;

import androidx.test.ext.junit.runners.AndroidJUnit4;

import com.android.internal.infra.AndroidFuture;

import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.ArgumentCaptor;

import java.io.IOException;
import java.util.concurrent.Executor;

@RunWith(AndroidJUnit4.class)
public class BundleUtilTest {
    private ParcelFileDescriptor mReadOnlyPfd;
    private ParcelFileDescriptor mReadWritePfd;
    private Bitmap mImmutableBitmap;
    private Bitmap mMutableBitmap;
    private SharedMemory mSharedMemory;
    private CursorWindow mCursorWindow;

    private final Executor mDirectExecutor = Runnable::run;

    @Before
    public void setUp() throws IOException, ErrnoException {
        ParcelFileDescriptor[] pipe = ParcelFileDescriptor.createPipe();
        mReadOnlyPfd = pipe[0];
        mReadWritePfd = pipe[1];

        mMutableBitmap = Bitmap.createBitmap(1, 1, Bitmap.Config.ARGB_8888);
        mImmutableBitmap = mMutableBitmap.asShared();

        mSharedMemory = SharedMemory.create("test", 1024);
        mCursorWindow = new CursorWindow("test_cursor");
    }

    @After
    public void tearDown() throws Exception {
        try {
            mReadOnlyPfd.close();
        } catch (IOException e) {
            // ignore
        }
        try {
            mReadWritePfd.close();
        } catch (IOException e) {
            // ignore
        }
        if (mImmutableBitmap != null) {
            mImmutableBitmap.recycle();
        }
        if (mMutableBitmap != null) {
            mMutableBitmap.recycle();
        }
        mSharedMemory.close();
        mCursorWindow.close();
    }

    private ParcelFileDescriptor createReadOnlyPfd() throws IOException {
        return ParcelFileDescriptor.createPipe()[0];
    }

    private ParcelFileDescriptor createReadWritePfd() throws IOException {
        return ParcelFileDescriptor.createPipe()[1];
    }

    private Bundle getParcelledBundle(Bundle bundle) {
        Parcel p = Parcel.obtain();
        bundle.writeToParcel(p, 0);
        p.setDataPosition(0);
        Bundle newBundle = new Bundle(p);
        p.recycle();
        return newBundle;
    }

    @Test
    public void sanitizeInferenceParams_nullBundle_throws() {
        assertThrows(
                IllegalArgumentException.class, () -> BundleUtil.sanitizeInferenceParams(null));
    }

    @Test
    public void sanitizeInferenceParams_bundleWithBinder_throws() {
        Bundle bundle = new Bundle();
        bundle.putBinder("binder", new Binder());
        assertThrows(
                BadParcelableException.class, () -> BundleUtil.sanitizeInferenceParams(bundle));
    }

    @Test
    public void sanitizeInferenceParams_validTypes_success() {
        Bundle bundle = new Bundle();
        bundle.putByteArray("bytes", new byte[1]);
        bundle.putParcelable("persistable", new PersistableBundle());
        BundleUtil.sanitizeInferenceParams(getParcelledBundle(bundle)); // no exception
    }

    @Test
    public void sanitizeInferenceParams_readOnlyPfd_success() {
        Bundle bundle = new Bundle();
        bundle.putParcelable("pfd", mReadOnlyPfd);
        BundleUtil.sanitizeInferenceParams(getParcelledBundle(bundle));
    }

    @Test
    public void sanitizeInferenceParams_readWritePfd_throws() {
        Bundle bundle = new Bundle();
        bundle.putParcelable("pfd", mReadWritePfd);
        assertThrows(
                BadParcelableException.class, () -> BundleUtil.sanitizeInferenceParams(bundle));
    }

    @Test
    public void sanitizeInferenceParams_sharedMemory_setsReadOnly() throws Exception {
        Bundle bundle = new Bundle();
        bundle.putParcelable("shmem", mSharedMemory);
        BundleUtil.sanitizeInferenceParams(getParcelledBundle(bundle));
    }

    @Test
    public void sanitizeInferenceParams_immutableBitmap_success() {
        Bundle bundle = new Bundle();
        bundle.putParcelable("bitmap", mImmutableBitmap);
        BundleUtil.sanitizeInferenceParams(getParcelledBundle(bundle));
    }

    @Test
    public void sanitizeInferenceParams_mutableBitmap_throws() {
        Bundle bundle = new Bundle();
        bundle.putParcelable("bitmap", mMutableBitmap);
        assertThrows(
                BadParcelableException.class, () -> BundleUtil.sanitizeInferenceParams(bundle));
    }

    @Test
    public void sanitizeInferenceParams_unsupportedParcelable_throws() {
        Bundle bundle = new Bundle();
        bundle.putParcelable("unsupported", mock(Parcelable.class));
        assertThrows(
                BadParcelableException.class, () -> BundleUtil.sanitizeInferenceParams(bundle));
    }

    @Test
    public void sanitizeInferenceParams_readOnlyPfdArray_success() throws IOException {
        Bundle bundle = new Bundle();
        bundle.putParcelableArray(
                "pfd_array", new Parcelable[] {createReadOnlyPfd(), createReadOnlyPfd()});
        BundleUtil.sanitizeInferenceParams(getParcelledBundle(bundle));
    }

    @Test
    public void sanitizeInferenceParams_readWritePfdInArray_throws() throws IOException {
        Bundle bundle = new Bundle();
        bundle.putParcelableArray(
                "pfd_array", new Parcelable[] {createReadOnlyPfd(), createReadWritePfd()});
        assertThrows(
                BadParcelableException.class, () -> BundleUtil.sanitizeInferenceParams(bundle));
    }

    @Test
    public void sanitizeInferenceParams_immutableBitmapsArray_success() {
        Bundle bundle = new Bundle();
        bundle.putParcelableArray(
                "bitmap_array", new Parcelable[] {mImmutableBitmap, mImmutableBitmap});
        BundleUtil.sanitizeInferenceParams(getParcelledBundle(bundle));
    }

    @Test
    public void sanitizeInferenceParams_mutableBitmapInArray_throws() {
        Bundle bundle = new Bundle();
        bundle.putParcelableArray(
                "bitmap_array", new Parcelable[] {mImmutableBitmap, mMutableBitmap});
        assertThrows(
                BadParcelableException.class, () -> BundleUtil.sanitizeInferenceParams(bundle));
    }

    @Test
    public void sanitizeInferenceParams_nestedBundle_sanitizesRecursively() {
        Bundle nestedBundle = new Bundle();
        nestedBundle.putParcelable("pfd", mReadWritePfd);
        Bundle bundle = new Bundle();
        bundle.putBundle("nested", nestedBundle);
        assertThrows(
                BadParcelableException.class,
                () -> BundleUtil.sanitizeInferenceParams(getParcelledBundle(bundle)));
    }

    @Test
    public void sanitizeInferenceParams_withCursorWindow_success() {
        Bundle bundle = new Bundle();
        bundle.putParcelable("cursor", mCursorWindow);
        BundleUtil.sanitizeInferenceParams(getParcelledBundle(bundle));
    }

    @Test
    public void sanitizeResponseParams_unsupportedTypes_throws() {
        Bundle bundleSharedMem = new Bundle();
        bundleSharedMem.putParcelable("shmem", mSharedMemory);
        assertThrows(
                BadParcelableException.class,
                () -> BundleUtil.sanitizeResponseParams(bundleSharedMem));

        Bundle bundleCursor = new Bundle();
        bundleCursor.putParcelable("cursor", mCursorWindow);
        assertThrows(
                BadParcelableException.class,
                () -> BundleUtil.sanitizeResponseParams(bundleCursor));
    }

    @Test
    public void sanitizeStateParams_readOnlyPfd_success() {
        Bundle bundle = new Bundle();
        bundle.putParcelable("pfd", mReadOnlyPfd);
        BundleUtil.sanitizeStateParams(getParcelledBundle(bundle));
    }

    @Test
    public void tryCloseResource_nullAndEmptyBundle_doesNotThrow() {
        BundleUtil.tryCloseResource(null);
        BundleUtil.tryCloseResource(new Bundle());
    }

    @Test
    public void tryCloseResource_closesPfd() throws Exception {
        ParcelFileDescriptor pfd = createReadWritePfd();
        Bundle bundle = new Bundle();
        bundle.putParcelable("pfd", pfd);
        assertTrue(pfd.getFileDescriptor().valid());
        BundleUtil.tryCloseResource(bundle);
        assertFalse(pfd.getFileDescriptor().valid());
    }

    @Test
    public void tryCloseResource_recyclesBitmap() {
        Bitmap bitmap = Bitmap.createBitmap(1, 1, Bitmap.Config.ARGB_8888).asShared();
        Bundle bundle = new Bundle();
        bundle.putParcelable("bitmap", bitmap);

        assertFalse(bitmap.isRecycled());
        BundleUtil.tryCloseResource(bundle);
        assertTrue(bitmap.isRecycled());
    }

    @Test
    public void tryCloseResource_closesSharedMemory() throws Exception {
        SharedMemory shmem = SharedMemory.create("closetest", 128);
        Bundle bundle = new Bundle();
        bundle.putParcelable("shmem", shmem);

        BundleUtil.tryCloseResource(bundle);
    }

    @Test
    public void tryCloseResource_closesCursorWindow() {
        CursorWindow window = new CursorWindow("closetest");
        Bundle bundle = new Bundle();
        bundle.putParcelable("cursor", window);
        BundleUtil.tryCloseResource(bundle);
    }

    @Test
    public void tryCloseResource_closesPfdInArray() throws Exception {
        ParcelFileDescriptor pfd = createReadWritePfd();
        Bundle bundle = new Bundle();
        bundle.putParcelableArray("pfd_array", new Parcelable[] {pfd});
        assertTrue(pfd.getFileDescriptor().valid());
        BundleUtil.tryCloseResource(bundle);
        assertFalse(pfd.getFileDescriptor().valid());
    }

    @Test
    public void tryCloseResource_recyclesBitmapInArray() {
        Bitmap bitmap = Bitmap.createBitmap(1, 1, Bitmap.Config.ARGB_8888).asShared();
        Bundle bundle = new Bundle();
        bundle.putParcelableArray("bitmap_array", new Parcelable[] {bitmap});

        assertFalse(bitmap.isRecycled());
        BundleUtil.tryCloseResource(bundle);
        assertTrue(bitmap.isRecycled());
    }

    @Test
    public void tryCloseResource_closesNestedBundleResource() throws Exception {
        ParcelFileDescriptor pfdMock = mock(ParcelFileDescriptor.class);
        Bundle nestedBundle = new Bundle();
        nestedBundle.putParcelable("pfd", pfdMock);
        Bundle bundle = new Bundle();
        bundle.putBundle("nested", nestedBundle);
        BundleUtil.tryCloseResource(bundle);
        verify(pfdMock).close();
    }

    @Test
    public void wrapWithValidation_IStreamingResponseCallback_onSuccess() throws Exception {
        IStreamingResponseCallback mockCallback = mock(IStreamingResponseCallback.class);
        AndroidFuture<Void> future = new AndroidFuture<>();
        InferenceInfoStore mockStore = mock(InferenceInfoStore.class);
        Bundle result = new Bundle();
        result.putParcelable("pfd", createReadOnlyPfd());
        result = getParcelledBundle(result);

        IStreamingResponseCallback wrapper =
                BundleUtil.wrapWithValidation(
                        mockCallback, mDirectExecutor, future, mockStore, false);
        wrapper.onSuccess(result);

        verify(mockCallback).onSuccess(result);
        verify(mockStore).addInferenceInfoFromBundle(any(Bundle.class));
        assertTrue(future.isDone());
        assertFalse(
                result.getParcelable("pfd", ParcelFileDescriptor.class)
                        .getFileDescriptor()
                        .valid());
    }

    @Test
    public void wrapWithValidation_IStreamingResponseCallback_onFailure() throws Exception {
        IStreamingResponseCallback mockCallback = mock(IStreamingResponseCallback.class);
        AndroidFuture<Void> future = new AndroidFuture<>();
        InferenceInfoStore mockStore = mock(InferenceInfoStore.class);
        PersistableBundle errorParams = new PersistableBundle();

        IStreamingResponseCallback wrapper =
                BundleUtil.wrapWithValidation(
                        mockCallback, mDirectExecutor, future, mockStore, false);
        wrapper.onFailure(1, "error", errorParams);

        verify(mockCallback).onFailure(1, "error", errorParams);
        verify(mockStore).addInferenceInfoFromBundle(any(PersistableBundle.class));
        assertTrue(future.isDone());
    }

    @Test
    public void wrapWithValidation_IStreamingResponseCallback_onInferenceInfo_forwardsWhenEnabled()
            throws Exception {
        IStreamingResponseCallback mockCallback = mock(IStreamingResponseCallback.class);
        AndroidFuture<Void> future = new AndroidFuture<>();
        InferenceInfoStore mockStore = mock(InferenceInfoStore.class);
        InferenceInfo info = new InferenceInfo.Builder(123).build();

        IStreamingResponseCallback wrapper =
                BundleUtil.wrapWithValidation(
                        mockCallback, mDirectExecutor, future, mockStore, true);
        wrapper.onInferenceInfo(info);

        verify(mockStore).add(info);
        verify(mockCallback).onInferenceInfo(info);
    }

    @Test
    public void
            wrapWithValidation_IStreamingResponseCallback_onInferenceInfo_doesNotForwardWhenDisabled()
                    throws Exception {
        IStreamingResponseCallback mockCallback = mock(IStreamingResponseCallback.class);
        AndroidFuture<Void> future = new AndroidFuture<>();
        InferenceInfoStore mockStore = mock(InferenceInfoStore.class);
        InferenceInfo info = new InferenceInfo.Builder(123).build();
        IStreamingResponseCallback wrapperNoForward =
                BundleUtil.wrapWithValidation(
                        mockCallback, mDirectExecutor, future, mockStore, false);
        wrapperNoForward.onInferenceInfo(info);
        verify(mockCallback, never()).onInferenceInfo(info);
    }

    @Test
    public void wrapWithValidation_IResponseCallback_onSuccess() throws Exception {
        IResponseCallback mockCallback = mock(IResponseCallback.class);
        AndroidFuture<Void> future = new AndroidFuture<>();
        InferenceInfoStore mockStore = mock(InferenceInfoStore.class);
        Bundle result = new Bundle();
        result.putParcelable("pfd", createReadOnlyPfd());
        result = getParcelledBundle(result);

        IResponseCallback wrapper =
                BundleUtil.wrapWithValidation(
                        mockCallback, mDirectExecutor, future, mockStore, false);
        wrapper.onSuccess(result);

        verify(mockCallback).onSuccess(result);
        verify(mockStore).addInferenceInfoFromBundle(any(Bundle.class));
        assertTrue(future.isDone());
        assertFalse(
                result.getParcelable("pfd", ParcelFileDescriptor.class)
                        .getFileDescriptor()
                        .valid());
    }

    @Test
    public void wrapWithValidation_IResponseCallback_onDataAugmentRequest() throws Exception {
        IResponseCallback mockCallback = mock(IResponseCallback.class);
        RemoteCallback mockRemoteCallback = mock(RemoteCallback.class);
        AndroidFuture<Void> future = new AndroidFuture<>();
        InferenceInfoStore mockStore = mock(InferenceInfoStore.class);
        Bundle processedContent = new Bundle();
        processedContent.putParcelable("pfd_processed", createReadOnlyPfd());
        processedContent = getParcelledBundle(processedContent);
        Bundle augmentedData = new Bundle();
        augmentedData.putParcelable("pfd_augmented", createReadOnlyPfd());
        augmentedData = getParcelledBundle(augmentedData);

        IResponseCallback wrapper =
                BundleUtil.wrapWithValidation(
                        mockCallback, mDirectExecutor, future, mockStore, false);
        ArgumentCaptor<RemoteCallback> callbackCaptor =
                ArgumentCaptor.forClass(RemoteCallback.class);

        wrapper.onDataAugmentRequest(processedContent, mockRemoteCallback);
        verify(mockCallback).onDataAugmentRequest(any(Bundle.class), callbackCaptor.capture());

        callbackCaptor.getValue().sendResult(augmentedData);
        verify(mockRemoteCallback).sendResult(augmentedData);

        assertFalse(
                processedContent
                        .getParcelable("pfd_processed", ParcelFileDescriptor.class)
                        .getFileDescriptor()
                        .valid());
        assertFalse(
                augmentedData
                        .getParcelable("pfd_augmented", ParcelFileDescriptor.class)
                        .getFileDescriptor()
                        .valid());
    }

    @Test
    public void wrapWithValidation_ITokenInfoCallback_onSuccess() throws Exception {
        ITokenInfoCallback mockCallback = mock(ITokenInfoCallback.class);
        AndroidFuture<Void> future = new AndroidFuture<>();
        InferenceInfoStore mockStore = mock(InferenceInfoStore.class);
        TokenInfo tokenInfo = new TokenInfo(1, PersistableBundle.EMPTY);

        ITokenInfoCallback wrapper = BundleUtil.wrapWithValidation(mockCallback, future, mockStore);
        wrapper.onSuccess(tokenInfo);

        verify(mockCallback).onSuccess(tokenInfo);
        verify(mockStore).addInferenceInfoFromBundle(tokenInfo.getInfoParams());
        assertTrue(future.isDone());
    }

    @Test
    public void wrapWithValidation_ITokenInfoCallback_onFailure() throws Exception {
        ITokenInfoCallback mockCallback = mock(ITokenInfoCallback.class);
        AndroidFuture<Void> future = new AndroidFuture<>();
        InferenceInfoStore mockStore = mock(InferenceInfoStore.class);
        PersistableBundle errorParams = new PersistableBundle();
        ITokenInfoCallback wrapper = BundleUtil.wrapWithValidation(mockCallback, future, mockStore);
        wrapper.onFailure(1, "error", errorParams);
        verify(mockCallback).onFailure(1, "error", errorParams);
        verify(mockStore).addInferenceInfoFromBundle(errorParams);
        assertTrue(future.isDone());
    }
}