From 541324e46ff1947ccd6c928991e04116fcab5471 Mon Sep 17 00:00:00 2001 From: Adam Fowler Date: Fri, 29 Mar 2024 07:04:52 +0000 Subject: [PATCH] WebSocket extensions and `permessage-deflate` extension (#45) --- Package.swift | 13 +- Snippets/WebsocketTest.swift | 3 +- .../PerMessageDeflateExtension.swift | 197 +++--- .../Client/WebSocketClientChannel.swift | 18 +- .../Client/WebSocketClientConfiguration.swift | 6 +- .../NIOWebSocketServerUpgrade+ext.swift | 5 +- .../Server/WebSocketChannel.swift | 68 +- .../Server/WebSocketHTTPChannelBuilder.swift | 8 +- .../Server/WebSocketRouter.swift | 18 +- .../Server/WebSocketServerConfiguration.swift | 6 +- .../WebSocketDataHandler.swift | 2 +- .../WebSocketExtension.swift | 163 +++++ .../WebSocketHandler.swift | 36 +- .../WebSocketOutboundWriter.swift | 14 +- .../WebSocketExtensionTests.swift | 617 +++++++++--------- 15 files changed, 738 insertions(+), 436 deletions(-) create mode 100644 Sources/HummingbirdWebSocket/WebSocketExtension.swift diff --git a/Package.swift b/Package.swift index 911fe6e..9951868 100644 --- a/Package.swift +++ b/Package.swift @@ -8,7 +8,7 @@ let package = Package( platforms: [.macOS(.v14), .iOS(.v17), .tvOS(.v17)], products: [ .library(name: "HummingbirdWebSocket", targets: ["HummingbirdWebSocket"]), - // .library(name: "HummingbirdWSCompression", targets: ["HummingbirdWSCompression"]), + .library(name: "HummingbirdWSCompression", targets: ["HummingbirdWSCompression"]), ], dependencies: [ .package(url: "https://github.com/hummingbird-project/hummingbird.git", branch: "main"), @@ -18,7 +18,6 @@ let package = Package( .package(url: "https://github.com/apple/swift-nio.git", from: "2.62.0"), .package(url: "https://github.com/apple/swift-nio-extras.git", from: "1.22.0"), .package(url: "https://github.com/apple/swift-nio-ssl.git", from: "2.5.0"), - .package(url: "https://github.com/swift-extras/swift-extras-base64.git", from: "0.5.0"), .package(url: "https://github.com/adam-fowler/compress-nio.git", from: "1.0.0"), ], targets: [ @@ -32,13 +31,13 @@ let package = Package( .product(name: "NIOHTTPTypesHTTP1", package: "swift-nio-extras"), .product(name: "NIOWebSocket", package: "swift-nio"), ]), - /* .target(name: "HummingbirdWSCompression", dependencies: [ - .byName(name: "HummingbirdWSCore"), - .product(name: "CompressNIO", package: "compress-nio"), - ]),*/ + .target(name: "HummingbirdWSCompression", dependencies: [ + .byName(name: "HummingbirdWebSocket"), + .product(name: "CompressNIO", package: "compress-nio"), + ]), .testTarget(name: "HummingbirdWebSocketTests", dependencies: [ .byName(name: "HummingbirdWebSocket"), - // .byName(name: "HummingbirdWSCompression"), + .byName(name: "HummingbirdWSCompression"), .product(name: "Atomics", package: "swift-atomics"), .product(name: "Hummingbird", package: "hummingbird"), .product(name: "HummingbirdTesting", package: "hummingbird"), diff --git a/Snippets/WebsocketTest.swift b/Snippets/WebsocketTest.swift index 1c2ac3a..5097893 100644 --- a/Snippets/WebsocketTest.swift +++ b/Snippets/WebsocketTest.swift @@ -1,6 +1,7 @@ import HTTPTypes import Hummingbird import HummingbirdWebSocket +import HummingbirdWSCompression import Logging var logger = Logger(label: "Echo") @@ -22,7 +23,7 @@ router.ws("/ws") { inbound, outbound, _ in let app = Application( router: router, - server: .webSocketUpgrade(webSocketRouter: router), + server: .webSocketUpgrade(webSocketRouter: router, configuration: .init(extensions: [.perMessageDeflate()])), logger: logger ) try await app.runService() diff --git a/Sources/HummingbirdWSCompression/PerMessageDeflateExtension.swift b/Sources/HummingbirdWSCompression/PerMessageDeflateExtension.swift index 4e9efc3..420cb40 100644 --- a/Sources/HummingbirdWSCompression/PerMessageDeflateExtension.swift +++ b/Sources/HummingbirdWSCompression/PerMessageDeflateExtension.swift @@ -13,7 +13,7 @@ //===----------------------------------------------------------------------===// import CompressNIO -import HummingbirdWSCore +import HummingbirdWebSocket import NIOCore import NIOWebSocket @@ -27,6 +27,7 @@ struct PerMessageDeflateExtensionBuilder: WebSocketExtensionBuilder { let serverNoContextTakeover: Bool let compressionLevel: Int? let memoryLevel: Int? + let maxDecompressedFrameSize: Int init( clientMaxWindow: Int? = nil, @@ -34,7 +35,8 @@ struct PerMessageDeflateExtensionBuilder: WebSocketExtensionBuilder { serverMaxWindow: Int? = nil, serverNoContextTakeover: Bool = false, compressionLevel: Int? = nil, - memoryLevel: Int? = nil + memoryLevel: Int? = nil, + maxDecompressedFrameSize: Int = (1 << 14) ) { self.clientMaxWindow = clientMaxWindow self.clientNoContextTakeover = clientNoContextTakeover @@ -42,6 +44,7 @@ struct PerMessageDeflateExtensionBuilder: WebSocketExtensionBuilder { self.serverNoContextTakeover = serverNoContextTakeover self.compressionLevel = compressionLevel self.memoryLevel = memoryLevel + self.maxDecompressedFrameSize = maxDecompressedFrameSize } /// Return client request header @@ -86,17 +89,15 @@ struct PerMessageDeflateExtensionBuilder: WebSocketExtensionBuilder { /// Create server PerMessageDeflateExtension based off request headers /// - Parameters: /// - request: Client request - /// - eventLoop: EventLoop it is bound to - func serverExtension(from request: WebSocketExtensionHTTPParameters, eventLoop: EventLoop) throws -> (WebSocketExtension)? { + func serverExtension(from request: WebSocketExtensionHTTPParameters) throws -> (WebSocketExtension)? { let configuration = self.responseConfiguration(to: request) - return try PerMessageDeflateExtension(configuration: configuration, eventLoop: eventLoop) + return try PerMessageDeflateExtension(configuration: configuration) } /// Create client PerMessageDeflateExtension based off response headers /// - Parameters: /// - response: Server response - /// - eventLoop: EventLoop it is bound to - func clientExtension(from response: WebSocketExtensionHTTPParameters, eventLoop: EventLoop) throws -> WebSocketExtension? { + func clientExtension(from response: WebSocketExtensionHTTPParameters) throws -> WebSocketExtension? { let clientMaxWindowParam = response.parameters["client_max_window_bits"]?.integer let clientNoContextTakeoverParam = response.parameters["client_no_context_takeover"] != nil let serverMaxWindowParam = response.parameters["server_max_window_bits"]?.integer @@ -107,8 +108,9 @@ struct PerMessageDeflateExtensionBuilder: WebSocketExtensionBuilder { sendMaxWindow: clientMaxWindowParam, sendNoContextTakeover: clientNoContextTakeoverParam, compressionLevel: self.compressionLevel, - memoryLevel: self.memoryLevel - ), eventLoop: eventLoop) + memoryLevel: self.memoryLevel, + maxDecompressedFrameSize: self.maxDecompressedFrameSize + )) } private func responseConfiguration(to request: WebSocketExtensionHTTPParameters) -> PerMessageDeflateExtension.Configuration { @@ -134,7 +136,8 @@ struct PerMessageDeflateExtensionBuilder: WebSocketExtensionBuilder { sendMaxWindow: optionalMin(requestServerMaxWindow?.integer, self.serverMaxWindow), sendNoContextTakeover: requestServerNoContextTakeover || self.serverNoContextTakeover, compressionLevel: self.compressionLevel, - memoryLevel: self.memoryLevel + memoryLevel: self.memoryLevel, + maxDecompressedFrameSize: self.maxDecompressedFrameSize ) } } @@ -144,11 +147,6 @@ struct PerMessageDeflateExtensionBuilder: WebSocketExtensionBuilder { /// Uses deflate to compress messages sent across a WebSocket /// See RFC 7692 for more details https://www.rfc-editor.org/rfc/rfc7692 struct PerMessageDeflateExtension: WebSocketExtension { - enum SendState: Sendable { - case idle - case sendingMessage - } - struct Configuration: Sendable { let receiveMaxWindow: Int? let receiveNoContextTakeover: Bool @@ -156,90 +154,126 @@ struct PerMessageDeflateExtension: WebSocketExtension { let sendNoContextTakeover: Bool let compressionLevel: Int? let memoryLevel: Int? + let maxDecompressedFrameSize: Int } - /// Internal mutable state and referenced types, that cannot be set to Sendable - class InternalState { + actor Decompressor { fileprivate let decompressor: any NIODecompressor + + init(_ decompressor: any NIODecompressor) throws { + self.decompressor = decompressor + try self.decompressor.startStream() + } + + func decompress(_ frame: WebSocketFrame, maxSize: Int, resetStream: Bool, context: some WebSocketContextProtocol) throws -> WebSocketFrame { + var frame = frame + precondition(frame.fin, "Only concatenated frames with fin set can be processed by the permessage-deflate extension") + // Reinstate last four bytes 0x00 0x00 0xff 0xff that were removed in the frame + // send (see https://datatracker.ietf.org/doc/html/rfc7692#section-7.2.2). + frame.data.writeBytes([0, 0, 255, 255]) + frame.data = try frame.data.decompressStream(with: self.decompressor, maxSize: maxSize, allocator: context.allocator) + if resetStream { + try self.decompressor.resetStream() + } + return frame + } + + func shutdown() throws { + try self.decompressor.finishStream() + } + } + + actor Compressor { + enum SendState: Sendable { + case idle + case sendingMessage + } + fileprivate let compressor: any NIOCompressor - fileprivate var sendState: SendState + var sendState: SendState - init(configuration: Configuration) throws { - self.decompressor = CompressionAlgorithm.deflate( - configuration: .init( - windowSize: numericCast(configuration.receiveMaxWindow ?? 15) - ) - ).decompressor - // compression level -1 will setup the default compression level, 8 is the default memory level - self.compressor = CompressionAlgorithm.deflate( - configuration: .init( - windowSize: numericCast(configuration.sendMaxWindow ?? 15), - compressionLevel: configuration.compressionLevel.map { numericCast($0) } ?? -1, - memoryLevel: configuration.memoryLevel.map { numericCast($0) } ?? 8 - ) - ).compressor + init(_ compressor: any NIOCompressor) throws { + self.compressor = compressor self.sendState = .idle - try self.decompressor.startStream() try self.compressor.startStream() } - func shutdown() { - try? self.compressor.finishStream() - try? self.decompressor.finishStream() + func compress(_ frame: WebSocketFrame, resetStream: Bool, context: some WebSocketContextProtocol) throws -> WebSocketFrame { + // if the frame is larger than 16 bytes, we haven't received a final frame or we are in the process of sending a message + // compress the data + let shouldWeCompress = frame.data.readableBytes > 16 || !frame.fin || self.sendState != .idle + if shouldWeCompress { + var newFrame = frame + if self.sendState == .idle { + newFrame.rsv1 = true + self.sendState = .sendingMessage + } + newFrame.data = try newFrame.data.compressStream(with: self.compressor, flush: .sync, allocator: context.allocator) + // if final frame then remove last four bytes 0x00 0x00 0xff 0xff + // (see https://datatracker.ietf.org/doc/html/rfc7692#section-7.2.1) + if newFrame.fin { + newFrame.data = newFrame.data.getSlice(at: newFrame.data.readerIndex, length: newFrame.data.readableBytes - 4) ?? newFrame.data + self.sendState = .idle + if resetStream { + try self.compressor.resetStream() + } + } + return newFrame + } + return frame + } + + func shutdown() throws { + try self.compressor.finishStream() } } + let name = "permessage-deflate" let configuration: Configuration - let internalState: NIOLoopBound + let decompressor: Decompressor + let compressor: Compressor - init(configuration: Configuration, eventLoop: EventLoop) throws { + init(configuration: Configuration) throws { self.configuration = configuration - self.internalState = try .init(.init(configuration: configuration), eventLoop: eventLoop) + self.decompressor = try .init( + CompressionAlgorithm.deflate( + configuration: .init( + windowSize: numericCast(configuration.receiveMaxWindow ?? 15) + ) + ).decompressor + ) + self.compressor = try .init( + CompressionAlgorithm.deflate( + configuration: .init( + windowSize: numericCast(configuration.sendMaxWindow ?? 15), + compressionLevel: configuration.compressionLevel.map { numericCast($0) } ?? -1, + memoryLevel: configuration.memoryLevel.map { numericCast($0) } ?? 8 + ) + ).compressor + ) } - func shutdown() { - self.internalState.value.shutdown() + func shutdown() async { + try? await self.decompressor.shutdown() + try? await self.compressor.shutdown() } - func processReceivedFrame(_ frame: WebSocketFrame, ws: WebSocket) throws -> WebSocketFrame { - var frame = frame + func processReceivedFrame(_ frame: WebSocketFrame, context: some WebSocketContextProtocol) async throws -> WebSocketFrame { if frame.rsv1 { - let state = self.internalState.value - precondition(frame.fin, "Only concatenated frames with fin set can be processed by the permessage-deflate extension") - // Reinstate last four bytes 0x00 0x00 0xff 0xff that were removed in the frame - // send (see https://datatracker.ietf.org/doc/html/rfc7692#section-7.2.2). - frame.data.writeBytes([0, 0, 255, 255]) - frame.data = try frame.data.decompressStream(with: state.decompressor, maxSize: ws.maxFrameSize, allocator: ws.channel.allocator) - if self.configuration.receiveNoContextTakeover { - try state.decompressor.resetStream() - } + return try await self.decompressor.decompress( + frame, + maxSize: self.configuration.maxDecompressedFrameSize, + resetStream: self.configuration.receiveNoContextTakeover, + context: context + ) } return frame } - func processFrameToSend(_ frame: WebSocketFrame, ws: WebSocket) throws -> WebSocketFrame { - let state = self.internalState.value - // if the frame is larger than 16 bytes, we haven't received a final frame or we are in the process of sending a message - // compress the data - let shouldWeCompress = frame.data.readableBytes > 16 || !frame.fin || state.sendState != .idle + func processFrameToSend(_ frame: WebSocketFrame, context: some WebSocketContextProtocol) async throws -> WebSocketFrame { let isCorrectType = frame.opcode == .text || frame.opcode == .binary - if shouldWeCompress, isCorrectType { - var newFrame = frame - if state.sendState == .idle { - newFrame.rsv1 = true - state.sendState = .sendingMessage - } - newFrame.data = try newFrame.data.compressStream(with: state.compressor, flush: .sync, allocator: ws.channel.allocator) - // if final frame then remove last four bytes 0x00 0x00 0xff 0xff - // (see https://datatracker.ietf.org/doc/html/rfc7692#section-7.2.1) - if newFrame.fin { - newFrame.data = newFrame.data.getSlice(at: newFrame.data.readerIndex, length: newFrame.data.readableBytes - 4) ?? newFrame.data - state.sendState = .idle - if self.configuration.sendNoContextTakeover { - try state.compressor.resetStream() - } - } - return newFrame + if isCorrectType { + return try await self.compressor.compress(frame, resetStream: self.configuration.sendNoContextTakeover, context: context) } return frame } @@ -250,7 +284,11 @@ extension WebSocketExtensionFactory { /// - Parameters: /// - maxWindow: Max window to be used for decompression and compression /// - noContextTakeover: Should we reset window on every message - public static func perMessageDeflate(maxWindow: Int? = nil, noContextTakeover: Bool = false) -> WebSocketExtensionFactory { + public static func perMessageDeflate( + maxWindow: Int? = nil, + noContextTakeover: Bool = false, + maxDecompressedFrameSize: Int = 1 << 14 + ) -> WebSocketExtensionFactory { return .init { PerMessageDeflateExtensionBuilder( clientMaxWindow: maxWindow, @@ -258,7 +296,8 @@ extension WebSocketExtensionFactory { serverMaxWindow: maxWindow, serverNoContextTakeover: noContextTakeover, compressionLevel: nil, - memoryLevel: nil + memoryLevel: nil, + maxDecompressedFrameSize: maxDecompressedFrameSize ) } } @@ -279,7 +318,8 @@ extension WebSocketExtensionFactory { serverMaxWindow: Int? = nil, serverNoContextTakeover: Bool = false, compressionLevel: Int? = nil, - memoryLevel: Int? = nil + memoryLevel: Int? = nil, + maxDecompressedFrameSize: Int = 1 << 14 ) -> WebSocketExtensionFactory { return .init { PerMessageDeflateExtensionBuilder( @@ -288,7 +328,8 @@ extension WebSocketExtensionFactory { serverMaxWindow: serverMaxWindow, serverNoContextTakeover: serverNoContextTakeover, compressionLevel: compressionLevel, - memoryLevel: memoryLevel + memoryLevel: memoryLevel, + maxDecompressedFrameSize: maxDecompressedFrameSize ) } } diff --git a/Sources/HummingbirdWebSocket/Client/WebSocketClientChannel.swift b/Sources/HummingbirdWebSocket/Client/WebSocketClientChannel.swift index 20d9330..43d9220 100644 --- a/Sources/HummingbirdWebSocket/Client/WebSocketClientChannel.swift +++ b/Sources/HummingbirdWebSocket/Client/WebSocketClientChannel.swift @@ -22,7 +22,7 @@ import NIOWebSocket struct WebSocketClientChannel: ClientConnectionChannel { enum UpgradeResult { - case websocket(NIOAsyncChannel) + case websocket(NIOAsyncChannel, [any WebSocketExtension]) case notUpgraded } @@ -42,10 +42,16 @@ struct WebSocketClientChannel: ClientConnectionChannel { channel.eventLoop.makeCompletedFuture { let upgrader = NIOTypedWebSocketClientUpgrader( maxFrameSize: self.configuration.maxFrameSize, - upgradePipelineHandler: { channel, _ in + upgradePipelineHandler: { channel, head in channel.eventLoop.makeCompletedFuture { let asyncChannel = try NIOAsyncChannel(wrappingChannelSynchronously: channel) - return UpgradeResult.websocket(asyncChannel) + // work out what extensions we should add based off the server response + let headerFields = HTTPFields(head.headers, splitCookie: false) + let serverExtensions = WebSocketExtensionHTTPParameters.parseHeaders(headerFields) + let extensions = try configuration.extensions.compactMap { + try $0.clientExtension(from: serverExtensions) + } + return UpgradeResult.websocket(asyncChannel, extensions) } } ) @@ -55,6 +61,8 @@ struct WebSocketClientChannel: ClientConnectionChannel { headers.add(name: "Content-Length", value: "0") let additionalHeaders = HTTPHeaders(self.configuration.additionalHeaders) headers.add(contentsOf: additionalHeaders) + // add websocket extensions to headers + headers.add(contentsOf: self.configuration.extensions.map { (name: "Sec-WebSocket-Extensions", value: $0.clientRequestHeader()) }) let requestHead = HTTPRequestHead( version: .http1_1, @@ -83,8 +91,8 @@ struct WebSocketClientChannel: ClientConnectionChannel { func handle(value: Value, logger: Logger) async throws { switch try await value.get() { - case .websocket(let webSocketChannel): - let webSocket = WebSocketHandler(asyncChannel: webSocketChannel, type: .client) + case .websocket(let webSocketChannel, let extensions): + let webSocket = WebSocketHandler(asyncChannel: webSocketChannel, type: .client, extensions: extensions) await webSocket.handle(handler: self.handler, context: WebSocketContext(channel: webSocketChannel.channel, logger: logger)) case .notUpgraded: // The upgrade to websocket did not succeed. diff --git a/Sources/HummingbirdWebSocket/Client/WebSocketClientConfiguration.swift b/Sources/HummingbirdWebSocket/Client/WebSocketClientConfiguration.swift index d4085fe..d8049af 100644 --- a/Sources/HummingbirdWebSocket/Client/WebSocketClientConfiguration.swift +++ b/Sources/HummingbirdWebSocket/Client/WebSocketClientConfiguration.swift @@ -19,6 +19,8 @@ public struct WebSocketClientConfiguration: Sendable { public var maxFrameSize: Int /// Additional headers to be sent with the initial HTTP request public var additionalHeaders: HTTPFields + /// WebSocket extensions + public var extensions: [any WebSocketExtensionBuilder] /// Initialize WebSocketClient configuration /// - Paramters @@ -26,9 +28,11 @@ public struct WebSocketClientConfiguration: Sendable { /// - additionalHeaders: Additional headers to be sent with the initial HTTP request public init( maxFrameSize: Int = (1 << 14), - additionalHeaders: HTTPFields = .init() + additionalHeaders: HTTPFields = .init(), + extensions: [WebSocketExtensionFactory] = [] ) { self.maxFrameSize = maxFrameSize self.additionalHeaders = additionalHeaders + self.extensions = extensions.map { $0.build() } } } diff --git a/Sources/HummingbirdWebSocket/Server/NIOWebSocketServerUpgrade+ext.swift b/Sources/HummingbirdWebSocket/Server/NIOWebSocketServerUpgrade+ext.swift index 3ca339c..0553024 100644 --- a/Sources/HummingbirdWebSocket/Server/NIOWebSocketServerUpgrade+ext.swift +++ b/Sources/HummingbirdWebSocket/Server/NIOWebSocketServerUpgrade+ext.swift @@ -25,12 +25,13 @@ public enum ShouldUpgradeResult: Sendable { case upgrade(HTTPFields, Value) /// Map upgrade result to difference type - func map(_ map: (Value) throws -> Result) rethrows -> ShouldUpgradeResult { + func map(_ map: (HTTPFields, Value) throws -> (HTTPFields, Result)) rethrows -> ShouldUpgradeResult { switch self { case .dontUpgrade: return .dontUpgrade case .upgrade(let headers, let value): - return try .upgrade(headers, map(value)) + let result = try map(headers, value) + return .upgrade(result.0, result.1) } } } diff --git a/Sources/HummingbirdWebSocket/Server/WebSocketChannel.swift b/Sources/HummingbirdWebSocket/Server/WebSocketChannel.swift index 7759a25..79f3a3b 100644 --- a/Sources/HummingbirdWebSocket/Server/WebSocketChannel.swift +++ b/Sources/HummingbirdWebSocket/Server/WebSocketChannel.swift @@ -24,7 +24,7 @@ import NIOHTTPTypesHTTP1 import NIOWebSocket /// Child channel supporting a web socket upgrade from HTTP1 -public struct HTTP1AndWebSocketChannel: ServerChildChannel, HTTPChannelHandler { +public struct HTTP1WebSocketUpgradeChannel: ServerChildChannel, HTTPChannelHandler { public typealias WebSocketChannelHandler = @Sendable (NIOAsyncChannel, Logger) async -> Void /// Upgrade result (either a websocket AsyncChannel, or an HTTP1 AsyncChannel) public enum UpgradeResult { @@ -50,15 +50,21 @@ public struct HTTP1AndWebSocketChannel: ServerChildChannel, HTTPChannelHandler { ) { self.additionalChannelHandlers = additionalChannelHandlers self.configuration = configuration - self.shouldUpgrade = { head, channel, logger in + self.shouldUpgrade = { head, channel, logger -> EventLoopFuture> in channel.eventLoop.makeCompletedFuture { () -> ShouldUpgradeResult in try shouldUpgrade(head, channel, logger) - .map { handler in - return { asyncChannel, logger in - let webSocket = WebSocketHandler(asyncChannel: asyncChannel, type: .server) + .map { headers, handler -> (HTTPFields, WebSocketChannelHandler) in + let (headers, extensions) = try Self.webSocketExtensionNegotiation( + extensionBuilders: configuration.extensions, + requestHeaders: head.headerFields, + responseHeaders: headers, + logger: logger + ) + return (headers, { asyncChannel, logger in + let webSocket = WebSocketHandler(asyncChannel: asyncChannel, type: .server, extensions: extensions) let context = WebSocketContext(channel: channel, logger: logger) await webSocket.handle(handler: handler, context: context) - } + }) } } } @@ -80,16 +86,22 @@ public struct HTTP1AndWebSocketChannel: ServerChildChannel, HTTPChannelHandler { ) { self.additionalChannelHandlers = additionalChannelHandlers self.configuration = configuration - self.shouldUpgrade = { head, channel, logger in + self.shouldUpgrade = { head, channel, logger -> EventLoopFuture> in let promise = channel.eventLoop.makePromise(of: ShouldUpgradeResult.self) promise.completeWithTask { try await shouldUpgrade(head, channel, logger) - .map { handler in - return { asyncChannel, logger in - let webSocket = WebSocketHandler(asyncChannel: asyncChannel, type: .server) + .map { headers, handler in + let (headers, extensions) = try Self.webSocketExtensionNegotiation( + extensionBuilders: configuration.extensions, + requestHeaders: head.headerFields, + responseHeaders: headers, + logger: logger + ) + return (headers, { asyncChannel, logger in + let webSocket = WebSocketHandler(asyncChannel: asyncChannel, type: .server, extensions: extensions) let context = WebSocketContext(channel: channel, logger: logger) await webSocket.handle(handler: handler, context: context) - } + }) } } return promise.futureResult @@ -198,6 +210,40 @@ public struct HTTP1AndWebSocketChannel: ServerChildChannel, HTTPChannelHandler { } } + /// WebSocket extension negotiation + /// - Parameters: + /// - requestHeaders: Request headers + /// - headers: Response headers + /// - logger: Logger + /// - Returns: Response headers and extensions enabled + static func webSocketExtensionNegotiation( + extensionBuilders: [any WebSocketExtensionBuilder], + requestHeaders: HTTPFields, + responseHeaders: HTTPFields, + logger: Logger + ) throws -> (responseHeaders: HTTPFields, extensions: [any WebSocketExtension]) { + var responseHeaders = responseHeaders + let clientHeaders = WebSocketExtensionHTTPParameters.parseHeaders(requestHeaders) + if clientHeaders.count > 0 { + logger.trace( + "Extensions requested", + metadata: ["hb_extensions": .string(clientHeaders.map(\.name).joined(separator: ","))] + ) + } + let extensionResponseHeaders = extensionBuilders.compactMap { $0.serverResponseHeader(to: clientHeaders) } + responseHeaders.append(contentsOf: extensionResponseHeaders.map { .init(name: .secWebSocketExtensions, value: $0) }) + let extensions = try extensionBuilders.compactMap { + try $0.serverExtension(from: clientHeaders) + } + if extensions.count > 0 { + logger.debug( + "Enabled extensions", + metadata: ["hb_extensions": .string(extensions.map(\.name).joined(separator: ","))] + ) + } + return (responseHeaders, extensions) + } + public var responder: @Sendable (Request, Channel) async throws -> Response let shouldUpgrade: @Sendable (HTTPRequest, Channel, Logger) -> EventLoopFuture> let configuration: WebSocketServerConfiguration diff --git a/Sources/HummingbirdWebSocket/Server/WebSocketHTTPChannelBuilder.swift b/Sources/HummingbirdWebSocket/Server/WebSocketHTTPChannelBuilder.swift index fe55a7b..e89cb40 100644 --- a/Sources/HummingbirdWebSocket/Server/WebSocketHTTPChannelBuilder.swift +++ b/Sources/HummingbirdWebSocket/Server/WebSocketHTTPChannelBuilder.swift @@ -24,9 +24,9 @@ extension HTTPChannelBuilder { configuration: WebSocketServerConfiguration = .init(), additionalChannelHandlers: @autoclosure @escaping @Sendable () -> [any RemovableChannelHandler] = [], shouldUpgrade: @escaping @Sendable (HTTPRequest, Channel, Logger) async throws -> ShouldUpgradeResult> - ) -> HTTPChannelBuilder { + ) -> HTTPChannelBuilder { return .init { responder in - return HTTP1AndWebSocketChannel( + return HTTP1WebSocketUpgradeChannel( responder: responder, configuration: configuration, additionalChannelHandlers: additionalChannelHandlers, @@ -40,9 +40,9 @@ extension HTTPChannelBuilder { configuration: WebSocketServerConfiguration = .init(), additionalChannelHandlers: @autoclosure @escaping @Sendable () -> [any RemovableChannelHandler] = [], shouldUpgrade: @escaping @Sendable (HTTPRequest, Channel, Logger) throws -> ShouldUpgradeResult> - ) -> HTTPChannelBuilder { + ) -> HTTPChannelBuilder { return .init { responder in - return HTTP1AndWebSocketChannel( + return HTTP1WebSocketUpgradeChannel( responder: responder, configuration: configuration, additionalChannelHandlers: additionalChannelHandlers, diff --git a/Sources/HummingbirdWebSocket/Server/WebSocketRouter.swift b/Sources/HummingbirdWebSocket/Server/WebSocketRouter.swift index 54e1570..ede01bb 100644 --- a/Sources/HummingbirdWebSocket/Server/WebSocketRouter.swift +++ b/Sources/HummingbirdWebSocket/Server/WebSocketRouter.swift @@ -117,8 +117,8 @@ public struct WebSocketUpgradeMiddleware: Rout } } -extension HTTP1AndWebSocketChannel { - /// Initialize HTTP1AndWebSocketChannel with async `shouldUpgrade` function +extension HTTP1WebSocketUpgradeChannel { + /// Initialize HTTP1WebSocketUpgradeChannel with async `shouldUpgrade` function /// - Parameters: /// - additionalChannelHandlers: Additional channel handlers to add /// - responder: HTTP responder @@ -141,8 +141,14 @@ extension HTTP1AndWebSocketChannel { do { let response = try await webSocketResponder.respond(to: request, context: context) if response.status == .ok, let webSocketHandler = context.webSocket.handler.withLockedValue({ $0 }) { - return .upgrade(response.headers) { asyncChannel, _ in - let webSocket = WebSocketHandler(asyncChannel: asyncChannel, type: .server) + let (headers, extensions) = try Self.webSocketExtensionNegotiation( + extensionBuilders: configuration.extensions, + requestHeaders: head.headerFields, + responseHeaders: response.headers, + logger: logger + ) + return .upgrade(headers) { asyncChannel, _ in + let webSocket = WebSocketHandler(asyncChannel: asyncChannel, type: .server, extensions: extensions) await webSocket.handle(handler: webSocketHandler.handler, context: webSocketHandler.context) } } else { @@ -174,10 +180,10 @@ extension HTTPChannelBuilder { webSocketRouter: WSResponderBuilder, configuration: WebSocketServerConfiguration = .init(), additionalChannelHandlers: @autoclosure @escaping @Sendable () -> [any RemovableChannelHandler] = [] - ) -> HTTPChannelBuilder where WSResponderBuilder.Responder.Context: WebSocketRequestContext { + ) -> HTTPChannelBuilder where WSResponderBuilder.Responder.Context: WebSocketRequestContext { let webSocketReponder = webSocketRouter.buildResponder() return .init { responder in - return HTTP1AndWebSocketChannel( + return HTTP1WebSocketUpgradeChannel( responder: responder, webSocketResponder: webSocketReponder, configuration: configuration, diff --git a/Sources/HummingbirdWebSocket/Server/WebSocketServerConfiguration.swift b/Sources/HummingbirdWebSocket/Server/WebSocketServerConfiguration.swift index 8d7e471..c80d9d0 100644 --- a/Sources/HummingbirdWebSocket/Server/WebSocketServerConfiguration.swift +++ b/Sources/HummingbirdWebSocket/Server/WebSocketServerConfiguration.swift @@ -16,14 +16,18 @@ public struct WebSocketServerConfiguration: Sendable { /// Max websocket frame size that can be sent/received public var maxFrameSize: Int + /// WebSocket extensions + public var extensions: [any WebSocketExtensionBuilder] /// Initialize WebSocketClient configuration /// - Paramters /// - maxFrameSize: Max websocket frame size that can be sent/received /// - additionalHeaders: Additional headers to be sent with the initial HTTP request public init( - maxFrameSize: Int = (1 << 14) + maxFrameSize: Int = (1 << 14), + extensions: [WebSocketExtensionFactory] = [] ) { self.maxFrameSize = maxFrameSize + self.extensions = extensions.map { $0.build() } } } diff --git a/Sources/HummingbirdWebSocket/WebSocketDataHandler.swift b/Sources/HummingbirdWebSocket/WebSocketDataHandler.swift index 6bc9a7a..91064c3 100644 --- a/Sources/HummingbirdWebSocket/WebSocketDataHandler.swift +++ b/Sources/HummingbirdWebSocket/WebSocketDataHandler.swift @@ -19,4 +19,4 @@ import NIOCore import NIOWebSocket /// Function that handles websocket data and text blocks -public typealias WebSocketDataHandler = @Sendable (WebSocketInboundStream, WebSocketOutboundWriter, Context) async throws -> Void +public typealias WebSocketDataHandler = @Sendable (WebSocketInboundStream, WebSocketOutboundWriter, Context) async throws -> Void diff --git a/Sources/HummingbirdWebSocket/WebSocketExtension.swift b/Sources/HummingbirdWebSocket/WebSocketExtension.swift new file mode 100644 index 0000000..012f902 --- /dev/null +++ b/Sources/HummingbirdWebSocket/WebSocketExtension.swift @@ -0,0 +1,163 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Hummingbird server framework project +// +// Copyright (c) 2023 the Hummingbird authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See hummingbird/CONTRIBUTORS.txt for the list of Hummingbird authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import HTTPTypes +import NIOCore +import NIOWebSocket + +/// Protocol for WebSocket extension +public protocol WebSocketExtension: Sendable { + /// Extension name + var name: String { get } + /// Process frame received from websocket + func processReceivedFrame(_ frame: WebSocketFrame, context: some WebSocketContextProtocol) async throws -> WebSocketFrame + /// Process frame about to be sent to websocket + func processFrameToSend(_ frame: WebSocketFrame, context: some WebSocketContextProtocol) async throws -> WebSocketFrame + /// shutdown extension + func shutdown() async +} + +/// Protocol for WebSocket extension builder +public protocol WebSocketExtensionBuilder: Sendable { + /// name of WebSocket extension name + static var name: String { get } + /// construct client request header + func clientRequestHeader() -> String + /// construct server response header based of client request + func serverReponseHeader(to: WebSocketExtensionHTTPParameters) -> String? + /// construct server version of extension based of client request + func serverExtension(from: WebSocketExtensionHTTPParameters) throws -> (any WebSocketExtension)? + /// construct client version of extension based of server response + func clientExtension(from: WebSocketExtensionHTTPParameters) throws -> (any WebSocketExtension)? +} + +extension WebSocketExtensionBuilder { + /// construct server response header based of all client requests + public func serverResponseHeader(to requests: [WebSocketExtensionHTTPParameters]) -> String? { + for request in requests { + guard request.name == Self.name else { continue } + if let response = serverReponseHeader(to: request) { + return response + } + } + return nil + } + + /// construct all server extensions based of all client requests + public func serverExtension(from requests: [WebSocketExtensionHTTPParameters]) throws -> (any WebSocketExtension)? { + for request in requests { + guard request.name == Self.name else { continue } + if let ext = try serverExtension(from: request) { + return ext + } + } + return nil + } + + /// construct all client extensions based of all server responses + public func clientExtension(from requests: [WebSocketExtensionHTTPParameters]) throws -> (any WebSocketExtension)? { + for request in requests { + guard request.name == Self.name else { continue } + if let ext = try clientExtension(from: request) { + return ext + } + } + return nil + } +} + +/// Build WebSocket extension builder +public struct WebSocketExtensionFactory: Sendable { + public let build: @Sendable () -> any WebSocketExtensionBuilder + + public init(_ build: @escaping @Sendable () -> any WebSocketExtensionBuilder) { + self.build = build + } +} + +/// Parsed parameters from `Sec-WebSocket-Extensions` header +public struct WebSocketExtensionHTTPParameters: Sendable, Equatable { + /// A single parameter + public enum Parameter: Sendable, Equatable { + // Parameter with a value + case value(String) + // Parameter with no value + case null + + // Convert to optional + public var optional: String? { + switch self { + case .value(let string): + return .some(string) + case .null: + return .none + } + } + + // Convert to integer + public var integer: Int? { + switch self { + case .value(let string): + return Int(string) + case .null: + return .none + } + } + } + + public let parameters: [String: Parameter] + let name: String + + /// initialise WebSocket extension parameters from string + init?(from header: some StringProtocol) { + let split = header.split(separator: ";", omittingEmptySubsequences: true).map { $0.trimmingCharacters(in: .whitespacesAndNewlines) }[...] + if let name = split.first { + self.name = name + } else { + return nil + } + var index = split.index(after: split.startIndex) + var parameters: [String: Parameter] = [:] + while index != split.endIndex { + let keyValue = split[index].split(separator: "=", maxSplits: 1).map { $0.trimmingCharacters(in: .whitespacesAndNewlines) } + if let key = keyValue.first { + if keyValue.count > 1 { + parameters[key] = .value(keyValue[1]) + } else { + parameters[key] = .null + } + } + index = split.index(after: index) + } + self.parameters = parameters + } + + /// Parse all `Sec-WebSocket-Extensions` header values + /// - Parameters: + /// - headers: headers coming from other + /// - type: client or server + /// - Returns: Array of extensions + public static func parseHeaders(_ headers: HTTPFields) -> [WebSocketExtensionHTTPParameters] { + let extHeaders = headers[values: .secWebSocketExtensions] + return extHeaders.compactMap { .init(from: $0) } + } +} + +extension WebSocketExtensionHTTPParameters { + /// Initialiser used by tests + init(_ name: String, parameters: [String: Parameter]) { + self.name = name + self.parameters = parameters + } +} diff --git a/Sources/HummingbirdWebSocket/WebSocketHandler.swift b/Sources/HummingbirdWebSocket/WebSocketHandler.swift index f45b3d8..fdcbbc6 100644 --- a/Sources/HummingbirdWebSocket/WebSocketHandler.swift +++ b/Sources/HummingbirdWebSocket/WebSocketHandler.swift @@ -29,18 +29,24 @@ public enum WebSocketType: Sendable { /// Manages ping, pong and close messages. Collates data and text messages into final frame /// and passes them onto the ``WebSocketDataHandler`` data handler setup by the user. actor WebSocketHandler: Sendable { + enum InternalError: Error { + case close(WebSocketErrorCode) + } + static let pingDataSize = 16 let asyncChannel: NIOAsyncChannel let type: WebSocketType - var closed: Bool var pingData: ByteBuffer + var closed = false + let extensions: [any WebSocketExtension] - init(asyncChannel: NIOAsyncChannel, type: WebSocketType) { + init(asyncChannel: NIOAsyncChannel, type: WebSocketType, extensions: [any WebSocketExtension]) { self.asyncChannel = asyncChannel self.type = type self.pingData = ByteBufferAllocator().buffer(capacity: Self.pingDataSize) self.closed = false + self.extensions = extensions } /// Handle WebSocket AsynChannel @@ -51,7 +57,9 @@ actor WebSocketHandler: Sendable { let webSocketOutbound = WebSocketOutboundWriter( type: self.type, allocator: asyncChannel.channel.allocator, - outbound: outbound + outbound: outbound, + extensions: self.extensions, + context: context ) try await withTaskCancellationHandler { try await withGracefulShutdownHandler { @@ -94,8 +102,14 @@ actor WebSocketHandler: Sendable { break } if let frameSeq = frameSequence, frame.fin { - await webSocketInbound.send(frameSeq.data) - frameSequence = nil + var collatedFrame = frameSeq.collapsed + for ext in self.extensions.reversed() { + collatedFrame = try await ext.processReceivedFrame(collatedFrame, context: context) + } + if let finalFrame = WebSocketDataFrame(frame: collatedFrame) { + await webSocketInbound.send(finalFrame) + frameSequence = nil + } } } catch { // catch errors while processing websocket frames so responding close message @@ -110,6 +124,8 @@ actor WebSocketHandler: Sendable { // handle websocket data and text try await handler(webSocketInbound, webSocketOutbound, context) try await self.close(code: .normalClosure, outbound: webSocketOutbound, context: context) + } catch InternalError.close(let code) { + try await self.close(code: code, outbound: webSocketOutbound, context: context) } catch { if self.type == .server { let errorCode = WebSocketErrorCode.unexpectedServerError @@ -140,7 +156,7 @@ actor WebSocketHandler: Sendable { /// Respond to ping func onPing( _ frame: WebSocketFrame, - outbound: WebSocketOutboundWriter, + outbound: WebSocketOutboundWriter, context: some WebSocketContextProtocol ) async throws { if frame.fin { @@ -153,7 +169,7 @@ actor WebSocketHandler: Sendable { /// Respond to pong func onPong( _ frame: WebSocketFrame, - outbound: WebSocketOutboundWriter, + outbound: WebSocketOutboundWriter, context: some WebSocketContextProtocol ) async throws { guard !self.closed else { return } @@ -166,7 +182,7 @@ actor WebSocketHandler: Sendable { } /// Send ping - func ping(outbound: WebSocketOutboundWriter) async throws { + func ping(outbound: WebSocketOutboundWriter) async throws { guard !self.closed else { return } if self.pingData.readableBytes == 0 { // creating random payload @@ -177,7 +193,7 @@ actor WebSocketHandler: Sendable { } /// Send pong - func pong(data: ByteBuffer?, outbound: WebSocketOutboundWriter) async throws { + func pong(data: ByteBuffer?, outbound: WebSocketOutboundWriter) async throws { guard !self.closed else { return } try await outbound.write(frame: .init(fin: true, opcode: .pong, data: data ?? .init())) } @@ -185,7 +201,7 @@ actor WebSocketHandler: Sendable { /// Send close func close( code: WebSocketErrorCode = .normalClosure, - outbound: WebSocketOutboundWriter, + outbound: WebSocketOutboundWriter, context: some WebSocketContextProtocol ) async throws { guard !self.closed else { return } diff --git a/Sources/HummingbirdWebSocket/WebSocketOutboundWriter.swift b/Sources/HummingbirdWebSocket/WebSocketOutboundWriter.swift index 4d334ee..dbc19d7 100644 --- a/Sources/HummingbirdWebSocket/WebSocketOutboundWriter.swift +++ b/Sources/HummingbirdWebSocket/WebSocketOutboundWriter.swift @@ -16,7 +16,7 @@ import NIOCore import NIOWebSocket /// Outbound websocket writer -public struct WebSocketOutboundWriter: Sendable { +public struct WebSocketOutboundWriter: Sendable { /// WebSocket frame that can be written public enum OutboundFrame: Sendable { /// Text frame @@ -32,6 +32,8 @@ public struct WebSocketOutboundWriter: Sendable { let type: WebSocketType let allocator: ByteBufferAllocator let outbound: NIOAsyncChannelOutboundWriter + let extensions: [any WebSocketExtension] + let context: Context /// Write WebSocket frame public func write(_ frame: OutboundFrame) async throws { @@ -58,8 +60,18 @@ public struct WebSocketOutboundWriter: Sendable { frame: WebSocketFrame ) async throws { var frame = frame + do { + for ext in self.extensions { + frame = try await ext.processFrameToSend(frame, context: self.context) + } + } catch { + self.context.logger.debug("Closing as we failed to generate valid frame data") + throw WebSocketHandler.InternalError.close(.unexpectedServerError) + } frame.maskKey = self.makeMaskKey() try await self.outbound.write(frame) + + self.context.logger.trace("Sent \(frame.opcode)") } func finish() { diff --git a/Tests/HummingbirdWebSocketTests/WebSocketExtensionTests.swift b/Tests/HummingbirdWebSocketTests/WebSocketExtensionTests.swift index 057e67a..867f359 100644 --- a/Tests/HummingbirdWebSocketTests/WebSocketExtensionTests.swift +++ b/Tests/HummingbirdWebSocketTests/WebSocketExtensionTests.swift @@ -11,312 +11,313 @@ // SPDX-License-Identifier: Apache-2.0 // //===----------------------------------------------------------------------===// -/* - import Hummingbird - import HummingbirdWebSocket - import HummingbirdWSClient - @testable import HummingbirdWSCompression - @testable import HummingbirdWSCore - import NIOCore - import NIOPosix - import NIOWebSocket - import XCTest - final class HummingbirdWebSocketExtensionTests: XCTestCase { - static var eventLoopGroup: EventLoopGroup! - - override class func setUp() { - self.eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) - } - - override class func tearDown() { - XCTAssertNoThrow(try self.eventLoopGroup.syncShutdownGracefully()) - } - - /// Create random buffer - /// - Parameters: - /// - size: size of buffer - /// - randomness: how random you want the buffer to be (percentage) - func createRandomBuffer(size: Int, randomness: Int = 100) -> ByteBuffer { - var buffer = ByteBufferAllocator().buffer(capacity: size) - let randomness = (randomness * randomness) / 100 - for i in 0.. Void, - onClient: @escaping (WebSocket) async throws -> Void - ) async throws -> HBApplication { - let app = HBApplication(configuration: .init(address: .hostname(port: 0))) - app.logger.logLevel = .trace - // add HTTP to WebSocket upgrade - app.ws.addUpgrade(maxFrameSize: 1 << 14, extensions: serverExtensions) - // on websocket connect. - app.ws.on("/test", onUpgrade: { _, ws in - try await onServer(ws) - return .ok - }) - try app.start() - - let eventLoop = app.eventLoopGroup.next() - let ws = try await WebSocketClient.connect( - url: HBURL("ws://localhost:\(app.server.port!)/test"), - configuration: .init(extensions: clientExtensions), - on: eventLoop - ) - try await onClient(ws) - return app - } - - func testExtensionHeaderParsing() { - let headers: HTTPHeaders = ["Sec-WebSocket-Extensions": "permessage-deflate; client_max_window_bits; server_max_window_bits=10, permessage-deflate;client_max_window_bits"] - let extensions = WebSocketExtensionHTTPParameters.parseHeaders(headers) - XCTAssertEqual( - extensions, - [ - .init("permessage-deflate", parameters: ["client_max_window_bits": .null, "server_max_window_bits": .value("10")]), - .init("permessage-deflate", parameters: ["client_max_window_bits": .null]), - ] - ) - } - - func testDeflateServerResponse() { - let requestHeaders: [WebSocketExtensionHTTPParameters] = [ - .init("permessage-deflate", parameters: ["client_max_window_bits": .value("10")]), - ] - let ext = PerMessageDeflateExtensionBuilder(clientNoContextTakeover: true, serverNoContextTakeover: true) - let serverResponse = ext.serverResponseHeader(to: requestHeaders) - XCTAssertEqual( - serverResponse, - "permessage-deflate;client_max_window_bits=10;client_no_context_takeover;server_no_context_takeover" - ) - } - - func testDeflateServerResponseClientMaxWindowBits() { - let requestHeaders: [WebSocketExtensionHTTPParameters] = [ - .init("permessage-deflate", parameters: ["client_max_window_bits": .null]), - ] - let ext1 = PerMessageDeflateExtensionBuilder(serverNoContextTakeover: true) - let serverResponse1 = ext1.serverResponseHeader(to: requestHeaders) - XCTAssertEqual( - serverResponse1, - "permessage-deflate;server_no_context_takeover" - ) - let ext2 = PerMessageDeflateExtensionBuilder(clientNoContextTakeover: true, serverMaxWindow: 12) - let serverResponse2 = ext2.serverResponseHeader(to: requestHeaders) - XCTAssertEqual( - serverResponse2, - "permessage-deflate;client_no_context_takeover;server_max_window_bits=12" - ) - } - - func testUnregonisedExtensionServerResponse() { - let requestHeaders: [WebSocketExtensionHTTPParameters] = [ - .init("permessage-foo", parameters: ["bar": .value("baz")]), - .init("permessage-deflate", parameters: ["client_max_window_bits": .value("10")]), - ] - let ext = PerMessageDeflateExtensionBuilder() - let serverResponse = ext.serverResponseHeader(to: requestHeaders) - XCTAssertEqual( - serverResponse, - "permessage-deflate;client_max_window_bits=10" - ) - } - - func testPerMessageDeflate() async throws { - let promise = TimeoutPromise(eventLoop: Self.eventLoopGroup.next(), timeout: .seconds(10)) - - let app = try await self.setupClientAndServer( - serverExtensions: [.perMessageDeflate()], - clientExtensions: [.perMessageDeflate()], - onServer: { ws in - XCTAssertNotNil(ws.extensions.first as? PerMessageDeflateExtension) - let stream = ws.readStream() - Task { - var iterator = stream.makeAsyncIterator() - let firstMessage = await iterator.next() - XCTAssertEqual(firstMessage, .text("Hello, testing this is compressed")) - let secondMessage = await iterator.next() - XCTAssertEqual(secondMessage, .text("Hello")) - for await _ in stream {} - ws.onClose { _ in - promise.succeed() - } - } - }, - onClient: { ws in - XCTAssertNotNil(ws.extensions.first as? PerMessageDeflateExtension) - try await ws.write(.text("Hello, testing this is compressed")) - try await ws.write(.text("Hello")) - try await ws.close() - } - ) - defer { app.stop() } - - try promise.wait() - } - - func testPerMessageDeflateMaxWindow() async throws { - let promise = TimeoutPromise(eventLoop: Self.eventLoopGroup.next(), timeout: .seconds(10)) - - let buffer = self.createRandomBuffer(size: 4096, randomness: 10) - let app = try await self.setupClientAndServer( - serverExtensions: [.perMessageDeflate()], - clientExtensions: [.perMessageDeflate(maxWindow: 10)], - onServer: { ws in - XCTAssertEqual((ws.extensions.first as? PerMessageDeflateExtension)?.configuration.receiveMaxWindow, 10) - let stream = ws.readStream() - Task { - for try await data in stream { - XCTAssertEqual(data, .binary(buffer)) - } - ws.onClose { _ in - promise.succeed() - } - } - }, - onClient: { ws in - XCTAssertEqual((ws.extensions.first as? PerMessageDeflateExtension)?.configuration.sendMaxWindow, 10) - try await ws.write(.binary(buffer)) - try await ws.close() - } - ) - defer { app.stop() } - - try promise.wait() - } - - func testPerMessageDeflateNoContextTakeover() async throws { - let promise = TimeoutPromise(eventLoop: Self.eventLoopGroup.next(), timeout: .seconds(10)) - - let buffer = self.createRandomBuffer(size: 4096, randomness: 10) - let app = try await self.setupClientAndServer( - serverExtensions: [.perMessageDeflate()], - clientExtensions: [.perMessageDeflate(clientNoContextTakeover: true)], - onServer: { ws in - XCTAssertEqual((ws.extensions.first as? PerMessageDeflateExtension)?.configuration.receiveNoContextTakeover, true) - let stream = ws.readStream() - Task { - for try await data in stream { - XCTAssertEqual(data, .binary(buffer)) - } - ws.onClose { _ in - promise.succeed() - } - } - }, - onClient: { ws in - XCTAssertEqual((ws.extensions.first as? PerMessageDeflateExtension)?.configuration.sendNoContextTakeover, true) - try await ws.write(.binary(buffer)) - try await ws.close() - } - ) - defer { app.stop() } - - try promise.wait() - } - - func testPerMessageExtensionOrdering() async throws { - let promise = TimeoutPromise(eventLoop: Self.eventLoopGroup.next(), timeout: .seconds(10)) - - let buffer = self.createRandomBuffer(size: 4096, randomness: 10) - let app = try await self.setupClientAndServer( - serverExtensions: [.xor(), .perMessageDeflate()], - clientExtensions: [.xor(value: 34), .perMessageDeflate()], - onServer: { ws in - // XCTAssertEqual((ws.extensions.first as? PerMessageDeflateExtension)?.configuration.receiveNoContextTakeover, true) - let stream = ws.readStream() - Task { - for try await data in stream { - XCTAssertEqual(data, .binary(buffer)) - } - ws.onClose { _ in - promise.succeed() - } - } - }, - onClient: { ws in - // XCTAssertEqual((ws.extensions.first as? PerMessageDeflateExtension)?.configuration.sendNoContextTakeover, true) - try await ws.write(.binary(buffer)) - try await ws.close() - } - ) - defer { app.stop() } - - try promise.wait() - } - } - - struct XorWebSocketExtension: WebSocketExtension { - func shutdown() {} - - func xorFrame(_ frame: WebSocketFrame, ws: WebSocket) -> WebSocketFrame { - var newBuffer = ws.channel.allocator.buffer(capacity: frame.data.readableBytes) - for byte in frame.data.readableBytesView { - newBuffer.writeInteger(byte ^ self.value) - } - var frame = frame - frame.data = newBuffer - return frame - } - - func processReceivedFrame(_ frame: WebSocketFrame, ws: WebSocket) -> WebSocketFrame { - return self.xorFrame(frame, ws: ws) - } - - func processFrameToSend(_ frame: WebSocketFrame, ws: WebSocket) throws -> WebSocketFrame { - return self.xorFrame(frame, ws: ws) - } - - let value: UInt8 - } - - struct XorWebSocketExtensionBuilder: WebSocketExtensionBuilder { - static var name = "permessage-xor" - let value: UInt8? - - init(value: UInt8? = nil) { - self.value = value - } - - func clientRequestHeader() -> String { - var header = Self.name - if let value = value { - header += ";value=\(value)" - } - return header - } - - func serverReponseHeader(to request: WebSocketExtensionHTTPParameters) -> String? { - var header = Self.name - if let value = request.parameters["value"]?.integer { - header += ";value=\(value)" - } - return header - } - - func serverExtension(from request: WebSocketExtensionHTTPParameters, eventLoop: EventLoop) throws -> (WebSocketExtension)? { - XorWebSocketExtension(value: UInt8(request.parameters["value"]?.integer ?? 255)) - } - - func clientExtension(from request: WebSocketExtensionHTTPParameters, eventLoop: EventLoop) throws -> (WebSocketExtension)? { - XorWebSocketExtension(value: UInt8(request.parameters["value"]?.integer ?? 255)) - } - } - - extension WebSocketExtensionFactory { - static func xor(value: UInt8? = nil) -> WebSocketExtensionFactory { - .init { XorWebSocketExtensionBuilder(value: value) } - } - } - */ +import Hummingbird +import HummingbirdCore +@testable import HummingbirdWebSocket +@testable import HummingbirdWSCompression +import Logging +import NIOCore +import NIOWebSocket +import ServiceLifecycle +import XCTest + +final class HummingbirdWebSocketExtensionTests: XCTestCase { + func testClientAndServer( + serverChannel: HTTPChannelBuilder, + clientExtensions: [WebSocketExtensionFactory] = [], + client clientHandler: @escaping WebSocketDataHandler + ) async throws { + try await withThrowingTaskGroup(of: Void.self) { group in + let promise = Promise() + let serverLogger = { + var logger = Logger(label: "WebSocketServer") + logger.logLevel = .trace + return logger + }() + let clientLogger = { + var logger = Logger(label: "WebSocketClient") + logger.logLevel = .trace + return logger + }() + let serviceGroup: ServiceGroup + let router = Router() + let app = Application( + router: router, + server: serverChannel, + onServerRunning: { channel in await promise.complete(channel.localAddress!.port!) }, + logger: serverLogger + ) + serviceGroup = ServiceGroup( + configuration: .init( + services: [app], + gracefulShutdownSignals: [.sigterm, .sigint], + logger: app.logger + ) + ) + group.addTask { + try await serviceGroup.run() + } + group.addTask { + let port = await promise.wait() + let client = try WebSocketClient( + url: .init("ws://localhost:\(port)/test"), + configuration: .init(extensions: clientExtensions), + logger: clientLogger, + handler: clientHandler + ) + do { + try await client.run() + } catch { + print("\(error)") + throw error + } + } + do { + try await group.next() + await serviceGroup.triggerGracefulShutdown() + } catch { + await serviceGroup.triggerGracefulShutdown() + throw error + } + } + } + + func testClientAndServer( + serverExtensions: [WebSocketExtensionFactory] = [], + clientExtensions: [WebSocketExtensionFactory] = [], + server serverHandler: @escaping WebSocketDataHandler, + client clientHandler: @escaping WebSocketDataHandler + ) async throws { + try await self.testClientAndServer( + serverChannel: .webSocketUpgrade(configuration: .init(extensions: serverExtensions)) { _, _, _ in + .upgrade([:], serverHandler) + }, + clientExtensions: clientExtensions, + client: clientHandler + ) + } + + /// Create random buffer + /// - Parameters: + /// - size: size of buffer + /// - randomness: how random you want the buffer to be (percentage) + func createRandomBuffer(size: Int, randomness: Int = 100) -> ByteBuffer { + var buffer = ByteBufferAllocator().buffer(capacity: size) + let randomness = (randomness * randomness) / 100 + for i in 0.. WebSocketFrame { + var newBuffer = context.allocator.buffer(capacity: frame.data.readableBytes) + for byte in frame.data.readableBytesView { + newBuffer.writeInteger(byte ^ self.value) + } + var frame = frame + frame.data = newBuffer + return frame + } + + func processReceivedFrame(_ frame: WebSocketFrame, context: some WebSocketContextProtocol) -> WebSocketFrame { + return self.xorFrame(frame, context: context) + } + + func processFrameToSend(_ frame: WebSocketFrame, context: some WebSocketContextProtocol) throws -> WebSocketFrame { + return self.xorFrame(frame, context: context) + } + + let value: UInt8 +} + +struct XorWebSocketExtensionBuilder: WebSocketExtensionBuilder { + static var name = "permessage-xor" + let value: UInt8? + + init(value: UInt8? = nil) { + self.value = value + } + + func clientRequestHeader() -> String { + var header = Self.name + if let value { + header += ";value=\(value)" + } + return header + } + + func serverReponseHeader(to request: WebSocketExtensionHTTPParameters) -> String? { + var header = Self.name + if let value = request.parameters["value"]?.integer { + header += ";value=\(value)" + } + return header + } + + func serverExtension(from request: WebSocketExtensionHTTPParameters) throws -> (WebSocketExtension)? { + XorWebSocketExtension(value: UInt8(request.parameters["value"]?.integer ?? 255)) + } + + func clientExtension(from request: WebSocketExtensionHTTPParameters) throws -> (WebSocketExtension)? { + XorWebSocketExtension(value: UInt8(request.parameters["value"]?.integer ?? 255)) + } +} + +extension WebSocketExtensionFactory { + static func xor(value: UInt8? = nil) -> WebSocketExtensionFactory { + .init { XorWebSocketExtensionBuilder(value: value) } + } +}