Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 54007019 authored by Keisuke Kuroyanagi's avatar Keisuke Kuroyanagi
Browse files

Move entry updating method to language model dict content.

Bug: 14425059
Change-Id: I710055490d141539458cbf968adf5a7ccffd9552
parent 95f100ba
Loading
Loading
Loading
Loading
+53 −2
Original line number Diff line number Diff line
@@ -25,6 +25,7 @@ namespace latinime {

// Slots for the unigram and bigram counters in an entry count table.
// NOTE(review): presumably these index the per-level count arrays handled by the GC/truncation
// methods of this class -- confirm against the callers.
const int LanguageModelDictContent::UNIGRAM_COUNT_INDEX_IN_ENTRY_COUNT_TABLE = 0;
const int LanguageModelDictContent::BIGRAM_COUNT_INDEX_IN_ENTRY_COUNT_TABLE = 1;
// Placeholder probability that createUpdatedEntryFrom() passes to
// ForgettingCurveUtils::createUpdatedHistoricalInfo() when the input word is valid;
// NOT_A_PROBABILITY is passed otherwise.
const int LanguageModelDictContent::DUMMY_PROBABILITY_FOR_VALID_WORDS = 1;

bool LanguageModelDictContent::save(FILE *const file) const {
    return mTrieMap.save(file);
@@ -143,6 +144,56 @@ bool LanguageModelDictContent::truncateEntries(const int *const entryCounts,
    return true;
}

bool LanguageModelDictContent::updateAllEntriesOnInputWord(const WordIdArrayView prevWordIds,
        const int wordId, const bool isValid, const HistoricalInfo historicalInfo,
        const HeaderPolicy *const headerPolicy, int *const outAddedNewNgramEntryCount) {
    if (outAddedNewNgramEntryCount) {
        *outAddedNewNgramEntryCount = 0;
    }
    if (!mHasHistoricalInfo) {
        AKLOGE("updateAllEntriesOnInputWord is called for dictionary without historical info.");
        return false;
    }
    const ProbabilityEntry originalUnigramProbabilityEntry = getProbabilityEntry(wordId);
    const ProbabilityEntry updatedUnigramProbabilityEntry = createUpdatedEntryFrom(
            originalUnigramProbabilityEntry, isValid, historicalInfo, headerPolicy);
    if (!setProbabilityEntry(wordId, &updatedUnigramProbabilityEntry)) {
        return false;
    }
    for (size_t i = 0; i < prevWordIds.size(); ++i) {
        if (prevWordIds[i] == NOT_A_WORD_ID) {
            break;
        }
        // TODO: Optimize this code.
        const WordIdArrayView limitedPrevWordIds = prevWordIds.limit(i + 1);
        const ProbabilityEntry originalNgramProbabilityEntry = getNgramProbabilityEntry(
                limitedPrevWordIds, wordId);
        const ProbabilityEntry updatedNgramProbabilityEntry = createUpdatedEntryFrom(
                originalNgramProbabilityEntry, isValid, historicalInfo, headerPolicy);
        if (!setNgramProbabilityEntry(limitedPrevWordIds, wordId, &updatedNgramProbabilityEntry)) {
            return false;
        }
        if (!originalNgramProbabilityEntry.isValid() && outAddedNewNgramEntryCount) {
            *outAddedNewNgramEntryCount += 1;
        }
    }
    return true;
}

// Builds the entry that should replace originalProbabilityEntry after one more input of the
// word. The historical info is recomputed by ForgettingCurveUtils::createUpdatedHistoricalInfo()
// from the original entry's historical info and the given historicalInfo. The flags of the
// original entry are carried over when that entry is valid; otherwise the flags are 0.
const ProbabilityEntry LanguageModelDictContent::createUpdatedEntryFrom(
        const ProbabilityEntry &originalProbabilityEntry, const bool isValid,
        const HistoricalInfo historicalInfo, const HeaderPolicy *const headerPolicy) const {
    // Valid words are reported with a placeholder probability, everything else with none.
    const int probabilityForUpdate =
            isValid ? DUMMY_PROBABILITY_FOR_VALID_WORDS : NOT_A_PROBABILITY;
    const HistoricalInfo newHistoricalInfo = ForgettingCurveUtils::createUpdatedHistoricalInfo(
            originalProbabilityEntry.getHistoricalInfo(), probabilityForUpdate, &historicalInfo,
            headerPolicy);
    if (!originalProbabilityEntry.isValid()) {
        // No previous entry to inherit flags from.
        return ProbabilityEntry(0 /* flags */, &newHistoricalInfo);
    }
    return ProbabilityEntry(originalProbabilityEntry.getFlags(), &newHistoricalInfo);
}

bool LanguageModelDictContent::runGCInner(
        const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap,
        const TrieMap::TrieMapRange trieMapRange,
@@ -203,7 +254,7 @@ int LanguageModelDictContent::getBitmapEntryIndex(const WordIdArrayView prevWord
    return bitmapEntryIndex;
}

bool LanguageModelDictContent::updateAllProbabilityEntriesInner(const int bitmapEntryIndex,
bool LanguageModelDictContent::updateAllProbabilityEntriesForGCInner(const int bitmapEntryIndex,
        const int level, const HeaderPolicy *const headerPolicy, int *const outEntryCounts) {
    for (const auto &entry : mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex)) {
        if (level > MAX_PREV_WORD_COUNT_FOR_N_GRAM) {
@@ -237,7 +288,7 @@ bool LanguageModelDictContent::updateAllProbabilityEntriesInner(const int bitmap
        if (!entry.hasNextLevelMap()) {
            continue;
        }
        if (!updateAllProbabilityEntriesInner(entry.getNextLevelBitmapEntryIndex(), level + 1,
        if (!updateAllProbabilityEntriesForGCInner(entry.getNextLevelBitmapEntryIndex(), level + 1,
                headerPolicy, outEntryCounts)) {
            return false;
        }
+14 −4
Original line number Diff line number Diff line
@@ -154,19 +154,23 @@ class LanguageModelDictContent {

    EntryRange getProbabilityEntries(const WordIdArrayView prevWordIds) const;

    bool updateAllProbabilityEntries(const HeaderPolicy *const headerPolicy,
    bool updateAllProbabilityEntriesForGC(const HeaderPolicy *const headerPolicy,
            int *const outEntryCounts) {
        for (int i = 0; i <= MAX_PREV_WORD_COUNT_FOR_N_GRAM; ++i) {
            outEntryCounts[i] = 0;
        }
        return updateAllProbabilityEntriesInner(mTrieMap.getRootBitmapEntryIndex(), 0 /* level */,
                headerPolicy, outEntryCounts);
        return updateAllProbabilityEntriesForGCInner(mTrieMap.getRootBitmapEntryIndex(),
                0 /* level */, headerPolicy, outEntryCounts);
    }

    // entryCounts should be created by updateAllProbabilityEntriesForGC.
    bool truncateEntries(const int *const entryCounts, const int *const maxEntryCounts,
            const HeaderPolicy *const headerPolicy, int *const outEntryCounts);

    bool updateAllEntriesOnInputWord(const WordIdArrayView prevWordIds, const int wordId,
            const bool isValid, const HistoricalInfo historicalInfo,
            const HeaderPolicy *const headerPolicy, int *const outAddedNewNgramEntryCount);

 private:
    DISALLOW_COPY_AND_ASSIGN(LanguageModelDictContent);

@@ -193,6 +197,9 @@ class LanguageModelDictContent {
        DISALLOW_DEFAULT_CONSTRUCTOR(EntryInfoToTurncate);
    };

    // TODO: Remove
    static const int DUMMY_PROBABILITY_FOR_VALID_WORDS;

    TrieMap mTrieMap;
    const bool mHasHistoricalInfo;

@@ -201,13 +208,16 @@ class LanguageModelDictContent {
            int *const outNgramCount);
    int createAndGetBitmapEntryIndex(const WordIdArrayView prevWordIds);
    int getBitmapEntryIndex(const WordIdArrayView prevWordIds) const;
    bool updateAllProbabilityEntriesInner(const int bitmapEntryIndex, const int level,
    bool updateAllProbabilityEntriesForGCInner(const int bitmapEntryIndex, const int level,
            const HeaderPolicy *const headerPolicy, int *const outEntryCounts);
    bool turncateEntriesInSpecifiedLevel(const HeaderPolicy *const headerPolicy,
            const int maxEntryCount, const int targetLevel, int *const outEntryCount);
    bool getEntryInfo(const HeaderPolicy *const headerPolicy, const int targetLevel,
            const int bitmapEntryIndex, std::vector<int> *const prevWordIds,
            std::vector<EntryInfoToTurncate> *const outEntryInfo) const;
    const ProbabilityEntry createUpdatedEntryFrom(const ProbabilityEntry &originalProbabilityEntry,
            const bool isValid, const HistoricalInfo historicalInfo,
            const HeaderPolicy *const headerPolicy) const;
};
} // namespace latinime
#endif /* LATINIME_LANGUAGE_MODEL_DICT_CONTENT_H */
+3 −28
Original line number Diff line number Diff line
@@ -142,14 +142,9 @@ bool Ver4PatriciaTrieNodeWriter::updatePtNodeUnigramProperty(
    if (!toBeUpdatedPtNodeParams->isTerminal()) {
        return false;
    }
    const ProbabilityEntry originalProbabilityEntry =
            mBuffers->getLanguageModelDictContent()->getProbabilityEntry(
                    toBeUpdatedPtNodeParams->getTerminalId());
    const ProbabilityEntry probabilityEntryOfUnigramProperty = ProbabilityEntry(unigramProperty);
    const ProbabilityEntry updatedProbabilityEntry =
            createUpdatedEntryFrom(&originalProbabilityEntry, &probabilityEntryOfUnigramProperty);
    return mBuffers->getMutableLanguageModelDictContent()->setProbabilityEntry(
            toBeUpdatedPtNodeParams->getTerminalId(), &updatedProbabilityEntry);
            toBeUpdatedPtNodeParams->getTerminalId(), &probabilityEntryOfUnigramProperty);
}

bool Ver4PatriciaTrieNodeWriter::updatePtNodeProbabilityAndGetNeedsToKeepPtNodeAfterGC(
@@ -203,10 +198,8 @@ bool Ver4PatriciaTrieNodeWriter::writeNewTerminalPtNodeAndAdvancePosition(
    // Write probability.
    ProbabilityEntry newProbabilityEntry;
    const ProbabilityEntry probabilityEntryOfUnigramProperty = ProbabilityEntry(unigramProperty);
    const ProbabilityEntry probabilityEntryToWrite = createUpdatedEntryFrom(
            &newProbabilityEntry, &probabilityEntryOfUnigramProperty);
    return mBuffers->getMutableLanguageModelDictContent()->setProbabilityEntry(
            terminalId, &probabilityEntryToWrite);
            terminalId, &probabilityEntryOfUnigramProperty);
}

// TODO: Support counting ngram entries.
@@ -217,10 +210,8 @@ bool Ver4PatriciaTrieNodeWriter::addNgramEntry(const WordIdArrayView prevWordIds
    const ProbabilityEntry probabilityEntry =
            languageModelDictContent->getNgramProbabilityEntry(prevWordIds, wordId);
    const ProbabilityEntry probabilityEntryOfNgramProperty(ngramProperty);
    const ProbabilityEntry updatedProbabilityEntry = createUpdatedEntryFrom(
            &probabilityEntry, &probabilityEntryOfNgramProperty);
    if (!languageModelDictContent->setNgramProbabilityEntry(
            prevWordIds, wordId, &updatedProbabilityEntry)) {
            prevWordIds, wordId, &probabilityEntryOfNgramProperty)) {
        AKLOGE("Cannot add new ngram entry. prevWordId[0]: %d, prevWordId.size(): %zd, wordId: %d",
                prevWordIds[0], prevWordIds.size(), wordId);
        return false;
@@ -346,22 +337,6 @@ bool Ver4PatriciaTrieNodeWriter::writePtNodeAndGetTerminalIdAndAdvancePosition(
            ptNodeParams->getCodePointCount() > 1 /* hasMultipleChars */);
}

// TODO: Move probability handling code to LanguageModelDictContent.
const ProbabilityEntry Ver4PatriciaTrieNodeWriter::createUpdatedEntryFrom(
        const ProbabilityEntry *const originalProbabilityEntry,
        const ProbabilityEntry *const probabilityEntry) const {
    if (mHeaderPolicy->hasHistoricalInfoOfWords()) {
        const HistoricalInfo updatedHistoricalInfo =
                ForgettingCurveUtils::createUpdatedHistoricalInfo(
                        originalProbabilityEntry->getHistoricalInfo(),
                        probabilityEntry->getProbability(), probabilityEntry->getHistoricalInfo(),
                        mHeaderPolicy);
        return ProbabilityEntry(probabilityEntry->getFlags(), &updatedHistoricalInfo);
    } else {
        return *probabilityEntry;
    }
}

bool Ver4PatriciaTrieNodeWriter::updatePtNodeFlags(const int ptNodePos, const bool isTerminal,
        const bool hasMultipleChars) {
    // Create node flags and write them.
+2 −11
Original line number Diff line number Diff line
@@ -38,11 +38,10 @@ class Ver4ShortcutListPolicy;
class Ver4PatriciaTrieNodeWriter : public PtNodeWriter {
 public:
    Ver4PatriciaTrieNodeWriter(BufferWithExtendableBuffer *const trieBuffer,
            Ver4DictBuffers *const buffers, const HeaderPolicy *const headerPolicy,
            const PtNodeReader *const ptNodeReader,
            Ver4DictBuffers *const buffers, const PtNodeReader *const ptNodeReader,
            const PtNodeArrayReader *const ptNodeArrayReader,
            Ver4ShortcutListPolicy *const shortcutPolicy)
            : mTrieBuffer(trieBuffer), mBuffers(buffers), mHeaderPolicy(headerPolicy),
            : mTrieBuffer(trieBuffer), mBuffers(buffers),
              mReadingHelper(ptNodeReader, ptNodeArrayReader), mShortcutPolicy(shortcutPolicy) {}

    virtual ~Ver4PatriciaTrieNodeWriter() {}
@@ -96,20 +95,12 @@ class Ver4PatriciaTrieNodeWriter : public PtNodeWriter {
            const PtNodeParams *const ptNodeParams, int *const outTerminalId,
            int *const ptNodeWritingPos);

    // Create updated probability entry using given probability property. In addition to the
    // probability, this method updates historical information if needed.
    // TODO: Update flags.
    const ProbabilityEntry createUpdatedEntryFrom(
            const ProbabilityEntry *const originalProbabilityEntry,
            const ProbabilityEntry *const probabilityEntry) const;

    bool updatePtNodeFlags(const int ptNodePos, const bool isTerminal, const bool hasMultipleChars);

    static const int CHILDREN_POSITION_FIELD_SIZE;

    BufferWithExtendableBuffer *const mTrieBuffer;
    Ver4DictBuffers *const mBuffers;
    const HeaderPolicy *const mHeaderPolicy;
    DynamicPtReadingHelper mReadingHelper;
    Ver4ShortcutListPolicy *const mShortcutPolicy;
};
+35 −18
Original line number Diff line number Diff line
@@ -43,7 +43,6 @@ const char *const Ver4PatriciaTriePolicy::MAX_BIGRAM_COUNT_QUERY = "MAX_BIGRAM_C
const int Ver4PatriciaTriePolicy::MARGIN_TO_REFUSE_DYNAMIC_OPERATIONS = 1024;
const int Ver4PatriciaTriePolicy::MIN_DICT_SIZE_TO_REFUSE_DYNAMIC_OPERATIONS =
        Ver4DictConstants::MAX_DICTIONARY_SIZE - MARGIN_TO_REFUSE_DYNAMIC_OPERATIONS;
const int Ver4PatriciaTriePolicy::DUMMY_PROBABILITY_FOR_VALID_WORDS = 1;

void Ver4PatriciaTriePolicy::createAndGetAllChildDicNodes(const DicNode *const dicNode,
        DicNodeVector *const childDicNodes) const {
@@ -151,8 +150,7 @@ void Ver4PatriciaTriePolicy::iterateNgramEntries(const WordIdArrayView prevWordI
            }
            const int probability = probabilityEntry.hasHistoricalInfo() ?
                    ForgettingCurveUtils::decodeProbability(
                            probabilityEntry.getHistoricalInfo(), mHeaderPolicy)
                            + ForgettingCurveUtils::getProbabilityBiasForNgram(i + 1 /* n */) :
                            probabilityEntry.getHistoricalInfo(), mHeaderPolicy) :
                    probabilityEntry.getProbability();
            listener->onVisitEntry(probability, entry.getWordId());
        }
@@ -371,25 +369,44 @@ bool Ver4PatriciaTriePolicy::updateEntriesForWordWithNgramContext(
                "dictionary.");
        return false;
    }
    // TODO: Have count up method in language model dict content.
    const int probability = isValidWord ? DUMMY_PROBABILITY_FOR_VALID_WORDS : NOT_A_PROBABILITY;
    const bool updateAsAValidWord = ngramContext->isNthPrevWordBeginningOfSentence(1 /* n */) ?
            false : isValidWord;
    int wordId = getWordId(wordCodePoints, false /* tryLowerCaseSearch */);
    if (wordId == NOT_A_WORD_ID) {
        // The word is not in the dictionary.
        const UnigramProperty unigramProperty(false /* representsBeginningOfSentence */,
            false /* isNotAWord */, false /* isBlacklisted */, probability, historicalInfo);
                false /* isNotAWord */, false /* isBlacklisted */, NOT_A_PROBABILITY,
                HistoricalInfo(historicalInfo.getTimestamp(), 0 /* level */, 0 /* count */));
        if (!addUnigramEntry(wordCodePoints, &unigramProperty)) {
        AKLOGE("Cannot update unigarm entry in updateEntriesForWordWithNgramContext().");
            AKLOGE("Cannot add unigarm entry in updateEntriesForWordWithNgramContext().");
            return false;
        }
    const int probabilityForNgram = ngramContext->isNthPrevWordBeginningOfSentence(1 /* n */)
            ? NOT_A_PROBABILITY : probability;
    const NgramProperty ngramProperty(wordCodePoints.toVector(), probabilityForNgram,
            historicalInfo);
    for (size_t i = 1; i <= ngramContext->getPrevWordCount(); ++i) {
        const NgramContext trimmedNgramContext(ngramContext->getTrimmedNgramContext(i));
        if (!addNgramEntry(&trimmedNgramContext, &ngramProperty)) {
            AKLOGE("Cannot update ngram entry in updateEntriesForWordWithNgramContext().");
        wordId = getWordId(wordCodePoints, false /* tryLowerCaseSearch */);
    }

    WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> prevWordIdArray;
    const WordIdArrayView prevWordIds = ngramContext->getPrevWordIds(this, &prevWordIdArray,
            false /* tryLowerCaseSearch */);
    if (prevWordIds.firstOrDefault(NOT_A_WORD_ID) == NOT_A_WORD_ID
            && ngramContext->isNthPrevWordBeginningOfSentence(1 /* n */)) {
        const UnigramProperty beginningOfSentenceUnigramProperty(
                true /* representsBeginningOfSentence */,
                true /* isNotAWord */, false /* isBlacklisted */, NOT_A_PROBABILITY,
                HistoricalInfo(historicalInfo.getTimestamp(), 0 /* level */, 0 /* count */));
        if (!addUnigramEntry(ngramContext->getNthPrevWordCodePoints(1 /* n */),
                &beginningOfSentenceUnigramProperty)) {
            AKLOGE("Cannot add BoS entry in updateEntriesForWordWithNgramContext().");
            return false;
        }
        // Refresh word ids.
        ngramContext->getPrevWordIds(this, &prevWordIdArray, false /* tryLowerCaseSearch */);
    }
    int addedNewNgramEntryCount = 0;
    if (!mBuffers->getMutableLanguageModelDictContent()->updateAllEntriesOnInputWord(prevWordIds,
            wordId, updateAsAValidWord, historicalInfo, mHeaderPolicy, &addedNewNgramEntryCount)) {
        return false;
    }
    mBigramCount += addedNewNgramEntryCount;
    return true;
}

Loading