Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit df266ac7 authored by Keisuke Kuroynagi's avatar Keisuke Kuroynagi Committed by Android Git Automerger
Browse files

am b6f286bf: Merge "Make bigram dictionary and traverse session use structure policy."

* commit 'b6f286bf':
  Make bigram dictionary and traverse session use structure policy.
parents 1a16cdc6 b6f286bf
Loading
Loading
Loading
Loading
+6 −15
Original line number Diff line number Diff line
@@ -123,9 +123,10 @@ int BigramDictionary::getPredictions(const int *prevWord, int prevWordLength, in
    for (BinaryDictionaryBigramsIterator bigramsIt(mBinaryDictionaryInfo, pos);
            bigramsIt.hasNext(); /* no-op */) {
        bigramsIt.next();
        const int length = BinaryFormat::getWordAtAddress(
                mBinaryDictionaryInfo->getDictRoot(), bigramsIt.getBigramPos(),
                MAX_WORD_LENGTH, bigramBuffer, &unigramProbability);
        const int length = mBinaryDictionaryInfo->getStructurePolicy()->
                getCodePointsAndProbabilityAndReturnCodePointCount(
                        mBinaryDictionaryInfo, bigramsIt.getBigramPos(), MAX_WORD_LENGTH,
                        bigramBuffer, &unigramProbability);

        // inputSize == 0 means we are trying to find bigram predictions.
        if (inputSize < 1 || checkFirstCharacter(bigramBuffer, inputCodePoints)) {
@@ -153,18 +154,8 @@ int BigramDictionary::getBigramListPositionForWord(const int *prevWord, const in
    int pos = mBinaryDictionaryInfo->getStructurePolicy()->getTerminalNodePositionOfWord(
            mBinaryDictionaryInfo, prevWord, prevWordLength, forceLowerCaseSearch);
    if (NOT_VALID_WORD == pos) return 0;
    const uint8_t *const root = mBinaryDictionaryInfo->getDictRoot();
    const uint8_t flags = BinaryFormat::getFlagsAndForwardPointer(root, &pos);
    if (0 == (flags & BinaryFormat::FLAG_HAS_BIGRAMS)) return 0;
    if (0 == (flags & BinaryFormat::FLAG_HAS_MULTIPLE_CHARS)) {
        BinaryFormat::getCodePointAndForwardPointer(root, &pos);
    } else {
        pos = BinaryFormat::skipOtherCharacters(root, pos);
    }
    pos = BinaryFormat::skipProbability(flags, pos);
    pos = BinaryFormat::skipChildrenPosition(flags, pos);
    pos = BinaryFormat::skipShortcuts(root, flags, pos);
    return pos;
    return BinaryFormat::getBigramListPositionForWordPosition(
            mBinaryDictionaryInfo->getDictRoot(), pos);
}

bool BigramDictionary::checkFirstCharacter(int *word, int *inputCodePoints) const {
+15 −13
Original line number Diff line number Diff line
@@ -71,8 +71,9 @@ class BinaryFormat {
    static bool hasChildrenInFlags(const uint8_t flags);
    static int getTerminalPosition(const uint8_t *const root, const int *const inWord,
            const int length, const bool forceLowerCaseSearch);
    static int getWordAtAddress(const uint8_t *const root, const int address, const int maxDepth,
            int *outWord, int *outUnigramProbability);
    static int getCodePointsAndProbabilityAndReturnCodePointCount(
            const uint8_t *const root, const int nodePos, const int maxCodePointCount,
            int *outCodePoints, int *outUnigramProbability);
    static int getBigramListPositionForWordPosition(const uint8_t *const root, int position);

 private:
@@ -342,8 +343,9 @@ AK_FORCE_INLINE int BinaryFormat::getTerminalPosition(const uint8_t *const root,
 * outUnigramProbability: a pointer to an int to write the probability into.
 * Return value : the length of the word, of 0 if the word was not found.
 */
AK_FORCE_INLINE int BinaryFormat::getWordAtAddress(const uint8_t *const root, const int address,
        const int maxDepth, int *outWord, int *outUnigramProbability) {
AK_FORCE_INLINE int BinaryFormat::getCodePointsAndProbabilityAndReturnCodePointCount(
        const uint8_t *const root, const int nodePos,
        const int maxCodePointCount, int *outCodePoints, int *outUnigramProbability) {
    int pos = 0;
    int wordPos = 0;

@@ -353,7 +355,7 @@ AK_FORCE_INLINE int BinaryFormat::getWordAtAddress(const uint8_t *const root, co
    // The only reason we count nodes is because we want to reduce the probability of infinite
    // looping in case there is a bug. Since we know there is an upper bound to the depth we are
    // supposed to traverse, it does not hurt to count iterations.
    for (int loopCount = maxDepth; loopCount > 0; --loopCount) {
    for (int loopCount = maxCodePointCount; loopCount > 0; --loopCount) {
        int lastCandidateGroupPos = 0;
        // Let's loop through char groups in this node searching for either the terminal
        // or one of its ascendants.
@@ -362,17 +364,17 @@ AK_FORCE_INLINE int BinaryFormat::getWordAtAddress(const uint8_t *const root, co
            const int startPos = pos;
            const uint8_t flags = getFlagsAndForwardPointer(root, &pos);
            const int character = getCodePointAndForwardPointer(root, &pos);
            if (address == startPos) {
            if (nodePos == startPos) {
                // We found the address. Copy the rest of the word in the buffer and return
                // the length.
                outWord[wordPos] = character;
                outCodePoints[wordPos] = character;
                if (FLAG_HAS_MULTIPLE_CHARS & flags) {
                    int nextChar = getCodePointAndForwardPointer(root, &pos);
                    // We count chars in order to avoid infinite loops if the file is broken or
                    // if there is some other bug
                    int charCount = maxDepth;
                    int charCount = maxCodePointCount;
                    while (NOT_A_CODE_POINT != nextChar && --charCount > 0) {
                        outWord[++wordPos] = nextChar;
                        outCodePoints[++wordPos] = nextChar;
                        nextChar = getCodePointAndForwardPointer(root, &pos);
                    }
                }
@@ -399,7 +401,7 @@ AK_FORCE_INLINE int BinaryFormat::getWordAtAddress(const uint8_t *const root, co
            if (hasChildren) {
                // Here comes the tricky part. First, read the children position.
                const int childrenPos = readChildrenPosition(root, flags, pos);
                if (childrenPos > address) {
                if (childrenPos > nodePos) {
                    // If the children pos is greater than address, it means the previous chargroup,
                    // whose address is stored in lastCandidateGroupPos, was the right one.
                    found = true;
@@ -429,12 +431,12 @@ AK_FORCE_INLINE int BinaryFormat::getWordAtAddress(const uint8_t *const root, co
                    const int lastChar =
                            getCodePointAndForwardPointer(root, &lastCandidateGroupPos);
                    // We copy all the characters in this group to the buffer
                    outWord[wordPos] = lastChar;
                    outCodePoints[wordPos] = lastChar;
                    if (FLAG_HAS_MULTIPLE_CHARS & lastFlags) {
                        int nextChar = getCodePointAndForwardPointer(root, &lastCandidateGroupPos);
                        int charCount = maxDepth;
                        int charCount = maxCodePointCount;
                        while (-1 != nextChar && --charCount > 0) {
                            outWord[++wordPos] = nextChar;
                            outCodePoints[++wordPos] = nextChar;
                            nextChar = getCodePointAndForwardPointer(root, &lastCandidateGroupPos);
                        }
                    }
+3 −2
Original line number Diff line number Diff line
@@ -50,8 +50,9 @@ class DictionaryStructurePolicy {
            const BinaryDictionaryInfo *const binaryDictionaryInfo,
            const NodeFilter *const nodeFilter, DicNodeVector *const childDicNodes) const = 0;

    virtual void getWordAtPosition(const BinaryDictionaryInfo *const binaryDictionaryInfo,
            const int terminalNodePos, const int maxDepth, int *const outWord,
    virtual int getCodePointsAndProbabilityAndReturnCodePointCount(
            const BinaryDictionaryInfo *const binaryDictionaryInfo,
            const int nodePos, const int maxCodePointCount, int *const outCodePoints,
            int *const outUnigramProbability) const = 0;

    virtual int getTerminalNodePositionOfWord(
+7 −10
Original line number Diff line number Diff line
@@ -18,10 +18,8 @@

#include "defines.h"
#include "jni.h"
#include "suggest/core/dicnode/dic_node_utils.h"
#include "suggest/core/dictionary/binary_dictionary_header.h"
#include "suggest/core/dictionary/binary_dictionary_info.h"
#include "suggest/core/dictionary/binary_format.h"
#include "suggest/core/dictionary/dictionary.h"

namespace latinime {
@@ -29,23 +27,22 @@ namespace latinime {
void DicTraverseSession::init(const Dictionary *const dictionary, const int *prevWord,
        int prevWordLength, const SuggestOptions *const suggestOptions) {
    mDictionary = dictionary;
    mMultiWordCostMultiplier = mDictionary->getBinaryDictionaryInfo()
            ->getHeader()->getMultiWordCostMultiplier();
    const BinaryDictionaryInfo *const binaryDictionaryInfo =
            mDictionary->getBinaryDictionaryInfo();
    mMultiWordCostMultiplier = binaryDictionaryInfo->getHeader()->getMultiWordCostMultiplier();
    mSuggestOptions = suggestOptions;
    if (!prevWord) {
        mPrevWordPos = NOT_VALID_WORD;
        return;
    }
    // TODO: merge following similar calls to getTerminalPosition into one case-insensitive call.
    mPrevWordPos = BinaryFormat::getTerminalPosition(
            dictionary->getBinaryDictionaryInfo()->getDictRoot(), prevWord,
            prevWordLength, false /* forceLowerCaseSearch */);
    mPrevWordPos = binaryDictionaryInfo->getStructurePolicy()->getTerminalNodePositionOfWord(
            binaryDictionaryInfo, prevWord, prevWordLength, false /* forceLowerCaseSearch */);
    if (mPrevWordPos == NOT_VALID_WORD) {
        // Check bigrams for lower-cased previous word if original was not found. Useful for
        // auto-capitalized words like "The [current_word]".
        mPrevWordPos = BinaryFormat::getTerminalPosition(
                dictionary->getBinaryDictionaryInfo()->getDictRoot(), prevWord,
                prevWordLength, true /* forceLowerCaseSearch */);
        mPrevWordPos = binaryDictionaryInfo->getStructurePolicy()->getTerminalNodePositionOfWord(
                binaryDictionaryInfo, prevWord, prevWordLength, true /* forceLowerCaseSearch */);
    }
}

+6 −4
Original line number Diff line number Diff line
@@ -33,11 +33,13 @@ void PatriciaTriePolicy::createAndGetAllChildNodes(const DicNode *const dicNode,
    // TODO: Move children creating methods from DicNodeUtils.
}

void PatriciaTriePolicy::getWordAtPosition(const BinaryDictionaryInfo *const binaryDictionaryInfo,
        const int terminalNodePos, const int maxDepth, int *const outWord,
int PatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount(
        const BinaryDictionaryInfo *const binaryDictionaryInfo,
        const int nodePos, const int maxCodePointCount, int *const outCodePoints,
        int *const outUnigramProbability) const {
    BinaryFormat::getWordAtAddress(binaryDictionaryInfo->getDictRoot(), terminalNodePos,
            maxDepth, outWord, outUnigramProbability);
    return BinaryFormat::getCodePointsAndProbabilityAndReturnCodePointCount(
            binaryDictionaryInfo->getDictRoot(), nodePos,
            maxCodePointCount, outCodePoints, outUnigramProbability);
}

int PatriciaTriePolicy::getTerminalNodePositionOfWord(
Loading