Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Unverified Commit fe5cb5a1 authored by Simon Chan's avatar Simon Chan
Browse files

feat(adb): don't kill subprocess on stream close

parent e45fb2ed
Loading
Loading
Loading
Loading
+28 −18
Original line number Diff line number Diff line
@@ -5,8 +5,8 @@ import type {
    AdbIncomingSocketHandler,
    AdbServerConnection,
    AdbServerConnectionOptions,
    AdbServerConnector,
} from "@yume-chan/adb";
import type { ReadableWritablePair } from "@yume-chan/stream-extra";
import {
    PushReadableStream,
    UnwrapConsumableStream,
@@ -15,12 +15,21 @@ import {
} from "@yume-chan/stream-extra";
import type { ValueOrPromise } from "@yume-chan/struct";

function nodeSocketToStreamPair(socket: Socket) {
function nodeSocketToConnection(socket: Socket): AdbServerConnection {
    socket.setNoDelay(true);

    const closed = new Promise<void>((resolve) => {
        socket.on("close", resolve);
    });

    return {
        readable: new PushReadableStream<Uint8Array>((controller) => {
            // eslint-disable-next-line @typescript-eslint/no-misused-promises
            socket.on("data", async (data) => {
                if (controller.abortSignal.aborted) {
                    return;
                }

                socket.pause();
                await controller.enqueue(data);
                socket.resume();
@@ -32,9 +41,6 @@ function nodeSocketToStreamPair(socket: Socket) {
                    // controller already closed
                }
            });
            controller.abortSignal.addEventListener("abort", () => {
                socket.end();
            });
        }),
        writable: new WritableStream<Uint8Array>({
            write: async (chunk) => {
@@ -48,16 +54,17 @@ function nodeSocketToStreamPair(socket: Socket) {
                    });
                });
            },
        }),
        get closed() {
            return closed;
        },
        close() {
                return new Promise<void>((resolve) => {
                    socket.end(resolve);
                });
            socket.end();
        },
        }),
    };
}

export class AdbServerNodeTcpConnection implements AdbServerConnection {
export class AdbServerNodeTcpConnector implements AdbServerConnector {
    readonly spec: SocketConnectOpts;

    readonly #listeners = new Map<string, Server>();
@@ -68,7 +75,7 @@ export class AdbServerNodeTcpConnection implements AdbServerConnection {

    async connect(
        { unref }: AdbServerConnectionOptions = { unref: false },
    ): Promise<ReadableWritablePair<Uint8Array, Uint8Array>> {
    ): Promise<AdbServerConnection> {
        const socket = new Socket();
        if (unref) {
            socket.unref();
@@ -78,7 +85,7 @@ export class AdbServerNodeTcpConnection implements AdbServerConnection {
            socket.once("connect", resolve);
            socket.once("error", reject);
        });
        return nodeSocketToStreamPair(socket);
        return nodeSocketToConnection(socket);
    }

    async addReverseTunnel(
@@ -87,16 +94,19 @@ export class AdbServerNodeTcpConnection implements AdbServerConnection {
    ): Promise<string> {
        // eslint-disable-next-line @typescript-eslint/no-misused-promises
        const server = new Server(async (socket) => {
            const stream = nodeSocketToStreamPair(socket);
            const connection = nodeSocketToConnection(socket);
            try {
                await handler({
                    service: address!,
                    readable: stream.readable,
                    readable: connection.readable,
                    writable: new WrapWritableStream(
                        stream.writable,
                        connection.writable,
                    ).bePipedThroughFrom(new UnwrapConsumableStream()),
                    close() {
                        socket.end();
                    get closed() {
                        return connection.closed;
                    },
                    async close() {
                        await connection.close();
                    },
                });
            } catch {
+3 −1
Original line number Diff line number Diff line
@@ -22,7 +22,9 @@ export interface Closeable {
export interface AdbSocket
    extends ReadableWritablePair<Uint8Array, Consumable<Uint8Array>>,
        Closeable {
    readonly service: string;
    get service(): string;

    get closed(): Promise<void>;
}

export type AdbIncomingSocketHandler = (
+11 −16
Original line number Diff line number Diff line
import type { Consumable, WritableStream } from "@yume-chan/stream-extra";
import { DuplexStreamFactory, ReadableStream } from "@yume-chan/stream-extra";
import { ReadableStream } from "@yume-chan/stream-extra";

import type { Adb, AdbSocket } from "../../../adb.js";
import { unreachable } from "../../../utils/index.js";

import type { AdbSubprocessProtocol } from "./types.js";

@@ -34,19 +35,16 @@ export class AdbSubprocessNoneProtocol implements AdbSubprocessProtocol {

    readonly #socket: AdbSocket;

    readonly #duplex: DuplexStreamFactory<Uint8Array, Uint8Array>;

    // Legacy shell forwards all data to stdin.
    get stdin(): WritableStream<Consumable<Uint8Array>> {
        return this.#socket.writable;
    }

    #stdout: ReadableStream<Uint8Array>;
    /**
     * Legacy shell mixes stdout and stderr.
     */
    get stdout(): ReadableStream<Uint8Array> {
        return this.#stdout;
        return this.#socket.readable;
    }

    #stderr: ReadableStream<Uint8Array>;
@@ -65,24 +63,21 @@ export class AdbSubprocessNoneProtocol implements AdbSubprocessProtocol {
    constructor(socket: AdbSocket) {
        this.#socket = socket;

        // Link `stdout`, `stderr` and `stdin` together,
        // so closing any of them will close the others.
        this.#duplex = new DuplexStreamFactory<Uint8Array, Uint8Array>({
            close: async () => {
                await this.#socket.close();
        this.#stderr = new ReadableStream({
            start: (controller) => {
                this.#socket.closed
                    .then(() => controller.close())
                    .catch(unreachable);
            },
        });

        this.#stdout = this.#duplex.wrapReadable(this.#socket.readable);
        this.#stderr = this.#duplex.wrapReadable(new ReadableStream());
        this.#exit = this.#duplex.closed.then(() => 0);
        this.#exit = socket.closed.then(() => 0);
    }

    resize() {
        // Not supported, but don't throw.
    }

    kill() {
        return this.#duplex.close();
    async kill() {
        await this.#socket.close();
    }
}
+7 −2
Original line number Diff line number Diff line
@@ -159,6 +159,7 @@ export class AdbSubprocessShellProtocol implements AdbSubprocessProtocol {

        let stdoutController!: PushReadableStreamController<Uint8Array>;
        let stderrController!: PushReadableStreamController<Uint8Array>;

        this.#stdout = new PushReadableStream<Uint8Array>((controller) => {
            stdoutController = controller;
        });
@@ -176,10 +177,14 @@ export class AdbSubprocessShellProtocol implements AdbSubprocessProtocol {
                                this.#exit.resolve(chunk.data[0]!);
                                break;
                            case AdbShellProtocolId.Stdout:
                                if (!stdoutController.abortSignal.aborted) {
                                    await stdoutController.enqueue(chunk.data);
                                }
                                break;
                            case AdbShellProtocolId.Stderr:
                                if (!stderrController.abortSignal.aborted) {
                                    await stderrController.enqueue(chunk.data);
                                }
                                break;
                        }
                    },
+52 −36
Original line number Diff line number Diff line
@@ -3,13 +3,11 @@
import { PromiseResolver } from "@yume-chan/async";
import type {
    AbortSignal,
    Consumable,
    ReadableWritablePair,
    WritableStreamDefaultWriter,
} from "@yume-chan/stream-extra";
import {
    BufferedReadableStream,
    DuplexStreamFactory,
    UnwrapConsumableStream,
    WrapWritableStream,
} from "@yume-chan/stream-extra";
@@ -25,7 +23,7 @@ import {
    encodeUtf8,
} from "@yume-chan/struct";

import type { AdbIncomingSocketHandler, AdbSocket } from "../adb.js";
import type { AdbIncomingSocketHandler, AdbSocket, Closeable } from "../adb.js";
import { AdbBanner } from "../banner.js";
import type { AdbFeature } from "../features.js";
import { NOOP, hexToNumber, numberToHex } from "../utils/index.js";
@@ -37,10 +35,16 @@ export interface AdbServerConnectionOptions {
    signal?: AbortSignal | undefined;
}

export interface AdbServerConnection {
export interface AdbServerConnection
    extends ReadableWritablePair<Uint8Array, Uint8Array>,
        Closeable {
    get closed(): Promise<void>;
}

export interface AdbServerConnector {
    connect(
        options?: AdbServerConnectionOptions,
    ): ValueOrPromise<ReadableWritablePair<Uint8Array, Uint8Array>>;
    ): ValueOrPromise<AdbServerConnection>;

    addReverseTunnel(
        handler: AdbIncomingSocketHandler,
@@ -74,9 +78,9 @@ export interface AdbServerDevice {
export class AdbServerClient {
    static readonly VERSION = 41;

    readonly connection: AdbServerConnection;
    readonly connection: AdbServerConnector;

    constructor(connection: AdbServerConnection) {
    constructor(connection: AdbServerConnector) {
        this.connection = connection;
    }

@@ -126,30 +130,41 @@ export class AdbServerClient {
    async connect(
        request: string,
        options?: AdbServerConnectionOptions,
    ): Promise<ReadableWritablePair<Uint8Array, Uint8Array>> {
    ): Promise<AdbServerConnection> {
        const connection = await this.connection.connect(options);

        try {
            const writer = connection.writable.getWriter();
            await AdbServerClient.writeString(writer, request);
            writer.releaseLock();
        } catch (e) {
            await connection.readable.cancel();
            await connection.close();
            throw e;
        }

        const readable = new BufferedReadableStream(connection.readable);

        try {
            // `raceSignal` throws if the signal is aborted,
            // `raceSignal` throws when the signal is aborted,
            // so the `catch` block can close the connection.
            await raceSignal(
                () => AdbServerClient.readOkay(readable),
                options?.signal,
            );

            writer.releaseLock();
            return {
                readable: readable.release(),
                writable: connection.writable,
                get closed() {
                    return connection.closed;
                },
                async close() {
                    await connection.close();
                },
            };
        } catch (e) {
            writer.close().catch(NOOP);
            readable.cancel().catch(NOOP);
            await readable.cancel().catch(NOOP);
            await connection.close();
            throw e;
        }
    }
@@ -328,8 +343,18 @@ export class AdbServerClient {
        }

        const connection = await this.connect(switchService);
        const readable = new BufferedReadableStream(connection.readable);

        try {
            const writer = connection.writable.getWriter();
            await AdbServerClient.writeString(writer, service);
            writer.releaseLock();
        } catch (e) {
            await connection.readable.cancel();
            await connection.close();
            throw e;
        }

        const readable = new BufferedReadableStream(connection.readable);
        try {
            if (transportId === undefined) {
                const array = await readable.readExactly(8);
@@ -342,34 +367,25 @@ export class AdbServerClient {
                transportId = BigIntFieldType.Uint64.getter(dataView, 0, true);
            }

            await AdbServerClient.writeString(writer, service);
            await AdbServerClient.readOkay(readable);

            writer.releaseLock();

            const duplex = new DuplexStreamFactory<
                Uint8Array,
                Consumable<Uint8Array>
            >();
            const wrapReadable = duplex.wrapReadable(readable.release());
            const wrapWritable = duplex.createWritable(
                new WrapWritableStream(connection.writable).bePipedThroughFrom(
                    new UnwrapConsumableStream(),
                ),
            );

            return {
                transportId,
                service,
                readable: wrapReadable,
                writable: wrapWritable,
                close() {
                    return duplex.close();
                readable: readable.release(),
                writable: new WrapWritableStream(
                    connection.writable,
                ).bePipedThroughFrom(new UnwrapConsumableStream()),
                get closed() {
                    return connection.closed;
                },
                async close() {
                    await connection.close();
                },
            };
        } catch (e) {
            writer.close().catch(NOOP);
            readable.cancel().catch(NOOP);
            await readable.cancel().catch(NOOP);
            await connection.close();
            throw e;
        }
    }
Loading