Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 183e21c3 authored by Keisuke Kuroyanagi's avatar Keisuke Kuroyanagi Committed by Android (Google) Code Review
Browse files

Merge "Use better conditional probability for ngram entries."

parents 17100ad8 72d17d92
Loading
Loading
Loading
Loading
+21 −6
Original line number Diff line number Diff line
@@ -43,18 +43,18 @@ const WordAttributes LanguageModelDictContent::getWordAttributes(const WordIdArr
        const int wordId, const HeaderPolicy *const headerPolicy) const {
    int bitmapEntryIndices[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1];
    bitmapEntryIndices[0] = mTrieMap.getRootBitmapEntryIndex();
    int maxLevel = 0;
    int maxPrevWordCount = 0;
    for (size_t i = 0; i < prevWordIds.size(); ++i) {
        const int nextBitmapEntryIndex =
                mTrieMap.get(prevWordIds[i], bitmapEntryIndices[i]).mNextLevelBitmapEntryIndex;
        if (nextBitmapEntryIndex == TrieMap::INVALID_INDEX) {
            break;
        }
        maxLevel = i + 1;
        maxPrevWordCount = i + 1;
        bitmapEntryIndices[i + 1] = nextBitmapEntryIndex;
    }

    for (int i = maxLevel; i >= 0; --i) {
    for (int i = maxPrevWordCount; i >= 0; --i) {
        const TrieMap::Result result = mTrieMap.get(wordId, bitmapEntryIndices[i]);
        if (!result.mIsValid) {
            continue;
@@ -69,9 +69,24 @@ const WordAttributes LanguageModelDictContent::getWordAttributes(const WordIdArr
                // The entry should not be treated as a valid entry.
                continue;
            }
            probability = std::min(rawProbability
                    + ForgettingCurveUtils::getProbabilityBiasForNgram(i + 1 /* n */),
                            MAX_PROBABILITY);
            if (i == 0) {
                // unigram
                probability = rawProbability;
            } else {
                const ProbabilityEntry prevWordProbabilityEntry = getNgramProbabilityEntry(
                        prevWordIds.skip(1 /* n */).limit(i - 1), prevWordIds[0]);
                if (!prevWordProbabilityEntry.isValid()) {
                    continue;
                }
                if (prevWordProbabilityEntry.representsBeginningOfSentence()) {
                    probability = rawProbability;
                } else {
                    const int prevWordRawProbability = ForgettingCurveUtils::decodeProbability(
                            prevWordProbabilityEntry.getHistoricalInfo(), headerPolicy);
                    probability = std::min(MAX_PROBABILITY - prevWordRawProbability
                            + rawProbability, MAX_PROBABILITY);
                }
            }
        } else {
            probability = probabilityEntry.getProbability();
        }
+5 −5
Original line number Diff line number Diff line
@@ -98,17 +98,17 @@ class ProbabilityEntry {
    }

    uint64_t encode(const bool hasHistoricalInfo) const {
        uint64_t encodedEntry = static_cast<uint64_t>(mFlags);
        uint64_t encodedEntry = static_cast<uint8_t>(mFlags);
        if (hasHistoricalInfo) {
            encodedEntry = (encodedEntry << (Ver4DictConstants::TIME_STAMP_FIELD_SIZE * CHAR_BIT))
                    ^ static_cast<uint64_t>(mHistoricalInfo.getTimestamp());
                    | static_cast<uint32_t>(mHistoricalInfo.getTimestamp());
            encodedEntry = (encodedEntry << (Ver4DictConstants::WORD_LEVEL_FIELD_SIZE * CHAR_BIT))
                    ^ static_cast<uint64_t>(mHistoricalInfo.getLevel());
                    | static_cast<uint8_t>(mHistoricalInfo.getLevel());
            encodedEntry = (encodedEntry << (Ver4DictConstants::WORD_COUNT_FIELD_SIZE * CHAR_BIT))
                    ^ static_cast<uint64_t>(mHistoricalInfo.getCount());
                    | static_cast<uint8_t>(mHistoricalInfo.getCount());
        } else {
            encodedEntry = (encodedEntry << (Ver4DictConstants::PROBABILITY_SIZE * CHAR_BIT))
                    ^ static_cast<uint64_t>(mProbability);
                    | static_cast<uint8_t>(mProbability);
        }
        return encodedEntry;
    }