Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 7d395e2e authored by Maryam Karimzadehgan's avatar Maryam Karimzadehgan
Browse files

Use ML model for the Back Gesture in EdgeBackGestureHandler.

Change-Id: I2fd5255e903932c03a35ae463b0eff3840dc81bd
Test: manual model test and getting the results
b: 150170384
parent bea4e46f
Loading
Loading
Loading
Loading
+11 −0
Original line number Diff line number Diff line
@@ -419,6 +419,17 @@ public final class SystemUiDeviceConfigFlags {
     */
    public static final String SCREENSHOT_KEYCHORD_DELAY = "screenshot_keychord_delay";

    /**
     * (boolean) Whether to use an ML model for the Back Gesture.
     */
    public static final String USE_BACK_GESTURE_ML_MODEL = "use_back_gesture_ml_model";

    /**
     * (float) Threshold for Back Gesture ML model prediction.
     */
    public static final String BACK_GESTURE_ML_MODEL_THRESHOLD = "back_gesture_ml_model_threshold";


    private SystemUiDeviceConfigFlags() {
    }
}
+20 −6
Original line number Diff line number Diff line
@@ -17,6 +17,7 @@
package com.android.systemui;

import android.content.Context;
import android.content.res.AssetManager;
import android.content.res.Resources;
import android.os.Handler;
import android.util.Log;
@@ -26,6 +27,7 @@ import com.android.systemui.dagger.DaggerGlobalRootComponent;
import com.android.systemui.dagger.GlobalRootComponent;
import com.android.systemui.dagger.SysUIComponent;
import com.android.systemui.dagger.WMComponent;
import com.android.systemui.navigationbar.gestural.BackGestureTfClassifierProvider;
import com.android.systemui.screenshot.ScreenshotNotificationSmartActionsProvider;

import java.util.concurrent.ExecutionException;
@@ -110,12 +112,16 @@ public class SystemUIFactory {
        return mSysUIComponent;
    }

    /** Returns the list of system UI components that should be started. */
    /**
     * Returns the list of system UI components that should be started.
     */
    public String[] getSystemUIServiceComponents(Resources resources) {
        return resources.getStringArray(R.array.config_systemUIServiceComponents);
    }

    /** Returns the list of system UI components that should be started per user. */
    /**
     * Returns the list of system UI components that should be started per user.
     */
    public String[] getSystemUIServiceComponentsPerUser(Resources resources) {
        return resources.getStringArray(R.array.config_systemUIServiceComponentsPerUser);
    }
@@ -125,9 +131,17 @@ public class SystemUIFactory {
     * This method is overridden in vendor specific implementation of Sys UI.
     */
    public ScreenshotNotificationSmartActionsProvider
            createScreenshotNotificationSmartActionsProvider(Context context,
            Executor executor,
            Handler uiHandler) {
                createScreenshotNotificationSmartActionsProvider(
                        Context context, Executor executor, Handler uiHandler) {
        return new ScreenshotNotificationSmartActionsProvider();
    }

    /**
     * Creates an instance of BackGestureTfClassifierProvider.
     * This method is overridden in vendor specific implementation of Sys UI.
     */
    public BackGestureTfClassifierProvider createBackGestureTfClassifierProvider(
            AssetManager am) {
        return new BackGestureTfClassifierProvider();
    }
}
 No newline at end of file
+66 −0
Original line number Diff line number Diff line
/*
 * Copyright (C) 2020 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.android.systemui.navigationbar.gestural;

import android.content.res.AssetManager;

import java.util.HashMap;
import java.util.Map;

/**
 * This class can be overridden by a vendor-specific sys UI implementation,
 * in order to provide classification models for the Back Gesture.
 */
public class BackGestureTfClassifierProvider {
    private static final String TAG = "BackGestureTfClassifierProvider";

    /**
     * Default implementation that returns an empty map.
     * This method is overridden in vendor-specific Sys UI implementation.
     *
     * @param am       An AssetManager to get the vocab file.
    */
    public Map<String, Integer> loadVocab(AssetManager am) {
        return new HashMap<String, Integer>();
    }

    /**
     * This method is overridden in vendor-specific Sys UI implementation.
     *
     * @param featuresVector   List of input features.
     *
    */
    public float predict(Object[] featuresVector) {
        return -1;
    }

    /**
     * Interpreter owns resources. This method releases the resources after
     * use to avoid memory leak.
     * This method is overridden in vendor-specific Sys UI implementation.
     *
     */
    public void release() {}

    /**
     * Returns whether to use the ML model for Back Gesture.
     * This method is overridden in vendor-specific Sys UI implementation.
     *
     */
    public boolean isActive() {
        return false;
    }
}
+120 −18
Original line number Diff line number Diff line
@@ -40,7 +40,6 @@ import android.util.Log;
import android.util.TypedValue;
import android.view.Choreographer;
import android.view.ISystemGestureExclusionListener;
import android.view.InputChannel;
import android.view.InputDevice;
import android.view.InputEvent;
import android.view.InputMonitor;
@@ -56,6 +55,7 @@ import com.android.internal.config.sysui.SystemUiDeviceConfigFlags;
import com.android.internal.policy.GestureNavigationSettingsObserver;
import com.android.systemui.Dependency;
import com.android.systemui.R;
import com.android.systemui.SystemUIFactory;
import com.android.systemui.broadcast.BroadcastDispatcher;
import com.android.systemui.bubbles.BubbleController;
import com.android.systemui.model.SysUiState;
@@ -79,6 +79,7 @@ import com.android.systemui.tracing.nano.SystemUiTraceProto;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Executor;

/**
@@ -120,8 +121,31 @@ public class EdgeBackGestureHandler extends CurrentUserTracker implements Displa
        public void onTaskStackChanged() {
            mGestureBlockingActivityRunning = isGestureBlockingActivityRunning();
        }
        @Override
        public void onTaskCreated(int taskId, ComponentName componentName) {
            if (componentName != null) {
                mPackageName = componentName.getPackageName();
            } else {
                mPackageName = "_UNKNOWN";
            }
        }
    };

    private DeviceConfig.OnPropertiesChangedListener mOnPropertiesChangedListener =
            new DeviceConfig.OnPropertiesChangedListener() {
                @Override
                public void onPropertiesChanged(DeviceConfig.Properties properties) {
                    if (DeviceConfig.NAMESPACE_SYSTEMUI.equals(properties.getNamespace())
                            && (properties.getKeyset().contains(
                                    SystemUiDeviceConfigFlags.BACK_GESTURE_ML_MODEL_THRESHOLD)
                            || properties.getKeyset().contains(
                                    SystemUiDeviceConfigFlags.USE_BACK_GESTURE_ML_MODEL))) {
                        updateMLModelState();
                    }
                }
            };


    private final Context mContext;
    private final OverviewProxyService mOverviewProxyService;
    private final SysUiState mSysUiState;
@@ -177,6 +201,13 @@ public class EdgeBackGestureHandler extends CurrentUserTracker implements Displa
    private int mRightInset;
    private int mSysUiFlags;

    // For Tf-Lite model.
    private BackGestureTfClassifierProvider mBackGestureTfClassifierProvider;
    private Map<String, Integer> mVocab;
    private boolean mUseMLModel;
    private float mMLModelThreshold;
    private String mPackageName;

    private final GestureNavigationSettingsObserver mGestureNavigationSettingsObserver;

    private final NavigationEdgeBackPlugin.BackCallback mBackCallback =
@@ -242,7 +273,6 @@ public class EdgeBackGestureHandler extends CurrentUserTracker implements Displa
                Log.e(TAG, "Failed to add gesture blocking activities", e);
            }
        }

        mLongPressTimeout = Math.min(MAX_LONG_PRESS_TIMEOUT,
                ViewConfiguration.getLongPressTimeout());

@@ -357,6 +387,7 @@ public class EdgeBackGestureHandler extends CurrentUserTracker implements Displa
            mContext.getSystemService(DisplayManager.class).unregisterDisplayListener(this);
            mPluginManager.removePluginListener(this);
            ActivityManagerWrapper.getInstance().unregisterTaskStackListener(mTaskStackListener);
            DeviceConfig.removeOnPropertiesChangedListener(mOnPropertiesChangedListener);

            try {
                WindowManagerGlobal.getWindowManagerService()
@@ -372,6 +403,9 @@ public class EdgeBackGestureHandler extends CurrentUserTracker implements Displa
            mContext.getSystemService(DisplayManager.class).registerDisplayListener(this,
                    mContext.getMainThreadHandler());
            ActivityManagerWrapper.getInstance().registerTaskStackListener(mTaskStackListener);
            DeviceConfig.addOnPropertiesChangedListener(DeviceConfig.NAMESPACE_SYSTEMUI,
                    runnable -> (mContext.getMainThreadHandler()).post(runnable),
                    mOnPropertiesChangedListener);

            try {
                WindowManagerGlobal.getWindowManagerService()
@@ -393,6 +427,8 @@ public class EdgeBackGestureHandler extends CurrentUserTracker implements Displa
            mPluginManager.addPluginListener(
                    this, NavigationEdgeBackPlugin.class, /*allowMultiple=*/ false);
        }
        // Update the ML model resources.
        updateMLModelState();
    }

    @Override
@@ -445,12 +481,73 @@ public class EdgeBackGestureHandler extends CurrentUserTracker implements Displa
        }
    }

    private void updateMLModelState() {
        boolean newState = mIsEnabled && DeviceConfig.getBoolean(DeviceConfig.NAMESPACE_SYSTEMUI,
                SystemUiDeviceConfigFlags.USE_BACK_GESTURE_ML_MODEL, false);

        if (newState == mUseMLModel) {
            return;
        }

        if (newState) {
            mBackGestureTfClassifierProvider = SystemUIFactory.getInstance()
                    .createBackGestureTfClassifierProvider(mContext.getAssets());
            mMLModelThreshold = DeviceConfig.getFloat(DeviceConfig.NAMESPACE_SYSTEMUI,
                    SystemUiDeviceConfigFlags.BACK_GESTURE_ML_MODEL_THRESHOLD, 0.9f);
            if (mBackGestureTfClassifierProvider.isActive()) {
                mVocab = mBackGestureTfClassifierProvider.loadVocab(mContext.getAssets());
                mUseMLModel = true;
                return;
            }
        }

        mUseMLModel = false;
        if (mBackGestureTfClassifierProvider != null) {
            mBackGestureTfClassifierProvider.release();
            mBackGestureTfClassifierProvider = null;
        }
    }

    private float getBackGesturePredictionsCategory(int x, int y) {
        if (!mVocab.containsKey(mPackageName)) {
            return -1;
        }

        int distanceFromEdge;
        int location;
        if (x <= mDisplaySize.x / 2.0) {
            location = 1;  // left
            distanceFromEdge = x;
        } else {
            location = 2;  // right
            distanceFromEdge = mDisplaySize.x - x;
        }

        Object[] featuresVector = {
            new long[]{(long) mDisplaySize.x},
            new long[]{(long) distanceFromEdge},
            new long[]{(long) location},
            new long[]{(long) mVocab.get(mPackageName)},
            new long[]{(long) y},
        };

        final float results = mBackGestureTfClassifierProvider.predict(featuresVector);
        if (results == -1) return -1;

        return results >= mMLModelThreshold ? 1 : 0;
    }

    private boolean isWithinTouchRegion(int x, int y) {
        boolean withinRange = false;
        float results = -1;

        if (mUseMLModel &&  (results = getBackGesturePredictionsCategory(x, y)) != -1) {
            withinRange = results == 1 ? true : false;
        } else {
            // Disallow if we are in the bottom gesture area
            if (y >= (mDisplaySize.y - mBottomGestureHeight)) {
                return false;
            }

            // If the point is way too far (twice the margin), it is
            // not interesting to us for logging purposes, nor we
            // should process it.  Simply return false and keep
@@ -459,12 +556,12 @@ public class EdgeBackGestureHandler extends CurrentUserTracker implements Displa
                    && x < (mDisplaySize.x - 2 * (mEdgeWidthRight + mRightInset))) {
                return false;
            }

            // Denotes whether we should proceed with the gesture.
            // Even if it is false, we may want to log it assuming
            // it is not invalid due to exclusion.
        boolean withinRange = x <= mEdgeWidthLeft + mLeftInset
            withinRange = x <= mEdgeWidthLeft + mLeftInset
                    || x >= (mDisplaySize.x - mEdgeWidthRight - mRightInset);
        }

        // Always allow if the user is in a transient sticky immersive state
        if (mIsNavBarShownTransiently) {
@@ -666,6 +763,11 @@ public class EdgeBackGestureHandler extends CurrentUserTracker implements Displa
        ActivityManager.RunningTaskInfo runningTask =
                ActivityManagerWrapper.getInstance().getRunningTask();
        ComponentName topActivity = runningTask == null ? null : runningTask.topActivity;
        if (topActivity != null) {
            mPackageName = topActivity.getPackageName();
        } else {
            mPackageName = "_UNKNOWN";
        }
        return topActivity != null && mGestureBlockingActivities.contains(topActivity);
    }