Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 514a7403 authored by TreeHugger Robot's avatar TreeHugger Robot Committed by Android (Google) Code Review
Browse files

Merge "Storage refactor for EntityConfidence"

parents d6ba912b bbe43dfd
Loading
Loading
Loading
Loading
+27 −30
Original line number Diff line number Diff line
@@ -18,13 +18,12 @@ package android.view.textclassifier;

import android.annotation.FloatRange;
import android.annotation.NonNull;
import android.util.ArrayMap;

import com.android.internal.util.Preconditions;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

@@ -36,42 +35,43 @@ import java.util.Map;
 */
final class EntityConfidence<T> {

    private final Map<T, Float> mEntityConfidence = new HashMap<>();

    private final Comparator<T> mEntityComparator = (e1, e2) -> {
        float score1 = mEntityConfidence.get(e1);
        float score2 = mEntityConfidence.get(e2);
        if (score1 > score2) {
            return -1;
        }
        if (score1 < score2) {
            return 1;
        }
        return 0;
    };
    private final ArrayMap<T, Float> mEntityConfidence = new ArrayMap<>();
    private final ArrayList<T> mSortedEntities = new ArrayList<>();

    EntityConfidence() {}

    EntityConfidence(@NonNull EntityConfidence<T> source) {
        Preconditions.checkNotNull(source);
        mEntityConfidence.putAll(source.mEntityConfidence);
        mSortedEntities.addAll(source.mSortedEntities);
    }

    /**
     * Sets an entity type for the classified text and assigns a confidence score.
     * Constructs an EntityConfidence from a map of entity to confidence.
     *
     * @param confidenceScore a value from 0 (low confidence) to 1 (high confidence).
     *      0 implies the entity does not exist for the classified text.
     *      Values greater than 1 are clamped to 1.
     * Map entries that have 0 confidence are removed, and values greater than 1 are clamped to 1.
     *
     * @param source a map from entity to a confidence value in the range 0 (low confidence) to
     *               1 (high confidence).
     */
    public void setEntityType(
            @NonNull T type, @FloatRange(from = 0.0, to = 1.0) float confidenceScore) {
        Preconditions.checkNotNull(type);
        if (confidenceScore > 0) {
            mEntityConfidence.put(type, Math.min(1, confidenceScore));
        } else {
            mEntityConfidence.remove(type);
    EntityConfidence(@NonNull Map<T, Float> source) {
        Preconditions.checkNotNull(source);

        // Prune non-existent entities and clamp to 1.
        mEntityConfidence.ensureCapacity(source.size());
        for (Map.Entry<T, Float> it : source.entrySet()) {
            if (it.getValue() <= 0) continue;
            mEntityConfidence.put(it.getKey(), Math.min(1, it.getValue()));
        }

        // Create a list of entities sorted by decreasing confidence for getEntities().
        mSortedEntities.ensureCapacity(mEntityConfidence.size());
        mSortedEntities.addAll(mEntityConfidence.keySet());
        mSortedEntities.sort((e1, e2) -> {
            float score1 = mEntityConfidence.get(e1);
            float score2 = mEntityConfidence.get(e2);
            return Float.compare(score2, score1);
        });
    }

    /**
@@ -80,10 +80,7 @@ final class EntityConfidence<T> {
     */
    @NonNull
    public List<T> getEntities() {
        List<T> entities = new ArrayList<>(mEntityConfidence.size());
        entities.addAll(mEntityConfidence.keySet());
        entities.sort(mEntityComparator);
        return Collections.unmodifiableList(entities);
        return Collections.unmodifiableList(mSortedEntities);
    }

    /**
+7 −8
Original line number Diff line number Diff line
@@ -24,6 +24,7 @@ import android.content.Context;
import android.content.Intent;
import android.graphics.drawable.Drawable;
import android.os.LocaleList;
import android.util.ArrayMap;
import android.view.View.OnClickListener;
import android.view.textclassifier.TextClassifier.EntityType;

@@ -32,6 +33,7 @@ import com.android.internal.util.Preconditions;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.Map;

/**
 * Information for generating a widget to handle classified text.
@@ -95,7 +97,6 @@ public final class TextClassification {
    @NonNull private final List<Intent> mIntents;
    @NonNull private final List<OnClickListener> mOnClickListeners;
    @NonNull private final EntityConfidence<String> mEntityConfidence;
    @NonNull private final List<String> mEntities;
    private int mLogType;
    @NonNull private final String mVersionInfo;

@@ -105,7 +106,7 @@ public final class TextClassification {
            @NonNull List<String> labels,
            @NonNull List<Intent> intents,
            @NonNull List<OnClickListener> onClickListeners,
            @NonNull EntityConfidence<String> entityConfidence,
            @NonNull Map<String, Float> entityConfidence,
            int logType,
            @NonNull String versionInfo) {
        Preconditions.checkArgument(labels.size() == intents.size());
@@ -117,7 +118,6 @@ public final class TextClassification {
        mIntents = intents;
        mOnClickListeners = onClickListeners;
        mEntityConfidence = new EntityConfidence<>(entityConfidence);
        mEntities = mEntityConfidence.getEntities();
        mLogType = logType;
        mVersionInfo = versionInfo;
    }
@@ -135,7 +135,7 @@ public final class TextClassification {
     */
    @IntRange(from = 0)
    public int getEntityCount() {
        return mEntities.size();
        return mEntityConfidence.getEntities().size();
    }

    /**
@@ -147,7 +147,7 @@ public final class TextClassification {
     */
    @NonNull
    public @EntityType String getEntity(int index) {
        return mEntities.get(index);
        return mEntityConfidence.getEntities().get(index);
    }

    /**
@@ -311,8 +311,7 @@ public final class TextClassification {
        @NonNull private final List<String> mLabels = new ArrayList<>();
        @NonNull private final List<Intent> mIntents = new ArrayList<>();
        @NonNull private final List<OnClickListener> mOnClickListeners = new ArrayList<>();
        @NonNull private final EntityConfidence<String> mEntityConfidence =
                new EntityConfidence<>();
        @NonNull private final Map<String, Float> mEntityConfidence = new ArrayMap<>();
        private int mLogType;
        @NonNull private String mVersionInfo = "";

@@ -334,7 +333,7 @@ public final class TextClassification {
        public Builder setEntityType(
                @NonNull @EntityType String type,
                @FloatRange(from = 0.0, to = 1.0) float confidenceScore) {
            mEntityConfidence.setEntityType(type, confidenceScore);
            mEntityConfidence.put(type, confidenceScore);
            return this;
        }

+1 −5
Original line number Diff line number Diff line
@@ -103,11 +103,7 @@ public final class TextLinks {
            mOriginalText = originalText;
            mStart = start;
            mEnd = end;
            mEntityScores = new EntityConfidence<>();

            for (Map.Entry<String, Float> entry : entityScores.entrySet()) {
                mEntityScores.setEntityType(entry.getKey(), entry.getValue());
            }
            mEntityScores = new EntityConfidence<>(entityScores);
        }

        /**
+8 −9
Original line number Diff line number Diff line
@@ -21,12 +21,13 @@ import android.annotation.IntRange;
import android.annotation.NonNull;
import android.annotation.Nullable;
import android.os.LocaleList;
import android.util.ArrayMap;
import android.view.textclassifier.TextClassifier.EntityType;

import com.android.internal.util.Preconditions;

import java.util.List;
import java.util.Locale;
import java.util.Map;

/**
 * Information about where text selection should be.
@@ -36,7 +37,6 @@ public final class TextSelection {
    private final int mStartIndex;
    private final int mEndIndex;
    @NonNull private final EntityConfidence<String> mEntityConfidence;
    @NonNull private final List<String> mEntities;
    @NonNull private final String mLogSource;
    @NonNull private final String mVersionInfo;

@@ -46,7 +46,6 @@ public final class TextSelection {
        mStartIndex = startIndex;
        mEndIndex = endIndex;
        mEntityConfidence = new EntityConfidence<>(entityConfidence);
        mEntities = mEntityConfidence.getEntities();
        mLogSource = logSource;
        mVersionInfo = versionInfo;
    }
@@ -70,7 +69,7 @@ public final class TextSelection {
     */
    @IntRange(from = 0)
    public int getEntityCount() {
        return mEntities.size();
        return mEntityConfidence.getEntities().size();
    }

    /**
@@ -82,7 +81,7 @@ public final class TextSelection {
     */
    @NonNull
    public @EntityType String getEntity(int index) {
        return mEntities.get(index);
        return mEntityConfidence.getEntities().get(index);
    }

    /**
@@ -126,8 +125,7 @@ public final class TextSelection {

        private final int mStartIndex;
        private final int mEndIndex;
        @NonNull private final EntityConfidence<String> mEntityConfidence =
                new EntityConfidence<>();
        @NonNull private final Map<String, Float> mEntityConfidence = new ArrayMap<>();
        @NonNull private String mLogSource = "";
        @NonNull private String mVersionInfo = "";

@@ -154,7 +152,7 @@ public final class TextSelection {
        public Builder setEntityType(
                @NonNull @EntityType String type,
                @FloatRange(from = 0.0, to = 1.0) float confidenceScore) {
            mEntityConfidence.setEntityType(type, confidenceScore);
            mEntityConfidence.put(type, confidenceScore);
            return this;
        }

@@ -181,7 +179,8 @@ public final class TextSelection {
         */
        public TextSelection build() {
            return new TextSelection(
                    mStartIndex, mEndIndex, mEntityConfidence, mLogSource, mVersionInfo);
                    mStartIndex, mEndIndex, new EntityConfidence<>(mEntityConfidence),  mLogSource,
                    mVersionInfo);
        }
    }