Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 6ff8cf84 authored by Nikita Iashchenko's avatar Nikita Iashchenko Committed by android-build-merger
Browse files

Merge "Switch TextClassifier implementation from native to java"

am: 06eb53ce

Change-Id: Iaaddc488874381fbc1f6b76f5f60dc47abf65cdd
parents 3603c807 06eb53ce
Loading
Loading
Loading
Loading
+4 −1
Original line number Diff line number Diff line
@@ -704,7 +704,6 @@ java_defaults {
    required: [
        // TODO: remove gps_debug when the build system propagates "required" properly.
        "gps_debug.conf",
        "libtextclassifier",
        // Loaded with System.loadLibrary by android.view.textclassifier
        "libmedia2_jni",
    ],
@@ -855,6 +854,10 @@ java_library {
        "nist-sip",
        "tagsoup",
        "rappor",
        "libtextclassifier-java",
    ],
    required: [
        "libtextclassifier",
    ],
    dxflags: ["--core-library"],
}
+21 −19
Original line number Diff line number Diff line
@@ -43,6 +43,8 @@ import android.provider.ContactsContract;
import com.android.internal.annotations.GuardedBy;
import com.android.internal.util.Preconditions;

import com.google.android.textclassifier.AnnotatorModel;

import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
@@ -91,7 +93,7 @@ public final class TextClassifierImpl implements TextClassifier {
    @GuardedBy("mLock") // Do not access outside this lock.
    private ModelFile mModel;
    @GuardedBy("mLock") // Do not access outside this lock.
    private TextClassifierImplNative mNative;
    private AnnotatorModel mNative;

    private final Object mLoggerLock = new Object();
    @GuardedBy("mLoggerLock") // Do not access outside this lock.
@@ -124,7 +126,7 @@ public final class TextClassifierImpl implements TextClassifier {
                    && rangeLength <= mSettings.getSuggestSelectionMaxRangeLength()) {
                final String localesString = concatenateLocales(request.getDefaultLocales());
                final ZonedDateTime refTime = ZonedDateTime.now();
                final TextClassifierImplNative nativeImpl = getNative(request.getDefaultLocales());
                final AnnotatorModel nativeImpl = getNative(request.getDefaultLocales());
                final int start;
                final int end;
                if (mSettings.isModelDarkLaunchEnabled() && !request.isDarkLaunchAllowed()) {
@@ -133,7 +135,7 @@ public final class TextClassifierImpl implements TextClassifier {
                } else {
                    final int[] startEnd = nativeImpl.suggestSelection(
                            string, request.getStartIndex(), request.getEndIndex(),
                            new TextClassifierImplNative.SelectionOptions(localesString));
                            new AnnotatorModel.SelectionOptions(localesString));
                    start = startEnd[0];
                    end = startEnd[1];
                }
@@ -141,10 +143,10 @@ public final class TextClassifierImpl implements TextClassifier {
                        && start >= 0 && end <= string.length()
                        && start <= request.getStartIndex() && end >= request.getEndIndex()) {
                    final TextSelection.Builder tsBuilder = new TextSelection.Builder(start, end);
                    final TextClassifierImplNative.ClassificationResult[] results =
                    final AnnotatorModel.ClassificationResult[] results =
                            nativeImpl.classifyText(
                                    string, start, end,
                                    new TextClassifierImplNative.ClassificationOptions(
                                    new AnnotatorModel.ClassificationOptions(
                                            refTime.toInstant().toEpochMilli(),
                                            refTime.getZone().getId(),
                                            localesString));
@@ -183,11 +185,11 @@ public final class TextClassifierImpl implements TextClassifier {
                final String localesString = concatenateLocales(request.getDefaultLocales());
                final ZonedDateTime refTime = request.getReferenceTime() != null
                        ? request.getReferenceTime() : ZonedDateTime.now();
                final TextClassifierImplNative.ClassificationResult[] results =
                final AnnotatorModel.ClassificationResult[] results =
                        getNative(request.getDefaultLocales())
                                .classifyText(
                                        string, request.getStartIndex(), request.getEndIndex(),
                                        new TextClassifierImplNative.ClassificationOptions(
                                        new AnnotatorModel.ClassificationOptions(
                                                refTime.toInstant().toEpochMilli(),
                                                refTime.getZone().getId(),
                                                localesString));
@@ -227,17 +229,17 @@ public final class TextClassifierImpl implements TextClassifier {
                    ? request.getEntityConfig().resolveEntityListModifications(
                            getEntitiesForHints(request.getEntityConfig().getHints()))
                    : mSettings.getEntityListDefault();
            final TextClassifierImplNative nativeImpl =
            final AnnotatorModel nativeImpl =
                    getNative(request.getDefaultLocales());
            final TextClassifierImplNative.AnnotatedSpan[] annotations =
            final AnnotatorModel.AnnotatedSpan[] annotations =
                    nativeImpl.annotate(
                        textString,
                        new TextClassifierImplNative.AnnotationOptions(
                        new AnnotatorModel.AnnotationOptions(
                                refTime.toInstant().toEpochMilli(),
                                        refTime.getZone().getId(),
                                concatenateLocales(request.getDefaultLocales())));
            for (TextClassifierImplNative.AnnotatedSpan span : annotations) {
                final TextClassifierImplNative.ClassificationResult[] results =
            for (AnnotatorModel.AnnotatedSpan span : annotations) {
                final AnnotatorModel.ClassificationResult[] results =
                        span.getClassification();
                if (results.length == 0
                        || !entitiesToIdentify.contains(results[0].getCollection())) {
@@ -296,7 +298,7 @@ public final class TextClassifierImpl implements TextClassifier {
        }
    }

    private TextClassifierImplNative getNative(LocaleList localeList)
    private AnnotatorModel getNative(LocaleList localeList)
            throws FileNotFoundException {
        synchronized (mLock) {
            localeList = localeList == null ? LocaleList.getEmptyLocaleList() : localeList;
@@ -309,7 +311,7 @@ public final class TextClassifierImpl implements TextClassifier {
                destroyNativeIfExistsLocked();
                final ParcelFileDescriptor fd = ParcelFileDescriptor.open(
                        new File(bestModel.getPath()), ParcelFileDescriptor.MODE_READ_ONLY);
                mNative = new TextClassifierImplNative(fd.getFd());
                mNative = new AnnotatorModel(fd.getFd());
                closeAndLogError(fd);
                mModel = bestModel;
            }
@@ -397,14 +399,14 @@ public final class TextClassifierImpl implements TextClassifier {
    }

    private TextClassification createClassificationResult(
            TextClassifierImplNative.ClassificationResult[] classifications,
            AnnotatorModel.ClassificationResult[] classifications,
            String text, int start, int end, @Nullable Instant referenceTime) {
        final String classifiedText = text.substring(start, end);
        final TextClassification.Builder builder = new TextClassification.Builder()
                .setText(classifiedText);

        final int size = classifications.length;
        TextClassifierImplNative.ClassificationResult highestScoringResult = null;
        AnnotatorModel.ClassificationResult highestScoringResult = null;
        float highestScore = Float.MIN_VALUE;
        for (int i = 0; i < size; i++) {
            builder.setEntityType(classifications[i].getCollection(),
@@ -467,9 +469,9 @@ public final class TextClassifierImpl implements TextClassifier {
            try {
                final ParcelFileDescriptor modelFd = ParcelFileDescriptor.open(
                        file, ParcelFileDescriptor.MODE_READ_ONLY);
                final int version = TextClassifierImplNative.getVersion(modelFd.getFd());
                final int version = AnnotatorModel.getVersion(modelFd.getFd());
                final String supportedLocalesStr =
                        TextClassifierImplNative.getLocales(modelFd.getFd());
                        AnnotatorModel.getLocales(modelFd.getFd());
                if (supportedLocalesStr.isEmpty()) {
                    Log.d(DEFAULT_LOG_TAG, "Ignoring " + file.getAbsolutePath());
                    return null;
@@ -657,7 +659,7 @@ public final class TextClassifierImpl implements TextClassifier {
        public static List<LabeledIntent> create(
                Context context,
                @Nullable Instant referenceTime,
                TextClassifierImplNative.ClassificationResult classification,
                AnnotatorModel.ClassificationResult classification,
                String text) {
            final String type = classification.getCollection().trim().toLowerCase(Locale.ENGLISH);
            text = text.trim();
+0 −301
Original line number Diff line number Diff line
/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package android.view.textclassifier;

import android.content.res.AssetFileDescriptor;

/**
 * Java wrapper for TextClassifier native library interface. This library is used for detecting
 * entities in text.
 */
final class TextClassifierImplNative {

    static {
        System.loadLibrary("textclassifier");
    }

    private final long mModelPtr;

    /**
     * Creates a new instance of TextClassifierImplNative, using the provided model image, given as
     * a file descriptor.
     */
    TextClassifierImplNative(int fd) {
        mModelPtr = nativeNew(fd);
        if (mModelPtr == 0L) {
            throw new IllegalArgumentException("Couldn't initialize TC from file descriptor.");
        }
    }

    /**
     * Creates a new instance of TextClassifierImplNative, using the provided model image, given as
     * a file path.
     */
    TextClassifierImplNative(String path) {
        mModelPtr = nativeNewFromPath(path);
        if (mModelPtr == 0L) {
            throw new IllegalArgumentException("Couldn't initialize TC from given file.");
        }
    }

    /**
     * Creates a new instance of TextClassifierImplNative, using the provided model image, given as
     * an AssetFileDescriptor.
     */
    TextClassifierImplNative(AssetFileDescriptor afd) {
        mModelPtr = nativeNewFromAssetFileDescriptor(afd, afd.getStartOffset(), afd.getLength());
        if (mModelPtr == 0L) {
            throw new IllegalArgumentException(
                    "Couldn't initialize TC from given AssetFileDescriptor");
        }
    }

    /**
     * Given a string context and current selection, computes the SmartSelection suggestion.
     *
     * <p>The begin and end are character indices into the context UTF8 string. selectionBegin is
     * the character index where the selection begins, and selectionEnd is the index of one
     * character past the selection span.
     *
     * <p>The return value is an array of two ints: suggested selection beginning and end, with the
     * same semantics as the input selectionBeginning and selectionEnd.
     */
    public int[] suggestSelection(
            String context, int selectionBegin, int selectionEnd, SelectionOptions options) {
        return nativeSuggestSelection(mModelPtr, context, selectionBegin, selectionEnd, options);
    }

    /**
     * Given a string context and current selection, classifies the type of the selected text.
     *
     * <p>The begin and end params are character indices in the context string.
     *
     * <p>Returns an array of ClassificationResult objects with the probability scores for different
     * collections.
     */
    public ClassificationResult[] classifyText(
            String context, int selectionBegin, int selectionEnd, ClassificationOptions options) {
        return nativeClassifyText(mModelPtr, context, selectionBegin, selectionEnd, options);
    }

    /**
     * Annotates given input text. The annotations should cover the whole input context except for
     * whitespaces, and are sorted by their position in the context string.
     */
    public AnnotatedSpan[] annotate(String text, AnnotationOptions options) {
        return nativeAnnotate(mModelPtr, text, options);
    }

    /** Frees up the allocated memory. */
    public void close() {
        nativeClose(mModelPtr);
    }

    /** Returns a comma separated list of locales supported by the model as BCP 47 tags. */
    public static String getLocales(int fd) {
        return nativeGetLocales(fd);
    }

    /** Returns the version of the model. */
    public static int getVersion(int fd) {
        return nativeGetVersion(fd);
    }

    /** Represents a datetime parsing result from classifyText calls. */
    public static final class DatetimeResult {
        static final int GRANULARITY_YEAR = 0;
        static final int GRANULARITY_MONTH = 1;
        static final int GRANULARITY_WEEK = 2;
        static final int GRANULARITY_DAY = 3;
        static final int GRANULARITY_HOUR = 4;
        static final int GRANULARITY_MINUTE = 5;
        static final int GRANULARITY_SECOND = 6;

        private final long mTimeMsUtc;
        private final int mGranularity;

        DatetimeResult(long timeMsUtc, int granularity) {
            mGranularity = granularity;
            mTimeMsUtc = timeMsUtc;
        }

        public long getTimeMsUtc() {
            return mTimeMsUtc;
        }

        public int getGranularity() {
            return mGranularity;
        }
    }

    /** Represents a result of classifyText method call. */
    public static final class ClassificationResult {
        private final String mCollection;
        private final float mScore;
        private final DatetimeResult mDatetimeResult;

        ClassificationResult(
                String collection, float score, DatetimeResult datetimeResult) {
            mCollection = collection;
            mScore = score;
            mDatetimeResult = datetimeResult;
        }

        public String getCollection() {
            if (mCollection.equals(TextClassifier.TYPE_DATE) && mDatetimeResult != null) {
                switch (mDatetimeResult.getGranularity()) {
                    case DatetimeResult.GRANULARITY_HOUR:
                        // fall through
                    case DatetimeResult.GRANULARITY_MINUTE:
                        // fall through
                    case DatetimeResult.GRANULARITY_SECOND:
                        return TextClassifier.TYPE_DATE_TIME;
                    default:
                        return TextClassifier.TYPE_DATE;
                }
            }
            return mCollection;
        }

        public float getScore() {
            return mScore;
        }

        public DatetimeResult getDatetimeResult() {
            return mDatetimeResult;
        }
    }

    /** Represents a result of Annotate call. */
    public static final class AnnotatedSpan {
        private final int mStartIndex;
        private final int mEndIndex;
        private final ClassificationResult[] mClassification;

        AnnotatedSpan(
                int startIndex, int endIndex, ClassificationResult[] classification) {
            mStartIndex = startIndex;
            mEndIndex = endIndex;
            mClassification = classification;
        }

        public int getStartIndex() {
            return mStartIndex;
        }

        public int getEndIndex() {
            return mEndIndex;
        }

        public ClassificationResult[] getClassification() {
            return mClassification;
        }
    }

    /** Represents options for the suggestSelection call. */
    public static final class SelectionOptions {
        private final String mLocales;

        SelectionOptions(String locales) {
            mLocales = locales;
        }

        public String getLocales() {
            return mLocales;
        }
    }

    /** Represents options for the classifyText call. */
    public static final class ClassificationOptions {
        private final long mReferenceTimeMsUtc;
        private final String mReferenceTimezone;
        private final String mLocales;

        ClassificationOptions(long referenceTimeMsUtc, String referenceTimezone, String locale) {
            mReferenceTimeMsUtc = referenceTimeMsUtc;
            mReferenceTimezone = referenceTimezone;
            mLocales = locale;
        }

        public long getReferenceTimeMsUtc() {
            return mReferenceTimeMsUtc;
        }

        public String getReferenceTimezone() {
            return mReferenceTimezone;
        }

        public String getLocale() {
            return mLocales;
        }
    }

    /** Represents options for the Annotate call. */
    public static final class AnnotationOptions {
        private final long mReferenceTimeMsUtc;
        private final String mReferenceTimezone;
        private final String mLocales;

        AnnotationOptions(long referenceTimeMsUtc, String referenceTimezone, String locale) {
            mReferenceTimeMsUtc = referenceTimeMsUtc;
            mReferenceTimezone = referenceTimezone;
            mLocales = locale;
        }

        public long getReferenceTimeMsUtc() {
            return mReferenceTimeMsUtc;
        }

        public String getReferenceTimezone() {
            return mReferenceTimezone;
        }

        public String getLocale() {
            return mLocales;
        }
    }

    private static native long nativeNew(int fd);

    private static native long nativeNewFromPath(String path);

    private static native long nativeNewFromAssetFileDescriptor(
            AssetFileDescriptor afd, long offset, long size);

    private static native int[] nativeSuggestSelection(
            long context,
            String text,
            int selectionBegin,
            int selectionEnd,
            SelectionOptions options);

    private static native ClassificationResult[] nativeClassifyText(
            long context,
            String text,
            int selectionBegin,
            int selectionEnd,
            ClassificationOptions options);

    private static native AnnotatedSpan[] nativeAnnotate(
            long context, String text, AnnotationOptions options);

    private static native void nativeClose(long context);

    private static native String nativeGetLocales(int fd);

    private static native int nativeGetVersion(int fd);
}