diff --git a/Tests/HummingbirdWebSocketTests/WebSocketTests.swift b/Tests/HummingbirdWebSocketTests/WebSocketTests.swift index 4f436f6..6de603d 100644 --- a/Tests/HummingbirdWebSocketTests/WebSocketTests.swift +++ b/Tests/HummingbirdWebSocketTests/WebSocketTests.swift @@ -447,6 +447,43 @@ final class HummingbirdWebSocketTests: XCTestCase { } catch let error as WebSocketClientError where error == .webSocketUpgradeFailed {} } + /// Test context from router is passed through to web socket + func testRouterContextUpdate() async throws { + struct MyRequestContext: WebSocketRequestContext { + var coreContext: CoreRequestContext + var webSocket: WebSocketRouterContext + var name: String + + init(channel: Channel, logger: Logger) { + self.coreContext = .init(allocator: channel.allocator, logger: logger) + self.webSocket = .init() + self.name = "" + } + } + struct MyMiddleware: RouterMiddleware { + func handle(_ request: Request, context: MyRequestContext, next: (Request, MyRequestContext) async throws -> Response) async throws -> Response { + var context = context + context.name = "Roger Moore" + return try await next(request, context) + } + } + let router = Router(context: MyRequestContext.self) + router.middlewares.add(MyMiddleware()) + router.ws("/ws") { _, _ in + return .upgrade([:]) + } handle: { _, outbound, context in + try await outbound.write(.text(context.name)) + } + do { + try await self.testClientAndServerWithRouter(webSocketRouter: router, uri: "localhost:8080") { port, logger in + try WebSocketClient(url: .init("ws://localhost:\(port)/ws"), logger: logger) { inbound, _, _ in + let text = await inbound.first { _ in true } + XCTAssertEqual(text, .text("Roger Moore")) + } + } + } catch let error as WebSocketClientError where error == .webSocketUpgradeFailed {} + } + func testHTTPRequest() async throws { let router = Router(context: BasicWebSocketRequestContext.self) router.ws("/ws") { _, _ in