Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 20b6775a authored by Satoshi Kataoka
Browse files

Refactor most probable string

Change-Id: I96597decf5e36d9ce088c34427915f2379255054
parent 52a0d491
Loading
Loading
Loading
Loading
+16 −43
Original line number Diff line number Diff line
@@ -14,7 +14,7 @@
 * limitations under the License.
 */

#include <cstring> // for memset()
#include <cstring> // for memset() and memcpy()
#include <sstream> // for debug prints

#define LOG_TAG "LatinIME: proximity_info_state.cpp"
@@ -59,12 +59,15 @@ void ProximityInfoState::initInputParams(const int pointerId, const float maxPoi
    int pushTouchPointStartIndex = 0;
    int lastSavedInputSize = 0;
    mMaxPointToKeyLength = maxPointToKeyLength;
    mSampledInputSize = 0;
    mMostProbableStringProbability = 0.0f;

    if (mIsContinuationPossible && mSampledInputIndice.size() > 1) {
        // Just update difference.
        // Two points prior is never skipped. Thus, we pop 2 input point data here.
        pushTouchPointStartIndex = mSampledInputIndice[mSampledInputIndice.size() - 2];
        popInputData();
        popInputData();
        // Previous two points are never skipped. Thus, we pop 2 input point data here.
        pushTouchPointStartIndex = ProximityInfoStateUtils::trimLastTwoTouchPoints(
                &mSampledInputXs, &mSampledInputYs, &mSampledTimes, &mSampledLengthCache,
                &mSampledInputIndice);
        lastSavedInputSize = mSampledInputXs.size();
    } else {
        // Clear all data.
@@ -81,11 +84,11 @@ void ProximityInfoState::initInputParams(const int pointerId, const float maxPoi
        mCharProbabilities.clear();
        mDirections.clear();
    }

    if (DEBUG_GEO_FULL) {
        AKLOGI("Init ProximityInfoState: reused points =  %d, last input size = %d",
                pushTouchPointStartIndex, lastSavedInputSize);
    }
    mSampledInputSize = 0;

    if (xCoordinates && yCoordinates) {
        mSampledInputSize = ProximityInfoStateUtils::updateTouchPoints(
@@ -121,6 +124,9 @@ void ProximityInfoState::initInputParams(const int pointerId, const float maxPoi
            ProximityInfoStateUtils::updateSampledSearchKeysVector(mProximityInfo,
                    mSampledInputSize, lastSavedInputSize, &mSampledLengthCache,
                    &mSampledNearKeysVector, &mSampledSearchKeysVector);
            mMostProbableStringProbability = ProximityInfoStateUtils::getMostProbableString(
                    mProximityInfo, mSampledInputSize, &mCharProbabilities, mMostProbableString);

        }
    }

@@ -132,8 +138,6 @@ void ProximityInfoState::initInputParams(const int pointerId, const float maxPoi
    // end
    ///////////////////////

    memset(mNormalizedSquaredDistances, NOT_A_DISTANCE, sizeof(mNormalizedSquaredDistances));
    memset(mPrimaryInputWord, 0, sizeof(mPrimaryInputWord));
    mTouchPositionCorrectionEnabled = mSampledInputSize > 0 && mHasTouchPositionCorrectionData
            && xCoordinates && yCoordinates;
    if (!isGeometric && pointerId == 0) {
@@ -142,8 +146,7 @@ void ProximityInfoState::initInputParams(const int pointerId, const float maxPoi
        if (mTouchPositionCorrectionEnabled) {
            ProximityInfoStateUtils::initNormalizedSquaredDistances(
                    mProximityInfo, inputSize, xCoordinates, yCoordinates, mInputProximities,
                    hasInputCoordinates(), &mSampledInputXs, &mSampledInputYs,
                    mNormalizedSquaredDistances);
                    &mSampledInputXs, &mSampledInputYs, mNormalizedSquaredDistances);
        }
    }
    if (DEBUG_GEO_FULL) {
@@ -278,16 +281,10 @@ int ProximityInfoState::getAllPossibleChars(
}

// Returns the result of testing bit keyId of mSampledSearchKeysVector[index].
// keyId must be non-negative and index must lie in [0, mSampledInputSize); both conditions
// are only checked through ASSERT.
bool ProximityInfoState::isKeyInSerchKeysAfterIndex(const int index, const int keyId) const {
    ASSERT(keyId >= 0);
    ASSERT(index >= 0 && index < mSampledInputSize);
    // NOTE(review): this combined check restates the two assertions directly above.
    ASSERT(keyId >= 0 && index >= 0 && index < mSampledInputSize);
    return mSampledSearchKeysVector[index].test(keyId);
}

// Forwards to ProximityInfoStateUtils::popInputData, handing it this state's sampled X, Y,
// time, length-cache and index vectors.
void ProximityInfoState::popInputData() {
    ProximityInfoStateUtils::popInputData(&mSampledInputXs, &mSampledInputYs, &mSampledTimes,
            &mSampledLengthCache, &mSampledInputIndice);
}

float ProximityInfoState::getDirection(const int index0, const int index1) const {
    return ProximityInfoStateUtils::getDirection(
            &mSampledInputXs, &mSampledInputYs, index0, index1);
@@ -313,33 +310,9 @@ float ProximityInfoState::getLineToKeyDistance(
            keyX, keyY, x0, y0, x1, y1, extend);
}

// Gets the word detected by tracing the most probable string, writes it into codePointBuf,
// and returns the accumulated score of the entries chosen while tracing it.
float ProximityInfoState::getMostProbableString(int *const codePointBuf) const {
    // Penalty added to every entry whose key is not NOT_AN_INDEX before entries are compared.
    static const float DEMOTION_LOG_PROBABILITY = 0.3f;
    // Next free slot in codePointBuf.
    int index = 0;
    float sumLogProbability = 0.0f;
    // TODO: Current implementation is greedy algorithm. DP would be efficient for many cases.
    // Stops one slot short of MAX_WORD_LENGTH so the terminator written after the loop fits.
    for (int i = 0; i < mSampledInputSize && index < MAX_WORD_LENGTH - 1; ++i) {
        // Smallest value wins; an entry must be below this starting value to be selected.
        float minLogProbability = static_cast<float>(MAX_POINT_TO_KEY_LENGTH);
        int character = NOT_AN_INDEX;
        for (hash_map_compat<int, float>::const_iterator it = mCharProbabilities[i].begin();
                it != mCharProbabilities[i].end(); ++it) {
            const float logProbability = (it->first != NOT_AN_INDEX)
                    ? it->second + DEMOTION_LOG_PROBABILITY : it->second;
            if (logProbability < minLogProbability) {
                minLogProbability = logProbability;
                character = it->first;
            }
        }
        // A point whose winning key is NOT_AN_INDEX emits no code point but still adds its
        // value to the sum below.
        if (character != NOT_AN_INDEX) {
            codePointBuf[index] = mProximityInfo->getCodePointOf(character);
            index++;
        }
        sumLogProbability += minLogProbability;
    }
    codePointBuf[index] = '\0';
    return sumLogProbability;
    // NOTE(review): the two statements below follow an unconditional return and are
    // unreachable as written. They copy the cached mMostProbableString into codePointBuf and
    // return mMostProbableStringProbability; presumably they are the replacement body of
    // this function -- confirm which body is intended.
    memcpy(codePointBuf, mMostProbableString, sizeof(mMostProbableString));
    return mMostProbableStringProbability;
}

bool ProximityInfoState::hasSpaceProximity(const int index) const {
+20 −21
Original line number Diff line number Diff line
@@ -54,10 +54,12 @@ class ProximityInfoState {
              mSampledInputIndice(), mSampledLengthCache(), mBeelineSpeedPercentiles(),
              mSampledDistanceCache_G(), mSpeedRates(), mDirections(), mCharProbabilities(),
              mSampledNearKeysVector(), mSampledSearchKeysVector(),
              mTouchPositionCorrectionEnabled(false), mSampledInputSize(0) {
              mTouchPositionCorrectionEnabled(false), mSampledInputSize(0),
              mMostProbableStringProbability(0.0f) {
        memset(mInputProximities, 0, sizeof(mInputProximities));
        memset(mNormalizedSquaredDistances, 0, sizeof(mNormalizedSquaredDistances));
        memset(mPrimaryInputWord, 0, sizeof(mPrimaryInputWord));
        memset(mMostProbableString, 0, sizeof(mMostProbableString));
    }

    // Non virtual inline destructor -- never inherit this class
@@ -67,6 +69,21 @@ class ProximityInfoState {
        return getProximityCodePointsAt(index)[0];
    }

    inline bool sameAsTyped(const int *word, int length) const {
        if (length != mSampledInputSize) {
            return false;
        }
        const int *inputProximities = mInputProximities;
        while (length--) {
            if (*inputProximities != *word) {
                return false;
            }
            inputProximities += MAX_PROXIMITY_CHARS_SIZE;
            word++;
        }
        return true;
    }

    AK_FORCE_INLINE bool existsCodePointInProximityAt(const int index, const int c) const {
        const int *codePoints = getProximityCodePointsAt(index);
        int i = 0;
@@ -107,21 +124,6 @@ class ProximityInfoState {
        return mTouchPositionCorrectionEnabled;
    }

    inline bool sameAsTyped(const int *word, int length) const {
        if (length != mSampledInputSize) {
            return false;
        }
        const int *inputProximities = mInputProximities;
        while (length--) {
            if (*inputProximities != *word) {
                return false;
            }
            inputProximities += MAX_PROXIMITY_CHARS_SIZE;
            word++;
        }
        return true;
    }

    // Returns true when mSampledInputSize is positive.
    bool isUsed() const {
        return mSampledInputSize > 0;
    }
@@ -208,14 +210,9 @@ class ProximityInfoState {
    // Defined here                        //
    /////////////////////////////////////////

    // Returns true only when both sampled coordinate vectors (X and Y) are non-empty.
    bool hasInputCoordinates() const {
        return mSampledInputXs.size() > 0 && mSampledInputYs.size() > 0;
    }

    // Delegates to ProximityInfoStateUtils::getProximityCodePointsAt, passing mInputProximities
    // and the given index, and returns the pointer it yields.
    inline const int *getProximityCodePointsAt(const int index) const {
        return ProximityInfoStateUtils::getProximityCodePointsAt(mInputProximities, index);
    }
    void popInputData();

    // const
    const ProximityInfo *mProximityInfo;
@@ -255,6 +252,8 @@ class ProximityInfoState {
    int mNormalizedSquaredDistances[MAX_PROXIMITY_CHARS_SIZE * MAX_WORD_LENGTH];
    int mSampledInputSize;
    int mPrimaryInputWord[MAX_WORD_LENGTH];
    float mMostProbableStringProbability;
    int mMostProbableString[MAX_WORD_LENGTH];
};
} // namespace latinime
#endif // LATINIME_PROXIMITY_INFO_STATE_H
+51 −1
Original line number Diff line number Diff line
@@ -15,6 +15,7 @@
 */

#include <cmath>
#include <cstring> // for memset()
#include <sstream> // for debug prints
#include <vector>

@@ -26,6 +27,17 @@

namespace latinime {

// Calls popInputData() twice on the given sample vectors and returns the value that was
// stored in the second-to-last slot of sampledInputIndice before either call.
// sampledInputIndice must hold at least two entries; this is not checked here.
/* static */ int ProximityInfoStateUtils::trimLastTwoTouchPoints(std::vector<int> *sampledInputXs,
        std::vector<int> *sampledInputYs, std::vector<int> *sampledInputTimes,
        std::vector<int> *sampledLengthCache, std::vector<int> *sampledInputIndice) {
    static const int TRIMMED_POINT_COUNT = 2;
    // Read the resume index first: popInputData() is handed sampledInputIndice as well, so
    // the slot may no longer exist afterwards.
    const std::vector<int>::size_type indexCount = sampledInputIndice->size();
    const int resumeIndex = (*sampledInputIndice)[indexCount - TRIMMED_POINT_COUNT];
    for (int trimmed = 0; trimmed < TRIMMED_POINT_COUNT; ++trimmed) {
        popInputData(sampledInputXs, sampledInputYs, sampledInputTimes, sampledLengthCache,
                sampledInputIndice);
    }
    return resumeIndex;
}

/* static */ int ProximityInfoStateUtils::updateTouchPoints(const int mostCommonKeyWidth,
        const ProximityInfo *const proximityInfo, const int maxPointToKeyLength,
        const int *const inputProximities, const int *const inputXCoordinates,
@@ -133,6 +145,7 @@ namespace latinime {

/* static */ void ProximityInfoStateUtils::initPrimaryInputWord(
        const int inputSize, const int *const inputProximities, int *primaryInputWord) {
    memset(primaryInputWord, 0, sizeof(primaryInputWord[0]) * MAX_WORD_LENGTH);
    for (int i = 0; i < inputSize; ++i) {
        primaryInputWord[i] = getPrimaryCodePointAt(inputProximities, i);
    }
@@ -171,10 +184,13 @@ namespace latinime {
/* static */ void ProximityInfoStateUtils::initNormalizedSquaredDistances(
        const ProximityInfo *const proximityInfo, const int inputSize,
        const int *inputXCoordinates, const int *inputYCoordinates,
        const int *const inputProximities, const bool hasInputCoordinates,
        const int *const inputProximities,
        const std::vector<int> *const sampledInputXs,
        const std::vector<int> *const sampledInputYs,
        int *normalizedSquaredDistances) {
    memset(normalizedSquaredDistances, NOT_A_DISTANCE,
            sizeof(normalizedSquaredDistances[0]) * MAX_PROXIMITY_CHARS_SIZE * MAX_WORD_LENGTH);
    const bool hasInputCoordinates = sampledInputXs->size() > 0 && sampledInputYs->size() > 0;
    for (int i = 0; i < inputSize; ++i) {
        const int *proximityCodePoints = getProximityCodePointsAt(inputProximities, i);
        const int primaryKey = proximityCodePoints[0];
@@ -1011,6 +1027,40 @@ namespace latinime {
    return true;
}

// Traces the most probable string for the sampled input into codePointBuf and returns the
// accumulated score of the entries chosen along the way.
//
// For each sampled point, the entry of (*charProbabilities)[i] with the smallest value is
// selected; entries keyed by anything other than NOT_AN_INDEX are penalized by
// DEMOTION_LOG_PROBABILITY before the comparison. When the selected key is not NOT_AN_INDEX,
// its code point (looked up through proximityInfo) is appended to codePointBuf. The selected
// value is added to the returned sum for every visited point, whether or not a code point
// was emitted.
//
// codePointBuf must have room for MAX_WORD_LENGTH ints: it is zero-filled up front and at
// most MAX_WORD_LENGTH - 1 code points are written, so the result is always zero-terminated.
/* static */ float ProximityInfoStateUtils::getMostProbableString(
        const ProximityInfo *const proximityInfo, const int sampledInputSize,
        const std::vector<hash_map_compat<int, float> > *const charProbabilities,
        int *const codePointBuf) {
    // size() is unsigned, so only the sampled size needs a sign check.
    ASSERT(sampledInputSize >= 0);
    memset(codePointBuf, 0, sizeof(codePointBuf[0]) * MAX_WORD_LENGTH);
    static const float DEMOTION_LOG_PROBABILITY = 0.3f;
    // Never index past the end of charProbabilities, even if sampledInputSize is larger than
    // the number of per-point maps actually available.
    const int availableSize = static_cast<int>(charProbabilities->size());
    const int pointCount = (sampledInputSize < availableSize) ? sampledInputSize : availableSize;
    int index = 0;
    float sumLogProbability = 0.0f;
    // TODO: Current implementation is greedy algorithm. DP would be efficient for many cases.
    for (int i = 0; i < pointCount && index < MAX_WORD_LENGTH - 1; ++i) {
        // An entry must score below this starting value to be selected at all.
        float minLogProbability = static_cast<float>(MAX_POINT_TO_KEY_LENGTH);
        int character = NOT_AN_INDEX;
        for (hash_map_compat<int, float>::const_iterator it = (*charProbabilities)[i].begin();
                it != (*charProbabilities)[i].end(); ++it) {
            const float logProbability = (it->first != NOT_AN_INDEX)
                    ? it->second + DEMOTION_LOG_PROBABILITY : it->second;
            if (logProbability < minLogProbability) {
                minLogProbability = logProbability;
                character = it->first;
            }
        }
        if (character != NOT_AN_INDEX) {
            codePointBuf[index] = proximityInfo->getCodePointOf(character);
            index++;
        }
        sumLogProbability += minLogProbability;
    }
    codePointBuf[index] = '\0';
    return sumLogProbability;
}

/* static */ void ProximityInfoStateUtils::dump(const bool isGeometric, const int inputSize,
        const int *const inputXCoordinates, const int *const inputYCoordinates,
        const int sampledInputSize, const std::vector<int> *const sampledInputXs,
+10 −1
Original line number Diff line number Diff line
@@ -32,6 +32,9 @@ class ProximityInfoStateUtils {
    typedef hash_map_compat<int, float> NearKeysDistanceMap;
    typedef std::bitset<MAX_KEY_COUNT_IN_A_KEYBOARD> NearKeycodesSet;

    static int trimLastTwoTouchPoints(std::vector<int> *sampledInputXs,
            std::vector<int> *sampledInputYs, std::vector<int> *sampledInputTimes,
            std::vector<int> *sampledLengthCache, std::vector<int> *sampledInputIndice);
    static int updateTouchPoints(const int mostCommonKeyWidth,
            const ProximityInfo *const proximityInfo, const int maxPointToKeyLength,
            const int *const inputProximities,
@@ -96,7 +99,7 @@ class ProximityInfoStateUtils {
    static void initNormalizedSquaredDistances(
            const ProximityInfo *const proximityInfo, const int inputSize,
            const int *inputXCoordinates, const int *inputYCoordinates,
            const int *const inputProximities, const bool hasInputCoordinates,
            const int *const inputProximities,
            const std::vector<int> *const sampledInputXs,
            const std::vector<int> *const sampledInputYs,
            int *normalizedSquaredDistances);
@@ -113,6 +116,12 @@ class ProximityInfoStateUtils {
            const std::vector<int> *const sampledInputYs,
            const std::vector<int> *const sampledTimes,
            const std::vector<int> *const sampledInputIndices);
    // TODO: Move to most_probable_string_utils.h
    static float getMostProbableString(
            const ProximityInfo *const proximityInfo, const int sampledInputSize,
            const std::vector<hash_map_compat<int, float> > *const charProbabilities,
            int *const codePointBuf);

 private:
    DISALLOW_IMPLICIT_CONSTRUCTORS(ProximityInfoStateUtils);