Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 627b0107 authored by Keisuke Kuroyanagi, committed by Android (Google) Code Review
Browse files

Merge "Support decaying dict in getWordProbability()."

parents c0617f3e 36ba139c
Loading
Loading
Loading
Loading
+8 −7
Original line number Diff line number Diff line
@@ -39,7 +39,7 @@ bool LanguageModelDictContent::runGC(
}

int LanguageModelDictContent::getWordProbability(const WordIdArrayView prevWordIds,
        const int wordId) const {
        const int wordId, const HeaderPolicy *const headerPolicy) const {
    int bitmapEntryIndices[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1];
    bitmapEntryIndices[0] = mTrieMap.getRootBitmapEntryIndex();
    int maxLevel = 0;
@@ -58,14 +58,15 @@ int LanguageModelDictContent::getWordProbability(const WordIdArrayView prevWordI
        if (!result.mIsValid) {
            continue;
        }
        const int probability =
                ProbabilityEntry::decode(result.mValue, mHasHistoricalInfo).getProbability();
        const ProbabilityEntry probabilityEntry =
                ProbabilityEntry::decode(result.mValue, mHasHistoricalInfo);
        if (mHasHistoricalInfo) {
            return std::min(
                    probability + ForgettingCurveUtils::getProbabilityBiasForNgram(i + 1 /* n */),
                    MAX_PROBABILITY);
            const int probability = ForgettingCurveUtils::decodeProbability(
                    probabilityEntry.getHistoricalInfo(), headerPolicy)
                            + ForgettingCurveUtils::getProbabilityBiasForNgram(i + 1 /* n */);
            return std::min(probability, MAX_PROBABILITY);
        } else {
            return probability;
            return probabilityEntry.getProbability();
        }
    }
    // Cannot find the word.
+2 −1
Original line number Diff line number Diff line
@@ -128,7 +128,8 @@ class LanguageModelDictContent {
            const LanguageModelDictContent *const originalContent,
            int *const outNgramCount);

    int getWordProbability(const WordIdArrayView prevWordIds, const int wordId) const;
    int getWordProbability(const WordIdArrayView prevWordIds, const int wordId,
            const HeaderPolicy *const headerPolicy) const;

    ProbabilityEntry getProbabilityEntry(const int wordId) const {
        return getNgramProbabilityEntry(WordIdArrayView(), wordId);
+4 −3
Original line number Diff line number Diff line
@@ -121,9 +121,10 @@ const WordAttributes Ver4PatriciaTriePolicy::getWordAttributesInContext(
            mBuffers->getTerminalPositionLookupTable()->getTerminalPtNodePosition(wordId);
    const PtNodeParams ptNodeParams = mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos);
    // TODO: Support n-gram.
    return WordAttributes(mBuffers->getLanguageModelDictContent()->getWordProbability(
            prevWordIds.limit(1 /* maxSize */), wordId), ptNodeParams.isBlacklisted(),
            ptNodeParams.isNotAWord(), ptNodeParams.getProbability() == 0);
    const int probability = mBuffers->getLanguageModelDictContent()->getWordProbability(
            prevWordIds.limit(1 /* maxSize */), wordId, mHeaderPolicy);
    return WordAttributes(probability, ptNodeParams.isBlacklisted(), ptNodeParams.isNotAWord(),
            probability == 0);
}

int Ver4PatriciaTriePolicy::getProbabilityOfWord(const WordIdArrayView prevWordIds,
+4 −2
Original line number Diff line number Diff line
@@ -107,13 +107,15 @@ TEST(LanguageModelDictContentTest, TestGetWordProbability) {
    languageModelDictContent.setProbabilityEntry(prevWordIds[0], &probabilityEntry);
    languageModelDictContent.setNgramProbabilityEntry(prevWordIds.limit(1), wordId,
            &bigramProbabilityEntry);
    EXPECT_EQ(bigramProbability, languageModelDictContent.getWordProbability(prevWordIds, wordId));
    EXPECT_EQ(bigramProbability, languageModelDictContent.getWordProbability(prevWordIds, wordId,
            nullptr /* headerPolicy */));
    const ProbabilityEntry trigramProbabilityEntry(flag, trigramProbability);
    languageModelDictContent.setNgramProbabilityEntry(prevWordIds.limit(1),
            prevWordIds[1], &probabilityEntry);
    languageModelDictContent.setNgramProbabilityEntry(prevWordIds.limit(2), wordId,
            &trigramProbabilityEntry);
    EXPECT_EQ(trigramProbability, languageModelDictContent.getWordProbability(prevWordIds, wordId));
    EXPECT_EQ(trigramProbability, languageModelDictContent.getWordProbability(prevWordIds, wordId,
            nullptr /* headerPolicy */));
}

}  // namespace