diff --git a/build.sbt b/build.sbt index 5c64b6b7..5f51020e 100644 --- a/build.sbt +++ b/build.sbt @@ -23,7 +23,7 @@ inThisBuild( addCommandAlias("fmt", "all scalafmtSbt scalafmt test:scalafmt") addCommandAlias("check", "all scalafmtSbtCheck scalafmtCheck test:scalafmtCheck") -val zioVersion = "2.0.0-RC3" +val zioVersion = "2.0.0-RC4" lazy val root = project .in(file(".")) diff --git a/zio-query/shared/src/main/scala/zio/query/ZQuery.scala b/zio-query/shared/src/main/scala/zio/query/ZQuery.scala index 3bda7b82..63ea2279 100644 --- a/zio-query/shared/src/main/scala/zio/query/ZQuery.scala +++ b/zio-query/shared/src/main/scala/zio/query/ZQuery.scala @@ -606,14 +606,14 @@ final class ZQuery[-R, +E, +A] private (private val step: ZIO[R, Nothing, Result /** * Returns a new query that executes this one and times the execution. */ - final def timed(implicit trace: ZTraceElement): ZQuery[R with Clock, E, (Duration, A)] = + final def timed(implicit trace: ZTraceElement): ZQuery[R, E, (Duration, A)] = summarized(Clock.nanoTime)((start, end) => Duration.fromNanos(end - start)) /** * Returns an effect that will timeout this query, returning `None` if the * timeout elapses before the query was completed. */ - final def timeout(duration: => Duration)(implicit trace: ZTraceElement): ZQuery[R with Clock, E, Option[A]] = + final def timeout(duration: => Duration)(implicit trace: ZTraceElement): ZQuery[R, E, Option[A]] = timeoutTo(None)(Some(_))(duration) /** @@ -622,7 +622,7 @@ final class ZQuery[-R, +E, +A] private (private val step: ZIO[R, Nothing, Result */ final def timeoutFail[E1 >: E](e: => E1)(duration: => Duration)(implicit trace: ZTraceElement - ): ZQuery[R with Clock, E1, A] = + ): ZQuery[R, E1, A] = timeoutTo(ZQuery.fail(e))(ZQuery.succeedNow)(duration).flatten /** @@ -631,7 +631,7 @@ final class ZQuery[-R, +E, +A] private (private val step: ZIO[R, Nothing, Result */ final def timeoutFailCause[E1 >: E](cause: => Cause[E1])(duration: => Duration)(implicit trace: ZTraceElement - ): ZQuery[R with Clock, E1, A] = + ): ZQuery[R, E1, A] = timeoutTo(ZQuery.failCause(cause))(ZQuery.succeedNow)(duration).flatten /** @@ -641,7 +641,7 @@ final class ZQuery[-R, +E, +A] private (private val step: ZIO[R, Nothing, Result @deprecated("use timeoutFailCause", "0.3.0") final def timeoutHalt[E1 >: E](cause: => Cause[E1])(duration: => Duration)(implicit trace: ZTraceElement - ): ZQuery[R with Clock, E1, A] = + ): ZQuery[R, E1, A] = timeoutFailCause(cause)(duration) /** @@ -1216,17 +1216,7 @@ object ZQuery { )( f: A => ZQuery[R, E, B] )(implicit bf: BuildFrom[Collection[A], B, Collection[B]], trace: ZTraceElement): ZQuery[R, E, Collection[B]] = - if (as.isEmpty) ZQuery.succeed(bf.newBuilder(as).result()) - else { - val iterator = as.iterator - var builder: ZQuery[R, E, Builder[B, Collection[B]]] = null - while (iterator.hasNext) { - val a = iterator.next() - if (builder eq null) builder = f(a).map(bf.newBuilder(as) += _) - else builder = builder.zipWithPar(f(a))(_ += _) - } - builder.map(_.result()) - } + ZQuery(ZIO.foreachPar(Chunk.fromIterable(as))(f(_).step).map(Result.collectAllPar(_).map(bf.fromSpecific(as)))) /** * Performs a query for each element in a Set, collecting the results @@ -1513,38 +1503,38 @@ object ZQuery { final class TimeoutTo[-R, +E, +A, +B](self: ZQuery[R, E, A], b: () => B) { def apply[B1 >: B]( f: A => B1 - )(duration: => Duration)(implicit trace: ZTraceElement): ZQuery[R with Clock, E, B1] = - ZQuery.environment[Clock].flatMap { clock => - def race( - query: ZQuery[R, E, B1], - fiber: Fiber[Nothing, B1] - ): ZQuery[R, E, B1] = - ZQuery { - query.step.raceWith[R, Nothing, Nothing, B1, Result[R, E, B1]](fiber.join)( - (leftExit, rightFiber) => - leftExit.foldZIO( - cause => rightFiber.interrupt *> ZIO.succeedNow(Result.fail(cause)), - result => - result match { - case Result.Blocked(blockedRequests, continue) => - continue match { - case Continue.Effect(query) => - ZIO.succeedNow(Result.blocked(blockedRequests, Continue.effect(race(query, fiber)))) - case Continue.Get(io) => - ZIO.succeedNow( - Result.blocked(blockedRequests, Continue.effect(race(ZQuery.fromZIO(io), fiber))) - ) - } - case Result.Done(value) => rightFiber.interrupt *> ZIO.succeedNow(Result.done(value)) - case Result.Fail(cause) => rightFiber.interrupt *> ZIO.succeedNow(Result.fail(cause)) - } - ), - (rightExit, leftFiber) => leftFiber.interrupt *> ZIO.succeedNow(Result.fromExit(rightExit)) - ) - } + )(duration: => Duration)(implicit trace: ZTraceElement): ZQuery[R, E, B1] = { + + def race( + query: ZQuery[R, E, B1], + fiber: Fiber[Nothing, B1] + ): ZQuery[R, E, B1] = + ZQuery { + query.step.raceWith[R, Nothing, Nothing, B1, Result[R, E, B1]](fiber.join)( + (leftExit, rightFiber) => + leftExit.foldZIO( + cause => rightFiber.interrupt *> ZIO.succeedNow(Result.fail(cause)), + result => + result match { + case Result.Blocked(blockedRequests, continue) => + continue match { + case Continue.Effect(query) => + ZIO.succeedNow(Result.blocked(blockedRequests, Continue.effect(race(query, fiber)))) + case Continue.Get(io) => + ZIO.succeedNow( + Result.blocked(blockedRequests, Continue.effect(race(ZQuery.fromZIO(io), fiber))) + ) + } + case Result.Done(value) => rightFiber.interrupt *> ZIO.succeedNow(Result.done(value)) + case Result.Fail(cause) => rightFiber.interrupt *> ZIO.succeedNow(Result.fail(cause)) + } + ), + (rightExit, leftFiber) => leftFiber.interrupt *> ZIO.succeedNow(Result.fromExit(rightExit)) + ) + } - ZQuery.fromZIO(clock.get.sleep(duration).interruptible.as(b()).fork).flatMap(fiber => race(self.map(f), fiber)) - } + ZQuery.fromZIO(ZIO.sleep(duration).interruptible.as(b()).fork).flatMap(fiber => race(self.map(f), fiber)) + } } final class ServiceWithPartiallyApplied[R](private val dummy: Boolean = true) extends AnyVal { diff --git a/zio-query/shared/src/main/scala/zio/query/internal/Continue.scala b/zio-query/shared/src/main/scala/zio/query/internal/Continue.scala index 03750fe2..a41bfd1e 100644 --- a/zio-query/shared/src/main/scala/zio/query/internal/Continue.scala +++ b/zio-query/shared/src/main/scala/zio/query/internal/Continue.scala @@ -1,9 +1,9 @@ package zio.query.internal +import zio._ import zio.query._ import zio.query.internal.Continue._ import zio.stacktracer.TracingImplicits.disableAutoTrace -import zio.{ CanFail, Cause, IO, Ref, ZEnvironment, ZIO, ZTraceElement } /** * A `Continue[R, E, A]` models a continuation of a blocked request that @@ -167,6 +167,42 @@ private[query] object Continue { } } + /** + * Collects a collection of continuation into a continuation returning a + * collection of their results, in parallel. + */ + def collectAllPar[R, E, A, Collection[+Element] <: Iterable[Element]]( + continues: Collection[Continue[R, E, A]] + )(implicit + bf: BuildFrom[Collection[Continue[R, E, A]], A, Collection[A]], + trace: ZTraceElement + ): Continue[R, E, Collection[A]] = + continues.zipWithIndex + .foldLeft[(Chunk[(ZQuery[R, E, A], Int)], Chunk[(IO[E, A], Int)])]((Chunk.empty, Chunk.empty)) { + case ((queries, ios), (continue, index)) => + continue match { + case Effect(query) => (queries :+ ((query, index)), ios) + case Get(io) => (queries, ios :+ ((io, index))) + } + } match { + case (Chunk(), ios) => + get(ZIO.collectAll(ios.map(_._1)).map(bf.fromSpecific(continues))) + case (queries, ios) => + val query = ZQuery.collectAllPar(queries.map(_._1)).flatMap { as => + val array = Array.ofDim[AnyRef](continues.size) + as.zip(queries.map(_._2)).foreach { case (a, i) => + array(i) = a.asInstanceOf[AnyRef] + } + ZQuery.fromZIO(ZIO.collectAll(ios.map(_._1))).map { as => + as.zip(ios.map(_._2)).foreach { case (a, i) => + array(i) = a.asInstanceOf[AnyRef] + } + bf.fromSpecific(continues)(array.asInstanceOf[Array[A]]) + } + } + effect(query) + } + /** * Constructs a continuation that may perform arbitrary effects. */ diff --git a/zio-query/shared/src/main/scala/zio/query/internal/Result.scala b/zio-query/shared/src/main/scala/zio/query/internal/Result.scala index 94f1b796..aa054bc0 100644 --- a/zio-query/shared/src/main/scala/zio/query/internal/Result.scala +++ b/zio-query/shared/src/main/scala/zio/query/internal/Result.scala @@ -1,9 +1,9 @@ package zio.query.internal +import zio._ import zio.query.internal.Result._ import zio.query.{ DataSourceAspect, Described } import zio.stacktracer.TracingImplicits.disableAutoTrace -import zio.{ CanFail, Cause, Exit, ZEnvironment, ZTraceElement } /** * A `Result[R, E, A]` is the result of running one step of a `ZQuery`. A @@ -111,6 +111,43 @@ private[query] object Result { def blocked[R, E, A](blockedRequests: BlockedRequests[R], continue: Continue[R, E, A]): Result[R, E, A] = Blocked(blockedRequests, continue) + /** + * Collects a collection of results into a single result. Blocked requests + * and their continuations will be executed in parallel. + */ + def collectAllPar[R, E, A, Collection[+Element] <: Iterable[Element]](results: Collection[Result[R, E, A]])(implicit + bf: BuildFrom[Collection[Result[R, E, A]], A, Collection[A]], + trace: ZTraceElement + ): Result[R, E, Collection[A]] = + results.zipWithIndex + .foldLeft[(Chunk[((BlockedRequests[R], Continue[R, E, A]), Int)], Chunk[(A, Int)], Chunk[(Cause[E], Int)])]( + (Chunk.empty, Chunk.empty, Chunk.empty) + ) { case ((blocked, done, fails), (result, index)) => + result match { + case Blocked(br, c) => (blocked :+ (((br, c), index)), done, fails) + case Done(a) => (blocked, done :+ ((a, index)), fails) + case Fail(e) => (blocked, done, fails :+ ((e, index))) + } + } match { + case (Chunk(), done, Chunk()) => + Result.done(bf.fromSpecific(results)(done.map(_._1))) + case (blocked, done, Chunk()) => + val blockedRequests = blocked.map(_._1._1).foldLeft[BlockedRequests[R]](BlockedRequests.empty)(_ && _) + val continue = Continue.collectAllPar(blocked.map(_._1._2)).map { as => + val array = Array.ofDim[AnyRef](results.size) + as.zip(blocked.map(_._2)).foreach { case (a, i) => + array(i) = a.asInstanceOf[AnyRef] + } + done.foreach { case (a, i) => + array(i) = a.asInstanceOf[AnyRef] + } + bf.fromSpecific(results)(array.asInstanceOf[Array[A]]) + } + Result.blocked(blockedRequests, continue) + case (_, _, fail) => + Result.fail(fail.map(_._1).foldLeft[Cause[E]](Cause.empty)(_ && _)) + } + /** * Constructs a result that is done with the specified value. */ diff --git a/zio-query/shared/src/test/scala/zio/query/ZQuerySpec.scala b/zio-query/shared/src/test/scala/zio/query/ZQuerySpec.scala index d52ed350..d874c2cb 100644 --- a/zio-query/shared/src/test/scala/zio/query/ZQuerySpec.scala +++ b/zio-query/shared/src/test/scala/zio/query/ZQuerySpec.scala @@ -172,7 +172,7 @@ object ZQuerySpec extends ZIOBaseSpec { richUsers <- ZQuery.foreachPar(users) { user => Sources .getPayment(user.paymentId) - .zipPar(Sources.getAddress(user.addressId)) + .zip(Sources.getAddress(user.addressId)) .map { case (payment, address) => (user, payment, address) } @@ -281,8 +281,8 @@ object ZQuerySpec extends ZIOBaseSpec { case object GetAllIds extends UserRequest[List[Int]] final case class GetNameById(id: Int) extends UserRequest[String] - val UserRequestDataSource: DataSource[Console, UserRequest[Any]] = - DataSource.Batched.make[Console, UserRequest[Any]]("UserRequestDataSource") { requests => + val UserRequestDataSource: DataSource[Any, UserRequest[Any]] = + DataSource.Batched.make[Any, UserRequest[Any]]("UserRequestDataSource") { requests => ZIO.when(requests.toSet.size != requests.size)(ZIO.dieMessage("Duplicate requests)")) *> Console.printLine(requests.toString).orDie *> ZIO.succeed { @@ -294,25 +294,25 @@ object ZQuerySpec extends ZIOBaseSpec { } } - val getAllUserIds: ZQuery[Console, Nothing, List[Int]] = + val getAllUserIds: ZQuery[Any, Nothing, List[Int]] = ZQuery.fromRequest(GetAllIds)(UserRequestDataSource) - def getUserNameById(id: Int): ZQuery[Console, Nothing, String] = + def getUserNameById(id: Int): ZQuery[Any, Nothing, String] = ZQuery.fromRequest(GetNameById(id))(UserRequestDataSource) - val getAllUserNames: ZQuery[Console, Nothing, List[String]] = + val getAllUserNames: ZQuery[Any, Nothing, List[String]] = for { userIds <- getAllUserIds userNames <- ZQuery.foreachPar(userIds)(getUserNameById) } yield userNames case object GetFoo extends Request[Nothing, String] - val getFoo: ZQuery[Console, Nothing, String] = ZQuery.fromRequest(GetFoo)( + val getFoo: ZQuery[Any, Nothing, String] = ZQuery.fromRequest(GetFoo)( DataSource.fromFunctionZIO("foo")(_ => Console.printLine("Running foo query") *> ZIO.succeed("foo")) ) case object GetBar extends Request[Nothing, String] - val getBar: ZQuery[Console, Nothing, String] = ZQuery.fromRequest(GetBar)( + val getBar: ZQuery[Any, Nothing, String] = ZQuery.fromRequest(GetBar)( DataSource.fromFunctionZIO("bar")(_ => Console.printLine("Running bar query") *> ZIO.succeed("bar")) ) @@ -447,12 +447,12 @@ object ZQuerySpec extends ZIOBaseSpec { 4 -> "d" ) - def backendGetAll: ZIO[Console, Nothing, Map[Int, String]] = + def backendGetAll: ZIO[Any, Nothing, Map[Int, String]] = for { _ <- Console.printLine("getAll called").orDie } yield testData - def backendGetSome(ids: Chunk[Int]): ZIO[Console, Nothing, Map[Int, String]] = + def backendGetSome(ids: Chunk[Int]): ZIO[Any, Nothing, Map[Int, String]] = for { _ <- Console.printLine(s"getSome ${ids.mkString(", ")} called").orDie } yield ids.flatMap { id => @@ -468,10 +468,10 @@ object ZQuerySpec extends ZIOBaseSpec { final case class Get(id: Int) extends Req[String] } - val ds: DataSource.Batched[Console, Req[_]] = new DataSource.Batched[Console, Req[_]] { + val ds: DataSource.Batched[Any, Req[_]] = new DataSource.Batched[Any, Req[_]] { override def run( requests: Chunk[Req[_]] - )(implicit trace: ZTraceElement): ZIO[Console, Nothing, CompletedRequestMap] = { + )(implicit trace: ZTraceElement): ZIO[Any, Nothing, CompletedRequestMap] = { val (all, oneByOne) = requests.partition { case Req.GetAll => true case Req.Get(_) => false @@ -505,8 +505,8 @@ object ZQuerySpec extends ZIOBaseSpec { override val identifier: String = "test" } - def getAll: ZQuery[Console, DataSourceErrors, Map[Int, String]] = + def getAll: ZQuery[Any, DataSourceErrors, Map[Int, String]] = ZQuery.fromRequest(Req.GetAll)(ds) - def get(id: Int): ZQuery[Console, DataSourceErrors, String] = + def get(id: Int): ZQuery[Any, DataSourceErrors, String] = ZQuery.fromRequest(Req.Get(id))(ds) }