Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit f8aeee04 authored by TreeHugger Robot's avatar TreeHugger Robot Committed by Android (Google) Code Review
Browse files

Merge "Stores serialized entity data to the extras" into qt-dev

parents fb1a20e4 fdb3554e
Loading
Loading
Loading
Loading
+58 −0
Original line number Original line Diff line number Diff line
@@ -22,7 +22,12 @@ import android.content.Intent;
import android.icu.util.ULocale;
import android.icu.util.ULocale;
import android.os.Bundle;
import android.os.Bundle;


import com.android.internal.util.ArrayUtils;

import com.google.android.textclassifier.AnnotatorModel;

import java.util.ArrayList;
import java.util.ArrayList;
import java.util.List;


/**
/**
 * Utility class for inserting and retrieving data in TextClassifier request/response extras.
 * Utility class for inserting and retrieving data in TextClassifier request/response extras.
@@ -31,6 +36,7 @@ import java.util.ArrayList;
// TODO: Make this a TestApi for CTS testing.
// TODO: Make this a TestApi for CTS testing.
public final class ExtrasUtils {
public final class ExtrasUtils {


    private static final String SERIALIZED_ENTITIES_DATA = "serialized-entities-data";
    private static final String ENTITIES_EXTRAS = "entities-extras";
    private static final String ENTITIES_EXTRAS = "entities-extras";
    private static final String ACTION_INTENT = "action-intent";
    private static final String ACTION_INTENT = "action-intent";
    private static final String ACTIONS_INTENTS = "actions-intents";
    private static final String ACTIONS_INTENTS = "actions-intents";
@@ -40,6 +46,7 @@ public final class ExtrasUtils {
    private static final String MODEL_VERSION = "model-version";
    private static final String MODEL_VERSION = "model-version";
    private static final String MODEL_NAME = "model-name";
    private static final String MODEL_NAME = "model-name";
    private static final String TEXT_LANGUAGES = "text-languages";
    private static final String TEXT_LANGUAGES = "text-languages";
    private static final String ENTITIES = "entities";


    private ExtrasUtils() {}
    private ExtrasUtils() {}


@@ -154,6 +161,24 @@ public final class ExtrasUtils {
        return container.getParcelable(ACTION_INTENT);
        return container.getParcelable(ACTION_INTENT);
    }
    }


    /**
     * Stores serialized entity data information in TextClassifier response object's extras
     * {@code container}.
     */
    public static void putSerializedEntityData(
            Bundle container, @Nullable byte[] serializedEntityData) {
        container.putByteArray(SERIALIZED_ENTITIES_DATA, serializedEntityData);
    }

    /**
     * Returns serialized entity data information contained in a TextClassifier response
     * object.
     */
    @Nullable
    public static byte[] getSerializedEntityData(Bundle container) {
        return container.getByteArray(SERIALIZED_ENTITIES_DATA);
    }

    /**
    /**
     * Stores {@code entities} information in TextClassifier response object's extras
     * Stores {@code entities} information in TextClassifier response object's extras
     * {@code container}.
     * {@code container}.
@@ -253,4 +278,37 @@ public final class ExtrasUtils {
        }
        }
        return extra.getString(MODEL_NAME);
        return extra.getString(MODEL_NAME);
    }
    }

    /**
     * Stores the entities from {@link AnnotatorModel.ClassificationResult} in {@code container}.
     */
    public static void putEntities(
            Bundle container,
            @Nullable AnnotatorModel.ClassificationResult[] classifications) {
        if (ArrayUtils.isEmpty(classifications)) {
            return;
        }
        ArrayList<Bundle> entitiesBundle = new ArrayList<>();
        for (AnnotatorModel.ClassificationResult classification : classifications) {
            if (classification == null) {
                continue;
            }
            Bundle entityBundle = new Bundle();
            entityBundle.putString(ENTITY_TYPE, classification.getCollection());
            entityBundle.putByteArray(
                    SERIALIZED_ENTITIES_DATA,
                    classification.getSerializedEntityData());
            entitiesBundle.add(entityBundle);
        }
        if (!entitiesBundle.isEmpty()) {
            container.putParcelableArrayList(ENTITIES, entitiesBundle);
        }
    }

    /**
     * Returns a list of entities contained in the {@code extra}.
     */
    public static List<Bundle> getEntities(Bundle container) {
        return container.getParcelableArrayList(ENTITIES);
    }
}
}
+49 −6
Original line number Original line Diff line number Diff line
@@ -44,6 +44,8 @@ import android.view.textclassifier.TextClassifier.Utils;
import com.android.internal.annotations.VisibleForTesting;
import com.android.internal.annotations.VisibleForTesting;
import com.android.internal.util.Preconditions;
import com.android.internal.util.Preconditions;


import com.google.android.textclassifier.AnnotatorModel;

import java.lang.annotation.Retention;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.RetentionPolicy;
import java.time.ZonedDateTime;
import java.time.ZonedDateTime;
@@ -137,7 +139,7 @@ public final class TextClassification implements Parcelable {
            @Nullable Intent legacyIntent,
            @Nullable Intent legacyIntent,
            @Nullable OnClickListener legacyOnClickListener,
            @Nullable OnClickListener legacyOnClickListener,
            @NonNull List<RemoteAction> actions,
            @NonNull List<RemoteAction> actions,
            @NonNull Map<String, Float> entityConfidence,
            @NonNull EntityConfidence entityConfidence,
            @Nullable String id,
            @Nullable String id,
            @NonNull Bundle extras) {
            @NonNull Bundle extras) {
        mText = text;
        mText = text;
@@ -146,7 +148,7 @@ public final class TextClassification implements Parcelable {
        mLegacyIntent = legacyIntent;
        mLegacyIntent = legacyIntent;
        mLegacyOnClickListener = legacyOnClickListener;
        mLegacyOnClickListener = legacyOnClickListener;
        mActions = Collections.unmodifiableList(actions);
        mActions = Collections.unmodifiableList(actions);
        mEntityConfidence = new EntityConfidence(entityConfidence);
        mEntityConfidence = Preconditions.checkNotNull(entityConfidence);
        mId = id;
        mId = id;
        mExtras = extras;
        mExtras = extras;
    }
    }
@@ -326,7 +328,10 @@ public final class TextClassification implements Parcelable {
    public static final class Builder {
    public static final class Builder {


        @NonNull private List<RemoteAction> mActions = new ArrayList<>();
        @NonNull private List<RemoteAction> mActions = new ArrayList<>();
        @NonNull private final Map<String, Float> mEntityConfidence = new ArrayMap<>();
        @NonNull private final Map<String, Float> mTypeScoreMap = new ArrayMap<>();
        @NonNull
        private final Map<String, AnnotatorModel.ClassificationResult> mClassificationResults =
                new ArrayMap<>();
        @Nullable private String mText;
        @Nullable private String mText;
        @Nullable private Drawable mLegacyIcon;
        @Nullable private Drawable mLegacyIcon;
        @Nullable private String mLegacyLabel;
        @Nullable private String mLegacyLabel;
@@ -359,7 +364,36 @@ public final class TextClassification implements Parcelable {
        public Builder setEntityType(
        public Builder setEntityType(
                @NonNull @EntityType String type,
                @NonNull @EntityType String type,
                @FloatRange(from = 0.0, to = 1.0) float confidenceScore) {
                @FloatRange(from = 0.0, to = 1.0) float confidenceScore) {
            mEntityConfidence.put(type, confidenceScore);
            setEntityType(type, confidenceScore, null);
            return this;
        }

        /**
         * @see #setEntityType(String, float)
         *
         * @hide
         */
        @NonNull
        public Builder setEntityType(AnnotatorModel.ClassificationResult classificationResult) {
            setEntityType(
                    classificationResult.getCollection(),
                    classificationResult.getScore(),
                    classificationResult);
            return this;
        }

        /**
         * @see #setEntityType(String, float)
         *
         * @hide
         */
        @NonNull
        private Builder setEntityType(
                @NonNull @EntityType String type,
                @FloatRange(from = 0.0, to = 1.0) float confidenceScore,
                @Nullable AnnotatorModel.ClassificationResult classificationResult) {
            mTypeScoreMap.put(type, confidenceScore);
            mClassificationResults.put(type, classificationResult);
            return this;
            return this;
        }
        }


@@ -482,11 +516,13 @@ public final class TextClassification implements Parcelable {
         */
         */
        @NonNull
        @NonNull
        public TextClassification build() {
        public TextClassification build() {
            EntityConfidence entityConfidence = new EntityConfidence(mTypeScoreMap);
            return new TextClassification(mText, mLegacyIcon, mLegacyLabel, mLegacyIntent,
            return new TextClassification(mText, mLegacyIcon, mLegacyLabel, mLegacyIntent,
                    mLegacyOnClickListener, mActions, mEntityConfidence, mId, buildExtras());
                    mLegacyOnClickListener, mActions, entityConfidence, mId,
                    buildExtras(entityConfidence));
        }
        }


        private Bundle buildExtras() {
        private Bundle buildExtras(EntityConfidence entityConfidence) {
            final Bundle extras = mExtras == null ? new Bundle() : mExtras.deepCopy();
            final Bundle extras = mExtras == null ? new Bundle() : mExtras.deepCopy();
            if (mActionIntents.stream().anyMatch(Objects::nonNull)) {
            if (mActionIntents.stream().anyMatch(Objects::nonNull)) {
                ExtrasUtils.putActionsIntents(extras, mActionIntents);
                ExtrasUtils.putActionsIntents(extras, mActionIntents);
@@ -494,6 +530,13 @@ public final class TextClassification implements Parcelable {
            if (mForeignLanguageExtra != null) {
            if (mForeignLanguageExtra != null) {
                ExtrasUtils.putForeignLanguageExtra(extras, mForeignLanguageExtra);
                ExtrasUtils.putForeignLanguageExtra(extras, mForeignLanguageExtra);
            }
            }
            List<String> sortedTypes = entityConfidence.getEntities();
            ArrayList<AnnotatorModel.ClassificationResult> sortedEntities = new ArrayList<>();
            for (String type : sortedTypes) {
                sortedEntities.add(mClassificationResults.get(type));
            }
            ExtrasUtils.putEntities(
                    extras, sortedEntities.toArray(new AnnotatorModel.ClassificationResult[0]));
            return extras.isEmpty() ? Bundle.EMPTY : extras;
            return extras.isEmpty() ? Bundle.EMPTY : extras;
        }
        }
    }
    }
+2 −4
Original line number Original line Diff line number Diff line
@@ -458,6 +458,7 @@ public final class TextClassifierImpl implements TextClassifier {
                remoteAction = labeledIntentResult.remoteAction;
                remoteAction = labeledIntentResult.remoteAction;
                ExtrasUtils.putActionIntent(extras, labeledIntentResult.resolvedIntent);
                ExtrasUtils.putActionIntent(extras, labeledIntentResult.resolvedIntent);
            }
            }
            ExtrasUtils.putSerializedEntityData(extras, nativeSuggestion.getSerializedEntityData());
            ExtrasUtils.putEntitiesExtras(
            ExtrasUtils.putEntitiesExtras(
                    extras,
                    extras,
                    TemplateIntentFactory.nameVariantsToBundle(nativeSuggestion.getEntityData()));
                    TemplateIntentFactory.nameVariantsToBundle(nativeSuggestion.getEntityData()));
@@ -618,9 +619,7 @@ public final class TextClassifierImpl implements TextClassifier {
        AnnotatorModel.ClassificationResult highestScoringResult =
        AnnotatorModel.ClassificationResult highestScoringResult =
                typeCount > 0 ? classifications[0] : null;
                typeCount > 0 ? classifications[0] : null;
        for (int i = 0; i < typeCount; i++) {
        for (int i = 0; i < typeCount; i++) {
            builder.setEntityType(
            builder.setEntityType(classifications[i]);
                    classifications[i].getCollection(),
                    classifications[i].getScore());
            if (classifications[i].getScore() > highestScoringResult.getScore()) {
            if (classifications[i].getScore() > highestScoringResult.getScore()) {
                highestScoringResult = classifications[i];
                highestScoringResult = classifications[i];
            }
            }
@@ -663,7 +662,6 @@ public final class TextClassifierImpl implements TextClassifier {
            }
            }
            builder.addAction(action, intent);
            builder.addAction(action, intent);
        }
        }

        return builder.setId(createId(text, start, end)).build();
        return builder.setId(createId(text, start, end)).build();
    }
    }


+7 −0
Original line number Original line Diff line number Diff line
@@ -216,6 +216,11 @@ public class TextClassifierTest {


        TextClassification classification = mClassifier.classifyText(request);
        TextClassification classification = mClassifier.classifyText(request);
        assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_DATE));
        assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_DATE));
        Bundle extras = classification.getExtras();
        List<Bundle> entities = ExtrasUtils.getEntities(extras);
        Truth.assertThat(entities).hasSize(1);
        Bundle entity = entities.get(0);
        Truth.assertThat(ExtrasUtils.getEntityType(entity)).isEqualTo(TextClassifier.TYPE_DATE);
    }
    }


    @Test
    @Test
@@ -484,6 +489,8 @@ public class TextClassifierTest {
        Truth.assertThat(conversationAction.getAction()).isNull();
        Truth.assertThat(conversationAction.getAction()).isNull();
        String code = ExtrasUtils.getCopyText(conversationAction.getExtras());
        String code = ExtrasUtils.getCopyText(conversationAction.getExtras());
        Truth.assertThat(code).isEqualTo("12345");
        Truth.assertThat(code).isEqualTo("12345");
        Truth.assertThat(
                ExtrasUtils.getSerializedEntityData(conversationAction.getExtras())).isNotEmpty();
    }
    }


    private boolean isTextClassifierDisabled() {
    private boolean isTextClassifierDisabled() {