Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 5f322de2 authored by Song Pan's avatar Song Pan Committed by Android (Google) Code Review
Browse files

Merge "Fix a bug in BitOutputStream where any trailing zeros will be discarded...

Merge "Fix a bug in BitOutputStream where any trailing zeros will be discarded along with an optimization to avoid maintaining two copies of data in memory during toByteArray."
parents 3746de82 e5d60740
Loading
Loading
Loading
Loading
+50 −32
Original line number Diff line number Diff line
@@ -16,17 +16,25 @@

package com.android.server.integrity.model;

import java.util.BitSet;
import java.io.IOException;
import java.io.OutputStream;
import java.util.Arrays;

/** A wrapper class for writing a stream of bits. */
public class BitOutputStream {

    private BitSet mBitSet;
    private int mIndex;
    private static final int BUFFER_SIZE = 4 * 1024;
    private static final int BYTE_BITS = 8;

    public BitOutputStream() {
        mBitSet = new BitSet();
        mIndex = 0;
    private int mNextBitIndex;

    private final OutputStream mOutputStream;
    private final byte[] mBuffer;

    public BitOutputStream(OutputStream outputStream) {
        mBuffer = new byte[BUFFER_SIZE];
        mNextBitIndex = 0;
        mOutputStream = outputStream;
    }

    /**
@@ -35,15 +43,17 @@ public class BitOutputStream {
     * @param numOfBits The number of bits used to represent the value.
     * @param value The value to convert to bits.
     */
    public void setNext(int numOfBits, int value) {
    public void setNext(int numOfBits, int value) throws IOException {
        if (numOfBits <= 0) {
            return;
        }
        int offset = 1 << (numOfBits - 1);

        // optional: we can do some clever size checking to "OR" an entire segment of bits instead
        // of setting bits one by one, but it is probably not worth it.
        int nextBitMask = 1 << (numOfBits - 1);
        while (numOfBits-- > 0) {
            mBitSet.set(mIndex, (value & offset) != 0);
            offset >>>= 1;
            mIndex++;
            setNext((value & nextBitMask) != 0);
            nextBitMask >>>= 1;
        }
    }

@@ -52,35 +62,43 @@ public class BitOutputStream {
     *
     * @param value The value to set the bit to.
     */
    public void setNext(boolean value) {
        mBitSet.set(mIndex, value);
        mIndex++;
    public void setNext(boolean value) throws IOException {
        int byteToWrite = mNextBitIndex / BYTE_BITS;
        if (byteToWrite == BUFFER_SIZE) {
            mOutputStream.write(mBuffer);
            reset();
            byteToWrite = 0;
        }
        if (value) {
            mBuffer[byteToWrite] |= 1 << (BYTE_BITS - 1 - (mNextBitIndex % BYTE_BITS));
        }
        mNextBitIndex++;
    }

    /** Set the next bit in the stream to true. */
    public void setNext() {
    public void setNext() throws IOException {
        setNext(/* value= */ true);
    }

    /** Convert BitSet in big-endian to ByteArray in big-endian. */
    public byte[] toByteArray() {
        int bitSetSize = mBitSet.length();
        int numOfBytes = bitSetSize / 8;
        if (bitSetSize % 8 != 0) {
            numOfBytes++;
        }
        byte[] bytes = new byte[numOfBytes];
        for (int i = 0; i < mBitSet.length(); i++) {
            if (mBitSet.get(i)) {
                bytes[i / 8] |= 1 << (7 - (i % 8));
            }
    /**
     * Flush the data written to the underlying {@link java.io.OutputStream}. Any unfinished bytes
     * will be padded with 0.
     */
    public void flush() throws IOException {
        int endByte = mNextBitIndex / BYTE_BITS;
        if (mNextBitIndex % BYTE_BITS != 0) {
            // If next bit is not the first bit of a byte, then mNextBitIndex / BYTE_BITS would be
            // the byte that includes already written bits. We need to increment it so this byte
            // gets written.
            endByte++;
        }
        return bytes;
        mOutputStream.write(mBuffer, 0, endByte);
        reset();
    }

    /** Clear the stream. */
    public void clear() {
        mBitSet.clear();
        mIndex = 0;
    /** Reset this output stream to start state. */
    private void reset() {
        mNextBitIndex = 0;
        Arrays.fill(mBuffer, (byte) 0);
    }
}
+27 −14
Original line number Diff line number Diff line
@@ -23,31 +23,44 @@ import java.io.OutputStream;
 * An output stream that tracks the total number written bytes since construction and allows
 * querying this value any time during the execution.
 *
 * This class is used for constructing the rule indexing.
 * <p>This class is used for constructing the rule indexing.
 */
public class ByteTrackedOutputStream {
public class ByteTrackedOutputStream extends OutputStream {

    private static int sWrittenBytesCount;
    private static OutputStream sOutputStream;
    private static final int INT_BYTES = 4;

    private int mWrittenBytesCount;
    private final OutputStream mOutputStream;

    public ByteTrackedOutputStream(OutputStream outputStream) {
        sWrittenBytesCount = 0;
        sOutputStream = outputStream;
        mWrittenBytesCount = 0;
        mOutputStream = outputStream;
    }

    @Override
    public void write(int b) throws IOException {
        mWrittenBytesCount += INT_BYTES;
        mOutputStream.write(b);
    }

    /**
     * Writes the given bytes into the output stream provided in constructor and updates the
     * total number of written bytes.
     * Writes the given bytes into the output stream provided in constructor and updates the total
     * number of written bytes.
     */
    @Override
    public void write(byte[] bytes) throws IOException {
        sWrittenBytesCount += bytes.length;
        sOutputStream.write(bytes);
        mWrittenBytesCount += bytes.length;
        mOutputStream.write(bytes);
    }

    /**
     * Returns the total number of bytes written into the output stream at the requested time.
     */
    @Override
    public void write(byte[] b, int off, int len) throws IOException {
        mWrittenBytesCount += len;
        mOutputStream.write(b, off, len);
    }

    /** Returns the total number of bytes written into the output stream at the requested time. */
    public int getWrittenBytesCount() {
        return sWrittenBytesCount;
        return mWrittenBytesCount;
    }
}
+31 −34
Original line number Diff line number Diff line
@@ -97,42 +97,38 @@ public class RuleBinarySerializer implements RuleSerializer {
                            ruleFileByteTrackedOutputStream);
            LinkedHashMap<String, Integer> unindexedRulesIndexes =
                    serializeRuleList(
                            indexedRules.get(NOT_INDEXED),
                            ruleFileByteTrackedOutputStream);
                            indexedRules.get(NOT_INDEXED), ruleFileByteTrackedOutputStream);

            // Serialize their indexes.
            BitOutputStream indexingBitOutputStream = new BitOutputStream();
            BitOutputStream indexingBitOutputStream = new BitOutputStream(indexingFileOutputStream);
            serializeIndexGroup(packageNameIndexes, indexingBitOutputStream, /* isIndexed= */ true);
            serializeIndexGroup(appCertificateIndexes, indexingBitOutputStream, /* isIndexed= */
                    true);
            serializeIndexGroup(unindexedRulesIndexes, indexingBitOutputStream, /* isIndexed= */
                    false);
            // TODO(b/147609625): This dummy bit is set for fixing the padding issue. Remove when
            // the issue is fixed and correct the tests that does this padding too.
            indexingBitOutputStream.setNext();
            indexingFileOutputStream.write(indexingBitOutputStream.toByteArray());
            serializeIndexGroup(
                    appCertificateIndexes, indexingBitOutputStream, /* isIndexed= */ true);
            serializeIndexGroup(
                    unindexedRulesIndexes, indexingBitOutputStream, /* isIndexed= */ false);
            indexingBitOutputStream.flush();
        } catch (Exception e) {
            throw new RuleSerializeException(e.getMessage(), e);
        }
    }

    private void serializeRuleFileMetadata(Optional<Integer> formatVersion,
            ByteTrackedOutputStream outputStream)
    private void serializeRuleFileMetadata(
            Optional<Integer> formatVersion, ByteTrackedOutputStream outputStream)
            throws IOException {
        int formatVersionValue = formatVersion.orElse(DEFAULT_FORMAT_VERSION);

        BitOutputStream bitOutputStream = new BitOutputStream();
        BitOutputStream bitOutputStream = new BitOutputStream(outputStream);
        bitOutputStream.setNext(FORMAT_VERSION_BITS, formatVersionValue);
        outputStream.write(bitOutputStream.toByteArray());
        bitOutputStream.flush();
    }

    private LinkedHashMap<String, Integer> serializeRuleList(
            Map<String, List<Rule>> rulesMap, ByteTrackedOutputStream outputStream)
            throws IOException {
        Preconditions.checkArgument(rulesMap != null,
                "serializeRuleList should never be called with null rule list.");
        Preconditions.checkArgument(
                rulesMap != null, "serializeRuleList should never be called with null rule list.");

        BitOutputStream bitOutputStream = new BitOutputStream();
        BitOutputStream bitOutputStream = new BitOutputStream(outputStream);
        LinkedHashMap<String, Integer> indexMapping = new LinkedHashMap();
        indexMapping.put(START_INDEXING_KEY, outputStream.getWrittenBytesCount());

@@ -145,9 +141,8 @@ public class RuleBinarySerializer implements RuleSerializer {
            }

            for (Rule rule : rulesMap.get(key)) {
                bitOutputStream.clear();
                serializeRule(rule, bitOutputStream);
                outputStream.write(bitOutputStream.toByteArray());
                bitOutputStream.flush();
                indexTracker++;
            }
        }
@@ -156,7 +151,7 @@ public class RuleBinarySerializer implements RuleSerializer {
        return indexMapping;
    }

    private void serializeRule(Rule rule, BitOutputStream bitOutputStream) {
    private void serializeRule(Rule rule, BitOutputStream bitOutputStream) throws IOException {
        if (rule == null) {
            throw new IllegalArgumentException("Null rule can not be serialized");
        }
@@ -171,7 +166,8 @@ public class RuleBinarySerializer implements RuleSerializer {
        bitOutputStream.setNext();
    }

    private void serializeFormula(Formula formula, BitOutputStream bitOutputStream) {
    private void serializeFormula(Formula formula, BitOutputStream bitOutputStream)
            throws IOException {
        if (formula instanceof AtomicFormula) {
            serializeAtomicFormula((AtomicFormula) formula, bitOutputStream);
        } else if (formula instanceof CompoundFormula) {
@@ -183,7 +179,7 @@ public class RuleBinarySerializer implements RuleSerializer {
    }

    private void serializeCompoundFormula(
            CompoundFormula compoundFormula, BitOutputStream bitOutputStream) {
            CompoundFormula compoundFormula, BitOutputStream bitOutputStream) throws IOException {
        if (compoundFormula == null) {
            throw new IllegalArgumentException("Null compound formula can not be serialized");
        }
@@ -197,7 +193,7 @@ public class RuleBinarySerializer implements RuleSerializer {
    }

    private void serializeAtomicFormula(
            AtomicFormula atomicFormula, BitOutputStream bitOutputStream) {
            AtomicFormula atomicFormula, BitOutputStream bitOutputStream) throws IOException {
        if (atomicFormula == null) {
            throw new IllegalArgumentException("Null atomic formula can not be serialized");
        }
@@ -231,11 +227,10 @@ public class RuleBinarySerializer implements RuleSerializer {
    private void serializeIndexGroup(
            LinkedHashMap<String, Integer> indexes,
            BitOutputStream bitOutputStream,
            boolean isIndexed) {

            boolean isIndexed)
            throws IOException {
        // Output the starting location of this indexing group.
        serializeStringValue(
                START_INDEXING_KEY, /* isHashedValue= */false, bitOutputStream);
        serializeStringValue(START_INDEXING_KEY, /* isHashedValue= */ false, bitOutputStream);
        serializeIntValue(indexes.get(START_INDEXING_KEY), bitOutputStream);

        // If the group is indexed, output the locations of the indexes.
@@ -243,8 +238,8 @@ public class RuleBinarySerializer implements RuleSerializer {
            for (Map.Entry<String, Integer> entry : indexes.entrySet()) {
                if (!entry.getKey().equals(START_INDEXING_KEY)
                        && !entry.getKey().equals(END_INDEXING_KEY)) {
                    serializeStringValue(entry.getKey(), /* isHashedValue= */false,
                            bitOutputStream);
                    serializeStringValue(
                            entry.getKey(), /* isHashedValue= */ false, bitOutputStream);
                    serializeIntValue(entry.getValue(), bitOutputStream);
                }
            }
@@ -256,7 +251,8 @@ public class RuleBinarySerializer implements RuleSerializer {
    }

    private void serializeStringValue(
            String value, boolean isHashedValue, BitOutputStream bitOutputStream) {
            String value, boolean isHashedValue, BitOutputStream bitOutputStream)
            throws IOException {
        if (value == null) {
            throw new IllegalArgumentException("String value can not be null.");
        }
@@ -269,11 +265,12 @@ public class RuleBinarySerializer implements RuleSerializer {
        }
    }

    private void serializeIntValue(int value, BitOutputStream bitOutputStream) {
    private void serializeIntValue(int value, BitOutputStream bitOutputStream) throws IOException {
        bitOutputStream.setNext(/* numOfBits= */ 32, value);
    }

    private void serializeBooleanValue(boolean value, BitOutputStream bitOutputStream) {
    private void serializeBooleanValue(boolean value, BitOutputStream bitOutputStream)
            throws IOException {
        bitOutputStream.setNext(value);
    }

+4 −4
Original line number Diff line number Diff line
@@ -188,7 +188,7 @@ public class RuleEvaluatorTest {
    }

    @Test
    public void testEvaluateRules_ruleNotInDNF_ignoreAndAllow() {
    public void testEvaluateRules_orRules() {
        CompoundFormula compoundFormula =
                new CompoundFormula(
                        CompoundFormula.OR,
@@ -206,11 +206,11 @@ public class RuleEvaluatorTest {
        IntegrityCheckResult result =
                RuleEvaluator.evaluateRules(Collections.singletonList(rule), APP_INSTALL_METADATA);

        assertEquals(ALLOW, result.getEffect());
        assertEquals(DENY, result.getEffect());
    }

    @Test
    public void testEvaluateRules_compoundFormulaWithNot_allow() {
    public void testEvaluateRules_compoundFormulaWithNot_deny() {
        CompoundFormula openSubFormula =
                new CompoundFormula(
                        CompoundFormula.AND,
@@ -230,7 +230,7 @@ public class RuleEvaluatorTest {
        IntegrityCheckResult result =
                RuleEvaluator.evaluateRules(Collections.singletonList(rule), APP_INSTALL_METADATA);

        assertEquals(ALLOW, result.getEffect());
        assertEquals(DENY, result.getEffect());
    }

    @Test
+4 −4
Original line number Diff line number Diff line
@@ -53,17 +53,17 @@ public class ByteTrackedOutputStreamTest {
        ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
        ByteTrackedOutputStream byteTrackedOutputStream = new ByteTrackedOutputStream(outputStream);

        BitOutputStream bitOutputStream = new BitOutputStream();
        BitOutputStream bitOutputStream = new BitOutputStream(byteTrackedOutputStream);
        bitOutputStream.setNext(/* numOfBits= */5, /* value= */1);
        byteTrackedOutputStream.write(bitOutputStream.toByteArray());
        bitOutputStream.flush();

        // Even though we wrote 5 bits, this will complete to 1 byte.
        assertThat(byteTrackedOutputStream.getWrittenBytesCount()).isEqualTo(1);

        // Add a bit less than 2 bytes (10 bits).
        bitOutputStream.clear();
        bitOutputStream.setNext(/* numOfBits= */10, /* value= */1);
        byteTrackedOutputStream.write(bitOutputStream.toByteArray());
        bitOutputStream.flush();
        assertThat(byteTrackedOutputStream.getWrittenBytesCount()).isEqualTo(3);

        assertThat(outputStream.toByteArray().length).isEqualTo(3);
    }
Loading