
Commit 12493955 authored by Keisuke Kuroyanagi, committed by Android (Google) Code Review

Merge "Support ngram entry migration."

parents baecaa54 c9865785
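
The change below replaces the bigram-only handling in dictionary migration and in getWordProperty() with a generic ngram walk: the new LanguageModelDictContent::exportAllNgramEntriesRelatedToWord() starts at the trie-map level owned by a word, descends level by level while keeping the chain of previous word ids on a stack, and emits a DumppedFullEntryInfo for every valid probability entry it finds. The standalone sketch below only models that traversal with plain STL containers; NgramNode, ExportedEntry and collectNgrams are illustrative names and are not part of the LatinIME sources.

#include <cstdio>
#include <vector>

// Simplified stand-in for one trie-map entry: a word id, an optional probability
// record, and the entries of the next trie level (the longer ngram contexts).
struct NgramNode {
    int wordId = 0;
    int probability = 0;
    bool hasEntry = false;             // models ProbabilityEntry::isValid()
    std::vector<NgramNode> nextLevel;  // models the next-level trie map entries
};

// One exported ngram: the chain of context word ids plus the target word.
struct ExportedEntry {
    std::vector<int> prevWordIds;
    int targetWordId;
    int probability;
};

// Depth-first walk in the spirit of exportAllNgramEntriesRelatedToWordInner():
// prevWordIds is the running context stack that grows on descent and shrinks on return.
void collectNgrams(const std::vector<NgramNode> &level, std::vector<int> *prevWordIds,
        std::vector<ExportedEntry> *out) {
    for (const NgramNode &node : level) {
        if (node.hasEntry) {
            out->push_back({*prevWordIds, node.wordId, node.probability});
        }
        if (!node.nextLevel.empty()) {
            prevWordIds->push_back(node.wordId);
            collectNgrams(node.nextLevel, prevWordIds, out);
            prevWordIds->pop_back();
        }
    }
}

int main() {
    // Entries hanging off word id 1: a bigram (1) -> 2 and a trigram (1, 2) -> 3.
    const std::vector<NgramNode> levelOfWord1 = {
        {2, 10, true, {
            {3, 20, true, {}},
        }},
    };
    std::vector<int> prevWordIds = {1};  // the word the export starts from
    std::vector<ExportedEntry> entries;
    collectNgrams(levelOfWord1, &prevWordIds, &entries);
    for (const ExportedEntry &e : entries) {
        std::printf("%zu-gram: target word %d, probability %d\n",
                e.prevWordIds.size() + 1, e.targetWordId, e.probability);
    }
    return 0;
}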
+1 −2
@@ -629,8 +629,7 @@ static bool latinime_BinaryDictionary_migrateNative(JNIEnv *env, jclass clazz, j
        }
    } while (token != 0);

-    // Add bigrams.
-    // TODO: Support ngrams.
+    // Add ngrams.
    do {
        token = dictionary->getNextWordAndNextToken(token, wordCodePoints, &wordCodePointCount);
        const WordProperty wordProperty = dictionary->getWordProperty(
+6 −4
@@ -580,10 +580,12 @@ const WordProperty Ver4PatriciaTriePolicy::getWordProperty(
                    getWordIdFromTerminalPtNodePos(word1TerminalPtNodePos), MAX_WORD_LENGTH,
                    bigramWord1CodePoints);
            const HistoricalInfo *const historicalInfo = bigramEntry.getHistoricalInfo();
-            const int probability = bigramEntry.hasHistoricalInfo() ?
-                    ForgettingCurveUtils::decodeProbability(
-                            bigramEntry.getHistoricalInfo(), mHeaderPolicy) :
-                    bigramEntry.getProbability();
+            const int rawBigramProbability = bigramEntry.hasHistoricalInfo()
+                    ? ForgettingCurveUtils::decodeProbability(
+                            bigramEntry.getHistoricalInfo(), mHeaderPolicy)
+                    : bigramEntry.getProbability();
+            const int probability = getBigramConditionalProbability(ptNodeParams.getProbability(),
+                    ptNodeParams.representsBeginningOfSentence(), rawBigramProbability);
            ngrams.emplace_back(
                    NgramContext(wordCodePoints.data(), wordCodePoints.size(),
                            ptNodeParams.representsBeginningOfSentence()),
+55 −16
@@ -140,6 +140,44 @@ LanguageModelDictContent::EntryRange LanguageModelDictContent::getProbabilityEnt
    return EntryRange(mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex), mHasHistoricalInfo);
}

+std::vector<LanguageModelDictContent::DumppedFullEntryInfo>
+        LanguageModelDictContent::exportAllNgramEntriesRelatedToWord(
+                const HeaderPolicy *const headerPolicy, const int wordId) const {
+    const TrieMap::Result result = mTrieMap.getRoot(wordId);
+    if (!result.mIsValid || result.mNextLevelBitmapEntryIndex == TrieMap::INVALID_INDEX) {
+        // The word doesn't have any related ngram entries.
+        return std::vector<DumppedFullEntryInfo>();
+    }
+    std::vector<int> prevWordIds = { wordId };
+    std::vector<DumppedFullEntryInfo> entries;
+    exportAllNgramEntriesRelatedToWordInner(headerPolicy, result.mNextLevelBitmapEntryIndex,
+            &prevWordIds, &entries);
+    return entries;
+}
+
+void LanguageModelDictContent::exportAllNgramEntriesRelatedToWordInner(
+        const HeaderPolicy *const headerPolicy, const int bitmapEntryIndex,
+        std::vector<int> *const prevWordIds,
+        std::vector<DumppedFullEntryInfo> *const outBummpedFullEntryInfo) const {
+    for (const auto &entry : mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex)) {
+        const int wordId = entry.key();
+        const ProbabilityEntry probabilityEntry =
+                ProbabilityEntry::decode(entry.value(), mHasHistoricalInfo);
+        if (probabilityEntry.isValid()) {
+            const WordAttributes wordAttributes = getWordAttributes(
+                    WordIdArrayView(*prevWordIds), wordId, headerPolicy);
+            outBummpedFullEntryInfo->emplace_back(*prevWordIds, wordId,
+                    wordAttributes, probabilityEntry);
+        }
+        if (entry.hasNextLevelMap()) {
+            prevWordIds->push_back(wordId);
+            exportAllNgramEntriesRelatedToWordInner(headerPolicy,
+                    entry.getNextLevelBitmapEntryIndex(), prevWordIds, outBummpedFullEntryInfo);
+            prevWordIds->pop_back();
+        }
+    }
+}
+
bool LanguageModelDictContent::truncateEntries(const EntryCounts &currentEntryCounts,
        const EntryCounts &maxEntryCounts, const HeaderPolicy *const headerPolicy,
        MutableEntryCounters *const outEntryCounters) {
@@ -231,25 +269,26 @@ bool LanguageModelDictContent::runGCInner(
}

int LanguageModelDictContent::createAndGetBitmapEntryIndex(const WordIdArrayView prevWordIds) {
-    if (prevWordIds.empty()) {
-        return mTrieMap.getRootBitmapEntryIndex();
-    }
-    const int lastBitmapEntryIndex =
-            getBitmapEntryIndex(prevWordIds.limit(prevWordIds.size() - 1));
-    if (lastBitmapEntryIndex == TrieMap::INVALID_INDEX) {
-        return TrieMap::INVALID_INDEX;
-    }
-    const int oldestPrevWordId = prevWordIds.lastOrDefault(NOT_A_WORD_ID);
-    const TrieMap::Result result = mTrieMap.get(oldestPrevWordId, lastBitmapEntryIndex);
-    if (!result.mIsValid) {
-        if (!mTrieMap.put(oldestPrevWordId,
-                ProbabilityEntry().encode(mHasHistoricalInfo), lastBitmapEntryIndex)) {
-            return TrieMap::INVALID_INDEX;
-        }
-    }
-    return mTrieMap.getNextLevelBitmapEntryIndex(prevWordIds.lastOrDefault(NOT_A_WORD_ID),
-            lastBitmapEntryIndex);
+    int lastBitmapEntryIndex = mTrieMap.getRootBitmapEntryIndex();
+    for (const int wordId : prevWordIds) {
+        const TrieMap::Result result = mTrieMap.get(wordId, lastBitmapEntryIndex);
+        if (result.mIsValid && result.mNextLevelBitmapEntryIndex != TrieMap::INVALID_INDEX) {
+            lastBitmapEntryIndex = result.mNextLevelBitmapEntryIndex;
+            continue;
+        }
+        if (!result.mIsValid) {
+            if (!mTrieMap.put(wordId, ProbabilityEntry().encode(mHasHistoricalInfo),
+                    lastBitmapEntryIndex)) {
+                AKLOGE("Failed to update trie map. wordId: %d, lastBitmapEntryIndex %d", wordId,
+                        lastBitmapEntryIndex);
+                return TrieMap::INVALID_INDEX;
+            }
+        }
+        lastBitmapEntryIndex = mTrieMap.getNextLevelBitmapEntryIndex(wordId,
+                lastBitmapEntryIndex);
+    }
+    return lastBitmapEntryIndex;
}

int LanguageModelDictContent::getBitmapEntryIndex(const WordIdArrayView prevWordIds) const {
    int bitmapEntryIndex = mTrieMap.getRootBitmapEntryIndex();
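
The reworked createAndGetBitmapEntryIndex() above no longer special-cases the last previous word: it walks every previous word id in order, inserting a placeholder entry when a word is missing and creating the next trie level on demand. The following is a rough standalone model of that get-or-create walk using nested STL containers instead of the real TrieMap; Level, Slot and getOrCreateLevel are illustrative names, not LatinIME code.

#include <cstdio>
#include <map>
#include <memory>
#include <vector>

// Minimal model of the nested trie-map levels: each level maps a word id to a
// slot holding an optional entry and the child level.
struct Level;
struct Slot {
    bool hasEntry = false;             // models a stored ProbabilityEntry
    std::unique_ptr<Level> nextLevel;  // models the next-level bitmap entry index
};
struct Level {
    std::map<int, Slot> slots;
};

// Walks the previous-word chain level by level, creating placeholder slots and
// missing levels along the way, and returns the level that should hold entries
// keyed by the target word.
Level *getOrCreateLevel(Level *root, const std::vector<int> &prevWordIds) {
    Level *current = root;
    for (const int wordId : prevWordIds) {
        Slot &slot = current->slots[wordId];  // inserts an empty placeholder if absent
        if (!slot.nextLevel) {
            slot.nextLevel.reset(new Level());
        }
        current = slot.nextLevel.get();
    }
    return current;
}

int main() {
    Level root;
    // Prepare the level for entries with context (7, 8), creating both
    // intermediate slots on the way, then store the (7, 8) -> 9 entry.
    Level *level = getOrCreateLevel(&root, {7, 8});
    level->slots[9].hasEntry = true;
    std::printf("entries one level below word 7: %zu\n", root.slots[7].nextLevel->slots.size());
    return 0;
}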
+27 −0
@@ -110,6 +110,27 @@ class LanguageModelDictContent {
        const bool mHasHistoricalInfo;
    };

+    class DumppedFullEntryInfo {
+     public:
+        DumppedFullEntryInfo(std::vector<int> &prevWordIds, const int targetWordId,
+                const WordAttributes &wordAttributes, const ProbabilityEntry &probabilityEntry)
+                : mPrevWordIds(prevWordIds), mTargetWordId(targetWordId),
+                  mWordAttributes(wordAttributes), mProbabilityEntry(probabilityEntry) {}
+
+        const WordIdArrayView getPrevWordIds() const { return WordIdArrayView(mPrevWordIds); }
+        int getTargetWordId() const { return mTargetWordId; }
+        const WordAttributes &getWordAttributes() const { return mWordAttributes; }
+        const ProbabilityEntry &getProbabilityEntry() const { return mProbabilityEntry; }
+
+     private:
+        DISALLOW_ASSIGNMENT_OPERATOR(DumppedFullEntryInfo);
+
+        const std::vector<int> mPrevWordIds;
+        const int mTargetWordId;
+        const WordAttributes mWordAttributes;
+        const ProbabilityEntry mProbabilityEntry;
+    };
+
    LanguageModelDictContent(const ReadWriteByteArrayView trieMapBuffer,
            const bool hasHistoricalInfo)
            : mTrieMap(trieMapBuffer), mHasHistoricalInfo(hasHistoricalInfo) {}
@@ -151,6 +172,9 @@ class LanguageModelDictContent {

    EntryRange getProbabilityEntries(const WordIdArrayView prevWordIds) const;

+    std::vector<DumppedFullEntryInfo> exportAllNgramEntriesRelatedToWord(
+            const HeaderPolicy *const headerPolicy, const int wordId) const;
+
    bool updateAllProbabilityEntriesForGC(const HeaderPolicy *const headerPolicy,
            MutableEntryCounters *const outEntryCounters) {
        return updateAllProbabilityEntriesForGCInner(mTrieMap.getRootBitmapEntryIndex(),
@@ -212,6 +236,9 @@ class LanguageModelDictContent {
    const ProbabilityEntry createUpdatedEntryFrom(const ProbabilityEntry &originalProbabilityEntry,
            const bool isValid, const HistoricalInfo historicalInfo,
            const HeaderPolicy *const headerPolicy) const;
+    void exportAllNgramEntriesRelatedToWordInner(const HeaderPolicy *const headerPolicy,
+            const int bitmapEntryIndex, std::vector<int> *const prevWordIds,
+            std::vector<DumppedFullEntryInfo> *const outBummpedFullEntryInfo) const;
};
} // namespace latinime
#endif /* LATINIME_LANGUAGE_MODEL_DICT_CONTENT_H */
+31 −21
@@ -491,30 +491,37 @@ const WordProperty Ver4PatriciaTriePolicy::getWordProperty(
    const int ptNodePos =
            mBuffers->getTerminalPositionLookupTable()->getTerminalPtNodePosition(wordId);
    const PtNodeParams ptNodeParams = mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos);
-    const ProbabilityEntry probabilityEntry =
-            mBuffers->getLanguageModelDictContent()->getProbabilityEntry(
-                    ptNodeParams.getTerminalId());
-    const HistoricalInfo *const historicalInfo = probabilityEntry.getHistoricalInfo();
-    // Fetch bigram information.
-    // TODO: Support n-gram.
+    const LanguageModelDictContent *const languageModelDictContent =
+            mBuffers->getLanguageModelDictContent();
+    // Fetch ngram information.
    std::vector<NgramProperty> ngrams;
-    const WordIdArrayView prevWordIds = WordIdArrayView::singleElementView(&wordId);
-    int bigramWord1CodePoints[MAX_WORD_LENGTH];
-    for (const auto entry : mBuffers->getLanguageModelDictContent()->getProbabilityEntries(
-            prevWordIds)) {
-        const int codePointCount = getCodePointsAndReturnCodePointCount(entry.getWordId(),
-                MAX_WORD_LENGTH, bigramWord1CodePoints);
+    int ngramTargetCodePoints[MAX_WORD_LENGTH];
+    int ngramPrevWordsCodePoints[MAX_PREV_WORD_COUNT_FOR_N_GRAM][MAX_WORD_LENGTH];
+    int ngramPrevWordsCodePointCount[MAX_PREV_WORD_COUNT_FOR_N_GRAM];
+    bool ngramPrevWordIsBeginningOfSentense[MAX_PREV_WORD_COUNT_FOR_N_GRAM];
+    for (const auto entry : languageModelDictContent->exportAllNgramEntriesRelatedToWord(
+            mHeaderPolicy, wordId)) {
+        const int codePointCount = getCodePointsAndReturnCodePointCount(entry.getTargetWordId(),
+                MAX_WORD_LENGTH, ngramTargetCodePoints);
+        const WordIdArrayView prevWordIds = entry.getPrevWordIds();
+        for (size_t i = 0; i < prevWordIds.size(); ++i) {
+            ngramPrevWordsCodePointCount[i] = getCodePointsAndReturnCodePointCount(prevWordIds[i],
+                       MAX_WORD_LENGTH, ngramPrevWordsCodePoints[i]);
+            ngramPrevWordIsBeginningOfSentense[i] = languageModelDictContent->getProbabilityEntry(
+                    prevWordIds[i]).representsBeginningOfSentence();
+            if (ngramPrevWordIsBeginningOfSentense[i]) {
+                ngramPrevWordsCodePointCount[i] = CharUtils::removeBeginningOfSentenceMarker(
+                        ngramPrevWordsCodePoints[i], ngramPrevWordsCodePointCount[i]);
+            }
+        }
+        const NgramContext ngramContext(ngramPrevWordsCodePoints, ngramPrevWordsCodePointCount,
+                ngramPrevWordIsBeginningOfSentense, prevWordIds.size());
        const ProbabilityEntry ngramProbabilityEntry = entry.getProbabilityEntry();
        const HistoricalInfo *const historicalInfo = ngramProbabilityEntry.getHistoricalInfo();
-        const int probability = ngramProbabilityEntry.hasHistoricalInfo() ?
-                ForgettingCurveUtils::decodeProbability(historicalInfo, mHeaderPolicy) :
-                ngramProbabilityEntry.getProbability();
-        ngrams.emplace_back(
-                NgramContext(
-                        wordCodePoints.data(), wordCodePoints.size(),
-                        probabilityEntry.representsBeginningOfSentence()),
-                CodePointArrayView(bigramWord1CodePoints, codePointCount).toVector(),
-                probability, *historicalInfo);
+        // TODO: Output flags in WordAttributes.
+        ngrams.emplace_back(ngramContext,
+                CodePointArrayView(ngramTargetCodePoints, codePointCount).toVector(),
+                entry.getWordAttributes().getProbability(), *historicalInfo);
    }
    // Fetch shortcut information.
    std::vector<UnigramProperty::ShortcutProperty> shortcuts;
@@ -534,6 +541,9 @@ const WordProperty Ver4PatriciaTriePolicy::getWordProperty(
                    shortcutProbability);
        }
    }
+    const ProbabilityEntry probabilityEntry = languageModelDictContent->getProbabilityEntry(
+            ptNodeParams.getTerminalId());
+    const HistoricalInfo *const historicalInfo = probabilityEntry.getHistoricalInfo();
    const UnigramProperty unigramProperty(probabilityEntry.representsBeginningOfSentence(),
            probabilityEntry.isNotAWord(), probabilityEntry.isBlacklisted(),
            probabilityEntry.isPossiblyOffensive(), probabilityEntry.getProbability(),