Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 948ef10d authored by Keisuke Kuroyanagi's avatar Keisuke Kuroyanagi Committed by Android (Google) Code Review
Browse files

Merge "Improve bigram probability computation for decaying dicts."

parents 8784b436 aae1a062
Loading
Loading
Loading
Loading
+0 −8
Original line number Diff line number Diff line
@@ -35,23 +35,15 @@ const char *const HeaderPolicy::EXTENDED_REGION_SIZE_KEY = "EXTENDED_REGION_SIZE
// count.
const char *const HeaderPolicy::HAS_HISTORICAL_INFO_KEY = "HAS_HISTORICAL_INFO";
const char *const HeaderPolicy::LOCALE_KEY = "locale"; // match Java declaration
const char *const HeaderPolicy::FORGETTING_CURVE_OCCURRENCES_TO_LEVEL_UP_KEY =
        "FORGETTING_CURVE_OCCURRENCES_TO_LEVEL_UP";
const char *const HeaderPolicy::FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID_KEY =
        "FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID";
const char *const HeaderPolicy::FORGETTING_CURVE_DURATION_TO_LEVEL_DOWN_IN_SECONDS_KEY =
        "FORGETTING_CURVE_DURATION_TO_LEVEL_DOWN_IN_SECONDS";

const char *const HeaderPolicy::MAX_UNIGRAM_COUNT_KEY = "MAX_UNIGRAM_COUNT";
const char *const HeaderPolicy::MAX_BIGRAM_COUNT_KEY = "MAX_BIGRAM_COUNT";

const int HeaderPolicy::DEFAULT_MULTIPLE_WORDS_DEMOTION_RATE = 100;
const float HeaderPolicy::MULTIPLE_WORD_COST_MULTIPLIER_SCALE = 100.0f;
const int HeaderPolicy::DEFAULT_FORGETTING_CURVE_OCCURRENCES_TO_LEVEL_UP = 2;
const int HeaderPolicy::DEFAULT_FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID = 3;
// 30 days
const int HeaderPolicy::DEFAULT_FORGETTING_CURVE_DURATION_TO_LEVEL_DOWN_IN_SECONDS =
        30 * 24 * 60 * 60;

const int HeaderPolicy::DEFAULT_MAX_UNIGRAM_COUNT = 10000;
const int HeaderPolicy::DEFAULT_MAX_BIGRAM_COUNT = 10000;
+1 −30
Original line number Diff line number Diff line
@@ -53,15 +53,9 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy {
                      EXTENDED_REGION_SIZE_KEY, 0 /* defaultValue */)),
              mHasHistoricalInfoOfWords(HeaderReadWriteUtils::readBoolAttributeValue(
                      &mAttributeMap, HAS_HISTORICAL_INFO_KEY, false /* defaultValue */)),
              mForgettingCurveOccurrencesToLevelUp(HeaderReadWriteUtils::readIntAttributeValue(
                      &mAttributeMap, FORGETTING_CURVE_OCCURRENCES_TO_LEVEL_UP_KEY,
                      DEFAULT_FORGETTING_CURVE_OCCURRENCES_TO_LEVEL_UP)),
              mForgettingCurveProbabilityValuesTableId(HeaderReadWriteUtils::readIntAttributeValue(
                      &mAttributeMap, FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID_KEY,
                      DEFAULT_FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID)),
              mForgettingCurveDurationToLevelDown(HeaderReadWriteUtils::readIntAttributeValue(
                      &mAttributeMap, FORGETTING_CURVE_DURATION_TO_LEVEL_DOWN_IN_SECONDS_KEY,
                      DEFAULT_FORGETTING_CURVE_DURATION_TO_LEVEL_DOWN_IN_SECONDS)),
              mMaxUnigramCount(HeaderReadWriteUtils::readIntAttributeValue(
                      &mAttributeMap, MAX_UNIGRAM_COUNT_KEY, DEFAULT_MAX_UNIGRAM_COUNT)),
              mMaxBigramCount(HeaderReadWriteUtils::readIntAttributeValue(
@@ -86,15 +80,9 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy {
              mUnigramCount(0), mBigramCount(0), mExtendedRegionSize(0),
              mHasHistoricalInfoOfWords(HeaderReadWriteUtils::readBoolAttributeValue(
                      &mAttributeMap, HAS_HISTORICAL_INFO_KEY, false /* defaultValue */)),
              mForgettingCurveOccurrencesToLevelUp(HeaderReadWriteUtils::readIntAttributeValue(
                      &mAttributeMap, FORGETTING_CURVE_OCCURRENCES_TO_LEVEL_UP_KEY,
                      DEFAULT_FORGETTING_CURVE_OCCURRENCES_TO_LEVEL_UP)),
              mForgettingCurveProbabilityValuesTableId(HeaderReadWriteUtils::readIntAttributeValue(
                      &mAttributeMap, FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID_KEY,
                      DEFAULT_FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID)),
              mForgettingCurveDurationToLevelDown(HeaderReadWriteUtils::readIntAttributeValue(
                      &mAttributeMap, FORGETTING_CURVE_DURATION_TO_LEVEL_DOWN_IN_SECONDS_KEY,
                      DEFAULT_FORGETTING_CURVE_DURATION_TO_LEVEL_DOWN_IN_SECONDS)),
              mMaxUnigramCount(HeaderReadWriteUtils::readIntAttributeValue(
                      &mAttributeMap, MAX_UNIGRAM_COUNT_KEY, DEFAULT_MAX_UNIGRAM_COUNT)),
              mMaxBigramCount(HeaderReadWriteUtils::readIntAttributeValue(
@@ -113,12 +101,8 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy {
              mUnigramCount(headerPolicy->mUnigramCount), mBigramCount(headerPolicy->mBigramCount),
              mExtendedRegionSize(headerPolicy->mExtendedRegionSize),
              mHasHistoricalInfoOfWords(headerPolicy->mHasHistoricalInfoOfWords),
              mForgettingCurveOccurrencesToLevelUp(
                      headerPolicy->mForgettingCurveOccurrencesToLevelUp),
              mForgettingCurveProbabilityValuesTableId(
                      headerPolicy->mForgettingCurveProbabilityValuesTableId),
              mForgettingCurveDurationToLevelDown(
                      headerPolicy->mForgettingCurveDurationToLevelDown),
              mMaxUnigramCount(headerPolicy->mMaxUnigramCount),
              mMaxBigramCount(headerPolicy->mMaxBigramCount),
              mCodePointTable(headerPolicy->mCodePointTable) {}
@@ -130,8 +114,7 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy {
              mRequiresGermanUmlautProcessing(false), mIsDecayingDict(false),
              mDate(0), mLastDecayedTime(0), mUnigramCount(0), mBigramCount(0),
              mExtendedRegionSize(0), mHasHistoricalInfoOfWords(false),
              mForgettingCurveOccurrencesToLevelUp(0), mForgettingCurveProbabilityValuesTableId(0),
              mForgettingCurveDurationToLevelDown(0), mMaxUnigramCount(0), mMaxBigramCount(0),
              mForgettingCurveProbabilityValuesTableId(0), mMaxUnigramCount(0), mMaxBigramCount(0),
              mCodePointTable(nullptr) {}

    ~HeaderPolicy() {}
@@ -217,18 +200,10 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy {
        return &mAttributeMap;
    }

    AK_FORCE_INLINE int getForgettingCurveOccurrencesToLevelUp() const {
        return mForgettingCurveOccurrencesToLevelUp;
    }

    AK_FORCE_INLINE int getForgettingCurveProbabilityValuesTableId() const {
        return mForgettingCurveProbabilityValuesTableId;
    }

    AK_FORCE_INLINE int getForgettingCurveDurationToLevelDown() const {
        return mForgettingCurveDurationToLevelDown;
    }

    AK_FORCE_INLINE int getMaxUnigramCount() const {
        return mMaxUnigramCount;
    }
@@ -280,9 +255,7 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy {
    static const char *const MAX_BIGRAM_COUNT_KEY;
    static const int DEFAULT_MULTIPLE_WORDS_DEMOTION_RATE;
    static const float MULTIPLE_WORD_COST_MULTIPLIER_SCALE;
    static const int DEFAULT_FORGETTING_CURVE_OCCURRENCES_TO_LEVEL_UP;
    static const int DEFAULT_FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID;
    static const int DEFAULT_FORGETTING_CURVE_DURATION_TO_LEVEL_DOWN_IN_SECONDS;
    static const int DEFAULT_MAX_UNIGRAM_COUNT;
    static const int DEFAULT_MAX_BIGRAM_COUNT;

@@ -300,9 +273,7 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy {
    const int mBigramCount;
    const int mExtendedRegionSize;
    const bool mHasHistoricalInfoOfWords;
    const int mForgettingCurveOccurrencesToLevelUp;
    const int mForgettingCurveProbabilityValuesTableId;
    const int mForgettingCurveDurationToLevelDown;
    const int mMaxUnigramCount;
    const int mMaxBigramCount;
    const int *const mCodePointTable;
+53 −27
Original line number Diff line number Diff line
@@ -146,18 +146,15 @@ const WordAttributes Ver4PatriciaTriePolicy::getWordAttributes(const int probabi

int Ver4PatriciaTriePolicy::getProbability(const int unigramProbability,
        const int bigramProbability) const {
    if (mHeaderPolicy->isDecayingDict()) {
        // Both probabilities are encoded. Decode them and get probability.
        return ForgettingCurveUtils::getProbability(unigramProbability, bigramProbability);
    } else {
    // In the v4 format, bigramProbability is a conditional probability.
    const int bigramConditionalProbability = bigramProbability;
    if (unigramProbability == NOT_A_PROBABILITY) {
        return NOT_A_PROBABILITY;
        } else if (bigramProbability == NOT_A_PROBABILITY) {
            return ProbabilityUtils::backoff(unigramProbability);
        } else {
            return bigramProbability;
    }
    if (bigramConditionalProbability == NOT_A_PROBABILITY) {
        return ProbabilityUtils::backoff(unigramProbability);
    }
    return bigramConditionalProbability;
}

int Ver4PatriciaTriePolicy::getProbabilityOfWord(const WordIdArrayView prevWordIds,
@@ -170,37 +167,66 @@ int Ver4PatriciaTriePolicy::getProbabilityOfWord(const WordIdArrayView prevWordI
    if (ptNodeParams.isDeleted() || ptNodeParams.isBlacklisted() || ptNodeParams.isNotAWord()) {
        return NOT_A_PROBABILITY;
    }
    if (!prevWordIds.empty()) {
        const int bigramsPosition = getBigramsPositionOfPtNode(
                getTerminalPtNodePosFromWordId(prevWordIds[0]));
    if (prevWordIds.empty()) {
        return getProbability(ptNodeParams.getProbability(), NOT_A_PROBABILITY);
    }
    if (prevWordIds[0] == NOT_A_WORD_ID) {
        return NOT_A_PROBABILITY;
    }
    const PtNodeParams prevWordPtNodeParams =
            mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(prevWordIds[0]);
    if (prevWordPtNodeParams.isDeleted()) {
        return getProbability(ptNodeParams.getProbability(), NOT_A_PROBABILITY);
    }
    const int bigramsPosition = mBuffers->getBigramDictContent()->getBigramListHeadPos(
            prevWordPtNodeParams.getTerminalId());
    BinaryDictionaryBigramsIterator bigramsIt(&mBigramPolicy, bigramsPosition);
    while (bigramsIt.hasNext()) {
        bigramsIt.next();
        if (bigramsIt.getBigramPos() == ptNodePos
                && bigramsIt.getProbability() != NOT_A_PROBABILITY) {
                return getProbability(ptNodeParams.getProbability(), bigramsIt.getProbability());
            const int bigramConditionalProbability = getBigramConditionalProbability(
                    prevWordPtNodeParams.getProbability(), bigramsIt.getProbability());
            return getProbability(ptNodeParams.getProbability(), bigramConditionalProbability);
        }
    }
    return NOT_A_PROBABILITY;
}
    return getProbability(ptNodeParams.getProbability(), NOT_A_PROBABILITY);
}

void Ver4PatriciaTriePolicy::iterateNgramEntries(const WordIdArrayView prevWordIds,
        NgramListener *const listener) const {
    if (prevWordIds.empty()) {
    if (prevWordIds.firstOrDefault(NOT_A_DICT_POS) == NOT_A_DICT_POS) {
        return;
    }
    const PtNodeParams prevWordPtNodeParams =
            mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(prevWordIds[0]);
    if (prevWordPtNodeParams.isDeleted()) {
        return;
    }
    const int bigramsPosition = getBigramsPositionOfPtNode(
            getTerminalPtNodePosFromWordId(prevWordIds[0]));
    const int bigramsPosition = mBuffers->getBigramDictContent()->getBigramListHeadPos(
            prevWordPtNodeParams.getTerminalId());
    BinaryDictionaryBigramsIterator bigramsIt(&mBigramPolicy, bigramsPosition);
    while (bigramsIt.hasNext()) {
        bigramsIt.next();
        listener->onVisitEntry(bigramsIt.getProbability(),
        const int bigramConditionalProbability = getBigramConditionalProbability(
                prevWordPtNodeParams.getProbability(), bigramsIt.getProbability());
        listener->onVisitEntry(bigramConditionalProbability,
                getWordIdFromTerminalPtNodePos(bigramsIt.getBigramPos()));
    }
}

int Ver4PatriciaTriePolicy::getBigramConditionalProbability(const int prevWordUnigramProbability,
        const int bigramProbability) const {
    if (mHeaderPolicy->hasHistoricalInfoOfWords()) {
        // Calculate conditional probability.
        return std::min(MAX_PROBABILITY - prevWordUnigramProbability + bigramProbability,
                MAX_PROBABILITY);
    } else {
        // bigramProbability is a conditional probability.
        return bigramProbability;
    }
}

BinaryDictionaryShortcutIterator Ver4PatriciaTriePolicy::getShortcutIterator(
        const int wordId) const {
    const int shortcutPos = getShortcutPositionOfPtNode(getTerminalPtNodePosFromWordId(wordId));
+2 −0
Original line number Diff line number Diff line
@@ -174,6 +174,8 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
    int getTerminalPtNodePosFromWordId(const int wordId) const;
    const WordAttributes getWordAttributes(const int probability,
            const PtNodeParams &ptNodeParams) const;
    int getBigramConditionalProbability(const int prevWordUnigramProbability,
            const int bigramProbability) const;
};
} // namespace v402
} // namespace backward
+25 −30
Original line number Diff line number Diff line
@@ -29,10 +29,14 @@ namespace latinime {
const int ForgettingCurveUtils::MULTIPLIER_TWO_IN_PROBABILITY_SCALE = 8;
const int ForgettingCurveUtils::DECAY_INTERVAL_SECONDS = 2 * 60 * 60;

const int ForgettingCurveUtils::MAX_LEVEL = 3;
const int ForgettingCurveUtils::MIN_VISIBLE_LEVEL = 1;
const int ForgettingCurveUtils::MAX_ELAPSED_TIME_STEP_COUNT = 15;
const int ForgettingCurveUtils::DISCARD_LEVEL_ZERO_ENTRY_TIME_STEP_COUNT_THRESHOLD = 14;
const int ForgettingCurveUtils::MAX_LEVEL = 15;
const int ForgettingCurveUtils::MIN_VISIBLE_LEVEL = 2;
const int ForgettingCurveUtils::MAX_ELAPSED_TIME_STEP_COUNT = 31;
const int ForgettingCurveUtils::DISCARD_LEVEL_ZERO_ENTRY_TIME_STEP_COUNT_THRESHOLD = 30;
const int ForgettingCurveUtils::OCCURRENCES_TO_RAISE_THE_LEVEL = 1;
// TODO: Evaluate whether this should be 7.5 days.
// 15 days
const int ForgettingCurveUtils::DURATION_TO_LOWER_THE_LEVEL_IN_SECONDS = 15 * 24 * 60 * 60;

const float ForgettingCurveUtils::UNIGRAM_COUNT_HARD_LIMIT_WEIGHT = 1.2;
const float ForgettingCurveUtils::BIGRAM_COUNT_HARD_LIMIT_WEIGHT = 1.2;
@@ -54,19 +58,23 @@ const ForgettingCurveUtils::ProbabilityTable ForgettingCurveUtils::sProbabilityT
            || (originalHistoricalInfo->getLevel() == newHistoricalInfo->getLevel()
                    && originalHistoricalInfo->getCount() < newHistoricalInfo->getCount())) {
        // Initial information.
        int count = newHistoricalInfo->getCount();
        if (count >= OCCURRENCES_TO_RAISE_THE_LEVEL) {
            const int level = clampToValidLevelRange(newHistoricalInfo->getLevel() + 1);
            return HistoricalInfo(timestamp, level, 0 /* count */);
        }
        const int level = clampToValidLevelRange(newHistoricalInfo->getLevel());
        const int count = clampToValidCountRange(newHistoricalInfo->getCount(), headerPolicy);
        return HistoricalInfo(timestamp, level, count);
        return HistoricalInfo(timestamp, level, clampToValidCountRange(count, headerPolicy));
    } else {
        const int updatedCount = originalHistoricalInfo->getCount() + 1;
        if (updatedCount >= headerPolicy->getForgettingCurveOccurrencesToLevelUp()) {
        if (updatedCount >= OCCURRENCES_TO_RAISE_THE_LEVEL) {
            // The count exceeds the max value the level can be incremented.
            if (originalHistoricalInfo->getLevel() >= MAX_LEVEL) {
                // The level is already max.
                return HistoricalInfo(timestamp,
                        originalHistoricalInfo->getLevel(), originalHistoricalInfo->getCount());
            } else {
                // Level up.
                // Raise the level.
                return HistoricalInfo(timestamp,
                        originalHistoricalInfo->getLevel() + 1, 0 /* count */);
            }
@@ -79,31 +87,18 @@ const ForgettingCurveUtils::ProbabilityTable ForgettingCurveUtils::sProbabilityT
/* static */ int ForgettingCurveUtils::decodeProbability(
        const HistoricalInfo *const historicalInfo, const HeaderPolicy *const headerPolicy) {
    const int elapsedTimeStepCount = getElapsedTimeStepCount(historicalInfo->getTimestamp(),
            headerPolicy->getForgettingCurveDurationToLevelDown());
            DURATION_TO_LOWER_THE_LEVEL_IN_SECONDS);
    return sProbabilityTable.getProbability(
            headerPolicy->getForgettingCurveProbabilityValuesTableId(),
            clampToValidLevelRange(historicalInfo->getLevel()),
            clampToValidTimeStepCountRange(elapsedTimeStepCount));
}

/* static */ int ForgettingCurveUtils::getProbability(const int unigramProbability,
        const int bigramProbability) {
    if (unigramProbability == NOT_A_PROBABILITY) {
        return NOT_A_PROBABILITY;
    } else if (bigramProbability == NOT_A_PROBABILITY) {
        return std::min(backoff(unigramProbability), MAX_PROBABILITY);
    } else {
        // TODO: Investigate better way to handle bigram probability.
        return std::min(std::max(unigramProbability,
                bigramProbability + MULTIPLIER_TWO_IN_PROBABILITY_SCALE), MAX_PROBABILITY);
    }
}

/* static */ bool ForgettingCurveUtils::needsToKeep(const HistoricalInfo *const historicalInfo,
        const HeaderPolicy *const headerPolicy) {
    return historicalInfo->getLevel() > 0
            || getElapsedTimeStepCount(historicalInfo->getTimestamp(),
                    headerPolicy->getForgettingCurveDurationToLevelDown())
                    DURATION_TO_LOWER_THE_LEVEL_IN_SECONDS)
                            < DISCARD_LEVEL_ZERO_ENTRY_TIME_STEP_COUNT_THRESHOLD;
}

@@ -113,14 +108,14 @@ const ForgettingCurveUtils::ProbabilityTable ForgettingCurveUtils::sProbabilityT
    if (originalHistoricalInfo->getTimestamp() == NOT_A_TIMESTAMP) {
        return HistoricalInfo();
    }
    const int durationToLevelDownInSeconds = headerPolicy->getForgettingCurveDurationToLevelDown();
    const int durationToLevelDownInSeconds = DURATION_TO_LOWER_THE_LEVEL_IN_SECONDS;
    const int elapsedTimeStep = getElapsedTimeStepCount(
            originalHistoricalInfo->getTimestamp(), durationToLevelDownInSeconds);
    if (elapsedTimeStep <= MAX_ELAPSED_TIME_STEP_COUNT) {
        // No need to update historical info.
        return *originalHistoricalInfo;
    }
    // Level down.
    // Lower the level.
    const int maxLevelDownAmonut = elapsedTimeStep / (MAX_ELAPSED_TIME_STEP_COUNT + 1);
    const int levelDownAmount = (maxLevelDownAmonut >= originalHistoricalInfo->getLevel()) ?
            originalHistoricalInfo->getLevel() : maxLevelDownAmonut;
@@ -170,7 +165,7 @@ const ForgettingCurveUtils::ProbabilityTable ForgettingCurveUtils::sProbabilityT

/* static */ int ForgettingCurveUtils::clampToValidCountRange(const int count,
        const HeaderPolicy *const headerPolicy) {
    return std::min(std::max(count, 0), headerPolicy->getForgettingCurveOccurrencesToLevelUp() - 1);
    return std::min(std::max(count, 0), OCCURRENCES_TO_RAISE_THE_LEVEL - 1);
}

/* static */ int ForgettingCurveUtils::clampToValidLevelRange(const int level) {
@@ -187,9 +182,9 @@ const int ForgettingCurveUtils::ProbabilityTable::MODEST_PROBABILITY_TABLE_ID =
const int ForgettingCurveUtils::ProbabilityTable::STRONG_PROBABILITY_TABLE_ID = 2;
const int ForgettingCurveUtils::ProbabilityTable::AGGRESSIVE_PROBABILITY_TABLE_ID = 3;
const int ForgettingCurveUtils::ProbabilityTable::WEAK_MAX_PROBABILITY = 127;
const int ForgettingCurveUtils::ProbabilityTable::MODEST_BASE_PROBABILITY = 32;
const int ForgettingCurveUtils::ProbabilityTable::STRONG_BASE_PROBABILITY = 35;
const int ForgettingCurveUtils::ProbabilityTable::AGGRESSIVE_BASE_PROBABILITY = 40;
const int ForgettingCurveUtils::ProbabilityTable::MODEST_BASE_PROBABILITY = 8;
const int ForgettingCurveUtils::ProbabilityTable::STRONG_BASE_PROBABILITY = 9;
const int ForgettingCurveUtils::ProbabilityTable::AGGRESSIVE_BASE_PROBABILITY = 10;


ForgettingCurveUtils::ProbabilityTable::ProbabilityTable() : mTables() {
@@ -202,7 +197,7 @@ ForgettingCurveUtils::ProbabilityTable::ProbabilityTable() : mTables() {
            const float endProbability = getBaseProbabilityForLevel(tableId, level - 1);
            for (int timeStepCount = 0; timeStepCount <= MAX_ELAPSED_TIME_STEP_COUNT;
                    ++timeStepCount) {
                if (level == 0) {
                if (level < MIN_VISIBLE_LEVEL) {
                    mTables[tableId][level][timeStepCount] = NOT_A_PROBABILITY;
                    continue;
                }
Loading