Skip to content

Commit

Permalink
[HTTPRequestStateMachine] Allow channelReadComplete at any time (#450)
Browse files Browse the repository at this point in the history
  • Loading branch information
fabianfett authored Oct 1, 2021
1 parent a6ca288 commit d928cc8
Show file tree
Hide file tree
Showing 5 changed files with 170 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,21 @@ extension HTTPRequestStateMachine {
return buffer
}

// For all the following cases, please note:
// Normally these code paths should never be hit. However there is one way to trigger
// this:
//
// If the connection to a server is closed, NIO will forward all outstanding
// `channelRead`s without waiting for a next `context.read` call. After all
// `channelRead`s are delivered, we will also see a `channelReadComplete` call. After
// this has happened, we know that we will get a channelInactive or further
// `channelReads`. If the request ever gets to an `.end` all buffered data will be
// forwarded to the user.

case .waitingForRead,
.waitingForDemand,
.waitingForReadOrDemand:
preconditionFailure("How can we receive a body part, after a channelReadComplete, but no read has been forwarded yet. Invalid state: \(self.state)")
return nil

case .modifying:
preconditionFailure("Invalid state: \(self.state)")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ extension HTTP1ClientChannelHandlerTests {
("testWriteBackpressure", testWriteBackpressure),
("testClientHandlerCancelsRequestIfWeWantToShutdown", testClientHandlerCancelsRequestIfWeWantToShutdown),
("testIdleReadTimeout", testIdleReadTimeout),
("testFailHTTPRequestWithContentLengthBecauseOfChannelInactiveWaitingForDemand", testFailHTTPRequestWithContentLengthBecauseOfChannelInactiveWaitingForDemand),
]
}
}
58 changes: 58 additions & 0 deletions Tests/AsyncHTTPClientTests/HTTP1ClientChannelHandlerTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,64 @@ class HTTP1ClientChannelHandlerTests: XCTestCase {
XCTAssertEqual($0 as? HTTPClientError, .readTimeout)
}
}

func testFailHTTPRequestWithContentLengthBecauseOfChannelInactiveWaitingForDemand() {
let embedded = EmbeddedChannel()
var maybeTestUtils: HTTP1TestTools?
XCTAssertNoThrow(maybeTestUtils = try embedded.setupHTTP1Connection())
guard let testUtils = maybeTestUtils else { return XCTFail("Expected connection setup works") }

var maybeRequest: HTTPClient.Request?
XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "http://localhost/"))
guard let request = maybeRequest else { return XCTFail("Expected to be able to create a request") }

let delegate = ResponseBackpressureDelegate(eventLoop: embedded.eventLoop)
var maybeRequestBag: RequestBag<ResponseBackpressureDelegate>?
XCTAssertNoThrow(maybeRequestBag = try RequestBag(
request: request,
eventLoopPreference: .delegate(on: embedded.eventLoop),
task: .init(eventLoop: embedded.eventLoop, logger: testUtils.logger),
redirectHandler: nil,
connectionDeadline: .now() + .seconds(30),
requestOptions: .forTests(),
delegate: delegate
))
guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") }

testUtils.connection.executeRequest(requestBag)

XCTAssertNoThrow(try embedded.receiveHeadAndVerify {
XCTAssertEqual($0.method, .GET)
XCTAssertEqual($0.uri, "/")
XCTAssertEqual($0.headers.first(name: "host"), "localhost")
})
XCTAssertNoThrow(try embedded.receiveEnd())

let responseHead = HTTPResponseHead(version: .http1_1, status: .ok, headers: HTTPHeaders([("content-length", "50")]))

XCTAssertEqual(testUtils.readEventHandler.readHitCounter, 0)
embedded.read()
XCTAssertEqual(testUtils.readEventHandler.readHitCounter, 1)
XCTAssertNoThrow(try embedded.writeInbound(HTTPClientResponsePart.head(responseHead)))

// not sending anything after the head should lead to request fail and connection close
embedded.pipeline.fireChannelReadComplete()
embedded.pipeline.read()
XCTAssertEqual(testUtils.readEventHandler.readHitCounter, 2)

XCTAssertNoThrow(try embedded.writeInbound(HTTPClientResponsePart.body(ByteBuffer(string: "foo bar"))))
embedded.pipeline.fireChannelReadComplete()
// We miss a `embedded.pipeline.read()` here by purpose.
XCTAssertEqual(testUtils.readEventHandler.readHitCounter, 2)

XCTAssertNoThrow(try embedded.writeInbound(HTTPClientResponsePart.body(ByteBuffer(string: "last bytes"))))
embedded.pipeline.fireChannelReadComplete()
embedded.pipeline.fireChannelInactive()

XCTAssertThrowsError(try requestBag.task.futureResult.wait()) {
XCTAssertEqual($0 as? HTTPClientError, .remoteConnectionClosed)
}
}
}

class TestBackpressureWriter {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ extension HTTPRequestStateMachineTests {
("testFailHTTP1RequestWithoutContentLengthWithNIOSSLErrorUncleanShutdown", testFailHTTP1RequestWithoutContentLengthWithNIOSSLErrorUncleanShutdown),
("testFailHTTP1RequestWithoutContentLengthWithNIOSSLErrorUncleanShutdownButIgnoreIt", testFailHTTP1RequestWithoutContentLengthWithNIOSSLErrorUncleanShutdownButIgnoreIt),
("testFailHTTP1RequestWithContentLengthWithNIOSSLErrorUncleanShutdownButIgnoreIt", testFailHTTP1RequestWithContentLengthWithNIOSSLErrorUncleanShutdownButIgnoreIt),
("testFailHTTPRequestWithContentLengthBecauseOfChannelInactiveWaitingForDemand", testFailHTTPRequestWithContentLengthBecauseOfChannelInactiveWaitingForDemand),
("testFailHTTPRequestWithContentLengthBecauseOfChannelInactiveWaitingForRead", testFailHTTPRequestWithContentLengthBecauseOfChannelInactiveWaitingForRead),
("testFailHTTPRequestWithContentLengthBecauseOfChannelInactiveWaitingForReadAndDemand", testFailHTTPRequestWithContentLengthBecauseOfChannelInactiveWaitingForReadAndDemand),
("testFailHTTPRequestWithContentLengthBecauseOfChannelInactiveWaitingForReadAndDemandMultipleTimes", testFailHTTPRequestWithContentLengthBecauseOfChannelInactiveWaitingForReadAndDemandMultipleTimes),
]
}
}
95 changes: 95 additions & 0 deletions Tests/AsyncHTTPClientTests/HTTPRequestStateMachineTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -557,6 +557,101 @@ class HTTPRequestStateMachineTests: XCTestCase {
XCTAssertEqual(state.errorHappened(HTTPParserError.invalidEOFState), .failRequest(HTTPParserError.invalidEOFState, .close))
XCTAssertEqual(state.channelInactive(), .wait)
}

func testFailHTTPRequestWithContentLengthBecauseOfChannelInactiveWaitingForDemand() {
var state = HTTPRequestStateMachine(isChannelWritable: true, ignoreUncleanSSLShutdown: false)
let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/")
let metadata = RequestFramingMetadata(connectionClose: false, body: .none)
XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: false))

let responseHead = HTTPResponseHead(version: .http1_1, status: .ok, headers: ["Content-Length": "50"])
let body = ByteBuffer(string: "foo bar")
XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false))
XCTAssertEqual(state.demandMoreResponseBodyParts(), .wait)
XCTAssertEqual(state.channelReadComplete(), .wait)
XCTAssertEqual(state.read(), .read)
XCTAssertEqual(state.channelRead(.body(body)), .wait)
XCTAssertEqual(state.channelReadComplete(), .forwardResponseBodyParts([body]))
XCTAssertEqual(state.read(), .wait)

XCTAssertEqual(state.channelRead(.body(ByteBuffer(string: " baz lightyear"))), .wait)
XCTAssertEqual(state.channelReadComplete(), .wait)
XCTAssertEqual(state.channelInactive(), .failRequest(HTTPClientError.remoteConnectionClosed, .none))
}

func testFailHTTPRequestWithContentLengthBecauseOfChannelInactiveWaitingForRead() {
var state = HTTPRequestStateMachine(isChannelWritable: true, ignoreUncleanSSLShutdown: false)
let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/")
let metadata = RequestFramingMetadata(connectionClose: false, body: .none)
XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: false))

let responseHead = HTTPResponseHead(version: .http1_1, status: .ok, headers: ["Content-Length": "50"])
let body = ByteBuffer(string: "foo bar")
XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false))
XCTAssertEqual(state.demandMoreResponseBodyParts(), .wait)
XCTAssertEqual(state.channelReadComplete(), .wait)
XCTAssertEqual(state.read(), .read)
XCTAssertEqual(state.channelRead(.body(body)), .wait)
XCTAssertEqual(state.channelReadComplete(), .forwardResponseBodyParts([body]))
XCTAssertEqual(state.demandMoreResponseBodyParts(), .wait)

XCTAssertEqual(state.channelRead(.body(ByteBuffer(string: " baz lightyear"))), .wait)
XCTAssertEqual(state.channelReadComplete(), .wait)
XCTAssertEqual(state.channelInactive(), .failRequest(HTTPClientError.remoteConnectionClosed, .none))
}

func testFailHTTPRequestWithContentLengthBecauseOfChannelInactiveWaitingForReadAndDemand() {
var state = HTTPRequestStateMachine(isChannelWritable: true, ignoreUncleanSSLShutdown: false)
let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/")
let metadata = RequestFramingMetadata(connectionClose: false, body: .none)
XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: false))

let responseHead = HTTPResponseHead(version: .http1_1, status: .ok, headers: ["Content-Length": "50"])
let body = ByteBuffer(string: "foo bar")
XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false))
XCTAssertEqual(state.demandMoreResponseBodyParts(), .wait)
XCTAssertEqual(state.channelReadComplete(), .wait)
XCTAssertEqual(state.read(), .read)
XCTAssertEqual(state.channelRead(.body(body)), .wait)
XCTAssertEqual(state.channelReadComplete(), .forwardResponseBodyParts([body]))

XCTAssertEqual(state.channelRead(.body(ByteBuffer(string: " baz lightyear"))), .wait)
XCTAssertEqual(state.channelReadComplete(), .wait)
XCTAssertEqual(state.channelInactive(), .failRequest(HTTPClientError.remoteConnectionClosed, .none))
}

func testFailHTTPRequestWithContentLengthBecauseOfChannelInactiveWaitingForReadAndDemandMultipleTimes() {
var state = HTTPRequestStateMachine(isChannelWritable: true, ignoreUncleanSSLShutdown: false)
let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/")
let metadata = RequestFramingMetadata(connectionClose: false, body: .none)
XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, startBody: false))

let responseHead = HTTPResponseHead(version: .http1_1, status: .ok, headers: ["Content-Length": "50"])
let body = ByteBuffer(string: "foo bar")
XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false))
XCTAssertEqual(state.demandMoreResponseBodyParts(), .wait)
XCTAssertEqual(state.channelReadComplete(), .wait)
XCTAssertEqual(state.read(), .read)
XCTAssertEqual(state.channelRead(.body(body)), .wait)
XCTAssertEqual(state.channelReadComplete(), .forwardResponseBodyParts([body]))

let part1 = ByteBuffer(string: "baz lightyear")
XCTAssertEqual(state.channelRead(.body(part1)), .wait)
XCTAssertEqual(state.channelReadComplete(), .wait)

let part2 = ByteBuffer(string: "nearly last")
XCTAssertEqual(state.channelRead(.body(part2)), .wait)
XCTAssertEqual(state.channelReadComplete(), .wait)

let part3 = ByteBuffer(string: "final message")
XCTAssertEqual(state.channelRead(.body(part3)), .wait)
XCTAssertEqual(state.channelReadComplete(), .wait)

XCTAssertEqual(state.channelRead(.end(nil)), .succeedRequest(.close, [part1, part2, part3]))
XCTAssertEqual(state.channelReadComplete(), .wait)

XCTAssertEqual(state.channelInactive(), .wait)
}
}

extension HTTPRequestStateMachine.Action: Equatable {
Expand Down

0 comments on commit d928cc8

Please sign in to comment.