Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit baf915e4, authored by Keisuke Kuroyanagi and committed by Android Git Automerger
Browse files

am 6d181179: Merge "Add methods for unigrams to LanguageModelDictContent." into lmp-dev

* commit '6d181179':
  Add methods for unigrams to LanguageModelDictContent.
parents b8c1075a 6d181179
Loading
Loading
Loading
Loading
+2 −0
Original line number Diff line number Diff line
@@ -126,6 +126,8 @@ LATIN_IME_CORE_TEST_FILES := \
    defines_test.cpp \
    suggest/core/layout/normal_distribution_2d_test.cpp \
    suggest/core/dictionary/bloom_filter_test.cpp \
    suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_test.cpp \
    suggest/policyimpl/dictionary/structure/v4/content/probability_entry_test.cpp \
    suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer_test.cpp \
    suggest/policyimpl/dictionary/utils/trie_map_test.cpp \
    utils/autocorrection_threshold_utils_test.cpp \
+59 −0
Original line number Diff line number Diff line
@@ -22,4 +22,63 @@ bool LanguageModelDictContent::save(FILE *const file) const {
    return mTrieMap.save(file);
}

// Copies the entries of originalContent into this content, starting from the root level of the
// original trie map. Entries whose key has no valid mapping in terminalIdMap are dropped (see
// runGCInner). When outNgramCount is non-null it is incremented once per copied entry; it is not
// reset here. Returns false as soon as an entry cannot be written.
bool LanguageModelDictContent::runGC(
        const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap,
        const LanguageModelDictContent *const originalContent,
        int *const outNgramCount) {
    // Root-level entries are written with 0 as the destination bitmap entry index.
    const int rootLevelBitmapEntryIndex = 0;
    const bool succeeded = runGCInner(terminalIdMap,
            originalContent->mTrieMap.getEntriesInRootLevel(), rootLevelBitmapEntryIndex,
            outNgramCount);
    return succeeded;
}

// Looks up the probability entry stored for wordId.
// Only unigrams are supported: when prevWordIds is not empty, or when the trie map has no valid
// root entry for wordId, a default-constructed ProbabilityEntry is returned.
ProbabilityEntry LanguageModelDictContent::getProbabilityEntry(
        const WordIdArrayView prevWordIds, const int wordId) const {
    if (prevWordIds.empty()) {
        const TrieMap::Result lookup = mTrieMap.getRoot(wordId);
        if (lookup.mIsValid) {
            // The stored value is the encoded form of the entry.
            return ProbabilityEntry::decode(lookup.mValue, mHasHistoricalInfo);
        }
        // Not found.
        return ProbabilityEntry();
    }
    // TODO: Read n-gram entry.
    return ProbabilityEntry();
}

// Stores the probability entry for wordId at the root level of the trie map.
// The entry is encoded with or without historical info according to mHasHistoricalInfo.
// Only unigrams are supported: a non-empty prevWordIds is rejected.
// Returns false when probabilityEntry is null, when prevWordIds is not empty, or when the trie
// map reports that the value could not be written.
bool LanguageModelDictContent::setProbabilityEntry(const WordIdArrayView prevWordIds,
        const int wordId, const ProbabilityEntry *const probabilityEntry) {
    if (!probabilityEntry) {
        // Nothing to encode; report failure rather than dereferencing a null pointer.
        return false;
    }
    if (!prevWordIds.empty()) {
        // TODO: Add n-gram entry.
        return false;
    }
    return mTrieMap.putRoot(wordId, probabilityEntry->encode(mHasHistoricalInfo));
}


// Copies every entry of trieMapRange into this content's trie map at the level identified by
// nextLevelBitmapEntryIndex, then recurses into the entry's next level when it has one.
// Each entry key is translated through terminalIdMap; entries without a mapping, or mapped to
// Ver4DictConstants::NOT_A_TERMINAL_ID, are skipped together with everything below them.
// When outNgramCount is non-null it is incremented once per copied entry.
// Returns false as soon as a write or a nested copy fails.
bool LanguageModelDictContent::runGCInner(
        const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap,
        const TrieMap::TrieMapRange trieMapRange,
        const int nextLevelBitmapEntryIndex, int *const outNgramCount) {
    for (auto &entry : trieMapRange) {
        const auto mapping = terminalIdMap->find(entry.key());
        const bool isRemoved = (mapping == terminalIdMap->end())
                || (mapping->second == Ver4DictConstants::NOT_A_TERMINAL_ID);
        if (isRemoved) {
            // The word has been removed.
            continue;
        }
        const auto &mappedTerminalId = mapping->second;
        const bool written =
                mTrieMap.put(mappedTerminalId, entry.value(), nextLevelBitmapEntryIndex);
        if (!written) {
            return false;
        }
        if (outNgramCount) {
            ++(*outNgramCount);
        }
        if (!entry.hasNextLevelMap()) {
            continue;
        }
        // Destination level for the children of the entry that was just written.
        const int childBitmapEntryIndex = mTrieMap.getNextLevelBitmapEntryIndex(
                mappedTerminalId, nextLevelBitmapEntryIndex);
        const bool childrenCopied = runGCInner(terminalIdMap, entry.getEntriesInNextLevel(),
                childBitmapEntryIndex, outNgramCount);
        if (!childrenCopied) {
            return false;
        }
    }
    return true;
}

} // namespace latinime
+30 −2
Original line number Diff line number Diff line
@@ -20,25 +20,53 @@
#include <cstdio>

#include "defines.h"
#include "suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h"
#include "suggest/policyimpl/dictionary/structure/v4/content/terminal_position_lookup_table.h"
#include "suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h"
#include "suggest/policyimpl/dictionary/utils/trie_map.h"
#include "utils/byte_array_view.h"
#include "utils/int_array_view.h"

namespace latinime {

/**
 * Class representing language model.
 *
 * This class provides methods to get and store unigram/n-gram probability information and flags.
 */
class LanguageModelDictContent {
 public:
    // Creates content whose trie map is constructed from trieMapBuffer.
    // hasHistoricalInfo selects the layout used to encode and decode probability entries.
    LanguageModelDictContent(const ReadWriteByteArrayView trieMapBuffer,
            const bool hasHistoricalInfo)
            : mTrieMap(trieMapBuffer), mHasHistoricalInfo(hasHistoricalInfo) {}

    // Creates content with a default-constructed trie map.
    // hasHistoricalInfo selects the layout used to encode and decode probability entries.
    explicit LanguageModelDictContent(const bool hasHistoricalInfo)
            : mTrieMap(), mHasHistoricalInfo(hasHistoricalInfo) {}

    // Forwards to the trie map's own size-limit check.
    bool isNearSizeLimit() const {
        return mTrieMap.isNearSizeLimit();
    }

    // Writes the trie map to file. Returns the trie map's result.
    bool save(FILE *const file) const;

    // Copies the entries of originalContent into this content, translating each entry key
    // through terminalIdMap and dropping entries that have no valid mapping.
    // When outNgramCount is non-null it is incremented once per copied entry.
    // Returns false if an entry could not be written.
    bool runGC(const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap,
            const LanguageModelDictContent *const originalContent,
            int *const outNgramCount);

    // Returns the entry stored for wordId. Only unigrams are handled so far: a non-empty
    // prevWordIds, or an unknown wordId, yields a default-constructed ProbabilityEntry.
    ProbabilityEntry getProbabilityEntry(const WordIdArrayView prevWordIds, const int wordId) const;

    // Stores probabilityEntry for wordId. Only unigrams are handled so far: a non-empty
    // prevWordIds makes the call return false.
    bool setProbabilityEntry(const WordIdArrayView prevWordIds, const int wordId,
            const ProbabilityEntry *const probabilityEntry);

 private:
    DISALLOW_COPY_AND_ASSIGN(LanguageModelDictContent);

    // Storage for the encoded probability entries.
    TrieMap mTrieMap;
    // Passed to ProbabilityEntry::encode()/decode() to select the entry layout.
    const bool mHasHistoricalInfo;

    // Recursive worker of runGC(): copies the entries of trieMapRange into the level of
    // mTrieMap identified by nextLevelBitmapEntryIndex, then descends into their next levels.
    bool runGCInner(const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap,
            const TrieMap::TrieMapRange trieMapRange, const int nextLevelBitmapEntryIndex,
            int *const outNgramCount);
};
} // namespace latinime
#endif /* LATINIME_LANGUAGE_MODEL_DICT_CONTENT_H */
+52 −0
Original line number Diff line number Diff line
@@ -17,6 +17,9 @@
#ifndef LATINIME_PROBABILITY_ENTRY_H
#define LATINIME_PROBABILITY_ENTRY_H

#include <climits>
#include <cstdint>

#include "defines.h"
#include "suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h"
#include "suggest/policyimpl/dictionary/utils/historical_info.h"
@@ -67,6 +70,50 @@ class ProbabilityEntry {
        return &mHistoricalInfo;
    }

    // Packs this entry into a single integer.
    // Layout from the most significant field to the least significant one:
    //   with historical info:    flags | timestamp | level | count
    //   without historical info: flags | probability
    // Field widths in bytes come from Ver4DictConstants. Every value is truncated to the width
    // of its field, so a negative or oversized value cannot alter the fields stored above it.
    uint64_t encode(const bool hasHistoricalInfo) const {
        // Shifts |encoded| left by |fieldSize| bytes and stores |value| in the freed low bytes.
        // Masking is required because converting a negative int to uint64_t sets all upper
        // bits, which would otherwise overwrite the previously written fields.
        const auto appendField = [](const uint64_t encoded, const uint64_t value,
                const int fieldSize) -> uint64_t {
            const int fieldBits = fieldSize * CHAR_BIT;
            const uint64_t fieldMask = (1ull << fieldBits) - 1;
            return (encoded << fieldBits) | (value & fieldMask);
        };
        uint64_t encodedEntry = appendField(0 /* encoded */, static_cast<uint64_t>(mFlags),
                Ver4DictConstants::FLAGS_IN_PROBABILITY_FILE_SIZE);
        if (hasHistoricalInfo) {
            encodedEntry = appendField(encodedEntry,
                    static_cast<uint64_t>(mHistoricalInfo.getTimeStamp()),
                    Ver4DictConstants::TIME_STAMP_FIELD_SIZE);
            encodedEntry = appendField(encodedEntry,
                    static_cast<uint64_t>(mHistoricalInfo.getLevel()),
                    Ver4DictConstants::WORD_LEVEL_FIELD_SIZE);
            encodedEntry = appendField(encodedEntry,
                    static_cast<uint64_t>(mHistoricalInfo.getCount()),
                    Ver4DictConstants::WORD_COUNT_FIELD_SIZE);
        } else {
            encodedEntry = appendField(encodedEntry, static_cast<uint64_t>(mProbability),
                    Ver4DictConstants::PROBABILITY_SIZE);
        }
        return encodedEntry;
    }

    // Rebuilds a ProbabilityEntry from a value produced by encode().
    // hasHistoricalInfo must match the value that was used for encoding because it selects the
    // field layout. Entries decoded with historical info get NOT_A_PROBABILITY as probability.
    static ProbabilityEntry decode(const uint64_t encodedEntry, const bool hasHistoricalInfo) {
        if (!hasHistoricalInfo) {
            // Layout: flags | probability. Positions are byte offsets from the low end.
            const int flagsPos = Ver4DictConstants::PROBABILITY_SIZE;
            const int decodedFlags = readFromEncodedEntry(encodedEntry,
                    Ver4DictConstants::FLAGS_IN_PROBABILITY_FILE_SIZE, flagsPos);
            const int decodedProbability = readFromEncodedEntry(encodedEntry,
                    Ver4DictConstants::PROBABILITY_SIZE, 0 /* pos */);
            return ProbabilityEntry(decodedFlags, decodedProbability);
        }
        // Layout: flags | timestamp | level | count. Positions are byte offsets from the low
        // end, so each field starts after all the fields stored below it.
        const int levelPos = Ver4DictConstants::WORD_COUNT_FIELD_SIZE;
        const int timestampPos = Ver4DictConstants::WORD_LEVEL_FIELD_SIZE
                + Ver4DictConstants::WORD_COUNT_FIELD_SIZE;
        const int flagsPos = Ver4DictConstants::TIME_STAMP_FIELD_SIZE
                + Ver4DictConstants::WORD_LEVEL_FIELD_SIZE
                + Ver4DictConstants::WORD_COUNT_FIELD_SIZE;
        const int decodedFlags = readFromEncodedEntry(encodedEntry,
                Ver4DictConstants::FLAGS_IN_PROBABILITY_FILE_SIZE, flagsPos);
        const int decodedTimestamp = readFromEncodedEntry(encodedEntry,
                Ver4DictConstants::TIME_STAMP_FIELD_SIZE, timestampPos);
        const int decodedLevel = readFromEncodedEntry(encodedEntry,
                Ver4DictConstants::WORD_LEVEL_FIELD_SIZE, levelPos);
        const int decodedCount = readFromEncodedEntry(encodedEntry,
                Ver4DictConstants::WORD_COUNT_FIELD_SIZE, 0 /* pos */);
        const HistoricalInfo historicalInfo(decodedTimestamp, decodedLevel, decodedCount);
        return ProbabilityEntry(decodedFlags, NOT_A_PROBABILITY, &historicalInfo);
    }

 private:
    // Copy constructor is public to use this class as a type of return value.
    DISALLOW_ASSIGNMENT_OPERATOR(ProbabilityEntry);
@@ -74,6 +121,11 @@ class ProbabilityEntry {
    // Flag bits of this entry; encode() stores them in the most significant field.
    const int mFlags;
    // Probability value; encode() writes it only when there is no historical info.
    const int mProbability;
    // Source of the timestamp, level and count that encode() writes when there is historical
    // info.
    const HistoricalInfo mHistoricalInfo;

    // Extracts one field from an encoded entry.
    // |size| is the field width in bytes and |pos| is the byte offset of the field counted from
    // the least significant byte. The field is returned as a non-sign-extended value cast to int.
    static int readFromEncodedEntry(const uint64_t encodedEntry, const int size, const int pos) {
        const uint64_t fieldMask = (1ull << (size * CHAR_BIT)) - 1;
        const uint64_t alignedEntry = encodedEntry >> (pos * CHAR_BIT);
        return static_cast<int>(alignedEntry & fieldMask);
    }
};
} // namespace latinime
#endif /* LATINIME_PROBABILITY_ENTRY_H */
+8 −0
Original line number Diff line number Diff line
@@ -90,6 +90,14 @@ class Ver4DictBuffers {
        return &mProbabilityDictContent;
    }

    // Returns a pointer to the mLanguageModelDictContent member through which it can be
    // modified.
    AK_FORCE_INLINE LanguageModelDictContent *getMutableLanguageModelDictContent() {
        return &mLanguageModelDictContent;
    }

    // Returns a read-only pointer to the mLanguageModelDictContent member.
    AK_FORCE_INLINE const LanguageModelDictContent *getLanguageModelDictContent() const {
        return &mLanguageModelDictContent;
    }

    // Returns a pointer to the mBigramDictContent member through which it can be modified.
    AK_FORCE_INLINE BigramDictContent *getMutableBigramDictContent() {
        return &mBigramDictContent;
    }
Loading