Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 9aa66991 authored by Keisuke Kuroyanagi's avatar Keisuke Kuroyanagi
Browse files

Update probabilities in language model dict content for GC.

Bug: 14425059
Change-Id: I354408afd8e5c1955ff0acea3d0243d628fe3843
parent cdc260b7
Loading
Loading
Loading
Loading
+44 −0
Original line number Diff line number Diff line
@@ -16,6 +16,8 @@

#include "suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h"

#include "suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h"

namespace latinime {

bool LanguageModelDictContent::save(FILE *const file) const {
@@ -118,4 +120,46 @@ int LanguageModelDictContent::getBitmapEntryIndex(const WordIdArrayView prevWord
    return bitmapEntryIndex;
}

/**
 * Processes every entry of the trie map level identified by bitmapEntryIndex and, recursively,
 * the levels below it, refreshing or dropping probability entries.
 *
 * For each entry:
 *  - When mHasHistoricalInfo is set and the entry does not represent beginning-of-sentence, its
 *    historical info is recomputed with ForgettingCurveUtils::createHistoricalInfoToSave(). If
 *    ForgettingCurveUtils::needsToKeep() accepts the result, the entry is rewritten with the new
 *    historical info; otherwise it is removed and is neither counted nor descended into.
 *  - Every remaining entry that is not beginning-of-sentence adds one to outEntryCounts[level].
 *  - Entries that have a next level map are processed recursively with level + 1.
 *
 * @param bitmapEntryIndex bitmap entry index of the trie map level to process.
 * @param level depth of that level, 0 for the root invocation. An entry found at a depth greater
 *        than MAX_PREV_WORD_COUNT_FOR_N_GRAM is reported as an error.
 * @param headerPolicy passed through to ForgettingCurveUtils.
 * @param outEntryCounts array indexed by level; must hold at least
 *        MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1 elements. It is reset to zero by the root (level 0)
 *        invocation and then incremented while the levels are processed.
 * @return false when an entry is found at an invalid level or when updating/removing an entry
 *         in mTrieMap fails; true otherwise.
 */
bool LanguageModelDictContent::updateAllProbabilityEntriesInner(const int bitmapEntryIndex,
        const int level, const HeaderPolicy *const headerPolicy, int *const outEntryCounts) {
    if (level == 0) {
        // The counters are only ever incremented below, so the root invocation has to give them
        // a defined starting value; otherwise a caller that hands in a plain uninitialized array
        // would make the increments read indeterminate values.
        for (int i = 0; i <= MAX_PREV_WORD_COUNT_FOR_N_GRAM; ++i) {
            outEntryCounts[i] = 0;
        }
    }
    for (const auto &entry : mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex)) {
        // Checked per entry: outEntryCounts[level] must stay within the table.
        if (level > MAX_PREV_WORD_COUNT_FOR_N_GRAM) {
            AKLOGE("Invalid level. level: %d, MAX_PREV_WORD_COUNT_FOR_N_GRAM: %d.",
                    level, MAX_PREV_WORD_COUNT_FOR_N_GRAM);
            return false;
        }
        const ProbabilityEntry probabilityEntry =
                ProbabilityEntry::decode(entry.value(), mHasHistoricalInfo);
        if (mHasHistoricalInfo && !probabilityEntry.representsBeginningOfSentence()) {
            const HistoricalInfo historicalInfo = ForgettingCurveUtils::createHistoricalInfoToSave(
                    probabilityEntry.getHistoricalInfo(), headerPolicy);
            if (ForgettingCurveUtils::needsToKeep(&historicalInfo, headerPolicy)) {
                // Update the entry.
                const ProbabilityEntry updatedEntry(probabilityEntry.getFlags(), &historicalInfo);
                if (!mTrieMap.put(entry.key(), updatedEntry.encode(mHasHistoricalInfo),
                        bitmapEntryIndex)) {
                    return false;
                }
            } else {
                // Remove the entry. A removed entry is not counted and its next level map is
                // not visited.
                if (!mTrieMap.remove(entry.key(), bitmapEntryIndex)) {
                    return false;
                }
                continue;
            }
        }
        if (!probabilityEntry.representsBeginningOfSentence()) {
            outEntryCounts[level] += 1;
        }
        if (!entry.hasNextLevelMap()) {
            continue;
        }
        if (!updateAllProbabilityEntriesInner(entry.getNextLevelBitmapEntryIndex(), level + 1,
                headerPolicy, outEntryCounts)) {
            return false;
        }
    }
    return true;
}

} // namespace latinime
+10 −0
Original line number Diff line number Diff line
@@ -29,6 +29,8 @@

namespace latinime {

class HeaderPolicy;

/**
 * Class representing language model.
 *
@@ -73,6 +75,12 @@ class LanguageModelDictContent {

    bool removeNgramProbabilityEntry(const WordIdArrayView prevWordIds, const int wordId);

    /**
     * Updates all probability entries stored in mTrieMap, starting from its root level.
     *
     * This is a thin entry point: it forwards to updateAllProbabilityEntriesInner() with the
     * root bitmap entry index of mTrieMap and level 0.
     *
     * @param headerPolicy forwarded unchanged to updateAllProbabilityEntriesInner().
     * @param outEntryCounts forwarded unchanged; the inner routine increments
     *        outEntryCounts[level] for entries it counts.
     *        NOTE(review): presumably needs MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1 elements,
     *        judging by the level check in the inner routine -- confirm.
     * @return the result of updateAllProbabilityEntriesInner().
     */
    bool updateAllProbabilityEntries(const HeaderPolicy *const headerPolicy,
            int *const outEntryCounts) {
        const int rootBitmapEntryIndex = mTrieMap.getRootBitmapEntryIndex();
        const int rootLevel = 0;
        return updateAllProbabilityEntriesInner(rootBitmapEntryIndex, rootLevel, headerPolicy,
                outEntryCounts);
    }

 private:
    DISALLOW_COPY_AND_ASSIGN(LanguageModelDictContent);

@@ -84,6 +92,8 @@ class LanguageModelDictContent {
            int *const outNgramCount);
    int createAndGetBitmapEntryIndex(const WordIdArrayView prevWordIds);
    int getBitmapEntryIndex(const WordIdArrayView prevWordIds) const;
    bool updateAllProbabilityEntriesInner(const int bitmapEntryIndex, const int level,
            const HeaderPolicy *const headerPolicy, int *const outEntryCounts);
};
} // namespace latinime
#endif /* LATINIME_LANGUAGE_MODEL_DICT_CONTENT_H */
+8 −21
Original line number Diff line number Diff line
@@ -161,29 +161,15 @@ bool Ver4PatriciaTrieNodeWriter::updatePtNodeProbabilityAndGetNeedsToKeepPtNodeA
    const ProbabilityEntry originalProbabilityEntry =
            mBuffers->getLanguageModelDictContent()->getProbabilityEntry(
                    toBeUpdatedPtNodeParams->getTerminalId());
    if (originalProbabilityEntry.hasHistoricalInfo()) {
        const HistoricalInfo historicalInfo = ForgettingCurveUtils::createHistoricalInfoToSave(
                originalProbabilityEntry.getHistoricalInfo(), mHeaderPolicy);
        const ProbabilityEntry probabilityEntry(originalProbabilityEntry.getFlags(),
                &historicalInfo);
        if (!mBuffers->getMutableLanguageModelDictContent()->setProbabilityEntry(
                toBeUpdatedPtNodeParams->getTerminalId(), &probabilityEntry)) {
            AKLOGE("Cannot write updated probability entry. terminalId: %d",
                    toBeUpdatedPtNodeParams->getTerminalId());
            return false;
    if (originalProbabilityEntry.isValid()) {
        *outNeedsToKeepPtNode = true;
        return true;
    }
        const bool isValid = ForgettingCurveUtils::needsToKeep(&historicalInfo, mHeaderPolicy);
        if (!isValid) {
    if (!markPtNodeAsWillBecomeNonTerminal(toBeUpdatedPtNodeParams)) {
        AKLOGE("Cannot mark PtNode as willBecomeNonTerminal.");
        return false;
    }
        }
        *outNeedsToKeepPtNode = isValid;
    } else {
        // No need to update probability.
        *outNeedsToKeepPtNode = true;
    }
    *outNeedsToKeepPtNode = false;
    return true;
}

@@ -380,6 +366,7 @@ bool Ver4PatriciaTrieNodeWriter::writePtNodeAndGetTerminalIdAndAdvancePosition(
            isTerminal, ptNodeParams->getCodePointCount() > 1 /* hasMultipleChars */);
}

// TODO: Move probability handling code to LanguageModelDictContent.
const ProbabilityEntry Ver4PatriciaTrieNodeWriter::createUpdatedEntryFrom(
        const ProbabilityEntry *const originalProbabilityEntry,
        const ProbabilityEntry *const probabilityEntry) const {
+6 −0
Original line number Diff line number Diff line
@@ -85,6 +85,12 @@ bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos,
            mBuffers, headerPolicy, &ptNodeReader, &ptNodeArrayReader, &bigramPolicy,
            &shortcutPolicy);

    int entryCountTable[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1];
    if (!mBuffers->getMutableLanguageModelDictContent()->updateAllProbabilityEntries(headerPolicy,
            entryCountTable)) {
        AKLOGE("Failed to update probabilities in language model dict content.");
        return false;
    }
    DynamicPtReadingHelper readingHelper(&ptNodeReader, &ptNodeArrayReader);
    readingHelper.initWithPtNodeArrayPos(rootPtNodeArrayPos);
    DynamicPtGcEventListeners
+4 −0
Original line number Diff line number Diff line
@@ -84,6 +84,10 @@ class TrieMap {
                return mValue;
            }

            // Returns the bitmap entry index stored in mNextLevelBitmapEntryIndex, i.e. the
            // index a caller passes on to process this entry's next level map.
            // NOTE(review): presumably only meaningful when hasNextLevelMap() is true -- confirm.
            AK_FORCE_INLINE int getNextLevelBitmapEntryIndex() const {
                return mNextLevelBitmapEntryIndex;
            }

         private:
            const TrieMap *const mTrieMap;
            const int mKey;