Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 395fe8e9 authored by Keisuke Kuroyanagi
Browse files

Implement LanguageModelDictContent.getWordProbability().

Bug: 14425059
Change-Id: I290a05cee6f341caa25fb222892505529cef1eb7
parent 9f8da0f8
Loading
Loading
Loading
Loading
+2 −2
Original line number Diff line number Diff line
@@ -354,7 +354,7 @@ bool Ver4PatriciaTriePolicy::addNgramEntry(const PrevWordsInfo *const prevWordsI
    }
    bool addedNewBigram = false;
    const int prevWordPtNodePos = getTerminalPtNodePosFromWordId(prevWordIds[0]);
    if (mUpdatingHelper.addNgramEntry(PtNodePosArrayView::fromObject(&prevWordPtNodePos),
    if (mUpdatingHelper.addNgramEntry(PtNodePosArrayView::singleElementView(&prevWordPtNodePos),
            wordPos, bigramProperty, &addedNewBigram)) {
        if (addedNewBigram) {
            mBigramCount++;
@@ -396,7 +396,7 @@ bool Ver4PatriciaTriePolicy::removeNgramEntry(const PrevWordsInfo *const prevWor
    }
    const int prevWordPtNodePos = getTerminalPtNodePosFromWordId(prevWordIds[0]);
    if (mUpdatingHelper.removeNgramEntry(
            PtNodePosArrayView::fromObject(&prevWordPtNodePos), wordPos)) {
            PtNodePosArrayView::singleElementView(&prevWordPtNodePos), wordPos)) {
        mBigramCount--;
        return true;
    } else {
+34 −0
Original line number Diff line number Diff line
@@ -38,6 +38,40 @@ bool LanguageModelDictContent::runGC(
            0 /* nextLevelBitmapEntryIndex */, outNgramCount);
}

// Looks up the probability of wordId in the context prevWordIds by walking the
// n-gram trie one context word at a time.  Falls back to shorter contexts (and
// finally the unigram level) when the full context is not stored; returns
// NOT_A_PROBABILITY when the word is not found at any level.  When historical
// info is stored, a per-n bias is added so that matches with a longer context
// outrank shorter ones.
int LanguageModelDictContent::getWordProbability(const WordIdArrayView prevWordIds,
        const int wordId) const {
    int bitmapEntryIndices[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1];
    bitmapEntryIndices[0] = mTrieMap.getRootBitmapEntryIndex();
    int maxLevel = 0;
    // Clamp the context length to the array capacity so that an oversized
    // prevWordIds cannot write past the end of bitmapEntryIndices below.
    const size_t prevWordCount = std::min(prevWordIds.size(),
            static_cast<size_t>(MAX_PREV_WORD_COUNT_FOR_N_GRAM));
    for (size_t i = 0; i < prevWordCount; ++i) {
        const int nextBitmapEntryIndex =
                mTrieMap.get(prevWordIds[i], bitmapEntryIndices[i]).mNextLevelBitmapEntryIndex;
        if (nextBitmapEntryIndex == TrieMap::INVALID_INDEX) {
            // The context is only partially present; keep the deepest level reached.
            break;
        }
        maxLevel = i + 1;
        bitmapEntryIndices[i + 1] = nextBitmapEntryIndex;
    }

    // Search from the longest matched context down to the unigram level.
    for (int i = maxLevel; i >= 0; --i) {
        const TrieMap::Result result = mTrieMap.get(wordId, bitmapEntryIndices[i]);
        if (!result.mIsValid) {
            continue;
        }
        const int probability =
                ProbabilityEntry::decode(result.mValue, mHasHistoricalInfo).getProbability();
        if (mHasHistoricalInfo) {
            // Bias longer n-grams upward, capped at the probability scale maximum.
            return std::min(
                    probability + ForgettingCurveUtils::getProbabilityBiasForNgram(i + 1 /* n */),
                    MAX_PROBABILITY);
        } else {
            return probability;
        }
    }
    // Cannot find the word.
    return NOT_A_PROBABILITY;
}

ProbabilityEntry LanguageModelDictContent::getNgramProbabilityEntry(
        const WordIdArrayView prevWordIds, const int wordId) const {
    const int bitmapEntryIndex = getBitmapEntryIndex(prevWordIds);
+2 −0
Original line number Diff line number Diff line
@@ -128,6 +128,8 @@ class LanguageModelDictContent {
            const LanguageModelDictContent *const originalContent,
            int *const outNgramCount);

    int getWordProbability(const WordIdArrayView prevWordIds, const int wordId) const;

    // Returns the unigram probability entry for wordId: an n-gram lookup with an
    // empty previous-word context.
    ProbabilityEntry getProbabilityEntry(const int wordId) const {
        return getNgramProbabilityEntry(WordIdArrayView(), wordId);
    }
+6 −18
Original line number Diff line number Diff line
@@ -115,24 +115,12 @@ int Ver4PatriciaTriePolicy::getWordId(const CodePointArrayView wordCodePoints,

int Ver4PatriciaTriePolicy::getProbabilityOfWordInContext(const int *const prevWordIds,
        const int wordId, MultiBigramMap *const multiBigramMap) const {
    // TODO: Quit using MultiBigramMap.
    if (wordId == NOT_A_WORD_ID) {
        return NOT_A_PROBABILITY;
    }
    const int ptNodePos =
            mBuffers->getTerminalPositionLookupTable()->getTerminalPtNodePosition(wordId);
    const PtNodeParams ptNodeParams(mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos));
    if (multiBigramMap) {
        return multiBigramMap->getBigramProbability(this /* structurePolicy */, prevWordIds,
                wordId, ptNodeParams.getProbability());
    }
    if (prevWordIds) {
        const int probability = getProbabilityOfWord(prevWordIds, wordId);
        if (probability != NOT_A_PROBABILITY) {
            return probability;
        }
    }
    return getProbability(ptNodeParams.getProbability(), NOT_A_PROBABILITY);
    // TODO: Support n-gram.
    return mBuffers->getLanguageModelDictContent()->getWordProbability(
            WordIdArrayView::singleElementView(prevWordIds), wordId);
}

int Ver4PatriciaTriePolicy::getProbability(const int unigramProbability,
@@ -166,7 +154,7 @@ int Ver4PatriciaTriePolicy::getProbabilityOfWord(const int *const prevWordIds,
        // TODO: Support n-gram.
        const ProbabilityEntry probabilityEntry =
                mBuffers->getLanguageModelDictContent()->getNgramProbabilityEntry(
                        IntArrayView::fromObject(prevWordIds), wordId);
                        IntArrayView::singleElementView(prevWordIds), wordId);
        if (!probabilityEntry.isValid()) {
            return NOT_A_PROBABILITY;
        }
@@ -194,7 +182,7 @@ void Ver4PatriciaTriePolicy::iterateNgramEntries(const int *const prevWordIds,
    // TODO: Support n-gram.
    const auto languageModelDictContent = mBuffers->getLanguageModelDictContent();
    for (const auto entry : languageModelDictContent->getProbabilityEntries(
            WordIdArrayView::fromObject(prevWordIds))) {
            WordIdArrayView::singleElementView(prevWordIds))) {
        const ProbabilityEntry &probabilityEntry = entry.getProbabilityEntry();
        const int probability = probabilityEntry.hasHistoricalInfo() ?
                ForgettingCurveUtils::decodeProbability(
@@ -511,7 +499,7 @@ const WordProperty Ver4PatriciaTriePolicy::getWordProperty(
    // Fetch bigram information.
    // TODO: Support n-gram.
    std::vector<BigramProperty> bigrams;
    const WordIdArrayView prevWordIds = WordIdArrayView::fromObject(&wordId);
    const WordIdArrayView prevWordIds = WordIdArrayView::singleElementView(&wordId);
    int bigramWord1CodePoints[MAX_WORD_LENGTH];
    for (const auto entry : mBuffers->getLanguageModelDictContent()->getProbabilityEntries(
            prevWordIds)) {
+5 −0
Original line number Diff line number Diff line
@@ -48,6 +48,11 @@ class ForgettingCurveUtils {
    static bool needsToDecay(const bool mindsBlockByDecay, const int unigramCount,
            const int bigramCount, const HeaderPolicy *const headerPolicy);

    // TODO: Improve probability computation method and remove this.
    // Returns the bias added to a stored n-gram probability so that entries
    // matched with a longer context (larger n) outrank shorter-context matches.
    // Grows linearly with n and is zero for unigrams (n == 1).
    static int getProbabilityBiasForNgram(const int n) {
        return (n - 1) * MULTIPLIER_TWO_IN_PROBABILITY_SCALE;
    }

    AK_FORCE_INLINE static int getUnigramCountHardLimit(const int maxUnigramCount) {
        return static_cast<int>(static_cast<float>(maxUnigramCount)
                * UNIGRAM_COUNT_HARD_LIMIT_WEIGHT);
Loading