Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit c2429c54 authored by Keisuke Kuroyanagi's avatar Keisuke Kuroyanagi Committed by Android (Google) Code Review
Browse files

Merge "Move entry updating method to language model dict content."

parents 76c8b644 54007019
Loading
Loading
Loading
Loading
+53 −2
Original line number Diff line number Diff line
@@ -25,6 +25,7 @@ namespace latinime {

const int LanguageModelDictContent::UNIGRAM_COUNT_INDEX_IN_ENTRY_COUNT_TABLE = 0;
const int LanguageModelDictContent::BIGRAM_COUNT_INDEX_IN_ENTRY_COUNT_TABLE = 1;
const int LanguageModelDictContent::DUMMY_PROBABILITY_FOR_VALID_WORDS = 1;

bool LanguageModelDictContent::save(FILE *const file) const {
    return mTrieMap.save(file);
@@ -143,6 +144,56 @@ bool LanguageModelDictContent::truncateEntries(const int *const entryCounts,
    return true;
}

bool LanguageModelDictContent::updateAllEntriesOnInputWord(const WordIdArrayView prevWordIds,
        const int wordId, const bool isValid, const HistoricalInfo historicalInfo,
        const HeaderPolicy *const headerPolicy, int *const outAddedNewNgramEntryCount) {
    if (outAddedNewNgramEntryCount) {
        *outAddedNewNgramEntryCount = 0;
    }
    if (!mHasHistoricalInfo) {
        AKLOGE("updateAllEntriesOnInputWord is called for dictionary without historical info.");
        return false;
    }
    const ProbabilityEntry originalUnigramProbabilityEntry = getProbabilityEntry(wordId);
    const ProbabilityEntry updatedUnigramProbabilityEntry = createUpdatedEntryFrom(
            originalUnigramProbabilityEntry, isValid, historicalInfo, headerPolicy);
    if (!setProbabilityEntry(wordId, &updatedUnigramProbabilityEntry)) {
        return false;
    }
    for (size_t i = 0; i < prevWordIds.size(); ++i) {
        if (prevWordIds[i] == NOT_A_WORD_ID) {
            break;
        }
        // TODO: Optimize this code.
        const WordIdArrayView limitedPrevWordIds = prevWordIds.limit(i + 1);
        const ProbabilityEntry originalNgramProbabilityEntry = getNgramProbabilityEntry(
                limitedPrevWordIds, wordId);
        const ProbabilityEntry updatedNgramProbabilityEntry = createUpdatedEntryFrom(
                originalNgramProbabilityEntry, isValid, historicalInfo, headerPolicy);
        if (!setNgramProbabilityEntry(limitedPrevWordIds, wordId, &updatedNgramProbabilityEntry)) {
            return false;
        }
        if (!originalNgramProbabilityEntry.isValid() && outAddedNewNgramEntryCount) {
            *outAddedNewNgramEntryCount += 1;
        }
    }
    return true;
}

// Builds the entry that should replace originalProbabilityEntry after an input event.
// The historical info is recomputed by ForgettingCurveUtils::createUpdatedHistoricalInfo from the
// original entry's historical info, the given historicalInfo and headerPolicy. The probability
// passed to it is DUMMY_PROBABILITY_FOR_VALID_WORDS when isValid is true and NOT_A_PROBABILITY
// otherwise.
// The flags of a valid original entry are carried over; otherwise the flags are 0.
const ProbabilityEntry LanguageModelDictContent::createUpdatedEntryFrom(
        const ProbabilityEntry &originalProbabilityEntry, const bool isValid,
        const HistoricalInfo historicalInfo, const HeaderPolicy *const headerPolicy) const {
    const int probabilityForUpdate =
            isValid ? DUMMY_PROBABILITY_FOR_VALID_WORDS : NOT_A_PROBABILITY;
    const HistoricalInfo refreshedHistoricalInfo =
            ForgettingCurveUtils::createUpdatedHistoricalInfo(
                    originalProbabilityEntry.getHistoricalInfo(), probabilityForUpdate,
                    &historicalInfo, headerPolicy);
    if (!originalProbabilityEntry.isValid()) {
        // There is no existing entry to inherit flags from.
        return ProbabilityEntry(0 /* flags */, &refreshedHistoricalInfo);
    }
    return ProbabilityEntry(originalProbabilityEntry.getFlags(), &refreshedHistoricalInfo);
}

bool LanguageModelDictContent::runGCInner(
        const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap,
        const TrieMap::TrieMapRange trieMapRange,
@@ -203,7 +254,7 @@ int LanguageModelDictContent::getBitmapEntryIndex(const WordIdArrayView prevWord
    return bitmapEntryIndex;
}

bool LanguageModelDictContent::updateAllProbabilityEntriesInner(const int bitmapEntryIndex,
bool LanguageModelDictContent::updateAllProbabilityEntriesForGCInner(const int bitmapEntryIndex,
        const int level, const HeaderPolicy *const headerPolicy, int *const outEntryCounts) {
    for (const auto &entry : mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex)) {
        if (level > MAX_PREV_WORD_COUNT_FOR_N_GRAM) {
@@ -237,7 +288,7 @@ bool LanguageModelDictContent::updateAllProbabilityEntriesInner(const int bitmap
        if (!entry.hasNextLevelMap()) {
            continue;
        }
        if (!updateAllProbabilityEntriesInner(entry.getNextLevelBitmapEntryIndex(), level + 1,
        if (!updateAllProbabilityEntriesForGCInner(entry.getNextLevelBitmapEntryIndex(), level + 1,
                headerPolicy, outEntryCounts)) {
            return false;
        }
+14 −4
Original line number Diff line number Diff line
@@ -154,19 +154,23 @@ class LanguageModelDictContent {

    EntryRange getProbabilityEntries(const WordIdArrayView prevWordIds) const;

    bool updateAllProbabilityEntries(const HeaderPolicy *const headerPolicy,
    bool updateAllProbabilityEntriesForGC(const HeaderPolicy *const headerPolicy,
            int *const outEntryCounts) {
        for (int i = 0; i <= MAX_PREV_WORD_COUNT_FOR_N_GRAM; ++i) {
            outEntryCounts[i] = 0;
        }
        return updateAllProbabilityEntriesInner(mTrieMap.getRootBitmapEntryIndex(), 0 /* level */,
                headerPolicy, outEntryCounts);
        return updateAllProbabilityEntriesForGCInner(mTrieMap.getRootBitmapEntryIndex(),
                0 /* level */, headerPolicy, outEntryCounts);
    }

    // entryCounts should be created by updateAllProbabilityEntriesForGC.
    bool truncateEntries(const int *const entryCounts, const int *const maxEntryCounts,
            const HeaderPolicy *const headerPolicy, int *const outEntryCounts);

    bool updateAllEntriesOnInputWord(const WordIdArrayView prevWordIds, const int wordId,
            const bool isValid, const HistoricalInfo historicalInfo,
            const HeaderPolicy *const headerPolicy, int *const outAddedNewNgramEntryCount);

 private:
    DISALLOW_COPY_AND_ASSIGN(LanguageModelDictContent);

@@ -193,6 +197,9 @@ class LanguageModelDictContent {
        DISALLOW_DEFAULT_CONSTRUCTOR(EntryInfoToTurncate);
    };

    // TODO: Remove
    static const int DUMMY_PROBABILITY_FOR_VALID_WORDS;

    TrieMap mTrieMap;
    const bool mHasHistoricalInfo;

@@ -201,13 +208,16 @@ class LanguageModelDictContent {
            int *const outNgramCount);
    int createAndGetBitmapEntryIndex(const WordIdArrayView prevWordIds);
    int getBitmapEntryIndex(const WordIdArrayView prevWordIds) const;
    bool updateAllProbabilityEntriesInner(const int bitmapEntryIndex, const int level,
    bool updateAllProbabilityEntriesForGCInner(const int bitmapEntryIndex, const int level,
            const HeaderPolicy *const headerPolicy, int *const outEntryCounts);
    bool turncateEntriesInSpecifiedLevel(const HeaderPolicy *const headerPolicy,
            const int maxEntryCount, const int targetLevel, int *const outEntryCount);
    bool getEntryInfo(const HeaderPolicy *const headerPolicy, const int targetLevel,
            const int bitmapEntryIndex, std::vector<int> *const prevWordIds,
            std::vector<EntryInfoToTurncate> *const outEntryInfo) const;
    const ProbabilityEntry createUpdatedEntryFrom(const ProbabilityEntry &originalProbabilityEntry,
            const bool isValid, const HistoricalInfo historicalInfo,
            const HeaderPolicy *const headerPolicy) const;
};
} // namespace latinime
#endif /* LATINIME_LANGUAGE_MODEL_DICT_CONTENT_H */
+3 −28
Original line number Diff line number Diff line
@@ -142,14 +142,9 @@ bool Ver4PatriciaTrieNodeWriter::updatePtNodeUnigramProperty(
    if (!toBeUpdatedPtNodeParams->isTerminal()) {
        return false;
    }
    const ProbabilityEntry originalProbabilityEntry =
            mBuffers->getLanguageModelDictContent()->getProbabilityEntry(
                    toBeUpdatedPtNodeParams->getTerminalId());
    const ProbabilityEntry probabilityEntryOfUnigramProperty = ProbabilityEntry(unigramProperty);
    const ProbabilityEntry updatedProbabilityEntry =
            createUpdatedEntryFrom(&originalProbabilityEntry, &probabilityEntryOfUnigramProperty);
    return mBuffers->getMutableLanguageModelDictContent()->setProbabilityEntry(
            toBeUpdatedPtNodeParams->getTerminalId(), &updatedProbabilityEntry);
            toBeUpdatedPtNodeParams->getTerminalId(), &probabilityEntryOfUnigramProperty);
}

bool Ver4PatriciaTrieNodeWriter::updatePtNodeProbabilityAndGetNeedsToKeepPtNodeAfterGC(
@@ -203,10 +198,8 @@ bool Ver4PatriciaTrieNodeWriter::writeNewTerminalPtNodeAndAdvancePosition(
    // Write probability.
    ProbabilityEntry newProbabilityEntry;
    const ProbabilityEntry probabilityEntryOfUnigramProperty = ProbabilityEntry(unigramProperty);
    const ProbabilityEntry probabilityEntryToWrite = createUpdatedEntryFrom(
            &newProbabilityEntry, &probabilityEntryOfUnigramProperty);
    return mBuffers->getMutableLanguageModelDictContent()->setProbabilityEntry(
            terminalId, &probabilityEntryToWrite);
            terminalId, &probabilityEntryOfUnigramProperty);
}

// TODO: Support counting ngram entries.
@@ -217,10 +210,8 @@ bool Ver4PatriciaTrieNodeWriter::addNgramEntry(const WordIdArrayView prevWordIds
    const ProbabilityEntry probabilityEntry =
            languageModelDictContent->getNgramProbabilityEntry(prevWordIds, wordId);
    const ProbabilityEntry probabilityEntryOfNgramProperty(ngramProperty);
    const ProbabilityEntry updatedProbabilityEntry = createUpdatedEntryFrom(
            &probabilityEntry, &probabilityEntryOfNgramProperty);
    if (!languageModelDictContent->setNgramProbabilityEntry(
            prevWordIds, wordId, &updatedProbabilityEntry)) {
            prevWordIds, wordId, &probabilityEntryOfNgramProperty)) {
        AKLOGE("Cannot add new ngram entry. prevWordId[0]: %d, prevWordId.size(): %zd, wordId: %d",
                prevWordIds[0], prevWordIds.size(), wordId);
        return false;
@@ -346,22 +337,6 @@ bool Ver4PatriciaTrieNodeWriter::writePtNodeAndGetTerminalIdAndAdvancePosition(
            ptNodeParams->getCodePointCount() > 1 /* hasMultipleChars */);
}

// TODO: Move probability handling code to LanguageModelDictContent.
// Returns the entry to store when probabilityEntry is applied on top of
// originalProbabilityEntry. If the header policy reports no historical info for words,
// probabilityEntry is returned as is. Otherwise the result keeps probabilityEntry's flags and
// carries historical info recomputed by ForgettingCurveUtils::createUpdatedHistoricalInfo.
const ProbabilityEntry Ver4PatriciaTrieNodeWriter::createUpdatedEntryFrom(
        const ProbabilityEntry *const originalProbabilityEntry,
        const ProbabilityEntry *const probabilityEntry) const {
    if (!mHeaderPolicy->hasHistoricalInfoOfWords()) {
        // Nothing to merge; the given entry is used unchanged.
        return *probabilityEntry;
    }
    const HistoricalInfo mergedHistoricalInfo = ForgettingCurveUtils::createUpdatedHistoricalInfo(
            originalProbabilityEntry->getHistoricalInfo(), probabilityEntry->getProbability(),
            probabilityEntry->getHistoricalInfo(), mHeaderPolicy);
    return ProbabilityEntry(probabilityEntry->getFlags(), &mergedHistoricalInfo);
}

bool Ver4PatriciaTrieNodeWriter::updatePtNodeFlags(const int ptNodePos, const bool isTerminal,
        const bool hasMultipleChars) {
    // Create node flags and write them.
+2 −11
Original line number Diff line number Diff line
@@ -38,11 +38,10 @@ class Ver4ShortcutListPolicy;
class Ver4PatriciaTrieNodeWriter : public PtNodeWriter {
 public:
    Ver4PatriciaTrieNodeWriter(BufferWithExtendableBuffer *const trieBuffer,
            Ver4DictBuffers *const buffers, const HeaderPolicy *const headerPolicy,
            const PtNodeReader *const ptNodeReader,
            Ver4DictBuffers *const buffers, const PtNodeReader *const ptNodeReader,
            const PtNodeArrayReader *const ptNodeArrayReader,
            Ver4ShortcutListPolicy *const shortcutPolicy)
            : mTrieBuffer(trieBuffer), mBuffers(buffers), mHeaderPolicy(headerPolicy),
            : mTrieBuffer(trieBuffer), mBuffers(buffers),
              mReadingHelper(ptNodeReader, ptNodeArrayReader), mShortcutPolicy(shortcutPolicy) {}

    virtual ~Ver4PatriciaTrieNodeWriter() {}
@@ -96,20 +95,12 @@ class Ver4PatriciaTrieNodeWriter : public PtNodeWriter {
            const PtNodeParams *const ptNodeParams, int *const outTerminalId,
            int *const ptNodeWritingPos);

    // Create updated probability entry using given probability property. In addition to the
    // probability, this method updates historical information if needed.
    // TODO: Update flags.
    const ProbabilityEntry createUpdatedEntryFrom(
            const ProbabilityEntry *const originalProbabilityEntry,
            const ProbabilityEntry *const probabilityEntry) const;

    bool updatePtNodeFlags(const int ptNodePos, const bool isTerminal, const bool hasMultipleChars);

    static const int CHILDREN_POSITION_FIELD_SIZE;

    BufferWithExtendableBuffer *const mTrieBuffer;
    Ver4DictBuffers *const mBuffers;
    const HeaderPolicy *const mHeaderPolicy;
    DynamicPtReadingHelper mReadingHelper;
    Ver4ShortcutListPolicy *const mShortcutPolicy;
};
+35 −18
Original line number Diff line number Diff line
@@ -43,7 +43,6 @@ const char *const Ver4PatriciaTriePolicy::MAX_BIGRAM_COUNT_QUERY = "MAX_BIGRAM_C
const int Ver4PatriciaTriePolicy::MARGIN_TO_REFUSE_DYNAMIC_OPERATIONS = 1024;
const int Ver4PatriciaTriePolicy::MIN_DICT_SIZE_TO_REFUSE_DYNAMIC_OPERATIONS =
        Ver4DictConstants::MAX_DICTIONARY_SIZE - MARGIN_TO_REFUSE_DYNAMIC_OPERATIONS;
const int Ver4PatriciaTriePolicy::DUMMY_PROBABILITY_FOR_VALID_WORDS = 1;

void Ver4PatriciaTriePolicy::createAndGetAllChildDicNodes(const DicNode *const dicNode,
        DicNodeVector *const childDicNodes) const {
@@ -151,8 +150,7 @@ void Ver4PatriciaTriePolicy::iterateNgramEntries(const WordIdArrayView prevWordI
            }
            const int probability = probabilityEntry.hasHistoricalInfo() ?
                    ForgettingCurveUtils::decodeProbability(
                            probabilityEntry.getHistoricalInfo(), mHeaderPolicy)
                            + ForgettingCurveUtils::getProbabilityBiasForNgram(i + 1 /* n */) :
                            probabilityEntry.getHistoricalInfo(), mHeaderPolicy) :
                    probabilityEntry.getProbability();
            listener->onVisitEntry(probability, entry.getWordId());
        }
@@ -371,25 +369,44 @@ bool Ver4PatriciaTriePolicy::updateEntriesForWordWithNgramContext(
                "dictionary.");
        return false;
    }
    // TODO: Have count up method in language model dict content.
    const int probability = isValidWord ? DUMMY_PROBABILITY_FOR_VALID_WORDS : NOT_A_PROBABILITY;
    const bool updateAsAValidWord = ngramContext->isNthPrevWordBeginningOfSentence(1 /* n */) ?
            false : isValidWord;
    int wordId = getWordId(wordCodePoints, false /* tryLowerCaseSearch */);
    if (wordId == NOT_A_WORD_ID) {
        // The word is not in the dictionary.
        const UnigramProperty unigramProperty(false /* representsBeginningOfSentence */,
            false /* isNotAWord */, false /* isBlacklisted */, probability, historicalInfo);
                false /* isNotAWord */, false /* isBlacklisted */, NOT_A_PROBABILITY,
                HistoricalInfo(historicalInfo.getTimestamp(), 0 /* level */, 0 /* count */));
        if (!addUnigramEntry(wordCodePoints, &unigramProperty)) {
        AKLOGE("Cannot update unigarm entry in updateEntriesForWordWithNgramContext().");
            AKLOGE("Cannot add unigarm entry in updateEntriesForWordWithNgramContext().");
            return false;
        }
    const int probabilityForNgram = ngramContext->isNthPrevWordBeginningOfSentence(1 /* n */)
            ? NOT_A_PROBABILITY : probability;
    const NgramProperty ngramProperty(wordCodePoints.toVector(), probabilityForNgram,
            historicalInfo);
    for (size_t i = 1; i <= ngramContext->getPrevWordCount(); ++i) {
        const NgramContext trimmedNgramContext(ngramContext->getTrimmedNgramContext(i));
        if (!addNgramEntry(&trimmedNgramContext, &ngramProperty)) {
            AKLOGE("Cannot update ngram entry in updateEntriesForWordWithNgramContext().");
        wordId = getWordId(wordCodePoints, false /* tryLowerCaseSearch */);
    }

    WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> prevWordIdArray;
    const WordIdArrayView prevWordIds = ngramContext->getPrevWordIds(this, &prevWordIdArray,
            false /* tryLowerCaseSearch */);
    if (prevWordIds.firstOrDefault(NOT_A_WORD_ID) == NOT_A_WORD_ID
            && ngramContext->isNthPrevWordBeginningOfSentence(1 /* n */)) {
        const UnigramProperty beginningOfSentenceUnigramProperty(
                true /* representsBeginningOfSentence */,
                true /* isNotAWord */, false /* isBlacklisted */, NOT_A_PROBABILITY,
                HistoricalInfo(historicalInfo.getTimestamp(), 0 /* level */, 0 /* count */));
        if (!addUnigramEntry(ngramContext->getNthPrevWordCodePoints(1 /* n */),
                &beginningOfSentenceUnigramProperty)) {
            AKLOGE("Cannot add BoS entry in updateEntriesForWordWithNgramContext().");
            return false;
        }
        // Refresh word ids.
        ngramContext->getPrevWordIds(this, &prevWordIdArray, false /* tryLowerCaseSearch */);
    }
    int addedNewNgramEntryCount = 0;
    if (!mBuffers->getMutableLanguageModelDictContent()->updateAllEntriesOnInputWord(prevWordIds,
            wordId, updateAsAValidWord, historicalInfo, mHeaderPolicy, &addedNewNgramEntryCount)) {
        return false;
    }
    mBigramCount += addedNewNgramEntryCount;
    return true;
}

Loading