Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit e5d60740 authored by Song Pan's avatar Song Pan
Browse files

Fix a bug in BitOutputStream where any trailing zeros will be discarded along

with an optimization to avoid maintaining two copies of data in memory during
toByteArray.

Test: atest frameworks/base/services/tests/servicestests/src/com/android/server/integrity
Bug: 143697198
Change-Id: Ide9256d7bdf5a268920a944206af8ea2b03201a4
parent 10432c77
Loading
Loading
Loading
Loading
+50 −32
Original line number Diff line number Diff line
@@ -16,17 +16,25 @@

package com.android.server.integrity.model;

import java.util.BitSet;
import java.io.IOException;
import java.io.OutputStream;
import java.util.Arrays;

/** A wrapper class for writing a stream of bits. */
public class BitOutputStream {

    private BitSet mBitSet;
    private int mIndex;
    private static final int BUFFER_SIZE = 4 * 1024;
    private static final int BYTE_BITS = 8;

    public BitOutputStream() {
        mBitSet = new BitSet();
        mIndex = 0;
    private int mNextBitIndex;

    private final OutputStream mOutputStream;
    private final byte[] mBuffer;

    public BitOutputStream(OutputStream outputStream) {
        mBuffer = new byte[BUFFER_SIZE];
        mNextBitIndex = 0;
        mOutputStream = outputStream;
    }

    /**
@@ -35,15 +43,17 @@ public class BitOutputStream {
     * @param numOfBits The number of bits used to represent the value.
     * @param value The value to convert to bits.
     */
    public void setNext(int numOfBits, int value) {
    public void setNext(int numOfBits, int value) throws IOException {
        if (numOfBits <= 0) {
            return;
        }
        int offset = 1 << (numOfBits - 1);

        // optional: we can do some clever size checking to "OR" an entire segment of bits instead
        // of setting bits one by one, but it is probably not worth it.
        int nextBitMask = 1 << (numOfBits - 1);
        while (numOfBits-- > 0) {
            mBitSet.set(mIndex, (value & offset) != 0);
            offset >>>= 1;
            mIndex++;
            setNext((value & nextBitMask) != 0);
            nextBitMask >>>= 1;
        }
    }

@@ -52,35 +62,43 @@ public class BitOutputStream {
     *
     * @param value The value to set the bit to.
     */
    public void setNext(boolean value) {
        mBitSet.set(mIndex, value);
        mIndex++;
    public void setNext(boolean value) throws IOException {
        int byteToWrite = mNextBitIndex / BYTE_BITS;
        if (byteToWrite == BUFFER_SIZE) {
            mOutputStream.write(mBuffer);
            reset();
            byteToWrite = 0;
        }
        if (value) {
            mBuffer[byteToWrite] |= 1 << (BYTE_BITS - 1 - (mNextBitIndex % BYTE_BITS));
        }
        mNextBitIndex++;
    }

    /** Set the next bit in the stream to true. */
    public void setNext() {
    public void setNext() throws IOException {
        setNext(/* value= */ true);
    }

    /** Convert BitSet in big-endian to ByteArray in big-endian. */
    public byte[] toByteArray() {
        int bitSetSize = mBitSet.length();
        int numOfBytes = bitSetSize / 8;
        if (bitSetSize % 8 != 0) {
            numOfBytes++;
        }
        byte[] bytes = new byte[numOfBytes];
        for (int i = 0; i < mBitSet.length(); i++) {
            if (mBitSet.get(i)) {
                bytes[i / 8] |= 1 << (7 - (i % 8));
            }
    /**
     * Flush the data written to the underlying {@link java.io.OutputStream}. Any unfinished bytes
     * will be padded with 0.
     */
    public void flush() throws IOException {
        int endByte = mNextBitIndex / BYTE_BITS;
        if (mNextBitIndex % BYTE_BITS != 0) {
            // If next bit is not the first bit of a byte, then mNextBitIndex / BYTE_BITS would be
            // the byte that includes already written bits. We need to increment it so this byte
            // gets written.
            endByte++;
        }
        return bytes;
        mOutputStream.write(mBuffer, 0, endByte);
        reset();
    }

    /** Clear the stream. */
    public void clear() {
        mBitSet.clear();
        mIndex = 0;
    /** Reset this output stream to start state. */
    private void reset() {
        mNextBitIndex = 0;
        Arrays.fill(mBuffer, (byte) 0);
    }
}
+27 −14
Original line number Diff line number Diff line
@@ -23,31 +23,44 @@ import java.io.OutputStream;
 * An output stream that tracks the total number written bytes since construction and allows
 * querying this value any time during the execution.
 *
 * This class is used for constructing the rule indexing.
 * <p>This class is used for constructing the rule indexing.
 */
public class ByteTrackedOutputStream {
public class ByteTrackedOutputStream extends OutputStream {

    private static int sWrittenBytesCount;
    private static OutputStream sOutputStream;
    private static final int INT_BYTES = 4;

    private int mWrittenBytesCount;
    private final OutputStream mOutputStream;

    public ByteTrackedOutputStream(OutputStream outputStream) {
        sWrittenBytesCount = 0;
        sOutputStream = outputStream;
        mWrittenBytesCount = 0;
        mOutputStream = outputStream;
    }

    @Override
    public void write(int b) throws IOException {
        mWrittenBytesCount += INT_BYTES;
        mOutputStream.write(b);
    }

    /**
     * Writes the given bytes into the output stream provided in constructor and updates the
     * total number of written bytes.
     * Writes the given bytes into the output stream provided in constructor and updates the total
     * number of written bytes.
     */
    @Override
    public void write(byte[] bytes) throws IOException {
        sWrittenBytesCount += bytes.length;
        sOutputStream.write(bytes);
        mWrittenBytesCount += bytes.length;
        mOutputStream.write(bytes);
    }

    /**
     * Returns the total number of bytes written into the output stream at the requested time.
     */
    @Override
    public void write(byte[] b, int off, int len) throws IOException {
        mWrittenBytesCount += len;
        mOutputStream.write(b, off, len);
    }

    /** Returns the total number of bytes written into the output stream at the requested time. */
    public int getWrittenBytesCount() {
        return sWrittenBytesCount;
        return mWrittenBytesCount;
    }
}
+31 −34
Original line number Diff line number Diff line
@@ -97,42 +97,38 @@ public class RuleBinarySerializer implements RuleSerializer {
                            ruleFileByteTrackedOutputStream);
            LinkedHashMap<String, Integer> unindexedRulesIndexes =
                    serializeRuleList(
                            indexedRules.get(NOT_INDEXED),
                            ruleFileByteTrackedOutputStream);
                            indexedRules.get(NOT_INDEXED), ruleFileByteTrackedOutputStream);

            // Serialize their indexes.
            BitOutputStream indexingBitOutputStream = new BitOutputStream();
            BitOutputStream indexingBitOutputStream = new BitOutputStream(indexingFileOutputStream);
            serializeIndexGroup(packageNameIndexes, indexingBitOutputStream, /* isIndexed= */ true);
            serializeIndexGroup(appCertificateIndexes, indexingBitOutputStream, /* isIndexed= */
                    true);
            serializeIndexGroup(unindexedRulesIndexes, indexingBitOutputStream, /* isIndexed= */
                    false);
            // TODO(b/147609625): This dummy bit is set for fixing the padding issue. Remove when
            // the issue is fixed and correct the tests that does this padding too.
            indexingBitOutputStream.setNext();
            indexingFileOutputStream.write(indexingBitOutputStream.toByteArray());
            serializeIndexGroup(
                    appCertificateIndexes, indexingBitOutputStream, /* isIndexed= */ true);
            serializeIndexGroup(
                    unindexedRulesIndexes, indexingBitOutputStream, /* isIndexed= */ false);
            indexingBitOutputStream.flush();
        } catch (Exception e) {
            throw new RuleSerializeException(e.getMessage(), e);
        }
    }

    private void serializeRuleFileMetadata(Optional<Integer> formatVersion,
            ByteTrackedOutputStream outputStream)
    private void serializeRuleFileMetadata(
            Optional<Integer> formatVersion, ByteTrackedOutputStream outputStream)
            throws IOException {
        int formatVersionValue = formatVersion.orElse(DEFAULT_FORMAT_VERSION);

        BitOutputStream bitOutputStream = new BitOutputStream();
        BitOutputStream bitOutputStream = new BitOutputStream(outputStream);
        bitOutputStream.setNext(FORMAT_VERSION_BITS, formatVersionValue);
        outputStream.write(bitOutputStream.toByteArray());
        bitOutputStream.flush();
    }

    private LinkedHashMap<String, Integer> serializeRuleList(
            Map<String, List<Rule>> rulesMap, ByteTrackedOutputStream outputStream)
            throws IOException {
        Preconditions.checkArgument(rulesMap != null,
                "serializeRuleList should never be called with null rule list.");
        Preconditions.checkArgument(
                rulesMap != null, "serializeRuleList should never be called with null rule list.");

        BitOutputStream bitOutputStream = new BitOutputStream();
        BitOutputStream bitOutputStream = new BitOutputStream(outputStream);
        LinkedHashMap<String, Integer> indexMapping = new LinkedHashMap();
        indexMapping.put(START_INDEXING_KEY, outputStream.getWrittenBytesCount());

@@ -145,9 +141,8 @@ public class RuleBinarySerializer implements RuleSerializer {
            }

            for (Rule rule : rulesMap.get(key)) {
                bitOutputStream.clear();
                serializeRule(rule, bitOutputStream);
                outputStream.write(bitOutputStream.toByteArray());
                bitOutputStream.flush();
                indexTracker++;
            }
        }
@@ -156,7 +151,7 @@ public class RuleBinarySerializer implements RuleSerializer {
        return indexMapping;
    }

    private void serializeRule(Rule rule, BitOutputStream bitOutputStream) {
    private void serializeRule(Rule rule, BitOutputStream bitOutputStream) throws IOException {
        if (rule == null) {
            throw new IllegalArgumentException("Null rule can not be serialized");
        }
@@ -171,7 +166,8 @@ public class RuleBinarySerializer implements RuleSerializer {
        bitOutputStream.setNext();
    }

    private void serializeFormula(Formula formula, BitOutputStream bitOutputStream) {
    private void serializeFormula(Formula formula, BitOutputStream bitOutputStream)
            throws IOException {
        if (formula instanceof AtomicFormula) {
            serializeAtomicFormula((AtomicFormula) formula, bitOutputStream);
        } else if (formula instanceof CompoundFormula) {
@@ -183,7 +179,7 @@ public class RuleBinarySerializer implements RuleSerializer {
    }

    private void serializeCompoundFormula(
            CompoundFormula compoundFormula, BitOutputStream bitOutputStream) {
            CompoundFormula compoundFormula, BitOutputStream bitOutputStream) throws IOException {
        if (compoundFormula == null) {
            throw new IllegalArgumentException("Null compound formula can not be serialized");
        }
@@ -197,7 +193,7 @@ public class RuleBinarySerializer implements RuleSerializer {
    }

    private void serializeAtomicFormula(
            AtomicFormula atomicFormula, BitOutputStream bitOutputStream) {
            AtomicFormula atomicFormula, BitOutputStream bitOutputStream) throws IOException {
        if (atomicFormula == null) {
            throw new IllegalArgumentException("Null atomic formula can not be serialized");
        }
@@ -231,11 +227,10 @@ public class RuleBinarySerializer implements RuleSerializer {
    private void serializeIndexGroup(
            LinkedHashMap<String, Integer> indexes,
            BitOutputStream bitOutputStream,
            boolean isIndexed) {

            boolean isIndexed)
            throws IOException {
        // Output the starting location of this indexing group.
        serializeStringValue(
                START_INDEXING_KEY, /* isHashedValue= */false, bitOutputStream);
        serializeStringValue(START_INDEXING_KEY, /* isHashedValue= */ false, bitOutputStream);
        serializeIntValue(indexes.get(START_INDEXING_KEY), bitOutputStream);

        // If the group is indexed, output the locations of the indexes.
@@ -243,8 +238,8 @@ public class RuleBinarySerializer implements RuleSerializer {
            for (Map.Entry<String, Integer> entry : indexes.entrySet()) {
                if (!entry.getKey().equals(START_INDEXING_KEY)
                        && !entry.getKey().equals(END_INDEXING_KEY)) {
                    serializeStringValue(entry.getKey(), /* isHashedValue= */false,
                            bitOutputStream);
                    serializeStringValue(
                            entry.getKey(), /* isHashedValue= */ false, bitOutputStream);
                    serializeIntValue(entry.getValue(), bitOutputStream);
                }
            }
@@ -256,7 +251,8 @@ public class RuleBinarySerializer implements RuleSerializer {
    }

    private void serializeStringValue(
            String value, boolean isHashedValue, BitOutputStream bitOutputStream) {
            String value, boolean isHashedValue, BitOutputStream bitOutputStream)
            throws IOException {
        if (value == null) {
            throw new IllegalArgumentException("String value can not be null.");
        }
@@ -269,11 +265,12 @@ public class RuleBinarySerializer implements RuleSerializer {
        }
    }

    private void serializeIntValue(int value, BitOutputStream bitOutputStream) {
    private void serializeIntValue(int value, BitOutputStream bitOutputStream) throws IOException {
        bitOutputStream.setNext(/* numOfBits= */ 32, value);
    }

    private void serializeBooleanValue(boolean value, BitOutputStream bitOutputStream) {
    private void serializeBooleanValue(boolean value, BitOutputStream bitOutputStream)
            throws IOException {
        bitOutputStream.setNext(value);
    }

+4 −4
Original line number Diff line number Diff line
@@ -188,7 +188,7 @@ public class RuleEvaluatorTest {
    }

    @Test
    public void testEvaluateRules_ruleNotInDNF_ignoreAndAllow() {
    public void testEvaluateRules_orRules() {
        CompoundFormula compoundFormula =
                new CompoundFormula(
                        CompoundFormula.OR,
@@ -206,11 +206,11 @@ public class RuleEvaluatorTest {
        IntegrityCheckResult result =
                RuleEvaluator.evaluateRules(Collections.singletonList(rule), APP_INSTALL_METADATA);

        assertEquals(ALLOW, result.getEffect());
        assertEquals(DENY, result.getEffect());
    }

    @Test
    public void testEvaluateRules_compoundFormulaWithNot_allow() {
    public void testEvaluateRules_compoundFormulaWithNot_deny() {
        CompoundFormula openSubFormula =
                new CompoundFormula(
                        CompoundFormula.AND,
@@ -230,7 +230,7 @@ public class RuleEvaluatorTest {
        IntegrityCheckResult result =
                RuleEvaluator.evaluateRules(Collections.singletonList(rule), APP_INSTALL_METADATA);

        assertEquals(ALLOW, result.getEffect());
        assertEquals(DENY, result.getEffect());
    }

    @Test
+4 −4
Original line number Diff line number Diff line
@@ -53,17 +53,17 @@ public class ByteTrackedOutputStreamTest {
        ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
        ByteTrackedOutputStream byteTrackedOutputStream = new ByteTrackedOutputStream(outputStream);

        BitOutputStream bitOutputStream = new BitOutputStream();
        BitOutputStream bitOutputStream = new BitOutputStream(byteTrackedOutputStream);
        bitOutputStream.setNext(/* numOfBits= */5, /* value= */1);
        byteTrackedOutputStream.write(bitOutputStream.toByteArray());
        bitOutputStream.flush();

        // Even though we wrote 5 bits, this will complete to 1 byte.
        assertThat(byteTrackedOutputStream.getWrittenBytesCount()).isEqualTo(1);

        // Add a bit less than 2 bytes (10 bits).
        bitOutputStream.clear();
        bitOutputStream.setNext(/* numOfBits= */10, /* value= */1);
        byteTrackedOutputStream.write(bitOutputStream.toByteArray());
        bitOutputStream.flush();
        assertThat(byteTrackedOutputStream.getWrittenBytesCount()).isEqualTo(3);

        assertThat(outputStream.toByteArray().length).isEqualTo(3);
    }
Loading