Loading native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp +44 −0 Original line number Diff line number Diff line Loading @@ -16,6 +16,8 @@ #include "suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h" #include "suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h" namespace latinime { bool LanguageModelDictContent::save(FILE *const file) const { Loading Loading @@ -118,4 +120,46 @@ int LanguageModelDictContent::getBitmapEntryIndex(const WordIdArrayView prevWord return bitmapEntryIndex; } bool LanguageModelDictContent::updateAllProbabilityEntriesInner(const int bitmapEntryIndex, const int level, const HeaderPolicy *const headerPolicy, int *const outEntryCounts) { for (const auto &entry : mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex)) { if (level > MAX_PREV_WORD_COUNT_FOR_N_GRAM) { AKLOGE("Invalid level. level: %d, MAX_PREV_WORD_COUNT_FOR_N_GRAM: %d.", level, MAX_PREV_WORD_COUNT_FOR_N_GRAM); return false; } const ProbabilityEntry probabilityEntry = ProbabilityEntry::decode(entry.value(), mHasHistoricalInfo); if (mHasHistoricalInfo && !probabilityEntry.representsBeginningOfSentence()) { const HistoricalInfo historicalInfo = ForgettingCurveUtils::createHistoricalInfoToSave( probabilityEntry.getHistoricalInfo(), headerPolicy); if (ForgettingCurveUtils::needsToKeep(&historicalInfo, headerPolicy)) { // Update the entry. const ProbabilityEntry updatedEntry(probabilityEntry.getFlags(), &historicalInfo); if (!mTrieMap.put(entry.key(), updatedEntry.encode(mHasHistoricalInfo), bitmapEntryIndex)) { return false; } } else { // Remove the entry. if (!mTrieMap.remove(entry.key(), bitmapEntryIndex)) { return false; } continue; } } if (!probabilityEntry.representsBeginningOfSentence()) { outEntryCounts[level] += 1; } if (!entry.hasNextLevelMap()) { continue; } if (!updateAllProbabilityEntriesInner(entry.getNextLevelBitmapEntryIndex(), level + 1, headerPolicy, outEntryCounts)) { return false; } } return true; } } // namespace latinime native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h +10 −0 Original line number Diff line number Diff line Loading @@ -29,6 +29,8 @@ namespace latinime { class HeaderPolicy; /** * Class representing language model. * Loading Loading @@ -73,6 +75,12 @@ class LanguageModelDictContent { bool removeNgramProbabilityEntry(const WordIdArrayView prevWordIds, const int wordId); bool updateAllProbabilityEntries(const HeaderPolicy *const headerPolicy, int *const outEntryCounts) { return updateAllProbabilityEntriesInner(mTrieMap.getRootBitmapEntryIndex(), 0 /* level */, headerPolicy, outEntryCounts); } private: DISALLOW_COPY_AND_ASSIGN(LanguageModelDictContent); Loading @@ -84,6 +92,8 @@ class LanguageModelDictContent { int *const outNgramCount); int createAndGetBitmapEntryIndex(const WordIdArrayView prevWordIds); int getBitmapEntryIndex(const WordIdArrayView prevWordIds) const; bool updateAllProbabilityEntriesInner(const int bitmapEntryIndex, const int level, const HeaderPolicy *const headerPolicy, int *const outEntryCounts); }; } // namespace latinime #endif /* LATINIME_LANGUAGE_MODEL_DICT_CONTENT_H */ native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp +8 −21 Original line number Diff line number Diff line Loading @@ -161,29 +161,15 @@ bool Ver4PatriciaTrieNodeWriter::updatePtNodeProbabilityAndGetNeedsToKeepPtNodeA const ProbabilityEntry originalProbabilityEntry = mBuffers->getLanguageModelDictContent()->getProbabilityEntry( toBeUpdatedPtNodeParams->getTerminalId()); if (originalProbabilityEntry.hasHistoricalInfo()) { const HistoricalInfo historicalInfo = ForgettingCurveUtils::createHistoricalInfoToSave( originalProbabilityEntry.getHistoricalInfo(), mHeaderPolicy); const ProbabilityEntry probabilityEntry(originalProbabilityEntry.getFlags(), &historicalInfo); if (!mBuffers->getMutableLanguageModelDictContent()->setProbabilityEntry( toBeUpdatedPtNodeParams->getTerminalId(), &probabilityEntry)) { AKLOGE("Cannot write updated probability entry. terminalId: %d", toBeUpdatedPtNodeParams->getTerminalId()); return false; if (originalProbabilityEntry.isValid()) { *outNeedsToKeepPtNode = true; return true; } const bool isValid = ForgettingCurveUtils::needsToKeep(&historicalInfo, mHeaderPolicy); if (!isValid) { if (!markPtNodeAsWillBecomeNonTerminal(toBeUpdatedPtNodeParams)) { AKLOGE("Cannot mark PtNode as willBecomeNonTerminal."); return false; } } *outNeedsToKeepPtNode = isValid; } else { // No need to update probability. *outNeedsToKeepPtNode = true; } *outNeedsToKeepPtNode = false; return true; } Loading Loading @@ -380,6 +366,7 @@ bool Ver4PatriciaTrieNodeWriter::writePtNodeAndGetTerminalIdAndAdvancePosition( isTerminal, ptNodeParams->getCodePointCount() > 1 /* hasMultipleChars */); } // TODO: Move probability handling code to LanguageModelDictContent. const ProbabilityEntry Ver4PatriciaTrieNodeWriter::createUpdatedEntryFrom( const ProbabilityEntry *const originalProbabilityEntry, const ProbabilityEntry *const probabilityEntry) const { Loading native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp +6 −0 Original line number Diff line number Diff line Loading @@ -85,6 +85,12 @@ bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos, mBuffers, headerPolicy, &ptNodeReader, &ptNodeArrayReader, &bigramPolicy, &shortcutPolicy); int entryCountTable[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1]; if (!mBuffers->getMutableLanguageModelDictContent()->updateAllProbabilityEntries(headerPolicy, entryCountTable)) { AKLOGE("Failed to update probabilities in language model dict content."); return false; } DynamicPtReadingHelper readingHelper(&ptNodeReader, &ptNodeArrayReader); readingHelper.initWithPtNodeArrayPos(rootPtNodeArrayPos); DynamicPtGcEventListeners Loading native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.h +4 −0 Original line number Diff line number Diff line Loading @@ -84,6 +84,10 @@ class TrieMap { return mValue; } AK_FORCE_INLINE int getNextLevelBitmapEntryIndex() const { return mNextLevelBitmapEntryIndex; } private: const TrieMap *const mTrieMap; const int mKey; Loading Loading
native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp +44 −0 Original line number Diff line number Diff line Loading @@ -16,6 +16,8 @@ #include "suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h" #include "suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h" namespace latinime { bool LanguageModelDictContent::save(FILE *const file) const { Loading Loading @@ -118,4 +120,46 @@ int LanguageModelDictContent::getBitmapEntryIndex(const WordIdArrayView prevWord return bitmapEntryIndex; } bool LanguageModelDictContent::updateAllProbabilityEntriesInner(const int bitmapEntryIndex, const int level, const HeaderPolicy *const headerPolicy, int *const outEntryCounts) { for (const auto &entry : mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex)) { if (level > MAX_PREV_WORD_COUNT_FOR_N_GRAM) { AKLOGE("Invalid level. level: %d, MAX_PREV_WORD_COUNT_FOR_N_GRAM: %d.", level, MAX_PREV_WORD_COUNT_FOR_N_GRAM); return false; } const ProbabilityEntry probabilityEntry = ProbabilityEntry::decode(entry.value(), mHasHistoricalInfo); if (mHasHistoricalInfo && !probabilityEntry.representsBeginningOfSentence()) { const HistoricalInfo historicalInfo = ForgettingCurveUtils::createHistoricalInfoToSave( probabilityEntry.getHistoricalInfo(), headerPolicy); if (ForgettingCurveUtils::needsToKeep(&historicalInfo, headerPolicy)) { // Update the entry. const ProbabilityEntry updatedEntry(probabilityEntry.getFlags(), &historicalInfo); if (!mTrieMap.put(entry.key(), updatedEntry.encode(mHasHistoricalInfo), bitmapEntryIndex)) { return false; } } else { // Remove the entry. if (!mTrieMap.remove(entry.key(), bitmapEntryIndex)) { return false; } continue; } } if (!probabilityEntry.representsBeginningOfSentence()) { outEntryCounts[level] += 1; } if (!entry.hasNextLevelMap()) { continue; } if (!updateAllProbabilityEntriesInner(entry.getNextLevelBitmapEntryIndex(), level + 1, headerPolicy, outEntryCounts)) { return false; } } return true; } } // namespace latinime
native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h +10 −0 Original line number Diff line number Diff line Loading @@ -29,6 +29,8 @@ namespace latinime { class HeaderPolicy; /** * Class representing language model. * Loading Loading @@ -73,6 +75,12 @@ class LanguageModelDictContent { bool removeNgramProbabilityEntry(const WordIdArrayView prevWordIds, const int wordId); bool updateAllProbabilityEntries(const HeaderPolicy *const headerPolicy, int *const outEntryCounts) { return updateAllProbabilityEntriesInner(mTrieMap.getRootBitmapEntryIndex(), 0 /* level */, headerPolicy, outEntryCounts); } private: DISALLOW_COPY_AND_ASSIGN(LanguageModelDictContent); Loading @@ -84,6 +92,8 @@ class LanguageModelDictContent { int *const outNgramCount); int createAndGetBitmapEntryIndex(const WordIdArrayView prevWordIds); int getBitmapEntryIndex(const WordIdArrayView prevWordIds) const; bool updateAllProbabilityEntriesInner(const int bitmapEntryIndex, const int level, const HeaderPolicy *const headerPolicy, int *const outEntryCounts); }; } // namespace latinime #endif /* LATINIME_LANGUAGE_MODEL_DICT_CONTENT_H */
native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp +8 −21 Original line number Diff line number Diff line Loading @@ -161,29 +161,15 @@ bool Ver4PatriciaTrieNodeWriter::updatePtNodeProbabilityAndGetNeedsToKeepPtNodeA const ProbabilityEntry originalProbabilityEntry = mBuffers->getLanguageModelDictContent()->getProbabilityEntry( toBeUpdatedPtNodeParams->getTerminalId()); if (originalProbabilityEntry.hasHistoricalInfo()) { const HistoricalInfo historicalInfo = ForgettingCurveUtils::createHistoricalInfoToSave( originalProbabilityEntry.getHistoricalInfo(), mHeaderPolicy); const ProbabilityEntry probabilityEntry(originalProbabilityEntry.getFlags(), &historicalInfo); if (!mBuffers->getMutableLanguageModelDictContent()->setProbabilityEntry( toBeUpdatedPtNodeParams->getTerminalId(), &probabilityEntry)) { AKLOGE("Cannot write updated probability entry. terminalId: %d", toBeUpdatedPtNodeParams->getTerminalId()); return false; if (originalProbabilityEntry.isValid()) { *outNeedsToKeepPtNode = true; return true; } const bool isValid = ForgettingCurveUtils::needsToKeep(&historicalInfo, mHeaderPolicy); if (!isValid) { if (!markPtNodeAsWillBecomeNonTerminal(toBeUpdatedPtNodeParams)) { AKLOGE("Cannot mark PtNode as willBecomeNonTerminal."); return false; } } *outNeedsToKeepPtNode = isValid; } else { // No need to update probability. *outNeedsToKeepPtNode = true; } *outNeedsToKeepPtNode = false; return true; } Loading Loading @@ -380,6 +366,7 @@ bool Ver4PatriciaTrieNodeWriter::writePtNodeAndGetTerminalIdAndAdvancePosition( isTerminal, ptNodeParams->getCodePointCount() > 1 /* hasMultipleChars */); } // TODO: Move probability handling code to LanguageModelDictContent. const ProbabilityEntry Ver4PatriciaTrieNodeWriter::createUpdatedEntryFrom( const ProbabilityEntry *const originalProbabilityEntry, const ProbabilityEntry *const probabilityEntry) const { Loading
native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp +6 −0 Original line number Diff line number Diff line Loading @@ -85,6 +85,12 @@ bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos, mBuffers, headerPolicy, &ptNodeReader, &ptNodeArrayReader, &bigramPolicy, &shortcutPolicy); int entryCountTable[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1]; if (!mBuffers->getMutableLanguageModelDictContent()->updateAllProbabilityEntries(headerPolicy, entryCountTable)) { AKLOGE("Failed to update probabilities in language model dict content."); return false; } DynamicPtReadingHelper readingHelper(&ptNodeReader, &ptNodeArrayReader); readingHelper.initWithPtNodeArrayPos(rootPtNodeArrayPos); DynamicPtGcEventListeners Loading
native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.h +4 −0 Original line number Diff line number Diff line Loading @@ -84,6 +84,10 @@ class TrieMap { return mValue; } AK_FORCE_INLINE int getNextLevelBitmapEntryIndex() const { return mNextLevelBitmapEntryIndex; } private: const TrieMap *const mTrieMap; const int mKey; Loading