Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merge provides all elements from the subsequences on cancellation #276

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 46 additions & 44 deletions Sources/AsyncAlgorithms/Merge/MergeStateMachine.swift
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ struct MergeStateMachine<
buffer: Deque<Element>,
upstreamContinuations: [UnsafeContinuation<Void, Error>],
upstreamsFinished: Int,
downstreamContinuation: UnsafeContinuation<Element?, Error>?
downstreamContinuation: UnsafeContinuation<Element?, Error>?,
cancelled: Bool
Copy link
Member

@phausler phausler Jul 18, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should this just be another state case?

if so it wouldn't need to hold onto the Task which means that the closure (which captures the bases) would no longer be held on to.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we make this another state case, we will have a lot of duplicated code that is shared between merging and cancelled. In the interest in keeping code duplication minimal, IMO we should not create a separate case.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed. I think we should just leave it with the current case

)

/// The state once any of the upstream sequences threw an `Error`.
Expand Down Expand Up @@ -100,11 +101,11 @@ struct MergeStateMachine<
// Nothing to do here. No demand was signalled until now
return .none

case .merging(_, _, _, _, .some):
case .merging(_, _, _, _, .some, _):
// An iterator was deinitialized while we have a suspended continuation.
preconditionFailure("Internal inconsistency current state \(self.state) and received iteratorDeinitialized()")

case let .merging(task, _, upstreamContinuations, _, .none):
case let .merging(task, _, upstreamContinuations, _, .none, _):
// The iterator was dropped which signals that the consumer is finished.
// We can transition to finished now and need to clean everything up.
state = .finished
Expand Down Expand Up @@ -142,7 +143,8 @@ struct MergeStateMachine<
buffer: .init(),
upstreamContinuations: [], // This should reserve capacity in the variadic generics case
upstreamsFinished: 0,
downstreamContinuation: nil
downstreamContinuation: nil,
cancelled: false
)

case .merging, .upstreamFailure, .finished:
Expand Down Expand Up @@ -175,11 +177,11 @@ struct MergeStateMachine<
// Child tasks are only created after we transitioned to `merging`
preconditionFailure("Internal inconsistency current state \(self.state) and received childTaskSuspended()")

case .merging(_, _, _, _, .some):
case .merging(_, _, _, _, .some, _):
// We have outstanding demand so request the next element
return .resumeContinuation(upstreamContinuation: continuation)

case .merging(let task, let buffer, var upstreamContinuations, let upstreamsFinished, .none):
case .merging(let task, let buffer, var upstreamContinuations, let upstreamsFinished, .none, let cancelled):
// There is no outstanding demand from the downstream
// so we are storing the continuation and resume it once there is demand.
state = .modifying
Expand All @@ -191,7 +193,8 @@ struct MergeStateMachine<
buffer: buffer,
upstreamContinuations: upstreamContinuations,
upstreamsFinished: upstreamsFinished,
downstreamContinuation: nil
downstreamContinuation: nil,
cancelled: cancelled
)

return .none
Expand Down Expand Up @@ -236,7 +239,7 @@ struct MergeStateMachine<
// Child tasks that are producing elements are only created after we transitioned to `merging`
preconditionFailure("Internal inconsistency current state \(self.state) and received elementProduced()")

case let .merging(task, buffer, upstreamContinuations, upstreamsFinished, .some(downstreamContinuation)):
case let .merging(task, buffer, upstreamContinuations, upstreamsFinished, .some(downstreamContinuation), cancelled):
// We produced an element and have an outstanding downstream continuation
// this means we can go right ahead and resume the continuation with that element
precondition(buffer.isEmpty, "We are holding a continuation so the buffer must be empty")
Expand All @@ -246,15 +249,16 @@ struct MergeStateMachine<
buffer: buffer,
upstreamContinuations: upstreamContinuations,
upstreamsFinished: upstreamsFinished,
downstreamContinuation: nil
downstreamContinuation: nil,
cancelled: cancelled
)

return .resumeContinuation(
downstreamContinuation: downstreamContinuation,
element: element
)

case .merging(let task, var buffer, let upstreamContinuations, let upstreamsFinished, .none):
case .merging(let task, var buffer, let upstreamContinuations, let upstreamsFinished, .none, let cancelled):
// There is not outstanding downstream continuation so we must buffer the element
// This happens if we race our upstream sequences to produce elements
// and the _losers_ are signalling their produced element
Expand All @@ -267,7 +271,8 @@ struct MergeStateMachine<
buffer: buffer,
upstreamContinuations: upstreamContinuations,
upstreamsFinished: upstreamsFinished,
downstreamContinuation: nil
downstreamContinuation: nil,
cancelled: cancelled
)

return .none
Expand Down Expand Up @@ -310,7 +315,7 @@ struct MergeStateMachine<
case .initial:
preconditionFailure("Internal inconsistency current state \(self.state) and received upstreamFinished()")

case .merging(let task, let buffer, let upstreamContinuations, var upstreamsFinished, let .some(downstreamContinuation)):
case .merging(let task, let buffer, let upstreamContinuations, var upstreamsFinished, let .some(downstreamContinuation), let cancelled):
// One of the upstreams finished
precondition(buffer.isEmpty, "We are holding a continuation so the buffer must be empty")

Expand All @@ -335,13 +340,14 @@ struct MergeStateMachine<
buffer: buffer,
upstreamContinuations: upstreamContinuations,
upstreamsFinished: upstreamsFinished,
downstreamContinuation: downstreamContinuation
downstreamContinuation: downstreamContinuation,
cancelled: cancelled
)

return .none
}

case .merging(let task, let buffer, let upstreamContinuations, var upstreamsFinished, .none):
case .merging(let task, let buffer, let upstreamContinuations, var upstreamsFinished, .none, let cancelled):
// First we increment our counter of finished upstreams
upstreamsFinished += 1

Expand All @@ -350,7 +356,8 @@ struct MergeStateMachine<
buffer: buffer,
upstreamContinuations: upstreamContinuations,
upstreamsFinished: upstreamsFinished,
downstreamContinuation: nil
downstreamContinuation: nil,
cancelled: cancelled
)

if upstreamsFinished == self.numberOfUpstreamSequences {
Expand Down Expand Up @@ -402,7 +409,7 @@ struct MergeStateMachine<
case .initial:
preconditionFailure("Internal inconsistency current state \(self.state) and received upstreamThrew()")

case let .merging(task, buffer, upstreamContinuations, _, .some(downstreamContinuation)):
case let .merging(task, buffer, upstreamContinuations, _, .some(downstreamContinuation), _):
// An upstream threw an error and we have a downstream continuation.
// We just need to resume the downstream continuation with the error and cancel everything
precondition(buffer.isEmpty, "We are holding a continuation so the buffer must be empty")
Expand All @@ -417,7 +424,7 @@ struct MergeStateMachine<
upstreamContinuations: upstreamContinuations
)

case let .merging(task, buffer, upstreamContinuations, _, .none):
case let .merging(task, buffer, upstreamContinuations, _, .none, _):
// An upstream threw an error and we don't have a downstream continuation.
// We need to store the error and wait for the downstream to consume the
// rest of the buffer and the error. However, we can already cancel the task
Expand Down Expand Up @@ -454,10 +461,7 @@ struct MergeStateMachine<
upstreamContinuations: [UnsafeContinuation<Void, Error>]
)
/// Indicates that the task and the upstream continuations should be cancelled.
case cancelTaskAndUpstreamContinuations(
task: Task<Void, Never>,
upstreamContinuations: [UnsafeContinuation<Void, Error>]
)
case cancelTask(Task<Void, Never>)
/// Indicates that nothing should be done.
case none
}
Expand All @@ -471,26 +475,21 @@ struct MergeStateMachine<

return .none

case let .merging(task, _, upstreamContinuations, _, .some(downstreamContinuation)):
// The downstream Task got cancelled so we need to cancel our upstream Task
// and resume all continuations. We can also transition to finished.
state = .finished
case let .merging(task, buffer, upstreamContinuations, upstreamFinished, downstreamContinuation, cancelled):
guard !cancelled else {
return .none
}

return .resumeDownstreamContinuationWithNilAndCancelTaskAndUpstreamContinuations(
downstreamContinuation: downstreamContinuation,
self.state = .merging(
task: task,
upstreamContinuations: upstreamContinuations
buffer: buffer,
upstreamContinuations: upstreamContinuations,
upstreamsFinished: upstreamFinished,
downstreamContinuation: downstreamContinuation,
cancelled: true
)

case let .merging(task, _, upstreamContinuations, _, .none):
// The downstream Task got cancelled so we need to cancel our upstream Task
// and resume all continuations. We can also transition to finished.
state = .finished

return .cancelTaskAndUpstreamContinuations(
task: task,
upstreamContinuations: upstreamContinuations
)
return .cancelTask(task)

case .upstreamFailure:
// An upstream already threw and we cancelled everything already.
Expand Down Expand Up @@ -531,11 +530,11 @@ struct MergeStateMachine<
// We are transitioning to merging in the taskStarted method.
return .startTaskAndSuspendDownstreamTask(base1, base2, base3)

case .merging(_, _, _, _, .some):
case .merging(_, _, _, _, .some, _):
// We have multiple AsyncIterators iterating the sequence
preconditionFailure("Internal inconsistency current state \(self.state) and received next()")

case .merging(let task, var buffer, let upstreamContinuations, let upstreamsFinished, .none):
case .merging(let task, var buffer, let upstreamContinuations, let upstreamsFinished, .none, let cancelled):
state = .modifying

if let element = buffer.popFirst() {
Expand All @@ -545,7 +544,8 @@ struct MergeStateMachine<
buffer: buffer,
upstreamContinuations: upstreamContinuations,
upstreamsFinished: upstreamsFinished,
downstreamContinuation: nil
downstreamContinuation: nil,
cancelled: cancelled
)

return .returnElement(.success(element))
Expand All @@ -556,7 +556,8 @@ struct MergeStateMachine<
buffer: buffer,
upstreamContinuations: upstreamContinuations,
upstreamsFinished: upstreamsFinished,
downstreamContinuation: nil
downstreamContinuation: nil,
cancelled: cancelled
)

return .suspendDownstreamTask
Expand Down Expand Up @@ -601,21 +602,22 @@ struct MergeStateMachine<
mutating func next(for continuation: UnsafeContinuation<Element?, Error>) -> NextForAction {
switch state {
case .initial,
.merging(_, _, _, _, .some),
.merging(_, _, _, _, .some, _),
.upstreamFailure,
.finished:
// All other states are handled by `next` already so we should never get in here with
// any of those
preconditionFailure("Internal inconsistency current state \(self.state) and received next(for:)")

case let .merging(task, buffer, upstreamContinuations, upstreamsFinished, .none):
case let .merging(task, buffer, upstreamContinuations, upstreamsFinished, .none, cancelled):
// We suspended the task and need signal the upstreams
state = .merging(
task: task,
buffer: buffer,
upstreamContinuations: [], // TODO: don't alloc new array here
upstreamsFinished: upstreamsFinished,
downstreamContinuation: continuation
downstreamContinuation: continuation,
cancelled: cancelled
)

return .resumeUpstreamContinuations(
Expand Down
11 changes: 3 additions & 8 deletions Sources/AsyncAlgorithms/Merge/MergeStorage.swift
Original file line number Diff line number Diff line change
Expand Up @@ -128,12 +128,7 @@ final class MergeStorage<

downstreamContinuation.resume(returning: nil)

case let .cancelTaskAndUpstreamContinuations(
task,
upstreamContinuations
):
upstreamContinuations.forEach { $0.resume(throwing: CancellationError()) }

case let .cancelTask(task):
task.cancel()

case .none:
Expand Down Expand Up @@ -262,8 +257,8 @@ final class MergeStorage<
task,
upstreamContinuations
):
upstreamContinuations.forEach { $0.resume(throwing: CancellationError()) }
task.cancel()
upstreamContinuations.forEach { $0.resume() }

downstreamContinuation.resume(returning: nil)

Expand All @@ -273,8 +268,8 @@ final class MergeStorage<
task,
upstreamContinuations
):
upstreamContinuations.forEach { $0.resume(throwing: CancellationError()) }
task.cancel()
upstreamContinuations.forEach { $0.resume() }

break loop
case .none:
Expand Down
69 changes: 69 additions & 0 deletions Tests/AsyncAlgorithmsTests/TestMerge.swift
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,38 @@ final class TestMerge2: XCTestCase {
}
t.cancel()
}

func testAsyncStreamElementsThatAreInjectedOnCancellationAreDelivered() async {
let (stream1, continuation1) = AsyncStream.makeStream(of: Int.self)
let (stream2, continuation2) = AsyncStream.makeStream(of: Int.self)
continuation1.onTermination = { reason in
XCTAssertEqual(reason, .cancelled)
continuation1.yield(1)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this seems a bit off; it is yielding a value AFTER it has transitioned to a terminal state? It would be perfectly reasonable for the stream to no longer produce values once it has moved to a terminal

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was actually surprised by this behaviour as well and I don't know if this is intentional in AsyncStream. I expected the onTermination to be called after we returned nil the consumer. This allows to inject new values before cancellation finishes :D

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess that is perhaps the cooperative nature of cancellation - it eventually happens but may not happen immediately

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do agree that cancellation can be cooperative but having users rely on this behaviour seems rather brittle. Nothing in the API nor in the proposal specifies this behaviour. I will include this in the new back pressured interface proposal.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it is reasonable, that users have the option to yield something before cancellation. If it fits with current naming 🤷, but I think the feature in itself is very important.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I disagree on the premise that user's should be able to yield something before cancellation and my new proposed APIs for AsyncStream actually disallow this but it has a future direction that lays out how we can enable this again. The reason why I think the current APIs are a bit wonky here is that you can yield something in onTermination but you cannot "stop" the cancellation. So we are somewhere in the middle where we allow users to produce something but do not give them full control like withTaskCancellationHandler does.

Overall, not important to this PR but a general discussion for the forums.

}
continuation2.onTermination = { reason in
XCTAssertEqual(reason, .cancelled)
continuation2.yield(2)
}
continuation1.yield(0) // initial
let merge = merge(stream1, stream2)
let finished = expectation(description: "finished")
let iterated = expectation(description: "iterated")
let task = Task {
var count = 0
for await _ in merge {
if count == 0 { iterated.fulfill() }
count += 1
}
finished.fulfill()
XCTAssertEqual(count, 3)
}
// ensure the other task actually starts
await fulfillment(of: [iterated], timeout: 1.0)
// cancellation should ensure the loop finishes
// without regards to the remaining underlying sequence
task.cancel()
await fulfillment(of: [finished], timeout: 1.0)
}
}

final class TestMerge3: XCTestCase {
Expand Down Expand Up @@ -555,4 +587,41 @@ final class TestMerge3: XCTestCase {

iterator = nil
}

func testAsyncStreamElementsThatAreInjectedOnCancellationAreDelivered() async {
let (stream1, continuation1) = AsyncStream.makeStream(of: Int.self)
let (stream2, continuation2) = AsyncStream.makeStream(of: Int.self)
let (stream3, continuation3) = AsyncStream.makeStream(of: Int.self)
continuation1.onTermination = { reason in
XCTAssertEqual(reason, .cancelled)
continuation1.yield(1)
}
continuation2.onTermination = { reason in
XCTAssertEqual(reason, .cancelled)
continuation2.yield(2)
}
continuation3.onTermination = { reason in
XCTAssertEqual(reason, .cancelled)
continuation3.yield(3)
}
continuation1.yield(0) // initial
let merge = merge(stream1, stream2, stream3)
let finished = expectation(description: "finished")
let iterated = expectation(description: "iterated")
let task = Task {
var count = 0
for await _ in merge {
if count == 0 { iterated.fulfill() }
count += 1
}
finished.fulfill()
XCTAssertEqual(count, 4)
}
// ensure the other task actually starts
await fulfillment(of: [iterated], timeout: 1.0)
// cancellation should ensure the loop finishes
// without regards to the remaining underlying sequence
task.cancel()
await fulfillment(of: [finished], timeout: 1.0)
}
}