Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit e5aad564 authored by Satoshi Kataoka's avatar Satoshi Kataoka
Browse files

Refactor proximity info state

Change-Id: I00e0618d95d20e5bf5c9e6481e4d3037723785f7
parent f1074c50
Loading
Loading
Loading
Loading
+12 −11
Original line number Diff line number Diff line
@@ -71,9 +71,9 @@ void ProximityInfoState::initInputParams(const int pointerId, const float maxPoi
        mSampledTimes.clear();
        mSampledInputIndice.clear();
        mSampledLengthCache.clear();
        mDistanceCache_G.clear();
        mNearKeysVector.clear();
        mSearchKeysVector.clear();
        mSampledDistanceCache_G.clear();
        mSampledNearKeysVector.clear();
        mSampledSearchKeysVector.clear();
        mSpeedRates.clear();
        mBeelineSpeedPercentiles.clear();
        mCharProbabilities.clear();
@@ -108,16 +108,17 @@ void ProximityInfoState::initInputParams(const int pointerId, const float maxPoi
        ProximityInfoStateUtils::initGeometricDistanceInfos(
                mProximityInfo, mProximityInfo->getKeyCount(),
                mSampledInputSize, lastSavedInputSize, &mSampledInputXs, &mSampledInputYs,
                &mNearKeysVector, &mSearchKeysVector, &mDistanceCache_G);
                &mSampledNearKeysVector, &mSampledDistanceCache_G);
        if (isGeometric) {
            // updates probabilities of skipping or mapping each key for all points.
            ProximityInfoStateUtils::updateAlignPointProbabilities(
                    mMaxPointToKeyLength, mProximityInfo->getMostCommonKeyWidth(),
                    mProximityInfo->getKeyCount(), lastSavedInputSize, mSampledInputSize,
                    &mSampledInputXs, &mSampledInputYs, &mSpeedRates, &mSampledLengthCache,
                    &mDistanceCache_G, &mNearKeysVector, &mCharProbabilities);
            ProximityInfoStateUtils::updateSearchKeysVector(mProximityInfo, mSampledInputSize,
                    lastSavedInputSize, &mSampledLengthCache, &mNearKeysVector, &mSearchKeysVector);
                    &mSampledDistanceCache_G, &mSampledNearKeysVector, &mCharProbabilities);
            ProximityInfoStateUtils::updateSampledSearchKeysVector(mProximityInfo,
                    mSampledInputSize, lastSavedInputSize, &mSampledLengthCache,
                    &mSampledNearKeysVector, &mSampledSearchKeysVector);
        }
    }

@@ -189,7 +190,7 @@ float ProximityInfoState::getPointToKeyLength(
    const int keyId = mProximityInfo->getKeyIndexOf(codePoint);
    if (keyId != NOT_AN_INDEX) {
        const int index = inputIndex * mProximityInfo->getKeyCount() + keyId;
        return min(mDistanceCache_G[index] * scale, mMaxPointToKeyLength);
        return min(mSampledDistanceCache_G[index] * scale, mMaxPointToKeyLength);
    }
    if (isSkippableCodePoint(codePoint)) {
        return 0.0f;
@@ -206,7 +207,7 @@ float ProximityInfoState::getPointToKeyLength_G(const int inputIndex, const int
float ProximityInfoState::getPointToKeyByIdLength(
        const int inputIndex, const int keyId, const float scale) const {
    return ProximityInfoStateUtils::getPointToKeyByIdLength(mMaxPointToKeyLength,
            &mDistanceCache_G, mProximityInfo->getKeyCount(), inputIndex, keyId, scale);
            &mSampledDistanceCache_G, mProximityInfo->getKeyCount(), inputIndex, keyId, scale);
}

float ProximityInfoState::getPointToKeyByIdLength(const int inputIndex, const int keyId) const {
@@ -289,7 +290,7 @@ int ProximityInfoState::getAllPossibleChars(
    int newFilterSize = filterSize;
    const int keyCount = mProximityInfo->getKeyCount();
    for (int j = 0; j < keyCount; ++j) {
        if (mSearchKeysVector[index].test(j)) {
        if (mSampledSearchKeysVector[index].test(j)) {
            const int keyCodePoint = mProximityInfo->getCodePointOf(j);
            bool insert = true;
            // TODO: Avoid linear search
@@ -310,7 +311,7 @@ int ProximityInfoState::getAllPossibleChars(
bool ProximityInfoState::isKeyInSerchKeysAfterIndex(const int index, const int keyId) const {
    ASSERT(keyId >= 0);
    ASSERT(index >= 0 && index < mSampledInputSize);
    return mSearchKeysVector[index].test(keyId);
    return mSampledSearchKeysVector[index].test(keyId);
}

void ProximityInfoState::popInputData() {
+7 −7
Original line number Diff line number Diff line
@@ -52,9 +52,9 @@ class ProximityInfoState {
              mKeyCount(0), mCellHeight(0), mCellWidth(0), mGridHeight(0), mGridWidth(0),
              mIsContinuationPossible(false), mSampledInputXs(), mSampledInputYs(), mSampledTimes(),
              mSampledInputIndice(), mSampledLengthCache(), mBeelineSpeedPercentiles(),
              mDistanceCache_G(), mSpeedRates(), mDirections(), mCharProbabilities(),
              mNearKeysVector(), mSearchKeysVector(), mTouchPositionCorrectionEnabled(false),
              mSampledInputSize(0) {
              mSampledDistanceCache_G(), mSpeedRates(), mDirections(), mCharProbabilities(),
              mSampledNearKeysVector(), mSampledSearchKeysVector(),
              mTouchPositionCorrectionEnabled(false), mSampledInputSize(0) {
        memset(mInputProximities, 0, sizeof(mInputProximities));
        memset(mNormalizedSquaredDistances, 0, sizeof(mNormalizedSquaredDistances));
        memset(mPrimaryInputWord, 0, sizeof(mPrimaryInputWord));
@@ -240,20 +240,20 @@ class ProximityInfoState {
    std::vector<int> mSampledInputIndice;
    std::vector<int> mSampledLengthCache;
    std::vector<int> mBeelineSpeedPercentiles;
    std::vector<float> mDistanceCache_G;
    std::vector<float> mSampledDistanceCache_G;
    std::vector<float> mSpeedRates;
    std::vector<float> mDirections;
    // probabilities of skipping or mapping to a key for each point.
    std::vector<hash_map_compat<int, float> > mCharProbabilities;
    // The vector for the key code set which holds nearby keys for each sampled input point
    // 1. Used to calculate the probability of the key
    // 2. Used to calculate mSearchKeysVector
    std::vector<ProximityInfoStateUtils::NearKeycodesSet> mNearKeysVector;
    // 2. Used to calculate mSampledSearchKeysVector
    std::vector<ProximityInfoStateUtils::NearKeycodesSet> mSampledNearKeysVector;
    // The vector for the key code set which holds nearby keys of some trailing sampled input points
    // for each sampled input point. These nearby keys contain the next characters which can be in
    // the dictionary. Specifically, currently we are looking for keys nearby trailing sampled
    // inputs including the current input point.
    std::vector<ProximityInfoStateUtils::NearKeycodesSet> mSearchKeysVector;
    std::vector<ProximityInfoStateUtils::NearKeycodesSet> mSampledSearchKeysVector;
    bool mTouchPositionCorrectionEnabled;
    int mInputProximities[MAX_PROXIMITY_CHARS_SIZE * MAX_WORD_LENGTH];
    int mNormalizedSquaredDistances[MAX_PROXIMITY_CHARS_SIZE * MAX_WORD_LENGTH];
+32 −34
Original line number Diff line number Diff line
@@ -215,15 +215,12 @@ namespace latinime {
        const int sampledInputSize, const int lastSavedInputSize,
        const std::vector<int> *const sampledInputXs,
        const std::vector<int> *const sampledInputYs,
        std::vector<NearKeycodesSet> *nearKeysVector,
        std::vector<NearKeycodesSet> *searchKeysVector,
        std::vector<float> *distanceCache_G) {
    nearKeysVector->resize(sampledInputSize);
    searchKeysVector->resize(sampledInputSize);
    distanceCache_G->resize(sampledInputSize * keyCount);
        std::vector<NearKeycodesSet> *SampledNearKeysVector,
        std::vector<float> *SampledDistanceCache_G) {
    SampledNearKeysVector->resize(sampledInputSize);
    SampledDistanceCache_G->resize(sampledInputSize * keyCount);
    for (int i = lastSavedInputSize; i < sampledInputSize; ++i) {
        (*nearKeysVector)[i].reset();
        (*searchKeysVector)[i].reset();
        (*SampledNearKeysVector)[i].reset();
        static const float NEAR_KEY_NORMALIZED_SQUARED_THRESHOLD = 4.0f;
        for (int k = 0; k < keyCount; ++k) {
            const int index = i * keyCount + k;
@@ -231,9 +228,9 @@ namespace latinime {
            const int y = (*sampledInputYs)[i];
            const float normalizedSquaredDistance =
                    proximityInfo->getNormalizedSquaredDistanceFromCenterFloatG(k, x, y);
            (*distanceCache_G)[index] = normalizedSquaredDistance;
            (*SampledDistanceCache_G)[index] = normalizedSquaredDistance;
            if (normalizedSquaredDistance < NEAR_KEY_NORMALIZED_SQUARED_THRESHOLD) {
                (*nearKeysVector)[i][k] = true;
                (*SampledNearKeysVector)[i][k] = true;
            }
        }
    }
@@ -638,21 +635,21 @@ namespace latinime {
// This function basically converts from a length to an edit distance. Accordingly, it's obviously
// wrong to compare with mMaxPointToKeyLength.
/* static */ float ProximityInfoStateUtils::getPointToKeyByIdLength(const float maxPointToKeyLength,
        const std::vector<float> *const distanceCache_G, const int keyCount,
        const std::vector<float> *const SampledDistanceCache_G, const int keyCount,
        const int inputIndex, const int keyId, const float scale) {
    if (keyId != NOT_AN_INDEX) {
        const int index = inputIndex * keyCount + keyId;
        return min((*distanceCache_G)[index] * scale, maxPointToKeyLength);
        return min((*SampledDistanceCache_G)[index] * scale, maxPointToKeyLength);
    }
    // If the char is not a key on the keyboard then return the max length.
    return static_cast<float>(MAX_POINT_TO_KEY_LENGTH);
}

/* static */ float ProximityInfoStateUtils::getPointToKeyByIdLength(const float maxPointToKeyLength,
        const std::vector<float> *const distanceCache_G, const int keyCount,
        const std::vector<float> *const SampledDistanceCache_G, const int keyCount,
        const int inputIndex, const int keyId) {
    return getPointToKeyByIdLength(maxPointToKeyLength, distanceCache_G, keyCount, inputIndex,
            keyId, 1.0f);
    return getPointToKeyByIdLength(
            maxPointToKeyLength, SampledDistanceCache_G, keyCount, inputIndex, keyId, 1.0f);
}

// Updates probabilities of aligning to some keys and skipping.
@@ -663,8 +660,8 @@ namespace latinime {
        const std::vector<int> *const sampledInputYs,
        const std::vector<float> *const sampledSpeedRates,
        const std::vector<int> *const sampledLengthCache,
        const std::vector<float> *const distanceCache_G,
        std::vector<NearKeycodesSet> *nearKeysVector,
        const std::vector<float> *const SampledDistanceCache_G,
        std::vector<NearKeycodesSet> *SampledNearKeysVector,
        std::vector<hash_map_compat<int, float> > *charProbabilities) {
    static const float MIN_PROBABILITY = 0.000001f;
    static const float MAX_SKIP_PROBABILITY = 0.95f;
@@ -701,9 +698,9 @@ namespace latinime {

        float nearestKeyDistance = static_cast<float>(MAX_POINT_TO_KEY_LENGTH);
        for (int j = 0; j < keyCount; ++j) {
            if ((*nearKeysVector)[i].test(j)) {
            if ((*SampledNearKeysVector)[i].test(j)) {
                const float distance = getPointToKeyByIdLength(
                        maxPointToKeyLength, distanceCache_G, keyCount, i, j);
                        maxPointToKeyLength, SampledDistanceCache_G, keyCount, i, j);
                if (distance < nearestKeyDistance) {
                    nearestKeyDistance = distance;
                }
@@ -786,14 +783,14 @@ namespace latinime {
        // Summing up probability densities of all near keys.
        float sumOfProbabilityDensities = 0.0f;
        for (int j = 0; j < keyCount; ++j) {
            if ((*nearKeysVector)[i].test(j)) {
            if ((*SampledNearKeysVector)[i].test(j)) {
                float distance = sqrtf(getPointToKeyByIdLength(
                        maxPointToKeyLength, distanceCache_G, keyCount, i, j));
                        maxPointToKeyLength, SampledDistanceCache_G, keyCount, i, j));
                if (i == 0 && i != sampledInputSize - 1) {
                    // For the first point, weighted average of distances from first point and the
                    // next point to the key is used as a point to key distance.
                    const float nextDistance = sqrtf(getPointToKeyByIdLength(
                            maxPointToKeyLength, distanceCache_G, keyCount, i + 1, j));
                            maxPointToKeyLength, SampledDistanceCache_G, keyCount, i + 1, j));
                    if (nextDistance < distance) {
                        // The distance of the first point tends to bigger than continuing
                        // points because the first touch by the user can be sloppy.
@@ -806,7 +803,7 @@ namespace latinime {
                    // For the first point, weighted average of distances from last point and
                    // the previous point to the key is used as a point to key distance.
                    const float previousDistance = sqrtf(getPointToKeyByIdLength(
                            maxPointToKeyLength, distanceCache_G, keyCount, i - 1, j));
                            maxPointToKeyLength, SampledDistanceCache_G, keyCount, i - 1, j));
                    if (previousDistance < distance) {
                        // The distance of the last point tends to bigger than continuing points
                        // because the last touch by the user can be sloppy. So we promote the
@@ -824,14 +821,14 @@ namespace latinime {

        // Split the probability of an input point to keys that are close to the input point.
        for (int j = 0; j < keyCount; ++j) {
            if ((*nearKeysVector)[i].test(j)) {
            if ((*SampledNearKeysVector)[i].test(j)) {
                float distance = sqrtf(getPointToKeyByIdLength(
                        maxPointToKeyLength, distanceCache_G, keyCount, i, j));
                        maxPointToKeyLength, SampledDistanceCache_G, keyCount, i, j));
                if (i == 0 && i != sampledInputSize - 1) {
                    // For the first point, weighted average of distances from the first point and
                    // the next point to the key is used as a point to key distance.
                    const float prevDistance = sqrtf(getPointToKeyByIdLength(
                            maxPointToKeyLength, distanceCache_G, keyCount, i + 1, j));
                            maxPointToKeyLength, SampledDistanceCache_G, keyCount, i + 1, j));
                    if (prevDistance < distance) {
                        distance = (distance + prevDistance * NEXT_DISTANCE_WEIGHT)
                                / (1.0f + NEXT_DISTANCE_WEIGHT);
@@ -840,7 +837,7 @@ namespace latinime {
                    // For the first point, weighted average of distances from last point and
                    // the previous point to the key is used as a point to key distance.
                    const float prevDistance = sqrtf(getPointToKeyByIdLength(
                            maxPointToKeyLength, distanceCache_G, keyCount, i - 1, j));
                            maxPointToKeyLength, SampledDistanceCache_G, keyCount, i - 1, j));
                    if (prevDistance < distance) {
                        distance = (distance + prevDistance * PREV_DISTANCE_WEIGHT)
                                / (1.0f + PREV_DISTANCE_WEIGHT);
@@ -906,10 +903,10 @@ namespace latinime {
        for (int j = 0; j < keyCount; ++j) {
            hash_map_compat<int, float>::iterator it = (*charProbabilities)[i].find(j);
            if (it == (*charProbabilities)[i].end()){
                (*nearKeysVector)[i].reset(j);
                (*SampledNearKeysVector)[i].reset(j);
            } else if(it->second < MIN_PROBABILITY) {
                // Erases from near keys vector because it has very low probability.
                (*nearKeysVector)[i].reset(j);
                (*SampledNearKeysVector)[i].reset(j);
                (*charProbabilities)[i].erase(j);
            } else {
                it->second = -logf(it->second);
@@ -919,25 +916,26 @@ namespace latinime {
    }
}

/* static */ void ProximityInfoStateUtils::updateSearchKeysVector(
/* static */ void ProximityInfoStateUtils::updateSampledSearchKeysVector(
        const ProximityInfo *const proximityInfo, const int sampledInputSize,
        const int lastSavedInputSize,
        const std::vector<int> *const sampledLengthCache,
        const std::vector<NearKeycodesSet> *const nearKeysVector,
        std::vector<NearKeycodesSet> *searchKeysVector) {
        const std::vector<NearKeycodesSet> *const SampledNearKeysVector,
        std::vector<NearKeycodesSet> *sampledSearchKeysVector) {
    sampledSearchKeysVector->resize(sampledInputSize);
    const int readForwordLength = static_cast<int>(
            hypotf(proximityInfo->getKeyboardWidth(), proximityInfo->getKeyboardHeight())
                    * ProximityInfoParams::SEARCH_KEY_RADIUS_RATIO);
    for (int i = 0; i < sampledInputSize; ++i) {
        if (i >= lastSavedInputSize) {
            (*searchKeysVector)[i].reset();
            (*sampledSearchKeysVector)[i].reset();
        }
        for (int j = max(i, lastSavedInputSize); j < sampledInputSize; ++j) {
            // TODO: Investigate if this is required. This may not fail.
            if ((*sampledLengthCache)[j] - (*sampledLengthCache)[i] >= readForwordLength) {
                break;
            }
            (*searchKeysVector)[i] |= (*nearKeysVector)[j];
            (*sampledSearchKeysVector)[i] |= (*SampledNearKeysVector)[j];
        }
    }
}
+9 −10
Original line number Diff line number Diff line
@@ -69,29 +69,28 @@ class ProximityInfoStateUtils {
            const std::vector<int> *const sampledInputYs,
            const std::vector<float> *const sampledSpeedRates,
            const std::vector<int> *const sampledLengthCache,
            const std::vector<float> *const distanceCache_G,
            std::vector<NearKeycodesSet> *nearKeysVector,
            const std::vector<float> *const SampledDistanceCache_G,
            std::vector<NearKeycodesSet> *SampledNearKeysVector,
            std::vector<hash_map_compat<int, float> > *charProbabilities);
    static void updateSearchKeysVector(
    static void updateSampledSearchKeysVector(
            const ProximityInfo *const proximityInfo, const int sampledInputSize,
            const int lastSavedInputSize,
            const std::vector<int> *const sampledLengthCache,
            const std::vector<NearKeycodesSet> *const nearKeysVector,
            std::vector<NearKeycodesSet> *searchKeysVector);
            const std::vector<NearKeycodesSet> *const SampledNearKeysVector,
            std::vector<NearKeycodesSet> *sampledSearchKeysVector);
    static float getPointToKeyByIdLength(const float maxPointToKeyLength,
            const std::vector<float> *const distanceCache_G, const int keyCount,
            const std::vector<float> *const SampledDistanceCache_G, const int keyCount,
            const int inputIndex, const int keyId, const float scale);
    static float getPointToKeyByIdLength(const float maxPointToKeyLength,
            const std::vector<float> *const distanceCache_G, const int keyCount,
            const std::vector<float> *const SampledDistanceCache_G, const int keyCount,
            const int inputIndex, const int keyId);
    static void initGeometricDistanceInfos(
            const ProximityInfo *const proximityInfo, const int keyCount,
            const int sampledInputSize, const int lastSavedInputSize,
            const std::vector<int> *const sampledInputXs,
            const std::vector<int> *const sampledInputYs,
            std::vector<NearKeycodesSet> *nearKeysVector,
            std::vector<NearKeycodesSet> *searchKeysVector,
            std::vector<float> *distanceCache_G);
            std::vector<NearKeycodesSet> *SampledNearKeysVector,
            std::vector<float> *SampledDistanceCache_G);
    static void initPrimaryInputWord(
            const int inputSize, const int *const inputProximities, int *primaryInputWord);
    static void initNormalizedSquaredDistances(