Skip to content

Commit

Permalink
update kyo-reactive-stream
Browse files Browse the repository at this point in the history
  • Loading branch information
HollandDM committed Jan 19, 2025
1 parent e1a2a97 commit c61f5e6
Show file tree
Hide file tree
Showing 7 changed files with 104 additions and 48 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ package kyo.interop.flow

import java.util.concurrent.Flow.*
import kyo.*
import kyo.interop.flow.StreamSubscription.StreamFinishState
import kyo.interop.flow.StreamSubscription.StreamCanceled
import kyo.interop.flow.StreamSubscription.StreamComplete
import kyo.kernel.Boundary
import scala.annotation.nowarn

Expand All @@ -19,6 +20,10 @@ abstract private[kyo] class StreamPublisher[V, Ctx](
bind(subscriber)
end subscribe

private[StreamPublisher] def getSubscription(subscriber: Subscriber[? >: V])(using Frame): StreamSubscription[V, Ctx] < IO =
IO.Unsafe(new StreamSubscription[V, Ctx](stream, subscriber))
end getSubscription

end StreamPublisher

object StreamPublisher:
Expand All @@ -28,7 +33,7 @@ object StreamPublisher:
capacity: Int = Int.MaxValue
)(
using
Boundary[Ctx, IO],
Boundary[Ctx, IO & Abort[StreamCanceled]],
Frame,
Tag[Emit[Chunk[V]]],
Tag[Poll[Chunk[V]]]
Expand All @@ -41,13 +46,14 @@ object StreamPublisher:
end discardSubscriber

def consumeChannel(
publisher: StreamPublisher[V, Ctx],
channel: Channel[Subscriber[? >: V]],
supervisor: Fiber.Promise[Nothing, Unit]
): Unit < (Async & Ctx) =
Abort.recover[Closed](_ => supervisor.interrupt.unit)(
channel.stream().runForeach: subscriber =>
for
subscription <- IO.Unsafe(new StreamSubscription[V, Ctx](stream, subscriber))
subscription <- publisher.getSubscription(subscriber)
fiber <- subscription.subscribe.andThen(subscription.consume)
_ <- supervisor.onInterrupt(_ => fiber.interrupt(Result.Panic(Interrupt())).unit)
yield ()
Expand All @@ -68,7 +74,7 @@ object StreamPublisher:
case _ => discardSubscriber(subscriber)
}
supervisor <- Resource.acquireRelease(Fiber.Promise.init[Nothing, Unit])(_.interrupt.unit)
_ <- Resource.acquireRelease(Async._run(consumeChannel(channel, supervisor)))(_.interrupt.unit)
_ <- Resource.acquireRelease(Async._run(consumeChannel(publisher, channel, supervisor)))(_.interrupt.unit)
yield publisher
end for
end apply
Expand All @@ -77,7 +83,7 @@ object StreamPublisher:
@nowarn("msg=anonymous")
inline def apply[V, Ctx](
stream: Stream[V, Ctx],
subscribeCallback: (Fiber[Nothing, StreamFinishState] < (IO & Ctx)) => Unit
subscribeCallback: (Fiber[StreamCanceled, StreamComplete] < (IO & Ctx)) => Unit
)(
using
AllowUnsafe,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,9 @@ final private[kyo] class StreamSubscriber[V](
}
}

def stream(using Frame, Tag[Emit[Chunk[V]]]): Stream[V, Async] = Stream(emit)
def stream(using Frame, Tag[Emit[Chunk[V]]]): Stream[V, Async] < (Resource & IO) =
Resource.ensure(interupt).andThen:
Stream(emit)

end StreamSubscriber

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ final private[kyo] class StreamSubscription[V, Ctx](

private[interop] inline def subscribe(using Frame): Unit < IO = IO(subscriber.onSubscribe(this))

private[interop] def poll(using Tag[Poll[Chunk[V]]], Frame): StreamFinishState < (Async & Poll[Chunk[V]]) =
inline def loopPoll(requesting: Long): (Chunk[V] | StreamFinishState) < (IO & Poll[Chunk[V]]) =
Loop[Long, Chunk[V] | StreamFinishState, IO & Poll[Chunk[V]]](requesting): requesting =>
private[interop] def poll(using Tag[Poll[Chunk[V]]], Frame): StreamComplete < (Async & Poll[Chunk[V]] & Abort[StreamCanceled]) =
inline def loopPoll(requesting: Long): (Chunk[V] | StreamComplete) < (IO & Poll[Chunk[V]]) =
Loop[Long, Chunk[V] | StreamComplete, IO & Poll[Chunk[V]]](requesting): requesting =>
Poll.andMap:
case Present(values) =>
if values.size <= requesting then
Expand All @@ -43,9 +43,9 @@ final private[kyo] class StreamSubscription[V, Ctx](
IO(values.take(requesting.intValue).foreach(subscriber.onNext(_)))
.andThen(Loop.done(values.drop(requesting.intValue)))
case Absent =>
IO(Loop.done(StreamFinishState.StreamComplete))
IO(Loop.done(StreamComplete))

Loop[Chunk[V], StreamFinishState, Async & Poll[Chunk[V]]](Chunk.empty[V]): leftOver =>
Loop[Chunk[V], StreamComplete, Async & Poll[Chunk[V]] & Abort[StreamCanceled]](Chunk.empty[V]): leftOver =>
Abort.run[Closed](requestChannel.safe.take).map:
case Result.Success(requesting) =>
if requesting <= leftOver.size then
Expand All @@ -55,27 +55,26 @@ final private[kyo] class StreamSubscription[V, Ctx](
IO(leftOver.foreach(subscriber.onNext(_)))
.andThen(loopPoll(requesting - leftOver.size))
.map {
case nextLeftOver: Chunk[V] => Loop.continue(nextLeftOver)
case state: StreamFinishState => Loop.done(state)
case nextLeftOver: Chunk[V] => Loop.continue(nextLeftOver)
case _: StreamComplete => Loop.done(StreamComplete)
}
case Result.Failure(_) => IO(Loop.done(StreamFinishState.StreamCanceled))
case Result.Panic(exception) => Abort.panic(exception).andThen(Loop.done(StreamFinishState.StreamCanceled))
case result => Abort.get(result.mapFailure(_ => StreamCanceled)).andThen(Loop.done(StreamComplete))
end poll

private[interop] def consume(
using
Tag[Emit[Chunk[V]]],
Tag[Poll[Chunk[V]]],
Frame,
Boundary[Ctx, IO]
): Fiber[Nothing, StreamFinishState] < (IO & Ctx) =
Boundary[Ctx, IO & Abort[StreamCanceled]]
): Fiber[StreamCanceled, StreamComplete] < (IO & Ctx) =
Async
._run[Nothing, StreamFinishState, Ctx](Poll.run(stream.emit)(poll).map(_._2))
._run[StreamCanceled, StreamComplete, Ctx](Poll.run(stream.emit)(poll).map(_._2))
.map { fiber =>
fiber.onComplete {
case Result.Success(StreamFinishState.StreamComplete) => IO(subscriber.onComplete())
case Result.Panic(e) => IO(subscriber.onError(e))
case _ => IO.unit
case Result.Success(StreamComplete) => IO(subscriber.onComplete())
case Result.Panic(e) => IO(subscriber.onError(e))
case Result.Failure(StreamCanceled) => IO.unit
}.andThen(fiber)
}
end consume
Expand All @@ -84,9 +83,10 @@ end StreamSubscription

object StreamSubscription:

private[interop] enum StreamFinishState derives CanEqual:
case StreamComplete, StreamCanceled
end StreamFinishState
type StreamComplete = StreamComplete.type
case object StreamComplete
type StreamCanceled = StreamCanceled.type
case object StreamCanceled

inline def subscribe[V, Ctx](
stream: Stream[V, Ctx],
Expand Down Expand Up @@ -120,7 +120,7 @@ object StreamSubscription:
stream: Stream[V, Ctx],
subscriber: Subscriber[? >: V]
)(
subscribeCallback: (Fiber[Nothing, StreamFinishState] < (IO & Ctx)) => Unit
subscribeCallback: (Fiber[StreamCanceled, StreamComplete] < (IO & Ctx)) => Unit
)(
using
AllowUnsafe,
Expand All @@ -134,7 +134,7 @@ object StreamSubscription:
stream: Stream[V, Ctx],
subscriber: Subscriber[? >: V]
)(
subscribeCallback: (Fiber[Nothing, StreamFinishState] < (IO & Ctx)) => Unit
subscribeCallback: (Fiber[StreamCanceled, StreamComplete] < (IO & Ctx)) => Unit
)(
using
AllowUnsafe,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,12 @@ package object flow:
Frame,
Tag[Emit[Chunk[T]]],
Tag[Poll[Chunk[T]]]
): Stream[T, Async] < IO =
): Stream[T, Async] < (Resource & IO) =
for
subscriber <- StreamSubscriber[T](bufferSize, emitStrategy)
_ <- IO(publisher.subscribe(subscriber))
yield subscriber.stream
stream <- subscriber.stream
yield stream

@nowarn("msg=anonymous")
inline def subscribeToStream[T, Ctx](
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ package object reactivestreams:
Frame,
Tag[Emit[Chunk[T]]],
Tag[Poll[Chunk[T]]]
): Stream[T, Async] < IO =
): Stream[T, Async] < (Resource & IO) =
flow.fromPublisher(FlowAdapters.toFlowPublisher(publisher), bufferSize, emitStrategy)

@nowarn("msg=anonymous")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@ package kyo.interop.flow

import kyo.*
import kyo.Duration
import kyo.Result.Failure
import kyo.Result.Success
import kyo.interop.flow.StreamSubscriber.EmitStrategy
import kyo.interop.flow.StreamSubscription.StreamCanceled
import kyo.interop.flow.StreamSubscription.StreamComplete
import kyo.kernel.ArrowEffect

abstract private class PublisherToSubscriberTest extends Test:
Expand All @@ -16,7 +20,8 @@ abstract private class PublisherToSubscriberTest extends Test:
publisher <- stream.toPublisher
subscriber <- streamSubscriber
_ = publisher.subscribe(subscriber)
(isSame, _) <- subscriber.stream
subscriberStream <- subscriber.stream
(isSame, _) <- subscriberStream
.runFold(true -> 0) { case ((acc, expected), cur) =>
(acc && (expected == cur)) -> (expected + 1)
}
Expand All @@ -38,7 +43,8 @@ abstract private class PublisherToSubscriberTest extends Test:
publisher <- inputStream.toPublisher
subscriber <- streamSubscriber
_ = publisher.subscribe(subscriber)
result <- Abort.run[Throwable](subscriber.stream.runDiscard)
subscriberStream <- subscriber.stream
result <- Abort.run[Throwable](subscriberStream.runDiscard)
yield result match
case Result.Error(e: Throwable) => assert(e == TestError)
case _ => assert(false)
Expand All @@ -65,18 +71,22 @@ abstract private class PublisherToSubscriberTest extends Test:
inputStream = Stream(Emit.valueWith(Chunk.empty)(emit(counter)))
publisher <- inputStream.toPublisher
subscriber1 <- streamSubscriber
subStream1 <- subscriber1.stream
subscriber2 <- streamSubscriber
subStream2 <- subscriber2.stream
subscriber3 <- streamSubscriber
subStream3 <- subscriber3.stream
subscriber4 <- streamSubscriber
subStream4 <- subscriber4.stream
_ = publisher.subscribe(subscriber1)
_ = publisher.subscribe(subscriber2)
_ = publisher.subscribe(subscriber3)
_ = publisher.subscribe(subscriber4)
values <- Fiber.parallelUnbounded[Nothing, Chunk[Int], Any](List(
subscriber1.stream.run,
subscriber2.stream.run,
subscriber3.stream.run,
subscriber4.stream.run
subStream1.run,
subStream2.run,
subStream3.run,
subStream4.run
)).map(_.get)
yield
assert(values.size == 4)
Expand Down Expand Up @@ -117,27 +127,31 @@ abstract private class PublisherToSubscriberTest extends Test:
inputStream = Stream(Emit.valueWith(Chunk.empty)(emit(counter)))
publisher <- inputStream.toPublisher
subscriber1 <- streamSubscriber
subStream1 <- subscriber1.stream
subscriber2 <- streamSubscriber
subStream2 <- subscriber2.stream
subscriber3 <- streamSubscriber
subStream3 <- subscriber3.stream
subscriber4 <- streamSubscriber
subStream4 <- subscriber4.stream
_ = publisher.subscribe(subscriber1)
_ = publisher.subscribe(subscriber2)
_ = publisher.subscribe(subscriber3)
_ = publisher.subscribe(subscriber4)
fiber1 <- Async.run(modify(subscriber1.stream, shouldFail = false))
fiber2 <- Async.run(modify(subscriber2.stream, shouldFail = true))
fiber3 <- Async.run(modify(subscriber3.stream, shouldFail = false))
fiber4 <- Async.run(modify(subscriber4.stream, shouldFail = true))
fiber1 <- Async.run(modify(subStream1, shouldFail = false))
fiber2 <- Async.run(modify(subStream2, shouldFail = true))
fiber3 <- Async.run(modify(subStream3, shouldFail = false))
fiber4 <- Async.run(modify(subStream4, shouldFail = true))
value1 <- fiber1.get
value2 <- fiber2.getResult
value3 <- fiber3.get
value4 <- fiber4.getResult
yield
assert(value1.size + value3.size == MaxStreamLength)
assert(checkStrictIncrease(value1))
assert(value2 == Result.Panic(TestError))
assert(checkStrictIncrease(value3))
assert(value4 == Result.Panic(TestError))
assert(value1.size + value3.size == MaxStreamLength)
val actualSum = value1.sum + value3.sum
val expectedSum = (MaxStreamLength >> 1) * (MaxStreamLength - 1)
assert(actualSum == expectedSum)
Expand All @@ -158,14 +172,18 @@ abstract private class PublisherToSubscriberTest extends Test:
counter <- AtomicInt.init(0)
publisher <- Stream(Emit.valueWith(Chunk.empty)(emit(counter))).toPublisher
subscriber1 <- streamSubscriber
subStream1 <- subscriber1.stream
subscriber2 <- streamSubscriber
subStream2 <- subscriber2.stream
subscriber3 <- streamSubscriber
subStream3 <- subscriber3.stream
subscriber4 <- streamSubscriber
subStream4 <- subscriber4.stream
latch <- Latch.init(5)
fiber1 <- Async.run(latch.release.andThen(subscriber1.stream.run.unit))
fiber2 <- Async.run(latch.release.andThen(subscriber2.stream.run.unit))
fiber3 <- Async.run(latch.release.andThen(subscriber3.stream.run.unit))
fiber4 <- Async.run(latch.release.andThen(subscriber4.stream.run.unit))
fiber1 <- Async.run(latch.release.andThen(subStream1.run.unit))
fiber2 <- Async.run(latch.release.andThen(subStream2.run.unit))
fiber3 <- Async.run(latch.release.andThen(subStream3.run.unit))
fiber4 <- Async.run(latch.release.andThen(subStream4.run.unit))
publisherFiber <- Async.run(Resource.run(
Stream(Emit.valueWith(Chunk.empty)(emit(counter)))
.toPublisher
Expand All @@ -185,6 +203,31 @@ abstract private class PublisherToSubscriberTest extends Test:
yield assert(true)
end for
}

"when complete, associated subscription should be canceled" in runJVM {
val stream: Stream[Int, Any] =
Stream(
Loop(0)(cur => Emit.valueWith(Chunk(cur))(Loop.continue(cur + 1)))
)
for
promise <- Fiber.Promise.init[Throwable, Unit]
subscriber <- streamSubscriber
subscription <- IO.Unsafe {
StreamSubscription.Unsafe.subscribe(
stream,
subscriber
): fiber =>
discard(IO.Unsafe.evalOrThrow(fiber.map(_.onComplete { result =>
result match
case Failure(StreamCanceled) => promise.completeDiscard(Success(()))
case _ => promise.completeDiscard(Failure(TestError))
})))
}
_ <- Resource.run(subscriber.stream.map(_.take(10).runDiscard))
result <- promise.getResult
yield assert(result == Success(()))
end for
}
}
end PublisherToSubscriberTest

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,9 @@ final class StreamSubscriberTest extends Test:
val publisher = getPublisher(BatchSize)
for
subscriber <- StreamSubscriber[Int](BufferSize, EmitStrategy.Eager)
subStream <- subscriber.stream
_ = publisher.subscribe(subscriber)
results <- subscriber.stream.take(StreamLength).runFold(0)(_ + _)
results <- subStream.take(StreamLength).runFold(0)(_ + _)
yield assert(results == (StreamLength >> 1) * (StreamLength + 1))
end for
}
Expand All @@ -54,8 +55,9 @@ final class StreamSubscriberTest extends Test:
val publisher = getPublisher(BatchSize)
for
subscriber <- StreamSubscriber[Int](BufferSize, EmitStrategy.Buffer)
subStream <- subscriber.stream
_ = publisher.subscribe(subscriber)
results <- subscriber.stream.take(StreamLength).runFold(0)(_ + _)
results <- subStream.take(StreamLength).runFold(0)(_ + _)
yield assert(results == (StreamLength >> 1) * (StreamLength + 1))
end for
}
Expand All @@ -64,12 +66,14 @@ final class StreamSubscriberTest extends Test:
val publisher = getPublisher(BatchSize)
for
subscriber1 <- StreamSubscriber[Int](BufferSize, EmitStrategy.Eager)
subStream1 <- subscriber1.stream
subscriber2 <- StreamSubscriber[Int](BufferSize, EmitStrategy.Buffer)
subStream2 <- subscriber2.stream
_ = publisher.subscribe(subscriber1)
_ = publisher.subscribe(subscriber2)
results <- Async.parallelUnbounded(List(
subscriber1.stream.take(StreamLength >> 1).runFold(0)(_ + _),
subscriber2.stream.take(StreamLength >> 1).runFold(0)(_ + _)
subStream1.take(StreamLength >> 1).runFold(0)(_ + _),
subStream2.take(StreamLength >> 1).runFold(0)(_ + _)
))
yield
assert(results.size == 2)
Expand Down

0 comments on commit c61f5e6

Please sign in to comment.