Loading native/jni/src/suggest/core/dictionary/dictionary.cpp +3 −0 Original line number Diff line number Diff line Loading @@ -81,6 +81,9 @@ void Dictionary::NgramListenerForPrediction::onVisitEntry(const int ngramProbabi } const WordAttributes wordAttributes = mDictStructurePolicy->getWordAttributesInContext( mPrevWordIds, targetWordId, nullptr /* multiBigramMap */); if (wordAttributes.getProbability() == NOT_A_PROBABILITY) { return; } mSuggestionResults->addPrediction(targetWordCodePoints, codePointCount, wordAttributes.getProbability()); } Loading native/jni/src/suggest/core/dictionary/ngram_listener.h +2 −0 Original line number Diff line number Diff line Loading @@ -26,6 +26,8 @@ namespace latinime { */ class NgramListener { public: // ngramProbability is always 0 for v403 decaying dictionary. // TODO: Remove ngramProbability. virtual void onVisitEntry(const int ngramProbability, const int targetWordId) = 0; virtual ~NgramListener() {}; Loading native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp +73 −54 Original line number Diff line number Diff line Loading @@ -19,11 +19,11 @@ #include <algorithm> #include <cstring> #include "suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h" #include "suggest/policyimpl/dictionary/structure/v4/content/dynamic_language_model_probability_utils.h" #include "suggest/policyimpl/dictionary/utils/probability_utils.h" namespace latinime { const int LanguageModelDictContent::DUMMY_PROBABILITY_FOR_VALID_WORDS = 1; const int LanguageModelDictContent::TRIE_MAP_BUFFER_INDEX = 0; const int LanguageModelDictContent::GLOBAL_COUNTERS_BUFFER_INDEX = 1; Loading @@ -39,7 +39,8 @@ bool LanguageModelDictContent::runGC( } const WordAttributes LanguageModelDictContent::getWordAttributes(const WordIdArrayView prevWordIds, const int wordId, const HeaderPolicy *const headerPolicy) const { const int wordId, const bool mustMatchAllPrevWords, const HeaderPolicy *const headerPolicy) const { int bitmapEntryIndices[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1]; bitmapEntryIndices[0] = mTrieMap.getRootBitmapEntryIndex(); int maxPrevWordCount = 0; Loading @@ -53,7 +54,15 @@ const WordAttributes LanguageModelDictContent::getWordAttributes(const WordIdArr bitmapEntryIndices[i + 1] = nextBitmapEntryIndex; } const ProbabilityEntry unigramProbabilityEntry = getProbabilityEntry(wordId); if (mHasHistoricalInfo && unigramProbabilityEntry.getHistoricalInfo()->getCount() == 0) { // The word should be treated as a invalid word. return WordAttributes(); } for (int i = maxPrevWordCount; i >= 0; --i) { if (mustMatchAllPrevWords && prevWordIds.size() > static_cast<size_t>(i)) { break; } const TrieMap::Result result = mTrieMap.get(wordId, bitmapEntryIndices[i]); if (!result.mIsValid) { continue; Loading @@ -62,36 +71,39 @@ const WordAttributes LanguageModelDictContent::getWordAttributes(const WordIdArr ProbabilityEntry::decode(result.mValue, mHasHistoricalInfo); int probability = NOT_A_PROBABILITY; if (mHasHistoricalInfo) { const int rawProbability = ForgettingCurveUtils::decodeProbability( probabilityEntry.getHistoricalInfo(), headerPolicy); if (rawProbability == NOT_A_PROBABILITY) { // The entry should not be treated as a valid entry. continue; } const HistoricalInfo *const historicalInfo = probabilityEntry.getHistoricalInfo(); int contextCount = 0; if (i == 0) { // unigram probability = rawProbability; contextCount = mGlobalCounters.getTotalCount(); } else { const ProbabilityEntry prevWordProbabilityEntry = getNgramProbabilityEntry( prevWordIds.skip(1 /* n */).limit(i - 1), prevWordIds[0]); if (!prevWordProbabilityEntry.isValid()) { continue; } if (prevWordProbabilityEntry.representsBeginningOfSentence()) { probability = rawProbability; } else { const int prevWordRawProbability = ForgettingCurveUtils::decodeProbability( prevWordProbabilityEntry.getHistoricalInfo(), headerPolicy); probability = std::min(MAX_PROBABILITY - prevWordRawProbability + rawProbability, MAX_PROBABILITY); } if (prevWordProbabilityEntry.representsBeginningOfSentence() && historicalInfo->getCount() == 1) { // BoS ngram requires multiple contextCount. continue; } contextCount = prevWordProbabilityEntry.getHistoricalInfo()->getCount(); } const float rawProbability = DynamicLanguageModelProbabilityUtils::computeRawProbabilityFromCounts( historicalInfo->getCount(), contextCount, i + 1); const int encodedRawProbability = ProbabilityUtils::encodeRawProbability(rawProbability); const int decayedProbability = DynamicLanguageModelProbabilityUtils::getDecayedProbability( encodedRawProbability, *historicalInfo); probability = DynamicLanguageModelProbabilityUtils::backoff( decayedProbability, i + 1 /* n */); } else { probability = probabilityEntry.getProbability(); } // TODO: Some flags in unigramProbabilityEntry should be overwritten by flags in // probabilityEntry. const ProbabilityEntry unigramProbabilityEntry = getProbabilityEntry(wordId); return WordAttributes(probability, unigramProbabilityEntry.isBlacklisted(), unigramProbabilityEntry.isNotAWord(), unigramProbabilityEntry.isPossiblyOffensive()); Loading Loading @@ -167,7 +179,8 @@ void LanguageModelDictContent::exportAllNgramEntriesRelatedToWordInner( ProbabilityEntry::decode(entry.value(), mHasHistoricalInfo); if (probabilityEntry.isValid()) { const WordAttributes wordAttributes = getWordAttributes( WordIdArrayView(*prevWordIds), wordId, headerPolicy); WordIdArrayView(*prevWordIds), wordId, true /* mustMatchAllPrevWords */, headerPolicy); outBummpedFullEntryInfo->emplace_back(*prevWordIds, wordId, wordAttributes, probabilityEntry); } Loading Loading @@ -231,7 +244,7 @@ bool LanguageModelDictContent::updateAllEntriesOnInputWord(const WordIdArrayView return false; } mGlobalCounters.updateMaxValueOfCounters( updatedUnigramProbabilityEntry.getHistoricalInfo()->getCount()); updatedNgramProbabilityEntry.getHistoricalInfo()->getCount()); if (!originalNgramProbabilityEntry.isValid()) { entryCountersToUpdate->incrementNgramCount(i + 2); } Loading @@ -242,10 +255,9 @@ bool LanguageModelDictContent::updateAllEntriesOnInputWord(const WordIdArrayView const ProbabilityEntry LanguageModelDictContent::createUpdatedEntryFrom( const ProbabilityEntry &originalProbabilityEntry, const bool isValid, const HistoricalInfo historicalInfo, const HeaderPolicy *const headerPolicy) const { const HistoricalInfo updatedHistoricalInfo = ForgettingCurveUtils::createUpdatedHistoricalInfo( originalProbabilityEntry.getHistoricalInfo(), isValid ? DUMMY_PROBABILITY_FOR_VALID_WORDS : NOT_A_PROBABILITY, &historicalInfo, headerPolicy); const HistoricalInfo updatedHistoricalInfo = HistoricalInfo(historicalInfo.getTimestamp(), 0 /* level */, originalProbabilityEntry.getHistoricalInfo()->getCount() + historicalInfo.getCount()); if (originalProbabilityEntry.isValid()) { return ProbabilityEntry(originalProbabilityEntry.getFlags(), &updatedHistoricalInfo); } else { Loading Loading @@ -311,7 +323,7 @@ int LanguageModelDictContent::getBitmapEntryIndex(const WordIdArrayView prevWord bool LanguageModelDictContent::updateAllProbabilityEntriesForGCInner(const int bitmapEntryIndex, const int prevWordCount, const HeaderPolicy *const headerPolicy, MutableEntryCounters *const outEntryCounters) { const bool needsToHalveCounters, MutableEntryCounters *const outEntryCounters) { for (const auto &entry : mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex)) { if (prevWordCount > MAX_PREV_WORD_COUNT_FOR_N_GRAM) { AKLOGE("Invalid prevWordCount. prevWordCount: %d, MAX_PREV_WORD_COUNT_FOR_N_GRAM: %d.", Loading @@ -328,33 +340,41 @@ bool LanguageModelDictContent::updateAllProbabilityEntriesForGCInner(const int b } continue; } if (mHasHistoricalInfo && !probabilityEntry.representsBeginningOfSentence() && probabilityEntry.isValid()) { const HistoricalInfo historicalInfo = ForgettingCurveUtils::createHistoricalInfoToSave( probabilityEntry.getHistoricalInfo(), headerPolicy); if (ForgettingCurveUtils::needsToKeep(&historicalInfo, headerPolicy)) { // Update the entry. const ProbabilityEntry updatedEntry(probabilityEntry.getFlags(), &historicalInfo); if (!mTrieMap.put(entry.key(), updatedEntry.encode(mHasHistoricalInfo), bitmapEntryIndex)) { if (mHasHistoricalInfo && probabilityEntry.isValid()) { const HistoricalInfo *originalHistoricalInfo = probabilityEntry.getHistoricalInfo(); if (DynamicLanguageModelProbabilityUtils::shouldRemoveEntryDuringGC( *originalHistoricalInfo)) { // Remove the entry. if (!mTrieMap.remove(entry.key(), bitmapEntryIndex)) { return false; } } else { continue; } if (needsToHalveCounters) { const int updatedCount = originalHistoricalInfo->getCount() / 2; if (updatedCount == 0) { // Remove the entry. if (!mTrieMap.remove(entry.key(), bitmapEntryIndex)) { return false; } continue; } const HistoricalInfo historicalInfoToSave(originalHistoricalInfo->getTimestamp(), originalHistoricalInfo->getLevel(), updatedCount); const ProbabilityEntry updatedEntry(probabilityEntry.getFlags(), &historicalInfoToSave); if (!mTrieMap.put(entry.key(), updatedEntry.encode(mHasHistoricalInfo), bitmapEntryIndex)) { return false; } } if (!probabilityEntry.representsBeginningOfSentence()) { outEntryCounters->incrementNgramCount(prevWordCount + 1); } outEntryCounters->incrementNgramCount(prevWordCount + 1); if (!entry.hasNextLevelMap()) { continue; } if (!updateAllProbabilityEntriesForGCInner(entry.getNextLevelBitmapEntryIndex(), prevWordCount + 1, headerPolicy, outEntryCounters)) { prevWordCount + 1, headerPolicy, needsToHalveCounters, outEntryCounters)) { return false; } } Loading Loading @@ -408,11 +428,11 @@ bool LanguageModelDictContent::getEntryInfo(const HeaderPolicy *const headerPoli } const ProbabilityEntry probabilityEntry = ProbabilityEntry::decode(entry.value(), mHasHistoricalInfo); const int probability = (mHasHistoricalInfo) ? ForgettingCurveUtils::decodeProbability(probabilityEntry.getHistoricalInfo(), headerPolicy) : probabilityEntry.getProbability(); outEntryInfo->emplace_back(probability, probabilityEntry.getHistoricalInfo()->getTimestamp(), const int priority = mHasHistoricalInfo ? DynamicLanguageModelProbabilityUtils::getPriorityToPreventFromEviction( *probabilityEntry.getHistoricalInfo()) : probabilityEntry.getProbability(); outEntryInfo->emplace_back(priority, probabilityEntry.getHistoricalInfo()->getCount(), entry.key(), targetLevel, prevWordIds->data()); } return true; Loading @@ -420,11 +440,11 @@ bool LanguageModelDictContent::getEntryInfo(const HeaderPolicy *const headerPoli bool LanguageModelDictContent::EntryInfoToTurncate::Comparator::operator()( const EntryInfoToTurncate &left, const EntryInfoToTurncate &right) const { if (left.mProbability != right.mProbability) { return left.mProbability < right.mProbability; if (left.mPriority != right.mPriority) { return left.mPriority < right.mPriority; } if (left.mTimestamp != right.mTimestamp) { return left.mTimestamp > right.mTimestamp; if (left.mCount != right.mCount) { return left.mCount < right.mCount; } if (left.mKey != right.mKey) { return left.mKey < right.mKey; Loading @@ -441,10 +461,9 @@ bool LanguageModelDictContent::EntryInfoToTurncate::Comparator::operator()( return false; } LanguageModelDictContent::EntryInfoToTurncate::EntryInfoToTurncate(const int probability, const int timestamp, const int key, const int prevWordCount, const int *const prevWordIds) : mProbability(probability), mTimestamp(timestamp), mKey(key), mPrevWordCount(prevWordCount) { LanguageModelDictContent::EntryInfoToTurncate::EntryInfoToTurncate(const int priority, const int count, const int key, const int prevWordCount, const int *const prevWordIds) : mPriority(priority), mCount(count), mKey(key), mPrevWordCount(prevWordCount) { memmove(mPrevWordIds, prevWordIds, mPrevWordCount * sizeof(mPrevWordIds[0])); } Loading native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h +17 −9 Original line number Diff line number Diff line Loading @@ -151,13 +151,14 @@ class LanguageModelDictContent { const LanguageModelDictContent *const originalContent); const WordAttributes getWordAttributes(const WordIdArrayView prevWordIds, const int wordId, const HeaderPolicy *const headerPolicy) const; const bool mustMatchAllPrevWords, const HeaderPolicy *const headerPolicy) const; ProbabilityEntry getProbabilityEntry(const int wordId) const { return getNgramProbabilityEntry(WordIdArrayView(), wordId); } bool setProbabilityEntry(const int wordId, const ProbabilityEntry *const probabilityEntry) { mGlobalCounters.addToTotalCount(probabilityEntry->getHistoricalInfo()->getCount()); return setNgramProbabilityEntry(WordIdArrayView(), wordId, probabilityEntry); } Loading @@ -180,8 +181,15 @@ class LanguageModelDictContent { bool updateAllProbabilityEntriesForGC(const HeaderPolicy *const headerPolicy, MutableEntryCounters *const outEntryCounters) { return updateAllProbabilityEntriesForGCInner(mTrieMap.getRootBitmapEntryIndex(), 0 /* prevWordCount */, headerPolicy, outEntryCounters); if (!updateAllProbabilityEntriesForGCInner(mTrieMap.getRootBitmapEntryIndex(), 0 /* prevWordCount */, headerPolicy, mGlobalCounters.needsToHalveCounters(), outEntryCounters)) { return false; } if (mGlobalCounters.needsToHalveCounters()) { mGlobalCounters.halveCounters(); } return true; } // entryCounts should be created by updateAllProbabilityEntries. Loading @@ -206,11 +214,12 @@ class LanguageModelDictContent { DISALLOW_ASSIGNMENT_OPERATOR(Comparator); }; EntryInfoToTurncate(const int probability, const int timestamp, const int key, EntryInfoToTurncate(const int priority, const int count, const int key, const int prevWordCount, const int *const prevWordIds); int mProbability; int mTimestamp; int mPriority; // TODO: Remove. int mCount; int mKey; int mPrevWordCount; int mPrevWordIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1]; Loading @@ -219,8 +228,6 @@ class LanguageModelDictContent { DISALLOW_DEFAULT_CONSTRUCTOR(EntryInfoToTurncate); }; // TODO: Remove static const int DUMMY_PROBABILITY_FOR_VALID_WORDS; static const int TRIE_MAP_BUFFER_INDEX; static const int GLOBAL_COUNTERS_BUFFER_INDEX; Loading @@ -233,7 +240,8 @@ class LanguageModelDictContent { int createAndGetBitmapEntryIndex(const WordIdArrayView prevWordIds); int getBitmapEntryIndex(const WordIdArrayView prevWordIds) const; bool updateAllProbabilityEntriesForGCInner(const int bitmapEntryIndex, const int prevWordCount, const HeaderPolicy *const headerPolicy, MutableEntryCounters *const outEntryCounters); const HeaderPolicy *const headerPolicy, const bool needsToHalveCounters, MutableEntryCounters *const outEntryCounters); bool turncateEntriesInSpecifiedLevel(const HeaderPolicy *const headerPolicy, const int maxEntryCount, const int targetLevel, int *const outEntryCount); bool getEntryInfo(const HeaderPolicy *const headerPolicy, const int targetLevel, Loading native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_global_counters.h +4 −0 Original line number Diff line number Diff line Loading @@ -63,6 +63,10 @@ class LanguageModelDictContentGlobalCounters { mTotalCount += 1; } void addToTotalCount(const int count) { mTotalCount += count; } void updateMaxValueOfCounters(const int count) { mMaxValueOfCounters = std::max(count, mMaxValueOfCounters); } Loading Loading
native/jni/src/suggest/core/dictionary/dictionary.cpp +3 −0 Original line number Diff line number Diff line Loading @@ -81,6 +81,9 @@ void Dictionary::NgramListenerForPrediction::onVisitEntry(const int ngramProbabi } const WordAttributes wordAttributes = mDictStructurePolicy->getWordAttributesInContext( mPrevWordIds, targetWordId, nullptr /* multiBigramMap */); if (wordAttributes.getProbability() == NOT_A_PROBABILITY) { return; } mSuggestionResults->addPrediction(targetWordCodePoints, codePointCount, wordAttributes.getProbability()); } Loading
native/jni/src/suggest/core/dictionary/ngram_listener.h +2 −0 Original line number Diff line number Diff line Loading @@ -26,6 +26,8 @@ namespace latinime { */ class NgramListener { public: // ngramProbability is always 0 for v403 decaying dictionary. // TODO: Remove ngramProbability. virtual void onVisitEntry(const int ngramProbability, const int targetWordId) = 0; virtual ~NgramListener() {}; Loading
native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp +73 −54 Original line number Diff line number Diff line Loading @@ -19,11 +19,11 @@ #include <algorithm> #include <cstring> #include "suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h" #include "suggest/policyimpl/dictionary/structure/v4/content/dynamic_language_model_probability_utils.h" #include "suggest/policyimpl/dictionary/utils/probability_utils.h" namespace latinime { const int LanguageModelDictContent::DUMMY_PROBABILITY_FOR_VALID_WORDS = 1; const int LanguageModelDictContent::TRIE_MAP_BUFFER_INDEX = 0; const int LanguageModelDictContent::GLOBAL_COUNTERS_BUFFER_INDEX = 1; Loading @@ -39,7 +39,8 @@ bool LanguageModelDictContent::runGC( } const WordAttributes LanguageModelDictContent::getWordAttributes(const WordIdArrayView prevWordIds, const int wordId, const HeaderPolicy *const headerPolicy) const { const int wordId, const bool mustMatchAllPrevWords, const HeaderPolicy *const headerPolicy) const { int bitmapEntryIndices[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1]; bitmapEntryIndices[0] = mTrieMap.getRootBitmapEntryIndex(); int maxPrevWordCount = 0; Loading @@ -53,7 +54,15 @@ const WordAttributes LanguageModelDictContent::getWordAttributes(const WordIdArr bitmapEntryIndices[i + 1] = nextBitmapEntryIndex; } const ProbabilityEntry unigramProbabilityEntry = getProbabilityEntry(wordId); if (mHasHistoricalInfo && unigramProbabilityEntry.getHistoricalInfo()->getCount() == 0) { // The word should be treated as a invalid word. return WordAttributes(); } for (int i = maxPrevWordCount; i >= 0; --i) { if (mustMatchAllPrevWords && prevWordIds.size() > static_cast<size_t>(i)) { break; } const TrieMap::Result result = mTrieMap.get(wordId, bitmapEntryIndices[i]); if (!result.mIsValid) { continue; Loading @@ -62,36 +71,39 @@ const WordAttributes LanguageModelDictContent::getWordAttributes(const WordIdArr ProbabilityEntry::decode(result.mValue, mHasHistoricalInfo); int probability = NOT_A_PROBABILITY; if (mHasHistoricalInfo) { const int rawProbability = ForgettingCurveUtils::decodeProbability( probabilityEntry.getHistoricalInfo(), headerPolicy); if (rawProbability == NOT_A_PROBABILITY) { // The entry should not be treated as a valid entry. continue; } const HistoricalInfo *const historicalInfo = probabilityEntry.getHistoricalInfo(); int contextCount = 0; if (i == 0) { // unigram probability = rawProbability; contextCount = mGlobalCounters.getTotalCount(); } else { const ProbabilityEntry prevWordProbabilityEntry = getNgramProbabilityEntry( prevWordIds.skip(1 /* n */).limit(i - 1), prevWordIds[0]); if (!prevWordProbabilityEntry.isValid()) { continue; } if (prevWordProbabilityEntry.representsBeginningOfSentence()) { probability = rawProbability; } else { const int prevWordRawProbability = ForgettingCurveUtils::decodeProbability( prevWordProbabilityEntry.getHistoricalInfo(), headerPolicy); probability = std::min(MAX_PROBABILITY - prevWordRawProbability + rawProbability, MAX_PROBABILITY); } if (prevWordProbabilityEntry.representsBeginningOfSentence() && historicalInfo->getCount() == 1) { // BoS ngram requires multiple contextCount. continue; } contextCount = prevWordProbabilityEntry.getHistoricalInfo()->getCount(); } const float rawProbability = DynamicLanguageModelProbabilityUtils::computeRawProbabilityFromCounts( historicalInfo->getCount(), contextCount, i + 1); const int encodedRawProbability = ProbabilityUtils::encodeRawProbability(rawProbability); const int decayedProbability = DynamicLanguageModelProbabilityUtils::getDecayedProbability( encodedRawProbability, *historicalInfo); probability = DynamicLanguageModelProbabilityUtils::backoff( decayedProbability, i + 1 /* n */); } else { probability = probabilityEntry.getProbability(); } // TODO: Some flags in unigramProbabilityEntry should be overwritten by flags in // probabilityEntry. const ProbabilityEntry unigramProbabilityEntry = getProbabilityEntry(wordId); return WordAttributes(probability, unigramProbabilityEntry.isBlacklisted(), unigramProbabilityEntry.isNotAWord(), unigramProbabilityEntry.isPossiblyOffensive()); Loading Loading @@ -167,7 +179,8 @@ void LanguageModelDictContent::exportAllNgramEntriesRelatedToWordInner( ProbabilityEntry::decode(entry.value(), mHasHistoricalInfo); if (probabilityEntry.isValid()) { const WordAttributes wordAttributes = getWordAttributes( WordIdArrayView(*prevWordIds), wordId, headerPolicy); WordIdArrayView(*prevWordIds), wordId, true /* mustMatchAllPrevWords */, headerPolicy); outBummpedFullEntryInfo->emplace_back(*prevWordIds, wordId, wordAttributes, probabilityEntry); } Loading Loading @@ -231,7 +244,7 @@ bool LanguageModelDictContent::updateAllEntriesOnInputWord(const WordIdArrayView return false; } mGlobalCounters.updateMaxValueOfCounters( updatedUnigramProbabilityEntry.getHistoricalInfo()->getCount()); updatedNgramProbabilityEntry.getHistoricalInfo()->getCount()); if (!originalNgramProbabilityEntry.isValid()) { entryCountersToUpdate->incrementNgramCount(i + 2); } Loading @@ -242,10 +255,9 @@ bool LanguageModelDictContent::updateAllEntriesOnInputWord(const WordIdArrayView const ProbabilityEntry LanguageModelDictContent::createUpdatedEntryFrom( const ProbabilityEntry &originalProbabilityEntry, const bool isValid, const HistoricalInfo historicalInfo, const HeaderPolicy *const headerPolicy) const { const HistoricalInfo updatedHistoricalInfo = ForgettingCurveUtils::createUpdatedHistoricalInfo( originalProbabilityEntry.getHistoricalInfo(), isValid ? DUMMY_PROBABILITY_FOR_VALID_WORDS : NOT_A_PROBABILITY, &historicalInfo, headerPolicy); const HistoricalInfo updatedHistoricalInfo = HistoricalInfo(historicalInfo.getTimestamp(), 0 /* level */, originalProbabilityEntry.getHistoricalInfo()->getCount() + historicalInfo.getCount()); if (originalProbabilityEntry.isValid()) { return ProbabilityEntry(originalProbabilityEntry.getFlags(), &updatedHistoricalInfo); } else { Loading Loading @@ -311,7 +323,7 @@ int LanguageModelDictContent::getBitmapEntryIndex(const WordIdArrayView prevWord bool LanguageModelDictContent::updateAllProbabilityEntriesForGCInner(const int bitmapEntryIndex, const int prevWordCount, const HeaderPolicy *const headerPolicy, MutableEntryCounters *const outEntryCounters) { const bool needsToHalveCounters, MutableEntryCounters *const outEntryCounters) { for (const auto &entry : mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex)) { if (prevWordCount > MAX_PREV_WORD_COUNT_FOR_N_GRAM) { AKLOGE("Invalid prevWordCount. prevWordCount: %d, MAX_PREV_WORD_COUNT_FOR_N_GRAM: %d.", Loading @@ -328,33 +340,41 @@ bool LanguageModelDictContent::updateAllProbabilityEntriesForGCInner(const int b } continue; } if (mHasHistoricalInfo && !probabilityEntry.representsBeginningOfSentence() && probabilityEntry.isValid()) { const HistoricalInfo historicalInfo = ForgettingCurveUtils::createHistoricalInfoToSave( probabilityEntry.getHistoricalInfo(), headerPolicy); if (ForgettingCurveUtils::needsToKeep(&historicalInfo, headerPolicy)) { // Update the entry. const ProbabilityEntry updatedEntry(probabilityEntry.getFlags(), &historicalInfo); if (!mTrieMap.put(entry.key(), updatedEntry.encode(mHasHistoricalInfo), bitmapEntryIndex)) { if (mHasHistoricalInfo && probabilityEntry.isValid()) { const HistoricalInfo *originalHistoricalInfo = probabilityEntry.getHistoricalInfo(); if (DynamicLanguageModelProbabilityUtils::shouldRemoveEntryDuringGC( *originalHistoricalInfo)) { // Remove the entry. if (!mTrieMap.remove(entry.key(), bitmapEntryIndex)) { return false; } } else { continue; } if (needsToHalveCounters) { const int updatedCount = originalHistoricalInfo->getCount() / 2; if (updatedCount == 0) { // Remove the entry. if (!mTrieMap.remove(entry.key(), bitmapEntryIndex)) { return false; } continue; } const HistoricalInfo historicalInfoToSave(originalHistoricalInfo->getTimestamp(), originalHistoricalInfo->getLevel(), updatedCount); const ProbabilityEntry updatedEntry(probabilityEntry.getFlags(), &historicalInfoToSave); if (!mTrieMap.put(entry.key(), updatedEntry.encode(mHasHistoricalInfo), bitmapEntryIndex)) { return false; } } if (!probabilityEntry.representsBeginningOfSentence()) { outEntryCounters->incrementNgramCount(prevWordCount + 1); } outEntryCounters->incrementNgramCount(prevWordCount + 1); if (!entry.hasNextLevelMap()) { continue; } if (!updateAllProbabilityEntriesForGCInner(entry.getNextLevelBitmapEntryIndex(), prevWordCount + 1, headerPolicy, outEntryCounters)) { prevWordCount + 1, headerPolicy, needsToHalveCounters, outEntryCounters)) { return false; } } Loading Loading @@ -408,11 +428,11 @@ bool LanguageModelDictContent::getEntryInfo(const HeaderPolicy *const headerPoli } const ProbabilityEntry probabilityEntry = ProbabilityEntry::decode(entry.value(), mHasHistoricalInfo); const int probability = (mHasHistoricalInfo) ? ForgettingCurveUtils::decodeProbability(probabilityEntry.getHistoricalInfo(), headerPolicy) : probabilityEntry.getProbability(); outEntryInfo->emplace_back(probability, probabilityEntry.getHistoricalInfo()->getTimestamp(), const int priority = mHasHistoricalInfo ? DynamicLanguageModelProbabilityUtils::getPriorityToPreventFromEviction( *probabilityEntry.getHistoricalInfo()) : probabilityEntry.getProbability(); outEntryInfo->emplace_back(priority, probabilityEntry.getHistoricalInfo()->getCount(), entry.key(), targetLevel, prevWordIds->data()); } return true; Loading @@ -420,11 +440,11 @@ bool LanguageModelDictContent::getEntryInfo(const HeaderPolicy *const headerPoli bool LanguageModelDictContent::EntryInfoToTurncate::Comparator::operator()( const EntryInfoToTurncate &left, const EntryInfoToTurncate &right) const { if (left.mProbability != right.mProbability) { return left.mProbability < right.mProbability; if (left.mPriority != right.mPriority) { return left.mPriority < right.mPriority; } if (left.mTimestamp != right.mTimestamp) { return left.mTimestamp > right.mTimestamp; if (left.mCount != right.mCount) { return left.mCount < right.mCount; } if (left.mKey != right.mKey) { return left.mKey < right.mKey; Loading @@ -441,10 +461,9 @@ bool LanguageModelDictContent::EntryInfoToTurncate::Comparator::operator()( return false; } LanguageModelDictContent::EntryInfoToTurncate::EntryInfoToTurncate(const int probability, const int timestamp, const int key, const int prevWordCount, const int *const prevWordIds) : mProbability(probability), mTimestamp(timestamp), mKey(key), mPrevWordCount(prevWordCount) { LanguageModelDictContent::EntryInfoToTurncate::EntryInfoToTurncate(const int priority, const int count, const int key, const int prevWordCount, const int *const prevWordIds) : mPriority(priority), mCount(count), mKey(key), mPrevWordCount(prevWordCount) { memmove(mPrevWordIds, prevWordIds, mPrevWordCount * sizeof(mPrevWordIds[0])); } Loading
native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h +17 −9 Original line number Diff line number Diff line Loading @@ -151,13 +151,14 @@ class LanguageModelDictContent { const LanguageModelDictContent *const originalContent); const WordAttributes getWordAttributes(const WordIdArrayView prevWordIds, const int wordId, const HeaderPolicy *const headerPolicy) const; const bool mustMatchAllPrevWords, const HeaderPolicy *const headerPolicy) const; ProbabilityEntry getProbabilityEntry(const int wordId) const { return getNgramProbabilityEntry(WordIdArrayView(), wordId); } bool setProbabilityEntry(const int wordId, const ProbabilityEntry *const probabilityEntry) { mGlobalCounters.addToTotalCount(probabilityEntry->getHistoricalInfo()->getCount()); return setNgramProbabilityEntry(WordIdArrayView(), wordId, probabilityEntry); } Loading @@ -180,8 +181,15 @@ class LanguageModelDictContent { bool updateAllProbabilityEntriesForGC(const HeaderPolicy *const headerPolicy, MutableEntryCounters *const outEntryCounters) { return updateAllProbabilityEntriesForGCInner(mTrieMap.getRootBitmapEntryIndex(), 0 /* prevWordCount */, headerPolicy, outEntryCounters); if (!updateAllProbabilityEntriesForGCInner(mTrieMap.getRootBitmapEntryIndex(), 0 /* prevWordCount */, headerPolicy, mGlobalCounters.needsToHalveCounters(), outEntryCounters)) { return false; } if (mGlobalCounters.needsToHalveCounters()) { mGlobalCounters.halveCounters(); } return true; } // entryCounts should be created by updateAllProbabilityEntries. Loading @@ -206,11 +214,12 @@ class LanguageModelDictContent { DISALLOW_ASSIGNMENT_OPERATOR(Comparator); }; EntryInfoToTurncate(const int probability, const int timestamp, const int key, EntryInfoToTurncate(const int priority, const int count, const int key, const int prevWordCount, const int *const prevWordIds); int mProbability; int mTimestamp; int mPriority; // TODO: Remove. int mCount; int mKey; int mPrevWordCount; int mPrevWordIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1]; Loading @@ -219,8 +228,6 @@ class LanguageModelDictContent { DISALLOW_DEFAULT_CONSTRUCTOR(EntryInfoToTurncate); }; // TODO: Remove static const int DUMMY_PROBABILITY_FOR_VALID_WORDS; static const int TRIE_MAP_BUFFER_INDEX; static const int GLOBAL_COUNTERS_BUFFER_INDEX; Loading @@ -233,7 +240,8 @@ class LanguageModelDictContent { int createAndGetBitmapEntryIndex(const WordIdArrayView prevWordIds); int getBitmapEntryIndex(const WordIdArrayView prevWordIds) const; bool updateAllProbabilityEntriesForGCInner(const int bitmapEntryIndex, const int prevWordCount, const HeaderPolicy *const headerPolicy, MutableEntryCounters *const outEntryCounters); const HeaderPolicy *const headerPolicy, const bool needsToHalveCounters, MutableEntryCounters *const outEntryCounters); bool turncateEntriesInSpecifiedLevel(const HeaderPolicy *const headerPolicy, const int maxEntryCount, const int targetLevel, int *const outEntryCount); bool getEntryInfo(const HeaderPolicy *const headerPolicy, const int targetLevel, Loading
native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_global_counters.h +4 −0 Original line number Diff line number Diff line Loading @@ -63,6 +63,10 @@ class LanguageModelDictContentGlobalCounters { mTotalCount += 1; } void addToTotalCount(const int count) { mTotalCount += count; } void updateMaxValueOfCounters(const int count) { mMaxValueOfCounters = std::max(count, mMaxValueOfCounters); } Loading