Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 04d87370 authored by Yusuke Nojima
Browse files

Calculate edit distances incrementally.

Change-Id: I3ee734b9b71351523dc8658cba33d6c8435e348e
parent 283d35cb
Loading
Loading
Loading
Loading
+93 −46
Original line number Diff line number Diff line
@@ -27,6 +27,87 @@

namespace latinime {

/////////////////////////////
// edit distance functions //
/////////////////////////////

#if 0 /* no longer used */
// Fills editDistanceTable with a full edit-distance table between input and output and
// returns its last cell.
//
// Characters are compared after Dictionary::toBaseLowerCase(). Moving one cell down or right
// costs 1, the diagonal costs 0 when the characters match and 1 otherwise, and a swap of two
// adjacent characters is also considered.
//
// editDistanceTable: scratch buffer; cells 0 .. (inputLength + 1) * (outputLength + 1) - 1
//                    are written.
// input / inputLength: first word and the number of characters read from it.
// output / outputLength: second word and the number of characters read from it.
// Returns dp[inputLength][outputLength].
inline static int editDistance(
        int* editDistanceTable, const unsigned short* input,
        const int inputLength, const unsigned short* output, const int outputLength) {
    // The table is a flattened li x lo matrix: dp[a][b] is stored at dp[a * lo + b],
    // where a counts consumed input characters and b counts consumed output characters.
    int* dp = editDistanceTable;
    const int li = inputLength + 1;
    const int lo = outputLength + 1;
    // First column: dp[i][0] = i.
    for (int i = 0; i < li; ++i) {
        dp[lo * i] = i;
    }
    // First row: dp[0][j] = j.
    for (int i = 0; i < lo; ++i) {
        dp[i] = i;
    }

    // Fill cell (i + 1, j + 1) from its upper, left and upper-left neighbours.
    for (int i = 0; i < li - 1; ++i) {
        for (int j = 0; j < lo - 1; ++j) {
            const uint32_t ci = Dictionary::toBaseLowerCase(input[i]);
            const uint32_t co = Dictionary::toBaseLowerCase(output[j]);
            const uint16_t cost = (ci == co) ? 0 : 1;
            dp[(i + 1) * lo + (j + 1)] = min(dp[i * lo + (j + 1)] + 1,
                    min(dp[(i + 1) * lo + j] + 1, dp[i * lo + j] + cost));
            // Adjacent swap: input[i - 1], input[i] equal output[j], output[j - 1] (lower-cased).
            // In that case the cell two rows up and two columns left, plus cost, is a candidate.
            if (i > 0 && j > 0 && ci == Dictionary::toBaseLowerCase(output[j - 1])
                    && co == Dictionary::toBaseLowerCase(input[i - 1])) {
                dp[(i + 1) * lo + (j + 1)] = min(
                        dp[(i + 1) * lo + (j + 1)], dp[(i - 1) * lo + (j - 1)] + cost);
            }
        }
    }

    // Optional dump of every cell when DEBUG_EDIT_DISTANCE is enabled.
    if (DEBUG_EDIT_DISTANCE) {
        LOGI("IN = %d, OUT = %d", inputLength, outputLength);
        for (int i = 0; i < li; ++i) {
            for (int j = 0; j < lo; ++j) {
                LOGI("EDIT[%d][%d], %d", i, j, dp[i * lo + j]);
            }
        }
    }
    // Last cell of the matrix, i.e. dp[inputLength][outputLength].
    return dp[li * lo - 1];
}
#endif

// Seeds the first MAX_WORD_LENGTH_INTERNAL + 1 entries of editDistanceTable so that
// entry k holds the value k. calcEditDistanceOneStep() reads these entries as its row 0.
inline static void initEditDistance(int *editDistanceTable) {
    int column = 0;
    while (column <= MAX_WORD_LENGTH_INTERNAL) {
        editDistanceTable[column] = column;
        ++column;
    }
}

// Appends one row to the incrementally built edit-distance table.
//
// With dp[r][c] standing for editDistanceTable[r * (inputLength + 1) + c], rows
// dp[0] .. dp[outputLength - 1] are expected to be filled already; this call computes
// dp[outputLength][0] .. dp[outputLength][inputLength].
//
// Characters are compared after Dictionary::toBaseLowerCase(). output[outputLength - 1] is
// read unconditionally, so outputLength must be at least 1.
inline static void calcEditDistanceOneStep(int *editDistanceTable, const unsigned short *input,
        const int inputLength, const unsigned short *output, const int outputLength) {
    const int rowWidth = inputLength + 1;
    int *const row = editDistanceTable + outputLength * rowWidth;
    const int *const rowAbove = editDistanceTable + (outputLength - 1) * rowWidth;
    // The row two steps back only exists once at least two output characters are present.
    const int *rowTwoAbove = NULL;
    if (outputLength >= 2) {
        rowTwoAbove = editDistanceTable + (outputLength - 2) * rowWidth;
    }

    // Column 0 of the new row.
    row[0] = outputLength;

    const uint32_t newestOutputChar = Dictionary::toBaseLowerCase(output[outputLength - 1]);
    uint32_t olderOutputChar = 0;
    if (outputLength >= 2) {
        olderOutputChar = Dictionary::toBaseLowerCase(output[outputLength - 2]);
    }

    for (int col = 1; col <= inputLength; ++col) {
        const uint32_t inputChar = Dictionary::toBaseLowerCase(input[col - 1]);
        const uint16_t substitutionCost = (inputChar == newestOutputChar) ? 0 : 1;

        // Cheapest of: step from the cell above, step from the upper-left cell, step from the left.
        int best = min(rowAbove[col] + 1, rowAbove[col - 1] + substitutionCost);
        best = min(row[col - 1] + 1, best);

        // Two adjacent characters swapped between input and output: one extra step from the
        // cell two rows up and two columns left.
        if (col >= 2 && rowTwoAbove && inputChar == olderOutputChar
                && newestOutputChar == Dictionary::toBaseLowerCase(input[col - 2])) {
            best = min(best, rowTwoAbove[col - 2] + 1);
        }
        row[col] = best;
    }
}

// Reads the last cell of row outputLength from a table whose rows are inputLength + 1
// entries wide, i.e. editDistanceTable[outputLength * (inputLength + 1) + inputLength].
inline static int getCurrentEditDistance(
        int *editDistanceTable, const int inputLength, const int outputLength) {
    const int rowWidth = inputLength + 1;
    const int lastCellIndex = rowWidth * outputLength + inputLength;
    return editDistanceTable[lastCellIndex];
}

//////////////////////
// inline functions //
//////////////////////
@@ -43,6 +124,7 @@ inline bool Correction::isQuote(const unsigned short c) {

// Copies the two multipliers into the const members TYPED_LETTER_MULTIPLIER and
// FULL_WORD_MULTIPLIER, then prepares mEditDistanceTable for incremental updates.
Correction::Correction(const int typedLetterMultiplier, const int fullWordMultiplier)
        : TYPED_LETTER_MULTIPLIER(typedLetterMultiplier), FULL_WORD_MULTIPLIER(fullWordMultiplier) {
    // Writes the leading entries (0 .. MAX_WORD_LENGTH_INTERNAL) of the table once, up front.
    initEditDistance(mEditDistanceTable);
}

void Correction::initCorrection(const ProximityInfo *pi, const int inputLength,
@@ -197,13 +279,21 @@ void Correction::startToTraverseAllNodes() {
}

// Returns true when either limit is hit: mOutputIndex - 1 has reached mMaxDepth, or
// mProximityCount exceeds mMaxEditDistance.
bool Correction::needsToPrune() const {
    // TODO: use edit distance here
    if (mOutputIndex - 1 >= mMaxDepth) {
        return true;
    }
    return mProximityCount > mMaxEditDistance;
}

// Stores c at mWord[mOutputIndex] and then extends mEditDistanceTable by the row for the
// resulting output length (mOutputIndex + 1), measured against the primary input word
// obtained from mProximityInfo.
void Correction::addCharToCurrentWord(const int32_t c) {
    mWord[mOutputIndex] = c;

    const int lengthWithNewChar = mOutputIndex + 1;
    const unsigned short *const typedWord = mProximityInfo->getPrimaryInputWord();
    calcEditDistanceOneStep(
            mEditDistanceTable, typedWord, mInputLength, mWord, lengthWithNewChar);
}

// TODO: inline?
Correction::CorrectionType Correction::processSkipChar(
        const int32_t c, const bool isTerminal, const bool inputIndexIncremented) {
    mWord[mOutputIndex] = c;
    addCharToCurrentWord(c);
    if (needsToTraverseAllNodes() && isTerminal) {
        mTerminalInputIndex = mInputIndex - (inputIndexIncremented ? 1 : 0);
        mTerminalOutputIndex = mOutputIndex;
@@ -412,7 +502,7 @@ Correction::CorrectionType Correction::processCharAndCalcState(
                mProximityInfo->getNormalizedSquaredDistance(mInputIndex, proximityIndex);
    }

    mWord[mOutputIndex] = c;
    addCharToCurrentWord(c);

    // 4. Last char excessive correction
    mLastCharExceeded = mExcessiveCount == 0 && mSkippedCount == 0 && mTransposedCount == 0
@@ -526,47 +616,6 @@ inline static bool isUpperCase(unsigned short c) {
     return false;
}

/* static */
// Computes the whole edit-distance table between input and output into editDistanceTable and
// returns the value of its final cell.
//
// The table is a flattened (inputLength + 1) x (outputLength + 1) matrix; the cell for
// `row` input characters and `col` output characters lives at index row * colCount + col.
// Characters are compared after Dictionary::toBaseLowerCase().
inline static int editDistance(
        int* editDistanceTable, const unsigned short* input,
        const int inputLength, const unsigned short* output, const int outputLength) {
    int* const table = editDistanceTable;
    const int rowCount = inputLength + 1;
    const int colCount = outputLength + 1;

    // Borders: first column holds the row number, first row holds the column number.
    for (int row = 0; row < rowCount; ++row) {
        table[colCount * row] = row;
    }
    for (int col = 0; col < colCount; ++col) {
        table[col] = col;
    }

    for (int row = 1; row < rowCount; ++row) {
        for (int col = 1; col < colCount; ++col) {
            const uint32_t inputChar = Dictionary::toBaseLowerCase(input[row - 1]);
            const uint32_t outputChar = Dictionary::toBaseLowerCase(output[col - 1]);
            const uint16_t cost = (inputChar == outputChar) ? 0 : 1;

            const int cell = row * colCount + col;
            const int above = (row - 1) * colCount + col;
            const int left = row * colCount + (col - 1);
            const int diagonal = (row - 1) * colCount + (col - 1);

            table[cell] = min(table[above] + 1,
                    min(table[left] + 1, table[diagonal] + cost));

            // Adjacent swap: the current and previous characters of input match the previous
            // and current characters of output respectively.
            if (row > 1 && col > 1 && inputChar == Dictionary::toBaseLowerCase(output[col - 2])
                    && outputChar == Dictionary::toBaseLowerCase(input[row - 2])) {
                const int twoBack = (row - 2) * colCount + (col - 2);
                table[cell] = min(table[cell], table[twoBack] + cost);
            }
        }
    }

    // Optional dump of every cell when DEBUG_EDIT_DISTANCE is enabled.
    if (DEBUG_EDIT_DISTANCE) {
        LOGI("IN = %d, OUT = %d", inputLength, outputLength);
        for (int row = 0; row < rowCount; ++row) {
            for (int col = 0; col < colCount; ++col) {
                LOGI("EDIT[%d][%d], %d", row, col, table[row * colCount + col]);
            }
        }
    }
    return table[rowCount * colCount - 1];
}

//////////////////////
// RankingAlgorithm //
//////////////////////
@@ -612,9 +661,7 @@ int Correction::RankingAlgorithm::calculateFinalFreq(const int inputIndex, const
    // TODO: Optimize this.
    // TODO: Ignoring edit distance for transposed char, for now
    if (transposedCount == 0 && (proximityMatchedCount > 0 || skipped || excessiveCount > 0)) {
        const unsigned short* primaryInputWord = proximityInfo->getPrimaryInputWord();
        ed = editDistance(editDistanceTable, primaryInputWord,
                inputLength, word, outputIndex + 1);
        ed = getCurrentEditDistance(editDistanceTable, inputLength, outputIndex + 1);
        const int matchWeight = powerIntCapped(typedLetterMultiplier,
                max(inputLength, outputIndex + 1) - ed);
        multiplyIntCapped(matchWeight, &finalFreq);
+1 −0
Original line number Diff line number Diff line
@@ -102,6 +102,7 @@ private:
    inline bool isQuote(const unsigned short c);
    inline CorrectionType processSkipChar(
            const int32_t c, const bool isTerminal, const bool inputIndexIncremented);
    inline void addCharToCurrentWord(const int32_t c);

    const int TYPED_LETTER_MULTIPLIER;
    const int FULL_WORD_MULTIPLIER;