Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit c096100b authored by Keisuke Kuroyanagi, committed by Android (Google) Code Review
Browse files

Merge "Enable count based dynamic ngram language model for v403."

parents 04a492cb bcb52d73
Loading
Loading
Loading
Loading
+3 −0
Original line number Diff line number Diff line
@@ -81,6 +81,9 @@ void Dictionary::NgramListenerForPrediction::onVisitEntry(const int ngramProbabi
    }
    const WordAttributes wordAttributes = mDictStructurePolicy->getWordAttributesInContext(
            mPrevWordIds, targetWordId, nullptr /* multiBigramMap */);
    if (wordAttributes.getProbability() == NOT_A_PROBABILITY) {
        return;
    }
    mSuggestionResults->addPrediction(targetWordCodePoints, codePointCount,
            wordAttributes.getProbability());
}
+2 −0
Original line number Diff line number Diff line
@@ -26,6 +26,8 @@ namespace latinime {
 */
class NgramListener {
 public:
    // ngramProbability is always 0 for v403 decaying dictionary.
    // TODO: Remove ngramProbability.
    virtual void onVisitEntry(const int ngramProbability, const int targetWordId) = 0;
    virtual ~NgramListener() {};

+73 −54
Original line number Diff line number Diff line
@@ -19,11 +19,11 @@
#include <algorithm>
#include <cstring>

#include "suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h"
#include "suggest/policyimpl/dictionary/structure/v4/content/dynamic_language_model_probability_utils.h"
#include "suggest/policyimpl/dictionary/utils/probability_utils.h"

namespace latinime {

const int LanguageModelDictContent::DUMMY_PROBABILITY_FOR_VALID_WORDS = 1;
const int LanguageModelDictContent::TRIE_MAP_BUFFER_INDEX = 0;
const int LanguageModelDictContent::GLOBAL_COUNTERS_BUFFER_INDEX = 1;

@@ -39,7 +39,8 @@ bool LanguageModelDictContent::runGC(
}

const WordAttributes LanguageModelDictContent::getWordAttributes(const WordIdArrayView prevWordIds,
        const int wordId, const HeaderPolicy *const headerPolicy) const {
        const int wordId, const bool mustMatchAllPrevWords,
        const HeaderPolicy *const headerPolicy) const {
    int bitmapEntryIndices[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1];
    bitmapEntryIndices[0] = mTrieMap.getRootBitmapEntryIndex();
    int maxPrevWordCount = 0;
@@ -53,7 +54,15 @@ const WordAttributes LanguageModelDictContent::getWordAttributes(const WordIdArr
        bitmapEntryIndices[i + 1] = nextBitmapEntryIndex;
    }

    const ProbabilityEntry unigramProbabilityEntry = getProbabilityEntry(wordId);
    if (mHasHistoricalInfo && unigramProbabilityEntry.getHistoricalInfo()->getCount() == 0) {
        // The word should be treated as a invalid word.
        return WordAttributes();
    }
    for (int i = maxPrevWordCount; i >= 0; --i) {
        if (mustMatchAllPrevWords && prevWordIds.size() > static_cast<size_t>(i)) {
            break;
        }
        const TrieMap::Result result = mTrieMap.get(wordId, bitmapEntryIndices[i]);
        if (!result.mIsValid) {
            continue;
@@ -62,36 +71,39 @@ const WordAttributes LanguageModelDictContent::getWordAttributes(const WordIdArr
                ProbabilityEntry::decode(result.mValue, mHasHistoricalInfo);
        int probability = NOT_A_PROBABILITY;
        if (mHasHistoricalInfo) {
            const int rawProbability = ForgettingCurveUtils::decodeProbability(
                    probabilityEntry.getHistoricalInfo(), headerPolicy);
            if (rawProbability == NOT_A_PROBABILITY) {
                // The entry should not be treated as a valid entry.
                continue;
            }
            const HistoricalInfo *const historicalInfo = probabilityEntry.getHistoricalInfo();
            int contextCount = 0;
            if (i == 0) {
                // unigram
                probability = rawProbability;
                contextCount = mGlobalCounters.getTotalCount();
            } else {
                const ProbabilityEntry prevWordProbabilityEntry = getNgramProbabilityEntry(
                        prevWordIds.skip(1 /* n */).limit(i - 1), prevWordIds[0]);
                if (!prevWordProbabilityEntry.isValid()) {
                    continue;
                }
                if (prevWordProbabilityEntry.representsBeginningOfSentence()) {
                    probability = rawProbability;
                } else {
                    const int prevWordRawProbability = ForgettingCurveUtils::decodeProbability(
                            prevWordProbabilityEntry.getHistoricalInfo(), headerPolicy);
                    probability = std::min(MAX_PROBABILITY - prevWordRawProbability
                            + rawProbability, MAX_PROBABILITY);
                }
                if (prevWordProbabilityEntry.representsBeginningOfSentence()
                        && historicalInfo->getCount() == 1) {
                    // BoS ngram requires multiple contextCount.
                    continue;
                }
                contextCount = prevWordProbabilityEntry.getHistoricalInfo()->getCount();
            }
            const float rawProbability =
                    DynamicLanguageModelProbabilityUtils::computeRawProbabilityFromCounts(
                            historicalInfo->getCount(), contextCount, i + 1);
            const int encodedRawProbability =
                    ProbabilityUtils::encodeRawProbability(rawProbability);
            const int decayedProbability =
                    DynamicLanguageModelProbabilityUtils::getDecayedProbability(
                            encodedRawProbability, *historicalInfo);
            probability = DynamicLanguageModelProbabilityUtils::backoff(
                    decayedProbability, i + 1 /* n */);
        } else {
            probability = probabilityEntry.getProbability();
        }
        // TODO: Some flags in unigramProbabilityEntry should be overwritten by flags in
        // probabilityEntry.
        const ProbabilityEntry unigramProbabilityEntry = getProbabilityEntry(wordId);
        return WordAttributes(probability, unigramProbabilityEntry.isBlacklisted(),
                unigramProbabilityEntry.isNotAWord(),
                unigramProbabilityEntry.isPossiblyOffensive());
@@ -167,7 +179,8 @@ void LanguageModelDictContent::exportAllNgramEntriesRelatedToWordInner(
                ProbabilityEntry::decode(entry.value(), mHasHistoricalInfo);
        if (probabilityEntry.isValid()) {
            const WordAttributes wordAttributes = getWordAttributes(
                    WordIdArrayView(*prevWordIds), wordId, headerPolicy);
                    WordIdArrayView(*prevWordIds), wordId, true /* mustMatchAllPrevWords */,
                    headerPolicy);
            outBummpedFullEntryInfo->emplace_back(*prevWordIds, wordId,
                    wordAttributes, probabilityEntry);
        }
@@ -231,7 +244,7 @@ bool LanguageModelDictContent::updateAllEntriesOnInputWord(const WordIdArrayView
            return false;
        }
        mGlobalCounters.updateMaxValueOfCounters(
                updatedUnigramProbabilityEntry.getHistoricalInfo()->getCount());
                updatedNgramProbabilityEntry.getHistoricalInfo()->getCount());
        if (!originalNgramProbabilityEntry.isValid()) {
            entryCountersToUpdate->incrementNgramCount(i + 2);
        }
@@ -242,10 +255,9 @@ bool LanguageModelDictContent::updateAllEntriesOnInputWord(const WordIdArrayView
const ProbabilityEntry LanguageModelDictContent::createUpdatedEntryFrom(
        const ProbabilityEntry &originalProbabilityEntry, const bool isValid,
        const HistoricalInfo historicalInfo, const HeaderPolicy *const headerPolicy) const {
    const HistoricalInfo updatedHistoricalInfo = ForgettingCurveUtils::createUpdatedHistoricalInfo(
            originalProbabilityEntry.getHistoricalInfo(), isValid ?
                    DUMMY_PROBABILITY_FOR_VALID_WORDS : NOT_A_PROBABILITY,
            &historicalInfo, headerPolicy);
    const HistoricalInfo updatedHistoricalInfo = HistoricalInfo(historicalInfo.getTimestamp(),
            0 /* level */, originalProbabilityEntry.getHistoricalInfo()->getCount()
                    + historicalInfo.getCount());
    if (originalProbabilityEntry.isValid()) {
        return ProbabilityEntry(originalProbabilityEntry.getFlags(), &updatedHistoricalInfo);
    } else {
@@ -311,7 +323,7 @@ int LanguageModelDictContent::getBitmapEntryIndex(const WordIdArrayView prevWord

bool LanguageModelDictContent::updateAllProbabilityEntriesForGCInner(const int bitmapEntryIndex,
        const int prevWordCount, const HeaderPolicy *const headerPolicy,
        MutableEntryCounters *const outEntryCounters) {
        const bool needsToHalveCounters, MutableEntryCounters *const outEntryCounters) {
    for (const auto &entry : mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex)) {
        if (prevWordCount > MAX_PREV_WORD_COUNT_FOR_N_GRAM) {
            AKLOGE("Invalid prevWordCount. prevWordCount: %d, MAX_PREV_WORD_COUNT_FOR_N_GRAM: %d.",
@@ -328,33 +340,41 @@ bool LanguageModelDictContent::updateAllProbabilityEntriesForGCInner(const int b
            }
            continue;
        }
        if (mHasHistoricalInfo && !probabilityEntry.representsBeginningOfSentence()
                && probabilityEntry.isValid()) {
            const HistoricalInfo historicalInfo = ForgettingCurveUtils::createHistoricalInfoToSave(
                    probabilityEntry.getHistoricalInfo(), headerPolicy);
            if (ForgettingCurveUtils::needsToKeep(&historicalInfo, headerPolicy)) {
                // Update the entry.
                const ProbabilityEntry updatedEntry(probabilityEntry.getFlags(), &historicalInfo);
                if (!mTrieMap.put(entry.key(), updatedEntry.encode(mHasHistoricalInfo),
                        bitmapEntryIndex)) {
        if (mHasHistoricalInfo && probabilityEntry.isValid()) {
            const HistoricalInfo *originalHistoricalInfo = probabilityEntry.getHistoricalInfo();
            if (DynamicLanguageModelProbabilityUtils::shouldRemoveEntryDuringGC(
                    *originalHistoricalInfo)) {
                // Remove the entry.
                if (!mTrieMap.remove(entry.key(), bitmapEntryIndex)) {
                    return false;
                }
            } else {
                continue;
            }
            if (needsToHalveCounters) {
                const int updatedCount = originalHistoricalInfo->getCount() / 2;
                if (updatedCount == 0) {
                    // Remove the entry.
                    if (!mTrieMap.remove(entry.key(), bitmapEntryIndex)) {
                        return false;
                    }
                    continue;
                }
                const HistoricalInfo historicalInfoToSave(originalHistoricalInfo->getTimestamp(),
                        originalHistoricalInfo->getLevel(), updatedCount);
                const ProbabilityEntry updatedEntry(probabilityEntry.getFlags(),
                        &historicalInfoToSave);
                if (!mTrieMap.put(entry.key(), updatedEntry.encode(mHasHistoricalInfo),
                        bitmapEntryIndex)) {
                    return false;
                }
            }
        if (!probabilityEntry.representsBeginningOfSentence()) {
            outEntryCounters->incrementNgramCount(prevWordCount + 1);
        }
        outEntryCounters->incrementNgramCount(prevWordCount + 1);
        if (!entry.hasNextLevelMap()) {
            continue;
        }
        if (!updateAllProbabilityEntriesForGCInner(entry.getNextLevelBitmapEntryIndex(),
                prevWordCount + 1, headerPolicy, outEntryCounters)) {
                prevWordCount + 1, headerPolicy, needsToHalveCounters, outEntryCounters)) {
            return false;
        }
    }
@@ -408,11 +428,11 @@ bool LanguageModelDictContent::getEntryInfo(const HeaderPolicy *const headerPoli
        }
        const ProbabilityEntry probabilityEntry =
                ProbabilityEntry::decode(entry.value(), mHasHistoricalInfo);
        const int probability = (mHasHistoricalInfo) ?
                ForgettingCurveUtils::decodeProbability(probabilityEntry.getHistoricalInfo(),
                        headerPolicy) : probabilityEntry.getProbability();
        outEntryInfo->emplace_back(probability,
                probabilityEntry.getHistoricalInfo()->getTimestamp(),
        const int priority = mHasHistoricalInfo
                ? DynamicLanguageModelProbabilityUtils::getPriorityToPreventFromEviction(
                        *probabilityEntry.getHistoricalInfo())
                : probabilityEntry.getProbability();
        outEntryInfo->emplace_back(priority, probabilityEntry.getHistoricalInfo()->getCount(),
                entry.key(), targetLevel, prevWordIds->data());
    }
    return true;
@@ -420,11 +440,11 @@ bool LanguageModelDictContent::getEntryInfo(const HeaderPolicy *const headerPoli

bool LanguageModelDictContent::EntryInfoToTurncate::Comparator::operator()(
        const EntryInfoToTurncate &left, const EntryInfoToTurncate &right) const {
    if (left.mProbability != right.mProbability) {
        return left.mProbability < right.mProbability;
    if (left.mPriority != right.mPriority) {
        return left.mPriority < right.mPriority;
    }
    if (left.mTimestamp != right.mTimestamp) {
        return left.mTimestamp > right.mTimestamp;
    if (left.mCount != right.mCount) {
        return left.mCount < right.mCount;
    }
    if (left.mKey != right.mKey) {
        return left.mKey < right.mKey;
@@ -441,10 +461,9 @@ bool LanguageModelDictContent::EntryInfoToTurncate::Comparator::operator()(
    return false;
}

LanguageModelDictContent::EntryInfoToTurncate::EntryInfoToTurncate(const int probability,
        const int timestamp, const int key, const int prevWordCount, const int *const prevWordIds)
        : mProbability(probability), mTimestamp(timestamp), mKey(key),
          mPrevWordCount(prevWordCount) {
// Builds one candidate record used when truncating entries to a maximum count
// ("Turncate" is a preexisting spelling kept for API compatibility).
// Stores the eviction priority, the usage count, the trie-map key and the
// previous-word context, copying prevWordCount ids into the fixed-size
// member array mPrevWordIds.
LanguageModelDictContent::EntryInfoToTurncate::EntryInfoToTurncate(const int priority,
        const int count, const int key, const int prevWordCount, const int *const prevWordIds)
        : mPriority(priority), mCount(count), mKey(key), mPrevWordCount(prevWordCount) {
    // Copy exactly mPrevWordCount ids; memmove is safe even for overlapping ranges.
    memmove(mPrevWordIds, prevWordIds, mPrevWordCount * sizeof(mPrevWordIds[0]));
}

+17 −9
Original line number Diff line number Diff line
@@ -151,13 +151,14 @@ class LanguageModelDictContent {
            const LanguageModelDictContent *const originalContent);

    const WordAttributes getWordAttributes(const WordIdArrayView prevWordIds, const int wordId,
            const HeaderPolicy *const headerPolicy) const;
            const bool mustMatchAllPrevWords, const HeaderPolicy *const headerPolicy) const;

    // Returns the unigram probability entry for wordId by querying the ngram
    // table with an empty previous-word context.
    ProbabilityEntry getProbabilityEntry(const int wordId) const {
        return getNgramProbabilityEntry(WordIdArrayView(), wordId);
    }

    // Stores a unigram probability entry for wordId (empty previous-word
    // context). Also folds the entry's historical count into the global total
    // counter, which the count-based probability computation reads.
    // Returns false when the underlying ngram entry could not be written.
    bool setProbabilityEntry(const int wordId, const ProbabilityEntry *const probabilityEntry) {
        mGlobalCounters.addToTotalCount(probabilityEntry->getHistoricalInfo()->getCount());
        return setNgramProbabilityEntry(WordIdArrayView(), wordId, probabilityEntry);
    }

@@ -180,8 +181,15 @@ class LanguageModelDictContent {

    bool updateAllProbabilityEntriesForGC(const HeaderPolicy *const headerPolicy,
            MutableEntryCounters *const outEntryCounters) {
        return updateAllProbabilityEntriesForGCInner(mTrieMap.getRootBitmapEntryIndex(),
                0 /* prevWordCount */, headerPolicy, outEntryCounters);
        if (!updateAllProbabilityEntriesForGCInner(mTrieMap.getRootBitmapEntryIndex(),
                0 /* prevWordCount */, headerPolicy, mGlobalCounters.needsToHalveCounters(),
                outEntryCounters)) {
            return false;
        }
        if (mGlobalCounters.needsToHalveCounters()) {
            mGlobalCounters.halveCounters();
        }
        return true;
    }

    // entryCounts should be created by updateAllProbabilityEntries.
@@ -206,11 +214,12 @@ class LanguageModelDictContent {
            DISALLOW_ASSIGNMENT_OPERATOR(Comparator);
        };

        EntryInfoToTurncate(const int probability, const int timestamp, const int key,
        EntryInfoToTurncate(const int priority, const int count, const int key,
                const int prevWordCount, const int *const prevWordIds);

        int mProbability;
        int mTimestamp;
        int mPriority;
        // TODO: Remove.
        int mCount;
        int mKey;
        int mPrevWordCount;
        int mPrevWordIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1];
@@ -219,8 +228,6 @@ class LanguageModelDictContent {
        DISALLOW_DEFAULT_CONSTRUCTOR(EntryInfoToTurncate);
    };

    // TODO: Remove
    static const int DUMMY_PROBABILITY_FOR_VALID_WORDS;
    static const int TRIE_MAP_BUFFER_INDEX;
    static const int GLOBAL_COUNTERS_BUFFER_INDEX;

@@ -233,7 +240,8 @@ class LanguageModelDictContent {
    int createAndGetBitmapEntryIndex(const WordIdArrayView prevWordIds);
    int getBitmapEntryIndex(const WordIdArrayView prevWordIds) const;
    bool updateAllProbabilityEntriesForGCInner(const int bitmapEntryIndex, const int prevWordCount,
            const HeaderPolicy *const headerPolicy, MutableEntryCounters *const outEntryCounters);
            const HeaderPolicy *const headerPolicy, const bool needsToHalveCounters,
            MutableEntryCounters *const outEntryCounters);
    bool turncateEntriesInSpecifiedLevel(const HeaderPolicy *const headerPolicy,
            const int maxEntryCount, const int targetLevel, int *const outEntryCount);
    bool getEntryInfo(const HeaderPolicy *const headerPolicy, const int targetLevel,
+4 −0
Original line number Diff line number Diff line
@@ -63,6 +63,10 @@ class LanguageModelDictContentGlobalCounters {
        mTotalCount += 1;
    }

    // Increases the global total count by the given amount. Used when a whole
    // entry's count is added at once (cf. the +1 increment helper above).
    void addToTotalCount(const int count) {
        mTotalCount += count;
    }

    // Records count as the new maximum observed counter value if it exceeds
    // the current maximum.
    void updateMaxValueOfCounters(const int count) {
        mMaxValueOfCounters = std::max(count, mMaxValueOfCounters);
    }
Loading