Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 7bdf008d authored by Keisuke Kuroyanagi's avatar Keisuke Kuroyanagi Committed by Android Git Automerger
Browse files

am 948ef10d: Merge "Improve bigram probability computation for decaying dicts."

* commit '948ef10d':
  Improve bigram probability computation for decaying dicts.
parents 462476b0 948ef10d
Loading
Loading
Loading
Loading
+0 −8
Original line number Diff line number Diff line
@@ -35,23 +35,15 @@ const char *const HeaderPolicy::EXTENDED_REGION_SIZE_KEY = "EXTENDED_REGION_SIZE
// count.
const char *const HeaderPolicy::HAS_HISTORICAL_INFO_KEY = "HAS_HISTORICAL_INFO";
const char *const HeaderPolicy::LOCALE_KEY = "locale"; // match Java declaration
// Header attribute keys for the forgetting-curve parameters. The HeaderPolicy constructors look
// these keys up in the attribute map via HeaderReadWriteUtils::readIntAttributeValue().
const char *const HeaderPolicy::FORGETTING_CURVE_OCCURRENCES_TO_LEVEL_UP_KEY =
        "FORGETTING_CURVE_OCCURRENCES_TO_LEVEL_UP";
const char *const HeaderPolicy::FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID_KEY =
        "FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID";
const char *const HeaderPolicy::FORGETTING_CURVE_DURATION_TO_LEVEL_DOWN_IN_SECONDS_KEY =
        "FORGETTING_CURVE_DURATION_TO_LEVEL_DOWN_IN_SECONDS";

// Header attribute keys for the maximum unigram / bigram counts.
const char *const HeaderPolicy::MAX_UNIGRAM_COUNT_KEY = "MAX_UNIGRAM_COUNT";
const char *const HeaderPolicy::MAX_BIGRAM_COUNT_KEY = "MAX_BIGRAM_COUNT";

const int HeaderPolicy::DEFAULT_MULTIPLE_WORDS_DEMOTION_RATE = 100;
const float HeaderPolicy::MULTIPLE_WORD_COST_MULTIPLIER_SCALE = 100.0f;
// Values handed to readIntAttributeValue() together with the forgetting-curve keys above.
const int HeaderPolicy::DEFAULT_FORGETTING_CURVE_OCCURRENCES_TO_LEVEL_UP = 2;
const int HeaderPolicy::DEFAULT_FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID = 3;
// 30 days, expressed in seconds.
const int HeaderPolicy::DEFAULT_FORGETTING_CURVE_DURATION_TO_LEVEL_DOWN_IN_SECONDS =
        30 * 24 * 60 * 60;

// Values handed to readIntAttributeValue() together with the max unigram / bigram count keys.
const int HeaderPolicy::DEFAULT_MAX_UNIGRAM_COUNT = 10000;
const int HeaderPolicy::DEFAULT_MAX_BIGRAM_COUNT = 10000;
+1 −30
Original line number Diff line number Diff line
@@ -53,15 +53,9 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy {
                      EXTENDED_REGION_SIZE_KEY, 0 /* defaultValue */)),
              mHasHistoricalInfoOfWords(HeaderReadWriteUtils::readBoolAttributeValue(
                      &mAttributeMap, HAS_HISTORICAL_INFO_KEY, false /* defaultValue */)),
              mForgettingCurveOccurrencesToLevelUp(HeaderReadWriteUtils::readIntAttributeValue(
                      &mAttributeMap, FORGETTING_CURVE_OCCURRENCES_TO_LEVEL_UP_KEY,
                      DEFAULT_FORGETTING_CURVE_OCCURRENCES_TO_LEVEL_UP)),
              mForgettingCurveProbabilityValuesTableId(HeaderReadWriteUtils::readIntAttributeValue(
                      &mAttributeMap, FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID_KEY,
                      DEFAULT_FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID)),
              mForgettingCurveDurationToLevelDown(HeaderReadWriteUtils::readIntAttributeValue(
                      &mAttributeMap, FORGETTING_CURVE_DURATION_TO_LEVEL_DOWN_IN_SECONDS_KEY,
                      DEFAULT_FORGETTING_CURVE_DURATION_TO_LEVEL_DOWN_IN_SECONDS)),
              mMaxUnigramCount(HeaderReadWriteUtils::readIntAttributeValue(
                      &mAttributeMap, MAX_UNIGRAM_COUNT_KEY, DEFAULT_MAX_UNIGRAM_COUNT)),
              mMaxBigramCount(HeaderReadWriteUtils::readIntAttributeValue(
@@ -86,15 +80,9 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy {
              mUnigramCount(0), mBigramCount(0), mExtendedRegionSize(0),
              mHasHistoricalInfoOfWords(HeaderReadWriteUtils::readBoolAttributeValue(
                      &mAttributeMap, HAS_HISTORICAL_INFO_KEY, false /* defaultValue */)),
              mForgettingCurveOccurrencesToLevelUp(HeaderReadWriteUtils::readIntAttributeValue(
                      &mAttributeMap, FORGETTING_CURVE_OCCURRENCES_TO_LEVEL_UP_KEY,
                      DEFAULT_FORGETTING_CURVE_OCCURRENCES_TO_LEVEL_UP)),
              mForgettingCurveProbabilityValuesTableId(HeaderReadWriteUtils::readIntAttributeValue(
                      &mAttributeMap, FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID_KEY,
                      DEFAULT_FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID)),
              mForgettingCurveDurationToLevelDown(HeaderReadWriteUtils::readIntAttributeValue(
                      &mAttributeMap, FORGETTING_CURVE_DURATION_TO_LEVEL_DOWN_IN_SECONDS_KEY,
                      DEFAULT_FORGETTING_CURVE_DURATION_TO_LEVEL_DOWN_IN_SECONDS)),
              mMaxUnigramCount(HeaderReadWriteUtils::readIntAttributeValue(
                      &mAttributeMap, MAX_UNIGRAM_COUNT_KEY, DEFAULT_MAX_UNIGRAM_COUNT)),
              mMaxBigramCount(HeaderReadWriteUtils::readIntAttributeValue(
@@ -113,12 +101,8 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy {
              mUnigramCount(headerPolicy->mUnigramCount), mBigramCount(headerPolicy->mBigramCount),
              mExtendedRegionSize(headerPolicy->mExtendedRegionSize),
              mHasHistoricalInfoOfWords(headerPolicy->mHasHistoricalInfoOfWords),
              mForgettingCurveOccurrencesToLevelUp(
                      headerPolicy->mForgettingCurveOccurrencesToLevelUp),
              mForgettingCurveProbabilityValuesTableId(
                      headerPolicy->mForgettingCurveProbabilityValuesTableId),
              mForgettingCurveDurationToLevelDown(
                      headerPolicy->mForgettingCurveDurationToLevelDown),
              mMaxUnigramCount(headerPolicy->mMaxUnigramCount),
              mMaxBigramCount(headerPolicy->mMaxBigramCount),
              mCodePointTable(headerPolicy->mCodePointTable) {}
@@ -130,8 +114,7 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy {
              mRequiresGermanUmlautProcessing(false), mIsDecayingDict(false),
              mDate(0), mLastDecayedTime(0), mUnigramCount(0), mBigramCount(0),
              mExtendedRegionSize(0), mHasHistoricalInfoOfWords(false),
              mForgettingCurveOccurrencesToLevelUp(0), mForgettingCurveProbabilityValuesTableId(0),
              mForgettingCurveDurationToLevelDown(0), mMaxUnigramCount(0), mMaxBigramCount(0),
              mForgettingCurveProbabilityValuesTableId(0), mMaxUnigramCount(0), mMaxBigramCount(0),
              mCodePointTable(nullptr) {}

    ~HeaderPolicy() {}
@@ -217,18 +200,10 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy {
        return &mAttributeMap;
    }

    // Returns mForgettingCurveOccurrencesToLevelUp, which the attribute-map based constructors
    // read from FORGETTING_CURVE_OCCURRENCES_TO_LEVEL_UP_KEY.
    AK_FORCE_INLINE int getForgettingCurveOccurrencesToLevelUp() const {
        return mForgettingCurveOccurrencesToLevelUp;
    }

    // Returns mForgettingCurveProbabilityValuesTableId, which the attribute-map based
    // constructors read from FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID_KEY.
    AK_FORCE_INLINE int getForgettingCurveProbabilityValuesTableId() const {
        return mForgettingCurveProbabilityValuesTableId;
    }

    // Returns mForgettingCurveDurationToLevelDown, which the attribute-map based constructors
    // read from FORGETTING_CURVE_DURATION_TO_LEVEL_DOWN_IN_SECONDS_KEY.
    AK_FORCE_INLINE int getForgettingCurveDurationToLevelDown() const {
        return mForgettingCurveDurationToLevelDown;
    }

    // Returns mMaxUnigramCount, which the attribute-map based constructors read from
    // MAX_UNIGRAM_COUNT_KEY.
    AK_FORCE_INLINE int getMaxUnigramCount() const {
        return mMaxUnigramCount;
    }
@@ -280,9 +255,7 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy {
    static const char *const MAX_BIGRAM_COUNT_KEY;
    static const int DEFAULT_MULTIPLE_WORDS_DEMOTION_RATE;
    static const float MULTIPLE_WORD_COST_MULTIPLIER_SCALE;
    static const int DEFAULT_FORGETTING_CURVE_OCCURRENCES_TO_LEVEL_UP;
    static const int DEFAULT_FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID;
    static const int DEFAULT_FORGETTING_CURVE_DURATION_TO_LEVEL_DOWN_IN_SECONDS;
    static const int DEFAULT_MAX_UNIGRAM_COUNT;
    static const int DEFAULT_MAX_BIGRAM_COUNT;

@@ -300,9 +273,7 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy {
    const int mBigramCount;
    const int mExtendedRegionSize;
    const bool mHasHistoricalInfoOfWords;
    const int mForgettingCurveOccurrencesToLevelUp;
    const int mForgettingCurveProbabilityValuesTableId;
    const int mForgettingCurveDurationToLevelDown;
    const int mMaxUnigramCount;
    const int mMaxBigramCount;
    const int *const mCodePointTable;
+53 −27
Original line number Diff line number Diff line
@@ -146,18 +146,15 @@ const WordAttributes Ver4PatriciaTriePolicy::getWordAttributes(const int probabi

// Combines a unigram probability and a bigram probability into a single probability value.
// Yields NOT_A_PROBABILITY when unigramProbability is NOT_A_PROBABILITY, and falls back to a
// backed-off unigram probability when bigramProbability is NOT_A_PROBABILITY.
int Ver4PatriciaTriePolicy::getProbability(const int unigramProbability,
        const int bigramProbability) const {
    if (mHeaderPolicy->isDecayingDict()) {
        // Both probabilities are encoded. Decode them and get probability.
        return ForgettingCurveUtils::getProbability(unigramProbability, bigramProbability);
    } else {
    // In the v4 format, bigramProbability is a conditional probability.
    const int bigramConditionalProbability = bigramProbability;
    if (unigramProbability == NOT_A_PROBABILITY) {
        return NOT_A_PROBABILITY;
        } else if (bigramProbability == NOT_A_PROBABILITY) {
            return ProbabilityUtils::backoff(unigramProbability);
        } else {
            return bigramProbability;
    }
    if (bigramConditionalProbability == NOT_A_PROBABILITY) {
        return ProbabilityUtils::backoff(unigramProbability);
    }
    return bigramConditionalProbability;
}

int Ver4PatriciaTriePolicy::getProbabilityOfWord(const WordIdArrayView prevWordIds,
@@ -170,37 +167,66 @@ int Ver4PatriciaTriePolicy::getProbabilityOfWord(const WordIdArrayView prevWordI
    if (ptNodeParams.isDeleted() || ptNodeParams.isBlacklisted() || ptNodeParams.isNotAWord()) {
        return NOT_A_PROBABILITY;
    }
    if (!prevWordIds.empty()) {
        const int bigramsPosition = getBigramsPositionOfPtNode(
                getTerminalPtNodePosFromWordId(prevWordIds[0]));
    if (prevWordIds.empty()) {
        return getProbability(ptNodeParams.getProbability(), NOT_A_PROBABILITY);
    }
    if (prevWordIds[0] == NOT_A_WORD_ID) {
        return NOT_A_PROBABILITY;
    }
    const PtNodeParams prevWordPtNodeParams =
            mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(prevWordIds[0]);
    if (prevWordPtNodeParams.isDeleted()) {
        return getProbability(ptNodeParams.getProbability(), NOT_A_PROBABILITY);
    }
    const int bigramsPosition = mBuffers->getBigramDictContent()->getBigramListHeadPos(
            prevWordPtNodeParams.getTerminalId());
    BinaryDictionaryBigramsIterator bigramsIt(&mBigramPolicy, bigramsPosition);
    while (bigramsIt.hasNext()) {
        bigramsIt.next();
        if (bigramsIt.getBigramPos() == ptNodePos
                && bigramsIt.getProbability() != NOT_A_PROBABILITY) {
                return getProbability(ptNodeParams.getProbability(), bigramsIt.getProbability());
            const int bigramConditionalProbability = getBigramConditionalProbability(
                    prevWordPtNodeParams.getProbability(), bigramsIt.getProbability());
            return getProbability(ptNodeParams.getProbability(), bigramConditionalProbability);
        }
    }
    return NOT_A_PROBABILITY;
}
    return getProbability(ptNodeParams.getProbability(), NOT_A_PROBABILITY);
}

// Walks the bigram list that belongs to the previous word (prevWordIds[0]) and reports each
// entry to listener->onVisitEntry() together with the word id of the bigram target.
// Returns without visiting anything when prevWordIds is empty, when its first element is
// NOT_A_DICT_POS, or when the PtNode of the previous word is marked as deleted.
void Ver4PatriciaTriePolicy::iterateNgramEntries(const WordIdArrayView prevWordIds,
        NgramListener *const listener) const {
    if (prevWordIds.empty()) {
    if (prevWordIds.firstOrDefault(NOT_A_DICT_POS) == NOT_A_DICT_POS) {
        return;
    }
    const PtNodeParams prevWordPtNodeParams =
            mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(prevWordIds[0]);
    if (prevWordPtNodeParams.isDeleted()) {
        return;
    }
    const int bigramsPosition = getBigramsPositionOfPtNode(
            getTerminalPtNodePosFromWordId(prevWordIds[0]));
    const int bigramsPosition = mBuffers->getBigramDictContent()->getBigramListHeadPos(
            prevWordPtNodeParams.getTerminalId());
    BinaryDictionaryBigramsIterator bigramsIt(&mBigramPolicy, bigramsPosition);
    while (bigramsIt.hasNext()) {
        bigramsIt.next();
        listener->onVisitEntry(bigramsIt.getProbability(),
        // The stored bigram probability is converted with the previous word's unigram
        // probability before it is handed to the listener.
        const int bigramConditionalProbability = getBigramConditionalProbability(
                prevWordPtNodeParams.getProbability(), bigramsIt.getProbability());
        listener->onVisitEntry(bigramConditionalProbability,
                getWordIdFromTerminalPtNodePos(bigramsIt.getBigramPos()));
    }
}

int Ver4PatriciaTriePolicy::getBigramConditionalProbability(const int prevWordUnigramProbability,
        const int bigramProbability) const {
    if (mHeaderPolicy->hasHistoricalInfoOfWords()) {
        // Calculate conditional probability.
        return std::min(MAX_PROBABILITY - prevWordUnigramProbability + bigramProbability,
                MAX_PROBABILITY);
    } else {
        // bigramProbability is a conditional probability.
        return bigramProbability;
    }
}

BinaryDictionaryShortcutIterator Ver4PatriciaTriePolicy::getShortcutIterator(
        const int wordId) const {
    const int shortcutPos = getShortcutPositionOfPtNode(getTerminalPtNodePosFromWordId(wordId));
+2 −0
Original line number Diff line number Diff line
@@ -174,6 +174,8 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
    int getTerminalPtNodePosFromWordId(const int wordId) const;
    const WordAttributes getWordAttributes(const int probability,
            const PtNodeParams &ptNodeParams) const;
    int getBigramConditionalProbability(const int prevWordUnigramProbability,
            const int bigramProbability) const;
};
} // namespace v402
} // namespace backward
+25 −30
Original line number Diff line number Diff line
@@ -29,10 +29,14 @@ namespace latinime {
const int ForgettingCurveUtils::MULTIPLIER_TWO_IN_PROBABILITY_SCALE = 8;
const int ForgettingCurveUtils::DECAY_INTERVAL_SECONDS = 2 * 60 * 60;

// Level / time-step parameters of the forgetting curve:
//  - MAX_LEVEL: the level of a historical info entry is not raised beyond this value.
//  - MIN_VISIBLE_LEVEL: levels below this get NOT_A_PROBABILITY in the probability table.
//  - MAX_ELAPSED_TIME_STEP_COUNT: largest elapsed time step count covered by the table.
//  - DISCARD_LEVEL_ZERO_ENTRY_TIME_STEP_COUNT_THRESHOLD: a level 0 entry is kept only while
//    its elapsed time step count is below this value.
const int ForgettingCurveUtils::MAX_LEVEL = 3;
const int ForgettingCurveUtils::MIN_VISIBLE_LEVEL = 1;
const int ForgettingCurveUtils::MAX_ELAPSED_TIME_STEP_COUNT = 15;
const int ForgettingCurveUtils::DISCARD_LEVEL_ZERO_ENTRY_TIME_STEP_COUNT_THRESHOLD = 14;
const int ForgettingCurveUtils::MAX_LEVEL = 15;
const int ForgettingCurveUtils::MIN_VISIBLE_LEVEL = 2;
const int ForgettingCurveUtils::MAX_ELAPSED_TIME_STEP_COUNT = 31;
const int ForgettingCurveUtils::DISCARD_LEVEL_ZERO_ENTRY_TIME_STEP_COUNT_THRESHOLD = 30;
// Occurrence count at which the level of an entry is raised.
const int ForgettingCurveUtils::OCCURRENCES_TO_RAISE_THE_LEVEL = 1;
// TODO: Evaluate whether this should be 7.5 days.
// 15 days
const int ForgettingCurveUtils::DURATION_TO_LOWER_THE_LEVEL_IN_SECONDS = 15 * 24 * 60 * 60;

const float ForgettingCurveUtils::UNIGRAM_COUNT_HARD_LIMIT_WEIGHT = 1.2;
const float ForgettingCurveUtils::BIGRAM_COUNT_HARD_LIMIT_WEIGHT = 1.2;
@@ -54,19 +58,23 @@ const ForgettingCurveUtils::ProbabilityTable ForgettingCurveUtils::sProbabilityT
            || (originalHistoricalInfo->getLevel() == newHistoricalInfo->getLevel()
                    && originalHistoricalInfo->getCount() < newHistoricalInfo->getCount())) {
        // Initial information.
        int count = newHistoricalInfo->getCount();
        if (count >= OCCURRENCES_TO_RAISE_THE_LEVEL) {
            const int level = clampToValidLevelRange(newHistoricalInfo->getLevel() + 1);
            return HistoricalInfo(timestamp, level, 0 /* count */);
        }
        const int level = clampToValidLevelRange(newHistoricalInfo->getLevel());
        const int count = clampToValidCountRange(newHistoricalInfo->getCount(), headerPolicy);
        return HistoricalInfo(timestamp, level, count);
        return HistoricalInfo(timestamp, level, clampToValidCountRange(count, headerPolicy));
    } else {
        const int updatedCount = originalHistoricalInfo->getCount() + 1;
        if (updatedCount >= headerPolicy->getForgettingCurveOccurrencesToLevelUp()) {
        if (updatedCount >= OCCURRENCES_TO_RAISE_THE_LEVEL) {
            // The count has reached the threshold for raising the level.
            if (originalHistoricalInfo->getLevel() >= MAX_LEVEL) {
                // The level is already max.
                return HistoricalInfo(timestamp,
                        originalHistoricalInfo->getLevel(), originalHistoricalInfo->getCount());
            } else {
                // Level up.
                // Raise the level.
                return HistoricalInfo(timestamp,
                        originalHistoricalInfo->getLevel() + 1, 0 /* count */);
            }
@@ -79,31 +87,18 @@ const ForgettingCurveUtils::ProbabilityTable ForgettingCurveUtils::sProbabilityT
// Decodes historical info into a probability: computes the number of time steps elapsed since
// the entry's timestamp and looks the result up in sProbabilityTable, using the level and the
// time step count after clamping both to their valid ranges.
/* static */ int ForgettingCurveUtils::decodeProbability(
        const HistoricalInfo *const historicalInfo, const HeaderPolicy *const headerPolicy) {
    const int elapsedTimeStepCount = getElapsedTimeStepCount(historicalInfo->getTimestamp(),
            headerPolicy->getForgettingCurveDurationToLevelDown());
            DURATION_TO_LOWER_THE_LEVEL_IN_SECONDS);
    return sProbabilityTable.getProbability(
            headerPolicy->getForgettingCurveProbabilityValuesTableId(),
            clampToValidLevelRange(historicalInfo->getLevel()),
            clampToValidTimeStepCountRange(elapsedTimeStepCount));
}

// Merges a unigram probability and a bigram probability into one probability that never
// exceeds MAX_PROBABILITY. NOT_A_PROBABILITY is returned when there is no unigram probability.
/* static */ int ForgettingCurveUtils::getProbability(const int unigramProbability,
        const int bigramProbability) {
    // Without a unigram probability there is nothing to combine.
    if (unigramProbability == NOT_A_PROBABILITY) {
        return NOT_A_PROBABILITY;
    }
    // No bigram probability: use the backed-off unigram probability.
    if (bigramProbability == NOT_A_PROBABILITY) {
        return std::min(backoff(unigramProbability), MAX_PROBABILITY);
    }
    // TODO: Investigate better way to handle bigram probability.
    const int boostedBigramProbability = bigramProbability + MULTIPLIER_TWO_IN_PROBABILITY_SCALE;
    const int combinedProbability = std::max(unigramProbability, boostedBigramProbability);
    return std::min(combinedProbability, MAX_PROBABILITY);
}

// Returns true when the entry should be kept: either its level is above 0, or its elapsed time
// step count is still below DISCARD_LEVEL_ZERO_ENTRY_TIME_STEP_COUNT_THRESHOLD.
/* static */ bool ForgettingCurveUtils::needsToKeep(const HistoricalInfo *const historicalInfo,
        const HeaderPolicy *const headerPolicy) {
    return historicalInfo->getLevel() > 0
            || getElapsedTimeStepCount(historicalInfo->getTimestamp(),
                    headerPolicy->getForgettingCurveDurationToLevelDown())
                    DURATION_TO_LOWER_THE_LEVEL_IN_SECONDS)
                            < DISCARD_LEVEL_ZERO_ENTRY_TIME_STEP_COUNT_THRESHOLD;
}

@@ -113,14 +108,14 @@ const ForgettingCurveUtils::ProbabilityTable ForgettingCurveUtils::sProbabilityT
    if (originalHistoricalInfo->getTimestamp() == NOT_A_TIMESTAMP) {
        return HistoricalInfo();
    }
    const int durationToLevelDownInSeconds = headerPolicy->getForgettingCurveDurationToLevelDown();
    const int durationToLevelDownInSeconds = DURATION_TO_LOWER_THE_LEVEL_IN_SECONDS;
    const int elapsedTimeStep = getElapsedTimeStepCount(
            originalHistoricalInfo->getTimestamp(), durationToLevelDownInSeconds);
    if (elapsedTimeStep <= MAX_ELAPSED_TIME_STEP_COUNT) {
        // No need to update historical info.
        return *originalHistoricalInfo;
    }
    // Level down.
    // Lower the level.
    const int maxLevelDownAmonut = elapsedTimeStep / (MAX_ELAPSED_TIME_STEP_COUNT + 1);
    const int levelDownAmount = (maxLevelDownAmonut >= originalHistoricalInfo->getLevel()) ?
            originalHistoricalInfo->getLevel() : maxLevelDownAmonut;
@@ -170,7 +165,7 @@ const ForgettingCurveUtils::ProbabilityTable ForgettingCurveUtils::sProbabilityT

// Clamps count into the range [0, T - 1], where T is the number of occurrences needed to raise
// the level; negative counts become 0.
/* static */ int ForgettingCurveUtils::clampToValidCountRange(const int count,
        const HeaderPolicy *const headerPolicy) {
    return std::min(std::max(count, 0), headerPolicy->getForgettingCurveOccurrencesToLevelUp() - 1);
    return std::min(std::max(count, 0), OCCURRENCES_TO_RAISE_THE_LEVEL - 1);
}

/* static */ int ForgettingCurveUtils::clampToValidLevelRange(const int level) {
@@ -187,9 +182,9 @@ const int ForgettingCurveUtils::ProbabilityTable::MODEST_PROBABILITY_TABLE_ID =
const int ForgettingCurveUtils::ProbabilityTable::STRONG_PROBABILITY_TABLE_ID = 2;
const int ForgettingCurveUtils::ProbabilityTable::AGGRESSIVE_PROBABILITY_TABLE_ID = 3;
const int ForgettingCurveUtils::ProbabilityTable::WEAK_MAX_PROBABILITY = 127;
const int ForgettingCurveUtils::ProbabilityTable::MODEST_BASE_PROBABILITY = 32;
const int ForgettingCurveUtils::ProbabilityTable::STRONG_BASE_PROBABILITY = 35;
const int ForgettingCurveUtils::ProbabilityTable::AGGRESSIVE_BASE_PROBABILITY = 40;
const int ForgettingCurveUtils::ProbabilityTable::MODEST_BASE_PROBABILITY = 8;
const int ForgettingCurveUtils::ProbabilityTable::STRONG_BASE_PROBABILITY = 9;
const int ForgettingCurveUtils::ProbabilityTable::AGGRESSIVE_BASE_PROBABILITY = 10;


ForgettingCurveUtils::ProbabilityTable::ProbabilityTable() : mTables() {
@@ -202,7 +197,7 @@ ForgettingCurveUtils::ProbabilityTable::ProbabilityTable() : mTables() {
            const float endProbability = getBaseProbabilityForLevel(tableId, level - 1);
            for (int timeStepCount = 0; timeStepCount <= MAX_ELAPSED_TIME_STEP_COUNT;
                    ++timeStepCount) {
                if (level == 0) {
                if (level < MIN_VISIBLE_LEVEL) {
                    mTables[tableId][level][timeStepCount] = NOT_A_PROBABILITY;
                    continue;
                }
Loading