Loading core/java/android/view/textclassifier/ModelFileManager.java +1 −2 Original line number Diff line number Diff line Loading @@ -74,10 +74,9 @@ public final class ModelFileManager { * @param localeList the required locales, use {@code null} if there is no preference. */ public ModelFile findBestModelFile(@Nullable LocaleList localeList) { // Specified localeList takes priority over the system default, so it is listed first. final String languages = localeList == null || localeList.isEmpty() ? LocaleList.getDefault().toLanguageTags() : localeList.toLanguageTags() + "," + LocaleList.getDefault().toLanguageTags(); : localeList.toLanguageTags(); final List<Locale.LanguageRange> languageRangeList = Locale.LanguageRange.parse(languages); ModelFile bestModel = null; Loading core/java/android/view/textclassifier/TextClassificationConstants.java +35 −0 Original line number Diff line number Diff line Loading @@ -90,6 +90,10 @@ public final class TextClassificationConstants { "entity_list_not_editable"; private static final String ENTITY_LIST_EDITABLE = "entity_list_editable"; private static final String IN_APP_CONVERSATION_ACTION_TYPES_DEFAULT = "in_app_conversation_action_types_default"; private static final String NOTIFICATION_CONVERSATION_ACTION_TYPES_DEFAULT = "notification_conversation_action_types_default"; private static final boolean LOCAL_TEXT_CLASSIFIER_ENABLED_DEFAULT = true; private static final boolean SYSTEM_TEXT_CLASSIFIER_ENABLED_DEFAULT = true; Loading @@ -111,6 +115,18 @@ public final class TextClassificationConstants { .add(TextClassifier.TYPE_DATE) .add(TextClassifier.TYPE_DATE_TIME) .add(TextClassifier.TYPE_FLIGHT_NUMBER).toString(); private static final String CONVERSATION_ACTIONS_TYPES_DEFAULT_VALUES = new StringJoiner(ENTITY_LIST_DELIMITER) .add(ConversationActions.TYPE_TEXT_REPLY) .add(ConversationActions.TYPE_CREATE_REMINDER) .add(ConversationActions.TYPE_CALL_PHONE) .add(ConversationActions.TYPE_OPEN_URL) .add(ConversationActions.TYPE_SEND_EMAIL) .add(ConversationActions.TYPE_SEND_SMS) .add(ConversationActions.TYPE_TRACK_FLIGHT) .add(ConversationActions.TYPE_VIEW_CALENDAR) .add(ConversationActions.TYPE_VIEW_MAP) .toString(); private final boolean mSystemTextClassifierEnabled; private final boolean mLocalTextClassifierEnabled; Loading @@ -126,6 +142,8 @@ public final class TextClassificationConstants { private final List<String> mEntityListDefault; private final List<String> mEntityListNotEditable; private final List<String> mEntityListEditable; private final List<String> mInAppConversationActionTypesDefault; private final List<String> mNotificationConversationActionTypesDefault; private TextClassificationConstants(@Nullable String settings) { final KeyValueListParser parser = new KeyValueListParser(','); Loading Loading @@ -177,6 +195,12 @@ public final class TextClassificationConstants { mEntityListEditable = parseEntityList(parser.getString( ENTITY_LIST_EDITABLE, ENTITY_LIST_DEFAULT_VALUE)); mInAppConversationActionTypesDefault = parseEntityList(parser.getString( IN_APP_CONVERSATION_ACTION_TYPES_DEFAULT, CONVERSATION_ACTIONS_TYPES_DEFAULT_VALUES)); mNotificationConversationActionTypesDefault = parseEntityList(parser.getString( NOTIFICATION_CONVERSATION_ACTION_TYPES_DEFAULT, CONVERSATION_ACTIONS_TYPES_DEFAULT_VALUES)); } /** Load from a settings string. */ Loading Loading @@ -240,6 +264,14 @@ public final class TextClassificationConstants { return mEntityListEditable; } public List<String> getInAppConversationActionTypes() { return mInAppConversationActionTypesDefault; } public List<String> getNotificationConversationActionTypes() { return mNotificationConversationActionTypesDefault; } private static List<String> parseEntityList(String listStr) { return Collections.unmodifiableList(Arrays.asList(listStr.split(ENTITY_LIST_DELIMITER))); } Loading @@ -261,6 +293,9 @@ public final class TextClassificationConstants { pw.printPair("getEntityListDefault", mEntityListDefault); pw.printPair("getEntityListNotEditable", mEntityListNotEditable); pw.printPair("getEntityListEditable", mEntityListEditable); pw.printPair("getInAppConversationActionTypes", mInAppConversationActionTypesDefault); pw.printPair("getNotificationConversationActionTypes", mNotificationConversationActionTypesDefault); pw.decreaseIndent(); pw.println(); } Loading core/java/android/view/textclassifier/TextClassifierImpl.java +110 −2 Original line number Diff line number Diff line Loading @@ -40,11 +40,13 @@ import android.os.UserManager; import android.provider.Browser; import android.provider.CalendarContract; import android.provider.ContactsContract; import android.text.TextUtils; import com.android.internal.annotations.GuardedBy; import com.android.internal.util.IndentingPrintWriter; import com.android.internal.util.Preconditions; import com.google.android.textclassifier.ActionsSuggestionsModel; import com.google.android.textclassifier.AnnotatorModel; import com.google.android.textclassifier.LangIdModel; Loading Loading @@ -90,6 +92,11 @@ public final class TextClassifierImpl implements TextClassifier { private static final File UPDATED_LANG_ID_MODEL_FILE = new File("/data/misc/textclassifier/lang_id.model"); // Actions private static final String ACTIONS_FACTORY_MODEL_FILENAME_REGEX = "actions_suggestions.model"; private static final File UPDATED_ACTIONS_MODEL = new File("/data/misc/textclassifier/actions_suggestions.model"); private final Context mContext; private final TextClassifier mFallback; private final GenerateLinksLogger mGenerateLinksLogger; Loading @@ -101,6 +108,8 @@ public final class TextClassifierImpl implements TextClassifier { private AnnotatorModel mAnnotatorImpl; @GuardedBy("mLock") // Do not access outside this lock. private LangIdModel mLangIdImpl; @GuardedBy("mLock") // Do not access outside this lock. private ActionsSuggestionsModel mActionsImpl; private final Object mLoggerLock = new Object(); @GuardedBy("mLoggerLock") // Do not access outside this lock. Loading @@ -110,6 +119,7 @@ public final class TextClassifierImpl implements TextClassifier { private final ModelFileManager mAnnotatorModelFileManager; private final ModelFileManager mLangIdModelFileManager; private final ModelFileManager mActionsModelFileManager; public TextClassifierImpl( Context context, TextClassificationConstants settings, TextClassifier fallback) { Loading @@ -131,6 +141,13 @@ public final class TextClassifierImpl implements TextClassifier { UPDATED_LANG_ID_MODEL_FILE, fd -> -1, // TODO: Replace this with LangIdModel.getVersion(fd) fd -> ModelFileManager.ModelFile.LANGUAGE_INDEPENDENT)); mActionsModelFileManager = new ModelFileManager( new ModelFileManager.ModelFileSupplierImpl( FACTORY_MODEL_DIR, ACTIONS_FACTORY_MODEL_FILENAME_REGEX, UPDATED_ACTIONS_MODEL, ActionsSuggestionsModel::getVersion, ActionsSuggestionsModel::getLocales)); } public TextClassifierImpl(Context context, TextClassificationConstants settings) { Loading Loading @@ -346,10 +363,69 @@ public final class TextClassifierImpl implements TextClassifier { return mFallback.detectLanguage(request); } @Override public ConversationActions suggestConversationActions(ConversationActions.Request request) { Preconditions.checkNotNull(request); Utils.checkMainThread(); try { ActionsSuggestionsModel actionsImpl = getActionsImpl(); if (actionsImpl == null) { // Actions model is optional, fallback if it is not available. return mFallback.suggestConversationActions(request); } List<ActionsSuggestionsModel.ConversationMessage> nativeMessages = new ArrayList<>(); for (ConversationActions.Message message : request.getConversation()) { if (TextUtils.isEmpty(message.getText())) { continue; } // TODO: We need to map the Person object to user id. int userId = 1; nativeMessages.add( new ActionsSuggestionsModel.ConversationMessage( userId, message.getText().toString())); } ActionsSuggestionsModel.Conversation nativeConversation = new ActionsSuggestionsModel.Conversation(nativeMessages.toArray( new ActionsSuggestionsModel.ConversationMessage[0])); ActionsSuggestionsModel.ActionSuggestion[] nativeSuggestions = actionsImpl.suggestActions(nativeConversation, null); Collection<String> expectedTypes = resolveActionTypesFromRequest(request); List<ConversationActions.ConversationAction> conversationActions = new ArrayList<>(); int maxSuggestions = Math.min(request.getMaxSuggestions(), nativeSuggestions.length); for (int i = 0; i < maxSuggestions; i++) { ActionsSuggestionsModel.ActionSuggestion nativeSuggestion = nativeSuggestions[i]; String actionType = nativeSuggestion.getActionType(); if (!expectedTypes.contains(actionType)) { continue; } conversationActions.add( new ConversationActions.ConversationAction.Builder(actionType) .setTextReply(nativeSuggestion.getResponseText()) .setConfidenceScore(nativeSuggestion.getScore()) .build()); } return new ConversationActions(conversationActions); } catch (Throwable t) { // Avoid throwing from this method. Log the error. Log.e(LOG_TAG, "Error suggesting conversation actions.", t); } return mFallback.suggestConversationActions(request); } private Collection<String> resolveActionTypesFromRequest(ConversationActions.Request request) { List<String> defaultActionTypes = request.getHints().contains(ConversationActions.HINT_FOR_NOTIFICATION) ? mSettings.getNotificationConversationActionTypes() : mSettings.getInAppConversationActionTypes(); return request.getTypeConfig().resolveTypes(defaultActionTypes); } private AnnotatorModel getAnnotatorImpl(LocaleList localeList) throws FileNotFoundException { synchronized (mLock) { localeList = localeList == null ? LocaleList.getEmptyLocaleList() : localeList; localeList = localeList == null ? LocaleList.getDefault() : localeList; final ModelFileManager.ModelFile bestModel = mAnnotatorModelFileManager.findBestModelFile(localeList); if (bestModel == null) { Loading Loading @@ -386,7 +462,7 @@ public final class TextClassifierImpl implements TextClassifier { synchronized (mLock) { if (mLangIdImpl == null) { final ModelFileManager.ModelFile bestModel = mLangIdModelFileManager.findBestModelFile(LocaleList.getEmptyLocaleList()); mLangIdModelFileManager.findBestModelFile(null); if (bestModel == null) { throw new FileNotFoundException("No LangID model is found"); } Loading @@ -404,6 +480,30 @@ public final class TextClassifierImpl implements TextClassifier { } } @Nullable private ActionsSuggestionsModel getActionsImpl() throws FileNotFoundException { synchronized (mLock) { if (mActionsImpl == null) { // TODO: Use LangID to determine the locale we should use here? final ModelFileManager.ModelFile bestModel = mActionsModelFileManager.findBestModelFile(LocaleList.getDefault()); if (bestModel == null) { return null; } final ParcelFileDescriptor pfd = ParcelFileDescriptor.open( new File(bestModel.getPath()), ParcelFileDescriptor.MODE_READ_ONLY); try { if (pfd != null) { mActionsImpl = new ActionsSuggestionsModel(pfd.getFd()); } } finally { maybeCloseAndLogError(pfd); } } return mActionsImpl; } } private String createId(String text, int start, int end) { synchronized (mLock) { return SelectionSessionLogger.createId(text, start, end, mContext, Loading Loading @@ -471,11 +571,19 @@ public final class TextClassifierImpl implements TextClassifier { } printWriter.decreaseIndent(); printWriter.println("LangID model file(s):"); printWriter.increaseIndent(); for (ModelFileManager.ModelFile modelFile : mLangIdModelFileManager.listModelFiles()) { printWriter.println(modelFile.toString()); } printWriter.decreaseIndent(); printWriter.println("Actions model file(s):"); printWriter.increaseIndent(); for (ModelFileManager.ModelFile modelFile : mActionsModelFileManager.listModelFiles()) { printWriter.println(modelFile.toString()); } printWriter.decreaseIndent(); printWriter.printPair("mFallback", mFallback); printWriter.decreaseIndent(); printWriter.println(); Loading core/tests/coretests/src/android/view/textclassifier/ModelFileManagerTest.java +26 −2 Original line number Diff line number Diff line Loading @@ -43,7 +43,7 @@ import java.util.stream.Collectors; @SmallTest @RunWith(AndroidJUnit4.class) public class ModelFileManagerTest { private static final Locale DEFAULT_LOCALE = Locale.forLanguageTag("en-US"); @Mock private Supplier<List<ModelFileManager.ModelFile>> mModelFileSupplier; private ModelFileManager.ModelFileSupplierImpl mModelFileSupplierImpl; Loading Loading @@ -71,6 +71,8 @@ public class ModelFileManagerTest { mRootTestDir.mkdirs(); mFactoryModelDir.mkdirs(); Locale.setDefault(DEFAULT_LOCALE); } @After Loading Loading @@ -134,7 +136,7 @@ public class ModelFileManagerTest { } @Test public void findBestModel_useIndependentWhenNoLanguageModelMatch() { public void findBestModel_noMatchedLanguageModel() { Locale locale = Locale.forLanguageTag("ja"); ModelFileManager.ModelFile languageIndependentModelFile = new ModelFileManager.ModelFile( Loading @@ -156,6 +158,28 @@ public class ModelFileManagerTest { assertThat(bestModelFile).isEqualTo(languageIndependentModelFile); } @Test public void findBestModel_noMatchedLanguageModel_defaultLocaleModelExists() { ModelFileManager.ModelFile languageIndependentModelFile = new ModelFileManager.ModelFile( new File("/path/a"), 1, Collections.emptyList(), true); ModelFileManager.ModelFile languageDependentModelFile = new ModelFileManager.ModelFile( new File("/path/b"), 1, Collections.singletonList(DEFAULT_LOCALE), false); when(mModelFileSupplier.get()) .thenReturn( Arrays.asList(languageIndependentModelFile, languageDependentModelFile)); ModelFileManager.ModelFile bestModelFile = mModelFileManager.findBestModelFile( LocaleList.forLanguageTags("zh-hk")); assertThat(bestModelFile).isEqualTo(languageIndependentModelFile); } @Test public void findBestModel_languageIsMoreImportantThanVersion() { ModelFileManager.ModelFile matchButOlderModel = Loading core/tests/coretests/src/android/view/textclassifier/TextClassificationManagerTest.java +28 −0 Original line number Diff line number Diff line Loading @@ -19,6 +19,7 @@ package android.view.textclassifier; import static org.hamcrest.CoreMatchers.not; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNotSame; import static org.junit.Assert.assertThat; import static org.junit.Assert.assertTrue; Loading Loading @@ -324,6 +325,33 @@ public class TextClassificationManagerTest { assertThat(textLanguage, isTextLanguage("ja")); } @Test public void testSuggestConversationActions_textReplyOnly_maxThree() { if (isTextClassifierDisabled()) return; ConversationActions.Message message = new ConversationActions.Message.Builder().setText("Hello").build(); ConversationActions.TypeConfig typeConfig = new ConversationActions.TypeConfig.Builder().includeTypesFromTextClassifier(false) .setIncludedTypes( Collections.singletonList(ConversationActions.TYPE_TEXT_REPLY)) .build(); ConversationActions.Request request = new ConversationActions.Request.Builder(Collections.singletonList(message)) .setMaxSuggestions(1) .setTypeConfig(typeConfig) .build(); ConversationActions conversationActions = mClassifier.suggestConversationActions(request); assertTrue(conversationActions.getConversationActions().size() <= 1); for (ConversationActions.ConversationAction conversationAction : conversationActions.getConversationActions()) { assertEquals(conversationAction.getType(), ConversationActions.TYPE_TEXT_REPLY); assertNotNull(conversationAction.getTextReply()); assertTrue(conversationAction.getConfidenceScore() > 0); assertTrue(conversationAction.getConfidenceScore() <= 1); } } @Test public void testSetTextClassifier() { TextClassifier classifier = mock(TextClassifier.class); Loading Loading
core/java/android/view/textclassifier/ModelFileManager.java +1 −2 Original line number Diff line number Diff line Loading @@ -74,10 +74,9 @@ public final class ModelFileManager { * @param localeList the required locales, use {@code null} if there is no preference. */ public ModelFile findBestModelFile(@Nullable LocaleList localeList) { // Specified localeList takes priority over the system default, so it is listed first. final String languages = localeList == null || localeList.isEmpty() ? LocaleList.getDefault().toLanguageTags() : localeList.toLanguageTags() + "," + LocaleList.getDefault().toLanguageTags(); : localeList.toLanguageTags(); final List<Locale.LanguageRange> languageRangeList = Locale.LanguageRange.parse(languages); ModelFile bestModel = null; Loading
core/java/android/view/textclassifier/TextClassificationConstants.java +35 −0 Original line number Diff line number Diff line Loading @@ -90,6 +90,10 @@ public final class TextClassificationConstants { "entity_list_not_editable"; private static final String ENTITY_LIST_EDITABLE = "entity_list_editable"; private static final String IN_APP_CONVERSATION_ACTION_TYPES_DEFAULT = "in_app_conversation_action_types_default"; private static final String NOTIFICATION_CONVERSATION_ACTION_TYPES_DEFAULT = "notification_conversation_action_types_default"; private static final boolean LOCAL_TEXT_CLASSIFIER_ENABLED_DEFAULT = true; private static final boolean SYSTEM_TEXT_CLASSIFIER_ENABLED_DEFAULT = true; Loading @@ -111,6 +115,18 @@ public final class TextClassificationConstants { .add(TextClassifier.TYPE_DATE) .add(TextClassifier.TYPE_DATE_TIME) .add(TextClassifier.TYPE_FLIGHT_NUMBER).toString(); private static final String CONVERSATION_ACTIONS_TYPES_DEFAULT_VALUES = new StringJoiner(ENTITY_LIST_DELIMITER) .add(ConversationActions.TYPE_TEXT_REPLY) .add(ConversationActions.TYPE_CREATE_REMINDER) .add(ConversationActions.TYPE_CALL_PHONE) .add(ConversationActions.TYPE_OPEN_URL) .add(ConversationActions.TYPE_SEND_EMAIL) .add(ConversationActions.TYPE_SEND_SMS) .add(ConversationActions.TYPE_TRACK_FLIGHT) .add(ConversationActions.TYPE_VIEW_CALENDAR) .add(ConversationActions.TYPE_VIEW_MAP) .toString(); private final boolean mSystemTextClassifierEnabled; private final boolean mLocalTextClassifierEnabled; Loading @@ -126,6 +142,8 @@ public final class TextClassificationConstants { private final List<String> mEntityListDefault; private final List<String> mEntityListNotEditable; private final List<String> mEntityListEditable; private final List<String> mInAppConversationActionTypesDefault; private final List<String> mNotificationConversationActionTypesDefault; private TextClassificationConstants(@Nullable String settings) { final KeyValueListParser parser = new KeyValueListParser(','); Loading Loading @@ -177,6 +195,12 @@ public final class TextClassificationConstants { mEntityListEditable = parseEntityList(parser.getString( ENTITY_LIST_EDITABLE, ENTITY_LIST_DEFAULT_VALUE)); mInAppConversationActionTypesDefault = parseEntityList(parser.getString( IN_APP_CONVERSATION_ACTION_TYPES_DEFAULT, CONVERSATION_ACTIONS_TYPES_DEFAULT_VALUES)); mNotificationConversationActionTypesDefault = parseEntityList(parser.getString( NOTIFICATION_CONVERSATION_ACTION_TYPES_DEFAULT, CONVERSATION_ACTIONS_TYPES_DEFAULT_VALUES)); } /** Load from a settings string. */ Loading Loading @@ -240,6 +264,14 @@ public final class TextClassificationConstants { return mEntityListEditable; } public List<String> getInAppConversationActionTypes() { return mInAppConversationActionTypesDefault; } public List<String> getNotificationConversationActionTypes() { return mNotificationConversationActionTypesDefault; } private static List<String> parseEntityList(String listStr) { return Collections.unmodifiableList(Arrays.asList(listStr.split(ENTITY_LIST_DELIMITER))); } Loading @@ -261,6 +293,9 @@ public final class TextClassificationConstants { pw.printPair("getEntityListDefault", mEntityListDefault); pw.printPair("getEntityListNotEditable", mEntityListNotEditable); pw.printPair("getEntityListEditable", mEntityListEditable); pw.printPair("getInAppConversationActionTypes", mInAppConversationActionTypesDefault); pw.printPair("getNotificationConversationActionTypes", mNotificationConversationActionTypesDefault); pw.decreaseIndent(); pw.println(); } Loading
core/java/android/view/textclassifier/TextClassifierImpl.java +110 −2 Original line number Diff line number Diff line Loading @@ -40,11 +40,13 @@ import android.os.UserManager; import android.provider.Browser; import android.provider.CalendarContract; import android.provider.ContactsContract; import android.text.TextUtils; import com.android.internal.annotations.GuardedBy; import com.android.internal.util.IndentingPrintWriter; import com.android.internal.util.Preconditions; import com.google.android.textclassifier.ActionsSuggestionsModel; import com.google.android.textclassifier.AnnotatorModel; import com.google.android.textclassifier.LangIdModel; Loading Loading @@ -90,6 +92,11 @@ public final class TextClassifierImpl implements TextClassifier { private static final File UPDATED_LANG_ID_MODEL_FILE = new File("/data/misc/textclassifier/lang_id.model"); // Actions private static final String ACTIONS_FACTORY_MODEL_FILENAME_REGEX = "actions_suggestions.model"; private static final File UPDATED_ACTIONS_MODEL = new File("/data/misc/textclassifier/actions_suggestions.model"); private final Context mContext; private final TextClassifier mFallback; private final GenerateLinksLogger mGenerateLinksLogger; Loading @@ -101,6 +108,8 @@ public final class TextClassifierImpl implements TextClassifier { private AnnotatorModel mAnnotatorImpl; @GuardedBy("mLock") // Do not access outside this lock. private LangIdModel mLangIdImpl; @GuardedBy("mLock") // Do not access outside this lock. private ActionsSuggestionsModel mActionsImpl; private final Object mLoggerLock = new Object(); @GuardedBy("mLoggerLock") // Do not access outside this lock. Loading @@ -110,6 +119,7 @@ public final class TextClassifierImpl implements TextClassifier { private final ModelFileManager mAnnotatorModelFileManager; private final ModelFileManager mLangIdModelFileManager; private final ModelFileManager mActionsModelFileManager; public TextClassifierImpl( Context context, TextClassificationConstants settings, TextClassifier fallback) { Loading @@ -131,6 +141,13 @@ public final class TextClassifierImpl implements TextClassifier { UPDATED_LANG_ID_MODEL_FILE, fd -> -1, // TODO: Replace this with LangIdModel.getVersion(fd) fd -> ModelFileManager.ModelFile.LANGUAGE_INDEPENDENT)); mActionsModelFileManager = new ModelFileManager( new ModelFileManager.ModelFileSupplierImpl( FACTORY_MODEL_DIR, ACTIONS_FACTORY_MODEL_FILENAME_REGEX, UPDATED_ACTIONS_MODEL, ActionsSuggestionsModel::getVersion, ActionsSuggestionsModel::getLocales)); } public TextClassifierImpl(Context context, TextClassificationConstants settings) { Loading Loading @@ -346,10 +363,69 @@ public final class TextClassifierImpl implements TextClassifier { return mFallback.detectLanguage(request); } @Override public ConversationActions suggestConversationActions(ConversationActions.Request request) { Preconditions.checkNotNull(request); Utils.checkMainThread(); try { ActionsSuggestionsModel actionsImpl = getActionsImpl(); if (actionsImpl == null) { // Actions model is optional, fallback if it is not available. return mFallback.suggestConversationActions(request); } List<ActionsSuggestionsModel.ConversationMessage> nativeMessages = new ArrayList<>(); for (ConversationActions.Message message : request.getConversation()) { if (TextUtils.isEmpty(message.getText())) { continue; } // TODO: We need to map the Person object to user id. int userId = 1; nativeMessages.add( new ActionsSuggestionsModel.ConversationMessage( userId, message.getText().toString())); } ActionsSuggestionsModel.Conversation nativeConversation = new ActionsSuggestionsModel.Conversation(nativeMessages.toArray( new ActionsSuggestionsModel.ConversationMessage[0])); ActionsSuggestionsModel.ActionSuggestion[] nativeSuggestions = actionsImpl.suggestActions(nativeConversation, null); Collection<String> expectedTypes = resolveActionTypesFromRequest(request); List<ConversationActions.ConversationAction> conversationActions = new ArrayList<>(); int maxSuggestions = Math.min(request.getMaxSuggestions(), nativeSuggestions.length); for (int i = 0; i < maxSuggestions; i++) { ActionsSuggestionsModel.ActionSuggestion nativeSuggestion = nativeSuggestions[i]; String actionType = nativeSuggestion.getActionType(); if (!expectedTypes.contains(actionType)) { continue; } conversationActions.add( new ConversationActions.ConversationAction.Builder(actionType) .setTextReply(nativeSuggestion.getResponseText()) .setConfidenceScore(nativeSuggestion.getScore()) .build()); } return new ConversationActions(conversationActions); } catch (Throwable t) { // Avoid throwing from this method. Log the error. Log.e(LOG_TAG, "Error suggesting conversation actions.", t); } return mFallback.suggestConversationActions(request); } private Collection<String> resolveActionTypesFromRequest(ConversationActions.Request request) { List<String> defaultActionTypes = request.getHints().contains(ConversationActions.HINT_FOR_NOTIFICATION) ? mSettings.getNotificationConversationActionTypes() : mSettings.getInAppConversationActionTypes(); return request.getTypeConfig().resolveTypes(defaultActionTypes); } private AnnotatorModel getAnnotatorImpl(LocaleList localeList) throws FileNotFoundException { synchronized (mLock) { localeList = localeList == null ? LocaleList.getEmptyLocaleList() : localeList; localeList = localeList == null ? LocaleList.getDefault() : localeList; final ModelFileManager.ModelFile bestModel = mAnnotatorModelFileManager.findBestModelFile(localeList); if (bestModel == null) { Loading Loading @@ -386,7 +462,7 @@ public final class TextClassifierImpl implements TextClassifier { synchronized (mLock) { if (mLangIdImpl == null) { final ModelFileManager.ModelFile bestModel = mLangIdModelFileManager.findBestModelFile(LocaleList.getEmptyLocaleList()); mLangIdModelFileManager.findBestModelFile(null); if (bestModel == null) { throw new FileNotFoundException("No LangID model is found"); } Loading @@ -404,6 +480,30 @@ public final class TextClassifierImpl implements TextClassifier { } } @Nullable private ActionsSuggestionsModel getActionsImpl() throws FileNotFoundException { synchronized (mLock) { if (mActionsImpl == null) { // TODO: Use LangID to determine the locale we should use here? final ModelFileManager.ModelFile bestModel = mActionsModelFileManager.findBestModelFile(LocaleList.getDefault()); if (bestModel == null) { return null; } final ParcelFileDescriptor pfd = ParcelFileDescriptor.open( new File(bestModel.getPath()), ParcelFileDescriptor.MODE_READ_ONLY); try { if (pfd != null) { mActionsImpl = new ActionsSuggestionsModel(pfd.getFd()); } } finally { maybeCloseAndLogError(pfd); } } return mActionsImpl; } } private String createId(String text, int start, int end) { synchronized (mLock) { return SelectionSessionLogger.createId(text, start, end, mContext, Loading Loading @@ -471,11 +571,19 @@ public final class TextClassifierImpl implements TextClassifier { } printWriter.decreaseIndent(); printWriter.println("LangID model file(s):"); printWriter.increaseIndent(); for (ModelFileManager.ModelFile modelFile : mLangIdModelFileManager.listModelFiles()) { printWriter.println(modelFile.toString()); } printWriter.decreaseIndent(); printWriter.println("Actions model file(s):"); printWriter.increaseIndent(); for (ModelFileManager.ModelFile modelFile : mActionsModelFileManager.listModelFiles()) { printWriter.println(modelFile.toString()); } printWriter.decreaseIndent(); printWriter.printPair("mFallback", mFallback); printWriter.decreaseIndent(); printWriter.println(); Loading
core/tests/coretests/src/android/view/textclassifier/ModelFileManagerTest.java +26 −2 Original line number Diff line number Diff line Loading @@ -43,7 +43,7 @@ import java.util.stream.Collectors; @SmallTest @RunWith(AndroidJUnit4.class) public class ModelFileManagerTest { private static final Locale DEFAULT_LOCALE = Locale.forLanguageTag("en-US"); @Mock private Supplier<List<ModelFileManager.ModelFile>> mModelFileSupplier; private ModelFileManager.ModelFileSupplierImpl mModelFileSupplierImpl; Loading Loading @@ -71,6 +71,8 @@ public class ModelFileManagerTest { mRootTestDir.mkdirs(); mFactoryModelDir.mkdirs(); Locale.setDefault(DEFAULT_LOCALE); } @After Loading Loading @@ -134,7 +136,7 @@ public class ModelFileManagerTest { } @Test public void findBestModel_useIndependentWhenNoLanguageModelMatch() { public void findBestModel_noMatchedLanguageModel() { Locale locale = Locale.forLanguageTag("ja"); ModelFileManager.ModelFile languageIndependentModelFile = new ModelFileManager.ModelFile( Loading @@ -156,6 +158,28 @@ public class ModelFileManagerTest { assertThat(bestModelFile).isEqualTo(languageIndependentModelFile); } @Test public void findBestModel_noMatchedLanguageModel_defaultLocaleModelExists() { ModelFileManager.ModelFile languageIndependentModelFile = new ModelFileManager.ModelFile( new File("/path/a"), 1, Collections.emptyList(), true); ModelFileManager.ModelFile languageDependentModelFile = new ModelFileManager.ModelFile( new File("/path/b"), 1, Collections.singletonList(DEFAULT_LOCALE), false); when(mModelFileSupplier.get()) .thenReturn( Arrays.asList(languageIndependentModelFile, languageDependentModelFile)); ModelFileManager.ModelFile bestModelFile = mModelFileManager.findBestModelFile( LocaleList.forLanguageTags("zh-hk")); assertThat(bestModelFile).isEqualTo(languageIndependentModelFile); } @Test public void findBestModel_languageIsMoreImportantThanVersion() { ModelFileManager.ModelFile matchButOlderModel = Loading
core/tests/coretests/src/android/view/textclassifier/TextClassificationManagerTest.java +28 −0 Original line number Diff line number Diff line Loading @@ -19,6 +19,7 @@ package android.view.textclassifier; import static org.hamcrest.CoreMatchers.not; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNotSame; import static org.junit.Assert.assertThat; import static org.junit.Assert.assertTrue; Loading Loading @@ -324,6 +325,33 @@ public class TextClassificationManagerTest { assertThat(textLanguage, isTextLanguage("ja")); } @Test public void testSuggestConversationActions_textReplyOnly_maxThree() { if (isTextClassifierDisabled()) return; ConversationActions.Message message = new ConversationActions.Message.Builder().setText("Hello").build(); ConversationActions.TypeConfig typeConfig = new ConversationActions.TypeConfig.Builder().includeTypesFromTextClassifier(false) .setIncludedTypes( Collections.singletonList(ConversationActions.TYPE_TEXT_REPLY)) .build(); ConversationActions.Request request = new ConversationActions.Request.Builder(Collections.singletonList(message)) .setMaxSuggestions(1) .setTypeConfig(typeConfig) .build(); ConversationActions conversationActions = mClassifier.suggestConversationActions(request); assertTrue(conversationActions.getConversationActions().size() <= 1); for (ConversationActions.ConversationAction conversationAction : conversationActions.getConversationActions()) { assertEquals(conversationAction.getType(), ConversationActions.TYPE_TEXT_REPLY); assertNotNull(conversationAction.getTextReply()); assertTrue(conversationAction.getConfidenceScore() > 0); assertTrue(conversationAction.getConfidenceScore() <= 1); } } @Test public void testSetTextClassifier() { TextClassifier classifier = mock(TextClassifier.class); Loading