Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit e865f53b authored by Raphael Kim's avatar Raphael Kim
Browse files

[Secure Channel] Resolve Ukey2 Client Init collision edge-case

Bug: 281069627
Test: Manually on two devices
Change-Id: Ic86227af8ff55b0b8aca2a2eb627d1e8a90ef91a
parent 4debf56a
Loading
Loading
Loading
Loading
+89 −10
Original line number Diff line number Diff line
@@ -53,8 +53,6 @@ public class SecureChannel {
    private static final int VERSION = 1;
    private static final int HEADER_LENGTH = 6;

    private static final String HANDSHAKE_PROTOCOL = "AES_256_CBC-HMAC_SHA256";

    private final InputStream mInput;
    private final OutputStream mOutput;
    private final Callback mCallback;
@@ -62,14 +60,16 @@ public class SecureChannel {
    private final AttestationVerifier mVerifier;

    private volatile boolean mStopped;
    private boolean mInProgress;
    private volatile boolean mInProgress;

    private Role mRole;
    private byte[] mClientInit;
    private D2DHandshakeContext mHandshakeContext;
    private D2DConnectionContextV1 mConnectionContext;

    private String mAlias;
    private int mVerificationResult;
    private boolean mPskVerified;


    /**
@@ -202,8 +202,8 @@ public class SecureChannel {
        }

        try {
            initiateHandshake();
            mInProgress = true;
            initiateHandshake();
        } catch (BadHandleException e) {
            throw new SecureChannelException("Failed to initiate handshake protocol.", e);
        }
@@ -329,12 +329,56 @@ public class SecureChannel {

        mRole = Role.Initiator;
        mHandshakeContext = D2DHandshakeContext.forInitiator();
        mClientInit = mHandshakeContext.getNextHandshakeMessage();

        // Send Client Init
        if (DEBUG) {
            Slog.d(TAG, "Sending Ukey2 Client Init message");
        }
        sendMessage(MessageType.HANDSHAKE_INIT, mHandshakeContext.getNextHandshakeMessage());
        sendMessage(MessageType.HANDSHAKE_INIT, constructHandshakeInitMessage(mClientInit));
    }

    // In an occasion where both participants try to initiate a handshake, resolve the conflict
    // with a dice roll simulated by the message byte content comparison.
    // The higher value wins! (a.k.a. gets to be the initiator)
    private byte[] handleHandshakeCollision(byte[] handshakeInitMessage)
            throws IOException, HandshakeException, BadHandleException, CryptoException {

        // First byte indicates message type; 0 = CLIENT INIT, 1 = SERVER INIT
        ByteBuffer buffer = ByteBuffer.wrap(handshakeInitMessage);
        boolean isClientInit = buffer.get() == 0;
        byte[] handshakeMessage = new byte[buffer.remaining()];
        buffer.get(handshakeMessage);

        // If received message is Server Init or current role is Responder, then there was
        // no collision. Return extracted handshake message.
        if (mHandshakeContext == null || !isClientInit) {
            return handshakeMessage;
        }

        Slog.w(TAG, "Detected a Ukey2 handshake role collision. Negotiating a role.");

        // if received message is "larger" than the sent message, then reset the handshake context.
        if (compareByteArray(mClientInit, handshakeMessage) < 0) {
            Slog.d(TAG, "Assigned: Responder");
            mHandshakeContext = null;
            return handshakeMessage;
        } else {
            Slog.d(TAG, "Assigned: Initiator; Discarding received Client Init");

            // Wait for another init message after discarding the client init
            ByteBuffer nextInitMessage = ByteBuffer.wrap(readMessage(MessageType.HANDSHAKE_INIT));

            // Throw if this message is a Client Init again; 0 = CLIENT INIT, 1 = SERVER INIT
            if (nextInitMessage.get() == 0) {
                // This should never happen!
                throw new HandshakeException("Failed to resolve Ukey2 handshake role collision.");
            }
            byte[] nextHandshakeMessage = new byte[nextInitMessage.remaining()];
            nextInitMessage.get(nextHandshakeMessage);

            return nextHandshakeMessage;
        }
    }

    private void exchangeHandshake()
@@ -345,8 +389,15 @@ public class SecureChannel {
        }

        // Waiting for message
        byte[] handshakeMessage = readMessage(MessageType.HANDSHAKE_INIT);
        byte[] handshakeInitMessage = readMessage(MessageType.HANDSHAKE_INIT);

        // Mark "in-progress" upon receiving the first message
        mInProgress = true;

        // Handle a potential collision where both devices tried to initiate a connection
        byte[] handshakeMessage = handleHandshakeCollision(handshakeInitMessage);

        // Proceed with the rest of Ukey2 handshake
        if (mHandshakeContext == null) { // Server-side logic
            mRole = Role.Responder;
            mHandshakeContext = D2DHandshakeContext.forResponder();
@@ -361,7 +412,8 @@ public class SecureChannel {
            if (DEBUG) {
                Slog.d(TAG, "Sending Ukey2 Server Init message");
            }
            sendMessage(MessageType.HANDSHAKE_INIT, mHandshakeContext.getNextHandshakeMessage());
            sendMessage(MessageType.HANDSHAKE_INIT,
                    constructHandshakeInitMessage(mHandshakeContext.getNextHandshakeMessage()));

            // Receive Client Finish
            if (DEBUG) {
@@ -418,9 +470,9 @@ public class SecureChannel {
                ? Role.Responder
                : Role.Initiator,
                mPreSharedKey);
        boolean authenticated = Arrays.equals(receivedAuthToken, expectedAuthToken);
        mPskVerified = Arrays.equals(receivedAuthToken, expectedAuthToken);

        if (!authenticated) {
        if (!mPskVerified) {
            throw new SecureChannelException("Failed to verify the hash of pre-shared key.");
        }

@@ -477,10 +529,21 @@ public class SecureChannel {
    }

    private boolean isSecured() {
        // Is ukey-2 encrypted
        if (mConnectionContext == null) {
            return false;
        }
        return mVerifier == null || mVerificationResult == RESULT_SUCCESS;
        // Is authenticated
        return mPskVerified || mVerificationResult == RESULT_SUCCESS;
    }

    // First byte indicates message type; 0 = CLIENT INIT, 1 = SERVER INIT
    // This information is needed to help resolve potential role collision.
    private byte[] constructHandshakeInitMessage(byte[] message) {
        return ByteBuffer.allocate(1 + message.length)
                .put((byte) (Role.Initiator.equals(mRole) ? 0 : 1))
                .put(message)
                .array();
    }

    private byte[] constructToken(D2DHandshakeContext.Role role, byte[] authValue)
@@ -494,6 +557,22 @@ public class SecureChannel {
                .array());
    }

    // Arbitrary comparator
    private int compareByteArray(byte[] a, byte[] b) {
        if (a == b) {
            return 0;
        }
        if (a.length != b.length) {
            return a.length - b.length;
        }
        for (int i = 0; i < a.length; i++) {
            if (a[i] != b[i]) {
                return a[i] - b[i];
            }
        }
        return 0;
    }

    private String generateAlias() {
        String alias;
        do {
+10 −11
Original line number Diff line number Diff line
@@ -29,7 +29,6 @@ import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.Future;

class SecureTransport extends Transport implements SecureChannel.Callback {
    private final SecureChannel mSecureChannel;
@@ -70,7 +69,10 @@ class SecureTransport extends Transport implements SecureChannel.Callback {
    @Override
    protected void sendMessage(int message, int sequence, @NonNull byte[] data)
            throws IOException {
        // Check if channel is secured; otherwise start securing
        if (!mShouldProcessRequests) {
            establishSecureConnection();
        }

        if (DEBUG) {
            Slog.d(TAG, "Queueing message 0x" + Integer.toHexString(message)
@@ -90,8 +92,6 @@ class SecureTransport extends Transport implements SecureChannel.Callback {
    }

    private void establishSecureConnection() {
        // Check if channel is secured and start securing
        if (!mShouldProcessRequests) {
        Slog.d(TAG, "Establishing secure connection.");
        try {
            mSecureChannel.establishSecureConnection();
@@ -100,7 +100,6 @@ class SecureTransport extends Transport implements SecureChannel.Callback {
            onError(e);
        }
    }
    }

    @Override
    public void onSecureConnection() {