diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/TopNScanOpt.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/TopNScanOpt.java index 2ca1380d1b050f..bc0586f40c405a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/TopNScanOpt.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/TopNScanOpt.java @@ -24,27 +24,58 @@ import org.apache.doris.nereids.trees.plans.algebra.Filter; import org.apache.doris.nereids.trees.plans.algebra.OlapScan; import org.apache.doris.nereids.trees.plans.algebra.Project; +import org.apache.doris.nereids.trees.plans.algebra.TopN; +import org.apache.doris.nereids.trees.plans.physical.AbstractPhysicalSort; import org.apache.doris.nereids.trees.plans.physical.PhysicalDeferMaterializeTopN; +import org.apache.doris.nereids.trees.plans.physical.PhysicalDistribute; +import org.apache.doris.nereids.trees.plans.physical.PhysicalSink; import org.apache.doris.nereids.trees.plans.physical.PhysicalTopN; import org.apache.doris.qe.ConnectContext; -import com.google.common.collect.ImmutableList; - /** * topN opt * refer to: * ... * ... + * + * // only support simple case: select ... from tbl [where ...] order by ... limit ... */ public class TopNScanOpt extends PlanPostProcessor { + @Override + public Plan visit(Plan plan, CascadesContext context) { + return plan; + } + + @Override + public Plan visitPhysicalSink(PhysicalSink physicalSink, CascadesContext context) { + if (physicalSink.child() instanceof TopN) { + return super.visit(physicalSink, context); + } + return physicalSink; + } + + @Override + public Plan visitPhysicalDistribute(PhysicalDistribute distribute, CascadesContext context) { + if (distribute.child() instanceof TopN && distribute.child() instanceof AbstractPhysicalSort + && ((AbstractPhysicalSort) distribute.child()).getSortPhase() == SortPhase.LOCAL_SORT) { + return super.visit(distribute, context); + } + return distribute; + } + @Override public PhysicalTopN visitPhysicalTopN(PhysicalTopN topN, CascadesContext ctx) { - Plan child = topN.child().accept(this, ctx); - topN = rewriteTopN(topN); - if (child != topN.child()) { - topN = ((PhysicalTopN) topN.withChildren(child)).copyStatsAndGroupIdFrom(topN); + if (topN.getSortPhase() == SortPhase.LOCAL_SORT) { + Plan child = topN.child(); + topN = rewriteTopN(topN); + if (child != topN.child()) { + topN = ((PhysicalTopN) topN.withChildren(child)).copyStatsAndGroupIdFrom(topN); + } + return topN; + } else if (topN.getSortPhase() == SortPhase.MERGE_SORT) { + return (PhysicalTopN) super.visit(topN, ctx); } return topN; } @@ -52,13 +83,14 @@ public PhysicalTopN visitPhysicalTopN(PhysicalTopN topN, CascadesContext context) { - Plan child = topN.child().accept(this, context); - if (child != topN.child()) { - topN = topN.withChildren(ImmutableList.of(child)).copyStatsAndGroupIdFrom(topN); - } - PhysicalTopN rewrittenTopN = rewriteTopN(topN.getPhysicalTopN()); - if (topN.getPhysicalTopN() != rewrittenTopN) { - topN = topN.withPhysicalTopN(rewrittenTopN).copyStatsAndGroupIdFrom(topN); + if (topN.getSortPhase() == SortPhase.LOCAL_SORT) { + PhysicalTopN rewrittenTopN = rewriteTopN(topN.getPhysicalTopN()); + if (topN.getPhysicalTopN() != rewrittenTopN) { + topN = topN.withPhysicalTopN(rewrittenTopN).copyStatsAndGroupIdFrom(topN); + } + return topN; + } else if (topN.getSortPhase() == SortPhase.MERGE_SORT) { + return super.visit(topN, context); } return topN; } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/postprocess/TopNRuntimeFilterTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/postprocess/TopNRuntimeFilterTest.java index f4fdf6f44f069b..0ac233898adff5 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/postprocess/TopNRuntimeFilterTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/postprocess/TopNRuntimeFilterTest.java @@ -20,8 +20,10 @@ import org.apache.doris.nereids.datasets.ssb.SSBTestBase; import org.apache.doris.nereids.processor.post.PlanPostProcessors; import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.SortPhase; import org.apache.doris.nereids.trees.plans.physical.PhysicalDeferMaterializeTopN; import org.apache.doris.nereids.trees.plans.physical.PhysicalPlan; +import org.apache.doris.nereids.trees.plans.physical.PhysicalTopN; import org.apache.doris.nereids.util.PlanChecker; import org.junit.jupiter.api.Assertions; @@ -41,7 +43,7 @@ public void testUseTopNRf() { .implement(); PhysicalPlan plan = checker.getPhysicalPlan(); plan = new PlanPostProcessors(checker.getCascadesContext()).process(plan); - Assertions.assertTrue(plan.children().get(0).child(0) instanceof PhysicalDeferMaterializeTopN); + Assertions.assertInstanceOf(PhysicalDeferMaterializeTopN.class, plan.children().get(0).child(0)); PhysicalDeferMaterializeTopN localTopN = (PhysicalDeferMaterializeTopN) plan.child(0).child(0); Assertions.assertTrue(localTopN.getPhysicalTopN().isEnableRuntimeFilter()); @@ -56,9 +58,25 @@ public void testNotUseTopNRf() { .implement(); PhysicalPlan plan = checker.getPhysicalPlan(); plan = new PlanPostProcessors(checker.getCascadesContext()).process(plan); - Assertions.assertTrue(plan.children().get(0).child(0) instanceof PhysicalDeferMaterializeTopN); + Assertions.assertInstanceOf(PhysicalDeferMaterializeTopN.class, plan.children().get(0).child(0)); PhysicalDeferMaterializeTopN localTopN = (PhysicalDeferMaterializeTopN) plan.child(0).child(0); Assertions.assertFalse(localTopN.getPhysicalTopN().isEnableRuntimeFilter()); } + + @Test + public void testNotUseTopNRfForComplexCase() { + String sql = "select * from (select 1) tl join (select * from customer order by c_custkey limit 5) tb"; + PlanChecker checker = PlanChecker.from(connectContext).analyze(sql) + .rewrite() + .implement(); + PhysicalPlan plan = checker.getPhysicalPlan(); + plan = new PlanPostProcessors(checker.getCascadesContext()).process(plan); + Assertions.assertInstanceOf(PhysicalTopN.class, plan.child(0).child(0).child(1).child(0)); + Assertions.assertEquals(SortPhase.LOCAL_SORT, ((PhysicalTopN) plan + .child(0).child(0).child(1).child(0)).getSortPhase()); + PhysicalTopN localTopN = (PhysicalTopN) plan + .child(0).child(0).child(1).child(0); + Assertions.assertFalse(localTopN.isEnableRuntimeFilter()); + } }