Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit bd70ed4c authored by TreeHugger Robot's avatar TreeHugger Robot Committed by Android (Google) Code Review
Browse files

Merge "Implement TextClassifierImpl.detectLanguage()"

parents e367c38d ee3a48ee
Loading
Loading
Loading
Loading
+131 −41
Original line number Diff line number Diff line
@@ -31,6 +31,7 @@ import android.content.Intent;
import android.content.pm.PackageManager;
import android.content.pm.ResolveInfo;
import android.graphics.drawable.Icon;
import android.icu.util.ULocale;
import android.net.Uri;
import android.os.Bundle;
import android.os.LocaleList;
@@ -45,6 +46,7 @@ import com.android.internal.util.IndentingPrintWriter;
import com.android.internal.util.Preconditions;

import com.google.android.textclassifier.AnnotatorModel;
import com.google.android.textclassifier.LangIdModel;

import java.io.File;
import java.io.FileNotFoundException;
@@ -83,6 +85,9 @@ public final class TextClassifierImpl implements TextClassifier {
    private static final String MODEL_FILE_REGEX = "textclassifier\\.(.*)\\.model";
    private static final String UPDATED_MODEL_FILE_PATH =
            "/data/misc/textclassifier/textclassifier.model";
    private static final String LANG_ID_MODEL_FILE_PATH = "/etc/textclassifier/lang_id.model";
    private static final String UPDATED_LANG_ID_MODEL_FILE_PATH =
            "/data/misc/textclassifier/lang_id.model";

    private final Context mContext;
    private final TextClassifier mFallback;
@@ -94,7 +99,9 @@ public final class TextClassifierImpl implements TextClassifier {
    @GuardedBy("mLock") // Do not access outside this lock.
    private ModelFile mModel;
    @GuardedBy("mLock") // Do not access outside this lock.
    private AnnotatorModel mNative;
    private AnnotatorModel mAnnotatorImpl;
    @GuardedBy("mLock") // Do not access outside this lock.
    private LangIdModel mLangIdImpl;

    private final Object mLoggerLock = new Object();
    @GuardedBy("mLoggerLock") // Do not access outside this lock.
@@ -127,14 +134,15 @@ public final class TextClassifierImpl implements TextClassifier {
                    && rangeLength <= mSettings.getSuggestSelectionMaxRangeLength()) {
                final String localesString = concatenateLocales(request.getDefaultLocales());
                final ZonedDateTime refTime = ZonedDateTime.now();
                final AnnotatorModel nativeImpl = getNative(request.getDefaultLocales());
                final AnnotatorModel annotatorImpl =
                        getAnnotatorImpl(request.getDefaultLocales());
                final int start;
                final int end;
                if (mSettings.isModelDarkLaunchEnabled() && !request.isDarkLaunchAllowed()) {
                    start = request.getStartIndex();
                    end = request.getEndIndex();
                } else {
                    final int[] startEnd = nativeImpl.suggestSelection(
                    final int[] startEnd = annotatorImpl.suggestSelection(
                            string, request.getStartIndex(), request.getEndIndex(),
                            new AnnotatorModel.SelectionOptions(localesString));
                    start = startEnd[0];
@@ -145,7 +153,7 @@ public final class TextClassifierImpl implements TextClassifier {
                        && start <= request.getStartIndex() && end >= request.getEndIndex()) {
                    final TextSelection.Builder tsBuilder = new TextSelection.Builder(start, end);
                    final AnnotatorModel.ClassificationResult[] results =
                            nativeImpl.classifyText(
                            annotatorImpl.classifyText(
                                    string, start, end,
                                    new AnnotatorModel.ClassificationOptions(
                                            refTime.toInstant().toEpochMilli(),
@@ -187,7 +195,7 @@ public final class TextClassifierImpl implements TextClassifier {
                final ZonedDateTime refTime = request.getReferenceTime() != null
                        ? request.getReferenceTime() : ZonedDateTime.now();
                final AnnotatorModel.ClassificationResult[] results =
                        getNative(request.getDefaultLocales())
                        getAnnotatorImpl(request.getDefaultLocales())
                                .classifyText(
                                        string, request.getStartIndex(), request.getEndIndex(),
                                        new AnnotatorModel.ClassificationOptions(
@@ -230,10 +238,10 @@ public final class TextClassifierImpl implements TextClassifier {
                    ? request.getEntityConfig().resolveEntityListModifications(
                            getEntitiesForHints(request.getEntityConfig().getHints()))
                    : mSettings.getEntityListDefault();
            final AnnotatorModel nativeImpl =
                    getNative(request.getDefaultLocales());
            final AnnotatorModel annotatorImpl =
                    getAnnotatorImpl(request.getDefaultLocales());
            final AnnotatorModel.AnnotatedSpan[] annotations =
                    nativeImpl.annotate(
                    annotatorImpl.annotate(
                        textString,
                        new AnnotatorModel.AnnotationOptions(
                                refTime.toInstant().toEpochMilli(),
@@ -288,6 +296,7 @@ public final class TextClassifierImpl implements TextClassifier {
        }
    }

    /** @inheritDoc */
    @Override
    public void onSelectionEvent(SelectionEvent event) {
        Preconditions.checkNotNull(event);
@@ -299,7 +308,29 @@ public final class TextClassifierImpl implements TextClassifier {
        }
    }

    private AnnotatorModel getNative(LocaleList localeList)
    /** @inheritDoc */
    @Override
    public TextLanguage detectLanguage(@NonNull TextLanguage.Request request) {
        Preconditions.checkNotNull(request);
        Utils.checkMainThread();
        try {
            final TextLanguage.Builder builder = new TextLanguage.Builder();
            final LangIdModel.LanguageResult[] langResults =
                    getLangIdImpl().detectLanguages(request.getText().toString());
            for (int i = 0; i < langResults.length; i++) {
                builder.putLocale(
                        ULocale.forLanguageTag(langResults[i].getLanguage()),
                        langResults[i].getScore());
            }
            return builder.build();
        } catch (Throwable t) {
            // Avoid throwing from this method. Log the error.
            Log.e(LOG_TAG, "Error detecting text language.", t);
        }
        return mFallback.detectLanguage(request);
    }

    private AnnotatorModel getAnnotatorImpl(LocaleList localeList)
            throws FileNotFoundException {
        synchronized (mLock) {
            localeList = localeList == null ? LocaleList.getEmptyLocaleList() : localeList;
@@ -307,31 +338,79 @@ public final class TextClassifierImpl implements TextClassifier {
            if (bestModel == null) {
                throw new FileNotFoundException("No model for " + localeList.toLanguageTags());
            }
            if (mNative == null || !Objects.equals(mModel, bestModel)) {
            if (mAnnotatorImpl == null || !Objects.equals(mModel, bestModel)) {
                Log.d(DEFAULT_LOG_TAG, "Loading " + bestModel);
                destroyNativeIfExistsLocked();
                destroyAnnotatorImplIfExistsLocked();
                final ParcelFileDescriptor fd = ParcelFileDescriptor.open(
                        new File(bestModel.getPath()), ParcelFileDescriptor.MODE_READ_ONLY);
                mNative = new AnnotatorModel(fd.getFd());
                closeAndLogError(fd);
                try {
                    if (fd != null) {
                        mAnnotatorImpl = new AnnotatorModel(fd.getFd());
                        mModel = bestModel;
                    }
            return mNative;
                } finally {
                    maybeCloseAndLogError(fd);
                }
            }
            return mAnnotatorImpl;
        }
    }

    private String createId(String text, int start, int end) {
    @GuardedBy("mLock") // Do not call outside this lock.
    private void destroyAnnotatorImplIfExistsLocked() {
        if (mAnnotatorImpl != null) {
            mAnnotatorImpl.close();
            mAnnotatorImpl = null;
        }
    }

    private LangIdModel getLangIdImpl() throws FileNotFoundException {
        synchronized (mLock) {
            return SelectionSessionLogger.createId(text, start, end, mContext, mModel.getVersion(),
                    mModel.getSupportedLocales());
            if (mLangIdImpl == null) {
                ParcelFileDescriptor factoryFd = null;
                ParcelFileDescriptor updateFd = null;
                try {
                    int factoryVersion = -1;
                    int updateVersion = factoryVersion;
                    final File factoryFile = new File(LANG_ID_MODEL_FILE_PATH);
                    if (factoryFile.exists()) {
                        factoryFd = ParcelFileDescriptor.open(
                                factoryFile, ParcelFileDescriptor.MODE_READ_ONLY);
                        // TODO: Uncomment when method is implemented:
                        // if (factoryFd != null) {
                        //     factoryVersion = LangIdModel.getVersion(factoryFd.getFd());
                        // }
                    }
                    final File updateFile = new File(UPDATED_LANG_ID_MODEL_FILE_PATH);
                    if (updateFile.exists()) {
                        updateFd = ParcelFileDescriptor.open(
                                updateFile, ParcelFileDescriptor.MODE_READ_ONLY);
                        // TODO: Uncomment when method is implemented:
                        // if (updateFd != null) {
                        //     updateVersion = LangIdModel.getVersion(updateFd.getFd());
                        // }
                    }

                    if (updateVersion > factoryVersion) {
                        mLangIdImpl = new LangIdModel(updateFd.getFd());
                    } else if (factoryFd != null) {
                        mLangIdImpl = new LangIdModel(factoryFd.getFd());
                    } else {
                        throw new FileNotFoundException("Language detection model not found");
                    }
                } finally {
                    maybeCloseAndLogError(factoryFd);
                    maybeCloseAndLogError(updateFd);
                }
            }
            return mLangIdImpl;
        }
    }

    @GuardedBy("mLock") // Do not call outside this lock.
    private void destroyNativeIfExistsLocked() {
        if (mNative != null) {
            mNative.close();
            mNative = null;
    private String createId(String text, int start, int end) {
        synchronized (mLock) {
            return SelectionSessionLogger.createId(text, start, end, mContext, mModel.getVersion(),
                    mModel.getSupportedLocales());
        }
    }

@@ -407,20 +486,19 @@ public final class TextClassifierImpl implements TextClassifier {
                .setText(classifiedText);

        final int size = classifications.length;
        AnnotatorModel.ClassificationResult highestScoringResult = null;
        float highestScore = Float.MIN_VALUE;
        AnnotatorModel.ClassificationResult highestScoringResult =
                size > 0 ? classifications[0] : null;
        for (int i = 0; i < size; i++) {
            builder.setEntityType(classifications[i].getCollection(),
                                  classifications[i].getScore());
            if (classifications[i].getScore() > highestScore) {
            if (classifications[i].getScore() > highestScoringResult.getScore()) {
                highestScoringResult = classifications[i];
                highestScore = classifications[i].getScore();
            }
        }

        boolean isPrimaryAction = true;
        for (LabeledIntent labeledIntent : IntentFactory.create(
                mContext, referenceTime, highestScoringResult, classifiedText)) {
                mContext, classifiedText, referenceTime, highestScoringResult)) {
            final RemoteAction action = labeledIntent.asRemoteAction(mContext);
            if (action == null) {
                continue;
@@ -461,9 +539,13 @@ public final class TextClassifierImpl implements TextClassifier {
    }

    /**
     * Closes the ParcelFileDescriptor and logs any errors that occur.
     * Closes the ParcelFileDescriptor, if non-null, and logs any errors that occur.
     */
    private static void closeAndLogError(ParcelFileDescriptor fd) {
    private static void maybeCloseAndLogError(@Nullable ParcelFileDescriptor fd) {
        if (fd == null) {
            return;
        }

        try {
            fd.close();
        } catch (IOException e) {
@@ -485,12 +567,17 @@ public final class TextClassifierImpl implements TextClassifier {
        /** Returns null if the path did not point to a compatible model. */
        static @Nullable ModelFile fromPath(String path) {
            final File file = new File(path);
            if (!file.exists()) {
                return null;
            }
            ParcelFileDescriptor modelFd = null;
            try {
                final ParcelFileDescriptor modelFd = ParcelFileDescriptor.open(
                        file, ParcelFileDescriptor.MODE_READ_ONLY);
                modelFd = ParcelFileDescriptor.open(file, ParcelFileDescriptor.MODE_READ_ONLY);
                if (modelFd == null) {
                    return null;
                }
                final int version = AnnotatorModel.getVersion(modelFd.getFd());
                final String supportedLocalesStr =
                        AnnotatorModel.getLocales(modelFd.getFd());
                final String supportedLocalesStr = AnnotatorModel.getLocales(modelFd.getFd());
                if (supportedLocalesStr.isEmpty()) {
                    Log.d(DEFAULT_LOG_TAG, "Ignoring " + file.getAbsolutePath());
                    return null;
@@ -500,12 +587,13 @@ public final class TextClassifierImpl implements TextClassifier {
                for (String langTag : supportedLocalesStr.split(",")) {
                    supportedLocales.add(Locale.forLanguageTag(langTag));
                }
                closeAndLogError(modelFd);
                return new ModelFile(path, file.getName(), version, supportedLocales,
                                     languageIndependent);
            } catch (FileNotFoundException e) {
                Log.e(DEFAULT_LOG_TAG, "Failed to peek " + file.getAbsolutePath(), e);
                return null;
            } finally {
                maybeCloseAndLogError(modelFd);
            }
        }

@@ -557,12 +645,12 @@ public final class TextClassifierImpl implements TextClassifier {
        public boolean equals(Object other) {
            if (this == other) {
                return true;
            } else if (other == null || !ModelFile.class.isAssignableFrom(other.getClass())) {
                return false;
            } else {
            }
            if (other instanceof ModelFile) {
                final ModelFile otherModel = (ModelFile) other;
                return mPath.equals(otherModel.mPath);
            }
            return false;
        }

        @Override
@@ -677,10 +765,12 @@ public final class TextClassifierImpl implements TextClassifier {
        @NonNull
        public static List<LabeledIntent> create(
                Context context,
                String text,
                @Nullable Instant referenceTime,
                AnnotatorModel.ClassificationResult classification,
                String text) {
            final String type = classification.getCollection().trim().toLowerCase(Locale.ENGLISH);
                @Nullable AnnotatorModel.ClassificationResult classification) {
            final String type = classification != null
                    ? classification.getCollection().trim().toLowerCase(Locale.ENGLISH)
                    : null;
            text = text.trim();
            switch (type) {
                case TextClassifier.TYPE_EMAIL:
+37 −0
Original line number Diff line number Diff line
@@ -306,6 +306,24 @@ public class TextClassificationManagerTest {
        mClassifier.generateLinks(request);
    }

    @Test
    public void testDetectLanguage() {
        if (isTextClassifierDisabled()) return;
        String text = "This is English text";
        TextLanguage.Request request = new TextLanguage.Request.Builder(text).build();
        TextLanguage textLanguage = mClassifier.detectLanguage(request);
        assertThat(textLanguage, isTextLanguage("en"));
    }

    @Test
    public void testDetectLanguage_japanese() {
        if (isTextClassifierDisabled()) return;
        String text = "これは日本語のテキストです";
        TextLanguage.Request request = new TextLanguage.Request.Builder(text).build();
        TextLanguage textLanguage = mClassifier.detectLanguage(request);
        assertThat(textLanguage, isTextLanguage("ja"));
    }

    @Test
    public void testSetTextClassifier() {
        TextClassifier classifier = mock(TextClassifier.class);
@@ -444,4 +462,23 @@ public class TextClassificationManagerTest {
            }
        };
    }

    private static Matcher<TextLanguage> isTextLanguage(final String languageTag) {
        return new BaseMatcher<TextLanguage>() {
            @Override
            public boolean matches(Object o) {
                if (o instanceof TextLanguage) {
                    TextLanguage result = (TextLanguage) o;
                    return result.getLocaleHypothesisCount() > 0
                            && languageTag.equals(result.getLocale(0).toLanguageTag());
                }
                return false;
            }

            @Override
            public void describeTo(Description description) {
                description.appendText("locale=").appendValue(languageTag);
            }
        };
    }
}