Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 9a23f0fb authored by Keisuke Kuroyanagi's avatar Keisuke Kuroyanagi
Browse files

Add bigrams to language model content.

Bug: 14425059

Change-Id: Id81e3775ea0104750a23e3dca62c00681ed8dc2e
parent 0807c897
Loading
Loading
Loading
Loading
+2 −2
Original line number Diff line number Diff line
@@ -234,8 +234,8 @@ bool Ver4PatriciaTrieNodeWriter::writeNewTerminalPtNodeAndAdvancePosition(
bool Ver4PatriciaTrieNodeWriter::addNgramEntry(const WordIdArrayView prevWordIds, const int wordId,
        const BigramProperty *const bigramProperty, bool *const outAddedNewEntry) {
    if (!mBigramPolicy->addNewEntry(prevWordIds[0], wordId, bigramProperty, outAddedNewEntry)) {
        AKLOGE("Cannot add new bigram entry. terminalId: %d, targetTerminalId: %d",
                sourcePtNodeParams->getTerminalId(), targetPtNodeParam->getTerminalId());
        AKLOGE("Cannot add new bigram entry. prevWordId: %d, wordId: %d",
                prevWordIds[0], wordId);
        return false;
    }
    const int ptNodePos =
+14 −1
Original line number Diff line number Diff line
@@ -46,7 +46,7 @@ ProbabilityEntry LanguageModelDictContent::getNgramProbabilityEntry(

bool LanguageModelDictContent::setNgramProbabilityEntry(const WordIdArrayView prevWordIds,
        const int terminalId, const ProbabilityEntry *const probabilityEntry) {
    const int bitmapEntryIndex = getBitmapEntryIndex(prevWordIds);
    const int bitmapEntryIndex = createAndGetBitmapEntryIndex(prevWordIds);
    if (bitmapEntryIndex == TrieMap::INVALID_INDEX) {
        return false;
    }
@@ -80,6 +80,19 @@ bool LanguageModelDictContent::runGCInner(
    return true;
}

int LanguageModelDictContent::createAndGetBitmapEntryIndex(const WordIdArrayView prevWordIds) {
    if (prevWordIds.empty()) {
        return mTrieMap.getRootBitmapEntryIndex();
    }
    const int lastBitmapEntryIndex =
            getBitmapEntryIndex(prevWordIds.limit(prevWordIds.size() - 1));
    if (lastBitmapEntryIndex == TrieMap::INVALID_INDEX) {
        return TrieMap::INVALID_INDEX;
    }
    return mTrieMap.getNextLevelBitmapEntryIndex(prevWordIds[prevWordIds.size() - 1],
            lastBitmapEntryIndex);
}

int LanguageModelDictContent::getBitmapEntryIndex(const WordIdArrayView prevWordIds) const {
    int bitmapEntryIndex = mTrieMap.getRootBitmapEntryIndex();
    for (const int wordId : prevWordIds) {
+1 −1
Original line number Diff line number Diff line
@@ -76,7 +76,7 @@ class LanguageModelDictContent {
    bool runGCInner(const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap,
            const TrieMap::TrieMapRange trieMapRange, const int nextLevelBitmapEntryIndex,
            int *const outNgramCount);

    int createAndGetBitmapEntryIndex(const WordIdArrayView prevWordIds);
    int getBitmapEntryIndex(const WordIdArrayView prevWordIds) const;
};
} // namespace latinime
+22 −2
Original line number Diff line number Diff line
@@ -21,6 +21,8 @@
#include <cstdint>

#include "defines.h"
#include "suggest/core/dictionary/property/bigram_property.h"
#include "suggest/core/dictionary/property/unigram_property.h"
#include "suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h"
#include "suggest/policyimpl/dictionary/utils/historical_info.h"

@@ -45,6 +47,20 @@ class ProbabilityEntry {
            const HistoricalInfo *const historicalInfo)
            : mFlags(flags), mProbability(probability), mHistoricalInfo(*historicalInfo) {}

    // Create from unigram property.
    // TODO: Set flags.
    ProbabilityEntry(const UnigramProperty *const unigramProperty)
            : mFlags(0), mProbability(unigramProperty->getProbability()),
              mHistoricalInfo(unigramProperty->getTimestamp(), unigramProperty->getLevel(),
                      unigramProperty->getCount()) {}

    // Create from bigram property.
    // TODO: Set flags.
    ProbabilityEntry(const BigramProperty *const bigramProperty)
            : mFlags(0), mProbability(bigramProperty->getProbability()),
              mHistoricalInfo(bigramProperty->getTimestamp(), bigramProperty->getLevel(),
                      bigramProperty->getCount()) {}

    const ProbabilityEntry createEntryWithUpdatedProbability(const int probability) const {
        return ProbabilityEntry(mFlags, probability, &mHistoricalInfo);
    }
@@ -54,6 +70,10 @@ class ProbabilityEntry {
        return ProbabilityEntry(mFlags, mProbability, historicalInfo);
    }

    bool isValid() const {
        return (mProbability != NOT_A_PROBABILITY) || hasHistoricalInfo();
    }

    bool hasHistoricalInfo() const {
        return mHistoricalInfo.isValid();
    }
@@ -89,7 +109,7 @@ class ProbabilityEntry {
    static ProbabilityEntry decode(const uint64_t encodedEntry, const bool hasHistoricalInfo) {
        if (hasHistoricalInfo) {
            const int flags = readFromEncodedEntry(encodedEntry,
                    Ver4DictConstants::FLAGS_IN_PROBABILITY_FILE_SIZE,
                    Ver4DictConstants::FLAGS_IN_LANGUAGE_MODEL_SIZE,
                    Ver4DictConstants::TIME_STAMP_FIELD_SIZE
                            + Ver4DictConstants::WORD_LEVEL_FIELD_SIZE
                            + Ver4DictConstants::WORD_COUNT_FIELD_SIZE);
@@ -106,7 +126,7 @@ class ProbabilityEntry {
            return ProbabilityEntry(flags, NOT_A_PROBABILITY, &historicalInfo);
        } else {
            const int flags = readFromEncodedEntry(encodedEntry,
                    Ver4DictConstants::FLAGS_IN_PROBABILITY_FILE_SIZE,
                    Ver4DictConstants::FLAGS_IN_LANGUAGE_MODEL_SIZE,
                    Ver4DictConstants::PROBABILITY_SIZE);
            const int probability = readFromEncodedEntry(encodedEntry,
                    Ver4DictConstants::PROBABILITY_SIZE, 0 /* pos */);
+1 −1
Original line number Diff line number Diff line
@@ -46,7 +46,7 @@ const int Ver4DictConstants::SHORTCUT_BUFFERS_INDEX =

const int Ver4DictConstants::NOT_A_TERMINAL_ID = -1;
const int Ver4DictConstants::PROBABILITY_SIZE = 1;
const int Ver4DictConstants::FLAGS_IN_PROBABILITY_FILE_SIZE = 1;
const int Ver4DictConstants::FLAGS_IN_LANGUAGE_MODEL_SIZE = 1;
const int Ver4DictConstants::TERMINAL_ADDRESS_TABLE_ADDRESS_SIZE = 3;
const int Ver4DictConstants::NOT_A_TERMINAL_ADDRESS = 0;
const int Ver4DictConstants::TERMINAL_ID_FIELD_SIZE = 4;
Loading