Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 3b75e31c authored by Abodunrinwa Toki's avatar Abodunrinwa Toki Committed by android-build-merger
Browse files

Merge "SmartSelection: Use downloaded model file." into oc-dev

am: a97608ca

Change-Id: If5dbd82ad5150a1d4f3cfe35dd7325468f36a94c
parents d0c1cf63 a97608ca
Loading
Loading
Loading
Loading
+18 −0
Original line number Diff line number Diff line
@@ -75,6 +75,20 @@ final class SmartSelection {
        nativeClose(mCtx);
    }

    /**
     * Returns the language of the model.
     */
    public static String getLanguage(int fd) {
        return nativeGetLanguage(fd);
    }

    /**
     * Returns the version of the model.
     */
    public static int getVersion(int fd) {
        return nativeGetVersion(fd);
    }

    private static native long nativeNew(int fd);

    private static native int[] nativeSuggest(
@@ -85,6 +99,10 @@ final class SmartSelection {

    private static native void nativeClose(long context);

    private static native String nativeGetLanguage(int fd);

    private static native int nativeGetVersion(int fd);

    /** Classification result for classifyText method. */
    static final class ClassificationResult {
        final String mCollection;
+103 −9
Original line number Diff line number Diff line
@@ -44,6 +44,7 @@ import com.android.internal.util.Preconditions;

import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
@@ -71,6 +72,8 @@ final class TextClassifierImpl implements TextClassifier {
    private static final String LOG_TAG = "TextClassifierImpl";
    private static final String MODEL_DIR = "/etc/textclassifier/";
    private static final String MODEL_FILE_REGEX = "textclassifier\\.smartselection\\.(.*)\\.model";
    private static final String UPDATED_MODEL_FILE_PATH =
            "/data/misc/textclassifier/textclassifier.smartselection.model";

    private final Context mContext;

@@ -175,21 +178,80 @@ final class TextClassifierImpl implements TextClassifier {
        synchronized (mSmartSelectionLock) {
            localeList = localeList == null ? LocaleList.getEmptyLocaleList() : localeList;
            final Locale locale = findBestSupportedLocaleLocked(localeList);
            if (locale == null) {
                throw new FileNotFoundException("No file for null locale");
            }
            if (mSmartSelection == null || !Objects.equals(mLocale, locale)) {
                destroySmartSelectionIfExistsLocked();
                mSmartSelection = new SmartSelection(
                        ParcelFileDescriptor.open(
                                // findBestSupportedLocaleLocked should have initialized
                                // mModelFilePaths
                                new File(mModelFilePaths.get(locale)),
                                ParcelFileDescriptor.MODE_READ_ONLY)
                                .getFd());
                mSmartSelection = new SmartSelection(getFdLocked(locale));
                mLocale = locale;
            }
            return mSmartSelection;
        }
    }

    @GuardedBy("mSmartSelectionLock") // Do not call outside this lock.
    private int getFdLocked(Locale locale) throws FileNotFoundException {
        ParcelFileDescriptor updateFd;
        try {
            updateFd = ParcelFileDescriptor.open(
                    new File(UPDATED_MODEL_FILE_PATH), ParcelFileDescriptor.MODE_READ_ONLY);
        } catch (FileNotFoundException e) {
            updateFd = null;
        }
        ParcelFileDescriptor factoryFd;
        try {
            final String factoryModelFilePath = getFactoryModelFilePathsLocked().get(locale);
            if (factoryModelFilePath != null) {
                factoryFd = ParcelFileDescriptor.open(
                        new File(factoryModelFilePath), ParcelFileDescriptor.MODE_READ_ONLY);
            } else {
                factoryFd = null;
            }
        } catch (FileNotFoundException e) {
            factoryFd = null;
        }

        if (updateFd == null) {
            if (factoryFd != null) {
                return factoryFd.getFd();
            } else {
                throw new FileNotFoundException(
                        String.format("No model file found for %s", locale));
            }
        }

        final int updateFdInt = updateFd.getFd();
        final boolean localeMatches = Objects.equals(
                locale.getLanguage().trim().toLowerCase(),
                SmartSelection.getLanguage(updateFdInt).trim().toLowerCase());
        if (factoryFd == null) {
            if (localeMatches) {
                return updateFdInt;
            } else {
                closeAndLogError(updateFd);
                throw new FileNotFoundException(
                        String.format("No model file found for %s", locale));
            }
        }

        if (!localeMatches) {
            closeAndLogError(updateFd);
            return factoryFd.getFd();
        }

        final int updateVersion = SmartSelection.getVersion(updateFdInt);
        final int factoryFdInt = factoryFd.getFd();
        final int factoryVersion = SmartSelection.getVersion(factoryFdInt);
        if (updateVersion > factoryVersion) {
            closeAndLogError(factoryFd);
            return updateFdInt;
        } else {
            closeAndLogError(updateFd);
            return factoryFdInt;
        }
    }

    @GuardedBy("mSmartSelectionLock") // Do not call outside this lock.
    private void destroySmartSelectionIfExistsLocked() {
        if (mSmartSelection != null) {
@@ -206,11 +268,18 @@ final class TextClassifierImpl implements TextClassifier {
                ? LocaleList.getDefault().toLanguageTags()
                : localeList.toLanguageTags() + "," + LocaleList.getDefault().toLanguageTags();
        final List<Locale.LanguageRange> languageRangeList = Locale.LanguageRange.parse(languages);
        return Locale.lookup(languageRangeList, loadModelFilePathsLocked().keySet());

        final List<Locale> supportedLocales =
                new ArrayList<>(getFactoryModelFilePathsLocked().keySet());
        final Locale updatedModelLocale = getUpdatedModelLocale();
        if (updatedModelLocale != null) {
            supportedLocales.add(updatedModelLocale);
        }
        return Locale.lookup(languageRangeList, supportedLocales);
    }

    @GuardedBy("mSmartSelectionLock") // Do not call outside this lock.
    private Map<Locale, String> loadModelFilePathsLocked() {
    private Map<Locale, String> getFactoryModelFilePathsLocked() {
        if (mModelFilePaths == null) {
            final Map<Locale, String> modelFilePaths = new HashMap<>();
            final File modelsDir = new File(MODEL_DIR);
@@ -233,6 +302,20 @@ final class TextClassifierImpl implements TextClassifier {
        return mModelFilePaths;
    }

    @Nullable
    private Locale getUpdatedModelLocale() {
        try {
            final ParcelFileDescriptor updateFd = ParcelFileDescriptor.open(
                    new File(UPDATED_MODEL_FILE_PATH), ParcelFileDescriptor.MODE_READ_ONLY);
            final Locale locale = Locale.forLanguageTag(
                    SmartSelection.getLanguage(updateFd.getFd()));
            closeAndLogError(updateFd);
            return locale;
        } catch (FileNotFoundException e) {
            return null;
        }
    }

    private TextClassificationResult createClassificationResult(
            SmartSelection.ClassificationResult[] classifications, CharSequence text) {
        final TextClassificationResult.Builder builder = new TextClassificationResult.Builder()
@@ -318,6 +401,17 @@ final class TextClassifierImpl implements TextClassifier {
        return type;
    }

    /**
     * Closes the ParcelFileDescriptor and logs any errors that occur.
     */
    private static void closeAndLogError(ParcelFileDescriptor fd) {
        try {
            fd.close();
        } catch (IOException e) {
            Log.e(LOG_TAG, "Error closing file.", e);
        }
    }

    /**
     * @throws IllegalArgumentException if text is null; startIndex is negative;
     *      endIndex is greater than text.length() or is not greater than startIndex