Skip to content

Commit

Permalink
Updates based off changes to HB
Browse files Browse the repository at this point in the history
  • Loading branch information
adam-fowler committed May 16, 2024
1 parent 6cd9c6b commit cebb11f
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,14 @@ import HummingbirdWSCore
import Logging
import NIOCore

extension HTTPChannelBuilder {
extension HTTPServerBuilder {
/// HTTP1 channel builder supporting a websocket upgrade
/// - parameters
public static func http1WebSocketUpgrade(
configuration: WebSocketServerConfiguration = .init(),
additionalChannelHandlers: @autoclosure @escaping @Sendable () -> [any RemovableChannelHandler] = [],
shouldUpgrade: @escaping @Sendable (HTTPRequest, Channel, Logger) async throws -> ShouldUpgradeResult<WebSocketDataHandler<BasicWebSocketContext>>
) -> HTTPChannelBuilder {
) -> HTTPServerBuilder {
return .init { responder in
return HTTP1WebSocketUpgradeChannel(
responder: responder,
Expand All @@ -42,7 +42,7 @@ extension HTTPChannelBuilder {
configuration: WebSocketServerConfiguration = .init(),
additionalChannelHandlers: @autoclosure @escaping @Sendable () -> [any RemovableChannelHandler] = [],
shouldUpgrade: @escaping @Sendable (HTTPRequest, Channel, Logger) throws -> ShouldUpgradeResult<WebSocketDataHandler<BasicWebSocketContext>>
) -> HTTPChannelBuilder {
) -> HTTPServerBuilder {
return .init { responder in
return HTTP1WebSocketUpgradeChannel(
responder: responder,
Expand Down
13 changes: 8 additions & 5 deletions Sources/HummingbirdWebSocket/WebSocketChannel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,16 @@ import NIOWebSocket
public struct HTTP1WebSocketUpgradeChannel: ServerChildChannel, HTTPChannelHandler {
public typealias WebSocketChannelHandler = @Sendable (NIOAsyncChannel<WebSocketFrame, WebSocketFrame>, Logger) async -> Void
/// Upgrade result (either a websocket AsyncChannel, or an HTTP1 AsyncChannel)
public enum UpgradeResult {
public enum UpgradeResult: Sendable {
case websocket(NIOAsyncChannel<WebSocketFrame, WebSocketFrame>, WebSocketChannelHandler, Logger)
case notUpgraded(NIOAsyncChannel<HTTPRequestPart, HTTPResponsePart>)
case failedUpgrade(NIOAsyncChannel<HTTPRequestPart, HTTPResponsePart>, Logger)
}

public typealias Value = EventLoopFuture<UpgradeResult>
public struct Value: ServerChildChannelValue {
let upgradeResult: EventLoopFuture<UpgradeResult>
public let channel: Channel
}

/// Initialize HTTP1AndWebSocketChannel with synchronous `shouldUpgrade` function
/// - Parameters:
Expand Down Expand Up @@ -181,17 +184,17 @@ public struct HTTP1WebSocketUpgradeChannel: ServerChildChannel, HTTPChannelHandl
configuration: .init(upgradeConfiguration: serverUpgradeConfiguration)
)

return negotiationResultFuture
return .init(upgradeResult: negotiationResultFuture, channel: channel)
}
}

/// Handle upgrade result output from channel
/// - Parameters:
/// - upgradeResult: The upgrade result output by Channel
/// - logger: Logger to use
public func handle(value upgradeResult: EventLoopFuture<UpgradeResult>, logger: Logger) async {
public func handle(value: Value, logger: Logger) async {
do {
let result = try await upgradeResult.get()
let result = try await value.upgradeResult.get()
switch result {
case .notUpgraded(let http1):
await self.handleHTTP(asyncChannel: http1, logger: logger)
Expand Down
4 changes: 2 additions & 2 deletions Sources/HummingbirdWebSocket/WebSocketRouter.swift
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ extension HTTP1WebSocketUpgradeChannel {
}
}

extension HTTPChannelBuilder {
extension HTTPServerBuilder {
/// HTTP1 channel builder supporting a websocket upgrade
///
/// With this function you provide a separate router from the one you have supplied
Expand All @@ -213,7 +213,7 @@ extension HTTPChannelBuilder {
webSocketRouter: WSResponderBuilder,
configuration: WebSocketServerConfiguration = .init(),
additionalChannelHandlers: @autoclosure @escaping @Sendable () -> [any RemovableChannelHandler] = []
) -> HTTPChannelBuilder where WSResponderBuilder.Responder.Context: WebSocketRequestContext {
) -> HTTPServerBuilder where WSResponderBuilder.Responder.Context: WebSocketRequestContext {
let webSocketReponder = webSocketRouter.buildResponder()
return .init { responder in
return HTTP1WebSocketUpgradeChannel(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import XCTest

final class HummingbirdWebSocketExtensionTests: XCTestCase {
func testClientAndServer(
serverChannel: HTTPChannelBuilder,
serverChannel: HTTPServerBuilder,
clientExtensions: [WebSocketExtensionFactory] = [],
client clientHandler: @escaping WebSocketDataHandler<BasicWebSocketContext>
) async throws {
Expand Down
18 changes: 9 additions & 9 deletions Tests/HummingbirdWebSocketTests/WebSocketTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ final class HummingbirdWebSocketTests: XCTestCase {
}

@discardableResult func testClientAndServer(
serverChannel: HTTPChannelBuilder,
server: HTTPServerBuilder,
getClient: @escaping @Sendable (Int, Logger) throws -> WebSocketClient
) async throws -> WebSocketCloseFrame? {
try await withThrowingTaskGroup(of: Void.self) { group in
Expand All @@ -96,7 +96,7 @@ final class HummingbirdWebSocketTests: XCTestCase {
let serviceGroup: ServiceGroup
let app = Application(
router: router,
server: serverChannel,
server: server,
configuration: .init(address: .hostname("127.0.0.1", port: 0)),
onServerRunning: { channel in await promise.complete(channel.localAddress!.port!) },
logger: serverLogger
Expand Down Expand Up @@ -129,7 +129,7 @@ final class HummingbirdWebSocketTests: XCTestCase {
shouldUpgrade: @escaping @Sendable (HTTPRequest) throws -> HTTPFields? = { _ in return [:] },
getClient: @escaping @Sendable (Int, Logger) throws -> WebSocketClient
) async throws -> WebSocketCloseFrame? {
let webSocketUpgrade: HTTPChannelBuilder = .http1WebSocketUpgrade { head, _, _ in
let webSocketUpgrade: HTTPServerBuilder = .http1WebSocketUpgrade { head, _, _ in
if let headers = try shouldUpgrade(head) {
return .upgrade(headers, serverHandler)
} else {
Expand All @@ -138,12 +138,12 @@ final class HummingbirdWebSocketTests: XCTestCase {
}
if let serverTLSConfiguration {
return try await self.testClientAndServer(
serverChannel: .tls(webSocketUpgrade, tlsConfiguration: serverTLSConfiguration),
server: .tls(webSocketUpgrade, tlsConfiguration: serverTLSConfiguration),
getClient: getClient
)
} else {
return try await self.testClientAndServer(
serverChannel: webSocketUpgrade,
server: webSocketUpgrade,
getClient: getClient
)
}
Expand Down Expand Up @@ -173,9 +173,9 @@ final class HummingbirdWebSocketTests: XCTestCase {
webSocketRouter: Router<some WebSocketRequestContext>,
getClient: @escaping @Sendable (Int, Logger) throws -> WebSocketClient
) async throws -> WebSocketCloseFrame? {
let webSocketUpgrade: HTTPChannelBuilder = .http1WebSocketUpgrade(webSocketRouter: webSocketRouter)
let webSocketUpgrade: HTTPServerBuilder = .http1WebSocketUpgrade(webSocketRouter: webSocketRouter)
return try await self.testClientAndServer(
serverChannel: webSocketUpgrade,
server: webSocketUpgrade,
getClient: getClient
)
}
Expand Down Expand Up @@ -540,11 +540,11 @@ final class HummingbirdWebSocketTests: XCTestCase {
router.ws("/ws") { inbound, _, _ in
for try await _ in inbound {}
}
let webSocketUpgrade: HTTPChannelBuilder = .http1WebSocketUpgrade(
let webSocketUpgrade: HTTPServerBuilder = .http1WebSocketUpgrade(
webSocketRouter: router,
configuration: .init(autoPing: .enabled(timePeriod: .milliseconds(50)))
)
try await self.testClientAndServer(serverChannel: webSocketUpgrade) { port, logger in
try await self.testClientAndServer(server: webSocketUpgrade) { port, logger in
WebSocketClient(
url: .init("ws://localhost:\(port)/ws"),
configuration: .init(additionalHeaders: [.secWebSocketExtensions: "hb"]),
Expand Down

0 comments on commit cebb11f

Please sign in to comment.