From 98be10047fdca23c5246b89c19d84a85fc61574c Mon Sep 17 00:00:00 2001
From: Tim <hello@timsmart.co>
Date: Fri, 24 Jan 2025 22:54:41 +1300
Subject: [PATCH] add {FiberHandle,FiberSet,FiberMap}.awaitEmpty apis (#4337)

---
 .changeset/tough-cars-invite.md          |  5 ++
 packages/effect/src/FiberHandle.ts       | 71 +++++++++++-----------
 packages/effect/src/FiberMap.ts          | 77 ++++++++++++------------
 packages/effect/src/FiberSet.ts          | 59 ++++++++++--------
 packages/effect/test/FiberHandle.test.ts | 14 ++++-
 packages/effect/test/FiberMap.test.ts    | 17 +++++-
 packages/effect/test/FiberSet.test.ts    | 17 +++++-
 7 files changed, 159 insertions(+), 101 deletions(-)
 create mode 100644 .changeset/tough-cars-invite.md

diff --git a/.changeset/tough-cars-invite.md b/.changeset/tough-cars-invite.md
new file mode 100644
index 00000000000..304c8a69e64
--- /dev/null
+++ b/.changeset/tough-cars-invite.md
@@ -0,0 +1,5 @@
+---
+"effect": minor
+---
+
+add {FiberHandle,FiberSet,FiberMap}.awaitEmpty apis
diff --git a/packages/effect/src/FiberHandle.ts b/packages/effect/src/FiberHandle.ts
index fc9c94d9dff..3a506f5c5f9 100644
--- a/packages/effect/src/FiberHandle.ts
+++ b/packages/effect/src/FiberHandle.ts
@@ -339,44 +339,31 @@ export const run: {
 } = function() {
   const self = arguments[0] as FiberHandle
   if (Effect.isEffect(arguments[1])) {
-    const effect = arguments[1]
-    const options = arguments[2] as {
-      readonly onlyIfMissing?: boolean
-      readonly propagateInterruption?: boolean | undefined
-    } | undefined
-    return Effect.suspend(() => {
-      if (self.state._tag === "Closed") {
-        return Effect.interrupt
-      } else if (self.state.fiber !== undefined && options?.onlyIfMissing === true) {
-        return Effect.sync(constInterruptedFiber)
-      }
-      return Effect.uninterruptibleMask((restore) =>
-        Effect.tap(
-          restore(Effect.forkDaemon(effect)),
-          (fiber) => set(self, fiber, options)
-        )
-      )
-    }) as any
+    return runImpl(self, arguments[1], arguments[2]) as any
   }
-  const options = arguments[1] as {
+  const options = arguments[1]
+  return (effect: Effect.Effect<unknown, unknown, any>) => runImpl(self, effect, options)
+}
+
+const runImpl = <A, E, R, XE extends E, XA extends A>(
+  self: FiberHandle<A, E>,
+  effect: Effect.Effect<XA, XE, R>,
+  options?: {
     readonly onlyIfMissing?: boolean
     readonly propagateInterruption?: boolean | undefined
-  } | undefined
-  return (effect: Effect.Effect<unknown, unknown, any>) =>
-    Effect.suspend(() => {
-      if (self.state._tag === "Closed") {
-        return Effect.interrupt
-      } else if (self.state.fiber !== undefined && options?.onlyIfMissing === true) {
-        return Effect.sync(constInterruptedFiber)
-      }
-      return Effect.uninterruptibleMask((restore) =>
-        Effect.tap(
-          restore(Effect.forkDaemon(effect)),
-          (fiber) => set(self, fiber, options)
-        )
-      )
-    })
-}
+  }
+): Effect.Effect<Fiber.RuntimeFiber<XA, XE>, never, R> =>
+  Effect.fiberIdWith((fiberId) => {
+    if (self.state._tag === "Closed") {
+      return Effect.interrupt
+    } else if (self.state.fiber !== undefined && options?.onlyIfMissing === true) {
+      return Effect.sync(constInterruptedFiber)
+    }
+    return Effect.tap(
+      Effect.forkDaemon(effect),
+      (fiber) => unsafeSet(self, fiber, { ...options, interruptAs: fiberId })
+    )
+  })
 
 /**
  * Capture a Runtime and use it to fork Effect's, adding the forked fibers to the FiberHandle.
@@ -470,3 +457,17 @@ export const runtime: <A, E>(
  */
 export const join = <A, E>(self: FiberHandle<A, E>): Effect.Effect<void, E> =>
   Deferred.await(self.deferred as Deferred.Deferred<void, E>)
+
+/**
+ * Wait for the fiber in the FiberHandle to complete.
+ *
+ * @since 3.13.0
+ * @categories combinators
+ */
+export const awaitEmpty = <A, E>(self: FiberHandle<A, E>): Effect.Effect<void, E> =>
+  Effect.suspend(() => {
+    if (self.state._tag === "Closed" || self.state.fiber === undefined) {
+      return Effect.void
+    }
+    return Fiber.await(self.state.fiber)
+  })
diff --git a/packages/effect/src/FiberMap.ts b/packages/effect/src/FiberMap.ts
index b80d0b13c13..e4ba703c693 100644
--- a/packages/effect/src/FiberMap.ts
+++ b/packages/effect/src/FiberMap.ts
@@ -8,7 +8,7 @@ import * as Effect from "./Effect.js"
 import * as Exit from "./Exit.js"
 import * as Fiber from "./Fiber.js"
 import * as FiberId from "./FiberId.js"
-import { constFalse, dual } from "./Function.js"
+import { constFalse, constVoid, dual } from "./Function.js"
 import * as HashSet from "./HashSet.js"
 import * as Inspectable from "./Inspectable.js"
 import * as Iterable from "./Iterable.js"
@@ -438,49 +438,35 @@ export const run: {
     } | undefined
   ): Effect.Effect<Fiber.RuntimeFiber<XA, XE>, never, R>
 } = function() {
+  const self = arguments[0]
   if (Effect.isEffect(arguments[2])) {
-    const self = arguments[0] as FiberMap<any>
-    const key = arguments[1]
-    const effect = arguments[2] as Effect.Effect<any, any, any>
-    const options = arguments[3] as {
-      readonly onlyIfMissing?: boolean
-      readonly propagateInterruption?: boolean | undefined
-    } | undefined
-    return Effect.suspend(() => {
-      if (self.state._tag === "Closed") {
-        return Effect.interrupt
-      } else if (options?.onlyIfMissing === true && unsafeHas(self, key)) {
-        return Effect.sync(constInterruptedFiber)
-      }
-      return Effect.uninterruptibleMask((restore) =>
-        Effect.tap(
-          restore(Effect.forkDaemon(effect)),
-          (fiber) => set(self, key, fiber, options)
-        )
-      )
-    }) as any
+    return runImpl(self, arguments[1], arguments[2], arguments[3]) as any
   }
-  const self = arguments[0] as FiberMap<any>
   const key = arguments[1]
-  const options = arguments[2] as {
+  const options = arguments[2]
+  return (effect: Effect.Effect<any, any, any>) => runImpl(self, key, effect, options)
+}
+
+const runImpl = <K, A, E, R, XE extends E, XA extends A>(
+  self: FiberMap<K, A, E>,
+  key: K,
+  effect: Effect.Effect<XA, XE, R>,
+  options?: {
     readonly onlyIfMissing?: boolean
     readonly propagateInterruption?: boolean | undefined
-  } | undefined
-  return (effect: Effect.Effect<any, any, any>) =>
-    Effect.suspend(() => {
-      if (self.state._tag === "Closed") {
-        return Effect.interrupt
-      } else if (options?.onlyIfMissing === true && unsafeHas(self, key)) {
-        return Effect.sync(constInterruptedFiber)
-      }
-      return Effect.uninterruptibleMask((restore) =>
-        Effect.tap(
-          restore(Effect.forkDaemon(effect)),
-          (fiber) => set(self, key, fiber, options)
-        )
-      )
-    })
-}
+  }
+) =>
+  Effect.fiberIdWith((fiberId) => {
+    if (self.state._tag === "Closed") {
+      return Effect.interrupt
+    } else if (options?.onlyIfMissing === true && unsafeHas(self, key)) {
+      return Effect.sync(constInterruptedFiber)
+    }
+    return Effect.tap(
+      Effect.forkDaemon(effect),
+      (fiber) => unsafeSet(self, key, fiber, { ...options, interruptAs: fiberId })
+    )
+  })
 
 /**
  * Capture a Runtime and use it to fork Effect's, adding the forked fibers to the FiberMap.
@@ -581,3 +567,16 @@ export const size = <K, A, E>(self: FiberMap<K, A, E>): Effect.Effect<number> =>
  */
 export const join = <K, A, E>(self: FiberMap<K, A, E>): Effect.Effect<void, E> =>
   Deferred.await(self.deferred as Deferred.Deferred<void, E>)
+
+/**
+ * Wait for the FiberMap to be empty.
+ *
+ * @since 3.13.0
+ * @categories combinators
+ */
+export const awaitEmpty = <K, A, E>(self: FiberMap<K, A, E>): Effect.Effect<void, E> =>
+  Effect.whileLoop({
+    while: () => self.state._tag === "Open" && MutableHashMap.size(self.state.backing) > 0,
+    body: () => Fiber.await(Iterable.unsafeHead(self)[1]),
+    step: constVoid
+  })
diff --git a/packages/effect/src/FiberSet.ts b/packages/effect/src/FiberSet.ts
index 722db68018c..754d976544c 100644
--- a/packages/effect/src/FiberSet.ts
+++ b/packages/effect/src/FiberSet.ts
@@ -7,7 +7,7 @@ import * as Effect from "./Effect.js"
 import * as Exit from "./Exit.js"
 import * as Fiber from "./Fiber.js"
 import * as FiberId from "./FiberId.js"
-import { constFalse, dual } from "./Function.js"
+import { constFalse, constVoid, dual } from "./Function.js"
 import * as HashSet from "./HashSet.js"
 import * as Inspectable from "./Inspectable.js"
 import * as Iterable from "./Iterable.js"
@@ -291,34 +291,32 @@ export const run: {
 } = function() {
   const self = arguments[0] as FiberSet<any, any>
   if (!Effect.isEffect(arguments[1])) {
-    const options = arguments[1] as { readonly propagateInterruption?: boolean | undefined } | undefined
-    return (effect: Effect.Effect<any, any, any>) =>
-      Effect.suspend(() => {
-        if (self.state._tag === "Closed") {
-          return Effect.interrupt
-        }
-        return Effect.uninterruptibleMask((restore) =>
-          Effect.tap(
-            restore(Effect.forkDaemon(effect)),
-            (fiber) => add(self, fiber, options)
-          )
-        )
-      })
+    const options = arguments[1]
+    return (effect: Effect.Effect<any, any, any>) => runImpl(self, effect, options)
   }
-  const effect = arguments[1]
-  const options = arguments[2] as { readonly propagateInterruption?: boolean | undefined } | undefined
-  return Effect.suspend(() => {
+  return runImpl(self, arguments[1], arguments[2]) as any
+}
+
+const runImpl = <A, E, R, XE extends E, XA extends A>(
+  self: FiberSet<A, E>,
+  effect: Effect.Effect<XA, XE, R>,
+  options?: {
+    readonly propagateInterruption?: boolean | undefined
+  }
+): Effect.Effect<Fiber.RuntimeFiber<XA, XE>, never, R> =>
+  Effect.fiberIdWith((fiberId) => {
     if (self.state._tag === "Closed") {
       return Effect.interrupt
     }
-    return Effect.uninterruptibleMask((restore) =>
-      Effect.tap(
-        restore(Effect.forkDaemon(effect)),
-        (fiber) => add(self, fiber, options)
-      )
+    return Effect.tap(
+      Effect.forkDaemon(effect),
+      (fiber) =>
+        unsafeAdd(self, fiber, {
+          ...options,
+          interruptAs: fiberId
+        })
     )
-  }) as any
-}
+  })
 
 /**
  * Capture a Runtime and use it to fork Effect's, adding the forked fibers to the FiberSet.
@@ -405,3 +403,16 @@ export const size = <A, E>(self: FiberSet<A, E>): Effect.Effect<number> =>
  */
 export const join = <A, E>(self: FiberSet<A, E>): Effect.Effect<void, E> =>
   Deferred.await(self.deferred as Deferred.Deferred<void, E>)
+
+/**
+ * Wait until the fiber set is empty.
+ *
+ * @since 3.13.0
+ * @categories combinators
+ */
+export const awaitEmpty = <A, E>(self: FiberSet<A, E>): Effect.Effect<void> =>
+  Effect.whileLoop({
+    while: () => self.state._tag === "Open" && self.state.backing.size > 0,
+    body: () => Fiber.await(Iterable.unsafeHead(self)),
+    step: constVoid
+  })
diff --git a/packages/effect/test/FiberHandle.test.ts b/packages/effect/test/FiberHandle.test.ts
index 3efc116cab4..9078c03155a 100644
--- a/packages/effect/test/FiberHandle.test.ts
+++ b/packages/effect/test/FiberHandle.test.ts
@@ -1,4 +1,4 @@
-import { Deferred, Effect, Exit, Fiber, Ref } from "effect"
+import { Deferred, Effect, Exit, Fiber, Ref, TestClock } from "effect"
 import * as FiberHandle from "effect/FiberHandle"
 import * as it from "effect/test/utils/extend"
 import { assert, describe } from "vitest"
@@ -100,4 +100,16 @@ describe("FiberHandle", () => {
         )
       ))
     }))
+
+  it.scoped("awaitEmpty", () =>
+    Effect.gen(function*() {
+      const handle = yield* FiberHandle.make()
+      yield* FiberHandle.run(handle, Effect.sleep(1000))
+
+      const fiber = yield* Effect.fork(FiberHandle.awaitEmpty(handle))
+      yield* TestClock.adjust(500)
+      assert.isNull(fiber.unsafePoll())
+      yield* TestClock.adjust(500)
+      assert.isDefined(fiber.unsafePoll())
+    }))
 })
diff --git a/packages/effect/test/FiberMap.test.ts b/packages/effect/test/FiberMap.test.ts
index c0453ef452e..80a94a7efd8 100644
--- a/packages/effect/test/FiberMap.test.ts
+++ b/packages/effect/test/FiberMap.test.ts
@@ -1,4 +1,4 @@
-import { Array, Deferred, Effect, Exit, Fiber, Ref, Scope } from "effect"
+import { Array, Deferred, Effect, Exit, Fiber, Ref, Scope, TestClock } from "effect"
 import * as FiberMap from "effect/FiberMap"
 import * as it from "effect/test/utils/extend"
 import { assert, describe } from "vitest"
@@ -122,4 +122,19 @@ describe("FiberMap", () => {
         )
       ))
     }))
+
+  it.scoped("awaitEmpty", () =>
+    Effect.gen(function*() {
+      const map = yield* FiberMap.make<string>()
+      yield* FiberMap.run(map, "a", Effect.sleep(1000))
+      yield* FiberMap.run(map, "b", Effect.sleep(1000))
+      yield* FiberMap.run(map, "c", Effect.sleep(1000))
+      yield* FiberMap.run(map, "d", Effect.sleep(1000))
+
+      const fiber = yield* Effect.fork(FiberMap.awaitEmpty(map))
+      yield* TestClock.adjust(500)
+      assert.isNull(fiber.unsafePoll())
+      yield* TestClock.adjust(500)
+      assert.isDefined(fiber.unsafePoll())
+    }))
 })
diff --git a/packages/effect/test/FiberSet.test.ts b/packages/effect/test/FiberSet.test.ts
index 84ead23be50..3eaa8ddfbf8 100644
--- a/packages/effect/test/FiberSet.test.ts
+++ b/packages/effect/test/FiberSet.test.ts
@@ -1,4 +1,4 @@
-import { Array, Deferred, Effect, Exit, Fiber, Ref, Scope } from "effect"
+import { Array, Deferred, Effect, Exit, Fiber, Ref, Scope, TestClock } from "effect"
 import * as FiberSet from "effect/FiberSet"
 import * as it from "effect/test/utils/extend"
 import { assert, describe } from "vitest"
@@ -94,4 +94,19 @@ describe("FiberSet", () => {
         )
       ))
     }))
+
+  it.scoped("awaitEmpty", () =>
+    Effect.gen(function*() {
+      const set = yield* FiberSet.make()
+      yield* FiberSet.run(set, Effect.sleep(1000))
+      yield* FiberSet.run(set, Effect.sleep(1000))
+      yield* FiberSet.run(set, Effect.sleep(1000))
+      yield* FiberSet.run(set, Effect.sleep(1000))
+
+      const fiber = yield* Effect.fork(FiberSet.awaitEmpty(set))
+      yield* TestClock.adjust(500)
+      assert.isNull(fiber.unsafePoll())
+      yield* TestClock.adjust(500)
+      assert.isDefined(fiber.unsafePoll())
+    }))
 })