Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 21648562 authored by Keisuke Kuroynagi's avatar Keisuke Kuroynagi Committed by Android (Google) Code Review
Browse files

Merge "Move children filtering methods to DicNodeChildrenFilter."

parents b6f5d3e3 7a06a792
Loading
Loading
Loading
Loading
+58 −0
Original line number Diff line number Diff line
/*
 * Copyright (C) 2013 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef LATINIME_DIC_NODE_PROXIMITY_FILTER_H
#define LATINIME_DIC_NODE_PROXIMITY_FILTER_H

#include "defines.h"
#include "suggest/core/layout/proximity_info_state.h"
#include "suggest/core/layout/proximity_info_utils.h"
#include "suggest/core/policy/dictionary_structure_policy.h"

namespace latinime {

class DicNodeProximityFilter : public DictionaryStructurePolicy::NodeFilter {
 public:
    DicNodeProximityFilter(const ProximityInfoState *const pInfoState,
            const int pointIndex, const bool exactOnly)
            : mProximityInfoState(pInfoState), mPointIndex(pointIndex), mExactOnly(exactOnly) {}

    bool isFilteredOut(const int codePoint) const {
        return !isProximityCodePoint(codePoint);
    }

 private:
    DISALLOW_IMPLICIT_CONSTRUCTORS(DicNodeProximityFilter);

    const ProximityInfoState *const mProximityInfoState;
    const int mPointIndex;
    const bool mExactOnly;

    // TODO: Move to proximity info state
    bool isProximityCodePoint(const int codePoint) const {
        if (!mProximityInfoState) {
            return true;
        }
        if (mExactOnly) {
            return mProximityInfoState->getPrimaryCodePointAt(mPointIndex) == codePoint;
        }
        const ProximityType matchedId = mProximityInfoState->getProximityType(
                mPointIndex, codePoint, true /* checkProximityChars */);
        return ProximityInfoUtils::isMatchOrProximityChar(matchedId);
    }
};
} // namespace latinime
#endif // LATINIME_DIC_NODE_PROXIMITY_FILTER_H
+13 −72
Original line number Diff line number Diff line
@@ -14,18 +14,17 @@
 * limitations under the License.
 */

#include "suggest/core/dicnode/dic_node_utils.h"

#include <cstring>
#include <vector>

#include "suggest/core/dicnode/dic_node.h"
#include "suggest/core/dicnode/dic_node_utils.h"
#include "suggest/core/dicnode/dic_node_proximity_filter.h"
#include "suggest/core/dicnode/dic_node_vector.h"
#include "suggest/core/dictionary/binary_dictionary_info.h"
#include "suggest/core/dictionary/binary_format.h"
#include "suggest/core/dictionary/multi_bigram_map.h"
#include "suggest/core/dictionary/probability_utils.h"
#include "suggest/core/layout/proximity_info.h"
#include "suggest/core/layout/proximity_info_state.h"
#include "suggest/core/policy/dictionary_structure_policy.h"
#include "utils/char_utils.h"

@@ -57,21 +56,20 @@ namespace latinime {
///////////////////////////////////

/* static */ void DicNodeUtils::createAndGetPassingChildNode(DicNode *dicNode,
        const ProximityInfoState *pInfoState, const int pointIndex, const bool exactOnly,
        const DicNodeProximityFilter *const childrenFilter,
        DicNodeVector *childDicNodes) {
    // Passing multiple chars node. No need to traverse child
    const int codePoint = dicNode->getNodeTypedCodePoint();
    const int baseLowerCaseCodePoint = CharUtils::toBaseLowerCase(codePoint);
    const bool isMatch = isMatchedNodeCodePoint(pInfoState, pointIndex, exactOnly, codePoint);
    if (isMatch || CharUtils::isIntentionalOmissionCodePoint(baseLowerCaseCodePoint)) {
    if (!childrenFilter->isFilteredOut(codePoint)
            || CharUtils::isIntentionalOmissionCodePoint(baseLowerCaseCodePoint)) {
        childDicNodes->pushPassingChild(dicNode);
    }
}

/* static */ int DicNodeUtils::createAndGetLeavingChildNode(DicNode *dicNode, int pos,
        const BinaryDictionaryInfo *const binaryDictionaryInfo,
        const ProximityInfoState *pInfoState, const int pointIndex, const bool exactOnly,
        const std::vector<int> *const codePointsFilter, const ProximityInfo *const pInfo,
        const DicNodeProximityFilter *const childrenFilter,
        DicNodeVector *childDicNodes) {
    int nextPos = pos;
    const uint8_t flags = BinaryFormat::getFlagsAndForwardPointer(
@@ -110,10 +108,7 @@ namespace latinime {
    const int siblingPos = BinaryFormat::skipChildrenPosAndAttributes(
            binaryDictionaryInfo->getDictRoot(), flags, pos);

    if (isDicNodeFilteredOut(mergedNodeCodePoints[0], pInfo, codePointsFilter)) {
        return siblingPos;
    }
    if (!isMatchedNodeCodePoint(pInfoState, pointIndex, exactOnly, mergedNodeCodePoints[0])) {
    if (childrenFilter->isFilteredOut(mergedNodeCodePoints[0])) {
        return siblingPos;
    }
    childDicNodes->pushLeavingChild(dicNode, nextPos, flags, childrenPos, attributesPos,
@@ -121,39 +116,9 @@ namespace latinime {
    return siblingPos;
}

/* static */ bool DicNodeUtils::isDicNodeFilteredOut(const int nodeCodePoint,
        const ProximityInfo *const pInfo, const std::vector<int> *const codePointsFilter) {
    const int filterSize = codePointsFilter ? codePointsFilter->size() : 0;
    if (filterSize <= 0) {
        return false;
    }
    if (pInfo && (pInfo->getKeyIndexOf(nodeCodePoint) == NOT_AN_INDEX
            || CharUtils::isIntentionalOmissionCodePoint(nodeCodePoint))) {
        // If normalized nodeCodePoint is not on the keyboard or skippable, this child is never
        // filtered.
        return false;
    }
    const int lowerCodePoint = CharUtils::toLowerCase(nodeCodePoint);
    const int baseLowerCodePoint = CharUtils::toBaseCodePoint(lowerCodePoint);
    // TODO: Avoid linear search
    for (int i = 0; i < filterSize; ++i) {
        // Checking if a normalized code point is in filter characters when pInfo is not
        // null. When pInfo is null, nodeCodePoint is used to check filtering without
        // normalizing.
        if ((pInfo && ((*codePointsFilter)[i] == lowerCodePoint
                || (*codePointsFilter)[i] == baseLowerCodePoint))
                        || (!pInfo && (*codePointsFilter)[i] == nodeCodePoint)) {
            return false;
        }
    }
    return true;
}

/* static */ void DicNodeUtils::createAndGetAllLeavingChildNodes(DicNode *dicNode,
        const BinaryDictionaryInfo *const binaryDictionaryInfo,
        const ProximityInfoState *pInfoState, const int pointIndex, const bool exactOnly,
        const std::vector<int> *const codePointsFilter, const ProximityInfo *const pInfo,
        DicNodeVector *childDicNodes) {
        const DicNodeProximityFilter *const childrenFilter, DicNodeVector *childDicNodes) {
    if (!dicNode->hasChildren()) {
        return;
    }
@@ -161,14 +126,8 @@ namespace latinime {
    const int childCount = BinaryFormat::getGroupCountAndForwardPointer(
            binaryDictionaryInfo->getDictRoot(), &nextPos);
    for (int i = 0; i < childCount; i++) {
        const int filterSize = codePointsFilter ? codePointsFilter->size() : 0;
        nextPos = createAndGetLeavingChildNode(dicNode, nextPos, binaryDictionaryInfo,
                pInfoState, pointIndex, exactOnly, codePointsFilter, pInfo,
                childDicNodes);
        if (!pInfo && filterSize > 0 && childDicNodes->exceeds(filterSize)) {
            // All code points have been found.
            break;
        }
                childrenFilter, childDicNodes);
    }
}

@@ -184,13 +143,12 @@ namespace latinime {
    if (dicNode->isTotalInputSizeExceedingLimit()) {
        return;
    }
    const DicNodeProximityFilter childrenFilter(pInfoState, pointIndex, exactOnly);
    if (!dicNode->isLeavingNode()) {
        DicNodeUtils::createAndGetPassingChildNode(dicNode, pInfoState, pointIndex, exactOnly,
                childDicNodes);
        DicNodeUtils::createAndGetPassingChildNode(dicNode, &childrenFilter, childDicNodes);
    } else {
        DicNodeUtils::createAndGetAllLeavingChildNodes(
                dicNode, binaryDictionaryInfo, pInfoState, pointIndex, exactOnly,
                0 /* codePointsFilter */, 0 /* pInfo */, childDicNodes);
                dicNode, binaryDictionaryInfo, &childrenFilter, childDicNodes);
    }
}

@@ -230,23 +188,6 @@ namespace latinime {
    return ProbabilityUtils::backoff(unigramProbability);
}

///////////////////////////////////////
// Bigram / Unigram dictionary utils //
///////////////////////////////////////

/* static */ bool DicNodeUtils::isMatchedNodeCodePoint(const ProximityInfoState *pInfoState,
        const int pointIndex, const bool exactOnly, const int nodeCodePoint) {
    if (!pInfoState) {
        return true;
    }
    if (exactOnly) {
        return pInfoState->getPrimaryCodePointAt(pointIndex) == nodeCodePoint;
    }
    const ProximityType matchedId = pInfoState->getProximityType(pointIndex, nodeCodePoint,
            true /* checkProximityChars */);
    return isProximityChar(matchedId);
}

////////////////
// Char utils //
////////////////
+5 −21
Original line number Diff line number Diff line
@@ -18,7 +18,6 @@
#define LATINIME_DIC_NODE_UTILS_H

#include <stdint.h>
#include <vector>

#include "defines.h"

@@ -26,8 +25,8 @@ namespace latinime {

class BinaryDictionaryInfo;
class DicNode;
class DicNodeProximityFilter;
class DicNodeVector;
class ProximityInfo;
class ProximityInfoState;
class MultiBigramMap;

@@ -44,19 +43,12 @@ class DicNodeUtils {
            const BinaryDictionaryInfo *const binaryDictionaryInfo, DicNodeVector *childDicNodes);
    static float getBigramNodeImprobability(const BinaryDictionaryInfo *const binaryDictionaryInfo,
            const DicNode *const node, MultiBigramMap *const multiBigramMap);
    static bool isDicNodeFilteredOut(const int nodeCodePoint, const ProximityInfo *const pInfo,
            const std::vector<int> *const codePointsFilter);
    // TODO: Move to private
    static void getProximityChildDicNodes(DicNode *dicNode,
            const BinaryDictionaryInfo *const binaryDictionaryInfo,
            const ProximityInfoState *pInfoState, const int pointIndex, bool exactOnly,
            DicNodeVector *childDicNodes);

    // TODO: Move to proximity info
    static bool isProximityChar(ProximityType type) {
        return type == MATCH_CHAR || type == PROXIMITY_CHAR || type == ADDITIONAL_PROXIMITY_CHAR;
    }

 private:
    DISALLOW_IMPLICIT_CONSTRUCTORS(DicNodeUtils);
    // Max number of bigrams to look up
@@ -64,22 +56,14 @@ class DicNodeUtils {

    static int getBigramNodeProbability(const BinaryDictionaryInfo *const binaryDictionaryInfo,
            const DicNode *const node, MultiBigramMap *multiBigramMap);
    static void createAndGetPassingChildNode(DicNode *dicNode, const ProximityInfoState *pInfoState,
            const int pointIndex, const bool exactOnly, DicNodeVector *childDicNodes);
    static void createAndGetPassingChildNode(DicNode *dicNode,
            const DicNodeProximityFilter *const childrenFilter, DicNodeVector *childDicNodes);
    static void createAndGetAllLeavingChildNodes(DicNode *dicNode,
            const BinaryDictionaryInfo *const binaryDictionaryInfo,
            const ProximityInfoState *pInfoState, const int pointIndex, const bool exactOnly,
            const std::vector<int> *const codePointsFilter,
            const ProximityInfo *const pInfo, DicNodeVector *childDicNodes);
            const DicNodeProximityFilter *const childrenFilter, DicNodeVector *childDicNodes);
    static int createAndGetLeavingChildNode(DicNode *dicNode, int pos,
            const BinaryDictionaryInfo *const binaryDictionaryInfo,
            const ProximityInfoState *pInfoState, const int pointIndex,
            const bool exactOnly, const std::vector<int> *const codePointsFilter,
            const ProximityInfo *const pInfo, DicNodeVector *childDicNodes);

    // TODO: Move to proximity info
    static bool isMatchedNodeCodePoint(const ProximityInfoState *pInfoState, const int pointIndex,
            const bool exactOnly, const int nodeCodePoint);
            const DicNodeProximityFilter *const childrenFilter, DicNodeVector *childDicNodes);
};
} // namespace latinime
#endif // LATINIME_DIC_NODE_UTILS_H
+4 −0
Original line number Diff line number Diff line
@@ -117,6 +117,10 @@ class ProximityInfoUtils {
        return getSquaredDistanceFloat(x, y, projectionX, projectionY);
    }

     static AK_FORCE_INLINE bool isMatchOrProximityChar(const ProximityType type) {
         return type == MATCH_CHAR || type == PROXIMITY_CHAR || type == ADDITIONAL_PROXIMITY_CHAR;
     }

    // Normal distribution N(u, sigma^2).
    struct NormalDistribution {
     public:
+2 −1
Original line number Diff line number Diff line
@@ -23,6 +23,7 @@
#include "suggest/core/dicnode/dic_node.h"
#include "suggest/core/dicnode/dic_node_vector.h"
#include "suggest/core/layout/proximity_info_state.h"
#include "suggest/core/layout/proximity_info_utils.h"
#include "suggest/core/policy/traversal.h"
#include "suggest/core/session/dic_traverse_session.h"
#include "suggest/policyimpl/typing/scoring_params.h"
@@ -159,7 +160,7 @@ class TypingTraversal : public Traversal {
            const DicNode *const dicNode) const {
        const ProximityType proximityType =
                getProximityType(traverseSession, parentDicNode, dicNode);
        if (!DicNodeUtils::isProximityChar(proximityType)) {
        if (!ProximityInfoUtils::isMatchOrProximityChar(proximityType)) {
            return false;
        }
        return true;