Skip to content

Commit

Permalink
feat: Add TraversalBuilder.getValuePresentedSource method for further…
Browse files Browse the repository at this point in the history
… optimization.
  • Loading branch information
He-Pin committed Jan 10, 2025
1 parent b160861 commit 11e9547
Show file tree
Hide file tree
Showing 5 changed files with 140 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,14 @@ import org.apache.pekko
import pekko.NotUsed
import pekko.stream._
import pekko.stream.impl.TraversalTestUtils._
import pekko.stream.scaladsl.Keep
import pekko.stream.impl.fusing.IterableSource
import pekko.stream.impl.fusing.GraphStages.{ FutureSource, SingleSource }
import pekko.stream.scaladsl.{ Keep, Source }
import pekko.util.OptionVal
import pekko.testkit.PekkoSpec

import scala.concurrent.Future

class TraversalBuilderSpec extends PekkoSpec {

"CompositeTraversalBuilder" must {
Expand Down Expand Up @@ -447,4 +452,93 @@ class TraversalBuilderSpec extends PekkoSpec {
}
}

"find Source.single via TraversalBuilder" in {
TraversalBuilder.getSingleSource(Source.single("a")).get.elem should ===("a")
TraversalBuilder.getSingleSource(Source(List("a", "b"))) should be(OptionVal.None)

val singleSourceA = new SingleSource("a")
TraversalBuilder.getSingleSource(singleSourceA) should be(OptionVal.Some(singleSourceA))

TraversalBuilder.getSingleSource(Source.single("c").async) should be(OptionVal.None)
TraversalBuilder.getSingleSource(Source.single("d").mapMaterializedValue(_ => "Mat")) should be(OptionVal.None)
}

"find Source.single via TraversalBuilder with getValuePresentedSource" in {
TraversalBuilder.getValuePresentedSource(Source.single("a")).get.asInstanceOf[SingleSource[String]].elem should ===(
"a")
val singleSourceA = new SingleSource("a")
TraversalBuilder.getValuePresentedSource(singleSourceA) should be(OptionVal.Some(singleSourceA))

TraversalBuilder.getValuePresentedSource(Source.single("c").async) should be(OptionVal.None)
TraversalBuilder.getValuePresentedSource(Source.single("d").mapMaterializedValue(_ => "Mat")) should be(
OptionVal.None)
}

"find Source.empty via TraversalBuilder with getValuePresentedSource" in {
val emptySource = EmptySource
TraversalBuilder.getValuePresentedSource(emptySource) should be(OptionVal.Some(emptySource))

TraversalBuilder.getValuePresentedSource(Source.empty.async) should be(OptionVal.None)
TraversalBuilder.getValuePresentedSource(Source.empty.mapMaterializedValue(_ => "Mat")) should be(OptionVal.None)
}

"find javadsl Source.empty via TraversalBuilder with getValuePresentedSource" in {
import pekko.stream.javadsl.Source
val emptySource = Source.empty()
TraversalBuilder.getValuePresentedSource(Source.empty()) should be(OptionVal.Some(emptySource))

TraversalBuilder.getValuePresentedSource(Source.empty().async) should be(OptionVal.None)
TraversalBuilder.getValuePresentedSource(Source.empty().mapMaterializedValue(_ => "Mat")) should be(OptionVal.None)
}

"find Source.future via TraversalBuilder with getValuePresentedSource" in {
val future = Future.successful("a")
TraversalBuilder.getValuePresentedSource(Source.future(future)).get.asInstanceOf[FutureSource[String]].future should ===(
future)
val futureSourceA = new FutureSource(future)
TraversalBuilder.getValuePresentedSource(futureSourceA) should be(OptionVal.Some(futureSourceA))

TraversalBuilder.getValuePresentedSource(Source.future(future).async) should be(OptionVal.None)
TraversalBuilder.getValuePresentedSource(Source.future(future).mapMaterializedValue(_ => "Mat")) should be(
OptionVal.None)
}

"find Source.iterable via TraversalBuilder with getValuePresentedSource" in {
val iterable = List("a")
TraversalBuilder.getValuePresentedSource(Source(iterable)).get.asInstanceOf[IterableSource[String]].elements should ===(
iterable)
val iterableSource = new IterableSource(iterable)
TraversalBuilder.getValuePresentedSource(iterableSource) should be(OptionVal.Some(iterableSource))

TraversalBuilder.getValuePresentedSource(Source(iterable).async) should be(OptionVal.None)
TraversalBuilder.getValuePresentedSource(Source(iterable).mapMaterializedValue(_ => "Mat")) should be(
OptionVal.None)
}

"find Source.javaStreamSource via TraversalBuilder with getValuePresentedSource" in {
val javaStream = java.util.stream.Stream.empty[String]()
TraversalBuilder.getValuePresentedSource(Source.fromJavaStream(() => javaStream)).get
.asInstanceOf[JavaStreamSource[String, _]].open() shouldEqual javaStream
val streamSource = new JavaStreamSource(() => javaStream)
TraversalBuilder.getValuePresentedSource(streamSource) should be(OptionVal.Some(streamSource))

TraversalBuilder.getValuePresentedSource(Source.fromJavaStream(() => javaStream).async) should be(OptionVal.None)
TraversalBuilder.getValuePresentedSource(
Source.fromJavaStream(() => javaStream).mapMaterializedValue(_ => "Mat")) should be(
OptionVal.None)
}

"find Source.failed via TraversalBuilder with getValuePresentedSource" in {
val failure = new RuntimeException("failure")
TraversalBuilder.getValuePresentedSource(Source.failed(failure)).get.asInstanceOf[FailedSource[String]]
.failure should ===(
failure)
val failedSourceA = new FailedSource(failure)
TraversalBuilder.getValuePresentedSource(failedSourceA) should be(OptionVal.Some(failedSourceA))

TraversalBuilder.getValuePresentedSource(Source.failed(failure).async) should be(OptionVal.None)
TraversalBuilder.getValuePresentedSource(Source.failed(failure).mapMaterializedValue(_ => "Mat")) should be(
OptionVal.None)
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@ import scala.concurrent.duration._
import org.apache.pekko
import pekko.NotUsed
import pekko.stream._
import pekko.stream.impl.TraversalBuilder
import pekko.stream.impl.fusing.GraphStages.SingleSource
import pekko.stream.stage.GraphStage
import pekko.stream.stage.GraphStageLogic
import pekko.stream.stage.OutHandler
Expand All @@ -29,7 +27,6 @@ import pekko.stream.testkit.TestPublisher
import pekko.stream.testkit.Utils.TE
import pekko.stream.testkit.scaladsl.TestSink
import pekko.testkit.TestLatch
import pekko.util.OptionVal

import org.scalatest.exceptions.TestFailedException

Expand Down Expand Up @@ -283,16 +280,5 @@ class FlowFlattenMergeSpec extends StreamSpec {
probe.expectComplete()
}

"find Source.single via TraversalBuilder" in {
TraversalBuilder.getSingleSource(Source.single("a")).get.elem should ===("a")
TraversalBuilder.getSingleSource(Source(List("a", "b"))) should be(OptionVal.None)

val singleSourceA = new SingleSource("a")
TraversalBuilder.getSingleSource(singleSourceA) should be(OptionVal.Some(singleSourceA))

TraversalBuilder.getSingleSource(Source.single("c").async) should be(OptionVal.None)
TraversalBuilder.getSingleSource(Source.single("d").mapMaterializedValue(_ => "Mat")) should be(OptionVal.None)
}

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import pekko.stream.stage.{ GraphStage, GraphStageLogic, OutHandler }
/**
* INTERNAL API
*/
@InternalApi private[pekko] final class FailedSource[T](failure: Throwable) extends GraphStage[SourceShape[T]] {
@InternalApi private[pekko] final class FailedSource[T](val failure: Throwable) extends GraphStage[SourceShape[T]] {
val out = Outlet[T]("FailedSource.out")
override val shape = SourceShape(out)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import java.util.function.Consumer

/** INTERNAL API */
@InternalApi private[stream] final class JavaStreamSource[T, S <: java.util.stream.BaseStream[T, S]](
open: () => java.util.stream.BaseStream[T, S])
val open: () => java.util.stream.BaseStream[T, S])
extends GraphStage[SourceShape[T]] {

val out: Outlet[T] = Outlet("JavaStreamSource")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ import pekko.annotation.{ DoNotInherit, InternalApi }
import pekko.stream._
import pekko.stream.impl.StreamLayout.AtomicModule
import pekko.stream.impl.TraversalBuilder.{ AnyFunction1, AnyFunction2 }
import pekko.stream.impl.fusing.GraphStageModule
import pekko.stream.impl.fusing.GraphStages.SingleSource
import pekko.stream.impl.fusing.{ GraphStageModule, IterableSource }
import pekko.stream.impl.fusing.GraphStages.{ FutureSource, SingleSource }
import pekko.stream.scaladsl.Keep
import pekko.util.OptionVal
import pekko.util.unused
Expand Down Expand Up @@ -380,12 +380,53 @@ import pekko.util.unused
}
}

/**
* Try to find `SingleSource` or wrapped such. This is used as a
* performance optimization in FlattenConcat and possibly other places.
* @since 1.2.0
*/
@InternalApi def getValuePresentedSource[A >: Null](
graph: Graph[SourceShape[A], _]): OptionVal[Graph[SourceShape[A], _]] = {
def isValuePresentedSource(graph: Graph[SourceShape[_ <: A], _]): Boolean = graph match {
case _: SingleSource[_] | _: FutureSource[_] | _: IterableSource[_] | _: JavaStreamSource[_, _] |
_: FailedSource[_] =>
true
case maybeEmpty if isEmptySource(maybeEmpty) => true
case _ => false
}
graph match {
case _ if isValuePresentedSource(graph) => OptionVal.Some(graph)
case _ =>
graph.traversalBuilder match {
case l: LinearTraversalBuilder =>
l.pendingBuilder match {
case OptionVal.Some(a: AtomicTraversalBuilder) =>
a.module match {
case m: GraphStageModule[_, _] =>
m.stage match {
case _ if isValuePresentedSource(m.stage.asInstanceOf[Graph[SourceShape[A], _]]) =>
// It would be != EmptyTraversal if mapMaterializedValue was used and then we can't optimize.
if ((l.traversalSoFar eq EmptyTraversal) && !l.attributes.isAsync)
OptionVal.Some(m.stage.asInstanceOf[Graph[SourceShape[A], _]])
else OptionVal.None
case _ => OptionVal.None
}
case _ => OptionVal.None
}
case _ => OptionVal.None
}
case _ => OptionVal.None
}
}
}

/**
* Test if a Graph is an empty Source.
*/
def isEmptySource(graph: Graph[SourceShape[_], _]): Boolean = graph match {
case source: scaladsl.Source[_, _] if source eq scaladsl.Source.empty => true
case source: javadsl.Source[_, _] if source eq javadsl.Source.empty() => true
case EmptySource => true
case _ => false
}

Expand Down

0 comments on commit 11e9547

Please sign in to comment.