Skip to content

Commit

Permalink
Add an abstract Synchronous cache and optimize BlockedRequests
Browse files Browse the repository at this point in the history
  • Loading branch information
kyri-petrou committed Jul 2, 2024
1 parent a57ea1d commit d310588
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 28 deletions.
51 changes: 35 additions & 16 deletions zio-query/shared/src/main/scala/zio/query/Cache.scala
Original file line number Diff line number Diff line change
Expand Up @@ -74,33 +74,52 @@ object Cache {
def empty(expectedNumOfElements: Int)(implicit trace: Trace): UIO[Cache] =
ZIO.succeed(Cache.unsafeMake(expectedNumOfElements))

private[query] final class Default(private val map: ConcurrentHashMap[Request[_, _], Promise[_, _]]) extends Cache {
/**
* A 'Synchronous' cache is one that doesn't require an effect to look up its
* value. Prefer extending this class when implementing a cache that doesn't
* perform any asynchronous IO.
*/
abstract class Synchronous extends Cache {
def getOrNull[E, A](request: Request[E, A]): Promise[E, A]
def lookupNow[E, A, B](request: Request[_, _]): Either[Promise[E, B], Promise[E, B]]
def putNow[E, A](request: Request[E, A], result: Promise[E, A]): Unit
def removeNow[E, A](request: Request[E, A]): Unit

def get[E, A](request: Request[E, A])(implicit trace: Trace): IO[Unit, Promise[E, A]] =
final def get[E, A](request: Request[E, A])(implicit trace: Trace): IO[Unit, Promise[E, A]] =
ZIO.suspendSucceed {
val out = map.get(request).asInstanceOf[Promise[E, A]]
if (out eq null) Exit.fail(()) else Exit.succeed(out)
val p = getOrNull(request)
if (p eq null) Exit.fail(()) else Exit.succeed(p)
}

def lookup[E, A, B](request: A)(implicit
ev: A <:< Request[E, B],
trace: Trace
): UIO[Either[Promise[E, B], Promise[E, B]]] =
ZIO.succeed(lookupUnsafe(request)(Unsafe.unsafe))
final def lookup[E, A, B](
request: A
)(implicit ev: A <:< Request[E, B], trace: Trace): UIO[Either[Promise[E, B], Promise[E, B]]] =
ZIO.succeed(lookupNow(request))

final def put[E, A](request: Request[E, A], result: Promise[E, A])(implicit trace: Trace): UIO[Unit] =
ZIO.succeed(putNow(request, result))

final def remove[E, A](request: Request[E, A])(implicit trace: Trace): UIO[Unit] =
ZIO.succeed(removeNow(request))
}

private final class Default(map: ConcurrentHashMap[Request[_, _], Promise[_, _]]) extends Synchronous {
private implicit val unsafe: Unsafe = Unsafe.unsafe

def getOrNull[E, A](request: Request[E, A]): Promise[E, A] =
map.get(request).asInstanceOf[Promise[E, A]]

def lookupUnsafe[E, A, B](request: Request[_, _])(implicit
unsafe: Unsafe
): Either[Promise[E, B], Promise[E, B]] = {
def lookupNow[E, A, B](request: Request[_, _]): Either[Promise[E, B], Promise[E, B]] = {
val newPromise = Promise.unsafe.make[E, B](FiberId.None)
val existing = map.putIfAbsent(request, newPromise).asInstanceOf[Promise[E, B]]
if (existing eq null) Left(newPromise) else Right(existing)
}

def put[E, A](request: Request[E, A], result: Promise[E, A])(implicit trace: Trace): UIO[Unit] =
ZIO.succeed(map.put(request, result))
def putNow[E, A](request: Request[E, A], result: Promise[E, A]): Unit =
map.put(request, result)

def remove[E, A](request: Request[E, A])(implicit trace: Trace): UIO[Unit] =
ZIO.succeed(map.remove(request))
def removeNow[E, A](request: Request[E, A]): Unit =
map.remove(request)
}

// TODO: Initialize the map with a sensible default value. Default is 16, which seems way too small for a cache
Expand Down
4 changes: 2 additions & 2 deletions zio-query/shared/src/main/scala/zio/query/ZQuery.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1559,8 +1559,8 @@ object ZQuery {
}

cache match {
case cache: Cache.Default => foldPromise(cache.lookupUnsafe(request)(Unsafe.unsafe))
case cache => CachedResult.Effectful(cache.lookup(request).flatMap(foldPromise(_).toZIO))
case cache: Cache.Synchronous => foldPromise(cache.lookupNow(request))
case cache => CachedResult.Effectful(cache.lookup(request).flatMap(foldPromise(_).toZIO))
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -285,28 +285,38 @@ private[query] object BlockedRequests {
private def completePromises(
dataSource: DataSource[_, Any],
sequential: Chunk[Chunk[BlockedRequest[Any]]]
)(get: Request[?, ?] => Option[Exit[Any, Any]]): Unit =
sequential.foreach {
_.foreach { br =>
val req = br.request
val res = get(req) match {
case Some(exit) => exit.asInstanceOf[Exit[br.Failure, br.Success]]
case None => Exit.die(QueryFailure(dataSource, req))
}
)(get: Request[?, ?] => Option[Exit[Any, Any]]): Unit = {

def loopInner(c: Chunk[BlockedRequest[Any]]): Unit = {
val it = c.iterator
while (it.hasNext) {
val br = it.next()
val req = br.request
val exit = get(req)
val res =
if (exit.isEmpty) Exit.die(QueryFailure(dataSource, req))
else exit.get.asInstanceOf[Exit[br.Failure, br.Success]]
br.result.unsafe.done(res)(Unsafe.unsafe)
}
}

val it0 = sequential.iterator
while (it0.hasNext) {
val next = it0.next()
loopInner(next)
}
}

private def cacheLeftovers(
cache: Cache,
map: mutable.HashMap[Request[_, _], Exit[Any, Any]]
)(implicit trace: Trace): UIO[Unit] =
cache match {
case cache: Cache.Default =>
case cache: Cache.Synchronous =>
ZIO.succeedUnsafe { implicit unsafe =>
map.foreach { case (request: Request[Any, Any], exit) =>
cache
.lookupUnsafe(request)
.lookupNow(request)
.merge
.unsafe
.done(exit)
Expand Down

0 comments on commit d310588

Please sign in to comment.