Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit b6f286bf authored by Keisuke Kuroynagi's avatar Keisuke Kuroynagi Committed by Android (Google) Code Review
Browse files

Merge "Make bigram dictionary and traverse session use structure policy."

parents 4944827e 1311cdcb
Loading
Loading
Loading
Loading
+6 −15
Original line number Diff line number Diff line
@@ -123,9 +123,10 @@ int BigramDictionary::getPredictions(const int *prevWord, int prevWordLength, in
    for (BinaryDictionaryBigramsIterator bigramsIt(mBinaryDictionaryInfo, pos);
            bigramsIt.hasNext(); /* no-op */) {
        bigramsIt.next();
        const int length = BinaryFormat::getWordAtAddress(
                mBinaryDictionaryInfo->getDictRoot(), bigramsIt.getBigramPos(),
                MAX_WORD_LENGTH, bigramBuffer, &unigramProbability);
        const int length = mBinaryDictionaryInfo->getStructurePolicy()->
                getCodePointsAndProbabilityAndReturnCodePointCount(
                        mBinaryDictionaryInfo, bigramsIt.getBigramPos(), MAX_WORD_LENGTH,
                        bigramBuffer, &unigramProbability);

        // inputSize == 0 means we are trying to find bigram predictions.
        if (inputSize < 1 || checkFirstCharacter(bigramBuffer, inputCodePoints)) {
@@ -153,18 +154,8 @@ int BigramDictionary::getBigramListPositionForWord(const int *prevWord, const in
    int pos = mBinaryDictionaryInfo->getStructurePolicy()->getTerminalNodePositionOfWord(
            mBinaryDictionaryInfo, prevWord, prevWordLength, forceLowerCaseSearch);
    if (NOT_VALID_WORD == pos) return 0;
    const uint8_t *const root = mBinaryDictionaryInfo->getDictRoot();
    const uint8_t flags = BinaryFormat::getFlagsAndForwardPointer(root, &pos);
    if (0 == (flags & BinaryFormat::FLAG_HAS_BIGRAMS)) return 0;
    if (0 == (flags & BinaryFormat::FLAG_HAS_MULTIPLE_CHARS)) {
        BinaryFormat::getCodePointAndForwardPointer(root, &pos);
    } else {
        pos = BinaryFormat::skipOtherCharacters(root, pos);
    }
    pos = BinaryFormat::skipProbability(flags, pos);
    pos = BinaryFormat::skipChildrenPosition(flags, pos);
    pos = BinaryFormat::skipShortcuts(root, flags, pos);
    return pos;
    return BinaryFormat::getBigramListPositionForWordPosition(
            mBinaryDictionaryInfo->getDictRoot(), pos);
}

bool BigramDictionary::checkFirstCharacter(int *word, int *inputCodePoints) const {
+15 −13
Original line number Diff line number Diff line
@@ -71,8 +71,9 @@ class BinaryFormat {
    static bool hasChildrenInFlags(const uint8_t flags);
    static int getTerminalPosition(const uint8_t *const root, const int *const inWord,
            const int length, const bool forceLowerCaseSearch);
    static int getWordAtAddress(const uint8_t *const root, const int address, const int maxDepth,
            int *outWord, int *outUnigramProbability);
    static int getCodePointsAndProbabilityAndReturnCodePointCount(
            const uint8_t *const root, const int nodePos, const int maxCodePointCount,
            int *outCodePoints, int *outUnigramProbability);
    static int getBigramListPositionForWordPosition(const uint8_t *const root, int position);

 private:
@@ -342,8 +343,9 @@ AK_FORCE_INLINE int BinaryFormat::getTerminalPosition(const uint8_t *const root,
 * outUnigramProbability: a pointer to an int to write the probability into.
 * Return value : the length of the word, or 0 if the word was not found.
 */
AK_FORCE_INLINE int BinaryFormat::getWordAtAddress(const uint8_t *const root, const int address,
        const int maxDepth, int *outWord, int *outUnigramProbability) {
AK_FORCE_INLINE int BinaryFormat::getCodePointsAndProbabilityAndReturnCodePointCount(
        const uint8_t *const root, const int nodePos,
        const int maxCodePointCount, int *outCodePoints, int *outUnigramProbability) {
    int pos = 0;
    int wordPos = 0;

@@ -353,7 +355,7 @@ AK_FORCE_INLINE int BinaryFormat::getWordAtAddress(const uint8_t *const root, co
    // The only reason we count nodes is because we want to reduce the probability of infinite
    // looping in case there is a bug. Since we know there is an upper bound to the depth we are
    // supposed to traverse, it does not hurt to count iterations.
    for (int loopCount = maxDepth; loopCount > 0; --loopCount) {
    for (int loopCount = maxCodePointCount; loopCount > 0; --loopCount) {
        int lastCandidateGroupPos = 0;
        // Let's loop through char groups in this node searching for either the terminal
        // or one of its ascendants.
@@ -362,17 +364,17 @@ AK_FORCE_INLINE int BinaryFormat::getWordAtAddress(const uint8_t *const root, co
            const int startPos = pos;
            const uint8_t flags = getFlagsAndForwardPointer(root, &pos);
            const int character = getCodePointAndForwardPointer(root, &pos);
            if (address == startPos) {
            if (nodePos == startPos) {
                // We found the address. Copy the rest of the word in the buffer and return
                // the length.
                outWord[wordPos] = character;
                outCodePoints[wordPos] = character;
                if (FLAG_HAS_MULTIPLE_CHARS & flags) {
                    int nextChar = getCodePointAndForwardPointer(root, &pos);
                    // We count chars in order to avoid infinite loops if the file is broken or
                    // if there is some other bug
                    int charCount = maxDepth;
                    int charCount = maxCodePointCount;
                    while (NOT_A_CODE_POINT != nextChar && --charCount > 0) {
                        outWord[++wordPos] = nextChar;
                        outCodePoints[++wordPos] = nextChar;
                        nextChar = getCodePointAndForwardPointer(root, &pos);
                    }
                }
@@ -399,7 +401,7 @@ AK_FORCE_INLINE int BinaryFormat::getWordAtAddress(const uint8_t *const root, co
            if (hasChildren) {
                // Here comes the tricky part. First, read the children position.
                const int childrenPos = readChildrenPosition(root, flags, pos);
                if (childrenPos > address) {
                if (childrenPos > nodePos) {
                    // If the children pos is greater than address, it means the previous chargroup,
                    // whose address is stored in lastCandidateGroupPos, was the right one.
                    found = true;
@@ -429,12 +431,12 @@ AK_FORCE_INLINE int BinaryFormat::getWordAtAddress(const uint8_t *const root, co
                    const int lastChar =
                            getCodePointAndForwardPointer(root, &lastCandidateGroupPos);
                    // We copy all the characters in this group to the buffer
                    outWord[wordPos] = lastChar;
                    outCodePoints[wordPos] = lastChar;
                    if (FLAG_HAS_MULTIPLE_CHARS & lastFlags) {
                        int nextChar = getCodePointAndForwardPointer(root, &lastCandidateGroupPos);
                        int charCount = maxDepth;
                        int charCount = maxCodePointCount;
                        while (-1 != nextChar && --charCount > 0) {
                            outWord[++wordPos] = nextChar;
                            outCodePoints[++wordPos] = nextChar;
                            nextChar = getCodePointAndForwardPointer(root, &lastCandidateGroupPos);
                        }
                    }
+3 −2
Original line number Diff line number Diff line
@@ -50,8 +50,9 @@ class DictionaryStructurePolicy {
            const BinaryDictionaryInfo *const binaryDictionaryInfo,
            const NodeFilter *const nodeFilter, DicNodeVector *const childDicNodes) const = 0;

    virtual void getWordAtPosition(const BinaryDictionaryInfo *const binaryDictionaryInfo,
            const int terminalNodePos, const int maxDepth, int *const outWord,
    virtual int getCodePointsAndProbabilityAndReturnCodePointCount(
            const BinaryDictionaryInfo *const binaryDictionaryInfo,
            const int nodePos, const int maxCodePointCount, int *const outCodePoints,
            int *const outUnigramProbability) const = 0;

    virtual int getTerminalNodePositionOfWord(
+7 −10
Original line number Diff line number Diff line
@@ -18,10 +18,8 @@

#include "defines.h"
#include "jni.h"
#include "suggest/core/dicnode/dic_node_utils.h"
#include "suggest/core/dictionary/binary_dictionary_header.h"
#include "suggest/core/dictionary/binary_dictionary_info.h"
#include "suggest/core/dictionary/binary_format.h"
#include "suggest/core/dictionary/dictionary.h"

namespace latinime {
@@ -29,23 +27,22 @@ namespace latinime {
void DicTraverseSession::init(const Dictionary *const dictionary, const int *prevWord,
        int prevWordLength, const SuggestOptions *const suggestOptions) {
    mDictionary = dictionary;
    mMultiWordCostMultiplier = mDictionary->getBinaryDictionaryInfo()
            ->getHeader()->getMultiWordCostMultiplier();
    const BinaryDictionaryInfo *const binaryDictionaryInfo =
            mDictionary->getBinaryDictionaryInfo();
    mMultiWordCostMultiplier = binaryDictionaryInfo->getHeader()->getMultiWordCostMultiplier();
    mSuggestOptions = suggestOptions;
    if (!prevWord) {
        mPrevWordPos = NOT_VALID_WORD;
        return;
    }
    // TODO: merge following similar calls to getTerminalPosition into one case-insensitive call.
    mPrevWordPos = BinaryFormat::getTerminalPosition(
            dictionary->getBinaryDictionaryInfo()->getDictRoot(), prevWord,
            prevWordLength, false /* forceLowerCaseSearch */);
    mPrevWordPos = binaryDictionaryInfo->getStructurePolicy()->getTerminalNodePositionOfWord(
            binaryDictionaryInfo, prevWord, prevWordLength, false /* forceLowerCaseSearch */);
    if (mPrevWordPos == NOT_VALID_WORD) {
        // Check bigrams for lower-cased previous word if original was not found. Useful for
        // auto-capitalized words like "The [current_word]".
        mPrevWordPos = BinaryFormat::getTerminalPosition(
                dictionary->getBinaryDictionaryInfo()->getDictRoot(), prevWord,
                prevWordLength, true /* forceLowerCaseSearch */);
        mPrevWordPos = binaryDictionaryInfo->getStructurePolicy()->getTerminalNodePositionOfWord(
                binaryDictionaryInfo, prevWord, prevWordLength, true /* forceLowerCaseSearch */);
    }
}

+6 −4
Original line number Diff line number Diff line
@@ -33,11 +33,13 @@ void PatriciaTriePolicy::createAndGetAllChildNodes(const DicNode *const dicNode,
    // TODO: Move children creating methods from DicNodeUtils.
}

void PatriciaTriePolicy::getWordAtPosition(const BinaryDictionaryInfo *const binaryDictionaryInfo,
        const int terminalNodePos, const int maxDepth, int *const outWord,
int PatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount(
        const BinaryDictionaryInfo *const binaryDictionaryInfo,
        const int nodePos, const int maxCodePointCount, int *const outCodePoints,
        int *const outUnigramProbability) const {
    BinaryFormat::getWordAtAddress(binaryDictionaryInfo->getDictRoot(), terminalNodePos,
            maxDepth, outWord, outUnigramProbability);
    return BinaryFormat::getCodePointsAndProbabilityAndReturnCodePointCount(
            binaryDictionaryInfo->getDictRoot(), nodePos,
            maxCodePointCount, outCodePoints, outUnigramProbability);
}

int PatriciaTriePolicy::getTerminalNodePositionOfWord(
Loading