
Commit c9865785 authored by Keisuke Kuroyanagi

Support ngram entry migration.

Bug: 14425059
Change-Id: I98cb9fa303af2d93a0a3512e8732231c564e3c5d
parent 0b8bb0c2
+1 −2
@@ -629,8 +629,7 @@ static bool latinime_BinaryDictionary_migrateNative(JNIEnv *env, jclass clazz, j
        }
    } while (token != 0);

-    // Add bigrams.
-    // TODO: Support ngrams.
+    // Add ngrams.
    do {
        token = dictionary->getNextWordAndNextToken(token, wordCodePoints, &wordCodePointCount);
        const WordProperty wordProperty = dictionary->getWordProperty(
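For context: migrateNative() walks every word in the source dictionary via iteration tokens and replays that word's data into the new format; this hunk makes the second pass replay full ngram entries rather than bigrams only. A minimal sketch of that pass's shape, assuming a hypothetical addNgramEntriesOfWord() write helper (only getNextWordAndNextToken() and getWordProperty() appear in the hunk; getWordProperty()'s argument list is truncated above and assumed here):

// Sketch of the "Add ngrams" pass: visit every word by token, fetch its
// WordProperty, and write each attached ngram entry into the new dictionary.
int token = 0;
do {
    token = dictionary->getNextWordAndNextToken(token, wordCodePoints, &wordCodePointCount);
    const WordProperty wordProperty = dictionary->getWordProperty(wordCodePoints,
            wordCodePointCount);  // argument list assumed
    // Hypothetical helper: persists each ngram entry carried by wordProperty.
    addNgramEntriesOfWord(destDict, wordProperty);
} while (token != 0);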
+6 −4
@@ -580,10 +580,12 @@ const WordProperty Ver4PatriciaTriePolicy::getWordProperty(
                    getWordIdFromTerminalPtNodePos(word1TerminalPtNodePos), MAX_WORD_LENGTH,
                    bigramWord1CodePoints);
            const HistoricalInfo *const historicalInfo = bigramEntry.getHistoricalInfo();
-            const int probability = bigramEntry.hasHistoricalInfo() ?
-                    ForgettingCurveUtils::decodeProbability(
-                            bigramEntry.getHistoricalInfo(), mHeaderPolicy) :
-                    bigramEntry.getProbability();
+            const int rawBigramProbability = bigramEntry.hasHistoricalInfo()
+                    ? ForgettingCurveUtils::decodeProbability(
+                            bigramEntry.getHistoricalInfo(), mHeaderPolicy)
+                    : bigramEntry.getProbability();
+            const int probability = getBigramConditionalProbability(ptNodeParams.getProbability(),
+                    ptNodeParams.representsBeginningOfSentence(), rawBigramProbability);
            ngrams.emplace_back(
                    NgramContext(wordCodePoints.data(), wordCodePoints.size(),
                            ptNodeParams.representsBeginningOfSentence()),
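The replacement decodes the raw bigram probability first and then derives a probability conditioned on the context word through getBigramConditionalProbability(), whose body lives elsewhere in this file. Purely as an illustration of that kind of conversion, assuming log-scale probability values (the helper's real arithmetic is not shown in this diff):

// Illustration only, not the function's actual body: with log-scale values,
// log P(word1 | word0) = log P(word0, word1) - log P(word0).
static int getBigramConditionalProbabilitySketch(const int word0Probability,
        const bool isBeginningOfSentence, const int rawBigramProbability) {
    if (isBeginningOfSentence) {
        // Assumption: beginning-of-sentence contexts keep the raw value.
        return rawBigramProbability;
    }
    const int conditional = rawBigramProbability - word0Probability;
    return conditional > 0 ? conditional : 0;  // clamp to a valid probability
}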
+55 −16
@@ -140,6 +140,44 @@ LanguageModelDictContent::EntryRange LanguageModelDictContent::getProbabilityEnt
    return EntryRange(mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex), mHasHistoricalInfo);
}

+std::vector<LanguageModelDictContent::DumppedFullEntryInfo>
+        LanguageModelDictContent::exportAllNgramEntriesRelatedToWord(
+                const HeaderPolicy *const headerPolicy, const int wordId) const {
+    const TrieMap::Result result = mTrieMap.getRoot(wordId);
+    if (!result.mIsValid || result.mNextLevelBitmapEntryIndex == TrieMap::INVALID_INDEX) {
+        // The word doesn't have any related ngram entries.
+        return std::vector<DumppedFullEntryInfo>();
+    }
+    std::vector<int> prevWordIds = { wordId };
+    std::vector<DumppedFullEntryInfo> entries;
+    exportAllNgramEntriesRelatedToWordInner(headerPolicy, result.mNextLevelBitmapEntryIndex,
+            &prevWordIds, &entries);
+    return entries;
+}
+
+void LanguageModelDictContent::exportAllNgramEntriesRelatedToWordInner(
+        const HeaderPolicy *const headerPolicy, const int bitmapEntryIndex,
+        std::vector<int> *const prevWordIds,
+        std::vector<DumppedFullEntryInfo> *const outBummpedFullEntryInfo) const {
+    for (const auto &entry : mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex)) {
+        const int wordId = entry.key();
+        const ProbabilityEntry probabilityEntry =
+                ProbabilityEntry::decode(entry.value(), mHasHistoricalInfo);
+        if (probabilityEntry.isValid()) {
+            const WordAttributes wordAttributes = getWordAttributes(
+                    WordIdArrayView(*prevWordIds), wordId, headerPolicy);
+            outBummpedFullEntryInfo->emplace_back(*prevWordIds, wordId,
+                    wordAttributes, probabilityEntry);
+        }
+        if (entry.hasNextLevelMap()) {
+            prevWordIds->push_back(wordId);
+            exportAllNgramEntriesRelatedToWordInner(headerPolicy,
+                    entry.getNextLevelBitmapEntryIndex(), prevWordIds, outBummpedFullEntryInfo);
+            prevWordIds->pop_back();
+        }
+    }
+}
+
bool LanguageModelDictContent::truncateEntries(const EntryCounts &currentEntryCounts,
        const EntryCounts &maxEntryCounts, const HeaderPolicy *const headerPolicy,
        MutableEntryCounters *const outEntryCounters) {
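The two functions added above form a public/private pair: exportAllNgramEntriesRelatedToWord() positions the traversal at the trie level keyed by wordId, and the Inner() helper walks every deeper level depth-first, emitting one DumppedFullEntryInfo per valid probability entry. A hedged usage sketch plus an example traversal (names are illustrative):

// Dump every ngram entry whose context begins with wordId.
for (const auto &entryInfo :
        languageModelDictContent->exportAllNgramEntriesRelatedToWord(headerPolicy, wordId)) {
    // entryInfo.getPrevWordIds(): context word ids, starting with wordId.
    // entryInfo.getTargetWordId(): the word this ngram predicts.
}
// Example, assuming the trie map holds the bigram (w -> a) and the trigram
// (w, b -> c); exportAllNgramEntriesRelatedToWord(headerPolicy, w) yields:
//   level(w), key a: valid entry    -> { prevWordIds: [w],    target: a }
//   level(w), key b: has next level -> recurse with prevWordIds = [w, b]
//   level(w, b), key c: valid entry -> { prevWordIds: [w, b], target: c }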
@@ -231,25 +269,26 @@ bool LanguageModelDictContent::runGCInner(
}

int LanguageModelDictContent::createAndGetBitmapEntryIndex(const WordIdArrayView prevWordIds) {
-    if (prevWordIds.empty()) {
-        return mTrieMap.getRootBitmapEntryIndex();
-    }
-    const int lastBitmapEntryIndex =
-            getBitmapEntryIndex(prevWordIds.limit(prevWordIds.size() - 1));
-    if (lastBitmapEntryIndex == TrieMap::INVALID_INDEX) {
-        return TrieMap::INVALID_INDEX;
-    }
-    const int oldestPrevWordId = prevWordIds.lastOrDefault(NOT_A_WORD_ID);
-    const TrieMap::Result result = mTrieMap.get(oldestPrevWordId, lastBitmapEntryIndex);
-    if (!result.mIsValid) {
-        if (!mTrieMap.put(oldestPrevWordId,
-                ProbabilityEntry().encode(mHasHistoricalInfo), lastBitmapEntryIndex)) {
-            return TrieMap::INVALID_INDEX;
-        }
-    }
-    return mTrieMap.getNextLevelBitmapEntryIndex(prevWordIds.lastOrDefault(NOT_A_WORD_ID),
-            lastBitmapEntryIndex);
+    int lastBitmapEntryIndex = mTrieMap.getRootBitmapEntryIndex();
+    for (const int wordId : prevWordIds) {
+        const TrieMap::Result result = mTrieMap.get(wordId, lastBitmapEntryIndex);
+        if (result.mIsValid && result.mNextLevelBitmapEntryIndex != TrieMap::INVALID_INDEX) {
+            lastBitmapEntryIndex = result.mNextLevelBitmapEntryIndex;
+            continue;
+        }
+        if (!result.mIsValid) {
+            if (!mTrieMap.put(wordId, ProbabilityEntry().encode(mHasHistoricalInfo),
+                    lastBitmapEntryIndex)) {
+                AKLOGE("Failed to update trie map. wordId: %d, lastBitmapEntryIndex %d", wordId,
+                        lastBitmapEntryIndex);
+                return TrieMap::INVALID_INDEX;
+            }
+        }
+        lastBitmapEntryIndex = mTrieMap.getNextLevelBitmapEntryIndex(wordId,
+                lastBitmapEntryIndex);
+    }
+    return lastBitmapEntryIndex;
}

int LanguageModelDictContent::getBitmapEntryIndex(const WordIdArrayView prevWordIds) const {
    int bitmapEntryIndex = mTrieMap.getRootBitmapEntryIndex();
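Where the old createAndGetBitmapEntryIndex() could only step one level past an existing context, the rewrite walks one trie level per previous word and creates missing levels on the fly, seeding them with placeholder ProbabilityEntry().encode() values; that is what lets migration insert an ngram whose shorter prefixes were never stored explicitly. A sketch of a write path on top of it (the surrounding function is illustrative; only createAndGetBitmapEntryIndex() and TrieMap::put() come from the diff):

// Sketch: store P(wordId | prevWordIds) in the trie map.
bool setNgramEntrySketch(const WordIdArrayView prevWordIds, const int wordId,
        const ProbabilityEntry &probabilityEntry) {
    // Creates intermediate levels for prevWordIds[0], prevWordIds[1], ... as needed.
    const int bitmapEntryIndex = createAndGetBitmapEntryIndex(prevWordIds);
    if (bitmapEntryIndex == TrieMap::INVALID_INDEX) {
        return false;
    }
    return mTrieMap.put(wordId, probabilityEntry.encode(mHasHistoricalInfo), bitmapEntryIndex);
}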
+27 −0
@@ -110,6 +110,27 @@ class LanguageModelDictContent {
        const bool mHasHistoricalInfo;
    };

+    class DumppedFullEntryInfo {
+     public:
+        DumppedFullEntryInfo(std::vector<int> &prevWordIds, const int targetWordId,
+                const WordAttributes &wordAttributes, const ProbabilityEntry &probabilityEntry)
+                : mPrevWordIds(prevWordIds), mTargetWordId(targetWordId),
+                  mWordAttributes(wordAttributes), mProbabilityEntry(probabilityEntry) {}
+
+        const WordIdArrayView getPrevWordIds() const { return WordIdArrayView(mPrevWordIds); }
+        int getTargetWordId() const { return mTargetWordId; }
+        const WordAttributes &getWordAttributes() const { return mWordAttributes; }
+        const ProbabilityEntry &getProbabilityEntry() const { return mProbabilityEntry; }
+
+     private:
+        DISALLOW_ASSIGNMENT_OPERATOR(DumppedFullEntryInfo);
+
+        const std::vector<int> mPrevWordIds;
+        const int mTargetWordId;
+        const WordAttributes mWordAttributes;
+        const ProbabilityEntry mProbabilityEntry;
+    };
+
    LanguageModelDictContent(const ReadWriteByteArrayView trieMapBuffer,
            const bool hasHistoricalInfo)
            : mTrieMap(trieMapBuffer), mHasHistoricalInfo(hasHistoricalInfo) {}
@@ -151,6 +172,9 @@ class LanguageModelDictContent {

    EntryRange getProbabilityEntries(const WordIdArrayView prevWordIds) const;

+    std::vector<DumppedFullEntryInfo> exportAllNgramEntriesRelatedToWord(
+            const HeaderPolicy *const headerPolicy, const int wordId) const;
+
    bool updateAllProbabilityEntriesForGC(const HeaderPolicy *const headerPolicy,
            MutableEntryCounters *const outEntryCounters) {
        return updateAllProbabilityEntriesForGCInner(mTrieMap.getRootBitmapEntryIndex(),
@@ -212,6 +236,9 @@ class LanguageModelDictContent {
    const ProbabilityEntry createUpdatedEntryFrom(const ProbabilityEntry &originalProbabilityEntry,
            const bool isValid, const HistoricalInfo historicalInfo,
            const HeaderPolicy *const headerPolicy) const;
+    void exportAllNgramEntriesRelatedToWordInner(const HeaderPolicy *const headerPolicy,
+            const int bitmapEntryIndex, std::vector<int> *const prevWordIds,
+            std::vector<DumppedFullEntryInfo> *const outBummpedFullEntryInfo) const;
};
} // namespace latinime
#endif /* LATINIME_LANGUAGE_MODEL_DICT_CONTENT_H */
+31 −21
@@ -491,30 +491,37 @@ const WordProperty Ver4PatriciaTriePolicy::getWordProperty(
    const int ptNodePos =
            mBuffers->getTerminalPositionLookupTable()->getTerminalPtNodePosition(wordId);
    const PtNodeParams ptNodeParams = mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos);
-    const ProbabilityEntry probabilityEntry =
-            mBuffers->getLanguageModelDictContent()->getProbabilityEntry(
-                    ptNodeParams.getTerminalId());
-    const HistoricalInfo *const historicalInfo = probabilityEntry.getHistoricalInfo();
-    // Fetch bigram information.
-    // TODO: Support n-gram.
+    const LanguageModelDictContent *const languageModelDictContent =
+            mBuffers->getLanguageModelDictContent();
+    // Fetch ngram information.
    std::vector<NgramProperty> ngrams;
-    const WordIdArrayView prevWordIds = WordIdArrayView::singleElementView(&wordId);
-    int bigramWord1CodePoints[MAX_WORD_LENGTH];
-    for (const auto entry : mBuffers->getLanguageModelDictContent()->getProbabilityEntries(
-            prevWordIds)) {
-        const int codePointCount = getCodePointsAndReturnCodePointCount(entry.getWordId(),
-                MAX_WORD_LENGTH, bigramWord1CodePoints);
+    int ngramTargetCodePoints[MAX_WORD_LENGTH];
+    int ngramPrevWordsCodePoints[MAX_PREV_WORD_COUNT_FOR_N_GRAM][MAX_WORD_LENGTH];
+    int ngramPrevWordsCodePointCount[MAX_PREV_WORD_COUNT_FOR_N_GRAM];
+    bool ngramPrevWordIsBeginningOfSentense[MAX_PREV_WORD_COUNT_FOR_N_GRAM];
+    for (const auto entry : languageModelDictContent->exportAllNgramEntriesRelatedToWord(
+            mHeaderPolicy, wordId)) {
+        const int codePointCount = getCodePointsAndReturnCodePointCount(entry.getTargetWordId(),
+                MAX_WORD_LENGTH, ngramTargetCodePoints);
+        const WordIdArrayView prevWordIds = entry.getPrevWordIds();
+        for (size_t i = 0; i < prevWordIds.size(); ++i) {
+            ngramPrevWordsCodePointCount[i] = getCodePointsAndReturnCodePointCount(prevWordIds[i],
+                    MAX_WORD_LENGTH, ngramPrevWordsCodePoints[i]);
+            ngramPrevWordIsBeginningOfSentense[i] = languageModelDictContent->getProbabilityEntry(
+                    prevWordIds[i]).representsBeginningOfSentence();
+            if (ngramPrevWordIsBeginningOfSentense[i]) {
+                ngramPrevWordsCodePointCount[i] = CharUtils::removeBeginningOfSentenceMarker(
+                        ngramPrevWordsCodePoints[i], ngramPrevWordsCodePointCount[i]);
+            }
+        }
+        const NgramContext ngramContext(ngramPrevWordsCodePoints, ngramPrevWordsCodePointCount,
+                ngramPrevWordIsBeginningOfSentense, prevWordIds.size());
        const ProbabilityEntry ngramProbabilityEntry = entry.getProbabilityEntry();
        const HistoricalInfo *const historicalInfo = ngramProbabilityEntry.getHistoricalInfo();
-        const int probability = ngramProbabilityEntry.hasHistoricalInfo() ?
-                ForgettingCurveUtils::decodeProbability(historicalInfo, mHeaderPolicy) :
-                ngramProbabilityEntry.getProbability();
-        ngrams.emplace_back(
-                NgramContext(
-                        wordCodePoints.data(), wordCodePoints.size(),
-                        probabilityEntry.representsBeginningOfSentence()),
-                CodePointArrayView(bigramWord1CodePoints, codePointCount).toVector(),
-                probability, *historicalInfo);
+        // TODO: Output flags in WordAttributes.
+        ngrams.emplace_back(ngramContext,
+                CodePointArrayView(ngramTargetCodePoints, codePointCount).toVector(),
+                entry.getWordAttributes().getProbability(), *historicalInfo);
    }
    // Fetch shortcut information.
    std::vector<UnigramProperty::ShortcutProperty> shortcuts;
@@ -534,6 +541,9 @@ const WordProperty Ver4PatriciaTriePolicy::getWordProperty(
                    shortcutProbability);
        }
    }
+    const ProbabilityEntry probabilityEntry = languageModelDictContent->getProbabilityEntry(
+            ptNodeParams.getTerminalId());
+    const HistoricalInfo *const historicalInfo = probabilityEntry.getHistoricalInfo();
    const UnigramProperty unigramProperty(probabilityEntry.representsBeginningOfSentence(),
            probabilityEntry.isNotAWord(), probabilityEntry.isBlacklisted(),
            probabilityEntry.isPossiblyOffensive(), probabilityEntry.getProbability(),
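After these hunks, getWordProperty() reports each related ngram with its complete previous-word chain (up to MAX_PREV_WORD_COUNT_FOR_N_GRAM words, with beginning-of-sentence markers stripped), and the unigram ProbabilityEntry is fetched just before the UnigramProperty is built. A hedged caller-side sketch (the WordProperty accessor name is assumed, not shown in this diff):

// Illustrative caller: inspect the ngrams attached to a word.
const WordProperty wordProperty = policy->getWordProperty(wordCodePoints);
for (const NgramProperty &ngram : *wordProperty.getNgramProperties()) {  // accessor assumed
    // Each entry carries an NgramContext built from the full prev-word chain,
    // no longer just a single bigram context word.
}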