Skip to content

Commit

Permalink
[core] introduce IO.withLocal (#1005)
Browse files Browse the repository at this point in the history
Getting the value of a `Local` and then immediately performing a side
effect is a common pattern in the codebase. See
[TRef](https://github.com/getkyo/kyo/blob/9cdf197909d2ecdcd9436f5e9c687f5bbd9b5b2a/kyo-stm/shared/src/main/scala/kyo/TRef.scala#L185)
as an example. Other methods like
[Clock.deadline](https://github.com/getkyo/kyo/blob/9cdf197909d2ecdcd9436f5e9c687f5bbd9b5b2a/kyo-core/shared/src/main/scala/kyo/Clock.scala#L309)
have a similar pattern.

It currently requires two separate suspensions but that's unnecessary.
The local get already suspends execution the same way as `IO` does: via
the kernel's internal
[Defer](https://github.com/getkyo/kyo/blob/9cdf197909d2ecdcd9436f5e9c687f5bbd9b5b2a/kyo-kernel/shared/src/main/scala/kyo/kernel/internal/Kyo.scala#L80)
effect, which is the underlying suspension of `ContextEffect` and is
guaranteed to be the last effect to be handled. This PR introduces
`IO.withLocal` and `IO.Unsafe.withLocal` as APIs that provide both
operations in a single step/allocation.
  • Loading branch information
fwbrasil authored Jan 15, 2025
1 parent d668e60 commit 5674fab
Show file tree
Hide file tree
Showing 13 changed files with 257 additions and 128 deletions.
42 changes: 20 additions & 22 deletions kyo-core/shared/src/main/scala/kyo/Clock.scala
Original file line number Diff line number Diff line change
Expand Up @@ -160,23 +160,21 @@ object Clock:
def withTimeShift[A, S](factor: Double)(v: => A < S)(using Frame): A < (IO & S) =
if factor == 1 then v
else
use { clock =>
IO.Unsafe {
val shifted =
new Unsafe:
val underlying = clock.unsafe
val start = underlying.now()
val sleepFactor = (1.toDouble / factor)
def nowMonotonic()(using AllowUnsafe) =
now().toDuration
def now()(using AllowUnsafe) =
val diff = underlying.now() - start
start + (diff * factor)
end now
override def sleep(duration: Duration) =
underlying.sleep(duration * sleepFactor)
let(Clock(shifted))(v)
}
IO.Unsafe.withLocal(local) { clock =>
val shifted =
new Unsafe:
val underlying = clock.unsafe
val start = underlying.now()
val sleepFactor = (1.toDouble / factor)
def nowMonotonic()(using AllowUnsafe) =
now().toDuration
def now()(using AllowUnsafe) =
val diff = underlying.now() - start
start + (diff * factor)
end now
override def sleep(duration: Duration) =
underlying.sleep(duration * sleepFactor)
let(Clock(shifted))(v)
}
end withTimeShift

Expand Down Expand Up @@ -274,7 +272,7 @@ object Clock:
* The current time
*/
def now(using Frame): Instant < IO =
use(_.now)
IO.Unsafe.withLocal(local)(_.unsafe.now())

/** Gets the current monotonic time using the local Clock instance. Unlike `now`, this is guaranteed to be strictly monotonic and
* suitable for measuring elapsed time.
Expand All @@ -286,18 +284,18 @@ object Clock:
* The current monotonic time as a Duration since system start
*/
def nowMonotonic(using Frame): Duration < IO =
use(_.nowMonotonic)
IO.Unsafe.withLocal(local)(_.unsafe.nowMonotonic())

private[kyo] def sleep(duration: Duration)(using Frame): Fiber[Nothing, Unit] < IO =
use(_.sleep(duration))
IO.Unsafe.withLocal(local)(_.unsafe.sleep(duration).safe)

/** Creates a new stopwatch using the local Clock instance.
*
* @return
* A new Stopwatch instance
*/
def stopwatch(using Frame): Stopwatch < IO =
use(_.stopwatch)
IO.Unsafe.withLocal(local)(_.unsafe.stopwatch().safe)

/** Creates a new deadline with the specified duration using the local Clock instance.
*
Expand All @@ -307,7 +305,7 @@ object Clock:
* A new Deadline instance
*/
def deadline(duration: Duration)(using Frame): Deadline < IO =
use(_.deadline(duration))
IO.Unsafe.withLocal(local)(_.unsafe.deadline(duration).safe)

/** Repeatedly executes a task with a fixed delay between completions.
*
Expand Down
52 changes: 24 additions & 28 deletions kyo-core/shared/src/main/scala/kyo/Console.scala
Original file line number Diff line number Diff line change
Expand Up @@ -117,16 +117,14 @@ object Console:
* The result of the computation with IO effects.
*/
def withIn[A, S](lines: Iterable[String])(v: A < S)(using Frame): A < (IO & S) =
use { console =>
IO {
val it = lines.iterator
val proxy =
new Proxy(console.unsafe):
override def readLine()(using AllowUnsafe) =
if !it.hasNext then Result.fail(new EOFException("Consoles.readLine failed."))
else Result.success(it.next())
let(Console(proxy))(v)
}
IO.withLocal(local) { console =>
val it = lines.iterator
val proxy =
new Proxy(console.unsafe):
override def readLine()(using AllowUnsafe) =
if !it.hasNext then Result.fail(new EOFException("Consoles.readLine failed."))
else Result.success(it.next())
let(Console(proxy))(v)
}

/** Container for captured console output.
Expand All @@ -150,19 +148,17 @@ object Console:
* A tuple containing the captured output (Out) and the computation result.
*/
def withOut[A, S](v: A < S)(using Frame): (Out, A) < (IO & S) =
use { console =>
IO {
val stdOut = new StringBuffer
val stdErr = new StringBuffer
val proxy =
new Proxy(console.unsafe):
override def print(s: String)(using AllowUnsafe) = Result.success(stdOut.append(s)).unit
override def printErr(s: String)(using AllowUnsafe) = Result.success(stdErr.append(s)).unit
override def printLine(s: String)(using AllowUnsafe) = Result.success(stdOut.append(s + "\n")).unit
override def printLineErr(s: String)(using AllowUnsafe) = Result.success(stdErr.append(s + "\n")).unit
let(Console(proxy))(v)
.map(r => IO((Out(stdOut.toString(), stdErr.toString()), r)))
}
IO.withLocal(local) { console =>
val stdOut = new StringBuffer
val stdErr = new StringBuffer
val proxy =
new Proxy(console.unsafe):
override def print(s: String)(using AllowUnsafe) = Result.success(stdOut.append(s)).unit
override def printErr(s: String)(using AllowUnsafe) = Result.success(stdErr.append(s)).unit
override def printLine(s: String)(using AllowUnsafe) = Result.success(stdOut.append(s + "\n")).unit
override def printLineErr(s: String)(using AllowUnsafe) = Result.success(stdErr.append(s + "\n")).unit
let(Console(proxy))(v)
.map(r => IO((Out(stdOut.toString(), stdErr.toString()), r)))
}

/** Reads a line from the console.
Expand All @@ -171,7 +167,7 @@ object Console:
* A String representing the line read from the console.
*/
def readLine(using Frame): String < (IO & Abort[IOException]) =
local.use(_.readLine)
IO.Unsafe.withLocal(local)(console => Abort.get(console.unsafe.readLine()))

private def toString(v: Any)(using Frame): String =
v match
Expand All @@ -186,31 +182,31 @@ object Console:
* The value to print.
*/
def print[A](v: A)(using Frame): Unit < (IO & Abort[IOException]) =
local.use(_.print(toString(v)))
IO.Unsafe.withLocal(local)(console => Abort.get(console.unsafe.print(toString(v))))

/** Prints a value to the console's error stream without a newline.
*
* @param v
* The value to print to the error stream.
*/
def printErr[A](v: A)(using Frame): Unit < (IO & Abort[IOException]) =
local.use(_.printErr(toString(v)))
IO.Unsafe.withLocal(local)(console => Abort.get(console.unsafe.printErr(toString(v))))

/** Prints a value to the console followed by a newline.
*
* @param v
* The value to print.
*/
def printLine[A](v: A)(using Frame): Unit < (IO & Abort[IOException]) =
local.use(_.println(toString(v)))
IO.Unsafe.withLocal(local)(console => Abort.get(console.unsafe.printLine(toString(v))))

/** Prints a value to the console's error stream followed by a newline.
*
* @param v
* The value to print to the error stream.
*/
def printLineErr[A](v: A)(using Frame): Unit < (IO & Abort[IOException]) =
local.use(_.printLineErr(toString(v)))
IO.Unsafe.withLocal(local)(console => Abort.get(console.unsafe.printLineErr(toString(v))))

/** WARNING: Low-level API meant for integrations, libraries, and performance-sensitive code. See AllowUnsafe for more details. */
abstract class Unsafe:
Expand Down
26 changes: 24 additions & 2 deletions kyo-core/shared/src/main/scala/kyo/IO.scala
Original file line number Diff line number Diff line change
Expand Up @@ -71,15 +71,37 @@ object IO:
def ensure[A, S](f: => Unit < IO)(v: A < S)(using frame: Frame): A < (IO & S) =
Unsafe(Safepoint.ensure(IO.Unsafe.evalOrThrow(f))(v))

/** Retrieves a local value and applies a function that can perform side effects.
*
* This is the preferred way to access a local value when you need to perform side effects with it. Common use cases include accessing
* loggers, configuration, or request-scoped values that you need to use in computations that produce side effects.
*
* While `local.get.map(v => IO(f(v)))` would also work, this method is more direct since both IO and Local use the same underlying
* mechanism to handle effects. Under the hood, accessing a local value and performing IO operations both use the same type of
* suspension, the kernel's internal `Defer` effect. This means we can safely combine them without creating unnecessary layers of
* suspension.
*
* @param local
* The local value to access
* @param f
* Function that can perform side effects with the local value
* @return
* An IO effect containing the result of applying the function
*/
def withLocal[A, B, S](local: Local[A])(f: A => B < S)(using Frame): B < (S & IO) =
local.use(f)

/** WARNING: Low-level API meant for integrations, libraries, and performance-sensitive code. See AllowUnsafe for more details. */
object Unsafe:

inline def apply[A, S](inline f: AllowUnsafe ?=> A < S)(using inline frame: Frame): A < (IO & S) =
Effect.deferInline {
import AllowUnsafe.embrace.danger
f
f(using AllowUnsafe.embrace.danger)
}

def withLocal[A, B, S](local: Local[A])(f: AllowUnsafe ?=> A => B < S)(using Frame): B < (S & IO) =
local.use(f(using AllowUnsafe.embrace.danger))

/** Evaluates an IO effect that may throw exceptions, converting any thrown exceptions into the final result.
*
* WARNING: This is a low-level API that should be used with caution. It forcefully evaluates the IO effect and will throw any
Expand Down
4 changes: 2 additions & 2 deletions kyo-core/shared/src/main/scala/kyo/Log.scala
Original file line number Diff line number Diff line change
Expand Up @@ -173,9 +173,9 @@ object Log extends LogPlatformSpecific:
private inline def logWhen(inline level: Level)(inline doLog: Log => Unit < IO)(using
inline frame: Frame
): Unit < IO =
use { log =>
IO.Unsafe.withLocal(local) { log =>
if level.enabled(log.level) then
IO.Unsafe(doLog(log))
doLog(log)
else
(
)
Expand Down
6 changes: 3 additions & 3 deletions kyo-core/shared/src/main/scala/kyo/Meter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -262,12 +262,12 @@ object Meter:

private inline def withReentry[A, S](inline reenter: => A < S)(acquire: AllowUnsafe ?=> A < S): A < (IO & S) =
if reentrant then
acquiredMeters.use { meters =>
IO.withLocal(acquiredMeters) { meters =>
if meters.contains(this) then reenter
else IO.Unsafe(acquire)
else acquire
}
else
IO.Unsafe(acquire)
acquire

private inline def withAcquiredMeter[A, S](inline v: => A < S) =
if reentrant then
Expand Down
39 changes: 26 additions & 13 deletions kyo-core/shared/src/main/scala/kyo/Random.scala
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,8 @@ object Random:
* @return
* A random Int value.
*/
def nextInt(using Frame): Int < IO = local.use(_.nextInt)
def nextInt(using Frame): Int < IO =
IO.Unsafe.withLocal(local)(_.unsafe.nextInt())

/** Generates a random integer between 0 (inclusive) and the specified bound (exclusive).
*
Expand All @@ -189,42 +190,48 @@ object Random:
* @return
* A random Int value within the specified range.
*/
def nextInt(exclusiveBound: Int)(using Frame): Int < IO = local.use(_.nextInt(exclusiveBound))
def nextInt(exclusiveBound: Int)(using Frame): Int < IO =
IO.Unsafe.withLocal(local)(_.unsafe.nextInt(exclusiveBound))

/** Generates a random long integer.
*
* @return
* A random Long value.
*/
def nextLong(using Frame): Long < IO = local.use(_.nextLong)
def nextLong(using Frame): Long < IO =
IO.Unsafe.withLocal(local)(_.unsafe.nextLong())

/** Generates a random double between 0.0 (inclusive) and 1.0 (exclusive).
*
* @return
* A random Double value between 0.0 and 1.0.
*/
def nextDouble(using Frame): Double < IO = local.use(_.nextDouble)
def nextDouble(using Frame): Double < IO =
IO.Unsafe.withLocal(local)(_.unsafe.nextDouble())

/** Generates a random boolean value.
*
* @return
* A random Boolean value.
*/
def nextBoolean(using Frame): Boolean < IO = local.use(_.nextBoolean)
def nextBoolean(using Frame): Boolean < IO =
IO.Unsafe.withLocal(local)(_.unsafe.nextBoolean())

/** Generates a random float between 0.0 (inclusive) and 1.0 (exclusive).
*
* @return
* A random Float value between 0.0 and 1.0.
*/
def nextFloat(using Frame): Float < IO = local.use(_.nextFloat)
def nextFloat(using Frame): Float < IO =
IO.Unsafe.withLocal(local)(_.unsafe.nextFloat())

/** Generates a random double from a Gaussian distribution with mean 0.0 and standard deviation 1.0.
*
* @return
* A random Double value from a Gaussian distribution.
*/
def nextGaussian(using Frame): Double < IO = local.use(_.nextGaussian)
def nextGaussian(using Frame): Double < IO =
IO.Unsafe.withLocal(local)(_.unsafe.nextGaussian())

/** Selects a random element from the given sequence.
*
Expand All @@ -235,7 +242,8 @@ object Random:
* @return
* A randomly selected element from the sequence.
*/
def nextValue[A](seq: Seq[A])(using Frame): A < IO = local.use(_.nextValue(seq))
def nextValue[A](seq: Seq[A])(using Frame): A < IO =
IO.Unsafe.withLocal(local)(_.unsafe.nextValue(seq): A)

/** Generates a sequence of random elements from the given sequence.
*
Expand All @@ -248,7 +256,8 @@ object Random:
* @return
* A new sequence of randomly selected elements.
*/
def nextValues[A](length: Int, seq: Seq[A])(using Frame): Seq[A] < IO = local.use(_.nextValues(length, seq))
def nextValues[A](length: Int, seq: Seq[A])(using Frame): Seq[A] < IO =
IO.Unsafe.withLocal(local)(_.unsafe.nextValues(length, seq))

/** Generates a random alphanumeric string of the specified length.
*
Expand All @@ -257,7 +266,8 @@ object Random:
* @return
* A random alphanumeric String of the specified length.
*/
def nextStringAlphanumeric(length: Int)(using Frame): String < IO = local.use(_.nextStringAlphanumeric(length))
def nextStringAlphanumeric(length: Int)(using Frame): String < IO =
IO.Unsafe.withLocal(local)(_.unsafe.nextStringAlphanumeric(length))

/** Generates a random string of the specified length using the given character sequence and the current Random instance.
*
Expand All @@ -268,7 +278,8 @@ object Random:
* @return
* A random String of the specified length.
*/
def nextString(length: Int, chars: Seq[Char])(using Frame): String < IO = local.use(_.nextString(length, chars))
def nextString(length: Int, chars: Seq[Char])(using Frame): String < IO =
IO.Unsafe.withLocal(local)(_.unsafe.nextString(length, chars))

/** Generates a sequence of random bytes.
*
Expand All @@ -277,7 +288,8 @@ object Random:
* @return
* A Seq[Byte] of random bytes.
*/
def nextBytes(length: Int)(using Frame): Seq[Byte] < IO = local.use(_.nextBytes(length))
def nextBytes(length: Int)(using Frame): Seq[Byte] < IO =
IO.Unsafe.withLocal(local)(_.unsafe.nextBytes(length))

/** Shuffles the elements of the given sequence randomly.
*
Expand All @@ -288,6 +300,7 @@ object Random:
* @return
* A new sequence with the elements shuffled randomly.
*/
def shuffle[A](seq: Seq[A])(using Frame): Seq[A] < IO = local.use(_.shuffle(seq))
def shuffle[A](seq: Seq[A])(using Frame): Seq[A] < IO =
IO.Unsafe.withLocal(local)(_.unsafe.shuffle(seq))

end Random
Loading

0 comments on commit 5674fab

Please sign in to comment.