Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 5849feee authored by Keisuke Kuroyanagi's avatar Keisuke Kuroyanagi Committed by Android (Google) Code Review
Browse files

Merge "Use ReadOnlyByteArrayView in PatriciaTriePolicy."

parents 198a47a1 180e7b4c
Loading
Loading
Loading
Loading
+44 −35
Original line number Diff line number Diff line
@@ -37,19 +37,19 @@ void PatriciaTriePolicy::createAndGetAllChildDicNodes(const DicNode *const dicNo
        return;
    }
    int nextPos = dicNode->getChildrenPtNodeArrayPos();
    if (nextPos < 0 || nextPos >= mDictBufferSize) {
        AKLOGE("Children PtNode array position is invalid. pos: %d, dict size: %d",
                nextPos, mDictBufferSize);
    if (!isValidPos(nextPos)) {
        AKLOGE("Children PtNode array position is invalid. pos: %d, dict size: %zd",
                nextPos, mBuffer.size());
        mIsCorrupted = true;
        ASSERT(false);
        return;
    }
    const int childCount = PatriciaTrieReadingUtils::getPtNodeArraySizeAndAdvancePosition(
            mDictRoot, &nextPos);
            mBuffer.data(), &nextPos);
    for (int i = 0; i < childCount; i++) {
        if (nextPos < 0 || nextPos >= mDictBufferSize) {
            AKLOGE("Child PtNode position is invalid. pos: %d, dict size: %d, childCount: %d / %d",
                    nextPos, mDictBufferSize, i, childCount);
        if (!isValidPos(nextPos)) {
            AKLOGE("Child PtNode position is invalid. pos: %d, dict size: %zd, childCount: %d / %d",
                    nextPos, mBuffer.size(), i, childCount);
            mIsCorrupted = true;
            ASSERT(false);
            return;
@@ -91,56 +91,57 @@ int PatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount(
        int lastCandidatePtNodePos = 0;
        // Let's loop through PtNodes in this PtNode array searching for either the terminal
        // or one of its ascendants.
        if (pos < 0 || pos >= mDictBufferSize) {
            AKLOGE("PtNode array position is invalid. pos: %d, dict size: %d",
                    pos, mDictBufferSize);
        if (!isValidPos(pos)) {
            AKLOGE("PtNode array position is invalid. pos: %d, dict size: %zd",
                    pos, mBuffer.size());
            mIsCorrupted = true;
            ASSERT(false);
            *outUnigramProbability = NOT_A_PROBABILITY;
            return 0;
        }
        for (int ptNodeCount = PatriciaTrieReadingUtils::getPtNodeArraySizeAndAdvancePosition(
                mDictRoot, &pos); ptNodeCount > 0; --ptNodeCount) {
                mBuffer.data(), &pos); ptNodeCount > 0; --ptNodeCount) {
            const int startPos = pos;
            if (pos < 0 || pos >= mDictBufferSize) {
                AKLOGE("PtNode position is invalid. pos: %d, dict size: %d", pos, mDictBufferSize);
            if (!isValidPos(pos)) {
                AKLOGE("PtNode position is invalid. pos: %d, dict size: %zd", pos, mBuffer.size());
                mIsCorrupted = true;
                ASSERT(false);
                *outUnigramProbability = NOT_A_PROBABILITY;
                return 0;
            }
            const PatriciaTrieReadingUtils::NodeFlags flags =
                    PatriciaTrieReadingUtils::getFlagsAndAdvancePosition(mDictRoot, &pos);
                    PatriciaTrieReadingUtils::getFlagsAndAdvancePosition(mBuffer.data(), &pos);
            const int character = PatriciaTrieReadingUtils::getCodePointAndAdvancePosition(
                    mDictRoot, &pos);
                    mBuffer.data(), &pos);
            if (ptNodePos == startPos) {
                // We found the position. Copy the rest of the code points in the buffer and return
                // the length.
                outCodePoints[wordPos] = character;
                if (PatriciaTrieReadingUtils::hasMultipleChars(flags)) {
                    int nextChar = PatriciaTrieReadingUtils::getCodePointAndAdvancePosition(
                            mDictRoot, &pos);
                            mBuffer.data(), &pos);
                    // We count code points in order to avoid infinite loops if the file is broken
                    // or if there is some other bug
                    int charCount = maxCodePointCount;
                    while (NOT_A_CODE_POINT != nextChar && --charCount > 0) {
                        outCodePoints[++wordPos] = nextChar;
                        nextChar = PatriciaTrieReadingUtils::getCodePointAndAdvancePosition(
                                mDictRoot, &pos);
                                mBuffer.data(), &pos);
                    }
                }
                *outUnigramProbability =
                        PatriciaTrieReadingUtils::readProbabilityAndAdvancePosition(mDictRoot,
                        PatriciaTrieReadingUtils::readProbabilityAndAdvancePosition(mBuffer.data(),
                                &pos);
                return ++wordPos;
            }
            // We need to skip past this PtNode, so skip any remaining code points after the
            // first and possibly the probability.
            if (PatriciaTrieReadingUtils::hasMultipleChars(flags)) {
                PatriciaTrieReadingUtils::skipCharacters(mDictRoot, flags, MAX_WORD_LENGTH, &pos);
                PatriciaTrieReadingUtils::skipCharacters(mBuffer.data(), flags, MAX_WORD_LENGTH,
                        &pos);
            }
            if (PatriciaTrieReadingUtils::isTerminal(flags)) {
                PatriciaTrieReadingUtils::readProbabilityAndAdvancePosition(mDictRoot, &pos);
                PatriciaTrieReadingUtils::readProbabilityAndAdvancePosition(mBuffer.data(), &pos);
            }
            // The fact that this PtNode has children is very important. Since we already know
            // that this PtNode does not match, if it has no children we know it is irrelevant
@@ -155,7 +156,8 @@ int PatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount(
                int currentPos = pos;
                // Here comes the tricky part. First, read the children position.
                const int childrenPos = PatriciaTrieReadingUtils
                        ::readChildrenPositionAndAdvancePosition(mDictRoot, flags, &currentPos);
                        ::readChildrenPositionAndAdvancePosition(mBuffer.data(), flags,
                                &currentPos);
                if (childrenPos > ptNodePos) {
                    // If the children pos is greater than the position, it means the previous
                    // PtNode, whose position is stored in lastCandidatePtNodePos, was the right
@@ -185,30 +187,30 @@ int PatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount(
                if (0 != lastCandidatePtNodePos) {
                    const PatriciaTrieReadingUtils::NodeFlags lastFlags =
                            PatriciaTrieReadingUtils::getFlagsAndAdvancePosition(
                                    mDictRoot, &lastCandidatePtNodePos);
                                    mBuffer.data(), &lastCandidatePtNodePos);
                    const int lastChar = PatriciaTrieReadingUtils::getCodePointAndAdvancePosition(
                            mDictRoot, &lastCandidatePtNodePos);
                            mBuffer.data(), &lastCandidatePtNodePos);
                    // We copy all the characters in this PtNode to the buffer
                    outCodePoints[wordPos] = lastChar;
                    if (PatriciaTrieReadingUtils::hasMultipleChars(lastFlags)) {
                        int nextChar = PatriciaTrieReadingUtils::getCodePointAndAdvancePosition(
                                mDictRoot, &lastCandidatePtNodePos);
                                mBuffer.data(), &lastCandidatePtNodePos);
                        int charCount = maxCodePointCount;
                        while (-1 != nextChar && --charCount > 0) {
                            outCodePoints[++wordPos] = nextChar;
                            nextChar = PatriciaTrieReadingUtils::getCodePointAndAdvancePosition(
                                    mDictRoot, &lastCandidatePtNodePos);
                                    mBuffer.data(), &lastCandidatePtNodePos);
                        }
                    }
                    ++wordPos;
                    // Now we only need to branch to the children address. Skip the probability if
                    // it's there, read pos, and break to resume the search at pos.
                    if (PatriciaTrieReadingUtils::isTerminal(lastFlags)) {
                        PatriciaTrieReadingUtils::readProbabilityAndAdvancePosition(mDictRoot,
                        PatriciaTrieReadingUtils::readProbabilityAndAdvancePosition(mBuffer.data(),
                                &lastCandidatePtNodePos);
                    }
                    pos = PatriciaTrieReadingUtils::readChildrenPositionAndAdvancePosition(
                            mDictRoot, lastFlags, &lastCandidatePtNodePos);
                            mBuffer.data(), lastFlags, &lastCandidatePtNodePos);
                    break;
                } else {
                    // Here is a little tricky part: we come here if we found out that all children
@@ -220,14 +222,14 @@ int PatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount(
                    // ready to start the next one.
                    if (PatriciaTrieReadingUtils::hasChildrenInFlags(flags)) {
                        PatriciaTrieReadingUtils::readChildrenPositionAndAdvancePosition(
                                mDictRoot, flags, &pos);
                                mBuffer.data(), flags, &pos);
                    }
                    if (PatriciaTrieReadingUtils::hasShortcutTargets(flags)) {
                        mShortcutListPolicy.skipAllShortcuts(&pos);
                    }
                    if (PatriciaTrieReadingUtils::hasBigrams(flags)) {
                        if (!mBigramListPolicy.skipAllBigrams(&pos)) {
                            AKLOGE("Cannot skip bigrams. BufSize: %d, pos: %d.", mDictBufferSize,
                            AKLOGE("Cannot skip bigrams. BufSize: %zd, pos: %d.", mBuffer.size(),
                                    pos);
                            mIsCorrupted = true;
                            ASSERT(false);
@@ -244,14 +246,14 @@ int PatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount(
                // our pos is after the end of this PtNode, at the start of the next one.
                if (PatriciaTrieReadingUtils::hasChildrenInFlags(flags)) {
                    PatriciaTrieReadingUtils::readChildrenPositionAndAdvancePosition(
                            mDictRoot, flags, &pos);
                            mBuffer.data(), flags, &pos);
                }
                if (PatriciaTrieReadingUtils::hasShortcutTargets(flags)) {
                    mShortcutListPolicy.skipAllShortcuts(&pos);
                }
                if (PatriciaTrieReadingUtils::hasBigrams(flags)) {
                    if (!mBigramListPolicy.skipAllBigrams(&pos)) {
                        AKLOGE("Cannot skip bigrams. BufSize: %d, pos: %d.", mDictBufferSize, pos);
                        AKLOGE("Cannot skip bigrams. BufSize: %zd, pos: %d.", mBuffer.size(), pos);
                        mIsCorrupted = true;
                        ASSERT(false);
                        *outUnigramProbability = NOT_A_PROBABILITY;
@@ -402,7 +404,7 @@ int PatriciaTriePolicy::createAndGetLeavingChildNode(const DicNode *const dicNod
    int shortcutPos = NOT_A_DICT_POS;
    int bigramPos = NOT_A_DICT_POS;
    int siblingPos = NOT_A_DICT_POS;
    PatriciaTrieReadingUtils::readPtNodeInfo(mDictRoot, ptNodePos, &mShortcutListPolicy,
    PatriciaTrieReadingUtils::readPtNodeInfo(mBuffer.data(), ptNodePos, &mShortcutListPolicy,
            &mBigramListPolicy, &flags, &mergedNodeCodePointCount, mergedNodeCodePoints,
            &probability, &childrenPos, &shortcutPos, &bigramPos, &siblingPos);
    // Skip PtNodes that don't start with a Unicode code point because they represent non-word
    // information.
@@ -452,14 +454,16 @@ const WordProperty PatriciaTriePolicy::getWordProperty(
    int shortcutPos = getShortcutPositionOfPtNode(ptNodePos);
    if (shortcutPos != NOT_A_DICT_POS) {
        int shortcutTargetCodePoints[MAX_WORD_LENGTH];
        ShortcutListReadingUtils::getShortcutListSizeAndForwardPointer(mDictRoot, &shortcutPos);
        ShortcutListReadingUtils::getShortcutListSizeAndForwardPointer(mBuffer.data(),
                &shortcutPos);
        bool hasNext = true;
        while (hasNext) {
            const ShortcutListReadingUtils::ShortcutFlags shortcutFlags =
                    ShortcutListReadingUtils::getFlagsAndForwardPointer(mDictRoot, &shortcutPos);
                    ShortcutListReadingUtils::getFlagsAndForwardPointer(mBuffer.data(),
                            &shortcutPos);
            hasNext = ShortcutListReadingUtils::hasNext(shortcutFlags);
            const int shortcutTargetLength = ShortcutListReadingUtils::readShortcutTarget(
                    mDictRoot, MAX_WORD_LENGTH, shortcutTargetCodePoints, &shortcutPos);
                    mBuffer.data(), MAX_WORD_LENGTH, shortcutTargetCodePoints, &shortcutPos);
            const std::vector<int> shortcutTarget(shortcutTargetCodePoints,
                    shortcutTargetCodePoints + shortcutTargetLength);
            const int shortcutProbability =
@@ -512,4 +516,9 @@ int PatriciaTriePolicy::getWordIdFromTerminalPtNodePos(const int ptNodePos) cons
// Converts a word id into the position of its terminal PtNode.
// The id is handed back unchanged as the position; the only special case is NOT_A_WORD_ID,
// which is translated to NOT_A_DICT_POS.
int PatriciaTriePolicy::getTerminalPtNodePosFromWordId(const int wordId) const {
    if (wordId == NOT_A_WORD_ID) {
        return NOT_A_DICT_POS;
    }
    return wordId;
}

// Returns true iff pos is a non-negative offset that lies strictly before the end of mBuffer.
//
// The comparison is done in size_t: pos is widened only after it is known to be non-negative.
// Narrowing mBuffer.size() to int instead would yield a wrong (typically negative) bound for
// buffers larger than INT_MAX, making valid positions look invalid.
bool PatriciaTriePolicy::isValidPos(const int pos) const {
    return pos >= 0 && static_cast<size_t>(pos) < mBuffer.size();
}

} // namespace latinime
+8 −9
Original line number Diff line number Diff line
@@ -44,13 +44,12 @@ class PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
            : mMmappedBuffer(std::move(mmappedBuffer)),
              mHeaderPolicy(mMmappedBuffer->getReadOnlyByteArrayView().data(),
                      FormatUtils::VERSION_2),
              mDictRoot(mMmappedBuffer->getReadOnlyByteArrayView().data()
                      + mHeaderPolicy.getSize()),
              mDictBufferSize(mMmappedBuffer->getReadOnlyByteArrayView().size()
                      - mHeaderPolicy.getSize()),
              mBigramListPolicy(mDictRoot, mDictBufferSize), mShortcutListPolicy(mDictRoot),
              mPtNodeReader(mDictRoot, mDictBufferSize, &mBigramListPolicy, &mShortcutListPolicy),
              mPtNodeArrayReader(mDictRoot, mDictBufferSize),
              mBuffer(mMmappedBuffer->getReadOnlyByteArrayView().skip(mHeaderPolicy.getSize())),
              mBigramListPolicy(mBuffer.data(), mBuffer.size()),
              mShortcutListPolicy(mBuffer.data()),
              mPtNodeReader(mBuffer.data(), mBuffer.size(), &mBigramListPolicy,
                      &mShortcutListPolicy),
              mPtNodeArrayReader(mBuffer.data(), mBuffer.size()),
              mTerminalPtNodePositionsForIteratingWords(), mIsCorrupted(false) {}

    AK_FORCE_INLINE int getRootPosition() const {
@@ -149,8 +148,7 @@ class PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {

    const MmappedBuffer::MmappedBufferPtr mMmappedBuffer;
    const HeaderPolicy mHeaderPolicy;
    const uint8_t *const mDictRoot;
    const int mDictBufferSize;
    const ReadOnlyByteArrayView mBuffer;
    const BigramListPolicy mBigramListPolicy;
    const ShortcutListPolicy mShortcutListPolicy;
    const Ver2ParticiaTrieNodeReader mPtNodeReader;
@@ -166,6 +164,7 @@ class PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
    int getTerminalPtNodePosFromWordId(const int wordId) const;
    const WordAttributes getWordAttributes(const int probability,
            const PtNodeParams &ptNodeParams) const;
    bool isValidPos(const int pos) const;
};
} // namespace latinime
#endif // LATINIME_PATRICIA_TRIE_POLICY_H
+7 −0
Original line number Diff line number Diff line
@@ -42,6 +42,13 @@ class ReadOnlyByteArrayView {
        return mPtr;
    }

    // Returns a view over the bytes that remain after dropping the first n bytes of this view.
    // When n is greater than or equal to the view's size, a default-constructed view is
    // returned instead.
    AK_FORCE_INLINE const ReadOnlyByteArrayView skip(const size_t n) const {
        return (n < mSize) ? ReadOnlyByteArrayView(mPtr + n, mSize - n)
                : ReadOnlyByteArrayView();
    }

 private:
    DISALLOW_ASSIGNMENT_OPERATOR(ReadOnlyByteArrayView);