Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit a308f793 authored by Junyu Lai's avatar Junyu Lai Committed by Gerrit Code Review
Browse files

Merge "Gracefully handle integer overflows."

parents 3b791360 3708b684
Loading
Loading
Loading
Loading
+25 −11
Original line number Diff line number Diff line
@@ -26,6 +26,7 @@ import static android.net.NetworkStatsHistory.DataStreamUtils.writeVarLongArray;
import static android.net.NetworkStatsHistory.Entry.UNKNOWN;
import static android.net.NetworkStatsHistory.ParcelUtils.readLongArray;
import static android.net.NetworkStatsHistory.ParcelUtils.writeLongArray;
import static android.net.NetworkUtils.multiplySafeByRational;
import static android.text.format.DateUtils.SECOND_IN_MILLIS;

import static com.android.internal.util.ArrayUtils.total;
@@ -364,11 +365,12 @@ public class NetworkStatsHistory implements Parcelable {
            if (overlap <= 0) continue;

            // integer math each time is faster than floating point
            final long fracRxBytes = rxBytes * overlap / duration;
            final long fracRxPackets = rxPackets * overlap / duration;
            final long fracTxBytes = txBytes * overlap / duration;
            final long fracTxPackets = txPackets * overlap / duration;
            final long fracOperations = operations * overlap / duration;
            final long fracRxBytes = multiplySafeByRational(rxBytes, overlap, duration);
            final long fracRxPackets = multiplySafeByRational(rxPackets, overlap, duration);
            final long fracTxBytes = multiplySafeByRational(txBytes, overlap, duration);
            final long fracTxPackets = multiplySafeByRational(txPackets, overlap, duration);
            final long fracOperations = multiplySafeByRational(operations, overlap, duration);


            addLong(activeTime, i, overlap);
            addLong(this.rxBytes, i, fracRxBytes); rxBytes -= fracRxBytes;
@@ -568,12 +570,24 @@ public class NetworkStatsHistory implements Parcelable {
            if (overlap <= 0) continue;

            // integer math each time is faster than floating point
            if (activeTime != null) entry.activeTime += activeTime[i] * overlap / bucketSpan;
            if (rxBytes != null) entry.rxBytes += rxBytes[i] * overlap / bucketSpan;
            if (rxPackets != null) entry.rxPackets += rxPackets[i] * overlap / bucketSpan;
            if (txBytes != null) entry.txBytes += txBytes[i] * overlap / bucketSpan;
            if (txPackets != null) entry.txPackets += txPackets[i] * overlap / bucketSpan;
            if (operations != null) entry.operations += operations[i] * overlap / bucketSpan;
            if (activeTime != null) {
                entry.activeTime += multiplySafeByRational(activeTime[i], overlap, bucketSpan);
            }
            if (rxBytes != null) {
                entry.rxBytes += multiplySafeByRational(rxBytes[i], overlap, bucketSpan);
            }
            if (rxPackets != null) {
                entry.rxPackets += multiplySafeByRational(rxPackets[i], overlap, bucketSpan);
            }
            if (txBytes != null) {
                entry.txBytes += multiplySafeByRational(txBytes[i], overlap, bucketSpan);
            }
            if (txPackets != null) {
                entry.txPackets += multiplySafeByRational(txPackets[i], overlap, bucketSpan);
            }
            if (operations != null) {
                entry.operations += multiplySafeByRational(operations[i], overlap, bucketSpan);
            }
        }
        return entry;
    }
+31 −0
Original line number Diff line number Diff line
@@ -476,4 +476,35 @@ public class NetworkUtils {

        return true;
    }

    /**
     * Safely multiple a value by a rational.
     * <p>
     * Internally it uses integer-based math whenever possible, but switches
     * over to double-based math if values would overflow.
     * @hide
     */
    public static long multiplySafeByRational(long value, long num, long den) {
        if (den == 0) {
            throw new ArithmeticException("Invalid Denominator");
        }
        long x = value;
        long y = num;

        // Logic shamelessly borrowed from Math.multiplyExact()
        long r = x * y;
        long ax = Math.abs(x);
        long ay = Math.abs(y);
        if (((ax | ay) >>> 31 != 0)) {
            // Some bits greater than 2^31 that might cause overflow
            // Check the result using the divide operator
            // and check for the special case of Long.MIN_VALUE * -1
            if (((y != 0) && (r / y != x)) ||
                    (x == Long.MIN_VALUE && y == -1)) {
                // Use double math to avoid overflowing
                return (long) (((double) num / den) * value);
            }
        }
        return r / den;
    }
}
+11 −35
Original line number Diff line number Diff line
@@ -28,6 +28,7 @@ import static android.net.NetworkStats.SET_DEFAULT;
import static android.net.NetworkStats.TAG_NONE;
import static android.net.NetworkStats.UID_ALL;
import static android.net.TrafficStats.UID_REMOVED;
import static android.net.NetworkUtils.multiplySafeByRational;
import static android.text.format.DateUtils.WEEK_IN_MILLIS;

import static com.android.server.net.NetworkStatsService.TAG;
@@ -185,35 +186,6 @@ public class NetworkStatsCollection implements FileRotator.Reader {
        }
    }

    /**
     * Safely multiple a value by a rational.
     * <p>
     * Internally it uses integer-based math whenever possible, but switches
     * over to double-based math if values would overflow.
     */
    @VisibleForTesting
    public static long multiplySafe(long value, long num, long den) {
        if (den == 0) den = 1;
        long x = value;
        long y = num;

        // Logic shamelessly borrowed from Math.multiplyExact()
        long r = x * y;
        long ax = Math.abs(x);
        long ay = Math.abs(y);
        if (((ax | ay) >>> 31 != 0)) {
            // Some bits greater than 2^31 that might cause overflow
            // Check the result using the divide operator
            // and check for the special case of Long.MIN_VALUE * -1
            if (((y != 0) && (r / y != x)) ||
                    (x == Long.MIN_VALUE && y == -1)) {
                // Use double math to avoid overflowing
                return (long) (((double) num / den) * value);
            }
        }
        return r / den;
    }

    public int[] getRelevantUids(@NetworkStatsAccess.Level int accessLevel) {
        return getRelevantUids(accessLevel, Binder.getCallingUid());
    }
@@ -311,11 +283,13 @@ public class NetworkStatsCollection implements FileRotator.Reader {
            }

            final long rawBytes = entry.rxBytes + entry.txBytes;
            final long rawRxBytes = entry.rxBytes;
            final long rawTxBytes = entry.txBytes;
            final long rawRxBytes = entry.rxBytes == 0 ? 1 : entry.rxBytes;
            final long rawTxBytes = entry.txBytes == 0 ? 1 : entry.txBytes;
            final long targetBytes = augmentPlan.getDataUsageBytes();
            final long targetRxBytes = multiplySafe(targetBytes, rawRxBytes, rawBytes);
            final long targetTxBytes = multiplySafe(targetBytes, rawTxBytes, rawBytes);

            final long targetRxBytes = multiplySafeByRational(targetBytes, rawRxBytes, rawBytes);
            final long targetTxBytes = multiplySafeByRational(targetBytes, rawTxBytes, rawBytes);


            // Scale all matching buckets to reach anchor target
            final long beforeTotal = combined.getTotalBytes();
@@ -323,8 +297,10 @@ public class NetworkStatsCollection implements FileRotator.Reader {
                combined.getValues(i, entry);
                if (entry.bucketStart >= augmentStart
                        && entry.bucketStart + entry.bucketDuration <= augmentEnd) {
                    entry.rxBytes = multiplySafe(targetRxBytes, entry.rxBytes, rawRxBytes);
                    entry.txBytes = multiplySafe(targetTxBytes, entry.txBytes, rawTxBytes);
                    entry.rxBytes = multiplySafeByRational(
                            targetRxBytes, entry.rxBytes, rawRxBytes);
                    entry.txBytes = multiplySafeByRational(
                            targetTxBytes, entry.txBytes, rawTxBytes);
                    // We purposefully clear out packet counters to indicate
                    // that this data has been augmented.
                    entry.rxPackets = 0;
+16 −13
Original line number Diff line number Diff line
@@ -23,11 +23,12 @@ import static android.net.NetworkStats.TAG_NONE;
import static android.net.NetworkStats.UID_ALL;
import static android.net.NetworkStatsHistory.FIELD_ALL;
import static android.net.NetworkTemplate.buildTemplateMobileAll;
import static android.net.NetworkUtils.multiplySafeByRational;
import static android.os.Process.myUid;
import static android.text.format.DateUtils.HOUR_IN_MILLIS;
import static android.text.format.DateUtils.MINUTE_IN_MILLIS;

import static com.android.server.net.NetworkStatsCollection.multiplySafe;
import static com.android.testutils.MiscAssertsKt.assertThrows;

import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
@@ -505,23 +506,25 @@ public class NetworkStatsCollectionTest {
    }

    @Test
    public void testMultiplySafe() {
        assertEquals(25, multiplySafe(50, 1, 2));
        assertEquals(100, multiplySafe(50, 2, 1));
    public void testMultiplySafeRational() {
        assertEquals(25, multiplySafeByRational(50, 1, 2));
        assertEquals(100, multiplySafeByRational(50, 2, 1));

        assertEquals(-10, multiplySafe(30, -1, 3));
        assertEquals(0, multiplySafe(30, 0, 3));
        assertEquals(10, multiplySafe(30, 1, 3));
        assertEquals(20, multiplySafe(30, 2, 3));
        assertEquals(30, multiplySafe(30, 3, 3));
        assertEquals(40, multiplySafe(30, 4, 3));
        assertEquals(-10, multiplySafeByRational(30, -1, 3));
        assertEquals(0, multiplySafeByRational(30, 0, 3));
        assertEquals(10, multiplySafeByRational(30, 1, 3));
        assertEquals(20, multiplySafeByRational(30, 2, 3));
        assertEquals(30, multiplySafeByRational(30, 3, 3));
        assertEquals(40, multiplySafeByRational(30, 4, 3));

        assertEquals(100_000_000_000L,
                multiplySafe(300_000_000_000L, 10_000_000_000L, 30_000_000_000L));
                multiplySafeByRational(300_000_000_000L, 10_000_000_000L, 30_000_000_000L));
        assertEquals(100_000_000_010L,
                multiplySafe(300_000_000_000L, 10_000_000_001L, 30_000_000_000L));
                multiplySafeByRational(300_000_000_000L, 10_000_000_001L, 30_000_000_000L));
        assertEquals(823_202_048L,
                multiplySafe(4_939_212_288L, 2_121_815_528L, 12_730_893_165L));
                multiplySafeByRational(4_939_212_288L, 2_121_815_528L, 12_730_893_165L));

        assertThrows(ArithmeticException.class, () -> multiplySafeByRational(30, 3, 0));
    }

    /**