Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 758d0936 authored by Keisuke Kuroyanagi
Browse files

Get entry count after truncation using LanguageModelDictContent.

Bug: 14425059
Change-Id: I41b237c1c22c21740946d52e3be9d6f963c9cd54
parent c7f1de82
Loading
Loading
Loading
Loading
+12 −3
Original line number Diff line number Diff line
@@ -23,6 +23,9 @@

namespace latinime {

// Slot of an entry count table that holds the unigram count.
const int LanguageModelDictContent::UNIGRAM_COUNT_INDEX_IN_ENTRY_COUNT_TABLE = 0;
// Slot of an entry count table that holds the bigram count.
const int LanguageModelDictContent::BIGRAM_COUNT_INDEX_IN_ENTRY_COUNT_TABLE = 1;

// Saves this content to |file| by delegating to the underlying trie map and
// reports the trie map's result back to the caller.
bool LanguageModelDictContent::save(FILE *const file) const {
    const bool saved = mTrieMap.save(file);
    return saved;
}
@@ -78,12 +81,15 @@ LanguageModelDictContent::EntryRange LanguageModelDictContent::getProbabilityEnt
}

bool LanguageModelDictContent::truncateEntries(const int *const entryCounts,
        const int *const maxEntryCounts, const HeaderPolicy *const headerPolicy) {
        const int *const maxEntryCounts, const HeaderPolicy *const headerPolicy,
        int *const outEntryCounts) {
    for (int i = 0; i <= MAX_PREV_WORD_COUNT_FOR_N_GRAM; ++i) {
        if (entryCounts[i] <= maxEntryCounts[i]) {
            outEntryCounts[i] = entryCounts[i];
            continue;
        }
        if (!turncateEntriesInSpecifiedLevel(headerPolicy, maxEntryCounts[i], i)) {
        if (!turncateEntriesInSpecifiedLevel(headerPolicy, maxEntryCounts[i], i,
                &outEntryCounts[i])) {
            return false;
        }
    }
@@ -185,7 +191,8 @@ bool LanguageModelDictContent::updateAllProbabilityEntriesInner(const int bitmap
}

bool LanguageModelDictContent::turncateEntriesInSpecifiedLevel(
        const HeaderPolicy *const headerPolicy, const int maxEntryCount, const int targetLevel) {
        const HeaderPolicy *const headerPolicy, const int maxEntryCount, const int targetLevel,
        int *const outEntryCount) {
    std::vector<int> prevWordIds;
    std::vector<EntryInfoToTurncate> entryInfoVector;
    if (!getEntryInfo(headerPolicy, targetLevel, mTrieMap.getRootBitmapEntryIndex(),
@@ -193,8 +200,10 @@ bool LanguageModelDictContent::turncateEntriesInSpecifiedLevel(
        return false;
    }
    if (static_cast<int>(entryInfoVector.size()) <= maxEntryCount) {
        *outEntryCount = static_cast<int>(entryInfoVector.size());
        return true;
    }
    *outEntryCount = maxEntryCount;
    const int entryCountToRemove = static_cast<int>(entryInfoVector.size()) - maxEntryCount;
    std::partial_sort(entryInfoVector.begin(), entryInfoVector.begin() + entryCountToRemove,
            entryInfoVector.end(),
+5 −2
Original line number Diff line number Diff line
@@ -39,6 +39,9 @@ class HeaderPolicy;
 */
class LanguageModelDictContent {
 public:
    static const int UNIGRAM_COUNT_INDEX_IN_ENTRY_COUNT_TABLE;
    static const int BIGRAM_COUNT_INDEX_IN_ENTRY_COUNT_TABLE;

    // Pair of word id and probability entry used for iteration.
    class WordIdAndProbabilityEntry {
     public:
@@ -158,7 +161,7 @@ class LanguageModelDictContent {

    // entryCounts should be created by updateAllProbabilityEntries.
    bool truncateEntries(const int *const entryCounts, const int *const maxEntryCounts,
            const HeaderPolicy *const headerPolicy);
            const HeaderPolicy *const headerPolicy, int *const outEntryCounts);

 private:
    DISALLOW_COPY_AND_ASSIGN(LanguageModelDictContent);
@@ -197,7 +200,7 @@ class LanguageModelDictContent {
    bool updateAllProbabilityEntriesInner(const int bitmapEntryIndex, const int level,
            const HeaderPolicy *const headerPolicy, int *const outEntryCounts);
    bool turncateEntriesInSpecifiedLevel(const HeaderPolicy *const headerPolicy,
            const int maxEntryCount, const int targetLevel);
            const int maxEntryCount, const int targetLevel, int *const outEntryCount);
    bool getEntryInfo(const HeaderPolicy *const headerPolicy, const int targetLevel,
            const int bitmapEntryIndex, std::vector<int> *const prevWordIds,
            std::vector<EntryInfoToTurncate> *const outEntryInfo) const;
+9 −4
Original line number Diff line number Diff line
@@ -93,14 +93,16 @@ bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos,
    }
    if (headerPolicy->isDecayingDict()) {
        int maxEntryCountTable[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1];
        maxEntryCountTable[0] = headerPolicy->getMaxUnigramCount();
        maxEntryCountTable[1] = headerPolicy->getMaxBigramCount();
        maxEntryCountTable[LanguageModelDictContent::UNIGRAM_COUNT_INDEX_IN_ENTRY_COUNT_TABLE] =
                headerPolicy->getMaxUnigramCount();
        maxEntryCountTable[LanguageModelDictContent::BIGRAM_COUNT_INDEX_IN_ENTRY_COUNT_TABLE] =
                headerPolicy->getMaxBigramCount();
        for (size_t i = 2; i < NELEMS(maxEntryCountTable); ++i) {
            // TODO: Have max n-gram count.
            maxEntryCountTable[i] = headerPolicy->getMaxBigramCount();
        }
        if (!mBuffers->getMutableLanguageModelDictContent()->truncateEntries(entryCountTable,
                maxEntryCountTable,  headerPolicy)) {
                maxEntryCountTable, headerPolicy, entryCountTable)) {
            AKLOGE("Failed to truncate entries in language model dict content.");
            return false;
        }
@@ -204,7 +206,10 @@ bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos,
            &traversePolicyToUpdateAllPtNodeFlagsAndTerminalIds)) {
        return false;
    }
    *outUnigramCount = traversePolicyToUpdateAllPositionFields.getUnigramCount();
    *outUnigramCount =
            entryCountTable[LanguageModelDictContent::UNIGRAM_COUNT_INDEX_IN_ENTRY_COUNT_TABLE];
    *outBigramCount =
            entryCountTable[LanguageModelDictContent::BIGRAM_COUNT_INDEX_IN_ENTRY_COUNT_TABLE];
    return true;
}