Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit ce6aa898 authored by Yusuke Nojima's avatar Yusuke Nojima Committed by Android (Google) Code Review
Browse files

Merge "Calculate edit distances incrementally."

parents 6262fa53 04d87370
Loading
Loading
Loading
Loading
+93 −46
Original line number Diff line number Diff line
@@ -27,6 +27,87 @@

namespace latinime {

/////////////////////////////
// edit distance funcitons //
/////////////////////////////

#if 0 /* no longer used */
inline static int editDistance(
        int* editDistanceTable, const unsigned short* input,
        const int inputLength, const unsigned short* output, const int outputLength) {
    // dp[li][lo] dp[a][b] = dp[ a * lo + b]
    int* dp = editDistanceTable;
    const int li = inputLength + 1;
    const int lo = outputLength + 1;
    for (int i = 0; i < li; ++i) {
        dp[lo * i] = i;
    }
    for (int i = 0; i < lo; ++i) {
        dp[i] = i;
    }

    for (int i = 0; i < li - 1; ++i) {
        for (int j = 0; j < lo - 1; ++j) {
            const uint32_t ci = Dictionary::toBaseLowerCase(input[i]);
            const uint32_t co = Dictionary::toBaseLowerCase(output[j]);
            const uint16_t cost = (ci == co) ? 0 : 1;
            dp[(i + 1) * lo + (j + 1)] = min(dp[i * lo + (j + 1)] + 1,
                    min(dp[(i + 1) * lo + j] + 1, dp[i * lo + j] + cost));
            if (i > 0 && j > 0 && ci == Dictionary::toBaseLowerCase(output[j - 1])
                    && co == Dictionary::toBaseLowerCase(input[i - 1])) {
                dp[(i + 1) * lo + (j + 1)] = min(
                        dp[(i + 1) * lo + (j + 1)], dp[(i - 1) * lo + (j - 1)] + cost);
            }
        }
    }

    if (DEBUG_EDIT_DISTANCE) {
        LOGI("IN = %d, OUT = %d", inputLength, outputLength);
        for (int i = 0; i < li; ++i) {
            for (int j = 0; j < lo; ++j) {
                LOGI("EDIT[%d][%d], %d", i, j, dp[i * lo + j]);
            }
        }
    }
    return dp[li * lo - 1];
}
#endif

inline static void initEditDistance(int *editDistanceTable) {
    for (int i = 0; i <= MAX_WORD_LENGTH_INTERNAL; ++i) {
        editDistanceTable[i] = i;
    }
}

inline static void calcEditDistanceOneStep(int *editDistanceTable, const unsigned short *input,
        const int inputLength, const unsigned short *output, const int outputLength) {
    // Let dp[i][j] be editDistanceTable[i * (inputLength + 1) + j].
    // Assuming that dp[0][0] ... dp[outputLength - 1][inputLength] are already calculated,
    // and calculate dp[ouputLength][0] ... dp[outputLength][inputLength].
    int *const current = editDistanceTable + outputLength * (inputLength + 1);
    const int *const prev = editDistanceTable + (outputLength - 1) * (inputLength + 1);
    const int *const prevprev =
            outputLength >= 2 ? editDistanceTable + (outputLength - 2) * (inputLength + 1) : NULL;
    current[0] = outputLength;
    const uint32_t co = Dictionary::toBaseLowerCase(output[outputLength - 1]);
    const uint32_t prevCO =
            outputLength >= 2 ? Dictionary::toBaseLowerCase(output[outputLength - 2]) : 0;
    for (int i = 1; i <= inputLength; ++i) {
        const uint32_t ci = Dictionary::toBaseLowerCase(input[i - 1]);
        const uint16_t cost = (ci == co) ? 0 : 1;
        current[i] = min(current[i - 1] + 1, min(prev[i] + 1, prev[i - 1] + cost));
        if (i >= 2 && prevprev && ci == prevCO
                && co == Dictionary::toBaseLowerCase(input[i - 2])) {
            current[i] = min(current[i], prevprev[i - 2] + 1);
        }
    }
}

inline static int getCurrentEditDistance(
        int *editDistanceTable, const int inputLength, const int outputLength) {
    return editDistanceTable[(inputLength + 1) * (outputLength + 1) - 1];
}

//////////////////////
// inline functions //
//////////////////////
@@ -43,6 +124,7 @@ inline bool Correction::isQuote(const unsigned short c) {

Correction::Correction(const int typedLetterMultiplier, const int fullWordMultiplier)
        : TYPED_LETTER_MULTIPLIER(typedLetterMultiplier), FULL_WORD_MULTIPLIER(fullWordMultiplier) {
    initEditDistance(mEditDistanceTable);
}

void Correction::initCorrection(const ProximityInfo *pi, const int inputLength,
@@ -197,13 +279,21 @@ void Correction::startToTraverseAllNodes() {
}

bool Correction::needsToPrune() const {
    // TODO: use edit distance here
    return mOutputIndex - 1 >= mMaxDepth || mProximityCount > mMaxEditDistance;
}

void Correction::addCharToCurrentWord(const int32_t c) {
    mWord[mOutputIndex] = c;
    const unsigned short *primaryInputWord = mProximityInfo->getPrimaryInputWord();
    calcEditDistanceOneStep(mEditDistanceTable, primaryInputWord, mInputLength,
            mWord, mOutputIndex + 1);
}

// TODO: inline?
Correction::CorrectionType Correction::processSkipChar(
        const int32_t c, const bool isTerminal, const bool inputIndexIncremented) {
    mWord[mOutputIndex] = c;
    addCharToCurrentWord(c);
    if (needsToTraverseAllNodes() && isTerminal) {
        mTerminalInputIndex = mInputIndex - (inputIndexIncremented ? 1 : 0);
        mTerminalOutputIndex = mOutputIndex;
@@ -412,7 +502,7 @@ Correction::CorrectionType Correction::processCharAndCalcState(
                mProximityInfo->getNormalizedSquaredDistance(mInputIndex, proximityIndex);
    }

    mWord[mOutputIndex] = c;
    addCharToCurrentWord(c);

    // 4. Last char excessive correction
    mLastCharExceeded = mExcessiveCount == 0 && mSkippedCount == 0 && mTransposedCount == 0
@@ -526,47 +616,6 @@ inline static bool isUpperCase(unsigned short c) {
     return false;
}

/* static */
inline static int editDistance(
        int* editDistanceTable, const unsigned short* input,
        const int inputLength, const unsigned short* output, const int outputLength) {
    // dp[li][lo] dp[a][b] = dp[ a * lo + b]
    int* dp = editDistanceTable;
    const int li = inputLength + 1;
    const int lo = outputLength + 1;
    for (int i = 0; i < li; ++i) {
        dp[lo * i] = i;
    }
    for (int i = 0; i < lo; ++i) {
        dp[i] = i;
    }

    for (int i = 0; i < li - 1; ++i) {
        for (int j = 0; j < lo - 1; ++j) {
            const uint32_t ci = Dictionary::toBaseLowerCase(input[i]);
            const uint32_t co = Dictionary::toBaseLowerCase(output[j]);
            const uint16_t cost = (ci == co) ? 0 : 1;
            dp[(i + 1) * lo + (j + 1)] = min(dp[i * lo + (j + 1)] + 1,
                    min(dp[(i + 1) * lo + j] + 1, dp[i * lo + j] + cost));
            if (i > 0 && j > 0 && ci == Dictionary::toBaseLowerCase(output[j - 1])
                    && co == Dictionary::toBaseLowerCase(input[i - 1])) {
                dp[(i + 1) * lo + (j + 1)] = min(
                        dp[(i + 1) * lo + (j + 1)], dp[(i - 1) * lo + (j - 1)] + cost);
            }
        }
    }

    if (DEBUG_EDIT_DISTANCE) {
        LOGI("IN = %d, OUT = %d", inputLength, outputLength);
        for (int i = 0; i < li; ++i) {
            for (int j = 0; j < lo; ++j) {
                LOGI("EDIT[%d][%d], %d", i, j, dp[i * lo + j]);
            }
        }
    }
    return dp[li * lo - 1];
}

//////////////////////
// RankingAlgorithm //
//////////////////////
@@ -612,9 +661,7 @@ int Correction::RankingAlgorithm::calculateFinalFreq(const int inputIndex, const
    // TODO: Optimize this.
    // TODO: Ignoring edit distance for transposed char, for now
    if (transposedCount == 0 && (proximityMatchedCount > 0 || skipped || excessiveCount > 0)) {
        const unsigned short* primaryInputWord = proximityInfo->getPrimaryInputWord();
        ed = editDistance(editDistanceTable, primaryInputWord,
                inputLength, word, outputIndex + 1);
        ed = getCurrentEditDistance(editDistanceTable, inputLength, outputIndex + 1);
        const int matchWeight = powerIntCapped(typedLetterMultiplier,
                max(inputLength, outputIndex + 1) - ed);
        multiplyIntCapped(matchWeight, &finalFreq);
+1 −0
Original line number Diff line number Diff line
@@ -102,6 +102,7 @@ private:
    inline bool isQuote(const unsigned short c);
    inline CorrectionType processSkipChar(
            const int32_t c, const bool isTerminal, const bool inputIndexIncremented);
    inline void addCharToCurrentWord(const int32_t c);

    const int TYPED_LETTER_MULTIPLIER;
    const int FULL_WORD_MULTIPLIER;