Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 25f7fdc1 authored by Abodunrinwa Toki's avatar Abodunrinwa Toki
Browse files

Language detection fixes.

- Load foreign language detection score threshold from model
- Pass resource config language to native code instead of Locale.getDefault()
- Avoid nullpointer exception in ExtrasUtils
- Don't set action_intents extras if empty

Bug: 124791964
Bug: 124794807
Test: atest core/tests/coretests/src/android/view/textclassifier
Change-Id: I2593d7cb4d364d8bf26239ed59b7212f79ddc350
parent 753f4ce4
Loading
Loading
Loading
Loading
+2 −1
Original line number Diff line number Diff line
@@ -94,7 +94,8 @@ public final class ExtrasUtils {
        if (actionIntents != null) {
            final int size = actionIntents.size();
            for (int i = 0; i < size; i++) {
                if (intentAction.equals(actionIntents.get(i).getAction())) {
                final Intent intent = actionIntents.get(i);
                if (intent != null && intentAction.equals(intent.getAction())) {
                    return classification.getActions().get(i);
                }
            }
+4 −3
Original line number Diff line number Diff line
@@ -54,6 +54,7 @@ import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;

/**
 * Information for generating a widget to handle classified text.
@@ -276,8 +277,8 @@ public final class TextClassification implements Parcelable {
    @Override
    public String toString() {
        return String.format(Locale.US,
                "TextClassification {text=%s, entities=%s, actions=%s, id=%s}",
                mText, mEntityConfidence, mActions, mId);
                "TextClassification {text=%s, entities=%s, actions=%s, id=%s, extras=%s}",
                mText, mEntityConfidence, mActions, mId, mExtras);
    }

    /**
@@ -532,7 +533,7 @@ public final class TextClassification implements Parcelable {

        private Bundle buildExtras() {
            final Bundle extras = mExtras == null ? new Bundle() : mExtras.deepCopy();
            if (!mActionIntents.isEmpty()) {
            if (mActionIntents.stream().anyMatch(Objects::nonNull)) {
                ExtrasUtils.putActionsIntents(extras, mActionIntents);
            }
            if (mForeignLanguageExtra != null) {
+42 −22
Original line number Diff line number Diff line
@@ -240,9 +240,7 @@ public final class TextClassifierImpl implements TextClassifier {
                                                refTime.getZone().getId(),
                                                localesString),
                                        mContext,
                                        // TODO: Pass the locale list once it is supported in
                                        //  native side.
                                        LocaleList.getDefault().get(0).toLanguageTag()
                                        getResourceLocaleString()
                                );
                if (results.length > 0) {
                    return createClassificationResult(
@@ -403,8 +401,7 @@ public final class TextClassifierImpl implements TextClassifier {
                            nativeConversation,
                            null,
                            mContext,
                            // TODO: Pass the locale list once it is supported in native side.
                            LocaleList.getDefault().get(0).toLanguageTag());
                            getResourceLocaleString());
            return createConversationActionResult(request, nativeSuggestions);
        } catch (Throwable t) {
            // Avoid throwing from this method. Log the error.
@@ -456,10 +453,9 @@ public final class TextClassifierImpl implements TextClassifier {
        TextLanguage textLanguage = detectLanguage(request);
        int localeHypothesisCount = textLanguage.getLocaleHypothesisCount();
        List<String> languageTags = new ArrayList<>();
        // TODO: Reconsider this and probably make the score threshold configurable.
        for (int i = 0; i < localeHypothesisCount; i++) {
            ULocale locale = textLanguage.getLocale(i);
            if (textLanguage.getConfidenceScore(locale) < 0.5) {
            if (textLanguage.getConfidenceScore(locale) < getForeignLanguageThreshold()) {
                break;
            }
            languageTags.add(locale.toLanguageTag());
@@ -587,15 +583,10 @@ public final class TextClassifierImpl implements TextClassifier {
            }
        }

        final float foreignTextThreshold = mSettings.getLangIdThresholdOverride() >= 0
                ? mSettings.getLangIdThresholdOverride()
                : 0.5f /* TODO: Load this from the langId model. */;
        final Bundle foreignLanguageBundle =
                detectForeignLanguage(classifiedText, foreignTextThreshold);
        final Bundle foreignLanguageBundle = detectForeignLanguage(classifiedText);
        builder.setForeignLanguageExtra(foreignLanguageBundle);

        boolean isPrimaryAction = true;
        final ArrayList<Intent> sourceIntents = new ArrayList<>();
        List<LabeledIntent> labeledIntents = mIntentFactory.create(
                mContext,
                classifiedText,
@@ -626,16 +617,20 @@ public final class TextClassifierImpl implements TextClassifier {

    /**
     * Returns a bundle with the language and confidence score if it finds the text to be
     * in a foreign language. Otherwise returns null.
     * in a foreign language. Otherwise returns null. This algorithm defines what the system thinks
     * is a foreign language.
     */
    // TODO: Revisit this algorithm.
    // TODO: Consider making this public API.
    @Nullable
    private Bundle detectForeignLanguage(String text, float threshold) {
    private Bundle detectForeignLanguage(String text) {
        try {
            final float threshold = getForeignLanguageThreshold();
            if (threshold > 1) {
                Log.v(LOG_TAG, "Foreign language detection disabled.");
                return null;
            }

        // TODO: Revisit this algorithm.
        try {
            final LangIdModel langId = getLangIdImpl();
            final LangIdModel.LanguageResult[] langResults = langId.detectLanguages(text);
            if (langResults.length <= 0) {
@@ -651,8 +646,8 @@ public final class TextClassifierImpl implements TextClassifier {
            if (highestScoringResult.getScore() < threshold) {
                return null;
            }
            // TODO: Remove
            Log.d(LOG_TAG, String.format("Language detected: <%s:%s>",

            Log.v(LOG_TAG, String.format("Language detected: <%s:%s>",
                    highestScoringResult.getLanguage(), highestScoringResult.getScore()));

            final Locale detected = new Locale(highestScoringResult.getLanguage());
@@ -671,6 +666,18 @@ public final class TextClassifierImpl implements TextClassifier {
        return null;
    }

    private float getForeignLanguageThreshold() {
        try {
            return mSettings.getLangIdThresholdOverride() >= 0
                    ? mSettings.getLangIdThresholdOverride()
                    : getLangIdImpl().getTranslateThreshold();
        } catch (FileNotFoundException e) {
            final float defaultThreshold = 0.5f;
            Log.v(LOG_TAG, "Using default foreign language threshold: " + defaultThreshold);
            return defaultThreshold;
        }
    }

    @Override
    public void dump(@NonNull IndentingPrintWriter printWriter) {
        synchronized (mLock) {
@@ -718,6 +725,19 @@ public final class TextClassifierImpl implements TextClassifier {
        }
    }

    /**
     * Returns the locale string for the current resources configuration.
     */
    private String getResourceLocaleString() {
        // TODO: Pass the locale list once it is supported in native side.
        try {
            return mContext.getResources().getConfiguration().getLocales().get(0).toLanguageTag();
        } catch (NullPointerException e) {
            // NPE is unexpected. Erring on the side of caution.
            return LocaleList.getDefault().get(0).toLanguageTag();
        }
    }

    /**
     * Helper class to store the information from which RemoteActions are built.
     */
+3 −2
Original line number Diff line number Diff line
@@ -93,8 +93,8 @@ public class TextClassificationTest {
        final String id = "id";
        final TextClassification reference = new TextClassification.Builder()
                .setText(text)
                .addAction(remoteAction0)
                .addAction(remoteAction1)
                .addAction(remoteAction0)  // Action intent not included.
                .addAction(remoteAction1)  // Action intent not included.
                .setEntityType(TextClassifier.TYPE_ADDRESS, 0.3f)
                .setEntityType(TextClassifier.TYPE_PHONE, 0.7f)
                .setId(id)
@@ -132,6 +132,7 @@ public class TextClassificationTest {

        // Extras
        assertEquals(BUNDLE_VALUE, result.getExtras().getString(BUNDLE_KEY));
        assertNull(ExtrasUtils.getActionsIntents(result));
    }

    @Test