Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 9d7e8c71 authored by Keisuke Kuroyanagi's avatar Keisuke Kuroyanagi
Browse files

Support unigram historical information migration.

Bug: 13406708
Change-Id: Ibed15b3bc5d5ae68faefa379028dbe10d32b0c0f
parent 6b74f516
Loading
Loading
Loading
Loading
+4 −2
Original line number Diff line number Diff line
@@ -335,8 +335,9 @@ static void latinime_BinaryDictionary_addUnigramWord(JNIEnv *env, jclass clazz,
    if (!shortcutTargetCodePoints.empty()) {
        shortcuts.emplace_back(&shortcutTargetCodePoints, shortcutProbability);
    }
    // Use 1 for count to indicate the word has inputed.
    const UnigramProperty unigramProperty(isNotAWord, isBlacklisted,
            probability, timestamp, 0 /* level */, 0 /* count */, &shortcuts);
            probability, timestamp, 0 /* level */, 1 /* count */, &shortcuts);
    dictionary->addUnigramWord(codePoints, codePointCount, &unigramProperty);
}

@@ -436,8 +437,9 @@ static int latinime_BinaryDictionary_addMultipleDictionaryEntries(JNIEnv *env, j
                    env->GetIntField(languageModelParam, shortcutProbabilityFieldId);
            shortcuts.emplace_back(&shortcutTargetCodePoints, shortcutProbability);
        }
        // Use 1 for count to indicate the word has inputed.
        const UnigramProperty unigramProperty(isNotAWord, isBlacklisted,
                unigramProbability, timestamp, 0 /* level */, 0 /* count */, &shortcuts);
                unigramProbability, timestamp, 0 /* level */, 1 /* count */, &shortcuts);
        dictionary->addUnigramWord(word1CodePoints, word1Length, &unigramProperty);
        if (word0) {
            jint bigramProbability = env->GetIntField(languageModelParam, bigramProbabilityFieldId);
+4 −2
Original line number Diff line number Diff line
@@ -257,10 +257,12 @@ const BigramEntry Ver4BigramListPolicy::createUpdatedBigramEntryFrom(
        const int timestamp) const {
    // TODO: Consolidate historical info and probability.
    if (mHeaderPolicy->hasHistoricalInfoOfWords()) {
        // Use 1 for count to indicate the bigram has inputed.
        const HistoricalInfo historicalInfoForUpdate(timestamp, 0 /* level */, 1 /* count */);
        const HistoricalInfo updatedHistoricalInfo =
                ForgettingCurveUtils::createUpdatedHistoricalInfo(
                        originalBigramEntry->getHistoricalInfo(), newProbability, timestamp,
                        mHeaderPolicy);
                        originalBigramEntry->getHistoricalInfo(), newProbability,
                        &historicalInfoForUpdate, mHeaderPolicy);
        return originalBigramEntry->updateHistoricalInfoAndGetEntry(&updatedHistoricalInfo);
    } else {
        return originalBigramEntry->updateProbabilityAndGetEntry(newProbability);
+3 −2
Original line number Diff line number Diff line
@@ -387,11 +387,12 @@ const ProbabilityEntry Ver4PatriciaTrieNodeWriter::createUpdatedEntryFrom(
        const UnigramProperty *const unigramProperty) const {
    // TODO: Consolidate historical info and probability.
    if (mHeaderPolicy->hasHistoricalInfoOfWords()) {
        const HistoricalInfo historicalInfoForUpdate(unigramProperty->getTimestamp(),
                unigramProperty->getLevel(), unigramProperty->getCount());
        const HistoricalInfo updatedHistoricalInfo =
                ForgettingCurveUtils::createUpdatedHistoricalInfo(
                        originalProbabilityEntry->getHistoricalInfo(),
                        unigramProperty->getProbability(), unigramProperty->getTimestamp(),
                        mHeaderPolicy);
                        unigramProperty->getProbability(), &historicalInfoForUpdate, mHeaderPolicy);
        return originalProbabilityEntry->createEntryWithUpdatedHistoricalInfo(
                &updatedHistoricalInfo);
    } else {
+10 −2
Original line number Diff line number Diff line
@@ -425,6 +425,9 @@ const WordProperty Ver4PatriciaTriePolicy::getWordProperty(const int *const code
}

int Ver4PatriciaTriePolicy::getNextWordAndNextToken(const int token, int *const outCodePoints) {
    // TODO: Return code point count like other methods.
    // Null termination.
    outCodePoints[0] = 0;
    if (token == 0) {
        mTerminalPtNodePositionsForIteratingWords.clear();
        DynamicPtReadingHelper::TraversePolicyToGetAllTerminalPtNodePositions traversePolicy(
@@ -441,8 +444,13 @@ int Ver4PatriciaTriePolicy::getNextWordAndNextToken(const int token, int *const
    }
    const int terminalPtNodePos = mTerminalPtNodePositionsForIteratingWords[token];
    int unigramProbability = NOT_A_PROBABILITY;
    getCodePointsAndProbabilityAndReturnCodePointCount(terminalPtNodePos, MAX_WORD_LENGTH,
            outCodePoints, &unigramProbability);
    const int codePointCount = getCodePointsAndProbabilityAndReturnCodePointCount(
            terminalPtNodePos, MAX_WORD_LENGTH, outCodePoints, &unigramProbability);
    if (codePointCount < MAX_WORD_LENGTH) {
        // Null termination. outCodePoints have to be null terminated or contain MAX_WORD_LENGTH
        // code points.
        outCodePoints[codePointCount] = 0;
    }
    const int nextToken = token + 1;
    if (nextToken >= terminalPtNodePositionsVectorSize) {
        // All words have been iterated.
+38 −12
Original line number Diff line number Diff line
@@ -30,7 +30,7 @@ const int ForgettingCurveUtils::MULTIPLIER_TWO_IN_PROBABILITY_SCALE = 8;
const int ForgettingCurveUtils::DECAY_INTERVAL_SECONDS = 2 * 60 * 60;

const int ForgettingCurveUtils::MAX_LEVEL = 3;
const int ForgettingCurveUtils::MIN_VALID_LEVEL = 1;
const int ForgettingCurveUtils::MIN_VISIBLE_LEVEL = 1;
const int ForgettingCurveUtils::MAX_ELAPSED_TIME_STEP_COUNT = 15;
const int ForgettingCurveUtils::DISCARD_LEVEL_ZERO_ENTRY_TIME_STEP_COUNT_THRESHOLD = 14;

@@ -41,25 +41,34 @@ const ForgettingCurveUtils::ProbabilityTable ForgettingCurveUtils::sProbabilityT

// TODO: Revise the logic to decide the initial probability depending on the given probability.
/* static */ const HistoricalInfo ForgettingCurveUtils::createUpdatedHistoricalInfo(
        const HistoricalInfo *const originalHistoricalInfo,
        const int newProbability, const int timestamp, const HeaderPolicy *const headerPolicy) {
        const HistoricalInfo *const originalHistoricalInfo, const int newProbability,
        const HistoricalInfo *const newHistoricalInfo, const HeaderPolicy *const headerPolicy) {
    const int timestamp = newHistoricalInfo->getTimeStamp();
    if (newProbability != NOT_A_PROBABILITY && originalHistoricalInfo->getLevel() == 0) {
        return HistoricalInfo(timestamp, MIN_VALID_LEVEL /* level */, 0 /* count */);
    } else if (!originalHistoricalInfo->isValid()) {
        // Add entry as a valid word.
        const int level = clampToVisibleEntryLevelRange(newHistoricalInfo->getLevel());
        const int count = clampToValidCountRange(newHistoricalInfo->getCount(), headerPolicy);
        return HistoricalInfo(timestamp, level, count);
    } else if (!originalHistoricalInfo->isValid()
            || originalHistoricalInfo->getLevel() < newHistoricalInfo->getLevel()
            || (originalHistoricalInfo->getLevel() == newHistoricalInfo->getLevel()
                    && originalHistoricalInfo->getCount() < newHistoricalInfo->getCount())) {
        // Initial information.
        return HistoricalInfo(timestamp, 0 /* level */, 1 /* count */);
        const int level = clampToValidLevelRange(newHistoricalInfo->getLevel());
        const int count = clampToValidCountRange(newHistoricalInfo->getCount(), headerPolicy);
        return HistoricalInfo(timestamp, level, count);
    } else {
        const int updatedCount = originalHistoricalInfo->getCount() + 1;
        if (updatedCount >= headerPolicy->getForgettingCurveOccurrencesToLevelUp()) {
            // The count exceeds the max value the level can be incremented.
            if (originalHistoricalInfo->getLevel() >= MAX_LEVEL) {
                // The level is already max.
                return HistoricalInfo(timestamp, originalHistoricalInfo->getLevel(),
                        originalHistoricalInfo->getCount());
                return HistoricalInfo(timestamp,
                        originalHistoricalInfo->getLevel(), originalHistoricalInfo->getCount());
            } else {
                // Level up.
                return HistoricalInfo(timestamp, originalHistoricalInfo->getLevel() + 1,
                        0 /* count */);
                return HistoricalInfo(timestamp,
                        originalHistoricalInfo->getLevel() + 1, 0 /* count */);
            }
        } else {
            return HistoricalInfo(timestamp, originalHistoricalInfo->getLevel(), updatedCount);
@@ -73,8 +82,8 @@ const ForgettingCurveUtils::ProbabilityTable ForgettingCurveUtils::sProbabilityT
            headerPolicy->getForgettingCurveDurationToLevelDown());
    return sProbabilityTable.getProbability(
            headerPolicy->getForgettingCurveProbabilityValuesTableId(),
            std::min(std::max(historicalInfo->getLevel(), 0), MAX_LEVEL),
            std::min(std::max(elapsedTimeStepCount, 0), MAX_ELAPSED_TIME_STEP_COUNT));
            clampToValidLevelRange(historicalInfo->getLevel()),
            clampToValidTimeStepCountRange(elapsedTimeStepCount));
}

/* static */ int ForgettingCurveUtils::getProbability(const int unigramProbability,
@@ -155,6 +164,23 @@ const ForgettingCurveUtils::ProbabilityTable ForgettingCurveUtils::sProbabilityT
    return elapsedTimeInSeconds / timeStepDurationInSeconds;
}

/* static */ int ForgettingCurveUtils::clampToVisibleEntryLevelRange(const int level) {
    return std::min(std::max(level, MIN_VISIBLE_LEVEL), MAX_LEVEL);
}

/* static */ int ForgettingCurveUtils::clampToValidCountRange(const int count,
        const HeaderPolicy *const headerPolicy) {
    return std::min(std::max(count, 0), headerPolicy->getForgettingCurveOccurrencesToLevelUp() - 1);
}

/* static */ int ForgettingCurveUtils::clampToValidLevelRange(const int level) {
    return std::min(std::max(level, 0), MAX_LEVEL);
}

/* static */ int ForgettingCurveUtils::clampToValidTimeStepCountRange(const int timeStepCount) {
    return std::min(std::max(timeStepCount, 0), MAX_ELAPSED_TIME_STEP_COUNT);
}

const int ForgettingCurveUtils::ProbabilityTable::PROBABILITY_TABLE_COUNT = 4;
const int ForgettingCurveUtils::ProbabilityTable::WEAK_PROBABILITY_TABLE_ID = 0;
const int ForgettingCurveUtils::ProbabilityTable::MODEST_PROBABILITY_TABLE_ID = 1;
Loading