Skip to content

Commit

Permalink
WebSocketExtensionContext
Browse files Browse the repository at this point in the history
  • Loading branch information
adam-fowler committed Jul 3, 2024
1 parent 444cdb4 commit e716af4
Show file tree
Hide file tree
Showing 6 changed files with 45 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ struct PerMessageDeflateExtension: WebSocketExtension {
try self.decompressor.startStream()
}

func decompress(_ frame: WebSocketFrame, maxSize: Int, resetStream: Bool, context: some WebSocketContext) throws -> WebSocketFrame {
func decompress(_ frame: WebSocketFrame, maxSize: Int, resetStream: Bool, context: WebSocketExtensionContext) throws -> WebSocketFrame {
if self.state == .idle {
if frame.rsv1 {
self.state = .decompressingMessage
Expand Down Expand Up @@ -231,7 +231,7 @@ struct PerMessageDeflateExtension: WebSocketExtension {
try self.compressor.startStream()
}

func compress(_ frame: WebSocketFrame, resetStream: Bool, context: some WebSocketContext) throws -> WebSocketFrame {
func compress(_ frame: WebSocketFrame, resetStream: Bool, context: WebSocketExtensionContext) throws -> WebSocketFrame {
// if the frame is larger than `minFrameSizeToCompress` bytes, we haven't received a final frame
// or we are in the process of sending a message compress the data
let shouldWeCompress = frame.data.readableBytes >= self.minFrameSizeToCompress || !frame.fin || self.sendState != .idle
Expand Down Expand Up @@ -292,7 +292,7 @@ struct PerMessageDeflateExtension: WebSocketExtension {
try? await self.compressor.shutdown()
}

func processReceivedFrame(_ frame: WebSocketFrame, context: some WebSocketContext) async throws -> WebSocketFrame {
func processReceivedFrame(_ frame: WebSocketFrame, context: WebSocketExtensionContext) async throws -> WebSocketFrame {
return try await self.decompressor.decompress(
frame,
maxSize: self.configuration.maxDecompressedFrameSize,
Expand All @@ -301,7 +301,7 @@ struct PerMessageDeflateExtension: WebSocketExtension {
)
}

func processFrameToSend(_ frame: WebSocketFrame, context: some WebSocketContext) async throws -> WebSocketFrame {
func processFrameToSend(_ frame: WebSocketFrame, context: WebSocketExtensionContext) async throws -> WebSocketFrame {
let isCorrectType = frame.opcode == .text || frame.opcode == .binary || frame.opcode == .continuation
if isCorrectType {
return try await self.compressor.compress(frame, resetStream: self.configuration.sendNoContextTakeover, context: context)
Expand Down
17 changes: 15 additions & 2 deletions Sources/HummingbirdWSCore/WebSocketExtension.swift
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,29 @@

import Foundation
import HTTPTypes
import Logging
import NIOCore
import NIOWebSocket

/// Basic context implementation of ``WebSocketContext``.
public struct WebSocketExtensionContext {
public let allocator: ByteBufferAllocator
public let logger: Logger

init(allocator: ByteBufferAllocator, logger: Logger) {
self.allocator = allocator
self.logger = logger
}
}

/// Protocol for WebSocket extension
public protocol WebSocketExtension: Sendable {
/// Extension name
var name: String { get }
/// Process frame received from websocket
func processReceivedFrame(_ frame: WebSocketFrame, context: some WebSocketContext) async throws -> WebSocketFrame
func processReceivedFrame(_ frame: WebSocketFrame, context: WebSocketExtensionContext) async throws -> WebSocketFrame
/// Process frame about to be sent to websocket
func processFrameToSend(_ frame: WebSocketFrame, context: some WebSocketContext) async throws -> WebSocketFrame
func processFrameToSend(_ frame: WebSocketFrame, context: WebSocketExtensionContext) async throws -> WebSocketFrame
/// shutdown extension
func shutdown() async
}
Expand Down
30 changes: 12 additions & 18 deletions Sources/HummingbirdWSCore/WebSocketHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -54,17 +54,6 @@ public struct WebSocketCloseFrame {
/// Manages ping, pong and close messages. Collates data and text messages into final frame
/// and passes them onto the ``WebSocketDataHandler`` data handler setup by the user.
package actor WebSocketHandler {
/// Basic context implementation of ``WebSocketContext``.
public struct Context: WebSocketContext {
public let allocator: ByteBufferAllocator
public let logger: Logger

package init(allocator: ByteBufferAllocator, logger: Logger) {
self.allocator = allocator
self.logger = logger
}
}

enum InternalError: Error {
case close(WebSocketErrorCode)
}
Expand All @@ -89,7 +78,8 @@ package actor WebSocketHandler {
var outbound: NIOAsyncChannelOutboundWriter<WebSocketFrame>
let type: WebSocketType
let configuration: Configuration
let context: Context
let logger: Logger
let allocator: ByteBufferAllocator
var pingData: ByteBuffer
var pingTime: ContinuousClock.Instant = .now
var closeState: CloseState
Expand All @@ -103,7 +93,8 @@ package actor WebSocketHandler {
self.outbound = outbound
self.type = type
self.configuration = configuration
self.context = .init(allocator: context.allocator, logger: context.logger)
self.logger = context.logger
self.allocator = context.allocator
self.pingData = ByteBufferAllocator().buffer(capacity: Self.pingDataSize)
self.closeState = .open
}
Expand Down Expand Up @@ -212,10 +203,13 @@ package actor WebSocketHandler {
var frame = frame
do {
for ext in self.configuration.extensions {
frame = try await ext.processFrameToSend(frame, context: self.context)
frame = try await ext.processFrameToSend(
frame,
context: WebSocketExtensionContext(allocator: self.allocator, logger: self.logger)
)
}
} catch {
self.context.logger.debug("Closing as we failed to generate valid frame data")
self.logger.debug("Closing as we failed to generate valid frame data")
throw WebSocketHandler.InternalError.close(.unexpectedServerError)
}
// Set mask key if client
Expand All @@ -224,7 +218,7 @@ package actor WebSocketHandler {
}
try await self.outbound.write(frame)

self.context.logger.trace("Sent \(frame.traceDescription)")
self.logger.trace("Sent \(frame.traceDescription)")
}

func finish() {
Expand Down Expand Up @@ -284,7 +278,7 @@ package actor WebSocketHandler {
) async throws {
switch self.closeState {
case .open:
var buffer = self.context.allocator.buffer(capacity: 2 + (reason?.utf8.count ?? 0))
var buffer = self.allocator.buffer(capacity: 2 + (reason?.utf8.count ?? 0))
buffer.write(webSocketErrorCode: code)
if let reason {
buffer.writeString(reason)
Expand Down Expand Up @@ -329,7 +323,7 @@ package actor WebSocketHandler {
.protocolError
}

var buffer = self.context.allocator.buffer(capacity: 2)
var buffer = self.allocator.buffer(capacity: 2)
buffer.write(webSocketErrorCode: code)

try await self.write(frame: .init(fin: true, opcode: .connectionClose, data: buffer))
Expand Down
11 changes: 7 additions & 4 deletions Sources/HummingbirdWSCore/WebSocketInboundStream.swift
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ public final class WebSocketInboundStream: AsyncSequence, Sendable {
// parse messages coming from inbound
while let frame = try await self.iterator.next() {
do {
self.handler.context.logger.trace("Received \(frame.traceDescription)")
self.handler.logger.trace("Received \(frame.traceDescription)")
switch frame.opcode {
case .connectionClose:
try await self.handler.receivedClose(frame)
Expand All @@ -70,16 +70,19 @@ public final class WebSocketInboundStream: AsyncSequence, Sendable {
// apply extensions
var frame = frame
for ext in self.handler.configuration.extensions.reversed() {
frame = try await ext.processReceivedFrame(frame, context: self.handler.context)
frame = try await ext.processReceivedFrame(
frame,
context: WebSocketExtensionContext(allocator: self.handler.allocator, logger: self.handler.logger)
)
}
return .init(from: frame)
default:
// if we receive a reserved opcode we should fail the connection
self.handler.context.logger.trace("Received reserved opcode", metadata: ["opcode": .stringConvertible(frame.opcode)])
self.handler.logger.trace("Received reserved opcode", metadata: ["opcode": .stringConvertible(frame.opcode)])
throw WebSocketHandler.InternalError.close(.protocolError)
}
} catch {
self.handler.context.logger.trace("Error: \(error)")
self.handler.logger.trace("Error: \(error)")
// catch errors while processing websocket frames so responding close message
// can be dealt with
let errorCode = WebSocketErrorCode(error)
Expand Down
4 changes: 2 additions & 2 deletions Sources/HummingbirdWSCore/WebSocketOutboundWriter.swift
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ public struct WebSocketOutboundWriter: Sendable {
try await self.handler.write(frame: .init(fin: true, opcode: .binary, data: buffer))
case .text(let string):
// send text based data
let buffer = self.handler.context.allocator.buffer(string: string)
let buffer = self.handler.allocator.buffer(string: string)
try await self.handler.write(frame: .init(fin: true, opcode: .text, data: buffer))
case .pong:
// send unexplained pong as a heartbeat
Expand Down Expand Up @@ -73,7 +73,7 @@ public struct WebSocketOutboundWriter: Sendable {

/// Write string to WebSocket frame
public mutating func callAsFunction(_ text: String) async throws {
let buffer = self.handler.context.allocator.buffer(string: text)
let buffer = self.handler.allocator.buffer(string: text)
try await self.write(buffer, opcode: self.opcode)
}

Expand Down
10 changes: 5 additions & 5 deletions Tests/HummingbirdWebSocketTests/WebSocketExtensionTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ struct XorWebSocketExtension: WebSocketExtension {
let name = "xor"
func shutdown() {}

func xorFrame(_ frame: WebSocketFrame, context: some WebSocketContext) -> WebSocketFrame {
func xorFrame(_ frame: WebSocketFrame, context: WebSocketExtensionContext) -> WebSocketFrame {
var newBuffer = context.allocator.buffer(capacity: frame.data.readableBytes)
for byte in frame.unmaskedData.readableBytesView {
newBuffer.writeInteger(byte ^ self.value)
Expand All @@ -321,11 +321,11 @@ struct XorWebSocketExtension: WebSocketExtension {
return frame
}

func processReceivedFrame(_ frame: WebSocketFrame, context: some WebSocketContext) -> WebSocketFrame {
func processReceivedFrame(_ frame: WebSocketFrame, context: WebSocketExtensionContext) -> WebSocketFrame {
return self.xorFrame(frame, context: context)
}

func processFrameToSend(_ frame: WebSocketFrame, context: some WebSocketContext) throws -> WebSocketFrame {
func processFrameToSend(_ frame: WebSocketFrame, context: WebSocketExtensionContext) throws -> WebSocketFrame {
return self.xorFrame(frame, context: context)
}

Expand Down Expand Up @@ -377,12 +377,12 @@ struct CheckDeflateWebSocketExtension: WebSocketExtension {
let name = "check-deflate"
func shutdown() {}

func processReceivedFrame(_ frame: WebSocketFrame, context: some WebSocketContext) throws -> WebSocketFrame {
func processReceivedFrame(_ frame: WebSocketFrame, context: WebSocketExtensionContext) throws -> WebSocketFrame {
guard frame.rsv1 else { throw NoDeflateError() }
return frame
}

func processFrameToSend(_ frame: WebSocketFrame, context: some WebSocketContext) throws -> WebSocketFrame {
func processFrameToSend(_ frame: WebSocketFrame, context: WebSocketExtensionContext) throws -> WebSocketFrame {
return frame
}
}
Expand Down

0 comments on commit e716af4

Please sign in to comment.