Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit a9767337 authored by TreeHugger Robot's avatar TreeHugger Robot Committed by Android (Google) Code Review
Browse files

Merge "TextClassifier: Multiple entities & confidence scores."

parents 7c4faa98 a6096f6c
Loading
Loading
Loading
Loading
+2 −2
Original line number Diff line number Diff line
@@ -42,10 +42,10 @@ final class EntityConfidence<T> {
        float score1 = mEntityConfidence.get(e1);
        float score2 = mEntityConfidence.get(e2);
        if (score1 > score2) {
            return 1;
            return -1;
        }
        if (score1 < score2) {
            return -1;
            return 1;
        }
        return 0;
    };
+16 −4
Original line number Diff line number Diff line
@@ -37,8 +37,8 @@ final class LangId {
    /**
     * Detects the language for given text.
     */
    public String findLanguage(String text) {
        return nativeFindLanguage(mModelPtr, text);
    public ClassificationResult[] findLanguages(String text) {
        return nativeFindLanguages(mModelPtr, text);
    }

    /**
@@ -50,8 +50,20 @@ final class LangId {

    private static native long nativeNew(int fd);

    private static native String nativeFindLanguage(long context, String text);
    private static native ClassificationResult[] nativeFindLanguages(
            long context, String text);

    private static native void nativeClose(long context);
}

    /** Classification result for findLanguage method. */
    static final class ClassificationResult {
        final String mLanguage;
        /** float range: 0 - 1 */
        final float mScore;

        ClassificationResult(String language, float score) {
            mLanguage = language;
            mScore = score;
        }
    }
}
+17 −3
Original line number Diff line number Diff line
@@ -55,9 +55,11 @@ final class SmartSelection {
     *
     * The begin and end params are character indices in the context string.
     *
     * Returns the type of the selection, e.g. "email", "address", "phone".
     * Returns an array of ClassificationResult objects with the probability
     * scores for different collections.
     */
    public String classifyText(String context, int selectionBegin, int selectionEnd) {
    public ClassificationResult[] classifyText(
            String context, int selectionBegin, int selectionEnd) {
        return nativeClassifyText(mCtx, context, selectionBegin, selectionEnd);
    }

@@ -73,9 +75,21 @@ final class SmartSelection {
    private static native int[] nativeSuggest(
            long context, String text, int selectionBegin, int selectionEnd);

    private static native String nativeClassifyText(
    private static native ClassificationResult[] nativeClassifyText(
            long context, String text, int selectionBegin, int selectionEnd);

    private static native void nativeClose(long context);

    /** Classification result for classifyText method. */
    static final class ClassificationResult {
        final String mCollection;
        /** float range: 0 - 1 */
        final float mScore;

        ClassificationResult(String collection, float score) {
            mCollection = collection;
            mScore = score;
        }
    }
}
+11 −6
Original line number Diff line number Diff line
@@ -85,12 +85,17 @@ public final class TextClassificationManager {
        Preconditions.checkArgument(text != null);
        try {
            if (text.length() > 0) {
                final String language = getLanguageDetector().findLanguage(text.toString());
                final Locale locale = new Locale.Builder().setLanguageTag(language).build();
                return Collections.unmodifiableList(Arrays.asList(
                        new TextLanguage.Builder(0, text.length())
                                .setLanguage(locale, 1.0f /* confidence */)
                                .build()));
                final LangId.ClassificationResult[] results =
                        getLanguageDetector().findLanguages(text.toString());
                final TextLanguage.Builder tlBuilder = new TextLanguage.Builder(0, text.length());
                final int size = results.length;
                for (int i = 0; i < size; i++) {
                    tlBuilder.setLanguage(
                            new Locale.Builder().setLanguageTag(results[i].mLanguage).build(),
                            results[i].mScore);
                }

                return Collections.unmodifiableList(Arrays.asList(tlBuilder.build()));
            }
        } catch (Throwable t) {
            // Avoid throwing from this method. Log the error.
+35 −21
Original line number Diff line number Diff line
@@ -86,10 +86,14 @@ final class TextClassifierImpl implements TextClassifier {
                final int start = startEnd[0];
                final int end = startEnd[1];
                if (start >= 0 && end <= string.length() && start <= end) {
                    final String type = getSmartSelection().classifyText(string, start, end);
                    return new TextSelection.Builder(start, end)
                            .setEntityType(type, 1.0f)
                            .build();
                    final TextSelection.Builder tsBuilder = new TextSelection.Builder(start, end);
                    final SmartSelection.ClassificationResult[] results =
                            getSmartSelection().classifyText(string, start, end);
                    final int size = results.length;
                    for (int i = 0; i < size; i++) {
                        tsBuilder.setEntityType(results[i].mCollection, results[i].mScore);
                    }
                    return tsBuilder.build();
                } else {
                    // We can not trust the result. Log the issue and ignore the result.
                    Log.d(LOG_TAG, "Got bad indices for input text. Ignoring result.");
@@ -113,13 +117,13 @@ final class TextClassifierImpl implements TextClassifier {
        try {
            if (text.length() > 0) {
                final CharSequence classified = text.subSequence(startIndex, endIndex);
                String type = getSmartSelection()
                SmartSelection.ClassificationResult[] results = getSmartSelection()
                        .classifyText(text.toString(), startIndex, endIndex);
                if (!TextUtils.isEmpty(type)) {
                    type = type.toLowerCase(Locale.ENGLISH).trim();
                if (results.length > 0) {
                    // TODO: Added this log for debug only. Remove before release.
                    Log.d(LOG_TAG, String.format("Classification type: %s", type));
                    return createClassificationResult(type, classified);
                    Log.d(LOG_TAG,
                            String.format("Classification type: %s", results[0].mCollection));
                    return createClassificationResult(results, classified);
                }
            }
        } catch (Throwable t) {
@@ -174,11 +178,17 @@ final class TextClassifierImpl implements TextClassifier {
        }
    }

    private TextClassificationResult createClassificationResult(String type, CharSequence text) {
    private TextClassificationResult createClassificationResult(
            SmartSelection.ClassificationResult[] classifications, CharSequence text) {
        final TextClassificationResult.Builder builder = new TextClassificationResult.Builder()
                .setText(text.toString())
                .setEntityType(type, 1.0f /* confidence */);
                .setText(text.toString());

        final int size = classifications.length;
        for (int i = 0; i < size; i++) {
            builder.setEntityType(classifications[i].mCollection, classifications[i].mScore);
        }

        final String type = classifications[0].mCollection;
        final Intent intent = IntentFactory.create(mContext, type, text.toString());
        final PackageManager pm;
        final ResolveInfo resolveInfo;
@@ -252,8 +262,10 @@ final class TextClassifierImpl implements TextClassifier {
                final int selectionEnd = selection[1];
                if (selectionStart >= 0 && selectionEnd <= text.length()
                        && selectionStart <= selectionEnd) {
                    final String type =
                    final SmartSelection.ClassificationResult[] results =
                            smartSelection.classifyText(text, selectionStart, selectionEnd);
                    if (results.length > 0) {
                        final String type = results[0].mCollection;
                        if (matches(type, linkMask)) {
                            final Intent intent = IntentFactory.create(
                                    context, type, text.substring(selectionStart, selectionEnd));
@@ -263,6 +275,7 @@ final class TextClassifierImpl implements TextClassifier {
                            }
                        }
                    }
                }
                start = end;
            }
            return new LinksInfoImpl(text, avoidOverlaps(spans, text));
@@ -272,6 +285,7 @@ final class TextClassifierImpl implements TextClassifier {
         * Returns true if the classification type matches the specified linkMask.
         */
        private static boolean matches(String type, int linkMask) {
            type = type.trim().toLowerCase(Locale.ENGLISH);
            if ((linkMask & Linkify.PHONE_NUMBERS) != 0
                    && TextClassifier.TYPE_PHONE.equals(type)) {
                return true;
@@ -403,6 +417,7 @@ final class TextClassifierImpl implements TextClassifier {

        @Nullable
        public static Intent create(Context context, String type, String text) {
            type = type.trim().toLowerCase(Locale.ENGLISH);
            switch (type) {
                case TextClassifier.TYPE_EMAIL:
                    return new Intent(Intent.ACTION_SENDTO)
@@ -418,12 +433,12 @@ final class TextClassifierImpl implements TextClassifier {
                            .putExtra(Browser.EXTRA_APPLICATION_ID, context.getPackageName());
                default:
                    return null;
                // TODO: Add other classification types.
            }
        }

        @Nullable
        public static String getLabel(Context context, String type) {
            type = type.trim().toLowerCase(Locale.ENGLISH);
            switch (type) {
                case TextClassifier.TYPE_EMAIL:
                    return context.getString(com.android.internal.R.string.email);
@@ -435,7 +450,6 @@ final class TextClassifierImpl implements TextClassifier {
                    return context.getString(com.android.internal.R.string.browse);
                default:
                    return null;
                // TODO: Add other classification types.
            }
        }
    }