Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit adbebcc6 authored by Tony Mak's avatar Tony Mak
Browse files

Implements TextClassifierImpl.suggestConversationActions

TODO: Construct RemoteAction for contextual actions.
TODO: Map Person object to user id.
TODO: Consider to use LangID to infer the locale. And get a new model when locale is changed.

BUG: 111437455
BUG: 111406942

Test: atest ./core/tests/coretests/src/android/view/textclassifier/TextClassificationManagerTest.java

Change-Id: Id35066455918b3321fcd30df0ff215e30586a4b3
parent 50619770
Loading
Loading
Loading
Loading
+1 −2
Original line number Diff line number Diff line
@@ -74,10 +74,9 @@ public final class ModelFileManager {
     * @param localeList the required locales, use {@code null} if there is no preference.
     */
    public ModelFile findBestModelFile(@Nullable LocaleList localeList) {
        // Specified localeList takes priority over the system default, so it is listed first.
        final String languages = localeList == null || localeList.isEmpty()
                ? LocaleList.getDefault().toLanguageTags()
                : localeList.toLanguageTags() + "," + LocaleList.getDefault().toLanguageTags();
                : localeList.toLanguageTags();
        final List<Locale.LanguageRange> languageRangeList = Locale.LanguageRange.parse(languages);

        ModelFile bestModel = null;
+35 −0
Original line number Diff line number Diff line
@@ -90,6 +90,10 @@ public final class TextClassificationConstants {
            "entity_list_not_editable";
    private static final String ENTITY_LIST_EDITABLE =
            "entity_list_editable";
    private static final String IN_APP_CONVERSATION_ACTION_TYPES_DEFAULT =
            "in_app_conversation_action_types_default";
    private static final String NOTIFICATION_CONVERSATION_ACTION_TYPES_DEFAULT =
            "notification_conversation_action_types_default";

    private static final boolean LOCAL_TEXT_CLASSIFIER_ENABLED_DEFAULT = true;
    private static final boolean SYSTEM_TEXT_CLASSIFIER_ENABLED_DEFAULT = true;
@@ -111,6 +115,18 @@ public final class TextClassificationConstants {
            .add(TextClassifier.TYPE_DATE)
            .add(TextClassifier.TYPE_DATE_TIME)
            .add(TextClassifier.TYPE_FLIGHT_NUMBER).toString();
    private static final String CONVERSATION_ACTIONS_TYPES_DEFAULT_VALUES =
            new StringJoiner(ENTITY_LIST_DELIMITER)
                    .add(ConversationActions.TYPE_TEXT_REPLY)
                    .add(ConversationActions.TYPE_CREATE_REMINDER)
                    .add(ConversationActions.TYPE_CALL_PHONE)
                    .add(ConversationActions.TYPE_OPEN_URL)
                    .add(ConversationActions.TYPE_SEND_EMAIL)
                    .add(ConversationActions.TYPE_SEND_SMS)
                    .add(ConversationActions.TYPE_TRACK_FLIGHT)
                    .add(ConversationActions.TYPE_VIEW_CALENDAR)
                    .add(ConversationActions.TYPE_VIEW_MAP)
                    .toString();

    private final boolean mSystemTextClassifierEnabled;
    private final boolean mLocalTextClassifierEnabled;
@@ -126,6 +142,8 @@ public final class TextClassificationConstants {
    private final List<String> mEntityListDefault;
    private final List<String> mEntityListNotEditable;
    private final List<String> mEntityListEditable;
    private final List<String> mInAppConversationActionTypesDefault;
    private final List<String> mNotificationConversationActionTypesDefault;

    private TextClassificationConstants(@Nullable String settings) {
        final KeyValueListParser parser = new KeyValueListParser(',');
@@ -177,6 +195,12 @@ public final class TextClassificationConstants {
        mEntityListEditable = parseEntityList(parser.getString(
                ENTITY_LIST_EDITABLE,
                ENTITY_LIST_DEFAULT_VALUE));
        mInAppConversationActionTypesDefault = parseEntityList(parser.getString(
                IN_APP_CONVERSATION_ACTION_TYPES_DEFAULT,
                CONVERSATION_ACTIONS_TYPES_DEFAULT_VALUES));
        mNotificationConversationActionTypesDefault = parseEntityList(parser.getString(
                NOTIFICATION_CONVERSATION_ACTION_TYPES_DEFAULT,
                CONVERSATION_ACTIONS_TYPES_DEFAULT_VALUES));
    }

    /** Load from a settings string. */
@@ -240,6 +264,14 @@ public final class TextClassificationConstants {
        return mEntityListEditable;
    }

    public List<String> getInAppConversationActionTypes() {
        return mInAppConversationActionTypesDefault;
    }

    public List<String> getNotificationConversationActionTypes() {
        return mNotificationConversationActionTypesDefault;
    }

    private static List<String> parseEntityList(String listStr) {
        return Collections.unmodifiableList(Arrays.asList(listStr.split(ENTITY_LIST_DELIMITER)));
    }
@@ -261,6 +293,9 @@ public final class TextClassificationConstants {
        pw.printPair("getEntityListDefault", mEntityListDefault);
        pw.printPair("getEntityListNotEditable", mEntityListNotEditable);
        pw.printPair("getEntityListEditable", mEntityListEditable);
        pw.printPair("getInAppConversationActionTypes", mInAppConversationActionTypesDefault);
        pw.printPair("getNotificationConversationActionTypes",
                mNotificationConversationActionTypesDefault);
        pw.decreaseIndent();
        pw.println();
    }
+110 −2
Original line number Diff line number Diff line
@@ -40,11 +40,13 @@ import android.os.UserManager;
import android.provider.Browser;
import android.provider.CalendarContract;
import android.provider.ContactsContract;
import android.text.TextUtils;

import com.android.internal.annotations.GuardedBy;
import com.android.internal.util.IndentingPrintWriter;
import com.android.internal.util.Preconditions;

import com.google.android.textclassifier.ActionsSuggestionsModel;
import com.google.android.textclassifier.AnnotatorModel;
import com.google.android.textclassifier.LangIdModel;

@@ -90,6 +92,11 @@ public final class TextClassifierImpl implements TextClassifier {
    private static final File UPDATED_LANG_ID_MODEL_FILE =
            new File("/data/misc/textclassifier/lang_id.model");

    // Actions
    private static final String ACTIONS_FACTORY_MODEL_FILENAME_REGEX = "actions_suggestions.model";
    private static final File UPDATED_ACTIONS_MODEL =
            new File("/data/misc/textclassifier/actions_suggestions.model");

    private final Context mContext;
    private final TextClassifier mFallback;
    private final GenerateLinksLogger mGenerateLinksLogger;
@@ -101,6 +108,8 @@ public final class TextClassifierImpl implements TextClassifier {
    private AnnotatorModel mAnnotatorImpl;
    @GuardedBy("mLock") // Do not access outside this lock.
    private LangIdModel mLangIdImpl;
    @GuardedBy("mLock") // Do not access outside this lock.
    private ActionsSuggestionsModel mActionsImpl;

    private final Object mLoggerLock = new Object();
    @GuardedBy("mLoggerLock") // Do not access outside this lock.
@@ -110,6 +119,7 @@ public final class TextClassifierImpl implements TextClassifier {

    private final ModelFileManager mAnnotatorModelFileManager;
    private final ModelFileManager mLangIdModelFileManager;
    private final ModelFileManager mActionsModelFileManager;

    public TextClassifierImpl(
            Context context, TextClassificationConstants settings, TextClassifier fallback) {
@@ -131,6 +141,13 @@ public final class TextClassifierImpl implements TextClassifier {
                        UPDATED_LANG_ID_MODEL_FILE,
                        fd -> -1, // TODO: Replace this with LangIdModel.getVersion(fd)
                        fd -> ModelFileManager.ModelFile.LANGUAGE_INDEPENDENT));
        mActionsModelFileManager = new ModelFileManager(
                new ModelFileManager.ModelFileSupplierImpl(
                        FACTORY_MODEL_DIR,
                        ACTIONS_FACTORY_MODEL_FILENAME_REGEX,
                        UPDATED_ACTIONS_MODEL,
                        ActionsSuggestionsModel::getVersion,
                        ActionsSuggestionsModel::getLocales));
    }

    public TextClassifierImpl(Context context, TextClassificationConstants settings) {
@@ -346,10 +363,69 @@ public final class TextClassifierImpl implements TextClassifier {
        return mFallback.detectLanguage(request);
    }

    @Override
    public ConversationActions suggestConversationActions(ConversationActions.Request request) {
        Preconditions.checkNotNull(request);
        Utils.checkMainThread();
        try {
            ActionsSuggestionsModel actionsImpl = getActionsImpl();
            if (actionsImpl == null) {
                // Actions model is optional, fallback if it is not available.
                return mFallback.suggestConversationActions(request);
            }
            List<ActionsSuggestionsModel.ConversationMessage> nativeMessages = new ArrayList<>();
            for (ConversationActions.Message message : request.getConversation()) {
                if (TextUtils.isEmpty(message.getText())) {
                    continue;
                }
                // TODO: We need to map the Person object to user id.
                int userId = 1;
                nativeMessages.add(
                        new ActionsSuggestionsModel.ConversationMessage(
                                userId, message.getText().toString()));
            }
            ActionsSuggestionsModel.Conversation nativeConversation =
                    new ActionsSuggestionsModel.Conversation(nativeMessages.toArray(
                            new ActionsSuggestionsModel.ConversationMessage[0]));

            ActionsSuggestionsModel.ActionSuggestion[] nativeSuggestions =
                    actionsImpl.suggestActions(nativeConversation, null);

            Collection<String> expectedTypes = resolveActionTypesFromRequest(request);
            List<ConversationActions.ConversationAction> conversationActions = new ArrayList<>();
            int maxSuggestions = Math.min(request.getMaxSuggestions(), nativeSuggestions.length);
            for (int i = 0; i < maxSuggestions; i++) {
                ActionsSuggestionsModel.ActionSuggestion nativeSuggestion = nativeSuggestions[i];
                String actionType = nativeSuggestion.getActionType();
                if (!expectedTypes.contains(actionType)) {
                    continue;
                }
                conversationActions.add(
                        new ConversationActions.ConversationAction.Builder(actionType)
                                .setTextReply(nativeSuggestion.getResponseText())
                                .setConfidenceScore(nativeSuggestion.getScore())
                                .build());
            }
            return new ConversationActions(conversationActions);
        } catch (Throwable t) {
            // Avoid throwing from this method. Log the error.
            Log.e(LOG_TAG, "Error suggesting conversation actions.", t);
        }
        return mFallback.suggestConversationActions(request);
    }

    private Collection<String> resolveActionTypesFromRequest(ConversationActions.Request request) {
        List<String> defaultActionTypes =
                request.getHints().contains(ConversationActions.HINT_FOR_NOTIFICATION)
                        ? mSettings.getNotificationConversationActionTypes()
                        : mSettings.getInAppConversationActionTypes();
        return request.getTypeConfig().resolveTypes(defaultActionTypes);
    }

    private AnnotatorModel getAnnotatorImpl(LocaleList localeList)
            throws FileNotFoundException {
        synchronized (mLock) {
            localeList = localeList == null ? LocaleList.getEmptyLocaleList() : localeList;
            localeList = localeList == null ? LocaleList.getDefault() : localeList;
            final ModelFileManager.ModelFile bestModel =
                    mAnnotatorModelFileManager.findBestModelFile(localeList);
            if (bestModel == null) {
@@ -386,7 +462,7 @@ public final class TextClassifierImpl implements TextClassifier {
        synchronized (mLock) {
            if (mLangIdImpl == null) {
                final ModelFileManager.ModelFile bestModel =
                        mLangIdModelFileManager.findBestModelFile(LocaleList.getEmptyLocaleList());
                        mLangIdModelFileManager.findBestModelFile(null);
                if (bestModel == null) {
                    throw new FileNotFoundException("No LangID model is found");
                }
@@ -404,6 +480,30 @@ public final class TextClassifierImpl implements TextClassifier {
        }
    }

    @Nullable
    private ActionsSuggestionsModel getActionsImpl() throws FileNotFoundException {
        synchronized (mLock) {
            if (mActionsImpl == null) {
                // TODO: Use LangID to determine the locale we should use here?
                final ModelFileManager.ModelFile bestModel =
                        mActionsModelFileManager.findBestModelFile(LocaleList.getDefault());
                if (bestModel == null) {
                    return null;
                }
                final ParcelFileDescriptor pfd = ParcelFileDescriptor.open(
                        new File(bestModel.getPath()), ParcelFileDescriptor.MODE_READ_ONLY);
                try {
                    if (pfd != null) {
                        mActionsImpl = new ActionsSuggestionsModel(pfd.getFd());
                    }
                } finally {
                    maybeCloseAndLogError(pfd);
                }
            }
            return mActionsImpl;
        }
    }

    private String createId(String text, int start, int end) {
        synchronized (mLock) {
            return SelectionSessionLogger.createId(text, start, end, mContext,
@@ -471,11 +571,19 @@ public final class TextClassifierImpl implements TextClassifier {
            }
            printWriter.decreaseIndent();
            printWriter.println("LangID model file(s):");
            printWriter.increaseIndent();
            for (ModelFileManager.ModelFile modelFile :
                    mLangIdModelFileManager.listModelFiles()) {
                printWriter.println(modelFile.toString());
            }
            printWriter.decreaseIndent();
            printWriter.println("Actions model file(s):");
            printWriter.increaseIndent();
            for (ModelFileManager.ModelFile modelFile :
                    mActionsModelFileManager.listModelFiles()) {
                printWriter.println(modelFile.toString());
            }
            printWriter.decreaseIndent();
            printWriter.printPair("mFallback", mFallback);
            printWriter.decreaseIndent();
            printWriter.println();
+26 −2
Original line number Diff line number Diff line
@@ -43,7 +43,7 @@ import java.util.stream.Collectors;
@SmallTest
@RunWith(AndroidJUnit4.class)
public class ModelFileManagerTest {

    private static final Locale DEFAULT_LOCALE = Locale.forLanguageTag("en-US");
    @Mock
    private Supplier<List<ModelFileManager.ModelFile>> mModelFileSupplier;
    private ModelFileManager.ModelFileSupplierImpl mModelFileSupplierImpl;
@@ -71,6 +71,8 @@ public class ModelFileManagerTest {

        mRootTestDir.mkdirs();
        mFactoryModelDir.mkdirs();

        Locale.setDefault(DEFAULT_LOCALE);
    }

    @After
@@ -134,7 +136,7 @@ public class ModelFileManagerTest {
    }

    @Test
    public void findBestModel_useIndependentWhenNoLanguageModelMatch() {
    public void findBestModel_noMatchedLanguageModel() {
        Locale locale = Locale.forLanguageTag("ja");
        ModelFileManager.ModelFile languageIndependentModelFile =
                new ModelFileManager.ModelFile(
@@ -156,6 +158,28 @@ public class ModelFileManagerTest {
        assertThat(bestModelFile).isEqualTo(languageIndependentModelFile);
    }

    @Test
    public void findBestModel_noMatchedLanguageModel_defaultLocaleModelExists() {
        ModelFileManager.ModelFile languageIndependentModelFile =
                new ModelFileManager.ModelFile(
                        new File("/path/a"), 1,
                        Collections.emptyList(), true);

        ModelFileManager.ModelFile languageDependentModelFile =
                new ModelFileManager.ModelFile(
                        new File("/path/b"), 1,
                        Collections.singletonList(DEFAULT_LOCALE), false);

        when(mModelFileSupplier.get())
                .thenReturn(
                        Arrays.asList(languageIndependentModelFile, languageDependentModelFile));

        ModelFileManager.ModelFile bestModelFile =
                mModelFileManager.findBestModelFile(
                        LocaleList.forLanguageTags("zh-hk"));
        assertThat(bestModelFile).isEqualTo(languageIndependentModelFile);
    }

    @Test
    public void findBestModel_languageIsMoreImportantThanVersion() {
        ModelFileManager.ModelFile matchButOlderModel =
+28 −0
Original line number Diff line number Diff line
@@ -19,6 +19,7 @@ package android.view.textclassifier;
import static org.hamcrest.CoreMatchers.not;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNotSame;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
@@ -324,6 +325,33 @@ public class TextClassificationManagerTest {
        assertThat(textLanguage, isTextLanguage("ja"));
    }

    @Test
    public void testSuggestConversationActions_textReplyOnly_maxThree() {
        if (isTextClassifierDisabled()) return;
        ConversationActions.Message message =
                new ConversationActions.Message.Builder().setText("Hello").build();
        ConversationActions.TypeConfig typeConfig =
                new ConversationActions.TypeConfig.Builder().includeTypesFromTextClassifier(false)
                        .setIncludedTypes(
                                Collections.singletonList(ConversationActions.TYPE_TEXT_REPLY))
                        .build();
        ConversationActions.Request request =
                new ConversationActions.Request.Builder(Collections.singletonList(message))
                        .setMaxSuggestions(1)
                        .setTypeConfig(typeConfig)
                        .build();

        ConversationActions conversationActions = mClassifier.suggestConversationActions(request);
        assertTrue(conversationActions.getConversationActions().size() <= 1);
        for (ConversationActions.ConversationAction conversationAction :
                conversationActions.getConversationActions()) {
            assertEquals(conversationAction.getType(), ConversationActions.TYPE_TEXT_REPLY);
            assertNotNull(conversationAction.getTextReply());
            assertTrue(conversationAction.getConfidenceScore() > 0);
            assertTrue(conversationAction.getConfidenceScore() <= 1);
        }
    }

    @Test
    public void testSetTextClassifier() {
        TextClassifier classifier = mock(TextClassifier.class);