diff --git a/stream-tests/src/test/scala/org/apache/pekko/stream/impl/TraversalBuilderSpec.scala b/stream-tests/src/test/scala/org/apache/pekko/stream/impl/TraversalBuilderSpec.scala index 484ac0ed47..4c0f82e755 100644 --- a/stream-tests/src/test/scala/org/apache/pekko/stream/impl/TraversalBuilderSpec.scala +++ b/stream-tests/src/test/scala/org/apache/pekko/stream/impl/TraversalBuilderSpec.scala @@ -17,9 +17,14 @@ import org.apache.pekko import pekko.NotUsed import pekko.stream._ import pekko.stream.impl.TraversalTestUtils._ -import pekko.stream.scaladsl.Keep +import pekko.stream.impl.fusing.IterableSource +import pekko.stream.impl.fusing.GraphStages.{ FutureSource, SingleSource } +import pekko.stream.scaladsl.{ Keep, Source } +import pekko.util.OptionVal import pekko.testkit.PekkoSpec +import scala.concurrent.Future + class TraversalBuilderSpec extends PekkoSpec { "CompositeTraversalBuilder" must { @@ -447,4 +452,93 @@ class TraversalBuilderSpec extends PekkoSpec { } } + "find Source.single via TraversalBuilder" in { + TraversalBuilder.getSingleSource(Source.single("a")).get.elem should ===("a") + TraversalBuilder.getSingleSource(Source(List("a", "b"))) should be(OptionVal.None) + + val singleSourceA = new SingleSource("a") + TraversalBuilder.getSingleSource(singleSourceA) should be(OptionVal.Some(singleSourceA)) + + TraversalBuilder.getSingleSource(Source.single("c").async) should be(OptionVal.None) + TraversalBuilder.getSingleSource(Source.single("d").mapMaterializedValue(_ => "Mat")) should be(OptionVal.None) + } + + "find Source.single via TraversalBuilder with getValuePresentedSource" in { + TraversalBuilder.getValuePresentedSource(Source.single("a")).get.asInstanceOf[SingleSource[String]].elem should ===( + "a") + val singleSourceA = new SingleSource("a") + TraversalBuilder.getValuePresentedSource(singleSourceA) should be(OptionVal.Some(singleSourceA)) + + TraversalBuilder.getValuePresentedSource(Source.single("c").async) should be(OptionVal.None) + TraversalBuilder.getValuePresentedSource(Source.single("d").mapMaterializedValue(_ => "Mat")) should be( + OptionVal.None) + } + + "find Source.empty via TraversalBuilder with getValuePresentedSource" in { + val emptySource = EmptySource + TraversalBuilder.getValuePresentedSource(emptySource) should be(OptionVal.Some(emptySource)) + + TraversalBuilder.getValuePresentedSource(Source.empty.async) should be(OptionVal.None) + TraversalBuilder.getValuePresentedSource(Source.empty.mapMaterializedValue(_ => "Mat")) should be(OptionVal.None) + } + + "find javadsl Source.empty via TraversalBuilder with getValuePresentedSource" in { + import pekko.stream.javadsl.Source + val emptySource = Source.empty() + TraversalBuilder.getValuePresentedSource(Source.empty()) should be(OptionVal.Some(emptySource)) + + TraversalBuilder.getValuePresentedSource(Source.empty().async) should be(OptionVal.None) + TraversalBuilder.getValuePresentedSource(Source.empty().mapMaterializedValue(_ => "Mat")) should be(OptionVal.None) + } + + "find Source.future via TraversalBuilder with getValuePresentedSource" in { + val future = Future.successful("a") + TraversalBuilder.getValuePresentedSource(Source.future(future)).get.asInstanceOf[FutureSource[String]].future should ===( + future) + val futureSourceA = new FutureSource(future) + TraversalBuilder.getValuePresentedSource(futureSourceA) should be(OptionVal.Some(futureSourceA)) + + TraversalBuilder.getValuePresentedSource(Source.future(future).async) should be(OptionVal.None) + TraversalBuilder.getValuePresentedSource(Source.future(future).mapMaterializedValue(_ => "Mat")) should be( + OptionVal.None) + } + + "find Source.iterable via TraversalBuilder with getValuePresentedSource" in { + val iterable = List("a") + TraversalBuilder.getValuePresentedSource(Source(iterable)).get.asInstanceOf[IterableSource[String]].elements should ===( + iterable) + val iterableSource = new IterableSource(iterable) + TraversalBuilder.getValuePresentedSource(iterableSource) should be(OptionVal.Some(iterableSource)) + + TraversalBuilder.getValuePresentedSource(Source(iterable).async) should be(OptionVal.None) + TraversalBuilder.getValuePresentedSource(Source(iterable).mapMaterializedValue(_ => "Mat")) should be( + OptionVal.None) + } + + "find Source.javaStreamSource via TraversalBuilder with getValuePresentedSource" in { + val javaStream = java.util.stream.Stream.empty[String]() + TraversalBuilder.getValuePresentedSource(Source.fromJavaStream(() => javaStream)).get + .asInstanceOf[JavaStreamSource[String, _]].open() shouldEqual javaStream + val streamSource = new JavaStreamSource(() => javaStream) + TraversalBuilder.getValuePresentedSource(streamSource) should be(OptionVal.Some(streamSource)) + + TraversalBuilder.getValuePresentedSource(Source.fromJavaStream(() => javaStream).async) should be(OptionVal.None) + TraversalBuilder.getValuePresentedSource( + Source.fromJavaStream(() => javaStream).mapMaterializedValue(_ => "Mat")) should be( + OptionVal.None) + } + + "find Source.failed via TraversalBuilder with getValuePresentedSource" in { + val failure = new RuntimeException("failure") + TraversalBuilder.getValuePresentedSource(Source.failed(failure)).get.asInstanceOf[FailedSource[String]] + .failure should ===( + failure) + val failedSourceA = new FailedSource(failure) + TraversalBuilder.getValuePresentedSource(failedSourceA) should be(OptionVal.Some(failedSourceA)) + + TraversalBuilder.getValuePresentedSource(Source.failed(failure).async) should be(OptionVal.None) + TraversalBuilder.getValuePresentedSource(Source.failed(failure).mapMaterializedValue(_ => "Mat")) should be( + OptionVal.None) + } + } diff --git a/stream-tests/src/test/scala/org/apache/pekko/stream/scaladsl/FlowFlattenMergeSpec.scala b/stream-tests/src/test/scala/org/apache/pekko/stream/scaladsl/FlowFlattenMergeSpec.scala index 99b342d74b..7d4cf5de37 100644 --- a/stream-tests/src/test/scala/org/apache/pekko/stream/scaladsl/FlowFlattenMergeSpec.scala +++ b/stream-tests/src/test/scala/org/apache/pekko/stream/scaladsl/FlowFlattenMergeSpec.scala @@ -19,8 +19,6 @@ import scala.concurrent.duration._ import org.apache.pekko import pekko.NotUsed import pekko.stream._ -import pekko.stream.impl.TraversalBuilder -import pekko.stream.impl.fusing.GraphStages.SingleSource import pekko.stream.stage.GraphStage import pekko.stream.stage.GraphStageLogic import pekko.stream.stage.OutHandler @@ -29,7 +27,6 @@ import pekko.stream.testkit.TestPublisher import pekko.stream.testkit.Utils.TE import pekko.stream.testkit.scaladsl.TestSink import pekko.testkit.TestLatch -import pekko.util.OptionVal import org.scalatest.exceptions.TestFailedException @@ -283,16 +280,5 @@ class FlowFlattenMergeSpec extends StreamSpec { probe.expectComplete() } - "find Source.single via TraversalBuilder" in { - TraversalBuilder.getSingleSource(Source.single("a")).get.elem should ===("a") - TraversalBuilder.getSingleSource(Source(List("a", "b"))) should be(OptionVal.None) - - val singleSourceA = new SingleSource("a") - TraversalBuilder.getSingleSource(singleSourceA) should be(OptionVal.Some(singleSourceA)) - - TraversalBuilder.getSingleSource(Source.single("c").async) should be(OptionVal.None) - TraversalBuilder.getSingleSource(Source.single("d").mapMaterializedValue(_ => "Mat")) should be(OptionVal.None) - } - } } diff --git a/stream/src/main/scala/org/apache/pekko/stream/impl/FailedSource.scala b/stream/src/main/scala/org/apache/pekko/stream/impl/FailedSource.scala index 4ab1c25355..b107857f86 100644 --- a/stream/src/main/scala/org/apache/pekko/stream/impl/FailedSource.scala +++ b/stream/src/main/scala/org/apache/pekko/stream/impl/FailedSource.scala @@ -22,7 +22,7 @@ import pekko.stream.stage.{ GraphStage, GraphStageLogic, OutHandler } /** * INTERNAL API */ -@InternalApi private[pekko] final class FailedSource[T](failure: Throwable) extends GraphStage[SourceShape[T]] { +@InternalApi private[pekko] final class FailedSource[T](val failure: Throwable) extends GraphStage[SourceShape[T]] { val out = Outlet[T]("FailedSource.out") override val shape = SourceShape(out) diff --git a/stream/src/main/scala/org/apache/pekko/stream/impl/JavaStreamSource.scala b/stream/src/main/scala/org/apache/pekko/stream/impl/JavaStreamSource.scala index 74bba55d0a..d05625fee1 100644 --- a/stream/src/main/scala/org/apache/pekko/stream/impl/JavaStreamSource.scala +++ b/stream/src/main/scala/org/apache/pekko/stream/impl/JavaStreamSource.scala @@ -23,7 +23,7 @@ import java.util.function.Consumer /** INTERNAL API */ @InternalApi private[stream] final class JavaStreamSource[T, S <: java.util.stream.BaseStream[T, S]]( - open: () => java.util.stream.BaseStream[T, S]) + val open: () => java.util.stream.BaseStream[T, S]) extends GraphStage[SourceShape[T]] { val out: Outlet[T] = Outlet("JavaStreamSource") diff --git a/stream/src/main/scala/org/apache/pekko/stream/impl/TraversalBuilder.scala b/stream/src/main/scala/org/apache/pekko/stream/impl/TraversalBuilder.scala index 7ff61a2b3a..24410e0f6d 100644 --- a/stream/src/main/scala/org/apache/pekko/stream/impl/TraversalBuilder.scala +++ b/stream/src/main/scala/org/apache/pekko/stream/impl/TraversalBuilder.scala @@ -21,8 +21,8 @@ import pekko.annotation.{ DoNotInherit, InternalApi } import pekko.stream._ import pekko.stream.impl.StreamLayout.AtomicModule import pekko.stream.impl.TraversalBuilder.{ AnyFunction1, AnyFunction2 } -import pekko.stream.impl.fusing.GraphStageModule -import pekko.stream.impl.fusing.GraphStages.SingleSource +import pekko.stream.impl.fusing.{ GraphStageModule, IterableSource } +import pekko.stream.impl.fusing.GraphStages.{ FutureSource, SingleSource } import pekko.stream.scaladsl.Keep import pekko.util.OptionVal import pekko.util.unused @@ -380,12 +380,53 @@ import pekko.util.unused } } + /** + * Try to find `SingleSource` or wrapped such. This is used as a + * performance optimization in FlattenConcat and possibly other places. + * @since 1.2.0 + */ + @InternalApi def getValuePresentedSource[A >: Null]( + graph: Graph[SourceShape[A], _]): OptionVal[Graph[SourceShape[A], _]] = { + def isValuePresentedSource(graph: Graph[SourceShape[_ <: A], _]): Boolean = graph match { + case _: SingleSource[_] | _: FutureSource[_] | _: IterableSource[_] | _: JavaStreamSource[_, _] | + _: FailedSource[_] => + true + case maybeEmpty if isEmptySource(maybeEmpty) => true + case _ => false + } + graph match { + case _ if isValuePresentedSource(graph) => OptionVal.Some(graph) + case _ => + graph.traversalBuilder match { + case l: LinearTraversalBuilder => + l.pendingBuilder match { + case OptionVal.Some(a: AtomicTraversalBuilder) => + a.module match { + case m: GraphStageModule[_, _] => + m.stage match { + case _ if isValuePresentedSource(m.stage.asInstanceOf[Graph[SourceShape[A], _]]) => + // It would be != EmptyTraversal if mapMaterializedValue was used and then we can't optimize. + if ((l.traversalSoFar eq EmptyTraversal) && !l.attributes.isAsync) + OptionVal.Some(m.stage.asInstanceOf[Graph[SourceShape[A], _]]) + else OptionVal.None + case _ => OptionVal.None + } + case _ => OptionVal.None + } + case _ => OptionVal.None + } + case _ => OptionVal.None + } + } + } + /** * Test if a Graph is an empty Source. */ def isEmptySource(graph: Graph[SourceShape[_], _]): Boolean = graph match { case source: scaladsl.Source[_, _] if source eq scaladsl.Source.empty => true case source: javadsl.Source[_, _] if source eq javadsl.Source.empty() => true + case EmptySource => true case _ => false }