Skip to content

Commit

Permalink
Fix state reverting (#298)
Browse files Browse the repository at this point in the history
* fail if we get part when state is endOrError

* Prevent TaskHandler state change after `.endOrError`

Motivation:
Right now if task handler encounters an error, it changes state to
`.endOrError`. We gate on that state to make sure that we do not
process errors in the pipeline twice. Unfortunately, that state
can be reset when we upload body or receive response parts.

Modifications:
Adds state validation before state is updated to a new value
Adds a test

Result:
Fixes #297
  • Loading branch information
artemredkin authored Aug 24, 2020
1 parent ffcd1e1 commit 4b4d660
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 3 deletions.
19 changes: 16 additions & 3 deletions Sources/AsyncHTTPClient/HTTPHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -839,16 +839,23 @@ extension TaskHandler: ChannelDuplexHandler {
}.flatMap {
self.writeBody(request: request, context: context)
}.flatMap {
self.state = .bodySent
context.eventLoop.assertInEventLoop()
if case .endOrError = self.state {
return context.eventLoop.makeSucceededFuture(())
}

self.state = .bodySent
if let expectedBodyLength = self.expectedBodyLength, expectedBodyLength != self.actualBodyLength {
let error = HTTPClientError.bodyLengthMismatch
self.errorCaught(context: context, error: error)
return context.eventLoop.makeFailedFuture(error)
}
return context.writeAndFlush(self.wrapOutboundOut(.end(nil)))
}.map {
context.eventLoop.assertInEventLoop()
if case .endOrError = self.state {
return
}

self.state = .sent
self.callOutToDelegateFireAndForget(self.delegate.didSendRequest)
}.flatMapErrorThrowing { error in
Expand Down Expand Up @@ -924,6 +931,10 @@ extension TaskHandler: ChannelDuplexHandler {
let response = self.unwrapInboundIn(data)
switch response {
case .head(let head):
if case .endOrError = self.state {
return
}

if !head.isKeepAlive {
self.closing = true
}
Expand All @@ -940,7 +951,7 @@ extension TaskHandler: ChannelDuplexHandler {
}
case .body(let body):
switch self.state {
case .redirected:
case .redirected, .endOrError:
break
default:
self.state = .body
Expand All @@ -952,6 +963,8 @@ extension TaskHandler: ChannelDuplexHandler {
}
case .end:
switch self.state {
case .endOrError:
break
case .redirected(let head, let redirectURL):
self.state = .endOrError
self.task.releaseAssociatedConnection(delegateType: Delegate.self, closing: self.closing).whenSuccess {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ extension HTTPClientInternalTests {
("testInternalRequestURI", testInternalRequestURI),
("testBodyPartStreamStateChangedBeforeNotification", testBodyPartStreamStateChangedBeforeNotification),
("testHandlerDoubleError", testHandlerDoubleError),
("testTaskHandlerStateChangeAfterError", testTaskHandlerStateChangeAfterError),
]
}
}
44 changes: 44 additions & 0 deletions Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -1119,4 +1119,48 @@ class HTTPClientInternalTests: XCTestCase {

XCTAssertEqual(delegate.count, 1)
}

func testTaskHandlerStateChangeAfterError() throws {
let channel = EmbeddedChannel()
let task = Task<Void>(eventLoop: channel.eventLoop, logger: HTTPClient.loggingDisabled)

let handler = TaskHandler(task: task,
kind: .host,
delegate: TestHTTPDelegate(),
redirectHandler: nil,
ignoreUncleanSSLShutdown: false,
logger: HTTPClient.loggingDisabled)

try channel.pipeline.addHandler(handler).wait()

var request = try Request(url: "http://localhost:8080/get")
request.headers.add(name: "X-Test-Header", value: "X-Test-Value")
request.body = .stream(length: 4) { writer in
writer.write(.byteBuffer(channel.allocator.buffer(string: "1234"))).map {
handler.state = .endOrError
}
}

XCTAssertNoThrow(try channel.writeOutbound(request))

try channel.writeInbound(HTTPClientResponsePart.head(.init(version: .init(major: 1, minor: 1), status: .ok)))
XCTAssertTrue(handler.state.isEndOrError)

try channel.writeInbound(HTTPClientResponsePart.body(channel.allocator.buffer(string: "1234")))
XCTAssertTrue(handler.state.isEndOrError)

try channel.writeInbound(HTTPClientResponsePart.end(nil))
XCTAssertTrue(handler.state.isEndOrError)
}
}

extension TaskHandler.State {
var isEndOrError: Bool {
switch self {
case .endOrError:
return true
default:
return false
}
}
}

0 comments on commit 4b4d660

Please sign in to comment.