Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit c0c146b6 authored by Zhi Dou's avatar Zhi Dou
Browse files

new buffer reader for buffer read

This change will create new buffer reader when the code needs to read
the buffer read. Previously there is only new buffer reader in the
class, so if there are multiple threads call the get method in
PackageTable, then these threads will modify the buffer position can
cause a race condition.

This change will create new buffer reader when the code needs to read
the buffer in the method. so each thread will have its own instance of
the reader, and the reader will maintain the position.

Test: atest aconfig_storage_file.test.java
Bug: 397997135
Flag: EXAMPT refactor
Change-Id: Ic355a9273591bbd42931b21f18c6cddcb952b35c
parent 0a68346a
Loading
Loading
Loading
Loading
+25 −7
Original line number Diff line number Diff line
@@ -19,10 +19,12 @@ package android.aconfig.storage;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.charset.StandardCharsets;
import java.util.Objects;

public class ByteBufferReader {

    private ByteBuffer mByteBuffer;
    private int mPosition;

    public ByteBufferReader(ByteBuffer byteBuffer) {
        this.mByteBuffer = byteBuffer;
@@ -30,19 +32,19 @@ public class ByteBufferReader {
    }

    public int readByte() {
        return Byte.toUnsignedInt(mByteBuffer.get());
        return Byte.toUnsignedInt(mByteBuffer.get(nextGetIndex(1)));
    }

    public int readShort() {
        return Short.toUnsignedInt(mByteBuffer.getShort());
        return Short.toUnsignedInt(mByteBuffer.getShort(nextGetIndex(2)));
    }

    public int readInt() {
        return this.mByteBuffer.getInt();
        return this.mByteBuffer.getInt(nextGetIndex(4));
    }

    public long readLong() {
        return this.mByteBuffer.getLong();
        return this.mByteBuffer.getLong(nextGetIndex(8));
    }

    public String readString() {
@@ -52,7 +54,7 @@ public class ByteBufferReader {
                    "String length exceeds maximum allowed size (1024 bytes): " + length);
        }
        byte[] bytes = new byte[length];
        mByteBuffer.get(bytes, 0, length);
        getArray(nextGetIndex(length), bytes, 0, length);
        return new String(bytes, StandardCharsets.UTF_8);
    }

@@ -61,10 +63,26 @@ public class ByteBufferReader {
    }

    public void position(int newPosition) {
        mByteBuffer.position(newPosition);
        mPosition = newPosition;
    }

    public int position() {
        return mByteBuffer.position();
        return mPosition;
    }

    private int nextGetIndex(int nb) {
        int p = mPosition;
        mPosition += nb;
        return p;
    }

    private void getArray(int index, byte[] dst, int offset, int length) {
        Objects.checkFromIndexSize(index, length, mByteBuffer.limit());
        Objects.checkFromIndexSize(offset, length, dst.length);

        int end = offset + length;
        for (int i = offset, j = index; i < end; i++, j++) {
            dst[i] = mByteBuffer.get(j);
        }
    }
}
+8 −8
Original line number Diff line number Diff line
@@ -24,12 +24,12 @@ import java.util.Objects;
public class FlagTable {

    private Header mHeader;
    private ByteBufferReader mReader;
    private ByteBuffer mBuffer;

    public static FlagTable fromBytes(ByteBuffer bytes) {
        FlagTable flagTable = new FlagTable();
        flagTable.mReader = new ByteBufferReader(bytes);
        flagTable.mHeader = Header.fromBytes(flagTable.mReader);
        flagTable.mBuffer = bytes;
        flagTable.mHeader = Header.fromBytes(new ByteBufferReader(bytes));

        return flagTable;
    }
@@ -41,16 +41,16 @@ public class FlagTable {
        if (newPosition >= mHeader.mNodeOffset) {
            return null;
        }

        mReader.position(newPosition);
        int nodeIndex = mReader.readInt();
        ByteBufferReader reader = new ByteBufferReader(mBuffer) ;
        reader.position(newPosition);
        int nodeIndex = reader.readInt();
        if (nodeIndex < mHeader.mNodeOffset || nodeIndex >= mHeader.mFileSize) {
            return null;
        }

        while (nodeIndex != -1) {
            mReader.position(nodeIndex);
            Node node = Node.fromBytes(mReader);
            reader.position(nodeIndex);
            Node node = Node.fromBytes(reader);
            if (Objects.equals(flagName, node.mFlagName) && packageId == node.mPackageId) {
                return node;
            }
+12 −10
Original line number Diff line number Diff line
@@ -30,12 +30,12 @@ public class PackageTable {
    private static final int NODE_SKIP_BYTES = 12;

    private Header mHeader;
    private ByteBufferReader mReader;
    private ByteBuffer mBuffer;

    public static PackageTable fromBytes(ByteBuffer bytes) {
        PackageTable packageTable = new PackageTable();
        packageTable.mReader = new ByteBufferReader(bytes);
        packageTable.mHeader = Header.fromBytes(packageTable.mReader);
        packageTable.mBuffer = bytes;
        packageTable.mHeader = Header.fromBytes(new ByteBufferReader(bytes));

        return packageTable;
    }
@@ -47,16 +47,17 @@ public class PackageTable {
        if (newPosition >= mHeader.mNodeOffset) {
            return null;
        }
        mReader.position(newPosition);
        int nodeIndex = mReader.readInt();
        ByteBufferReader reader = new ByteBufferReader(mBuffer);
        reader.position(newPosition);
        int nodeIndex = reader.readInt();

        if (nodeIndex < mHeader.mNodeOffset || nodeIndex >= mHeader.mFileSize) {
            return null;
        }

        while (nodeIndex != -1) {
            mReader.position(nodeIndex);
            Node node = Node.fromBytes(mReader, mHeader.mVersion);
            reader.position(nodeIndex);
            Node node = Node.fromBytes(reader, mHeader.mVersion);
            if (Objects.equals(packageName, node.mPackageName)) {
                return node;
            }
@@ -68,12 +69,13 @@ public class PackageTable {

    public List<String> getPackageList() {
        List<String> list = new ArrayList<>(mHeader.mNumPackages);
        mReader.position(mHeader.mNodeOffset);
        ByteBufferReader reader = new ByteBufferReader(mBuffer);
        reader.position(mHeader.mNodeOffset);
        int fingerprintBytes = mHeader.mVersion == 1 ? 0 : FINGERPRINT_BYTES;
        int skipBytes = fingerprintBytes + NODE_SKIP_BYTES;
        for (int i = 0; i < mHeader.mNumPackages; i++) {
            list.add(mReader.readString());
            mReader.position(mReader.position() + skipBytes);
            list.add(reader.readString());
            reader.position(reader.position() + skipBytes);
        }
        return list;
    }
+44 −0
Original line number Diff line number Diff line
@@ -28,7 +28,9 @@ import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;

import java.util.HashSet;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.CyclicBarrier;

@RunWith(JUnit4.class)
public class PackageTableTest {
@@ -142,4 +144,46 @@ public class PackageTableTest {
        assertTrue(packages.contains("com.android.aconfig.storage.test_2"));
        assertTrue(packages.contains("com.android.aconfig.storage.test_4"));
    }

    @Test
    public void testPackageTable_multithreadsRead() throws Exception {
        PackageTable packageTable =
                PackageTable.fromBytes(TestDataUtils.getTestPackageMapByteBuffer(2));
        int numberOfThreads = 3;
        Thread[] threads = new Thread[numberOfThreads];
        final CyclicBarrier gate = new CyclicBarrier(numberOfThreads + 1);
        String[] expects = {
            "com.android.aconfig.storage.test_1",
            "com.android.aconfig.storage.test_2",
            "com.android.aconfig.storage.test_4"
        };

        for (int i = 0; i < numberOfThreads; i++) {
            final String packageName = expects[i];
            threads[i] =
                    new Thread() {
                        @Override
                        public void run() {
                            try {
                                gate.await();
                            } catch (Exception e) {
                            }
                            for (int j = 0; j < 10; j++) {
                                if (!Objects.equals(
                                        packageName,
                                        packageTable.get(packageName).getPackageName())) {
                                    throw new RuntimeException();
                                }
                            }
                        }
                    };
            threads[i].start();
        }

        gate.await();

        for (int i = 0; i < numberOfThreads; i++) {
            threads[i].join();
        }
    }
}