Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit c990ed5e authored by Jean Chalard's avatar Jean Chalard Committed by Android Git Automerger
Browse files

am 502c041e: am 5064ac88: Merge "Be careful about the dictionary size in detection methods"

* commit '502c041e':
  Be careful about the dictionary size in detection methods
parents e3ca68aa 502c041e
Loading
Loading
Loading
Loading
+2 −1
Original line number Diff line number Diff line
@@ -109,7 +109,8 @@ static jlong latinime_BinaryDictionary_open(JNIEnv *env, jclass clazz, jstring s
    }
    Dictionary *dictionary = 0;
    if (BinaryFormat::UNKNOWN_FORMAT
            == BinaryFormat::detectFormat(static_cast<uint8_t *>(dictBuf))) {
            == BinaryFormat::detectFormat(static_cast<uint8_t *>(dictBuf),
                    static_cast<int>(dictSize))) {
        AKLOGE("DICT: dictionary format is unknown, bad magic number");
#ifdef USE_MMAP_FOR_DICTIONARY
        releaseDictBuf(static_cast<const char *>(dictBuf) - adjust, adjDictSize, fd);
+34 −19
Original line number Diff line number Diff line
@@ -64,13 +64,14 @@ class BinaryFormat {
    static const int UNKNOWN_FORMAT = -1;
    static const int SHORTCUT_LIST_SIZE_SIZE = 2;

    static int detectFormat(const uint8_t *const dict);
    static int getHeaderSize(const uint8_t *const dict);
    static int getFlags(const uint8_t *const dict);
    static int detectFormat(const uint8_t *const dict, const int dictSize);
    static int getHeaderSize(const uint8_t *const dict, const int dictSize);
    static int getFlags(const uint8_t *const dict, const int dictSize);
    static bool hasBlacklistedOrNotAWordFlag(const int flags);
    static void readHeaderValue(const uint8_t *const dict, const char *const key, int *outValue,
            const int outValueSize);
    static int readHeaderValueInt(const uint8_t *const dict, const char *const key);
    static void readHeaderValue(const uint8_t *const dict, const int dictSize,
            const char *const key, int *outValue, const int outValueSize);
    static int readHeaderValueInt(const uint8_t *const dict, const int dictSize,
            const char *const key);
    static int getGroupCountAndForwardPointer(const uint8_t *const dict, int *pos);
    static uint8_t getFlagsAndForwardPointer(const uint8_t *const dict, int *pos);
    static int getCodePointAndForwardPointer(const uint8_t *const dict, int *pos);
@@ -96,7 +97,7 @@ class BinaryFormat {
            const uint8_t *bigramFilter, const int unigramProbability);
    static int getBigramProbabilityFromHashMap(const int position,
            const hash_map_compat<int, int> *bigramMap, const int unigramProbability);
    static float getMultiWordCostMultiplier(const uint8_t *const dict);
    static float getMultiWordCostMultiplier(const uint8_t *const dict, const int dictSize);
    static void fillBigramProbabilityToHashMap(const uint8_t *const root, int position,
            hash_map_compat<int, int> *bigramMap);
    static int getBigramProbability(const uint8_t *const root, int position,
@@ -122,6 +123,8 @@ class BinaryFormat {
    static const int FLAG_ATTRIBUTE_ADDRESS_TYPE_TWOBYTES = 0x20;
    static const int FLAG_ATTRIBUTE_ADDRESS_TYPE_THREEBYTES = 0x30;

    // Any file smaller than this is not a dictionary.
    static const int DICTIONARY_MINIMUM_SIZE = 4;
    // Originally, format version 1 had a 16-bit magic number, then the version number `01'
    // then options that must be 0. Hence the first 32-bits of the format are always as follow
    // and it's okay to consider them a magic number as a whole.
@@ -131,6 +134,8 @@ class BinaryFormat {
    // number, so we had to change it so that version 2 files would be rejected by older
    // implementations. On this occasion, we made the magic number 32 bits long.
    static const int FORMAT_VERSION_2_MAGIC_NUMBER = -1681835266; // 0x9BC13AFE
    // Magic number (4 bytes), version (2 bytes), options (2 bytes), header size (4 bytes) = 12
    static const int FORMAT_VERSION_2_MINIMUM_SIZE = 12;

    static const int CHARACTER_ARRAY_TERMINATOR_SIZE = 1;
    static const int MINIMAL_ONE_BYTE_CHARACTER_VALUE = 0x20;
@@ -141,8 +146,11 @@ class BinaryFormat {
    static int skipBigrams(const uint8_t *const dict, const uint8_t flags, const int pos);
};

AK_FORCE_INLINE int BinaryFormat::detectFormat(const uint8_t *const dict) {
AK_FORCE_INLINE int BinaryFormat::detectFormat(const uint8_t *const dict, const int dictSize) {
    // The magic number is stored big-endian.
    // If the dictionary is less than 4 bytes, we can't even read the magic number, so we don't
    // understand this format.
    if (dictSize < DICTIONARY_MINIMUM_SIZE) return UNKNOWN_FORMAT;
    const int magicNumber = (dict[0] << 24) + (dict[1] << 16) + (dict[2] << 8) + dict[3];
    switch (magicNumber) {
    case FORMAT_VERSION_1_MAGIC_NUMBER:
@@ -152,6 +160,10 @@ AK_FORCE_INLINE int BinaryFormat::detectFormat(const uint8_t *const dict) {
        // Options (2 bytes) must be 0x00 0x00
        return 1;
    case FORMAT_VERSION_2_MAGIC_NUMBER:
        // Version 2 dictionaries are at least 12 bytes long (see below details for the header).
        // If this dictionary has the version 2 magic number but is less than 12 bytes long, then
        // it's an unknown format and we need to avoid confidently reading the next bytes.
        if (dictSize < FORMAT_VERSION_2_MINIMUM_SIZE) return UNKNOWN_FORMAT;
        // Format 2 header is as follows:
        // Magic number (4 bytes) 0x9B 0xC1 0x3A 0xFE
        // Version number (2 bytes) 0x00 0x02
@@ -163,8 +175,8 @@ AK_FORCE_INLINE int BinaryFormat::detectFormat(const uint8_t *const dict) {
    }
}

inline int BinaryFormat::getFlags(const uint8_t *const dict) {
    switch (detectFormat(dict)) {
inline int BinaryFormat::getFlags(const uint8_t *const dict, const int dictSize) {
    switch (detectFormat(dict, dictSize)) {
    case 1:
        return NO_FLAGS; // TODO: NO_FLAGS is unused anywhere else?
    default:
@@ -176,8 +188,8 @@ inline bool BinaryFormat::hasBlacklistedOrNotAWordFlag(const int flags) {
    return (flags & (FLAG_IS_BLACKLISTED | FLAG_IS_NOT_A_WORD)) != 0;
}

inline int BinaryFormat::getHeaderSize(const uint8_t *const dict) {
    switch (detectFormat(dict)) {
inline int BinaryFormat::getHeaderSize(const uint8_t *const dict, const int dictSize) {
    switch (detectFormat(dict, dictSize)) {
    case 1:
        return FORMAT_VERSION_1_HEADER_SIZE;
    case 2:
@@ -188,12 +200,12 @@ inline int BinaryFormat::getHeaderSize(const uint8_t *const dict) {
    }
}

inline void BinaryFormat::readHeaderValue(const uint8_t *const dict, const char *const key,
        int *outValue, const int outValueSize) {
inline void BinaryFormat::readHeaderValue(const uint8_t *const dict, const int dictSize,
        const char *const key, int *outValue, const int outValueSize) {
    int outValueIndex = 0;
    // Only format 2 and above have header attributes as {key,value} string pairs. For prior
    // formats, we just return an empty string, as if the key wasn't found.
    if (2 <= detectFormat(dict)) {
    if (2 <= detectFormat(dict, dictSize)) {
        const int headerOptionsOffset = 4 /* magic number */
                + 2 /* dictionary version */ + 2 /* flags */;
        const int headerSize =
@@ -236,11 +248,12 @@ inline void BinaryFormat::readHeaderValue(const uint8_t *const dict, const char
    if (outValueIndex >= 0) outValue[outValueIndex] = 0;
}

inline int BinaryFormat::readHeaderValueInt(const uint8_t *const dict, const char *const key) {
inline int BinaryFormat::readHeaderValueInt(const uint8_t *const dict, const int dictSize,
        const char *const key) {
    const int bufferSize = LARGEST_INT_DIGIT_COUNT;
    int intBuffer[bufferSize];
    char charBuffer[bufferSize];
    BinaryFormat::readHeaderValue(dict, key, intBuffer, bufferSize);
    BinaryFormat::readHeaderValue(dict, dictSize, key, intBuffer, bufferSize);
    for (int i = 0; i < bufferSize; ++i) {
        charBuffer[i] = intBuffer[i];
    }
@@ -256,8 +269,10 @@ AK_FORCE_INLINE int BinaryFormat::getGroupCountAndForwardPointer(const uint8_t *
    return ((msb & 0x7F) << 8) | dict[(*pos)++];
}

inline float BinaryFormat::getMultiWordCostMultiplier(const uint8_t *const dict) {
    const int headerValue = readHeaderValueInt(dict, "MULTIPLE_WORDS_DEMOTION_RATE");
inline float BinaryFormat::getMultiWordCostMultiplier(const uint8_t *const dict,
        const int dictSize) {
    const int headerValue = readHeaderValueInt(dict, dictSize,
            "MULTIPLE_WORDS_DEMOTION_RATE");
    if (headerValue == S_INT_MIN) {
        return 1.0f;
    }
+4 −2
Original line number Diff line number Diff line
@@ -34,9 +34,11 @@ namespace latinime {

Dictionary::Dictionary(void *dict, int dictSize, int mmapFd, int dictBufAdjust)
        : mDict(static_cast<unsigned char *>(dict)),
          mOffsetDict((static_cast<unsigned char *>(dict)) + BinaryFormat::getHeaderSize(mDict)),
          mOffsetDict((static_cast<unsigned char *>(dict))
                  + BinaryFormat::getHeaderSize(mDict, dictSize)),
          mDictSize(dictSize), mMmapFd(mmapFd), mDictBufAdjust(dictBufAdjust),
          mUnigramDictionary(new UnigramDictionary(mOffsetDict, BinaryFormat::getFlags(mDict))),
          mUnigramDictionary(new UnigramDictionary(mOffsetDict,
                  BinaryFormat::getFlags(mDict, dictSize))),
          mBigramDictionary(new BigramDictionary(mOffsetDict)),
          mGestureSuggest(new Suggest(GestureSuggestPolicyFactory::getGestureSuggestPolicy())),
          mTypingSuggest(new Suggest(TypingSuggestPolicyFactory::getTypingSuggestPolicy())) {
+2 −1
Original line number Diff line number Diff line
@@ -64,7 +64,8 @@ static TraverseSessionFactoryRegisterer traverseSessionFactoryRegisterer;
void DicTraverseSession::init(const Dictionary *const dictionary, const int *prevWord,
        int prevWordLength) {
    mDictionary = dictionary;
    mMultiWordCostMultiplier = BinaryFormat::getMultiWordCostMultiplier(mDictionary->getDict());
    mMultiWordCostMultiplier = BinaryFormat::getMultiWordCostMultiplier(mDictionary->getDict(),
            mDictionary->getDictSize());
    if (!prevWord) {
        mPrevWordPos = NOT_VALID_WORD;
        return;