Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit e6cab1b0 authored by Sandeep Bandaru's avatar Sandeep Bandaru
Browse files

Added new error handling logic and executors for remote calls.

- Avoids crashing in multiple scenarios which is controlled by `shouldThrow`
- Fixes some potential error handling in case of callbacks are not invoked.
- Added Intelligence and Inference executor to handle remote service calls and multiple failure callbacks in a clean manner.

Flag: EXEMPT refactor
Bug: 423888142
Bug: 416887063
Test: ran cts test suite
Change-Id: I4a2a1121d74d7d5537693ed4c5aa7d2f6de6e916
parent e5f57c74
Loading
Loading
Loading
Loading
+64 −60
Original line number Diff line number Diff line
@@ -61,8 +61,7 @@ public class BundleUtil {
     *
     * @throws BadParcelableException when the bundle does not meet the read-only requirements.
     */
    public static void sanitizeInferenceParams(
            @InferenceParams Bundle bundle) {
    public static void sanitizeInferenceParams(@InferenceParams Bundle bundle) {
        ensureValidBundle(bundle);

        if (!bundle.hasFileDescriptors()) {
@@ -107,8 +106,7 @@ public class BundleUtil {
     *
     * @throws BadParcelableException when the bundle does not meet the read-only requirements.
     */
    public static void sanitizeResponseParams(
            @ResponseParams Bundle bundle) {
    public static void sanitizeResponseParams(@ResponseParams Bundle bundle) {
        ensureValidBundle(bundle);

        if (!bundle.hasFileDescriptors()) {
@@ -147,13 +145,11 @@ public class BundleUtil {
    }

    /**
     * Validation of the inference request payload as described in {@link StateParams}
     * description.
     * Validation of the inference request payload as described in {@link StateParams} description.
     *
     * @throws BadParcelableException when the bundle does not meet the read-only requirements.
     */
    public static void sanitizeStateParams(
            @StateParams Bundle bundle) {
    public static void sanitizeStateParams(@StateParams Bundle bundle) {
        ensureValidBundle(bundle);

        if (!bundle.hasFileDescriptors()) {
@@ -185,7 +181,6 @@ public class BundleUtil {
        }
    }


    public static IStreamingResponseCallback wrapWithValidation(
            IStreamingResponseCallback streamingResponseCallback,
            Executor resourceClosingExecutor,
@@ -204,8 +199,7 @@ public class BundleUtil {
            }

            @Override
            public void onSuccess(Bundle resultBundle)
                    throws RemoteException {
            public void onSuccess(Bundle resultBundle) throws RemoteException {
                try {
                    sanitizeResponseParams(resultBundle);
                    streamingResponseCallback.onSuccess(resultBundle);
@@ -217,20 +211,23 @@ public class BundleUtil {
            }

            @Override
            public void onFailure(int errorCode, String errorMessage,
                    PersistableBundle errorParams) throws RemoteException {
            public void onFailure(int errorCode, String errorMessage, PersistableBundle errorParams)
                    throws RemoteException {
                try {
                    streamingResponseCallback.onFailure(errorCode, errorMessage, errorParams);
                    inferenceInfoStore.addInferenceInfoFromBundle(errorParams);
                future.completeExceptionally(new TimeoutException());
                } finally {
                    future.complete(null);
                }
            }

            @Override
            public void onDataAugmentRequest(Bundle processedContent,
                    RemoteCallback remoteCallback)
            public void onDataAugmentRequest(Bundle processedContent, RemoteCallback remoteCallback)
                    throws RemoteException {
                try {
                    sanitizeResponseParams(processedContent);
                    streamingResponseCallback.onDataAugmentRequest(processedContent,
                    streamingResponseCallback.onDataAugmentRequest(
                            processedContent,
                            new RemoteCallback(
                                    augmentedData -> {
                                        try {
@@ -256,15 +253,15 @@ public class BundleUtil {
        };
    }

    public static IResponseCallback wrapWithValidation(IResponseCallback responseCallback,
    public static IResponseCallback wrapWithValidation(
            IResponseCallback responseCallback,
            Executor resourceClosingExecutor,
            AndroidFuture future,
            InferenceInfoStore inferenceInfoStore,
            boolean shouldForwardInferenceInfo) {
        return new IResponseCallback.Stub() {
            @Override
            public void onSuccess(Bundle resultBundle)
                    throws RemoteException {
            public void onSuccess(Bundle resultBundle) throws RemoteException {
                try {
                    sanitizeResponseParams(resultBundle);
                    responseCallback.onSuccess(resultBundle);
@@ -276,20 +273,24 @@ public class BundleUtil {
            }

            @Override
            public void onFailure(int errorCode, String errorMessage,
                    PersistableBundle errorParams) throws RemoteException {
            public void onFailure(int errorCode, String errorMessage, PersistableBundle errorParams)
                    throws RemoteException {
                try {
                    responseCallback.onFailure(errorCode, errorMessage, errorParams);
                    inferenceInfoStore.addInferenceInfoFromBundle(errorParams);
                future.completeExceptionally(new TimeoutException());
                } finally {
                    future.complete(null);
                }
            }

            @Override
            public void onDataAugmentRequest(Bundle processedContent,
                    RemoteCallback remoteCallback)
            public void onDataAugmentRequest(Bundle processedContent, RemoteCallback remoteCallback)
                    throws RemoteException {
                try {
                    sanitizeResponseParams(processedContent);
                    responseCallback.onDataAugmentRequest(processedContent, new RemoteCallback(
                    responseCallback.onDataAugmentRequest(
                            processedContent,
                            new RemoteCallback(
                                    augmentedData -> {
                                        try {
                                            sanitizeInferenceParams(augmentedData);
@@ -314,30 +315,37 @@ public class BundleUtil {
        };
    }


    public static ITokenInfoCallback wrapWithValidation(ITokenInfoCallback responseCallback,
    public static ITokenInfoCallback wrapWithValidation(
            ITokenInfoCallback responseCallback,
            AndroidFuture future,
            InferenceInfoStore inferenceInfoStore) {
        return new ITokenInfoCallback.Stub() {
            @Override
            public void onSuccess(TokenInfo tokenInfo) throws RemoteException {
                try {
                    responseCallback.onSuccess(tokenInfo);
                    inferenceInfoStore.addInferenceInfoFromBundle(tokenInfo.getInfoParams());
                } finally {
                    future.complete(null);
                }
            }

            @Override
            public void onFailure(int errorCode, String errorMessage, PersistableBundle errorParams)
                    throws RemoteException {
                try {
                    responseCallback.onFailure(errorCode, errorMessage, errorParams);
                    inferenceInfoStore.addInferenceInfoFromBundle(errorParams);
                future.completeExceptionally(new TimeoutException());
                } finally {
                    future.complete(null);
                }
            }
        };
    }

    private static boolean canMarshall(Object obj) {
        return obj instanceof byte[] || obj instanceof PersistableBundle
        return obj instanceof byte[]
                || obj instanceof PersistableBundle
                || PersistableBundle.isValidType(obj);
    }

@@ -352,16 +360,13 @@ public class BundleUtil {
    }

    private static void validateParcelableArray(Parcelable[] parcelables) {
        if (parcelables.length > 0
                && parcelables[0] instanceof ParcelFileDescriptor) {
        if (parcelables.length > 0 && parcelables[0] instanceof ParcelFileDescriptor) {
            // Safe to cast
            validatePfdsReadOnly(parcelables);
        } else if (parcelables.length > 0
                && parcelables[0] instanceof Bitmap) {
        } else if (parcelables.length > 0 && parcelables[0] instanceof Bitmap) {
            validateBitmapsImmutable(parcelables);
        } else {
            throw new BadParcelableException(
                    "Could not cast to any known parcelable array");
            throw new BadParcelableException("Could not cast to any known parcelable array");
        }
    }

@@ -382,8 +387,7 @@ public class BundleUtil {
                        "Bundle contains a parcel file descriptor which is not read-only.");
            }
        } catch (ErrnoException e) {
            throw new BadParcelableException(
                    "Invalid File descriptor passed in the Bundle.", e);
            throw new BadParcelableException("Invalid File descriptor passed in the Bundle.", e);
        }
    }

+422 −283

File changed.

Preview size limit exceeded, changes collapsed.

+32 −46
Original line number Diff line number Diff line
@@ -92,7 +92,9 @@ public class RemoteStorageService extends IRemoteStorageService.Stub {
                        service.getReadOnlyFeatureFileDescriptorMap(
                                feature,
                                new RemoteCallback(
                                        result -> handleFileDescriptorMapResult(result, remoteCallback))));
                                        result ->
                                            handleFileDescriptorMapResult(
                                                        result, remoteCallback))));
    }

    @Override
@@ -106,21 +108,14 @@ public class RemoteStorageService extends IRemoteStorageService.Stub {
                                        result -> handleMetadataResult(result, remoteCallback))));
    }

    private void handleFileDescriptorMapResult(
            Bundle result,
            RemoteCallback remoteCallback) {
        mCallbackExecutor.execute(() -> {
    private void handleFileDescriptorMapResult(Bundle result, RemoteCallback remoteCallback) {
        try {
            if (result == null) {
                remoteCallback.sendResult(null);
                return;
            }
            for (String key : result.keySet()) {
                    ParcelFileDescriptor pfd =
                            result.getParcelable(
                                    key,
                                    ParcelFileDescriptor
                                            .class);
                ParcelFileDescriptor pfd = result.getParcelable(key, ParcelFileDescriptor.class);
                validatePfdReadOnly(pfd);
            }
            remoteCallback.sendResult(result);
@@ -128,15 +123,11 @@ public class RemoteStorageService extends IRemoteStorageService.Stub {
            Slog.e(TAG, "Failed to send result", e);
            remoteCallback.sendResult(null);
        } finally {
                mResourceClosingExecutor.execute(
                        () -> tryCloseResource(result));
            mResourceClosingExecutor.execute(() -> tryCloseResource(result));
        }
        });
    }

    private void handleMetadataResult(Bundle result, RemoteCallback remoteCallback) {
        mCallbackExecutor.execute(
                () -> {
        try {
            if (result == null) {
                remoteCallback.sendResult(null);
@@ -145,16 +136,11 @@ public class RemoteStorageService extends IRemoteStorageService.Stub {
            sanitizeStateParams(result);
            remoteCallback.sendResult(result);
        } catch (BadParcelableException e) {
                        Slog.e(
                                TAG,
                                "Failed to send result",
                                e);
            Slog.e(TAG, "Failed to send result", e);
            remoteCallback.sendResult(null);
        } finally {
                        mResourceClosingExecutor.execute(
                                () -> tryCloseResource(result));
            mResourceClosingExecutor.execute(() -> tryCloseResource(result));
        }
                });
    }

    private static void tryClosePfd(ParcelFileDescriptor pfd) {
+33 −27
Original line number Diff line number Diff line
@@ -26,17 +26,17 @@ import com.android.internal.infra.AndroidFuture;
import java.util.concurrent.TimeoutException;

/**
 * This class extends the {@link IDownloadCallback} and adds a timeout Runnable to the callback
 * such that, in the case where the callback methods are not invoked, we do not have to wait for
 * timeout based on {@link #onDownloadCompleted} which might take minutes or hours to complete in
 * some cases. Instead, in such cases we rely on the remote service sending progress updates and if
 * there are *no* progress callbacks in the duration of {@link #idleTimeoutMs}, we can assume the
 * download will not complete and enabling faster cleanup.
 * This class extends the {@link IDownloadCallback} and adds a timeout Runnable to the callback such
 * that, in the case where the callback methods are not invoked, we do not have to wait for timeout
 * based on {@link #onDownloadCompleted} which might take minutes or hours to complete in some
 * cases. Instead, in such cases we rely on the remote service sending progress updates and if there
 * are *no* progress callbacks in the duration of {@link #idleTimeoutMs}, we can assume the download
 * will not complete and enabling faster cleanup.
 */
public class ListenableDownloadCallback extends IDownloadCallback.Stub implements Runnable {
    private final IDownloadCallback callback;
    private final Handler handler;
    private final AndroidFuture future;
    private final AndroidFuture<?> future;
    private final long idleTimeoutMs;

    /**
@@ -47,45 +47,41 @@ public class ListenableDownloadCallback extends IDownloadCallback.Stub implement
     * @param future future to complete to signal the callback has reached a terminal state.
     * @param idleTimeoutMs timeout within which download updates should be received.
     */
    public ListenableDownloadCallback(IDownloadCallback callback, Handler handler,
            AndroidFuture future,
    public ListenableDownloadCallback(
            IDownloadCallback callback, Handler handler, AndroidFuture<?> future,
            long idleTimeoutMs) {
        this.callback = callback;
        this.handler = handler;
        this.future = future;
        this.idleTimeoutMs = idleTimeoutMs;
        handler.postDelayed(this,
                idleTimeoutMs); // init the timeout runnable in case no callback is ever invoked
        resetTimeout(); // init the timeout runnable in case no callback is ever invoked
    }

    @Override
    public void onDownloadStarted(long bytesToDownload) throws RemoteException {
        callback.onDownloadStarted(bytesToDownload);
        handler.removeCallbacks(this);
        handler.postDelayed(this, idleTimeoutMs);
        resetTimeout();
    }

    @Override
    public void onDownloadProgress(long bytesDownloaded) throws RemoteException {
        callback.onDownloadProgress(bytesDownloaded);
        handler.removeCallbacks(this); // remove previously queued timeout tasks.
        handler.postDelayed(this, idleTimeoutMs); // queue fresh timeout task for next update.
        resetTimeout();
    }

    @Override
    public void onDownloadFailed(int failureStatus,
            String errorMessage, PersistableBundle errorParams) throws RemoteException {
    public void onDownloadFailed(
            int failureStatus, String errorMessage, PersistableBundle errorParams)
            throws RemoteException {
        callback.onDownloadFailed(failureStatus, errorMessage, errorParams);
        handler.removeCallbacks(this);
        future.completeExceptionally(new TimeoutException());
        complete();
    }

    @Override
    public void onDownloadCompleted(
            android.os.PersistableBundle downloadParams) throws RemoteException {
    public void onDownloadCompleted(android.os.PersistableBundle downloadParams)
            throws RemoteException {
        callback.onDownloadCompleted(downloadParams);
        handler.removeCallbacks(this);
        future.complete(null);
        complete();
    }

    @Override
@@ -94,4 +90,14 @@ public class ListenableDownloadCallback extends IDownloadCallback.Stub implement
                new TimeoutException()); // complete the future as we haven't received updates
        // for download progress.
    }

    private void resetTimeout() {
        handler.removeCallbacks(this);
        handler.postDelayed(this, idleTimeoutMs);
    }

    private void complete() {
        handler.removeCallbacks(this);
        future.complete(null);
    }
}
+120 −0
Original line number Diff line number Diff line
/*
 * Copyright (C) 2025 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.android.server.ondeviceintelligence.executors;

import android.Manifest;
import android.os.RemoteException;
import android.service.ondeviceintelligence.IOnDeviceSandboxedInferenceService;
import android.util.Slog;

import com.android.internal.infra.AndroidFuture;
import com.android.server.ondeviceintelligence.OnDeviceIntelligenceManagerService;

import java.util.concurrent.TimeoutException;

/**
 * An executor for making remote calls to the {@link IOnDeviceSandboxedInferenceService}.
 *
 * <p>This class centralizes the logic for permission checks, ensuring the remote service is
 * available, and handling various failure scenarios like timeouts or remote exceptions before
 * dispatching the call.
 */
public final class InferenceServiceExecutor
        extends RemoteCallExecutor<IOnDeviceSandboxedInferenceService> {
    private InferenceServiceExecutor(Builder builder) {
        super(builder);
    }

    @Override
    public AndroidFuture<?> execute(
            RemoteCallRunner<IOnDeviceSandboxedInferenceService> remoteCall) {
        OnDeviceIntelligenceManagerService manager =
                OnDeviceIntelligenceManagerService.getInstance();
        manager.getContext()
                .enforceCallingPermission(
                        Manifest.permission.USE_ON_DEVICE_INTELLIGENCE,
                        OnDeviceIntelligenceManagerService.TAG);
        if (!manager.isServiceEnabled()) {
            Slog.w(OnDeviceIntelligenceManagerService.TAG, "Service not available");
            executeOnRemoteExecutor(() -> {
                try {
                    mFailureConsumer.accept(FailureType.SERVICE_UNAVAILABLE);
                } catch (RemoteException e) {
                    Slog.e(OnDeviceIntelligenceManagerService.TAG,
                            "Failed to call service unavailable callback", e);
                }
            });
            return null;
        }
        if (!manager.ensureRemoteInferenceServiceInitialized(/* shouldThrow= */ false)) {
            Slog.w(OnDeviceIntelligenceManagerService.TAG, "Service not available");
            executeOnRemoteExecutor(() -> {
                try {
                    mFailureConsumer.accept(FailureType.SERVICE_UNAVAILABLE);
                } catch (RemoteException e) {
                    Slog.e(OnDeviceIntelligenceManagerService.TAG,
                            "Failed to call service unavailable callback", e);
                }
            });
            return null;
        }
        AndroidFuture<?> future = manager.getRemoteInferenceService().postAsync(remoteCall::run);
        future.whenComplete(
                (res, ex) -> {
                    if (ex != null) {
                        Slog.e(
                                OnDeviceIntelligenceManagerService.TAG,
                                "Remote inference service call failed",
                                ex);
                        if (ex instanceof TimeoutException) {
                            executeOnRemoteExecutor(() -> {
                                try {
                                    mFailureConsumer.accept(FailureType.TIMEOUT);
                                } catch (RemoteException e) {
                                    Slog.e(OnDeviceIntelligenceManagerService.TAG,
                                            "Failed to call timeout callback", e);
                                }
                            });
                        } else {
                            executeOnRemoteExecutor(() -> {
                                try {
                                    mFailureConsumer.accept(FailureType.REMOTE_FAILURE);
                                } catch (RemoteException e) {
                                    Slog.e(OnDeviceIntelligenceManagerService.TAG,
                                            "Failed to call remote failure callback", e);
                                }
                            });
                        }
                    }
                });
        return future;
    }

    /** Builder for {@link InferenceServiceExecutor}. */
    public static class Builder
            extends RemoteCallExecutor.Builder<IOnDeviceSandboxedInferenceService, Builder> {

        public Builder() {
            // empty constructor.
        }

        @Override
        public InferenceServiceExecutor build() {
            return new InferenceServiceExecutor(this);
        }
    }
}
Loading