Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 72d17d92 authored by Keisuke Kuroyanagi's avatar Keisuke Kuroyanagi
Browse files

Use better conditional probability for ngram entries.

Old:
P(W | W_prev) = f(W, W_prev) + C
New:
P(W | W_prev) = f(W, W_prev) / f(W_prev)

Bug: 14425059
Bug: 16547409

Change-Id: I4d13be6de2c6bad6bad7fb22320a23ba4ecd361c
parent 54007019
Loading
Loading
Loading
Loading
+21 −6
Original line number Diff line number Diff line
@@ -43,18 +43,18 @@ const WordAttributes LanguageModelDictContent::getWordAttributes(const WordIdArr
        const int wordId, const HeaderPolicy *const headerPolicy) const {
    int bitmapEntryIndices[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1];
    bitmapEntryIndices[0] = mTrieMap.getRootBitmapEntryIndex();
    int maxLevel = 0;
    int maxPrevWordCount = 0;
    for (size_t i = 0; i < prevWordIds.size(); ++i) {
        const int nextBitmapEntryIndex =
                mTrieMap.get(prevWordIds[i], bitmapEntryIndices[i]).mNextLevelBitmapEntryIndex;
        if (nextBitmapEntryIndex == TrieMap::INVALID_INDEX) {
            break;
        }
        maxLevel = i + 1;
        maxPrevWordCount = i + 1;
        bitmapEntryIndices[i + 1] = nextBitmapEntryIndex;
    }

    for (int i = maxLevel; i >= 0; --i) {
    for (int i = maxPrevWordCount; i >= 0; --i) {
        const TrieMap::Result result = mTrieMap.get(wordId, bitmapEntryIndices[i]);
        if (!result.mIsValid) {
            continue;
@@ -69,9 +69,24 @@ const WordAttributes LanguageModelDictContent::getWordAttributes(const WordIdArr
                // The entry should not be treated as a valid entry.
                continue;
            }
            probability = std::min(rawProbability
                    + ForgettingCurveUtils::getProbabilityBiasForNgram(i + 1 /* n */),
                            MAX_PROBABILITY);
            if (i == 0) {
                // unigram
                probability = rawProbability;
            } else {
                const ProbabilityEntry prevWordProbabilityEntry = getNgramProbabilityEntry(
                        prevWordIds.skip(1 /* n */).limit(i - 1), prevWordIds[0]);
                if (!prevWordProbabilityEntry.isValid()) {
                    continue;
                }
                if (prevWordProbabilityEntry.representsBeginningOfSentence()) {
                    probability = rawProbability;
                } else {
                    const int prevWordRawProbability = ForgettingCurveUtils::decodeProbability(
                            prevWordProbabilityEntry.getHistoricalInfo(), headerPolicy);
                    probability = std::min(MAX_PROBABILITY - prevWordRawProbability
                            + rawProbability, MAX_PROBABILITY);
                }
            }
        } else {
            probability = probabilityEntry.getProbability();
        }
+5 −5
Original line number Diff line number Diff line
@@ -98,17 +98,17 @@ class ProbabilityEntry {
    }

    uint64_t encode(const bool hasHistoricalInfo) const {
        uint64_t encodedEntry = static_cast<uint64_t>(mFlags);
        uint64_t encodedEntry = static_cast<uint8_t>(mFlags);
        if (hasHistoricalInfo) {
            encodedEntry = (encodedEntry << (Ver4DictConstants::TIME_STAMP_FIELD_SIZE * CHAR_BIT))
                    ^ static_cast<uint64_t>(mHistoricalInfo.getTimestamp());
                    | static_cast<uint32_t>(mHistoricalInfo.getTimestamp());
            encodedEntry = (encodedEntry << (Ver4DictConstants::WORD_LEVEL_FIELD_SIZE * CHAR_BIT))
                    ^ static_cast<uint64_t>(mHistoricalInfo.getLevel());
                    | static_cast<uint8_t>(mHistoricalInfo.getLevel());
            encodedEntry = (encodedEntry << (Ver4DictConstants::WORD_COUNT_FIELD_SIZE * CHAR_BIT))
                    ^ static_cast<uint64_t>(mHistoricalInfo.getCount());
                    | static_cast<uint8_t>(mHistoricalInfo.getCount());
        } else {
            encodedEntry = (encodedEntry << (Ver4DictConstants::PROBABILITY_SIZE * CHAR_BIT))
                    ^ static_cast<uint64_t>(mProbability);
                    | static_cast<uint8_t>(mProbability);
        }
        return encodedEntry;
    }