Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 5004517c authored by TreeHugger Robot's avatar TreeHugger Robot Committed by Android (Google) Code Review
Browse files

Merge "Using LangID to detect the language of the text and pass it to annotator"

parents 387c9bd8 159f028b
Loading
Loading
Loading
Loading
+2 −0
Original line number Diff line number Diff line
@@ -12023,6 +12023,8 @@ public final class Settings {
         * notification_conversation_action_types_default   (String[])
         * lang_id_threshold_override                       (float)
         * template_intent_factory_enabled                  (boolean)
         * translate_in_classification_enabled              (boolean)
         * detect_languages_from_text_enabled               (boolean)
         * </pre>
         *
         * <p>
+4 −1
Original line number Diff line number Diff line
@@ -82,9 +82,12 @@ public final class ActionsSuggestionsHelper {
            long referenceTime = message.getReferenceTime() == null
                    ? 0
                    : message.getReferenceTime().toInstant().toEpochMilli();
            String timeZone = message.getReferenceTime() == null
                    ? null
                    : message.getReferenceTime().getZone().getId();
            nativeMessages.push(new ActionsSuggestionsModel.ConversationMessage(
                    personEncoder.encode(message.getAuthor()),
                    message.getText().toString(), referenceTime,
                    message.getText().toString(), referenceTime, timeZone,
                    languageDetector.apply(message.getText())));
        }
        return nativeMessages.toArray(
+34 −2
Original line number Diff line number Diff line
@@ -48,6 +48,8 @@ import java.util.StringJoiner;
 * notification_conversation_action_types_default   (String[])
 * lang_id_threshold_override                       (float)
 * template_intent_factory_enabled                  (boolean)
 * translate_in_classification_enabled              (boolean)
 * detect_languages_from_text_enabled               (boolean)
 * </pre>
 *
 * <p>
@@ -139,7 +141,7 @@ public final class TextClassificationConstants {
    private static final String NOTIFICATION_CONVERSATION_ACTION_TYPES_DEFAULT =
            "notification_conversation_action_types_default";
    /**
     * Threshold in classifyText to consider a text is in a foreign language.
     * Threshold to accept a suggested language from LangID model.
     */
    private static final String LANG_ID_THRESHOLD_OVERRIDE = "lang_id_threshold_override";
    /**
@@ -147,6 +149,18 @@ public final class TextClassificationConstants {
     */
    private static final String TEMPLATE_INTENT_FACTORY_ENABLED = "template_intent_factory_enabled";

    /**
     * Whether to enable "translate" action in classifyText.
     */
    private static final String TRANSLATE_IN_CLASSIFICATION_ENABLED =
            "translate_in_classification_enabled";
    /**
     * Whether to detect the languages of the text in request by using langId for the native
     * model.
     */
    private static final String DETECT_LANGUAGES_FROM_TEXT_ENABLED =
            "detect_languages_from_text_enabled";

    private static final boolean LOCAL_TEXT_CLASSIFIER_ENABLED_DEFAULT = true;
    private static final boolean SYSTEM_TEXT_CLASSIFIER_ENABLED_DEFAULT = true;
    private static final boolean MODEL_DARK_LAUNCH_ENABLED_DEFAULT = false;
@@ -183,11 +197,13 @@ public final class TextClassificationConstants {
    /**
     * < 0  : Not set. Use value from LangId model.
     * 0 - 1: Override value in LangId model.
     * > 1  : Effectively turns off the foreign language detection. Scores should never be > 1.
     *
     * @see EntityConfidence
     */
    private static final float LANG_ID_THRESHOLD_OVERRIDE_DEFAULT = -1f;
    private static final boolean TEMPLATE_INTENT_FACTORY_ENABLED_DEFAULT = true;
    private static final boolean TRANSLATE_IN_CLASSIFICATION_ENABLED_DEFAULT = true;
    private static final boolean DETECT_LANGUAGES_FROM_TEXT_ENABLED_DEFAULT = true;

    private final boolean mSystemTextClassifierEnabled;
    private final boolean mLocalTextClassifierEnabled;
@@ -207,6 +223,8 @@ public final class TextClassificationConstants {
    private final List<String> mNotificationConversationActionTypesDefault;
    private final float mLangIdThresholdOverride;
    private final boolean mTemplateIntentFactoryEnabled;
    private final boolean mTranslateInClassificationEnabled;
    private final boolean mDetectLanguagesFromTextEnabled;

    private TextClassificationConstants(@Nullable String settings) {
        ConfigParser configParser = new ConfigParser(settings);
@@ -280,6 +298,10 @@ public final class TextClassificationConstants {
        mTemplateIntentFactoryEnabled = configParser.getBoolean(
                TEMPLATE_INTENT_FACTORY_ENABLED,
                TEMPLATE_INTENT_FACTORY_ENABLED_DEFAULT);
        mTranslateInClassificationEnabled = configParser.getBoolean(
                TRANSLATE_IN_CLASSIFICATION_ENABLED, TRANSLATE_IN_CLASSIFICATION_ENABLED_DEFAULT);
        mDetectLanguagesFromTextEnabled = configParser.getBoolean(
                DETECT_LANGUAGES_FROM_TEXT_ENABLED, DETECT_LANGUAGES_FROM_TEXT_ENABLED_DEFAULT);
    }

    /** Load from a settings string. */
@@ -359,6 +381,14 @@ public final class TextClassificationConstants {
        return mTemplateIntentFactoryEnabled;
    }

    /**
     * Returns the parsed value of the {@code translate_in_classification_enabled} flag, i.e.
     * whether the "translate" action is enabled in classifyText.
     *
     * @return {@code true} if the flag is enabled, {@code false} otherwise
     */
    public boolean isTranslateInClassificationEnabled() {
        return this.mTranslateInClassificationEnabled;
    }

    /**
     * Returns the parsed value of the {@code detect_languages_from_text_enabled} flag, i.e.
     * whether the languages of the request text are detected with langId for the native model.
     *
     * @return {@code true} if the flag is enabled, {@code false} otherwise
     */
    public boolean isDetectLanguagesFromTextEnabled() {
        return this.mDetectLanguagesFromTextEnabled;
    }

    /**
     * Splits a delimited settings value into a read-only list of its elements.
     *
     * @param listStr the raw delimited string; must be non-null because it is split directly
     * @return an unmodifiable list of the pieces obtained by splitting {@code listStr} on
     *         {@code STRING_LIST_DELIMITER}
     */
    private static List<String> parseStringList(String listStr) {
        final String[] pieces = listStr.split(STRING_LIST_DELIMITER);
        final List<String> fixedSizeView = Arrays.asList(pieces);
        // Wrap the array-backed view so callers cannot modify the parsed values.
        return Collections.unmodifiableList(fixedSizeView);
    }
@@ -385,6 +415,8 @@ public final class TextClassificationConstants {
                mNotificationConversationActionTypesDefault);
        pw.printPair("getLangIdThresholdOverride", mLangIdThresholdOverride);
        pw.printPair("isTemplateIntentFactoryEnabled", mTemplateIntentFactoryEnabled);
        pw.printPair("isTranslateInClassificationEnabled", mTranslateInClassificationEnabled);
        pw.printPair("isDetectLanguageFromTextEnabled", mDetectLanguagesFromTextEnabled);
        pw.decreaseIndent();
        pw.println();
    }
+39 −20
Original line number Diff line number Diff line
@@ -164,6 +164,7 @@ public final class TextClassifierImpl implements TextClassifier {
            if (string.length() > 0
                    && rangeLength <= mSettings.getSuggestSelectionMaxRangeLength()) {
                final String localesString = concatenateLocales(request.getDefaultLocales());
                final String detectLanguageTags = detectLanguageTagsFromText(request.getText());
                final ZonedDateTime refTime = ZonedDateTime.now();
                final AnnotatorModel annotatorImpl =
                        getAnnotatorImpl(request.getDefaultLocales());
@@ -175,7 +176,7 @@ public final class TextClassifierImpl implements TextClassifier {
                } else {
                    final int[] startEnd = annotatorImpl.suggestSelection(
                            string, request.getStartIndex(), request.getEndIndex(),
                            new AnnotatorModel.SelectionOptions(localesString));
                            new AnnotatorModel.SelectionOptions(localesString, detectLanguageTags));
                    start = startEnd[0];
                    end = startEnd[1];
                }
@@ -189,7 +190,8 @@ public final class TextClassifierImpl implements TextClassifier {
                                    new AnnotatorModel.ClassificationOptions(
                                            refTime.toInstant().toEpochMilli(),
                                            refTime.getZone().getId(),
                                            localesString),
                                            localesString,
                                            detectLanguageTags),
                                    // Passing null here to suppress intent generation
                                    // TODO: Use an explicit flag to suppress it.
                                    /* appContext */ null,
@@ -227,6 +229,7 @@ public final class TextClassifierImpl implements TextClassifier {
            final String string = request.getText().toString();
            if (string.length() > 0 && rangeLength <= mSettings.getClassifyTextMaxRangeLength()) {
                final String localesString = concatenateLocales(request.getDefaultLocales());
                final String detectLanguageTags = detectLanguageTagsFromText(request.getText());
                final ZonedDateTime refTime = request.getReferenceTime() != null
                        ? request.getReferenceTime() : ZonedDateTime.now();
                final AnnotatorModel.ClassificationResult[] results =
@@ -236,9 +239,10 @@ public final class TextClassifierImpl implements TextClassifier {
                                        new AnnotatorModel.ClassificationOptions(
                                                refTime.toInstant().toEpochMilli(),
                                                refTime.getZone().getId(),
                                                localesString),
                                                localesString,
                                                detectLanguageTags),
                                        mContext,
                                        getResourceLocaleString()
                                        getResourceLocalesString()
                                );
                if (results.length > 0) {
                    return createClassificationResult(
@@ -276,6 +280,8 @@ public final class TextClassifierImpl implements TextClassifier {
                    ? request.getEntityConfig().resolveEntityListModifications(
                    getEntitiesForHints(request.getEntityConfig().getHints()))
                    : mSettings.getEntityListDefault();
            final String localesString = concatenateLocales(request.getDefaultLocales());
            final String detectLanguageTags = detectLanguageTagsFromText(request.getText());
            final AnnotatorModel annotatorImpl =
                    getAnnotatorImpl(request.getDefaultLocales());
            final AnnotatorModel.AnnotatedSpan[] annotations =
@@ -284,7 +290,8 @@ public final class TextClassifierImpl implements TextClassifier {
                            new AnnotatorModel.AnnotationOptions(
                                    refTime.toInstant().toEpochMilli(),
                                    refTime.getZone().getId(),
                                    concatenateLocales(request.getDefaultLocales())));
                                    localesString,
                                    detectLanguageTags));
            for (AnnotatorModel.AnnotatedSpan span : annotations) {
                final AnnotatorModel.ClassificationResult[] results =
                        span.getClassification();
@@ -386,8 +393,8 @@ public final class TextClassifierImpl implements TextClassifier {
                return mFallback.suggestConversationActions(request);
            }
            ActionsSuggestionsModel.ConversationMessage[] nativeMessages =
                    ActionsSuggestionsHelper.toNativeMessages(request.getConversation(),
                            this::detectLanguageTagsFromText);
                    ActionsSuggestionsHelper.toNativeMessages(
                            request.getConversation(), this::detectLanguageTagsFromText);
            if (nativeMessages.length == 0) {
                return mFallback.suggestConversationActions(request);
            }
@@ -399,7 +406,7 @@ public final class TextClassifierImpl implements TextClassifier {
                            nativeConversation,
                            null,
                            mContext,
                            getResourceLocaleString());
                            getResourceLocalesString());
            return createConversationActionResult(request, nativeSuggestions);
        } catch (Throwable t) {
            // Avoid throwing from this method. Log the error.
@@ -463,19 +470,28 @@ public final class TextClassifierImpl implements TextClassifier {

    @Nullable
    private String detectLanguageTagsFromText(CharSequence text) {
        if (!mSettings.isDetectLanguagesFromTextEnabled()) {
            return null;
        }
        final float threshold = getLangIdThreshold();
        if (threshold < 0 || threshold > 1) {
            Log.w(LOG_TAG,
                    "[detectLanguageTagsFromText] unexpected threshold is found: " + threshold);
            return null;
        }
        TextLanguage.Request request = new TextLanguage.Request.Builder(text).build();
        TextLanguage textLanguage = detectLanguage(request);
        int localeHypothesisCount = textLanguage.getLocaleHypothesisCount();
        List<String> languageTags = new ArrayList<>();
        for (int i = 0; i < localeHypothesisCount; i++) {
            ULocale locale = textLanguage.getLocale(i);
            if (textLanguage.getConfidenceScore(locale) < getForeignLanguageThreshold()) {
            if (textLanguage.getConfidenceScore(locale) < threshold) {
                break;
            }
            languageTags.add(locale.toLanguageTag());
        }
        if (languageTags.isEmpty()) {
            return LocaleList.getDefault().toLanguageTags();
            return null;
        }
        return String.join(",", languageTags);
    }
@@ -644,10 +660,14 @@ public final class TextClassifierImpl implements TextClassifier {
    // TODO: Consider making this public API.
    @Nullable
    private Bundle detectForeignLanguage(String text) {
        if (!mSettings.isTranslateInClassificationEnabled()) {
            return null;
        }
        try {
            final float threshold = getForeignLanguageThreshold();
            if (threshold > 1) {
                Log.v(LOG_TAG, "Foreign language detection disabled.");
            final float threshold = getLangIdThreshold();
            if (threshold < 0 || threshold > 1) {
                Log.w(LOG_TAG,
                        "[detectForeignLanguage] unexpected threshold is found: " + threshold);
                return null;
            }

@@ -686,11 +706,11 @@ public final class TextClassifierImpl implements TextClassifier {
        return null;
    }

    private float getForeignLanguageThreshold() {
    private float getLangIdThreshold() {
        try {
            return mSettings.getLangIdThresholdOverride() >= 0
                    ? mSettings.getLangIdThresholdOverride()
                    : getLangIdImpl().getTranslateThreshold();
                    : getLangIdImpl().getLangIdThreshold();
        } catch (FileNotFoundException e) {
            final float defaultThreshold = 0.5f;
            Log.v(LOG_TAG, "Using default foreign language threshold: " + defaultThreshold);
@@ -746,15 +766,14 @@ public final class TextClassifierImpl implements TextClassifier {
    }

    /**
     * Returns the locale string for the current resources configuration.
     * Returns the locales string for the current resources configuration.
     */
    private String getResourceLocaleString() {
        // TODO: Pass the locale list once it is supported in native side.
    private String getResourceLocalesString() {
        try {
            return mContext.getResources().getConfiguration().getLocales().get(0).toLanguageTag();
            return mContext.getResources().getConfiguration().getLocales().toLanguageTags();
        } catch (NullPointerException e) {
            // NPE is unexpected. Erring on the side of caution.
            return LocaleList.getDefault().get(0).toLanguageTag();
            return LocaleList.getDefault().toLanguageTags();
        }
    }
}
+1 −1
Original line number Diff line number Diff line
@@ -215,7 +215,7 @@ public class ActionsSuggestionsHelperTest {
            long referenceTimeInMsUtc) {
        assertThat(nativeMessage.getText()).isEqualTo(text.toString());
        assertThat(nativeMessage.getUserId()).isEqualTo(userId);
        assertThat(nativeMessage.getLocales()).isEqualTo(LOCALE_TAG);
        assertThat(nativeMessage.getDetectedTextLanguageTags()).isEqualTo(LOCALE_TAG);
        assertThat(nativeMessage.getReferenceTimeMsUtc()).isEqualTo(referenceTimeInMsUtc);
    }
}
Loading