Skip to content

Commit

Permalink
Avoid using a HashSet in BlockedRequests#run
Browse files Browse the repository at this point in the history
  • Loading branch information
kyri-petrou committed Apr 5, 2024
1 parent f660fbc commit c6ba3f3
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,6 @@ object CompletedRequestMap {
*/
def fromIterable[E, A](iterable: Iterable[(Request[E, A], Exit[E, A])]): CompletedRequestMap = {
val builder = HashMap.newBuilder[Any, Exit[Any, Any]]
builder.sizeHint(iterable.size)
builder ++= iterable
new CompletedRequestMap(builder.result())
}
Expand All @@ -109,13 +108,10 @@ object CompletedRequestMap {
*/
def fromIterableOption[E, A](iterable: Iterable[(Request[E, A], Exit[E, Option[A]])]): CompletedRequestMap = {
val builder = HashMap.newBuilder[Any, Exit[Any, Any]]
builder.sizeHint(iterable.size)
iterable.foreach { case (request, result) =>
result match {
case Exit.Failure(e) => builder += (request -> Exit.failCause(e))
case Exit.Success(Some(a)) => builder += (request -> Exit.succeed(a))
case Exit.Success(None) => ()
}
iterable.foreach {
case (request, Exit.Failure(e)) => builder += (request -> Exit.failCause(e))
case (request, Exit.Success(Some(a))) => builder += (request -> Exit.succeed(a))
case (_, Exit.Success(None)) => ()
}
new CompletedRequestMap(builder.result())
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,30 +107,31 @@ private[query] sealed trait BlockedRequests[-R] { self =>
val flattened = BlockedRequests.flatten(self)
ZIO.foreachDiscard(flattened) { requestsByDataSource =>
ZIO.foreachParDiscard(requestsByDataSource.toIterable) { case (dataSource, sequential) =>
val requests = sequential.map(_.map(_.request))
val nRequests = sequential.foldLeft(0)(_ + _.size)
val requests = sequential.map(_.map(_.request))

dataSource
.runAll(requests)
.foldCause(
cause => {
.catchAllCause { cause =>
ZIO.succeed {
val exit = Exit.failCause(cause).asInstanceOf[Exit[Any, Any]]
val map = new mutable.HashMap[Request[_, _], Exit[Any, Any]]()
map.sizeHint(nRequests)
map ++= requests.view.flatten.map(r => r.asInstanceOf[Request[Any, Any]] -> exit)
},
_.toMutableMap
)
CompletedRequestMap.fromIterable(
requests.view.flatten.map(r => r.asInstanceOf[Request[Any, Any]] -> exit)
)
}
}
.flatMap { completedRequests =>
ZQuery.cachingEnabled.getWith {
if (_) {
val completed = mutable.HashSet.empty[Request[_, _]]
completed.sizeHint(nRequests)
completePromisesWith(dataSource, completedRequests, sequential)(completed.add)
val leftovers = completedRequests.keySet.diff(completed)
if (leftovers.nonEmpty) cacheLeftovers(cache, completedRequests, leftovers) else ZIO.unit
ZQuery.cachingEnabled.getWith { cachingEnabled =>
val completedRequestsM = completedRequests.toMutableMap
if (cachingEnabled) {
completePromises(dataSource, sequential) { req =>
// Pop the entry, and fallback to the immutable one if we already removed it
completedRequestsM.remove(req) orElse completedRequests.lookup(req)
}
// cache responses that were not requested but were completed by the DataSource
if (completedRequestsM.nonEmpty) cacheLeftovers(cache, completedRequestsM) else ZIO.unit
} else {
ZIO.succeed(completePromisesWith(dataSource, completedRequests, sequential)(_ => ()))
// No need to remove entries here since we don't need to know which ones we need to put in the cache
ZIO.succeed(completePromises(dataSource, sequential)(completedRequestsM.get))
}
}
}
Expand Down Expand Up @@ -276,47 +277,40 @@ private[query] object BlockedRequests {
else
parallel.sequential :: sequential

private def completePromisesWith(
private def completePromises(
dataSource: DataSource[_, Any],
completedRequests: mutable.HashMap[Request[_, _], Exit[Any, Any]],
sequential: Chunk[Chunk[BlockedRequest[Any]]]
)(onSuccess: Request[?, ?] => Unit): Unit =
)(get: Request[?, ?] => Option[Exit[Any, Any]]): Unit =
sequential.foreach {
_.foreach { br =>
val req = br.request
val res = completedRequests.get(req) match {
case Some(exit) =>
onSuccess(req)
exit.asInstanceOf[Exit[br.Failure, br.Success]]
case None => Exit.die(QueryFailure(dataSource, req))
val res = get(req) match {
case Some(exit) => exit.asInstanceOf[Exit[br.Failure, br.Success]]
case None => Exit.die(QueryFailure(dataSource, req))
}
br.result.unsafe.done(res)(Unsafe.unsafe)
}
}

private def cacheLeftovers(
cache: Cache,
completedRequests: mutable.HashMap[Request[_, _], Exit[Any, Any]],
leftovers: Iterable[Request[_, _]]
map: mutable.HashMap[Request[_, _], Exit[Any, Any]]
)(implicit trace: Trace): UIO[Unit] =
ZIO.fiberIdWith { fiberId =>
val iter = leftovers.iterator
cache match {
case cache: Cache.Default =>
ZIO.succeedUnsafe { implicit unsafe =>
while (iter.hasNext) {
val request = iter.next()
val exit = completedRequests(request)
map.foreachEntry { case (request, exit) =>
val promise = Promise.unsafe.make[Any, Any](fiberId)
promise.unsafe.done(exit)
cache.putUnsafe(request.asInstanceOf[Request[Any, Any]], promise)
}
}
case cache =>
val iter = map.iterator
ZIO.whileLoop(iter.hasNext) {
Promise.makeAs[Any, Any](fiberId).flatMap { promise =>
val request = iter.next()
val exit = completedRequests(request)
val (request, exit) = iter.next()
cache
.get(request)
.orElse(
Expand Down

0 comments on commit c6ba3f3

Please sign in to comment.