Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 3195604a authored by Tom Ouyang's avatar Tom Ouyang Committed by Android Git Automerger
Browse files

am 7cc319f5: am 9559dd2e: Improve bigram frequency lookup

* commit '7cc319f5':
  Improve bigram frequency lookup
parents 17a7b697 7cc319f5
Loading
Loading
Loading
Loading
+72 −0
Original line number Diff line number Diff line
@@ -23,6 +23,7 @@

#include "bloom_filter.h"
#include "char_utils.h"
#include "hash_map_compat.h"

namespace latinime {

@@ -93,7 +94,13 @@ class BinaryFormat {
            const int unigramProbability, const int bigramProbability);
    static int getProbability(const int position, const std::map<int, int> *bigramMap,
            const uint8_t *bigramFilter, const int unigramProbability);
    static int getBigramProbabilityFromHashMap(const int position,
            const hash_map_compat<int, int> *bigramMap, const int unigramProbability);
    static float getMultiWordCostMultiplier(const uint8_t *const dict);
    static void fillBigramProbabilityToHashMap(const uint8_t *const root, int position,
            hash_map_compat<int, int> *bigramMap);
    static int getBigramProbability(const uint8_t *const root, int position,
            const int nextPosition, const int unigramProbability);

    // Flags for special processing
    // Those *must* match the flags in makedict (BinaryDictInputOutput#*_PROCESSING_FLAG) or
@@ -105,6 +112,8 @@ class BinaryFormat {

 private:
    DISALLOW_IMPLICIT_CONSTRUCTORS(BinaryFormat);
    static int getBigramListPositionForWordPosition(const uint8_t *const root, int position);

    static const int FLAG_GROUP_ADDRESS_TYPE_NOADDRESS = 0x00;
    static const int FLAG_GROUP_ADDRESS_TYPE_ONEBYTE = 0x40;
    static const int FLAG_GROUP_ADDRESS_TYPE_TWOBYTES = 0x80;
@@ -687,5 +696,68 @@ inline int BinaryFormat::getProbability(const int position, const std::map<int,
    }
    return backoff(unigramProbability);
}

// This returns a probability in log space.
inline int BinaryFormat::getBigramProbabilityFromHashMap(const int position,
        const hash_map_compat<int, int> *bigramMap, const int unigramProbability) {
    if (!bigramMap) return backoff(unigramProbability);
    const hash_map_compat<int, int>::const_iterator bigramProbabilityIt = bigramMap->find(position);
    if (bigramProbabilityIt != bigramMap->end()) {
        const int bigramProbability = bigramProbabilityIt->second;
        return computeProbabilityForBigram(unigramProbability, bigramProbability);
    }
    return backoff(unigramProbability);
}

AK_FORCE_INLINE void BinaryFormat::fillBigramProbabilityToHashMap(
        const uint8_t *const root, int position, hash_map_compat<int, int> *bigramMap) {
    position = getBigramListPositionForWordPosition(root, position);
    if (0 == position) return;

    uint8_t bigramFlags;
    do {
        bigramFlags = getFlagsAndForwardPointer(root, &position);
        const int probability = MASK_ATTRIBUTE_PROBABILITY & bigramFlags;
        const int bigramPos = getAttributeAddressAndForwardPointer(root, bigramFlags,
                &position);
        (*bigramMap)[bigramPos] = probability;
    } while (FLAG_ATTRIBUTE_HAS_NEXT & bigramFlags);
}

AK_FORCE_INLINE int BinaryFormat::getBigramProbability(const uint8_t *const root, int position,
        const int nextPosition, const int unigramProbability) {
    position = getBigramListPositionForWordPosition(root, position);
    if (0 == position) return backoff(unigramProbability);

    uint8_t bigramFlags;
    do {
        bigramFlags = getFlagsAndForwardPointer(root, &position);
        const int bigramPos = getAttributeAddressAndForwardPointer(
                root, bigramFlags, &position);
        if (bigramPos == nextPosition) {
            const int bigramProbability = MASK_ATTRIBUTE_PROBABILITY & bigramFlags;
            return computeProbabilityForBigram(unigramProbability, bigramProbability);
        }
    } while (FLAG_ATTRIBUTE_HAS_NEXT & bigramFlags);
    return backoff(unigramProbability);
}

// Returns a pointer to the start of the bigram list.
AK_FORCE_INLINE int BinaryFormat::getBigramListPositionForWordPosition(
        const uint8_t *const root, int position) {
    if (NOT_VALID_WORD == position) return 0;
    const uint8_t flags = getFlagsAndForwardPointer(root, &position);
    if (!(flags & FLAG_HAS_BIGRAMS)) return 0;
    if (flags & FLAG_HAS_MULTIPLE_CHARS) {
        position = skipOtherCharacters(root, position);
    } else {
        getCodePointAndForwardPointer(root, &position);
    }
    position = skipProbability(flags, position);
    position = skipChildrenPosition(flags, position);
    position = skipShortcuts(root, flags, position);
    return position;
}

} // namespace latinime
#endif // LATINIME_BINARY_FORMAT_H
+9 −0
Original line number Diff line number Diff line
@@ -379,6 +379,15 @@ static inline void prof_out(void) {
#error "BIGRAM_FILTER_MODULO is larger than BIGRAM_FILTER_BYTE_SIZE"
#endif

// Max number of bigram maps (previous word contexts) to be cached. Increasing this number could
// improve bigram lookup speed for multi-word suggestions, but at the cost of more memory usage.
// Also, there are diminishing returns since the most frequently used bigrams are typically near
// the beginning of the input and are thus the first ones to be cached. Note that these bigrams
// are reset for each new composing word.
#define MAX_CACHED_PREV_WORDS_IN_BIGRAM_MAP 25
// Most common previous word contexts currently have 100 bigrams
#define DEFAULT_HASH_MAP_SIZE_FOR_EACH_BIGRAM_MAP 100

template<typename T> AK_FORCE_INLINE const T &min(const T &a, const T &b) { return a < b ? a : b; }
template<typename T> AK_FORCE_INLINE const T &max(const T &a, const T &b) { return a > b ? a : b; }

+89 −0
Original line number Diff line number Diff line
/*
 * Copyright (C) 2013 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef LATINIME_MULTI_BIGRAM_MAP_H
#define LATINIME_MULTI_BIGRAM_MAP_H

#include <cstring>
#include <stdint.h>

#include "defines.h"
#include "binary_format.h"
#include "hash_map_compat.h"

namespace latinime {

// Class for caching bigram maps for multiple previous word contexts. This is useful since the
// algorithm needs to look up the set of bigrams for every word pair that occurs in every
// multi-word suggestion.
class MultiBigramMap {
 public:
    MultiBigramMap() : mBigramMaps() {}
    ~MultiBigramMap() {}

    // Look up the bigram probability for the given word pair from the cached bigram maps.
    // Also caches the bigrams if there is space remaining and they have not been cached already.
    int getBigramProbability(const uint8_t *const dicRoot, const int wordPosition,
            const int nextWordPosition, const int unigramProbability) {
        hash_map_compat<int, BigramMap>::const_iterator mapPosition =
                mBigramMaps.find(wordPosition);
        if (mapPosition != mBigramMaps.end()) {
            return mapPosition->second.getBigramProbability(nextWordPosition, unigramProbability);
        }
        if (mBigramMaps.size() < MAX_CACHED_PREV_WORDS_IN_BIGRAM_MAP) {
            addBigramsForWordPosition(dicRoot, wordPosition);
            return mBigramMaps[wordPosition].getBigramProbability(
                    nextWordPosition, unigramProbability);
        }
        return BinaryFormat::getBigramProbability(
                dicRoot, wordPosition, nextWordPosition, unigramProbability);
    }

    void clear() {
        mBigramMaps.clear();
    }

 private:
    DISALLOW_COPY_AND_ASSIGN(MultiBigramMap);

    class BigramMap {
     public:
        BigramMap() : mBigramMap(DEFAULT_HASH_MAP_SIZE_FOR_EACH_BIGRAM_MAP) {}
        ~BigramMap() {}

        void init(const uint8_t *const dicRoot, int position) {
            BinaryFormat::fillBigramProbabilityToHashMap(dicRoot, position, &mBigramMap);
        }

        inline int getBigramProbability(const int nextWordPosition, const int unigramProbability)
                const {
           return BinaryFormat::getBigramProbabilityFromHashMap(
                   nextWordPosition, &mBigramMap, unigramProbability);
        }

     private:
        // Note: Default copy constructor needed for use in hash_map.
        hash_map_compat<int, int> mBigramMap;
    };

    void addBigramsForWordPosition(const uint8_t *const dicRoot, const int position) {
        mBigramMaps[position].init(dicRoot, position);
    }

    hash_map_compat<int, BigramMap> mBigramMaps;
};
} // namespace latinime
#endif // LATINIME_MULTI_BIGRAM_MAP_H
+13 −70
Original line number Diff line number Diff line
@@ -21,6 +21,7 @@
#include "dic_node.h"
#include "dic_node_utils.h"
#include "dic_node_vector.h"
#include "multi_bigram_map.h"
#include "proximity_info.h"
#include "proximity_info_state.h"

@@ -191,11 +192,11 @@ namespace latinime {
 * Computes the combined bigram / unigram cost for the given dicNode.
 */
/* static */ float DicNodeUtils::getBigramNodeImprobability(const uint8_t *const dicRoot,
        const DicNode *const node, hash_map_compat<int, int16_t> *bigramCacheMap) {
        const DicNode *const node, MultiBigramMap *multiBigramMap) {
    if (node->isImpossibleBigramWord()) {
        return static_cast<float>(MAX_VALUE_FOR_WEIGHTING);
    }
    const int probability = getBigramNodeProbability(dicRoot, node, bigramCacheMap);
    const int probability = getBigramNodeProbability(dicRoot, node, multiBigramMap);
    // TODO: This equation to calculate the improbability looks unreasonable.  Investigate this.
    const float cost = static_cast<float>(MAX_PROBABILITY - probability)
            / static_cast<float>(MAX_PROBABILITY);
@@ -203,83 +204,25 @@ namespace latinime {
}

/* static */ int DicNodeUtils::getBigramNodeProbability(const uint8_t *const dicRoot,
        const DicNode *const node, hash_map_compat<int, int16_t> *bigramCacheMap) {
        const DicNode *const node, MultiBigramMap *multiBigramMap) {
    const int unigramProbability = node->getProbability();
    const int encodedDiffOfBigramProbability =
            getBigramNodeEncodedDiffProbability(dicRoot, node, bigramCacheMap);
    if (NOT_A_PROBABILITY == encodedDiffOfBigramProbability) {
    const int wordPos = node->getPos();
    const int prevWordPos = node->getPrevWordPos();
    if (NOT_VALID_WORD == wordPos || NOT_VALID_WORD == prevWordPos) {
        // Note: Normally wordPos comes from the dictionary and should never equal NOT_VALID_WORD.
        return backoff(unigramProbability);
    }
    return BinaryFormat::computeProbabilityForBigram(
            unigramProbability, encodedDiffOfBigramProbability);
    if (multiBigramMap) {
        return multiBigramMap->getBigramProbability(
                dicRoot, prevWordPos, wordPos, unigramProbability);
    }
    return BinaryFormat::getBigramProbability(dicRoot, prevWordPos, wordPos, unigramProbability);
}

///////////////////////////////////////
// Bigram / Unigram dictionary utils //
///////////////////////////////////////

/* static */ int16_t DicNodeUtils::getBigramNodeEncodedDiffProbability(const uint8_t *const dicRoot,
        const DicNode *const node, hash_map_compat<int, int16_t> *bigramCacheMap) {
    const int wordPos = node->getPos();
    const int prevWordPos = node->getPrevWordPos();
    return getBigramProbability(dicRoot, prevWordPos, wordPos, bigramCacheMap);
}

// TODO: Move this to BigramDictionary
/* static */ int16_t DicNodeUtils::getBigramProbability(const uint8_t *const dicRoot, int pos,
        const int nextPos, hash_map_compat<int, int16_t> *bigramCacheMap) {
    // TODO: this is painfully slow compared to the method used in the previous version of the
    // algorithm. Switch to that method.
    if (NOT_VALID_WORD == pos) return NOT_A_PROBABILITY;
    if (NOT_VALID_WORD == nextPos) return NOT_A_PROBABILITY;

    // Create a hash code for the given node pair (based on Josh Bloch's effective Java).
    // TODO: Use a real hash map data structure that deals with collisions.
    int hash = 17;
    hash = hash * 31 + pos;
    hash = hash * 31 + nextPos;

    hash_map_compat<int, int16_t>::const_iterator mapPos = bigramCacheMap->find(hash);
    if (mapPos != bigramCacheMap->end()) {
        return mapPos->second;
    }
    if (NOT_VALID_WORD == pos) {
        return NOT_A_PROBABILITY;
    }
    const uint8_t flags = BinaryFormat::getFlagsAndForwardPointer(dicRoot, &pos);
    if (0 == (flags & BinaryFormat::FLAG_HAS_BIGRAMS)) {
        return NOT_A_PROBABILITY;
    }
    if (0 == (flags & BinaryFormat::FLAG_HAS_MULTIPLE_CHARS)) {
        BinaryFormat::getCodePointAndForwardPointer(dicRoot, &pos);
    } else {
        pos = BinaryFormat::skipOtherCharacters(dicRoot, pos);
    }
    pos = BinaryFormat::skipChildrenPosition(flags, pos);
    pos = BinaryFormat::skipProbability(flags, pos);
    uint8_t bigramFlags;
    int count = 0;
    do {
        bigramFlags = BinaryFormat::getFlagsAndForwardPointer(dicRoot, &pos);
        const int bigramPos = BinaryFormat::getAttributeAddressAndForwardPointer(dicRoot,
                bigramFlags, &pos);
        if (bigramPos == nextPos) {
            const int16_t probability = BinaryFormat::MASK_ATTRIBUTE_PROBABILITY & bigramFlags;
            if (static_cast<int>(bigramCacheMap->size()) < MAX_BIGRAM_MAP_SIZE) {
                (*bigramCacheMap)[hash] = probability;
            }
            return probability;
        }
        count++;
    } while ((BinaryFormat::FLAG_ATTRIBUTE_HAS_NEXT & bigramFlags)
            && count < MAX_BIGRAMS_CONSIDERED_PER_CONTEXT);
    if (static_cast<int>(bigramCacheMap->size()) < MAX_BIGRAM_MAP_SIZE) {
        // TODO: does this -1 mean NOT_VALID_WORD?
        (*bigramCacheMap)[hash] = -1;
    }
    return NOT_A_PROBABILITY;
}

/* static */ bool DicNodeUtils::isMatchedNodeCodePoint(const ProximityInfoState *pInfoState,
        const int pointIndex, const bool exactOnly, const int nodeCodePoint) {
    if (!pInfoState) {
+3 −9
Original line number Diff line number Diff line
@@ -21,7 +21,6 @@
#include <vector>

#include "defines.h"
#include "hash_map_compat.h"

namespace latinime {

@@ -29,6 +28,7 @@ class DicNode;
class DicNodeVector;
class ProximityInfo;
class ProximityInfoState;
class MultiBigramMap;

class DicNodeUtils {
 public:
@@ -42,7 +42,7 @@ class DicNodeUtils {
    static void getAllChildDicNodes(DicNode *dicNode, const uint8_t *const dicRoot,
            DicNodeVector *childDicNodes);
    static float getBigramNodeImprobability(const uint8_t *const dicRoot,
            const DicNode *const node, hash_map_compat<int, int16_t> *const bigramCacheMap);
            const DicNode *const node, MultiBigramMap *const multiBigramMap);
    static bool isDicNodeFilteredOut(const int nodeCodePoint, const ProximityInfo *const pInfo,
            const std::vector<int> *const codePointsFilter);
    // TODO: Move to private
@@ -57,15 +57,11 @@ class DicNodeUtils {

 private:
    DISALLOW_IMPLICIT_CONSTRUCTORS(DicNodeUtils);
    // Max cache size for the space omission error correction bigram lookup
    static const int MAX_BIGRAM_MAP_SIZE = 20000;
    // Max number of bigrams to look up
    static const int MAX_BIGRAMS_CONSIDERED_PER_CONTEXT = 500;

    static int getBigramNodeProbability(const uint8_t *const dicRoot, const DicNode *const node,
            hash_map_compat<int, int16_t> *bigramCacheMap);
    static int16_t getBigramNodeEncodedDiffProbability(const uint8_t *const dicRoot,
            const DicNode *const node, hash_map_compat<int, int16_t> *bigramCacheMap);
            MultiBigramMap *multiBigramMap);
    static void createAndGetPassingChildNode(DicNode *dicNode, const ProximityInfoState *pInfoState,
            const int pointIndex, const bool exactOnly, DicNodeVector *childDicNodes);
    static void createAndGetAllLeavingChildNodes(DicNode *dicNode, const uint8_t *const dicRoot,
@@ -76,8 +72,6 @@ class DicNodeUtils {
            const int terminalDepth, const ProximityInfoState *pInfoState, const int pointIndex,
            const bool exactOnly, const std::vector<int> *const codePointsFilter,
            const ProximityInfo *const pInfo, DicNodeVector *childDicNodes);
    static int16_t getBigramProbability(const uint8_t *const dicRoot, int pos, const int nextPos,
            hash_map_compat<int, int16_t> *bigramCacheMap);

    // TODO: Move to proximity info
    static bool isMatchedNodeCodePoint(const ProximityInfoState *pInfoState, const int pointIndex,
Loading