diff --git a/Sources/HummingbirdWSCore/WebSocketHandler.swift b/Sources/HummingbirdWSCore/WebSocketHandler.swift index 9ca029f..669dedc 100644 --- a/Sources/HummingbirdWSCore/WebSocketHandler.swift +++ b/Sources/HummingbirdWSCore/WebSocketHandler.swift @@ -134,7 +134,7 @@ package actor WebSocketHandler { } } } - let rt = try await webSocketHandler.handle(inbound: inbound, outbound: outbound, handler: handler, context: context) + let rt = try await webSocketHandler.handle(type: type, inbound: inbound, outbound: outbound, handler: handler, context: context) group.cancelAll() return rt } @@ -153,6 +153,7 @@ package actor WebSocketHandler { } func handle( + type: WebSocketType, inbound: NIOAsyncChannelInboundStream, outbound: NIOAsyncChannelOutboundWriter, handler: @escaping WebSocketDataHandler, @@ -166,6 +167,7 @@ package actor WebSocketHandler { handler: self ) let closeCode: WebSocketErrorCode + var clientError: Error? do { // handle websocket data and text try await handler(webSocketInbound, webSocketOutbound, context) @@ -173,6 +175,7 @@ package actor WebSocketHandler { } catch InternalError.close(let code) { closeCode = code } catch { + clientError = error closeCode = .unexpectedServerError } do { @@ -188,6 +191,9 @@ package actor WebSocketHandler { } // don't propagate error if channel is already closed } catch ChannelError.ioOnClosedChannel {} + if type == .client, let clientError { + throw clientError + } } onGracefulShutdown: { Task { try? await self.close(code: .normalClosure) diff --git a/Tests/HummingbirdWebSocketTests/WebSocketTests.swift b/Tests/HummingbirdWebSocketTests/WebSocketTests.swift index 203490d..9d49c82 100644 --- a/Tests/HummingbirdWebSocketTests/WebSocketTests.swift +++ b/Tests/HummingbirdWebSocketTests/WebSocketTests.swift @@ -708,17 +708,45 @@ final class HummingbirdWebSocketTests: XCTestCase { await serviceGroup.triggerGracefulShutdown() } } -} -extension Logger { - /// Create new Logger with additional metadata value - /// - Parameters: - /// - metadataKey: Metadata key - /// - value: Metadata value - /// - Returns: Logger - func with(metadataKey: String, value: MetadataValue) -> Logger { - var logger = self - logger[metadataKey: metadataKey] = value - return logger + func testClientErrorHandling() async throws { + struct ClientError: Error {} + let app = Application( + router: Router(), + server: .http1WebSocketUpgrade { _, _, _ in + .upgrade([:]) { inbound, _, _ in + for try await _ in inbound {} + } + } + ) + try await app.test(.live) { client in + do { + _ = try await client.ws("/") { _, _, _ in + throw ClientError() + } + XCTFail("Shouldnt reach here") + } catch is ClientError { + } catch { + XCTFail("Throwing wrong error") + } + } + } + + func testServerErrorHandling() async throws { + struct ServerError: Error {} + let app = Application( + router: Router(), + server: .http1WebSocketUpgrade { _, _, _ in + .upgrade([:]) { _, _, _ in + throw ServerError() + } + } + ) + try await app.test(.live) { client in + let closeFrame = try await client.ws("/") { inbound, _, _ in + for try await _ in inbound {} + } + XCTAssertEqual(closeFrame?.closeCode, .unexpectedServerError) + } } }