Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 624c594f authored by Tony Mak's avatar Tony Mak Committed by android-build-merger
Browse files

Merge "Flag to configure model parameters" into qt-dev

am: 3b74731f

Change-Id: I942d859be8ee77027dfa94da8c004a143193c79e
parents 0b0b2787 3b74731f
Loading
Loading
Loading
Loading
+208 −0
Original line number Diff line number Diff line
/*
 * Copyright (C) 2019 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package android.view.textclassifier;

import android.annotation.Nullable;
import android.content.ContentResolver;
import android.content.Context;
import android.database.ContentObserver;
import android.provider.Settings;
import android.text.TextUtils;
import android.util.Base64;
import android.util.KeyValueListParser;

import com.android.internal.annotations.GuardedBy;
import com.android.internal.annotations.VisibleForTesting;
import com.android.internal.util.Preconditions;

import java.lang.ref.WeakReference;
import java.util.Objects;
import java.util.function.Supplier;

/**
 * Parses the {@link Settings.Global#TEXT_CLASSIFIER_ACTION_MODEL_PARAMS} flag.
 *
 * @hide
 */
public final class ActionsModelParamsSupplier implements
        Supplier<ActionsModelParamsSupplier.ActionsModelParams> {
    private static final String TAG = TextClassifier.DEFAULT_LOG_TAG;

    @VisibleForTesting
    static final String KEY_REQUIRED_MODEL_VERSION = "required_model_version";
    @VisibleForTesting
    static final String KEY_REQUIRED_LOCALES = "required_locales";
    @VisibleForTesting
    static final String KEY_SERIALIZED_PRECONDITIONS = "serialized_preconditions";

    private final Context mAppContext;
    private final SettingsObserver mSettingsObserver;

    private final Object mLock = new Object();
    private final Runnable mOnChangedListener;
    @Nullable
    @GuardedBy("mLock")
    private ActionsModelParams mActionsModelParams;
    @GuardedBy("mLock")
    private boolean mParsed = true;

    public ActionsModelParamsSupplier(Context context, @Nullable Runnable onChangedListener) {
        mAppContext = Preconditions.checkNotNull(context).getApplicationContext();
        mOnChangedListener = onChangedListener == null ? () -> {} : onChangedListener;
        mSettingsObserver = new SettingsObserver(mAppContext, () -> {
            synchronized (mLock) {
                Log.v(TAG, "Settings.Global.TEXT_CLASSIFIER_ACTION_MODEL_PARAMS is updated");
                mParsed = true;
                mOnChangedListener.run();
            }
        });
    }

    /**
     * Returns the parsed actions params or {@link ActionsModelParams#INVALID} if the value is
     * invalid.
     */
    @Override
    public ActionsModelParams get() {
        synchronized (mLock) {
            if (mParsed) {
                mActionsModelParams = parse(mAppContext.getContentResolver());
                mParsed = false;
            }
        }
        return mActionsModelParams;
    }

    private ActionsModelParams parse(ContentResolver contentResolver) {
        String settingStr = Settings.Global.getString(contentResolver,
                Settings.Global.TEXT_CLASSIFIER_ACTION_MODEL_PARAMS);
        if (TextUtils.isEmpty(settingStr)) {
            return ActionsModelParams.INVALID;
        }
        try {
            KeyValueListParser keyValueListParser = new KeyValueListParser(',');
            keyValueListParser.setString(settingStr);
            int version = keyValueListParser.getInt(KEY_REQUIRED_MODEL_VERSION, -1);
            if (version == -1) {
                Log.w(TAG, "ActionsModelParams.Parse, invalid model version");
                return ActionsModelParams.INVALID;
            }
            String locales = keyValueListParser.getString(KEY_REQUIRED_LOCALES, null);
            if (locales == null) {
                Log.w(TAG, "ActionsModelParams.Parse, invalid locales");
                return ActionsModelParams.INVALID;
            }
            String serializedPreconditionsStr =
                    keyValueListParser.getString(KEY_SERIALIZED_PRECONDITIONS, null);
            if (serializedPreconditionsStr == null) {
                Log.w(TAG, "ActionsModelParams.Parse, invalid preconditions");
                return ActionsModelParams.INVALID;
            }
            byte[] serializedPreconditions =
                    Base64.decode(serializedPreconditionsStr, Base64.NO_WRAP);
            return new ActionsModelParams(version, locales, serializedPreconditions);
        } catch (Throwable t) {
            Log.e(TAG, "Invalid TEXT_CLASSIFIER_ACTION_MODEL_PARAMS, ignore", t);
        }
        return ActionsModelParams.INVALID;
    }

    @Override
    protected void finalize() throws Throwable {
        try {
            mAppContext.getContentResolver().unregisterContentObserver(mSettingsObserver);
        } finally {
            super.finalize();
        }
    }

    /**
     * Represents the parsed result.
     */
    public static final class ActionsModelParams {

        public static final ActionsModelParams INVALID =
                new ActionsModelParams(-1, "", new byte[0]);

        /**
         * The required model version to apply {@code mSerializedPreconditions}.
         */
        private final int mRequiredModelVersion;

        /**
         * The required model locales to apply {@code mSerializedPreconditions}.
         */
        private final String mRequiredModelLocales;

        /**
         * The serialized params that will be applied to the model file, if all requirements are
         * met. Do not modify.
         */
        private final byte[] mSerializedPreconditions;

        public ActionsModelParams(int requiredModelVersion, String requiredModelLocales,
                byte[] serializedPreconditions) {
            mRequiredModelVersion = requiredModelVersion;
            mRequiredModelLocales = Preconditions.checkNotNull(requiredModelLocales);
            mSerializedPreconditions = Preconditions.checkNotNull(serializedPreconditions);
        }

        /**
         * Returns the serialized preconditions. Returns {@code null} if the the model in use does
         * not meet all the requirements listed in the {@code ActionsModelParams} or the params
         * are invalid.
         */
        @Nullable
        public byte[] getSerializedPreconditions(ModelFileManager.ModelFile modelInUse) {
            if (this == INVALID) {
                return null;
            }
            if (modelInUse.getVersion() != mRequiredModelVersion) {
                Log.w(TAG, String.format(
                        "Not applying mSerializedPreconditions, required version=%d, actual=%d",
                        mRequiredModelVersion, modelInUse.getVersion()));
                return null;
            }
            if (!Objects.equals(modelInUse.getSupportedLocalesStr(), mRequiredModelLocales)) {
                Log.w(TAG, String.format(
                        "Not applying mSerializedPreconditions, required locales=%s, actual=%s",
                        mRequiredModelLocales, modelInUse.getSupportedLocalesStr()));
                return null;
            }
            return mSerializedPreconditions;
        }
    }

    private static final class SettingsObserver extends ContentObserver {

        private final WeakReference<Runnable> mOnChangedListener;

        SettingsObserver(Context appContext, Runnable listener) {
            super(null);
            mOnChangedListener = new WeakReference<>(listener);
            appContext.getContentResolver().registerContentObserver(
                    Settings.Global.getUriFor(Settings.Global.TEXT_CLASSIFIER_ACTION_MODEL_PARAMS),
                    false /* notifyForDescendants */,
                    this);
        }

        public void onChange(boolean selfChange) {
            if (mOnChangedListener.get() != null) {
                mOnChangedListener.get().run();
            }
        }
    }
}
+9 −0
Original line number Diff line number Diff line
@@ -167,6 +167,7 @@ public final class ModelFileManager {
                        file,
                        version,
                        supportedLocales,
                        supportedLocalesStr,
                        ModelFile.LANGUAGE_INDEPENDENT.equals(supportedLocalesStr));
            } catch (FileNotFoundException e) {
                Log.e(DEFAULT_LOG_TAG, "Failed to find " + file.getAbsolutePath(), e);
@@ -201,13 +202,16 @@ public final class ModelFileManager {
        private final File mFile;
        private final int mVersion;
        private final List<Locale> mSupportedLocales;
        private final String mSupportedLocalesStr;
        private final boolean mLanguageIndependent;

        public ModelFile(File file, int version, List<Locale> supportedLocales,
                String supportedLocalesStr,
                boolean languageIndependent) {
            mFile = Preconditions.checkNotNull(file);
            mVersion = version;
            mSupportedLocales = Preconditions.checkNotNull(supportedLocales);
            mSupportedLocalesStr = Preconditions.checkNotNull(supportedLocalesStr);
            mLanguageIndependent = languageIndependent;
        }

@@ -237,6 +241,11 @@ public final class ModelFileManager {
            return Collections.unmodifiableList(mSupportedLocales);
        }

        /** Returns the original supported locals string read from the model file. */
        public String getSupportedLocalesStr() {
            return mSupportedLocalesStr;
        }

        /**
         * Returns if this model file is preferred to the given one.
         */
+19 −3
Original line number Diff line number Diff line
@@ -29,6 +29,7 @@ import android.os.ParcelFileDescriptor;
import android.util.ArrayMap;
import android.util.ArraySet;
import android.util.Pair;
import android.view.textclassifier.ActionsModelParamsSupplier.ActionsModelParams;
import android.view.textclassifier.intent.ClassificationIntentFactory;
import android.view.textclassifier.intent.LabeledIntent;
import android.view.textclassifier.intent.LegacyClassificationIntentFactory;
@@ -57,6 +58,7 @@ import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.Supplier;

/**
 * Default implementation of the {@link TextClassifier} interface.
@@ -124,6 +126,7 @@ public final class TextClassifierImpl implements TextClassifier {

    private final ClassificationIntentFactory mClassificationIntentFactory;
    private final TemplateIntentFactory mTemplateIntentFactory;
    private final Supplier<ActionsModelParams> mActionsModelParamsSupplier;

    public TextClassifierImpl(
            Context context, TextClassificationConstants settings, TextClassifier fallback) {
@@ -158,6 +161,15 @@ public final class TextClassifierImpl implements TextClassifier {
                ? new TemplateClassificationIntentFactory(
                mTemplateIntentFactory, new LegacyClassificationIntentFactory())
                : new LegacyClassificationIntentFactory();
        mActionsModelParamsSupplier = new ActionsModelParamsSupplier(mContext,
                () -> {
                    synchronized (mLock) {
                        // Clear mActionsImpl here, so that we will create a new
                        // ActionsSuggestionsModel object with the new flag in the next request.
                        mActionsImpl = null;
                        mActionModelInUse = null;
                    }
                });
    }

    public TextClassifierImpl(Context context, TextClassificationConstants settings) {
@@ -584,10 +596,14 @@ public final class TextClassifierImpl implements TextClassifier {
                final ParcelFileDescriptor pfd = ParcelFileDescriptor.open(
                        new File(bestModel.getPath()), ParcelFileDescriptor.MODE_READ_ONLY);
                try {
                    if (pfd != null) {
                        mActionsImpl = new ActionsSuggestionsModel(pfd.getFd());
                        mActionModelInUse = bestModel;
                    if (pfd == null) {
                        Log.d(LOG_TAG, "Failed to read the model file: " + bestModel.getPath());
                        return null;
                    }
                    ActionsModelParams params = mActionsModelParamsSupplier.get();
                    mActionsImpl = new ActionsSuggestionsModel(
                            pfd.getFd(), params.getSerializedPreconditions(bestModel));
                    mActionModelInUse = bestModel;
                } finally {
                    maybeCloseAndLogError(pfd);
                }
+95 −0
Original line number Diff line number Diff line
/*
 * Copyright (C) 2019 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package android.view.textclassifier;

import static com.google.common.truth.Truth.assertThat;

import androidx.test.filters.SmallTest;
import androidx.test.runner.AndroidJUnit4;

import org.junit.Test;
import org.junit.runner.RunWith;

import java.io.File;
import java.util.Collections;
import java.util.Locale;

@SmallTest
@RunWith(AndroidJUnit4.class)
public class ActionsModelParamsSupplierTest {

    @Test
    public void getSerializedPreconditions_validActionsModelParams() {
        ModelFileManager.ModelFile modelFile = new ModelFileManager.ModelFile(
                new File("/model/file"),
                200 /* version */,
                Collections.singletonList(Locale.forLanguageTag("en")),
                "en",
                false);
        byte[] serializedPreconditions = new byte[]{0x12, 0x24, 0x36};
        ActionsModelParamsSupplier.ActionsModelParams params =
                new ActionsModelParamsSupplier.ActionsModelParams(
                        200 /* version */,
                        "en",
                        serializedPreconditions);

        byte[] actual = params.getSerializedPreconditions(modelFile);

        assertThat(actual).isEqualTo(serializedPreconditions);
    }

    @Test
    public void getSerializedPreconditions_invalidVersion() {
        ModelFileManager.ModelFile modelFile = new ModelFileManager.ModelFile(
                new File("/model/file"),
                201 /* version */,
                Collections.singletonList(Locale.forLanguageTag("en")),
                "en",
                false);
        byte[] serializedPreconditions = new byte[]{0x12, 0x24, 0x36};
        ActionsModelParamsSupplier.ActionsModelParams params =
                new ActionsModelParamsSupplier.ActionsModelParams(
                        200 /* version */,
                        "en",
                        serializedPreconditions);

        byte[] actual = params.getSerializedPreconditions(modelFile);

        assertThat(actual).isNull();
    }

    @Test
    public void getSerializedPreconditions_invalidLocales() {
        final String LANGUAGE_TAG = "zh";
        ModelFileManager.ModelFile modelFile = new ModelFileManager.ModelFile(
                new File("/model/file"),
                200 /* version */,
                Collections.singletonList(Locale.forLanguageTag(LANGUAGE_TAG)),
                LANGUAGE_TAG,
                false);
        byte[] serializedPreconditions = new byte[]{0x12, 0x24, 0x36};
        ActionsModelParamsSupplier.ActionsModelParams params =
                new ActionsModelParamsSupplier.ActionsModelParams(
                        200 /* version */,
                        "en",
                        serializedPreconditions);

        byte[] actual = params.getSerializedPreconditions(modelFile);

        assertThat(actual).isNull();
    }

}
+24 −23
Original line number Diff line number Diff line
@@ -86,7 +86,7 @@ public class ModelFileManagerTest {
    public void get() {
        ModelFileManager.ModelFile modelFile =
                new ModelFileManager.ModelFile(
                        new File("/path/a"), 1, Collections.emptyList(), true);
                        new File("/path/a"), 1, Collections.emptyList(), "", true);
        when(mModelFileSupplier.get()).thenReturn(Collections.singletonList(modelFile));

        List<ModelFileManager.ModelFile> modelFiles = mModelFileManager.listModelFiles();
@@ -100,12 +100,12 @@ public class ModelFileManagerTest {
        ModelFileManager.ModelFile olderModelFile =
                new ModelFileManager.ModelFile(
                        new File("/path/a"), 1,
                        Collections.emptyList(), true);
                        Collections.emptyList(), "", true);

        ModelFileManager.ModelFile newerModelFile =
                new ModelFileManager.ModelFile(
                        new File("/path/b"), 2,
                        Collections.emptyList(), true);
                        Collections.emptyList(), "", true);
        when(mModelFileSupplier.get())
                .thenReturn(Arrays.asList(olderModelFile, newerModelFile));

@@ -121,12 +121,12 @@ public class ModelFileManagerTest {
        ModelFileManager.ModelFile languageIndependentModelFile =
                new ModelFileManager.ModelFile(
                        new File("/path/a"), 1,
                        Collections.emptyList(), true);
                        Collections.emptyList(), "", true);

        ModelFileManager.ModelFile languageDependentModelFile =
                new ModelFileManager.ModelFile(
                        new File("/path/b"), 1,
                        Collections.singletonList(locale), false);
                        Collections.singletonList(locale), locale.toLanguageTag(), false);
        when(mModelFileSupplier.get())
                .thenReturn(
                        Arrays.asList(languageIndependentModelFile, languageDependentModelFile));
@@ -143,12 +143,12 @@ public class ModelFileManagerTest {
        ModelFileManager.ModelFile languageIndependentModelFile =
                new ModelFileManager.ModelFile(
                        new File("/path/a"), 1,
                        Collections.emptyList(), true);
                        Collections.emptyList(), "", true);

        ModelFileManager.ModelFile languageDependentModelFile =
                new ModelFileManager.ModelFile(
                        new File("/path/b"), 1,
                        Collections.singletonList(locale), false);
                        Collections.singletonList(locale), locale.toLanguageTag(), false);

        when(mModelFileSupplier.get())
                .thenReturn(
@@ -165,12 +165,13 @@ public class ModelFileManagerTest {
        ModelFileManager.ModelFile languageIndependentModelFile =
                new ModelFileManager.ModelFile(
                        new File("/path/a"), 1,
                        Collections.emptyList(), true);
                        Collections.emptyList(), "", true);

        ModelFileManager.ModelFile languageDependentModelFile =
                new ModelFileManager.ModelFile(
                        new File("/path/b"), 1,
                        Collections.singletonList(DEFAULT_LOCALE), false);
                        Collections.singletonList(
                                DEFAULT_LOCALE), DEFAULT_LOCALE.toLanguageTag(), false);

        when(mModelFileSupplier.get())
                .thenReturn(
@@ -187,12 +188,12 @@ public class ModelFileManagerTest {
        ModelFileManager.ModelFile matchButOlderModel =
                new ModelFileManager.ModelFile(
                        new File("/path/a"), 1,
                        Collections.singletonList(Locale.forLanguageTag("fr")), false);
                        Collections.singletonList(Locale.forLanguageTag("fr")), "fr", false);

        ModelFileManager.ModelFile mismatchButNewerModel =
                new ModelFileManager.ModelFile(
                        new File("/path/b"), 2,
                        Collections.singletonList(Locale.forLanguageTag("ja")), false);
                        Collections.singletonList(Locale.forLanguageTag("ja")), "ja", false);

        when(mModelFileSupplier.get())
                .thenReturn(
@@ -209,12 +210,12 @@ public class ModelFileManagerTest {
        ModelFileManager.ModelFile matchLocaleModel =
                new ModelFileManager.ModelFile(
                        new File("/path/b"), 1,
                        Collections.singletonList(Locale.forLanguageTag("ja")), false);
                        Collections.singletonList(Locale.forLanguageTag("ja")), "ja", false);

        ModelFileManager.ModelFile languageIndependentModel =
                new ModelFileManager.ModelFile(
                        new File("/path/a"), 2,
                        Collections.emptyList(), true);
                        Collections.emptyList(), "", true);
        when(mModelFileSupplier.get())
                .thenReturn(
                        Arrays.asList(matchLocaleModel, languageIndependentModel));
@@ -231,12 +232,12 @@ public class ModelFileManagerTest {
        ModelFileManager.ModelFile modelA =
                new ModelFileManager.ModelFile(
                        new File("/path/a"), 1,
                        Collections.singletonList(Locale.forLanguageTag("ja")), false);
                        Collections.singletonList(Locale.forLanguageTag("ja")), "ja", false);

        ModelFileManager.ModelFile modelB =
                new ModelFileManager.ModelFile(
                        new File("/path/a"), 1,
                        Collections.singletonList(Locale.forLanguageTag("ja")), false);
                        Collections.singletonList(Locale.forLanguageTag("ja")), "ja", false);

        assertThat(modelA).isEqualTo(modelB);
    }
@@ -246,12 +247,12 @@ public class ModelFileManagerTest {
        ModelFileManager.ModelFile modelA =
                new ModelFileManager.ModelFile(
                        new File("/path/a"), 1,
                        Collections.singletonList(Locale.forLanguageTag("ja")), false);
                        Collections.singletonList(Locale.forLanguageTag("ja")), "ja", false);

        ModelFileManager.ModelFile modelB =
                new ModelFileManager.ModelFile(
                        new File("/path/b"), 1,
                        Collections.singletonList(Locale.forLanguageTag("ja")), false);
                        Collections.singletonList(Locale.forLanguageTag("ja")), "ja", false);

        assertThat(modelA).isNotEqualTo(modelB);
    }
@@ -262,7 +263,7 @@ public class ModelFileManagerTest {
        ModelFileManager.ModelFile modelA =
                new ModelFileManager.ModelFile(
                        new File("/path/a"), 1,
                        Collections.singletonList(Locale.forLanguageTag("ja")), false);
                        Collections.singletonList(Locale.forLanguageTag("ja")), "ja", false);

        assertThat(modelA.getPath()).isEqualTo("/path/a");
    }
@@ -272,7 +273,7 @@ public class ModelFileManagerTest {
        ModelFileManager.ModelFile modelA =
                new ModelFileManager.ModelFile(
                        new File("/path/a"), 1,
                        Collections.singletonList(Locale.forLanguageTag("ja")), false);
                        Collections.singletonList(Locale.forLanguageTag("ja")), "ja", false);

        assertThat(modelA.getName()).isEqualTo("a");
    }
@@ -282,12 +283,12 @@ public class ModelFileManagerTest {
        ModelFileManager.ModelFile modelA =
                new ModelFileManager.ModelFile(
                        new File("/path/a"), 1,
                        Collections.singletonList(Locale.forLanguageTag("ja")), false);
                        Collections.singletonList(Locale.forLanguageTag("ja")), "ja", false);

        ModelFileManager.ModelFile modelB =
                new ModelFileManager.ModelFile(
                        new File("/path/b"), 2,
                        Collections.emptyList(), true);
                        Collections.emptyList(), "", true);

        assertThat(modelA.isPreferredTo(modelB)).isTrue();
    }
@@ -297,12 +298,12 @@ public class ModelFileManagerTest {
        ModelFileManager.ModelFile modelA =
                new ModelFileManager.ModelFile(
                        new File("/path/a"), 2,
                        Collections.singletonList(Locale.forLanguageTag("ja")), false);
                        Collections.singletonList(Locale.forLanguageTag("ja")), "ja", false);

        ModelFileManager.ModelFile modelB =
                new ModelFileManager.ModelFile(
                        new File("/path/b"), 1,
                        Collections.emptyList(), false);
                        Collections.emptyList(), "", false);

        assertThat(modelA.isPreferredTo(modelB)).isTrue();
    }