Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit c681ce43 authored by George Hodulik's avatar George Hodulik
Browse files

Add AppPredictionServiceResolverComparator



This will sort the share activities based on the APS sorting.
We add a constructor for ResolverListController which takes an
AbstractResolverComparator, so that ChooserActivity may pass in
the APS comparator if it is enabled and available.

Test: Manually tested on APS sorter that did no sorting.
Test: atest frameworks/base/core/tests/coretests/src/com/android/internal/app
Bug: 129014961
Change-Id: I542254ffb0debad45bcd8d5073cc3f3e1bafc616
Signed-off-by: default avatarGeorge Hodulik <georgehodulik@google.com>
parent a00cc123
Loading
Loading
Loading
Loading
+24 −1
Original line number Diff line number Diff line
/*
 * Copyright 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.android.internal.app;

import android.app.usage.UsageStatsManager;
@@ -20,7 +36,7 @@ abstract class AbstractResolverComparator implements Comparator<ResolvedComponen

    private static final int NUM_OF_TOP_ANNOTATIONS_TO_USE = 3;

    protected AfterCompute mAfterCompute;
    private AfterCompute mAfterCompute;
    protected final PackageManager mPm;
    protected final UsageStatsManager mUsm;
    protected String[] mAnnotations;
@@ -72,6 +88,13 @@ abstract class AbstractResolverComparator implements Comparator<ResolvedComponen
        mAfterCompute = afterCompute;
    }

    protected final void afterCompute() {
        final AfterCompute afterCompute = mAfterCompute;
        if (afterCompute != null) {
            afterCompute.afterCompute();
        }
    }

    @Override
    public final int compare(ResolvedComponentInfo lhsp, ResolvedComponentInfo rhsp) {
        final ResolveInfo lhs = lhsp.getResolveInfoAt(0);
+119 −0
Original line number Diff line number Diff line
/*
 * Copyright 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.android.internal.app;

import static android.app.prediction.AppTargetEvent.ACTION_LAUNCH;

import android.app.prediction.AppPredictor;
import android.app.prediction.AppTarget;
import android.app.prediction.AppTargetEvent;
import android.app.prediction.AppTargetId;
import android.content.ComponentName;
import android.content.Context;
import android.content.Intent;
import android.content.pm.ResolveInfo;
import android.os.UserHandle;
import android.view.textclassifier.Log;

import com.android.internal.app.ResolverActivity.ResolvedComponentInfo;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * Uses an {@link AppPredictor} to sort Resolver targets.
 */
class AppPredictionServiceResolverComparator extends AbstractResolverComparator {

    private static final String TAG = "APSResolverComparator";

    private final AppPredictor mAppPredictor;
    private final Context mContext;
    private final Map<ComponentName, Integer> mTargetRanks = new HashMap<>();
    private final UserHandle mUser;

    AppPredictionServiceResolverComparator(
                Context context, Intent intent, AppPredictor appPredictor, UserHandle user) {
        super(context, intent);
        mContext = context;
        mAppPredictor = appPredictor;
        mUser = user;
    }

    @Override
    int compare(ResolveInfo lhs, ResolveInfo rhs) {
        Integer lhsRank = mTargetRanks.get(new ComponentName(lhs.activityInfo.packageName,
                lhs.activityInfo.name));
        Integer rhsRank = mTargetRanks.get(new ComponentName(rhs.activityInfo.packageName,
                rhs.activityInfo.name));
        if (lhsRank == null && rhsRank == null) {
            return 0;
        } else if (lhsRank == null) {
            return -1;
        } else if (rhsRank == null) {
            return 1;
        }
        return lhsRank - rhsRank;
    }

    @Override
    void compute(List<ResolvedComponentInfo> targets) {
        List<AppTarget> appTargets = new ArrayList<>();
        for (ResolvedComponentInfo target : targets) {
            appTargets.add(new AppTarget.Builder(new AppTargetId(target.name.flattenToString()))
                    .setTarget(target.name.getPackageName(), mUser)
                    .setClassName(target.name.getClassName()).build());
        }
        mAppPredictor.sortTargets(appTargets, mContext.getMainExecutor(),
                sortedAppTargets -> {
                    for (int i = 0; i < sortedAppTargets.size(); i++) {
                        mTargetRanks.put(new ComponentName(sortedAppTargets.get(i).getPackageName(),
                                sortedAppTargets.get(i).getClassName()), i);
                    }
                    afterCompute();
                });
    }

    @Override
    float getScore(ComponentName name) {
        Integer rank = mTargetRanks.get(name);
        if (rank == null) {
            Log.w(TAG, "Score requested for unknown component.");
            return 0f;
        }
        int consecutiveSumOfRanks = (mTargetRanks.size() - 1) * (mTargetRanks.size()) / 2;
        return 1.0f - (((float) rank) / consecutiveSumOfRanks);
    }

    @Override
    void updateModel(ComponentName componentName) {
        mAppPredictor.notifyAppTargetEvent(
                new AppTargetEvent.Builder(
                    new AppTarget.Builder(
                        new AppTargetId(componentName.toString()),
                        componentName.getPackageName(), mUser)
                        .setClassName(componentName.getClassName()).build(),
                    ACTION_LAUNCH).build());
    }

    @Override
    void destroy() {
        // Do nothing. App Predictor destruction is handled by caller.
    }
}
+27 −4
Original line number Diff line number Diff line
@@ -150,6 +150,7 @@ public class ChooserActivity extends ResolverActivity {
     */
    // TODO(b/123089490): Replace with system flag
    private static final boolean USE_PREDICTION_MANAGER_FOR_DIRECT_TARGETS = false;
    private static final boolean USE_PREDICTION_MANAGER_FOR_SHARE_ACTIVITIES = false;
    // TODO(b/123088566) Share these in a better way.
    private static final String APP_PREDICTION_SHARE_UI_SURFACE = "share";
    public static final String LAUNCH_LOCATON_DIRECT_SHARE = "direct_share";
@@ -1387,6 +1388,15 @@ public class ChooserActivity extends ResolverActivity {
        return USE_PREDICTION_MANAGER_FOR_DIRECT_TARGETS ? getAppPredictor() : null;
    }

    /**
     * This will return an app predictor if it is enabled for share activity sorting
     * and if one exists. Otherwise, it returns null.
     */
    @Nullable
    private AppPredictor getAppPredictorForShareActivitesIfEnabled() {
        return USE_PREDICTION_MANAGER_FOR_SHARE_ACTIVITIES ? getAppPredictor() : null;
    }

    void onRefinementResult(TargetInfo selectedTarget, Intent matchingIntent) {
        if (mRefinementResultReceiver != null) {
            mRefinementResultReceiver.destroy();
@@ -1491,8 +1501,10 @@ public class ChooserActivity extends ResolverActivity {
                PackageManager pm,
                Intent targetIntent,
                String referrerPackageName,
                int launchedFromUid) {
            super(context, pm, targetIntent, referrerPackageName, launchedFromUid);
                int launchedFromUid,
                AbstractResolverComparator resolverComparator) {
            super(context, pm, targetIntent, referrerPackageName, launchedFromUid,
                    resolverComparator);
        }

        @Override
@@ -1520,13 +1532,24 @@ public class ChooserActivity extends ResolverActivity {

    @VisibleForTesting
    protected ResolverListController createListController() {
        AppPredictor appPredictor = getAppPredictorForShareActivitesIfEnabled();
        AbstractResolverComparator resolverComparator;
        if (appPredictor != null) {
            resolverComparator = new AppPredictionServiceResolverComparator(this, getTargetIntent(),
                    appPredictor, getUser());
        } else {
            resolverComparator =
                    new ResolverRankerServiceResolverComparator(this, getTargetIntent(),
                        getReferrerPackageName(), null);
        }

        return new ChooserListController(
                this,
                mPm,
                getTargetIntent(),
                getReferrerPackageName(),
                mLaunchedFromUid
                );
                mLaunchedFromUid,
                resolverComparator);
    }

    @VisibleForTesting
+13 −3
Original line number Diff line number Diff line
@@ -63,14 +63,24 @@ public class ResolverListController {
            Intent targetIntent,
            String referrerPackage,
            int launchedFromUid) {
        this(context, pm, targetIntent, referrerPackage, launchedFromUid,
                    new ResolverRankerServiceResolverComparator(
                        context, targetIntent, referrerPackage, null));
    }

    public ResolverListController(
            Context context,
            PackageManager pm,
            Intent targetIntent,
            String referrerPackage,
            int launchedFromUid,
            AbstractResolverComparator resolverComparator) {
        mContext = context;
        mpm = pm;
        mLaunchedFromUid = launchedFromUid;
        mTargetIntent = targetIntent;
        mReferrerPackage = referrerPackage;
        mResolverComparator =
                new ResolverRankerServiceResolverComparator(
                    mContext, mTargetIntent, mReferrerPackage, null);
        mResolverComparator = resolverComparator;
    }

    @VisibleForTesting
+5 −9
Original line number Diff line number Diff line
@@ -126,7 +126,7 @@ class ResolverRankerServiceResolverComparator extends AbstractResolverComparator
                            Log.e(TAG, "Receiving null prediction results.");
                        }
                        mHandler.removeMessages(RESOLVER_RANKER_RESULT_TIMEOUT);
                        mAfterCompute.afterCompute();
                        afterCompute();
                    }
                    break;

@@ -135,7 +135,7 @@ class ResolverRankerServiceResolverComparator extends AbstractResolverComparator
                        Log.d(TAG, "RESOLVER_RANKER_RESULT_TIMEOUT; unbinding services");
                    }
                    mHandler.removeMessages(RESOLVER_RANKER_SERVICE_RESULT);
                    mAfterCompute.afterCompute();
                    afterCompute();
                    break;

                default:
@@ -149,7 +149,6 @@ class ResolverRankerServiceResolverComparator extends AbstractResolverComparator
        super(context, intent);
        mCollator = Collator.getInstance(context.getResources().getConfiguration().locale);
        mReferrerPackage = referrerPackage;
        mAfterCompute = afterCompute;
        mContext = context;

        mCurrentTime = System.currentTimeMillis();
@@ -157,6 +156,7 @@ class ResolverRankerServiceResolverComparator extends AbstractResolverComparator
        mStats = mUsm.queryAndAggregateUsageStats(mSinceTime, mCurrentTime);
        mAction = intent.getAction();
        mRankerServiceName = new ComponentName(mContext, this.getClass());
        setCallBack(afterCompute);
    }

    // compute features for each target according to usage stats of targets.
@@ -328,9 +328,7 @@ class ResolverRankerServiceResolverComparator extends AbstractResolverComparator
            mContext.unbindService(mConnection);
            mConnection.destroy();
        }
        if (mAfterCompute != null) {
            mAfterCompute.afterCompute();
        }
        afterCompute();
        if (DEBUG) {
            Log.d(TAG, "Unbinded Resolver Ranker.");
        }
@@ -513,9 +511,7 @@ class ResolverRankerServiceResolverComparator extends AbstractResolverComparator
                Log.e(TAG, "Error in Predict: " + e);
            }
        }
        if (mAfterCompute != null) {
            mAfterCompute.afterCompute();
        }
        afterCompute();
    }

    // adds select prob as the default values, according to a pre-trained Logistic Regression model.