Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit f8c36bff authored by Lukas Zilka's avatar Lukas Zilka
Browse files

Updates the name of the native library wrapper class, adds options and removes hints.

Test: Built, tested on device, CTS passes.

bit FrameworksCoreTests:android.view.textclassifier.TextClassificationManagerTest

Change-Id: I7c70427d28bec8218935ed45a39819b2ece8112a
parent bf653394
Loading
Loading
Loading
Loading
+0 −194
Original line number Diff line number Diff line
/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package android.view.textclassifier;

import android.annotation.Nullable;
import android.content.res.AssetFileDescriptor;

/**
 *  Java wrapper for SmartSelection native library interface.
 *  This library is used for detecting entities in text.
 */
final class SmartSelection {

    static {
        System.loadLibrary("textclassifier");
    }

    /** Hints the classifier that this may be a url. */
    static final int HINT_FLAG_URL = 0x01;
    /** Hints the classifier that this may be an email. */
    static final int HINT_FLAG_EMAIL = 0x02;

    private final long mCtx;

    /**
     * Creates a new instance of SmartSelect predictor, using the provided model image,
     * given as a file descriptor.
     */
    SmartSelection(int fd) {
        mCtx = nativeNew(fd);
    }

    /**
     * Creates a new instance of SmartSelect predictor, using the provided model image, given as a
     * file path.
     */
    SmartSelection(String path) {
        mCtx = nativeNewFromPath(path);
    }

    /**
     * Creates a new instance of SmartSelect predictor, using the provided model image, given as an
     * AssetFileDescriptor.
     */
    SmartSelection(AssetFileDescriptor afd) {
        mCtx = nativeNewFromAssetFileDescriptor(afd, afd.getStartOffset(), afd.getLength());
        if (mCtx == 0L) {
            throw new IllegalArgumentException(
                "Couldn't initialize TC from given AssetFileDescriptor");
        }
    }

    /**
     * Given a string context and current selection, computes the SmartSelection suggestion.
     *
     * The begin and end are character indices into the context UTF8 string. selectionBegin is the
     * character index where the selection begins, and selectionEnd is the index of one character
     * past the selection span.
     *
     * The return value is an array of two ints: suggested selection beginning and end, with the
     * same semantics as the input selectionBeginning and selectionEnd.
     */
    public int[] suggest(String context, int selectionBegin, int selectionEnd) {
        return nativeSuggest(mCtx, context, selectionBegin, selectionEnd);
    }

    /**
     * Given a string context and current selection, classifies the type of the selected text.
     *
     * The begin and end params are character indices in the context string.
     *
     * Returns an array of ClassificationResult objects with the probability
     * scores for different collections.
     */
    public ClassificationResult[] classifyText(
            String context, int selectionBegin, int selectionEnd, int hintFlags) {
        return nativeClassifyText(mCtx, context, selectionBegin, selectionEnd, hintFlags);
    }

    /**
     * Annotates given input text. Every word of the input is a part of some annotation.
     * The annotations are sorted by their position in the context string.
     * The annotations do not overlap.
     */
    public AnnotatedSpan[] annotate(String text) {
        return nativeAnnotate(mCtx, text);
    }

    /**
     * Frees up the allocated memory.
     */
    public void close() {
        nativeClose(mCtx);
    }

    /**
     * Returns a comma separated list of locales supported by the model as BCP 47 tags.
     */
    public static String getLanguages(int fd) {
        return nativeGetLanguage(fd);
    }

    /**
     * Returns the version of the model.
     */
    public static int getVersion(int fd) {
        return nativeGetVersion(fd);
    }

    private static native long nativeNew(int fd);

    private static native long nativeNewFromPath(String path);

    private static native long nativeNewFromAssetFileDescriptor(AssetFileDescriptor afd,
                                                                long offset, long size);

    private static native int[] nativeSuggest(
            long context, String text, int selectionBegin, int selectionEnd);

    private static native ClassificationResult[] nativeClassifyText(
            long context, String text, int selectionBegin, int selectionEnd, int hintFlags);

    private static native AnnotatedSpan[] nativeAnnotate(long context, String text);

    private static native void nativeClose(long context);

    private static native String nativeGetLanguage(int fd);

    private static native int nativeGetVersion(int fd);

    /** Classification result for classifyText method. */
    static final class ClassificationResult {
        final String mCollection;
        /** float range: 0 - 1 */
        final float mScore;
        @Nullable final DatetimeParseResult mDatetime;

        ClassificationResult(String collection, float score) {
            mCollection = collection;
            mScore = score;
            mDatetime = null;
        }

        ClassificationResult(String collection, float score, DatetimeParseResult datetime) {
            mCollection = collection;
            mScore = score;
            mDatetime = datetime;
        }
    }

    /** Parsed date information for the classification result. */
    static final class DatetimeParseResult {
        long mMsSinceEpoch;
    }

    /** Represents a result of Annotate call. */
    public static final class AnnotatedSpan {
        final int mStartIndex;
        final int mEndIndex;
        final ClassificationResult[] mClassification;

        AnnotatedSpan(int startIndex, int endIndex, ClassificationResult[] classification) {
            mStartIndex = startIndex;
            mEndIndex = endIndex;
            mClassification = classification;
        }

        public int getStartIndex() {
            return mStartIndex;
        }

        public int getEndIndex() {
            return mEndIndex;
        }

        public ClassificationResult[] getClassification() {
            return mClassification;
        }
    }
}
+69 −53
Original line number Diff line number Diff line
@@ -35,8 +35,6 @@ import android.provider.Browser;
import android.provider.CalendarContract;
import android.provider.ContactsContract;
import android.provider.Settings;
import android.text.util.Linkify;
import android.util.Patterns;
import android.view.textclassifier.logging.DefaultLogger;
import android.view.textclassifier.logging.Logger;

@@ -102,7 +100,7 @@ public final class TextClassifierImpl implements TextClassifier {
    @GuardedBy("mLock") // Do not access outside this lock.
    private ModelFile mModel;
    @GuardedBy("mLock") // Do not access outside this lock.
    private SmartSelection mSmartSelection;
    private TextClassifierImplNative mNative;

    private final Object mLoggerLock = new Object();
    @GuardedBy("mLoggerLock") // Do not access outside this lock.
@@ -128,8 +126,10 @@ public final class TextClassifierImpl implements TextClassifier {
            if (text.length() > 0
                    && rangeLength <= getSettings().getSuggestSelectionMaxRangeLength()) {
                final LocaleList locales = (options == null) ? null : options.getDefaultLocales();
                final String localesString = concatenateLocales(locales);
                final Calendar refTime = Calendar.getInstance();
                final boolean darkLaunchAllowed = options != null && options.isDarkLaunchAllowed();
                final SmartSelection smartSelection = getSmartSelection(locales);
                final TextClassifierImplNative nativeImpl = getNative(locales);
                final String string = text.toString();
                final int start;
                final int end;
@@ -137,8 +137,9 @@ public final class TextClassifierImpl implements TextClassifier {
                    start = selectionStartIndex;
                    end = selectionEndIndex;
                } else {
                    final int[] startEnd = smartSelection.suggest(
                            string, selectionStartIndex, selectionEndIndex);
                    final int[] startEnd = nativeImpl.suggestSelection(
                            string, selectionStartIndex, selectionEndIndex,
                            new TextClassifierImplNative.SelectionOptions(localesString));
                    start = startEnd[0];
                    end = startEnd[1];
                }
@@ -146,13 +147,16 @@ public final class TextClassifierImpl implements TextClassifier {
                        && start >= 0 && end <= string.length()
                        && start <= selectionStartIndex && end >= selectionEndIndex) {
                    final TextSelection.Builder tsBuilder = new TextSelection.Builder(start, end);
                    final SmartSelection.ClassificationResult[] results =
                            smartSelection.classifyText(
                    final TextClassifierImplNative.ClassificationResult[] results =
                            nativeImpl.classifyText(
                                    string, start, end,
                                    getHintFlags(string, start, end));
                                    new TextClassifierImplNative.ClassificationOptions(
                                            refTime.getTimeInMillis(),
                                            refTime.getTimeZone().getID(),
                                            localesString));
                    final int size = results.length;
                    for (int i = 0; i < size; i++) {
                        tsBuilder.setEntityType(results[i].mCollection, results[i].mScore);
                        tsBuilder.setEntityType(results[i].getCollection(), results[i].getScore());
                    }
                    return tsBuilder
                            .setSignature(
@@ -185,10 +189,17 @@ public final class TextClassifierImpl implements TextClassifier {
            if (text.length() > 0 && rangeLength <= getSettings().getClassifyTextMaxRangeLength()) {
                final String string = text.toString();
                final LocaleList locales = (options == null) ? null : options.getDefaultLocales();
                final Calendar refTime = (options == null) ? null : options.getReferenceTime();
                final SmartSelection.ClassificationResult[] results = getSmartSelection(locales)
                final String localesString = concatenateLocales(locales);
                final Calendar refTime = (options != null && options.getReferenceTime() != null)
                        ? options.getReferenceTime() : Calendar.getInstance();

                final TextClassifierImplNative.ClassificationResult[] results =
                        getNative(locales)
                                .classifyText(string, startIndex, endIndex,
                                getHintFlags(string, startIndex, endIndex));
                                        new TextClassifierImplNative.ClassificationOptions(
                                                refTime.getTimeInMillis(),
                                                refTime.getTimeZone().getID(),
                                                localesString));
                if (results.length > 0) {
                    return createClassificationResult(
                            results, string, startIndex, endIndex, refTime);
@@ -216,21 +227,31 @@ public final class TextClassifierImpl implements TextClassifier {

        try {
            final LocaleList defaultLocales = options != null ? options.getDefaultLocales() : null;
            final Calendar refTime = Calendar.getInstance();
            final Collection<String> entitiesToIdentify =
                    options != null && options.getEntityConfig() != null
                            ? options.getEntityConfig().resolveEntityListModifications(
                                    getEntitiesForHints(options.getEntityConfig().getHints()))
                            : ENTITY_TYPES_ALL;
            final SmartSelection smartSelection = getSmartSelection(defaultLocales);
            final SmartSelection.AnnotatedSpan[] annotations = smartSelection.annotate(textString);
            for (SmartSelection.AnnotatedSpan span : annotations) {
                final SmartSelection.ClassificationResult[] results = span.getClassification();
                if (results.length == 0 || !entitiesToIdentify.contains(results[0].mCollection)) {
            final TextClassifierImplNative nativeImpl =
                    getNative(defaultLocales);
            final TextClassifierImplNative.AnnotatedSpan[] annotations =
                    nativeImpl.annotate(
                        textString,
                        new TextClassifierImplNative.AnnotationOptions(
                                refTime.getTimeInMillis(),
                                refTime.getTimeZone().getID(),
                                concatenateLocales(defaultLocales)));
            for (TextClassifierImplNative.AnnotatedSpan span : annotations) {
                final TextClassifierImplNative.ClassificationResult[] results =
                        span.getClassification();
                if (results.length == 0
                        || !entitiesToIdentify.contains(results[0].getCollection())) {
                    continue;
                }
                final Map<String, Float> entityScores = new HashMap<>();
                for (int i = 0; i < results.length; i++) {
                    entityScores.put(results[i].mCollection, results[i].mScore);
                    entityScores.put(results[i].getCollection(), results[i].getScore());
                }
                builder.addLink(span.getStartIndex(), span.getEndIndex(), entityScores);
            }
@@ -274,23 +295,24 @@ public final class TextClassifierImpl implements TextClassifier {
        return mSettings;
    }

    private SmartSelection getSmartSelection(LocaleList localeList) throws FileNotFoundException {
    private TextClassifierImplNative getNative(LocaleList localeList)
            throws FileNotFoundException {
        synchronized (mLock) {
            localeList = localeList == null ? LocaleList.getEmptyLocaleList() : localeList;
            final ModelFile bestModel = findBestModelLocked(localeList);
            if (bestModel == null) {
                throw new FileNotFoundException("No model for " + localeList.toLanguageTags());
            }
            if (mSmartSelection == null || !Objects.equals(mModel, bestModel)) {
            if (mNative == null || !Objects.equals(mModel, bestModel)) {
                Log.d(DEFAULT_LOG_TAG, "Loading " + bestModel);
                destroySmartSelectionIfExistsLocked();
                destroyNativeIfExistsLocked();
                final ParcelFileDescriptor fd = ParcelFileDescriptor.open(
                        new File(bestModel.getPath()), ParcelFileDescriptor.MODE_READ_ONLY);
                mSmartSelection = new SmartSelection(fd.getFd());
                mNative = new TextClassifierImplNative(fd.getFd());
                closeAndLogError(fd);
                mModel = bestModel;
            }
            return mSmartSelection;
            return mNative;
        }
    }

@@ -302,11 +324,15 @@ public final class TextClassifierImpl implements TextClassifier {
    }

    @GuardedBy("mLock") // Do not call outside this lock.
    private void destroySmartSelectionIfExistsLocked() {
        if (mSmartSelection != null) {
            mSmartSelection.close();
            mSmartSelection = null;
    private void destroyNativeIfExistsLocked() {
        if (mNative != null) {
            mNative.close();
            mNative = null;
        }
    }

    private static String concatenateLocales(@Nullable LocaleList locales) {
        return (locales == null) ? "" : locales.toLanguageTags();
    }

    /**
@@ -372,20 +398,21 @@ public final class TextClassifierImpl implements TextClassifier {
    }

    private TextClassification createClassificationResult(
            SmartSelection.ClassificationResult[] classifications,
            TextClassifierImplNative.ClassificationResult[] classifications,
            String text, int start, int end, @Nullable Calendar referenceTime) {
        final String classifiedText = text.substring(start, end);
        final TextClassification.Builder builder = new TextClassification.Builder()
                .setText(classifiedText);

        final int size = classifications.length;
        SmartSelection.ClassificationResult highestScoringResult = null;
        TextClassifierImplNative.ClassificationResult highestScoringResult = null;
        float highestScore = Float.MIN_VALUE;
        for (int i = 0; i < size; i++) {
            builder.setEntityType(classifications[i].mCollection, classifications[i].mScore);
            if (classifications[i].mScore > highestScore) {
            builder.setEntityType(classifications[i].getCollection(),
                                  classifications[i].getScore());
            if (classifications[i].getScore() > highestScore) {
                highestScoringResult = classifications[i];
                highestScore = classifications[i].mScore;
                highestScore = classifications[i].getScore();
            }
        }

@@ -433,19 +460,6 @@ public final class TextClassifierImpl implements TextClassifier {
        }
    }

    private static int getHintFlags(CharSequence text, int start, int end) {
        int flag = 0;
        final CharSequence subText = text.subSequence(start, end);
        if (Patterns.AUTOLINK_EMAIL_ADDRESS.matcher(subText).matches()) {
            flag |= SmartSelection.HINT_FLAG_EMAIL;
        }
        if (Patterns.AUTOLINK_WEB_URL.matcher(subText).matches()
                && Linkify.sUrlMatchFilter.acceptMatch(text, start, end)) {
            flag |= SmartSelection.HINT_FLAG_URL;
        }
        return flag;
    }

    /**
     * Closes the ParcelFileDescriptor and logs any errors that occur.
     */
@@ -473,8 +487,9 @@ public final class TextClassifierImpl implements TextClassifier {
            try {
                final ParcelFileDescriptor modelFd = ParcelFileDescriptor.open(
                        file, ParcelFileDescriptor.MODE_READ_ONLY);
                final int version = SmartSelection.getVersion(modelFd.getFd());
                final String supportedLocalesStr = SmartSelection.getLanguages(modelFd.getFd());
                final int version = TextClassifierImplNative.getVersion(modelFd.getFd());
                final String supportedLocalesStr =
                        TextClassifierImplNative.getLocales(modelFd.getFd());
                if (supportedLocalesStr.isEmpty()) {
                    Log.d(DEFAULT_LOG_TAG, "Ignoring " + file.getAbsolutePath());
                    return null;
@@ -560,9 +575,9 @@ public final class TextClassifierImpl implements TextClassifier {
        public static List<Intent> create(
                Context context,
                @Nullable Calendar referenceTime,
                SmartSelection.ClassificationResult classification,
                TextClassifierImplNative.ClassificationResult classification,
                String text) {
            final String type = classification.mCollection.trim().toLowerCase(Locale.ENGLISH);
            final String type = classification.getCollection().trim().toLowerCase(Locale.ENGLISH);
            text = text.trim();
            switch (type) {
                case TextClassifier.TYPE_EMAIL:
@@ -575,9 +590,10 @@ public final class TextClassifierImpl implements TextClassifier {
                    return createForUrl(context, text);
                case TextClassifier.TYPE_DATE:
                case TextClassifier.TYPE_DATE_TIME:
                    if (classification.mDatetime != null) {
                    if (classification.getDatetimeResult() != null) {
                        Calendar eventTime = Calendar.getInstance();
                        eventTime.setTimeInMillis(classification.mDatetime.mMsSinceEpoch);
                        eventTime.setTimeInMillis(
                                classification.getDatetimeResult().getTimeMsUtc());
                        return createForDatetime(type, referenceTime, eventTime);
                    } else {
                        return new ArrayList<>();
+301 −0

File added.

Preview size limit exceeded, changes collapsed.

+39 −2
Original line number Diff line number Diff line
@@ -41,7 +41,7 @@ import java.util.Collections;
@RunWith(AndroidJUnit4.class)
public class TextClassificationManagerTest {

    private static final LocaleList LOCALES = LocaleList.forLanguageTags("en");
    private static final LocaleList LOCALES = LocaleList.forLanguageTags("en-US");
    private static final String NO_TYPE = null;

    private TextClassificationManager mTcm;
@@ -180,6 +180,42 @@ public class TextClassificationManagerTest {
                        "http://ANDROID.COM"));
    }

    @Test
    public void testTextClassifyText_date() {
        if (isTextClassifierDisabled()) return;

        String text = "Let's meet on January 9, 2018.";
        String classifiedText = "January 9, 2018";
        int startIndex = text.indexOf(classifiedText);
        int endIndex = startIndex + classifiedText.length();

        TextClassification classification = mClassifier.classifyText(
                text, startIndex, endIndex, mClassificationOptions);
        assertThat(classification,
                isTextClassification(
                        classifiedText,
                        TextClassifier.TYPE_DATE,
                        null));
    }

    @Test
    public void testTextClassifyText_datetime() {
        if (isTextClassifierDisabled()) return;

        String text = "Let's meet 2018/01/01 10:30:20.";
        String classifiedText = "2018/01/01 10:30:20";
        int startIndex = text.indexOf(classifiedText);
        int endIndex = startIndex + classifiedText.length();

        TextClassification classification = mClassifier.classifyText(
                text, startIndex, endIndex, mClassificationOptions);
        assertThat(classification,
                isTextClassification(
                        classifiedText,
                        TextClassifier.TYPE_DATE_TIME,
                        null));
    }

    @Test
    public void testGenerateLinks_phone() {
        if (isTextClassifierDisabled()) return;
@@ -334,7 +370,8 @@ public class TextClassificationManagerTest {
                            && text.equals(result.getText())
                            && result.getEntityCount() > 0
                            && type.equals(result.getEntity(0))
                            && intentUri.equals(result.getIntent().getDataString());
                            && (intentUri == null
                                || intentUri.equals(result.getIntent().getDataString()));
                    // TODO: Include other properties.
                }
                return false;