Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 9988f368 authored by Kang Li's avatar Kang Li
Browse files

Initialize Sharing Ranker with a pre-trained model.

Test: manual shared images in Photos; webpages in Chrome; contacts in
Contacts, and ran unit tests.

Change-Id: If8c724a4085f1436b3e1e5d62754c6563f756915
parent 57c45562
Loading
Loading
Loading
Loading
+38 −29
Original line number Diff line number Diff line
@@ -343,53 +343,42 @@ class ResolverComparator implements Comparator<ResolvedComponentInfo> {
    class LogisticRegressionAppRanker {
        private static final String PARAM_SHARED_PREF_NAME = "resolver_ranker_params";
        private static final String BIAS_PREF_KEY = "bias";
        private static final float LEARNING_RATE = 0.02f;
        private static final float REGULARIZER_PARAM = 0.1f;
        private static final String VERSION_PREF_KEY = "version";

        // parameters for a pre-trained model, to initialize the app ranker. When updating the
        // pre-trained model, please update these params, as well as initModel().
        private static final int CURRENT_VERSION = 1;
        private static final float LEARNING_RATE = 0.0001f;
        private static final float REGULARIZER_PARAM = 0.0001f;

        private SharedPreferences mParamSharedPref;
        private ArrayMap<String, Float> mFeatureWeights;
        private float mBias;

        public LogisticRegressionAppRanker(Context context) {
            mParamSharedPref = getParamSharedPref(context);
            initModel();
        }

        public float predict(ArrayMap<String, Float> target) {
            if (target == null || mParamSharedPref == null) {
            if (target == null) {
                return 0.0f;
            }
            final int featureSize = target.size();
            if (featureSize == 0) {
                return 0.0f;
            }
            float sum = 0.0f;
            if (mFeatureWeights == null) {
                mBias = mParamSharedPref.getFloat(BIAS_PREF_KEY, 0.0f);
                mFeatureWeights = new ArrayMap<>(featureSize);
                for (int i = 0; i < featureSize; i++) {
                    String featureName = target.keyAt(i);
                    float weight = mParamSharedPref.getFloat(featureName, 0.0f);
                    sum += weight * target.valueAt(i);
                    mFeatureWeights.put(featureName, weight);
                }
            } else {
            for (int i = 0; i < featureSize; i++) {
                String featureName = target.keyAt(i);
                float weight = mFeatureWeights.getOrDefault(featureName, 0.0f);
                sum += weight * target.valueAt(i);
            }
            }
            return (float) (1.0 / (1.0 + Math.exp(-mBias - sum)));
        }

        public void update(ArrayMap<String, Float> target, float predict, boolean isSelected) {
            if (target == null || target.size() == 0) {
            if (target == null) {
                return;
            }
            final int featureSize = target.size();
            if (mFeatureWeights == null) {
                mBias = 0.0f;
                mFeatureWeights = new ArrayMap<>(featureSize);
            }
            float error = isSelected ? 1.0f - predict : -predict;
            for (int i = 0; i < featureSize; i++) {
                String featureName = target.keyAt(i);
@@ -405,15 +394,13 @@ class ResolverComparator implements Comparator<ResolvedComponentInfo> {
        }

        public void commitUpdate() {
            if (mFeatureWeights == null || mFeatureWeights.size() == 0) {
                return;
            }
            SharedPreferences.Editor editor = mParamSharedPref.edit();
            editor.putFloat(BIAS_PREF_KEY, mBias);
            final int size = mFeatureWeights.size();
            for (int i = 0; i < size; i++) {
                editor.putFloat(mFeatureWeights.keyAt(i), mFeatureWeights.valueAt(i));
            }
            editor.putInt(VERSION_PREF_KEY, CURRENT_VERSION);
            editor.apply();
        }

@@ -431,5 +418,27 @@ class ResolverComparator implements Comparator<ResolvedComponentInfo> {
                    PARAM_SHARED_PREF_NAME + ".xml");
            return context.getSharedPreferences(prefsFile, Context.MODE_PRIVATE);
        }

        private void initModel() {
            mFeatureWeights = new ArrayMap<>(4);
            if (mParamSharedPref == null ||
                    mParamSharedPref.getInt(VERSION_PREF_KEY, 0) < CURRENT_VERSION) {
                // Initializing the app ranker to a pre-trained model. When updating the pre-trained
                // model, please increment CURRENT_VERSION, and update LEARNING_RATE and
                // REGULARIZER_PARAM.
                mBias = -1.6568f;
                mFeatureWeights.put(LAUNCH_SCORE, 2.5543f);
                mFeatureWeights.put(TIME_SPENT_SCORE, 2.8412f);
                mFeatureWeights.put(RECENCY_SCORE, 0.269f);
                mFeatureWeights.put(CHOOSER_SCORE, 4.2222f);
            } else {
                mBias = mParamSharedPref.getFloat(BIAS_PREF_KEY, 0.0f);
                mFeatureWeights.put(LAUNCH_SCORE, mParamSharedPref.getFloat(LAUNCH_SCORE, 0.0f));
                mFeatureWeights.put(
                        TIME_SPENT_SCORE, mParamSharedPref.getFloat(TIME_SPENT_SCORE, 0.0f));
                mFeatureWeights.put(RECENCY_SCORE, mParamSharedPref.getFloat(RECENCY_SCORE, 0.0f));
                mFeatureWeights.put(CHOOSER_SCORE, mParamSharedPref.getFloat(CHOOSER_SCORE, 0.0f));
            }
        }
    }
}