Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit ab58405b authored by Jack He's avatar Jack He
Browse files

LruCache: Improve efficieny and ease of use of APIs

* Add a Find() function that returns the pointer to the value associated
  with a key, further changes to the value using that poitner does not
  warm up the cache
* Remove eviction callback, but instead return an optional evicted node
  when Put() evicts a cold node. This prevents potential deadlock when
  calling LruCache methods in the callback
* HasKey() is not zero-copy
* Get() calls Find()
* Add unit tests for these new features
* Modify MetricIdAllocator to use these new features

Bug: 143515989
Test: atest --host bluetooth_test_common
Change-Id: I9071c86a9041e5c95b349824889ccedf9f9c18dc
parent b196e223
Loading
Loading
Loading
Loading
+55 −49
Original line number Diff line number Diff line
@@ -22,6 +22,7 @@
#include <iterator>
#include <list>
#include <mutex>
#include <optional>
#include <thread>
#include <unordered_map>

@@ -35,35 +36,54 @@ template <typename K, typename V>
class LruCache {
 public:
  using Node = std::pair<K, V>;
  using LruEvictionCallback = std::function<void(K, V)>;
  /**
   * Constructor of the cache
   *
   * @param capacity maximum size of the cache
   * @param log_tag, keyword to put at the head of log.
   * @param lru_eviction_callback a call back will be called when the cache is
   * full and Put() is called
   */
  LruCache(const size_t& capacity, const std::string& log_tag,
           LruEvictionCallback lru_eviction_callback)
      : capacity_(capacity), lru_eviction_callback_(lru_eviction_callback) {
  LruCache(const size_t& capacity, const std::string& log_tag)
      : capacity_(capacity) {
    if (capacity_ == 0) {
      // don't allow invalid capacity
      LOG(FATAL) << log_tag << " unable to have 0 LRU Cache capacity";
    }
  }

  // delete copy constructor
  LruCache(LruCache const&) = delete;
  LruCache& operator=(LruCache const&) = delete;

  ~LruCache() { Clear(); }

  /**
   * Clear the cache
   */
  void Clear() {
    std::lock_guard<std::mutex> lock(lru_mutex_);
    std::lock_guard<std::recursive_mutex> lock(lru_mutex_);
    lru_map_.clear();
    node_list_.clear();
  }

  /**
   * Same as Get, but return an iterator to the accessed element
   *
   * Modifying the returned iterator does not warm up the cache
   *
   * @param key
   * @return pointer to the underlying value to allow in-place modification
   * nullptr when not found, will be invalidated when the key is evicted
   */
  V* Find(const K& key) {
    std::lock_guard<std::recursive_mutex> lock(lru_mutex_);
    auto map_iterator = lru_map_.find(key);
    if (map_iterator == lru_map_.end()) {
      return nullptr;
    }
    node_list_.splice(node_list_.begin(), node_list_, map_iterator->second);
    return &(map_iterator->second->second);
  }

  /**
   * Get the value of a key, and move the key to the head of cache, if there is
   * one
@@ -73,19 +93,13 @@ class LruCache {
   * @return true if the cache has the key
   */
  bool Get(const K& key, V* value) {
    std::lock_guard<std::mutex> lock(lru_mutex_);
    auto map_iterator = lru_map_.find(key);
    if (map_iterator == lru_map_.end()) {
    CHECK(value != nullptr);
    std::lock_guard<std::recursive_mutex> lock(lru_mutex_);
    auto value_ptr = Find(key);
    if (value_ptr == nullptr) {
      return false;
    }
    auto& list_iterator = map_iterator->second;
    auto node = *list_iterator;
    node_list_.erase(list_iterator);
    node_list_.push_front(node);
    map_iterator->second = node_list_.begin();
    if (value != nullptr) {
      *value = node.second;
    }
    *value = *value_ptr;
    return true;
  }

@@ -97,8 +111,8 @@ class LruCache {
   * @return true if the cache has the key
   */
  bool HasKey(const K& key) {
    V dummy_value;
    return Get(key, &dummy_value);
    std::lock_guard<std::recursive_mutex> lock(lru_mutex_);
    return Find(key) != nullptr;
  }

  /**
@@ -106,52 +120,49 @@ class LruCache {
   *
   * @param key
   * @param value
   * @return true if tail value is popped
   * @return evicted node if tail value is popped, std::nullopt if no value
   * is popped. std::optional can be treated as a boolean as well
   */
  bool Put(const K& key, const V& value) {
    if (HasKey(key)) {
  std::optional<Node> Put(const K& key, V value) {
    std::lock_guard<std::recursive_mutex> lock(lru_mutex_);
    auto value_ptr = Find(key);
    if (value_ptr != nullptr) {
      // hasKey() calls get(), therefore already move the node to the head
      std::lock_guard<std::mutex> lock(lru_mutex_);
      lru_map_[key]->second = value;
      return false;
      *value_ptr = std::move(value);
      return std::nullopt;
    }

    bool value_popped = false;
    std::lock_guard<std::mutex> lock(lru_mutex_);
    // remove tail
    std::optional<Node> ret = std::nullopt;
    if (lru_map_.size() == capacity_) {
      lru_map_.erase(node_list_.back().first);
      K key_evicted = node_list_.back().first;
      V value_evicted = node_list_.back().second;
      ret = std::move(node_list_.back());
      node_list_.pop_back();
      lru_eviction_callback_(key_evicted, value_evicted);
      value_popped = true;
    }
    // insert to dummy next;
    Node add(key, value);
    node_list_.push_front(add);
    lru_map_[key] = node_list_.begin();
    return value_popped;
    node_list_.emplace_front(key, std::move(value));
    lru_map_.emplace(key, node_list_.begin());
    return ret;
  }

  /**
   * Delete a key from cache
   *
   * @param key
   * @return true if delete successfully
   * @return true if deleted successfully
   */
  bool Remove(const K& key) {
    std::lock_guard<std::mutex> lock(lru_mutex_);
    if (lru_map_.count(key) == 0) {
    std::lock_guard<std::recursive_mutex> lock(lru_mutex_);
    auto map_iterator = lru_map_.find(key);
    if (map_iterator == lru_map_.end()) {
      return false;
    }

    // remove from the list
    auto& iterator = lru_map_[key];
    node_list_.erase(iterator);
    node_list_.erase(map_iterator->second);

    // delete key from map
    lru_map_.erase(key);
    lru_map_.erase(map_iterator);

    return true;
  }
@@ -162,7 +173,7 @@ class LruCache {
   * @return size of the cache
   */
  int Size() const {
    std::lock_guard<std::mutex> lock(lru_mutex_);
    std::lock_guard<std::recursive_mutex> lock(lru_mutex_);
    return lru_map_.size();
  }

@@ -170,12 +181,7 @@ class LruCache {
  std::list<Node> node_list_;
  size_t capacity_;
  std::unordered_map<K, typename std::list<Node>::iterator> lru_map_;
  LruEvictionCallback lru_eviction_callback_;
  mutable std::mutex lru_mutex_;

  // delete copy constructor
  LruCache(LruCache const&) = delete;
  LruCache& operator=(LruCache const&) = delete;
  mutable std::recursive_mutex lru_mutex_;
};

}  // namespace common
+37 −34
Original line number Diff line number Diff line
@@ -30,16 +30,12 @@ using bluetooth::common::LruCache;

TEST(BluetoothLruCacheTest, LruCacheMainTest1) {
  int* value = new int(0);
  int dummy = 0;
  int* pointer = &dummy;
  auto callback = [pointer](int a, int b) { (*pointer) = a * b; };
  LruCache<int, int> cache(3, "testing", callback);  // capacity = 3;
  LruCache<int, int> cache(3, "testing");  // capacity = 3;
  cache.Put(1, 10);
  EXPECT_EQ(cache.Size(), 1);
  cache.Put(2, 20);
  cache.Put(3, 30);
  EXPECT_FALSE(cache.Put(2, 20));
  EXPECT_FALSE(cache.Put(3, 30));
  EXPECT_EQ(cache.Size(), 3);
  EXPECT_EQ(dummy, 0);

  // 1, 2, 3 should be in cache
  EXPECT_TRUE(cache.Get(1, value));
@@ -50,8 +46,7 @@ TEST(BluetoothLruCacheTest, LruCacheMainTest1) {
  EXPECT_EQ(*value, 30);
  EXPECT_EQ(cache.Size(), 3);

  cache.Put(4, 40);
  EXPECT_EQ(dummy, 10);
  EXPECT_THAT(cache.Put(4, 40), Optional(Pair(1, 10)));
  // 2, 3, 4 should be in cache, 1 is evicted
  EXPECT_FALSE(cache.Get(1, value));
  EXPECT_TRUE(cache.Get(4, value));
@@ -61,13 +56,12 @@ TEST(BluetoothLruCacheTest, LruCacheMainTest1) {
  EXPECT_TRUE(cache.Get(3, value));
  EXPECT_EQ(*value, 30);

  cache.Put(5, 50);
  EXPECT_THAT(cache.Put(5, 50), Optional(Pair(4, 40)));
  EXPECT_EQ(cache.Size(), 3);
  EXPECT_EQ(dummy, 160);
  // 2, 3, 5 should be in cache, 4 is evicted

  EXPECT_TRUE(cache.Remove(3));
  cache.Put(6, 60);
  EXPECT_FALSE(cache.Put(6, 60));
  // 2, 5, 6 should be in cache

  EXPECT_FALSE(cache.Get(3, value));
@@ -82,17 +76,11 @@ TEST(BluetoothLruCacheTest, LruCacheMainTest1) {

TEST(BluetoothLruCacheTest, LruCacheMainTest2) {
  int* value = new int(0);
  int dummy = 0;
  int* pointer = &dummy;
  auto callback = [pointer](int a, int b) { (*pointer)++; };
  LruCache<int, int> cache(2, "testing", callback);  // size = 2;
  cache.Put(1, 10);
  cache.Put(2, 20);
  EXPECT_EQ(dummy, 0);
  cache.Put(3, 30);
  EXPECT_EQ(dummy, 1);
  cache.Put(2, 200);
  EXPECT_EQ(dummy, 1);
  LruCache<int, int> cache(2, "testing");  // size = 2;
  EXPECT_FALSE(cache.Put(1, 10));
  EXPECT_FALSE(cache.Put(2, 20));
  EXPECT_THAT(cache.Put(3, 30), Optional(Pair(1, 10)));
  EXPECT_FALSE(cache.Put(2, 200));
  EXPECT_EQ(cache.Size(), 2);
  // 3, 2 should be in cache

@@ -102,8 +90,7 @@ TEST(BluetoothLruCacheTest, LruCacheMainTest2) {
  EXPECT_TRUE(cache.Get(3, value));
  EXPECT_EQ(*value, 30);

  cache.Put(4, 40);
  EXPECT_EQ(dummy, 2);
  EXPECT_THAT(cache.Put(4, 40), Optional(Pair(2, 200)));
  // 3, 4 should be in cache

  EXPECT_FALSE(cache.HasKey(2));
@@ -139,22 +126,39 @@ TEST(BluetoothLruCacheTest, LruCacheMainTest2) {
  EXPECT_EQ(*value, 50);
}

TEST(BluetoothLruCacheTest, LruCacheFindTest) {
  LruCache<int, int> cache(10, "testing");
  cache.Put(1, 10);
  cache.Put(2, 20);
  int value = 0;
  EXPECT_TRUE(cache.Get(1, &value));
  EXPECT_EQ(value, 10);
  auto value_ptr = cache.Find(1);
  EXPECT_NE(value_ptr, nullptr);
  *value_ptr = 20;
  EXPECT_TRUE(cache.Get(1, &value));
  EXPECT_EQ(value, 20);
  cache.Put(1, 40);
  EXPECT_EQ(*value_ptr, 40);
  EXPECT_EQ(cache.Find(10), nullptr);
}

TEST(BluetoothLruCacheTest, LruCacheGetTest) {
  LruCache<int, int> cache(10, "testing", [](int a, int b) {});
  LruCache<int, int> cache(10, "testing");
  cache.Put(1, 10);
  cache.Put(2, 20);
  int value = 0;
  EXPECT_TRUE(cache.Get(1, &value));
  EXPECT_EQ(value, 10);
  EXPECT_TRUE(cache.Get(1, nullptr));
  EXPECT_TRUE(cache.Get(2, nullptr));
  EXPECT_FALSE(cache.Get(3, nullptr));
  EXPECT_TRUE(cache.HasKey(1));
  EXPECT_TRUE(cache.HasKey(2));
  EXPECT_FALSE(cache.HasKey(3));
  EXPECT_FALSE(cache.Get(3, &value));
  EXPECT_EQ(value, 10);
}

TEST(BluetoothLruCacheTest, LruCacheRemoveTest) {
  LruCache<int, int> cache(10, "testing", [](int a, int b) {});
  LruCache<int, int> cache(10, "testing");
  for (int key = 0; key <= 30; key++) {
    cache.Put(key, key * 100);
  }
@@ -173,7 +177,7 @@ TEST(BluetoothLruCacheTest, LruCacheRemoveTest) {
}

TEST(BluetoothLruCacheTest, LruCacheClearTest) {
  LruCache<int, int> cache(10, "testing", [](int a, int b) {});
  LruCache<int, int> cache(10, "testing");
  for (int key = 0; key < 10; key++) {
    cache.Put(key, key * 100);
  }
@@ -196,8 +200,7 @@ TEST(BluetoothLruCacheTest, LruCacheClearTest) {
TEST(BluetoothLruCacheTest, LruCachePressureTest) {
  auto started = std::chrono::high_resolution_clock::now();
  int max_size = 0xFFFFF;  // 2^20 = 1M
  LruCache<int, int> cache(static_cast<size_t>(max_size), "testing",
                           [](int a, int b) {});
  LruCache<int, int> cache(static_cast<size_t>(max_size), "testing");

  // fill the cache
  for (int key = 0; key < max_size; key++) {
@@ -237,7 +240,7 @@ TEST(BluetoothLruCacheTest, LruCachePressureTest) {
}

TEST(BluetoothLruCacheTest, BluetoothLruMultiThreadPressureTest) {
  LruCache<int, int> cache(100, "testing", [](int a, int b) {});
  LruCache<int, int> cache(100, "testing");
  auto pointer = &cache;
  // make sure no deadlock
  std::vector<std::thread> workers;
+15 −11
Original line number Diff line number Diff line
@@ -42,13 +42,8 @@ static_assert((MetricIdAllocator::kMaxNumUnpairedDevicesInMemory +
              "kMaxNumPairedDevicesInMemory + MaxNumUnpairedDevicesInMemory");

MetricIdAllocator::MetricIdAllocator()
    : paired_device_cache_(kMaxNumPairedDevicesInMemory, LOGGING_TAG,
                           [this](RawAddress mac_address, int id) {
                             ForgetDevicePostprocess(mac_address, id);
                           }),
      temporary_device_cache_(
          kMaxNumUnpairedDevicesInMemory, LOGGING_TAG,
          [this](RawAddress dummy, int id) { this->id_set_.erase(id); }) {}
    : paired_device_cache_(kMaxNumPairedDevicesInMemory, LOGGING_TAG),
      temporary_device_cache_(kMaxNumUnpairedDevicesInMemory, LOGGING_TAG) {}

bool MetricIdAllocator::Init(
    const std::unordered_map<RawAddress, int>& paired_device_map,
@@ -68,11 +63,14 @@ bool MetricIdAllocator::Init(
  }

  next_id_ = kMinId;
  for (const std::pair<RawAddress, int>& p : paired_device_map) {
  for (const auto& p : paired_device_map) {
    if (p.second < kMinId || p.second > kMaxId) {
      LOG(FATAL) << LOGGING_TAG << "Invalid Bluetooth Metric Id in config";
    }
    paired_device_cache_.Put(p.first, p.second);
    auto evicted = paired_device_cache_.Put(p.first, p.second);
    if (evicted) {
      ForgetDevicePostprocess(evicted->first, evicted->second);
    }
    id_set_.insert(p.second);
    next_id_ = std::max(next_id_, p.second + 1);
  }
@@ -134,7 +132,10 @@ int MetricIdAllocator::AllocateId(const RawAddress& mac_address) {
  }
  id = next_id_++;
  id_set_.insert(id);
  temporary_device_cache_.Put(mac_address, id);
  auto evicted = temporary_device_cache_.Put(mac_address, id);
  if (evicted) {
    this->id_set_.erase(evicted->second);
  }

  if (next_id_ > kMaxId) {
    next_id_ = kMinId;
@@ -160,7 +161,10 @@ bool MetricIdAllocator::SaveDevice(const RawAddress& mac_address) {
               << "Failed to remove device from temporary_device_cache_";
    return false;
  }
  paired_device_cache_.Put(mac_address, id);
  auto evicted = paired_device_cache_.Put(mac_address, id);
  if (evicted) {
    ForgetDevicePostprocess(evicted->first, evicted->second);
  }
  if (!save_id_callback_(mac_address, id)) {
    LOG(ERROR) << LOGGING_TAG
               << "Callback returned false after saving the device";