Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit db18a578 authored by Richard Ledley's avatar Richard Ledley
Browse files

Add entity types to Options.

Test: bit FrameworksCoreTests:android.view.textclassifier.TextClassificationManagerTest
Bug: b/67629726
Change-Id: I9cad0159ab539a71d9f504019ebe91fe18206d60
parent 26b8722d
Loading
Loading
Loading
Loading
+13 −0
Original line number Diff line number Diff line
@@ -49130,9 +49130,13 @@ package android.view.textclassifier {
    method public default android.view.textclassifier.TextClassification classifyText(java.lang.CharSequence, int, int, android.os.LocaleList);
    method public default android.view.textclassifier.TextLinks generateLinks(java.lang.CharSequence, android.view.textclassifier.TextLinks.Options);
    method public default android.view.textclassifier.TextLinks generateLinks(java.lang.CharSequence);
    method public default java.util.Collection<java.lang.String> getEntitiesForPreset(int);
    method public default android.view.textclassifier.TextSelection suggestSelection(java.lang.CharSequence, int, int, android.view.textclassifier.TextSelection.Options);
    method public default android.view.textclassifier.TextSelection suggestSelection(java.lang.CharSequence, int, int);
    method public default android.view.textclassifier.TextSelection suggestSelection(java.lang.CharSequence, int, int, android.os.LocaleList);
    field public static final int ENTITY_PRESET_ALL = 0; // 0x0
    field public static final int ENTITY_PRESET_BASE = 2; // 0x2
    field public static final int ENTITY_PRESET_NONE = 1; // 0x1
    field public static final android.view.textclassifier.TextClassifier NO_OP;
    field public static final java.lang.String TYPE_ADDRESS = "address";
    field public static final java.lang.String TYPE_EMAIL = "email";
@@ -49142,6 +49146,13 @@ package android.view.textclassifier {
    field public static final java.lang.String TYPE_URL = "url";
  }
  public static final class TextClassifier.EntityConfig {
    ctor public TextClassifier.EntityConfig(int);
    method public android.view.textclassifier.TextClassifier.EntityConfig excludeEntities(java.lang.String...);
    method public java.util.List<java.lang.String> getEntities(android.view.textclassifier.TextClassifier);
    method public android.view.textclassifier.TextClassifier.EntityConfig includeEntities(java.lang.String...);
  }
  public final class TextLinks {
    method public boolean apply(android.text.SpannableString, java.util.function.Function<android.view.textclassifier.TextLinks.TextLink, android.text.style.ClickableSpan>);
    method public java.util.Collection<android.view.textclassifier.TextLinks.TextLink> getLinks();
@@ -49156,7 +49167,9 @@ package android.view.textclassifier {
  public static final class TextLinks.Options {
    ctor public TextLinks.Options();
    method public android.os.LocaleList getDefaultLocales();
    method public android.view.textclassifier.TextClassifier.EntityConfig getEntityConfig();
    method public android.view.textclassifier.TextLinks.Options setDefaultLocales(android.os.LocaleList);
    method public android.view.textclassifier.TextLinks.Options setEntityConfig(android.view.textclassifier.TextClassifier.EntityConfig);
  }
  public static final class TextLinks.TextLink {
+88 −0
Original line number Diff line number Diff line
@@ -16,17 +16,23 @@

package android.view.textclassifier;

import android.annotation.IntDef;
import android.annotation.IntRange;
import android.annotation.NonNull;
import android.annotation.Nullable;
import android.annotation.StringDef;
import android.annotation.WorkerThread;
import android.os.LocaleList;
import android.util.ArraySet;

import com.android.internal.util.Preconditions;

import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;

/**
 * Interface for providing text classification related features.
@@ -58,6 +64,20 @@ public interface TextClassifier {
    })
    @interface EntityType {}

    /** Designates that the TextClassifier should identify all entity types it can. **/
    int ENTITY_PRESET_ALL = 0;
    /** Designates that the TextClassifier should identify no entities. **/
    int ENTITY_PRESET_NONE = 1;
    /** Designates that the TextClassifier should identify a base set of entities determined by the
     * TextClassifier. **/
    int ENTITY_PRESET_BASE = 2;

    /** @hide */
    @Retention(RetentionPolicy.SOURCE)
    @IntDef(prefix = { "ENTITY_CONFIG_" },
            value = {ENTITY_PRESET_ALL, ENTITY_PRESET_NONE, ENTITY_PRESET_BASE})
    @interface EntityPreset {}

    /**
     * No-op TextClassifier.
     * This may be used to turn off TextClassifier features.
@@ -217,6 +237,8 @@ public interface TextClassifier {
     * Returns a {@link TextLinks} that may be applied to the text to annotate it with links
     * information.
     *
     * If no options are supplied, default values will be used, determined by the TextClassifier.
     *
     * @param text the text to generate annotations for
     * @param options configuration for link generation
     *
@@ -250,6 +272,16 @@ public interface TextClassifier {
        return generateLinks(text, null);
    }

    /**
     * Returns a {@link Collection} of the entity types in the specified preset.
     *
     * @see #ENTITIES_ALL
     * @see #ENTITIES_NONE
     */
    default Collection<String> getEntitiesForPreset(@EntityPreset int entityPreset) {
        return Collections.EMPTY_LIST;
    }

    /**
     * Logs a TextClassifier event.
     *
@@ -268,6 +300,62 @@ public interface TextClassifier {
        return TextClassifierConstants.DEFAULT;
    }

    /**
     * Configuration object for specifying what entities to identify.
     *
     * Configs are initially based on a predefined preset, and can be modified from there.
     */
    final class EntityConfig {
        private final @TextClassifier.EntityPreset int mEntityPreset;
        private final Collection<String> mExcludedEntityTypes;
        private final Collection<String> mIncludedEntityTypes;

        public EntityConfig(@TextClassifier.EntityPreset int mEntityPreset) {
            this.mEntityPreset = mEntityPreset;
            mExcludedEntityTypes = new ArraySet<>();
            mIncludedEntityTypes = new ArraySet<>();
        }

        /**
         * Specifies an entity to include in addition to any specified by the enity preset.
         *
         * Note that if an entity has been excluded, the exclusion will take precedence.
         */
        public EntityConfig includeEntities(String... entities) {
            for (String entity : entities) {
                mIncludedEntityTypes.add(entity);
            }
            return this;
        }

        /**
         * Specifies an entity to be excluded.
         */
        public EntityConfig excludeEntities(String... entities) {
            for (String entity : entities) {
                mExcludedEntityTypes.add(entity);
            }
            return this;
        }

        /**
         * Returns an unmodifiable list of the final set of entities to find.
         */
        public List<String> getEntities(TextClassifier textClassifier) {
            ArrayList<String> entities = new ArrayList<>();
            for (String entity : textClassifier.getEntitiesForPreset(mEntityPreset)) {
                if (!mExcludedEntityTypes.contains(entity)) {
                    entities.add(entity);
                }
            }
            for (String entity : mIncludedEntityTypes) {
                if (!mExcludedEntityTypes.contains(entity) && !entities.contains(entity)) {
                    entities.add(entity);
                }
            }
            return Collections.unmodifiableList(entities);
        }
    }

    /**
     * Utility functions for TextClassifier methods.
+38 −3
Original line number Diff line number Diff line
@@ -42,6 +42,9 @@ import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
@@ -66,6 +69,18 @@ final class TextClassifierImpl implements TextClassifier {
    private static final String MODEL_FILE_REGEX = "textclassifier\\.smartselection\\.(.*)\\.model";
    private static final String UPDATED_MODEL_FILE_PATH =
            "/data/misc/textclassifier/textclassifier.smartselection.model";
    private static final List<String> ENTITY_TYPES_ALL =
            Collections.unmodifiableList(Arrays.asList(
                    TextClassifier.TYPE_ADDRESS,
                    TextClassifier.TYPE_EMAIL,
                    TextClassifier.TYPE_PHONE,
                    TextClassifier.TYPE_URL));
    private static final List<String> ENTITY_TYPES_BASE =
            Collections.unmodifiableList(Arrays.asList(
                    TextClassifier.TYPE_ADDRESS,
                    TextClassifier.TYPE_EMAIL,
                    TextClassifier.TYPE_PHONE,
                    TextClassifier.TYPE_URL));

    private final Context mContext;

@@ -168,17 +183,23 @@ final class TextClassifierImpl implements TextClassifier {

    @Override
    public TextLinks generateLinks(
            @NonNull CharSequence text, @NonNull TextLinks.Options options) {
            @NonNull CharSequence text, @Nullable TextLinks.Options options) {
        Utils.validateInput(text);
        final String textString = text.toString();
        final TextLinks.Builder builder = new TextLinks.Builder(textString);
        try {
            LocaleList defaultLocales = options != null ? options.getDefaultLocales() : null;
            final LocaleList defaultLocales = options != null ? options.getDefaultLocales() : null;
            final Collection<String> entitiesToIdentify =
                    options != null && options.getEntityConfig() != null
                            ? options.getEntityConfig().getEntities(this) : ENTITY_TYPES_ALL;
            final SmartSelection smartSelection = getSmartSelection(defaultLocales);
            final SmartSelection.AnnotatedSpan[] annotations = smartSelection.annotate(textString);
            for (SmartSelection.AnnotatedSpan span : annotations) {
                final Map<String, Float> entityScores = new HashMap<>();
                final SmartSelection.ClassificationResult[] results = span.getClassification();
                if (results.length == 0 || !entitiesToIdentify.contains(results[0].mCollection)) {
                    continue;
                }
                final Map<String, Float> entityScores = new HashMap<>();
                for (int i = 0; i < results.length; i++) {
                    entityScores.put(results[i].mCollection, results[i].mScore);
                }
@@ -192,6 +213,20 @@ final class TextClassifierImpl implements TextClassifier {
        return builder.build();
    }

    @Override
    public Collection<String> getEntitiesForPreset(@TextClassifier.EntityPreset int entityPreset) {
        switch (entityPreset) {
            case TextClassifier.ENTITY_PRESET_NONE:
                return Collections.emptyList();
            case TextClassifier.ENTITY_PRESET_BASE:
                return ENTITY_TYPES_BASE;
            case TextClassifier.ENTITY_PRESET_ALL:
                // fall through
            default:
                return ENTITY_TYPES_ALL;
        }
    }

    @Override
    public void logEvent(String source, String event) {
        if (LOG_TAG.equals(source)) {
+24 −3
Original line number Diff line number Diff line
@@ -161,17 +161,29 @@ public final class TextLinks {
    public static final class Options {

        private LocaleList mDefaultLocales;
        private TextClassifier.EntityConfig mEntityConfig;

        /**
         * @param defaultLocales ordered list of locale preferences that may be used to disambiguate
         *      the provided text. If no locale preferences exist, set this to null or an empty
         *      locale list.
         * @param defaultLocales ordered list of locale preferences that may be used to
         *                       disambiguate the provided text. If no locale preferences exist,
         *                       set this to null or an empty locale list.
         */
        public Options setDefaultLocales(@Nullable LocaleList defaultLocales) {
            mDefaultLocales = defaultLocales;
            return this;
        }

        /**
         * Sets the entity configuration to use. This determines what types of entities the
         * TextClassifier will look for.
         *
         * @param entityConfig EntityConfig to use
         */
        public Options setEntityConfig(@Nullable TextClassifier.EntityConfig entityConfig) {
            mEntityConfig = entityConfig;
            return this;
        }

        /**
         * @return ordered list of locale preferences that can be used to disambiguate
         *      the provided text.
@@ -180,6 +192,15 @@ public final class TextLinks {
        public LocaleList getDefaultLocales() {
            return mDefaultLocales;
        }

        /**
         * @return The config representing the set of entities to look for.
         * @see #setEntityConfig(TextClassifier.EntityConfig)
         */
        @Nullable
        public TextClassifier.EntityConfig getEntityConfig() {
            return mEntityConfig;
        }
    }

    /**
+66 −33
Original line number Diff line number Diff line
@@ -16,10 +16,9 @@

package android.view.textclassifier;

import static org.hamcrest.CoreMatchers.not;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.mockito.Mockito.mock;

import android.os.LocaleList;
@@ -34,8 +33,6 @@ import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;

import java.util.Collection;

@SmallTest
@RunWith(AndroidJUnit4.class)
public class TextClassificationManagerTest {
@@ -166,20 +163,50 @@ public class TextClassificationManagerTest {
    }

    @Test
    public void testGenerateLinks() {
    public void testGenerateLinks_phone() {
        if (isTextClassifierDisabled()) return;
        String text = "The number is +12122537077. See you tonight!";
        assertThat(mClassifier.generateLinks(text, null),
                isTextLinksContaining(text, "+12122537077", TextClassifier.TYPE_PHONE));
    }

        checkGenerateLinksFindsLink(
                "The number is +12122537077. See you tonight!",
                "+12122537077",
                TextClassifier.TYPE_PHONE);
    @Test
    public void testGenerateLinks_exclude() {
        if (isTextClassifierDisabled()) return;
        String text = "The number is +12122537077. See you tonight!";
        assertThat(mClassifier.generateLinks(text, mLinksOptions.setEntityConfig(
                new TextClassifier.EntityConfig(TextClassifier.ENTITY_PRESET_ALL)
                        .excludeEntities(TextClassifier.TYPE_PHONE))),
                not(isTextLinksContaining(text, "+12122537077", TextClassifier.TYPE_PHONE)));
    }

        checkGenerateLinksFindsLink(
                "The address is 1600 Amphitheater Parkway, Mountain View, CA. See you tonight!",
                "1600 Amphitheater Parkway, Mountain View, CA",
                TextClassifier.TYPE_ADDRESS);
    @Test
    public void testGenerateLinks_none_config() {
        if (isTextClassifierDisabled()) return;
        String text = "The number is +12122537077. See you tonight!";
        assertThat(mClassifier.generateLinks(text, mLinksOptions.setEntityConfig(
                new TextClassifier.EntityConfig(TextClassifier.ENTITY_PRESET_NONE))),
                not(isTextLinksContaining(text, "+12122537077", TextClassifier.TYPE_PHONE)));
    }

        // TODO: Add more entity types when the model supports them.
    @Test
    public void testGenerateLinks_address() {
        if (isTextClassifierDisabled()) return;
        String text = "The address is 1600 Amphitheater Parkway, Mountain View, CA. See you!";
        assertThat(mClassifier.generateLinks(text, null),
                isTextLinksContaining(text, "1600 Amphitheater Parkway, Mountain View, CA",
                        TextClassifier.TYPE_ADDRESS));
    }

    @Test
    public void testGenerateLinks_include() {
        if (isTextClassifierDisabled()) return;
        String text = "The address is 1600 Amphitheater Parkway, Mountain View, CA. See you!";
        assertThat(mClassifier.generateLinks(text, mLinksOptions.setEntityConfig(
                new TextClassifier.EntityConfig(TextClassifier.ENTITY_PRESET_NONE)
                        .includeEntities(TextClassifier.TYPE_ADDRESS))),
                isTextLinksContaining(text, "1600 Amphitheater Parkway, Mountain View, CA",
                        TextClassifier.TYPE_ADDRESS));
    }

    @Test
@@ -193,25 +220,6 @@ public class TextClassificationManagerTest {
        return mClassifier == TextClassifier.NO_OP;
    }

    private void checkGenerateLinksFindsLink(String text, String classifiedText, String type) {
        assertTrue(text.contains(classifiedText));
        int startIndex = text.indexOf(classifiedText);
        int endIndex = startIndex + classifiedText.length();

        Collection<TextLinks.TextLink> links = mClassifier.generateLinks(text, mLinksOptions)
                .getLinks();
        for (TextLinks.TextLink link : links) {
            if (text.subSequence(link.getStart(), link.getEnd()).equals(classifiedText)) {
                assertEquals(type, link.getEntity(0));
                assertEquals(startIndex, link.getStart());
                assertEquals(endIndex, link.getEnd());
                assertTrue(link.getConfidenceScore(type) > 0);
                return;
            }
        }
        fail(); // Subsequence was not identified.
    }

    private static Matcher<TextSelection> isTextSelection(
            final int startIndex, final int endIndex, final String type) {
        return new BaseMatcher<TextSelection>() {
@@ -240,6 +248,31 @@ public class TextClassificationManagerTest {
        };
    }

    private static Matcher<TextLinks> isTextLinksContaining(
            final String text, final String substring, final String type) {
        return new BaseMatcher<TextLinks>() {

            @Override
            public void describeTo(Description description) {
                description.appendText("text=").appendValue(text)
                        .appendText(", substring=").appendValue(substring)
                        .appendText(", type=").appendValue(type);
            }

            @Override
            public boolean matches(Object o) {
                if (o instanceof TextLinks) {
                    for (TextLinks.TextLink link : ((TextLinks) o).getLinks()) {
                        if (text.subSequence(link.getStart(), link.getEnd()).equals(substring)) {
                            return type.equals(link.getEntity(0));
                        }
                    }
                }
                return false;
            }
        };
    }

    private static Matcher<TextClassification> isTextClassification(
            final String text, final String type, final String intentUri) {
        return new BaseMatcher<TextClassification>() {