From 24425989dadab6d6e4167174791a23d4e2a6d0c3 Mon Sep 17 00:00:00 2001 From: Franz Busch Date: Tue, 26 Apr 2022 16:04:54 +0200 Subject: [PATCH] Call `didSendRequestPart` after the write has hit the socket (#566) ### Motivation Today `didSendRequestPart` is called after a request body part has been passed to the executor. However, this does not mean that the write hit the socket. Users may depend on this behavior to implement back-pressure. For this reason, we should only call this `didSendRequestPart` once the write was successful. ### Modification Pass a promise to the actual channel write and only call the delegate once that promise succeeds. ### Result The delegate method `didSendRequestPart` is only called after the write was successful. Fixes #565. Co-authored-by: Fabian Fett --- .../AsyncAwait/Transaction.swift | 8 +- .../HTTP1/HTTP1ClientChannelHandler.swift | 69 ++++++--- .../HTTP1/HTTP1ConnectionStateMachine.swift | 63 +++++--- .../HTTP2/HTTP2ClientRequestHandler.swift | 55 ++++--- .../HTTPExecutableRequest.swift | 4 +- .../HTTPRequestStateMachine.swift | 71 +++++---- Sources/AsyncHTTPClient/RequestBag.swift | 18 ++- .../HTTP1ConnectionStateMachineTests.swift | 54 +++++-- .../HTTPRequestStateMachineTests+XCTest.swift | 1 + .../HTTPRequestStateMachineTests.swift | 143 ++++++++++++------ .../Mocks/MockRequestExecutor.swift | 6 +- 11 files changed, 327 insertions(+), 165 deletions(-) diff --git a/Sources/AsyncHTTPClient/AsyncAwait/Transaction.swift b/Sources/AsyncHTTPClient/AsyncAwait/Transaction.swift index c2ce52eeb..8830406b4 100644 --- a/Sources/AsyncHTTPClient/AsyncAwait/Transaction.swift +++ b/Sources/AsyncHTTPClient/AsyncAwait/Transaction.swift @@ -63,7 +63,7 @@ final class Transaction: @unchecked Sendable { switch writeAction { case .writeAndWait(let executor), .writeAndContinue(let executor): - executor.writeRequestBodyPart(.byteBuffer(byteBuffer), request: self) + executor.writeRequestBodyPart(.byteBuffer(byteBuffer), request: self, promise: nil) case .fail: // an error/cancellation has happened. we don't need to continue here @@ -105,14 +105,14 @@ final class Transaction: @unchecked Sendable { switch self.state.writeNextRequestPart() { case .writeAndContinue(let executor): self.stateLock.unlock() - executor.writeRequestBodyPart(.byteBuffer(part), request: self) + executor.writeRequestBodyPart(.byteBuffer(part), request: self, promise: nil) case .writeAndWait(let executor): try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in self.state.waitForRequestBodyDemand(continuation: continuation) self.stateLock.unlock() - executor.writeRequestBodyPart(.byteBuffer(part), request: self) + executor.writeRequestBodyPart(.byteBuffer(part), request: self, promise: nil) } case .fail: @@ -132,7 +132,7 @@ final class Transaction: @unchecked Sendable { break case .forwardStreamFinished(let executor, let succeedContinuation): - executor.finishRequestBodyStream(self) + executor.finishRequestBodyStream(self, promise: nil) succeedContinuation?.resume(returning: nil) } return diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTP1/HTTP1ClientChannelHandler.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTP1/HTTP1ClientChannelHandler.swift index 9d1a3b5fd..2a3bc9c27 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/HTTP1/HTTP1ClientChannelHandler.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTP1/HTTP1ClientChannelHandler.swift @@ -185,11 +185,11 @@ final class HTTP1ClientChannelHandler: ChannelDuplexHandler { case .sendRequestHead(let head, startBody: let startBody): self.sendRequestHead(head, startBody: startBody, context: context) - case .sendBodyPart(let part): - context.writeAndFlush(self.wrapOutboundOut(.body(part)), promise: nil) + case .sendBodyPart(let part, let writePromise): + context.writeAndFlush(self.wrapOutboundOut(.body(part)), promise: writePromise) - case .sendRequestEnd: - context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil) + case .sendRequestEnd(let writePromise): + context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: writePromise) if let timeoutAction = self.idleReadTimeoutStateMachine?.requestEndSent() { self.runTimeoutAction(timeoutAction, context: context) @@ -260,16 +260,25 @@ final class HTTP1ClientChannelHandler: ChannelDuplexHandler { switch finalAction { case .close: context.close(promise: nil) - case .sendRequestEnd: - context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil) + oldRequest.succeedRequest(buffer) + case .sendRequestEnd(let writePromise): + let writePromise = writePromise ?? context.eventLoop.makePromise(of: Void.self) + // We need to defer succeeding the old request to avoid ordering issues + writePromise.futureResult.whenComplete { result in + switch result { + case .success: + oldRequest.succeedRequest(buffer) + case .failure(let error): + oldRequest.fail(error) + } + } + + context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: writePromise) case .informConnectionIsIdle: self.connection.taskCompleted() - case .none: - break + oldRequest.succeedRequest(buffer) } - oldRequest.succeedRequest(buffer) - case .failRequest(let error, let finalAction): // see comment in the `succeedRequest` case. let oldRequest = self.request! @@ -277,17 +286,25 @@ final class HTTP1ClientChannelHandler: ChannelDuplexHandler { self.runTimeoutAction(.clearIdleReadTimeoutTimer, context: context) switch finalAction { - case .close: + case .close(let writePromise): context.close(promise: nil) - case .sendRequestEnd: - context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil) + writePromise?.fail(error) + oldRequest.fail(error) + case .informConnectionIsIdle: self.connection.taskCompleted() + oldRequest.fail(error) + + case .failWritePromise(let writePromise): + writePromise?.fail(error) + oldRequest.fail(error) + case .none: - break + oldRequest.fail(error) } - oldRequest.fail(error) + case .failSendBodyPart(let error, let writePromise), .failSendStreamFinished(let error, let writePromise): + writePromise?.fail(error) } } @@ -355,27 +372,29 @@ final class HTTP1ClientChannelHandler: ChannelDuplexHandler { // MARK: Private HTTPRequestExecutor - private func writeRequestBodyPart0(_ data: IOData, request: HTTPExecutableRequest) { + private func writeRequestBodyPart0(_ data: IOData, request: HTTPExecutableRequest, promise: EventLoopPromise?) { guard self.request === request, let context = self.channelContext else { // Because the HTTPExecutableRequest may run in a different thread to our eventLoop, // calls from the HTTPExecutableRequest to our ChannelHandler may arrive here after // the request has been popped by the state machine or the ChannelHandler has been // removed from the Channel pipeline. This is a normal threading issue, noone has // screwed up. + promise?.fail(HTTPClientError.requestStreamCancelled) return } - let action = self.state.requestStreamPartReceived(data) + let action = self.state.requestStreamPartReceived(data, promise: promise) self.run(action, context: context) } - private func finishRequestBodyStream0(_ request: HTTPExecutableRequest) { + private func finishRequestBodyStream0(_ request: HTTPExecutableRequest, promise: EventLoopPromise?) { guard self.request === request, let context = self.channelContext else { // See code comment in `writeRequestBodyPart0` + promise?.fail(HTTPClientError.requestStreamCancelled) return } - let action = self.state.requestStreamFinished() + let action = self.state.requestStreamFinished(promise: promise) self.run(action, context: context) } @@ -405,22 +424,22 @@ final class HTTP1ClientChannelHandler: ChannelDuplexHandler { } extension HTTP1ClientChannelHandler: HTTPRequestExecutor { - func writeRequestBodyPart(_ data: IOData, request: HTTPExecutableRequest) { + func writeRequestBodyPart(_ data: IOData, request: HTTPExecutableRequest, promise: EventLoopPromise?) { if self.eventLoop.inEventLoop { - self.writeRequestBodyPart0(data, request: request) + self.writeRequestBodyPart0(data, request: request, promise: promise) } else { self.eventLoop.execute { - self.writeRequestBodyPart0(data, request: request) + self.writeRequestBodyPart0(data, request: request, promise: promise) } } } - func finishRequestBodyStream(_ request: HTTPExecutableRequest) { + func finishRequestBodyStream(_ request: HTTPExecutableRequest, promise: EventLoopPromise?) { if self.eventLoop.inEventLoop { - self.finishRequestBodyStream0(request) + self.finishRequestBodyStream0(request, promise: promise) } else { self.eventLoop.execute { - self.finishRequestBodyStream0(request) + self.finishRequestBodyStream0(request, promise: promise) } } } diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTP1/HTTP1ConnectionStateMachine.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTP1/HTTP1ConnectionStateMachine.swift index 19825aec7..ecff7afc7 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/HTTP1/HTTP1ConnectionStateMachine.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTP1/HTTP1ConnectionStateMachine.swift @@ -28,21 +28,37 @@ struct HTTP1ConnectionStateMachine { enum Action { /// A action to execute, when we consider a request "done". - enum FinalStreamAction { + enum FinalSuccessfulStreamAction { /// Close the connection case close /// If the server has replied, with a status of 200...300 before all data was sent, a request is considered succeeded, /// as soon as we wrote the request end onto the wire. - case sendRequestEnd + /// + /// The promise is an optional write promise. + case sendRequestEnd(EventLoopPromise?) /// Inform an observer that the connection has become idle case informConnectionIsIdle + } + + /// A action to execute, when we consider a request "done". + enum FinalFailedStreamAction { + /// Close the connection + /// + /// The promise is an optional write promise. + case close(EventLoopPromise?) + /// Inform an observer that the connection has become idle + case informConnectionIsIdle + /// Fail the write promise + case failWritePromise(EventLoopPromise?) /// Do nothing. case none } case sendRequestHead(HTTPRequestHead, startBody: Bool) - case sendBodyPart(IOData) - case sendRequestEnd + case sendBodyPart(IOData, EventLoopPromise?) + case sendRequestEnd(EventLoopPromise?) + case failSendBodyPart(Error, EventLoopPromise?) + case failSendStreamFinished(Error, EventLoopPromise?) case pauseRequestBodyStream case resumeRequestBodyStream @@ -50,8 +66,8 @@ struct HTTP1ConnectionStateMachine { case forwardResponseHead(HTTPResponseHead, pauseRequestBodyStream: Bool) case forwardResponseBodyParts(CircularBuffer) - case failRequest(Error, FinalStreamAction) - case succeedRequest(FinalStreamAction, CircularBuffer) + case failRequest(Error, FinalFailedStreamAction) + case succeedRequest(FinalSuccessfulStreamAction, CircularBuffer) case read case close @@ -189,25 +205,25 @@ struct HTTP1ConnectionStateMachine { } } - mutating func requestStreamPartReceived(_ part: IOData) -> Action { + mutating func requestStreamPartReceived(_ part: IOData, promise: EventLoopPromise?) -> Action { guard case .inRequest(var requestStateMachine, let close) = self.state else { preconditionFailure("Invalid state: \(self.state)") } return self.avoidingStateMachineCoW { state -> Action in - let action = requestStateMachine.requestStreamPartReceived(part) + let action = requestStateMachine.requestStreamPartReceived(part, promise: promise) state = .inRequest(requestStateMachine, close: close) return state.modify(with: action) } } - mutating func requestStreamFinished() -> Action { + mutating func requestStreamFinished(promise: EventLoopPromise?) -> Action { guard case .inRequest(var requestStateMachine, let close) = self.state else { preconditionFailure("Invalid state: \(self.state)") } return self.avoidingStateMachineCoW { state -> Action in - let action = requestStateMachine.requestStreamFinished() + let action = requestStateMachine.requestStreamFinished(promise: promise) state = .inRequest(requestStateMachine, close: close) return state.modify(with: action) } @@ -377,10 +393,10 @@ extension HTTP1ConnectionStateMachine.State { return .pauseRequestBodyStream case .resumeRequestBodyStream: return .resumeRequestBodyStream - case .sendBodyPart(let part): - return .sendBodyPart(part) - case .sendRequestEnd: - return .sendRequestEnd + case .sendBodyPart(let part, let writePromise): + return .sendBodyPart(part, writePromise) + case .sendRequestEnd(let writePromise): + return .sendRequestEnd(writePromise) case .forwardResponseHead(let head, let pauseRequestBodyStream): return .forwardResponseHead(head, pauseRequestBodyStream: pauseRequestBodyStream) case .forwardResponseBodyParts(let parts): @@ -390,13 +406,13 @@ extension HTTP1ConnectionStateMachine.State { preconditionFailure("Invalid state: \(self)") } - let newFinalAction: HTTP1ConnectionStateMachine.Action.FinalStreamAction + let newFinalAction: HTTP1ConnectionStateMachine.Action.FinalSuccessfulStreamAction switch finalAction { case .close: self = .closing newFinalAction = .close - case .sendRequestEnd: - newFinalAction = .sendRequestEnd + case .sendRequestEnd(let writePromise): + newFinalAction = .sendRequestEnd(writePromise) case .none: self = .idle newFinalAction = close ? .close : .informConnectionIsIdle @@ -410,9 +426,12 @@ extension HTTP1ConnectionStateMachine.State { case .idle: preconditionFailure("How can we fail a task, if we are idle") case .inRequest(_, close: let close): - if close || finalAction == .close { + if case .close(let promise) = finalAction { + self = .closing + return .failRequest(error, .close(promise)) + } else if close { self = .closing - return .failRequest(error, .close) + return .failRequest(error, .close(nil)) } else { self = .idle return .failRequest(error, .informConnectionIsIdle) @@ -433,6 +452,12 @@ extension HTTP1ConnectionStateMachine.State { case .wait: return .wait + + case .failSendBodyPart(let error, let writePromise): + return .failSendBodyPart(error, writePromise) + + case .failSendStreamFinished(let error, let writePromise): + return .failSendStreamFinished(error, writePromise) } } } diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTP2/HTTP2ClientRequestHandler.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTP2/HTTP2ClientRequestHandler.swift index 8b2a50738..578b83029 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/HTTP2/HTTP2ClientRequestHandler.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTP2/HTTP2ClientRequestHandler.swift @@ -148,11 +148,11 @@ final class HTTP2ClientRequestHandler: ChannelDuplexHandler { // that the request is neither failed nor finished yet self.request!.pauseRequestBodyStream() - case .sendBodyPart(let data): - context.writeAndFlush(self.wrapOutboundOut(.body(data)), promise: nil) + case .sendBodyPart(let data, let writePromise): + context.writeAndFlush(self.wrapOutboundOut(.body(data)), promise: writePromise) - case .sendRequestEnd: - context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil) + case .sendRequestEnd(let writePromise): + context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: writePromise) if let timeoutAction = self.idleReadTimeoutStateMachine?.requestEndSent() { self.runTimeoutAction(timeoutAction, context: context) @@ -185,7 +185,7 @@ final class HTTP2ClientRequestHandler: ChannelDuplexHandler { // that the request is neither failed nor finished yet self.request!.receiveResponseBodyParts(parts) - case .failRequest(let error, _): + case .failRequest(let error, let finalAction): // We can force unwrap the request here, as we have just validated in the state machine, // that the request object is still present. self.request!.fail(error) @@ -195,7 +195,7 @@ final class HTTP2ClientRequestHandler: ChannelDuplexHandler { // once the h2 stream is closed, it is released from the h2 multiplexer. The // HTTPRequestStateMachine may signal finalAction: .none in the error case (as this is // the right result for HTTP/1). In the h2 case we MUST always close. - self.runFinalAction(.close, context: context) + self.runFailedFinalAction(finalAction, context: context, error: error) case .succeedRequest(let finalAction, let finalParts): // We can force unwrap the request here, as we have just validated in the state machine, @@ -203,7 +203,10 @@ final class HTTP2ClientRequestHandler: ChannelDuplexHandler { self.request!.succeedRequest(finalParts) self.request = nil self.runTimeoutAction(.clearIdleReadTimeoutTimer, context: context) - self.runFinalAction(finalAction, context: context) + self.runSuccessfulFinalAction(finalAction, context: context) + + case .failSendBodyPart(let error, let writePromise), .failSendStreamFinished(let error, let writePromise): + writePromise?.fail(error) } } @@ -234,13 +237,24 @@ final class HTTP2ClientRequestHandler: ChannelDuplexHandler { } } - private func runFinalAction(_ action: HTTPRequestStateMachine.Action.FinalStreamAction, context: ChannelHandlerContext) { + private func runSuccessfulFinalAction(_ action: HTTPRequestStateMachine.Action.FinalSuccessfulRequestAction, context: ChannelHandlerContext) { switch action { case .close: context.close(promise: nil) - case .sendRequestEnd: - context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil) + case .sendRequestEnd(let writePromise): + context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: writePromise) + + case .none: + break + } + } + + private func runFailedFinalAction(_ action: HTTPRequestStateMachine.Action.FinalFailedRequestAction, context: ChannelHandlerContext, error: Error) { + switch action { + case .close(let writePromise): + context.close(promise: nil) + writePromise?.fail(error) case .none: break @@ -281,27 +295,28 @@ final class HTTP2ClientRequestHandler: ChannelDuplexHandler { // MARK: Private HTTPRequestExecutor - private func writeRequestBodyPart0(_ data: IOData, request: HTTPExecutableRequest) { + private func writeRequestBodyPart0(_ data: IOData, request: HTTPExecutableRequest, promise: EventLoopPromise?) { guard self.request === request, let context = self.channelContext else { // Because the HTTPExecutableRequest may run in a different thread to our eventLoop, // calls from the HTTPExecutableRequest to our ChannelHandler may arrive here after // the request has been popped by the state machine or the ChannelHandler has been // removed from the Channel pipeline. This is a normal threading issue, noone has // screwed up. + promise?.fail(HTTPClientError.requestStreamCancelled) return } - let action = self.state.requestStreamPartReceived(data) + let action = self.state.requestStreamPartReceived(data, promise: promise) self.run(action, context: context) } - private func finishRequestBodyStream0(_ request: HTTPExecutableRequest) { + private func finishRequestBodyStream0(_ request: HTTPExecutableRequest, promise: EventLoopPromise?) { guard self.request === request, let context = self.channelContext else { // See code comment in `writeRequestBodyPart0` return } - let action = self.state.requestStreamFinished() + let action = self.state.requestStreamFinished(promise: promise) self.run(action, context: context) } @@ -327,22 +342,22 @@ final class HTTP2ClientRequestHandler: ChannelDuplexHandler { } extension HTTP2ClientRequestHandler: HTTPRequestExecutor { - func writeRequestBodyPart(_ data: IOData, request: HTTPExecutableRequest) { + func writeRequestBodyPart(_ data: IOData, request: HTTPExecutableRequest, promise: EventLoopPromise?) { if self.eventLoop.inEventLoop { - self.writeRequestBodyPart0(data, request: request) + self.writeRequestBodyPart0(data, request: request, promise: promise) } else { self.eventLoop.execute { - self.writeRequestBodyPart0(data, request: request) + self.writeRequestBodyPart0(data, request: request, promise: promise) } } } - func finishRequestBodyStream(_ request: HTTPExecutableRequest) { + func finishRequestBodyStream(_ request: HTTPExecutableRequest, promise: EventLoopPromise?) { if self.eventLoop.inEventLoop { - self.finishRequestBodyStream0(request) + self.finishRequestBodyStream0(request, promise: promise) } else { self.eventLoop.execute { - self.finishRequestBodyStream0(request) + self.finishRequestBodyStream0(request, promise: promise) } } } diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTPExecutableRequest.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTPExecutableRequest.swift index 2477e1154..d64ceedd6 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/HTTPExecutableRequest.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTPExecutableRequest.swift @@ -180,12 +180,12 @@ protocol HTTPRequestExecutor { /// Writes a body part into the channel pipeline /// /// This method may be **called on any thread**. The executor needs to ensure thread safety. - func writeRequestBodyPart(_: IOData, request: HTTPExecutableRequest) + func writeRequestBodyPart(_: IOData, request: HTTPExecutableRequest, promise: EventLoopPromise?) /// Signals that the request body stream has finished /// /// This method may be **called on any thread**. The executor needs to ensure thread safety. - func finishRequestBodyStream(_ task: HTTPExecutableRequest) + func finishRequestBodyStream(_ task: HTTPExecutableRequest, promise: EventLoopPromise?) /// Signals that more bytes from response body stream can be consumed. /// diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTPRequestStateMachine.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTPRequestStateMachine.swift index fa520a865..aafa3d28b 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/HTTPRequestStateMachine.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTPRequestStateMachine.swift @@ -70,21 +70,34 @@ struct HTTPRequestStateMachine { } enum Action { - /// A action to execute, when we consider a request "done". - enum FinalStreamAction { + /// A action to execute, when we consider a successful request "done". + enum FinalSuccessfulRequestAction { /// Close the connection case close /// If the server has replied, with a status of 200...300 before all data was sent, a request is considered succeeded, /// as soon as we wrote the request end onto the wire. - case sendRequestEnd + /// + /// The promise is an optional write promise. + case sendRequestEnd(EventLoopPromise?) + /// Do nothing. This is action is used, if the request failed, before we the request head was written onto the wire. + /// This might happen if the request is cancelled, or the request failed the soundness check. + case none + } + + /// A action to execute, when we consider a failed request "done". + enum FinalFailedRequestAction { + /// Close the connection + case close(EventLoopPromise?) /// Do nothing. This is action is used, if the request failed, before we the request head was written onto the wire. /// This might happen if the request is cancelled, or the request failed the soundness check. case none } case sendRequestHead(HTTPRequestHead, startBody: Bool) - case sendBodyPart(IOData) - case sendRequestEnd + case sendBodyPart(IOData, EventLoopPromise?) + case sendRequestEnd(EventLoopPromise?) + case failSendBodyPart(Error, EventLoopPromise?) + case failSendStreamFinished(Error, EventLoopPromise?) case pauseRequestBodyStream case resumeRequestBodyStream @@ -92,8 +105,8 @@ struct HTTPRequestStateMachine { case forwardResponseHead(HTTPResponseHead, pauseRequestBodyStream: Bool) case forwardResponseBodyParts(CircularBuffer) - case failRequest(Error, FinalStreamAction) - case succeedRequest(FinalStreamAction, CircularBuffer) + case failRequest(Error, FinalFailedRequestAction) + case succeedRequest(FinalSuccessfulRequestAction, CircularBuffer) case read case wait @@ -212,7 +225,7 @@ struct HTTPRequestStateMachine { return .failRequest(error, .none) case .running: self.state = .failed(error) - return .failRequest(error, .close) + return .failRequest(error, .close(nil)) case .finished, .failed: // ignore error @@ -254,14 +267,14 @@ struct HTTPRequestStateMachine { // we have received all necessary bytes. For this reason we forward the uncleanShutdown // error to the user. self.state = .failed(NIOSSLError.uncleanShutdown) - return .failRequest(NIOSSLError.uncleanShutdown, .close) + return .failRequest(NIOSSLError.uncleanShutdown, .close(nil)) case .waitForChannelToBecomeWritable, .running, .finished, .failed, .initialized, .modifying: return nil } } - mutating func requestStreamPartReceived(_ part: IOData) -> Action { + mutating func requestStreamPartReceived(_ part: IOData, promise: EventLoopPromise?) -> Action { switch self.state { case .initialized, .waitForChannelToBecomeWritable, @@ -274,7 +287,7 @@ struct HTTPRequestStateMachine { // won't be interested. We expect that the producer has been informed to pause // producing. assert(producerState == .paused) - return .wait + return .failSendBodyPart(HTTPClientError.requestStreamCancelled, promise) case .running(.streaming(let expectedBodyLength, var sentBodyBytes, let producerState), let responseState): // We don't check the producer state here: @@ -290,7 +303,7 @@ struct HTTPRequestStateMachine { if let expected = expectedBodyLength, sentBodyBytes + part.readableBytes > expected { let error = HTTPClientError.bodyLengthMismatch self.state = .failed(error) - return .failRequest(error, .close) + return .failRequest(error, .close(promise)) } sentBodyBytes += part.readableBytes @@ -303,10 +316,10 @@ struct HTTPRequestStateMachine { self.state = .running(requestState, responseState) - return .sendBodyPart(part) + return .sendBodyPart(part, promise) - case .failed: - return .wait + case .failed(let error): + return .failSendBodyPart(error, promise) case .finished: // A request may be finished, before we have send all parts. This might be the case if @@ -318,14 +331,14 @@ struct HTTPRequestStateMachine { // We may still receive something, here because of potential race conditions with the // producing thread. - return .wait + return .failSendBodyPart(HTTPClientError.requestStreamCancelled, promise) case .modifying: preconditionFailure("Invalid state: \(self.state)") } } - mutating func requestStreamFinished() -> Action { + mutating func requestStreamFinished(promise: EventLoopPromise?) -> Action { switch self.state { case .initialized, .waitForChannelToBecomeWritable, @@ -336,11 +349,11 @@ struct HTTPRequestStateMachine { if let expected = expectedBodyLength, expected != sentBodyBytes { let error = HTTPClientError.bodyLengthMismatch self.state = .failed(error) - return .failRequest(error, .close) + return .failRequest(error, .close(promise)) } self.state = .running(.endSent, .waitingForHead) - return .sendRequestEnd + return .sendRequestEnd(promise) case .running(.streaming(let expectedBodyLength, let sentBodyBytes, _), .receivingBody(let head, let streamState)): assert(head.status.code < 300) @@ -348,24 +361,24 @@ struct HTTPRequestStateMachine { if let expected = expectedBodyLength, expected != sentBodyBytes { let error = HTTPClientError.bodyLengthMismatch self.state = .failed(error) - return .failRequest(error, .close) + return .failRequest(error, .close(promise)) } self.state = .running(.endSent, .receivingBody(head, streamState)) - return .sendRequestEnd + return .sendRequestEnd(promise) case .running(.streaming(let expectedBodyLength, let sentBodyBytes, _), .endReceived): if let expected = expectedBodyLength, expected != sentBodyBytes { let error = HTTPClientError.bodyLengthMismatch self.state = .failed(error) - return .failRequest(error, .close) + return .failRequest(error, .close(promise)) } self.state = .finished - return .succeedRequest(.sendRequestEnd, .init()) + return .succeedRequest(.sendRequestEnd(promise), .init()) - case .failed: - return .wait + case .failed(let error): + return .failSendStreamFinished(error, promise) case .finished: // A request may be finished, before we have send all parts. This might be the case if @@ -377,7 +390,7 @@ struct HTTPRequestStateMachine { // We may still receive something, here because of potential race conditions with the // producing thread. - return .wait + return .failSendStreamFinished(HTTPClientError.requestStreamCancelled, promise) case .modifying: preconditionFailure("Invalid state: \(self.state)") @@ -398,7 +411,7 @@ struct HTTPRequestStateMachine { case .running: let error = HTTPClientError.cancelled self.state = .failed(error) - return .failRequest(error, .close) + return .failRequest(error, .close(nil)) case .finished: return .wait @@ -597,7 +610,7 @@ struct HTTPRequestStateMachine { // the request is still uploading, we will not be able to finish the upload. For // this reason we can fail the request here. state = .failed(HTTPClientError.remoteConnectionClosed) - return .failRequest(HTTPClientError.remoteConnectionClosed, .close) + return .failRequest(HTTPClientError.remoteConnectionClosed, .close(nil)) } } @@ -670,7 +683,7 @@ struct HTTPRequestStateMachine { case .running(.endSent, .waitingForHead), .running(.endSent, .receivingBody): let error = HTTPClientError.readTimeout self.state = .failed(error) - return .failRequest(error, .close) + return .failRequest(error, .close(nil)) case .running(.endSent, .endReceived): preconditionFailure("Invalid state. This state should be: .finished") diff --git a/Sources/AsyncHTTPClient/RequestBag.swift b/Sources/AsyncHTTPClient/RequestBag.swift index b4aeef0e7..dbef802e9 100644 --- a/Sources/AsyncHTTPClient/RequestBag.swift +++ b/Sources/AsyncHTTPClient/RequestBag.swift @@ -154,8 +154,11 @@ final class RequestBag { return self.task.eventLoop.makeFailedFuture(error) case .write(let part, let writer, let future): - writer.writeRequestBodyPart(part, request: self) - self.delegate.didSendRequestPart(task: self.task, part) + let promise = self.task.eventLoop.makePromise(of: Void.self) + promise.futureResult.whenSuccess { + self.delegate.didSendRequestPart(task: self.task, part) + } + writer.writeRequestBodyPart(part, request: self, promise: promise) return future } } @@ -168,11 +171,12 @@ final class RequestBag { switch action { case .none: break - case .forwardStreamFinished(let writer, let promise): - writer.finishRequestBodyStream(self) - promise?.succeed(()) - - self.delegate.didSendRequest(task: self.task) + case .forwardStreamFinished(let writer, let writerPromise): + let promise = writerPromise ?? self.task.eventLoop.makePromise(of: Void.self) + promise.futureResult.whenSuccess { + self.delegate.didSendRequest(task: self.task) + } + writer.finishRequestBodyStream(self, promise: promise) case .forwardStreamFailureAndFailTask(let writer, let error, let promise): writer.cancelRequest(self) diff --git a/Tests/AsyncHTTPClientTests/HTTP1ConnectionStateMachineTests.swift b/Tests/AsyncHTTPClientTests/HTTP1ConnectionStateMachineTests.swift index c8ad3d510..55014f8c6 100644 --- a/Tests/AsyncHTTPClientTests/HTTP1ConnectionStateMachineTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTP1ConnectionStateMachineTests.swift @@ -32,22 +32,22 @@ class HTTP1ConnectionStateMachineTests: XCTestCase { let part1 = IOData.byteBuffer(ByteBuffer(bytes: [1])) let part2 = IOData.byteBuffer(ByteBuffer(bytes: [2])) let part3 = IOData.byteBuffer(ByteBuffer(bytes: [3])) - XCTAssertEqual(state.requestStreamPartReceived(part0), .sendBodyPart(part0)) - XCTAssertEqual(state.requestStreamPartReceived(part1), .sendBodyPart(part1)) + XCTAssertEqual(state.requestStreamPartReceived(part0, promise: nil), .sendBodyPart(part0, nil)) + XCTAssertEqual(state.requestStreamPartReceived(part1, promise: nil), .sendBodyPart(part1, nil)) // oh the channel reports... we should slow down producing... XCTAssertEqual(state.writabilityChanged(writable: false), .pauseRequestBodyStream) // but we issued a .produceMoreRequestBodyData before... Thus, we must accept more produced // data - XCTAssertEqual(state.requestStreamPartReceived(part2), .sendBodyPart(part2)) + XCTAssertEqual(state.requestStreamPartReceived(part2, promise: nil), .sendBodyPart(part2, nil)) // however when we have put the data on the channel, we should not issue further // .produceMoreRequestBodyData events // once we receive a writable event again, we can allow the producer to produce more data XCTAssertEqual(state.writabilityChanged(writable: true), .resumeRequestBodyStream) - XCTAssertEqual(state.requestStreamPartReceived(part3), .sendBodyPart(part3)) - XCTAssertEqual(state.requestStreamFinished(), .sendRequestEnd) + XCTAssertEqual(state.requestStreamPartReceived(part3, promise: nil), .sendBodyPart(part3, nil)) + XCTAssertEqual(state.requestStreamFinished(promise: nil), .sendRequestEnd(nil)) let responseHead = HTTPResponseHead(version: .http1_1, status: .ok) XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) @@ -186,9 +186,9 @@ class HTTP1ConnectionStateMachineTests: XCTestCase { let part0 = IOData.byteBuffer(ByteBuffer(bytes: [0])) let part1 = IOData.byteBuffer(ByteBuffer(bytes: [1])) - XCTAssertEqual(state.requestStreamPartReceived(part0), .sendBodyPart(part0)) - XCTAssertEqual(state.requestStreamPartReceived(part1), .sendBodyPart(part1)) - XCTAssertEqual(state.requestCancelled(closeConnection: false), .failRequest(HTTPClientError.cancelled, .close)) + XCTAssertEqual(state.requestStreamPartReceived(part0, promise: nil), .sendBodyPart(part0, nil)) + XCTAssertEqual(state.requestStreamPartReceived(part1, promise: nil), .sendBodyPart(part1, nil)) + XCTAssertEqual(state.requestCancelled(closeConnection: false), .failRequest(HTTPClientError.cancelled, .close(nil))) } func testCancelRequestIsIgnoredWhenConnectionIsIdle() { @@ -241,7 +241,7 @@ class HTTP1ConnectionStateMachineTests: XCTestCase { XCTAssertEqual(state.channelRead(.body(ByteBuffer(string: "Hello world!\n"))), .wait) XCTAssertEqual(state.channelRead(.body(ByteBuffer(string: "Foo Bar!\n"))), .wait) let decompressionError = NIOHTTPDecompression.DecompressionError.limit - XCTAssertEqual(state.errorHappened(decompressionError), .failRequest(decompressionError, .close)) + XCTAssertEqual(state.errorHappened(decompressionError), .failRequest(decompressionError, .close(nil))) } func testConnectionIsClosedAfterSwitchingProtocols() { @@ -295,8 +295,8 @@ extension HTTP1ConnectionStateMachine.Action: Equatable { case (.sendRequestHead(let lhsHead, let lhsStartBody), .sendRequestHead(let rhsHead, let rhsStartBody)): return lhsHead == rhsHead && lhsStartBody == rhsStartBody - case (.sendBodyPart(let lhsData), .sendBodyPart(let rhsData)): - return lhsData == rhsData + case (.sendBodyPart(let lhsData, let lhsPromise), .sendBodyPart(let rhsData, let rhsPromise)): + return lhsData == rhsData && lhsPromise?.futureResult == rhsPromise?.futureResult case (.sendRequestEnd, .sendRequestEnd): return true @@ -332,3 +332,35 @@ extension HTTP1ConnectionStateMachine.Action: Equatable { } } } + +extension HTTP1ConnectionStateMachine.Action.FinalSuccessfulStreamAction: Equatable { + public static func == (lhs: HTTP1ConnectionStateMachine.Action.FinalSuccessfulStreamAction, rhs: HTTP1ConnectionStateMachine.Action.FinalSuccessfulStreamAction) -> Bool { + switch (lhs, rhs) { + case (.close, .close): + return true + case (sendRequestEnd(let lhsPromise), sendRequestEnd(let rhsPromise)): + return lhsPromise?.futureResult == rhsPromise?.futureResult + case (informConnectionIsIdle, informConnectionIsIdle): + return true + default: + return false + } + } +} + +extension HTTP1ConnectionStateMachine.Action.FinalFailedStreamAction: Equatable { + public static func == (lhs: HTTP1ConnectionStateMachine.Action.FinalFailedStreamAction, rhs: HTTP1ConnectionStateMachine.Action.FinalFailedStreamAction) -> Bool { + switch (lhs, rhs) { + case (.close(let lhsPromise), .close(let rhsPromise)): + return lhsPromise?.futureResult == rhsPromise?.futureResult + case (.informConnectionIsIdle, .informConnectionIsIdle): + return true + case (.failWritePromise(let lhsPromise), .failWritePromise(let rhsPromise)): + return lhsPromise?.futureResult == rhsPromise?.futureResult + case (.none, .none): + return true + default: + return false + } + } +} diff --git a/Tests/AsyncHTTPClientTests/HTTPRequestStateMachineTests+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTPRequestStateMachineTests+XCTest.swift index b54865fd8..ad85bd71e 100644 --- a/Tests/AsyncHTTPClientTests/HTTPRequestStateMachineTests+XCTest.swift +++ b/Tests/AsyncHTTPClientTests/HTTPRequestStateMachineTests+XCTest.swift @@ -30,6 +30,7 @@ extension HTTPRequestStateMachineTests { ("testPOSTContentLengthIsTooLong", testPOSTContentLengthIsTooLong), ("testPOSTContentLengthIsTooShort", testPOSTContentLengthIsTooShort), ("testRequestBodyStreamIsCancelledIfServerRespondsWith301", testRequestBodyStreamIsCancelledIfServerRespondsWith301), + ("testStreamPartReceived_whenCancelled", testStreamPartReceived_whenCancelled), ("testRequestBodyStreamIsCancelledIfServerRespondsWith301WhileWriteBackpressure", testRequestBodyStreamIsCancelledIfServerRespondsWith301WhileWriteBackpressure), ("testRequestBodyStreamIsContinuedIfServerRespondsWith200", testRequestBodyStreamIsContinuedIfServerRespondsWith200), ("testRequestBodyStreamIsContinuedIfServerSendHeadWithStatus200", testRequestBodyStreamIsContinuedIfServerSendHeadWithStatus200), diff --git a/Tests/AsyncHTTPClientTests/HTTPRequestStateMachineTests.swift b/Tests/AsyncHTTPClientTests/HTTPRequestStateMachineTests.swift index ab55345c9..a68d58aa0 100644 --- a/Tests/AsyncHTTPClientTests/HTTPRequestStateMachineTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPRequestStateMachineTests.swift @@ -14,6 +14,7 @@ @testable import AsyncHTTPClient import NIOCore +import NIOEmbedded import NIOHTTP1 import NIOSSL import XCTest @@ -42,22 +43,22 @@ class HTTPRequestStateMachineTests: XCTestCase { let part1 = IOData.byteBuffer(ByteBuffer(bytes: [1])) let part2 = IOData.byteBuffer(ByteBuffer(bytes: [2])) let part3 = IOData.byteBuffer(ByteBuffer(bytes: [3])) - XCTAssertEqual(state.requestStreamPartReceived(part0), .sendBodyPart(part0)) - XCTAssertEqual(state.requestStreamPartReceived(part1), .sendBodyPart(part1)) + XCTAssertEqual(state.requestStreamPartReceived(part0, promise: nil), .sendBodyPart(part0, nil)) + XCTAssertEqual(state.requestStreamPartReceived(part1, promise: nil), .sendBodyPart(part1, nil)) // oh the channel reports... we should slow down producing... XCTAssertEqual(state.writabilityChanged(writable: false), .pauseRequestBodyStream) // but we issued a .produceMoreRequestBodyData before... Thus, we must accept more produced // data - XCTAssertEqual(state.requestStreamPartReceived(part2), .sendBodyPart(part2)) + XCTAssertEqual(state.requestStreamPartReceived(part2, promise: nil), .sendBodyPart(part2, nil)) // however when we have put the data on the channel, we should not issue further // .produceMoreRequestBodyData events // once we receive a writable event again, we can allow the producer to produce more data XCTAssertEqual(state.writabilityChanged(writable: true), .resumeRequestBodyStream) - XCTAssertEqual(state.requestStreamPartReceived(part3), .sendBodyPart(part3)) - XCTAssertEqual(state.requestStreamFinished(), .sendRequestEnd) + XCTAssertEqual(state.requestStreamPartReceived(part3, promise: nil), .sendBodyPart(part3, nil)) + XCTAssertEqual(state.requestStreamFinished(promise: nil), .sendRequestEnd(nil)) let responseHead = HTTPResponseHead(version: .http1_1, status: .ok) XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) @@ -74,9 +75,9 @@ class HTTPRequestStateMachineTests: XCTestCase { XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: true)) let part0 = IOData.byteBuffer(ByteBuffer(bytes: [0, 1, 2, 3])) let part1 = IOData.byteBuffer(ByteBuffer(bytes: [0, 1, 2, 3])) - XCTAssertEqual(state.requestStreamPartReceived(part0), .sendBodyPart(part0)) + XCTAssertEqual(state.requestStreamPartReceived(part0, promise: nil), .sendBodyPart(part0, nil)) - state.requestStreamPartReceived(part1).assertFailRequest(HTTPClientError.bodyLengthMismatch, .close) + state.requestStreamPartReceived(part1, promise: nil).assertFailRequest(HTTPClientError.bodyLengthMismatch, .close(nil)) // if another error happens the new one is ignored XCTAssertEqual(state.errorHappened(HTTPClientError.remoteConnectionClosed), .wait) @@ -88,9 +89,9 @@ class HTTPRequestStateMachineTests: XCTestCase { let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(8)) XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: true)) let part0 = IOData.byteBuffer(ByteBuffer(bytes: [0, 1, 2, 3])) - XCTAssertEqual(state.requestStreamPartReceived(part0), .sendBodyPart(part0)) + XCTAssertEqual(state.requestStreamPartReceived(part0, promise: nil), .sendBodyPart(part0, nil)) - state.requestStreamFinished().assertFailRequest(HTTPClientError.bodyLengthMismatch, .close) + state.requestStreamFinished(promise: nil).assertFailRequest(HTTPClientError.bodyLengthMismatch, .close(nil)) } func testRequestBodyStreamIsCancelledIfServerRespondsWith301() { @@ -99,22 +100,31 @@ class HTTPRequestStateMachineTests: XCTestCase { let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(12)) XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: true)) let part = IOData.byteBuffer(ByteBuffer(bytes: [0, 1, 2, 3])) - XCTAssertEqual(state.requestStreamPartReceived(part), .sendBodyPart(part)) + XCTAssertEqual(state.requestStreamPartReceived(part, promise: nil), .sendBodyPart(part, nil)) // response is coming before having send all data let responseHead = HTTPResponseHead(version: .http1_1, status: .movedPermanently) XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: true)) XCTAssertEqual(state.writabilityChanged(writable: false), .wait) XCTAssertEqual(state.writabilityChanged(writable: true), .wait) - XCTAssertEqual(state.requestStreamPartReceived(part), .wait, + XCTAssertEqual(state.requestStreamPartReceived(part, promise: nil), .failSendBodyPart(HTTPClientError.requestStreamCancelled, nil), "Expected to drop all stream data after having received a response head, with status >= 300") XCTAssertEqual(state.channelRead(.end(nil)), .succeedRequest(.close, .init())) - XCTAssertEqual(state.requestStreamPartReceived(part), .wait, + XCTAssertEqual(state.requestStreamPartReceived(part, promise: nil), .failSendBodyPart(HTTPClientError.requestStreamCancelled, nil), "Expected to drop all stream data after having received a response head, with status >= 300") - XCTAssertEqual(state.requestStreamFinished(), .wait, + XCTAssertEqual(state.requestStreamFinished(promise: nil), .failSendStreamFinished(HTTPClientError.requestStreamCancelled, nil), + "Expected to drop all stream data after having received a response head, with status >= 300") + } + + func testStreamPartReceived_whenCancelled() { + var state = HTTPRequestStateMachine(isChannelWritable: false) + let part = IOData.byteBuffer(ByteBuffer(bytes: [0, 1, 2, 3])) + + XCTAssertEqual(state.requestCancelled(), .failRequest(HTTPClientError.cancelled, .none)) + XCTAssertEqual(state.requestStreamPartReceived(part, promise: nil), .failSendBodyPart(HTTPClientError.cancelled, nil), "Expected to drop all stream data after having received a response head, with status >= 300") } @@ -124,22 +134,22 @@ class HTTPRequestStateMachineTests: XCTestCase { let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(12)) XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: true)) let part = IOData.byteBuffer(ByteBuffer(bytes: [0, 1, 2, 3])) - XCTAssertEqual(state.requestStreamPartReceived(part), .sendBodyPart(part)) + XCTAssertEqual(state.requestStreamPartReceived(part, promise: nil), .sendBodyPart(part, nil)) XCTAssertEqual(state.writabilityChanged(writable: false), .pauseRequestBodyStream) // response is coming before having send all data let responseHead = HTTPResponseHead(version: .http1_1, status: .movedPermanently) XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) XCTAssertEqual(state.writabilityChanged(writable: true), .wait) - XCTAssertEqual(state.requestStreamPartReceived(part), .wait, + XCTAssertEqual(state.requestStreamPartReceived(part, promise: nil), .failSendBodyPart(HTTPClientError.requestStreamCancelled, nil), "Expected to drop all stream data after having received a response head, with status >= 300") XCTAssertEqual(state.channelRead(.end(nil)), .succeedRequest(.close, .init())) - XCTAssertEqual(state.requestStreamPartReceived(part), .wait, + XCTAssertEqual(state.requestStreamPartReceived(part, promise: nil), .failSendBodyPart(HTTPClientError.requestStreamCancelled, nil), "Expected to drop all stream data after having received a response head, with status >= 300") - XCTAssertEqual(state.requestStreamFinished(), .wait, + XCTAssertEqual(state.requestStreamFinished(promise: nil), .failSendStreamFinished(HTTPClientError.requestStreamCancelled, nil), "Expected to drop all stream data after having received a response head, with status >= 300") } @@ -149,7 +159,7 @@ class HTTPRequestStateMachineTests: XCTestCase { let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(12)) XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: true)) let part0 = IOData.byteBuffer(ByteBuffer(bytes: 0...3)) - XCTAssertEqual(state.requestStreamPartReceived(part0), .sendBodyPart(part0)) + XCTAssertEqual(state.requestStreamPartReceived(part0, promise: nil), .sendBodyPart(part0, nil)) // response is coming before having send all data let responseHead = HTTPResponseHead(version: .http1_1, status: .ok) @@ -157,10 +167,12 @@ class HTTPRequestStateMachineTests: XCTestCase { XCTAssertEqual(state.channelRead(.end(nil)), .forwardResponseBodyParts(.init())) let part1 = IOData.byteBuffer(ByteBuffer(bytes: 4...7)) - XCTAssertEqual(state.requestStreamPartReceived(part1), .sendBodyPart(part1)) + XCTAssertEqual(state.requestStreamPartReceived(part1, promise: nil), .sendBodyPart(part1, nil)) let part2 = IOData.byteBuffer(ByteBuffer(bytes: 8...11)) - XCTAssertEqual(state.requestStreamPartReceived(part2), .sendBodyPart(part2)) - XCTAssertEqual(state.requestStreamFinished(), .succeedRequest(.sendRequestEnd, .init())) + XCTAssertEqual(state.requestStreamPartReceived(part2, promise: nil), .sendBodyPart(part2, nil)) + XCTAssertEqual(state.requestStreamFinished(promise: nil), .succeedRequest(.sendRequestEnd(nil), .init())) + + XCTAssertEqual(state.requestStreamPartReceived(part2, promise: nil), .failSendBodyPart(HTTPClientError.requestStreamCancelled, nil)) } func testRequestBodyStreamIsContinuedIfServerSendHeadWithStatus200() { @@ -169,17 +181,17 @@ class HTTPRequestStateMachineTests: XCTestCase { let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(12)) XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: true)) let part0 = IOData.byteBuffer(ByteBuffer(bytes: 0...3)) - XCTAssertEqual(state.requestStreamPartReceived(part0), .sendBodyPart(part0)) + XCTAssertEqual(state.requestStreamPartReceived(part0, promise: nil), .sendBodyPart(part0, nil)) // response is coming before having send all data let responseHead = HTTPResponseHead(version: .http1_1, status: .ok) XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) let part1 = IOData.byteBuffer(ByteBuffer(bytes: 4...7)) - XCTAssertEqual(state.requestStreamPartReceived(part1), .sendBodyPart(part1)) + XCTAssertEqual(state.requestStreamPartReceived(part1, promise: nil), .sendBodyPart(part1, nil)) let part2 = IOData.byteBuffer(ByteBuffer(bytes: 8...11)) - XCTAssertEqual(state.requestStreamPartReceived(part2), .sendBodyPart(part2)) - XCTAssertEqual(state.requestStreamFinished(), .sendRequestEnd) + XCTAssertEqual(state.requestStreamPartReceived(part2, promise: nil), .sendBodyPart(part2, nil)) + XCTAssertEqual(state.requestStreamFinished(promise: nil), .sendRequestEnd(nil)) XCTAssertEqual(state.channelRead(.end(nil)), .succeedRequest(.none, .init())) } @@ -190,7 +202,7 @@ class HTTPRequestStateMachineTests: XCTestCase { let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(12)) XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: true)) let part0 = IOData.byteBuffer(ByteBuffer(bytes: 0...3)) - XCTAssertEqual(state.requestStreamPartReceived(part0), .sendBodyPart(part0)) + XCTAssertEqual(state.requestStreamPartReceived(part0, promise: nil), .sendBodyPart(part0, nil)) // response is coming before having send all data let responseHead = HTTPResponseHead(version: .http1_1, status: .ok) @@ -198,8 +210,8 @@ class HTTPRequestStateMachineTests: XCTestCase { XCTAssertEqual(state.channelRead(.end(nil)), .forwardResponseBodyParts(.init())) let part1 = IOData.byteBuffer(ByteBuffer(bytes: 4...7)) - XCTAssertEqual(state.requestStreamPartReceived(part1), .sendBodyPart(part1)) - state.requestStreamFinished().assertFailRequest(HTTPClientError.bodyLengthMismatch, .close) + XCTAssertEqual(state.requestStreamPartReceived(part1, promise: nil), .sendBodyPart(part1, nil)) + state.requestStreamFinished(promise: nil).assertFailRequest(HTTPClientError.bodyLengthMismatch, .close(nil)) XCTAssertEqual(state.channelInactive(), .wait) } @@ -209,15 +221,15 @@ class HTTPRequestStateMachineTests: XCTestCase { let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(12)) XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: true)) let part0 = IOData.byteBuffer(ByteBuffer(bytes: 0...3)) - XCTAssertEqual(state.requestStreamPartReceived(part0), .sendBodyPart(part0)) + XCTAssertEqual(state.requestStreamPartReceived(part0, promise: nil), .sendBodyPart(part0, nil)) // response is coming before having send all data let responseHead = HTTPResponseHead(version: .http1_1, status: .ok) XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) let part1 = IOData.byteBuffer(ByteBuffer(bytes: 4...7)) - XCTAssertEqual(state.requestStreamPartReceived(part1), .sendBodyPart(part1)) - state.requestStreamFinished().assertFailRequest(HTTPClientError.bodyLengthMismatch, .close) + XCTAssertEqual(state.requestStreamPartReceived(part1, promise: nil), .sendBodyPart(part1, nil)) + state.requestStreamFinished(promise: nil).assertFailRequest(HTTPClientError.bodyLengthMismatch, .close(nil)) XCTAssertEqual(state.channelRead(.end(nil)), .wait) } @@ -366,7 +378,7 @@ class HTTPRequestStateMachineTests: XCTestCase { let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: false)) - state.requestCancelled().assertFailRequest(HTTPClientError.cancelled, .close) + state.requestCancelled().assertFailRequest(HTTPClientError.cancelled, .close(nil)) } func testRemoteSuddenlyClosesTheConnection() { @@ -374,8 +386,8 @@ class HTTPRequestStateMachineTests: XCTestCase { let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/", headers: .init([("content-length", "4")])) let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(4)) XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: true)) - state.requestCancelled().assertFailRequest(HTTPClientError.cancelled, .close) - XCTAssertEqual(state.requestStreamPartReceived(.byteBuffer(.init(bytes: 1...3))), .wait) + state.requestCancelled().assertFailRequest(HTTPClientError.cancelled, .close(nil)) + XCTAssertEqual(state.requestStreamPartReceived(.byteBuffer(.init(bytes: 1...3)), promise: nil), .failSendBodyPart(HTTPClientError.cancelled, nil)) } func testReadTimeoutLeadsToFailureWithEverythingAfterBeingIgnored() { @@ -388,7 +400,7 @@ class HTTPRequestStateMachineTests: XCTestCase { XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) let part0 = ByteBuffer(bytes: 0...3) XCTAssertEqual(state.channelRead(.body(part0)), .wait) - state.idleReadTimeoutTriggered().assertFailRequest(HTTPClientError.readTimeout, .close) + state.idleReadTimeoutTriggered().assertFailRequest(HTTPClientError.readTimeout, .close(nil)) XCTAssertEqual(state.channelRead(.body(ByteBuffer(bytes: 4...7))), .wait) XCTAssertEqual(state.channelRead(.body(ByteBuffer(bytes: 8...11))), .wait) XCTAssertEqual(state.demandMoreResponseBodyParts(), .wait) @@ -441,7 +453,7 @@ class HTTPRequestStateMachineTests: XCTestCase { let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: false)) - state.errorHappened(HTTPParserError.invalidChunkSize).assertFailRequest(HTTPParserError.invalidChunkSize, .close) + state.errorHappened(HTTPParserError.invalidChunkSize).assertFailRequest(HTTPParserError.invalidChunkSize, .close(nil)) XCTAssertEqual(state.requestCancelled(), .wait, "A cancellation that happens to late is ignored") } @@ -486,7 +498,7 @@ class HTTPRequestStateMachineTests: XCTestCase { XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: true)) let part1: ByteBuffer = .init(string: "foo") - XCTAssertEqual(state.requestStreamPartReceived(.byteBuffer(part1)), .sendBodyPart(.byteBuffer(part1))) + XCTAssertEqual(state.requestStreamPartReceived(.byteBuffer(part1), promise: nil), .sendBodyPart(.byteBuffer(part1), nil)) let responseHead = HTTPResponseHead(version: .http1_0, status: .ok) let body = ByteBuffer(string: "foo bar") XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) @@ -495,7 +507,7 @@ class HTTPRequestStateMachineTests: XCTestCase { XCTAssertEqual(state.read(), .read) XCTAssertEqual(state.channelReadComplete(), .wait) XCTAssertEqual(state.channelRead(.body(body)), .wait) - state.channelRead(.end(nil)).assertFailRequest(HTTPClientError.remoteConnectionClosed, .close) + state.channelRead(.end(nil)).assertFailRequest(HTTPClientError.remoteConnectionClosed, .close(nil)) XCTAssertEqual(state.channelInactive(), .wait) } @@ -510,7 +522,7 @@ class HTTPRequestStateMachineTests: XCTestCase { XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) XCTAssertEqual(state.demandMoreResponseBodyParts(), .wait) XCTAssertEqual(state.channelRead(.body(body)), .wait) - state.errorHappened(NIOSSLError.uncleanShutdown).assertFailRequest(NIOSSLError.uncleanShutdown, .close) + state.errorHappened(NIOSSLError.uncleanShutdown).assertFailRequest(NIOSSLError.uncleanShutdown, .close(nil)) XCTAssertEqual(state.channelRead(.end(nil)), .wait) XCTAssertEqual(state.channelInactive(), .wait) } @@ -532,7 +544,7 @@ class HTTPRequestStateMachineTests: XCTestCase { let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: false)) - state.errorHappened(ArbitraryError()).assertFailRequest(ArbitraryError(), .close) + state.errorHappened(ArbitraryError()).assertFailRequest(ArbitraryError(), .close(nil)) XCTAssertEqual(state.channelInactive(), .wait) } @@ -550,7 +562,7 @@ class HTTPRequestStateMachineTests: XCTestCase { XCTAssertEqual(state.channelRead(.body(body)), .wait) XCTAssertEqual(state.channelReadComplete(), .forwardResponseBodyParts([body])) XCTAssertEqual(state.errorHappened(NIOSSLError.uncleanShutdown), .wait) - state.errorHappened(HTTPParserError.invalidEOFState).assertFailRequest(HTTPParserError.invalidEOFState, .close) + state.errorHappened(HTTPParserError.invalidEOFState).assertFailRequest(HTTPParserError.invalidEOFState, .close(nil)) XCTAssertEqual(state.channelInactive(), .wait) } @@ -656,11 +668,11 @@ extension HTTPRequestStateMachine.Action: Equatable { case (.sendRequestHead(let lhsHead, let lhsStartBody), .sendRequestHead(let rhsHead, let rhsStartBody)): return lhsHead == rhsHead && lhsStartBody == rhsStartBody - case (.sendBodyPart(let lhsData), .sendBodyPart(let rhsData)): - return lhsData == rhsData + case (.sendBodyPart(let lhsData, let lhsPromise), .sendBodyPart(let rhsData, let rhsPromise)): + return lhsData == rhsData && lhsPromise?.futureResult == rhsPromise?.futureResult - case (.sendRequestEnd, .sendRequestEnd): - return true + case (.sendRequestEnd(let lhsPromise), .sendRequestEnd(let rhsPromise)): + return lhsPromise?.futureResult == rhsPromise?.futureResult case (.pauseRequestBodyStream, .pauseRequestBodyStream): return true @@ -685,6 +697,45 @@ extension HTTPRequestStateMachine.Action: Equatable { case (.wait, .wait): return true + case (.failSendBodyPart(let lhsError as HTTPClientError, let lhsPromise), .failSendBodyPart(let rhsError as HTTPClientError, let rhsPromise)): + return lhsError == rhsError && lhsPromise?.futureResult == rhsPromise?.futureResult + + case (.failSendStreamFinished(let lhsError as HTTPClientError, let lhsPromise), .failSendStreamFinished(let rhsError as HTTPClientError, let rhsPromise)): + return lhsError == rhsError && lhsPromise?.futureResult == rhsPromise?.futureResult + + default: + return false + } + } +} + +extension HTTPRequestStateMachine.Action.FinalSuccessfulRequestAction: Equatable { + public static func == (lhs: HTTPRequestStateMachine.Action.FinalSuccessfulRequestAction, rhs: HTTPRequestStateMachine.Action.FinalSuccessfulRequestAction) -> Bool { + switch (lhs, rhs) { + case (.close, close): + return true + + case (.sendRequestEnd(let lhsPromise), .sendRequestEnd(let rhsPromise)): + return lhsPromise?.futureResult == rhsPromise?.futureResult + + case (.none, .none): + return true + + default: + return false + } + } +} + +extension HTTPRequestStateMachine.Action.FinalFailedRequestAction: Equatable { + public static func == (lhs: HTTPRequestStateMachine.Action.FinalFailedRequestAction, rhs: HTTPRequestStateMachine.Action.FinalFailedRequestAction) -> Bool { + switch (lhs, rhs) { + case (.close(let lhsPromise), close(let rhsPromise)): + return lhsPromise?.futureResult == rhsPromise?.futureResult + + case (.none, .none): + return true + default: return false } @@ -694,7 +745,7 @@ extension HTTPRequestStateMachine.Action: Equatable { extension HTTPRequestStateMachine.Action { fileprivate func assertFailRequest( _ expectedError: Error, - _ expectedFinalStreamAction: HTTPRequestStateMachine.Action.FinalStreamAction, + _ expectedFinalStreamAction: HTTPRequestStateMachine.Action.FinalFailedRequestAction, file: StaticString = #file, line: UInt = #line ) where Error: Swift.Error & Equatable { diff --git a/Tests/AsyncHTTPClientTests/Mocks/MockRequestExecutor.swift b/Tests/AsyncHTTPClientTests/Mocks/MockRequestExecutor.swift index b5b67c809..b37ce8fa3 100644 --- a/Tests/AsyncHTTPClientTests/Mocks/MockRequestExecutor.swift +++ b/Tests/AsyncHTTPClientTests/Mocks/MockRequestExecutor.swift @@ -184,12 +184,14 @@ extension MockRequestExecutor: HTTPRequestExecutor { // this should always be called twice. When we receive the first call, the next call to produce // data is already scheduled. If we call pause here, once, after the second call new subsequent // calls should not be scheduled. - func writeRequestBodyPart(_ part: IOData, request: HTTPExecutableRequest) { + func writeRequestBodyPart(_ part: IOData, request: HTTPExecutableRequest, promise: EventLoopPromise?) { self.writeNextRequestPart(.body(part), request: request) + promise?.succeed(()) } - func finishRequestBodyStream(_ request: HTTPExecutableRequest) { + func finishRequestBodyStream(_ request: HTTPExecutableRequest, promise: EventLoopPromise?) { self.writeNextRequestPart(.endOfStream, request: request) + promise?.succeed(()) } private func writeNextRequestPart(_ part: RequestParts, request: HTTPExecutableRequest) {