Loading services/companion/java/com/android/server/companion/transport/CompanionTransportManager.java +99 −81 Original line number Diff line number Diff line Loading @@ -34,15 +34,14 @@ import android.util.SparseArray; import com.android.internal.annotations.GuardedBy; import com.android.server.LocalServices; import com.android.server.companion.securechannel.SecureChannel; import libcore.io.IoUtils; import libcore.io.Streams; import libcore.util.EmptyArray; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import java.nio.ByteBuffer; import java.util.concurrent.ArrayBlockingQueue; import java.util.concurrent.BlockingQueue; import java.util.concurrent.CompletableFuture; import java.util.concurrent.Future; import java.util.concurrent.atomic.AtomicInteger; Loading @@ -54,8 +53,6 @@ public class CompanionTransportManager { private static final boolean DEBUG = true; private static final int HEADER_LENGTH = 12; // TODO: refactor message processing to use streams to remove this limit private static final int MAX_PAYLOAD_LENGTH = 1_000_000; private static final int MESSAGE_REQUEST_PING = 0x63807378; // ?PIN private static final int MESSAGE_REQUEST_PERMISSION_RESTORE = 0x63826983; // ?RES Loading Loading @@ -150,53 +147,47 @@ public class CompanionTransportManager { } } private class Transport { private class Transport implements SecureChannel.Callback { private final int mAssociationId; private final InputStream mRemoteIn; private final OutputStream mRemoteOut; private final SecureChannel mSecureChannel; private final AtomicInteger mNextSequence = new AtomicInteger(); private volatile boolean mShouldProcessRequests = false; @GuardedBy("mPendingRequests") private final SparseArray<CompletableFuture<byte[]>> mPendingRequests = new SparseArray<>(); private volatile boolean mStopped; private final BlockingQueue<Runnable> mRequestQueue = new ArrayBlockingQueue<>(100); public Transport(int associationId, ParcelFileDescriptor fd) { mAssociationId = associationId; mRemoteIn = new ParcelFileDescriptor.AutoCloseInputStream(fd); mRemoteOut = new ParcelFileDescriptor.AutoCloseOutputStream(fd); mSecureChannel = new SecureChannel( new ParcelFileDescriptor.AutoCloseInputStream(fd), new ParcelFileDescriptor.AutoCloseOutputStream(fd), this, mContext ); } public void start() { new Thread(() -> { try { while (!mStopped) { receiveMessage(); } } catch (IOException e) { if (!mStopped) { Slog.w(TAG, "Trouble during transport", e); stop(); } } }).start(); mSecureChannel.start(); } public void stop() { mStopped = true; IoUtils.closeQuietly(mRemoteIn); IoUtils.closeQuietly(mRemoteOut); mSecureChannel.stop(); mShouldProcessRequests = false; } public Future<byte[]> requestForResponse(int message, byte[] data) { if (DEBUG) Slog.d(TAG, "Requesting for response"); final int sequence = mNextSequence.incrementAndGet(); final CompletableFuture<byte[]> pending = new CompletableFuture<>(); synchronized (mPendingRequests) { mPendingRequests.put(sequence, pending); } // Queue up a task mRequestQueue.add(() -> { try { sendMessage(message, sequence, data); } catch (IOException e) { Loading @@ -205,63 +196,40 @@ public class CompanionTransportManager { } pending.completeExceptionally(e); } }); // Check if channel is secured and start securing if (!mShouldProcessRequests) { Slog.d(TAG, "Establishing secure connection."); try { mSecureChannel.establishSecureConnection(); } catch (Exception e) { synchronized (mPendingRequests) { mPendingRequests.remove(sequence); } pending.completeExceptionally(e); } } return pending; } private void sendMessage(int message, int sequence, @NonNull byte[] data) throws IOException { if (DEBUG) { Slog.d(TAG, "Sending message 0x" + Integer.toHexString(message) + " sequence " + sequence + " length " + data.length + " to association " + mAssociationId); } synchronized (mRemoteOut) { final ByteBuffer header = ByteBuffer.allocate(HEADER_LENGTH) final ByteBuffer payload = ByteBuffer.allocate(HEADER_LENGTH + data.length) .putInt(message) .putInt(sequence) .putInt(data.length); mRemoteOut.write(header.array()); mRemoteOut.write(data); mRemoteOut.flush(); } } private void receiveMessage() throws IOException { if (DEBUG) { Slog.d(TAG, "Waiting for next message..."); } final byte[] headerBytes = new byte[HEADER_LENGTH]; Streams.readFully(mRemoteIn, headerBytes); final ByteBuffer header = ByteBuffer.wrap(headerBytes); final int message = header.getInt(); final int sequence = header.getInt(); final int length = header.getInt(); if (DEBUG) { Slog.d(TAG, "Received message 0x" + Integer.toHexString(message) + " sequence " + sequence + " length " + length + " from association " + mAssociationId); } if (length > MAX_PAYLOAD_LENGTH) { Slog.w(TAG, "Ignoring message 0x" + Integer.toHexString(message) + " sequence " + sequence + " length " + length + " from association " + mAssociationId + " beyond maximum length"); Streams.skipByReading(mRemoteIn, length); return; } .putInt(data.length) .put(data); final byte[] data = new byte[length]; Streams.readFully(mRemoteIn, data); if (isRequest(message)) { processRequest(message, sequence, data); } else if (isResponse(message)) { processResponse(message, sequence, data); } else { Slog.w(TAG, "Unknown message 0x" + Integer.toHexString(message)); } mSecureChannel.sendSecureMessage(payload.array()); } private void processRequest(int message, int sequence, byte[] data) Loading Loading @@ -319,5 +287,55 @@ public class CompanionTransportManager { } } } @Override public void onSecureConnection() { mShouldProcessRequests = true; Slog.d(TAG, "Secure connection established."); // TODO: find a better way to handle incoming requests than a dedicated thread. new Thread(() -> { while (mShouldProcessRequests) { Runnable task = mRequestQueue.poll(); if (task != null) { task.run(); } } }).start(); } @Override public void onSecureMessageReceived(byte[] data) { final ByteBuffer payload = ByteBuffer.wrap(data); final int message = payload.getInt(); final int sequence = payload.getInt(); final int length = payload.getInt(); if (DEBUG) { Slog.d(TAG, "Received message 0x" + Integer.toHexString(message) + " sequence " + sequence + " length " + length + " from association " + mAssociationId); } final byte[] content = new byte[length]; payload.get(content); if (isRequest(message)) { try { processRequest(message, sequence, content); } catch (IOException e) { Slog.w(TAG, "Failed to respond to 0x" + Integer.toHexString(message), e); } } else if (isResponse(message)) { processResponse(message, sequence, content); } else { Slog.w(TAG, "Unknown message 0x" + Integer.toHexString(message)); } } @Override public void onError(Throwable error) { mShouldProcessRequests = false; Slog.e(TAG, error.getMessage(), error); } } } Loading
services/companion/java/com/android/server/companion/transport/CompanionTransportManager.java +99 −81 Original line number Diff line number Diff line Loading @@ -34,15 +34,14 @@ import android.util.SparseArray; import com.android.internal.annotations.GuardedBy; import com.android.server.LocalServices; import com.android.server.companion.securechannel.SecureChannel; import libcore.io.IoUtils; import libcore.io.Streams; import libcore.util.EmptyArray; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import java.nio.ByteBuffer; import java.util.concurrent.ArrayBlockingQueue; import java.util.concurrent.BlockingQueue; import java.util.concurrent.CompletableFuture; import java.util.concurrent.Future; import java.util.concurrent.atomic.AtomicInteger; Loading @@ -54,8 +53,6 @@ public class CompanionTransportManager { private static final boolean DEBUG = true; private static final int HEADER_LENGTH = 12; // TODO: refactor message processing to use streams to remove this limit private static final int MAX_PAYLOAD_LENGTH = 1_000_000; private static final int MESSAGE_REQUEST_PING = 0x63807378; // ?PIN private static final int MESSAGE_REQUEST_PERMISSION_RESTORE = 0x63826983; // ?RES Loading Loading @@ -150,53 +147,47 @@ public class CompanionTransportManager { } } private class Transport { private class Transport implements SecureChannel.Callback { private final int mAssociationId; private final InputStream mRemoteIn; private final OutputStream mRemoteOut; private final SecureChannel mSecureChannel; private final AtomicInteger mNextSequence = new AtomicInteger(); private volatile boolean mShouldProcessRequests = false; @GuardedBy("mPendingRequests") private final SparseArray<CompletableFuture<byte[]>> mPendingRequests = new SparseArray<>(); private volatile boolean mStopped; private final BlockingQueue<Runnable> mRequestQueue = new ArrayBlockingQueue<>(100); public Transport(int associationId, ParcelFileDescriptor fd) { mAssociationId = associationId; mRemoteIn = new ParcelFileDescriptor.AutoCloseInputStream(fd); mRemoteOut = new ParcelFileDescriptor.AutoCloseOutputStream(fd); mSecureChannel = new SecureChannel( new ParcelFileDescriptor.AutoCloseInputStream(fd), new ParcelFileDescriptor.AutoCloseOutputStream(fd), this, mContext ); } public void start() { new Thread(() -> { try { while (!mStopped) { receiveMessage(); } } catch (IOException e) { if (!mStopped) { Slog.w(TAG, "Trouble during transport", e); stop(); } } }).start(); mSecureChannel.start(); } public void stop() { mStopped = true; IoUtils.closeQuietly(mRemoteIn); IoUtils.closeQuietly(mRemoteOut); mSecureChannel.stop(); mShouldProcessRequests = false; } public Future<byte[]> requestForResponse(int message, byte[] data) { if (DEBUG) Slog.d(TAG, "Requesting for response"); final int sequence = mNextSequence.incrementAndGet(); final CompletableFuture<byte[]> pending = new CompletableFuture<>(); synchronized (mPendingRequests) { mPendingRequests.put(sequence, pending); } // Queue up a task mRequestQueue.add(() -> { try { sendMessage(message, sequence, data); } catch (IOException e) { Loading @@ -205,63 +196,40 @@ public class CompanionTransportManager { } pending.completeExceptionally(e); } }); // Check if channel is secured and start securing if (!mShouldProcessRequests) { Slog.d(TAG, "Establishing secure connection."); try { mSecureChannel.establishSecureConnection(); } catch (Exception e) { synchronized (mPendingRequests) { mPendingRequests.remove(sequence); } pending.completeExceptionally(e); } } return pending; } private void sendMessage(int message, int sequence, @NonNull byte[] data) throws IOException { if (DEBUG) { Slog.d(TAG, "Sending message 0x" + Integer.toHexString(message) + " sequence " + sequence + " length " + data.length + " to association " + mAssociationId); } synchronized (mRemoteOut) { final ByteBuffer header = ByteBuffer.allocate(HEADER_LENGTH) final ByteBuffer payload = ByteBuffer.allocate(HEADER_LENGTH + data.length) .putInt(message) .putInt(sequence) .putInt(data.length); mRemoteOut.write(header.array()); mRemoteOut.write(data); mRemoteOut.flush(); } } private void receiveMessage() throws IOException { if (DEBUG) { Slog.d(TAG, "Waiting for next message..."); } final byte[] headerBytes = new byte[HEADER_LENGTH]; Streams.readFully(mRemoteIn, headerBytes); final ByteBuffer header = ByteBuffer.wrap(headerBytes); final int message = header.getInt(); final int sequence = header.getInt(); final int length = header.getInt(); if (DEBUG) { Slog.d(TAG, "Received message 0x" + Integer.toHexString(message) + " sequence " + sequence + " length " + length + " from association " + mAssociationId); } if (length > MAX_PAYLOAD_LENGTH) { Slog.w(TAG, "Ignoring message 0x" + Integer.toHexString(message) + " sequence " + sequence + " length " + length + " from association " + mAssociationId + " beyond maximum length"); Streams.skipByReading(mRemoteIn, length); return; } .putInt(data.length) .put(data); final byte[] data = new byte[length]; Streams.readFully(mRemoteIn, data); if (isRequest(message)) { processRequest(message, sequence, data); } else if (isResponse(message)) { processResponse(message, sequence, data); } else { Slog.w(TAG, "Unknown message 0x" + Integer.toHexString(message)); } mSecureChannel.sendSecureMessage(payload.array()); } private void processRequest(int message, int sequence, byte[] data) Loading Loading @@ -319,5 +287,55 @@ public class CompanionTransportManager { } } } @Override public void onSecureConnection() { mShouldProcessRequests = true; Slog.d(TAG, "Secure connection established."); // TODO: find a better way to handle incoming requests than a dedicated thread. new Thread(() -> { while (mShouldProcessRequests) { Runnable task = mRequestQueue.poll(); if (task != null) { task.run(); } } }).start(); } @Override public void onSecureMessageReceived(byte[] data) { final ByteBuffer payload = ByteBuffer.wrap(data); final int message = payload.getInt(); final int sequence = payload.getInt(); final int length = payload.getInt(); if (DEBUG) { Slog.d(TAG, "Received message 0x" + Integer.toHexString(message) + " sequence " + sequence + " length " + length + " from association " + mAssociationId); } final byte[] content = new byte[length]; payload.get(content); if (isRequest(message)) { try { processRequest(message, sequence, content); } catch (IOException e) { Slog.w(TAG, "Failed to respond to 0x" + Integer.toHexString(message), e); } } else if (isResponse(message)) { processResponse(message, sequence, content); } else { Slog.w(TAG, "Unknown message 0x" + Integer.toHexString(message)); } } @Override public void onError(Throwable error) { mShouldProcessRequests = false; Slog.e(TAG, error.getMessage(), error); } } }