Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit e3084595 authored by Jean Chalard's avatar Jean Chalard
Browse files

Compute the correct frequency for bigram prediction

Change-Id: I3196f48a0ca2ed5e94f430254d58e65d341398c8
parent cb993763
Loading
Loading
Loading
Loading
+5 −2
Original line number Diff line number Diff line
@@ -117,14 +117,17 @@ int BigramDictionary::getBigrams(const int32_t *prevWord, int prevWordLength, in
    do {
        bigramFlags = BinaryFormat::getFlagsAndForwardPointer(root, &pos);
        uint16_t bigramBuffer[MAX_WORD_LENGTH];
        int unigramFreq;
        const int bigramPos = BinaryFormat::getAttributeAddressAndForwardPointer(root, bigramFlags,
                &pos);
        const int length = BinaryFormat::getWordAtAddress(root, bigramPos, MAX_WORD_LENGTH,
                bigramBuffer);
                bigramBuffer, &unigramFreq);

        // codesSize == 0 means we are trying to find bigram predictions.
        if (codesSize < 1 || checkFirstCharacter(bigramBuffer)) {
            const int frequency = UnigramDictionary::MASK_ATTRIBUTE_FREQUENCY & bigramFlags;
            const int bigramFreq = UnigramDictionary::MASK_ATTRIBUTE_FREQUENCY & bigramFlags;
            const int frequency =
                    BinaryFormat::computeFrequencyForBigram(unigramFreq, bigramFreq);
            if (addWordBigram(bigramBuffer, length, frequency)) {
                ++bigramCount;
            }
+4 −2
Original line number Diff line number Diff line
@@ -66,7 +66,7 @@ class BinaryFormat {
    static int getTerminalPosition(const uint8_t* const root, const int32_t* const inWord,
            const int length);
    static int getWordAtAddress(const uint8_t* const root, const int address, const int maxDepth,
            uint16_t* outWord);
            uint16_t* outWord, int* outUnigramFrequency);
    static int computeFrequencyForBigram(const int unigramFreq, const int bigramFreq);
    static int getProbability(const int position, const std::map<int, int> *bigramMap,
            const uint8_t *bigramFilter, const int unigramFreq);
@@ -391,10 +391,11 @@ inline int BinaryFormat::getTerminalPosition(const uint8_t* const root,
 * address: the byte position of the last chargroup of the word we are searching for (this is
 *   what is stored as the "bigram address" in each bigram)
 * outword: an array to write the found word, with MAX_WORD_LENGTH size.
 * outUnigramFrequency: a pointer to an int to write the frequency into.
 * Return value : the length of the word, of 0 if the word was not found.
 */
inline int BinaryFormat::getWordAtAddress(const uint8_t* const root, const int address,
        const int maxDepth, uint16_t* outWord) {
        const int maxDepth, uint16_t* outWord, int* outUnigramFrequency) {
    int pos = 0;
    int wordPos = 0;

@@ -427,6 +428,7 @@ inline int BinaryFormat::getWordAtAddress(const uint8_t* const root, const int a
                        nextChar = getCharCodeAndForwardPointer(root, &pos);
                    }
                }
                *outUnigramFrequency = readFrequencyWithoutMovingPointer(root, pos);
                return ++wordPos;
            }
            // We need to skip past this char group, so skip any remaining chars after the