Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 42b985dd authored by sandeepbandaru's avatar sandeepbandaru
Browse files

Handle Model Loading/Unloading Broadcasts for Samsung in Isolated AICore

This change registers a callback with the remote service via
updateProcessingState to receive model loading/unloading updates and
then send these as broadcasts to system process.

Bug: 339688752
Test: cts in topic
Change-Id: I84ae64c5f5dfcbc7005a912b31861b3b64bf61f1
parent ee2aaacd
Loading
Loading
Loading
Loading
+15 −0
Original line number Diff line number Diff line
@@ -105,6 +105,21 @@ public abstract class OnDeviceSandboxedInferenceService extends Service {
    public static final String SERVICE_INTERFACE =
            "android.service.ondeviceintelligence.OnDeviceSandboxedInferenceService";

    // TODO(339594686): make API
    /**
     * @hide
     */
    public static final String REGISTER_MODEL_UPDATE_CALLBACK_BUNDLE_KEY =
            "register_model_update_callback";
    /**
     * @hide
     */
    public static final String MODEL_LOADED_BUNDLE_KEY = "model_loaded";
    /**
     * @hide
     */
    public static final String MODEL_UNLOADED_BUNDLE_KEY = "model_unloaded";

    private IRemoteStorageService mRemoteStorageService;

    /**
+7 −0
Original line number Diff line number Diff line
@@ -4704,6 +4704,13 @@
    <!-- The component name for the default system on-device sandboxed inference service. -->
    <string name="config_defaultOnDeviceSandboxedInferenceService" translatable="false"></string>

    <!-- The broadcast intent name for notifying when the on-device model is loading  -->
    <string name="config_onDeviceIntelligenceModelLoadedBroadcastKey" translatable="false"></string>

    <!-- The broadcast intent name for notifying when the on-device model has been unloaded  -->
    <string name="config_onDeviceIntelligenceModelUnloadedBroadcastKey" translatable="false"></string>


    <!-- Component name that accepts ACTION_SEND intents for requesting ambient context consent for
         wearable sensing. -->
    <string translatable="false" name="config_defaultWearableSensingConsentComponent"></string>
+2 −0
Original line number Diff line number Diff line
@@ -3941,6 +3941,8 @@
  <java-symbol type="string" name="config_defaultWearableSensingService" />
  <java-symbol type="string" name="config_defaultOnDeviceIntelligenceService" />
  <java-symbol type="string" name="config_defaultOnDeviceSandboxedInferenceService" />
  <java-symbol type="string" name="config_onDeviceIntelligenceModelLoadedBroadcastKey" />
  <java-symbol type="string" name="config_onDeviceIntelligenceModelUnloadedBroadcastKey" />
  <java-symbol type="string" name="config_retailDemoPackage" />
  <java-symbol type="string" name="config_retailDemoPackageSignature" />

+125 −17
Original line number Diff line number Diff line
@@ -16,6 +16,10 @@

package com.android.server.ondeviceintelligence;

import static android.service.ondeviceintelligence.OnDeviceSandboxedInferenceService.MODEL_LOADED_BUNDLE_KEY;
import static android.service.ondeviceintelligence.OnDeviceSandboxedInferenceService.MODEL_UNLOADED_BUNDLE_KEY;
import static android.service.ondeviceintelligence.OnDeviceSandboxedInferenceService.REGISTER_MODEL_UPDATE_CALLBACK_BUNDLE_KEY;

import static com.android.server.ondeviceintelligence.BundleUtil.sanitizeInferenceParams;
import static com.android.server.ondeviceintelligence.BundleUtil.validatePfdReadOnly;
import static com.android.server.ondeviceintelligence.BundleUtil.sanitizeStateParams;
@@ -41,6 +45,7 @@ import android.app.ondeviceintelligence.ITokenInfoCallback;
import android.app.ondeviceintelligence.OnDeviceIntelligenceException;
import android.content.ComponentName;
import android.content.Context;
import android.content.Intent;
import android.content.pm.PackageManager;
import android.content.pm.ServiceInfo;
import android.content.res.Resources;
@@ -105,12 +110,20 @@ public class OnDeviceIntelligenceManagerService extends SystemService {
    /** Handler message to {@link #resetTemporaryServices()} */
    private static final int MSG_RESET_TEMPORARY_SERVICE = 0;

    /** Handler message to clean up temporary broadcast keys. */
    private static final int MSG_RESET_BROADCAST_KEYS = 1;

    /** Default value in absence of {@link DeviceConfig} override. */
    private static final boolean DEFAULT_SERVICE_ENABLED = true;
    private static final String NAMESPACE_ON_DEVICE_INTELLIGENCE = "ondeviceintelligence";

    private static final String SYSTEM_PACKAGE = "android";


    private final Executor resourceClosingExecutor = Executors.newCachedThreadPool();
    private final Executor callbackExecutor = Executors.newCachedThreadPool();
    private final Executor broadcastExecutor = Executors.newCachedThreadPool();


    private final Context mContext;
    protected final Object mLock = new Object();
@@ -123,10 +136,14 @@ public class OnDeviceIntelligenceManagerService extends SystemService {
    @GuardedBy("mLock")
    private String[] mTemporaryServiceNames;

    @GuardedBy("mLock")
    private String[] mTemporaryBroadcastKeys;
    @GuardedBy("mLock")
    private String mBroadcastPackageName;

    /**
     * Handler used to reset the temporary service names.
     */
    @GuardedBy("mLock")
    private Handler mTemporaryHandler;

    public OnDeviceIntelligenceManagerService(Context context) {
@@ -482,6 +499,8 @@ public class OnDeviceIntelligenceManagerService extends SystemService {
                                    ensureRemoteIntelligenceServiceInitialized();
                                    mRemoteOnDeviceIntelligenceService.run(
                                            IOnDeviceIntelligenceService::notifyInferenceServiceConnected);
                                    broadcastExecutor.execute(
                                            () -> registerModelLoadingBroadcasts(service));
                                    service.registerRemoteStorageService(
                                            getIRemoteStorageService());
                                } catch (RemoteException ex) {
@@ -493,6 +512,56 @@ public class OnDeviceIntelligenceManagerService extends SystemService {
        }
    }

    private void registerModelLoadingBroadcasts(IOnDeviceSandboxedInferenceService service) {
        String[] modelBroadcastKeys;
        try {
            modelBroadcastKeys = getBroadcastKeys();
        } catch (Resources.NotFoundException e) {
            Slog.d(TAG, "Skipping model broadcasts as broadcast intents configured.");
            return;
        }

        Bundle bundle = new Bundle();
        bundle.putBoolean(REGISTER_MODEL_UPDATE_CALLBACK_BUNDLE_KEY, true);
        try {
            service.updateProcessingState(bundle, new IProcessingUpdateStatusCallback.Stub() {
                @Override
                public void onSuccess(PersistableBundle statusParams) {
                    Binder.clearCallingIdentity();
                    synchronized (mLock) {
                        if (statusParams.containsKey(MODEL_LOADED_BUNDLE_KEY)) {
                            String modelLoadedBroadcastKey = modelBroadcastKeys[0];
                            if (modelLoadedBroadcastKey != null
                                    && !modelLoadedBroadcastKey.isEmpty()) {
                                final Intent intent = new Intent(modelLoadedBroadcastKey);
                                intent.setPackage(mBroadcastPackageName);
                                mContext.sendBroadcast(intent,
                                        Manifest.permission.USE_ON_DEVICE_INTELLIGENCE);
                            }
                        } else if (statusParams.containsKey(MODEL_UNLOADED_BUNDLE_KEY)) {
                            String modelUnloadedBroadcastKey = modelBroadcastKeys[1];
                            if (modelUnloadedBroadcastKey != null
                                    && !modelUnloadedBroadcastKey.isEmpty()) {
                                final Intent intent = new Intent(modelUnloadedBroadcastKey);
                                intent.setPackage(mBroadcastPackageName);
                                mContext.sendBroadcast(intent,
                                        Manifest.permission.USE_ON_DEVICE_INTELLIGENCE);
                            }
                        }
                    }
                }

                @Override
                public void onFailure(int errorCode, String errorMessage) {
                    Slog.e(TAG, "Failed to register model loading callback with status code",
                            new OnDeviceIntelligenceException(errorCode, errorMessage));
                }
            });
        } catch (RemoteException e) {
            Slog.e(TAG, "Failed to register model loading callback with status code", e);
        }
    }

    @NonNull
    private IRemoteStorageService.Stub getIRemoteStorageService() {
        return new IRemoteStorageService.Stub() {
@@ -629,6 +698,20 @@ public class OnDeviceIntelligenceManagerService extends SystemService {
                        R.string.config_defaultOnDeviceSandboxedInferenceService)};
    }

    protected String[] getBroadcastKeys() throws Resources.NotFoundException {
        // TODO 329240495 : Consider a small class with explicit field names for the two services
        synchronized (mLock) {
            if (mTemporaryBroadcastKeys != null && mTemporaryBroadcastKeys.length == 2) {
                return mTemporaryBroadcastKeys;
            }
        }

        return new String[]{mContext.getResources().getString(
                R.string.config_onDeviceIntelligenceModelLoadedBroadcastKey),
                mContext.getResources().getString(
                        R.string.config_onDeviceIntelligenceModelUnloadedBroadcastKey)};
    }

    @RequiresPermission(Manifest.permission.USE_ON_DEVICE_INTELLIGENCE)
    public void setTemporaryServices(@NonNull String[] componentNames, int durationMs) {
        Objects.requireNonNull(componentNames);
@@ -645,25 +728,26 @@ public class OnDeviceIntelligenceManagerService extends SystemService {
                mRemoteOnDeviceIntelligenceService.unbind();
                mRemoteOnDeviceIntelligenceService = null;
            }
            if (mTemporaryHandler == null) {
                mTemporaryHandler = new Handler(Looper.getMainLooper(), null, true) {
                    @Override
                    public void handleMessage(Message msg) {
                        if (msg.what == MSG_RESET_TEMPORARY_SERVICE) {
                            synchronized (mLock) {
                                resetTemporaryServices();
                            }
                        } else {
                            Slog.wtf(TAG, "invalid handler msg: " + msg);

            if (durationMs != -1) {
                getTemporaryHandler().sendEmptyMessageDelayed(MSG_RESET_TEMPORARY_SERVICE,
                        durationMs);
            }
        }
                };
            } else {
                mTemporaryHandler.removeMessages(MSG_RESET_TEMPORARY_SERVICE);
    }

    @RequiresPermission(Manifest.permission.USE_ON_DEVICE_INTELLIGENCE)
    public void setModelBroadcastKeys(@NonNull String[] broadcastKeys, String receiverPackageName,
            int durationMs) {
        Objects.requireNonNull(broadcastKeys);
        enforceShellOnly(Binder.getCallingUid(), "setModelBroadcastKeys");
        mContext.enforceCallingPermission(
                Manifest.permission.USE_ON_DEVICE_INTELLIGENCE, TAG);
        synchronized (mLock) {
            mTemporaryBroadcastKeys = broadcastKeys;
            mBroadcastPackageName = receiverPackageName;
            if (durationMs != -1) {
                mTemporaryHandler.sendEmptyMessageDelayed(MSG_RESET_TEMPORARY_SERVICE, durationMs);
                getTemporaryHandler().sendEmptyMessageDelayed(MSG_RESET_BROADCAST_KEYS, durationMs);
            }
        }
    }
@@ -751,4 +835,28 @@ public class OnDeviceIntelligenceManagerService extends SystemService {
            }
        }
    }

    private synchronized Handler getTemporaryHandler() {
        if (mTemporaryHandler == null) {
            mTemporaryHandler = new Handler(Looper.getMainLooper(), null, true) {
                @Override
                public void handleMessage(Message msg) {
                    if (msg.what == MSG_RESET_TEMPORARY_SERVICE) {
                        synchronized (mLock) {
                            resetTemporaryServices();
                        }
                    } else if (msg.what == MSG_RESET_BROADCAST_KEYS) {
                        synchronized (mLock) {
                            mTemporaryBroadcastKeys = null;
                            mBroadcastPackageName = SYSTEM_PACKAGE;
                        }
                    } else {
                        Slog.wtf(TAG, "invalid handler msg: " + msg);
                    }
                }
            };
        }

        return mTemporaryHandler;
    }
}
+28 −1
Original line number Diff line number Diff line
@@ -43,6 +43,8 @@ final class OnDeviceIntelligenceShellCommand extends ShellCommand {
                return setTemporaryServices();
            case "get-services":
                return getConfiguredServices();
            case "set-model-broadcasts":
                return setBroadcastKeys();
            default:
                return handleDefaultCommands(cmd);
        }
@@ -62,12 +64,18 @@ final class OnDeviceIntelligenceShellCommand extends ShellCommand {
        pw.println("    To reset, call without any arguments.");

        pw.println("  get-services To get the names of services that are currently being used.");
        pw.println(
                "  set-model-broadcasts [ModelLoadedBroadcastKey] [ModelUnloadedBroadcastKey] "
                        + "[ReceiverPackageName] "
                        + "[DURATION] To set the names of broadcast intent keys that are to be "
                        + "emitted for cts tests.");
    }

    private int setTemporaryServices() {
        final PrintWriter out = getOutPrintWriter();
        final String intelligenceServiceName = getNextArg();
        final String inferenceServiceName = getNextArg();

        if (getRemainingArgsCount() == 0 && intelligenceServiceName == null
                && inferenceServiceName == null) {
            mService.resetTemporaryServices();
@@ -79,7 +87,8 @@ final class OnDeviceIntelligenceShellCommand extends ShellCommand {
        Objects.requireNonNull(inferenceServiceName);
        final int duration = Integer.parseInt(getNextArgRequired());
        mService.setTemporaryServices(
                new String[]{intelligenceServiceName, inferenceServiceName}, duration);
                new String[]{intelligenceServiceName, inferenceServiceName},
                duration);
        out.println("OnDeviceIntelligenceService temporarily set to " + intelligenceServiceName
                + " \n and \n OnDeviceTrustedInferenceService set to " + inferenceServiceName
                + " for " + duration + "ms");
@@ -93,4 +102,22 @@ final class OnDeviceIntelligenceShellCommand extends ShellCommand {
                + " \n and \n OnDeviceTrustedInferenceService set to : " + services[1]);
        return 0;
    }

    private int setBroadcastKeys() {
        final PrintWriter out = getOutPrintWriter();
        final String modelLoadedKey = getNextArgRequired();
        final String modelUnloadedKey = getNextArgRequired();
        final String receiverPackageName = getNextArg();

        final int duration = Integer.parseInt(getNextArgRequired());
        mService.setModelBroadcastKeys(
                new String[]{modelLoadedKey, modelUnloadedKey}, receiverPackageName, duration);
        out.println("OnDeviceIntelligence Model Loading broadcast keys temporarily set to "
                + modelLoadedKey
                + " \n and \n OnDeviceTrustedInferenceService set to " + modelUnloadedKey
                + "\n and Package name set to : " + receiverPackageName
                + " for " + duration + "ms");
        return 0;
    }

}
 No newline at end of file