Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 67ca6c5d authored by Felipe Leme's avatar Felipe Leme
Browse files

YAAFFFCR - Yet Another Android Autofill Framework Field Classification Refactoring.

The field classification service is moving to another process, hence we need
to get the scores in a batch and handle the results in a callback.

Test: atest CtsAutoFillServiceTestCases:FillEventHistoryTest \
            CtsAutoFillServiceTestCases:FieldsClassificationTest
Bug: 70939974

Change-Id: I0de91f18828872c455abd1609d3a3890ddc3bd4f
parent be97bb84
Loading
Loading
Loading
Loading
+125 −11
Original line number Diff line number Diff line
@@ -43,14 +43,19 @@ import android.os.Binder;
import android.os.Bundle;
import android.os.IBinder;
import android.os.Looper;
import android.os.Parcel;
import android.os.Parcelable;
import android.os.RemoteCallbackList;
import android.os.RemoteException;
import android.os.SystemClock;
import android.os.UserHandle;
import android.os.UserManager;
import android.os.Parcelable.Creator;
import android.os.RemoteCallback;
import android.provider.Settings;
import android.service.autofill.AutofillService;
import android.service.autofill.AutofillServiceInfo;
import android.service.autofill.Dataset;
import android.service.autofill.EditDistanceScorer;
import android.service.autofill.FieldClassification;
import android.service.autofill.FieldClassification.Match;
@@ -123,6 +128,8 @@ final class AutofillManagerServiceImpl {
    // TODO(b/70939974): temporary, will be moved to ExtServices
    static final class FieldClassificationAlgorithmService {

        static final String EXTRA_SCORES = "scores";

        /**
         * Gets the name of all available algorithms.
         */
@@ -140,26 +147,110 @@ final class AutofillManagerServiceImpl {
        }

        /**
         * Gets a field classification score.
         * Gets the field classification scores.
         *
         * @param algorithmName algorithm to be used. If invalid, the default algorithm will be used
         * instead.
         * @param algorithmArgs optional arguments to be passed to the algorithm.
         * @param actualValue value entered by the user.
         * @param userDataValue value from the user data.
         *
         * @return pair containing the algorithm used and the score.
         * @param currentValues values entered by the user.
         * @param userValues values from the user data.
         * @param callback returns a nullable bundle with the parcelable results on
         * {@link #EXTRA_SCORES}.
         */
        // TODO(b/70939974): use parcelable instead of pair
        Pair<String, Float> getScore(@NonNull String algorithmName, @Nullable Bundle algorithmArgs,
                @NonNull AutofillValue actualValue, @NonNull String userDataValue) {
        @Nullable
        void getScores(@NonNull String algorithmName, @Nullable Bundle algorithmArgs,
                List<AutofillValue> currentValues, @NonNull String[] userValues,
                @NonNull RemoteCallback callback) {
            if (currentValues == null || userValues == null) {
                // TODO(b/70939974): use preconditions / add unit test
                throw new IllegalArgumentException("values cannot be null");
            }
            if (currentValues.isEmpty() || userValues.length == 0) {
                Slog.w(TAG, "getScores(): empty currentvalues (" + currentValues
                        + ") or userValues (" + Arrays.toString(userValues) + ")");
                // TODO(b/70939974): add unit test
                callback.sendResult(null);
            }
            String actualAlgorithName = algorithmName;
            if (!EditDistanceScorer.NAME.equals(algorithmName)) {
                Log.w(TAG, "Ignoring invalid algorithm (" + algorithmName + ") and using "
                Slog.w(TAG, "Ignoring invalid algorithm (" + algorithmName + ") and using "
                        + EditDistanceScorer.NAME + " instead");
                actualAlgorithName = EditDistanceScorer.NAME;
            }
            final int currentValuesSize = currentValues.size();
            if (sDebug) {
                Log.d(TAG, "getScores() will return a " + currentValuesSize + "x"
                        + userValues.length + " matrix for " + actualAlgorithName);
            }
            final FieldClassificationScores scores = new FieldClassificationScores(
                    actualAlgorithName, currentValuesSize, userValues.length);
            final EditDistanceScorer algorithm = EditDistanceScorer.getInstance();
            for (int i = 0; i < currentValuesSize; i++) {
                for (int j = 0; j < userValues.length; j++) {
                    final float score = algorithm.getScore(currentValues.get(i), userValues[j]);
                    scores.scores[i][j] = score;
                }
            }
            final Bundle result = new Bundle();
            result.putParcelable(EXTRA_SCORES, scores);
            callback.sendResult(result);
        }
    }

    // TODO(b/70939974): temporary, will be moved to ExtServices
    public static final class FieldClassificationScores implements Parcelable {
        public final String algorithmName;
        public final float[][] scores;

        public FieldClassificationScores(String algorithmName, int size1, int size2) {
            this.algorithmName = algorithmName;
            scores = new float[size1][size2];
        }

        public FieldClassificationScores(Parcel parcel) {
            algorithmName = parcel.readString();
            final int size1 = parcel.readInt();
            final int size2 = parcel.readInt();
            scores = new float[size1][size2];
            for (int i = 0; i < size1; i++) {
                for (int j = 0; j < size2; j++) {
                    scores[i][j] = parcel.readFloat();
                }
            }
        }

        @Override
        public int describeContents() {
            return 0;
        }

        @Override
        public void writeToParcel(Parcel parcel, int flags) {
            parcel.writeString(algorithmName);
            int size1 = scores.length;
            int size2 = scores[0].length;
            parcel.writeInt(size1);
            parcel.writeInt(size2);
            for (int i = 0; i < size1; i++) {
                for (int j = 0; j < size2; j++) {
                    parcel.writeFloat(scores[i][j]);
                }
            }
        }

        public static final Creator<FieldClassificationScores> CREATOR = new Creator<FieldClassificationScores>() {

            @Override
            public FieldClassificationScores createFromParcel(Parcel parcel) {
                return new FieldClassificationScores(parcel);
            }
            return new Pair<>(EditDistanceScorer.NAME,
                    EditDistanceScorer.getInstance().getScore(actualValue, userDataValue));

            @Override
            public FieldClassificationScores[] newArray(int size) {
                return new FieldClassificationScores[size];
            }

        };
    }

    private final FieldClassificationAlgorithmService mFcService =
@@ -769,6 +860,19 @@ final class AutofillManagerServiceImpl {
    /**
     * Updates the last fill response when an autofill context is committed.
     */
    void logContextCommittedLocked(int sessionId, @Nullable Bundle clientState,
            @Nullable ArrayList<String> selectedDatasets,
            @Nullable ArraySet<String> ignoredDatasets,
            @Nullable ArrayList<AutofillId> changedFieldIds,
            @Nullable ArrayList<String> changedDatasetIds,
            @Nullable ArrayList<AutofillId> manuallyFilledFieldIds,
            @Nullable ArrayList<ArrayList<String>> manuallyFilledDatasetIds,
            @NonNull String appPackageName) {
        logContextCommittedLocked(sessionId, clientState, selectedDatasets, ignoredDatasets,
                changedFieldIds, changedDatasetIds, manuallyFilledFieldIds,
                manuallyFilledDatasetIds, null, null, appPackageName);
    }

    void logContextCommittedLocked(int sessionId, @Nullable Bundle clientState,
            @Nullable ArrayList<String> selectedDatasets,
            @Nullable ArraySet<String> ignoredDatasets,
@@ -780,6 +884,16 @@ final class AutofillManagerServiceImpl {
            @Nullable ArrayList<FieldClassification> detectedFieldClassificationsList,
            @NonNull String appPackageName) {
        if (isValidEventLocked("logDatasetNotSelected()", sessionId)) {
            if (sVerbose) {
                Slog.v(TAG, "logContextCommitted() with FieldClassification: id=" + sessionId
                        + ", selectedDatasets=" + selectedDatasets
                        + ", ignoredDatasetIds=" + ignoredDatasets
                        + ", changedAutofillIds=" + changedFieldIds
                        + ", changedDatasetIds=" + changedDatasetIds
                        + ", manuallyFilledFieldIds=" + manuallyFilledFieldIds
                        + ", detectedFieldIds=" + detectedFieldIdsList
                        + ", detectedFieldClassifications=" + detectedFieldClassificationsList);
            }
            AutofillId[] detectedFieldsIds = null;
            FieldClassification[] detectedFieldClassifications = null;
            if (detectedFieldIdsList != null) {
+92 −64
Original line number Diff line number Diff line
@@ -25,11 +25,11 @@ import static android.view.autofill.AutofillManager.ACTION_VALUE_CHANGED;
import static android.view.autofill.AutofillManager.ACTION_VIEW_ENTERED;
import static android.view.autofill.AutofillManager.ACTION_VIEW_EXITED;

import static com.android.server.autofill.AutofillManagerServiceImpl.FieldClassificationAlgorithmService.EXTRA_SCORES;
import static com.android.server.autofill.Helper.sDebug;
import static com.android.server.autofill.Helper.sPartitionMaxCount;
import static com.android.server.autofill.Helper.sVerbose;
import static com.android.server.autofill.Helper.toArray;
import static com.android.server.autofill.ViewState.STATE_AUTOFILLED;
import static com.android.server.autofill.ViewState.STATE_RESTARTED_SESSION;

import android.annotation.NonNull;
@@ -51,11 +51,14 @@ import android.os.Binder;
import android.os.Bundle;
import android.os.IBinder;
import android.os.Parcelable;
import android.os.RemoteCallback;
import android.os.RemoteException;
import android.os.SystemClock;
import android.service.autofill.AutofillService;
import android.service.autofill.Dataset;
import android.service.autofill.FieldClassification;
import android.service.autofill.FieldClassification.Match;
import android.service.carrier.CarrierMessagingService.ResultCallback;
import android.service.autofill.FillContext;
import android.service.autofill.FillRequest;
import android.service.autofill.FillResponse;
@@ -65,11 +68,9 @@ import android.service.autofill.SaveInfo;
import android.service.autofill.SaveRequest;
import android.service.autofill.UserData;
import android.service.autofill.ValueFinder;
import android.service.autofill.FieldClassification;
import android.util.ArrayMap;
import android.util.ArraySet;
import android.util.LocalLog;
import android.util.Pair;
import android.util.Slog;
import android.util.SparseArray;
import android.util.TimeUtils;
@@ -86,16 +87,19 @@ import com.android.internal.logging.nano.MetricsProto.MetricsEvent;
import com.android.internal.os.HandlerCaller;
import com.android.internal.util.ArrayUtils;
import com.android.server.autofill.AutofillManagerServiceImpl.FieldClassificationAlgorithmService;
import com.android.server.autofill.AutofillManagerServiceImpl.FieldClassificationScores;
import com.android.server.autofill.ui.AutoFillUI;
import com.android.server.autofill.ui.PendingUi;

import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;

/**
 * A session for a given activity.
@@ -972,18 +976,6 @@ final class Session implements RemoteFillService.FillServiceCallbacks, ViewState

        final UserData userData = mService.getUserData();

        final ArrayList<AutofillId> detectedFieldIds;
        final ArrayList<FieldClassification> detectedFieldClassifications;

        if (userData != null) {
            final int maxFieldsSize = UserData.getMaxFieldClassificationIdsSize();
            detectedFieldIds = new ArrayList<>(maxFieldsSize);
            detectedFieldClassifications = new ArrayList<>(maxFieldsSize);
        } else {
            detectedFieldIds = null;
            detectedFieldClassifications = null;
        }

        for (int i = 0; i < mViewStates.size(); i++) {
            final ViewState viewState = mViewStates.valueAt(i);
            final int state = viewState.getState();
@@ -1088,32 +1080,14 @@ final class Session implements RemoteFillService.FillServiceCallbacks, ViewState
                        } // for j
                    }

                    // Sets field classification score for field
                    if (userData!= null) {
                        setFieldClassificationScore(mService.getFieldClassificationService(),
                                detectedFieldIds, detectedFieldClassifications, userData,
                                viewState.id, currentValue);
                    }
                } // else
            } // else
        }

        if (sVerbose) {
            Slog.v(TAG, "logContextCommitted(): id=" + id
                    + ", selectedDatasetids=" + mSelectedDatasetIds
                    + ", ignoredDatasetIds=" + ignoredDatasets
                    + ", changedAutofillIds=" + changedFieldIds
                    + ", changedDatasetIds=" + changedDatasetIds
                    + ", manuallyFilledIds=" + manuallyFilledIds
                    + ", detectedFieldIds=" + detectedFieldIds
                    + ", detectedFieldClassifications=" + detectedFieldClassifications
                    );
        }

        ArrayList<AutofillId> manuallyFilledFieldIds = null;
        ArrayList<ArrayList<String>> manuallyFilledDatasetIds = null;

        // Must "flatten" the map to the parcellable collection primitives
        // Must "flatten" the map to the parcelable collection primitives
        if (manuallyFilledIds != null) {
            final int size = manuallyFilledIds.size();
            manuallyFilledFieldIds = new ArrayList<>(size);
@@ -1126,22 +1100,35 @@ final class Session implements RemoteFillService.FillServiceCallbacks, ViewState
            }
        }

        mService.logContextCommittedLocked(id, mClientState, mSelectedDatasetIds, ignoredDatasets,
                changedFieldIds, changedDatasetIds,
        // Sets field classification scores
        final FieldClassificationAlgorithmService fcService =
                mService.getFieldClassificationService();
        if (userData != null && fcService != null) {
            logFieldClassificationScoreLocked(fcService, ignoredDatasets, changedFieldIds,
                    changedDatasetIds, manuallyFilledFieldIds, manuallyFilledDatasetIds,
                    manuallyFilledIds, userData,
                    mViewStates.values());
        } else {
            mService.logContextCommittedLocked(id, mClientState, mSelectedDatasetIds,
                    ignoredDatasets, changedFieldIds, changedDatasetIds,
                    manuallyFilledFieldIds, manuallyFilledDatasetIds,
                detectedFieldIds, detectedFieldClassifications, mComponentName.getPackageName());
                    mComponentName.getPackageName());
        }
    }

    /**
     * Adds the matches to {@code detectedFieldsIds} and {@code detectedFieldClassifications} for
     * {@code fieldId} based on its {@code currentValue} and {@code userData}.
     */
    private static void setFieldClassificationScore(
            @NonNull AutofillManagerServiceImpl.FieldClassificationAlgorithmService  service,
            @NonNull ArrayList<AutofillId> detectedFieldIds,
            @NonNull ArrayList<FieldClassification> detectedFieldClassifications,
            @NonNull UserData userData, @NonNull AutofillId fieldId,
            @NonNull AutofillValue currentValue) {
    private void logFieldClassificationScoreLocked(
            @NonNull AutofillManagerServiceImpl.FieldClassificationAlgorithmService fcService,
            @NonNull ArraySet<String> ignoredDatasets,
            @NonNull ArrayList<AutofillId> changedFieldIds,
            @NonNull ArrayList<String> changedDatasetIds,
            @NonNull ArrayList<AutofillId> manuallyFilledFieldIds,
            @NonNull ArrayList<ArrayList<String>> manuallyFilledDatasetIds,
            @NonNull ArrayMap<AutofillId, ArraySet<String>> manuallyFilledIds,
            @NonNull UserData userData, @NonNull Collection<ViewState> viewStates) {

        final String[] userValues = userData.getValues();
        final String[] remoteIds = userData.getRemoteIds();
@@ -1155,26 +1142,58 @@ final class Session implements RemoteFillService.FillServiceCallbacks, ViewState
            return;
        }

        final int maxFieldsSize = UserData.getMaxFieldClassificationIdsSize();

        final ArrayList<AutofillId> detectedFieldIds = new ArrayList<>(maxFieldsSize);
        final ArrayList<FieldClassification> detectedFieldClassifications = new ArrayList<>(
                maxFieldsSize);

        final String algorithm = userData.getFieldClassificationAlgorithm();
        final Bundle algorithmArgs = userData.getAlgorithmArgs();
        final int viewsSize = viewStates.size();

        // First, we get all scores.
        final AutofillId[] fieldIds = new AutofillId[viewsSize];
        final ArrayList<AutofillValue> currentValues = new ArrayList<>(viewsSize);
        int k = 0;
        for (ViewState viewState : viewStates) {
            currentValues.add(viewState.getCurrentValue());
            fieldIds[k++] = viewState.id;
        }

        final RemoteCallback callback = new RemoteCallback((result) -> {
            if (result == null) {
                if (sDebug) Slog.d(TAG, "setFieldClassificationScore(): no results");
                mService.logContextCommittedLocked(id, mClientState, mSelectedDatasetIds,
                        ignoredDatasets, changedFieldIds, changedDatasetIds,
                        manuallyFilledFieldIds, manuallyFilledDatasetIds,
                        mComponentName.getPackageName());
                return;
            }
            final FieldClassificationScores matrix = result.getParcelable(EXTRA_SCORES);

            // Then use the results.
            for (int i = 0; i < viewsSize; i++) {
                final AutofillId fieldId = fieldIds[i];

                ArrayList<Match> matches = null;
        for (int i = 0; i < userValues.length; i++) {
            String remoteId = remoteIds[i];
            final String value = userValues[i];
            final Pair<String, Float> result = service.getScore(algorithm, algorithmArgs,
                    currentValue, value);
            final String actualAlgorithm = result.first;
            final float score = result.second;
                for (int j = 0; j < userValues.length; j++) {
                    String remoteId = remoteIds[j];
                    final String actualAlgorithm = matrix.algorithmName;
                    final float score = matrix.scores[i][j];
                    if (score > 0) {
                        if (sVerbose) {
                    Slog.v(TAG, "adding score " + score + " at index " + i + " and id " + fieldId);
                            Slog.v(TAG, "adding score " + score + " at index " + j + " and id "
                                    + fieldId);
                        }
                        if (matches == null) {
                            matches = new ArrayList<>(userValues.length);
                        }
                        matches.add(new Match(remoteId, score, actualAlgorithm));
                    }
            else if (sVerbose) Slog.v(TAG, "skipping score 0 at index " + i + " and id " + fieldId);
                    else if (sVerbose) {
                        Slog.v(TAG, "skipping score 0 at index " + j + " and id " + fieldId);
                    }
                }
                if (matches != null) {
                    detectedFieldIds.add(fieldId);
@@ -1182,6 +1201,15 @@ final class Session implements RemoteFillService.FillServiceCallbacks, ViewState
                }
            }

            mService.logContextCommittedLocked(id, mClientState, mSelectedDatasetIds,
                    ignoredDatasets, changedFieldIds, changedDatasetIds, manuallyFilledFieldIds,
                    manuallyFilledDatasetIds, detectedFieldIds, detectedFieldClassifications,
                    mComponentName.getPackageName());
        });

        fcService.getScores(algorithm, algorithmArgs, currentValues, userValues, callback);
    }

    /**
     * Shows the save UI, when session can be saved.
     *