Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 06eb53ce authored by Nikita Iashchenko's avatar Nikita Iashchenko Committed by Gerrit Code Review
Browse files

Merge "Switch TextClassifier implementation from native to java"

parents 1eb5db6c 7ea2f83f
Loading
Loading
Loading
Loading
+4 −1
Original line number Diff line number Diff line
@@ -704,7 +704,6 @@ java_defaults {
    required: [
        // TODO: remove gps_debug when the build system propagates "required" properly.
        "gps_debug.conf",
        "libtextclassifier",
        // Loaded with System.loadLibrary by android.view.textclassifier
        "libmedia2_jni",
    ],
@@ -855,6 +854,10 @@ java_library {
        "nist-sip",
        "tagsoup",
        "rappor",
        "libtextclassifier-java",
    ],
    required: [
        "libtextclassifier",
    ],
    dxflags: ["--core-library"],
}
+21 −19
Original line number Diff line number Diff line
@@ -43,6 +43,8 @@ import android.provider.ContactsContract;
import com.android.internal.annotations.GuardedBy;
import com.android.internal.util.Preconditions;

import com.google.android.textclassifier.AnnotatorModel;

import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
@@ -91,7 +93,7 @@ public final class TextClassifierImpl implements TextClassifier {
    @GuardedBy("mLock") // Do not access outside this lock.
    private ModelFile mModel;
    @GuardedBy("mLock") // Do not access outside this lock.
    private TextClassifierImplNative mNative;
    private AnnotatorModel mNative;

    private final Object mLoggerLock = new Object();
    @GuardedBy("mLoggerLock") // Do not access outside this lock.
@@ -124,7 +126,7 @@ public final class TextClassifierImpl implements TextClassifier {
                    && rangeLength <= mSettings.getSuggestSelectionMaxRangeLength()) {
                final String localesString = concatenateLocales(request.getDefaultLocales());
                final ZonedDateTime refTime = ZonedDateTime.now();
                final TextClassifierImplNative nativeImpl = getNative(request.getDefaultLocales());
                final AnnotatorModel nativeImpl = getNative(request.getDefaultLocales());
                final int start;
                final int end;
                if (mSettings.isModelDarkLaunchEnabled() && !request.isDarkLaunchAllowed()) {
@@ -133,7 +135,7 @@ public final class TextClassifierImpl implements TextClassifier {
                } else {
                    final int[] startEnd = nativeImpl.suggestSelection(
                            string, request.getStartIndex(), request.getEndIndex(),
                            new TextClassifierImplNative.SelectionOptions(localesString));
                            new AnnotatorModel.SelectionOptions(localesString));
                    start = startEnd[0];
                    end = startEnd[1];
                }
@@ -141,10 +143,10 @@ public final class TextClassifierImpl implements TextClassifier {
                        && start >= 0 && end <= string.length()
                        && start <= request.getStartIndex() && end >= request.getEndIndex()) {
                    final TextSelection.Builder tsBuilder = new TextSelection.Builder(start, end);
                    final TextClassifierImplNative.ClassificationResult[] results =
                    final AnnotatorModel.ClassificationResult[] results =
                            nativeImpl.classifyText(
                                    string, start, end,
                                    new TextClassifierImplNative.ClassificationOptions(
                                    new AnnotatorModel.ClassificationOptions(
                                            refTime.toInstant().toEpochMilli(),
                                            refTime.getZone().getId(),
                                            localesString));
@@ -183,11 +185,11 @@ public final class TextClassifierImpl implements TextClassifier {
                final String localesString = concatenateLocales(request.getDefaultLocales());
                final ZonedDateTime refTime = request.getReferenceTime() != null
                        ? request.getReferenceTime() : ZonedDateTime.now();
                final TextClassifierImplNative.ClassificationResult[] results =
                final AnnotatorModel.ClassificationResult[] results =
                        getNative(request.getDefaultLocales())
                                .classifyText(
                                        string, request.getStartIndex(), request.getEndIndex(),
                                        new TextClassifierImplNative.ClassificationOptions(
                                        new AnnotatorModel.ClassificationOptions(
                                                refTime.toInstant().toEpochMilli(),
                                                refTime.getZone().getId(),
                                                localesString));
@@ -227,17 +229,17 @@ public final class TextClassifierImpl implements TextClassifier {
                    ? request.getEntityConfig().resolveEntityListModifications(
                            getEntitiesForHints(request.getEntityConfig().getHints()))
                    : mSettings.getEntityListDefault();
            final TextClassifierImplNative nativeImpl =
            final AnnotatorModel nativeImpl =
                    getNative(request.getDefaultLocales());
            final TextClassifierImplNative.AnnotatedSpan[] annotations =
            final AnnotatorModel.AnnotatedSpan[] annotations =
                    nativeImpl.annotate(
                        textString,
                        new TextClassifierImplNative.AnnotationOptions(
                        new AnnotatorModel.AnnotationOptions(
                                refTime.toInstant().toEpochMilli(),
                                        refTime.getZone().getId(),
                                concatenateLocales(request.getDefaultLocales())));
            for (TextClassifierImplNative.AnnotatedSpan span : annotations) {
                final TextClassifierImplNative.ClassificationResult[] results =
            for (AnnotatorModel.AnnotatedSpan span : annotations) {
                final AnnotatorModel.ClassificationResult[] results =
                        span.getClassification();
                if (results.length == 0
                        || !entitiesToIdentify.contains(results[0].getCollection())) {
@@ -296,7 +298,7 @@ public final class TextClassifierImpl implements TextClassifier {
        }
    }

    private TextClassifierImplNative getNative(LocaleList localeList)
    private AnnotatorModel getNative(LocaleList localeList)
            throws FileNotFoundException {
        synchronized (mLock) {
            localeList = localeList == null ? LocaleList.getEmptyLocaleList() : localeList;
@@ -309,7 +311,7 @@ public final class TextClassifierImpl implements TextClassifier {
                destroyNativeIfExistsLocked();
                final ParcelFileDescriptor fd = ParcelFileDescriptor.open(
                        new File(bestModel.getPath()), ParcelFileDescriptor.MODE_READ_ONLY);
                mNative = new TextClassifierImplNative(fd.getFd());
                mNative = new AnnotatorModel(fd.getFd());
                closeAndLogError(fd);
                mModel = bestModel;
            }
@@ -397,14 +399,14 @@ public final class TextClassifierImpl implements TextClassifier {
    }

    private TextClassification createClassificationResult(
            TextClassifierImplNative.ClassificationResult[] classifications,
            AnnotatorModel.ClassificationResult[] classifications,
            String text, int start, int end, @Nullable Instant referenceTime) {
        final String classifiedText = text.substring(start, end);
        final TextClassification.Builder builder = new TextClassification.Builder()
                .setText(classifiedText);

        final int size = classifications.length;
        TextClassifierImplNative.ClassificationResult highestScoringResult = null;
        AnnotatorModel.ClassificationResult highestScoringResult = null;
        float highestScore = Float.MIN_VALUE;
        for (int i = 0; i < size; i++) {
            builder.setEntityType(classifications[i].getCollection(),
@@ -467,9 +469,9 @@ public final class TextClassifierImpl implements TextClassifier {
            try {
                final ParcelFileDescriptor modelFd = ParcelFileDescriptor.open(
                        file, ParcelFileDescriptor.MODE_READ_ONLY);
                final int version = TextClassifierImplNative.getVersion(modelFd.getFd());
                final int version = AnnotatorModel.getVersion(modelFd.getFd());
                final String supportedLocalesStr =
                        TextClassifierImplNative.getLocales(modelFd.getFd());
                        AnnotatorModel.getLocales(modelFd.getFd());
                if (supportedLocalesStr.isEmpty()) {
                    Log.d(DEFAULT_LOG_TAG, "Ignoring " + file.getAbsolutePath());
                    return null;
@@ -657,7 +659,7 @@ public final class TextClassifierImpl implements TextClassifier {
        public static List<LabeledIntent> create(
                Context context,
                @Nullable Instant referenceTime,
                TextClassifierImplNative.ClassificationResult classification,
                AnnotatorModel.ClassificationResult classification,
                String text) {
            final String type = classification.getCollection().trim().toLowerCase(Locale.ENGLISH);
            text = text.trim();
+0 −301
Original line number Diff line number Diff line
/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package android.view.textclassifier;

import android.content.res.AssetFileDescriptor;

/**
 * Java wrapper for TextClassifier native library interface. This library is used for detecting
 * entities in text.
 */
final class TextClassifierImplNative {

    static {
        System.loadLibrary("textclassifier");
    }

    private final long mModelPtr;

    /**
     * Creates a new instance of TextClassifierImplNative, using the provided model image, given as
     * a file descriptor.
     */
    TextClassifierImplNative(int fd) {
        mModelPtr = nativeNew(fd);
        if (mModelPtr == 0L) {
            throw new IllegalArgumentException("Couldn't initialize TC from file descriptor.");
        }
    }

    /**
     * Creates a new instance of TextClassifierImplNative, using the provided model image, given as
     * a file path.
     */
    TextClassifierImplNative(String path) {
        mModelPtr = nativeNewFromPath(path);
        if (mModelPtr == 0L) {
            throw new IllegalArgumentException("Couldn't initialize TC from given file.");
        }
    }

    /**
     * Creates a new instance of TextClassifierImplNative, using the provided model image, given as
     * an AssetFileDescriptor.
     */
    TextClassifierImplNative(AssetFileDescriptor afd) {
        mModelPtr = nativeNewFromAssetFileDescriptor(afd, afd.getStartOffset(), afd.getLength());
        if (mModelPtr == 0L) {
            throw new IllegalArgumentException(
                    "Couldn't initialize TC from given AssetFileDescriptor");
        }
    }

    /**
     * Given a string context and current selection, computes the SmartSelection suggestion.
     *
     * <p>The begin and end are character indices into the context UTF8 string. selectionBegin is
     * the character index where the selection begins, and selectionEnd is the index of one
     * character past the selection span.
     *
     * <p>The return value is an array of two ints: suggested selection beginning and end, with the
     * same semantics as the input selectionBeginning and selectionEnd.
     */
    public int[] suggestSelection(
            String context, int selectionBegin, int selectionEnd, SelectionOptions options) {
        return nativeSuggestSelection(mModelPtr, context, selectionBegin, selectionEnd, options);
    }

    /**
     * Given a string context and current selection, classifies the type of the selected text.
     *
     * <p>The begin and end params are character indices in the context string.
     *
     * <p>Returns an array of ClassificationResult objects with the probability scores for different
     * collections.
     */
    public ClassificationResult[] classifyText(
            String context, int selectionBegin, int selectionEnd, ClassificationOptions options) {
        return nativeClassifyText(mModelPtr, context, selectionBegin, selectionEnd, options);
    }

    /**
     * Annotates given input text. The annotations should cover the whole input context except for
     * whitespaces, and are sorted by their position in the context string.
     */
    public AnnotatedSpan[] annotate(String text, AnnotationOptions options) {
        return nativeAnnotate(mModelPtr, text, options);
    }

    /** Frees up the allocated memory. */
    public void close() {
        nativeClose(mModelPtr);
    }

    /** Returns a comma separated list of locales supported by the model as BCP 47 tags. */
    public static String getLocales(int fd) {
        return nativeGetLocales(fd);
    }

    /** Returns the version of the model. */
    public static int getVersion(int fd) {
        return nativeGetVersion(fd);
    }

    /** Represents a datetime parsing result from classifyText calls. */
    public static final class DatetimeResult {
        static final int GRANULARITY_YEAR = 0;
        static final int GRANULARITY_MONTH = 1;
        static final int GRANULARITY_WEEK = 2;
        static final int GRANULARITY_DAY = 3;
        static final int GRANULARITY_HOUR = 4;
        static final int GRANULARITY_MINUTE = 5;
        static final int GRANULARITY_SECOND = 6;

        private final long mTimeMsUtc;
        private final int mGranularity;

        DatetimeResult(long timeMsUtc, int granularity) {
            mGranularity = granularity;
            mTimeMsUtc = timeMsUtc;
        }

        public long getTimeMsUtc() {
            return mTimeMsUtc;
        }

        public int getGranularity() {
            return mGranularity;
        }
    }

    /** Represents a result of classifyText method call. */
    public static final class ClassificationResult {
        private final String mCollection;
        private final float mScore;
        private final DatetimeResult mDatetimeResult;

        ClassificationResult(
                String collection, float score, DatetimeResult datetimeResult) {
            mCollection = collection;
            mScore = score;
            mDatetimeResult = datetimeResult;
        }

        public String getCollection() {
            if (mCollection.equals(TextClassifier.TYPE_DATE) && mDatetimeResult != null) {
                switch (mDatetimeResult.getGranularity()) {
                    case DatetimeResult.GRANULARITY_HOUR:
                        // fall through
                    case DatetimeResult.GRANULARITY_MINUTE:
                        // fall through
                    case DatetimeResult.GRANULARITY_SECOND:
                        return TextClassifier.TYPE_DATE_TIME;
                    default:
                        return TextClassifier.TYPE_DATE;
                }
            }
            return mCollection;
        }

        public float getScore() {
            return mScore;
        }

        public DatetimeResult getDatetimeResult() {
            return mDatetimeResult;
        }
    }

    /** Represents a result of Annotate call. */
    public static final class AnnotatedSpan {
        private final int mStartIndex;
        private final int mEndIndex;
        private final ClassificationResult[] mClassification;

        AnnotatedSpan(
                int startIndex, int endIndex, ClassificationResult[] classification) {
            mStartIndex = startIndex;
            mEndIndex = endIndex;
            mClassification = classification;
        }

        public int getStartIndex() {
            return mStartIndex;
        }

        public int getEndIndex() {
            return mEndIndex;
        }

        public ClassificationResult[] getClassification() {
            return mClassification;
        }
    }

    /** Represents options for the suggestSelection call. */
    public static final class SelectionOptions {
        private final String mLocales;

        SelectionOptions(String locales) {
            mLocales = locales;
        }

        public String getLocales() {
            return mLocales;
        }
    }

    /** Represents options for the classifyText call. */
    public static final class ClassificationOptions {
        private final long mReferenceTimeMsUtc;
        private final String mReferenceTimezone;
        private final String mLocales;

        ClassificationOptions(long referenceTimeMsUtc, String referenceTimezone, String locale) {
            mReferenceTimeMsUtc = referenceTimeMsUtc;
            mReferenceTimezone = referenceTimezone;
            mLocales = locale;
        }

        public long getReferenceTimeMsUtc() {
            return mReferenceTimeMsUtc;
        }

        public String getReferenceTimezone() {
            return mReferenceTimezone;
        }

        public String getLocale() {
            return mLocales;
        }
    }

    /** Represents options for the Annotate call. */
    public static final class AnnotationOptions {
        private final long mReferenceTimeMsUtc;
        private final String mReferenceTimezone;
        private final String mLocales;

        AnnotationOptions(long referenceTimeMsUtc, String referenceTimezone, String locale) {
            mReferenceTimeMsUtc = referenceTimeMsUtc;
            mReferenceTimezone = referenceTimezone;
            mLocales = locale;
        }

        public long getReferenceTimeMsUtc() {
            return mReferenceTimeMsUtc;
        }

        public String getReferenceTimezone() {
            return mReferenceTimezone;
        }

        public String getLocale() {
            return mLocales;
        }
    }

    private static native long nativeNew(int fd);

    private static native long nativeNewFromPath(String path);

    private static native long nativeNewFromAssetFileDescriptor(
            AssetFileDescriptor afd, long offset, long size);

    private static native int[] nativeSuggestSelection(
            long context,
            String text,
            int selectionBegin,
            int selectionEnd,
            SelectionOptions options);

    private static native ClassificationResult[] nativeClassifyText(
            long context,
            String text,
            int selectionBegin,
            int selectionEnd,
            ClassificationOptions options);

    private static native AnnotatedSpan[] nativeAnnotate(
            long context, String text, AnnotationOptions options);

    private static native void nativeClose(long context);

    private static native String nativeGetLocales(int fd);

    private static native int nativeGetVersion(int fd);
}