Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit c289e31a authored by Raphael Kim's avatar Raphael Kim
Browse files

Introduce hidden API to disable secure channel for back-compatibility

Bug: 270014877
Test: atest CompanionTests:SystemDataTransportTest
Change-Id: If583239888a3431805fe0faaf14345fdcd0f8414
parent db2a744d
Loading
Loading
Loading
Loading
+14 −0
Original line number Diff line number Diff line
@@ -1194,6 +1194,20 @@ public final class CompanionDeviceManager {
        }
    }

    /**
     * Enable or disable secure transport for testing. Defaults to enabled.
     *
     * @param enabled true to enable. false to disable.
     * @hide
     */
    public void enableSecureTransport(boolean enabled) {
        try {
            mService.enableSecureTransport(enabled);
        } catch (RemoteException e) {
            throw e.rethrowFromSystemServer();
        }
    }

    private boolean checkFeaturePresent() {
        boolean featurePresent = mService != null;
        if (!featurePresent && DEBUG) {
+2 −0
Original line number Diff line number Diff line
@@ -88,4 +88,6 @@ interface ICompanionDeviceManager {
    void enableSystemDataSync(int associationId, int flags);

    void disableSystemDataSync(int associationId, int flags);

    void enableSecureTransport(boolean enabled);
}
+2 −0
Original line number Diff line number Diff line
@@ -60,6 +60,7 @@ public class SystemDataTransportTest extends InstrumentationTestCase {
        mContext = getInstrumentation().getTargetContext();
        mCdm = mContext.getSystemService(CompanionDeviceManager.class);
        mAssociationId = createAssociation();
        mCdm.enableSecureTransport(false);
    }

    @Override
@@ -67,6 +68,7 @@ public class SystemDataTransportTest extends InstrumentationTestCase {
        super.tearDown();

        mCdm.disassociate(mAssociationId);
        mCdm.enableSecureTransport(true);
    }

    public void testPingHandRolled() {
+5 −0
Original line number Diff line number Diff line
@@ -725,6 +725,11 @@ public class CompanionDeviceManagerService extends SystemService {
            mAssociationRequestsProcessor.disableSystemDataSync(associationId, flags);
        }

        @Override
        public void enableSecureTransport(boolean enabled) {
            mTransportManager.enableSecureTransport(enabled);
        }

        @Override
        public void notifyDeviceAppeared(int associationId) {
            if (DEBUG) Log.i(TAG, "notifyDevice_Appeared() id=" + associationId);
+204 −78
Original line number Diff line number Diff line
@@ -36,9 +36,13 @@ import com.android.internal.annotations.GuardedBy;
import com.android.server.LocalServices;
import com.android.server.companion.securechannel.SecureChannel;

import libcore.io.IoUtils;
import libcore.io.Streams;
import libcore.util.EmptyArray;

import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.ByteBuffer;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
@@ -60,6 +64,8 @@ public class CompanionTransportManager {
    private static final int MESSAGE_RESPONSE_SUCCESS = 0x33838567; // !SUC
    private static final int MESSAGE_RESPONSE_FAILURE = 0x33706573; // !FAI

    private boolean mSecureTransportEnabled = true;

    private static boolean isRequest(int message) {
        return (message & 0xFF000000) == 0x63000000;
    }
@@ -119,7 +125,13 @@ public class CompanionTransportManager {
                detachSystemDataTransport(packageName, userId, associationId);
            }

            final Transport transport = new Transport(associationId, fd);
            final Transport transport;
            if (isSecureTransportEnabled(associationId)) {
                transport = new SecureTransport(associationId, fd);
            } else {
                transport = new RawTransport(associationId, fd);
            }

            transport.start();
            mTransports.put(associationId, transport);
        }
@@ -139,44 +151,56 @@ public class CompanionTransportManager {
    public Future<?> requestPermissionRestore(int associationId, byte[] data) {
        synchronized (mTransports) {
            final Transport transport = mTransports.get(associationId);
            if (transport != null) {
                return transport.requestForResponse(MESSAGE_REQUEST_PERMISSION_RESTORE, data);
            } else {
            if (transport == null) {
                return CompletableFuture.failedFuture(new IOException("Missing transport"));
            }

            return transport.requestForResponse(MESSAGE_REQUEST_PERMISSION_RESTORE, data);
        }
    }

    private class Transport implements SecureChannel.Callback {
        private final int mAssociationId;
    /**
     * @hide
     */
    public void enableSecureTransport(boolean enabled) {
        this.mSecureTransportEnabled = enabled;
    }

        private final SecureChannel mSecureChannel;
        private final AtomicInteger mNextSequence = new AtomicInteger();
    private boolean isSecureTransportEnabled(int associationId) {
        boolean enabled = !Build.IS_DEBUGGABLE || mSecureTransportEnabled;

        private volatile boolean mShouldProcessRequests = false;
        // TODO: version comparison logic
        return enabled;
    }

    // TODO: Make Transport inner classes into standalone classes.
    private abstract class Transport {
        protected final int mAssociationId;
        protected final InputStream mRemoteIn;
        protected final OutputStream mRemoteOut;

        @GuardedBy("mPendingRequests")
        private final SparseArray<CompletableFuture<byte[]>> mPendingRequests = new SparseArray<>();
        private final BlockingQueue<Runnable> mRequestQueue = new ArrayBlockingQueue<>(100);
        protected final SparseArray<CompletableFuture<byte[]>> mPendingRequests =
                new SparseArray<>();
        protected final AtomicInteger mNextSequence = new AtomicInteger();

        public Transport(int associationId, ParcelFileDescriptor fd) {
            mAssociationId = associationId;
            mSecureChannel = new SecureChannel(
        Transport(int associationId, ParcelFileDescriptor fd) {
            this(associationId,
                    new ParcelFileDescriptor.AutoCloseInputStream(fd),
                    new ParcelFileDescriptor.AutoCloseOutputStream(fd),
                    this,
                    mContext
            );
                    new ParcelFileDescriptor.AutoCloseOutputStream(fd));
        }

        public void start() {
            mSecureChannel.start();
        Transport(int associationId, InputStream in, OutputStream out) {
            this.mAssociationId = associationId;
            this.mRemoteIn = in;
            this.mRemoteOut = out;
        }

        public void stop() {
            mSecureChannel.stop();
            mShouldProcessRequests = false;
        }
        public abstract void start();
        public abstract void stop();

        protected abstract void sendMessage(int message, int sequence, @NonNull byte[] data)
                throws IOException;

        public Future<byte[]> requestForResponse(int message, byte[] data) {
            if (DEBUG) Slog.d(TAG, "Requesting for response");
@@ -186,8 +210,6 @@ public class CompanionTransportManager {
                mPendingRequests.put(sequence, pending);
            }

            // Queue up a task
            mRequestQueue.add(() -> {
            try {
                sendMessage(message, sequence, data);
            } catch (IOException e) {
@@ -196,40 +218,29 @@ public class CompanionTransportManager {
                }
                pending.completeExceptionally(e);
            }
            });

            // Check if channel is secured and start securing
            if (!mShouldProcessRequests) {
                Slog.d(TAG, "Establishing secure connection.");
                try {
                    mSecureChannel.establishSecureConnection();
                } catch (Exception e) {
                    synchronized (mPendingRequests) {
                        mPendingRequests.remove(sequence);
                    }
                    pending.completeExceptionally(e);
                }
            }

            return pending;
        }

        private void sendMessage(int message, int sequence, @NonNull byte[] data)
        protected final void handleMessage(int message, int sequence, @NonNull byte[] data)
                throws IOException {

            if (DEBUG) {
                Slog.d(TAG, "Sending message 0x" + Integer.toHexString(message)
                Slog.d(TAG, "Received message 0x" + Integer.toHexString(message)
                        + " sequence " + sequence + " length " + data.length
                        + " to association " + mAssociationId);
                        + " from association " + mAssociationId);
            }

            final ByteBuffer payload = ByteBuffer.allocate(HEADER_LENGTH + data.length)
                    .putInt(message)
                    .putInt(sequence)
                    .putInt(data.length)
                    .put(data);

            mSecureChannel.sendSecureMessage(payload.array());
            if (isRequest(message)) {
                try {
                    processRequest(message, sequence, data);
                } catch (IOException e) {
                    Slog.w(TAG, "Failed to respond to 0x" + Integer.toHexString(message), e);
                }
            } else if (isResponse(message)) {
                processResponse(message, sequence, data);
            } else {
                Slog.w(TAG, "Unknown message 0x" + Integer.toHexString(message));
            }
        }

        private void processRequest(int message, int sequence, byte[] data)
@@ -287,6 +298,129 @@ public class CompanionTransportManager {
                }
            }
        }
    }

    private class RawTransport extends Transport {
        private volatile boolean mStopped;

        RawTransport(int associationId, ParcelFileDescriptor fd) {
            super(associationId, fd);
        }

        @Override
        public void start() {
            new Thread(() -> {
                try {
                    while (!mStopped) {
                        receiveMessage();
                    }
                } catch (IOException e) {
                    if (!mStopped) {
                        Slog.w(TAG, "Trouble during transport", e);
                        stop();
                    }
                }
            }).start();
        }

        @Override
        public void stop() {
            mStopped = true;

            IoUtils.closeQuietly(mRemoteIn);
            IoUtils.closeQuietly(mRemoteOut);
        }

        @Override
        protected void sendMessage(int message, int sequence, @NonNull byte[] data)
                throws IOException {
            if (DEBUG) {
                Slog.d(TAG, "Sending message 0x" + Integer.toHexString(message)
                        + " sequence " + sequence + " length " + data.length
                        + " to association " + mAssociationId);
            }

            synchronized (mRemoteOut) {
                final ByteBuffer header = ByteBuffer.allocate(HEADER_LENGTH)
                        .putInt(message)
                        .putInt(sequence)
                        .putInt(data.length);
                mRemoteOut.write(header.array());
                mRemoteOut.write(data);
                mRemoteOut.flush();
            }
        }

        private void receiveMessage() throws IOException {
            final byte[] headerBytes = new byte[HEADER_LENGTH];
            Streams.readFully(mRemoteIn, headerBytes);
            final ByteBuffer header = ByteBuffer.wrap(headerBytes);
            final int message = header.getInt();
            final int sequence = header.getInt();
            final int length = header.getInt();
            final byte[] data = new byte[length];
            Streams.readFully(mRemoteIn, data);

            handleMessage(message, sequence, data);
        }
    }

    private class SecureTransport extends Transport implements SecureChannel.Callback {
        private final SecureChannel mSecureChannel;

        private volatile boolean mShouldProcessRequests = false;

        private final BlockingQueue<byte[]> mRequestQueue = new ArrayBlockingQueue<>(100);

        SecureTransport(int associationId, ParcelFileDescriptor fd) {
            super(associationId, fd);
            mSecureChannel = new SecureChannel(mRemoteIn, mRemoteOut, this, mContext);
        }

        @Override
        public void start() {
            mSecureChannel.start();
        }

        @Override
        public void stop() {
            mSecureChannel.stop();
            mShouldProcessRequests = false;
        }

        @Override
        public Future<byte[]> requestForResponse(int message, byte[] data) {
            // Check if channel is secured and start securing
            if (!mShouldProcessRequests) {
                Slog.d(TAG, "Establishing secure connection.");
                try {
                    mSecureChannel.establishSecureConnection();
                } catch (Exception e) {
                    Slog.w(TAG, "Failed to initiate secure channel handshake.", e);
                    onError(e);
                }
            }

            return super.requestForResponse(message, data);
        }

        @Override
        protected void sendMessage(int message, int sequence, @NonNull byte[] data)
                throws IOException {
            if (DEBUG) {
                Slog.d(TAG, "Queueing message 0x" + Integer.toHexString(message)
                        + " sequence " + sequence + " length " + data.length
                        + " to association " + mAssociationId);
            }

            // Queue up a message to send
            mRequestQueue.add(ByteBuffer.allocate(HEADER_LENGTH + data.length)
                    .putInt(message)
                    .putInt(sequence)
                    .putInt(data.length)
                    .put(data)
                    .array());
        }

        @Override
        public void onSecureConnection() {
@@ -295,12 +429,16 @@ public class CompanionTransportManager {

            // TODO: find a better way to handle incoming requests than a dedicated thread.
            new Thread(() -> {
                try {
                    while (mShouldProcessRequests) {
                    Runnable task = mRequestQueue.poll();
                    if (task != null) {
                        task.run();
                        byte[] request = mRequestQueue.poll();
                        if (request != null) {
                            mSecureChannel.sendSecureMessage(request);
                        }
                    }
                } catch (IOException e) {
                    onError(e);
                }
            }).start();
        }

@@ -310,25 +448,13 @@ public class CompanionTransportManager {
            final int message = payload.getInt();
            final int sequence = payload.getInt();
            final int length = payload.getInt();

            if (DEBUG) {
                Slog.d(TAG, "Received message 0x" + Integer.toHexString(message)
                        + " sequence " + sequence + " length " + length
                        + " from association " + mAssociationId);
            }

            final byte[] content = new byte[length];
            payload.get(content);
            if (isRequest(message)) {

            try {
                    processRequest(message, sequence, content);
                } catch (IOException e) {
                    Slog.w(TAG, "Failed to respond to 0x" + Integer.toHexString(message), e);
                }
            } else if (isResponse(message)) {
                processResponse(message, sequence, content);
            } else {
                Slog.w(TAG, "Unknown message 0x" + Integer.toHexString(message));
                handleMessage(message, sequence, content);
            } catch (IOException error) {
                onError(error);
            }
        }