Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 4a8d8b6b authored by Treehugger Robot's avatar Treehugger Robot Committed by Android (Google) Code Review
Browse files

Merge "Added new error handling logic and executors for remote calls." into main

parents 43e49b13 e6cab1b0
Loading
Loading
Loading
Loading
+64 −60
Original line number Diff line number Diff line
@@ -61,8 +61,7 @@ public class BundleUtil {
     *
     * @throws BadParcelableException when the bundle does not meet the read-only requirements.
     */
    public static void sanitizeInferenceParams(
            @InferenceParams Bundle bundle) {
    public static void sanitizeInferenceParams(@InferenceParams Bundle bundle) {
        ensureValidBundle(bundle);

        if (!bundle.hasFileDescriptors()) {
@@ -107,8 +106,7 @@ public class BundleUtil {
     *
     * @throws BadParcelableException when the bundle does not meet the read-only requirements.
     */
    public static void sanitizeResponseParams(
            @ResponseParams Bundle bundle) {
    public static void sanitizeResponseParams(@ResponseParams Bundle bundle) {
        ensureValidBundle(bundle);

        if (!bundle.hasFileDescriptors()) {
@@ -147,13 +145,11 @@ public class BundleUtil {
    }

    /**
     * Validation of the inference request payload as described in {@link StateParams}
     * description.
     * Validation of the inference request payload as described in {@link StateParams} description.
     *
     * @throws BadParcelableException when the bundle does not meet the read-only requirements.
     */
    public static void sanitizeStateParams(
            @StateParams Bundle bundle) {
    public static void sanitizeStateParams(@StateParams Bundle bundle) {
        ensureValidBundle(bundle);

        if (!bundle.hasFileDescriptors()) {
@@ -185,7 +181,6 @@ public class BundleUtil {
        }
    }


    public static IStreamingResponseCallback wrapWithValidation(
            IStreamingResponseCallback streamingResponseCallback,
            Executor resourceClosingExecutor,
@@ -204,8 +199,7 @@ public class BundleUtil {
            }

            @Override
            public void onSuccess(Bundle resultBundle)
                    throws RemoteException {
            public void onSuccess(Bundle resultBundle) throws RemoteException {
                try {
                    sanitizeResponseParams(resultBundle);
                    streamingResponseCallback.onSuccess(resultBundle);
@@ -217,20 +211,23 @@ public class BundleUtil {
            }

            @Override
            public void onFailure(int errorCode, String errorMessage,
                    PersistableBundle errorParams) throws RemoteException {
            public void onFailure(int errorCode, String errorMessage, PersistableBundle errorParams)
                    throws RemoteException {
                try {
                    streamingResponseCallback.onFailure(errorCode, errorMessage, errorParams);
                    inferenceInfoStore.addInferenceInfoFromBundle(errorParams);
                future.completeExceptionally(new TimeoutException());
                } finally {
                    future.complete(null);
                }
            }

            @Override
            public void onDataAugmentRequest(Bundle processedContent,
                    RemoteCallback remoteCallback)
            public void onDataAugmentRequest(Bundle processedContent, RemoteCallback remoteCallback)
                    throws RemoteException {
                try {
                    sanitizeResponseParams(processedContent);
                    streamingResponseCallback.onDataAugmentRequest(processedContent,
                    streamingResponseCallback.onDataAugmentRequest(
                            processedContent,
                            new RemoteCallback(
                                    augmentedData -> {
                                        try {
@@ -256,15 +253,15 @@ public class BundleUtil {
        };
    }

    public static IResponseCallback wrapWithValidation(IResponseCallback responseCallback,
    public static IResponseCallback wrapWithValidation(
            IResponseCallback responseCallback,
            Executor resourceClosingExecutor,
            AndroidFuture future,
            InferenceInfoStore inferenceInfoStore,
            boolean shouldForwardInferenceInfo) {
        return new IResponseCallback.Stub() {
            @Override
            public void onSuccess(Bundle resultBundle)
                    throws RemoteException {
            public void onSuccess(Bundle resultBundle) throws RemoteException {
                try {
                    sanitizeResponseParams(resultBundle);
                    responseCallback.onSuccess(resultBundle);
@@ -276,20 +273,24 @@ public class BundleUtil {
            }

            @Override
            public void onFailure(int errorCode, String errorMessage,
                    PersistableBundle errorParams) throws RemoteException {
            public void onFailure(int errorCode, String errorMessage, PersistableBundle errorParams)
                    throws RemoteException {
                try {
                    responseCallback.onFailure(errorCode, errorMessage, errorParams);
                    inferenceInfoStore.addInferenceInfoFromBundle(errorParams);
                future.completeExceptionally(new TimeoutException());
                } finally {
                    future.complete(null);
                }
            }

            @Override
            public void onDataAugmentRequest(Bundle processedContent,
                    RemoteCallback remoteCallback)
            public void onDataAugmentRequest(Bundle processedContent, RemoteCallback remoteCallback)
                    throws RemoteException {
                try {
                    sanitizeResponseParams(processedContent);
                    responseCallback.onDataAugmentRequest(processedContent, new RemoteCallback(
                    responseCallback.onDataAugmentRequest(
                            processedContent,
                            new RemoteCallback(
                                    augmentedData -> {
                                        try {
                                            sanitizeInferenceParams(augmentedData);
@@ -314,30 +315,37 @@ public class BundleUtil {
        };
    }


    public static ITokenInfoCallback wrapWithValidation(ITokenInfoCallback responseCallback,
    public static ITokenInfoCallback wrapWithValidation(
            ITokenInfoCallback responseCallback,
            AndroidFuture future,
            InferenceInfoStore inferenceInfoStore) {
        return new ITokenInfoCallback.Stub() {
            @Override
            public void onSuccess(TokenInfo tokenInfo) throws RemoteException {
                try {
                    responseCallback.onSuccess(tokenInfo);
                    inferenceInfoStore.addInferenceInfoFromBundle(tokenInfo.getInfoParams());
                } finally {
                    future.complete(null);
                }
            }

            @Override
            public void onFailure(int errorCode, String errorMessage, PersistableBundle errorParams)
                    throws RemoteException {
                try {
                    responseCallback.onFailure(errorCode, errorMessage, errorParams);
                    inferenceInfoStore.addInferenceInfoFromBundle(errorParams);
                future.completeExceptionally(new TimeoutException());
                } finally {
                    future.complete(null);
                }
            }
        };
    }

    private static boolean canMarshall(Object obj) {
        return obj instanceof byte[] || obj instanceof PersistableBundle
        return obj instanceof byte[]
                || obj instanceof PersistableBundle
                || PersistableBundle.isValidType(obj);
    }

@@ -352,16 +360,13 @@ public class BundleUtil {
    }

    private static void validateParcelableArray(Parcelable[] parcelables) {
        if (parcelables.length > 0
                && parcelables[0] instanceof ParcelFileDescriptor) {
        if (parcelables.length > 0 && parcelables[0] instanceof ParcelFileDescriptor) {
            // Safe to cast
            validatePfdsReadOnly(parcelables);
        } else if (parcelables.length > 0
                && parcelables[0] instanceof Bitmap) {
        } else if (parcelables.length > 0 && parcelables[0] instanceof Bitmap) {
            validateBitmapsImmutable(parcelables);
        } else {
            throw new BadParcelableException(
                    "Could not cast to any known parcelable array");
            throw new BadParcelableException("Could not cast to any known parcelable array");
        }
    }

@@ -382,8 +387,7 @@ public class BundleUtil {
                        "Bundle contains a parcel file descriptor which is not read-only.");
            }
        } catch (ErrnoException e) {
            throw new BadParcelableException(
                    "Invalid File descriptor passed in the Bundle.", e);
            throw new BadParcelableException("Invalid File descriptor passed in the Bundle.", e);
        }
    }

+422 −283

File changed.

Preview size limit exceeded, changes collapsed.

+32 −46
Original line number Diff line number Diff line
@@ -92,7 +92,9 @@ public class RemoteStorageService extends IRemoteStorageService.Stub {
                        service.getReadOnlyFeatureFileDescriptorMap(
                                feature,
                                new RemoteCallback(
                                        result -> handleFileDescriptorMapResult(result, remoteCallback))));
                                        result ->
                                            handleFileDescriptorMapResult(
                                                        result, remoteCallback))));
    }

    @Override
@@ -106,21 +108,14 @@ public class RemoteStorageService extends IRemoteStorageService.Stub {
                                        result -> handleMetadataResult(result, remoteCallback))));
    }

    private void handleFileDescriptorMapResult(
            Bundle result,
            RemoteCallback remoteCallback) {
        mCallbackExecutor.execute(() -> {
    private void handleFileDescriptorMapResult(Bundle result, RemoteCallback remoteCallback) {
        try {
            if (result == null) {
                remoteCallback.sendResult(null);
                return;
            }
            for (String key : result.keySet()) {
                    ParcelFileDescriptor pfd =
                            result.getParcelable(
                                    key,
                                    ParcelFileDescriptor
                                            .class);
                ParcelFileDescriptor pfd = result.getParcelable(key, ParcelFileDescriptor.class);
                validatePfdReadOnly(pfd);
            }
            remoteCallback.sendResult(result);
@@ -128,15 +123,11 @@ public class RemoteStorageService extends IRemoteStorageService.Stub {
            Slog.e(TAG, "Failed to send result", e);
            remoteCallback.sendResult(null);
        } finally {
                mResourceClosingExecutor.execute(
                        () -> tryCloseResource(result));
            mResourceClosingExecutor.execute(() -> tryCloseResource(result));
        }
        });
    }

    private void handleMetadataResult(Bundle result, RemoteCallback remoteCallback) {
        mCallbackExecutor.execute(
                () -> {
        try {
            if (result == null) {
                remoteCallback.sendResult(null);
@@ -145,16 +136,11 @@ public class RemoteStorageService extends IRemoteStorageService.Stub {
            sanitizeStateParams(result);
            remoteCallback.sendResult(result);
        } catch (BadParcelableException e) {
                        Slog.e(
                                TAG,
                                "Failed to send result",
                                e);
            Slog.e(TAG, "Failed to send result", e);
            remoteCallback.sendResult(null);
        } finally {
                        mResourceClosingExecutor.execute(
                                () -> tryCloseResource(result));
            mResourceClosingExecutor.execute(() -> tryCloseResource(result));
        }
                });
    }

    private static void tryClosePfd(ParcelFileDescriptor pfd) {
+33 −27
Original line number Diff line number Diff line
@@ -26,17 +26,17 @@ import com.android.internal.infra.AndroidFuture;
import java.util.concurrent.TimeoutException;

/**
 * This class extends the {@link IDownloadCallback} and adds a timeout Runnable to the callback
 * such that, in the case where the callback methods are not invoked, we do not have to wait for
 * timeout based on {@link #onDownloadCompleted} which might take minutes or hours to complete in
 * some cases. Instead, in such cases we rely on the remote service sending progress updates and if
 * there are *no* progress callbacks in the duration of {@link #idleTimeoutMs}, we can assume the
 * download will not complete and enabling faster cleanup.
 * This class extends the {@link IDownloadCallback} and adds a timeout Runnable to the callback such
 * that, in the case where the callback methods are not invoked, we do not have to wait for timeout
 * based on {@link #onDownloadCompleted} which might take minutes or hours to complete in some
 * cases. Instead, in such cases we rely on the remote service sending progress updates and if there
 * are *no* progress callbacks in the duration of {@link #idleTimeoutMs}, we can assume the download
 * will not complete and enabling faster cleanup.
 */
public class ListenableDownloadCallback extends IDownloadCallback.Stub implements Runnable {
    private final IDownloadCallback callback;
    private final Handler handler;
    private final AndroidFuture future;
    private final AndroidFuture<?> future;
    private final long idleTimeoutMs;

    /**
@@ -47,45 +47,41 @@ public class ListenableDownloadCallback extends IDownloadCallback.Stub implement
     * @param future future to complete to signal the callback has reached a terminal state.
     * @param idleTimeoutMs timeout within which download updates should be received.
     */
    public ListenableDownloadCallback(IDownloadCallback callback, Handler handler,
            AndroidFuture future,
    public ListenableDownloadCallback(
            IDownloadCallback callback, Handler handler, AndroidFuture<?> future,
            long idleTimeoutMs) {
        this.callback = callback;
        this.handler = handler;
        this.future = future;
        this.idleTimeoutMs = idleTimeoutMs;
        handler.postDelayed(this,
                idleTimeoutMs); // init the timeout runnable in case no callback is ever invoked
        resetTimeout(); // init the timeout runnable in case no callback is ever invoked
    }

    @Override
    public void onDownloadStarted(long bytesToDownload) throws RemoteException {
        callback.onDownloadStarted(bytesToDownload);
        handler.removeCallbacks(this);
        handler.postDelayed(this, idleTimeoutMs);
        resetTimeout();
    }

    @Override
    public void onDownloadProgress(long bytesDownloaded) throws RemoteException {
        callback.onDownloadProgress(bytesDownloaded);
        handler.removeCallbacks(this); // remove previously queued timeout tasks.
        handler.postDelayed(this, idleTimeoutMs); // queue fresh timeout task for next update.
        resetTimeout();
    }

    @Override
    public void onDownloadFailed(int failureStatus,
            String errorMessage, PersistableBundle errorParams) throws RemoteException {
    public void onDownloadFailed(
            int failureStatus, String errorMessage, PersistableBundle errorParams)
            throws RemoteException {
        callback.onDownloadFailed(failureStatus, errorMessage, errorParams);
        handler.removeCallbacks(this);
        future.completeExceptionally(new TimeoutException());
        complete();
    }

    @Override
    public void onDownloadCompleted(
            android.os.PersistableBundle downloadParams) throws RemoteException {
    public void onDownloadCompleted(android.os.PersistableBundle downloadParams)
            throws RemoteException {
        callback.onDownloadCompleted(downloadParams);
        handler.removeCallbacks(this);
        future.complete(null);
        complete();
    }

    @Override
@@ -94,4 +90,14 @@ public class ListenableDownloadCallback extends IDownloadCallback.Stub implement
                new TimeoutException()); // complete the future as we haven't received updates
        // for download progress.
    }

    private void resetTimeout() {
        handler.removeCallbacks(this);
        handler.postDelayed(this, idleTimeoutMs);
    }

    private void complete() {
        handler.removeCallbacks(this);
        future.complete(null);
    }
}
+120 −0
Original line number Diff line number Diff line
/*
 * Copyright (C) 2025 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.android.server.ondeviceintelligence.executors;

import android.Manifest;
import android.os.RemoteException;
import android.service.ondeviceintelligence.IOnDeviceSandboxedInferenceService;
import android.util.Slog;

import com.android.internal.infra.AndroidFuture;
import com.android.server.ondeviceintelligence.OnDeviceIntelligenceManagerService;

import java.util.concurrent.TimeoutException;

/**
 * An executor for making remote calls to the {@link IOnDeviceSandboxedInferenceService}.
 *
 * <p>This class centralizes the logic for permission checks, ensuring the remote service is
 * available, and handling various failure scenarios like timeouts or remote exceptions before
 * dispatching the call.
 */
public final class InferenceServiceExecutor
        extends RemoteCallExecutor<IOnDeviceSandboxedInferenceService> {
    private InferenceServiceExecutor(Builder builder) {
        super(builder);
    }

    @Override
    public AndroidFuture<?> execute(
            RemoteCallRunner<IOnDeviceSandboxedInferenceService> remoteCall) {
        OnDeviceIntelligenceManagerService manager =
                OnDeviceIntelligenceManagerService.getInstance();
        manager.getContext()
                .enforceCallingPermission(
                        Manifest.permission.USE_ON_DEVICE_INTELLIGENCE,
                        OnDeviceIntelligenceManagerService.TAG);
        if (!manager.isServiceEnabled()) {
            Slog.w(OnDeviceIntelligenceManagerService.TAG, "Service not available");
            executeOnRemoteExecutor(() -> {
                try {
                    mFailureConsumer.accept(FailureType.SERVICE_UNAVAILABLE);
                } catch (RemoteException e) {
                    Slog.e(OnDeviceIntelligenceManagerService.TAG,
                            "Failed to call service unavailable callback", e);
                }
            });
            return null;
        }
        if (!manager.ensureRemoteInferenceServiceInitialized(/* shouldThrow= */ false)) {
            Slog.w(OnDeviceIntelligenceManagerService.TAG, "Service not available");
            executeOnRemoteExecutor(() -> {
                try {
                    mFailureConsumer.accept(FailureType.SERVICE_UNAVAILABLE);
                } catch (RemoteException e) {
                    Slog.e(OnDeviceIntelligenceManagerService.TAG,
                            "Failed to call service unavailable callback", e);
                }
            });
            return null;
        }
        AndroidFuture<?> future = manager.getRemoteInferenceService().postAsync(remoteCall::run);
        future.whenComplete(
                (res, ex) -> {
                    if (ex != null) {
                        Slog.e(
                                OnDeviceIntelligenceManagerService.TAG,
                                "Remote inference service call failed",
                                ex);
                        if (ex instanceof TimeoutException) {
                            executeOnRemoteExecutor(() -> {
                                try {
                                    mFailureConsumer.accept(FailureType.TIMEOUT);
                                } catch (RemoteException e) {
                                    Slog.e(OnDeviceIntelligenceManagerService.TAG,
                                            "Failed to call timeout callback", e);
                                }
                            });
                        } else {
                            executeOnRemoteExecutor(() -> {
                                try {
                                    mFailureConsumer.accept(FailureType.REMOTE_FAILURE);
                                } catch (RemoteException e) {
                                    Slog.e(OnDeviceIntelligenceManagerService.TAG,
                                            "Failed to call remote failure callback", e);
                                }
                            });
                        }
                    }
                });
        return future;
    }

    /** Builder for {@link InferenceServiceExecutor}. */
    public static class Builder
            extends RemoteCallExecutor.Builder<IOnDeviceSandboxedInferenceService, Builder> {

        public Builder() {
            // empty constructor.
        }

        @Override
        public InferenceServiceExecutor build() {
            return new InferenceServiceExecutor(this);
        }
    }
}
Loading