Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit db2a744d authored by Raphael Kim's avatar Raphael Kim
Browse files

Integrate secure channel into CDM

Bug: 253307662
Test: Manually tested on CtsPermissionsSyncTestApp
Change-Id: I1b3bee3eab0ca1655ab818379f3598358bc4b677
parent a77e254e
Loading
Loading
Loading
Loading
+99 −81
Original line number Diff line number Diff line
@@ -34,15 +34,14 @@ import android.util.SparseArray;

import com.android.internal.annotations.GuardedBy;
import com.android.server.LocalServices;
import com.android.server.companion.securechannel.SecureChannel;

import libcore.io.IoUtils;
import libcore.io.Streams;
import libcore.util.EmptyArray;

import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.ByteBuffer;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicInteger;
@@ -54,8 +53,6 @@ public class CompanionTransportManager {
    private static final boolean DEBUG = true;

    private static final int HEADER_LENGTH = 12;
    // TODO: refactor message processing to use streams to remove this limit
    private static final int MAX_PAYLOAD_LENGTH = 1_000_000;

    private static final int MESSAGE_REQUEST_PING = 0x63807378; // ?PIN
    private static final int MESSAGE_REQUEST_PERMISSION_RESTORE = 0x63826983; // ?RES
@@ -150,53 +147,47 @@ public class CompanionTransportManager {
        }
    }

    private class Transport {
    private class Transport implements SecureChannel.Callback {
        private final int mAssociationId;

        private final InputStream mRemoteIn;
        private final OutputStream mRemoteOut;

        private final SecureChannel mSecureChannel;
        private final AtomicInteger mNextSequence = new AtomicInteger();

        private volatile boolean mShouldProcessRequests = false;

        @GuardedBy("mPendingRequests")
        private final SparseArray<CompletableFuture<byte[]>> mPendingRequests = new SparseArray<>();

        private volatile boolean mStopped;
        private final BlockingQueue<Runnable> mRequestQueue = new ArrayBlockingQueue<>(100);

        public Transport(int associationId, ParcelFileDescriptor fd) {
            mAssociationId = associationId;
            mRemoteIn = new ParcelFileDescriptor.AutoCloseInputStream(fd);
            mRemoteOut = new ParcelFileDescriptor.AutoCloseOutputStream(fd);
            mSecureChannel = new SecureChannel(
                    new ParcelFileDescriptor.AutoCloseInputStream(fd),
                    new ParcelFileDescriptor.AutoCloseOutputStream(fd),
                    this,
                    mContext
            );
        }

        public void start() {
            new Thread(() -> {
                try {
                    while (!mStopped) {
                        receiveMessage();
                    }
                } catch (IOException e) {
                    if (!mStopped) {
                        Slog.w(TAG, "Trouble during transport", e);
                        stop();
                    }
                }
            }).start();
            mSecureChannel.start();
        }

        public void stop() {
            mStopped = true;

            IoUtils.closeQuietly(mRemoteIn);
            IoUtils.closeQuietly(mRemoteOut);
            mSecureChannel.stop();
            mShouldProcessRequests = false;
        }

        public Future<byte[]> requestForResponse(int message, byte[] data) {
            if (DEBUG) Slog.d(TAG, "Requesting for response");
            final int sequence = mNextSequence.incrementAndGet();
            final CompletableFuture<byte[]> pending = new CompletableFuture<>();
            synchronized (mPendingRequests) {
                mPendingRequests.put(sequence, pending);
            }

            // Queue up a task
            mRequestQueue.add(() -> {
                try {
                    sendMessage(message, sequence, data);
                } catch (IOException e) {
@@ -205,63 +196,40 @@ public class CompanionTransportManager {
                    }
                    pending.completeExceptionally(e);
                }
            });

            // Check if channel is secured and start securing
            if (!mShouldProcessRequests) {
                Slog.d(TAG, "Establishing secure connection.");
                try {
                    mSecureChannel.establishSecureConnection();
                } catch (Exception e) {
                    synchronized (mPendingRequests) {
                        mPendingRequests.remove(sequence);
                    }
                    pending.completeExceptionally(e);
                }
            }

            return pending;
        }

        private void sendMessage(int message, int sequence, @NonNull byte[] data)
                throws IOException {

            if (DEBUG) {
                Slog.d(TAG, "Sending message 0x" + Integer.toHexString(message)
                        + " sequence " + sequence + " length " + data.length
                        + " to association " + mAssociationId);
            }

            synchronized (mRemoteOut) {
                final ByteBuffer header = ByteBuffer.allocate(HEADER_LENGTH)
            final ByteBuffer payload = ByteBuffer.allocate(HEADER_LENGTH + data.length)
                    .putInt(message)
                    .putInt(sequence)
                        .putInt(data.length);
                mRemoteOut.write(header.array());
                mRemoteOut.write(data);
                mRemoteOut.flush();
            }
        }

        private void receiveMessage() throws IOException {
            if (DEBUG) {
                Slog.d(TAG, "Waiting for next message...");
            }

            final byte[] headerBytes = new byte[HEADER_LENGTH];
            Streams.readFully(mRemoteIn, headerBytes);
            final ByteBuffer header = ByteBuffer.wrap(headerBytes);
            final int message = header.getInt();
            final int sequence = header.getInt();
            final int length = header.getInt();

            if (DEBUG) {
                Slog.d(TAG, "Received message 0x" + Integer.toHexString(message)
                        + " sequence " + sequence + " length " + length
                        + " from association " + mAssociationId);
            }
            if (length > MAX_PAYLOAD_LENGTH) {
                Slog.w(TAG, "Ignoring message 0x" + Integer.toHexString(message)
                        + " sequence " + sequence + " length " + length
                        + " from association " + mAssociationId + " beyond maximum length");
                Streams.skipByReading(mRemoteIn, length);
                return;
            }
                    .putInt(data.length)
                    .put(data);

            final byte[] data = new byte[length];
            Streams.readFully(mRemoteIn, data);

            if (isRequest(message)) {
                processRequest(message, sequence, data);
            } else if (isResponse(message)) {
                processResponse(message, sequence, data);
            } else {
                Slog.w(TAG, "Unknown message 0x" + Integer.toHexString(message));
            }
            mSecureChannel.sendSecureMessage(payload.array());
        }

        private void processRequest(int message, int sequence, byte[] data)
@@ -319,5 +287,55 @@ public class CompanionTransportManager {
                }
            }
        }

        @Override
        public void onSecureConnection() {
            mShouldProcessRequests = true;
            Slog.d(TAG, "Secure connection established.");

            // TODO: find a better way to handle incoming requests than a dedicated thread.
            new Thread(() -> {
                while (mShouldProcessRequests) {
                    Runnable task = mRequestQueue.poll();
                    if (task != null) {
                        task.run();
                    }
                }
            }).start();
        }

        @Override
        public void onSecureMessageReceived(byte[] data) {
            final ByteBuffer payload = ByteBuffer.wrap(data);
            final int message = payload.getInt();
            final int sequence = payload.getInt();
            final int length = payload.getInt();

            if (DEBUG) {
                Slog.d(TAG, "Received message 0x" + Integer.toHexString(message)
                        + " sequence " + sequence + " length " + length
                        + " from association " + mAssociationId);
            }

            final byte[] content = new byte[length];
            payload.get(content);
            if (isRequest(message)) {
                try {
                    processRequest(message, sequence, content);
                } catch (IOException e) {
                    Slog.w(TAG, "Failed to respond to 0x" + Integer.toHexString(message), e);
                }
            } else if (isResponse(message)) {
                processResponse(message, sequence, content);
            } else {
                Slog.w(TAG, "Unknown message 0x" + Integer.toHexString(message));
            }
        }

        @Override
        public void onError(Throwable error) {
            mShouldProcessRequests = false;
            Slog.e(TAG, error.getMessage(), error);
        }
    }
}