Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 9005dafb authored by TreeHugger Robot's avatar TreeHugger Robot Committed by Android (Google) Code Review
Browse files

Merge "Implements TextClassifierImpl.suggestConversationActions"

parents bda37423 adbebcc6
Loading
Loading
Loading
Loading
+1 −2
Original line number Diff line number Diff line
@@ -74,10 +74,9 @@ public final class ModelFileManager {
     * @param localeList the required locales, use {@code null} if there is no preference.
     */
    public ModelFile findBestModelFile(@Nullable LocaleList localeList) {
        // Specified localeList takes priority over the system default, so it is listed first.
        final String languages = localeList == null || localeList.isEmpty()
                ? LocaleList.getDefault().toLanguageTags()
                : localeList.toLanguageTags() + "," + LocaleList.getDefault().toLanguageTags();
                : localeList.toLanguageTags();
        final List<Locale.LanguageRange> languageRangeList = Locale.LanguageRange.parse(languages);

        ModelFile bestModel = null;
+35 −0
Original line number Diff line number Diff line
@@ -90,6 +90,10 @@ public final class TextClassificationConstants {
            "entity_list_not_editable";
    private static final String ENTITY_LIST_EDITABLE =
            "entity_list_editable";
    private static final String IN_APP_CONVERSATION_ACTION_TYPES_DEFAULT =
            "in_app_conversation_action_types_default";
    private static final String NOTIFICATION_CONVERSATION_ACTION_TYPES_DEFAULT =
            "notification_conversation_action_types_default";

    private static final boolean LOCAL_TEXT_CLASSIFIER_ENABLED_DEFAULT = true;
    private static final boolean SYSTEM_TEXT_CLASSIFIER_ENABLED_DEFAULT = true;
@@ -111,6 +115,18 @@ public final class TextClassificationConstants {
            .add(TextClassifier.TYPE_DATE)
            .add(TextClassifier.TYPE_DATE_TIME)
            .add(TextClassifier.TYPE_FLIGHT_NUMBER).toString();
    private static final String CONVERSATION_ACTIONS_TYPES_DEFAULT_VALUES =
            new StringJoiner(ENTITY_LIST_DELIMITER)
                    .add(ConversationActions.TYPE_TEXT_REPLY)
                    .add(ConversationActions.TYPE_CREATE_REMINDER)
                    .add(ConversationActions.TYPE_CALL_PHONE)
                    .add(ConversationActions.TYPE_OPEN_URL)
                    .add(ConversationActions.TYPE_SEND_EMAIL)
                    .add(ConversationActions.TYPE_SEND_SMS)
                    .add(ConversationActions.TYPE_TRACK_FLIGHT)
                    .add(ConversationActions.TYPE_VIEW_CALENDAR)
                    .add(ConversationActions.TYPE_VIEW_MAP)
                    .toString();

    private final boolean mSystemTextClassifierEnabled;
    private final boolean mLocalTextClassifierEnabled;
@@ -126,6 +142,8 @@ public final class TextClassificationConstants {
    private final List<String> mEntityListDefault;
    private final List<String> mEntityListNotEditable;
    private final List<String> mEntityListEditable;
    private final List<String> mInAppConversationActionTypesDefault;
    private final List<String> mNotificationConversationActionTypesDefault;

    private TextClassificationConstants(@Nullable String settings) {
        final KeyValueListParser parser = new KeyValueListParser(',');
@@ -177,6 +195,12 @@ public final class TextClassificationConstants {
        mEntityListEditable = parseEntityList(parser.getString(
                ENTITY_LIST_EDITABLE,
                ENTITY_LIST_DEFAULT_VALUE));
        mInAppConversationActionTypesDefault = parseEntityList(parser.getString(
                IN_APP_CONVERSATION_ACTION_TYPES_DEFAULT,
                CONVERSATION_ACTIONS_TYPES_DEFAULT_VALUES));
        mNotificationConversationActionTypesDefault = parseEntityList(parser.getString(
                NOTIFICATION_CONVERSATION_ACTION_TYPES_DEFAULT,
                CONVERSATION_ACTIONS_TYPES_DEFAULT_VALUES));
    }

    /** Load from a settings string. */
@@ -240,6 +264,14 @@ public final class TextClassificationConstants {
        return mEntityListEditable;
    }

    public List<String> getInAppConversationActionTypes() {
        return mInAppConversationActionTypesDefault;
    }

    public List<String> getNotificationConversationActionTypes() {
        return mNotificationConversationActionTypesDefault;
    }

    private static List<String> parseEntityList(String listStr) {
        return Collections.unmodifiableList(Arrays.asList(listStr.split(ENTITY_LIST_DELIMITER)));
    }
@@ -261,6 +293,9 @@ public final class TextClassificationConstants {
        pw.printPair("getEntityListDefault", mEntityListDefault);
        pw.printPair("getEntityListNotEditable", mEntityListNotEditable);
        pw.printPair("getEntityListEditable", mEntityListEditable);
        pw.printPair("getInAppConversationActionTypes", mInAppConversationActionTypesDefault);
        pw.printPair("getNotificationConversationActionTypes",
                mNotificationConversationActionTypesDefault);
        pw.decreaseIndent();
        pw.println();
    }
+110 −2
Original line number Diff line number Diff line
@@ -40,11 +40,13 @@ import android.os.UserManager;
import android.provider.Browser;
import android.provider.CalendarContract;
import android.provider.ContactsContract;
import android.text.TextUtils;

import com.android.internal.annotations.GuardedBy;
import com.android.internal.util.IndentingPrintWriter;
import com.android.internal.util.Preconditions;

import com.google.android.textclassifier.ActionsSuggestionsModel;
import com.google.android.textclassifier.AnnotatorModel;
import com.google.android.textclassifier.LangIdModel;

@@ -90,6 +92,11 @@ public final class TextClassifierImpl implements TextClassifier {
    private static final File UPDATED_LANG_ID_MODEL_FILE =
            new File("/data/misc/textclassifier/lang_id.model");

    // Actions
    private static final String ACTIONS_FACTORY_MODEL_FILENAME_REGEX = "actions_suggestions.model";
    private static final File UPDATED_ACTIONS_MODEL =
            new File("/data/misc/textclassifier/actions_suggestions.model");

    private final Context mContext;
    private final TextClassifier mFallback;
    private final GenerateLinksLogger mGenerateLinksLogger;
@@ -101,6 +108,8 @@ public final class TextClassifierImpl implements TextClassifier {
    private AnnotatorModel mAnnotatorImpl;
    @GuardedBy("mLock") // Do not access outside this lock.
    private LangIdModel mLangIdImpl;
    @GuardedBy("mLock") // Do not access outside this lock.
    private ActionsSuggestionsModel mActionsImpl;

    private final Object mLoggerLock = new Object();
    @GuardedBy("mLoggerLock") // Do not access outside this lock.
@@ -110,6 +119,7 @@ public final class TextClassifierImpl implements TextClassifier {

    private final ModelFileManager mAnnotatorModelFileManager;
    private final ModelFileManager mLangIdModelFileManager;
    private final ModelFileManager mActionsModelFileManager;

    public TextClassifierImpl(
            Context context, TextClassificationConstants settings, TextClassifier fallback) {
@@ -131,6 +141,13 @@ public final class TextClassifierImpl implements TextClassifier {
                        UPDATED_LANG_ID_MODEL_FILE,
                        fd -> -1, // TODO: Replace this with LangIdModel.getVersion(fd)
                        fd -> ModelFileManager.ModelFile.LANGUAGE_INDEPENDENT));
        mActionsModelFileManager = new ModelFileManager(
                new ModelFileManager.ModelFileSupplierImpl(
                        FACTORY_MODEL_DIR,
                        ACTIONS_FACTORY_MODEL_FILENAME_REGEX,
                        UPDATED_ACTIONS_MODEL,
                        ActionsSuggestionsModel::getVersion,
                        ActionsSuggestionsModel::getLocales));
    }

    public TextClassifierImpl(Context context, TextClassificationConstants settings) {
@@ -346,10 +363,69 @@ public final class TextClassifierImpl implements TextClassifier {
        return mFallback.detectLanguage(request);
    }

    @Override
    public ConversationActions suggestConversationActions(ConversationActions.Request request) {
        Preconditions.checkNotNull(request);
        Utils.checkMainThread();
        try {
            ActionsSuggestionsModel actionsImpl = getActionsImpl();
            if (actionsImpl == null) {
                // Actions model is optional, fallback if it is not available.
                return mFallback.suggestConversationActions(request);
            }
            List<ActionsSuggestionsModel.ConversationMessage> nativeMessages = new ArrayList<>();
            for (ConversationActions.Message message : request.getConversation()) {
                if (TextUtils.isEmpty(message.getText())) {
                    continue;
                }
                // TODO: We need to map the Person object to user id.
                int userId = 1;
                nativeMessages.add(
                        new ActionsSuggestionsModel.ConversationMessage(
                                userId, message.getText().toString()));
            }
            ActionsSuggestionsModel.Conversation nativeConversation =
                    new ActionsSuggestionsModel.Conversation(nativeMessages.toArray(
                            new ActionsSuggestionsModel.ConversationMessage[0]));

            ActionsSuggestionsModel.ActionSuggestion[] nativeSuggestions =
                    actionsImpl.suggestActions(nativeConversation, null);

            Collection<String> expectedTypes = resolveActionTypesFromRequest(request);
            List<ConversationActions.ConversationAction> conversationActions = new ArrayList<>();
            int maxSuggestions = Math.min(request.getMaxSuggestions(), nativeSuggestions.length);
            for (int i = 0; i < maxSuggestions; i++) {
                ActionsSuggestionsModel.ActionSuggestion nativeSuggestion = nativeSuggestions[i];
                String actionType = nativeSuggestion.getActionType();
                if (!expectedTypes.contains(actionType)) {
                    continue;
                }
                conversationActions.add(
                        new ConversationActions.ConversationAction.Builder(actionType)
                                .setTextReply(nativeSuggestion.getResponseText())
                                .setConfidenceScore(nativeSuggestion.getScore())
                                .build());
            }
            return new ConversationActions(conversationActions);
        } catch (Throwable t) {
            // Avoid throwing from this method. Log the error.
            Log.e(LOG_TAG, "Error suggesting conversation actions.", t);
        }
        return mFallback.suggestConversationActions(request);
    }

    private Collection<String> resolveActionTypesFromRequest(ConversationActions.Request request) {
        List<String> defaultActionTypes =
                request.getHints().contains(ConversationActions.HINT_FOR_NOTIFICATION)
                        ? mSettings.getNotificationConversationActionTypes()
                        : mSettings.getInAppConversationActionTypes();
        return request.getTypeConfig().resolveTypes(defaultActionTypes);
    }

    private AnnotatorModel getAnnotatorImpl(LocaleList localeList)
            throws FileNotFoundException {
        synchronized (mLock) {
            localeList = localeList == null ? LocaleList.getEmptyLocaleList() : localeList;
            localeList = localeList == null ? LocaleList.getDefault() : localeList;
            final ModelFileManager.ModelFile bestModel =
                    mAnnotatorModelFileManager.findBestModelFile(localeList);
            if (bestModel == null) {
@@ -386,7 +462,7 @@ public final class TextClassifierImpl implements TextClassifier {
        synchronized (mLock) {
            if (mLangIdImpl == null) {
                final ModelFileManager.ModelFile bestModel =
                        mLangIdModelFileManager.findBestModelFile(LocaleList.getEmptyLocaleList());
                        mLangIdModelFileManager.findBestModelFile(null);
                if (bestModel == null) {
                    throw new FileNotFoundException("No LangID model is found");
                }
@@ -404,6 +480,30 @@ public final class TextClassifierImpl implements TextClassifier {
        }
    }

    @Nullable
    private ActionsSuggestionsModel getActionsImpl() throws FileNotFoundException {
        synchronized (mLock) {
            if (mActionsImpl == null) {
                // TODO: Use LangID to determine the locale we should use here?
                final ModelFileManager.ModelFile bestModel =
                        mActionsModelFileManager.findBestModelFile(LocaleList.getDefault());
                if (bestModel == null) {
                    return null;
                }
                final ParcelFileDescriptor pfd = ParcelFileDescriptor.open(
                        new File(bestModel.getPath()), ParcelFileDescriptor.MODE_READ_ONLY);
                try {
                    if (pfd != null) {
                        mActionsImpl = new ActionsSuggestionsModel(pfd.getFd());
                    }
                } finally {
                    maybeCloseAndLogError(pfd);
                }
            }
            return mActionsImpl;
        }
    }

    private String createId(String text, int start, int end) {
        synchronized (mLock) {
            return SelectionSessionLogger.createId(text, start, end, mContext,
@@ -471,11 +571,19 @@ public final class TextClassifierImpl implements TextClassifier {
            }
            printWriter.decreaseIndent();
            printWriter.println("LangID model file(s):");
            printWriter.increaseIndent();
            for (ModelFileManager.ModelFile modelFile :
                    mLangIdModelFileManager.listModelFiles()) {
                printWriter.println(modelFile.toString());
            }
            printWriter.decreaseIndent();
            printWriter.println("Actions model file(s):");
            printWriter.increaseIndent();
            for (ModelFileManager.ModelFile modelFile :
                    mActionsModelFileManager.listModelFiles()) {
                printWriter.println(modelFile.toString());
            }
            printWriter.decreaseIndent();
            printWriter.printPair("mFallback", mFallback);
            printWriter.decreaseIndent();
            printWriter.println();
+26 −2
Original line number Diff line number Diff line
@@ -43,7 +43,7 @@ import java.util.stream.Collectors;
@SmallTest
@RunWith(AndroidJUnit4.class)
public class ModelFileManagerTest {

    private static final Locale DEFAULT_LOCALE = Locale.forLanguageTag("en-US");
    @Mock
    private Supplier<List<ModelFileManager.ModelFile>> mModelFileSupplier;
    private ModelFileManager.ModelFileSupplierImpl mModelFileSupplierImpl;
@@ -71,6 +71,8 @@ public class ModelFileManagerTest {

        mRootTestDir.mkdirs();
        mFactoryModelDir.mkdirs();

        Locale.setDefault(DEFAULT_LOCALE);
    }

    @After
@@ -134,7 +136,7 @@ public class ModelFileManagerTest {
    }

    @Test
    public void findBestModel_useIndependentWhenNoLanguageModelMatch() {
    public void findBestModel_noMatchedLanguageModel() {
        Locale locale = Locale.forLanguageTag("ja");
        ModelFileManager.ModelFile languageIndependentModelFile =
                new ModelFileManager.ModelFile(
@@ -156,6 +158,28 @@ public class ModelFileManagerTest {
        assertThat(bestModelFile).isEqualTo(languageIndependentModelFile);
    }

    @Test
    public void findBestModel_noMatchedLanguageModel_defaultLocaleModelExists() {
        ModelFileManager.ModelFile languageIndependentModelFile =
                new ModelFileManager.ModelFile(
                        new File("/path/a"), 1,
                        Collections.emptyList(), true);

        ModelFileManager.ModelFile languageDependentModelFile =
                new ModelFileManager.ModelFile(
                        new File("/path/b"), 1,
                        Collections.singletonList(DEFAULT_LOCALE), false);

        when(mModelFileSupplier.get())
                .thenReturn(
                        Arrays.asList(languageIndependentModelFile, languageDependentModelFile));

        ModelFileManager.ModelFile bestModelFile =
                mModelFileManager.findBestModelFile(
                        LocaleList.forLanguageTags("zh-hk"));
        assertThat(bestModelFile).isEqualTo(languageIndependentModelFile);
    }

    @Test
    public void findBestModel_languageIsMoreImportantThanVersion() {
        ModelFileManager.ModelFile matchButOlderModel =
+28 −0
Original line number Diff line number Diff line
@@ -19,6 +19,7 @@ package android.view.textclassifier;
import static org.hamcrest.CoreMatchers.not;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNotSame;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
@@ -324,6 +325,33 @@ public class TextClassificationManagerTest {
        assertThat(textLanguage, isTextLanguage("ja"));
    }

    @Test
    public void testSuggestConversationActions_textReplyOnly_maxThree() {
        if (isTextClassifierDisabled()) return;
        ConversationActions.Message message =
                new ConversationActions.Message.Builder().setText("Hello").build();
        ConversationActions.TypeConfig typeConfig =
                new ConversationActions.TypeConfig.Builder().includeTypesFromTextClassifier(false)
                        .setIncludedTypes(
                                Collections.singletonList(ConversationActions.TYPE_TEXT_REPLY))
                        .build();
        ConversationActions.Request request =
                new ConversationActions.Request.Builder(Collections.singletonList(message))
                        .setMaxSuggestions(1)
                        .setTypeConfig(typeConfig)
                        .build();

        ConversationActions conversationActions = mClassifier.suggestConversationActions(request);
        assertTrue(conversationActions.getConversationActions().size() <= 1);
        for (ConversationActions.ConversationAction conversationAction :
                conversationActions.getConversationActions()) {
            assertEquals(conversationAction.getType(), ConversationActions.TYPE_TEXT_REPLY);
            assertNotNull(conversationAction.getTextReply());
            assertTrue(conversationAction.getConfidenceScore() > 0);
            assertTrue(conversationAction.getConfidenceScore() <= 1);
        }
    }

    @Test
    public void testSetTextClassifier() {
        TextClassifier classifier = mock(TextClassifier.class);