Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit fdb3554e authored by Tony Mak's avatar Tony Mak
Browse files

Stores serialized entity data to the extras

libtextclassifier (native side) will serialize the extra entity information
(like parsed datetime) into a byte array and passed it to framework.
Framework puts it to the extras Bundle of result object, i.e. TextClassification
and ConversationActions.

In the future, we will provide a library (maybe AndroidX) to deserialize the
string and return structured objects.

BUG: 129119759

Test: atest frameworks/base/core/tests/coretests/src/android/view/textclassifier/

Change-Id: I8091a1038691419825f5d6da3562b8ba81787dc7
parent 0454f3b8
Loading
Loading
Loading
Loading
+58 −0
Original line number Diff line number Diff line
@@ -22,7 +22,12 @@ import android.content.Intent;
import android.icu.util.ULocale;
import android.os.Bundle;

import com.android.internal.util.ArrayUtils;

import com.google.android.textclassifier.AnnotatorModel;

import java.util.ArrayList;
import java.util.List;

/**
 * Utility class for inserting and retrieving data in TextClassifier request/response extras.
@@ -31,6 +36,7 @@ import java.util.ArrayList;
// TODO: Make this a TestApi for CTS testing.
public final class ExtrasUtils {

    private static final String SERIALIZED_ENTITIES_DATA = "serialized-entities-data";
    private static final String ENTITIES_EXTRAS = "entities-extras";
    private static final String ACTION_INTENT = "action-intent";
    private static final String ACTIONS_INTENTS = "actions-intents";
@@ -40,6 +46,7 @@ public final class ExtrasUtils {
    private static final String MODEL_VERSION = "model-version";
    private static final String MODEL_NAME = "model-name";
    private static final String TEXT_LANGUAGES = "text-languages";
    private static final String ENTITIES = "entities";

    private ExtrasUtils() {}

@@ -154,6 +161,24 @@ public final class ExtrasUtils {
        return container.getParcelable(ACTION_INTENT);
    }

    /**
     * Stores serialized entity data information in TextClassifier response object's extras
     * {@code container}.
     */
    public static void putSerializedEntityData(
            Bundle container, @Nullable byte[] serializedEntityData) {
        container.putByteArray(SERIALIZED_ENTITIES_DATA, serializedEntityData);
    }

    /**
     * Returns serialized entity data information contained in a TextClassifier response
     * object.
     */
    @Nullable
    public static byte[] getSerializedEntityData(Bundle container) {
        return container.getByteArray(SERIALIZED_ENTITIES_DATA);
    }

    /**
     * Stores {@code entities} information in TextClassifier response object's extras
     * {@code container}.
@@ -253,4 +278,37 @@ public final class ExtrasUtils {
        }
        return extra.getString(MODEL_NAME);
    }

    /**
     * Stores the entities from {@link AnnotatorModel.ClassificationResult} in {@code container}.
     */
    public static void putEntities(
            Bundle container,
            @Nullable AnnotatorModel.ClassificationResult[] classifications) {
        if (ArrayUtils.isEmpty(classifications)) {
            return;
        }
        ArrayList<Bundle> entitiesBundle = new ArrayList<>();
        for (AnnotatorModel.ClassificationResult classification : classifications) {
            if (classification == null) {
                continue;
            }
            Bundle entityBundle = new Bundle();
            entityBundle.putString(ENTITY_TYPE, classification.getCollection());
            entityBundle.putByteArray(
                    SERIALIZED_ENTITIES_DATA,
                    classification.getSerializedEntityData());
            entitiesBundle.add(entityBundle);
        }
        if (!entitiesBundle.isEmpty()) {
            container.putParcelableArrayList(ENTITIES, entitiesBundle);
        }
    }

    /**
     * Returns a list of entities contained in the {@code extra}.
     */
    public static List<Bundle> getEntities(Bundle container) {
        return container.getParcelableArrayList(ENTITIES);
    }
}
+49 −6
Original line number Diff line number Diff line
@@ -44,6 +44,8 @@ import android.view.textclassifier.TextClassifier.Utils;
import com.android.internal.annotations.VisibleForTesting;
import com.android.internal.util.Preconditions;

import com.google.android.textclassifier.AnnotatorModel;

import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.time.ZonedDateTime;
@@ -137,7 +139,7 @@ public final class TextClassification implements Parcelable {
            @Nullable Intent legacyIntent,
            @Nullable OnClickListener legacyOnClickListener,
            @NonNull List<RemoteAction> actions,
            @NonNull Map<String, Float> entityConfidence,
            @NonNull EntityConfidence entityConfidence,
            @Nullable String id,
            @NonNull Bundle extras) {
        mText = text;
@@ -146,7 +148,7 @@ public final class TextClassification implements Parcelable {
        mLegacyIntent = legacyIntent;
        mLegacyOnClickListener = legacyOnClickListener;
        mActions = Collections.unmodifiableList(actions);
        mEntityConfidence = new EntityConfidence(entityConfidence);
        mEntityConfidence = Preconditions.checkNotNull(entityConfidence);
        mId = id;
        mExtras = extras;
    }
@@ -326,7 +328,10 @@ public final class TextClassification implements Parcelable {
    public static final class Builder {

        @NonNull private List<RemoteAction> mActions = new ArrayList<>();
        @NonNull private final Map<String, Float> mEntityConfidence = new ArrayMap<>();
        @NonNull private final Map<String, Float> mTypeScoreMap = new ArrayMap<>();
        @NonNull
        private final Map<String, AnnotatorModel.ClassificationResult> mClassificationResults =
                new ArrayMap<>();
        @Nullable private String mText;
        @Nullable private Drawable mLegacyIcon;
        @Nullable private String mLegacyLabel;
@@ -359,7 +364,36 @@ public final class TextClassification implements Parcelable {
        public Builder setEntityType(
                @NonNull @EntityType String type,
                @FloatRange(from = 0.0, to = 1.0) float confidenceScore) {
            mEntityConfidence.put(type, confidenceScore);
            setEntityType(type, confidenceScore, null);
            return this;
        }

        /**
         * @see #setEntityType(String, float)
         *
         * @hide
         */
        @NonNull
        public Builder setEntityType(AnnotatorModel.ClassificationResult classificationResult) {
            setEntityType(
                    classificationResult.getCollection(),
                    classificationResult.getScore(),
                    classificationResult);
            return this;
        }

        /**
         * @see #setEntityType(String, float)
         *
         * @hide
         */
        @NonNull
        private Builder setEntityType(
                @NonNull @EntityType String type,
                @FloatRange(from = 0.0, to = 1.0) float confidenceScore,
                @Nullable AnnotatorModel.ClassificationResult classificationResult) {
            mTypeScoreMap.put(type, confidenceScore);
            mClassificationResults.put(type, classificationResult);
            return this;
        }

@@ -482,11 +516,13 @@ public final class TextClassification implements Parcelable {
         */
        @NonNull
        public TextClassification build() {
            EntityConfidence entityConfidence = new EntityConfidence(mTypeScoreMap);
            return new TextClassification(mText, mLegacyIcon, mLegacyLabel, mLegacyIntent,
                    mLegacyOnClickListener, mActions, mEntityConfidence, mId, buildExtras());
                    mLegacyOnClickListener, mActions, entityConfidence, mId,
                    buildExtras(entityConfidence));
        }

        private Bundle buildExtras() {
        private Bundle buildExtras(EntityConfidence entityConfidence) {
            final Bundle extras = mExtras == null ? new Bundle() : mExtras.deepCopy();
            if (mActionIntents.stream().anyMatch(Objects::nonNull)) {
                ExtrasUtils.putActionsIntents(extras, mActionIntents);
@@ -494,6 +530,13 @@ public final class TextClassification implements Parcelable {
            if (mForeignLanguageExtra != null) {
                ExtrasUtils.putForeignLanguageExtra(extras, mForeignLanguageExtra);
            }
            List<String> sortedTypes = entityConfidence.getEntities();
            ArrayList<AnnotatorModel.ClassificationResult> sortedEntities = new ArrayList<>();
            for (String type : sortedTypes) {
                sortedEntities.add(mClassificationResults.get(type));
            }
            ExtrasUtils.putEntities(
                    extras, sortedEntities.toArray(new AnnotatorModel.ClassificationResult[0]));
            return extras.isEmpty() ? Bundle.EMPTY : extras;
        }
    }
+2 −4
Original line number Diff line number Diff line
@@ -458,6 +458,7 @@ public final class TextClassifierImpl implements TextClassifier {
                remoteAction = labeledIntentResult.remoteAction;
                ExtrasUtils.putActionIntent(extras, labeledIntentResult.resolvedIntent);
            }
            ExtrasUtils.putSerializedEntityData(extras, nativeSuggestion.getSerializedEntityData());
            ExtrasUtils.putEntitiesExtras(
                    extras,
                    TemplateIntentFactory.nameVariantsToBundle(nativeSuggestion.getEntityData()));
@@ -618,9 +619,7 @@ public final class TextClassifierImpl implements TextClassifier {
        AnnotatorModel.ClassificationResult highestScoringResult =
                typeCount > 0 ? classifications[0] : null;
        for (int i = 0; i < typeCount; i++) {
            builder.setEntityType(
                    classifications[i].getCollection(),
                    classifications[i].getScore());
            builder.setEntityType(classifications[i]);
            if (classifications[i].getScore() > highestScoringResult.getScore()) {
                highestScoringResult = classifications[i];
            }
@@ -663,7 +662,6 @@ public final class TextClassifierImpl implements TextClassifier {
            }
            builder.addAction(action, intent);
        }

        return builder.setId(createId(text, start, end)).build();
    }

+7 −0
Original line number Diff line number Diff line
@@ -216,6 +216,11 @@ public class TextClassifierTest {

        TextClassification classification = mClassifier.classifyText(request);
        assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_DATE));
        Bundle extras = classification.getExtras();
        List<Bundle> entities = ExtrasUtils.getEntities(extras);
        Truth.assertThat(entities).hasSize(1);
        Bundle entity = entities.get(0);
        Truth.assertThat(ExtrasUtils.getEntityType(entity)).isEqualTo(TextClassifier.TYPE_DATE);
    }

    @Test
@@ -484,6 +489,8 @@ public class TextClassifierTest {
        Truth.assertThat(conversationAction.getAction()).isNull();
        String code = ExtrasUtils.getCopyText(conversationAction.getExtras());
        Truth.assertThat(code).isEqualTo("12345");
        Truth.assertThat(
                ExtrasUtils.getSerializedEntityData(conversationAction.getExtras())).isNotEmpty();
    }

    private boolean isTextClassifierDisabled() {