From 4d8a65498f4d5eff7b75a04178a6e80a1c1a6d82 Mon Sep 17 00:00:00 2001 From: Alejandro Serrano Date: Mon, 10 Jul 2023 13:34:12 +0200 Subject: [PATCH] MemoizedDeepRecursiveFunction --- .../core/MemoizedDeepRecursiveFunction.kt | 28 +++++++++++++++++++ .../kotlin/arrow/core/MemoizationTest.kt | 15 ++++++++++ 2 files changed, 43 insertions(+) create mode 100644 arrow-libs/core/arrow-core/src/commonMain/kotlin/arrow/core/MemoizedDeepRecursiveFunction.kt diff --git a/arrow-libs/core/arrow-core/src/commonMain/kotlin/arrow/core/MemoizedDeepRecursiveFunction.kt b/arrow-libs/core/arrow-core/src/commonMain/kotlin/arrow/core/MemoizedDeepRecursiveFunction.kt new file mode 100644 index 00000000000..5ec6995de8f --- /dev/null +++ b/arrow-libs/core/arrow-core/src/commonMain/kotlin/arrow/core/MemoizedDeepRecursiveFunction.kt @@ -0,0 +1,28 @@ +package arrow.core + +import arrow.atomic.Atomic +import arrow.atomic.loop + +public fun MemoizedDeepRecursiveFunction( + block: suspend DeepRecursiveScope.(T) -> R +): DeepRecursiveFunction { + val cache = Atomic(emptyMap()) + return DeepRecursiveFunction { x -> + when (x) { + in cache.get() -> cache.get().getValue(x) + else -> { + val result = block(x) + cache.loop { old -> + when (x) { + in old -> + return@DeepRecursiveFunction old.getValue(x) + else -> { + if (cache.compareAndSet(old, old + Pair(x, result))) + return@DeepRecursiveFunction result + } + } + } + } + } + } +} diff --git a/arrow-libs/core/arrow-core/src/commonTest/kotlin/arrow/core/MemoizationTest.kt b/arrow-libs/core/arrow-core/src/commonTest/kotlin/arrow/core/MemoizationTest.kt index 24478adfc21..c719a76f9e5 100644 --- a/arrow-libs/core/arrow-core/src/commonTest/kotlin/arrow/core/MemoizationTest.kt +++ b/arrow-libs/core/arrow-core/src/commonTest/kotlin/arrow/core/MemoizationTest.kt @@ -215,6 +215,21 @@ class MemoizationTest : StringSpec({ memoized(1, 2, 3, 4, 5) shouldBe null runs shouldBe 1 } + + "Recursive memoization" { + var runs = 0 + val memoizedDeepRecursiveFibonacci: DeepRecursiveFunction = + MemoizedDeepRecursiveFunction { n -> + when (n) { + 0 -> 0.also { runs++ } + 1 -> 1 + else -> callRecursive(n - 1) + callRecursive(n - 2) + } + } + val result = memoizedDeepRecursiveFibonacci(5) + result shouldBe 5 + runs shouldBe 1 + } }) private fun consecSumResult(n: Int): Int = (n * (n + 1)) / 2