From f078dd1508a5449696c9a2c1e8aeb96e4412c10a Mon Sep 17 00:00:00 2001 From: Marcos Candeia Date: Mon, 30 Sep 2024 09:23:06 -0300 Subject: [PATCH] Add support for strongly typed websockets Signed-off-by: Marcos Candeia --- src/actors/channels/channel.ts | 211 +++++++++++++++++++++++++++++ src/actors/channels/serializers.ts | 14 ++ src/actors/proxy.ts | 102 ++++++++++++-- src/actors/runtime.test.ts | 33 ++++- src/actors/runtime.ts | 17 ++- 5 files changed, 365 insertions(+), 12 deletions(-) create mode 100644 src/actors/channels/channel.ts create mode 100644 src/actors/channels/serializers.ts diff --git a/src/actors/channels/channel.ts b/src/actors/channels/channel.ts new file mode 100644 index 0000000..8263084 --- /dev/null +++ b/src/actors/channels/channel.ts @@ -0,0 +1,211 @@ +import { Queue } from "@core/asyncutil/queue"; +import { jsonSerializer } from "./serializers.ts"; + +export interface Channel { + closed: Promise; + signal: AbortSignal; + close(): void; + send(value: T): Promise; + recv(signal?: AbortSignal): AsyncIterableIterator; +} + +/** + * Checks if a value is a channel. + * + * @param v - The value to check. + * + * @returns True if the value is a channel, false otherwise. + */ +export const isChannel = < + T, + TChannel extends Channel = Channel, +>(v: TChannel | unknown): v is TChannel => { + return typeof (v as TChannel).recv === "function" && + typeof (v as TChannel).send === "function"; +}; + +/** + * Checks if a value is a channel upgrader. + * + * @param v - The value to check. + * + * @returns True if the value is a channel upgrader, false otherwise. + */ +export const isUpgrade = ( + v: ChannelUpgrader | unknown, +): v is ChannelUpgrader => { + return typeof (v as ChannelUpgrader) === "function"; +}; + +/** + * Links multiple abort signals together such that when any of them + * are aborted, the returned signal is also aborted. + * + * @param signals - The abort signals to link together. + * + * @returns The linked abort signal. + */ +export const link = (...signals: AbortSignal[]): AbortSignal => { + const ctrl = new AbortController(); + for (const signal of signals) { + signal.addEventListener("abort", (evt) => { + if (!ctrl.signal.aborted) { + ctrl.abort(evt); + } + }); + } + return ctrl.signal; +}; + +export class ClosedChannelError extends Error { + constructor() { + super("Channel is closed"); + } +} +export const ifClosedChannel = + (cb: () => Promise | void) => (err: unknown) => { + if (err instanceof ClosedChannelError) { + return cb(); + } + throw err; + }; + +export const ignoreIfClosed = ifClosedChannel(() => {}); +export const makeChan = (capacity = 0): Channel => { + let currentCapacity = capacity; + const queue: Queue<{ value: T; resolve: () => void }> = new Queue(); + const ctrl = new AbortController(); + const abortPromise = Promise.withResolvers(); + ctrl.signal.onabort = () => { + abortPromise.resolve(); + }; + + const send = (value: T): Promise => { + return new Promise((resolve, reject) => { + if (ctrl.signal.aborted) reject(new ClosedChannelError()); + let mResolve = resolve; + if (currentCapacity > 0) { + currentCapacity--; + mResolve = () => { + currentCapacity++; + }; + resolve(); + } + queue.push({ value, resolve: mResolve }); + }); + }; + + const close = () => { + ctrl.abort(); + }; + + const recv = async function* ( + signal?: AbortSignal, + ): AsyncIterableIterator { + const linked = signal ? link(ctrl.signal, signal) : ctrl.signal; + while (true) { + if (linked.aborted) { + return; + } + try { + const next = await queue.pop({ signal: linked }); + next.resolve(); + yield next.value; + } catch (_err) { + if (linked.aborted) { + return; + } + throw _err; + } + } + }; + + return { + send, + recv, + close, + signal: ctrl.signal, + closed: abortPromise.promise, + }; +}; + +export interface DuplexChannel { + send: Channel["send"]; + recv: Channel["recv"]; + close: () => void | Promise; +} + +export type ChannelUpgrader = ( + ch: DuplexChannel, +) => Promise; + +// deno-lint-ignore no-explicit-any +export type Message = TMessageProperties & { + chunk?: Uint8Array; +}; + +export interface MessageSerializer< + TSend, + TReceive, + TRawFormat extends string | ArrayBufferLike | ArrayBufferView | Blob, +> { + binaryType?: BinaryType; + serialize: ( + msg: Message, + ) => TRawFormat; + deserialize: (str: TRawFormat) => Message; +} + +export const makeWebSocket = < + TSend, + TReceive, + TMessageFormat extends string | ArrayBufferLike | ArrayBufferView | Blob = + | string + | ArrayBufferLike + | ArrayBufferView + | Blob, +>( + socket: WebSocket, + _serializer?: MessageSerializer, +): Promise, Message>> => { + const serializer = _serializer ?? + jsonSerializer, Message>(); + const sendChan = makeChan>(); + const recvChan = makeChan>(); + const ch = Promise.withResolvers< + DuplexChannel, Message> + >(); + socket.binaryType = serializer.binaryType ?? "blob"; + socket.onclose = () => { + sendChan.close(); + recvChan.close(); + }; + socket.onerror = (err) => { + socket.close(); + ch.reject(err); + }; + socket.onmessage = async (msg) => { + if (recvChan.signal.aborted) { + return; + } + await recvChan.send(serializer.deserialize(msg.data)); + }; + socket.onopen = async () => { + ch.resolve({ + recv: recvChan.recv.bind(recvChan), + send: sendChan.send.bind(recvChan), + close: () => socket.close(), + }); + for await (const message of sendChan.recv()) { + try { + socket.send( + serializer.serialize(message), + ); + } catch (_err) { + console.error("error sending message through socket", message); + } + } + socket.close(); + }; + return ch.promise; +}; diff --git a/src/actors/channels/serializers.ts b/src/actors/channels/serializers.ts new file mode 100644 index 0000000..1aad38f --- /dev/null +++ b/src/actors/channels/serializers.ts @@ -0,0 +1,14 @@ +import type { MessageSerializer } from "./channel.ts"; + +export const jsonSerializer = (): MessageSerializer< + TSend, + TReceive, + string +> => { + return { + deserialize: (msg) => { + return JSON.parse(msg); + }, + serialize: JSON.stringify, + }; +}; diff --git a/src/actors/proxy.ts b/src/actors/proxy.ts index bdcd151..ec6bc53 100644 --- a/src/actors/proxy.ts +++ b/src/actors/proxy.ts @@ -1,7 +1,75 @@ +import { + type ChannelUpgrader, + type DuplexChannel, + makeWebSocket, +} from "./channels/channel.ts"; import type { Actor, ActorConstructor } from "./runtime.ts"; import { EVENT_STREAM_RESPONSE_HEADER, readFromStream } from "./stream.ts"; export const ACTOR_ID_HEADER_NAME = "x-deno-isolate-instance-id"; +export const ACTOR_ID_QS_NAME = "x-deno-isolate-instance-id"; +/** + * Promise.prototype.then onfufilled callback type. + */ +export type Fulfilled = ((result: R) => T | PromiseLike) | null; + +/** + * Promise.then onrejected callback type. + */ +// deno-lint-ignore no-explicit-any +export type Rejected = ((reason: any) => E | PromiseLike) | null; + +export class ActorAwaiter< + TResponse, + TChannel extends DuplexChannel, +> implements + PromiseLike< + TResponse + >, + DuplexChannel { + ch: Promise | null = null; + constructor( + protected fetcher: () => Promise< + TResponse + >, + protected ws: () => Promise, + ) { + } + async close() { + const ch = await this.channel; + await ch.close(); + } + + async *recv(signal?: AbortSignal): AsyncIterableIterator { + const ch = await this.channel; + yield* ch.recv(signal); + } + + private get channel(): Promise { + return this.ch ??= this.ws(); + } + + async send(value: unknown): Promise { + const ch = await this.channel; + await ch.send(value); + } + + catch(onrejected: Rejected): Promise { + return this.fetcher().catch(onrejected); + } + + then( + onfufilled?: Fulfilled< + TResponse, + TResult1 + >, + onrejected?: Rejected, + ): Promise { + return this.fetcher().then(onfufilled).catch( + onrejected, + ); + } +} /** * options to create a new actor proxy. @@ -11,11 +79,15 @@ export interface ProxyOptions { server: string; } +type PromisifyKey = Actor[key] extends + (...args: infer Args) => Awaited + ? Return extends ChannelUpgrader + ? (...args: Args) => DuplexChannel + : (...args: Args) => Promise + : Actor[key]; + type Promisify = { - [key in keyof Actor]: Actor[key] extends (...args: infer Args) => infer Return - ? Return extends Promise ? Actor[key] - : (...args: Args) => Promise - : Actor[key]; + [key in keyof Actor]: PromisifyKey; }; export interface ActorsServer { @@ -66,12 +138,13 @@ export const actors = { id: (id: string): Promisify => { return new Proxy>({} as Promisify, { get: (_, prop) => { - return async (...args: unknown[]) => { + const endpoint = `${actorsServer.url}/actors/${ + typeof actor === "string" ? actor : actor.name + }/invoke/${String(prop)}`; + const fetcher = async (...args: unknown[]) => { const abortCtrl = new AbortController(); const resp = await fetch( - `${actorsServer.url}/actors/${ - typeof actor === "string" ? actor : actor.name - }/invoke/${String(prop)}`, + endpoint, { method: "POST", signal: abortCtrl.signal, @@ -104,6 +177,19 @@ export const actors = { } return resp.json(); }; + return (...args: unknown[]) => { + const awaiter = new ActorAwaiter(() => fetcher(...args), () => { + const ws = new WebSocket( + `${endpoint}?args=${ + encodeURIComponent( + btoa(JSON.stringify({ args: args ?? [] })), + ) + }&${ACTOR_ID_QS_NAME}=${id}`, + ); + return makeWebSocket(ws); + }); + return awaiter; + }; }, }); }, diff --git a/src/actors/runtime.test.ts b/src/actors/runtime.test.ts index e9b7f1d..b74f4e7 100644 --- a/src/actors/runtime.test.ts +++ b/src/actors/runtime.test.ts @@ -1,4 +1,5 @@ -import { assertEquals } from "@std/assert"; +import { assertEquals, assertFalse } from "@std/assert"; +import type { ChannelUpgrader } from "./channels/channel.ts"; import { actors } from "./proxy.ts"; import { ActorRuntime } from "./runtime.ts"; import type { ActorState } from "./state.ts"; @@ -36,6 +37,17 @@ class Counter { watch(): AsyncIterableIterator { return this.watchTarget.subscribe(); } + + chan(name: string): ChannelUpgrader { + return (async ({ send, recv }) => { + await send(`Hello ${name}`); + for await (const str of recv()) { + if (str === "PING") { + await send("PONG"); + } + } + }); + } } const runServer = (rt: ActorRuntime): AsyncDisposable => { @@ -47,13 +59,30 @@ const runServer = (rt: ActorRuntime): AsyncDisposable => { }; }; -Deno.test("counter increment and getCount", async () => { +Deno.test("counter tests", async () => { const rt = new ActorRuntime([Counter]); await using _server = runServer(rt); const actorId = "1234"; const counterProxy = actors.proxy(Counter); const actor = counterProxy.id(actorId); + const name = `Counter Actor`; + const ch = actor.chan(name); + const it = ch.recv(); + const { value, done } = await it.next(); + + assertFalse(done); + assertEquals(value, `Hello ${name}`); + + await ch.send("PING"); + + const { value: pong, done: pongDone } = await it.next(); + + assertFalse(pongDone); + assertEquals(pong, "PONG"); + + await ch.close(); + const watcher = await actor.watch(); // Test increment const number = await actor.increment(); diff --git a/src/actors/runtime.ts b/src/actors/runtime.ts index 6392c11..c9690ef 100644 --- a/src/actors/runtime.ts +++ b/src/actors/runtime.ts @@ -1,5 +1,6 @@ import { type ServerSentEventMessage, ServerSentEventStream } from "@std/http"; -import { ACTOR_ID_HEADER_NAME } from "./proxy.ts"; +import { isUpgrade, makeWebSocket } from "./channels/channel.ts"; +import { ACTOR_ID_HEADER_NAME, ACTOR_ID_QS_NAME } from "./proxy.ts"; import { ActorState } from "./state.ts"; import { DenoKvActorStorage } from "./storage/denoKv.ts"; import { EVENT_STREAM_RESPONSE_HEADER } from "./stream.ts"; @@ -117,7 +118,8 @@ export class ActorRuntime { */ async fetch(req: Request): Promise { const url = new URL(req.url); - const actorId = req.headers.get(ACTOR_ID_HEADER_NAME); + const actorId = req.headers.get(ACTOR_ID_HEADER_NAME) ?? + url.searchParams.get(ACTOR_ID_QS_NAME); if (!actorId) { return new Response(`missing ${ACTOR_ID_HEADER_NAME} header`, { status: 400, @@ -147,6 +149,12 @@ export class ActorRuntime { if (req.headers.get("content-type")?.includes("application/json")) { const { args: margs } = await req.json(); args = margs; + } else if (url.searchParams.get("args")) { + const qargs = url.searchParams.get("args"); + const parsedArgs = qargs + ? JSON.parse(atob(decodeURIComponent(qargs))) + : {}; + args = parsedArgs.args; } const methodImpl = actor[method as keyof typeof actor]; if (!isInvocable(methodImpl)) { @@ -161,6 +169,11 @@ export class ActorRuntime { const res = await (methodImpl as Function).bind(actor)( ...Array.isArray(args) ? args : [args], ); + if (req.headers.get("upgrade") === "websocket" && isUpgrade(res)) { + const { socket, response } = Deno.upgradeWebSocket(req); + makeWebSocket(socket).then((ch) => res(ch)).finally(() => socket.close()); + return response; + } if (isEventStreamResponse(res)) { req.signal.onabort = () => { res?.return?.();