Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit ef93f386 authored by Shai Barack's avatar Shai Barack Committed by Android (Google) Code Review
Browse files

Merge "Remove synchronization (and potential priority inversions) from RateLimitingCache" into main

parents e1d2a135 c59db84c
Loading
Loading
Loading
Loading
+48 −28
Original line number Original line Diff line number Diff line
@@ -17,6 +17,8 @@
package com.android.internal.util;
package com.android.internal.util;


import android.os.SystemClock;
import android.os.SystemClock;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;


/**
/**
 * A speed/rate limiting cache that's used to cache a value to be returned as long as period hasn't
 * A speed/rate limiting cache that's used to cache a value to be returned as long as period hasn't
@@ -30,6 +32,12 @@ import android.os.SystemClock;
 * and then the cached value is returned for the remainder of the period. It uses a simple fixed
 * and then the cached value is returned for the remainder of the period. It uses a simple fixed
 * window method to track rate. Use a window and count appropriate for bursts of calls and for
 * window method to track rate. Use a window and count appropriate for bursts of calls and for
 * high latency/cost of the AIDL call.
 * high latency/cost of the AIDL call.
 * <p>
 * This class is thread-safe. When multiple threads call get(), they will all fetch a new value
 * if the cached value is stale. This is to prevent a slow getting thread from blocking other
 * threads from getting a fresh value. In such circumstances it's possible to exceed
 * <code>count</code> calls in a given period by up to the number of threads that are concurrently
 * attempting to get a fresh value minus one.
 *
 *
 * @param <Value> The type of the return value
 * @param <Value> The type of the return value
 * @hide
 * @hide
@@ -37,12 +45,11 @@ import android.os.SystemClock;
@android.ravenwood.annotation.RavenwoodKeepWholeClass
@android.ravenwood.annotation.RavenwoodKeepWholeClass
public class RateLimitingCache<Value> {
public class RateLimitingCache<Value> {


    private volatile Value mCurrentValue;
    private volatile long mLastTimestamp; // Can be last fetch time or window start of fetch time
    private final long mPeriodMillis; // window size
    private final long mPeriodMillis; // window size
    private final int mLimit; // max per window
    private final int mLimit; // max per window
    private int mCount = 0; // current count within window
    // random offset to avoid batching of AIDL calls at window boundary
    private long mRandomOffset; // random offset to avoid batching of AIDL calls at window boundary
    private final long mRandomOffset;
    private final AtomicReference<CachedValue> mCachedValue = new AtomicReference();


    /**
    /**
     * The interface to fetch the actual value, if the cache is null or expired.
     * The interface to fetch the actual value, if the cache is null or expired.
@@ -56,6 +63,12 @@ public class RateLimitingCache<Value> {
        V fetchValue();
        V fetchValue();
    }
    }


    /**
     * Immutable snapshot of one fetched value plus the bookkeeping needed for
     * rate limiting. A fresh instance is installed atomically (via
     * AtomicReference CAS) on every successful fetch; only {@code count} is
     * ever mutated after publication, and it is shared between snapshots that
     * belong to the same rate-limiting window.
     */
    class CachedValue {
        // The value returned by the most recent ValueFetcher.fetchValue() call.
        Value value;
        // Fetch time; when rate limiting (limit > 1) this is normalized to the
        // start of the fixed window the fetch happened in.
        long timestamp;
        AtomicInteger count; // current count within window
    }

    /**
    /**
     * Create a speed limiting cache that returns the same value until periodMillis has passed
     * Create a speed limiting cache that returns the same value until periodMillis has passed
     * and then fetches a new value via the {@link ValueFetcher}.
     * and then fetches a new value via the {@link ValueFetcher}.
@@ -83,6 +96,8 @@ public class RateLimitingCache<Value> {
        mLimit = count;
        mLimit = count;
        if (mLimit > 1 && periodMillis > 1) {
        if (mLimit > 1 && periodMillis > 1) {
            mRandomOffset = (long) (Math.random() * (periodMillis / 2));
            mRandomOffset = (long) (Math.random() * (periodMillis / 2));
        } else {
            mRandomOffset = 0;
        }
        }
    }
    }


@@ -102,34 +117,39 @@ public class RateLimitingCache<Value> {
     * @return the cached or latest value
     * @return the cached or latest value
     */
     */
    public Value get(ValueFetcher<Value> query) {
    public Value get(ValueFetcher<Value> query) {
        // If the value never changes
        CachedValue cached = mCachedValue.get();
        if (mPeriodMillis < 0 && mLastTimestamp != 0) {

            return mCurrentValue;
        // If the value never changes and there is a previous cached value, return it
        if (mPeriodMillis < 0 && cached != null && cached.timestamp != 0) {
            return cached.value;
        }
        }


        synchronized (this) {
        // Get the current time and add a random offset to avoid colliding with other
        // Get the current time and add a random offset to avoid colliding with other
        // caches with similar harmonic window boundaries
        // caches with similar harmonic window boundaries
        final long now = getTime() + mRandomOffset;
        final long now = getTime() + mRandomOffset;
            final boolean newWindow = now - mLastTimestamp >= mPeriodMillis;
        final boolean newWindow = cached == null || now - cached.timestamp >= mPeriodMillis;
            if (newWindow || mCount < mLimit) {
        if (newWindow || cached.count.getAndIncrement() < mLimit) {
            // Fetch a new value
            // Fetch a new value
                mCurrentValue = query.fetchValue();
            Value freshValue = query.fetchValue();

            long freshTimestamp = now;
            // If rate limiting, set timestamp to start of this window
            // If rate limiting, set timestamp to start of this window
            if (mLimit > 1) {
            if (mLimit > 1) {
                    mLastTimestamp = now - (now % mPeriodMillis);
                freshTimestamp = now - (now % mPeriodMillis);
                } else {
                    mLastTimestamp = now;
            }
            }


            CachedValue freshCached = new CachedValue();
            freshCached.value = freshValue;
            freshCached.timestamp = freshTimestamp;
            if (newWindow) {
            if (newWindow) {
                    mCount = 1;
                freshCached.count = new AtomicInteger(1);
            } else {
            } else {
                    mCount++;
                freshCached.count = cached.count;
            }
            }

            // If we fail to CAS then it means that another thread beat us to it.
            // In this case we don't override their work.
            mCachedValue.compareAndSet(cached, freshCached);
        }
        }
            return mCurrentValue;
        return mCachedValue.get().value;
        }
    }
    }
}
}
+144 −1
Original line number Original line Diff line number Diff line
@@ -18,9 +18,15 @@ package com.android.internal.util;


import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;

import android.os.SystemClock;


import androidx.test.runner.AndroidJUnit4;
import androidx.test.runner.AndroidJUnit4;


import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.CountDownLatch;
import org.junit.Before;
import org.junit.Before;
import org.junit.Test;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runner.RunWith;
@@ -38,7 +44,7 @@ public class RateLimitingCacheTest {
        mCounter = -1;
        mCounter = -1;
    }
    }


    RateLimitingCache.ValueFetcher<Integer> mFetcher = () -> {
    private final RateLimitingCache.ValueFetcher<Integer> mFetcher = () -> {
        return ++mCounter;
        return ++mCounter;
    };
    };


@@ -119,6 +125,143 @@ public class RateLimitingCacheTest {
        assertCount(s, 2000, 20, 33);
        assertCount(s, 2000, 20, 33);
    }
    }


    /**
     * Exercises concurrent access to the cache: two threads hammer get() and the
     * total number of fetches must stay within the rate limit for the elapsed
     * number of windows (or the total number of calls, whichever is smaller).
     */
    @Test
    public void testMultipleThreads() throws InterruptedException {
        final long periodMillis = 1000;
        final int maxCountPerPeriod = 10;
        final RateLimitingCache<Integer> s =
                new RateLimitingCache<>(periodMillis, maxCountPerPeriod);

        Thread t1 = new Thread(() -> {
            for (int i = 0; i < 100; i++) {
                s.get(mFetcher);
            }
        });
        Thread t2 = new Thread(() -> {
            for (int i = 0; i < 100; i++) {
                s.get(mFetcher);
            }
        });

        final long startTimeMillis = SystemClock.elapsedRealtime();
        t1.start();
        t2.start();
        t1.join();
        t2.join();
        final long endTimeMillis = SystemClock.elapsedRealtime();

        // Count the first (partial) window plus every full window that elapsed.
        final long periodsElapsed = 1 + ((endTimeMillis - startTimeMillis) / periodMillis);
        // mCounter starts at -1 and is pre-incremented per fetch, hence the trailing -1.
        final long expected = Math.min(100 + 100, periodsElapsed * maxCountPerPeriod) - 1;
        // JUnit convention: expected value first, actual second — otherwise a
        // failure message reports the two values swapped.
        assertEquals(expected, mCounter);
    }

    /**
     * Multiple threads calling get() on the cache while the cached value is stale are allowed
     * to fetch, regardless of the rate limiting.
     * This is to prevent a slow getting thread from blocking other threads from getting a fresh
     * value.
     */
    @Test
    public void testMultipleThreads_oneThreadIsSlow() throws InterruptedException {
        final long periodMillis = 1000;
        final int maxCountPerPeriod = 1;
        final RateLimitingCache<Integer> s =
                new RateLimitingCache<>(periodMillis, maxCountPerPeriod);

        // latch1 trips once BOTH threads are inside fetchValue(); latch2 holds
        // both fetches open until the main thread releases them.
        final CountDownLatch latch1 = new CountDownLatch(2);
        final CountDownLatch latch2 = new CountDownLatch(1);
        final AtomicInteger counter = new AtomicInteger();
        final RateLimitingCache.ValueFetcher<Integer> fetcher = () -> {
            latch1.countDown();
            try {
                latch2.await();
            } catch (InterruptedException e) {
                throw new RuntimeException(e);
            }
            return counter.incrementAndGet();
        };

        Thread t1 = new Thread(() -> {
            for (int i = 0; i < 100; i++) {
                s.get(fetcher);
            }
        });
        Thread t2 = new Thread(() -> {
            for (int i = 0; i < 100; i++) {
                s.get(fetcher);
            }
        });

        t1.start();
        t2.start();
        // Both threads should be admitted to fetch because there is no fresh cached value,
        // even though this exceeds the rate limit of at most 1 call per period.
        // Wait for both threads to be fetching.
        latch1.await();
        // Allow the fetcher to return.
        latch2.countDown();
        // Wait for both threads to finish their fetches.
        t1.join();
        t2.join();

        // JUnit convention: expected value first, actual second.
        assertEquals(2, counter.get());
    }

    /**
     * Even if multiple threads race to refresh the cache, only one thread gets to set a new
     * value, so the cache must never hand out values that were fetched out of order. Three
     * background threads spin on get() while the main thread checks that the observed values
     * are monotonically non-decreasing.
     */
    @Test
    public void testMultipleThreads_cachedValueNeverGoesBackInTime() throws InterruptedException {
        final long periodMillis = 10;
        final int maxCountPerPeriod = 3;
        final RateLimitingCache<Integer> cache =
                new RateLimitingCache<>(periodMillis, maxCountPerPeriod);
        final AtomicInteger counter = new AtomicInteger();
        final RateLimitingCache.ValueFetcher<Integer> fetcher = () -> {
            // Note that this fetcher has a side effect, which is strictly not allowed for
            // RateLimitingCache users, but we make an exception for the purpose of this test.
            return counter.incrementAndGet();
        };

        // Three background threads spin on the cache until told to stop.
        final AtomicBoolean keepRunning = new AtomicBoolean(true);
        final Runnable spinner = () -> {
            while (keepRunning.get()) {
                cache.get(fetcher);
            }
        };
        final Thread workerA = new Thread(spinner);
        final Thread workerB = new Thread(spinner);
        final Thread workerC = new Thread(spinner);
        workerA.start();
        workerB.start();
        workerC.start();

        // Keep sampling until a sufficiently convincing high value, failing the
        // moment any sample is lower than the one before it.
        int highWaterMark = 0;
        while (highWaterMark < 10000) {
            final int sample = cache.get(fetcher);
            if (sample < highWaterMark) {
                fail("Unexpectedly saw decreasing value " + sample + " after " + highWaterMark);
            }
            highWaterMark = sample;
        }

        keepRunning.set(false);
        workerA.join();
        workerB.join();
        workerC.join();
    }

    /**
    /**
     * Helper to make repeated calls every 5 millis to verify the number of expected fetches for
     * Helper to make repeated calls every 5 millis to verify the number of expected fetches for
     * the given parameters.
     * the given parameters.