Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 82f7d3a9, authored by Keisuke Kuroyanagi; committed by Android (Google) Code Review
Browse files

Merge "Add a method to iterate entries in LanguageModelDictContent."

parents 00042cb4 07b3b41c
Loading
Loading
Loading
Loading
+6 −0
Original line number Diff line number Diff line
@@ -71,6 +71,12 @@ bool LanguageModelDictContent::removeNgramProbabilityEntry(const WordIdArrayView
    return mTrieMap.remove(wordId, bitmapEntryIndex);
}

// Returns an iterable range over the probability entries stored in the trie map level
// selected by prevWordIds. The level is resolved through getBitmapEntryIndex(), and each
// raw trie map value is decoded lazily by the returned range's iterators using
// mHasHistoricalInfo.
LanguageModelDictContent::EntryRange LanguageModelDictContent::getProbabilityEntries(
        const WordIdArrayView prevWordIds) const {
    return EntryRange(
            mTrieMap.getEntriesInSpecifiedLevel(getBitmapEntryIndex(prevWordIds)),
            mHasHistoricalInfo);
}

bool LanguageModelDictContent::truncateEntries(const int *const entryCounts,
        const int *const maxEntryCounts, const HeaderPolicy *const headerPolicy) {
    for (int i = 0; i <= MAX_PREV_WORD_COUNT_FOR_N_GRAM; ++i) {
+71 −0
Original line number Diff line number Diff line
@@ -39,6 +39,75 @@ class HeaderPolicy;
 */
class LanguageModelDictContent {
 public:
    // Immutable (word id, probability entry) pair handed out while iterating over entries.
    // Both members are const, so an instance cannot be reassigned after construction.
    class WordIdAndProbabilityEntry {
     public:
        // Stores the word id and a copy of the given probability entry.
        WordIdAndProbabilityEntry(const int wordId, const ProbabilityEntry &probabilityEntry)
                : mWordId(wordId), mProbabilityEntry(probabilityEntry) {}

        // Id of the word this entry belongs to.
        int getWordId() const {
            return mWordId;
        }

        // Copy of the stored probability entry.
        const ProbabilityEntry getProbabilityEntry() const {
            return mProbabilityEntry;
        }

     private:
        DISALLOW_DEFAULT_CONSTRUCTOR(WordIdAndProbabilityEntry);
        DISALLOW_ASSIGNMENT_OPERATOR(WordIdAndProbabilityEntry);

        const int mWordId;
        const ProbabilityEntry mProbabilityEntry;
    };

    // Iterator over probability entries. Wraps a TrieMap::TrieMapIterator (held by value) and
    // decodes each raw trie map value into a ProbabilityEntry on dereference. It provides
    // exactly the operations a range-based for loop needs: operator*, operator!= and prefix
    // operator++.
    class EntryIterator {
     public:
        // Copies trieMapIterator; hasHistoricalInfo is forwarded to ProbabilityEntry::decode().
        EntryIterator(const TrieMap::TrieMapIterator &trieMapIterator,
                const bool hasHistoricalInfo)
                : mTrieMapIterator(trieMapIterator), mHasHistoricalInfo(hasHistoricalInfo) {}

        // Builds the (word id, probability entry) pair for the current position. The pair is
        // returned by value because the probability entry is decoded on the fly.
        const WordIdAndProbabilityEntry operator*() const {
            const TrieMap::TrieMapIterator::IterationResult &result = *mTrieMapIterator;
            return WordIdAndProbabilityEntry(
                    result.key(), ProbabilityEntry::decode(result.value(), mHasHistoricalInfo));
        }

        // Two iterators differ when their underlying trie map iterators differ;
        // mHasHistoricalInfo does not take part in the comparison.
        bool operator!=(const EntryIterator &other) const {
            return mTrieMapIterator != other.mTrieMapIterator;
        }

        // Prefix increment: advances the underlying trie map iterator. Returns a non-const
        // reference, as is conventional for prefix ++, so that the result can be incremented
        // or otherwise used as a regular iterator lvalue.
        EntryIterator &operator++() {
            ++mTrieMapIterator;
            return *this;
        }

     private:
        DISALLOW_DEFAULT_CONSTRUCTOR(EntryIterator);
        DISALLOW_ASSIGNMENT_OPERATOR(EntryIterator);

        TrieMap::TrieMapIterator mTrieMapIterator;
        const bool mHasHistoricalInfo;
    };

    // Range of probability entries, usable directly in a range-based for loop.
    // Holds a copy of a TrieMap::TrieMapRange and hands out EntryIterators that decode the
    // raw values according to mHasHistoricalInfo.
    class EntryRange {
     public:
        EntryRange(const TrieMap::TrieMapRange trieMapRange, const bool hasHistoricalInfo)
                : mTrieMapRange(trieMapRange), mHasHistoricalInfo(hasHistoricalInfo) {}

        // Iterator positioned at the start of the underlying trie map range.
        EntryIterator begin() const {
            const TrieMap::TrieMapIterator &first = mTrieMapRange.begin();
            return EntryIterator(first, mHasHistoricalInfo);
        }

        // Iterator positioned at the end of the underlying trie map range.
        EntryIterator end() const {
            const TrieMap::TrieMapIterator &last = mTrieMapRange.end();
            return EntryIterator(last, mHasHistoricalInfo);
        }

     private:
        DISALLOW_DEFAULT_CONSTRUCTOR(EntryRange);
        DISALLOW_ASSIGNMENT_OPERATOR(EntryRange);

        const TrieMap::TrieMapRange mTrieMapRange;
        const bool mHasHistoricalInfo;
    };

    // Constructs the content on top of trieMapBuffer, which is handed to the TrieMap member.
    // hasHistoricalInfo is remembered and later used when decoding probability entries.
    LanguageModelDictContent(
            const ReadWriteByteArrayView trieMapBuffer, const bool hasHistoricalInfo)
            : mTrieMap(trieMapBuffer),
              mHasHistoricalInfo(hasHistoricalInfo) {}
@@ -76,6 +145,8 @@ class LanguageModelDictContent {

    bool removeNgramProbabilityEntry(const WordIdArrayView prevWordIds, const int wordId);

    // Returns a range over the probability entries in the trie map level selected by
    // prevWordIds. The result is intended for range-based for loops; each element is a
    // WordIdAndProbabilityEntry.
    EntryRange getProbabilityEntries(const WordIdArrayView prevWordIds) const;

    bool updateAllProbabilityEntries(const HeaderPolicy *const headerPolicy,
            int *const outEntryCounts) {
        for (int i = 0; i <= MAX_PREV_WORD_COUNT_FOR_N_GRAM; ++i) {
+1 −1
Original line number Diff line number Diff line
@@ -98,7 +98,7 @@ class TrieMap {
        TrieMapIterator(const TrieMap *const trieMap, const int bitmapEntryIndex)
                : mTrieMap(trieMap), mStateStack(), mBaseBitmapEntryIndex(bitmapEntryIndex),
                  mKey(0), mValue(0), mIsValid(false), mNextLevelBitmapEntryIndex(INVALID_INDEX) {
            if (!trieMap) {
            if (!trieMap || mBaseBitmapEntryIndex == INVALID_INDEX) {
                return;
            }
            const Entry bitmapEntry = mTrieMap->readEntry(mBaseBitmapEntryIndex);
+20 −0
Original line number Diff line number Diff line
@@ -18,6 +18,8 @@

#include <gtest/gtest.h>

#include <unordered_set>

#include "utils/int_array_view.h"

namespace latinime {
@@ -69,5 +71,23 @@ TEST(LanguageModelDictContentTest, TestUnigramProbabilityWithHistoricalInfo) {
    EXPECT_TRUE(LanguageModelDictContent.removeProbabilityEntry(wordId));
}

// Verifies that iterating getProbabilityEntries() with an empty previous-word id view visits
// every stored word id exactly once and yields the flags and probability that were stored.
TEST(LanguageModelDictContentTest, TestIterateProbabilityEntry) {
    LanguageModelDictContent languageModelDictContent(false /* useHistoricalInfo */);

    const ProbabilityEntry originalEntry(0xFC, 100);

    const int wordIds[] = { 1, 2, 3, 4, 5 };
    for (const int wordId : wordIds) {
        languageModelDictContent.setProbabilityEntry(wordId, &originalEntry);
    }
    // Ids that still have to be seen; each visited id is removed from this set.
    std::unordered_set<int> wordIdSet(std::begin(wordIds), std::end(wordIds));
    for (const auto &entry : languageModelDictContent.getProbabilityEntries(WordIdArrayView())) {
        EXPECT_EQ(originalEntry.getFlags(), entry.getProbabilityEntry().getFlags());
        EXPECT_EQ(originalEntry.getProbability(), entry.getProbabilityEntry().getProbability());
        // erase() returns the number of removed elements. Anything other than 1 means the
        // iterator produced an id that was never stored or produced the same id twice.
        EXPECT_EQ(1u, wordIdSet.erase(entry.getWordId()));
    }
    // Every stored id must have been visited.
    EXPECT_TRUE(wordIdSet.empty());
}

}  // namespace
}  // namespace latinime