Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 063f86d4 authored by Keisuke Kuroyanagi's avatar Keisuke Kuroyanagi
Browse files

Truncate entries in language model dict content.

Bug: 14425059

Change-Id: I023c1d5109a2c43fcea3bb11a0fd7198c82891ba
parent 9aa66991
Loading
Loading
Loading
Loading
+99 −0
Original line number Diff line number Diff line
@@ -16,6 +16,9 @@

#include "suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h"

#include <algorithm>
#include <cstring>

#include "suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h"

namespace latinime {
@@ -68,6 +71,19 @@ bool LanguageModelDictContent::removeNgramProbabilityEntry(const WordIdArrayView
    return mTrieMap.remove(wordId, bitmapEntryIndex);
}

bool LanguageModelDictContent::truncateEntries(const int *const entryCounts,
        const int *const maxEntryCounts, const HeaderPolicy *const headerPolicy) {
    for (int i = 0; i <= MAX_PREV_WORD_COUNT_FOR_N_GRAM; ++i) {
        if (entryCounts[i] <= maxEntryCounts[i]) {
            continue;
        }
        if (!turncateEntriesInSpecifiedLevel(headerPolicy, maxEntryCounts[i], i)) {
            return false;
        }
    }
    return true;
}

bool LanguageModelDictContent::runGCInner(
        const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap,
        const TrieMap::TrieMapRange trieMapRange,
@@ -162,4 +178,87 @@ bool LanguageModelDictContent::updateAllProbabilityEntriesInner(const int bitmap
    return true;
}

bool LanguageModelDictContent::turncateEntriesInSpecifiedLevel(
        const HeaderPolicy *const headerPolicy, const int maxEntryCount, const int targetLevel) {
    std::vector<int> prevWordIds;
    std::vector<EntryInfoToTurncate> entryInfoVector;
    if (!getEntryInfo(headerPolicy, targetLevel, mTrieMap.getRootBitmapEntryIndex(),
            &prevWordIds, &entryInfoVector)) {
        return false;
    }
    if (static_cast<int>(entryInfoVector.size()) <= maxEntryCount) {
        return true;
    }
    const int entryCountToRemove = static_cast<int>(entryInfoVector.size()) - maxEntryCount;
    std::partial_sort(entryInfoVector.begin(), entryInfoVector.begin() + entryCountToRemove,
            entryInfoVector.end(),
            EntryInfoToTurncate::Comparator());
    for (int i = 0; i < entryCountToRemove; ++i) {
        const EntryInfoToTurncate &entryInfo = entryInfoVector[i];
        if (!removeNgramProbabilityEntry(
                WordIdArrayView(entryInfo.mPrevWordIds, entryInfo.mEntryLevel), entryInfo.mKey)) {
            return false;
        }
    }
    return true;
}

bool LanguageModelDictContent::getEntryInfo(const HeaderPolicy *const headerPolicy,
        const int targetLevel, const int bitmapEntryIndex,  std::vector<int> *const prevWordIds,
        std::vector<EntryInfoToTurncate> *const outEntryInfo) const {
    const int currentLevel = prevWordIds->size();
    for (const auto &entry : mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex)) {
        if (currentLevel < targetLevel) {
            if (!entry.hasNextLevelMap()) {
                continue;
            }
            prevWordIds->push_back(entry.key());
            if (!getEntryInfo(headerPolicy, targetLevel, entry.getNextLevelBitmapEntryIndex(),
                    prevWordIds, outEntryInfo)) {
                return false;
            }
            prevWordIds->pop_back();
            continue;
        }
        const ProbabilityEntry probabilityEntry =
                ProbabilityEntry::decode(entry.value(), mHasHistoricalInfo);
        const int probability = (mHasHistoricalInfo) ?
                ForgettingCurveUtils::decodeProbability(probabilityEntry.getHistoricalInfo(),
                        headerPolicy) : probabilityEntry.getProbability();
        outEntryInfo->emplace_back(probability,
                probabilityEntry.getHistoricalInfo()->getTimeStamp(),
                entry.key(), targetLevel, prevWordIds->data());
    }
    return true;
}

bool LanguageModelDictContent::EntryInfoToTurncate::Comparator::operator()(
        const EntryInfoToTurncate &left, const EntryInfoToTurncate &right) const {
    if (left.mProbability != right.mProbability) {
        return left.mProbability < right.mProbability;
    }
    if (left.mTimestamp != right.mTimestamp) {
        return left.mTimestamp > right.mTimestamp;
    }
    if (left.mKey != right.mKey) {
        return left.mKey < right.mKey;
    }
    if (left.mEntryLevel != right.mEntryLevel) {
        return left.mEntryLevel > right.mEntryLevel;
    }
    for (int i = 0; i < left.mEntryLevel; ++i) {
        if (left.mPrevWordIds[i] != right.mPrevWordIds[i]) {
            return left.mPrevWordIds[i] < right.mPrevWordIds[i];
        }
    }
    // left and rigth represent the same entry.
    return false;
}

LanguageModelDictContent::EntryInfoToTurncate::EntryInfoToTurncate(const int probability,
        const int timestamp, const int key, const int entryLevel, const int *const prevWordIds)
        : mProbability(probability), mTimestamp(timestamp), mKey(key), mEntryLevel(entryLevel) {
    memmove(mPrevWordIds, prevWordIds, mEntryLevel * sizeof(mPrevWordIds[0]));
}

} // namespace latinime
+36 −0
Original line number Diff line number Diff line
@@ -18,6 +18,7 @@
#define LATINIME_LANGUAGE_MODEL_DICT_CONTENT_H

#include <cstdio>
#include <vector>

#include "defines.h"
#include "suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h"
@@ -77,13 +78,43 @@ class LanguageModelDictContent {

    bool updateAllProbabilityEntries(const HeaderPolicy *const headerPolicy,
            int *const outEntryCounts) {
        for (int i = 0; i <= MAX_PREV_WORD_COUNT_FOR_N_GRAM; ++i) {
            outEntryCounts[i] = 0;
        }
        return updateAllProbabilityEntriesInner(mTrieMap.getRootBitmapEntryIndex(), 0 /* level */,
                headerPolicy, outEntryCounts);
    }

    // entryCounts should be created by updateAllProbabilityEntries.
    bool truncateEntries(const int *const entryCounts, const int *const maxEntryCounts,
            const HeaderPolicy *const headerPolicy);

 private:
    DISALLOW_COPY_AND_ASSIGN(LanguageModelDictContent);

    class EntryInfoToTurncate {
     public:
        class Comparator {
         public:
            bool operator()(const EntryInfoToTurncate &left,
                    const EntryInfoToTurncate &right) const;
         private:
            DISALLOW_ASSIGNMENT_OPERATOR(Comparator);
        };

        EntryInfoToTurncate(const int probability, const int timestamp, const int key,
                const int entryLevel, const int *const prevWordIds);

        int mProbability;
        int mTimestamp;
        int mKey;
        int mEntryLevel;
        int mPrevWordIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1];

     private:
        DISALLOW_DEFAULT_CONSTRUCTOR(EntryInfoToTurncate);
    };

    TrieMap mTrieMap;
    const bool mHasHistoricalInfo;

@@ -94,6 +125,11 @@ class LanguageModelDictContent {
    int getBitmapEntryIndex(const WordIdArrayView prevWordIds) const;
    bool updateAllProbabilityEntriesInner(const int bitmapEntryIndex, const int level,
            const HeaderPolicy *const headerPolicy, int *const outEntryCounts);
    bool turncateEntriesInSpecifiedLevel(const HeaderPolicy *const headerPolicy,
            const int maxEntryCount, const int targetLevel);
    bool getEntryInfo(const HeaderPolicy *const headerPolicy, const int targetLevel,
            const int bitmapEntryIndex, std::vector<int> *const prevWordIds,
            std::vector<EntryInfoToTurncate> *const outEntryInfo) const;
};
} // namespace latinime
#endif /* LATINIME_LANGUAGE_MODEL_DICT_CONTENT_H */
+17 −0
Original line number Diff line number Diff line
@@ -91,6 +91,21 @@ bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos,
        AKLOGE("Failed to update probabilities in language model dict content.");
        return false;
    }
    if (headerPolicy->isDecayingDict()) {
        int maxEntryCountTable[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1];
        maxEntryCountTable[0] = headerPolicy->getMaxUnigramCount();
        maxEntryCountTable[1] = headerPolicy->getMaxBigramCount();
        for (size_t i = 2; i < NELEMS(maxEntryCountTable); ++i) {
            // TODO: Have max n-gram count.
            maxEntryCountTable[i] = headerPolicy->getMaxBigramCount();
        }
        if (!mBuffers->getMutableLanguageModelDictContent()->truncateEntries(entryCountTable,
                maxEntryCountTable,  headerPolicy)) {
            AKLOGE("Failed to truncate entries in language model dict content.");
            return false;
        }
    }

    DynamicPtReadingHelper readingHelper(&ptNodeReader, &ptNodeArrayReader);
    readingHelper.initWithPtNodeArrayPos(rootPtNodeArrayPos);
    DynamicPtGcEventListeners
@@ -193,6 +208,7 @@ bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos,
    return true;
}

// TODO: Remove.
bool Ver4PatriciaTrieWritingHelper::truncateUnigrams(
        const Ver4PatriciaTrieNodeReader *const ptNodeReader,
        Ver4PatriciaTrieNodeWriter *const ptNodeWriter, const int maxUnigramCount) {
@@ -233,6 +249,7 @@ bool Ver4PatriciaTrieWritingHelper::truncateUnigrams(
    return true;
}

// TODO: Remove.
bool Ver4PatriciaTrieWritingHelper::truncateBigrams(const int maxBigramCount) {
    const TerminalPositionLookupTable *const terminalPosLookupTable =
            mBuffers->getTerminalPositionLookupTable();