Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit ea26b2a4 authored by TreeHugger Robot's avatar TreeHugger Robot Committed by Android (Google) Code Review
Browse files

Merge "Allowing models to support multiple languages"

parents 348f19e6 ef0156d7
Loading
Loading
Loading
Loading
+2 −2
Original line number Diff line number Diff line
@@ -108,9 +108,9 @@ final class SmartSelection {
    }

    /**
     * Returns the language of the model.
     * Returns a comma separated list of locales supported by the model as BCP 47 tags.
     */
    public static String getLanguage(int fd) {
    public static String getLanguages(int fd) {
        return nativeGetLanguage(fd);
    }

+142 −114
Original line number Diff line number Diff line
@@ -58,6 +58,7 @@ import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.StringJoiner;
import java.util.concurrent.TimeUnit;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
@@ -101,11 +102,9 @@ public final class TextClassifierImpl implements TextClassifier {

    private final Object mLock = new Object();
    @GuardedBy("mLock") // Do not access outside this lock.
    private Map<Locale, String> mModelFilePaths;
    private List<ModelFile> mAllModelFiles;
    @GuardedBy("mLock") // Do not access outside this lock.
    private Locale mLocale;
    @GuardedBy("mLock") // Do not access outside this lock.
    private int mVersion;
    private ModelFile mModel;
    @GuardedBy("mLock") // Do not access outside this lock.
    private SmartSelection mSmartSelection;

@@ -281,18 +280,18 @@ public final class TextClassifierImpl implements TextClassifier {
    private SmartSelection getSmartSelection(LocaleList localeList) throws FileNotFoundException {
        synchronized (mLock) {
            localeList = localeList == null ? LocaleList.getEmptyLocaleList() : localeList;
            final Locale locale = findBestSupportedLocaleLocked(localeList);
            if (locale == null) {
                throw new FileNotFoundException("No file for null locale");
            final ModelFile bestModel = findBestModelLocked(localeList);
            if (bestModel == null) {
                throw new FileNotFoundException("No model for " + localeList.toLanguageTags());
            }
            if (mSmartSelection == null || !Objects.equals(mLocale, locale)) {
            if (mSmartSelection == null || !Objects.equals(mModel, bestModel)) {
                Log.d(DEFAULT_LOG_TAG, "Loading " + bestModel);
                destroySmartSelectionIfExistsLocked();
                final ParcelFileDescriptor fd = getFdLocked(locale);
                final int modelFd = fd.getFd();
                mVersion = SmartSelection.getVersion(modelFd);
                mSmartSelection = new SmartSelection(modelFd);
                final ParcelFileDescriptor fd = ParcelFileDescriptor.open(
                        new File(bestModel.getPath()), ParcelFileDescriptor.MODE_READ_ONLY);
                mSmartSelection = new SmartSelection(fd.getFd());
                closeAndLogError(fd);
                mLocale = locale;
                mModel = bestModel;
            }
            return mSmartSelection;
        }
@@ -300,74 +299,8 @@ public final class TextClassifierImpl implements TextClassifier {

    private String getSignature(String text, int start, int end) {
        synchronized (mLock) {
            return DefaultLogger.createSignature(text, start, end, mContext, mVersion, mLocale);
        }
    }

    @GuardedBy("mLock") // Do not call outside this lock.
    private ParcelFileDescriptor getFdLocked(Locale locale) throws FileNotFoundException {
        ParcelFileDescriptor updateFd;
        int updateVersion = -1;
        try {
            updateFd = ParcelFileDescriptor.open(
                    new File(UPDATED_MODEL_FILE_PATH), ParcelFileDescriptor.MODE_READ_ONLY);
            if (updateFd != null) {
                updateVersion = SmartSelection.getVersion(updateFd.getFd());
            }
        } catch (FileNotFoundException e) {
            updateFd = null;
        }
        ParcelFileDescriptor factoryFd;
        int factoryVersion = -1;
        try {
            final String factoryModelFilePath = getFactoryModelFilePathsLocked().get(locale);
            if (factoryModelFilePath != null) {
                factoryFd = ParcelFileDescriptor.open(
                        new File(factoryModelFilePath), ParcelFileDescriptor.MODE_READ_ONLY);
                if (factoryFd != null) {
                    factoryVersion = SmartSelection.getVersion(factoryFd.getFd());
                }
            } else {
                factoryFd = null;
            }
        } catch (FileNotFoundException e) {
            factoryFd = null;
        }

        if (updateFd == null) {
            if (factoryFd != null) {
                return factoryFd;
            } else {
                throw new FileNotFoundException(
                        String.format(Locale.US, "No model file found for %s", locale));
            }
        }

        final int updateFdInt = updateFd.getFd();
        final boolean localeMatches = Objects.equals(
                locale.getLanguage().trim().toLowerCase(),
                SmartSelection.getLanguage(updateFdInt).trim().toLowerCase());
        if (factoryFd == null) {
            if (localeMatches) {
                return updateFd;
            } else {
                closeAndLogError(updateFd);
                throw new FileNotFoundException(
                        String.format(Locale.US, "No model file found for %s", locale));
            }
        }

        if (!localeMatches) {
            closeAndLogError(updateFd);
            return factoryFd;
        }

        if (updateVersion > factoryVersion) {
            closeAndLogError(factoryFd);
            return updateFd;
        } else {
            closeAndLogError(updateFd);
            return factoryFd;
            return DefaultLogger.createSignature(text, start, end, mContext, mModel.getVersion(),
                    mModel.getSupportedLocales());
        }
    }

@@ -379,60 +312,66 @@ public final class TextClassifierImpl implements TextClassifier {
        }
    }

    /**
     * Finds the most appropriate model to use for the given target locale list.
     *
     * The basic logic is: we ignore all models that don't support any of the target locales. For
     * the remaining candidates, we take the update model unless its version number is lower than
     * the factory version. It's assumed that factory models do not have overlapping locale ranges
     * and conflict resolution between these models hence doesn't matter.
     */
    @GuardedBy("mLock") // Do not call outside this lock.
    @Nullable
    private Locale findBestSupportedLocaleLocked(LocaleList localeList) {
    private ModelFile findBestModelLocked(LocaleList localeList) {
        // Specified localeList takes priority over the system default, so it is listed first.
        final String languages = localeList.isEmpty()
                ? LocaleList.getDefault().toLanguageTags()
                : localeList.toLanguageTags() + "," + LocaleList.getDefault().toLanguageTags();
        final List<Locale.LanguageRange> languageRangeList = Locale.LanguageRange.parse(languages);

        final List<Locale> supportedLocales =
                new ArrayList<>(getFactoryModelFilePathsLocked().keySet());
        final Locale updatedModelLocale = getUpdatedModelLocale();
        if (updatedModelLocale != null) {
            supportedLocales.add(updatedModelLocale);
        ModelFile bestModel = null;
        int bestModelVersion = -1;
        for (ModelFile model : listAllModelsLocked()) {
            if (model.isAnyLanguageSupported(languageRangeList)) {
                if (model.getVersion() >= bestModelVersion) {
                    bestModel = model;
                    bestModelVersion = model.getVersion();
                }
            }
        }
        return Locale.lookup(languageRangeList, supportedLocales);
        return bestModel;
    }

    /** Returns a list of all model files available, in order of precedence. */
    @GuardedBy("mLock") // Do not call outside this lock.
    private Map<Locale, String> getFactoryModelFilePathsLocked() {
        if (mModelFilePaths == null) {
            final Map<Locale, String> modelFilePaths = new HashMap<>();
    private List<ModelFile> listAllModelsLocked() {
        if (mAllModelFiles == null) {
            final List<ModelFile> allModels = new ArrayList<>();
            // The update model has the highest precedence.
            if (new File(UPDATED_MODEL_FILE_PATH).exists()) {
                final ModelFile updatedModel = ModelFile.fromPath(UPDATED_MODEL_FILE_PATH);
                if (updatedModel != null) {
                    allModels.add(updatedModel);
                }
            }
            // Factory models should never have overlapping locales, so the order doesn't matter.
            final File modelsDir = new File(MODEL_DIR);
            if (modelsDir.exists() && modelsDir.isDirectory()) {
                final File[] models = modelsDir.listFiles();
                final File[] modelFiles = modelsDir.listFiles();
                final Pattern modelFilenamePattern = Pattern.compile(MODEL_FILE_REGEX);
                final int size = models.length;
                for (int i = 0; i < size; i++) {
                    final File modelFile = models[i];
                for (File modelFile : modelFiles) {
                    final Matcher matcher = modelFilenamePattern.matcher(modelFile.getName());
                    if (matcher.matches() && modelFile.isFile()) {
                        final String language = matcher.group(1);
                        final Locale locale = Locale.forLanguageTag(language);
                        modelFilePaths.put(locale, modelFile.getAbsolutePath());
                    }
                        final ModelFile model = ModelFile.fromPath(modelFile.getAbsolutePath());
                        if (model != null) {
                            allModels.add(model);
                        }
                    }
            mModelFilePaths = modelFilePaths;
                }
        return mModelFilePaths;
            }

    @Nullable
    private Locale getUpdatedModelLocale() {
        try {
            final ParcelFileDescriptor updateFd = ParcelFileDescriptor.open(
                    new File(UPDATED_MODEL_FILE_PATH), ParcelFileDescriptor.MODE_READ_ONLY);
            final Locale locale = Locale.forLanguageTag(
                    SmartSelection.getLanguage(updateFd.getFd()));
            closeAndLogError(updateFd);
            return locale;
        } catch (FileNotFoundException e) {
            return null;
            mAllModelFiles = allModels;
        }
        return mAllModelFiles;
    }

    private TextClassification createClassificationResult(
@@ -521,6 +460,95 @@ public final class TextClassifierImpl implements TextClassifier {
        }
    }

    /**
     * Describes TextClassifier model files on disk.
     */
    private static final class ModelFile {

        private final String mPath;
        private final String mName;
        private final int mVersion;
        private final List<Locale> mSupportedLocales;

        /** Returns null if the path did not point to a compatible model. */
        static @Nullable ModelFile fromPath(String path) {
            final File file = new File(path);
            try {
                final ParcelFileDescriptor modelFd = ParcelFileDescriptor.open(
                        file, ParcelFileDescriptor.MODE_READ_ONLY);
                final int version = SmartSelection.getVersion(modelFd.getFd());
                final String supportedLocalesStr = SmartSelection.getLanguages(modelFd.getFd());
                if (supportedLocalesStr.isEmpty()) {
                    Log.d(DEFAULT_LOG_TAG, "Ignoring " + file.getAbsolutePath());
                    return null;
                }
                final List<Locale> supportedLocales = new ArrayList<>();
                for (String langTag : supportedLocalesStr.split(",")) {
                    supportedLocales.add(Locale.forLanguageTag(langTag));
                }
                closeAndLogError(modelFd);
                return new ModelFile(path, file.getName(), version, supportedLocales);
            } catch (FileNotFoundException e) {
                Log.e(DEFAULT_LOG_TAG, "Failed to peek " + file.getAbsolutePath(), e);
                return null;
            }
        }

        /** The absolute path to the model file. */
        String getPath() {
            return mPath;
        }

        /** A name to use for signature generation. Effectively the name of the model file. */
        String getName() {
            return mName;
        }

        /** Returns the version tag in the model's metadata. */
        int getVersion() {
            return mVersion;
        }

        /** Returns whether the language supports any language in the given ranges. */
        boolean isAnyLanguageSupported(List<Locale.LanguageRange> languageRanges) {
            return Locale.lookup(languageRanges, mSupportedLocales) != null;
        }

        /** All locales supported by the model. */
        List<Locale> getSupportedLocales() {
            return Collections.unmodifiableList(mSupportedLocales);
        }

        @Override
        public boolean equals(Object other) {
            if (this == other) {
                return true;
            } else if (other == null || !ModelFile.class.isAssignableFrom(other.getClass())) {
                return false;
            } else {
                final ModelFile otherModel = (ModelFile) other;
                return mPath.equals(otherModel.mPath);
            }
        }

        @Override
        public String toString() {
            final StringJoiner localesJoiner = new StringJoiner(",");
            for (Locale locale : mSupportedLocales) {
                localesJoiner.add(locale.toLanguageTag());
            }
            return String.format(Locale.US, "ModelFile { path=%s name=%s version=%d locales=%s }",
                    mPath, mName, mVersion, localesJoiner.toString());
        }

        private ModelFile(String path, String name, int version, List<Locale> supportedLocales) {
            mPath = path;
            mName = name;
            mVersion = version;
            mSupportedLocales = supportedLocales;
        }
    }

    /**
     * Creates intents based on the classification type.
     */
+12 −7
Original line number Diff line number Diff line
@@ -17,7 +17,6 @@
package android.view.textclassifier.logging;

import android.annotation.NonNull;
import android.annotation.Nullable;
import android.content.Context;
import android.metrics.LogMaker;
import android.util.Log;
@@ -27,8 +26,10 @@ import com.android.internal.logging.MetricsLogger;
import com.android.internal.logging.nano.MetricsProto.MetricsEvent;
import com.android.internal.util.Preconditions;

import java.util.List;
import java.util.Locale;
import java.util.Objects;
import java.util.StringJoiner;

/**
 * Default Logger.
@@ -210,12 +211,16 @@ public final class DefaultLogger extends Logger {
     */
    public static String createSignature(
            String text, int start, int end, Context context, int modelVersion,
            @Nullable Locale locale) {
            List<Locale> locales) {
        Preconditions.checkNotNull(text);
        Preconditions.checkNotNull(context);
        final String modelName = (locale != null)
                ? String.format(Locale.US, "%s_v%d", locale.toLanguageTag(), modelVersion)
                : "";
        Preconditions.checkNotNull(locales);
        final StringJoiner localesJoiner = new StringJoiner(",");
        for (Locale locale : locales) {
            localesJoiner.add(locale.toLanguageTag());
        }
        final String modelName = String.format(Locale.US, "%s_v%d", localesJoiner.toString(),
                modelVersion);
        final int hash = Objects.hash(text, start, end, context.getPackageName());
        return SignatureParser.createSignature(CLASSIFIER_ID, modelName, hash);
    }
@@ -242,9 +247,9 @@ public final class DefaultLogger extends Logger {

        static String getModelName(String signature) {
            Preconditions.checkNotNull(signature);
            final int start = signature.indexOf("|");
            final int start = signature.indexOf("|") + 1;
            final int end = signature.indexOf("|", start);
            if (start >= 0 && end >= start) {
            if (start >= 1 && end >= start) {
                return signature.substring(start, end);
            }
            return "";