Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 59ed9a7f authored by TreeHugger Robot's avatar TreeHugger Robot Committed by Android (Google) Code Review
Browse files

Merge "Follow-up CL of ag/6935284, add entities to extras in generateLinks" into qt-dev

parents 143c735e b6afd3c1
Loading
Loading
Loading
Loading
+21 −0
Original line number Diff line number Diff line
@@ -36,6 +36,7 @@ import java.util.List;
// TODO: Make this a TestApi for CTS testing.
public final class ExtrasUtils {

    // Keys for response objects.
    private static final String SERIALIZED_ENTITIES_DATA = "serialized-entities-data";
    private static final String ENTITIES_EXTRAS = "entities-extras";
    private static final String ACTION_INTENT = "action-intent";
@@ -48,6 +49,10 @@ public final class ExtrasUtils {
    private static final String TEXT_LANGUAGES = "text-languages";
    private static final String ENTITIES = "entities";

    // Keys for request objects.
    private static final String IS_SERIALIZED_ENTITY_DATA_ENABLED =
            "is-serialized-entity-data-enabled";

    private ExtrasUtils() {}

    /**
@@ -308,7 +313,23 @@ public final class ExtrasUtils {
    /**
     * Returns a list of entities contained in the {@code extra}.
     */
    @Nullable
    public static List<Bundle> getEntities(Bundle container) {
        return container.getParcelableArrayList(ENTITIES);
    }

    /**
     * Whether the annotator should populate serialized entity data into the result object.
     */
    public static boolean isSerializedEntityDataEnabled(TextLinks.Request request) {
        return request.getExtras().getBoolean(IS_SERIALIZED_ENTITY_DATA_ENABLED);
    }

    /**
     * To indicate whether the annotator should populate serialized entity data in the result
     * object.
     */
    public static void putIsSerializedEntityDataEnabled(Bundle bundle, boolean isEnabled) {
        bundle.putBoolean(IS_SERIALIZED_ENTITY_DATA_ENABLED, isEnabled);
    }
}
+11 −2
Original line number Diff line number Diff line
@@ -307,6 +307,8 @@ public final class TextClassifierImpl implements TextClassifier {
            final String detectLanguageTags = detectLanguageTagsFromText(request.getText());
            final AnnotatorModel annotatorImpl =
                    getAnnotatorImpl(request.getDefaultLocales());
            final boolean isSerializedEntityDataEnabled =
                    ExtrasUtils.isSerializedEntityDataEnabled(request);
            final AnnotatorModel.AnnotatedSpan[] annotations =
                    annotatorImpl.annotate(
                            textString,
@@ -314,7 +316,10 @@ public final class TextClassifierImpl implements TextClassifier {
                                    refTime.toInstant().toEpochMilli(),
                                    refTime.getZone().getId(),
                                    localesString,
                                    detectLanguageTags));
                                    detectLanguageTags,
                                    entitiesToIdentify,
                                    AnnotatorModel.AnnotationUsecase.SMART.getValue(),
                                    isSerializedEntityDataEnabled));
            for (AnnotatorModel.AnnotatedSpan span : annotations) {
                final AnnotatorModel.ClassificationResult[] results =
                        span.getClassification();
@@ -326,7 +331,11 @@ public final class TextClassifierImpl implements TextClassifier {
                for (int i = 0; i < results.length; i++) {
                    entityScores.put(results[i].getCollection(), results[i].getScore());
                }
                builder.addLink(span.getStartIndex(), span.getEndIndex(), entityScores);
                Bundle extras = new Bundle();
                if (isSerializedEntityDataEnabled) {
                    ExtrasUtils.putEntities(extras, results);
                }
                builder.addLink(span.getStartIndex(), span.getEndIndex(), entityScores, extras);
            }
            final TextLinks links = builder.build();
            final long endTimeMs = System.currentTimeMillis();
+32 −0
Original line number Diff line number Diff line
@@ -361,6 +361,38 @@ public class TextClassifierTest {
        mClassifier.generateLinks(request);
    }

    @Test
    public void testGenerateLinks_entityData() {
        if (isTextClassifierDisabled()) return;
        String text = "The number is +12122537077.";
        Bundle extras = new Bundle();
        ExtrasUtils.putIsSerializedEntityDataEnabled(extras, true);
        TextLinks.Request request = new TextLinks.Request.Builder(text).setExtras(extras).build();

        TextLinks textLinks = mClassifier.generateLinks(request);

        Truth.assertThat(textLinks.getLinks()).hasSize(1);
        TextLinks.TextLink textLink = textLinks.getLinks().iterator().next();
        List<Bundle> entities = ExtrasUtils.getEntities(textLink.getExtras());
        Truth.assertThat(entities).hasSize(1);
        Bundle entity = entities.get(0);
        Truth.assertThat(ExtrasUtils.getEntityType(entity)).isEqualTo(TextClassifier.TYPE_PHONE);
    }

    @Test
    public void testGenerateLinks_entityData_disabled() {
        if (isTextClassifierDisabled()) return;
        String text = "The number is +12122537077.";
        TextLinks.Request request = new TextLinks.Request.Builder(text).build();

        TextLinks textLinks = mClassifier.generateLinks(request);

        Truth.assertThat(textLinks.getLinks()).hasSize(1);
        TextLinks.TextLink textLink = textLinks.getLinks().iterator().next();
        List<Bundle> entities = ExtrasUtils.getEntities(textLink.getExtras());
        Truth.assertThat(entities).isNull();
    }

    @Test
    public void testDetectLanguage() {
        if (isTextClassifierDisabled()) return;