Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 97a8d703 authored by Kang Li's avatar Kang Li Committed by Android (Google) Code Review
Browse files

Merge "Rank apps by Logistic Regression for Smart-Sharing."

parents 525aa0d9 a2c7774d
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -565,6 +565,7 @@ public class ChooserActivity extends ResolverActivity {
            if (ri != null && ri.activityInfo != null) {
                usageStatsManager.reportChooserSelection(ri.activityInfo.packageName, getUserId(),
                        annotation, null, info.getResolvedIntent().getAction());
                mResolverComparator.updateModel(info.getResolvedComponentName());
                if (DEBUG) {
                    Log.d(TAG, "ResolveInfo Package is" + ri.activityInfo.packageName);
                }
+1 −0
Original line number Diff line number Diff line
@@ -107,6 +107,7 @@ public class ResolverActivity extends Activity {
    private PickTargetOptionRequest mPickOptionRequest;
    private String mReferrerPackage;

    protected ResolverComparator mResolverComparator;
    protected ResolverDrawerLayout mResolverDrawerLayout;
    protected String mContentType;
    protected PackageManager mPm;
+154 −22
Original line number Diff line number Diff line
@@ -27,11 +27,16 @@ import android.content.pm.ApplicationInfo;
import android.content.pm.ComponentInfo;
import android.content.pm.PackageManager;
import android.content.pm.ResolveInfo;
import android.content.SharedPreferences;
import android.os.Environment;
import android.os.storage.StorageManager;
import android.os.UserHandle;
import android.text.TextUtils;
import android.util.ArrayMap;
import android.util.Log;
import com.android.internal.app.ResolverActivity.ResolvedComponentInfo;

import java.io.File;
import java.text.Collator;
import java.util.ArrayList;
import java.util.Comparator;
@@ -54,6 +59,12 @@ class ResolverComparator implements Comparator<ResolvedComponentInfo> {

    private static final float RECENCY_MULTIPLIER = 2.f;

    // feature names used in ranking.
    private static final String LAUNCH_SCORE = "launch";
    private static final String TIME_SPENT_SCORE = "timeSpent";
    private static final String RECENCY_SCORE = "recency";
    private static final String CHOOSER_SCORE = "chooser";

    private final Collator mCollator;
    private final boolean mHttp;
    private final PackageManager mPm;
@@ -65,6 +76,7 @@ class ResolverComparator implements Comparator<ResolvedComponentInfo> {
    private final String mReferrerPackage;
    public String mContentType;
    private String mAction;
    private LogisticRegressionAppRanker mRanker;

    public ResolverComparator(Context context, Intent intent, String referrerPackage) {
        mCollator = Collator.getInstance(context.getResources().getConfiguration().locale);
@@ -80,6 +92,7 @@ class ResolverComparator implements Comparator<ResolvedComponentInfo> {
        mStats = mUsm.queryAndAggregateUsageStats(mSinceTime, mCurrentTime);
        mContentType = intent.getType();
        mAction = intent.getAction();
        mRanker = new LogisticRegressionAppRanker(context);
    }

    public void compute(List<ResolvedComponentInfo> targets) {
@@ -152,16 +165,13 @@ class ResolverComparator implements Comparator<ResolvedComponentInfo> {
        for (ScoredTarget target : mScoredTargets.values()) {
            final float recency = (float) Math.max(target.lastTimeUsed - recentSinceTime, 0)
                    / (mostRecentlyUsedTime - recentSinceTime);
            final float recencyScore = recency * recency * RECENCY_MULTIPLIER;
            final float usageTimeScore = (float) target.timeSpent / mostTimeSpent;
            final float launchCountScore = (float) target.launchCount / mostLaunched;

            target.score = recencyScore + usageTimeScore + launchCountScore;
            target.setFeatures((float) target.launchCount / mostLaunched,
                    (float) target.timeSpent / mostTimeSpent,
                    recency * recency * RECENCY_MULTIPLIER,
                    (float) target.chooserCount / mostSelected);
            target.selectProb = mRanker.predict(target.getFeatures());
            if (DEBUG) {
                Log.d(TAG, "Scores: recencyScore: " + recencyScore
                        + " usageTimeScore: " + usageTimeScore
                        + " launchCountScore: " + launchCountScore
                        + " - " + target);
                Log.d(TAG, "Scores: " + target);
            }
        }
    }
@@ -215,17 +225,11 @@ class ResolverComparator implements Comparator<ResolvedComponentInfo> {
                final ScoredTarget rhsTarget = mScoredTargets.get(new ComponentName(
                        rhs.activityInfo.packageName, rhs.activityInfo.name));

                final int chooserCountDiff = Long.compare(
                        rhsTarget.chooserCount, lhsTarget.chooserCount);

                if (chooserCountDiff != 0) {
                    return chooserCountDiff > 0 ? 1 : -1;
                }

                final int diff = Float.compare(rhsTarget.score, lhsTarget.score);
                final int selectProbDiff = Float.compare(
                        rhsTarget.selectProb, lhsTarget.selectProb);

                if (diff != 0) {
                    return diff > 0 ? 1 : -1;
                if (selectProbDiff != 0) {
                    return selectProbDiff > 0 ? 1 : -1;
                }
            }
        }
@@ -241,32 +245,160 @@ class ResolverComparator implements Comparator<ResolvedComponentInfo> {
    public float getScore(ComponentName name) {
        final ScoredTarget target = mScoredTargets.get(name);
        if (target != null) {
            return target.score;
            return target.selectProb;
        }
        return 0;
    }

    static class ScoredTarget {
        public final ComponentInfo componentInfo;
        public float score;
        public long lastTimeUsed;
        public long timeSpent;
        public long launchCount;
        public long chooserCount;
        public ArrayMap<String, Float> features;
        public float selectProb;

        public ScoredTarget(ComponentInfo ci) {
            componentInfo = ci;
            features = new ArrayMap<>(5);
        }

        @Override
        public String toString() {
            return "ScoredTarget{" + componentInfo
                    + " score: " + score
                    + " lastTimeUsed: " + lastTimeUsed
                    + " timeSpent: " + timeSpent
                    + " launchCount: " + launchCount
                    + " chooserCount: " + chooserCount
                    + " selectProb: " + selectProb
                    + "}";
        }

        public void setFeatures(float launchCountScore, float usageTimeScore, float recencyScore,
                                float chooserCountScore) {
            features.put(LAUNCH_SCORE, launchCountScore);
            features.put(TIME_SPENT_SCORE, usageTimeScore);
            features.put(RECENCY_SCORE, recencyScore);
            features.put(CHOOSER_SCORE, chooserCountScore);
        }

        public ArrayMap<String, Float> getFeatures() {
            return features;
        }
    }

    public void updateModel(ComponentName componentName) {
        if (mScoredTargets == null || componentName == null ||
                !mScoredTargets.containsKey(componentName)) {
            return;
        }
        ScoredTarget selected = mScoredTargets.get(componentName);
        for (ComponentName targetComponent : mScoredTargets.keySet()) {
            if (targetComponent.equals(componentName)) {
                continue;
            }
            ScoredTarget target = mScoredTargets.get(targetComponent);
            // A potential point of optimization. Save updates or derive a closed form for the
            // positive case, to avoid calculating them repeatedly.
            if (target.selectProb >= selected.selectProb) {
                mRanker.update(target.getFeatures(), target.selectProb, false);
                mRanker.update(selected.getFeatures(), selected.selectProb, true);
            }
        }
        mRanker.commitUpdate();
    }

    class LogisticRegressionAppRanker {
        private static final String PARAM_SHARED_PREF_NAME = "resolver_ranker_params";
        private static final String BIAS_PREF_KEY = "bias";
        private static final float LEARNING_RATE = 0.02f;
        private static final float REGULARIZER_PARAM = 0.1f;
        private SharedPreferences mParamSharedPref;
        private ArrayMap<String, Float> mFeatureWeights;
        private float mBias;

        public LogisticRegressionAppRanker(Context context) {
            mParamSharedPref = getParamSharedPref(context);
        }

        public float predict(ArrayMap<String, Float> target) {
            if (target == null || mParamSharedPref == null) {
                return 0.0f;
            }
            final int featureSize = target.size();
            if (featureSize == 0) {
                return 0.0f;
            }
            float sum = 0.0f;
            if (mFeatureWeights == null) {
                mBias = mParamSharedPref.getFloat(BIAS_PREF_KEY, 0.0f);
                mFeatureWeights = new ArrayMap<>(featureSize);
                for (int i = 0; i < featureSize; i++) {
                    String featureName = target.keyAt(i);
                    float weight = mParamSharedPref.getFloat(featureName, 0.0f);
                    sum += weight * target.valueAt(i);
                    mFeatureWeights.put(featureName, weight);
                }
            } else {
                for (int i = 0; i < featureSize; i++) {
                    String featureName = target.keyAt(i);
                    float weight = mFeatureWeights.getOrDefault(featureName, 0.0f);
                    sum += weight * target.valueAt(i);
                }
            }
            return (float) (1.0 / (1.0 + Math.exp(-mBias - sum)));
        }

        public void update(ArrayMap<String, Float> target, float predict, boolean isSelected) {
            if (target == null || target.size() == 0) {
                return;
            }
            final int featureSize = target.size();
            if (mFeatureWeights == null) {
                mBias = 0.0f;
                mFeatureWeights = new ArrayMap<>(featureSize);
            }
            float error = isSelected ? 1.0f - predict : -predict;
            for (int i = 0; i < featureSize; i++) {
                String featureName = target.keyAt(i);
                float currentWeight = mFeatureWeights.getOrDefault(featureName, 0.0f);
                mBias += LEARNING_RATE * error;
                currentWeight = currentWeight - LEARNING_RATE * REGULARIZER_PARAM * currentWeight +
                        LEARNING_RATE * error * target.valueAt(i);
                mFeatureWeights.put(featureName, currentWeight);
            }
            if (DEBUG) {
                Log.d(TAG, "Weights: " + mFeatureWeights + " Bias: " + mBias);
            }
        }

        public void commitUpdate() {
            if (mFeatureWeights == null || mFeatureWeights.size() == 0) {
                return;
            }
            SharedPreferences.Editor editor = mParamSharedPref.edit();
            editor.putFloat(BIAS_PREF_KEY, mBias);
            final int size = mFeatureWeights.size();
            for (int i = 0; i < size; i++) {
                editor.putFloat(mFeatureWeights.keyAt(i), mFeatureWeights.valueAt(i));
            }
            editor.apply();
        }

        private SharedPreferences getParamSharedPref(Context context) {
            // The package info in the context isn't initialized in the way it is for normal apps,
            // so the standard, name-based context.getSharedPreferences doesn't work. Instead, we
            // build the path manually below using the same policy that appears in ContextImpl.
            if (DEBUG) {
                Log.d(TAG, "Context Package Name: " + context.getPackageName());
            }
            final File prefsFile = new File(new File(
                    Environment.getDataUserCePackageDirectory(StorageManager.UUID_PRIVATE_INTERNAL,
                            context.getUserId(), context.getPackageName()),
                    "shared_prefs"),
                    PARAM_SHARED_PREF_NAME + ".xml");
            return context.getSharedPreferences(prefsFile, Context.MODE_PRIVATE);
        }
    }
}