From 2c863cc4aa0da46b319ad56042e238648d5653fd Mon Sep 17 00:00:00 2001 From: tomvanwoow Date: Wed, 2 Oct 2024 22:46:04 +0100 Subject: [PATCH] Supports supplying custom sec-websocket-protocol in EdenWS subscribe --- src/treaty/index.ts | 9 +++++---- src/treaty/types.ts | 4 ++-- src/treaty2/index.ts | 2 +- src/treaty2/types.ts | 4 ++-- src/treaty2/ws.ts | 4 ++-- test/treaty.test.ts | 20 ++++++++++++++++++++ test/treaty2.test.ts | 21 +++++++++++++++++++++ 7 files changed, 53 insertions(+), 11 deletions(-) diff --git a/src/treaty/index.ts b/src/treaty/index.ts index 0285a1a..d9e8932 100644 --- a/src/treaty/index.ts +++ b/src/treaty/index.ts @@ -60,8 +60,8 @@ export class EdenWS = InputSchema> { ws: WebSocket url: string - constructor(url: string) { - this.ws = new WebSocket(url) + constructor(url: string, protocols?: string[]) { + this.ws = new WebSocket(url, protocols) this.url = url } @@ -172,7 +172,7 @@ const createProxy = ( transform?: EdenTreaty.Transform headers?: Record query?: Record - } + }, ] = [{}, {}] ) { let bodyObj: any = @@ -223,7 +223,8 @@ const createProxy = ( url.replace( /^([^]+):\/\//, url.startsWith('https://') ? 'wss://' : 'ws://' - ) + ), + options as unknown as string[], ) const execute = async ( diff --git a/src/treaty/types.ts b/src/treaty/types.ts index 7de89d6..7d83f49 100644 --- a/src/treaty/types.ts +++ b/src/treaty/types.ts @@ -40,8 +40,8 @@ export namespace EdenTreaty { undefined extends Route['query'] ? (params?: { $query?: Record - }) => EdenWS - : (params: { $query: Route['query'] }) => EdenWS + }, protocols?: string[]) => EdenWS + : (params: { $query: Route['query'] }, protocols?: string[]) => EdenWS : // ? HTTP route (( params: Prettify< diff --git a/src/treaty2/index.ts b/src/treaty2/index.ts index ddb0768..a582182 100644 --- a/src/treaty2/index.ts +++ b/src/treaty2/index.ts @@ -223,7 +223,7 @@ const createProxy = ( path + q - return new EdenWS(url) + return new EdenWS(url, options) } return (async () => { diff --git a/src/treaty2/types.ts b/src/treaty2/types.ts index 07b96f3..484c7a9 100644 --- a/src/treaty2/types.ts +++ b/src/treaty2/types.ts @@ -76,8 +76,8 @@ export namespace Treaty { query: Route['subscribe']['query'] }) extends infer Param ? {} extends Param - ? (options?: Param) => EdenWS - : (options?: Param) => EdenWS + ? (options?: Param, protocols?: string[]) => EdenWS + : (options?: Param, protocols?: string[]) => EdenWS : never : Route[K] extends { body: infer Body diff --git a/src/treaty2/ws.ts b/src/treaty2/ws.ts index 75da87a..094e623 100644 --- a/src/treaty2/ws.ts +++ b/src/treaty2/ws.ts @@ -5,8 +5,8 @@ import { parseMessageEvent } from '../utils/parsingUtils' export class EdenWS = {}> { ws: WebSocket - constructor(public url: string) { - this.ws = new WebSocket(url) + constructor(public url: string, protocols?: string[]) { + this.ws = new WebSocket(url, protocols) } send(data: Schema['body'] | Schema['body'][]) { diff --git a/test/treaty.test.ts b/test/treaty.test.ts index f30d0d0..9734210 100644 --- a/test/treaty.test.ts +++ b/test/treaty.test.ts @@ -51,6 +51,15 @@ const app = new Elysia() .post('/string', ({ body }) => body, { body: t.String() }) + .ws('/custom-protocol', { + headers: t.Object({ + 'sec-websocket-protocol': t.Literal('customprotocol'), + }), + open: async (ws) => { + ws.send('success'); + ws.close(); + }, + }) .listen(8082) const client = edenTreaty('http://localhost:8082') @@ -270,4 +279,15 @@ describe('Eden Treaty', () => { expect(data).toEqual({ body: null }) }) + + + it('sends the correct custom protocols', async (done) => { + await new Promise(res => { + const socket = client['custom-protocol'].subscribe(undefined, ['customprotocol']); + + socket.subscribe(() => res()); + }) + + done() + }); }) diff --git a/test/treaty2.test.ts b/test/treaty2.test.ts index 0b0de43..5c503e2 100644 --- a/test/treaty2.test.ts +++ b/test/treaty2.test.ts @@ -152,6 +152,15 @@ const app = new Elysia() ws.close() } }) + .ws('/custom-protocol', { + headers: t.Object({ + 'sec-websocket-protocol': t.Literal('customprotocol'), + }), + open: async (ws) => { + ws.send('success'); + ws.close(); + }, + }) .get('/stream', function* stream() { yield 'a' yield 'b' @@ -640,4 +649,16 @@ describe('Treaty2 - Using endpoint URL', () => { done() }) }) + + it('sends the correct custom protocols', async (done) => { + const client = treaty('http://localhost:8080') + + await new Promise(res => { + const socket = client['custom-protocol'].subscribe(undefined, ['customprotocol']); + + socket.subscribe(() => res()); + }) + + done() + }); })