Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit e6ec0721 authored by Cutter Coryell's avatar Cutter Coryell Committed by Automerger Merge Worker
Browse files

Merge "Load back-gesture model on background thread." into tm-qpr-dev am: fddc399c am: 30e2cea7

parents 451d13d0 30e2cea7
Loading
Loading
Loading
Loading
+60 −15
Original line number Diff line number Diff line
@@ -60,6 +60,7 @@ import com.android.internal.policy.GestureNavigationSettingsObserver;
import com.android.internal.util.LatencyTracker;
import com.android.systemui.R;
import com.android.systemui.broadcast.BroadcastDispatcher;
import com.android.systemui.dagger.qualifiers.Background;
import com.android.systemui.dagger.qualifiers.Main;
import com.android.systemui.flags.FeatureFlags;
import com.android.systemui.flags.Flags;
@@ -82,6 +83,7 @@ import com.android.systemui.shared.tracing.ProtoTraceable;
import com.android.systemui.tracing.ProtoTracer;
import com.android.systemui.tracing.nano.EdgeBackGestureHandlerProto;
import com.android.systemui.tracing.nano.SystemUiTraceProto;
import com.android.systemui.util.Assert;
import com.android.wm.shell.back.BackAnimation;
import com.android.wm.shell.pip.Pip;

@@ -191,6 +193,7 @@ public class EdgeBackGestureHandler extends CurrentUserTracker
    private final int mDisplayId;

    private final Executor mMainExecutor;
    private final Executor mBackgroundExecutor;

    private final Rect mPipExcludedBounds = new Rect();
    private final Rect mNavBarOverlayExcludedBounds = new Rect();
@@ -251,6 +254,7 @@ public class EdgeBackGestureHandler extends CurrentUserTracker
    private BackGestureTfClassifierProvider mBackGestureTfClassifierProvider;
    private Map<String, Integer> mVocab;
    private boolean mUseMLModel;
    private boolean mMLModelIsLoading;
    // minimum width below which we do not run the model
    private int mMLEnableWidth;
    private float mMLModelThreshold;
@@ -318,6 +322,7 @@ public class EdgeBackGestureHandler extends CurrentUserTracker
            SysUiState sysUiState,
            PluginManager pluginManager,
            @Main Executor executor,
            @Background Executor backgroundExecutor,
            BroadcastDispatcher broadcastDispatcher,
            ProtoTracer protoTracer,
            NavigationModeController navigationModeController,
@@ -334,6 +339,7 @@ public class EdgeBackGestureHandler extends CurrentUserTracker
        mContext = context;
        mDisplayId = context.getDisplayId();
        mMainExecutor = executor;
        mBackgroundExecutor = backgroundExecutor;
        mOverviewProxyService = overviewProxyService;
        mSysUiState = sysUiState;
        mPluginManager = pluginManager;
@@ -631,28 +637,63 @@ public class EdgeBackGestureHandler extends CurrentUserTracker
            return;
        }

        if (newState) {
            mBackGestureTfClassifierProvider = mBackGestureTfClassifierProviderProvider.get();
            mMLModelThreshold = DeviceConfig.getFloat(DeviceConfig.NAMESPACE_SYSTEMUI,
        mUseMLModel = newState;

        if (mUseMLModel) {
            Assert.isMainThread();
            if (mMLModelIsLoading) {
                Log.d(TAG, "Model tried to load while already loading.");
                return;
            }
            mMLModelIsLoading = true;
            mBackgroundExecutor.execute(() -> loadMLModel());
        } else if (mBackGestureTfClassifierProvider != null) {
            mBackGestureTfClassifierProvider.release();
            mBackGestureTfClassifierProvider = null;
            mVocab = null;
        }
    }

    private void loadMLModel() {
        BackGestureTfClassifierProvider provider = mBackGestureTfClassifierProviderProvider.get();
        float threshold = DeviceConfig.getFloat(DeviceConfig.NAMESPACE_SYSTEMUI,
                SystemUiDeviceConfigFlags.BACK_GESTURE_ML_MODEL_THRESHOLD, 0.9f);
            if (mBackGestureTfClassifierProvider.isActive()) {
        Map<String, Integer> vocab = null;
        if (provider != null && !provider.isActive()) {
            provider.release();
            provider = null;
            Log.w(TAG, "Cannot load model because it isn't active");
        }
        if (provider != null) {
            Trace.beginSection("EdgeBackGestureHandler#loadVocab");
                mVocab = mBackGestureTfClassifierProvider.loadVocab(mContext.getAssets());
            vocab = provider.loadVocab(mContext.getAssets());
            Trace.endSection();
                mUseMLModel = true;
                return;
        }
        BackGestureTfClassifierProvider finalProvider = provider;
        Map<String, Integer> finalVocab = vocab;
        mMainExecutor.execute(() -> onMLModelLoadFinished(finalProvider, finalVocab, threshold));
    }

        mUseMLModel = false;
        if (mBackGestureTfClassifierProvider != null) {
            mBackGestureTfClassifierProvider.release();
            mBackGestureTfClassifierProvider = null;
    private void onMLModelLoadFinished(BackGestureTfClassifierProvider provider,
            Map<String, Integer> vocab, float threshold) {
        Assert.isMainThread();
        mMLModelIsLoading = false;
        if (!mUseMLModel) {
            // This can happen if the user disables Gesture Nav while the model is loading.
            if (provider != null) {
                provider.release();
            }
            Log.d(TAG, "Model finished loading but isn't needed.");
            return;
        }
        mBackGestureTfClassifierProvider = provider;
        mVocab = vocab;
        mMLModelThreshold = threshold;
    }

    private int getBackGesturePredictionsCategory(int x, int y, int app) {
        if (app == -1) {
        BackGestureTfClassifierProvider provider = mBackGestureTfClassifierProvider;
        if (provider == null || app == -1) {
            return -1;
        }
        int distanceFromEdge;
@@ -673,7 +714,7 @@ public class EdgeBackGestureHandler extends CurrentUserTracker
            new long[]{(long) y},
        };

        mMLResults = mBackGestureTfClassifierProvider.predict(featuresVector);
        mMLResults = provider.predict(featuresVector);
        if (mMLResults == -1) {
            return -1;
        }
@@ -1031,6 +1072,7 @@ public class EdgeBackGestureHandler extends CurrentUserTracker
        private final SysUiState mSysUiState;
        private final PluginManager mPluginManager;
        private final Executor mExecutor;
        private final Executor mBackgroundExecutor;
        private final BroadcastDispatcher mBroadcastDispatcher;
        private final ProtoTracer mProtoTracer;
        private final NavigationModeController mNavigationModeController;
@@ -1050,6 +1092,7 @@ public class EdgeBackGestureHandler extends CurrentUserTracker
                       SysUiState sysUiState,
                       PluginManager pluginManager,
                       @Main Executor executor,
                       @Background Executor backgroundExecutor,
                       BroadcastDispatcher broadcastDispatcher,
                       ProtoTracer protoTracer,
                       NavigationModeController navigationModeController,
@@ -1067,6 +1110,7 @@ public class EdgeBackGestureHandler extends CurrentUserTracker
            mSysUiState = sysUiState;
            mPluginManager = pluginManager;
            mExecutor = executor;
            mBackgroundExecutor = backgroundExecutor;
            mBroadcastDispatcher = broadcastDispatcher;
            mProtoTracer = protoTracer;
            mNavigationModeController = navigationModeController;
@@ -1089,6 +1133,7 @@ public class EdgeBackGestureHandler extends CurrentUserTracker
                    mSysUiState,
                    mPluginManager,
                    mExecutor,
                    mBackgroundExecutor,
                    mBroadcastDispatcher,
                    mProtoTracer,
                    mNavigationModeController,