Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit a2c7774d authored by Kang Li's avatar Kang Li
Browse files

Rank apps by Logistic Regression for Smart-Sharing.

Bug: 30982298
Test: manual - tested by sharing images in Photos and sharing texts in
Chrome.

Change-Id: I9808abdefbc898d3452e684f3462efafdfd53c23
parent 1fd9c8d3
Loading
Loading
Loading
Loading
+1 −0
Original line number Original line Diff line number Diff line
@@ -563,6 +563,7 @@ public class ChooserActivity extends ResolverActivity {
            if (ri != null && ri.activityInfo != null) {
            if (ri != null && ri.activityInfo != null) {
                usageStatsManager.reportChooserSelection(ri.activityInfo.packageName, getUserId(),
                usageStatsManager.reportChooserSelection(ri.activityInfo.packageName, getUserId(),
                        annotation, null, info.getResolvedIntent().getAction());
                        annotation, null, info.getResolvedIntent().getAction());
                mResolverComparator.updateModel(info.getResolvedComponentName());
                if (DEBUG) {
                if (DEBUG) {
                    Log.d(TAG, "ResolveInfo Package is" + ri.activityInfo.packageName);
                    Log.d(TAG, "ResolveInfo Package is" + ri.activityInfo.packageName);
                }
                }
+1 −1
Original line number Original line Diff line number Diff line
@@ -109,10 +109,10 @@ public class ResolverActivity extends Activity {
    private boolean mResolvingHome = false;
    private boolean mResolvingHome = false;
    private int mProfileSwitchMessageId = -1;
    private int mProfileSwitchMessageId = -1;
    private final ArrayList<Intent> mIntents = new ArrayList<>();
    private final ArrayList<Intent> mIntents = new ArrayList<>();
    private ResolverComparator mResolverComparator;
    private PickTargetOptionRequest mPickOptionRequest;
    private PickTargetOptionRequest mPickOptionRequest;
    private ComponentName[] mFilteredComponents;
    private ComponentName[] mFilteredComponents;


    protected ResolverComparator mResolverComparator;
    protected ResolverDrawerLayout mResolverDrawerLayout;
    protected ResolverDrawerLayout mResolverDrawerLayout;
    protected String mContentType;
    protected String mContentType;


+154 −22
Original line number Original line Diff line number Diff line
@@ -27,11 +27,16 @@ import android.content.pm.ApplicationInfo;
import android.content.pm.ComponentInfo;
import android.content.pm.ComponentInfo;
import android.content.pm.PackageManager;
import android.content.pm.PackageManager;
import android.content.pm.ResolveInfo;
import android.content.pm.ResolveInfo;
import android.content.SharedPreferences;
import android.os.Environment;
import android.os.storage.StorageManager;
import android.os.UserHandle;
import android.os.UserHandle;
import android.text.TextUtils;
import android.text.TextUtils;
import android.util.ArrayMap;
import android.util.Log;
import android.util.Log;
import com.android.internal.app.ResolverActivity.ResolvedComponentInfo;
import com.android.internal.app.ResolverActivity.ResolvedComponentInfo;


import java.io.File;
import java.text.Collator;
import java.text.Collator;
import java.util.ArrayList;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.Comparator;
@@ -54,6 +59,12 @@ class ResolverComparator implements Comparator<ResolvedComponentInfo> {


    private static final float RECENCY_MULTIPLIER = 2.f;
    private static final float RECENCY_MULTIPLIER = 2.f;


    // feature names used in ranking.
    private static final String LAUNCH_SCORE = "launch";
    private static final String TIME_SPENT_SCORE = "timeSpent";
    private static final String RECENCY_SCORE = "recency";
    private static final String CHOOSER_SCORE = "chooser";

    private final Collator mCollator;
    private final Collator mCollator;
    private final boolean mHttp;
    private final boolean mHttp;
    private final PackageManager mPm;
    private final PackageManager mPm;
@@ -65,6 +76,7 @@ class ResolverComparator implements Comparator<ResolvedComponentInfo> {
    private final String mReferrerPackage;
    private final String mReferrerPackage;
    public String mContentType;
    public String mContentType;
    private String mAction;
    private String mAction;
    private LogisticRegressionAppRanker mRanker;


    public ResolverComparator(Context context, Intent intent, String referrerPackage) {
    public ResolverComparator(Context context, Intent intent, String referrerPackage) {
        mCollator = Collator.getInstance(context.getResources().getConfiguration().locale);
        mCollator = Collator.getInstance(context.getResources().getConfiguration().locale);
@@ -80,6 +92,7 @@ class ResolverComparator implements Comparator<ResolvedComponentInfo> {
        mStats = mUsm.queryAndAggregateUsageStats(mSinceTime, mCurrentTime);
        mStats = mUsm.queryAndAggregateUsageStats(mSinceTime, mCurrentTime);
        mContentType = intent.getType();
        mContentType = intent.getType();
        mAction = intent.getAction();
        mAction = intent.getAction();
        mRanker = new LogisticRegressionAppRanker(context);
    }
    }


    public void compute(List<ResolvedComponentInfo> targets) {
    public void compute(List<ResolvedComponentInfo> targets) {
@@ -152,16 +165,13 @@ class ResolverComparator implements Comparator<ResolvedComponentInfo> {
        for (ScoredTarget target : mScoredTargets.values()) {
        for (ScoredTarget target : mScoredTargets.values()) {
            final float recency = (float) Math.max(target.lastTimeUsed - recentSinceTime, 0)
            final float recency = (float) Math.max(target.lastTimeUsed - recentSinceTime, 0)
                    / (mostRecentlyUsedTime - recentSinceTime);
                    / (mostRecentlyUsedTime - recentSinceTime);
            final float recencyScore = recency * recency * RECENCY_MULTIPLIER;
            target.setFeatures((float) target.launchCount / mostLaunched,
            final float usageTimeScore = (float) target.timeSpent / mostTimeSpent;
                    (float) target.timeSpent / mostTimeSpent,
            final float launchCountScore = (float) target.launchCount / mostLaunched;
                    recency * recency * RECENCY_MULTIPLIER,

                    (float) target.chooserCount / mostSelected);
            target.score = recencyScore + usageTimeScore + launchCountScore;
            target.selectProb = mRanker.predict(target.getFeatures());
            if (DEBUG) {
            if (DEBUG) {
                Log.d(TAG, "Scores: recencyScore: " + recencyScore
                Log.d(TAG, "Scores: " + target);
                        + " usageTimeScore: " + usageTimeScore
                        + " launchCountScore: " + launchCountScore
                        + " - " + target);
            }
            }
        }
        }
    }
    }
@@ -215,17 +225,11 @@ class ResolverComparator implements Comparator<ResolvedComponentInfo> {
                final ScoredTarget rhsTarget = mScoredTargets.get(new ComponentName(
                final ScoredTarget rhsTarget = mScoredTargets.get(new ComponentName(
                        rhs.activityInfo.packageName, rhs.activityInfo.name));
                        rhs.activityInfo.packageName, rhs.activityInfo.name));


                final int chooserCountDiff = Long.compare(
                final int selectProbDiff = Float.compare(
                        rhsTarget.chooserCount, lhsTarget.chooserCount);
                        rhsTarget.selectProb, lhsTarget.selectProb);

                if (chooserCountDiff != 0) {
                    return chooserCountDiff > 0 ? 1 : -1;
                }

                final int diff = Float.compare(rhsTarget.score, lhsTarget.score);


                if (diff != 0) {
                if (selectProbDiff != 0) {
                    return diff > 0 ? 1 : -1;
                    return selectProbDiff > 0 ? 1 : -1;
                }
                }
            }
            }
        }
        }
@@ -241,32 +245,160 @@ class ResolverComparator implements Comparator<ResolvedComponentInfo> {
    public float getScore(ComponentName name) {
    public float getScore(ComponentName name) {
        final ScoredTarget target = mScoredTargets.get(name);
        final ScoredTarget target = mScoredTargets.get(name);
        if (target != null) {
        if (target != null) {
            return target.score;
            return target.selectProb;
        }
        }
        return 0;
        return 0;
    }
    }


    static class ScoredTarget {
    static class ScoredTarget {
        public final ComponentInfo componentInfo;
        public final ComponentInfo componentInfo;
        public float score;
        public long lastTimeUsed;
        public long lastTimeUsed;
        public long timeSpent;
        public long timeSpent;
        public long launchCount;
        public long launchCount;
        public long chooserCount;
        public long chooserCount;
        public ArrayMap<String, Float> features;
        public float selectProb;


        public ScoredTarget(ComponentInfo ci) {
        public ScoredTarget(ComponentInfo ci) {
            componentInfo = ci;
            componentInfo = ci;
            features = new ArrayMap<>(5);
        }
        }


        @Override
        @Override
        public String toString() {
        public String toString() {
            return "ScoredTarget{" + componentInfo
            return "ScoredTarget{" + componentInfo
                    + " score: " + score
                    + " lastTimeUsed: " + lastTimeUsed
                    + " lastTimeUsed: " + lastTimeUsed
                    + " timeSpent: " + timeSpent
                    + " timeSpent: " + timeSpent
                    + " launchCount: " + launchCount
                    + " launchCount: " + launchCount
                    + " chooserCount: " + chooserCount
                    + " chooserCount: " + chooserCount
                    + " selectProb: " + selectProb
                    + "}";
                    + "}";
        }
        }

        public void setFeatures(float launchCountScore, float usageTimeScore, float recencyScore,
                                float chooserCountScore) {
            features.put(LAUNCH_SCORE, launchCountScore);
            features.put(TIME_SPENT_SCORE, usageTimeScore);
            features.put(RECENCY_SCORE, recencyScore);
            features.put(CHOOSER_SCORE, chooserCountScore);
        }

        public ArrayMap<String, Float> getFeatures() {
            return features;
        }
    }

    public void updateModel(ComponentName componentName) {
        if (mScoredTargets == null || componentName == null ||
                !mScoredTargets.containsKey(componentName)) {
            return;
        }
        ScoredTarget selected = mScoredTargets.get(componentName);
        for (ComponentName targetComponent : mScoredTargets.keySet()) {
            if (targetComponent.equals(componentName)) {
                continue;
            }
            ScoredTarget target = mScoredTargets.get(targetComponent);
            // A potential point of optimization. Save updates or derive a closed form for the
            // positive case, to avoid calculating them repeatedly.
            if (target.selectProb >= selected.selectProb) {
                mRanker.update(target.getFeatures(), target.selectProb, false);
                mRanker.update(selected.getFeatures(), selected.selectProb, true);
            }
        }
        mRanker.commitUpdate();
    }

    class LogisticRegressionAppRanker {
        private static final String PARAM_SHARED_PREF_NAME = "resolver_ranker_params";
        private static final String BIAS_PREF_KEY = "bias";
        private static final float LEARNING_RATE = 0.02f;
        private static final float REGULARIZER_PARAM = 0.1f;
        private SharedPreferences mParamSharedPref;
        private ArrayMap<String, Float> mFeatureWeights;
        private float mBias;

        public LogisticRegressionAppRanker(Context context) {
            mParamSharedPref = getParamSharedPref(context);
        }

        public float predict(ArrayMap<String, Float> target) {
            if (target == null || mParamSharedPref == null) {
                return 0.0f;
            }
            final int featureSize = target.size();
            if (featureSize == 0) {
                return 0.0f;
            }
            float sum = 0.0f;
            if (mFeatureWeights == null) {
                mBias = mParamSharedPref.getFloat(BIAS_PREF_KEY, 0.0f);
                mFeatureWeights = new ArrayMap<>(featureSize);
                for (int i = 0; i < featureSize; i++) {
                    String featureName = target.keyAt(i);
                    float weight = mParamSharedPref.getFloat(featureName, 0.0f);
                    sum += weight * target.valueAt(i);
                    mFeatureWeights.put(featureName, weight);
                }
            } else {
                for (int i = 0; i < featureSize; i++) {
                    String featureName = target.keyAt(i);
                    float weight = mFeatureWeights.getOrDefault(featureName, 0.0f);
                    sum += weight * target.valueAt(i);
                }
            }
            return (float) (1.0 / (1.0 + Math.exp(-mBias - sum)));
        }

        public void update(ArrayMap<String, Float> target, float predict, boolean isSelected) {
            if (target == null || target.size() == 0) {
                return;
            }
            final int featureSize = target.size();
            if (mFeatureWeights == null) {
                mBias = 0.0f;
                mFeatureWeights = new ArrayMap<>(featureSize);
            }
            float error = isSelected ? 1.0f - predict : -predict;
            for (int i = 0; i < featureSize; i++) {
                String featureName = target.keyAt(i);
                float currentWeight = mFeatureWeights.getOrDefault(featureName, 0.0f);
                mBias += LEARNING_RATE * error;
                currentWeight = currentWeight - LEARNING_RATE * REGULARIZER_PARAM * currentWeight +
                        LEARNING_RATE * error * target.valueAt(i);
                mFeatureWeights.put(featureName, currentWeight);
            }
            if (DEBUG) {
                Log.d(TAG, "Weights: " + mFeatureWeights + " Bias: " + mBias);
            }
        }

        public void commitUpdate() {
            if (mFeatureWeights == null || mFeatureWeights.size() == 0) {
                return;
            }
            SharedPreferences.Editor editor = mParamSharedPref.edit();
            editor.putFloat(BIAS_PREF_KEY, mBias);
            final int size = mFeatureWeights.size();
            for (int i = 0; i < size; i++) {
                editor.putFloat(mFeatureWeights.keyAt(i), mFeatureWeights.valueAt(i));
            }
            editor.apply();
        }

        private SharedPreferences getParamSharedPref(Context context) {
            // The package info in the context isn't initialized in the way it is for normal apps,
            // so the standard, name-based context.getSharedPreferences doesn't work. Instead, we
            // build the path manually below using the same policy that appears in ContextImpl.
            if (DEBUG) {
                Log.d(TAG, "Context Package Name: " + context.getPackageName());
            }
            final File prefsFile = new File(new File(
                    Environment.getDataUserCePackageDirectory(StorageManager.UUID_PRIVATE_INTERNAL,
                            context.getUserId(), context.getPackageName()),
                    "shared_prefs"),
                    PARAM_SHARED_PREF_NAME + ".xml");
            return context.getSharedPreferences(prefsFile, Context.MODE_PRIVATE);
        }
    }
    }
}
}