Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit fa1e65cb authored by Keisuke Kuroyanagi's avatar Keisuke Kuroyanagi Committed by Android (Google) Code Review
Browse files

Merge "Use EntryCounters during GC."

parents c51b9b5b 47fc656c
Loading
Loading
Loading
Loading
+20 −23
Original line number Diff line number Diff line
@@ -23,8 +23,6 @@

namespace latinime {

const int LanguageModelDictContent::UNIGRAM_COUNT_INDEX_IN_ENTRY_COUNT_TABLE = 0;
const int LanguageModelDictContent::BIGRAM_COUNT_INDEX_IN_ENTRY_COUNT_TABLE = 1;
const int LanguageModelDictContent::DUMMY_PROBABILITY_FOR_VALID_WORDS = 1;

bool LanguageModelDictContent::save(FILE *const file) const {
@@ -33,10 +31,9 @@ bool LanguageModelDictContent::save(FILE *const file) const {

// Entry point of GC for this content. Iterates the root-level entries of
// originalContent's trie map and hands them to runGCInner, starting at bitmap
// entry index 0. Returns whatever runGCInner returns.
// NOTE(review): this span comes from a rendered diff, so both the signature with
// outNgramCount and the one without it are listed below.
bool LanguageModelDictContent::runGC(
        const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap,
        const LanguageModelDictContent *const originalContent,
        int *const outNgramCount) {
        const LanguageModelDictContent *const originalContent) {
    return runGCInner(terminalIdMap, originalContent->mTrieMap.getEntriesInRootLevel(),
            0 /* nextLevelBitmapEntryIndex */, outNgramCount);
            0 /* nextLevelBitmapEntryIndex */);
}

const WordAttributes LanguageModelDictContent::getWordAttributes(const WordIdArrayView prevWordIds,
@@ -143,18 +140,23 @@ LanguageModelDictContent::EntryRange LanguageModelDictContent::getProbabilityEnt
    return EntryRange(mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex), mHasHistoricalInfo);
}

// Brings the entry count of every n-gram level down to its allowed maximum.
// For each prevWordCount in [0, MAX_PREV_WORD_COUNT_FOR_N_GRAM]:
//   - if the current count is within the maximum, it is copied to the output as is;
//   - otherwise turncateEntriesInSpecifiedLevel is called for that level and the
//     count it reports is stored in the output.
// Returns false as soon as truncation of a level fails, true otherwise.
// NOTE(review): this span comes from a rendered diff, so the array-based and the
// EntryCounts-based variants of the same statements are both listed.
bool LanguageModelDictContent::truncateEntries(const int *const entryCounts,
        const int *const maxEntryCounts, const HeaderPolicy *const headerPolicy,
        int *const outEntryCounts) {
    for (int i = 0; i <= MAX_PREV_WORD_COUNT_FOR_N_GRAM; ++i) {
        if (entryCounts[i] <= maxEntryCounts[i]) {
            outEntryCounts[i] = entryCounts[i];
bool LanguageModelDictContent::truncateEntries(const EntryCounts &currentEntryCounts,
        const EntryCounts &maxEntryCounts, const HeaderPolicy *const headerPolicy,
        MutableEntryCounters *const outEntryCounters) {
    for (int prevWordCount = 0; prevWordCount <= MAX_PREV_WORD_COUNT_FOR_N_GRAM; ++prevWordCount) {
        // The counter accessors are keyed by total word count (n), i.e. prevWordCount + 1.
        const int totalWordCount = prevWordCount + 1;
        if (currentEntryCounts.getNgramCount(totalWordCount)
                <= maxEntryCounts.getNgramCount(totalWordCount)) {
            // Already within the limit: carry the current count over unchanged.
            outEntryCounters->setNgramCount(totalWordCount,
                    currentEntryCounts.getNgramCount(totalWordCount));
            continue;
        }
        if (!turncateEntriesInSpecifiedLevel(headerPolicy, maxEntryCounts[i], i,
                &outEntryCounts[i])) {
        // Over the limit: truncate this level; the helper takes the level as
        // prevWordCount and writes the resulting count through its out-parameter.
        int entryCount = 0;
        if (!turncateEntriesInSpecifiedLevel(headerPolicy,
                maxEntryCounts.getNgramCount(totalWordCount), prevWordCount, &entryCount)) {
            return false;
        }
        outEntryCounters->setNgramCount(totalWordCount, entryCount);
    }
    return true;
}
@@ -208,8 +210,7 @@ const ProbabilityEntry LanguageModelDictContent::createUpdatedEntryFrom(

bool LanguageModelDictContent::runGCInner(
        const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap,
        const TrieMap::TrieMapRange trieMapRange,
        const int nextLevelBitmapEntryIndex, int *const outNgramCount) {
        const TrieMap::TrieMapRange trieMapRange, const int nextLevelBitmapEntryIndex) {
    for (auto &entry : trieMapRange) {
        const auto it = terminalIdMap->find(entry.key());
        if (it == terminalIdMap->end() || it->second == Ver4DictConstants::NOT_A_TERMINAL_ID) {
@@ -219,13 +220,9 @@ bool LanguageModelDictContent::runGCInner(
        if (!mTrieMap.put(it->second, entry.value(), nextLevelBitmapEntryIndex)) {
            return false;
        }
        if (outNgramCount) {
            *outNgramCount += 1;
        }
        if (entry.hasNextLevelMap()) {
            if (!runGCInner(terminalIdMap, entry.getEntriesInNextLevel(),
                    mTrieMap.getNextLevelBitmapEntryIndex(it->second, nextLevelBitmapEntryIndex),
                    outNgramCount)) {
                    mTrieMap.getNextLevelBitmapEntryIndex(it->second, nextLevelBitmapEntryIndex))) {
                return false;
            }
        }
@@ -268,7 +265,7 @@ int LanguageModelDictContent::getBitmapEntryIndex(const WordIdArrayView prevWord

bool LanguageModelDictContent::updateAllProbabilityEntriesForGCInner(const int bitmapEntryIndex,
        const int prevWordCount, const HeaderPolicy *const headerPolicy,
        int *const outEntryCounts) {
        MutableEntryCounters *const outEntryCounters) {
    for (const auto &entry : mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex)) {
        if (prevWordCount > MAX_PREV_WORD_COUNT_FOR_N_GRAM) {
            AKLOGE("Invalid prevWordCount. prevWordCount: %d, MAX_PREV_WORD_COUNT_FOR_N_GRAM: %d.",
@@ -305,13 +302,13 @@ bool LanguageModelDictContent::updateAllProbabilityEntriesForGCInner(const int b
            }
        }
        if (!probabilityEntry.representsBeginningOfSentence()) {
            outEntryCounts[prevWordCount] += 1;
            outEntryCounters->incrementNgramCount(prevWordCount + 1);
        }
        if (!entry.hasNextLevelMap()) {
            continue;
        }
        if (!updateAllProbabilityEntriesForGCInner(entry.getNextLevelBitmapEntryIndex(),
                prevWordCount + 1, headerPolicy, outEntryCounts)) {
                prevWordCount + 1, headerPolicy, outEntryCounters)) {
            return false;
        }
    }
+7 −15
Original line number Diff line number Diff line
@@ -41,9 +41,6 @@ class HeaderPolicy;
 */
class LanguageModelDictContent {
 public:
    static const int UNIGRAM_COUNT_INDEX_IN_ENTRY_COUNT_TABLE;
    static const int BIGRAM_COUNT_INDEX_IN_ENTRY_COUNT_TABLE;

    // Pair of word id and probability entry used for iteration.
    class WordIdAndProbabilityEntry {
     public:
@@ -127,8 +124,7 @@ class LanguageModelDictContent {
    bool save(FILE *const file) const;

    bool runGC(const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap,
            const LanguageModelDictContent *const originalContent,
            int *const outNgramCount);
            const LanguageModelDictContent *const originalContent);

    const WordAttributes getWordAttributes(const WordIdArrayView prevWordIds, const int wordId,
            const HeaderPolicy *const headerPolicy) const;
@@ -156,17 +152,14 @@ class LanguageModelDictContent {
    EntryRange getProbabilityEntries(const WordIdArrayView prevWordIds) const;

    // Starts the GC-time update of all probability entries. Delegates to
    // updateAllProbabilityEntriesForGCInner beginning at the root bitmap entry of
    // mTrieMap with prevWordCount 0, and returns its result.
    // NOTE(review): this span comes from a rendered diff; the int-array variant
    // (which zeroes every slot of outEntryCounts first) and the
    // MutableEntryCounters variant are both listed.
    bool updateAllProbabilityEntriesForGC(const HeaderPolicy *const headerPolicy,
            int *const outEntryCounts) {
        for (int i = 0; i <= MAX_PREV_WORD_COUNT_FOR_N_GRAM; ++i) {
            outEntryCounts[i] = 0;
        }
            MutableEntryCounters *const outEntryCounters) {
        return updateAllProbabilityEntriesForGCInner(mTrieMap.getRootBitmapEntryIndex(),
                0 /* prevWordCount */, headerPolicy, outEntryCounts);
                0 /* prevWordCount */, headerPolicy, outEntryCounters);
    }

    // entryCounts should be created by updateAllProbabilityEntries.
    bool truncateEntries(const int *const entryCounts, const int *const maxEntryCounts,
            const HeaderPolicy *const headerPolicy, int *const outEntryCounts);
    bool truncateEntries(const EntryCounts &currentEntryCounts, const EntryCounts &maxEntryCounts,
            const HeaderPolicy *const headerPolicy, MutableEntryCounters *const outEntryCounters);

    bool updateAllEntriesOnInputWord(const WordIdArrayView prevWordIds, const int wordId,
            const bool isValid, const HistoricalInfo historicalInfo,
@@ -206,12 +199,11 @@ class LanguageModelDictContent {
    const bool mHasHistoricalInfo;

    bool runGCInner(const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap,
            const TrieMap::TrieMapRange trieMapRange, const int nextLevelBitmapEntryIndex,
            int *const outNgramCount);
            const TrieMap::TrieMapRange trieMapRange, const int nextLevelBitmapEntryIndex);
    int createAndGetBitmapEntryIndex(const WordIdArrayView prevWordIds);
    int getBitmapEntryIndex(const WordIdArrayView prevWordIds) const;
    bool updateAllProbabilityEntriesForGCInner(const int bitmapEntryIndex, const int prevWordCount,
            const HeaderPolicy *const headerPolicy, int *const outEntryCounts);
            const HeaderPolicy *const headerPolicy, MutableEntryCounters *const outEntryCounters);
    bool turncateEntriesInSpecifiedLevel(const HeaderPolicy *const headerPolicy,
            const int maxEntryCount, const int targetLevel, int *const outEntryCount);
    bool getEntryInfo(const HeaderPolicy *const headerPolicy, const int targetLevel,
+12 −25
Original line number Diff line number Diff line
@@ -57,16 +57,14 @@ bool Ver4PatriciaTrieWritingHelper::writeToDictFileWithGC(const int rootPtNodeAr
    Ver4DictBuffers::Ver4DictBuffersPtr dictBuffers(
            Ver4DictBuffers::createVer4DictBuffers(headerPolicy,
                    Ver4DictConstants::MAX_DICTIONARY_SIZE));
    int unigramCount = 0;
    int bigramCount = 0;
    if (!runGC(rootPtNodeArrayPos, headerPolicy, dictBuffers.get(), &unigramCount, &bigramCount)) {
    MutableEntryCounters entryCounters;
    if (!runGC(rootPtNodeArrayPos, headerPolicy, dictBuffers.get(), &entryCounters)) {
        return false;
    }
    BufferWithExtendableBuffer headerBuffer(
            BufferWithExtendableBuffer::DEFAULT_MAX_ADDITIONAL_BUFFER_SIZE);
    if (!headerPolicy->fillInAndWriteHeaderToBuffer(true /* updatesLastDecayedTime */,
            EntryCounts(unigramCount, bigramCount, 0 /* trigramCount */),
            0 /* extendedRegionSize */, &headerBuffer)) {
            entryCounters.getEntryCounts(), 0 /* extendedRegionSize */, &headerBuffer)) {
        return false;
    }
    return dictBuffers->flushHeaderAndDictBuffers(dictDirPath, &headerBuffer);
@@ -74,7 +72,7 @@ bool Ver4PatriciaTrieWritingHelper::writeToDictFileWithGC(const int rootPtNodeAr

bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos,
        const HeaderPolicy *const headerPolicy, Ver4DictBuffers *const buffersToWrite,
        int *const outUnigramCount, int *const outBigramCount) {
        MutableEntryCounters *const outEntryCounters) {
    Ver4PatriciaTrieNodeReader ptNodeReader(mBuffers->getTrieBuffer());
    Ver4PtNodeArrayReader ptNodeArrayReader(mBuffers->getTrieBuffer());
    Ver4ShortcutListPolicy shortcutPolicy(mBuffers->getMutableShortcutDictContent(),
@@ -82,24 +80,17 @@ bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos,
    Ver4PatriciaTrieNodeWriter ptNodeWriter(mBuffers->getWritableTrieBuffer(),
            mBuffers, &ptNodeReader, &ptNodeArrayReader, &shortcutPolicy);

    int entryCountTable[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1];
    if (!mBuffers->getMutableLanguageModelDictContent()->updateAllProbabilityEntriesForGC(
            headerPolicy, entryCountTable)) {
            headerPolicy, outEntryCounters)) {
        AKLOGE("Failed to update probabilities in language model dict content.");
        return false;
    }
    if (headerPolicy->isDecayingDict()) {
        int maxEntryCountTable[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1];
        maxEntryCountTable[LanguageModelDictContent::UNIGRAM_COUNT_INDEX_IN_ENTRY_COUNT_TABLE] =
                headerPolicy->getMaxUnigramCount();
        maxEntryCountTable[LanguageModelDictContent::BIGRAM_COUNT_INDEX_IN_ENTRY_COUNT_TABLE] =
                headerPolicy->getMaxBigramCount();
        for (size_t i = 2; i < NELEMS(maxEntryCountTable); ++i) {
            // TODO: Have max n-gram count.
            maxEntryCountTable[i] = headerPolicy->getMaxBigramCount();
        }
        if (!mBuffers->getMutableLanguageModelDictContent()->truncateEntries(entryCountTable,
                maxEntryCountTable, headerPolicy, entryCountTable)) {
        const EntryCounts maxEntryCounts(headerPolicy->getMaxUnigramCount(),
                headerPolicy->getMaxBigramCount(), headerPolicy->getMaxTrigramCount());
        if (!mBuffers->getMutableLanguageModelDictContent()->truncateEntries(
                outEntryCounters->getEntryCounts(), maxEntryCounts, headerPolicy,
                outEntryCounters)) {
            AKLOGE("Failed to truncate entries in language model dict content.");
            return false;
        }
@@ -143,9 +134,9 @@ bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos,
            &terminalIdMap)) {
        return false;
    }
    // Run GC for probability dict content.
    // Run GC for language model dict content.
    if (!buffersToWrite->getMutableLanguageModelDictContent()->runGC(&terminalIdMap,
            mBuffers->getLanguageModelDictContent(), nullptr /* outNgramCount */)) {
            mBuffers->getLanguageModelDictContent())) {
        return false;
    }
    // Run GC for shortcut dict content.
@@ -168,10 +159,6 @@ bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos,
            &traversePolicyToUpdateAllPtNodeFlagsAndTerminalIds)) {
        return false;
    }
    *outUnigramCount =
            entryCountTable[LanguageModelDictContent::UNIGRAM_COUNT_INDEX_IN_ENTRY_COUNT_TABLE];
    *outBigramCount =
            entryCountTable[LanguageModelDictContent::BIGRAM_COUNT_INDEX_IN_ENTRY_COUNT_TABLE];
    return true;
}

+1 −2
Original line number Diff line number Diff line
@@ -67,8 +67,7 @@ class Ver4PatriciaTrieWritingHelper {
    };

    bool runGC(const int rootPtNodeArrayPos, const HeaderPolicy *const headerPolicy,
            Ver4DictBuffers *const buffersToWrite, int *const outUnigramCount,
            int *const outBigramCount);
            Ver4DictBuffers *const buffersToWrite, MutableEntryCounters *const outEntryCounters);

    Ver4DictBuffers *const mBuffers;
};
+14 −0
Original line number Diff line number Diff line
@@ -46,6 +46,13 @@ class EntryCounts final {
        return mEntryCounts[2];
    }

    // Returns the stored count for n-grams of length n. n is 1-based and maps to
    // mEntryCounts[n - 1]; any n outside [1, mEntryCounts.size()] yields 0.
    int getNgramCount(const size_t n) const {
        const bool isValidN = (n >= 1) && (n <= mEntryCounts.size());
        if (isValidN) {
            return mEntryCounts[n - 1];
        }
        return 0;
    }

 private:
    DISALLOW_ASSIGNMENT_OPERATOR(EntryCounts);

@@ -110,6 +117,13 @@ class MutableEntryCounters final {
        --mEntryCounters[n - 1];
    }

    // Overwrites the counter for n-grams of length n with count. n is 1-based and
    // maps to mEntryCounters[n - 1]; calls with n outside
    // [1, mEntryCounters.size()] are silently ignored.
    void setNgramCount(const size_t n, const int count) {
        const bool isValidN = (n >= 1) && (n <= mEntryCounters.size());
        if (isValidN) {
            mEntryCounters[n - 1] = count;
        }
    }

 private:
    DISALLOW_COPY_AND_ASSIGN(MutableEntryCounters);