Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit e10d562e authored by Raphael Kim's avatar Raphael Kim Committed by Automerger Merge Worker
Browse files

Merge "Stop raw channel thread before converting to a secure channel." into udc-dev am: 8d551e48

parents 6fb1150f 8d551e48
Loading
Loading
Loading
Loading
+55 −38
Original line number Diff line number Diff line
@@ -128,6 +128,9 @@ public class SecureChannel {
     * Start listening for incoming messages.
     */
    public void start() {
        if (DEBUG) {
            Slog.d(TAG, "Starting secure channel.");
        }
        new Thread(() -> {
            try {
                // 1. Wait for the next handshake message and process it.
@@ -151,14 +154,14 @@ public class SecureChannel {
                // TODO: Handle different types errors.

                Slog.e(TAG, "Secure channel encountered an error.", e);
                stop();
                close();
                mCallback.onError(e);
            }
        }).start();
    }

    /**
     * Stop listening to incoming messages and close the channel.
     * Stop listening to incoming messages.
     */
    public void stop() {
        if (DEBUG) {
@@ -166,7 +169,17 @@ public class SecureChannel {
        }
        mStopped = true;
        mInProgress = false;
    }

    /**
     * Stop listening to incoming messages and close the channel.
     */
    public void close() {
        stop();

        if (DEBUG) {
            Slog.d(TAG, "Closing secure channel.");
        }
        IoUtils.closeQuietly(mInput);
        IoUtils.closeQuietly(mOutput);
        KeyStoreUtils.cleanUp(mAlias);
@@ -240,12 +253,13 @@ public class SecureChannel {
            if (isSecured()) {
                Slog.d(TAG, "Waiting to receive next secure message.");
            } else {
                Slog.d(TAG, "Waiting to receive next message.");
                Slog.d(TAG, "Waiting to receive next " + expected + " message.");
            }
        }

        // TODO: Handle message timeout

        synchronized (mInput) {
            // Header is _not_ encrypted, but will be covered by MAC
            final byte[] headerBytes = new byte[HEADER_LENGTH];
            Streams.readFully(mInput, headerBytes);
@@ -261,8 +275,10 @@ public class SecureChannel {

            if (type != expected.mValue) {
                Streams.skipByReading(mInput, Long.MAX_VALUE);
            throw new SecureChannelException("Unexpected message type. Expected " + expected.name()
                    + "; Found " + MessageType.from(type).name() + ". Skipping rest of data.");
                throw new SecureChannelException(
                        "Unexpected message type. Expected " + expected.name()
                                + "; Found " + MessageType.from(type).name()
                                + ". Skipping rest of data.");
            }

            // Length of attached data is prepended as plaintext
@@ -285,15 +301,16 @@ public class SecureChannel {

            return mConnectionContext.decodeMessageFromPeer(data, headerBytes);
        }
    }

    private void sendMessage(MessageType messageType, byte[] payload)
    private void sendMessage(MessageType messageType, final byte[] payload)
            throws IOException, BadHandleException {
        synchronized (mOutput) {
            byte[] header = ByteBuffer.allocate(HEADER_LENGTH)
            final byte[] header = ByteBuffer.allocate(HEADER_LENGTH)
                    .putInt(VERSION)
                    .putShort(messageType.mValue)
                    .array();
            byte[] data = MessageType.shouldEncrypt(messageType)
            final byte[] data = MessageType.shouldEncrypt(messageType)
                    ? mConnectionContext.encodeMessageToPeer(payload, header)
                    : payload;
            mOutput.write(header);
+18 −11
Original line number Diff line number Diff line
@@ -46,6 +46,7 @@ import com.android.server.companion.AssociationStore;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletableFuture;
@@ -296,26 +297,32 @@ public class CompanionTransportManager {
        Slog.i(TAG, "Remote device SDK: " + remoteSdk + ", release:" + new String(remoteRelease));

        Transport transport = mTempTransport;
        mTempTransport = null;
        mTempTransport.stop();

        int sdk = Build.VERSION.SDK_INT;
        String release = Build.VERSION.RELEASE;
        if (remoteSdk == NON_ANDROID) {
        if (Build.isDebuggable()) {
            // Debug builds cannot pass attestation verification. Use hardcoded key instead.
            Slog.d(TAG, "Creating an unauthenticated secure channel");
            final byte[] testKey = "CDM".getBytes(StandardCharsets.UTF_8);
            transport = new SecureTransport(transport.getAssociationId(), transport.getFd(),
                    mContext, testKey, null);
        } else if (remoteSdk == NON_ANDROID) {
            // TODO: pass in a real preSharedKey
            transport = new SecureTransport(transport.getAssociationId(), transport.getFd(),
                    mContext, null, null);
        } else if (sdk < SECURE_CHANNEL_AVAILABLE_SDK
                || remoteSdk < SECURE_CHANNEL_AVAILABLE_SDK) {
            // TODO: depending on the release version, either
            //       1) using a RawTransport for old T versions
            //       2) or an Ukey2 handshaked transport for UKey2 backported T versions
        } else {
                    mContext, new byte[0], null);
        } else if (sdk >= SECURE_CHANNEL_AVAILABLE_SDK
                && remoteSdk >= SECURE_CHANNEL_AVAILABLE_SDK) {
            Slog.i(TAG, "Creating a secure channel");
            transport = new SecureTransport(transport.getAssociationId(), transport.getFd(),
                    mContext);
        } else {
            // TODO: depending on the release version, either
            //       1) using a RawTransport for old T versions
            //       2) or an Ukey2 handshaked transport for UKey2 backported T versions
        }
        addMessageListenersToTransport(transport);
        transport.start();
        }
        mTransports.put(transport.getAssociationId(), transport);
        // Doesn't need to notifyTransportsChanged here, it'll be done in attachSystemDataTransport
    }
+26 −10
Original line number Diff line number Diff line
@@ -36,6 +36,9 @@ class RawTransport extends Transport {

    @Override
    public void start() {
        if (DEBUG) {
            Slog.d(TAG, "Starting raw transport.");
        }
        new Thread(() -> {
            try {
                while (!mStopped) {
@@ -44,7 +47,7 @@ class RawTransport extends Transport {
            } catch (IOException e) {
                if (!mStopped) {
                    Slog.w(TAG, "Trouble during transport", e);
                    stop();
                    close();
                }
            }
        }).start();
@@ -52,8 +55,19 @@ class RawTransport extends Transport {

    @Override
    public void stop() {
        if (DEBUG) {
            Slog.d(TAG, "Stopping raw transport.");
        }
        mStopped = true;
    }

    @Override
    public void close() {
        stop();

        if (DEBUG) {
            Slog.d(TAG, "Closing raw transport.");
        }
        IoUtils.closeQuietly(mRemoteIn);
        IoUtils.closeQuietly(mRemoteOut);
    }
@@ -79,6 +93,7 @@ class RawTransport extends Transport {
    }

    private void receiveMessage() throws IOException {
        synchronized (mRemoteIn) {
            final byte[] headerBytes = new byte[HEADER_LENGTH];
            Streams.readFully(mRemoteIn, headerBytes);
            final ByteBuffer header = ByteBuffer.wrap(headerBytes);
@@ -91,3 +106,4 @@ class RawTransport extends Transport {
            handleMessage(message, sequence, data);
        }
    }
}
+21 −9
Original line number Diff line number Diff line
@@ -21,6 +21,7 @@ import android.content.Context;
import android.os.ParcelFileDescriptor;
import android.util.Slog;

import com.android.internal.annotations.GuardedBy;
import com.android.server.companion.securechannel.AttestationVerifier;
import com.android.server.companion.securechannel.SecureChannel;

@@ -35,6 +36,7 @@ class SecureTransport extends Transport implements SecureChannel.Callback {

    private volatile boolean mShouldProcessRequests = false;

    @GuardedBy("mRequestQueue")
    private final BlockingQueue<byte[]> mRequestQueue = new ArrayBlockingQueue<>(100);

    SecureTransport(int associationId, ParcelFileDescriptor fd, Context context) {
@@ -59,6 +61,12 @@ class SecureTransport extends Transport implements SecureChannel.Callback {
        mShouldProcessRequests = false;
    }

    @Override
    public void close() {
        mSecureChannel.close();
        mShouldProcessRequests = false;
    }

    @Override
    public Future<byte[]> requestForResponse(int message, byte[] data) {
        // Check if channel is secured and start securing
@@ -85,6 +93,7 @@ class SecureTransport extends Transport implements SecureChannel.Callback {
        }

        // Queue up a message to send
        synchronized (mRequestQueue) {
            mRequestQueue.add(ByteBuffer.allocate(HEADER_LENGTH + data.length)
                    .putInt(message)
                    .putInt(sequence)
@@ -92,6 +101,7 @@ class SecureTransport extends Transport implements SecureChannel.Callback {
                    .put(data)
                    .array());
        }
    }

    @Override
    public void onSecureConnection() {
@@ -102,11 +112,13 @@ class SecureTransport extends Transport implements SecureChannel.Callback {
        new Thread(() -> {
            try {
                while (mShouldProcessRequests) {
                    synchronized (mRequestQueue) {
                        byte[] request = mRequestQueue.poll();
                        if (request != null) {
                            mSecureChannel.sendSecureMessage(request);
                        }
                    }
                }
            } catch (IOException e) {
                onError(e);
            }
+19 −2
Original line number Diff line number Diff line
@@ -110,13 +110,26 @@ public abstract class Transport {
        return mFd;
    }

    /**
     * Start listening to messages.
     */
    public abstract void start();

    /**
     * Soft stop listening to the incoming data without closing the streams.
     */
    public abstract void stop();

    /**
     * Stop listening to the incoming data and close the streams.
     */
    public abstract void close();

    protected abstract void sendMessage(int message, int sequence, @NonNull byte[] data)
            throws IOException;

    /**
     * Send a message
     * Send a message.
     */
    public void sendMessage(int message, @NonNull byte[] data) throws IOException {
        sendMessage(message, mNextSequence.incrementAndGet(), data);
@@ -170,7 +183,11 @@ public abstract class Transport {
                sendMessage(MESSAGE_RESPONSE_SUCCESS, sequence, data);
                break;
            }
            case MESSAGE_REQUEST_PLATFORM_INFO:
            case MESSAGE_REQUEST_PLATFORM_INFO: {
                callback(message, data);
                // DO NOT SEND A RESPONSE!
                break;
            }
            case MESSAGE_REQUEST_CONTEXT_SYNC: {
                callback(message, data);
                sendMessage(MESSAGE_RESPONSE_SUCCESS, sequence, EmptyArray.BYTE);