Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 39467767 authored by Keisuke Kuroyanagi's avatar Keisuke Kuroyanagi Committed by Android (Google) Code Review
Browse files

Merge "Truncate entries in language model dict content."

parents 094a8a68 063f86d4
Loading
Loading
Loading
Loading
+99 −0
Original line number Diff line number Diff line
@@ -16,6 +16,9 @@

#include "suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h"

#include <algorithm>
#include <cstring>

#include "suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h"

namespace latinime {
@@ -68,6 +71,19 @@ bool LanguageModelDictContent::removeNgramProbabilityEntry(const WordIdArrayView
    return mTrieMap.remove(wordId, bitmapEntryIndex);
}

bool LanguageModelDictContent::truncateEntries(const int *const entryCounts,
        const int *const maxEntryCounts, const HeaderPolicy *const headerPolicy) {
    for (int i = 0; i <= MAX_PREV_WORD_COUNT_FOR_N_GRAM; ++i) {
        if (entryCounts[i] <= maxEntryCounts[i]) {
            continue;
        }
        if (!turncateEntriesInSpecifiedLevel(headerPolicy, maxEntryCounts[i], i)) {
            return false;
        }
    }
    return true;
}

bool LanguageModelDictContent::runGCInner(
        const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap,
        const TrieMap::TrieMapRange trieMapRange,
@@ -162,4 +178,87 @@ bool LanguageModelDictContent::updateAllProbabilityEntriesInner(const int bitmap
    return true;
}

bool LanguageModelDictContent::turncateEntriesInSpecifiedLevel(
        const HeaderPolicy *const headerPolicy, const int maxEntryCount, const int targetLevel) {
    std::vector<int> prevWordIds;
    std::vector<EntryInfoToTurncate> entryInfoVector;
    if (!getEntryInfo(headerPolicy, targetLevel, mTrieMap.getRootBitmapEntryIndex(),
            &prevWordIds, &entryInfoVector)) {
        return false;
    }
    if (static_cast<int>(entryInfoVector.size()) <= maxEntryCount) {
        return true;
    }
    const int entryCountToRemove = static_cast<int>(entryInfoVector.size()) - maxEntryCount;
    std::partial_sort(entryInfoVector.begin(), entryInfoVector.begin() + entryCountToRemove,
            entryInfoVector.end(),
            EntryInfoToTurncate::Comparator());
    for (int i = 0; i < entryCountToRemove; ++i) {
        const EntryInfoToTurncate &entryInfo = entryInfoVector[i];
        if (!removeNgramProbabilityEntry(
                WordIdArrayView(entryInfo.mPrevWordIds, entryInfo.mEntryLevel), entryInfo.mKey)) {
            return false;
        }
    }
    return true;
}

bool LanguageModelDictContent::getEntryInfo(const HeaderPolicy *const headerPolicy,
        const int targetLevel, const int bitmapEntryIndex,  std::vector<int> *const prevWordIds,
        std::vector<EntryInfoToTurncate> *const outEntryInfo) const {
    const int currentLevel = prevWordIds->size();
    for (const auto &entry : mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex)) {
        if (currentLevel < targetLevel) {
            if (!entry.hasNextLevelMap()) {
                continue;
            }
            prevWordIds->push_back(entry.key());
            if (!getEntryInfo(headerPolicy, targetLevel, entry.getNextLevelBitmapEntryIndex(),
                    prevWordIds, outEntryInfo)) {
                return false;
            }
            prevWordIds->pop_back();
            continue;
        }
        const ProbabilityEntry probabilityEntry =
                ProbabilityEntry::decode(entry.value(), mHasHistoricalInfo);
        const int probability = (mHasHistoricalInfo) ?
                ForgettingCurveUtils::decodeProbability(probabilityEntry.getHistoricalInfo(),
                        headerPolicy) : probabilityEntry.getProbability();
        outEntryInfo->emplace_back(probability,
                probabilityEntry.getHistoricalInfo()->getTimeStamp(),
                entry.key(), targetLevel, prevWordIds->data());
    }
    return true;
}

bool LanguageModelDictContent::EntryInfoToTurncate::Comparator::operator()(
        const EntryInfoToTurncate &left, const EntryInfoToTurncate &right) const {
    if (left.mProbability != right.mProbability) {
        return left.mProbability < right.mProbability;
    }
    if (left.mTimestamp != right.mTimestamp) {
        return left.mTimestamp > right.mTimestamp;
    }
    if (left.mKey != right.mKey) {
        return left.mKey < right.mKey;
    }
    if (left.mEntryLevel != right.mEntryLevel) {
        return left.mEntryLevel > right.mEntryLevel;
    }
    for (int i = 0; i < left.mEntryLevel; ++i) {
        if (left.mPrevWordIds[i] != right.mPrevWordIds[i]) {
            return left.mPrevWordIds[i] < right.mPrevWordIds[i];
        }
    }
    // left and rigth represent the same entry.
    return false;
}

LanguageModelDictContent::EntryInfoToTurncate::EntryInfoToTurncate(const int probability,
        const int timestamp, const int key, const int entryLevel, const int *const prevWordIds)
        : mProbability(probability), mTimestamp(timestamp), mKey(key), mEntryLevel(entryLevel) {
    memmove(mPrevWordIds, prevWordIds, mEntryLevel * sizeof(mPrevWordIds[0]));
}

} // namespace latinime
+36 −0
Original line number Diff line number Diff line
@@ -18,6 +18,7 @@
#define LATINIME_LANGUAGE_MODEL_DICT_CONTENT_H

#include <cstdio>
#include <vector>

#include "defines.h"
#include "suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h"
@@ -77,13 +78,43 @@ class LanguageModelDictContent {

    bool updateAllProbabilityEntries(const HeaderPolicy *const headerPolicy,
            int *const outEntryCounts) {
        for (int i = 0; i <= MAX_PREV_WORD_COUNT_FOR_N_GRAM; ++i) {
            outEntryCounts[i] = 0;
        }
        return updateAllProbabilityEntriesInner(mTrieMap.getRootBitmapEntryIndex(), 0 /* level */,
                headerPolicy, outEntryCounts);
    }

    // entryCounts should be created by updateAllProbabilityEntries.
    bool truncateEntries(const int *const entryCounts, const int *const maxEntryCounts,
            const HeaderPolicy *const headerPolicy);

 private:
    DISALLOW_COPY_AND_ASSIGN(LanguageModelDictContent);

    class EntryInfoToTurncate {
     public:
        class Comparator {
         public:
            bool operator()(const EntryInfoToTurncate &left,
                    const EntryInfoToTurncate &right) const;
         private:
            DISALLOW_ASSIGNMENT_OPERATOR(Comparator);
        };

        EntryInfoToTurncate(const int probability, const int timestamp, const int key,
                const int entryLevel, const int *const prevWordIds);

        int mProbability;
        int mTimestamp;
        int mKey;
        int mEntryLevel;
        int mPrevWordIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1];

     private:
        DISALLOW_DEFAULT_CONSTRUCTOR(EntryInfoToTurncate);
    };

    TrieMap mTrieMap;
    const bool mHasHistoricalInfo;

@@ -94,6 +125,11 @@ class LanguageModelDictContent {
    int getBitmapEntryIndex(const WordIdArrayView prevWordIds) const;
    bool updateAllProbabilityEntriesInner(const int bitmapEntryIndex, const int level,
            const HeaderPolicy *const headerPolicy, int *const outEntryCounts);
    bool turncateEntriesInSpecifiedLevel(const HeaderPolicy *const headerPolicy,
            const int maxEntryCount, const int targetLevel);
    bool getEntryInfo(const HeaderPolicy *const headerPolicy, const int targetLevel,
            const int bitmapEntryIndex, std::vector<int> *const prevWordIds,
            std::vector<EntryInfoToTurncate> *const outEntryInfo) const;
};
} // namespace latinime
#endif /* LATINIME_LANGUAGE_MODEL_DICT_CONTENT_H */
+17 −0
Original line number Diff line number Diff line
@@ -91,6 +91,21 @@ bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos,
        AKLOGE("Failed to update probabilities in language model dict content.");
        return false;
    }
    if (headerPolicy->isDecayingDict()) {
        int maxEntryCountTable[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1];
        maxEntryCountTable[0] = headerPolicy->getMaxUnigramCount();
        maxEntryCountTable[1] = headerPolicy->getMaxBigramCount();
        for (size_t i = 2; i < NELEMS(maxEntryCountTable); ++i) {
            // TODO: Have max n-gram count.
            maxEntryCountTable[i] = headerPolicy->getMaxBigramCount();
        }
        if (!mBuffers->getMutableLanguageModelDictContent()->truncateEntries(entryCountTable,
                maxEntryCountTable,  headerPolicy)) {
            AKLOGE("Failed to truncate entries in language model dict content.");
            return false;
        }
    }

    DynamicPtReadingHelper readingHelper(&ptNodeReader, &ptNodeArrayReader);
    readingHelper.initWithPtNodeArrayPos(rootPtNodeArrayPos);
    DynamicPtGcEventListeners
@@ -193,6 +208,7 @@ bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos,
    return true;
}

// TODO: Remove.
bool Ver4PatriciaTrieWritingHelper::truncateUnigrams(
        const Ver4PatriciaTrieNodeReader *const ptNodeReader,
        Ver4PatriciaTrieNodeWriter *const ptNodeWriter, const int maxUnigramCount) {
@@ -233,6 +249,7 @@ bool Ver4PatriciaTrieWritingHelper::truncateUnigrams(
    return true;
}

// TODO: Remove.
bool Ver4PatriciaTrieWritingHelper::truncateBigrams(const int maxBigramCount) {
    const TerminalPositionLookupTable *const terminalPosLookupTable =
            mBuffers->getTerminalPositionLookupTable();