Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 6d181179 authored by Keisuke Kuroyanagi's avatar Keisuke Kuroyanagi Committed by Android (Google) Code Review
Browse files

Merge "Add methods for unigrams to LanguageModelDictContent." into lmp-dev

parents 1b9f7c96 08894842
Loading
Loading
Loading
Loading
+2 −0
Original line number Diff line number Diff line
@@ -126,6 +126,8 @@ LATIN_IME_CORE_TEST_FILES := \
    defines_test.cpp \
    suggest/core/layout/normal_distribution_2d_test.cpp \
    suggest/core/dictionary/bloom_filter_test.cpp \
    suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_test.cpp \
    suggest/policyimpl/dictionary/structure/v4/content/probability_entry_test.cpp \
    suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer_test.cpp \
    suggest/policyimpl/dictionary/utils/trie_map_test.cpp \
    utils/autocorrection_threshold_utils_test.cpp \
+59 −0
Original line number Diff line number Diff line
@@ -22,4 +22,63 @@ bool LanguageModelDictContent::save(FILE *const file) const {
    return mTrieMap.save(file);
}

bool LanguageModelDictContent::runGC(
        const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap,
        const LanguageModelDictContent *const originalContent,
        int *const outNgramCount) {
    return runGCInner(terminalIdMap, originalContent->mTrieMap.getEntriesInRootLevel(),
            0 /* nextLevelBitmapEntryIndex */, outNgramCount);
}

ProbabilityEntry LanguageModelDictContent::getProbabilityEntry(
        const WordIdArrayView prevWordIds, const int wordId) const {
    if (!prevWordIds.empty()) {
        // TODO: Read n-gram entry.
        return ProbabilityEntry();
    }
    const TrieMap::Result result = mTrieMap.getRoot(wordId);
    if (!result.mIsValid) {
        // Not found.
        return ProbabilityEntry();
    }
    return ProbabilityEntry::decode(result.mValue, mHasHistoricalInfo);
}

bool LanguageModelDictContent::setProbabilityEntry(const WordIdArrayView prevWordIds,
        const int terminalId, const ProbabilityEntry *const probabilityEntry) {
    if (!prevWordIds.empty()) {
        // TODO: Add n-gram entry.
        return false;
    }
    return mTrieMap.putRoot(terminalId, probabilityEntry->encode(mHasHistoricalInfo));
}


bool LanguageModelDictContent::runGCInner(
        const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap,
        const TrieMap::TrieMapRange trieMapRange,
        const int nextLevelBitmapEntryIndex, int *const outNgramCount) {
    for (auto &entry : trieMapRange) {
        const auto it = terminalIdMap->find(entry.key());
        if (it == terminalIdMap->end() || it->second == Ver4DictConstants::NOT_A_TERMINAL_ID) {
            // The word has been removed.
            continue;
        }
        if (!mTrieMap.put(it->second, entry.value(), nextLevelBitmapEntryIndex)) {
            return false;
        }
        if (outNgramCount) {
            *outNgramCount += 1;
        }
        if (entry.hasNextLevelMap()) {
            if (!runGCInner(terminalIdMap, entry.getEntriesInNextLevel(),
                    mTrieMap.getNextLevelBitmapEntryIndex(it->second, nextLevelBitmapEntryIndex),
                    outNgramCount)) {
                return false;
            }
        }
    }
    return true;
}

} // namespace latinime
+30 −2
Original line number Diff line number Diff line
@@ -20,25 +20,53 @@
#include <cstdio>

#include "defines.h"
#include "suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h"
#include "suggest/policyimpl/dictionary/structure/v4/content/terminal_position_lookup_table.h"
#include "suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h"
#include "suggest/policyimpl/dictionary/utils/trie_map.h"
#include "utils/byte_array_view.h"
#include "utils/int_array_view.h"

namespace latinime {

/**
 * Class representing language model.
 *
 * This class provides methods to get and store unigram/n-gram probability information and flags.
 */
class LanguageModelDictContent {
 public:
    LanguageModelDictContent(const ReadWriteByteArrayView trieMapBuffer,
            const bool hasHistoricalInfo)
            : mTrieMap(trieMapBuffer) {}
            : mTrieMap(trieMapBuffer), mHasHistoricalInfo(hasHistoricalInfo) {}

    explicit LanguageModelDictContent(const bool hasHistoricalInfo) : mTrieMap() {}
    explicit LanguageModelDictContent(const bool hasHistoricalInfo)
            : mTrieMap(), mHasHistoricalInfo(hasHistoricalInfo) {}

    bool isNearSizeLimit() const {
        return mTrieMap.isNearSizeLimit();
    }

    bool save(FILE *const file) const;

    bool runGC(const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap,
            const LanguageModelDictContent *const originalContent,
            int *const outNgramCount);

    ProbabilityEntry getProbabilityEntry(const WordIdArrayView prevWordIds, const int wordId) const;

    bool setProbabilityEntry(const WordIdArrayView prevWordIds, const int wordId,
            const ProbabilityEntry *const probabilityEntry);

 private:
    DISALLOW_COPY_AND_ASSIGN(LanguageModelDictContent);

    TrieMap mTrieMap;
    const bool mHasHistoricalInfo;

    bool runGCInner(const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap,
            const TrieMap::TrieMapRange trieMapRange, const int nextLevelBitmapEntryIndex,
            int *const outNgramCount);
};
} // namespace latinime
#endif /* LATINIME_LANGUAGE_MODEL_DICT_CONTENT_H */
+52 −0
Original line number Diff line number Diff line
@@ -17,6 +17,9 @@
#ifndef LATINIME_PROBABILITY_ENTRY_H
#define LATINIME_PROBABILITY_ENTRY_H

#include <climits>
#include <cstdint>

#include "defines.h"
#include "suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h"
#include "suggest/policyimpl/dictionary/utils/historical_info.h"
@@ -67,6 +70,50 @@ class ProbabilityEntry {
        return &mHistoricalInfo;
    }

    uint64_t encode(const bool hasHistoricalInfo) const {
        uint64_t encodedEntry = static_cast<uint64_t>(mFlags);
        if (hasHistoricalInfo) {
            encodedEntry = (encodedEntry << (Ver4DictConstants::TIME_STAMP_FIELD_SIZE * CHAR_BIT))
                    ^ static_cast<uint64_t>(mHistoricalInfo.getTimeStamp());
            encodedEntry = (encodedEntry << (Ver4DictConstants::WORD_LEVEL_FIELD_SIZE * CHAR_BIT))
                    ^ static_cast<uint64_t>(mHistoricalInfo.getLevel());
            encodedEntry = (encodedEntry << (Ver4DictConstants::WORD_COUNT_FIELD_SIZE * CHAR_BIT))
                    ^ static_cast<uint64_t>(mHistoricalInfo.getCount());
        } else {
            encodedEntry = (encodedEntry << (Ver4DictConstants::PROBABILITY_SIZE * CHAR_BIT))
                    ^ static_cast<uint64_t>(mProbability);
        }
        return encodedEntry;
    }

    static ProbabilityEntry decode(const uint64_t encodedEntry, const bool hasHistoricalInfo) {
        if (hasHistoricalInfo) {
            const int flags = readFromEncodedEntry(encodedEntry,
                    Ver4DictConstants::FLAGS_IN_PROBABILITY_FILE_SIZE,
                    Ver4DictConstants::TIME_STAMP_FIELD_SIZE
                            + Ver4DictConstants::WORD_LEVEL_FIELD_SIZE
                            + Ver4DictConstants::WORD_COUNT_FIELD_SIZE);
            const int timestamp = readFromEncodedEntry(encodedEntry,
                    Ver4DictConstants::TIME_STAMP_FIELD_SIZE,
                    Ver4DictConstants::WORD_LEVEL_FIELD_SIZE
                            + Ver4DictConstants::WORD_COUNT_FIELD_SIZE);
            const int level = readFromEncodedEntry(encodedEntry,
                    Ver4DictConstants::WORD_LEVEL_FIELD_SIZE,
                    Ver4DictConstants::WORD_COUNT_FIELD_SIZE);
            const int count = readFromEncodedEntry(encodedEntry,
                    Ver4DictConstants::WORD_COUNT_FIELD_SIZE, 0 /* pos */);
            const HistoricalInfo historicalInfo(timestamp, level, count);
            return ProbabilityEntry(flags, NOT_A_PROBABILITY, &historicalInfo);
        } else {
            const int flags = readFromEncodedEntry(encodedEntry,
                    Ver4DictConstants::FLAGS_IN_PROBABILITY_FILE_SIZE,
                    Ver4DictConstants::PROBABILITY_SIZE);
            const int probability = readFromEncodedEntry(encodedEntry,
                    Ver4DictConstants::PROBABILITY_SIZE, 0 /* pos */);
            return ProbabilityEntry(flags, probability);
        }
    }

 private:
    // Copy constructor is public to use this class as a type of return value.
    DISALLOW_ASSIGNMENT_OPERATOR(ProbabilityEntry);
@@ -74,6 +121,11 @@ class ProbabilityEntry {
    const int mFlags;
    const int mProbability;
    const HistoricalInfo mHistoricalInfo;

    static int readFromEncodedEntry(const uint64_t encodedEntry, const int size, const int pos) {
        return static_cast<int>(
                (encodedEntry >> (pos * CHAR_BIT)) & ((1ull << (size * CHAR_BIT)) - 1));
    }
};
} // namespace latinime
#endif /* LATINIME_PROBABILITY_ENTRY_H */
+8 −0
Original line number Diff line number Diff line
@@ -90,6 +90,14 @@ class Ver4DictBuffers {
        return &mProbabilityDictContent;
    }

    AK_FORCE_INLINE LanguageModelDictContent *getMutableLanguageModelDictContent() {
        return &mLanguageModelDictContent;
    }

    AK_FORCE_INLINE const LanguageModelDictContent *getLanguageModelDictContent() const {
        return &mLanguageModelDictContent;
    }

    AK_FORCE_INLINE BigramDictContent *getMutableBigramDictContent() {
        return &mBigramDictContent;
    }
Loading