Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 47ae7368, authored by Keisuke Kuroyanagi and committed by Android (Google) Code Review
Browse files

Merge "Add bigrams to language model content."

parents 0ba23b3a 9a23f0fb
Loading
Loading
Loading
Loading
+2 −2
Original line number Diff line number Diff line
@@ -234,8 +234,8 @@ bool Ver4PatriciaTrieNodeWriter::writeNewTerminalPtNodeAndAdvancePosition(
bool Ver4PatriciaTrieNodeWriter::addNgramEntry(const WordIdArrayView prevWordIds, const int wordId,
        const BigramProperty *const bigramProperty, bool *const outAddedNewEntry) {
    if (!mBigramPolicy->addNewEntry(prevWordIds[0], wordId, bigramProperty, outAddedNewEntry)) {
        AKLOGE("Cannot add new bigram entry. terminalId: %d, targetTerminalId: %d",
                sourcePtNodeParams->getTerminalId(), targetPtNodeParam->getTerminalId());
        AKLOGE("Cannot add new bigram entry. prevWordId: %d, wordId: %d",
                prevWordIds[0], wordId);
        return false;
    }
    const int ptNodePos =
+14 −1
Original line number Diff line number Diff line
@@ -46,7 +46,7 @@ ProbabilityEntry LanguageModelDictContent::getNgramProbabilityEntry(

bool LanguageModelDictContent::setNgramProbabilityEntry(const WordIdArrayView prevWordIds,
        const int terminalId, const ProbabilityEntry *const probabilityEntry) {
    const int bitmapEntryIndex = getBitmapEntryIndex(prevWordIds);
    const int bitmapEntryIndex = createAndGetBitmapEntryIndex(prevWordIds);
    if (bitmapEntryIndex == TrieMap::INVALID_INDEX) {
        return false;
    }
@@ -80,6 +80,19 @@ bool LanguageModelDictContent::runGCInner(
    return true;
}

int LanguageModelDictContent::createAndGetBitmapEntryIndex(const WordIdArrayView prevWordIds) {
    if (prevWordIds.empty()) {
        return mTrieMap.getRootBitmapEntryIndex();
    }
    const int lastBitmapEntryIndex =
            getBitmapEntryIndex(prevWordIds.limit(prevWordIds.size() - 1));
    if (lastBitmapEntryIndex == TrieMap::INVALID_INDEX) {
        return TrieMap::INVALID_INDEX;
    }
    return mTrieMap.getNextLevelBitmapEntryIndex(prevWordIds[prevWordIds.size() - 1],
            lastBitmapEntryIndex);
}

int LanguageModelDictContent::getBitmapEntryIndex(const WordIdArrayView prevWordIds) const {
    int bitmapEntryIndex = mTrieMap.getRootBitmapEntryIndex();
    for (const int wordId : prevWordIds) {
+1 −1
Original line number Diff line number Diff line
@@ -76,7 +76,7 @@ class LanguageModelDictContent {
    bool runGCInner(const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap,
            const TrieMap::TrieMapRange trieMapRange, const int nextLevelBitmapEntryIndex,
            int *const outNgramCount);

    int createAndGetBitmapEntryIndex(const WordIdArrayView prevWordIds);
    int getBitmapEntryIndex(const WordIdArrayView prevWordIds) const;
};
} // namespace latinime
+22 −2
Original line number Diff line number Diff line
@@ -21,6 +21,8 @@
#include <cstdint>

#include "defines.h"
#include "suggest/core/dictionary/property/bigram_property.h"
#include "suggest/core/dictionary/property/unigram_property.h"
#include "suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h"
#include "suggest/policyimpl/dictionary/utils/historical_info.h"

@@ -45,6 +47,20 @@ class ProbabilityEntry {
            const HistoricalInfo *const historicalInfo)
            : mFlags(flags), mProbability(probability), mHistoricalInfo(*historicalInfo) {}

    // Create from unigram property.
    // Copies the probability and the historical info fields (timestamp, level, count) out of
    // unigramProperty. Flags are not derived from the property yet and are initialized to 0.
    // unigramProperty is dereferenced unconditionally, so callers must pass a non-null pointer.
    // TODO: Set flags.
    ProbabilityEntry(const UnigramProperty *const unigramProperty)
            : mFlags(0), mProbability(unigramProperty->getProbability()),
              mHistoricalInfo(unigramProperty->getTimestamp(), unigramProperty->getLevel(),
                      unigramProperty->getCount()) {}

    // Create from bigram property.
    // Copies the probability and the historical info fields (timestamp, level, count) out of
    // bigramProperty. Flags are not derived from the property yet and are initialized to 0.
    // bigramProperty is dereferenced unconditionally, so callers must pass a non-null pointer.
    // TODO: Set flags.
    ProbabilityEntry(const BigramProperty *const bigramProperty)
            : mFlags(0), mProbability(bigramProperty->getProbability()),
              mHistoricalInfo(bigramProperty->getTimestamp(), bigramProperty->getLevel(),
                      bigramProperty->getCount()) {}

    // Builds a new entry that keeps this entry's flags and historical info but carries the
    // given probability. This entry itself is left untouched.
    const ProbabilityEntry createEntryWithUpdatedProbability(const int probability) const {
        const ProbabilityEntry updatedEntry(mFlags, probability, &mHistoricalInfo);
        return updatedEntry;
    }
@@ -54,6 +70,10 @@ class ProbabilityEntry {
        return ProbabilityEntry(mFlags, mProbability, historicalInfo);
    }

    // An entry is valid when it holds a probability other than NOT_A_PROBABILITY or, failing
    // that, when it has historical info.
    bool isValid() const {
        if (mProbability != NOT_A_PROBABILITY) {
            return true;
        }
        return hasHistoricalInfo();
    }

    // Reports whether the historical info stored in this entry is valid.
    bool hasHistoricalInfo() const {
        const bool historicalInfoIsValid = mHistoricalInfo.isValid();
        return historicalInfoIsValid;
    }
@@ -89,7 +109,7 @@ class ProbabilityEntry {
    static ProbabilityEntry decode(const uint64_t encodedEntry, const bool hasHistoricalInfo) {
        if (hasHistoricalInfo) {
            const int flags = readFromEncodedEntry(encodedEntry,
                    Ver4DictConstants::FLAGS_IN_PROBABILITY_FILE_SIZE,
                    Ver4DictConstants::FLAGS_IN_LANGUAGE_MODEL_SIZE,
                    Ver4DictConstants::TIME_STAMP_FIELD_SIZE
                            + Ver4DictConstants::WORD_LEVEL_FIELD_SIZE
                            + Ver4DictConstants::WORD_COUNT_FIELD_SIZE);
@@ -106,7 +126,7 @@ class ProbabilityEntry {
            return ProbabilityEntry(flags, NOT_A_PROBABILITY, &historicalInfo);
        } else {
            const int flags = readFromEncodedEntry(encodedEntry,
                    Ver4DictConstants::FLAGS_IN_PROBABILITY_FILE_SIZE,
                    Ver4DictConstants::FLAGS_IN_LANGUAGE_MODEL_SIZE,
                    Ver4DictConstants::PROBABILITY_SIZE);
            const int probability = readFromEncodedEntry(encodedEntry,
                    Ver4DictConstants::PROBABILITY_SIZE, 0 /* pos */);
+1 −1
Original line number Diff line number Diff line
@@ -46,7 +46,7 @@ const int Ver4DictConstants::SHORTCUT_BUFFERS_INDEX =

const int Ver4DictConstants::NOT_A_TERMINAL_ID = -1;
const int Ver4DictConstants::PROBABILITY_SIZE = 1;
const int Ver4DictConstants::FLAGS_IN_PROBABILITY_FILE_SIZE = 1;
const int Ver4DictConstants::FLAGS_IN_LANGUAGE_MODEL_SIZE = 1;
const int Ver4DictConstants::TERMINAL_ADDRESS_TABLE_ADDRESS_SIZE = 3;
const int Ver4DictConstants::NOT_A_TERMINAL_ADDRESS = 0;
const int Ver4DictConstants::TERMINAL_ID_FIELD_SIZE = 4;
Loading