Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit bcb52d73 authored by Keisuke Kuroyanagi's avatar Keisuke Kuroyanagi
Browse files

Enable count based dynamic ngram language model for v403.

Bug: 14425059

Change-Id: Icc15e14cfd77d37cd75f75318fd0fa36f9ca7a5b
parent 660b0047
Loading
Loading
Loading
Loading
+3 −0
Original line number Diff line number Diff line
@@ -81,6 +81,9 @@ void Dictionary::NgramListenerForPrediction::onVisitEntry(const int ngramProbabi
    }
    const WordAttributes wordAttributes = mDictStructurePolicy->getWordAttributesInContext(
            mPrevWordIds, targetWordId, nullptr /* multiBigramMap */);
    if (wordAttributes.getProbability() == NOT_A_PROBABILITY) {
        return;
    }
    mSuggestionResults->addPrediction(targetWordCodePoints, codePointCount,
            wordAttributes.getProbability());
}
+2 −0
Original line number Diff line number Diff line
@@ -26,6 +26,8 @@ namespace latinime {
 */
class NgramListener {
 public:
    // ngramProbability is always 0 for v403 decaying dictionary.
    // TODO: Remove ngramProbability.
    virtual void onVisitEntry(const int ngramProbability, const int targetWordId) = 0;
    virtual ~NgramListener() {};

+73 −54
Original line number Diff line number Diff line
@@ -19,11 +19,11 @@
#include <algorithm>
#include <cstring>

#include "suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h"
#include "suggest/policyimpl/dictionary/structure/v4/content/dynamic_language_model_probability_utils.h"
#include "suggest/policyimpl/dictionary/utils/probability_utils.h"

namespace latinime {

const int LanguageModelDictContent::DUMMY_PROBABILITY_FOR_VALID_WORDS = 1;
const int LanguageModelDictContent::TRIE_MAP_BUFFER_INDEX = 0;
const int LanguageModelDictContent::GLOBAL_COUNTERS_BUFFER_INDEX = 1;

@@ -39,7 +39,8 @@ bool LanguageModelDictContent::runGC(
}

const WordAttributes LanguageModelDictContent::getWordAttributes(const WordIdArrayView prevWordIds,
        const int wordId, const HeaderPolicy *const headerPolicy) const {
        const int wordId, const bool mustMatchAllPrevWords,
        const HeaderPolicy *const headerPolicy) const {
    int bitmapEntryIndices[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1];
    bitmapEntryIndices[0] = mTrieMap.getRootBitmapEntryIndex();
    int maxPrevWordCount = 0;
@@ -53,7 +54,15 @@ const WordAttributes LanguageModelDictContent::getWordAttributes(const WordIdArr
        bitmapEntryIndices[i + 1] = nextBitmapEntryIndex;
    }

    const ProbabilityEntry unigramProbabilityEntry = getProbabilityEntry(wordId);
    if (mHasHistoricalInfo && unigramProbabilityEntry.getHistoricalInfo()->getCount() == 0) {
        // The word should be treated as an invalid word.
        return WordAttributes();
    }
    for (int i = maxPrevWordCount; i >= 0; --i) {
        if (mustMatchAllPrevWords && prevWordIds.size() > static_cast<size_t>(i)) {
            break;
        }
        const TrieMap::Result result = mTrieMap.get(wordId, bitmapEntryIndices[i]);
        if (!result.mIsValid) {
            continue;
@@ -62,36 +71,39 @@ const WordAttributes LanguageModelDictContent::getWordAttributes(const WordIdArr
                ProbabilityEntry::decode(result.mValue, mHasHistoricalInfo);
        int probability = NOT_A_PROBABILITY;
        if (mHasHistoricalInfo) {
            const int rawProbability = ForgettingCurveUtils::decodeProbability(
                    probabilityEntry.getHistoricalInfo(), headerPolicy);
            if (rawProbability == NOT_A_PROBABILITY) {
                // The entry should not be treated as a valid entry.
                continue;
            }
            const HistoricalInfo *const historicalInfo = probabilityEntry.getHistoricalInfo();
            int contextCount = 0;
            if (i == 0) {
                // unigram
                probability = rawProbability;
                contextCount = mGlobalCounters.getTotalCount();
            } else {
                const ProbabilityEntry prevWordProbabilityEntry = getNgramProbabilityEntry(
                        prevWordIds.skip(1 /* n */).limit(i - 1), prevWordIds[0]);
                if (!prevWordProbabilityEntry.isValid()) {
                    continue;
                }
                if (prevWordProbabilityEntry.representsBeginningOfSentence()) {
                    probability = rawProbability;
                } else {
                    const int prevWordRawProbability = ForgettingCurveUtils::decodeProbability(
                            prevWordProbabilityEntry.getHistoricalInfo(), headerPolicy);
                    probability = std::min(MAX_PROBABILITY - prevWordRawProbability
                            + rawProbability, MAX_PROBABILITY);
                }
                if (prevWordProbabilityEntry.representsBeginningOfSentence()
                        && historicalInfo->getCount() == 1) {
                    // BoS ngram requires multiple contextCount.
                    continue;
                }
                contextCount = prevWordProbabilityEntry.getHistoricalInfo()->getCount();
            }
            const float rawProbability =
                    DynamicLanguageModelProbabilityUtils::computeRawProbabilityFromCounts(
                            historicalInfo->getCount(), contextCount, i + 1);
            const int encodedRawProbability =
                    ProbabilityUtils::encodeRawProbability(rawProbability);
            const int decayedProbability =
                    DynamicLanguageModelProbabilityUtils::getDecayedProbability(
                            encodedRawProbability, *historicalInfo);
            probability = DynamicLanguageModelProbabilityUtils::backoff(
                    decayedProbability, i + 1 /* n */);
        } else {
            probability = probabilityEntry.getProbability();
        }
        // TODO: Some flags in unigramProbabilityEntry should be overwritten by flags in
        // probabilityEntry.
        const ProbabilityEntry unigramProbabilityEntry = getProbabilityEntry(wordId);
        return WordAttributes(probability, unigramProbabilityEntry.isBlacklisted(),
                unigramProbabilityEntry.isNotAWord(),
                unigramProbabilityEntry.isPossiblyOffensive());
@@ -167,7 +179,8 @@ void LanguageModelDictContent::exportAllNgramEntriesRelatedToWordInner(
                ProbabilityEntry::decode(entry.value(), mHasHistoricalInfo);
        if (probabilityEntry.isValid()) {
            const WordAttributes wordAttributes = getWordAttributes(
                    WordIdArrayView(*prevWordIds), wordId, headerPolicy);
                    WordIdArrayView(*prevWordIds), wordId, true /* mustMatchAllPrevWords */,
                    headerPolicy);
            outBummpedFullEntryInfo->emplace_back(*prevWordIds, wordId,
                    wordAttributes, probabilityEntry);
        }
@@ -231,7 +244,7 @@ bool LanguageModelDictContent::updateAllEntriesOnInputWord(const WordIdArrayView
            return false;
        }
        mGlobalCounters.updateMaxValueOfCounters(
                updatedUnigramProbabilityEntry.getHistoricalInfo()->getCount());
                updatedNgramProbabilityEntry.getHistoricalInfo()->getCount());
        if (!originalNgramProbabilityEntry.isValid()) {
            entryCountersToUpdate->incrementNgramCount(i + 2);
        }
@@ -242,10 +255,9 @@ bool LanguageModelDictContent::updateAllEntriesOnInputWord(const WordIdArrayView
const ProbabilityEntry LanguageModelDictContent::createUpdatedEntryFrom(
        const ProbabilityEntry &originalProbabilityEntry, const bool isValid,
        const HistoricalInfo historicalInfo, const HeaderPolicy *const headerPolicy) const {
    const HistoricalInfo updatedHistoricalInfo = ForgettingCurveUtils::createUpdatedHistoricalInfo(
            originalProbabilityEntry.getHistoricalInfo(), isValid ?
                    DUMMY_PROBABILITY_FOR_VALID_WORDS : NOT_A_PROBABILITY,
            &historicalInfo, headerPolicy);
    const HistoricalInfo updatedHistoricalInfo = HistoricalInfo(historicalInfo.getTimestamp(),
            0 /* level */, originalProbabilityEntry.getHistoricalInfo()->getCount()
                    + historicalInfo.getCount());
    if (originalProbabilityEntry.isValid()) {
        return ProbabilityEntry(originalProbabilityEntry.getFlags(), &updatedHistoricalInfo);
    } else {
@@ -311,7 +323,7 @@ int LanguageModelDictContent::getBitmapEntryIndex(const WordIdArrayView prevWord

bool LanguageModelDictContent::updateAllProbabilityEntriesForGCInner(const int bitmapEntryIndex,
        const int prevWordCount, const HeaderPolicy *const headerPolicy,
        MutableEntryCounters *const outEntryCounters) {
        const bool needsToHalveCounters, MutableEntryCounters *const outEntryCounters) {
    for (const auto &entry : mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex)) {
        if (prevWordCount > MAX_PREV_WORD_COUNT_FOR_N_GRAM) {
            AKLOGE("Invalid prevWordCount. prevWordCount: %d, MAX_PREV_WORD_COUNT_FOR_N_GRAM: %d.",
@@ -328,33 +340,41 @@ bool LanguageModelDictContent::updateAllProbabilityEntriesForGCInner(const int b
            }
            continue;
        }
        if (mHasHistoricalInfo && !probabilityEntry.representsBeginningOfSentence()
                && probabilityEntry.isValid()) {
            const HistoricalInfo historicalInfo = ForgettingCurveUtils::createHistoricalInfoToSave(
                    probabilityEntry.getHistoricalInfo(), headerPolicy);
            if (ForgettingCurveUtils::needsToKeep(&historicalInfo, headerPolicy)) {
                // Update the entry.
                const ProbabilityEntry updatedEntry(probabilityEntry.getFlags(), &historicalInfo);
                if (!mTrieMap.put(entry.key(), updatedEntry.encode(mHasHistoricalInfo),
                        bitmapEntryIndex)) {
        if (mHasHistoricalInfo && probabilityEntry.isValid()) {
            const HistoricalInfo *originalHistoricalInfo = probabilityEntry.getHistoricalInfo();
            if (DynamicLanguageModelProbabilityUtils::shouldRemoveEntryDuringGC(
                    *originalHistoricalInfo)) {
                // Remove the entry.
                if (!mTrieMap.remove(entry.key(), bitmapEntryIndex)) {
                    return false;
                }
            } else {
                continue;
            }
            if (needsToHalveCounters) {
                const int updatedCount = originalHistoricalInfo->getCount() / 2;
                if (updatedCount == 0) {
                    // Remove the entry.
                    if (!mTrieMap.remove(entry.key(), bitmapEntryIndex)) {
                        return false;
                    }
                    continue;
                }
                const HistoricalInfo historicalInfoToSave(originalHistoricalInfo->getTimestamp(),
                        originalHistoricalInfo->getLevel(), updatedCount);
                const ProbabilityEntry updatedEntry(probabilityEntry.getFlags(),
                        &historicalInfoToSave);
                if (!mTrieMap.put(entry.key(), updatedEntry.encode(mHasHistoricalInfo),
                        bitmapEntryIndex)) {
                    return false;
                }
            }
        if (!probabilityEntry.representsBeginningOfSentence()) {
            outEntryCounters->incrementNgramCount(prevWordCount + 1);
        }
        outEntryCounters->incrementNgramCount(prevWordCount + 1);
        if (!entry.hasNextLevelMap()) {
            continue;
        }
        if (!updateAllProbabilityEntriesForGCInner(entry.getNextLevelBitmapEntryIndex(),
                prevWordCount + 1, headerPolicy, outEntryCounters)) {
                prevWordCount + 1, headerPolicy, needsToHalveCounters, outEntryCounters)) {
            return false;
        }
    }
@@ -408,11 +428,11 @@ bool LanguageModelDictContent::getEntryInfo(const HeaderPolicy *const headerPoli
        }
        const ProbabilityEntry probabilityEntry =
                ProbabilityEntry::decode(entry.value(), mHasHistoricalInfo);
        const int probability = (mHasHistoricalInfo) ?
                ForgettingCurveUtils::decodeProbability(probabilityEntry.getHistoricalInfo(),
                        headerPolicy) : probabilityEntry.getProbability();
        outEntryInfo->emplace_back(probability,
                probabilityEntry.getHistoricalInfo()->getTimestamp(),
        const int priority = mHasHistoricalInfo
                ? DynamicLanguageModelProbabilityUtils::getPriorityToPreventFromEviction(
                        *probabilityEntry.getHistoricalInfo())
                : probabilityEntry.getProbability();
        outEntryInfo->emplace_back(priority, probabilityEntry.getHistoricalInfo()->getCount(),
                entry.key(), targetLevel, prevWordIds->data());
    }
    return true;
@@ -420,11 +440,11 @@ bool LanguageModelDictContent::getEntryInfo(const HeaderPolicy *const headerPoli

bool LanguageModelDictContent::EntryInfoToTurncate::Comparator::operator()(
        const EntryInfoToTurncate &left, const EntryInfoToTurncate &right) const {
    if (left.mProbability != right.mProbability) {
        return left.mProbability < right.mProbability;
    if (left.mPriority != right.mPriority) {
        return left.mPriority < right.mPriority;
    }
    if (left.mTimestamp != right.mTimestamp) {
        return left.mTimestamp > right.mTimestamp;
    if (left.mCount != right.mCount) {
        return left.mCount < right.mCount;
    }
    if (left.mKey != right.mKey) {
        return left.mKey < right.mKey;
@@ -441,10 +461,9 @@ bool LanguageModelDictContent::EntryInfoToTurncate::Comparator::operator()(
    return false;
}

LanguageModelDictContent::EntryInfoToTurncate::EntryInfoToTurncate(const int probability,
        const int timestamp, const int key, const int prevWordCount, const int *const prevWordIds)
        : mProbability(probability), mTimestamp(timestamp), mKey(key),
          mPrevWordCount(prevWordCount) {
LanguageModelDictContent::EntryInfoToTurncate::EntryInfoToTurncate(const int priority,
        const int count, const int key, const int prevWordCount, const int *const prevWordIds)
        : mPriority(priority), mCount(count), mKey(key), mPrevWordCount(prevWordCount) {
    memmove(mPrevWordIds, prevWordIds, mPrevWordCount * sizeof(mPrevWordIds[0]));
}

+17 −9
Original line number Diff line number Diff line
@@ -151,13 +151,14 @@ class LanguageModelDictContent {
            const LanguageModelDictContent *const originalContent);

    const WordAttributes getWordAttributes(const WordIdArrayView prevWordIds, const int wordId,
            const HeaderPolicy *const headerPolicy) const;
            const bool mustMatchAllPrevWords, const HeaderPolicy *const headerPolicy) const;

    ProbabilityEntry getProbabilityEntry(const int wordId) const {
        return getNgramProbabilityEntry(WordIdArrayView(), wordId);
    }

    bool setProbabilityEntry(const int wordId, const ProbabilityEntry *const probabilityEntry) {
        mGlobalCounters.addToTotalCount(probabilityEntry->getHistoricalInfo()->getCount());
        return setNgramProbabilityEntry(WordIdArrayView(), wordId, probabilityEntry);
    }

@@ -180,8 +181,15 @@ class LanguageModelDictContent {

    bool updateAllProbabilityEntriesForGC(const HeaderPolicy *const headerPolicy,
            MutableEntryCounters *const outEntryCounters) {
        return updateAllProbabilityEntriesForGCInner(mTrieMap.getRootBitmapEntryIndex(),
                0 /* prevWordCount */, headerPolicy, outEntryCounters);
        if (!updateAllProbabilityEntriesForGCInner(mTrieMap.getRootBitmapEntryIndex(),
                0 /* prevWordCount */, headerPolicy, mGlobalCounters.needsToHalveCounters(),
                outEntryCounters)) {
            return false;
        }
        if (mGlobalCounters.needsToHalveCounters()) {
            mGlobalCounters.halveCounters();
        }
        return true;
    }

    // entryCounts should be created by updateAllProbabilityEntries.
@@ -206,11 +214,12 @@ class LanguageModelDictContent {
            DISALLOW_ASSIGNMENT_OPERATOR(Comparator);
        };

        EntryInfoToTurncate(const int probability, const int timestamp, const int key,
        EntryInfoToTurncate(const int priority, const int count, const int key,
                const int prevWordCount, const int *const prevWordIds);

        int mProbability;
        int mTimestamp;
        int mPriority;
        // TODO: Remove.
        int mCount;
        int mKey;
        int mPrevWordCount;
        int mPrevWordIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1];
@@ -219,8 +228,6 @@ class LanguageModelDictContent {
        DISALLOW_DEFAULT_CONSTRUCTOR(EntryInfoToTurncate);
    };

    // TODO: Remove
    static const int DUMMY_PROBABILITY_FOR_VALID_WORDS;
    static const int TRIE_MAP_BUFFER_INDEX;
    static const int GLOBAL_COUNTERS_BUFFER_INDEX;

@@ -233,7 +240,8 @@ class LanguageModelDictContent {
    int createAndGetBitmapEntryIndex(const WordIdArrayView prevWordIds);
    int getBitmapEntryIndex(const WordIdArrayView prevWordIds) const;
    bool updateAllProbabilityEntriesForGCInner(const int bitmapEntryIndex, const int prevWordCount,
            const HeaderPolicy *const headerPolicy, MutableEntryCounters *const outEntryCounters);
            const HeaderPolicy *const headerPolicy, const bool needsToHalveCounters,
            MutableEntryCounters *const outEntryCounters);
    bool turncateEntriesInSpecifiedLevel(const HeaderPolicy *const headerPolicy,
            const int maxEntryCount, const int targetLevel, int *const outEntryCount);
    bool getEntryInfo(const HeaderPolicy *const headerPolicy, const int targetLevel,
+4 −0
Original line number Diff line number Diff line
@@ -63,6 +63,10 @@ class LanguageModelDictContentGlobalCounters {
        mTotalCount += 1;
    }

    void addToTotalCount(const int count) {
        mTotalCount += count;
    }

    void updateMaxValueOfCounters(const int count) {
        mMaxValueOfCounters = std::max(count, mMaxValueOfCounters);
    }
Loading