Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 094a8a68 authored by Keisuke Kuroyanagi, committed by Android (Google) Code Review
Browse files

Merge "Update probabilities in language model dict content for GC."

parents 19a7012d 9aa66991
Loading
Loading
Loading
Loading
+44 −0
Original line number Diff line number Diff line
@@ -16,6 +16,8 @@

#include "suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h"

#include "suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h"

namespace latinime {

bool LanguageModelDictContent::save(FILE *const file) const {
@@ -118,4 +120,46 @@ int LanguageModelDictContent::getBitmapEntryIndex(const WordIdArrayView prevWord
    return bitmapEntryIndex;
}

bool LanguageModelDictContent::updateAllProbabilityEntriesInner(const int bitmapEntryIndex,
        const int level, const HeaderPolicy *const headerPolicy, int *const outEntryCounts) {
    for (const auto &entry : mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex)) {
        if (level > MAX_PREV_WORD_COUNT_FOR_N_GRAM) {
            AKLOGE("Invalid level. level: %d, MAX_PREV_WORD_COUNT_FOR_N_GRAM: %d.",
                    level, MAX_PREV_WORD_COUNT_FOR_N_GRAM);
            return false;
        }
        const ProbabilityEntry probabilityEntry =
                ProbabilityEntry::decode(entry.value(), mHasHistoricalInfo);
        if (mHasHistoricalInfo && !probabilityEntry.representsBeginningOfSentence()) {
            const HistoricalInfo historicalInfo = ForgettingCurveUtils::createHistoricalInfoToSave(
                    probabilityEntry.getHistoricalInfo(), headerPolicy);
            if (ForgettingCurveUtils::needsToKeep(&historicalInfo, headerPolicy)) {
                // Update the entry.
                const ProbabilityEntry updatedEntry(probabilityEntry.getFlags(), &historicalInfo);
                if (!mTrieMap.put(entry.key(), updatedEntry.encode(mHasHistoricalInfo),
                        bitmapEntryIndex)) {
                    return false;
                }
            } else {
                // Remove the entry.
                if (!mTrieMap.remove(entry.key(), bitmapEntryIndex)) {
                    return false;
                }
                continue;
            }
        }
        if (!probabilityEntry.representsBeginningOfSentence()) {
            outEntryCounts[level] += 1;
        }
        if (!entry.hasNextLevelMap()) {
            continue;
        }
        if (!updateAllProbabilityEntriesInner(entry.getNextLevelBitmapEntryIndex(), level + 1,
                headerPolicy, outEntryCounts)) {
            return false;
        }
    }
    return true;
}

} // namespace latinime
+10 −0
Original line number Diff line number Diff line
@@ -29,6 +29,8 @@

namespace latinime {

class HeaderPolicy;

/**
 * Class representing language model.
 *
@@ -73,6 +75,12 @@ class LanguageModelDictContent {

    bool removeNgramProbabilityEntry(const WordIdArrayView prevWordIds, const int wordId);

    /**
     * Refreshes all probability entries, starting at the root level (level 0) of the trie map.
     *
     * @param headerPolicy policy forwarded to updateAllProbabilityEntriesInner().
     * @param outEntryCounts per-level counters forwarded to updateAllProbabilityEntriesInner().
     *        They are incremented rather than reset there, so pass a zero-initialized array to
     *        get absolute counts.
     * @return the result of updateAllProbabilityEntriesInner().
     */
    bool updateAllProbabilityEntries(const HeaderPolicy *const headerPolicy,
            int *const outEntryCounts) {
        const int rootBitmapEntryIndex = mTrieMap.getRootBitmapEntryIndex();
        return updateAllProbabilityEntriesInner(rootBitmapEntryIndex, 0 /* level */,
                headerPolicy, outEntryCounts);
    }

 private:
    DISALLOW_COPY_AND_ASSIGN(LanguageModelDictContent);

@@ -84,6 +92,8 @@ class LanguageModelDictContent {
            int *const outNgramCount);
    int createAndGetBitmapEntryIndex(const WordIdArrayView prevWordIds);
    int getBitmapEntryIndex(const WordIdArrayView prevWordIds) const;
    bool updateAllProbabilityEntriesInner(const int bitmapEntryIndex, const int level,
            const HeaderPolicy *const headerPolicy, int *const outEntryCounts);
};
} // namespace latinime
#endif /* LATINIME_LANGUAGE_MODEL_DICT_CONTENT_H */
+8 −21
Original line number Diff line number Diff line
@@ -161,29 +161,15 @@ bool Ver4PatriciaTrieNodeWriter::updatePtNodeProbabilityAndGetNeedsToKeepPtNodeA
    const ProbabilityEntry originalProbabilityEntry =
            mBuffers->getLanguageModelDictContent()->getProbabilityEntry(
                    toBeUpdatedPtNodeParams->getTerminalId());
    if (originalProbabilityEntry.hasHistoricalInfo()) {
        const HistoricalInfo historicalInfo = ForgettingCurveUtils::createHistoricalInfoToSave(
                originalProbabilityEntry.getHistoricalInfo(), mHeaderPolicy);
        const ProbabilityEntry probabilityEntry(originalProbabilityEntry.getFlags(),
                &historicalInfo);
        if (!mBuffers->getMutableLanguageModelDictContent()->setProbabilityEntry(
                toBeUpdatedPtNodeParams->getTerminalId(), &probabilityEntry)) {
            AKLOGE("Cannot write updated probability entry. terminalId: %d",
                    toBeUpdatedPtNodeParams->getTerminalId());
            return false;
    if (originalProbabilityEntry.isValid()) {
        *outNeedsToKeepPtNode = true;
        return true;
    }
        const bool isValid = ForgettingCurveUtils::needsToKeep(&historicalInfo, mHeaderPolicy);
        if (!isValid) {
    if (!markPtNodeAsWillBecomeNonTerminal(toBeUpdatedPtNodeParams)) {
        AKLOGE("Cannot mark PtNode as willBecomeNonTerminal.");
        return false;
    }
        }
        *outNeedsToKeepPtNode = isValid;
    } else {
        // No need to update probability.
        *outNeedsToKeepPtNode = true;
    }
    *outNeedsToKeepPtNode = false;
    return true;
}

@@ -380,6 +366,7 @@ bool Ver4PatriciaTrieNodeWriter::writePtNodeAndGetTerminalIdAndAdvancePosition(
            isTerminal, ptNodeParams->getCodePointCount() > 1 /* hasMultipleChars */);
}

// TODO: Move probability handling code to LanguageModelDictContent.
const ProbabilityEntry Ver4PatriciaTrieNodeWriter::createUpdatedEntryFrom(
        const ProbabilityEntry *const originalProbabilityEntry,
        const ProbabilityEntry *const probabilityEntry) const {
+6 −0
Original line number Diff line number Diff line
@@ -85,6 +85,12 @@ bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos,
            mBuffers, headerPolicy, &ptNodeReader, &ptNodeArrayReader, &bigramPolicy,
            &shortcutPolicy);

    int entryCountTable[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1];
    if (!mBuffers->getMutableLanguageModelDictContent()->updateAllProbabilityEntries(headerPolicy,
            entryCountTable)) {
        AKLOGE("Failed to update probabilities in language model dict content.");
        return false;
    }
    DynamicPtReadingHelper readingHelper(&ptNodeReader, &ptNodeArrayReader);
    readingHelper.initWithPtNodeArrayPos(rootPtNodeArrayPos);
    DynamicPtGcEventListeners
+4 −0
Original line number Diff line number Diff line
@@ -84,6 +84,10 @@ class TrieMap {
                return mValue;
            }

            // Returns the bitmap entry index stored for this entry's next level map.
            // NOTE(review): presumably only meaningful when hasNextLevelMap() is true -- confirm.
            AK_FORCE_INLINE int getNextLevelBitmapEntryIndex() const {
                return this->mNextLevelBitmapEntryIndex;
            }

         private:
            const TrieMap *const mTrieMap;
            const int mKey;