Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit b252a3f7 authored by Tony Mak's avatar Tony Mak Committed by android-build-merger
Browse files

Merge "Stores serialized entity data to the extras" into qt-dev

am: f8aeee04

Change-Id: I422d2e65861e528081e4878a7f8a0c9b24376f36
parents 70bf1bcd f8aeee04
Loading
Loading
Loading
Loading
+58 −0
Original line number Diff line number Diff line
@@ -22,7 +22,12 @@ import android.content.Intent;
import android.icu.util.ULocale;
import android.os.Bundle;

import com.android.internal.util.ArrayUtils;

import com.google.android.textclassifier.AnnotatorModel;

import java.util.ArrayList;
import java.util.List;

/**
 * Utility class for inserting and retrieving data in TextClassifier request/response extras.
@@ -31,6 +36,7 @@ import java.util.ArrayList;
// TODO: Make this a TestApi for CTS testing.
public final class ExtrasUtils {

    private static final String SERIALIZED_ENTITIES_DATA = "serialized-entities-data";
    private static final String ENTITIES_EXTRAS = "entities-extras";
    private static final String ACTION_INTENT = "action-intent";
    private static final String ACTIONS_INTENTS = "actions-intents";
@@ -40,6 +46,7 @@ public final class ExtrasUtils {
    private static final String MODEL_VERSION = "model-version";
    private static final String MODEL_NAME = "model-name";
    private static final String TEXT_LANGUAGES = "text-languages";
    private static final String ENTITIES = "entities";

    private ExtrasUtils() {}

@@ -154,6 +161,24 @@ public final class ExtrasUtils {
        return container.getParcelable(ACTION_INTENT);
    }

    /**
     * Stores serialized entity data information in TextClassifier response object's extras
     * {@code container}.
     */
    public static void putSerializedEntityData(
            Bundle container, @Nullable byte[] serializedEntityData) {
        container.putByteArray(SERIALIZED_ENTITIES_DATA, serializedEntityData);
    }

    /**
     * Returns serialized entity data information contained in a TextClassifier response
     * object.
     */
    @Nullable
    public static byte[] getSerializedEntityData(Bundle container) {
        return container.getByteArray(SERIALIZED_ENTITIES_DATA);
    }

    /**
     * Stores {@code entities} information in TextClassifier response object's extras
     * {@code container}.
@@ -253,4 +278,37 @@ public final class ExtrasUtils {
        }
        return extra.getString(MODEL_NAME);
    }

    /**
     * Stores the entities from {@link AnnotatorModel.ClassificationResult} in {@code container}.
     */
    public static void putEntities(
            Bundle container,
            @Nullable AnnotatorModel.ClassificationResult[] classifications) {
        if (ArrayUtils.isEmpty(classifications)) {
            return;
        }
        ArrayList<Bundle> entitiesBundle = new ArrayList<>();
        for (AnnotatorModel.ClassificationResult classification : classifications) {
            if (classification == null) {
                continue;
            }
            Bundle entityBundle = new Bundle();
            entityBundle.putString(ENTITY_TYPE, classification.getCollection());
            entityBundle.putByteArray(
                    SERIALIZED_ENTITIES_DATA,
                    classification.getSerializedEntityData());
            entitiesBundle.add(entityBundle);
        }
        if (!entitiesBundle.isEmpty()) {
            container.putParcelableArrayList(ENTITIES, entitiesBundle);
        }
    }

    /**
     * Returns a list of entities contained in the {@code extra}.
     */
    public static List<Bundle> getEntities(Bundle container) {
        return container.getParcelableArrayList(ENTITIES);
    }
}
+49 −6
Original line number Diff line number Diff line
@@ -44,6 +44,8 @@ import android.view.textclassifier.TextClassifier.Utils;
import com.android.internal.annotations.VisibleForTesting;
import com.android.internal.util.Preconditions;

import com.google.android.textclassifier.AnnotatorModel;

import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.time.ZonedDateTime;
@@ -137,7 +139,7 @@ public final class TextClassification implements Parcelable {
            @Nullable Intent legacyIntent,
            @Nullable OnClickListener legacyOnClickListener,
            @NonNull List<RemoteAction> actions,
            @NonNull Map<String, Float> entityConfidence,
            @NonNull EntityConfidence entityConfidence,
            @Nullable String id,
            @NonNull Bundle extras) {
        mText = text;
@@ -146,7 +148,7 @@ public final class TextClassification implements Parcelable {
        mLegacyIntent = legacyIntent;
        mLegacyOnClickListener = legacyOnClickListener;
        mActions = Collections.unmodifiableList(actions);
        mEntityConfidence = new EntityConfidence(entityConfidence);
        mEntityConfidence = Preconditions.checkNotNull(entityConfidence);
        mId = id;
        mExtras = extras;
    }
@@ -326,7 +328,10 @@ public final class TextClassification implements Parcelable {
    public static final class Builder {

        @NonNull private List<RemoteAction> mActions = new ArrayList<>();
        @NonNull private final Map<String, Float> mEntityConfidence = new ArrayMap<>();
        @NonNull private final Map<String, Float> mTypeScoreMap = new ArrayMap<>();
        @NonNull
        private final Map<String, AnnotatorModel.ClassificationResult> mClassificationResults =
                new ArrayMap<>();
        @Nullable private String mText;
        @Nullable private Drawable mLegacyIcon;
        @Nullable private String mLegacyLabel;
@@ -359,7 +364,36 @@ public final class TextClassification implements Parcelable {
        public Builder setEntityType(
                @NonNull @EntityType String type,
                @FloatRange(from = 0.0, to = 1.0) float confidenceScore) {
            mEntityConfidence.put(type, confidenceScore);
            setEntityType(type, confidenceScore, null);
            return this;
        }

        /**
         * @see #setEntityType(String, float)
         *
         * @hide
         */
        @NonNull
        public Builder setEntityType(AnnotatorModel.ClassificationResult classificationResult) {
            setEntityType(
                    classificationResult.getCollection(),
                    classificationResult.getScore(),
                    classificationResult);
            return this;
        }

        /**
         * @see #setEntityType(String, float)
         *
         * @hide
         */
        @NonNull
        private Builder setEntityType(
                @NonNull @EntityType String type,
                @FloatRange(from = 0.0, to = 1.0) float confidenceScore,
                @Nullable AnnotatorModel.ClassificationResult classificationResult) {
            mTypeScoreMap.put(type, confidenceScore);
            mClassificationResults.put(type, classificationResult);
            return this;
        }

@@ -482,11 +516,13 @@ public final class TextClassification implements Parcelable {
         */
        @NonNull
        public TextClassification build() {
            EntityConfidence entityConfidence = new EntityConfidence(mTypeScoreMap);
            return new TextClassification(mText, mLegacyIcon, mLegacyLabel, mLegacyIntent,
                    mLegacyOnClickListener, mActions, mEntityConfidence, mId, buildExtras());
                    mLegacyOnClickListener, mActions, entityConfidence, mId,
                    buildExtras(entityConfidence));
        }

        private Bundle buildExtras() {
        private Bundle buildExtras(EntityConfidence entityConfidence) {
            final Bundle extras = mExtras == null ? new Bundle() : mExtras.deepCopy();
            if (mActionIntents.stream().anyMatch(Objects::nonNull)) {
                ExtrasUtils.putActionsIntents(extras, mActionIntents);
@@ -494,6 +530,13 @@ public final class TextClassification implements Parcelable {
            if (mForeignLanguageExtra != null) {
                ExtrasUtils.putForeignLanguageExtra(extras, mForeignLanguageExtra);
            }
            List<String> sortedTypes = entityConfidence.getEntities();
            ArrayList<AnnotatorModel.ClassificationResult> sortedEntities = new ArrayList<>();
            for (String type : sortedTypes) {
                sortedEntities.add(mClassificationResults.get(type));
            }
            ExtrasUtils.putEntities(
                    extras, sortedEntities.toArray(new AnnotatorModel.ClassificationResult[0]));
            return extras.isEmpty() ? Bundle.EMPTY : extras;
        }
    }
+2 −4
Original line number Diff line number Diff line
@@ -458,6 +458,7 @@ public final class TextClassifierImpl implements TextClassifier {
                remoteAction = labeledIntentResult.remoteAction;
                ExtrasUtils.putActionIntent(extras, labeledIntentResult.resolvedIntent);
            }
            ExtrasUtils.putSerializedEntityData(extras, nativeSuggestion.getSerializedEntityData());
            ExtrasUtils.putEntitiesExtras(
                    extras,
                    TemplateIntentFactory.nameVariantsToBundle(nativeSuggestion.getEntityData()));
@@ -618,9 +619,7 @@ public final class TextClassifierImpl implements TextClassifier {
        AnnotatorModel.ClassificationResult highestScoringResult =
                typeCount > 0 ? classifications[0] : null;
        for (int i = 0; i < typeCount; i++) {
            builder.setEntityType(
                    classifications[i].getCollection(),
                    classifications[i].getScore());
            builder.setEntityType(classifications[i]);
            if (classifications[i].getScore() > highestScoringResult.getScore()) {
                highestScoringResult = classifications[i];
            }
@@ -663,7 +662,6 @@ public final class TextClassifierImpl implements TextClassifier {
            }
            builder.addAction(action, intent);
        }

        return builder.setId(createId(text, start, end)).build();
    }

+7 −0
Original line number Diff line number Diff line
@@ -216,6 +216,11 @@ public class TextClassifierTest {

        TextClassification classification = mClassifier.classifyText(request);
        assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_DATE));
        Bundle extras = classification.getExtras();
        List<Bundle> entities = ExtrasUtils.getEntities(extras);
        Truth.assertThat(entities).hasSize(1);
        Bundle entity = entities.get(0);
        Truth.assertThat(ExtrasUtils.getEntityType(entity)).isEqualTo(TextClassifier.TYPE_DATE);
    }

    @Test
@@ -484,6 +489,8 @@ public class TextClassifierTest {
        Truth.assertThat(conversationAction.getAction()).isNull();
        String code = ExtrasUtils.getCopyText(conversationAction.getExtras());
        Truth.assertThat(code).isEqualTo("12345");
        Truth.assertThat(
                ExtrasUtils.getSerializedEntityData(conversationAction.getExtras())).isNotEmpty();
    }

    private boolean isTextClassifierDisabled() {