Skip to content

Commit

Permalink
Locals.update
Browse files Browse the repository at this point in the history
  • Loading branch information
fwbrasil committed Dec 15, 2023
1 parent 4fbe334 commit 9ae6418
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 11 deletions.
38 changes: 27 additions & 11 deletions kyo-core/shared/src/main/scala/kyo/locals.scala
Original file line number Diff line number Diff line change
Expand Up @@ -14,24 +14,40 @@ object locals {

val get: T > IOs =
new KyoIO[T, Any] {

def apply(v: Unit > (IOs with Any), s: Safepoint[IO, IOs], l: State) =
l.getOrElse(Local.this, default).asInstanceOf[T]
get(l)
}

def let[U, S1, S2](v: T > S1)(f: U > S2): U > (S1 with S2 with IOs) = {
def loop(v: T, f: U > S2): U > S2 =
f match {
case kyo: Kyo[MX, EX, Any, U, S2] @unchecked =>
new KyoCont[MX, EX, Any, U, S2](kyo) {
def apply(v2: Any > S2, s: Safepoint[MX, EX], l: Locals.State) =
loop(v, kyo(v2, s, l.updated(Local.this, v)))
def let[U, S](f: T)(v: U > S): U > (S with IOs) = {
def letLoop(f: T, v: U > S): U > S =
v match {
case kyo: Kyo[MX, EX, Any, U, S] @unchecked =>
new KyoCont[MX, EX, Any, U, S](kyo) {
def apply(v2: Any > S, s: Safepoint[MX, EX], l: Locals.State) =
letLoop(f, kyo(v2, s, l.updated(Local.this, f)))
}
case _ =>
f
v
}
v.map(loop(_, f))
letLoop(f, v)
}

def update[U, S](f: T => T)(v: U > S): U > (S with IOs) = {
def updateLoop(f: T => T, v: U > S): U > S =
v match {
case kyo: Kyo[MX, EX, Any, U, S] @unchecked =>
new KyoCont[MX, EX, Any, U, S](kyo) {
def apply(v2: Any > S, s: Safepoint[MX, EX], l: Locals.State) =
updateLoop(f, kyo(v2, s, l.updated(Local.this, f(get(l)))))
}
case _ =>
v
}
updateLoop(f, v)
}

private def get(l: Locals.State) =
l.getOrElse(Local.this, default).asInstanceOf[T]
}

object Locals {
Expand Down
34 changes: 34 additions & 0 deletions kyo-core/shared/src/test/scala/kyoTest/localsTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,40 @@ class localsTest extends KyoTest {
}
}

"update" - {
"get" in {
val l = Locals.init(10)
assert(
IOs.run[Int](l.update(_ + 10)(l.get)) ==
20
)
}
"effect + get" in {
val l = Locals.init(10)
assert(
IOs.run[Option[Int]](Options.run(Options(1).map(_ => l.update(_ + 10)(l.get)))) ==
Some(20)
)
}
"effect + get + effect" in {
val l = Locals.init(10)
assert(
IOs.run[Option[Int]](
Options.run(Options(1).map(_ => l.update(_ + 10)(l.get).map(Options(_))))
) ==
Some(20)
)
}
"multiple" in {
val l1 = Locals.init(10)
val l2 = Locals.init(20)
assert(
IOs.run[(Int, Int)](zip(l1.update(_ + 10)(l1.get), l2.update(_ + 10)(l2.get))) ==
(20, 30)
)
}
}

"save" - {
"let + save" in {
val l = Locals.init(10)
Expand Down

0 comments on commit 9ae6418

Please sign in to comment.