Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 38085ee7 authored by Keisuke Kuroyanagi's avatar Keisuke Kuroyanagi Committed by Android (Google) Code Review
Browse files

Merge "Implement LanguageModelDictContent.getWordProbability()."

parents 41ee1f08 395fe8e9
Loading
Loading
Loading
Loading
+2 −2
Original line number Diff line number Diff line
@@ -354,7 +354,7 @@ bool Ver4PatriciaTriePolicy::addNgramEntry(const PrevWordsInfo *const prevWordsI
    }
    bool addedNewBigram = false;
    const int prevWordPtNodePos = getTerminalPtNodePosFromWordId(prevWordIds[0]);
    if (mUpdatingHelper.addNgramEntry(PtNodePosArrayView::fromObject(&prevWordPtNodePos),
    if (mUpdatingHelper.addNgramEntry(PtNodePosArrayView::singleElementView(&prevWordPtNodePos),
            wordPos, bigramProperty, &addedNewBigram)) {
        if (addedNewBigram) {
            mBigramCount++;
@@ -396,7 +396,7 @@ bool Ver4PatriciaTriePolicy::removeNgramEntry(const PrevWordsInfo *const prevWor
    }
    const int prevWordPtNodePos = getTerminalPtNodePosFromWordId(prevWordIds[0]);
    if (mUpdatingHelper.removeNgramEntry(
            PtNodePosArrayView::fromObject(&prevWordPtNodePos), wordPos)) {
            PtNodePosArrayView::singleElementView(&prevWordPtNodePos), wordPos)) {
        mBigramCount--;
        return true;
    } else {
+34 −0
Original line number Diff line number Diff line
@@ -38,6 +38,40 @@ bool LanguageModelDictContent::runGC(
            0 /* nextLevelBitmapEntryIndex */, outNgramCount);
}

int LanguageModelDictContent::getWordProbability(const WordIdArrayView prevWordIds,
        const int wordId) const {
    int bitmapEntryIndices[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1];
    bitmapEntryIndices[0] = mTrieMap.getRootBitmapEntryIndex();
    int maxLevel = 0;
    for (size_t i = 0; i < prevWordIds.size(); ++i) {
        const int nextBitmapEntryIndex =
                mTrieMap.get(prevWordIds[i], bitmapEntryIndices[i]).mNextLevelBitmapEntryIndex;
        if (nextBitmapEntryIndex == TrieMap::INVALID_INDEX) {
            break;
        }
        maxLevel = i + 1;
        bitmapEntryIndices[i + 1] = nextBitmapEntryIndex;
    }

    for (int i = maxLevel; i >= 0; --i) {
        const TrieMap::Result result = mTrieMap.get(wordId, bitmapEntryIndices[i]);
        if (!result.mIsValid) {
            continue;
        }
        const int probability =
                ProbabilityEntry::decode(result.mValue, mHasHistoricalInfo).getProbability();
        if (mHasHistoricalInfo) {
            return std::min(
                    probability + ForgettingCurveUtils::getProbabilityBiasForNgram(i + 1 /* n */),
                    MAX_PROBABILITY);
        } else {
            return probability;
        }
    }
    // Cannot find the word.
    return NOT_A_PROBABILITY;
}

ProbabilityEntry LanguageModelDictContent::getNgramProbabilityEntry(
        const WordIdArrayView prevWordIds, const int wordId) const {
    const int bitmapEntryIndex = getBitmapEntryIndex(prevWordIds);
+2 −0
Original line number Diff line number Diff line
@@ -128,6 +128,8 @@ class LanguageModelDictContent {
            const LanguageModelDictContent *const originalContent,
            int *const outNgramCount);

    int getWordProbability(const WordIdArrayView prevWordIds, const int wordId) const;

    // Convenience wrapper: fetches the unigram probability entry for wordId
    // by querying the n-gram lookup with an empty previous-word context.
    ProbabilityEntry getProbabilityEntry(const int wordId) const {
        const WordIdArrayView emptyPrevWordIds;
        return getNgramProbabilityEntry(emptyPrevWordIds, wordId);
    }
+6 −18
Original line number Diff line number Diff line
@@ -115,24 +115,12 @@ int Ver4PatriciaTriePolicy::getWordId(const CodePointArrayView wordCodePoints,

int Ver4PatriciaTriePolicy::getProbabilityOfWordInContext(const int *const prevWordIds,
        const int wordId, MultiBigramMap *const multiBigramMap) const {
    // NOTE(review): this span interleaved pre- and post-change diff lines,
    // leaving an unreachable second return; reconstructed as the coherent
    // post-change implementation that delegates to the language model dict
    // content instead of MultiBigramMap.
    // TODO: Quit using MultiBigramMap (the parameter is no longer consulted).
    if (wordId == NOT_A_WORD_ID) {
        return NOT_A_PROBABILITY;
    }
    // TODO: Support n-gram.
    return mBuffers->getLanguageModelDictContent()->getWordProbability(
            WordIdArrayView::singleElementView(prevWordIds), wordId);
}

int Ver4PatriciaTriePolicy::getProbability(const int unigramProbability,
@@ -166,7 +154,7 @@ int Ver4PatriciaTriePolicy::getProbabilityOfWord(const int *const prevWordIds,
        // TODO: Support n-gram.
        const ProbabilityEntry probabilityEntry =
                mBuffers->getLanguageModelDictContent()->getNgramProbabilityEntry(
                        IntArrayView::fromObject(prevWordIds), wordId);
                        IntArrayView::singleElementView(prevWordIds), wordId);
        if (!probabilityEntry.isValid()) {
            return NOT_A_PROBABILITY;
        }
@@ -194,7 +182,7 @@ void Ver4PatriciaTriePolicy::iterateNgramEntries(const int *const prevWordIds,
    // TODO: Support n-gram.
    const auto languageModelDictContent = mBuffers->getLanguageModelDictContent();
    for (const auto entry : languageModelDictContent->getProbabilityEntries(
            WordIdArrayView::fromObject(prevWordIds))) {
            WordIdArrayView::singleElementView(prevWordIds))) {
        const ProbabilityEntry &probabilityEntry = entry.getProbabilityEntry();
        const int probability = probabilityEntry.hasHistoricalInfo() ?
                ForgettingCurveUtils::decodeProbability(
@@ -511,7 +499,7 @@ const WordProperty Ver4PatriciaTriePolicy::getWordProperty(
    // Fetch bigram information.
    // TODO: Support n-gram.
    std::vector<BigramProperty> bigrams;
    const WordIdArrayView prevWordIds = WordIdArrayView::fromObject(&wordId);
    const WordIdArrayView prevWordIds = WordIdArrayView::singleElementView(&wordId);
    int bigramWord1CodePoints[MAX_WORD_LENGTH];
    for (const auto entry : mBuffers->getLanguageModelDictContent()->getProbabilityEntries(
            prevWordIds)) {
+5 −0
Original line number Diff line number Diff line
@@ -48,6 +48,11 @@ class ForgettingCurveUtils {
    static bool needsToDecay(const bool mindsBlockByDecay, const int unigramCount,
            const int bigramCount, const HeaderPolicy *const headerPolicy);

    // TODO: Improve probability computation method and remove this.
    // Returns the probability bias applied to an n-gram hit: proportional to
    // the number of previous words in the context (n - 1).
    static int getProbabilityBiasForNgram(const int n) {
        const int contextWordCount = n - 1;
        return contextWordCount * MULTIPLIER_TWO_IN_PROBABILITY_SCALE;
    }

    AK_FORCE_INLINE static int getUnigramCountHardLimit(const int maxUnigramCount) {
        return static_cast<int>(static_cast<float>(maxUnigramCount)
                * UNIGRAM_COUNT_HARD_LIMIT_WEIGHT);
Loading