Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 5c1decfb authored by Keisuke Kuroyanagi's avatar Keisuke Kuroyanagi
Browse files

Add entry iteration method to TrieMap.

Bug: 14425059
Change-Id: I79420b755f29f651d8eed61e7e48b6eb001d8dd2
parent c4f6fc1e
Loading
Loading
Loading
Loading
+37 −0
Original line number Diff line number Diff line
@@ -98,6 +98,43 @@ bool TrieMap::put(const int key, const uint64_t value, const int bitmapEntryInde
    return putInternal(unsignedKey, value, getBitShuffledKey(unsignedKey), bitmapEntryIndex,
            readEntry(bitmapEntryIndex), 0 /* level */);
}
/**
 * Iterate next entry in a certain level.
 *
 * @param iterationState the iteration state that will be read and updated in this method.
 * @param outKey the output key
 * @return Result instance. mIsValid is false when all entries are iterated.
 */
const TrieMap::Result TrieMap::iterateNext(std::vector<TableIterationState> *const iterationState,
        int *const outKey) const {
    while (!iterationState->empty()) {
        TableIterationState &state = iterationState->back();
        if (state.mTableSize <= state.mCurrentIndex) {
            // Move to parent.
            iterationState->pop_back();
        } else {
            const int entryIndex = state.mTableIndex + state.mCurrentIndex;
            state.mCurrentIndex += 1;
            const Entry entry = readEntry(entryIndex);
            if (entry.isBitmapEntry()) {
                // Move to child.
                iterationState->emplace_back(popCount(entry.getBitmap()), entry.getTableIndex());
            } else {
                if (outKey) {
                    *outKey = entry.getKey();
                }
                if (!entry.hasTerminalLink()) {
                    return Result(entry.getValue(), true, INVALID_INDEX);
                }
                const int valueEntryIndex = entry.getValueEntryIndex();
                const Entry valueEntry = readEntry(valueEntryIndex);
                return Result(valueEntry.getValueOfValueEntry(), true, valueEntryIndex + 1);
            }
        }
    }
    // Visited all entries.
    return Result(0, false, INVALID_INDEX);
}

/**
 * Shuffle bits of the key in the fixed order.
+121 −0
Original line number Diff line number Diff line
@@ -44,6 +44,117 @@ class TrieMap {
                  mNextLevelBitmapEntryIndex(nextLevelBitmapEntryIndex) {}
    };

    /**
     * Struct to record iteration state in a table.
     */
    struct TableIterationState {
        int mTableSize;
        int mTableIndex;
        int mCurrentIndex;

        TableIterationState(const int tableSize, const int tableIndex)
                : mTableSize(tableSize), mTableIndex(tableIndex), mCurrentIndex(0) {}
    };

    class TrieMapRange;
    class TrieMapIterator {
     public:
        class IterationResult {
         public:
            IterationResult(const TrieMap *const trieMap, const int key, const uint64_t value,
                    const int nextLeveBitmapEntryIndex)
                    : mTrieMap(trieMap), mKey(key), mValue(value),
                      mNextLevelBitmapEntryIndex(nextLeveBitmapEntryIndex) {}

            const TrieMapRange getEntriesInNextLevel() const {
                return TrieMapRange(mTrieMap, mNextLevelBitmapEntryIndex);
            }

            bool hasNextLevelMap() const {
                return mNextLevelBitmapEntryIndex != INVALID_INDEX;
            }

            AK_FORCE_INLINE int key() const {
                return mKey;
            }

            AK_FORCE_INLINE uint64_t value() const {
                return mValue;
            }

         private:
            const TrieMap *const mTrieMap;
            const int mKey;
            const uint64_t mValue;
            const int mNextLevelBitmapEntryIndex;
        };

        TrieMapIterator(const TrieMap *const trieMap, const int bitmapEntryIndex)
                : mTrieMap(trieMap), mStateStack(), mBaseBitmapEntryIndex(bitmapEntryIndex),
                  mKey(0), mValue(0), mIsValid(false), mNextLevelBitmapEntryIndex(INVALID_INDEX) {
            if (!trieMap) {
                return;
            }
            const Entry bitmapEntry = mTrieMap->readEntry(mBaseBitmapEntryIndex);
            mStateStack.emplace_back(
                    mTrieMap->popCount(bitmapEntry.getBitmap()), bitmapEntry.getTableIndex());
            this->operator++();
        }

        const IterationResult operator*() const {
            return IterationResult(mTrieMap, mKey, mValue, mNextLevelBitmapEntryIndex);
        }

        bool operator!=(const TrieMapIterator &other) const {
            // Caveat: This works only for for loops.
            return mIsValid || other.mIsValid;
        }

        const TrieMapIterator &operator++() {
            const Result result = mTrieMap->iterateNext(&mStateStack, &mKey);
            mValue = result.mValue;
            mIsValid = result.mIsValid;
            mNextLevelBitmapEntryIndex = result.mNextLevelBitmapEntryIndex;
            return *this;
        }

     private:
        DISALLOW_DEFAULT_CONSTRUCTOR(TrieMapIterator);
        DISALLOW_ASSIGNMENT_OPERATOR(TrieMapIterator);

        const TrieMap *const mTrieMap;
        std::vector<TrieMap::TableIterationState> mStateStack;
        const int mBaseBitmapEntryIndex;
        int mKey;
        uint64_t mValue;
        bool mIsValid;
        int mNextLevelBitmapEntryIndex;
    };

    /**
     * Class to support iterating entries in TrieMap by range base for loops.
     */
    class TrieMapRange {
     public:
        TrieMapRange(const TrieMap *const trieMap, const int bitmapEntryIndex)
                : mTrieMap(trieMap), mBaseBitmapEntryIndex(bitmapEntryIndex) {};

        TrieMapIterator begin() const {
            return TrieMapIterator(mTrieMap, mBaseBitmapEntryIndex);
        }

        const TrieMapIterator end() const {
            return TrieMapIterator(nullptr, INVALID_INDEX);
        }

     private:
        DISALLOW_DEFAULT_CONSTRUCTOR(TrieMapRange);
        DISALLOW_ASSIGNMENT_OPERATOR(TrieMapRange);

        const TrieMap *const mTrieMap;
        const int mBaseBitmapEntryIndex;
    };

    static const int INVALID_INDEX;
    static const uint64_t MAX_VALUE;

@@ -73,6 +184,14 @@ class TrieMap {

    bool put(const int key, const uint64_t value, const int bitmapEntryIndex);

    const TrieMapRange getEntriesInRootLevel() const {
        return getEntriesInSpecifiedLevel(ROOT_BITMAP_ENTRY_INDEX);
    }

    const TrieMapRange getEntriesInSpecifiedLevel(const int bitmapEntryIndex) const {
        return TrieMapRange(this, bitmapEntryIndex);
    }

 private:
    DISALLOW_COPY_AND_ASSIGN(TrieMap);

@@ -171,6 +290,8 @@ class TrieMap {
    bool addNewEntryByExpandingTable(const uint32_t key, const uint64_t value,
            const int tableIndex, const uint32_t bitmap, const int bitmapEntryIndex,
            const int label);
    const Result iterateNext(std::vector<TableIterationState> *const iterationState,
            int *const outKey) const;

    AK_FORCE_INLINE const Entry readEntry(const int entryIndex) const {
        return Entry(readField0(entryIndex), readField1(entryIndex));
+57 −2
Original line number Diff line number Diff line
@@ -54,7 +54,7 @@ TEST(TrieMapTest, TestSetAndGetLarge) {
        EXPECT_TRUE(trieMap.putRoot(i, i));
    }
    for (int i = 0; i < ELEMENT_COUNT; ++i) {
        EXPECT_EQ(trieMap.getRoot(i).mValue, static_cast<uint64_t>(i));
        EXPECT_EQ(static_cast<uint64_t>(i), trieMap.getRoot(i).mValue);
    }
}

@@ -78,7 +78,7 @@ TEST(TrieMapTest, TestRandSetAndGetLarge) {
        testKeyValuePairs[key] = value;
    }
    for (const auto &v : testKeyValuePairs) {
        EXPECT_EQ(trieMap.getRoot(v.first).mValue, v.second);
        EXPECT_EQ(v.second, trieMap.getRoot(v.first).mValue);
    }
}

@@ -163,6 +163,61 @@ TEST(TrieMapTest, TestMultiLevel) {
            }
        }
    }

    // Iteration
    for (const auto &firstLevelEntry : trieMap.getEntriesInRootLevel()) {
        EXPECT_EQ(trieMap.getRoot(firstLevelEntry.key()).mValue, firstLevelEntry.value());
        EXPECT_EQ(firstLevelEntries[firstLevelEntry.key()], firstLevelEntry.value());
        firstLevelEntries.erase(firstLevelEntry.key());
        for (const auto &secondLevelEntry : firstLevelEntry.getEntriesInNextLevel()) {
            EXPECT_EQ(twoLevelMap[firstLevelEntry.key()][secondLevelEntry.key()],
                    secondLevelEntry.value());
            twoLevelMap[firstLevelEntry.key()].erase(secondLevelEntry.key());
            for (const auto &thirdLevelEntry : secondLevelEntry.getEntriesInNextLevel()) {
                EXPECT_EQ(threeLevelMap[firstLevelEntry.key()][secondLevelEntry.key()]
                        [thirdLevelEntry.key()], thirdLevelEntry.value());
                threeLevelMap[firstLevelEntry.key()][secondLevelEntry.key()].erase(
                        thirdLevelEntry.key());
            }
        }
    }

    // Ensure all entries have been traversed.
    EXPECT_TRUE(firstLevelEntries.empty());
    for (const auto &secondLevelEntry : twoLevelMap) {
        EXPECT_TRUE(secondLevelEntry.second.empty());
    }
    for (const auto &secondLevelEntry : threeLevelMap) {
        for (const auto &thirdLevelEntry : secondLevelEntry.second) {
            EXPECT_TRUE(thirdLevelEntry.second.empty());
        }
    }
}

TEST(TrieMapTest, TestIteration) {
    static const int ELEMENT_COUNT = 200000;
    TrieMap trieMap;
    std::unordered_map<int, uint64_t> testKeyValuePairs;

    // Use the uniform integer distribution [S_INT_MIN, S_INT_MAX].
    std::uniform_int_distribution<int> keyDistribution(S_INT_MIN, S_INT_MAX);
    auto keyRandomNumberGenerator = std::bind(keyDistribution, std::mt19937());

    // Use the uniform distribution [0, TrieMap::MAX_VALUE].
    std::uniform_int_distribution<uint64_t> valueDistribution(0, TrieMap::MAX_VALUE);
    auto valueRandomNumberGenerator = std::bind(valueDistribution, std::mt19937());
    for (int i = 0; i < ELEMENT_COUNT; ++i) {
        const int key = keyRandomNumberGenerator();
        const uint64_t value = valueRandomNumberGenerator();
        EXPECT_TRUE(trieMap.putRoot(key, value));
        testKeyValuePairs[key] = value;
    }
    for (const auto &entry : trieMap.getEntriesInRootLevel()) {
        EXPECT_EQ(trieMap.getRoot(entry.key()).mValue, entry.value());
        EXPECT_EQ(testKeyValuePairs[entry.key()], entry.value());
        testKeyValuePairs.erase(entry.key());
    }
    EXPECT_TRUE(testKeyValuePairs.empty());
}

}  // namespace