diff --git a/Sources/AsyncHTTPClient/HTTPClient.swift b/Sources/AsyncHTTPClient/HTTPClient.swift index 1cc9b3bc4..0a75a2b63 100644 --- a/Sources/AsyncHTTPClient/HTTPClient.swift +++ b/Sources/AsyncHTTPClient/HTTPClient.swift @@ -504,11 +504,43 @@ public class HTTPClient { /// - delegate: Delegate to process response parts. /// - eventLoop: NIO Event Loop preference. /// - deadline: Point in time by which the request must complete. - public func execute(request: Request, - delegate: Delegate, - eventLoop eventLoopPreference: EventLoopPreference, - deadline: NIODeadline? = nil, - logger originalLogger: Logger?) -> Task { + /// - logger: The logger to use for this request. + public func execute( + request: Request, + delegate: Delegate, + eventLoop eventLoopPreference: EventLoopPreference, + deadline: NIODeadline? = nil, + logger originalLogger: Logger? + ) -> Task { + self._execute( + request: request, + delegate: delegate, + eventLoop: eventLoopPreference, + deadline: deadline, + logger: originalLogger, + redirectState: RedirectState( + self.configuration.redirectConfiguration.mode, + initialURL: request.url.absoluteString + ) + ) + } + + /// Executes an arbitrary HTTP request like `execute`, but with an explicit `redirectState` that tracks followed redirects (`nil` means no redirect handler is installed). + /// + /// - parameters: + /// - request: HTTP request to execute. + /// - delegate: Delegate to process response parts. + /// - eventLoop: NIO Event Loop preference. + /// - deadline: Point in time by which the request must complete. + /// - logger: The logger to use for this request. + func _execute( + request: Request, + delegate: Delegate, + eventLoop eventLoopPreference: EventLoopPreference, + deadline: NIODeadline? = nil, + logger originalLogger: Logger?, + redirectState: RedirectState? + ) -> Task { let logger = (originalLogger ??
HTTPClient.loggingDisabled).attachingRequestInformation(request, requestID: globalRequestID.add(1)) let taskEL: EventLoop switch eventLoopPreference.preference { @@ -543,22 +575,20 @@ public class HTTPClient { return failedTask } - let redirectHandler: RedirectHandler? - switch self.configuration.redirectConfiguration.configuration { - case .follow(let max, let allowCycles): - var request = request - if request.redirectState == nil { - request.redirectState = .init(count: max, visited: allowCycles ? nil : Set()) + let redirectHandler: RedirectHandler? = { + guard let redirectState = redirectState else { return nil } + + return .init(request: request, redirectState: redirectState) { newRequest, newRedirectState in + self._execute( + request: newRequest, + delegate: delegate, + eventLoop: eventLoopPreference, + deadline: deadline, + logger: logger, + redirectState: newRedirectState + ) } - redirectHandler = RedirectHandler(request: request) { newRequest in - self.execute(request: newRequest, - delegate: delegate, - eventLoop: eventLoopPreference, - deadline: deadline) - } - case .disallow: - redirectHandler = nil - } + }() let task = Task(eventLoop: taskEL, logger: logger) do { @@ -804,21 +834,21 @@ extension HTTPClient.Configuration { /// Specifies redirect processing settings. public struct RedirectConfiguration { - enum Configuration { + enum Mode { /// Redirects are not followed. case disallow /// Redirects are followed with a specified limit. case follow(max: Int, allowCycles: Bool) } - var configuration: Configuration + var mode: Mode init() { - self.configuration = .follow(max: 5, allowCycles: false) + self.mode = .follow(max: 5, allowCycles: false) } - init(configuration: Configuration) { - self.configuration = configuration + init(configuration: Mode) { + self.mode = configuration } /// Redirects are not followed. 
diff --git a/Sources/AsyncHTTPClient/HTTPHandler.swift b/Sources/AsyncHTTPClient/HTTPHandler.swift index e745726cb..3137bd145 100644 --- a/Sources/AsyncHTTPClient/HTTPHandler.swift +++ b/Sources/AsyncHTTPClient/HTTPHandler.swift @@ -109,16 +109,9 @@ extension HTTPClient { /// Request-specific TLS configuration, defaults to no request-specific TLS configuration. public var tlsConfiguration: TLSConfiguration? - struct RedirectState { - var count: Int - var visited: Set? - } - /// Parsed, validated and deconstructed URL. let deconstructedURL: DeconstructedURL - var redirectState: RedirectState? - /// Create HTTP request. /// /// - parameters: @@ -190,7 +183,6 @@ extension HTTPClient { public init(url: URL, method: HTTPMethod = .GET, headers: HTTPHeaders = HTTPHeaders(), body: Body? = nil, tlsConfiguration: TLSConfiguration?) throws { self.deconstructedURL = try DeconstructedURL(url: url) - self.redirectState = nil self.url = url self.method = method self.headers = headers @@ -642,87 +634,42 @@ internal struct TaskCancelEvent {} internal struct RedirectHandler { let request: HTTPClient.Request - let execute: (HTTPClient.Request) -> HTTPClient.Task - - func redirectTarget(status: HTTPResponseStatus, headers: HTTPHeaders) -> URL? { - switch status { - case .movedPermanently, .found, .seeOther, .notModified, .useProxy, .temporaryRedirect, .permanentRedirect: - break - default: - return nil - } - - guard let location = headers.first(name: "Location") else { - return nil - } - - guard let url = URL(string: location, relativeTo: request.url) else { - return nil - } - - guard self.request.deconstructedURL.scheme.supportsRedirects(to: url.scheme) else { - return nil - } - - if url.isFileURL { - return nil - } - - return url.absoluteURL + let redirectState: RedirectState + let execute: (HTTPClient.Request, RedirectState) -> HTTPClient.Task + + func redirectTarget(status: HTTPResponseStatus, responseHeaders: HTTPHeaders) -> URL? 
{ + responseHeaders.extractRedirectTarget( + status: status, + originalURL: self.request.url, + originalScheme: self.request.deconstructedURL.scheme + ) } - func redirect(status: HTTPResponseStatus, to redirectURL: URL, promise: EventLoopPromise) { - var nextState: HTTPClient.Request.RedirectState? - if var state = request.redirectState { - guard state.count > 0 else { - return promise.fail(HTTPClientError.redirectLimitReached) - } - - state.count -= 1 - - if var visited = state.visited { - guard !visited.contains(redirectURL) else { - return promise.fail(HTTPClientError.redirectCycleDetected) - } - - visited.insert(redirectURL) - state.visited = visited - } - - nextState = state - } - - let originalRequest = self.request - - var convertToGet = false - if status == .seeOther, self.request.method != .HEAD { - convertToGet = true - } else if status == .movedPermanently || status == .found, self.request.method == .POST { - convertToGet = true - } - - var method = originalRequest.method - var headers = originalRequest.headers - var body = originalRequest.body - - if convertToGet { - method = .GET - body = nil - headers.remove(name: "Content-Length") - headers.remove(name: "Content-Type") - } - - if !originalRequest.url.hasTheSameOrigin(as: redirectURL) { - headers.remove(name: "Origin") - headers.remove(name: "Cookie") - headers.remove(name: "Authorization") - headers.remove(name: "Proxy-Authorization") - } - + func redirect( + status: HTTPResponseStatus, + to redirectURL: URL, + promise: EventLoopPromise + ) { do { - var newRequest = try HTTPClient.Request(url: redirectURL, method: method, headers: headers, body: body) - newRequest.redirectState = nextState - self.execute(newRequest).futureResult.whenComplete { result in + var redirectState = self.redirectState + try redirectState.redirect(to: redirectURL.absoluteString) + + let (method, headers, body) = transformRequestForRedirect( + from: request.url, + method: self.request.method, + headers: self.request.headers, + 
body: self.request.body, + to: redirectURL, + status: status + ) + + let newRequest = try HTTPClient.Request( + url: redirectURL, + method: method, + headers: headers, + body: body + ) + self.execute(newRequest, redirectState).futureResult.whenComplete { result in promise.futureResult.eventLoop.execute { promise.completeWith(result) } diff --git a/Sources/AsyncHTTPClient/RedirectState.swift b/Sources/AsyncHTTPClient/RedirectState.swift new file mode 100644 index 000000000..c4e427ef1 --- /dev/null +++ b/Sources/AsyncHTTPClient/RedirectState.swift @@ -0,0 +1,142 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2021 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import struct Foundation.URL +import NIOHTTP1 + +typealias RedirectMode = HTTPClient.Configuration.RedirectConfiguration.Mode + +struct RedirectState { + /// number of redirects we are allowed to follow. + private var limit: Int + + /// All visited URLs. + private var visited: [String] + + /// if false, `redirect(to:)` will throw an error if a cycle is detected. + private let allowCycles: Bool +} + +extension RedirectState { + /// Creates a `RedirectState` from a configuration. + /// Returns nil if the user disallowed redirects, + /// otherwise an instance of `RedirectState` which respects the user defined settings.
+ init?( + _ configuration: RedirectMode, + initialURL: String + ) { + switch configuration { + case .disallow: + return nil + case .follow(let maxRedirects, let allowCycles): + self.init(limit: maxRedirects, visited: [initialURL], allowCycles: allowCycles) + } + } +} + +extension RedirectState { + /// Call this method when you are about to do a redirect to the given `redirectURL`. + /// This method records that URL into `self`. + /// - Parameter redirectURL: the new URL to redirect the request to + /// - Throws: `HTTPClientError.redirectLimitReached` if the redirect limit is reached, or `HTTPClientError.redirectCycleDetected` if a cycle is detected and `allowCycles` is false + mutating func redirect(to redirectURL: String) throws { + guard self.visited.count <= limit else { + throw HTTPClientError.redirectLimitReached + } + + guard allowCycles || !self.visited.contains(redirectURL) else { + throw HTTPClientError.redirectCycleDetected + } + self.visited.append(redirectURL) + } +} + +extension HTTPHeaders { + /// Tries to extract a redirect URL from the `location` header if the `status` indicates it should do so. + /// It also validates that we can redirect to the scheme of the extracted redirect URL from the `originalScheme`. + /// - Parameters: + /// - status: response status of the request + /// - originalURL: url of the previous request + /// - originalScheme: scheme of the previous request + /// - Returns: redirect URL to follow + func extractRedirectTarget( + status: HTTPResponseStatus, + originalURL: URL, + originalScheme: Scheme + ) -> URL?
{ + switch status { + case .movedPermanently, .found, .seeOther, .notModified, .useProxy, .temporaryRedirect, .permanentRedirect: + break + default: + return nil + } + + guard let location = self.first(name: "Location") else { + return nil + } + + guard let url = URL(string: location, relativeTo: originalURL) else { + return nil + } + + guard originalScheme.supportsRedirects(to: url.scheme) else { + return nil + } + + if url.isFileURL { + return nil + } + + return url.absoluteURL + } +} + +/// Transforms the original `requestMethod`, `requestHeaders` and `requestBody` to be ready to be sent out as a new request to the `redirectURL`. +/// - Returns: New `HTTPMethod`, `HTTPHeaders` and `Body` to be sent as a new request to `redirectURL` +func transformRequestForRedirect( + from originalURL: URL, + method requestMethod: HTTPMethod, + headers requestHeaders: HTTPHeaders, + body requestBody: Body?, + to redirectURL: URL, + status responseStatus: HTTPResponseStatus +) -> (HTTPMethod, HTTPHeaders, Body?)
{ + let convertToGet: Bool + if responseStatus == .seeOther, requestMethod != .HEAD { + convertToGet = true + } else if responseStatus == .movedPermanently || responseStatus == .found, requestMethod == .POST { + convertToGet = true + } else { + convertToGet = false + } + + var method = requestMethod + var headers = requestHeaders + var body = requestBody + + if convertToGet { + method = .GET + body = nil + headers.remove(name: "Content-Length") + headers.remove(name: "Content-Type") + } + + if !originalURL.hasTheSameOrigin(as: redirectURL) { + headers.remove(name: "Origin") + headers.remove(name: "Cookie") + headers.remove(name: "Authorization") + headers.remove(name: "Proxy-Authorization") + } + return (method, headers, body) +} diff --git a/Sources/AsyncHTTPClient/RequestBag+StateMachine.swift b/Sources/AsyncHTTPClient/RequestBag+StateMachine.swift index c20d5e211..cd82a9000 100644 --- a/Sources/AsyncHTTPClient/RequestBag+StateMachine.swift +++ b/Sources/AsyncHTTPClient/RequestBag+StateMachine.swift @@ -268,7 +268,10 @@ extension RequestBag.StateMachine { preconditionFailure("If we receive a response, we must not have received something else before") } - if let redirectURL = self.redirectHandler?.redirectTarget(status: head.status, headers: head.headers) { + if let redirectURL = self.redirectHandler?.redirectTarget( + status: head.status, + responseHeaders: head.headers + ) { self.state = .redirected(head, redirectURL) return false } else { diff --git a/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift index 22659d32c..b873833fa 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift @@ -67,6 +67,7 @@ extension HTTPClientTests { ("testDecompressionLimit", testDecompressionLimit), ("testLoopDetectionRedirectLimit", testLoopDetectionRedirectLimit), ("testCountRedirectLimit", testCountRedirectLimit), + 
("testRedirectToTheInitialURLDoesThrowOnFirstRedirect", testRedirectToTheInitialURLDoesThrowOnFirstRedirect), ("testMultipleConcurrentRequests", testMultipleConcurrentRequests), ("testWorksWith500Error", testWorksWith500Error), ("testWorksWithHTTP10Response", testWorksWithHTTP10Response), @@ -109,6 +110,7 @@ extension HTTPClientTests { ("testWeHandleUsReceivingACloseHeaderCorrectly", testWeHandleUsReceivingACloseHeaderCorrectly), ("testWeHandleUsSendingACloseHeaderAmongstOtherConnectionHeadersCorrectly", testWeHandleUsSendingACloseHeaderAmongstOtherConnectionHeadersCorrectly), ("testWeHandleUsReceivingACloseHeaderAmongstOtherConnectionHeadersCorrectly", testWeHandleUsReceivingACloseHeaderAmongstOtherConnectionHeadersCorrectly), + ("testLoggingCorrectlyAttachesRequestInformationEvenAfterDuringRedirect", testLoggingCorrectlyAttachesRequestInformationEvenAfterDuringRedirect), ("testLoggingCorrectlyAttachesRequestInformation", testLoggingCorrectlyAttachesRequestInformation), ("testNothingIsLoggedAtInfoOrHigher", testNothingIsLoggedAtInfoOrHigher), ("testAllMethodsLog", testAllMethodsLog), diff --git a/Tests/AsyncHTTPClientTests/HTTPClientTests.swift b/Tests/AsyncHTTPClientTests/HTTPClientTests.swift index e7ba9d510..baeecca33 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientTests.swift @@ -882,11 +882,40 @@ class HTTPClientTests: XCTestCase { XCTAssertNoThrow(try localHTTPBin.shutdown()) } - XCTAssertThrowsError(try localClient.get(url: "https://localhost:\(localHTTPBin.port)/redirect/infinite1").wait(), "Should fail with redirect limit") { error in + XCTAssertThrowsError(try localClient.get(url: "https://localhost:\(localHTTPBin.port)/redirect/infinite1").timeout(after: .seconds(10)).wait()) { error in XCTAssertEqual(error as? 
HTTPClientError, HTTPClientError.redirectLimitReached) } } + func testRedirectToTheInitialURLDoesThrowOnFirstRedirect() throws { + let localHTTPBin = HTTPBin(.http1_1(ssl: true)) + defer { XCTAssertNoThrow(try localHTTPBin.shutdown()) } + let localClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: .init( + certificateVerification: .none, + redirectConfiguration: .follow(max: 1, allowCycles: false) + ) + ) + defer { XCTAssertNoThrow(try localClient.syncShutdown()) } + + var maybeRequest: HTTPClient.Request? + XCTAssertNoThrow(maybeRequest = try HTTPClient.Request( + url: "https://localhost:\(localHTTPBin.port)/redirect/target", + method: .GET, + headers: [ + "X-Target-Redirect-URL": "/redirect/target", + ] + )) + guard let request = maybeRequest else { return } + + XCTAssertThrowsError( + try localClient.execute(request: request).timeout(after: .seconds(10)).wait() + ) { error in + XCTAssertEqual(error as? HTTPClientError, HTTPClientError.redirectCycleDetected) + } + } + func testMultipleConcurrentRequests() throws { let numberOfRequestsPerThread = 1000 let numberOfParallelWorkers = 5 @@ -2033,6 +2062,52 @@ class HTTPClientTests: XCTestCase { } } + func testLoggingCorrectlyAttachesRequestInformationEvenAfterDuringRedirect() { + let logStore = CollectEverythingLogHandler.LogStore() + + var logger = Logger(label: "\(#function)", factory: { _ in + CollectEverythingLogHandler(logStore: logStore) + }) + logger.logLevel = .trace + logger[metadataKey: "custom-request-id"] = "abcd" + + var maybeRequest: HTTPClient.Request? 
+ XCTAssertNoThrow(maybeRequest = try HTTPClient.Request( + url: "http://localhost:\(self.defaultHTTPBin.port)/redirect/target", + method: .GET, + headers: [ + "X-Target-Redirect-URL": "/get", + ] + )) + guard let request = maybeRequest else { return } + + XCTAssertNoThrow(try self.defaultClient.execute( + request: request, + eventLoop: .indifferent, + deadline: nil, + logger: logger + ).wait()) + let logs = logStore.allEntries + + XCTAssertTrue(logs.allSatisfy { $0.metadata["custom-request-id"] == "abcd" }) + + guard let firstRequestID = logs.first?.metadata["ahc-request-id"] else { + return XCTFail("could not get first request ID") + } + guard let lastRequestID = logs.last?.metadata["ahc-request-id"] else { + return XCTFail("could not get second request ID") + } + + let firstRequestLogs = logs.prefix(while: { $0.metadata["ahc-request-id"] == firstRequestID }) + XCTAssertGreaterThan(firstRequestLogs.count, 0) + + let secondRequestLogs = logs.drop(while: { $0.metadata["ahc-request-id"] == firstRequestID }) + XCTAssertGreaterThan(secondRequestLogs.count, 0) + XCTAssertTrue(secondRequestLogs.allSatisfy { $0.metadata["ahc-request-id"] == lastRequestID }) + + logs.forEach { print($0) } + } + func testLoggingCorrectlyAttachesRequestInformation() { let logStore = CollectEverythingLogHandler.LogStore()