Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 0fc93fe4 authored by Keisuke Kuroyanagi's avatar Keisuke Kuroyanagi
Browse files

Implement PatriciaTriePolicy::getNextWordAndNextToken().

Bug: 12810574
Change-Id: Id1d44f90de9455d9cbe7b6e0a161cae91d6d422c
parent 85fe06e7
Loading
Loading
Loading
Loading
+51 −16
Original line number Diff line number Diff line
@@ -17,12 +17,12 @@
package com.android.inputmethod.latin.makedict;

import com.android.inputmethod.annotations.UsedForTesting;
import com.android.inputmethod.latin.BinaryDictionary;
import com.android.inputmethod.latin.makedict.BinaryDictDecoderUtils.CharEncoding;
import com.android.inputmethod.latin.makedict.BinaryDictDecoderUtils.DictBuffer;
import com.android.inputmethod.latin.makedict.FormatSpec.FormatOptions;
import com.android.inputmethod.latin.makedict.FusionDictionary.WeightedString;

import android.util.Log;
import com.android.inputmethod.latin.utils.CollectionUtils;

import java.io.File;
import java.io.FileNotFoundException;
@@ -33,6 +33,7 @@ import java.util.Arrays;
/**
 * An implementation of DictDecoder for version 2 binary dictionary.
 */
// TODO: Separate out the logic that is used only for testing.
@UsedForTesting
public class Ver2DictDecoder extends AbstractDictDecoder {
    private static final String TAG = Ver2DictDecoder.class.getSimpleName();
@@ -116,12 +117,19 @@ public class Ver2DictDecoder extends AbstractDictDecoder {
    }

    protected final File mDictionaryBinaryFile;
    // TODO: Remove mBufferFactory and mDictBuffer from this class's members because they are
    // now used only for testing.
    private final DictionaryBufferFactory mBufferFactory;
    protected DictBuffer mDictBuffer;
    private final BinaryDictionary mBinaryDictionary;

    /* package */ Ver2DictDecoder(final File file, final int factoryFlag) {
        mDictionaryBinaryFile = file;
        mDictBuffer = null;
        // dictType is not being used in dicttool. Passing an empty string.
        mBinaryDictionary = new BinaryDictionary(file.getAbsolutePath(),
                0 /* offset */, file.length() /* length */, true /* useFullEditDistance */,
                null /* locale */, "" /* dictType */, false /* isUpdatable */);

        if ((factoryFlag & MASK_DICTBUFFER) == USE_READONLY_BYTEBUFFER) {
            mBufferFactory = new DictionaryBufferFromReadOnlyByteBufferFactory();
@@ -137,6 +145,10 @@ public class Ver2DictDecoder extends AbstractDictDecoder {
    /* package */ Ver2DictDecoder(final File file, final DictionaryBufferFactory factory) {
        mDictionaryBinaryFile = file;
        mBufferFactory = factory;
        // dictType is not being used in dicttool. Passing an empty string.
        mBinaryDictionary = new BinaryDictionary(file.getAbsolutePath(),
                0 /* offset */, file.length() /* length */, true /* useFullEditDistance */,
                null /* locale */, "" /* dictType */, false /* isUpdatable */);
    }

    @Override
@@ -238,24 +250,47 @@ public class Ver2DictDecoder extends AbstractDictDecoder {
    @Override
    public FusionDictionary readDictionaryBinary(final boolean deleteDictIfBroken)
            throws FileNotFoundException, IOException, UnsupportedFormatException {
        if (mDictBuffer == null) {
            openDictBuffer();
        final DictionaryHeader header = readHeader();
        final FusionDictionary fusionDict =
                new FusionDictionary(new FusionDictionary.PtNodeArray(), header.mDictionaryOptions);
        int token = 0;
        final ArrayList<WordProperty> wordProperties = CollectionUtils.newArrayList();
        do {
            final BinaryDictionary.GetNextWordPropertyResult result =
                    mBinaryDictionary.getNextWordProperty(token);
            final WordProperty wordProperty = result.mWordProperty;
            if (wordProperty == null) {
                if (deleteDictIfBroken) {
                    mBinaryDictionary.close();
                    mDictionaryBinaryFile.delete();
                }
                return null;
            }
            wordProperties.add(wordProperty);
            token = result.mNextToken;
        } while (token != 0);

        // Insert unigrams into the fusion dictionary.
        for (final WordProperty wordProperty : wordProperties) {
            if (wordProperty.mIsBlacklistEntry) {
                fusionDict.addBlacklistEntry(wordProperty.mWord, wordProperty.mShortcutTargets,
                        wordProperty.mIsNotAWord);
            } else {
                fusionDict.add(wordProperty.mWord, wordProperty.mProbabilityInfo,
                        wordProperty.mShortcutTargets, wordProperty.mIsNotAWord);
            }
        }
        try {
            return BinaryDictDecoderUtils.readDictionaryBinary(this);
        } catch (IOException e) {
            Log.e(TAG, "The dictionary " + mDictionaryBinaryFile.getName() + " is broken.", e);
            if (deleteDictIfBroken && !mDictionaryBinaryFile.delete()) {
                Log.e(TAG, "Failed to delete the broken dictionary.");
        // Insert bigrams into the fusion dictionary.
        for (final WordProperty wordProperty : wordProperties) {
            if (wordProperty.mBigrams == null) {
                continue;
            }
            throw e;
        } catch (UnsupportedFormatException e) {
            Log.e(TAG, "The dictionary " + mDictionaryBinaryFile.getName() + " is broken.", e);
            if (deleteDictIfBroken && !mDictionaryBinaryFile.delete()) {
                Log.e(TAG, "Failed to delete the broken dictionary.");
            final String word0 = wordProperty.mWord;
            for (final WeightedString bigram : wordProperty.mBigrams) {
                fusionDict.setBigram(word0, bigram.mWord, bigram.mProbabilityInfo);
            }
            throw e;
        }
        return fusionDict;
    }

    @Override
+3 −2
Original line number Diff line number Diff line
@@ -45,6 +45,7 @@ public class Ver4DictDecoder extends AbstractDictDecoder {
    @UsedForTesting
    /* package */ Ver4DictDecoder(final File dictDirectory, final DictionaryBufferFactory factory) {
        mDictDirectory = dictDirectory;
        // dictType is not being used in dicttool. Passing an empty string.
        mBinaryDictionary = new BinaryDictionary(dictDirectory.getAbsolutePath(),
                0 /* offset */, 0 /* length */, true /* useFullEditDistance */, null /* locale */,
                "" /* dictType */, true /* isUpdatable */);
@@ -78,7 +79,7 @@ public class Ver4DictDecoder extends AbstractDictDecoder {
            token = result.mNextToken;
        } while (token != 0);

        // Insert unigrams to the fusion dictionary.
        // Insert unigrams into the fusion dictionary.
        for (final WordProperty wordProperty : wordProperties) {
            if (wordProperty.mIsBlacklistEntry) {
                fusionDict.addBlacklistEntry(wordProperty.mWord, wordProperty.mShortcutTargets,
@@ -88,7 +89,7 @@ public class Ver4DictDecoder extends AbstractDictDecoder {
                        wordProperty.mShortcutTargets, wordProperty.mIsNotAWord);
            }
        }
        // Insert bigrams to the fusion dictionary.
        // Insert bigrams into the fusion dictionary.
        for (final WordProperty wordProperty : wordProperties) {
            if (wordProperty.mBigrams == null) {
                continue;
+18 −16
Original line number Diff line number Diff line
@@ -30,18 +30,19 @@ namespace latinime {
class PtNodeParams {
 public:
    // Invalid PtNode.
    PtNodeParams() : mHeadPos(NOT_A_DICT_POS), mFlags(0), mParentPos(NOT_A_DICT_POS),
            mCodePointCount(0), mCodePoints(), mTerminalIdFieldPos(NOT_A_DICT_POS),
            mTerminalId(Ver4DictConstants::NOT_A_TERMINAL_ID), mProbabilityFieldPos(NOT_A_DICT_POS),
            mProbability(NOT_A_PROBABILITY), mChildrenPosFieldPos(NOT_A_DICT_POS),
            mChildrenPos(NOT_A_DICT_POS), mBigramLinkedNodePos(NOT_A_DICT_POS),
            mShortcutPos(NOT_A_DICT_POS), mBigramPos(NOT_A_DICT_POS),
            mSiblingPos(NOT_A_DICT_POS) {}
    PtNodeParams() : mHeadPos(NOT_A_DICT_POS), mFlags(0), mHasMovedFlag(false),
            mParentPos(NOT_A_DICT_POS), mCodePointCount(0), mCodePoints(),
            mTerminalIdFieldPos(NOT_A_DICT_POS), mTerminalId(Ver4DictConstants::NOT_A_TERMINAL_ID),
            mProbabilityFieldPos(NOT_A_DICT_POS), mProbability(NOT_A_PROBABILITY),
            mChildrenPosFieldPos(NOT_A_DICT_POS), mChildrenPos(NOT_A_DICT_POS),
            mBigramLinkedNodePos(NOT_A_DICT_POS), mShortcutPos(NOT_A_DICT_POS),
            mBigramPos(NOT_A_DICT_POS), mSiblingPos(NOT_A_DICT_POS) {}

    PtNodeParams(const PtNodeParams& ptNodeParams)
            : mHeadPos(ptNodeParams.mHeadPos), mFlags(ptNodeParams.mFlags),
              mParentPos(ptNodeParams.mParentPos), mCodePointCount(ptNodeParams.mCodePointCount),
              mCodePoints(), mTerminalIdFieldPos(ptNodeParams.mTerminalIdFieldPos),
              mHasMovedFlag(ptNodeParams.mHasMovedFlag), mParentPos(ptNodeParams.mParentPos),
              mCodePointCount(ptNodeParams.mCodePointCount), mCodePoints(),
              mTerminalIdFieldPos(ptNodeParams.mTerminalIdFieldPos),
              mTerminalId(ptNodeParams.mTerminalId),
              mProbabilityFieldPos(ptNodeParams.mProbabilityFieldPos),
              mProbability(ptNodeParams.mProbability),
@@ -58,7 +59,7 @@ class PtNodeParams {
            const int codePointCount, const int *const codePoints, const int probability,
            const int childrenPos, const int shortcutPos, const int bigramPos,
            const int siblingPos)
            : mHeadPos(headPos), mFlags(flags), mParentPos(NOT_A_DICT_POS),
            : mHeadPos(headPos), mFlags(flags), mHasMovedFlag(false), mParentPos(NOT_A_DICT_POS),
              mCodePointCount(codePointCount), mCodePoints(), mTerminalIdFieldPos(NOT_A_DICT_POS),
              mTerminalId(Ver4DictConstants::NOT_A_TERMINAL_ID),
              mProbabilityFieldPos(NOT_A_DICT_POS), mProbability(probability),
@@ -73,7 +74,7 @@ class PtNodeParams {
            const int parentPos, const int codePointCount, const int *const codePoints,
            const int terminalIdFieldPos, const int terminalId, const int probability,
            const int childrenPosFieldPos, const int childrenPos, const int siblingPos)
            : mHeadPos(headPos), mFlags(flags), mParentPos(parentPos),
            : mHeadPos(headPos), mFlags(flags), mHasMovedFlag(true), mParentPos(parentPos),
              mCodePointCount(codePointCount), mCodePoints(),
              mTerminalIdFieldPos(terminalIdFieldPos), mTerminalId(terminalId),
              mProbabilityFieldPos(NOT_A_DICT_POS), mProbability(probability),
@@ -87,8 +88,8 @@ class PtNodeParams {
    PtNodeParams(const PtNodeParams *const ptNodeParams,
            const PatriciaTrieReadingUtils::NodeFlags flags, const int parentPos,
            const int codePointCount, const int *const codePoints, const int probability)
            : mHeadPos(ptNodeParams->getHeadPos()), mFlags(flags), mParentPos(parentPos),
              mCodePointCount(codePointCount), mCodePoints(),
            : mHeadPos(ptNodeParams->getHeadPos()), mFlags(flags), mHasMovedFlag(true),
              mParentPos(parentPos), mCodePointCount(codePointCount), mCodePoints(),
              mTerminalIdFieldPos(ptNodeParams->getTerminalIdFieldPos()),
              mTerminalId(ptNodeParams->getTerminalId()),
              mProbabilityFieldPos(ptNodeParams->getProbabilityFieldPos()),
@@ -104,7 +105,7 @@ class PtNodeParams {

    PtNodeParams(const PatriciaTrieReadingUtils::NodeFlags flags, const int parentPos,
            const int codePointCount, const int *const codePoints, const int probability)
            : mHeadPos(NOT_A_DICT_POS), mFlags(flags), mParentPos(parentPos),
            : mHeadPos(NOT_A_DICT_POS), mFlags(flags), mHasMovedFlag(true), mParentPos(parentPos),
              mCodePointCount(codePointCount), mCodePoints(),
              mTerminalIdFieldPos(NOT_A_DICT_POS),
              mTerminalId(Ver4DictConstants::NOT_A_TERMINAL_ID),
@@ -126,11 +127,11 @@ class PtNodeParams {

    // Flags
    AK_FORCE_INLINE bool isDeleted() const {
        return DynamicPtReadingUtils::isDeleted(mFlags);
        return mHasMovedFlag && DynamicPtReadingUtils::isDeleted(mFlags);
    }

    AK_FORCE_INLINE bool willBecomeNonTerminal() const {
        return DynamicPtReadingUtils::willBecomeNonTerminal(mFlags);
        return mHasMovedFlag && DynamicPtReadingUtils::willBecomeNonTerminal(mFlags);
    }

    AK_FORCE_INLINE bool hasChildren() const {
@@ -224,6 +225,7 @@ class PtNodeParams {

    const int mHeadPos;
    const PatriciaTrieReadingUtils::NodeFlags mFlags;
    const bool mHasMovedFlag;
    const int mParentPos;
    const uint8_t mCodePointCount;
    int mCodePoints[MAX_WORD_LENGTH];
+29 −0
Original line number Diff line number Diff line
@@ -363,4 +363,33 @@ const WordProperty PatriciaTriePolicy::getWordProperty(const int *const codePoin
            &bigrams, &shortcuts);
}

int PatriciaTriePolicy::getNextWordAndNextToken(const int token, int *const outCodePoints) {
    if (token == 0) {
        // Start iterating the dictionary.
        mTerminalPtNodePositionsForIteratingWords.clear();
        DynamicPtReadingHelper::TraversePolicyToGetAllTerminalPtNodePositions traversePolicy(
                &mTerminalPtNodePositionsForIteratingWords);
        DynamicPtReadingHelper readingHelper(&mPtNodeReader, &mPtNodeArrayReader);
        readingHelper.initWithPtNodeArrayPos(getRootPosition());
        readingHelper.traverseAllPtNodesInPostorderDepthFirstManner(&traversePolicy);
    }
    const int terminalPtNodePositionsVectorSize =
            static_cast<int>(mTerminalPtNodePositionsForIteratingWords.size());
    if (token < 0 || token >= terminalPtNodePositionsVectorSize) {
        AKLOGE("Given token %d is invalid.", token);
        return 0;
    }
    const int terminalPtNodePos = mTerminalPtNodePositionsForIteratingWords[token];
    int unigramProbability = NOT_A_PROBABILITY;
    getCodePointsAndProbabilityAndReturnCodePointCount(terminalPtNodePos, MAX_WORD_LENGTH,
            outCodePoints, &unigramProbability);
    const int nextToken = token + 1;
    if (nextToken >= terminalPtNodePositionsVectorSize) {
        // All words have been iterated.
        mTerminalPtNodePositionsForIteratingWords.clear();
        return 0;
    }
    return nextToken;
}

} // namespace latinime
+5 −5
Original line number Diff line number Diff line
@@ -18,6 +18,7 @@
#define LATINIME_PATRICIA_TRIE_POLICY_H

#include <stdint.h>
#include <vector>

#include "defines.h"
#include "suggest/core/policy/dictionary_structure_with_buffer_policy.h"
@@ -44,7 +45,8 @@ class PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
                      - mHeaderPolicy.getSize()),
              mBigramListPolicy(mDictRoot), mShortcutListPolicy(mDictRoot),
              mPtNodeReader(mDictRoot, mDictBufferSize, &mBigramListPolicy, &mShortcutListPolicy),
              mPtNodeArrayReader(mDictRoot, mDictBufferSize) {}
              mPtNodeArrayReader(mDictRoot, mDictBufferSize),
              mTerminalPtNodePositionsForIteratingWords() {}

    AK_FORCE_INLINE int getRootPosition() const {
        return 0;
@@ -130,10 +132,7 @@ class PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
    const WordProperty getWordProperty(const int *const codePoints,
            const int codePointCount) const;

    int getNextWordAndNextToken(const int token, int *const outCodePoints) {
        // getNextWordAndNextToken is not supported.
        return 0;
    }
    int getNextWordAndNextToken(const int token, int *const outCodePoints);

 private:
    DISALLOW_IMPLICIT_CONSTRUCTORS(PatriciaTriePolicy);
@@ -146,6 +145,7 @@ class PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
    const ShortcutListPolicy mShortcutListPolicy;
    const Ver2ParticiaTrieNodeReader mPtNodeReader;
    const Ver2PtNodeArrayReader mPtNodeArrayReader;
    std::vector<int> mTerminalPtNodePositionsForIteratingWords;

    int createAndGetLeavingChildNode(const DicNode *const dicNode, const int ptNodePos,
            DicNodeVector *const childDicNodes) const;
Loading