Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit a9763f93 authored by Satoshi Kataoka's avatar Satoshi Kataoka
Browse files

refactor distance cache

Change-Id: I21b54b356641a63d7be17fd34b9ede7a63ec738a
parent 6cee61de
Loading
Loading
Loading
Loading
+4 −21
Original line number Diff line number Diff line
@@ -22,6 +22,7 @@
#include "correction.h"
#include "defines.h"
#include "proximity_info_state.h"
#include "suggest_utils.h"

namespace latinime {

@@ -673,27 +674,9 @@ inline static bool isUpperCase(unsigned short c) {
            if (i < adjustedProximityMatchedCount) {
                multiplyIntCapped(typedLetterMultiplier, &finalFreq);
            }
            if (squaredDistance >= 0) {
                // Promote or demote the score according to the distance from the sweet spot
                static const float A = ZERO_DISTANCE_PROMOTION_RATE / 100.0f;
                static const float B = 1.0f;
                static const float C = 0.5f;
                static const float MIN = 0.3f;
                static const float R1 = NEUTRAL_SCORE_SQUARED_RADIUS;
                static const float R2 = HALF_SCORE_SQUARED_RADIUS;
                const float x = static_cast<float>(squaredDistance)
                        / ProximityInfoState::NORMALIZED_SQUARED_DISTANCE_SCALING_FACTOR;
                const float factor = max((x < R1)
                        ? (A * (R1 - x) + B * x) / R1
                        : (B * (R2 - x) + C * (x - R1)) / (R2 - R1), MIN);
                // factor is a piecewise linear function like:
                // A -_                  .
                //     ^-_               .
                // B      \              .
                //         \_            .
                // C         ------------.
                //                       .
                // 0   R1 R2             .
            const float factor =
                    SuggestUtils::getDistanceScalingFactor(static_cast<float>(squaredDistance));
            if (factor > 0.0f) {
                multiplyRate((int)(factor * 100.0f), &finalFreq);
            } else if (squaredDistance == PROXIMITY_CHAR_WITHOUT_DISTANCE_INFO) {
                multiplyRate(WORDS_WITH_PROXIMITY_CHARACTER_DEMOTION_RATE, &finalFreq);
+6 −6
Original line number Diff line number Diff line
@@ -101,7 +101,7 @@ void ProximityInfoState::initInputParams(const int pointerId, const float maxPoi
        mTimes.clear();
        mInputIndice.clear();
        mLengthCache.clear();
        mDistanceCache.clear();
        mDistanceCache_G.clear();
        mNearKeysVector.clear();
        mSearchKeysVector.clear();
        mSpeedRates.clear();
@@ -210,7 +210,7 @@ void ProximityInfoState::initInputParams(const int pointerId, const float maxPoi
        const int keyCount = mProximityInfo->getKeyCount();
        mNearKeysVector.resize(mSampledInputSize);
        mSearchKeysVector.resize(mSampledInputSize);
        mDistanceCache.resize(mSampledInputSize * keyCount);
        mDistanceCache_G.resize(mSampledInputSize * keyCount);
        for (int i = lastSavedInputSize; i < mSampledInputSize; ++i) {
            mNearKeysVector[i].reset();
            mSearchKeysVector[i].reset();
@@ -221,7 +221,7 @@ void ProximityInfoState::initInputParams(const int pointerId, const float maxPoi
                const int y = mSampledInputYs[i];
                const float normalizedSquaredDistance =
                        mProximityInfo->getNormalizedSquaredDistanceFromCenterFloatG(k, x, y);
                mDistanceCache[index] = normalizedSquaredDistance;
                mDistanceCache_G[index] = normalizedSquaredDistance;
                if (normalizedSquaredDistance < NEAR_KEY_NORMALIZED_SQUARED_THRESHOLD) {
                    mNearKeysVector[i][k] = true;
                }
@@ -686,7 +686,7 @@ float ProximityInfoState::getPointToKeyLength(
    const int keyId = mProximityInfo->getKeyIndexOf(codePoint);
    if (keyId != NOT_AN_INDEX) {
        const int index = inputIndex * mProximityInfo->getKeyCount() + keyId;
        return min(mDistanceCache[index] * scale, mMaxPointToKeyLength);
        return min(mDistanceCache_G[index] * scale, mMaxPointToKeyLength);
    }
    if (isSkippableCodePoint(codePoint)) {
        return 0.0f;
@@ -695,7 +695,7 @@ float ProximityInfoState::getPointToKeyLength(
    return MAX_POINT_TO_KEY_LENGTH;
}

float ProximityInfoState::getPointToKeyLength(const int inputIndex, const int codePoint) const {
float ProximityInfoState::getPointToKeyLength_G(const int inputIndex, const int codePoint) const {
    return getPointToKeyLength(inputIndex, codePoint, 1.0f);
}

@@ -706,7 +706,7 @@ float ProximityInfoState::getPointToKeyByIdLength(
        const int inputIndex, const int keyId, const float scale) const {
    if (keyId != NOT_AN_INDEX) {
        const int index = inputIndex * mProximityInfo->getKeyCount() + keyId;
        return min(mDistanceCache[index] * scale, mMaxPointToKeyLength);
        return min(mDistanceCache_G[index] * scale, mMaxPointToKeyLength);
    }
    // If the char is not a key on the keyboard then return the max length.
    return static_cast<float>(MAX_POINT_TO_KEY_LENGTH);
+3 −3
Original line number Diff line number Diff line
@@ -58,7 +58,7 @@ class ProximityInfoState {
              mHasTouchPositionCorrectionData(false), mMostCommonKeyWidthSquare(0), mLocaleStr(),
              mKeyCount(0), mCellHeight(0), mCellWidth(0), mGridHeight(0), mGridWidth(0),
              mIsContinuationPossible(false), mSampledInputXs(), mSampledInputYs(), mTimes(),
              mInputIndice(), mLengthCache(), mBeelineSpeedPercentiles(), mDistanceCache(),
              mInputIndice(), mLengthCache(), mBeelineSpeedPercentiles(), mDistanceCache_G(),
              mSpeedRates(), mDirections(), mCharProbabilities(), mNearKeysVector(),
              mSearchKeysVector(), mTouchPositionCorrectionEnabled(false), mSampledInputSize(0) {
        memset(mInputCodes, 0, sizeof(mInputCodes));
@@ -157,7 +157,7 @@ class ProximityInfoState {
    float getPointToKeyByIdLength(const int inputIndex, const int keyId, const float scale) const;
    float getPointToKeyByIdLength(const int inputIndex, const int keyId) const;
    float getPointToKeyLength(const int inputIndex, const int codePoint, const float scale) const;
    float getPointToKeyLength(const int inputIndex, const int codePoint) const;
    float getPointToKeyLength_G(const int inputIndex, const int codePoint) const;

    ProximityType getMatchedProximityId(const int index, const int c,
            const bool checkProximityChars, int *proximityIndex = 0) const;
@@ -274,7 +274,7 @@ class ProximityInfoState {
    std::vector<int> mInputIndice;
    std::vector<int> mLengthCache;
    std::vector<int> mBeelineSpeedPercentiles;
    std::vector<float> mDistanceCache;
    std::vector<float> mDistanceCache_G;
    std::vector<float> mSpeedRates;
    std::vector<float> mDirections;
    // probabilities of skipping or mapping to a key for each point.
+53 −0
Original line number Diff line number Diff line
/*
 * Copyright (C) 2013 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef LATINIME_SUGGEST_UTILS_H
#define LATINIME_SUGGEST_UTILS_H

#include "defines.h"

namespace latinime {
class SuggestUtils {
 public:
    static float getDistanceScalingFactor(float normalizedSquaredDistance) {
        if (normalizedSquaredDistance < 0.0f) {
            return -1.0f;
        }
        // Promote or demote the score according to the distance from the sweet spot
        static const float A = ZERO_DISTANCE_PROMOTION_RATE / 100.0f;
        static const float B = 1.0f;
        static const float C = 0.5f;
        static const float MIN = 0.3f;
        static const float R1 = NEUTRAL_SCORE_SQUARED_RADIUS;
        static const float R2 = HALF_SCORE_SQUARED_RADIUS;
        const float x = static_cast<float>(normalizedSquaredDistance)
                / ProximityInfoState::NORMALIZED_SQUARED_DISTANCE_SCALING_FACTOR;
        const float factor = max((x < R1)
                ? (A * (R1 - x) + B * x) / R1
                : (B * (R2 - x) + C * (x - R1)) / (R2 - R1), MIN);
        // factor is a piecewise linear function like:
        // A -_                  .
        //     ^-_               .
        // B      \              .
        //         \_            .
        // C         ------------.
        //                       .
        // 0   R1 R2             .
        return factor;
    }
};
} // namespace latinime
#endif // LATINIME_SUGGEST_UTILS_H