Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit ee3a48ee authored by Abodunrinwa Toki's avatar Abodunrinwa Toki
Browse files

Implement TextClassifierImpl.detectLanguage()

- Includes some fixes to handle null ParcelFileDescriptors.
- Closes fds immediately after the model has been loaded.

Bug: 116020587
Test: atest android.view.textclassifier.TextClassificationManagerTest
Change-Id: Ieb05d081847ac218d2a5b46db95cd512838f67ab
parent c2896a27
Loading
Loading
Loading
Loading
+131 −41
Original line number Diff line number Diff line
@@ -31,6 +31,7 @@ import android.content.Intent;
import android.content.pm.PackageManager;
import android.content.pm.ResolveInfo;
import android.graphics.drawable.Icon;
import android.icu.util.ULocale;
import android.net.Uri;
import android.os.Bundle;
import android.os.LocaleList;
@@ -45,6 +46,7 @@ import com.android.internal.util.IndentingPrintWriter;
import com.android.internal.util.Preconditions;

import com.google.android.textclassifier.AnnotatorModel;
import com.google.android.textclassifier.LangIdModel;

import java.io.File;
import java.io.FileNotFoundException;
@@ -83,6 +85,9 @@ public final class TextClassifierImpl implements TextClassifier {
    private static final String MODEL_FILE_REGEX = "textclassifier\\.(.*)\\.model";
    private static final String UPDATED_MODEL_FILE_PATH =
            "/data/misc/textclassifier/textclassifier.model";
    private static final String LANG_ID_MODEL_FILE_PATH = "/etc/textclassifier/lang_id.model";
    private static final String UPDATED_LANG_ID_MODEL_FILE_PATH =
            "/data/misc/textclassifier/lang_id.model";

    private final Context mContext;
    private final TextClassifier mFallback;
@@ -94,7 +99,9 @@ public final class TextClassifierImpl implements TextClassifier {
    @GuardedBy("mLock") // Do not access outside this lock.
    private ModelFile mModel;
    @GuardedBy("mLock") // Do not access outside this lock.
    private AnnotatorModel mNative;
    private AnnotatorModel mAnnotatorImpl;
    @GuardedBy("mLock") // Do not access outside this lock.
    private LangIdModel mLangIdImpl;

    private final Object mLoggerLock = new Object();
    @GuardedBy("mLoggerLock") // Do not access outside this lock.
@@ -127,14 +134,15 @@ public final class TextClassifierImpl implements TextClassifier {
                    && rangeLength <= mSettings.getSuggestSelectionMaxRangeLength()) {
                final String localesString = concatenateLocales(request.getDefaultLocales());
                final ZonedDateTime refTime = ZonedDateTime.now();
                final AnnotatorModel nativeImpl = getNative(request.getDefaultLocales());
                final AnnotatorModel annotatorImpl =
                        getAnnotatorImpl(request.getDefaultLocales());
                final int start;
                final int end;
                if (mSettings.isModelDarkLaunchEnabled() && !request.isDarkLaunchAllowed()) {
                    start = request.getStartIndex();
                    end = request.getEndIndex();
                } else {
                    final int[] startEnd = nativeImpl.suggestSelection(
                    final int[] startEnd = annotatorImpl.suggestSelection(
                            string, request.getStartIndex(), request.getEndIndex(),
                            new AnnotatorModel.SelectionOptions(localesString));
                    start = startEnd[0];
@@ -145,7 +153,7 @@ public final class TextClassifierImpl implements TextClassifier {
                        && start <= request.getStartIndex() && end >= request.getEndIndex()) {
                    final TextSelection.Builder tsBuilder = new TextSelection.Builder(start, end);
                    final AnnotatorModel.ClassificationResult[] results =
                            nativeImpl.classifyText(
                            annotatorImpl.classifyText(
                                    string, start, end,
                                    new AnnotatorModel.ClassificationOptions(
                                            refTime.toInstant().toEpochMilli(),
@@ -187,7 +195,7 @@ public final class TextClassifierImpl implements TextClassifier {
                final ZonedDateTime refTime = request.getReferenceTime() != null
                        ? request.getReferenceTime() : ZonedDateTime.now();
                final AnnotatorModel.ClassificationResult[] results =
                        getNative(request.getDefaultLocales())
                        getAnnotatorImpl(request.getDefaultLocales())
                                .classifyText(
                                        string, request.getStartIndex(), request.getEndIndex(),
                                        new AnnotatorModel.ClassificationOptions(
@@ -230,10 +238,10 @@ public final class TextClassifierImpl implements TextClassifier {
                    ? request.getEntityConfig().resolveEntityListModifications(
                            getEntitiesForHints(request.getEntityConfig().getHints()))
                    : mSettings.getEntityListDefault();
            final AnnotatorModel nativeImpl =
                    getNative(request.getDefaultLocales());
            final AnnotatorModel annotatorImpl =
                    getAnnotatorImpl(request.getDefaultLocales());
            final AnnotatorModel.AnnotatedSpan[] annotations =
                    nativeImpl.annotate(
                    annotatorImpl.annotate(
                        textString,
                        new AnnotatorModel.AnnotationOptions(
                                refTime.toInstant().toEpochMilli(),
@@ -288,6 +296,7 @@ public final class TextClassifierImpl implements TextClassifier {
        }
    }

    /** @inheritDoc */
    @Override
    public void onSelectionEvent(SelectionEvent event) {
        Preconditions.checkNotNull(event);
@@ -299,7 +308,29 @@ public final class TextClassifierImpl implements TextClassifier {
        }
    }

    private AnnotatorModel getNative(LocaleList localeList)
    /** @inheritDoc */
    @Override
    public TextLanguage detectLanguage(@NonNull TextLanguage.Request request) {
        Preconditions.checkNotNull(request);
        Utils.checkMainThread();
        try {
            final TextLanguage.Builder builder = new TextLanguage.Builder();
            final LangIdModel.LanguageResult[] langResults =
                    getLangIdImpl().detectLanguages(request.getText().toString());
            for (int i = 0; i < langResults.length; i++) {
                builder.putLocale(
                        ULocale.forLanguageTag(langResults[i].getLanguage()),
                        langResults[i].getScore());
            }
            return builder.build();
        } catch (Throwable t) {
            // Avoid throwing from this method. Log the error.
            Log.e(LOG_TAG, "Error detecting text language.", t);
        }
        return mFallback.detectLanguage(request);
    }

    private AnnotatorModel getAnnotatorImpl(LocaleList localeList)
            throws FileNotFoundException {
        synchronized (mLock) {
            localeList = localeList == null ? LocaleList.getEmptyLocaleList() : localeList;
@@ -307,31 +338,79 @@ public final class TextClassifierImpl implements TextClassifier {
            if (bestModel == null) {
                throw new FileNotFoundException("No model for " + localeList.toLanguageTags());
            }
            if (mNative == null || !Objects.equals(mModel, bestModel)) {
            if (mAnnotatorImpl == null || !Objects.equals(mModel, bestModel)) {
                Log.d(DEFAULT_LOG_TAG, "Loading " + bestModel);
                destroyNativeIfExistsLocked();
                destroyAnnotatorImplIfExistsLocked();
                final ParcelFileDescriptor fd = ParcelFileDescriptor.open(
                        new File(bestModel.getPath()), ParcelFileDescriptor.MODE_READ_ONLY);
                mNative = new AnnotatorModel(fd.getFd());
                closeAndLogError(fd);
                try {
                    if (fd != null) {
                        mAnnotatorImpl = new AnnotatorModel(fd.getFd());
                        mModel = bestModel;
                    }
            return mNative;
                } finally {
                    maybeCloseAndLogError(fd);
                }
            }
            return mAnnotatorImpl;
        }
    }

    private String createId(String text, int start, int end) {
    @GuardedBy("mLock") // Do not call outside this lock.
    private void destroyAnnotatorImplIfExistsLocked() {
        if (mAnnotatorImpl != null) {
            mAnnotatorImpl.close();
            mAnnotatorImpl = null;
        }
    }

    private LangIdModel getLangIdImpl() throws FileNotFoundException {
        synchronized (mLock) {
            return SelectionSessionLogger.createId(text, start, end, mContext, mModel.getVersion(),
                    mModel.getSupportedLocales());
            if (mLangIdImpl == null) {
                ParcelFileDescriptor factoryFd = null;
                ParcelFileDescriptor updateFd = null;
                try {
                    int factoryVersion = -1;
                    int updateVersion = factoryVersion;
                    final File factoryFile = new File(LANG_ID_MODEL_FILE_PATH);
                    if (factoryFile.exists()) {
                        factoryFd = ParcelFileDescriptor.open(
                                factoryFile, ParcelFileDescriptor.MODE_READ_ONLY);
                        // TODO: Uncomment when method is implemented:
                        // if (factoryFd != null) {
                        //     factoryVersion = LangIdModel.getVersion(factoryFd.getFd());
                        // }
                    }
                    final File updateFile = new File(UPDATED_LANG_ID_MODEL_FILE_PATH);
                    if (updateFile.exists()) {
                        updateFd = ParcelFileDescriptor.open(
                                updateFile, ParcelFileDescriptor.MODE_READ_ONLY);
                        // TODO: Uncomment when method is implemented:
                        // if (updateFd != null) {
                        //     updateVersion = LangIdModel.getVersion(updateFd.getFd());
                        // }
                    }

                    if (updateVersion > factoryVersion) {
                        mLangIdImpl = new LangIdModel(updateFd.getFd());
                    } else if (factoryFd != null) {
                        mLangIdImpl = new LangIdModel(factoryFd.getFd());
                    } else {
                        throw new FileNotFoundException("Language detection model not found");
                    }
                } finally {
                    maybeCloseAndLogError(factoryFd);
                    maybeCloseAndLogError(updateFd);
                }
            }
            return mLangIdImpl;
        }
    }

    @GuardedBy("mLock") // Do not call outside this lock.
    private void destroyNativeIfExistsLocked() {
        if (mNative != null) {
            mNative.close();
            mNative = null;
    private String createId(String text, int start, int end) {
        synchronized (mLock) {
            return SelectionSessionLogger.createId(text, start, end, mContext, mModel.getVersion(),
                    mModel.getSupportedLocales());
        }
    }

@@ -407,20 +486,19 @@ public final class TextClassifierImpl implements TextClassifier {
                .setText(classifiedText);

        final int size = classifications.length;
        AnnotatorModel.ClassificationResult highestScoringResult = null;
        float highestScore = Float.MIN_VALUE;
        AnnotatorModel.ClassificationResult highestScoringResult =
                size > 0 ? classifications[0] : null;
        for (int i = 0; i < size; i++) {
            builder.setEntityType(classifications[i].getCollection(),
                                  classifications[i].getScore());
            if (classifications[i].getScore() > highestScore) {
            if (classifications[i].getScore() > highestScoringResult.getScore()) {
                highestScoringResult = classifications[i];
                highestScore = classifications[i].getScore();
            }
        }

        boolean isPrimaryAction = true;
        for (LabeledIntent labeledIntent : IntentFactory.create(
                mContext, referenceTime, highestScoringResult, classifiedText)) {
                mContext, classifiedText, referenceTime, highestScoringResult)) {
            final RemoteAction action = labeledIntent.asRemoteAction(mContext);
            if (action == null) {
                continue;
@@ -461,9 +539,13 @@ public final class TextClassifierImpl implements TextClassifier {
    }

    /**
     * Closes the ParcelFileDescriptor and logs any errors that occur.
     * Closes the ParcelFileDescriptor, if non-null, and logs any errors that occur.
     */
    private static void closeAndLogError(ParcelFileDescriptor fd) {
    private static void maybeCloseAndLogError(@Nullable ParcelFileDescriptor fd) {
        if (fd == null) {
            return;
        }

        try {
            fd.close();
        } catch (IOException e) {
@@ -485,12 +567,17 @@ public final class TextClassifierImpl implements TextClassifier {
        /** Returns null if the path did not point to a compatible model. */
        static @Nullable ModelFile fromPath(String path) {
            final File file = new File(path);
            if (!file.exists()) {
                return null;
            }
            ParcelFileDescriptor modelFd = null;
            try {
                final ParcelFileDescriptor modelFd = ParcelFileDescriptor.open(
                        file, ParcelFileDescriptor.MODE_READ_ONLY);
                modelFd = ParcelFileDescriptor.open(file, ParcelFileDescriptor.MODE_READ_ONLY);
                if (modelFd == null) {
                    return null;
                }
                final int version = AnnotatorModel.getVersion(modelFd.getFd());
                final String supportedLocalesStr =
                        AnnotatorModel.getLocales(modelFd.getFd());
                final String supportedLocalesStr = AnnotatorModel.getLocales(modelFd.getFd());
                if (supportedLocalesStr.isEmpty()) {
                    Log.d(DEFAULT_LOG_TAG, "Ignoring " + file.getAbsolutePath());
                    return null;
@@ -500,12 +587,13 @@ public final class TextClassifierImpl implements TextClassifier {
                for (String langTag : supportedLocalesStr.split(",")) {
                    supportedLocales.add(Locale.forLanguageTag(langTag));
                }
                closeAndLogError(modelFd);
                return new ModelFile(path, file.getName(), version, supportedLocales,
                                     languageIndependent);
            } catch (FileNotFoundException e) {
                Log.e(DEFAULT_LOG_TAG, "Failed to peek " + file.getAbsolutePath(), e);
                return null;
            } finally {
                maybeCloseAndLogError(modelFd);
            }
        }

@@ -557,12 +645,12 @@ public final class TextClassifierImpl implements TextClassifier {
        public boolean equals(Object other) {
            if (this == other) {
                return true;
            } else if (other == null || !ModelFile.class.isAssignableFrom(other.getClass())) {
                return false;
            } else {
            }
            if (other instanceof ModelFile) {
                final ModelFile otherModel = (ModelFile) other;
                return mPath.equals(otherModel.mPath);
            }
            return false;
        }

        @Override
@@ -677,10 +765,12 @@ public final class TextClassifierImpl implements TextClassifier {
        @NonNull
        public static List<LabeledIntent> create(
                Context context,
                String text,
                @Nullable Instant referenceTime,
                AnnotatorModel.ClassificationResult classification,
                String text) {
            final String type = classification.getCollection().trim().toLowerCase(Locale.ENGLISH);
                @Nullable AnnotatorModel.ClassificationResult classification) {
            final String type = classification != null
                    ? classification.getCollection().trim().toLowerCase(Locale.ENGLISH)
                    : null;
            text = text.trim();
            switch (type) {
                case TextClassifier.TYPE_EMAIL:
+37 −0
Original line number Diff line number Diff line
@@ -306,6 +306,24 @@ public class TextClassificationManagerTest {
        mClassifier.generateLinks(request);
    }

    @Test
    public void testDetectLanguage() {
        if (isTextClassifierDisabled()) return;
        String text = "This is English text";
        TextLanguage.Request request = new TextLanguage.Request.Builder(text).build();
        TextLanguage textLanguage = mClassifier.detectLanguage(request);
        assertThat(textLanguage, isTextLanguage("en"));
    }

    @Test
    public void testDetectLanguage_japanese() {
        if (isTextClassifierDisabled()) return;
        String text = "これは日本語のテキストです";
        TextLanguage.Request request = new TextLanguage.Request.Builder(text).build();
        TextLanguage textLanguage = mClassifier.detectLanguage(request);
        assertThat(textLanguage, isTextLanguage("ja"));
    }

    @Test
    public void testSetTextClassifier() {
        TextClassifier classifier = mock(TextClassifier.class);
@@ -444,4 +462,23 @@ public class TextClassificationManagerTest {
            }
        };
    }

    private static Matcher<TextLanguage> isTextLanguage(final String languageTag) {
        return new BaseMatcher<TextLanguage>() {
            @Override
            public boolean matches(Object o) {
                if (o instanceof TextLanguage) {
                    TextLanguage result = (TextLanguage) o;
                    return result.getLocaleHypothesisCount() > 0
                            && languageTag.equals(result.getLocale(0).toLanguageTag());
                }
                return false;
            }

            @Override
            public void describeTo(Description description) {
                description.appendText("locale=").appendValue(languageTag);
            }
        };
    }
}