Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit d4828d50 authored by Satoshi Kataoka's avatar Satoshi Kataoka
Browse files

Refactor proximity info state

Change-Id: I30cc0d8f2e48d70e214739a073eabf3a8ea73618
parent 6c22439b
Loading
Loading
Loading
Loading
+7 −322
Original line number Diff line number Diff line
@@ -138,7 +138,11 @@ void ProximityInfoState::initInputParams(const int pointerId, const float maxPoi
        }
        if (isGeometric) {
            // updates probabilities of skipping or mapping each key for all points.
            updateAlignPointProbabilities(lastSavedInputSize);
            ProximityInfoStateUtils::updateAlignPointProbabilities(
                    mMaxPointToKeyLength, mProximityInfo->getMostCommonKeyWidth(),
                    keyCount, lastSavedInputSize, mSampledInputSize, &mSampledInputXs,
                    &mSampledInputYs, &mSpeedRates, &mLengthCache, &mDistanceCache_G,
                    &mNearKeysVector, &mCharProbabilities);

            static const float READ_FORWORD_LENGTH_SCALE = 0.95f;
            const int readForwordLength = static_cast<int>(
@@ -307,16 +311,10 @@ float ProximityInfoState::getPointToKeyLength_G(const int inputIndex, const int
}

// TODO: Remove the "scale" parameter
// This function basically converts from a length to an edit distance. Accordingly, it's obviously
// wrong to compare with mMaxPointToKeyLength.
//
// Returns the length between the sampled input point |inputIndex| and the key |keyId|, with
// |scale| applied. The computation is delegated to
// ProximityInfoStateUtils::getPointToKeyByIdLength, which receives the max point-to-key length,
// the distance cache and the key count of the current keyboard.
// NOTE(review): the helper is presumably equivalent to the former inline body (cached distance
// times |scale| capped at mMaxPointToKeyLength, or MAX_POINT_TO_KEY_LENGTH when |keyId| is
// NOT_AN_INDEX) -- confirm against the helper's definition.
float ProximityInfoState::getPointToKeyByIdLength(
        const int inputIndex, const int keyId, const float scale) const {
    // Single return path: the previous inline computation made this delegation unreachable.
    return ProximityInfoStateUtils::getPointToKeyByIdLength(mMaxPointToKeyLength,
            &mDistanceCache_G, mProximityInfo->getKeyCount(), inputIndex, keyId, scale);
}

float ProximityInfoState::getPointToKeyByIdLength(const int inputIndex, const int keyId) const {
@@ -442,32 +440,6 @@ float ProximityInfoState::getDirection(const int index0, const int index1) const
            &mSampledInputXs, &mSampledInputYs, index0, index1);
}

// Returns the angle difference between the direction from the previous sampled point to
// |index| and the direction from |index| to the next sampled point.
// Returns 0 when |index| has no sampled point on both sides, i.e. for the first point, the
// last point and any out-of-range index.
float ProximityInfoState::getPointAngle(const int index) const {
    const int lastIndex = mSampledInputSize - 1;
    const bool hasBothNeighbors = index > 0 && index < lastIndex;
    if (!hasBothNeighbors) {
        return 0.0f;
    }
    const float incomingDirection = getDirection(index - 1, index);
    const float outgoingDirection = getDirection(index, index + 1);
    return getAngleDiff(incomingDirection, outgoingDirection);
}

float ProximityInfoState::getPointsAngle(
        const int index0, const int index1, const int index2) const {
    if (index0 < 0 || index0 > mSampledInputSize - 1) {
        return 0.0f;
    }
    if (index1 < 0 || index1 > mSampledInputSize - 1) {
        return 0.0f;
    }
    if (index2 < 0 || index2 > mSampledInputSize - 1) {
        return 0.0f;
    }
    const float previousDirection = getDirection(index0, index1);
    const float nextDirection = getDirection(index1, index2);
    return getAngleDiff(previousDirection, nextDirection);
}

float ProximityInfoState::getLineToKeyDistance(
        const int from, const int to, const int keyId, const bool extend) const {
    if (from < 0 || from > mSampledInputSize - 1) {
@@ -488,293 +460,6 @@ float ProximityInfoState::getLineToKeyDistance(
            keyX, keyY, x0, y0, x1, y1, extend);
}

// Updates probabilities of aligning to some keys and skipping.
// Word suggestion should be based on these probabilities.
//
// For every sampled input point i in [start, mSampledInputSize) this fills
// mCharProbabilities[i] with a skip probability (stored under the key NOT_AN_INDEX) and one
// probability for each near key (keys whose bit is set in mNearKeysVector[i]). Nearby points
// then suppress each other's key probabilities, and finally every stored value is replaced by
// its negative logarithm. Keys with no entry, or with a probability below MIN_PROBABILITY, are
// cleared from mNearKeysVector[i].
//
// start: index of the first sampled point whose probabilities are (re)computed.
void ProximityInfoState::updateAlignPointProbabilities(const int start) {
    static const float MIN_PROBABILITY = 0.000001f;
    static const float MAX_SKIP_PROBABILITY = 0.95f;
    static const float SKIP_FIRST_POINT_PROBABILITY = 0.01f;
    static const float SKIP_LAST_POINT_PROBABILITY = 0.1f;
    static const float MIN_SPEED_RATE_FOR_SKIP_PROBABILITY = 0.15f;
    static const float SPEED_WEIGHT_FOR_SKIP_PROBABILITY = 0.9f;
    static const float SLOW_STRAIGHT_WEIGHT_FOR_SKIP_PROBABILITY = 0.6f;
    static const float NEAREST_DISTANCE_WEIGHT = 0.5f;
    static const float NEAREST_DISTANCE_BIAS = 0.5f;
    static const float NEAREST_DISTANCE_WEIGHT_FOR_LAST = 0.6f;
    static const float NEAREST_DISTANCE_BIAS_FOR_LAST = 0.4f;

    static const float ANGLE_WEIGHT = 0.90f;
    static const float DEEP_CORNER_ANGLE_THRESHOLD = M_PI_F * 60.0f / 180.0f;
    static const float SKIP_DEEP_CORNER_PROBABILITY = 0.1f;
    static const float CORNER_ANGLE_THRESHOLD = M_PI_F * 30.0f / 180.0f;
    static const float STRAIGHT_ANGLE_THRESHOLD = M_PI_F * 15.0f / 180.0f;
    static const float SKIP_CORNER_PROBABILITY = 0.4f;
    static const float SPEED_MARGIN = 0.1f;
    static const float CENTER_VALUE_OF_NORMALIZED_DISTRIBUTION = 0.0f;

    const int keyCount = mProximityInfo->getKeyCount();
    mCharProbabilities.resize(mSampledInputSize);
    // Calculates probabilities of using a point as a correlated point with the character
    // for each point.
    for (int i = start; i < mSampledInputSize; ++i) {
        mCharProbabilities[i].clear();
        // First, calculates skip probability. Starts from MAX_SKIP_PROBABILITY.
        // Note that all values that are multiplied to this probability should be in [0.0, 1.0];
        float skipProbability = MAX_SKIP_PROBABILITY;

        const float currentAngle = getPointAngle(i);
        const float speedRate = getSpeedRate(i);

        // Finds the smallest point-to-key length among the keys near this point.
        float nearestKeyDistance = static_cast<float>(MAX_POINT_TO_KEY_LENGTH);
        for (int j = 0; j < keyCount; ++j) {
            if (mNearKeysVector[i].test(j)) {
                const float distance = getPointToKeyByIdLength(i, j);
                if (distance < nearestKeyDistance) {
                    nearestKeyDistance = distance;
                }
            }
        }

        if (i == 0) {
            skipProbability *= min(1.0f, nearestKeyDistance * NEAREST_DISTANCE_WEIGHT
                    + NEAREST_DISTANCE_BIAS);
            // Promote the first point
            skipProbability *= SKIP_FIRST_POINT_PROBABILITY;
        } else if (i == mSampledInputSize - 1) {
            skipProbability *= min(1.0f, nearestKeyDistance * NEAREST_DISTANCE_WEIGHT_FOR_LAST
                    + NEAREST_DISTANCE_BIAS_FOR_LAST);
            // Promote the last point
            skipProbability *= SKIP_LAST_POINT_PROBABILITY;
        } else {
            // If the current speed is relatively slower than that of both adjacent points, we
            // promote this point.
            if (getSpeedRate(i - 1) - SPEED_MARGIN > speedRate
                    && speedRate < getSpeedRate(i + 1) - SPEED_MARGIN) {
                if (currentAngle < CORNER_ANGLE_THRESHOLD) {
                    skipProbability *= min(1.0f, speedRate
                            * SLOW_STRAIGHT_WEIGHT_FOR_SKIP_PROBABILITY);
                } else {
                    // If the angle is small enough, we promote this point more. (e.g. pit vs put)
                    // NOTE(review): this comment appears to describe the branch above (angle
                    // below CORNER_ANGLE_THRESHOLD), which applies the smaller multiplier;
                    // this branch handles angles at or above the threshold -- confirm.
                    skipProbability *= min(1.0f, speedRate * SPEED_WEIGHT_FOR_SKIP_PROBABILITY
                            + MIN_SPEED_RATE_FOR_SKIP_PROBABILITY);
                }
            }

            skipProbability *= min(1.0f, speedRate * nearestKeyDistance *
                    NEAREST_DISTANCE_WEIGHT + NEAREST_DISTANCE_BIAS);

            // Adjusts skip probability by a rate depending on angle.
            // ANGLE_WEIGHT of skipProbability is adjusted by current angle.
            skipProbability *= (M_PI_F - currentAngle) / M_PI_F * ANGLE_WEIGHT
                    + (1.0f - ANGLE_WEIGHT);
            if (currentAngle > DEEP_CORNER_ANGLE_THRESHOLD) {
                skipProbability *= SKIP_DEEP_CORNER_PROBABILITY;
            }
            // We assume the angle of this point is the angle for point[i], point[i - 2]
            // and point[i - 3]. The reason why we don't use the angle for point[i], point[i - 1]
            // and point[i - 2] is this angle can be more affected by the noise.
            // getPointsAngle() returns 0 for negative indices, so calling it before the
            // i >= 3 check below is safe.
            const float prevAngle = getPointsAngle(i, i - 2, i - 3);
            if (i >= 3 && prevAngle < STRAIGHT_ANGLE_THRESHOLD
                    && currentAngle > CORNER_ANGLE_THRESHOLD) {
                skipProbability *= SKIP_CORNER_PROBABILITY;
            }
        }

        // probabilities must be in [0.0, MAX_SKIP_PROBABILITY];
        ASSERT(skipProbability >= 0.0f);
        ASSERT(skipProbability <= MAX_SKIP_PROBABILITY);
        mCharProbabilities[i][NOT_AN_INDEX] = skipProbability;

        // Second, calculates key probabilities by dividing the rest probability
        // (1.0f - skipProbability).
        const float inputCharProbability = 1.0f - skipProbability;

        // TODO: The variance is critical for accuracy; thus, adjusting these parameter by machine
        // learning or something would be efficient.
        static const float SPEEDxANGLE_WEIGHT_FOR_STANDARD_DIVIATION = 0.3f;
        static const float MAX_SPEEDxANGLE_RATE_FOR_STANDERD_DIVIATION = 0.25f;
        static const float SPEEDxNEAREST_WEIGHT_FOR_STANDARD_DIVIATION = 0.5f;
        static const float MAX_SPEEDxNEAREST_RATE_FOR_STANDERD_DIVIATION = 0.15f;
        static const float MIN_STANDERD_DIVIATION = 0.37f;

        // The standard deviation grows with speed x angle and with speed x nearest key
        // distance; each contribution is capped by its MAX_* constant.
        const float speedxAngleRate = min(speedRate * currentAngle / M_PI_F
                * SPEEDxANGLE_WEIGHT_FOR_STANDARD_DIVIATION,
                        MAX_SPEEDxANGLE_RATE_FOR_STANDERD_DIVIATION);
        const float speedxNearestKeyDistanceRate = min(speedRate * nearestKeyDistance
                * SPEEDxNEAREST_WEIGHT_FOR_STANDARD_DIVIATION,
                        MAX_SPEEDxNEAREST_RATE_FOR_STANDERD_DIVIATION);
        const float sigma = speedxAngleRate + speedxNearestKeyDistanceRate + MIN_STANDERD_DIVIATION;

        ProximityInfoUtils::NormalDistribution
                distribution(CENTER_VALUE_OF_NORMALIZED_DISTRIBUTION, sigma);
        static const float PREV_DISTANCE_WEIGHT = 0.5f;
        static const float NEXT_DISTANCE_WEIGHT = 0.6f;
        // Summing up probability densities of all near keys.
        float sumOfProbabilityDensities = 0.0f;
        for (int j = 0; j < keyCount; ++j) {
            if (mNearKeysVector[i].test(j)) {
                float distance = sqrtf(getPointToKeyByIdLength(i, j));
                if (i == 0 && i != mSampledInputSize - 1) {
                    // For the first point, weighted average of distances from first point and the
                    // next point to the key is used as a point to key distance.
                    const float nextDistance = sqrtf(getPointToKeyByIdLength(i + 1, j));
                    if (nextDistance < distance) {
                        // The distance of the first point tends to be bigger than that of the
                        // following points because the first touch by the user can be sloppy.
                        // So we promote the first point if the distance of that point is larger
                        // than the distance of the next point.
                        distance = (distance + nextDistance * NEXT_DISTANCE_WEIGHT)
                                / (1.0f + NEXT_DISTANCE_WEIGHT);
                    }
                } else if (i != 0 && i == mSampledInputSize - 1) {
                    // For the last point, weighted average of distances from last point and
                    // the previous point to the key is used as a point to key distance.
                    const float previousDistance = sqrtf(getPointToKeyByIdLength(i - 1, j));
                    if (previousDistance < distance) {
                        // The distance of the last point tends to be bigger than that of the
                        // preceding points because the last touch by the user can be sloppy. So
                        // we promote the last point if the distance of that point is larger than
                        // the distance of the previous point.
                        distance = (distance + previousDistance * PREV_DISTANCE_WEIGHT)
                                / (1.0f + PREV_DISTANCE_WEIGHT);
                    }
                }
                // TODO: Promote the first point when the extended line from the next input is near
                // from a key. Also, promote the last point as well.
                sumOfProbabilityDensities += distribution.getProbabilityDensity(distance);
            }
        }

        // Split the probability of an input point to keys that are close to the input point.
        // The distance adjustment below repeats the one used for the sum above so that the
        // per-key densities and their sum are computed from the same distances.
        for (int j = 0; j < keyCount; ++j) {
            if (mNearKeysVector[i].test(j)) {
                float distance = sqrtf(getPointToKeyByIdLength(i, j));
                if (i == 0 && i != mSampledInputSize - 1) {
                    // For the first point, weighted average of distances from the first point and
                    // the next point to the key is used as a point to key distance.
                    // (prevDistance holds the distance of the next point, i + 1, here.)
                    const float prevDistance = sqrtf(getPointToKeyByIdLength(i + 1, j));
                    if (prevDistance < distance) {
                        distance = (distance + prevDistance * NEXT_DISTANCE_WEIGHT)
                                / (1.0f + NEXT_DISTANCE_WEIGHT);
                    }
                } else if (i != 0 && i == mSampledInputSize - 1) {
                    // For the last point, weighted average of distances from last point and
                    // the previous point to the key is used as a point to key distance.
                    const float prevDistance = sqrtf(getPointToKeyByIdLength(i - 1, j));
                    if (prevDistance < distance) {
                        distance = (distance + prevDistance * PREV_DISTANCE_WEIGHT)
                                / (1.0f + PREV_DISTANCE_WEIGHT);
                    }
                }
                const float probabilityDensity = distribution.getProbabilityDensity(distance);
                const float probability = inputCharProbability * probabilityDensity
                        / sumOfProbabilityDensities;
                mCharProbabilities[i][j] = probability;
            }
        }
    }


    // Debug dump of every sampled point: coordinates, speed rate, angle and the raw
    // probabilities computed above (before suppression and before the log conversion).
    if (DEBUG_POINTS_PROBABILITY) {
        for (int i = 0; i < mSampledInputSize; ++i) {
            std::stringstream sstream;
            sstream << i << ", ";
            sstream << "(" << mSampledInputXs[i] << ", " << mSampledInputYs[i] << "), ";
            sstream << "Speed: "<< getSpeedRate(i) << ", ";
            sstream << "Angle: "<< getPointAngle(i) << ", \n";

            for (hash_map_compat<int, float>::iterator it = mCharProbabilities[i].begin();
                    it != mCharProbabilities[i].end(); ++it) {
                if (it->first == NOT_AN_INDEX) {
                    sstream << it->first
                            << "(skip):"
                            << it->second
                            << "\n";
                } else {
                    sstream << it->first
                            << "("
                            << static_cast<char>(mProximityInfo->getCodePointOf(it->first))
                            << "):"
                            << it->second
                            << "\n";
                }
            }
            AKLOGI("%s", sstream.str().c_str());
        }
    }

    // Decrease key probabilities of points which don't have the highest probability of that key
    // among nearby points. Probabilities of the first point and the last point are not suppressed.
    // NOTE(review): the outer loop starts at max(start, 1), so point 0 is never the suppressed
    // index, but its upper bound includes i == mSampledInputSize - 1 -- confirm whether the
    // last point is really meant to be excluded.
    for (int i = max(start, 1); i < mSampledInputSize; ++i) {
        // Walk forward, then backward, from point i until a point is too far away for
        // suppressCharProbabilities() to apply (it returns false in that case).
        for (int j = i + 1; j < mSampledInputSize; ++j) {
            if (!suppressCharProbabilities(i, j)) {
                break;
            }
        }
        for (int j = i - 1; j >= max(start, 0); --j) {
            if (!suppressCharProbabilities(i, j)) {
                break;
            }
        }
    }

    // Converting from raw probabilities to log probabilities to calculate spatial distance.
    // Each stored probability p becomes -logf(p).
    for (int i = start; i < mSampledInputSize; ++i) {
        for (int j = 0; j < keyCount; ++j) {
            hash_map_compat<int, float>::iterator it = mCharProbabilities[i].find(j);
            if (it == mCharProbabilities[i].end()){
                // No probability was assigned to this key for this point.
                mNearKeysVector[i].reset(j);
            } else if(it->second < MIN_PROBABILITY) {
                // Erases from near keys vector because it has very low probability.
                mNearKeysVector[i].reset(j);
                mCharProbabilities[i].erase(j);
            } else {
                it->second = -logf(it->second);
            }
        }
        // The skip probability is always kept, whatever its value.
        mCharProbabilities[i][NOT_AN_INDEX] = -logf(mCharProbabilities[i][NOT_AN_INDEX]);
    }
}

// Decreases char probabilities of index0 by checking probabilities of a near point (index1) and
// increases char probabilities of index1 by checking probabilities of index0.
bool ProximityInfoState::suppressCharProbabilities(const int index0, const int index1) {
    ASSERT(0 <= index0 && index0 < mSampledInputSize);
    ASSERT(0 <= index1 && index1 < mSampledInputSize);

    static const float SUPPRESSION_LENGTH_WEIGHT = 1.5f;
    static const float MIN_SUPPRESSION_RATE = 0.1f;
    static const float SUPPRESSION_WEIGHT = 0.5f;
    static const float SUPPRESSION_WEIGHT_FOR_PROBABILITY_GAIN = 0.1f;
    static const float SKIP_PROBABALITY_WEIGHT_FOR_PROBABILITY_GAIN = 0.3f;

    const float keyWidthFloat = static_cast<float>(mProximityInfo->getMostCommonKeyWidth());
    const float diff = fabsf(static_cast<float>(mLengthCache[index0] - mLengthCache[index1]));
    if (diff > keyWidthFloat * SUPPRESSION_LENGTH_WEIGHT) {
        return false;
    }
    const float suppressionRate = MIN_SUPPRESSION_RATE
            + diff / keyWidthFloat / SUPPRESSION_LENGTH_WEIGHT * SUPPRESSION_WEIGHT;
    for (hash_map_compat<int, float>::iterator it = mCharProbabilities[index0].begin();
            it != mCharProbabilities[index0].end(); ++it) {
        hash_map_compat<int, float>::iterator it2 =  mCharProbabilities[index1].find(it->first);
        if (it2 != mCharProbabilities[index1].end() && it->second < it2->second) {
            const float newProbability = it->second * suppressionRate;
            const float suppression = it->second - newProbability;
            it->second = newProbability;
            // mCharProbabilities[index0][NOT_AN_INDEX] is the probability of skipping this point.
            mCharProbabilities[index0][NOT_AN_INDEX] += suppression;

            // Add the probability of the same key nearby index1
            const float probabilityGain = min(suppression * SUPPRESSION_WEIGHT_FOR_PROBABILITY_GAIN,
                    mCharProbabilities[index1][NOT_AN_INDEX]
                            * SKIP_PROBABALITY_WEIGHT_FOR_PROBABILITY_GAIN);
            it2->second += probabilityGain;
            mCharProbabilities[index1][NOT_AN_INDEX] -= probabilityGain;
        }
    }
    return true;
}

// Traces the most probable string, writes the detected word into codePointBuf, and returns the
// probability of generating that word.
float ProximityInfoState::getMostProbableString(int *const codePointBuf) const {
+2 −24
Original line number Diff line number Diff line
@@ -17,7 +17,6 @@
#ifndef LATINIME_PROXIMITY_INFO_STATE_H
#define LATINIME_PROXIMITY_INFO_STATE_H

#include <bitset>
#include <cstring> // for memset()
#include <vector>

@@ -33,7 +32,6 @@ class ProximityInfo;

class ProximityInfoState {
 public:
    typedef std::bitset<MAX_KEY_COUNT_IN_A_KEYBOARD> NearKeycodesSet;
    static const int NORMALIZED_SQUARED_DISTANCE_SCALING_FACTOR_LOG_2;
    static const int NORMALIZED_SQUARED_DISTANCE_SCALING_FACTOR;
    static const float NOT_A_DISTANCE_FLOAT;
@@ -191,10 +189,6 @@ class ProximityInfoState {
    // get xy direction
    float getDirection(const int x, const int y) const;

    float getPointAngle(const int index) const;
    // Returns the angle formed by three sampled points. index0, index1, and index2 are indices.
    float getPointsAngle(const int index0, const int index1, const int index2) const;

    float getMostProbableString(int *const codePointBuf) const;

    float getProbability(const int index, const int charCode) const;
@@ -205,7 +199,6 @@ class ProximityInfoState {
    bool isKeyInSerchKeysAfterIndex(const int index, const int keyId) const;
 private:
    DISALLOW_COPY_AND_ASSIGN(ProximityInfoState);
    typedef hash_map_compat<int, float> NearKeysDistanceMap;
    /////////////////////////////////////////
    // Defined in proximity_info_state.cpp //
    /////////////////////////////////////////
@@ -226,24 +219,9 @@ class ProximityInfoState {
    inline const int *getProximityCodePointsAt(const int index) const {
        return ProximityInfoStateUtils::getProximityCodePointsAt(mInputProximities, index);
    }

    float updateNearKeysDistances(const int x, const int y,
            NearKeysDistanceMap *const currentNearKeysDistances);
    bool isPrevLocalMin(const NearKeysDistanceMap *const currentNearKeysDistances,
            const NearKeysDistanceMap *const prevNearKeysDistances,
            const NearKeysDistanceMap *const prevPrevNearKeysDistances) const;
    float getPointScore(
            const int x, const int y, const int time, const bool last, const float nearest,
            const float sumAngle, const NearKeysDistanceMap *const currentNearKeysDistances,
            const NearKeysDistanceMap *const prevNearKeysDistances,
            const NearKeysDistanceMap *const prevPrevNearKeysDistances) const;
    bool checkAndReturnIsContinuationPossible(const int inputSize, const int *const xCoordinates,
            const int *const yCoordinates, const int *const times, const bool isGeometric) const;
    void popInputData();
    void updateAlignPointProbabilities(const int start);
    bool suppressCharProbabilities(const int index1, const int index2);
    float calculateBeelineSpeedRate(const int id, const int inputSize,
            const int *const xCoordinates, const int *const yCoordinates, const int * times) const;

    // const
    const ProximityInfo *mProximityInfo;
@@ -272,12 +250,12 @@ class ProximityInfoState {
    // The vector for the key code set which holds nearby keys for each sampled input point
    // 1. Used to calculate the probability of the key
    // 2. Used to calculate mSearchKeysVector
    std::vector<NearKeycodesSet> mNearKeysVector;
    std::vector<ProximityInfoStateUtils::NearKeycodesSet> mNearKeysVector;
    // The vector for the key code set which holds nearby keys of some trailing sampled input points
    // for each sampled input point. These nearby keys contain the next characters which can be in
    // the dictionary. Specifically, currently we are looking for keys nearby trailing sampled
    // inputs including the current input point.
    std::vector<NearKeycodesSet> mSearchKeysVector;
    std::vector<ProximityInfoStateUtils::NearKeycodesSet> mSearchKeysVector;
    bool mTouchPositionCorrectionEnabled;
    int mInputProximities[MAX_PROXIMITY_CHARS_SIZE * MAX_WORD_LENGTH];
    int mNormalizedSquaredDistances[MAX_PROXIMITY_CHARS_SIZE * MAX_WORD_LENGTH];
+368 −0

File changed.

Preview size limit exceeded, changes collapsed.

+32 −2

File changed.

Preview size limit exceeded, changes collapsed.