Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit dd83c200 authored by TreeHugger Robot's avatar TreeHugger Robot Committed by Android (Google) Code Review
Browse files

Merge "TextClassifier: Switch model based on locale" into oc-dev

parents 9c847fdc c39006a1
Loading
Loading
Loading
Loading
+1 −11
Original line number Diff line number Diff line
@@ -44,8 +44,6 @@ public final class TextClassificationManager {
    private final Object mLangIdLock = new Object();

    private final Context mContext;
    // TODO: Implement a way to close the file descriptors.
    private ParcelFileDescriptor mSmartSelectionFd;
    private ParcelFileDescriptor mLangIdFd;
    private TextClassifier mDefault;
    private LangId mLangId;
@@ -61,15 +59,7 @@ public final class TextClassificationManager {
    public TextClassifier getDefaultTextClassifier() {
        synchronized (mTextClassifierLock) {
            if (mDefault == null) {
                try {
                    mSmartSelectionFd = ParcelFileDescriptor.open(
                            new File("/etc/textclassifier/textclassifier.smartselection.en.model"),
                            ParcelFileDescriptor.MODE_READ_ONLY);
                    mDefault = new TextClassifierImpl(mContext, mSmartSelectionFd);
                } catch (FileNotFoundException e) {
                    Log.e(LOG_TAG, "Error accessing 'text classifier selection' model file.", e);
                    mDefault = TextClassifier.NO_OP;
                }
                mDefault = new TextClassifierImpl(mContext);
            }
            return mDefault;
        }
+80 −13
Original line number Diff line number Diff line
@@ -38,17 +38,24 @@ import android.util.Log;
import android.util.Patterns;
import android.view.View;

import com.android.internal.annotations.GuardedBy;
import com.android.internal.util.Preconditions;

import java.io.File;
import java.io.FileNotFoundException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.StringJoiner;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
 * Default implementation of the {@link TextClassifier} interface.
@@ -62,16 +69,21 @@ import java.util.Map;
final class TextClassifierImpl implements TextClassifier {

    private static final String LOG_TAG = "TextClassifierImpl";

    private final Object mSmartSelectionLock = new Object();
    private static final String MODEL_DIR = "/etc/textclassifier/";
    private static final String MODEL_FILE_REGEX = "textclassifier\\.smartselection\\.(.*)\\.model";

    private final Context mContext;
    private final ParcelFileDescriptor mFd;

    private final Object mSmartSelectionLock = new Object();
    @GuardedBy("mSmartSelectionLock") // Do not access outside this lock.
    private Map<Locale, String> mModelFilePaths;
    @GuardedBy("mSmartSelectionLock") // Do not access outside this lock.
    private Locale mLocale;
    @GuardedBy("mSmartSelectionLock") // Do not access outside this lock.
    private SmartSelection mSmartSelection;

    TextClassifierImpl(Context context, ParcelFileDescriptor fd) {
    TextClassifierImpl(Context context) {
        mContext = Preconditions.checkNotNull(context);
        mFd = Preconditions.checkNotNull(fd);
    }

    @Override
@@ -81,15 +93,16 @@ final class TextClassifierImpl implements TextClassifier {
        validateInput(text, selectionStartIndex, selectionEndIndex);
        try {
            if (text.length() > 0) {
                final SmartSelection smartSelection = getSmartSelection(defaultLocales);
                final String string = text.toString();
                final int[] startEnd = getSmartSelection()
                        .suggest(string, selectionStartIndex, selectionEndIndex);
                final int[] startEnd = smartSelection.suggest(
                        string, selectionStartIndex, selectionEndIndex);
                final int start = startEnd[0];
                final int end = startEnd[1];
                if (start >= 0 && end <= string.length() && start <= end) {
                    final TextSelection.Builder tsBuilder = new TextSelection.Builder(start, end);
                    final SmartSelection.ClassificationResult[] results =
                            getSmartSelection().classifyText(
                            smartSelection.classifyText(
                                    string, start, end,
                                    getHintFlags(string, start, end));
                    final int size = results.length;
@@ -120,7 +133,7 @@ final class TextClassifierImpl implements TextClassifier {
        try {
            if (text.length() > 0) {
                final String string = text.toString();
                SmartSelection.ClassificationResult[] results = getSmartSelection()
                SmartSelection.ClassificationResult[] results = getSmartSelection(defaultLocales)
                        .classifyText(string, startIndex, endIndex,
                                getHintFlags(string, startIndex, endIndex));
                if (results.length > 0) {
@@ -147,7 +160,7 @@ final class TextClassifierImpl implements TextClassifier {
        Preconditions.checkArgument(text != null);
        try {
            return LinksInfoFactory.create(
                    mContext, getSmartSelection(), text.toString(), linkMask);
                    mContext, getSmartSelection(defaultLocales), text.toString(), linkMask);
        } catch (Throwable t) {
            // Avoid throwing from this method. Log the error.
            Log.e(LOG_TAG, "Error getting links info.", t);
@@ -156,15 +169,69 @@ final class TextClassifierImpl implements TextClassifier {
        return TextClassifier.NO_OP.getLinks(text, linkMask, defaultLocales);
    }

    private SmartSelection getSmartSelection() throws FileNotFoundException {
    private SmartSelection getSmartSelection(LocaleList localeList) throws FileNotFoundException {
        synchronized (mSmartSelectionLock) {
            if (mSmartSelection == null) {
                mSmartSelection = new SmartSelection(mFd.getFd());
            localeList = localeList == null ? LocaleList.getEmptyLocaleList() : localeList;
            final Locale locale = findBestSupportedLocaleLocked(localeList);
            if (mSmartSelection == null || !Objects.equals(mLocale, locale)) {
                destroySmartSelectionIfExistsLocked();
                mSmartSelection = new SmartSelection(
                        ParcelFileDescriptor.open(
                                // findBestSupportedLocaleLocked should have initialized
                                // mModelFilePaths
                                new File(mModelFilePaths.get(locale)),
                                ParcelFileDescriptor.MODE_READ_ONLY)
                                .getFd());
                mLocale = locale;
            }
            return mSmartSelection;
        }
    }

    @GuardedBy("mSmartSelectionLock") // Do not call outside this lock.
    private void destroySmartSelectionIfExistsLocked() {
        if (mSmartSelection != null) {
            mSmartSelection.close();
            mSmartSelection = null;
        }
    }

    @GuardedBy("mSmartSelectionLock") // Do not call outside this lock.
    @Nullable
    private Locale findBestSupportedLocaleLocked(LocaleList localeList) {
        final List<Locale.LanguageRange> languageRangeList = Locale.LanguageRange.parse(
                new StringJoiner(",")
                        // Specified localeList takes priority over the system default
                        .add(localeList.toLanguageTags())
                        .add(LocaleList.getDefault().toLanguageTags())
                        .toString());
        return Locale.lookup(languageRangeList, loadModelFilePathsLocked().keySet());
    }

    @GuardedBy("mSmartSelectionLock") // Do not call outside this lock.
    private Map<Locale, String> loadModelFilePathsLocked() {
        if (mModelFilePaths == null) {
            final Map<Locale, String> modelFilePaths = new HashMap<>();
            final File modelsDir = new File(MODEL_DIR);
            if (modelsDir.exists() && modelsDir.isDirectory()) {
                final File[] models = modelsDir.listFiles();
                final Pattern modelFilenamePattern = Pattern.compile(MODEL_FILE_REGEX);
                final int size = models.length;
                for (int i = 0; i < size; i++) {
                    final File modelFile = models[i];
                    final Matcher matcher = modelFilenamePattern.matcher(modelFile.getName());
                    if (matcher.matches() && modelFile.isFile()) {
                        final String language = matcher.group(1);
                        final Locale locale = Locale.forLanguageTag(language);
                        modelFilePaths.put(locale, modelFile.getAbsolutePath());
                    }
                }
            }
            mModelFilePaths = modelFilePaths;
        }
        return mModelFilePaths;
    }

    private TextClassificationResult createClassificationResult(
            SmartSelection.ClassificationResult[] classifications, CharSequence text) {
        final TextClassificationResult.Builder builder = new TextClassificationResult.Builder()