Loading native/jni/src/suggest/core/dictionary/bigram_dictionary.cpp +6 −15 Original line number Diff line number Diff line Loading @@ -123,9 +123,10 @@ int BigramDictionary::getPredictions(const int *prevWord, int prevWordLength, in for (BinaryDictionaryBigramsIterator bigramsIt(mBinaryDictionaryInfo, pos); bigramsIt.hasNext(); /* no-op */) { bigramsIt.next(); const int length = BinaryFormat::getWordAtAddress( mBinaryDictionaryInfo->getDictRoot(), bigramsIt.getBigramPos(), MAX_WORD_LENGTH, bigramBuffer, &unigramProbability); const int length = mBinaryDictionaryInfo->getStructurePolicy()-> getCodePointsAndProbabilityAndReturnCodePointCount( mBinaryDictionaryInfo, bigramsIt.getBigramPos(), MAX_WORD_LENGTH, bigramBuffer, &unigramProbability); // inputSize == 0 means we are trying to find bigram predictions. if (inputSize < 1 || checkFirstCharacter(bigramBuffer, inputCodePoints)) { Loading Loading @@ -153,18 +154,8 @@ int BigramDictionary::getBigramListPositionForWord(const int *prevWord, const in int pos = mBinaryDictionaryInfo->getStructurePolicy()->getTerminalNodePositionOfWord( mBinaryDictionaryInfo, prevWord, prevWordLength, forceLowerCaseSearch); if (NOT_VALID_WORD == pos) return 0; const uint8_t *const root = mBinaryDictionaryInfo->getDictRoot(); const uint8_t flags = BinaryFormat::getFlagsAndForwardPointer(root, &pos); if (0 == (flags & BinaryFormat::FLAG_HAS_BIGRAMS)) return 0; if (0 == (flags & BinaryFormat::FLAG_HAS_MULTIPLE_CHARS)) { BinaryFormat::getCodePointAndForwardPointer(root, &pos); } else { pos = BinaryFormat::skipOtherCharacters(root, pos); } pos = BinaryFormat::skipProbability(flags, pos); pos = BinaryFormat::skipChildrenPosition(flags, pos); pos = BinaryFormat::skipShortcuts(root, flags, pos); return pos; return BinaryFormat::getBigramListPositionForWordPosition( mBinaryDictionaryInfo->getDictRoot(), pos); } bool BigramDictionary::checkFirstCharacter(int *word, int *inputCodePoints) const { Loading native/jni/src/suggest/core/dictionary/binary_format.h +15 −13 Original line number Diff line number Diff line Loading @@ -71,8 +71,9 @@ class BinaryFormat { static bool hasChildrenInFlags(const uint8_t flags); static int getTerminalPosition(const uint8_t *const root, const int *const inWord, const int length, const bool forceLowerCaseSearch); static int getWordAtAddress(const uint8_t *const root, const int address, const int maxDepth, int *outWord, int *outUnigramProbability); static int getCodePointsAndProbabilityAndReturnCodePointCount( const uint8_t *const root, const int nodePos, const int maxCodePointCount, int *outCodePoints, int *outUnigramProbability); static int getBigramListPositionForWordPosition(const uint8_t *const root, int position); private: Loading Loading @@ -342,8 +343,9 @@ AK_FORCE_INLINE int BinaryFormat::getTerminalPosition(const uint8_t *const root, * outUnigramProbability: a pointer to an int to write the probability into. * Return value : the length of the word, of 0 if the word was not found. */ AK_FORCE_INLINE int BinaryFormat::getWordAtAddress(const uint8_t *const root, const int address, const int maxDepth, int *outWord, int *outUnigramProbability) { AK_FORCE_INLINE int BinaryFormat::getCodePointsAndProbabilityAndReturnCodePointCount( const uint8_t *const root, const int nodePos, const int maxCodePointCount, int *outCodePoints, int *outUnigramProbability) { int pos = 0; int wordPos = 0; Loading @@ -353,7 +355,7 @@ AK_FORCE_INLINE int BinaryFormat::getWordAtAddress(const uint8_t *const root, co // The only reason we count nodes is because we want to reduce the probability of infinite // looping in case there is a bug. Since we know there is an upper bound to the depth we are // supposed to traverse, it does not hurt to count iterations. for (int loopCount = maxDepth; loopCount > 0; --loopCount) { for (int loopCount = maxCodePointCount; loopCount > 0; --loopCount) { int lastCandidateGroupPos = 0; // Let's loop through char groups in this node searching for either the terminal // or one of its ascendants. Loading @@ -362,17 +364,17 @@ AK_FORCE_INLINE int BinaryFormat::getWordAtAddress(const uint8_t *const root, co const int startPos = pos; const uint8_t flags = getFlagsAndForwardPointer(root, &pos); const int character = getCodePointAndForwardPointer(root, &pos); if (address == startPos) { if (nodePos == startPos) { // We found the address. Copy the rest of the word in the buffer and return // the length. outWord[wordPos] = character; outCodePoints[wordPos] = character; if (FLAG_HAS_MULTIPLE_CHARS & flags) { int nextChar = getCodePointAndForwardPointer(root, &pos); // We count chars in order to avoid infinite loops if the file is broken or // if there is some other bug int charCount = maxDepth; int charCount = maxCodePointCount; while (NOT_A_CODE_POINT != nextChar && --charCount > 0) { outWord[++wordPos] = nextChar; outCodePoints[++wordPos] = nextChar; nextChar = getCodePointAndForwardPointer(root, &pos); } } Loading @@ -399,7 +401,7 @@ AK_FORCE_INLINE int BinaryFormat::getWordAtAddress(const uint8_t *const root, co if (hasChildren) { // Here comes the tricky part. First, read the children position. const int childrenPos = readChildrenPosition(root, flags, pos); if (childrenPos > address) { if (childrenPos > nodePos) { // If the children pos is greater than address, it means the previous chargroup, // which address is stored in lastCandidateGroupPos, was the right one. found = true; Loading Loading @@ -429,12 +431,12 @@ AK_FORCE_INLINE int BinaryFormat::getWordAtAddress(const uint8_t *const root, co const int lastChar = getCodePointAndForwardPointer(root, &lastCandidateGroupPos); // We copy all the characters in this group to the buffer outWord[wordPos] = lastChar; outCodePoints[wordPos] = lastChar; if (FLAG_HAS_MULTIPLE_CHARS & lastFlags) { int nextChar = getCodePointAndForwardPointer(root, &lastCandidateGroupPos); int charCount = maxDepth; int charCount = maxCodePointCount; while (-1 != nextChar && --charCount > 0) { outWord[++wordPos] = nextChar; outCodePoints[++wordPos] = nextChar; nextChar = getCodePointAndForwardPointer(root, &lastCandidateGroupPos); } } Loading native/jni/src/suggest/core/policy/dictionary_structure_policy.h +3 −2 Original line number Diff line number Diff line Loading @@ -50,8 +50,9 @@ class DictionaryStructurePolicy { const BinaryDictionaryInfo *const binaryDictionaryInfo, const NodeFilter *const nodeFilter, DicNodeVector *const childDicNodes) const = 0; virtual void getWordAtPosition(const BinaryDictionaryInfo *const binaryDictionaryInfo, const int terminalNodePos, const int maxDepth, int *const outWord, virtual int getCodePointsAndProbabilityAndReturnCodePointCount( const BinaryDictionaryInfo *const binaryDictionaryInfo, const int nodePos, const int maxCodePointCount, int *const outCodePoints, int *const outUnigramProbability) const = 0; virtual int getTerminalNodePositionOfWord( Loading native/jni/src/suggest/core/session/dic_traverse_session.cpp +7 −10 Original line number Diff line number Diff line Loading @@ -18,10 +18,8 @@ #include "defines.h" #include "jni.h" #include "suggest/core/dicnode/dic_node_utils.h" #include "suggest/core/dictionary/binary_dictionary_header.h" #include "suggest/core/dictionary/binary_dictionary_info.h" #include "suggest/core/dictionary/binary_format.h" #include "suggest/core/dictionary/dictionary.h" namespace latinime { Loading @@ -29,23 +27,22 @@ namespace latinime { void DicTraverseSession::init(const Dictionary *const dictionary, const int *prevWord, int prevWordLength, const SuggestOptions *const suggestOptions) { mDictionary = dictionary; mMultiWordCostMultiplier = mDictionary->getBinaryDictionaryInfo() ->getHeader()->getMultiWordCostMultiplier(); const BinaryDictionaryInfo *const binaryDictionaryInfo = mDictionary->getBinaryDictionaryInfo(); mMultiWordCostMultiplier = binaryDictionaryInfo->getHeader()->getMultiWordCostMultiplier(); mSuggestOptions = suggestOptions; if (!prevWord) { mPrevWordPos = NOT_VALID_WORD; return; } // TODO: merge following similar calls to getTerminalPosition into one case-insensitive call. mPrevWordPos = BinaryFormat::getTerminalPosition( dictionary->getBinaryDictionaryInfo()->getDictRoot(), prevWord, prevWordLength, false /* forceLowerCaseSearch */); mPrevWordPos = binaryDictionaryInfo->getStructurePolicy()->getTerminalNodePositionOfWord( binaryDictionaryInfo, prevWord, prevWordLength, false /* forceLowerCaseSearch */); if (mPrevWordPos == NOT_VALID_WORD) { // Check bigrams for lower-cased previous word if original was not found. Useful for // auto-capitalized words like "The [current_word]". mPrevWordPos = BinaryFormat::getTerminalPosition( dictionary->getBinaryDictionaryInfo()->getDictRoot(), prevWord, prevWordLength, true /* forceLowerCaseSearch */); mPrevWordPos = binaryDictionaryInfo->getStructurePolicy()->getTerminalNodePositionOfWord( binaryDictionaryInfo, prevWord, prevWordLength, true /* forceLowerCaseSearch */); } } Loading native/jni/src/suggest/policyimpl/dictionary/patricia_trie_policy.cpp +6 −4 Original line number Diff line number Diff line Loading @@ -33,11 +33,13 @@ void PatriciaTriePolicy::createAndGetAllChildNodes(const DicNode *const dicNode, // TODO: Move children creating methods form DicNodeUtils. } void PatriciaTriePolicy::getWordAtPosition(const BinaryDictionaryInfo *const binaryDictionaryInfo, const int terminalNodePos, const int maxDepth, int *const outWord, int PatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount( const BinaryDictionaryInfo *const binaryDictionaryInfo, const int nodePos, const int maxCodePointCount, int *const outCodePoints, int *const outUnigramProbability) const { BinaryFormat::getWordAtAddress(binaryDictionaryInfo->getDictRoot(), terminalNodePos, maxDepth, outWord, outUnigramProbability); return BinaryFormat::getCodePointsAndProbabilityAndReturnCodePointCount( binaryDictionaryInfo->getDictRoot(), nodePos, maxCodePointCount, outCodePoints, outUnigramProbability); } int PatriciaTriePolicy::getTerminalNodePositionOfWord( Loading Loading
native/jni/src/suggest/core/dictionary/bigram_dictionary.cpp +6 −15 Original line number Diff line number Diff line Loading @@ -123,9 +123,10 @@ int BigramDictionary::getPredictions(const int *prevWord, int prevWordLength, in for (BinaryDictionaryBigramsIterator bigramsIt(mBinaryDictionaryInfo, pos); bigramsIt.hasNext(); /* no-op */) { bigramsIt.next(); const int length = BinaryFormat::getWordAtAddress( mBinaryDictionaryInfo->getDictRoot(), bigramsIt.getBigramPos(), MAX_WORD_LENGTH, bigramBuffer, &unigramProbability); const int length = mBinaryDictionaryInfo->getStructurePolicy()-> getCodePointsAndProbabilityAndReturnCodePointCount( mBinaryDictionaryInfo, bigramsIt.getBigramPos(), MAX_WORD_LENGTH, bigramBuffer, &unigramProbability); // inputSize == 0 means we are trying to find bigram predictions. if (inputSize < 1 || checkFirstCharacter(bigramBuffer, inputCodePoints)) { Loading Loading @@ -153,18 +154,8 @@ int BigramDictionary::getBigramListPositionForWord(const int *prevWord, const in int pos = mBinaryDictionaryInfo->getStructurePolicy()->getTerminalNodePositionOfWord( mBinaryDictionaryInfo, prevWord, prevWordLength, forceLowerCaseSearch); if (NOT_VALID_WORD == pos) return 0; const uint8_t *const root = mBinaryDictionaryInfo->getDictRoot(); const uint8_t flags = BinaryFormat::getFlagsAndForwardPointer(root, &pos); if (0 == (flags & BinaryFormat::FLAG_HAS_BIGRAMS)) return 0; if (0 == (flags & BinaryFormat::FLAG_HAS_MULTIPLE_CHARS)) { BinaryFormat::getCodePointAndForwardPointer(root, &pos); } else { pos = BinaryFormat::skipOtherCharacters(root, pos); } pos = BinaryFormat::skipProbability(flags, pos); pos = BinaryFormat::skipChildrenPosition(flags, pos); pos = BinaryFormat::skipShortcuts(root, flags, pos); return pos; return BinaryFormat::getBigramListPositionForWordPosition( mBinaryDictionaryInfo->getDictRoot(), pos); } bool BigramDictionary::checkFirstCharacter(int *word, int *inputCodePoints) const { Loading
native/jni/src/suggest/core/dictionary/binary_format.h +15 −13 Original line number Diff line number Diff line Loading @@ -71,8 +71,9 @@ class BinaryFormat { static bool hasChildrenInFlags(const uint8_t flags); static int getTerminalPosition(const uint8_t *const root, const int *const inWord, const int length, const bool forceLowerCaseSearch); static int getWordAtAddress(const uint8_t *const root, const int address, const int maxDepth, int *outWord, int *outUnigramProbability); static int getCodePointsAndProbabilityAndReturnCodePointCount( const uint8_t *const root, const int nodePos, const int maxCodePointCount, int *outCodePoints, int *outUnigramProbability); static int getBigramListPositionForWordPosition(const uint8_t *const root, int position); private: Loading Loading @@ -342,8 +343,9 @@ AK_FORCE_INLINE int BinaryFormat::getTerminalPosition(const uint8_t *const root, * outUnigramProbability: a pointer to an int to write the probability into. * Return value : the length of the word, of 0 if the word was not found. */ AK_FORCE_INLINE int BinaryFormat::getWordAtAddress(const uint8_t *const root, const int address, const int maxDepth, int *outWord, int *outUnigramProbability) { AK_FORCE_INLINE int BinaryFormat::getCodePointsAndProbabilityAndReturnCodePointCount( const uint8_t *const root, const int nodePos, const int maxCodePointCount, int *outCodePoints, int *outUnigramProbability) { int pos = 0; int wordPos = 0; Loading @@ -353,7 +355,7 @@ AK_FORCE_INLINE int BinaryFormat::getWordAtAddress(const uint8_t *const root, co // The only reason we count nodes is because we want to reduce the probability of infinite // looping in case there is a bug. Since we know there is an upper bound to the depth we are // supposed to traverse, it does not hurt to count iterations. for (int loopCount = maxDepth; loopCount > 0; --loopCount) { for (int loopCount = maxCodePointCount; loopCount > 0; --loopCount) { int lastCandidateGroupPos = 0; // Let's loop through char groups in this node searching for either the terminal // or one of its ascendants. Loading @@ -362,17 +364,17 @@ AK_FORCE_INLINE int BinaryFormat::getWordAtAddress(const uint8_t *const root, co const int startPos = pos; const uint8_t flags = getFlagsAndForwardPointer(root, &pos); const int character = getCodePointAndForwardPointer(root, &pos); if (address == startPos) { if (nodePos == startPos) { // We found the address. Copy the rest of the word in the buffer and return // the length. outWord[wordPos] = character; outCodePoints[wordPos] = character; if (FLAG_HAS_MULTIPLE_CHARS & flags) { int nextChar = getCodePointAndForwardPointer(root, &pos); // We count chars in order to avoid infinite loops if the file is broken or // if there is some other bug int charCount = maxDepth; int charCount = maxCodePointCount; while (NOT_A_CODE_POINT != nextChar && --charCount > 0) { outWord[++wordPos] = nextChar; outCodePoints[++wordPos] = nextChar; nextChar = getCodePointAndForwardPointer(root, &pos); } } Loading @@ -399,7 +401,7 @@ AK_FORCE_INLINE int BinaryFormat::getWordAtAddress(const uint8_t *const root, co if (hasChildren) { // Here comes the tricky part. First, read the children position. const int childrenPos = readChildrenPosition(root, flags, pos); if (childrenPos > address) { if (childrenPos > nodePos) { // If the children pos is greater than address, it means the previous chargroup, // which address is stored in lastCandidateGroupPos, was the right one. found = true; Loading Loading @@ -429,12 +431,12 @@ AK_FORCE_INLINE int BinaryFormat::getWordAtAddress(const uint8_t *const root, co const int lastChar = getCodePointAndForwardPointer(root, &lastCandidateGroupPos); // We copy all the characters in this group to the buffer outWord[wordPos] = lastChar; outCodePoints[wordPos] = lastChar; if (FLAG_HAS_MULTIPLE_CHARS & lastFlags) { int nextChar = getCodePointAndForwardPointer(root, &lastCandidateGroupPos); int charCount = maxDepth; int charCount = maxCodePointCount; while (-1 != nextChar && --charCount > 0) { outWord[++wordPos] = nextChar; outCodePoints[++wordPos] = nextChar; nextChar = getCodePointAndForwardPointer(root, &lastCandidateGroupPos); } } Loading
native/jni/src/suggest/core/policy/dictionary_structure_policy.h +3 −2 Original line number Diff line number Diff line Loading @@ -50,8 +50,9 @@ class DictionaryStructurePolicy { const BinaryDictionaryInfo *const binaryDictionaryInfo, const NodeFilter *const nodeFilter, DicNodeVector *const childDicNodes) const = 0; virtual void getWordAtPosition(const BinaryDictionaryInfo *const binaryDictionaryInfo, const int terminalNodePos, const int maxDepth, int *const outWord, virtual int getCodePointsAndProbabilityAndReturnCodePointCount( const BinaryDictionaryInfo *const binaryDictionaryInfo, const int nodePos, const int maxCodePointCount, int *const outCodePoints, int *const outUnigramProbability) const = 0; virtual int getTerminalNodePositionOfWord( Loading
native/jni/src/suggest/core/session/dic_traverse_session.cpp +7 −10 Original line number Diff line number Diff line Loading @@ -18,10 +18,8 @@ #include "defines.h" #include "jni.h" #include "suggest/core/dicnode/dic_node_utils.h" #include "suggest/core/dictionary/binary_dictionary_header.h" #include "suggest/core/dictionary/binary_dictionary_info.h" #include "suggest/core/dictionary/binary_format.h" #include "suggest/core/dictionary/dictionary.h" namespace latinime { Loading @@ -29,23 +27,22 @@ namespace latinime { void DicTraverseSession::init(const Dictionary *const dictionary, const int *prevWord, int prevWordLength, const SuggestOptions *const suggestOptions) { mDictionary = dictionary; mMultiWordCostMultiplier = mDictionary->getBinaryDictionaryInfo() ->getHeader()->getMultiWordCostMultiplier(); const BinaryDictionaryInfo *const binaryDictionaryInfo = mDictionary->getBinaryDictionaryInfo(); mMultiWordCostMultiplier = binaryDictionaryInfo->getHeader()->getMultiWordCostMultiplier(); mSuggestOptions = suggestOptions; if (!prevWord) { mPrevWordPos = NOT_VALID_WORD; return; } // TODO: merge following similar calls to getTerminalPosition into one case-insensitive call. mPrevWordPos = BinaryFormat::getTerminalPosition( dictionary->getBinaryDictionaryInfo()->getDictRoot(), prevWord, prevWordLength, false /* forceLowerCaseSearch */); mPrevWordPos = binaryDictionaryInfo->getStructurePolicy()->getTerminalNodePositionOfWord( binaryDictionaryInfo, prevWord, prevWordLength, false /* forceLowerCaseSearch */); if (mPrevWordPos == NOT_VALID_WORD) { // Check bigrams for lower-cased previous word if original was not found. Useful for // auto-capitalized words like "The [current_word]". mPrevWordPos = BinaryFormat::getTerminalPosition( dictionary->getBinaryDictionaryInfo()->getDictRoot(), prevWord, prevWordLength, true /* forceLowerCaseSearch */); mPrevWordPos = binaryDictionaryInfo->getStructurePolicy()->getTerminalNodePositionOfWord( binaryDictionaryInfo, prevWord, prevWordLength, true /* forceLowerCaseSearch */); } } Loading
native/jni/src/suggest/policyimpl/dictionary/patricia_trie_policy.cpp +6 −4 Original line number Diff line number Diff line Loading @@ -33,11 +33,13 @@ void PatriciaTriePolicy::createAndGetAllChildNodes(const DicNode *const dicNode, // TODO: Move children creating methods form DicNodeUtils. } void PatriciaTriePolicy::getWordAtPosition(const BinaryDictionaryInfo *const binaryDictionaryInfo, const int terminalNodePos, const int maxDepth, int *const outWord, int PatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount( const BinaryDictionaryInfo *const binaryDictionaryInfo, const int nodePos, const int maxCodePointCount, int *const outCodePoints, int *const outUnigramProbability) const { BinaryFormat::getWordAtAddress(binaryDictionaryInfo->getDictRoot(), terminalNodePos, maxDepth, outWord, outUnigramProbability); return BinaryFormat::getCodePointsAndProbabilityAndReturnCodePointCount( binaryDictionaryInfo->getDictRoot(), nodePos, maxCodePointCount, outCodePoints, outUnigramProbability); } int PatriciaTriePolicy::getTerminalNodePositionOfWord( Loading