Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit e5d60740 authored by Song Pan's avatar Song Pan
Browse files

Fix a bug in BitOutputStream where any trailing zeros will be discarded along

with an optimization to avoid maintaining two copies of data in memory during
toByteArray.

Test: atest frameworks/base/services/tests/servicestests/src/com/android/server/integrity
Bug: 143697198
Change-Id: Ide9256d7bdf5a268920a944206af8ea2b03201a4
parent 10432c77
Loading
Loading
Loading
Loading
+50 −32
Original line number Original line Diff line number Diff line
@@ -16,17 +16,25 @@


package com.android.server.integrity.model;
package com.android.server.integrity.model;


import java.util.BitSet;
import java.io.IOException;
import java.io.OutputStream;
import java.util.Arrays;


/** A wrapper class for writing a stream of bits. */
/** A wrapper class for writing a stream of bits. */
public class BitOutputStream {
public class BitOutputStream {


    private BitSet mBitSet;
    private static final int BUFFER_SIZE = 4 * 1024;
    private int mIndex;
    private static final int BYTE_BITS = 8;


    public BitOutputStream() {
    private int mNextBitIndex;
        mBitSet = new BitSet();

        mIndex = 0;
    private final OutputStream mOutputStream;
    private final byte[] mBuffer;

    public BitOutputStream(OutputStream outputStream) {
        mBuffer = new byte[BUFFER_SIZE];
        mNextBitIndex = 0;
        mOutputStream = outputStream;
    }
    }


    /**
    /**
@@ -35,15 +43,17 @@ public class BitOutputStream {
     * @param numOfBits The number of bits used to represent the value.
     * @param numOfBits The number of bits used to represent the value.
     * @param value The value to convert to bits.
     * @param value The value to convert to bits.
     */
     */
    public void setNext(int numOfBits, int value) {
    public void setNext(int numOfBits, int value) throws IOException {
        if (numOfBits <= 0) {
        if (numOfBits <= 0) {
            return;
            return;
        }
        }
        int offset = 1 << (numOfBits - 1);

        // optional: we can do some clever size checking to "OR" an entire segment of bits instead
        // of setting bits one by one, but it is probably not worth it.
        int nextBitMask = 1 << (numOfBits - 1);
        while (numOfBits-- > 0) {
        while (numOfBits-- > 0) {
            mBitSet.set(mIndex, (value & offset) != 0);
            setNext((value & nextBitMask) != 0);
            offset >>>= 1;
            nextBitMask >>>= 1;
            mIndex++;
        }
        }
    }
    }


@@ -52,35 +62,43 @@ public class BitOutputStream {
     *
     *
     * @param value The value to set the bit to.
     * @param value The value to set the bit to.
     */
     */
    public void setNext(boolean value) {
    public void setNext(boolean value) throws IOException {
        mBitSet.set(mIndex, value);
        int byteToWrite = mNextBitIndex / BYTE_BITS;
        mIndex++;
        if (byteToWrite == BUFFER_SIZE) {
            mOutputStream.write(mBuffer);
            reset();
            byteToWrite = 0;
        }
        if (value) {
            mBuffer[byteToWrite] |= 1 << (BYTE_BITS - 1 - (mNextBitIndex % BYTE_BITS));
        }
        mNextBitIndex++;
    }
    }


    /** Set the next bit in the stream to true. */
    /** Set the next bit in the stream to true. */
    public void setNext() {
    public void setNext() throws IOException {
        setNext(/* value= */ true);
        setNext(/* value= */ true);
    }
    }


    /** Convert BitSet in big-endian to ByteArray in big-endian. */
    /**
    public byte[] toByteArray() {
     * Flush the data written to the underlying {@link java.io.OutputStream}. Any unfinished bytes
        int bitSetSize = mBitSet.length();
     * will be padded with 0.
        int numOfBytes = bitSetSize / 8;
     */
        if (bitSetSize % 8 != 0) {
    public void flush() throws IOException {
            numOfBytes++;
        int endByte = mNextBitIndex / BYTE_BITS;
        }
        if (mNextBitIndex % BYTE_BITS != 0) {
        byte[] bytes = new byte[numOfBytes];
            // If next bit is not the first bit of a byte, then mNextBitIndex / BYTE_BITS would be
        for (int i = 0; i < mBitSet.length(); i++) {
            // the byte that includes already written bits. We need to increment it so this byte
            if (mBitSet.get(i)) {
            // gets written.
                bytes[i / 8] |= 1 << (7 - (i % 8));
            endByte++;
            }
        }
        }
        return bytes;
        mOutputStream.write(mBuffer, 0, endByte);
        reset();
    }
    }


    /** Clear the stream. */
    /** Reset this output stream to start state. */
    public void clear() {
    private void reset() {
        mBitSet.clear();
        mNextBitIndex = 0;
        mIndex = 0;
        Arrays.fill(mBuffer, (byte) 0);
    }
    }
}
}
+27 −14
Original line number Original line Diff line number Diff line
@@ -23,31 +23,44 @@ import java.io.OutputStream;
 * An output stream that tracks the total number written bytes since construction and allows
 * An output stream that tracks the total number written bytes since construction and allows
 * querying this value any time during the execution.
 * querying this value any time during the execution.
 *
 *
 * This class is used for constructing the rule indexing.
 * <p>This class is used for constructing the rule indexing.
 */
 */
public class ByteTrackedOutputStream {
public class ByteTrackedOutputStream extends OutputStream {


    private static int sWrittenBytesCount;
    private static final int INT_BYTES = 4;
    private static OutputStream sOutputStream;

    private int mWrittenBytesCount;
    private final OutputStream mOutputStream;


    public ByteTrackedOutputStream(OutputStream outputStream) {
    public ByteTrackedOutputStream(OutputStream outputStream) {
        sWrittenBytesCount = 0;
        mWrittenBytesCount = 0;
        sOutputStream = outputStream;
        mOutputStream = outputStream;
    }

    @Override
    public void write(int b) throws IOException {
        mWrittenBytesCount += INT_BYTES;
        mOutputStream.write(b);
    }
    }


    /**
    /**
     * Writes the given bytes into the output stream provided in constructor and updates the
     * Writes the given bytes into the output stream provided in constructor and updates the total
     * total number of written bytes.
     * number of written bytes.
     */
     */
    @Override
    public void write(byte[] bytes) throws IOException {
    public void write(byte[] bytes) throws IOException {
        sWrittenBytesCount += bytes.length;
        mWrittenBytesCount += bytes.length;
        sOutputStream.write(bytes);
        mOutputStream.write(bytes);
    }
    }


    /**
    @Override
     * Returns the total number of bytes written into the output stream at the requested time.
    public void write(byte[] b, int off, int len) throws IOException {
     */
        mWrittenBytesCount += len;
        mOutputStream.write(b, off, len);
    }

    /** Returns the total number of bytes written into the output stream at the requested time. */
    public int getWrittenBytesCount() {
    public int getWrittenBytesCount() {
        return sWrittenBytesCount;
        return mWrittenBytesCount;
    }
    }
}
}
+31 −34
Original line number Original line Diff line number Diff line
@@ -97,42 +97,38 @@ public class RuleBinarySerializer implements RuleSerializer {
                            ruleFileByteTrackedOutputStream);
                            ruleFileByteTrackedOutputStream);
            LinkedHashMap<String, Integer> unindexedRulesIndexes =
            LinkedHashMap<String, Integer> unindexedRulesIndexes =
                    serializeRuleList(
                    serializeRuleList(
                            indexedRules.get(NOT_INDEXED),
                            indexedRules.get(NOT_INDEXED), ruleFileByteTrackedOutputStream);
                            ruleFileByteTrackedOutputStream);


            // Serialize their indexes.
            // Serialize their indexes.
            BitOutputStream indexingBitOutputStream = new BitOutputStream();
            BitOutputStream indexingBitOutputStream = new BitOutputStream(indexingFileOutputStream);
            serializeIndexGroup(packageNameIndexes, indexingBitOutputStream, /* isIndexed= */ true);
            serializeIndexGroup(packageNameIndexes, indexingBitOutputStream, /* isIndexed= */ true);
            serializeIndexGroup(appCertificateIndexes, indexingBitOutputStream, /* isIndexed= */
            serializeIndexGroup(
                    true);
                    appCertificateIndexes, indexingBitOutputStream, /* isIndexed= */ true);
            serializeIndexGroup(unindexedRulesIndexes, indexingBitOutputStream, /* isIndexed= */
            serializeIndexGroup(
                    false);
                    unindexedRulesIndexes, indexingBitOutputStream, /* isIndexed= */ false);
            // TODO(b/147609625): This dummy bit is set for fixing the padding issue. Remove when
            indexingBitOutputStream.flush();
            // the issue is fixed and correct the tests that does this padding too.
            indexingBitOutputStream.setNext();
            indexingFileOutputStream.write(indexingBitOutputStream.toByteArray());
        } catch (Exception e) {
        } catch (Exception e) {
            throw new RuleSerializeException(e.getMessage(), e);
            throw new RuleSerializeException(e.getMessage(), e);
        }
        }
    }
    }


    private void serializeRuleFileMetadata(Optional<Integer> formatVersion,
    private void serializeRuleFileMetadata(
            ByteTrackedOutputStream outputStream)
            Optional<Integer> formatVersion, ByteTrackedOutputStream outputStream)
            throws IOException {
            throws IOException {
        int formatVersionValue = formatVersion.orElse(DEFAULT_FORMAT_VERSION);
        int formatVersionValue = formatVersion.orElse(DEFAULT_FORMAT_VERSION);


        BitOutputStream bitOutputStream = new BitOutputStream();
        BitOutputStream bitOutputStream = new BitOutputStream(outputStream);
        bitOutputStream.setNext(FORMAT_VERSION_BITS, formatVersionValue);
        bitOutputStream.setNext(FORMAT_VERSION_BITS, formatVersionValue);
        outputStream.write(bitOutputStream.toByteArray());
        bitOutputStream.flush();
    }
    }


    private LinkedHashMap<String, Integer> serializeRuleList(
    private LinkedHashMap<String, Integer> serializeRuleList(
            Map<String, List<Rule>> rulesMap, ByteTrackedOutputStream outputStream)
            Map<String, List<Rule>> rulesMap, ByteTrackedOutputStream outputStream)
            throws IOException {
            throws IOException {
        Preconditions.checkArgument(rulesMap != null,
        Preconditions.checkArgument(
                "serializeRuleList should never be called with null rule list.");
                rulesMap != null, "serializeRuleList should never be called with null rule list.");


        BitOutputStream bitOutputStream = new BitOutputStream();
        BitOutputStream bitOutputStream = new BitOutputStream(outputStream);
        LinkedHashMap<String, Integer> indexMapping = new LinkedHashMap();
        LinkedHashMap<String, Integer> indexMapping = new LinkedHashMap();
        indexMapping.put(START_INDEXING_KEY, outputStream.getWrittenBytesCount());
        indexMapping.put(START_INDEXING_KEY, outputStream.getWrittenBytesCount());


@@ -145,9 +141,8 @@ public class RuleBinarySerializer implements RuleSerializer {
            }
            }


            for (Rule rule : rulesMap.get(key)) {
            for (Rule rule : rulesMap.get(key)) {
                bitOutputStream.clear();
                serializeRule(rule, bitOutputStream);
                serializeRule(rule, bitOutputStream);
                outputStream.write(bitOutputStream.toByteArray());
                bitOutputStream.flush();
                indexTracker++;
                indexTracker++;
            }
            }
        }
        }
@@ -156,7 +151,7 @@ public class RuleBinarySerializer implements RuleSerializer {
        return indexMapping;
        return indexMapping;
    }
    }


    private void serializeRule(Rule rule, BitOutputStream bitOutputStream) {
    private void serializeRule(Rule rule, BitOutputStream bitOutputStream) throws IOException {
        if (rule == null) {
        if (rule == null) {
            throw new IllegalArgumentException("Null rule can not be serialized");
            throw new IllegalArgumentException("Null rule can not be serialized");
        }
        }
@@ -171,7 +166,8 @@ public class RuleBinarySerializer implements RuleSerializer {
        bitOutputStream.setNext();
        bitOutputStream.setNext();
    }
    }


    private void serializeFormula(Formula formula, BitOutputStream bitOutputStream) {
    private void serializeFormula(Formula formula, BitOutputStream bitOutputStream)
            throws IOException {
        if (formula instanceof AtomicFormula) {
        if (formula instanceof AtomicFormula) {
            serializeAtomicFormula((AtomicFormula) formula, bitOutputStream);
            serializeAtomicFormula((AtomicFormula) formula, bitOutputStream);
        } else if (formula instanceof CompoundFormula) {
        } else if (formula instanceof CompoundFormula) {
@@ -183,7 +179,7 @@ public class RuleBinarySerializer implements RuleSerializer {
    }
    }


    private void serializeCompoundFormula(
    private void serializeCompoundFormula(
            CompoundFormula compoundFormula, BitOutputStream bitOutputStream) {
            CompoundFormula compoundFormula, BitOutputStream bitOutputStream) throws IOException {
        if (compoundFormula == null) {
        if (compoundFormula == null) {
            throw new IllegalArgumentException("Null compound formula can not be serialized");
            throw new IllegalArgumentException("Null compound formula can not be serialized");
        }
        }
@@ -197,7 +193,7 @@ public class RuleBinarySerializer implements RuleSerializer {
    }
    }


    private void serializeAtomicFormula(
    private void serializeAtomicFormula(
            AtomicFormula atomicFormula, BitOutputStream bitOutputStream) {
            AtomicFormula atomicFormula, BitOutputStream bitOutputStream) throws IOException {
        if (atomicFormula == null) {
        if (atomicFormula == null) {
            throw new IllegalArgumentException("Null atomic formula can not be serialized");
            throw new IllegalArgumentException("Null atomic formula can not be serialized");
        }
        }
@@ -231,11 +227,10 @@ public class RuleBinarySerializer implements RuleSerializer {
    private void serializeIndexGroup(
    private void serializeIndexGroup(
            LinkedHashMap<String, Integer> indexes,
            LinkedHashMap<String, Integer> indexes,
            BitOutputStream bitOutputStream,
            BitOutputStream bitOutputStream,
            boolean isIndexed) {
            boolean isIndexed)

            throws IOException {
        // Output the starting location of this indexing group.
        // Output the starting location of this indexing group.
        serializeStringValue(
        serializeStringValue(START_INDEXING_KEY, /* isHashedValue= */ false, bitOutputStream);
                START_INDEXING_KEY, /* isHashedValue= */false, bitOutputStream);
        serializeIntValue(indexes.get(START_INDEXING_KEY), bitOutputStream);
        serializeIntValue(indexes.get(START_INDEXING_KEY), bitOutputStream);


        // If the group is indexed, output the locations of the indexes.
        // If the group is indexed, output the locations of the indexes.
@@ -243,8 +238,8 @@ public class RuleBinarySerializer implements RuleSerializer {
            for (Map.Entry<String, Integer> entry : indexes.entrySet()) {
            for (Map.Entry<String, Integer> entry : indexes.entrySet()) {
                if (!entry.getKey().equals(START_INDEXING_KEY)
                if (!entry.getKey().equals(START_INDEXING_KEY)
                        && !entry.getKey().equals(END_INDEXING_KEY)) {
                        && !entry.getKey().equals(END_INDEXING_KEY)) {
                    serializeStringValue(entry.getKey(), /* isHashedValue= */false,
                    serializeStringValue(
                            bitOutputStream);
                            entry.getKey(), /* isHashedValue= */ false, bitOutputStream);
                    serializeIntValue(entry.getValue(), bitOutputStream);
                    serializeIntValue(entry.getValue(), bitOutputStream);
                }
                }
            }
            }
@@ -256,7 +251,8 @@ public class RuleBinarySerializer implements RuleSerializer {
    }
    }


    private void serializeStringValue(
    private void serializeStringValue(
            String value, boolean isHashedValue, BitOutputStream bitOutputStream) {
            String value, boolean isHashedValue, BitOutputStream bitOutputStream)
            throws IOException {
        if (value == null) {
        if (value == null) {
            throw new IllegalArgumentException("String value can not be null.");
            throw new IllegalArgumentException("String value can not be null.");
        }
        }
@@ -269,11 +265,12 @@ public class RuleBinarySerializer implements RuleSerializer {
        }
        }
    }
    }


    private void serializeIntValue(int value, BitOutputStream bitOutputStream) {
    private void serializeIntValue(int value, BitOutputStream bitOutputStream) throws IOException {
        bitOutputStream.setNext(/* numOfBits= */ 32, value);
        bitOutputStream.setNext(/* numOfBits= */ 32, value);
    }
    }


    private void serializeBooleanValue(boolean value, BitOutputStream bitOutputStream) {
    private void serializeBooleanValue(boolean value, BitOutputStream bitOutputStream)
            throws IOException {
        bitOutputStream.setNext(value);
        bitOutputStream.setNext(value);
    }
    }


+4 −4
Original line number Original line Diff line number Diff line
@@ -188,7 +188,7 @@ public class RuleEvaluatorTest {
    }
    }


    @Test
    @Test
    public void testEvaluateRules_ruleNotInDNF_ignoreAndAllow() {
    public void testEvaluateRules_orRules() {
        CompoundFormula compoundFormula =
        CompoundFormula compoundFormula =
                new CompoundFormula(
                new CompoundFormula(
                        CompoundFormula.OR,
                        CompoundFormula.OR,
@@ -206,11 +206,11 @@ public class RuleEvaluatorTest {
        IntegrityCheckResult result =
        IntegrityCheckResult result =
                RuleEvaluator.evaluateRules(Collections.singletonList(rule), APP_INSTALL_METADATA);
                RuleEvaluator.evaluateRules(Collections.singletonList(rule), APP_INSTALL_METADATA);


        assertEquals(ALLOW, result.getEffect());
        assertEquals(DENY, result.getEffect());
    }
    }


    @Test
    @Test
    public void testEvaluateRules_compoundFormulaWithNot_allow() {
    public void testEvaluateRules_compoundFormulaWithNot_deny() {
        CompoundFormula openSubFormula =
        CompoundFormula openSubFormula =
                new CompoundFormula(
                new CompoundFormula(
                        CompoundFormula.AND,
                        CompoundFormula.AND,
@@ -230,7 +230,7 @@ public class RuleEvaluatorTest {
        IntegrityCheckResult result =
        IntegrityCheckResult result =
                RuleEvaluator.evaluateRules(Collections.singletonList(rule), APP_INSTALL_METADATA);
                RuleEvaluator.evaluateRules(Collections.singletonList(rule), APP_INSTALL_METADATA);


        assertEquals(ALLOW, result.getEffect());
        assertEquals(DENY, result.getEffect());
    }
    }


    @Test
    @Test
+4 −4
Original line number Original line Diff line number Diff line
@@ -53,17 +53,17 @@ public class ByteTrackedOutputStreamTest {
        ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
        ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
        ByteTrackedOutputStream byteTrackedOutputStream = new ByteTrackedOutputStream(outputStream);
        ByteTrackedOutputStream byteTrackedOutputStream = new ByteTrackedOutputStream(outputStream);


        BitOutputStream bitOutputStream = new BitOutputStream();
        BitOutputStream bitOutputStream = new BitOutputStream(byteTrackedOutputStream);
        bitOutputStream.setNext(/* numOfBits= */5, /* value= */1);
        bitOutputStream.setNext(/* numOfBits= */5, /* value= */1);
        byteTrackedOutputStream.write(bitOutputStream.toByteArray());
        bitOutputStream.flush();


        // Even though we wrote 5 bits, this will complete to 1 byte.
        // Even though we wrote 5 bits, this will complete to 1 byte.
        assertThat(byteTrackedOutputStream.getWrittenBytesCount()).isEqualTo(1);
        assertThat(byteTrackedOutputStream.getWrittenBytesCount()).isEqualTo(1);


        // Add a bit less than 2 bytes (10 bits).
        // Add a bit less than 2 bytes (10 bits).
        bitOutputStream.clear();
        bitOutputStream.setNext(/* numOfBits= */10, /* value= */1);
        bitOutputStream.setNext(/* numOfBits= */10, /* value= */1);
        byteTrackedOutputStream.write(bitOutputStream.toByteArray());
        bitOutputStream.flush();
        assertThat(byteTrackedOutputStream.getWrittenBytesCount()).isEqualTo(3);


        assertThat(outputStream.toByteArray().length).isEqualTo(3);
        assertThat(outputStream.toByteArray().length).isEqualTo(3);
    }
    }
Loading