Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 1b2a6773 authored by TreeHugger Robot's avatar TreeHugger Robot Committed by Android (Google) Code Review
Browse files

Merge "Refactor model listing / selection code to support other types of model"

parents 3cb90a58 ba228421
Loading
Loading
Loading
Loading
+291 −0
Original line number Diff line number Diff line
/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package android.view.textclassifier;

import static android.view.textclassifier.TextClassifier.DEFAULT_LOG_TAG;

import android.annotation.Nullable;
import android.os.LocaleList;
import android.os.ParcelFileDescriptor;
import android.text.TextUtils;

import com.android.internal.annotations.VisibleForTesting;
import com.android.internal.util.Preconditions;

import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Objects;
import java.util.StringJoiner;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
 * Manages model files that are listed by the model files supplier.
 * @hide
 */
@VisibleForTesting(visibility = VisibleForTesting.Visibility.PACKAGE)
public final class ModelFileManager {
    private final Object mLock = new Object();
    private final Supplier<List<ModelFile>> mModelFileSupplier;

    private List<ModelFile> mModelFiles;

    public ModelFileManager(Supplier<List<ModelFile>> modelFileSupplier) {
        mModelFileSupplier = Preconditions.checkNotNull(modelFileSupplier);
    }

    /**
     * Returns an unmodifiable list of model files listed by the given model files supplier.
     * <p>
     * The result is cached.
     */
    public List<ModelFile> listModelFiles() {
        synchronized (mLock) {
            if (mModelFiles == null) {
                mModelFiles = Collections.unmodifiableList(mModelFileSupplier.get());
            }
            return mModelFiles;
        }
    }

    /**
     * Returns the best model file for the given localelist, {@code null} if nothing is found.
     *
     * @param localeList the required locales, use {@code null} if there is no preference.
     */
    public ModelFile findBestModelFile(@Nullable LocaleList localeList) {
        // Specified localeList takes priority over the system default, so it is listed first.
        final String languages = localeList == null || localeList.isEmpty()
                ? LocaleList.getDefault().toLanguageTags()
                : localeList.toLanguageTags() + "," + LocaleList.getDefault().toLanguageTags();
        final List<Locale.LanguageRange> languageRangeList = Locale.LanguageRange.parse(languages);

        ModelFile bestModel = null;
        for (ModelFile model : listModelFiles()) {
            if (model.isAnyLanguageSupported(languageRangeList)) {
                if (model.isPreferredTo(bestModel)) {
                    bestModel = model;
                }
            }
        }
        return bestModel;
    }

    /**
     * Default implementation of the model file supplier.
     */
    public static final class ModelFileSupplierImpl implements Supplier<List<ModelFile>> {
        private final File mUpdatedModelFile;
        private final File mFactoryModelDir;
        private final Pattern mModelFilenamePattern;
        private final Function<Integer, Integer> mVersionSupplier;
        private final Function<Integer, String> mSupportedLocalesSupplier;

        public ModelFileSupplierImpl(
                File factoryModelDir,
                String factoryModelFileNameRegex,
                File updatedModelFile,
                Function<Integer, Integer> versionSupplier,
                Function<Integer, String> supportedLocalesSupplier) {
            mUpdatedModelFile = Preconditions.checkNotNull(updatedModelFile);
            mFactoryModelDir = Preconditions.checkNotNull(factoryModelDir);
            mModelFilenamePattern = Pattern.compile(
                    Preconditions.checkNotNull(factoryModelFileNameRegex));
            mVersionSupplier = Preconditions.checkNotNull(versionSupplier);
            mSupportedLocalesSupplier = Preconditions.checkNotNull(supportedLocalesSupplier);
        }

        @Override
        public List<ModelFile> get() {
            final List<ModelFile> modelFiles = new ArrayList<>();
            // The update model has the highest precedence.
            if (mUpdatedModelFile.exists()) {
                final ModelFile updatedModel = createModelFile(mUpdatedModelFile);
                if (updatedModel != null) {
                    modelFiles.add(updatedModel);
                }
            }
            // Factory models should never have overlapping locales, so the order doesn't matter.
            if (mFactoryModelDir.exists() && mFactoryModelDir.isDirectory()) {
                final File[] files = mFactoryModelDir.listFiles();
                for (File file : files) {
                    final Matcher matcher = mModelFilenamePattern.matcher(file.getName());
                    if (matcher.matches() && file.isFile()) {
                        final ModelFile model = createModelFile(file);
                        if (model != null) {
                            modelFiles.add(model);
                        }
                    }
                }
            }
            return modelFiles;
        }

        /** Returns null if the path did not point to a compatible model. */
        @Nullable
        private ModelFile createModelFile(File file) {
            if (!file.exists()) {
                return null;
            }
            ParcelFileDescriptor modelFd = null;
            try {
                modelFd = ParcelFileDescriptor.open(file, ParcelFileDescriptor.MODE_READ_ONLY);
                if (modelFd == null) {
                    return null;
                }
                final int modelFdInt = modelFd.getFd();
                final int version = mVersionSupplier.apply(modelFdInt);
                final String supportedLocalesStr = mSupportedLocalesSupplier.apply(modelFdInt);
                if (supportedLocalesStr.isEmpty()) {
                    Log.d(DEFAULT_LOG_TAG, "Ignoring " + file.getAbsolutePath());
                    return null;
                }
                final List<Locale> supportedLocales = new ArrayList<>();
                for (String langTag : supportedLocalesStr.split(",")) {
                    supportedLocales.add(Locale.forLanguageTag(langTag));
                }
                return new ModelFile(
                        file,
                        version,
                        supportedLocales,
                        ModelFile.LANGUAGE_INDEPENDENT.equals(supportedLocalesStr));
            } catch (FileNotFoundException e) {
                Log.e(DEFAULT_LOG_TAG, "Failed to find " + file.getAbsolutePath(), e);
                return null;
            } finally {
                maybeCloseAndLogError(modelFd);
            }
        }

        /**
         * Closes the ParcelFileDescriptor, if non-null, and logs any errors that occur.
         */
        private static void maybeCloseAndLogError(@Nullable ParcelFileDescriptor fd) {
            if (fd == null) {
                return;
            }
            try {
                fd.close();
            } catch (IOException e) {
                Log.e(DEFAULT_LOG_TAG, "Error closing file.", e);
            }
        }

    }

    /**
     * Describes TextClassifier model files on disk.
     */
    public static final class ModelFile {
        public static final String LANGUAGE_INDEPENDENT = "*";

        private final File mFile;
        private final int mVersion;
        private final List<Locale> mSupportedLocales;
        private final boolean mLanguageIndependent;

        public ModelFile(File file, int version, List<Locale> supportedLocales,
                boolean languageIndependent) {
            mFile = Preconditions.checkNotNull(file);
            mVersion = version;
            mSupportedLocales = Preconditions.checkNotNull(supportedLocales);
            mLanguageIndependent = languageIndependent;
        }

        /** Returns the absolute path to the model file. */
        public String getPath() {
            return mFile.getAbsolutePath();
        }

        /** Returns a name to use for id generation, effectively the name of the model file. */
        public String getName() {
            return mFile.getName();
        }

        /** Returns the version tag in the model's metadata. */
        public int getVersion() {
            return mVersion;
        }

        /** Returns whether the language supports any language in the given ranges. */
        public boolean isAnyLanguageSupported(List<Locale.LanguageRange> languageRanges) {
            Preconditions.checkNotNull(languageRanges);
            return mLanguageIndependent || Locale.lookup(languageRanges, mSupportedLocales) != null;
        }

        /** Returns an immutable lists of supported locales. */
        public List<Locale> getSupportedLocales() {
            return Collections.unmodifiableList(mSupportedLocales);
        }

        /**
         * Returns if this model file is preferred to the given one.
         */
        public boolean isPreferredTo(@Nullable ModelFile model) {
            // A model is preferred to no model.
            if (model == null) {
                return true;
            }

            // A language-specific model is preferred to a language independent
            // model.
            if (!mLanguageIndependent && model.mLanguageIndependent) {
                return true;
            }

            // A higher-version model is preferred.
            if (mVersion > model.getVersion()) {
                return true;
            }
            return false;
        }

        @Override
        public int hashCode() {
            return Objects.hash(getPath());
        }

        @Override
        public boolean equals(Object other) {
            if (this == other) {
                return true;
            }
            if (other instanceof ModelFile) {
                final ModelFile otherModel = (ModelFile) other;
                return TextUtils.equals(getPath(), otherModel.getPath());
            }
            return false;
        }

        @Override
        public String toString() {
            final StringJoiner localesJoiner = new StringJoiner(",");
            for (Locale locale : mSupportedLocales) {
                localesJoiner.add(locale.toLanguageTag());
            }
            return String.format(Locale.US,
                    "ModelFile { path=%s name=%s version=%d locales=%s }",
                    getPath(), getName(), mVersion, localesJoiner.toString());
        }
    }
}
+62 −238

File changed.

Preview size limit exceeded, changes collapsed.

+301 −0
Original line number Diff line number Diff line
/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package android.view.textclassifier;

import static com.google.common.truth.Truth.assertThat;

import static org.mockito.Mockito.when;

import android.os.LocaleList;
import android.support.test.InstrumentationRegistry;
import android.support.test.filters.SmallTest;
import android.support.test.runner.AndroidJUnit4;

import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;

import java.io.File;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.function.Supplier;
import java.util.stream.Collectors;

@SmallTest
@RunWith(AndroidJUnit4.class)
public class ModelFileManagerTest {

    @Mock
    private Supplier<List<ModelFileManager.ModelFile>> mModelFileSupplier;
    private ModelFileManager.ModelFileSupplierImpl mModelFileSupplierImpl;
    private ModelFileManager mModelFileManager;
    private File mRootTestDir;
    private File mFactoryModelDir;
    private File mUpdatedModelFile;

    @Before
    public void setup() {
        MockitoAnnotations.initMocks(this);
        mModelFileManager = new ModelFileManager(mModelFileSupplier);
        mRootTestDir = InstrumentationRegistry.getContext().getCacheDir();
        mFactoryModelDir = new File(mRootTestDir, "factory");
        mUpdatedModelFile = new File(mRootTestDir, "updated.model");

        mModelFileSupplierImpl =
                new ModelFileManager.ModelFileSupplierImpl(
                        mFactoryModelDir,
                        "test\\d.model",
                        mUpdatedModelFile,
                        fd -> 1,
                        fd -> ModelFileManager.ModelFile.LANGUAGE_INDEPENDENT
                );

        mRootTestDir.mkdirs();
        mFactoryModelDir.mkdirs();
    }

    @After
    public void removeTestDir() {
        recursiveDelete(mRootTestDir);
    }

    @Test
    public void get() {
        ModelFileManager.ModelFile modelFile =
                new ModelFileManager.ModelFile(
                        new File("/path/a"), 1, Collections.emptyList(), true);
        when(mModelFileSupplier.get()).thenReturn(Collections.singletonList(modelFile));

        List<ModelFileManager.ModelFile> modelFiles = mModelFileManager.listModelFiles();

        assertThat(modelFiles).hasSize(1);
        assertThat(modelFiles.get(0)).isEqualTo(modelFile);
    }

    @Test
    public void findBestModel_versionCode() {
        ModelFileManager.ModelFile olderModelFile =
                new ModelFileManager.ModelFile(
                        new File("/path/a"), 1,
                        Collections.emptyList(), true);

        ModelFileManager.ModelFile newerModelFile =
                new ModelFileManager.ModelFile(
                        new File("/path/b"), 2,
                        Collections.emptyList(), true);
        when(mModelFileSupplier.get())
                .thenReturn(Arrays.asList(olderModelFile, newerModelFile));

        ModelFileManager.ModelFile bestModelFile =
                mModelFileManager.findBestModelFile(LocaleList.getEmptyLocaleList());

        assertThat(bestModelFile).isEqualTo(newerModelFile);
    }

    @Test
    public void findBestModel_languageDependentModelIsPreferred() {
        Locale locale = Locale.forLanguageTag("ja");
        ModelFileManager.ModelFile languageIndependentModelFile =
                new ModelFileManager.ModelFile(
                        new File("/path/a"), 1,
                        Collections.emptyList(), true);

        ModelFileManager.ModelFile languageDependentModelFile =
                new ModelFileManager.ModelFile(
                        new File("/path/b"), 1,
                        Collections.singletonList(locale), false);
        when(mModelFileSupplier.get())
                .thenReturn(
                        Arrays.asList(languageIndependentModelFile, languageDependentModelFile));

        ModelFileManager.ModelFile bestModelFile =
                mModelFileManager.findBestModelFile(
                        LocaleList.forLanguageTags(locale.toLanguageTag()));
        assertThat(bestModelFile).isEqualTo(languageDependentModelFile);
    }

    @Test
    public void findBestModel_useIndependentWhenNoLanguageModelMatch() {
        Locale locale = Locale.forLanguageTag("ja");
        ModelFileManager.ModelFile languageIndependentModelFile =
                new ModelFileManager.ModelFile(
                        new File("/path/a"), 1,
                        Collections.emptyList(), true);

        ModelFileManager.ModelFile languageDependentModelFile =
                new ModelFileManager.ModelFile(
                        new File("/path/b"), 1,
                        Collections.singletonList(locale), false);

        when(mModelFileSupplier.get())
                .thenReturn(
                        Arrays.asList(languageIndependentModelFile, languageDependentModelFile));

        ModelFileManager.ModelFile bestModelFile =
                mModelFileManager.findBestModelFile(
                        LocaleList.forLanguageTags("zh-hk"));
        assertThat(bestModelFile).isEqualTo(languageIndependentModelFile);
    }

    @Test
    public void findBestModel_languageIsMoreImportantThanVersion() {
        ModelFileManager.ModelFile matchButOlderModel =
                new ModelFileManager.ModelFile(
                        new File("/path/a"), 1,
                        Collections.singletonList(Locale.forLanguageTag("fr")), false);

        ModelFileManager.ModelFile mismatchButNewerModel =
                new ModelFileManager.ModelFile(
                        new File("/path/b"), 2,
                        Collections.singletonList(Locale.forLanguageTag("ja")), false);

        when(mModelFileSupplier.get())
                .thenReturn(
                        Arrays.asList(matchButOlderModel, mismatchButNewerModel));

        ModelFileManager.ModelFile bestModelFile =
                mModelFileManager.findBestModelFile(
                        LocaleList.forLanguageTags("fr"));
        assertThat(bestModelFile).isEqualTo(matchButOlderModel);
    }

    @Test
    public void modelFileEquals() {
        ModelFileManager.ModelFile modelA =
                new ModelFileManager.ModelFile(
                        new File("/path/a"), 1,
                        Collections.singletonList(Locale.forLanguageTag("ja")), false);

        ModelFileManager.ModelFile modelB =
                new ModelFileManager.ModelFile(
                        new File("/path/a"), 1,
                        Collections.singletonList(Locale.forLanguageTag("ja")), false);

        assertThat(modelA).isEqualTo(modelB);
    }

    @Test
    public void modelFile_different() {
        ModelFileManager.ModelFile modelA =
                new ModelFileManager.ModelFile(
                        new File("/path/a"), 1,
                        Collections.singletonList(Locale.forLanguageTag("ja")), false);

        ModelFileManager.ModelFile modelB =
                new ModelFileManager.ModelFile(
                        new File("/path/b"), 1,
                        Collections.singletonList(Locale.forLanguageTag("ja")), false);

        assertThat(modelA).isNotEqualTo(modelB);
    }


    @Test
    public void modelFile_getPath() {
        ModelFileManager.ModelFile modelA =
                new ModelFileManager.ModelFile(
                        new File("/path/a"), 1,
                        Collections.singletonList(Locale.forLanguageTag("ja")), false);

        assertThat(modelA.getPath()).isEqualTo("/path/a");
    }

    @Test
    public void modelFile_getName() {
        ModelFileManager.ModelFile modelA =
                new ModelFileManager.ModelFile(
                        new File("/path/a"), 1,
                        Collections.singletonList(Locale.forLanguageTag("ja")), false);

        assertThat(modelA.getName()).isEqualTo("a");
    }

    @Test
    public void modelFile_isPreferredTo_languageDependentIsBetter() {
        ModelFileManager.ModelFile modelA =
                new ModelFileManager.ModelFile(
                        new File("/path/a"), 1,
                        Collections.singletonList(Locale.forLanguageTag("ja")), false);

        ModelFileManager.ModelFile modelB =
                new ModelFileManager.ModelFile(
                        new File("/path/b"), 2,
                        Collections.emptyList(), true);

        assertThat(modelA.isPreferredTo(modelB)).isTrue();
    }

    @Test
    public void modelFile_isPreferredTo_version() {
        ModelFileManager.ModelFile modelA =
                new ModelFileManager.ModelFile(
                        new File("/path/a"), 2,
                        Collections.singletonList(Locale.forLanguageTag("ja")), false);

        ModelFileManager.ModelFile modelB =
                new ModelFileManager.ModelFile(
                        new File("/path/b"), 1,
                        Collections.emptyList(), false);

        assertThat(modelA.isPreferredTo(modelB)).isTrue();
    }

    @Test
    public void testFileSupplierImpl_updatedFileOnly() throws IOException {
        mUpdatedModelFile.createNewFile();
        File model1 = new File(mFactoryModelDir, "test1.model");
        model1.createNewFile();
        File model2 = new File(mFactoryModelDir, "test2.model");
        model2.createNewFile();
        new File(mFactoryModelDir, "not_match_regex.model").createNewFile();

        List<ModelFileManager.ModelFile> modelFiles = mModelFileSupplierImpl.get();
        List<String> modelFilePaths =
                modelFiles
                        .stream()
                        .map(modelFile -> modelFile.getPath())
                        .collect(Collectors.toList());

        assertThat(modelFiles).hasSize(3);
        assertThat(modelFilePaths).containsExactly(
                mUpdatedModelFile.getAbsolutePath(),
                model1.getAbsolutePath(),
                model2.getAbsolutePath());
    }

    @Test
    public void testFileSupplierImpl_empty() {
        mFactoryModelDir.delete();
        List<ModelFileManager.ModelFile> modelFiles = mModelFileSupplierImpl.get();

        assertThat(modelFiles).hasSize(0);
    }

    private static void recursiveDelete(File f) {
        if (f.isDirectory()) {
            for (File innerFile : f.listFiles()) {
                recursiveDelete(innerFile);
            }
        }
        f.delete();
    }
}