Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit beef470b authored by Richard Ledley's avatar Richard Ledley Committed by Android (Google) Code Review
Browse files

Merge "Add entity types to Options."

parents bdbb8811 db18a578
Loading
Loading
Loading
Loading
+13 −0
Original line number Diff line number Diff line
@@ -49119,9 +49119,13 @@ package android.view.textclassifier {
    method public default android.view.textclassifier.TextClassification classifyText(java.lang.CharSequence, int, int, android.os.LocaleList);
    method public default android.view.textclassifier.TextLinks generateLinks(java.lang.CharSequence, android.view.textclassifier.TextLinks.Options);
    method public default android.view.textclassifier.TextLinks generateLinks(java.lang.CharSequence);
    method public default java.util.Collection<java.lang.String> getEntitiesForPreset(int);
    method public default android.view.textclassifier.TextSelection suggestSelection(java.lang.CharSequence, int, int, android.view.textclassifier.TextSelection.Options);
    method public default android.view.textclassifier.TextSelection suggestSelection(java.lang.CharSequence, int, int);
    method public default android.view.textclassifier.TextSelection suggestSelection(java.lang.CharSequence, int, int, android.os.LocaleList);
    field public static final int ENTITY_PRESET_ALL = 0; // 0x0
    field public static final int ENTITY_PRESET_BASE = 2; // 0x2
    field public static final int ENTITY_PRESET_NONE = 1; // 0x1
    field public static final android.view.textclassifier.TextClassifier NO_OP;
    field public static final java.lang.String TYPE_ADDRESS = "address";
    field public static final java.lang.String TYPE_EMAIL = "email";
@@ -49131,6 +49135,13 @@ package android.view.textclassifier {
    field public static final java.lang.String TYPE_URL = "url";
  }
  public static final class TextClassifier.EntityConfig {
    ctor public TextClassifier.EntityConfig(int);
    method public android.view.textclassifier.TextClassifier.EntityConfig excludeEntities(java.lang.String...);
    method public java.util.List<java.lang.String> getEntities(android.view.textclassifier.TextClassifier);
    method public android.view.textclassifier.TextClassifier.EntityConfig includeEntities(java.lang.String...);
  }
  public final class TextLinks {
    method public boolean apply(android.text.SpannableString, java.util.function.Function<android.view.textclassifier.TextLinks.TextLink, android.text.style.ClickableSpan>);
    method public java.util.Collection<android.view.textclassifier.TextLinks.TextLink> getLinks();
@@ -49145,7 +49156,9 @@ package android.view.textclassifier {
  public static final class TextLinks.Options {
    ctor public TextLinks.Options();
    method public android.os.LocaleList getDefaultLocales();
    method public android.view.textclassifier.TextClassifier.EntityConfig getEntityConfig();
    method public android.view.textclassifier.TextLinks.Options setDefaultLocales(android.os.LocaleList);
    method public android.view.textclassifier.TextLinks.Options setEntityConfig(android.view.textclassifier.TextClassifier.EntityConfig);
  }
  public static final class TextLinks.TextLink {
+88 −0
Original line number Diff line number Diff line
@@ -16,17 +16,23 @@

package android.view.textclassifier;

import android.annotation.IntDef;
import android.annotation.IntRange;
import android.annotation.NonNull;
import android.annotation.Nullable;
import android.annotation.StringDef;
import android.annotation.WorkerThread;
import android.os.LocaleList;
import android.util.ArraySet;

import com.android.internal.util.Preconditions;

import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;

/**
 * Interface for providing text classification related features.
@@ -58,6 +64,20 @@ public interface TextClassifier {
    })
    @interface EntityType {}

    /** Designates that the TextClassifier should identify all entity types it can. **/
    int ENTITY_PRESET_ALL = 0;
    /** Designates that the TextClassifier should identify no entities. **/
    int ENTITY_PRESET_NONE = 1;
    /** Designates that the TextClassifier should identify a base set of entities determined by the
     * TextClassifier. **/
    int ENTITY_PRESET_BASE = 2;

    /** @hide */
    @Retention(RetentionPolicy.SOURCE)
    @IntDef(prefix = { "ENTITY_CONFIG_" },
            value = {ENTITY_PRESET_ALL, ENTITY_PRESET_NONE, ENTITY_PRESET_BASE})
    @interface EntityPreset {}

    /**
     * No-op TextClassifier.
     * This may be used to turn off TextClassifier features.
@@ -217,6 +237,8 @@ public interface TextClassifier {
     * Returns a {@link TextLinks} that may be applied to the text to annotate it with links
     * information.
     *
     * If no options are supplied, default values will be used, determined by the TextClassifier.
     *
     * @param text the text to generate annotations for
     * @param options configuration for link generation
     *
@@ -250,6 +272,16 @@ public interface TextClassifier {
        return generateLinks(text, null);
    }

    /**
     * Returns a {@link Collection} of the entity types in the specified preset.
     *
     * @see #ENTITIES_ALL
     * @see #ENTITIES_NONE
     */
    default Collection<String> getEntitiesForPreset(@EntityPreset int entityPreset) {
        return Collections.EMPTY_LIST;
    }

    /**
     * Logs a TextClassifier event.
     *
@@ -268,6 +300,62 @@ public interface TextClassifier {
        return TextClassifierConstants.DEFAULT;
    }

    /**
     * Configuration object for specifying what entities to identify.
     *
     * Configs are initially based on a predefined preset, and can be modified from there.
     */
    final class EntityConfig {
        private final @TextClassifier.EntityPreset int mEntityPreset;
        private final Collection<String> mExcludedEntityTypes;
        private final Collection<String> mIncludedEntityTypes;

        public EntityConfig(@TextClassifier.EntityPreset int mEntityPreset) {
            this.mEntityPreset = mEntityPreset;
            mExcludedEntityTypes = new ArraySet<>();
            mIncludedEntityTypes = new ArraySet<>();
        }

        /**
         * Specifies an entity to include in addition to any specified by the enity preset.
         *
         * Note that if an entity has been excluded, the exclusion will take precedence.
         */
        public EntityConfig includeEntities(String... entities) {
            for (String entity : entities) {
                mIncludedEntityTypes.add(entity);
            }
            return this;
        }

        /**
         * Specifies an entity to be excluded.
         */
        public EntityConfig excludeEntities(String... entities) {
            for (String entity : entities) {
                mExcludedEntityTypes.add(entity);
            }
            return this;
        }

        /**
         * Returns an unmodifiable list of the final set of entities to find.
         */
        public List<String> getEntities(TextClassifier textClassifier) {
            ArrayList<String> entities = new ArrayList<>();
            for (String entity : textClassifier.getEntitiesForPreset(mEntityPreset)) {
                if (!mExcludedEntityTypes.contains(entity)) {
                    entities.add(entity);
                }
            }
            for (String entity : mIncludedEntityTypes) {
                if (!mExcludedEntityTypes.contains(entity) && !entities.contains(entity)) {
                    entities.add(entity);
                }
            }
            return Collections.unmodifiableList(entities);
        }
    }

    /**
     * Utility functions for TextClassifier methods.
+38 −3
Original line number Diff line number Diff line
@@ -42,6 +42,9 @@ import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
@@ -66,6 +69,18 @@ final class TextClassifierImpl implements TextClassifier {
    private static final String MODEL_FILE_REGEX = "textclassifier\\.smartselection\\.(.*)\\.model";
    private static final String UPDATED_MODEL_FILE_PATH =
            "/data/misc/textclassifier/textclassifier.smartselection.model";
    private static final List<String> ENTITY_TYPES_ALL =
            Collections.unmodifiableList(Arrays.asList(
                    TextClassifier.TYPE_ADDRESS,
                    TextClassifier.TYPE_EMAIL,
                    TextClassifier.TYPE_PHONE,
                    TextClassifier.TYPE_URL));
    private static final List<String> ENTITY_TYPES_BASE =
            Collections.unmodifiableList(Arrays.asList(
                    TextClassifier.TYPE_ADDRESS,
                    TextClassifier.TYPE_EMAIL,
                    TextClassifier.TYPE_PHONE,
                    TextClassifier.TYPE_URL));

    private final Context mContext;

@@ -168,17 +183,23 @@ final class TextClassifierImpl implements TextClassifier {

    @Override
    public TextLinks generateLinks(
            @NonNull CharSequence text, @NonNull TextLinks.Options options) {
            @NonNull CharSequence text, @Nullable TextLinks.Options options) {
        Utils.validateInput(text);
        final String textString = text.toString();
        final TextLinks.Builder builder = new TextLinks.Builder(textString);
        try {
            LocaleList defaultLocales = options != null ? options.getDefaultLocales() : null;
            final LocaleList defaultLocales = options != null ? options.getDefaultLocales() : null;
            final Collection<String> entitiesToIdentify =
                    options != null && options.getEntityConfig() != null
                            ? options.getEntityConfig().getEntities(this) : ENTITY_TYPES_ALL;
            final SmartSelection smartSelection = getSmartSelection(defaultLocales);
            final SmartSelection.AnnotatedSpan[] annotations = smartSelection.annotate(textString);
            for (SmartSelection.AnnotatedSpan span : annotations) {
                final Map<String, Float> entityScores = new HashMap<>();
                final SmartSelection.ClassificationResult[] results = span.getClassification();
                if (results.length == 0 || !entitiesToIdentify.contains(results[0].mCollection)) {
                    continue;
                }
                final Map<String, Float> entityScores = new HashMap<>();
                for (int i = 0; i < results.length; i++) {
                    entityScores.put(results[i].mCollection, results[i].mScore);
                }
@@ -192,6 +213,20 @@ final class TextClassifierImpl implements TextClassifier {
        return builder.build();
    }

    @Override
    public Collection<String> getEntitiesForPreset(@TextClassifier.EntityPreset int entityPreset) {
        switch (entityPreset) {
            case TextClassifier.ENTITY_PRESET_NONE:
                return Collections.emptyList();
            case TextClassifier.ENTITY_PRESET_BASE:
                return ENTITY_TYPES_BASE;
            case TextClassifier.ENTITY_PRESET_ALL:
                // fall through
            default:
                return ENTITY_TYPES_ALL;
        }
    }

    @Override
    public void logEvent(String source, String event) {
        if (LOG_TAG.equals(source)) {
+24 −3
Original line number Diff line number Diff line
@@ -161,17 +161,29 @@ public final class TextLinks {
    public static final class Options {

        private LocaleList mDefaultLocales;
        private TextClassifier.EntityConfig mEntityConfig;

        /**
         * @param defaultLocales ordered list of locale preferences that may be used to disambiguate
         *      the provided text. If no locale preferences exist, set this to null or an empty
         *      locale list.
         * @param defaultLocales ordered list of locale preferences that may be used to
         *                       disambiguate the provided text. If no locale preferences exist,
         *                       set this to null or an empty locale list.
         */
        public Options setDefaultLocales(@Nullable LocaleList defaultLocales) {
            mDefaultLocales = defaultLocales;
            return this;
        }

        /**
         * Sets the entity configuration to use. This determines what types of entities the
         * TextClassifier will look for.
         *
         * @param entityConfig EntityConfig to use
         */
        public Options setEntityConfig(@Nullable TextClassifier.EntityConfig entityConfig) {
            mEntityConfig = entityConfig;
            return this;
        }

        /**
         * @return ordered list of locale preferences that can be used to disambiguate
         *      the provided text.
@@ -180,6 +192,15 @@ public final class TextLinks {
        public LocaleList getDefaultLocales() {
            return mDefaultLocales;
        }

        /**
         * @return The config representing the set of entities to look for.
         * @see #setEntityConfig(TextClassifier.EntityConfig)
         */
        @Nullable
        public TextClassifier.EntityConfig getEntityConfig() {
            return mEntityConfig;
        }
    }

    /**
+66 −33
Original line number Diff line number Diff line
@@ -16,10 +16,9 @@

package android.view.textclassifier;

import static org.hamcrest.CoreMatchers.not;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.mockito.Mockito.mock;

import android.os.LocaleList;
@@ -34,8 +33,6 @@ import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;

import java.util.Collection;

@SmallTest
@RunWith(AndroidJUnit4.class)
public class TextClassificationManagerTest {
@@ -166,20 +163,50 @@ public class TextClassificationManagerTest {
    }

    @Test
    public void testGenerateLinks() {
    public void testGenerateLinks_phone() {
        if (isTextClassifierDisabled()) return;
        String text = "The number is +12122537077. See you tonight!";
        assertThat(mClassifier.generateLinks(text, null),
                isTextLinksContaining(text, "+12122537077", TextClassifier.TYPE_PHONE));
    }

        checkGenerateLinksFindsLink(
                "The number is +12122537077. See you tonight!",
                "+12122537077",
                TextClassifier.TYPE_PHONE);
    @Test
    public void testGenerateLinks_exclude() {
        if (isTextClassifierDisabled()) return;
        String text = "The number is +12122537077. See you tonight!";
        assertThat(mClassifier.generateLinks(text, mLinksOptions.setEntityConfig(
                new TextClassifier.EntityConfig(TextClassifier.ENTITY_PRESET_ALL)
                        .excludeEntities(TextClassifier.TYPE_PHONE))),
                not(isTextLinksContaining(text, "+12122537077", TextClassifier.TYPE_PHONE)));
    }

        checkGenerateLinksFindsLink(
                "The address is 1600 Amphitheater Parkway, Mountain View, CA. See you tonight!",
                "1600 Amphitheater Parkway, Mountain View, CA",
                TextClassifier.TYPE_ADDRESS);
    @Test
    public void testGenerateLinks_none_config() {
        if (isTextClassifierDisabled()) return;
        String text = "The number is +12122537077. See you tonight!";
        assertThat(mClassifier.generateLinks(text, mLinksOptions.setEntityConfig(
                new TextClassifier.EntityConfig(TextClassifier.ENTITY_PRESET_NONE))),
                not(isTextLinksContaining(text, "+12122537077", TextClassifier.TYPE_PHONE)));
    }

        // TODO: Add more entity types when the model supports them.
    @Test
    public void testGenerateLinks_address() {
        if (isTextClassifierDisabled()) return;
        String text = "The address is 1600 Amphitheater Parkway, Mountain View, CA. See you!";
        assertThat(mClassifier.generateLinks(text, null),
                isTextLinksContaining(text, "1600 Amphitheater Parkway, Mountain View, CA",
                        TextClassifier.TYPE_ADDRESS));
    }

    @Test
    public void testGenerateLinks_include() {
        if (isTextClassifierDisabled()) return;
        String text = "The address is 1600 Amphitheater Parkway, Mountain View, CA. See you!";
        assertThat(mClassifier.generateLinks(text, mLinksOptions.setEntityConfig(
                new TextClassifier.EntityConfig(TextClassifier.ENTITY_PRESET_NONE)
                        .includeEntities(TextClassifier.TYPE_ADDRESS))),
                isTextLinksContaining(text, "1600 Amphitheater Parkway, Mountain View, CA",
                        TextClassifier.TYPE_ADDRESS));
    }

    @Test
@@ -193,25 +220,6 @@ public class TextClassificationManagerTest {
        return mClassifier == TextClassifier.NO_OP;
    }

    private void checkGenerateLinksFindsLink(String text, String classifiedText, String type) {
        assertTrue(text.contains(classifiedText));
        int startIndex = text.indexOf(classifiedText);
        int endIndex = startIndex + classifiedText.length();

        Collection<TextLinks.TextLink> links = mClassifier.generateLinks(text, mLinksOptions)
                .getLinks();
        for (TextLinks.TextLink link : links) {
            if (text.subSequence(link.getStart(), link.getEnd()).equals(classifiedText)) {
                assertEquals(type, link.getEntity(0));
                assertEquals(startIndex, link.getStart());
                assertEquals(endIndex, link.getEnd());
                assertTrue(link.getConfidenceScore(type) > 0);
                return;
            }
        }
        fail(); // Subsequence was not identified.
    }

    private static Matcher<TextSelection> isTextSelection(
            final int startIndex, final int endIndex, final String type) {
        return new BaseMatcher<TextSelection>() {
@@ -240,6 +248,31 @@ public class TextClassificationManagerTest {
        };
    }

    private static Matcher<TextLinks> isTextLinksContaining(
            final String text, final String substring, final String type) {
        return new BaseMatcher<TextLinks>() {

            @Override
            public void describeTo(Description description) {
                description.appendText("text=").appendValue(text)
                        .appendText(", substring=").appendValue(substring)
                        .appendText(", type=").appendValue(type);
            }

            @Override
            public boolean matches(Object o) {
                if (o instanceof TextLinks) {
                    for (TextLinks.TextLink link : ((TextLinks) o).getLinks()) {
                        if (text.subSequence(link.getStart(), link.getEnd()).equals(substring)) {
                            return type.equals(link.getEntity(0));
                        }
                    }
                }
                return false;
            }
        };
    }

    private static Matcher<TextClassification> isTextClassification(
            final String text, final String type, final String intentUri) {
        return new BaseMatcher<TextClassification>() {