diff --git a/Sources/AsyncHTTPClient/HTTPHandler.swift b/Sources/AsyncHTTPClient/HTTPHandler.swift index 12e6a4fc4..361f61159 100644 --- a/Sources/AsyncHTTPClient/HTTPHandler.swift +++ b/Sources/AsyncHTTPClient/HTTPHandler.swift @@ -839,16 +839,23 @@ extension TaskHandler: ChannelDuplexHandler { }.flatMap { self.writeBody(request: request, context: context) }.flatMap { - self.state = .bodySent context.eventLoop.assertInEventLoop() + if case .endOrError = self.state { + return context.eventLoop.makeSucceededFuture(()) + } + + self.state = .bodySent if let expectedBodyLength = self.expectedBodyLength, expectedBodyLength != self.actualBodyLength { let error = HTTPClientError.bodyLengthMismatch - self.errorCaught(context: context, error: error) return context.eventLoop.makeFailedFuture(error) } return context.writeAndFlush(self.wrapOutboundOut(.end(nil))) }.map { context.eventLoop.assertInEventLoop() + if case .endOrError = self.state { + return + } + self.state = .sent self.callOutToDelegateFireAndForget(self.delegate.didSendRequest) }.flatMapErrorThrowing { error in @@ -924,6 +931,10 @@ extension TaskHandler: ChannelDuplexHandler { let response = self.unwrapInboundIn(data) switch response { case .head(let head): + if case .endOrError = self.state { + return + } + if !head.isKeepAlive { self.closing = true } @@ -940,7 +951,7 @@ extension TaskHandler: ChannelDuplexHandler { } case .body(let body): switch self.state { - case .redirected: + case .redirected, .endOrError: break default: self.state = .body @@ -952,6 +963,8 @@ extension TaskHandler: ChannelDuplexHandler { } case .end: switch self.state { + case .endOrError: + break case .redirected(let head, let redirectURL): self.state = .endOrError self.task.releaseAssociatedConnection(delegateType: Delegate.self, closing: self.closing).whenSuccess { diff --git a/Tests/AsyncHTTPClientTests/HTTPClientInternalTests+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTPClientInternalTests+XCTest.swift index 648eb8078..839a68460 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientInternalTests+XCTest.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientInternalTests+XCTest.swift @@ -47,6 +47,7 @@ extension HTTPClientInternalTests { ("testInternalRequestURI", testInternalRequestURI), ("testBodyPartStreamStateChangedBeforeNotification", testBodyPartStreamStateChangedBeforeNotification), ("testHandlerDoubleError", testHandlerDoubleError), + ("testTaskHandlerStateChangeAfterError", testTaskHandlerStateChangeAfterError), ] } } diff --git a/Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift b/Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift index 706a3bbd7..803824a0c 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift @@ -1119,4 +1119,48 @@ class HTTPClientInternalTests: XCTestCase { XCTAssertEqual(delegate.count, 1) } + + func testTaskHandlerStateChangeAfterError() throws { + let channel = EmbeddedChannel() + let task = Task<Void>(eventLoop: channel.eventLoop, logger: HTTPClient.loggingDisabled) + + let handler = TaskHandler(task: task, + kind: .host, + delegate: TestHTTPDelegate(), + redirectHandler: nil, + ignoreUncleanSSLShutdown: false, + logger: HTTPClient.loggingDisabled) + + try channel.pipeline.addHandler(handler).wait() + + var request = try Request(url: "http://localhost:8080/get") + request.headers.add(name: "X-Test-Header", value: "X-Test-Value") + request.body = .stream(length: 4) { writer in + writer.write(.byteBuffer(channel.allocator.buffer(string: "1234"))).map { + handler.state = .endOrError + } + } + + XCTAssertNoThrow(try channel.writeOutbound(request)) + + try channel.writeInbound(HTTPClientResponsePart.head(.init(version: .init(major: 1, minor: 1), status: .ok))) + XCTAssertTrue(handler.state.isEndOrError) + + try channel.writeInbound(HTTPClientResponsePart.body(channel.allocator.buffer(string: "1234"))) + XCTAssertTrue(handler.state.isEndOrError) + + try channel.writeInbound(HTTPClientResponsePart.end(nil)) + XCTAssertTrue(handler.state.isEndOrError) + } +} + +extension TaskHandler.State { + var isEndOrError: Bool { + switch self { + case .endOrError: + return true + default: + return false + } + } }