Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 47fc656c authored by Keisuke Kuroyanagi
Browse files

Use EntryCounters during GC.

Bug: 14425059
Change-Id: I61eb798686dc753fb6c0fe99a0719c1732198f30
parent e8750d97
Loading
Loading
Loading
Loading
+20 −23
Original line number Diff line number Diff line
@@ -23,8 +23,6 @@

namespace latinime {

const int LanguageModelDictContent::UNIGRAM_COUNT_INDEX_IN_ENTRY_COUNT_TABLE = 0;
const int LanguageModelDictContent::BIGRAM_COUNT_INDEX_IN_ENTRY_COUNT_TABLE = 1;
const int LanguageModelDictContent::DUMMY_PROBABILITY_FOR_VALID_WORDS = 1;

bool LanguageModelDictContent::save(FILE *const file) const {
@@ -33,10 +31,9 @@ bool LanguageModelDictContent::save(FILE *const file) const {

bool LanguageModelDictContent::runGC(
        const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap,
        const LanguageModelDictContent *const originalContent,
        int *const outNgramCount) {
        const LanguageModelDictContent *const originalContent) {
    return runGCInner(terminalIdMap, originalContent->mTrieMap.getEntriesInRootLevel(),
            0 /* nextLevelBitmapEntryIndex */, outNgramCount);
            0 /* nextLevelBitmapEntryIndex */);
}

const WordAttributes LanguageModelDictContent::getWordAttributes(const WordIdArrayView prevWordIds,
@@ -143,18 +140,23 @@ LanguageModelDictContent::EntryRange LanguageModelDictContent::getProbabilityEnt
    return EntryRange(mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex), mHasHistoricalInfo);
}

bool LanguageModelDictContent::truncateEntries(const int *const entryCounts,
        const int *const maxEntryCounts, const HeaderPolicy *const headerPolicy,
        int *const outEntryCounts) {
    for (int i = 0; i <= MAX_PREV_WORD_COUNT_FOR_N_GRAM; ++i) {
        if (entryCounts[i] <= maxEntryCounts[i]) {
            outEntryCounts[i] = entryCounts[i];
bool LanguageModelDictContent::truncateEntries(const EntryCounts &currentEntryCounts,
        const EntryCounts &maxEntryCounts, const HeaderPolicy *const headerPolicy,
        MutableEntryCounters *const outEntryCounters) {
    for (int prevWordCount = 0; prevWordCount <= MAX_PREV_WORD_COUNT_FOR_N_GRAM; ++prevWordCount) {
        const int totalWordCount = prevWordCount + 1;
        if (currentEntryCounts.getNgramCount(totalWordCount)
                <= maxEntryCounts.getNgramCount(totalWordCount)) {
            outEntryCounters->setNgramCount(totalWordCount,
                    currentEntryCounts.getNgramCount(totalWordCount));
            continue;
        }
        if (!turncateEntriesInSpecifiedLevel(headerPolicy, maxEntryCounts[i], i,
                &outEntryCounts[i])) {
        int entryCount = 0;
        if (!turncateEntriesInSpecifiedLevel(headerPolicy,
                maxEntryCounts.getNgramCount(totalWordCount), prevWordCount, &entryCount)) {
            return false;
        }
        outEntryCounters->setNgramCount(totalWordCount, entryCount);
    }
    return true;
}
@@ -208,8 +210,7 @@ const ProbabilityEntry LanguageModelDictContent::createUpdatedEntryFrom(

bool LanguageModelDictContent::runGCInner(
        const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap,
        const TrieMap::TrieMapRange trieMapRange,
        const int nextLevelBitmapEntryIndex, int *const outNgramCount) {
        const TrieMap::TrieMapRange trieMapRange, const int nextLevelBitmapEntryIndex) {
    for (auto &entry : trieMapRange) {
        const auto it = terminalIdMap->find(entry.key());
        if (it == terminalIdMap->end() || it->second == Ver4DictConstants::NOT_A_TERMINAL_ID) {
@@ -219,13 +220,9 @@ bool LanguageModelDictContent::runGCInner(
        if (!mTrieMap.put(it->second, entry.value(), nextLevelBitmapEntryIndex)) {
            return false;
        }
        if (outNgramCount) {
            *outNgramCount += 1;
        }
        if (entry.hasNextLevelMap()) {
            if (!runGCInner(terminalIdMap, entry.getEntriesInNextLevel(),
                    mTrieMap.getNextLevelBitmapEntryIndex(it->second, nextLevelBitmapEntryIndex),
                    outNgramCount)) {
                    mTrieMap.getNextLevelBitmapEntryIndex(it->second, nextLevelBitmapEntryIndex))) {
                return false;
            }
        }
@@ -268,7 +265,7 @@ int LanguageModelDictContent::getBitmapEntryIndex(const WordIdArrayView prevWord

bool LanguageModelDictContent::updateAllProbabilityEntriesForGCInner(const int bitmapEntryIndex,
        const int prevWordCount, const HeaderPolicy *const headerPolicy,
        int *const outEntryCounts) {
        MutableEntryCounters *const outEntryCounters) {
    for (const auto &entry : mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex)) {
        if (prevWordCount > MAX_PREV_WORD_COUNT_FOR_N_GRAM) {
            AKLOGE("Invalid prevWordCount. prevWordCount: %d, MAX_PREV_WORD_COUNT_FOR_N_GRAM: %d.",
@@ -305,13 +302,13 @@ bool LanguageModelDictContent::updateAllProbabilityEntriesForGCInner(const int b
            }
        }
        if (!probabilityEntry.representsBeginningOfSentence()) {
            outEntryCounts[prevWordCount] += 1;
            outEntryCounters->incrementNgramCount(prevWordCount + 1);
        }
        if (!entry.hasNextLevelMap()) {
            continue;
        }
        if (!updateAllProbabilityEntriesForGCInner(entry.getNextLevelBitmapEntryIndex(),
                prevWordCount + 1, headerPolicy, outEntryCounts)) {
                prevWordCount + 1, headerPolicy, outEntryCounters)) {
            return false;
        }
    }
+7 −15
Original line number Diff line number Diff line
@@ -41,9 +41,6 @@ class HeaderPolicy;
 */
class LanguageModelDictContent {
 public:
    static const int UNIGRAM_COUNT_INDEX_IN_ENTRY_COUNT_TABLE;
    static const int BIGRAM_COUNT_INDEX_IN_ENTRY_COUNT_TABLE;

    // Pair of word id and probability entry used for iteration.
    class WordIdAndProbabilityEntry {
     public:
@@ -127,8 +124,7 @@ class LanguageModelDictContent {
    bool save(FILE *const file) const;

    bool runGC(const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap,
            const LanguageModelDictContent *const originalContent,
            int *const outNgramCount);
            const LanguageModelDictContent *const originalContent);

    const WordAttributes getWordAttributes(const WordIdArrayView prevWordIds, const int wordId,
            const HeaderPolicy *const headerPolicy) const;
@@ -156,17 +152,14 @@ class LanguageModelDictContent {
    EntryRange getProbabilityEntries(const WordIdArrayView prevWordIds) const;

    bool updateAllProbabilityEntriesForGC(const HeaderPolicy *const headerPolicy,
            int *const outEntryCounts) {
        for (int i = 0; i <= MAX_PREV_WORD_COUNT_FOR_N_GRAM; ++i) {
            outEntryCounts[i] = 0;
        }
            MutableEntryCounters *const outEntryCounters) {
        return updateAllProbabilityEntriesForGCInner(mTrieMap.getRootBitmapEntryIndex(),
                0 /* prevWordCount */, headerPolicy, outEntryCounts);
                0 /* prevWordCount */, headerPolicy, outEntryCounters);
    }

    // currentEntryCounts should come from the counters filled in by
    // updateAllProbabilityEntriesForGC.
    bool truncateEntries(const int *const entryCounts, const int *const maxEntryCounts,
            const HeaderPolicy *const headerPolicy, int *const outEntryCounts);
    bool truncateEntries(const EntryCounts &currentEntryCounts, const EntryCounts &maxEntryCounts,
            const HeaderPolicy *const headerPolicy, MutableEntryCounters *const outEntryCounters);

    bool updateAllEntriesOnInputWord(const WordIdArrayView prevWordIds, const int wordId,
            const bool isValid, const HistoricalInfo historicalInfo,
@@ -206,12 +199,11 @@ class LanguageModelDictContent {
    const bool mHasHistoricalInfo;

    bool runGCInner(const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap,
            const TrieMap::TrieMapRange trieMapRange, const int nextLevelBitmapEntryIndex,
            int *const outNgramCount);
            const TrieMap::TrieMapRange trieMapRange, const int nextLevelBitmapEntryIndex);
    int createAndGetBitmapEntryIndex(const WordIdArrayView prevWordIds);
    int getBitmapEntryIndex(const WordIdArrayView prevWordIds) const;
    bool updateAllProbabilityEntriesForGCInner(const int bitmapEntryIndex, const int prevWordCount,
            const HeaderPolicy *const headerPolicy, int *const outEntryCounts);
            const HeaderPolicy *const headerPolicy, MutableEntryCounters *const outEntryCounters);
    bool turncateEntriesInSpecifiedLevel(const HeaderPolicy *const headerPolicy,
            const int maxEntryCount, const int targetLevel, int *const outEntryCount);
    bool getEntryInfo(const HeaderPolicy *const headerPolicy, const int targetLevel,
+12 −25
Original line number Diff line number Diff line
@@ -57,16 +57,14 @@ bool Ver4PatriciaTrieWritingHelper::writeToDictFileWithGC(const int rootPtNodeAr
    Ver4DictBuffers::Ver4DictBuffersPtr dictBuffers(
            Ver4DictBuffers::createVer4DictBuffers(headerPolicy,
                    Ver4DictConstants::MAX_DICTIONARY_SIZE));
    int unigramCount = 0;
    int bigramCount = 0;
    if (!runGC(rootPtNodeArrayPos, headerPolicy, dictBuffers.get(), &unigramCount, &bigramCount)) {
    MutableEntryCounters entryCounters;
    if (!runGC(rootPtNodeArrayPos, headerPolicy, dictBuffers.get(), &entryCounters)) {
        return false;
    }
    BufferWithExtendableBuffer headerBuffer(
            BufferWithExtendableBuffer::DEFAULT_MAX_ADDITIONAL_BUFFER_SIZE);
    if (!headerPolicy->fillInAndWriteHeaderToBuffer(true /* updatesLastDecayedTime */,
            EntryCounts(unigramCount, bigramCount, 0 /* trigramCount */),
            0 /* extendedRegionSize */, &headerBuffer)) {
            entryCounters.getEntryCounts(), 0 /* extendedRegionSize */, &headerBuffer)) {
        return false;
    }
    return dictBuffers->flushHeaderAndDictBuffers(dictDirPath, &headerBuffer);
@@ -74,7 +72,7 @@ bool Ver4PatriciaTrieWritingHelper::writeToDictFileWithGC(const int rootPtNodeAr

bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos,
        const HeaderPolicy *const headerPolicy, Ver4DictBuffers *const buffersToWrite,
        int *const outUnigramCount, int *const outBigramCount) {
        MutableEntryCounters *const outEntryCounters) {
    Ver4PatriciaTrieNodeReader ptNodeReader(mBuffers->getTrieBuffer());
    Ver4PtNodeArrayReader ptNodeArrayReader(mBuffers->getTrieBuffer());
    Ver4ShortcutListPolicy shortcutPolicy(mBuffers->getMutableShortcutDictContent(),
@@ -82,24 +80,17 @@ bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos,
    Ver4PatriciaTrieNodeWriter ptNodeWriter(mBuffers->getWritableTrieBuffer(),
            mBuffers, &ptNodeReader, &ptNodeArrayReader, &shortcutPolicy);

    int entryCountTable[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1];
    if (!mBuffers->getMutableLanguageModelDictContent()->updateAllProbabilityEntriesForGC(
            headerPolicy, entryCountTable)) {
            headerPolicy, outEntryCounters)) {
        AKLOGE("Failed to update probabilities in language model dict content.");
        return false;
    }
    if (headerPolicy->isDecayingDict()) {
        int maxEntryCountTable[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1];
        maxEntryCountTable[LanguageModelDictContent::UNIGRAM_COUNT_INDEX_IN_ENTRY_COUNT_TABLE] =
                headerPolicy->getMaxUnigramCount();
        maxEntryCountTable[LanguageModelDictContent::BIGRAM_COUNT_INDEX_IN_ENTRY_COUNT_TABLE] =
                headerPolicy->getMaxBigramCount();
        for (size_t i = 2; i < NELEMS(maxEntryCountTable); ++i) {
            // TODO: Have max n-gram count.
            maxEntryCountTable[i] = headerPolicy->getMaxBigramCount();
        }
        if (!mBuffers->getMutableLanguageModelDictContent()->truncateEntries(entryCountTable,
                maxEntryCountTable, headerPolicy, entryCountTable)) {
        const EntryCounts maxEntryCounts(headerPolicy->getMaxUnigramCount(),
                headerPolicy->getMaxBigramCount(), headerPolicy->getMaxTrigramCount());
        if (!mBuffers->getMutableLanguageModelDictContent()->truncateEntries(
                outEntryCounters->getEntryCounts(), maxEntryCounts, headerPolicy,
                outEntryCounters)) {
            AKLOGE("Failed to truncate entries in language model dict content.");
            return false;
        }
@@ -143,9 +134,9 @@ bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos,
            &terminalIdMap)) {
        return false;
    }
    // Run GC for probability dict content.
    // Run GC for language model dict content.
    if (!buffersToWrite->getMutableLanguageModelDictContent()->runGC(&terminalIdMap,
            mBuffers->getLanguageModelDictContent(), nullptr /* outNgramCount */)) {
            mBuffers->getLanguageModelDictContent())) {
        return false;
    }
    // Run GC for shortcut dict content.
@@ -168,10 +159,6 @@ bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos,
            &traversePolicyToUpdateAllPtNodeFlagsAndTerminalIds)) {
        return false;
    }
    *outUnigramCount =
            entryCountTable[LanguageModelDictContent::UNIGRAM_COUNT_INDEX_IN_ENTRY_COUNT_TABLE];
    *outBigramCount =
            entryCountTable[LanguageModelDictContent::BIGRAM_COUNT_INDEX_IN_ENTRY_COUNT_TABLE];
    return true;
}

+1 −2
Original line number Diff line number Diff line
@@ -67,8 +67,7 @@ class Ver4PatriciaTrieWritingHelper {
    };

    bool runGC(const int rootPtNodeArrayPos, const HeaderPolicy *const headerPolicy,
            Ver4DictBuffers *const buffersToWrite, int *const outUnigramCount,
            int *const outBigramCount);
            Ver4DictBuffers *const buffersToWrite, MutableEntryCounters *const outEntryCounters);

    Ver4DictBuffers *const mBuffers;
};
+14 −0
Original line number Diff line number Diff line
@@ -46,6 +46,13 @@ class EntryCounts final {
        return mEntryCounts[2];
    }

    // Returns the stored count for n-grams of order n, where n is 1-based
    // (n == 1 reads the first slot of mEntryCounts). Yields 0 when n is 0 or
    // larger than the number of slots.
    int getNgramCount(const size_t n) const {
        if (n >= 1 && n <= mEntryCounts.size()) {
            return mEntryCounts[n - 1];
        }
        return 0;
    }

 private:
    DISALLOW_ASSIGNMENT_OPERATOR(EntryCounts);

@@ -110,6 +117,13 @@ class MutableEntryCounters final {
        --mEntryCounters[n - 1];
    }

    // Overwrites the counter for n-grams of order n with count, where n is
    // 1-based (n == 1 writes the first slot of mEntryCounters). The call is a
    // no-op when n is 0 or larger than the number of slots.
    void setNgramCount(const size_t n, const int count) {
        if (n >= 1 && n <= mEntryCounters.size()) {
            mEntryCounters[n - 1] = count;
        }
    }

 private:
    DISALLOW_COPY_AND_ASSIGN(MutableEntryCounters);