Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit f87bb77a authored by Keisuke Kuroyanagi's avatar Keisuke Kuroyanagi
Browse files

Create .cpp file for NgramContext.

Bug: 14425059

Change-Id: Ie950878817b9c80cc9c970e1a84880c9b9ab228a
parent 47fc656c
Loading
Loading
Loading
Loading
+2 −1
Original line number Diff line number Diff line
@@ -40,6 +40,7 @@ LATIN_IME_CORE_SRC_FILES := \
        proximity_info_state_utils.cpp) \
    suggest/core/policy/weighting.cpp \
    suggest/core/session/dic_traverse_session.cpp \
    suggest/core/session/ngram_context.cpp \
    $(addprefix suggest/core/result/, \
        suggestion_results.cpp \
        suggestions_output_utils.cpp) \
+123 −0
Original line number Diff line number Diff line
/*
 * Copyright (C) 2014 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "suggest/core/session/ngram_context.h"

#include "suggest/core/policy/dictionary_structure_with_buffer_policy.h"
#include "utils/char_utils.h"

namespace latinime {

NgramContext::NgramContext() : mPrevWordCount(0) {}

NgramContext::NgramContext(const NgramContext &ngramContext)
        : mPrevWordCount(ngramContext.mPrevWordCount) {
    for (size_t i = 0; i < mPrevWordCount; ++i) {
        mPrevWordCodePointCount[i] = ngramContext.mPrevWordCodePointCount[i];
        memmove(mPrevWordCodePoints[i], ngramContext.mPrevWordCodePoints[i],
                sizeof(mPrevWordCodePoints[i][0]) * mPrevWordCodePointCount[i]);
        mIsBeginningOfSentence[i] = ngramContext.mIsBeginningOfSentence[i];
    }
}

NgramContext::NgramContext(const int prevWordCodePoints[][MAX_WORD_LENGTH],
        const int *const prevWordCodePointCount, const bool *const isBeginningOfSentence,
        const size_t prevWordCount)
        : mPrevWordCount(std::min(NELEMS(mPrevWordCodePoints), prevWordCount)) {
    clear();
    for (size_t i = 0; i < mPrevWordCount; ++i) {
        if (prevWordCodePointCount[i] < 0 || prevWordCodePointCount[i] > MAX_WORD_LENGTH) {
            continue;
        }
        memmove(mPrevWordCodePoints[i], prevWordCodePoints[i],
                sizeof(mPrevWordCodePoints[i][0]) * prevWordCodePointCount[i]);
        mPrevWordCodePointCount[i] = prevWordCodePointCount[i];
        mIsBeginningOfSentence[i] = isBeginningOfSentence[i];
    }
}

NgramContext::NgramContext(const int *const prevWordCodePoints, const int prevWordCodePointCount,
        const bool isBeginningOfSentence) : mPrevWordCount(1) {
    clear();
    if (prevWordCodePointCount > MAX_WORD_LENGTH || !prevWordCodePoints) {
        return;
    }
    memmove(mPrevWordCodePoints[0], prevWordCodePoints,
            sizeof(mPrevWordCodePoints[0][0]) * prevWordCodePointCount);
    mPrevWordCodePointCount[0] = prevWordCodePointCount;
    mIsBeginningOfSentence[0] = isBeginningOfSentence;
}

bool NgramContext::isValid() const {
    if (mPrevWordCodePointCount[0] > 0) {
        return true;
    }
    if (mIsBeginningOfSentence[0]) {
        return true;
    }
    return false;
}

const CodePointArrayView NgramContext::getNthPrevWordCodePoints(const size_t n) const {
    if (n <= 0 || n > mPrevWordCount) {
        return CodePointArrayView();
    }
    return CodePointArrayView(mPrevWordCodePoints[n - 1], mPrevWordCodePointCount[n - 1]);
}

bool NgramContext::isNthPrevWordBeginningOfSentence(const size_t n) const {
    if (n <= 0 || n > mPrevWordCount) {
        return false;
    }
    return mIsBeginningOfSentence[n - 1];
}

/* static */ int NgramContext::getWordId(
        const DictionaryStructureWithBufferPolicy *const dictStructurePolicy,
        const int *const wordCodePoints, const int wordCodePointCount,
        const bool isBeginningOfSentence, const bool tryLowerCaseSearch) {
    if (!dictStructurePolicy || !wordCodePoints || wordCodePointCount > MAX_WORD_LENGTH) {
        return NOT_A_WORD_ID;
    }
    int codePoints[MAX_WORD_LENGTH];
    int codePointCount = wordCodePointCount;
    memmove(codePoints, wordCodePoints, sizeof(int) * codePointCount);
    if (isBeginningOfSentence) {
        codePointCount = CharUtils::attachBeginningOfSentenceMarker(codePoints, codePointCount,
                MAX_WORD_LENGTH);
        if (codePointCount <= 0) {
            return NOT_A_WORD_ID;
        }
    }
    const CodePointArrayView codePointArrayView(codePoints, codePointCount);
    const int wordId = dictStructurePolicy->getWordId(codePointArrayView,
            false /* forceLowerCaseSearch */);
    if (wordId != NOT_A_WORD_ID || !tryLowerCaseSearch) {
        // Return the id when when the word was found or doesn't try lower case search.
        return wordId;
    }
    // Check bigrams for lower-cased previous word if original was not found. Useful for
    // auto-capitalized words like "The [current_word]".
    return dictStructurePolicy->getWordId(codePointArrayView, true /* forceLowerCaseSearch */);
}

void NgramContext::clear() {
    for (size_t i = 0; i < NELEMS(mPrevWordCodePoints); ++i) {
        mPrevWordCodePointCount[i] = 0;
        mIsBeginningOfSentence[i] = false;
    }
}
} // namespace latinime
+15 −106
Original line number Diff line number Diff line
@@ -20,145 +20,54 @@
#include <array>

#include "defines.h"
#include "suggest/core/policy/dictionary_structure_with_buffer_policy.h"
#include "utils/char_utils.h"
#include "utils/int_array_view.h"

namespace latinime {

// Rename to NgramContext.
class DictionaryStructureWithBufferPolicy;

class NgramContext {
 public:
    // No prev word information.
    NgramContext() : mPrevWordCount(0) {
        clear();
    }

    NgramContext(const NgramContext &ngramContext)
            : mPrevWordCount(ngramContext.mPrevWordCount) {
        for (size_t i = 0; i < mPrevWordCount; ++i) {
            mPrevWordCodePointCount[i] = ngramContext.mPrevWordCodePointCount[i];
            memmove(mPrevWordCodePoints[i], ngramContext.mPrevWordCodePoints[i],
                    sizeof(mPrevWordCodePoints[i][0]) * mPrevWordCodePointCount[i]);
            mIsBeginningOfSentence[i] = ngramContext.mIsBeginningOfSentence[i];
        }
    }

    NgramContext();
    // Copy constructor to use this class with std::vector and use this class as a return value.
    NgramContext(const NgramContext &ngramContext);
    // Construct from previous words.
    NgramContext(const int prevWordCodePoints[][MAX_WORD_LENGTH],
            const int *const prevWordCodePointCount, const bool *const isBeginningOfSentence,
            const size_t prevWordCount)
            : mPrevWordCount(std::min(NELEMS(mPrevWordCodePoints), prevWordCount)) {
        clear();
        for (size_t i = 0; i < mPrevWordCount; ++i) {
            if (prevWordCodePointCount[i] < 0 || prevWordCodePointCount[i] > MAX_WORD_LENGTH) {
                continue;
            }
            memmove(mPrevWordCodePoints[i], prevWordCodePoints[i],
                    sizeof(mPrevWordCodePoints[i][0]) * prevWordCodePointCount[i]);
            mPrevWordCodePointCount[i] = prevWordCodePointCount[i];
            mIsBeginningOfSentence[i] = isBeginningOfSentence[i];
        }
    }

            const size_t prevWordCount);
    // Construct from a previous word.
    NgramContext(const int *const prevWordCodePoints, const int prevWordCodePointCount,
            const bool isBeginningOfSentence) : mPrevWordCount(1) {
        clear();
        if (prevWordCodePointCount > MAX_WORD_LENGTH || !prevWordCodePoints) {
            return;
        }
        memmove(mPrevWordCodePoints[0], prevWordCodePoints,
                sizeof(mPrevWordCodePoints[0][0]) * prevWordCodePointCount);
        mPrevWordCodePointCount[0] = prevWordCodePointCount;
        mIsBeginningOfSentence[0] = isBeginningOfSentence;
    }
            const bool isBeginningOfSentence);

    size_t getPrevWordCount() const {
        return mPrevWordCount;
    }

    // TODO: Remove.
    const NgramContext getTrimmedNgramContext(const size_t maxPrevWordCount) const {
        return NgramContext(mPrevWordCodePoints, mPrevWordCodePointCount, mIsBeginningOfSentence,
                std::min(mPrevWordCount, maxPrevWordCount));
    }

    bool isValid() const {
        if (mPrevWordCodePointCount[0] > 0) {
            return true;
        }
        if (mIsBeginningOfSentence[0]) {
            return true;
        }
        return false;
    }
    bool isValid() const;

    template<size_t N>
    const WordIdArrayView getPrevWordIds(
            const DictionaryStructureWithBufferPolicy *const dictStructurePolicy,
            std::array<int, N> *const prevWordIdBuffer, const bool tryLowerCaseSearch) const {
            WordIdArray<N> *const prevWordIdBuffer, const bool tryLowerCaseSearch) const {
        for (size_t i = 0; i < std::min(mPrevWordCount, N); ++i) {
            prevWordIdBuffer->at(i) = getWordId(dictStructurePolicy,
                    mPrevWordCodePoints[i], mPrevWordCodePointCount[i],
                    mIsBeginningOfSentence[i], tryLowerCaseSearch);
            prevWordIdBuffer->at(i) = getWordId(dictStructurePolicy, mPrevWordCodePoints[i],
                    mPrevWordCodePointCount[i], mIsBeginningOfSentence[i], tryLowerCaseSearch);
        }
        return WordIdArrayView::fromArray(*prevWordIdBuffer).limit(mPrevWordCount);
    }

    // n is 1-indexed.
    const CodePointArrayView getNthPrevWordCodePoints(const size_t n) const {
        if (n <= 0 || n > mPrevWordCount) {
            return CodePointArrayView();
        }
        return CodePointArrayView(mPrevWordCodePoints[n - 1], mPrevWordCodePointCount[n - 1]);
    }

    const CodePointArrayView getNthPrevWordCodePoints(const size_t n) const;
    // n is 1-indexed.
    bool isNthPrevWordBeginningOfSentence(const size_t n) const {
        if (n <= 0 || n > mPrevWordCount) {
            return false;
        }
        return mIsBeginningOfSentence[n - 1];
    }
    bool isNthPrevWordBeginningOfSentence(const size_t n) const;

 private:
    DISALLOW_ASSIGNMENT_OPERATOR(NgramContext);

    static int getWordId(const DictionaryStructureWithBufferPolicy *const dictStructurePolicy,
            const int *const wordCodePoints, const int wordCodePointCount,
            const bool isBeginningOfSentence, const bool tryLowerCaseSearch) {
        if (!dictStructurePolicy || !wordCodePoints || wordCodePointCount > MAX_WORD_LENGTH) {
            return NOT_A_WORD_ID;
        }
        int codePoints[MAX_WORD_LENGTH];
        int codePointCount = wordCodePointCount;
        memmove(codePoints, wordCodePoints, sizeof(int) * codePointCount);
        if (isBeginningOfSentence) {
            codePointCount = CharUtils::attachBeginningOfSentenceMarker(codePoints,
                    codePointCount, MAX_WORD_LENGTH);
            if (codePointCount <= 0) {
                return NOT_A_WORD_ID;
            }
        }
        const CodePointArrayView codePointArrayView(codePoints, codePointCount);
        const int wordId = dictStructurePolicy->getWordId(
                codePointArrayView, false /* forceLowerCaseSearch */);
        if (wordId != NOT_A_WORD_ID || !tryLowerCaseSearch) {
            // Return the id when when the word was found or doesn't try lower case search.
            return wordId;
        }
        // Check bigrams for lower-cased previous word if original was not found. Useful for
        // auto-capitalized words like "The [current_word]".
        return dictStructurePolicy->getWordId(codePointArrayView, true /* forceLowerCaseSearch */);
    }

    void clear() {
        for (size_t i = 0; i < NELEMS(mPrevWordCodePoints); ++i) {
            mPrevWordCodePointCount[i] = 0;
            mIsBeginningOfSentence[i] = false;
        }
    }
            const bool isBeginningOfSentence, const bool tryLowerCaseSearch);
    void clear();

    const size_t mPrevWordCount;
    int mPrevWordCodePoints[MAX_PREV_WORD_COUNT_FOR_N_GRAM][MAX_WORD_LENGTH];