Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 7a5aa1f0 authored by Maryam Karimzadehgan's avatar Maryam Karimzadehgan Committed by Android (Google) Code Review
Browse files

Merge "Use ML model for the Back Gesture in EdgeBackGestureHandler."

parents 143e6754 7d395e2e
Loading
Loading
Loading
Loading
+11 −0
Original line number Diff line number Diff line
@@ -419,6 +419,17 @@ public final class SystemUiDeviceConfigFlags {
     */
    public static final String SCREENSHOT_KEYCHORD_DELAY = "screenshot_keychord_delay";

    /**
     * (boolean) Whether to use an ML model for the Back Gesture.
     */
    public static final String USE_BACK_GESTURE_ML_MODEL = "use_back_gesture_ml_model";

    /**
     * (float) Threshold for Back Gesture ML model prediction.
     */
    public static final String BACK_GESTURE_ML_MODEL_THRESHOLD = "back_gesture_ml_model_threshold";


    private SystemUiDeviceConfigFlags() {
    }
}
+20 −6
Original line number Diff line number Diff line
@@ -17,6 +17,7 @@
package com.android.systemui;

import android.content.Context;
import android.content.res.AssetManager;
import android.content.res.Resources;
import android.os.Handler;
import android.util.Log;
@@ -26,6 +27,7 @@ import com.android.systemui.dagger.DaggerGlobalRootComponent;
import com.android.systemui.dagger.GlobalRootComponent;
import com.android.systemui.dagger.SysUIComponent;
import com.android.systemui.dagger.WMComponent;
import com.android.systemui.navigationbar.gestural.BackGestureTfClassifierProvider;
import com.android.systemui.screenshot.ScreenshotNotificationSmartActionsProvider;

import java.util.concurrent.ExecutionException;
@@ -110,12 +112,16 @@ public class SystemUIFactory {
        return mSysUIComponent;
    }

    /** Returns the list of system UI components that should be started. */
    /**
     * Returns the list of system UI components that should be started.
     */
    public String[] getSystemUIServiceComponents(Resources resources) {
        return resources.getStringArray(R.array.config_systemUIServiceComponents);
    }

    /** Returns the list of system UI components that should be started per user. */
    /**
     * Returns the list of system UI components that should be started per user.
     */
    public String[] getSystemUIServiceComponentsPerUser(Resources resources) {
        return resources.getStringArray(R.array.config_systemUIServiceComponentsPerUser);
    }
@@ -125,9 +131,17 @@ public class SystemUIFactory {
     * This method is overridden in vendor specific implementation of Sys UI.
     */
    public ScreenshotNotificationSmartActionsProvider
            createScreenshotNotificationSmartActionsProvider(Context context,
            Executor executor,
            Handler uiHandler) {
                createScreenshotNotificationSmartActionsProvider(
                        Context context, Executor executor, Handler uiHandler) {
        return new ScreenshotNotificationSmartActionsProvider();
    }

    /**
     * Creates an instance of BackGestureTfClassifierProvider.
     * This method is overridden in vendor specific implementation of Sys UI.
     */
    public BackGestureTfClassifierProvider createBackGestureTfClassifierProvider(
            AssetManager am) {
        return new BackGestureTfClassifierProvider();
    }
}
 No newline at end of file
+66 −0
Original line number Diff line number Diff line
/*
 * Copyright (C) 2020 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.android.systemui.navigationbar.gestural;

import android.content.res.AssetManager;

import java.util.HashMap;
import java.util.Map;

/**
 * This class can be overridden by a vendor-specific sys UI implementation,
 * in order to provide classification models for the Back Gesture.
 */
public class BackGestureTfClassifierProvider {
    private static final String TAG = "BackGestureTfClassifierProvider";

    /**
     * Default implementation that returns an empty map.
     * This method is overridden in vendor-specific Sys UI implementation.
     *
     * @param am       An AssetManager to get the vocab file.
    */
    public Map<String, Integer> loadVocab(AssetManager am) {
        return new HashMap<String, Integer>();
    }

    /**
     * This method is overridden in vendor-specific Sys UI implementation.
     *
     * @param featuresVector   List of input features.
     *
    */
    public float predict(Object[] featuresVector) {
        return -1;
    }

    /**
     * Interpreter owns resources. This method releases the resources after
     * use to avoid memory leak.
     * This method is overridden in vendor-specific Sys UI implementation.
     *
     */
    public void release() {}

    /**
     * Returns whether to use the ML model for Back Gesture.
     * This method is overridden in vendor-specific Sys UI implementation.
     *
     */
    public boolean isActive() {
        return false;
    }
}
+120 −18
Original line number Diff line number Diff line
@@ -40,7 +40,6 @@ import android.util.Log;
import android.util.TypedValue;
import android.view.Choreographer;
import android.view.ISystemGestureExclusionListener;
import android.view.InputChannel;
import android.view.InputDevice;
import android.view.InputEvent;
import android.view.InputMonitor;
@@ -56,6 +55,7 @@ import com.android.internal.config.sysui.SystemUiDeviceConfigFlags;
import com.android.internal.policy.GestureNavigationSettingsObserver;
import com.android.systemui.Dependency;
import com.android.systemui.R;
import com.android.systemui.SystemUIFactory;
import com.android.systemui.broadcast.BroadcastDispatcher;
import com.android.systemui.bubbles.BubbleController;
import com.android.systemui.model.SysUiState;
@@ -79,6 +79,7 @@ import com.android.systemui.tracing.nano.SystemUiTraceProto;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Executor;

/**
@@ -120,8 +121,31 @@ public class EdgeBackGestureHandler extends CurrentUserTracker implements Displa
        public void onTaskStackChanged() {
            mGestureBlockingActivityRunning = isGestureBlockingActivityRunning();
        }
        @Override
        public void onTaskCreated(int taskId, ComponentName componentName) {
            if (componentName != null) {
                mPackageName = componentName.getPackageName();
            } else {
                mPackageName = "_UNKNOWN";
            }
        }
    };

    private DeviceConfig.OnPropertiesChangedListener mOnPropertiesChangedListener =
            new DeviceConfig.OnPropertiesChangedListener() {
                @Override
                public void onPropertiesChanged(DeviceConfig.Properties properties) {
                    if (DeviceConfig.NAMESPACE_SYSTEMUI.equals(properties.getNamespace())
                            && (properties.getKeyset().contains(
                                    SystemUiDeviceConfigFlags.BACK_GESTURE_ML_MODEL_THRESHOLD)
                            || properties.getKeyset().contains(
                                    SystemUiDeviceConfigFlags.USE_BACK_GESTURE_ML_MODEL))) {
                        updateMLModelState();
                    }
                }
            };


    private final Context mContext;
    private final OverviewProxyService mOverviewProxyService;
    private final SysUiState mSysUiState;
@@ -177,6 +201,13 @@ public class EdgeBackGestureHandler extends CurrentUserTracker implements Displa
    private int mRightInset;
    private int mSysUiFlags;

    // For Tf-Lite model.
    private BackGestureTfClassifierProvider mBackGestureTfClassifierProvider;
    private Map<String, Integer> mVocab;
    private boolean mUseMLModel;
    private float mMLModelThreshold;
    private String mPackageName;

    private final GestureNavigationSettingsObserver mGestureNavigationSettingsObserver;

    private final NavigationEdgeBackPlugin.BackCallback mBackCallback =
@@ -242,7 +273,6 @@ public class EdgeBackGestureHandler extends CurrentUserTracker implements Displa
                Log.e(TAG, "Failed to add gesture blocking activities", e);
            }
        }

        mLongPressTimeout = Math.min(MAX_LONG_PRESS_TIMEOUT,
                ViewConfiguration.getLongPressTimeout());

@@ -357,6 +387,7 @@ public class EdgeBackGestureHandler extends CurrentUserTracker implements Displa
            mContext.getSystemService(DisplayManager.class).unregisterDisplayListener(this);
            mPluginManager.removePluginListener(this);
            ActivityManagerWrapper.getInstance().unregisterTaskStackListener(mTaskStackListener);
            DeviceConfig.removeOnPropertiesChangedListener(mOnPropertiesChangedListener);

            try {
                WindowManagerGlobal.getWindowManagerService()
@@ -372,6 +403,9 @@ public class EdgeBackGestureHandler extends CurrentUserTracker implements Displa
            mContext.getSystemService(DisplayManager.class).registerDisplayListener(this,
                    mContext.getMainThreadHandler());
            ActivityManagerWrapper.getInstance().registerTaskStackListener(mTaskStackListener);
            DeviceConfig.addOnPropertiesChangedListener(DeviceConfig.NAMESPACE_SYSTEMUI,
                    runnable -> (mContext.getMainThreadHandler()).post(runnable),
                    mOnPropertiesChangedListener);

            try {
                WindowManagerGlobal.getWindowManagerService()
@@ -393,6 +427,8 @@ public class EdgeBackGestureHandler extends CurrentUserTracker implements Displa
            mPluginManager.addPluginListener(
                    this, NavigationEdgeBackPlugin.class, /*allowMultiple=*/ false);
        }
        // Update the ML model resources.
        updateMLModelState();
    }

    @Override
@@ -445,12 +481,73 @@ public class EdgeBackGestureHandler extends CurrentUserTracker implements Displa
        }
    }

    private void updateMLModelState() {
        boolean newState = mIsEnabled && DeviceConfig.getBoolean(DeviceConfig.NAMESPACE_SYSTEMUI,
                SystemUiDeviceConfigFlags.USE_BACK_GESTURE_ML_MODEL, false);

        if (newState == mUseMLModel) {
            return;
        }

        if (newState) {
            mBackGestureTfClassifierProvider = SystemUIFactory.getInstance()
                    .createBackGestureTfClassifierProvider(mContext.getAssets());
            mMLModelThreshold = DeviceConfig.getFloat(DeviceConfig.NAMESPACE_SYSTEMUI,
                    SystemUiDeviceConfigFlags.BACK_GESTURE_ML_MODEL_THRESHOLD, 0.9f);
            if (mBackGestureTfClassifierProvider.isActive()) {
                mVocab = mBackGestureTfClassifierProvider.loadVocab(mContext.getAssets());
                mUseMLModel = true;
                return;
            }
        }

        mUseMLModel = false;
        if (mBackGestureTfClassifierProvider != null) {
            mBackGestureTfClassifierProvider.release();
            mBackGestureTfClassifierProvider = null;
        }
    }

    private float getBackGesturePredictionsCategory(int x, int y) {
        if (!mVocab.containsKey(mPackageName)) {
            return -1;
        }

        int distanceFromEdge;
        int location;
        if (x <= mDisplaySize.x / 2.0) {
            location = 1;  // left
            distanceFromEdge = x;
        } else {
            location = 2;  // right
            distanceFromEdge = mDisplaySize.x - x;
        }

        Object[] featuresVector = {
            new long[]{(long) mDisplaySize.x},
            new long[]{(long) distanceFromEdge},
            new long[]{(long) location},
            new long[]{(long) mVocab.get(mPackageName)},
            new long[]{(long) y},
        };

        final float results = mBackGestureTfClassifierProvider.predict(featuresVector);
        if (results == -1) return -1;

        return results >= mMLModelThreshold ? 1 : 0;
    }

    private boolean isWithinTouchRegion(int x, int y) {
        boolean withinRange = false;
        float results = -1;

        if (mUseMLModel &&  (results = getBackGesturePredictionsCategory(x, y)) != -1) {
            withinRange = results == 1 ? true : false;
        } else {
            // Disallow if we are in the bottom gesture area
            if (y >= (mDisplaySize.y - mBottomGestureHeight)) {
                return false;
            }

            // If the point is way too far (twice the margin), it is
            // not interesting to us for logging purposes, nor we
            // should process it.  Simply return false and keep
@@ -459,12 +556,12 @@ public class EdgeBackGestureHandler extends CurrentUserTracker implements Displa
                    && x < (mDisplaySize.x - 2 * (mEdgeWidthRight + mRightInset))) {
                return false;
            }

            // Denotes whether we should proceed with the gesture.
            // Even if it is false, we may want to log it assuming
            // it is not invalid due to exclusion.
        boolean withinRange = x <= mEdgeWidthLeft + mLeftInset
            withinRange = x <= mEdgeWidthLeft + mLeftInset
                    || x >= (mDisplaySize.x - mEdgeWidthRight - mRightInset);
        }

        // Always allow if the user is in a transient sticky immersive state
        if (mIsNavBarShownTransiently) {
@@ -666,6 +763,11 @@ public class EdgeBackGestureHandler extends CurrentUserTracker implements Displa
        ActivityManager.RunningTaskInfo runningTask =
                ActivityManagerWrapper.getInstance().getRunningTask();
        ComponentName topActivity = runningTask == null ? null : runningTask.topActivity;
        if (topActivity != null) {
            mPackageName = topActivity.getPackageName();
        } else {
            mPackageName = "_UNKNOWN";
        }
        return topActivity != null && mGestureBlockingActivities.contains(topActivity);
    }