Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 52b549b8 authored by Raphael Kim's avatar Raphael Kim Committed by Automerger Merge Worker
Browse files

Merge "CDM Transport clean-up" into udc-dev am: 9fe0570e

parents f95a69de 9fe0570e
Loading
Loading
Loading
Loading
+5 −317
Original line number Diff line number Diff line
@@ -18,51 +18,31 @@ package com.android.server.companion.transport;

import static android.Manifest.permission.DELIVER_COMPANION_MESSAGES;

import static com.android.server.companion.transport.Transport.MESSAGE_REQUEST_PERMISSION_RESTORE;

import android.annotation.NonNull;
import android.annotation.Nullable;
import android.annotation.SuppressLint;
import android.app.ActivityManagerInternal;
import android.content.Context;
import android.content.pm.ApplicationInfo;
import android.content.pm.PackageManager;
import android.content.pm.PackageManager.NameNotFoundException;
import android.os.Binder;
import android.os.Build;
import android.os.ParcelFileDescriptor;
import android.util.Slog;
import android.util.SparseArray;

import com.android.internal.annotations.GuardedBy;
import com.android.server.LocalServices;
import com.android.server.companion.securechannel.SecureChannel;

import libcore.io.IoUtils;
import libcore.io.Streams;
import libcore.util.EmptyArray;

import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.ByteBuffer;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicInteger;

@SuppressLint("LongLogTag")
public class CompanionTransportManager {
    private static final String TAG = "CDM_CompanionTransportManager";
    // TODO: flip to false
    private static final boolean DEBUG = true;

    private static final int HEADER_LENGTH = 12;

    private static final int MESSAGE_REQUEST_PING = 0x63807378; // ?PIN
    private static final int MESSAGE_REQUEST_PERMISSION_RESTORE = 0x63826983; // ?RES

    private static final int MESSAGE_RESPONSE_SUCCESS = 0x33838567; // !SUC
    private static final int MESSAGE_RESPONSE_FAILURE = 0x33706573; // !FAI
    private static final boolean DEBUG = false;

    private boolean mSecureTransportEnabled = true;

@@ -127,9 +107,9 @@ public class CompanionTransportManager {

            final Transport transport;
            if (isSecureTransportEnabled(associationId)) {
                transport = new SecureTransport(associationId, fd);
                transport = new SecureTransport(associationId, fd, mContext, mListener);
            } else {
                transport = new RawTransport(associationId, fd);
                transport = new RawTransport(associationId, fd, mContext, mListener);
            }

            transport.start();
@@ -172,296 +152,4 @@ public class CompanionTransportManager {
        // TODO: version comparison logic
        return enabled;
    }

    // TODO: Make Transport inner classes into standalone classes.
    private abstract class Transport {
        protected final int mAssociationId;
        protected final InputStream mRemoteIn;
        protected final OutputStream mRemoteOut;

        @GuardedBy("mPendingRequests")
        protected final SparseArray<CompletableFuture<byte[]>> mPendingRequests =
                new SparseArray<>();
        protected final AtomicInteger mNextSequence = new AtomicInteger();

        Transport(int associationId, ParcelFileDescriptor fd) {
            this(associationId,
                    new ParcelFileDescriptor.AutoCloseInputStream(fd),
                    new ParcelFileDescriptor.AutoCloseOutputStream(fd));
        }

        Transport(int associationId, InputStream in, OutputStream out) {
            this.mAssociationId = associationId;
            this.mRemoteIn = in;
            this.mRemoteOut = out;
        }

        public abstract void start();
        public abstract void stop();

        protected abstract void sendMessage(int message, int sequence, @NonNull byte[] data)
                throws IOException;

        public Future<byte[]> requestForResponse(int message, byte[] data) {
            if (DEBUG) Slog.d(TAG, "Requesting for response");
            final int sequence = mNextSequence.incrementAndGet();
            final CompletableFuture<byte[]> pending = new CompletableFuture<>();
            synchronized (mPendingRequests) {
                mPendingRequests.put(sequence, pending);
            }

            try {
                sendMessage(message, sequence, data);
            } catch (IOException e) {
                synchronized (mPendingRequests) {
                    mPendingRequests.remove(sequence);
                }
                pending.completeExceptionally(e);
            }

            return pending;
        }

        protected final void handleMessage(int message, int sequence, @NonNull byte[] data)
                throws IOException {
            if (DEBUG) {
                Slog.d(TAG, "Received message 0x" + Integer.toHexString(message)
                        + " sequence " + sequence + " length " + data.length
                        + " from association " + mAssociationId);
            }

            if (isRequest(message)) {
                try {
                    processRequest(message, sequence, data);
                } catch (IOException e) {
                    Slog.w(TAG, "Failed to respond to 0x" + Integer.toHexString(message), e);
                }
            } else if (isResponse(message)) {
                processResponse(message, sequence, data);
            } else {
                Slog.w(TAG, "Unknown message 0x" + Integer.toHexString(message));
            }
        }

        private void processRequest(int message, int sequence, byte[] data)
                throws IOException {
            switch (message) {
                case MESSAGE_REQUEST_PING: {
                    sendMessage(MESSAGE_RESPONSE_SUCCESS, sequence, data);
                    break;
                }
                case MESSAGE_REQUEST_PERMISSION_RESTORE: {
                    if (!mContext.getPackageManager().hasSystemFeature(PackageManager.FEATURE_WATCH)
                            && !Build.isDebuggable()) {
                        Slog.w(TAG, "Restoring permissions only supported on watches");
                        sendMessage(MESSAGE_RESPONSE_FAILURE, sequence, EmptyArray.BYTE);
                        break;
                    }
                    try {
                        mListener.onRequestPermissionRestore(data);
                        sendMessage(MESSAGE_RESPONSE_SUCCESS, sequence, EmptyArray.BYTE);
                    } catch (Exception e) {
                        Slog.w(TAG, "Failed to restore permissions");
                        sendMessage(MESSAGE_RESPONSE_FAILURE, sequence, EmptyArray.BYTE);
                    }
                    break;
                }
                default: {
                    Slog.w(TAG, "Unknown request 0x" + Integer.toHexString(message));
                    sendMessage(MESSAGE_RESPONSE_FAILURE, sequence, EmptyArray.BYTE);
                    break;
                }
            }
        }

        private void processResponse(int message, int sequence, byte[] data) {
            final CompletableFuture<byte[]> future;
            synchronized (mPendingRequests) {
                future = mPendingRequests.removeReturnOld(sequence);
            }
            if (future == null) {
                Slog.w(TAG, "Ignoring unknown sequence " + sequence);
                return;
            }

            switch (message) {
                case MESSAGE_RESPONSE_SUCCESS: {
                    future.complete(data);
                    break;
                }
                case MESSAGE_RESPONSE_FAILURE: {
                    future.completeExceptionally(new RuntimeException("Remote failure"));
                    break;
                }
                default: {
                    Slog.w(TAG, "Ignoring unknown response 0x" + Integer.toHexString(message));
                }
            }
        }
    }

    private class RawTransport extends Transport {
        private volatile boolean mStopped;

        RawTransport(int associationId, ParcelFileDescriptor fd) {
            super(associationId, fd);
        }

        @Override
        public void start() {
            new Thread(() -> {
                try {
                    while (!mStopped) {
                        receiveMessage();
                    }
                } catch (IOException e) {
                    if (!mStopped) {
                        Slog.w(TAG, "Trouble during transport", e);
                        stop();
                    }
                }
            }).start();
        }

        @Override
        public void stop() {
            mStopped = true;

            IoUtils.closeQuietly(mRemoteIn);
            IoUtils.closeQuietly(mRemoteOut);
        }

        @Override
        protected void sendMessage(int message, int sequence, @NonNull byte[] data)
                throws IOException {
            if (DEBUG) {
                Slog.d(TAG, "Sending message 0x" + Integer.toHexString(message)
                        + " sequence " + sequence + " length " + data.length
                        + " to association " + mAssociationId);
            }

            synchronized (mRemoteOut) {
                final ByteBuffer header = ByteBuffer.allocate(HEADER_LENGTH)
                        .putInt(message)
                        .putInt(sequence)
                        .putInt(data.length);
                mRemoteOut.write(header.array());
                mRemoteOut.write(data);
                mRemoteOut.flush();
            }
        }

        private void receiveMessage() throws IOException {
            final byte[] headerBytes = new byte[HEADER_LENGTH];
            Streams.readFully(mRemoteIn, headerBytes);
            final ByteBuffer header = ByteBuffer.wrap(headerBytes);
            final int message = header.getInt();
            final int sequence = header.getInt();
            final int length = header.getInt();
            final byte[] data = new byte[length];
            Streams.readFully(mRemoteIn, data);

            handleMessage(message, sequence, data);
        }
    }

    private class SecureTransport extends Transport implements SecureChannel.Callback {
        private final SecureChannel mSecureChannel;

        private volatile boolean mShouldProcessRequests = false;

        private final BlockingQueue<byte[]> mRequestQueue = new ArrayBlockingQueue<>(100);

        SecureTransport(int associationId, ParcelFileDescriptor fd) {
            super(associationId, fd);
            mSecureChannel = new SecureChannel(mRemoteIn, mRemoteOut, this, mContext);
        }

        @Override
        public void start() {
            mSecureChannel.start();
        }

        @Override
        public void stop() {
            mSecureChannel.stop();
            mShouldProcessRequests = false;
        }

        @Override
        public Future<byte[]> requestForResponse(int message, byte[] data) {
            // Check if channel is secured and start securing
            if (!mShouldProcessRequests) {
                Slog.d(TAG, "Establishing secure connection.");
                try {
                    mSecureChannel.establishSecureConnection();
                } catch (Exception e) {
                    Slog.w(TAG, "Failed to initiate secure channel handshake.", e);
                    onError(e);
                }
            }

            return super.requestForResponse(message, data);
        }

        @Override
        protected void sendMessage(int message, int sequence, @NonNull byte[] data)
                throws IOException {
            if (DEBUG) {
                Slog.d(TAG, "Queueing message 0x" + Integer.toHexString(message)
                        + " sequence " + sequence + " length " + data.length
                        + " to association " + mAssociationId);
            }

            // Queue up a message to send
            mRequestQueue.add(ByteBuffer.allocate(HEADER_LENGTH + data.length)
                    .putInt(message)
                    .putInt(sequence)
                    .putInt(data.length)
                    .put(data)
                    .array());
        }

        @Override
        public void onSecureConnection() {
            mShouldProcessRequests = true;
            Slog.d(TAG, "Secure connection established.");

            // TODO: find a better way to handle incoming requests than a dedicated thread.
            new Thread(() -> {
                try {
                    while (mShouldProcessRequests) {
                        byte[] request = mRequestQueue.poll();
                        if (request != null) {
                            mSecureChannel.sendSecureMessage(request);
                        }
                    }
                } catch (IOException e) {
                    onError(e);
                }
            }).start();
        }

        @Override
        public void onSecureMessageReceived(byte[] data) {
            final ByteBuffer payload = ByteBuffer.wrap(data);
            final int message = payload.getInt();
            final int sequence = payload.getInt();
            final int length = payload.getInt();
            final byte[] content = new byte[length];
            payload.get(content);

            try {
                handleMessage(message, sequence, content);
            } catch (IOException error) {
                onError(error);
            }
        }

        @Override
        public void onError(Throwable error) {
            mShouldProcessRequests = false;
            Slog.e(TAG, error.getMessage(), error);
        }
    }
}
+95 −0
Original line number Diff line number Diff line
/*
 * Copyright (C) 2023 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.android.server.companion.transport;

import android.annotation.NonNull;
import android.content.Context;
import android.os.ParcelFileDescriptor;
import android.util.Slog;

import com.android.server.companion.transport.CompanionTransportManager.Listener;

import libcore.io.IoUtils;
import libcore.io.Streams;

import java.io.IOException;
import java.nio.ByteBuffer;

class RawTransport extends Transport {
    private volatile boolean mStopped;

    RawTransport(int associationId, ParcelFileDescriptor fd, Context context, Listener listener) {
        super(associationId, fd, context, listener);
    }

    @Override
    public void start() {
        new Thread(() -> {
            try {
                while (!mStopped) {
                    receiveMessage();
                }
            } catch (IOException e) {
                if (!mStopped) {
                    Slog.w(TAG, "Trouble during transport", e);
                    stop();
                }
            }
        }).start();
    }

    @Override
    public void stop() {
        mStopped = true;

        IoUtils.closeQuietly(mRemoteIn);
        IoUtils.closeQuietly(mRemoteOut);
    }

    @Override
    protected void sendMessage(int message, int sequence, @NonNull byte[] data)
            throws IOException {
        if (DEBUG) {
            Slog.d(TAG, "Sending message 0x" + Integer.toHexString(message)
                    + " sequence " + sequence + " length " + data.length
                    + " to association " + mAssociationId);
        }

        synchronized (mRemoteOut) {
            final ByteBuffer header = ByteBuffer.allocate(HEADER_LENGTH)
                    .putInt(message)
                    .putInt(sequence)
                    .putInt(data.length);
            mRemoteOut.write(header.array());
            mRemoteOut.write(data);
            mRemoteOut.flush();
        }
    }

    private void receiveMessage() throws IOException {
        final byte[] headerBytes = new byte[HEADER_LENGTH];
        Streams.readFully(mRemoteIn, headerBytes);
        final ByteBuffer header = ByteBuffer.wrap(headerBytes);
        final int message = header.getInt();
        final int sequence = header.getInt();
        final int length = header.getInt();
        final byte[] data = new byte[length];
        Streams.readFully(mRemoteIn, data);

        handleMessage(message, sequence, data);
    }
}
+134 −0
Original line number Diff line number Diff line
/*
 * Copyright (C) 2023 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.android.server.companion.transport;

import android.annotation.NonNull;
import android.content.Context;
import android.os.ParcelFileDescriptor;
import android.util.Slog;

import com.android.server.companion.securechannel.SecureChannel;
import com.android.server.companion.transport.CompanionTransportManager.Listener;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.Future;

class SecureTransport extends Transport implements SecureChannel.Callback {
    private final SecureChannel mSecureChannel;

    private volatile boolean mShouldProcessRequests = false;

    private final BlockingQueue<byte[]> mRequestQueue = new ArrayBlockingQueue<>(100);

    SecureTransport(int associationId,
            ParcelFileDescriptor fd,
            Context context,
            Listener listener) {
        super(associationId, fd, context, listener);
        mSecureChannel = new SecureChannel(mRemoteIn, mRemoteOut, this, context);
    }

    @Override
    public void start() {
        mSecureChannel.start();
    }

    @Override
    public void stop() {
        mSecureChannel.stop();
        mShouldProcessRequests = false;
    }

    @Override
    public Future<byte[]> requestForResponse(int message, byte[] data) {
        // Check if channel is secured and start securing
        if (!mShouldProcessRequests) {
            Slog.d(TAG, "Establishing secure connection.");
            try {
                mSecureChannel.establishSecureConnection();
            } catch (Exception e) {
                Slog.w(TAG, "Failed to initiate secure channel handshake.", e);
                onError(e);
            }
        }

        return super.requestForResponse(message, data);
    }

    @Override
    protected void sendMessage(int message, int sequence, @NonNull byte[] data)
            throws IOException {
        if (DEBUG) {
            Slog.d(TAG, "Queueing message 0x" + Integer.toHexString(message)
                    + " sequence " + sequence + " length " + data.length
                    + " to association " + mAssociationId);
        }

        // Queue up a message to send
        mRequestQueue.add(ByteBuffer.allocate(HEADER_LENGTH + data.length)
                .putInt(message)
                .putInt(sequence)
                .putInt(data.length)
                .put(data)
                .array());
    }

    @Override
    public void onSecureConnection() {
        mShouldProcessRequests = true;
        Slog.d(TAG, "Secure connection established.");

        // TODO: find a better way to handle incoming requests than a dedicated thread.
        new Thread(() -> {
            try {
                while (mShouldProcessRequests) {
                    byte[] request = mRequestQueue.poll();
                    if (request != null) {
                        mSecureChannel.sendSecureMessage(request);
                    }
                }
            } catch (IOException e) {
                onError(e);
            }
        }).start();
    }

    @Override
    public void onSecureMessageReceived(byte[] data) {
        final ByteBuffer payload = ByteBuffer.wrap(data);
        final int message = payload.getInt();
        final int sequence = payload.getInt();
        final int length = payload.getInt();
        final byte[] content = new byte[length];
        payload.get(content);

        try {
            handleMessage(message, sequence, content);
        } catch (IOException error) {
            onError(error);
        }
    }

    @Override
    public void onError(Throwable error) {
        mShouldProcessRequests = false;
        Slog.e(TAG, error.getMessage(), error);
    }
}
+181 −0

File added.

Preview size limit exceeded, changes collapsed.