Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit aeddf474 authored by Junyu Lai's avatar Junyu Lai Committed by Automerger Merge Worker
Browse files

Merge "Gracefully handle integer overflows." am: a308f793

Original change: https://android-review.googlesource.com/c/platform/frameworks/base/+/1316217

Change-Id: I779298886cfaef330b433fea6009860682df35bd
parents 03e5841e a308f793
Loading
Loading
Loading
Loading
+25 −11
Original line number Diff line number Diff line
@@ -26,6 +26,7 @@ import static android.net.NetworkStatsHistory.DataStreamUtils.writeVarLongArray;
import static android.net.NetworkStatsHistory.Entry.UNKNOWN;
import static android.net.NetworkStatsHistory.ParcelUtils.readLongArray;
import static android.net.NetworkStatsHistory.ParcelUtils.writeLongArray;
import static android.net.NetworkUtils.multiplySafeByRational;
import static android.text.format.DateUtils.SECOND_IN_MILLIS;

import static com.android.internal.util.ArrayUtils.total;
@@ -364,11 +365,12 @@ public class NetworkStatsHistory implements Parcelable {
            if (overlap <= 0) continue;

            // integer math each time is faster than floating point
            final long fracRxBytes = rxBytes * overlap / duration;
            final long fracRxPackets = rxPackets * overlap / duration;
            final long fracTxBytes = txBytes * overlap / duration;
            final long fracTxPackets = txPackets * overlap / duration;
            final long fracOperations = operations * overlap / duration;
            final long fracRxBytes = multiplySafeByRational(rxBytes, overlap, duration);
            final long fracRxPackets = multiplySafeByRational(rxPackets, overlap, duration);
            final long fracTxBytes = multiplySafeByRational(txBytes, overlap, duration);
            final long fracTxPackets = multiplySafeByRational(txPackets, overlap, duration);
            final long fracOperations = multiplySafeByRational(operations, overlap, duration);


            addLong(activeTime, i, overlap);
            addLong(this.rxBytes, i, fracRxBytes); rxBytes -= fracRxBytes;
@@ -568,12 +570,24 @@ public class NetworkStatsHistory implements Parcelable {
            if (overlap <= 0) continue;

            // integer math each time is faster than floating point
            if (activeTime != null) entry.activeTime += activeTime[i] * overlap / bucketSpan;
            if (rxBytes != null) entry.rxBytes += rxBytes[i] * overlap / bucketSpan;
            if (rxPackets != null) entry.rxPackets += rxPackets[i] * overlap / bucketSpan;
            if (txBytes != null) entry.txBytes += txBytes[i] * overlap / bucketSpan;
            if (txPackets != null) entry.txPackets += txPackets[i] * overlap / bucketSpan;
            if (operations != null) entry.operations += operations[i] * overlap / bucketSpan;
            if (activeTime != null) {
                entry.activeTime += multiplySafeByRational(activeTime[i], overlap, bucketSpan);
            }
            if (rxBytes != null) {
                entry.rxBytes += multiplySafeByRational(rxBytes[i], overlap, bucketSpan);
            }
            if (rxPackets != null) {
                entry.rxPackets += multiplySafeByRational(rxPackets[i], overlap, bucketSpan);
            }
            if (txBytes != null) {
                entry.txBytes += multiplySafeByRational(txBytes[i], overlap, bucketSpan);
            }
            if (txPackets != null) {
                entry.txPackets += multiplySafeByRational(txPackets[i], overlap, bucketSpan);
            }
            if (operations != null) {
                entry.operations += multiplySafeByRational(operations[i], overlap, bucketSpan);
            }
        }
        return entry;
    }
+31 −0
Original line number Diff line number Diff line
@@ -476,4 +476,35 @@ public class NetworkUtils {

        return true;
    }

    /**
     * Safely multiple a value by a rational.
     * <p>
     * Internally it uses integer-based math whenever possible, but switches
     * over to double-based math if values would overflow.
     * @hide
     */
    public static long multiplySafeByRational(long value, long num, long den) {
        if (den == 0) {
            throw new ArithmeticException("Invalid Denominator");
        }
        long x = value;
        long y = num;

        // Logic shamelessly borrowed from Math.multiplyExact()
        long r = x * y;
        long ax = Math.abs(x);
        long ay = Math.abs(y);
        if (((ax | ay) >>> 31 != 0)) {
            // Some bits greater than 2^31 that might cause overflow
            // Check the result using the divide operator
            // and check for the special case of Long.MIN_VALUE * -1
            if (((y != 0) && (r / y != x)) ||
                    (x == Long.MIN_VALUE && y == -1)) {
                // Use double math to avoid overflowing
                return (long) (((double) num / den) * value);
            }
        }
        return r / den;
    }
}
+11 −35
Original line number Diff line number Diff line
@@ -28,6 +28,7 @@ import static android.net.NetworkStats.SET_DEFAULT;
import static android.net.NetworkStats.TAG_NONE;
import static android.net.NetworkStats.UID_ALL;
import static android.net.TrafficStats.UID_REMOVED;
import static android.net.NetworkUtils.multiplySafeByRational;
import static android.text.format.DateUtils.WEEK_IN_MILLIS;

import static com.android.server.net.NetworkStatsService.TAG;
@@ -185,35 +186,6 @@ public class NetworkStatsCollection implements FileRotator.Reader {
        }
    }

    /**
     * Safely multiple a value by a rational.
     * <p>
     * Internally it uses integer-based math whenever possible, but switches
     * over to double-based math if values would overflow.
     */
    @VisibleForTesting
    public static long multiplySafe(long value, long num, long den) {
        if (den == 0) den = 1;
        long x = value;
        long y = num;

        // Logic shamelessly borrowed from Math.multiplyExact()
        long r = x * y;
        long ax = Math.abs(x);
        long ay = Math.abs(y);
        if (((ax | ay) >>> 31 != 0)) {
            // Some bits greater than 2^31 that might cause overflow
            // Check the result using the divide operator
            // and check for the special case of Long.MIN_VALUE * -1
            if (((y != 0) && (r / y != x)) ||
                    (x == Long.MIN_VALUE && y == -1)) {
                // Use double math to avoid overflowing
                return (long) (((double) num / den) * value);
            }
        }
        return r / den;
    }

    public int[] getRelevantUids(@NetworkStatsAccess.Level int accessLevel) {
        return getRelevantUids(accessLevel, Binder.getCallingUid());
    }
@@ -311,11 +283,13 @@ public class NetworkStatsCollection implements FileRotator.Reader {
            }

            final long rawBytes = entry.rxBytes + entry.txBytes;
            final long rawRxBytes = entry.rxBytes;
            final long rawTxBytes = entry.txBytes;
            final long rawRxBytes = entry.rxBytes == 0 ? 1 : entry.rxBytes;
            final long rawTxBytes = entry.txBytes == 0 ? 1 : entry.txBytes;
            final long targetBytes = augmentPlan.getDataUsageBytes();
            final long targetRxBytes = multiplySafe(targetBytes, rawRxBytes, rawBytes);
            final long targetTxBytes = multiplySafe(targetBytes, rawTxBytes, rawBytes);

            final long targetRxBytes = multiplySafeByRational(targetBytes, rawRxBytes, rawBytes);
            final long targetTxBytes = multiplySafeByRational(targetBytes, rawTxBytes, rawBytes);


            // Scale all matching buckets to reach anchor target
            final long beforeTotal = combined.getTotalBytes();
@@ -323,8 +297,10 @@ public class NetworkStatsCollection implements FileRotator.Reader {
                combined.getValues(i, entry);
                if (entry.bucketStart >= augmentStart
                        && entry.bucketStart + entry.bucketDuration <= augmentEnd) {
                    entry.rxBytes = multiplySafe(targetRxBytes, entry.rxBytes, rawRxBytes);
                    entry.txBytes = multiplySafe(targetTxBytes, entry.txBytes, rawTxBytes);
                    entry.rxBytes = multiplySafeByRational(
                            targetRxBytes, entry.rxBytes, rawRxBytes);
                    entry.txBytes = multiplySafeByRational(
                            targetTxBytes, entry.txBytes, rawTxBytes);
                    // We purposefully clear out packet counters to indicate
                    // that this data has been augmented.
                    entry.rxPackets = 0;
+16 −13
Original line number Diff line number Diff line
@@ -23,11 +23,12 @@ import static android.net.NetworkStats.TAG_NONE;
import static android.net.NetworkStats.UID_ALL;
import static android.net.NetworkStatsHistory.FIELD_ALL;
import static android.net.NetworkTemplate.buildTemplateMobileAll;
import static android.net.NetworkUtils.multiplySafeByRational;
import static android.os.Process.myUid;
import static android.text.format.DateUtils.HOUR_IN_MILLIS;
import static android.text.format.DateUtils.MINUTE_IN_MILLIS;

import static com.android.server.net.NetworkStatsCollection.multiplySafe;
import static com.android.testutils.MiscAssertsKt.assertThrows;

import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
@@ -505,23 +506,25 @@ public class NetworkStatsCollectionTest {
    }

    @Test
    public void testMultiplySafe() {
        assertEquals(25, multiplySafe(50, 1, 2));
        assertEquals(100, multiplySafe(50, 2, 1));
    public void testMultiplySafeRational() {
        assertEquals(25, multiplySafeByRational(50, 1, 2));
        assertEquals(100, multiplySafeByRational(50, 2, 1));

        assertEquals(-10, multiplySafe(30, -1, 3));
        assertEquals(0, multiplySafe(30, 0, 3));
        assertEquals(10, multiplySafe(30, 1, 3));
        assertEquals(20, multiplySafe(30, 2, 3));
        assertEquals(30, multiplySafe(30, 3, 3));
        assertEquals(40, multiplySafe(30, 4, 3));
        assertEquals(-10, multiplySafeByRational(30, -1, 3));
        assertEquals(0, multiplySafeByRational(30, 0, 3));
        assertEquals(10, multiplySafeByRational(30, 1, 3));
        assertEquals(20, multiplySafeByRational(30, 2, 3));
        assertEquals(30, multiplySafeByRational(30, 3, 3));
        assertEquals(40, multiplySafeByRational(30, 4, 3));

        assertEquals(100_000_000_000L,
                multiplySafe(300_000_000_000L, 10_000_000_000L, 30_000_000_000L));
                multiplySafeByRational(300_000_000_000L, 10_000_000_000L, 30_000_000_000L));
        assertEquals(100_000_000_010L,
                multiplySafe(300_000_000_000L, 10_000_000_001L, 30_000_000_000L));
                multiplySafeByRational(300_000_000_000L, 10_000_000_001L, 30_000_000_000L));
        assertEquals(823_202_048L,
                multiplySafe(4_939_212_288L, 2_121_815_528L, 12_730_893_165L));
                multiplySafeByRational(4_939_212_288L, 2_121_815_528L, 12_730_893_165L));

        assertThrows(ArithmeticException.class, () -> multiplySafeByRational(30, 3, 0));
    }

    /**