Skip to content

Commit

Permalink
Merge pull request typelevel#4121 from lenguyenthanh/make-io-onError-…
Browse files Browse the repository at this point in the history
…consistent-with-applicative-error

Make `IO#onError` consistent with `ApplicativeError`
  • Loading branch information
armanbilge authored Aug 28, 2024
2 parents 7168625 + c6026b3 commit dab9d23
Showing 1 changed file with 35 additions and 12 deletions.
47 changes: 35 additions & 12 deletions core/shared/src/main/scala/cats/effect/IO.scala
Original file line number Diff line number Diff line change
Expand Up @@ -489,11 +489,8 @@ sealed abstract class IO[+A] private () extends IOPlatform[A] {
def guarantee(finalizer: IO[Unit]): IO[A] =
// this is a little faster than the default implementation, which helps Resource
IO uncancelable { poll =>
val handled = finalizer handleErrorWith { t =>
IO.executionContext.flatMap(ec => IO(ec.reportFailure(t)))
}

poll(this).onCancel(finalizer).onError(_ => handled).flatTap(_ => finalizer)
val onError: PartialFunction[Throwable, IO[Unit]] = { case _ => finalizer.reportError }
poll(this).onCancel(finalizer).onError(onError).flatTap(_ => finalizer)
}

/**
Expand All @@ -519,12 +516,10 @@ sealed abstract class IO[+A] private () extends IOPlatform[A] {
def guaranteeCase(finalizer: OutcomeIO[A @uncheckedVariance] => IO[Unit]): IO[A] =
IO.uncancelable { poll =>
val finalized = poll(this).onCancel(finalizer(Outcome.canceled))
val handled = finalized.onError { e =>
finalizer(Outcome.errored(e)).handleErrorWith { t =>
IO.executionContext.flatMap(ec => IO(ec.reportFailure(t)))
}
val onError: PartialFunction[Throwable, IO[Unit]] = {
case e => finalizer(Outcome.errored(e)).reportError
}
handled.flatTap(a => finalizer(Outcome.succeeded(IO.pure(a))))
finalized.onError(onError).flatTap { (a: A) => finalizer(Outcome.succeeded(IO.pure(a))) }
}

def handleError[B >: A](f: Throwable => B): IO[B] =
Expand Down Expand Up @@ -588,8 +583,20 @@ sealed abstract class IO[+A] private () extends IOPlatform[A] {
def onCancel(fin: IO[Unit]): IO[A] =
IO.OnCancel(this, fin)

def onError(f: Throwable => IO[Unit]): IO[A] =
handleErrorWith(t => f(t).voidError *> IO.raiseError(t))
@deprecated("Use onError with PartialFunction argument", "3.6.0")
def onError(f: Throwable => IO[Unit]): IO[A] = {
val pf: PartialFunction[Throwable, IO[Unit]] = { case t => f(t).reportError }
onError(pf)
}

/**
* Execute a callback on certain errors, then rethrow them. Any non matching error is rethrown
* as well.
*
* Implements `ApplicativeError.onError`.
*/
def onError(pf: PartialFunction[Throwable, IO[Unit]]): IO[A] =
handleErrorWith(t => pf.applyOrElse(t, (_: Throwable) => IO.unit) *> IO.raiseError(t))

/**
* Like `Parallel.parProductL`
Expand Down Expand Up @@ -928,6 +935,19 @@ sealed abstract class IO[+A] private () extends IOPlatform[A] {
def void: IO[Unit] =
map(_ => ())

/**
* Similar to [[IO.voidError]], but also reports the error.
*/
private[effect] def reportError(implicit ev: A <:< Unit): IO[Unit] = {
val _ = ev
asInstanceOf[IO[Unit]].handleErrorWith { t =>
IO.executionContext.flatMap(ec => IO(ec.reportFailure(t)))
}
}

/**
* Discard any error raised by the source.
*/
def voidError(implicit ev: A <:< Unit): IO[Unit] = {
val _ = ev
asInstanceOf[IO[Unit]].handleError(_ => ())
Expand Down Expand Up @@ -1975,6 +1995,9 @@ object IO extends IOCompanionPlatform with IOLowPriorityImplicits with TuplePara
override def handleError[A](fa: IO[A])(f: Throwable => A): IO[A] =
fa.handleError(f)

override def onError[A](fa: IO[A])(pf: PartialFunction[Throwable, IO[Unit]]): IO[A] =
fa.onError(pf)

override def timeout[A](fa: IO[A], duration: FiniteDuration)(
implicit ev: TimeoutException <:< Throwable): IO[A] = {
fa.timeout(duration)
Expand Down

0 comments on commit dab9d23

Please sign in to comment.