Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 64a907bb authored by Satoshi Kataoka's avatar Satoshi Kataoka Committed by Android (Google) Code Review
Browse files

Merge "Refactor proximity info state"

parents 9d514af4 e5aad564
Loading
Loading
Loading
Loading
+12 −11
Original line number Diff line number Diff line
@@ -71,9 +71,9 @@ void ProximityInfoState::initInputParams(const int pointerId, const float maxPoi
        mSampledTimes.clear();
        mSampledInputIndice.clear();
        mSampledLengthCache.clear();
        mDistanceCache_G.clear();
        mNearKeysVector.clear();
        mSearchKeysVector.clear();
        mSampledDistanceCache_G.clear();
        mSampledNearKeysVector.clear();
        mSampledSearchKeysVector.clear();
        mSpeedRates.clear();
        mBeelineSpeedPercentiles.clear();
        mCharProbabilities.clear();
@@ -108,16 +108,17 @@ void ProximityInfoState::initInputParams(const int pointerId, const float maxPoi
        ProximityInfoStateUtils::initGeometricDistanceInfos(
                mProximityInfo, mProximityInfo->getKeyCount(),
                mSampledInputSize, lastSavedInputSize, &mSampledInputXs, &mSampledInputYs,
                &mNearKeysVector, &mSearchKeysVector, &mDistanceCache_G);
                &mSampledNearKeysVector, &mSampledDistanceCache_G);
        if (isGeometric) {
            // updates probabilities of skipping or mapping each key for all points.
            ProximityInfoStateUtils::updateAlignPointProbabilities(
                    mMaxPointToKeyLength, mProximityInfo->getMostCommonKeyWidth(),
                    mProximityInfo->getKeyCount(), lastSavedInputSize, mSampledInputSize,
                    &mSampledInputXs, &mSampledInputYs, &mSpeedRates, &mSampledLengthCache,
                    &mDistanceCache_G, &mNearKeysVector, &mCharProbabilities);
            ProximityInfoStateUtils::updateSearchKeysVector(mProximityInfo, mSampledInputSize,
                    lastSavedInputSize, &mSampledLengthCache, &mNearKeysVector, &mSearchKeysVector);
                    &mSampledDistanceCache_G, &mSampledNearKeysVector, &mCharProbabilities);
            ProximityInfoStateUtils::updateSampledSearchKeysVector(mProximityInfo,
                    mSampledInputSize, lastSavedInputSize, &mSampledLengthCache,
                    &mSampledNearKeysVector, &mSampledSearchKeysVector);
        }
    }

@@ -189,7 +190,7 @@ float ProximityInfoState::getPointToKeyLength(
    const int keyId = mProximityInfo->getKeyIndexOf(codePoint);
    if (keyId != NOT_AN_INDEX) {
        const int index = inputIndex * mProximityInfo->getKeyCount() + keyId;
        return min(mDistanceCache_G[index] * scale, mMaxPointToKeyLength);
        return min(mSampledDistanceCache_G[index] * scale, mMaxPointToKeyLength);
    }
    if (isSkippableCodePoint(codePoint)) {
        return 0.0f;
@@ -206,7 +207,7 @@ float ProximityInfoState::getPointToKeyLength_G(const int inputIndex, const int
float ProximityInfoState::getPointToKeyByIdLength(
        const int inputIndex, const int keyId, const float scale) const {
    return ProximityInfoStateUtils::getPointToKeyByIdLength(mMaxPointToKeyLength,
            &mDistanceCache_G, mProximityInfo->getKeyCount(), inputIndex, keyId, scale);
            &mSampledDistanceCache_G, mProximityInfo->getKeyCount(), inputIndex, keyId, scale);
}

float ProximityInfoState::getPointToKeyByIdLength(const int inputIndex, const int keyId) const {
@@ -289,7 +290,7 @@ int ProximityInfoState::getAllPossibleChars(
    int newFilterSize = filterSize;
    const int keyCount = mProximityInfo->getKeyCount();
    for (int j = 0; j < keyCount; ++j) {
        if (mSearchKeysVector[index].test(j)) {
        if (mSampledSearchKeysVector[index].test(j)) {
            const int keyCodePoint = mProximityInfo->getCodePointOf(j);
            bool insert = true;
            // TODO: Avoid linear search
@@ -310,7 +311,7 @@ int ProximityInfoState::getAllPossibleChars(
bool ProximityInfoState::isKeyInSerchKeysAfterIndex(const int index, const int keyId) const {
    ASSERT(keyId >= 0);
    ASSERT(index >= 0 && index < mSampledInputSize);
    return mSearchKeysVector[index].test(keyId);
    return mSampledSearchKeysVector[index].test(keyId);
}

void ProximityInfoState::popInputData() {
+7 −7
Original line number Diff line number Diff line
@@ -52,9 +52,9 @@ class ProximityInfoState {
              mKeyCount(0), mCellHeight(0), mCellWidth(0), mGridHeight(0), mGridWidth(0),
              mIsContinuationPossible(false), mSampledInputXs(), mSampledInputYs(), mSampledTimes(),
              mSampledInputIndice(), mSampledLengthCache(), mBeelineSpeedPercentiles(),
              mDistanceCache_G(), mSpeedRates(), mDirections(), mCharProbabilities(),
              mNearKeysVector(), mSearchKeysVector(), mTouchPositionCorrectionEnabled(false),
              mSampledInputSize(0) {
              mSampledDistanceCache_G(), mSpeedRates(), mDirections(), mCharProbabilities(),
              mSampledNearKeysVector(), mSampledSearchKeysVector(),
              mTouchPositionCorrectionEnabled(false), mSampledInputSize(0) {
        memset(mInputProximities, 0, sizeof(mInputProximities));
        memset(mNormalizedSquaredDistances, 0, sizeof(mNormalizedSquaredDistances));
        memset(mPrimaryInputWord, 0, sizeof(mPrimaryInputWord));
@@ -240,20 +240,20 @@ class ProximityInfoState {
    std::vector<int> mSampledInputIndice;
    std::vector<int> mSampledLengthCache;
    std::vector<int> mBeelineSpeedPercentiles;
    std::vector<float> mDistanceCache_G;
    std::vector<float> mSampledDistanceCache_G;
    std::vector<float> mSpeedRates;
    std::vector<float> mDirections;
    // probabilities of skipping or mapping to a key for each point.
    std::vector<hash_map_compat<int, float> > mCharProbabilities;
    // The vector for the key code set which holds nearby keys for each sampled input point
    // 1. Used to calculate the probability of the key
    // 2. Used to calculate mSearchKeysVector
    std::vector<ProximityInfoStateUtils::NearKeycodesSet> mNearKeysVector;
    // 2. Used to calculate mSampledSearchKeysVector
    std::vector<ProximityInfoStateUtils::NearKeycodesSet> mSampledNearKeysVector;
    // The vector for the key code set which holds nearby keys of some trailing sampled input points
    // for each sampled input point. These nearby keys contain the next characters which can be in
    // the dictionary. Specifically, currently we are looking for keys nearby trailing sampled
    // inputs including the current input point.
    std::vector<ProximityInfoStateUtils::NearKeycodesSet> mSearchKeysVector;
    std::vector<ProximityInfoStateUtils::NearKeycodesSet> mSampledSearchKeysVector;
    bool mTouchPositionCorrectionEnabled;
    int mInputProximities[MAX_PROXIMITY_CHARS_SIZE * MAX_WORD_LENGTH];
    int mNormalizedSquaredDistances[MAX_PROXIMITY_CHARS_SIZE * MAX_WORD_LENGTH];
+32 −34
Original line number Diff line number Diff line
@@ -215,15 +215,12 @@ namespace latinime {
        const int sampledInputSize, const int lastSavedInputSize,
        const std::vector<int> *const sampledInputXs,
        const std::vector<int> *const sampledInputYs,
        std::vector<NearKeycodesSet> *nearKeysVector,
        std::vector<NearKeycodesSet> *searchKeysVector,
        std::vector<float> *distanceCache_G) {
    nearKeysVector->resize(sampledInputSize);
    searchKeysVector->resize(sampledInputSize);
    distanceCache_G->resize(sampledInputSize * keyCount);
        std::vector<NearKeycodesSet> *SampledNearKeysVector,
        std::vector<float> *SampledDistanceCache_G) {
    SampledNearKeysVector->resize(sampledInputSize);
    SampledDistanceCache_G->resize(sampledInputSize * keyCount);
    for (int i = lastSavedInputSize; i < sampledInputSize; ++i) {
        (*nearKeysVector)[i].reset();
        (*searchKeysVector)[i].reset();
        (*SampledNearKeysVector)[i].reset();
        static const float NEAR_KEY_NORMALIZED_SQUARED_THRESHOLD = 4.0f;
        for (int k = 0; k < keyCount; ++k) {
            const int index = i * keyCount + k;
@@ -231,9 +228,9 @@ namespace latinime {
            const int y = (*sampledInputYs)[i];
            const float normalizedSquaredDistance =
                    proximityInfo->getNormalizedSquaredDistanceFromCenterFloatG(k, x, y);
            (*distanceCache_G)[index] = normalizedSquaredDistance;
            (*SampledDistanceCache_G)[index] = normalizedSquaredDistance;
            if (normalizedSquaredDistance < NEAR_KEY_NORMALIZED_SQUARED_THRESHOLD) {
                (*nearKeysVector)[i][k] = true;
                (*SampledNearKeysVector)[i][k] = true;
            }
        }
    }
@@ -638,21 +635,21 @@ namespace latinime {
// This function basically converts from a length to an edit distance. Accordingly, it's obviously
// wrong to compare with mMaxPointToKeyLength.
/* static */ float ProximityInfoStateUtils::getPointToKeyByIdLength(const float maxPointToKeyLength,
        const std::vector<float> *const distanceCache_G, const int keyCount,
        const std::vector<float> *const SampledDistanceCache_G, const int keyCount,
        const int inputIndex, const int keyId, const float scale) {
    if (keyId != NOT_AN_INDEX) {
        const int index = inputIndex * keyCount + keyId;
        return min((*distanceCache_G)[index] * scale, maxPointToKeyLength);
        return min((*SampledDistanceCache_G)[index] * scale, maxPointToKeyLength);
    }
    // If the char is not a key on the keyboard then return the max length.
    return static_cast<float>(MAX_POINT_TO_KEY_LENGTH);
}

/* static */ float ProximityInfoStateUtils::getPointToKeyByIdLength(const float maxPointToKeyLength,
        const std::vector<float> *const distanceCache_G, const int keyCount,
        const std::vector<float> *const SampledDistanceCache_G, const int keyCount,
        const int inputIndex, const int keyId) {
    return getPointToKeyByIdLength(maxPointToKeyLength, distanceCache_G, keyCount, inputIndex,
            keyId, 1.0f);
    return getPointToKeyByIdLength(
            maxPointToKeyLength, SampledDistanceCache_G, keyCount, inputIndex, keyId, 1.0f);
}

// Updates probabilities of aligning to some keys and skipping.
@@ -663,8 +660,8 @@ namespace latinime {
        const std::vector<int> *const sampledInputYs,
        const std::vector<float> *const sampledSpeedRates,
        const std::vector<int> *const sampledLengthCache,
        const std::vector<float> *const distanceCache_G,
        std::vector<NearKeycodesSet> *nearKeysVector,
        const std::vector<float> *const SampledDistanceCache_G,
        std::vector<NearKeycodesSet> *SampledNearKeysVector,
        std::vector<hash_map_compat<int, float> > *charProbabilities) {
    static const float MIN_PROBABILITY = 0.000001f;
    static const float MAX_SKIP_PROBABILITY = 0.95f;
@@ -701,9 +698,9 @@ namespace latinime {

        float nearestKeyDistance = static_cast<float>(MAX_POINT_TO_KEY_LENGTH);
        for (int j = 0; j < keyCount; ++j) {
            if ((*nearKeysVector)[i].test(j)) {
            if ((*SampledNearKeysVector)[i].test(j)) {
                const float distance = getPointToKeyByIdLength(
                        maxPointToKeyLength, distanceCache_G, keyCount, i, j);
                        maxPointToKeyLength, SampledDistanceCache_G, keyCount, i, j);
                if (distance < nearestKeyDistance) {
                    nearestKeyDistance = distance;
                }
@@ -786,14 +783,14 @@ namespace latinime {
        // Summing up probability densities of all near keys.
        float sumOfProbabilityDensities = 0.0f;
        for (int j = 0; j < keyCount; ++j) {
            if ((*nearKeysVector)[i].test(j)) {
            if ((*SampledNearKeysVector)[i].test(j)) {
                float distance = sqrtf(getPointToKeyByIdLength(
                        maxPointToKeyLength, distanceCache_G, keyCount, i, j));
                        maxPointToKeyLength, SampledDistanceCache_G, keyCount, i, j));
                if (i == 0 && i != sampledInputSize - 1) {
                    // For the first point, weighted average of distances from first point and the
                    // next point to the key is used as a point to key distance.
                    const float nextDistance = sqrtf(getPointToKeyByIdLength(
                            maxPointToKeyLength, distanceCache_G, keyCount, i + 1, j));
                            maxPointToKeyLength, SampledDistanceCache_G, keyCount, i + 1, j));
                    if (nextDistance < distance) {
                        // The distance of the first point tends to bigger than continuing
                        // points because the first touch by the user can be sloppy.
@@ -806,7 +803,7 @@ namespace latinime {
                    // For the first point, weighted average of distances from last point and
                    // the previous point to the key is used as a point to key distance.
                    const float previousDistance = sqrtf(getPointToKeyByIdLength(
                            maxPointToKeyLength, distanceCache_G, keyCount, i - 1, j));
                            maxPointToKeyLength, SampledDistanceCache_G, keyCount, i - 1, j));
                    if (previousDistance < distance) {
                        // The distance of the last point tends to bigger than continuing points
                        // because the last touch by the user can be sloppy. So we promote the
@@ -824,14 +821,14 @@ namespace latinime {

        // Split the probability of an input point to keys that are close to the input point.
        for (int j = 0; j < keyCount; ++j) {
            if ((*nearKeysVector)[i].test(j)) {
            if ((*SampledNearKeysVector)[i].test(j)) {
                float distance = sqrtf(getPointToKeyByIdLength(
                        maxPointToKeyLength, distanceCache_G, keyCount, i, j));
                        maxPointToKeyLength, SampledDistanceCache_G, keyCount, i, j));
                if (i == 0 && i != sampledInputSize - 1) {
                    // For the first point, weighted average of distances from the first point and
                    // the next point to the key is used as a point to key distance.
                    const float prevDistance = sqrtf(getPointToKeyByIdLength(
                            maxPointToKeyLength, distanceCache_G, keyCount, i + 1, j));
                            maxPointToKeyLength, SampledDistanceCache_G, keyCount, i + 1, j));
                    if (prevDistance < distance) {
                        distance = (distance + prevDistance * NEXT_DISTANCE_WEIGHT)
                                / (1.0f + NEXT_DISTANCE_WEIGHT);
@@ -840,7 +837,7 @@ namespace latinime {
                    // For the first point, weighted average of distances from last point and
                    // the previous point to the key is used as a point to key distance.
                    const float prevDistance = sqrtf(getPointToKeyByIdLength(
                            maxPointToKeyLength, distanceCache_G, keyCount, i - 1, j));
                            maxPointToKeyLength, SampledDistanceCache_G, keyCount, i - 1, j));
                    if (prevDistance < distance) {
                        distance = (distance + prevDistance * PREV_DISTANCE_WEIGHT)
                                / (1.0f + PREV_DISTANCE_WEIGHT);
@@ -906,10 +903,10 @@ namespace latinime {
        for (int j = 0; j < keyCount; ++j) {
            hash_map_compat<int, float>::iterator it = (*charProbabilities)[i].find(j);
            if (it == (*charProbabilities)[i].end()){
                (*nearKeysVector)[i].reset(j);
                (*SampledNearKeysVector)[i].reset(j);
            } else if(it->second < MIN_PROBABILITY) {
                // Erases from near keys vector because it has very low probability.
                (*nearKeysVector)[i].reset(j);
                (*SampledNearKeysVector)[i].reset(j);
                (*charProbabilities)[i].erase(j);
            } else {
                it->second = -logf(it->second);
@@ -919,25 +916,26 @@ namespace latinime {
    }
}

/* static */ void ProximityInfoStateUtils::updateSearchKeysVector(
/* static */ void ProximityInfoStateUtils::updateSampledSearchKeysVector(
        const ProximityInfo *const proximityInfo, const int sampledInputSize,
        const int lastSavedInputSize,
        const std::vector<int> *const sampledLengthCache,
        const std::vector<NearKeycodesSet> *const nearKeysVector,
        std::vector<NearKeycodesSet> *searchKeysVector) {
        const std::vector<NearKeycodesSet> *const SampledNearKeysVector,
        std::vector<NearKeycodesSet> *sampledSearchKeysVector) {
    sampledSearchKeysVector->resize(sampledInputSize);
    const int readForwordLength = static_cast<int>(
            hypotf(proximityInfo->getKeyboardWidth(), proximityInfo->getKeyboardHeight())
                    * ProximityInfoParams::SEARCH_KEY_RADIUS_RATIO);
    for (int i = 0; i < sampledInputSize; ++i) {
        if (i >= lastSavedInputSize) {
            (*searchKeysVector)[i].reset();
            (*sampledSearchKeysVector)[i].reset();
        }
        for (int j = max(i, lastSavedInputSize); j < sampledInputSize; ++j) {
            // TODO: Investigate if this is required. This may not fail.
            if ((*sampledLengthCache)[j] - (*sampledLengthCache)[i] >= readForwordLength) {
                break;
            }
            (*searchKeysVector)[i] |= (*nearKeysVector)[j];
            (*sampledSearchKeysVector)[i] |= (*SampledNearKeysVector)[j];
        }
    }
}
+9 −10
Original line number Diff line number Diff line
@@ -69,29 +69,28 @@ class ProximityInfoStateUtils {
            const std::vector<int> *const sampledInputYs,
            const std::vector<float> *const sampledSpeedRates,
            const std::vector<int> *const sampledLengthCache,
            const std::vector<float> *const distanceCache_G,
            std::vector<NearKeycodesSet> *nearKeysVector,
            const std::vector<float> *const SampledDistanceCache_G,
            std::vector<NearKeycodesSet> *SampledNearKeysVector,
            std::vector<hash_map_compat<int, float> > *charProbabilities);
    static void updateSearchKeysVector(
    static void updateSampledSearchKeysVector(
            const ProximityInfo *const proximityInfo, const int sampledInputSize,
            const int lastSavedInputSize,
            const std::vector<int> *const sampledLengthCache,
            const std::vector<NearKeycodesSet> *const nearKeysVector,
            std::vector<NearKeycodesSet> *searchKeysVector);
            const std::vector<NearKeycodesSet> *const SampledNearKeysVector,
            std::vector<NearKeycodesSet> *sampledSearchKeysVector);
    static float getPointToKeyByIdLength(const float maxPointToKeyLength,
            const std::vector<float> *const distanceCache_G, const int keyCount,
            const std::vector<float> *const SampledDistanceCache_G, const int keyCount,
            const int inputIndex, const int keyId, const float scale);
    static float getPointToKeyByIdLength(const float maxPointToKeyLength,
            const std::vector<float> *const distanceCache_G, const int keyCount,
            const std::vector<float> *const SampledDistanceCache_G, const int keyCount,
            const int inputIndex, const int keyId);
    static void initGeometricDistanceInfos(
            const ProximityInfo *const proximityInfo, const int keyCount,
            const int sampledInputSize, const int lastSavedInputSize,
            const std::vector<int> *const sampledInputXs,
            const std::vector<int> *const sampledInputYs,
            std::vector<NearKeycodesSet> *nearKeysVector,
            std::vector<NearKeycodesSet> *searchKeysVector,
            std::vector<float> *distanceCache_G);
            std::vector<NearKeycodesSet> *SampledNearKeysVector,
            std::vector<float> *SampledDistanceCache_G);
    static void initPrimaryInputWord(
            const int inputSize, const int *const inputProximities, int *primaryInputWord);
    static void initNormalizedSquaredDistances(