Skip to content

Commit

Permalink
Channel: drain/drainUpTo include pending puts (#980)
Browse files Browse the repository at this point in the history
closes #978

---------

Co-authored-by: Adam Hearn <[email protected]>
Co-authored-by: Flavio Brasil <[email protected]>
  • Loading branch information
3 people authored Jan 9, 2025
1 parent 7a3afd8 commit ff73419
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 14 deletions.
48 changes: 37 additions & 11 deletions kyo-core/shared/src/main/scala/kyo/Channel.scala
Original file line number Diff line number Diff line change
Expand Up @@ -105,16 +105,17 @@ object Channel:
def takeExactly(n: Int)(using Frame): Chunk[A] < (Abort[Closed] & Async) =
if n <= 0 then Chunk.empty
else
Loop(Chunk.empty[A]): lastChunk =>
val nextN = n - lastChunk.size
Loop(Chunk.empty[A], 0): (lastChunk, lastSize) =>
val nextN = n - lastSize
Channel.drainUpTo(self)(nextN).map: chunk =>
val chunk1 = lastChunk.concat(chunk)
if chunk1.size == n then Loop.done(chunk1)
if chunk1.size >= n then Loop.done(chunk1)
else
self.take.map: a =>
val chunk2 = chunk1.append(a)
if chunk2.size == n then Loop.done(chunk2)
else Loop.continue(chunk2)
val size2 = chunk2.size
if size2 >= n then Loop.done(chunk2)
else Loop.continue(chunk2, size2)
end if

/** Creates a fiber that puts an element into the channel.
Expand Down Expand Up @@ -339,9 +340,23 @@ object Channel:
end poll

def drainUpTo(max: Int)(using AllowUnsafe) =
val result = queue.drainUpTo(max)
if result.exists(_.nonEmpty) then flush()
result
@tailrec
def loop(current: Chunk[A], i: Int): Result[Closed, Chunk[A]] =
if i == 0 then Result.Success(current)
else
val next = queue.drainUpTo(i)
next match
case Result.Success(c) =>
if c.isEmpty then Result.Success(current)
else
flush()
loop(current.concat(c), i - c.length)
case other => other
end match
end if
end loop

loop(Chunk.empty, max)
end drainUpTo

def putFiber(value: A)(using AllowUnsafe): Fiber.Unsafe[Closed, Unit] =
Expand All @@ -368,9 +383,20 @@ object Channel:
end takeFiber

def drain()(using AllowUnsafe) =
val result = queue.drain()
if result.exists(_.nonEmpty) then flush()
result
@tailrec
def loop(current: Chunk[A]): Result[Closed, Chunk[A]] =
val next = queue.drain()
next match
case Result.Success(c) =>
if c.isEmpty then Result.Success(current)
else
flush()
loop(current.concat(c))
case other => other
end match
end loop

loop(Chunk.empty)
end drain

def close()(using frame: Frame, allow: AllowUnsafe) =
Expand Down
34 changes: 31 additions & 3 deletions kyo-core/shared/src/test/scala/kyo/ChannelTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,33 @@ class ChannelTest extends Test:
r <- c.drain
yield assert(r == Seq(1, 2))
}
"should consider pending puts" in run {
import AllowUnsafe.embrace.danger
IO.Unsafe.evalOrThrow {
for
c <- Channel.init[Int](2)
_ <- c.putFiber(1)
_ <- c.putFiber(2)
_ <- c.putFiber(3)
result <- c.drain
finalSize <- c.size
yield assert(result == Chunk(1, 2, 3) && finalSize == 0)
}
}
"should consider pending puts - zero capacity" in pendingUntilFixed {
import AllowUnsafe.embrace.danger
IO.Unsafe.evalOrThrow {
for
c <- Channel.init[Int](0)
_ <- c.putFiber(1)
_ <- c.putFiber(2)
_ <- c.putFiber(3)
result <- c.drain
finalSize <- c.size
yield assert(result == Chunk(1, 2, 3) && finalSize == 0)
}
()
}
}
"drainUpTo" - {
"zero or negative" in run {
Expand Down Expand Up @@ -356,19 +383,19 @@ class ChannelTest extends Test:
s <- c.size
yield assert(r == Seq(1, 2) && s == 2)
}
"should consider pending puts" in pendingUntilFixed {
"should consider pending puts" in run {
import AllowUnsafe.embrace.danger
IO.Unsafe.evalOrThrow {
for
c <- Channel.init[Int](2)
_ <- c.putFiber(1)
_ <- c.putFiber(2)
_ <- c.putFiber(3)
_ <- c.putFiber(4)
result <- c.drainUpTo(3)
finalSize <- c.size
yield assert(result == Chunk(1, 2, 3) && finalSize == 0)
yield assert(result == Chunk(1, 2, 3) && finalSize == 1)
}
()
}
"should consider pending puts - zero capacity" in pendingUntilFixed {
import AllowUnsafe.embrace.danger
Expand All @@ -378,6 +405,7 @@ class ChannelTest extends Test:
_ <- c.putFiber(1)
_ <- c.putFiber(2)
_ <- c.putFiber(3)
_ <- c.putFiber(4)
result <- c.drainUpTo(3)
finalSize <- c.size
yield assert(result == Chunk(1, 2, 3) && finalSize == 0)
Expand Down

0 comments on commit ff73419

Please sign in to comment.