diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/TopNScanOpt.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/TopNScanOpt.java
index 2ca1380d1b050f..bc0586f40c405a 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/TopNScanOpt.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/TopNScanOpt.java
@@ -24,27 +24,58 @@
import org.apache.doris.nereids.trees.plans.algebra.Filter;
import org.apache.doris.nereids.trees.plans.algebra.OlapScan;
import org.apache.doris.nereids.trees.plans.algebra.Project;
+import org.apache.doris.nereids.trees.plans.algebra.TopN;
+import org.apache.doris.nereids.trees.plans.physical.AbstractPhysicalSort;
import org.apache.doris.nereids.trees.plans.physical.PhysicalDeferMaterializeTopN;
+import org.apache.doris.nereids.trees.plans.physical.PhysicalDistribute;
+import org.apache.doris.nereids.trees.plans.physical.PhysicalSink;
import org.apache.doris.nereids.trees.plans.physical.PhysicalTopN;
import org.apache.doris.qe.ConnectContext;
-import com.google.common.collect.ImmutableList;
-
/**
* topN opt
* refer to:
* ...
* ...
+ *
+ * // only support simple case: select ... from tbl [where ...] order by ... limit ...
*/
public class TopNScanOpt extends PlanPostProcessor {
+ @Override
+ public Plan visit(Plan plan, CascadesContext context) {
+ return plan;
+ }
+
+ @Override
+ public Plan visitPhysicalSink(PhysicalSink extends Plan> physicalSink, CascadesContext context) {
+ if (physicalSink.child() instanceof TopN) {
+ return super.visit(physicalSink, context);
+ }
+ return physicalSink;
+ }
+
+ @Override
+ public Plan visitPhysicalDistribute(PhysicalDistribute extends Plan> distribute, CascadesContext context) {
+ if (distribute.child() instanceof TopN && distribute.child() instanceof AbstractPhysicalSort
+ && ((AbstractPhysicalSort>) distribute.child()).getSortPhase() == SortPhase.LOCAL_SORT) {
+ return super.visit(distribute, context);
+ }
+ return distribute;
+ }
+
@Override
public PhysicalTopN extends Plan> visitPhysicalTopN(PhysicalTopN extends Plan> topN, CascadesContext ctx) {
- Plan child = topN.child().accept(this, ctx);
- topN = rewriteTopN(topN);
- if (child != topN.child()) {
- topN = ((PhysicalTopN) topN.withChildren(child)).copyStatsAndGroupIdFrom(topN);
+ if (topN.getSortPhase() == SortPhase.LOCAL_SORT) {
+ Plan child = topN.child();
+ topN = rewriteTopN(topN);
+ if (child != topN.child()) {
+ topN = ((PhysicalTopN extends Plan>) topN.withChildren(child)).copyStatsAndGroupIdFrom(topN);
+ }
+ return topN;
+ } else if (topN.getSortPhase() == SortPhase.MERGE_SORT) {
+ return (PhysicalTopN extends Plan>) super.visit(topN, ctx);
}
return topN;
}
@@ -52,13 +83,14 @@ public PhysicalTopN extends Plan> visitPhysicalTopN(PhysicalTopN extends Pla
@Override
public Plan visitPhysicalDeferMaterializeTopN(PhysicalDeferMaterializeTopN extends Plan> topN,
CascadesContext context) {
- Plan child = topN.child().accept(this, context);
- if (child != topN.child()) {
- topN = topN.withChildren(ImmutableList.of(child)).copyStatsAndGroupIdFrom(topN);
- }
- PhysicalTopN extends Plan> rewrittenTopN = rewriteTopN(topN.getPhysicalTopN());
- if (topN.getPhysicalTopN() != rewrittenTopN) {
- topN = topN.withPhysicalTopN(rewrittenTopN).copyStatsAndGroupIdFrom(topN);
+ if (topN.getSortPhase() == SortPhase.LOCAL_SORT) {
+ PhysicalTopN extends Plan> rewrittenTopN = rewriteTopN(topN.getPhysicalTopN());
+ if (topN.getPhysicalTopN() != rewrittenTopN) {
+ topN = topN.withPhysicalTopN(rewrittenTopN).copyStatsAndGroupIdFrom(topN);
+ }
+ return topN;
+ } else if (topN.getSortPhase() == SortPhase.MERGE_SORT) {
+ return super.visit(topN, context);
}
return topN;
}
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/postprocess/TopNRuntimeFilterTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/postprocess/TopNRuntimeFilterTest.java
index f4fdf6f44f069b..0ac233898adff5 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/nereids/postprocess/TopNRuntimeFilterTest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/postprocess/TopNRuntimeFilterTest.java
@@ -20,8 +20,10 @@
import org.apache.doris.nereids.datasets.ssb.SSBTestBase;
import org.apache.doris.nereids.processor.post.PlanPostProcessors;
import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.SortPhase;
import org.apache.doris.nereids.trees.plans.physical.PhysicalDeferMaterializeTopN;
import org.apache.doris.nereids.trees.plans.physical.PhysicalPlan;
+import org.apache.doris.nereids.trees.plans.physical.PhysicalTopN;
import org.apache.doris.nereids.util.PlanChecker;
import org.junit.jupiter.api.Assertions;
@@ -41,7 +43,7 @@ public void testUseTopNRf() {
.implement();
PhysicalPlan plan = checker.getPhysicalPlan();
plan = new PlanPostProcessors(checker.getCascadesContext()).process(plan);
- Assertions.assertTrue(plan.children().get(0).child(0) instanceof PhysicalDeferMaterializeTopN);
+ Assertions.assertInstanceOf(PhysicalDeferMaterializeTopN.class, plan.children().get(0).child(0));
PhysicalDeferMaterializeTopN extends Plan> localTopN
= (PhysicalDeferMaterializeTopN extends Plan>) plan.child(0).child(0);
Assertions.assertTrue(localTopN.getPhysicalTopN().isEnableRuntimeFilter());
@@ -56,9 +58,25 @@ public void testNotUseTopNRf() {
.implement();
PhysicalPlan plan = checker.getPhysicalPlan();
plan = new PlanPostProcessors(checker.getCascadesContext()).process(plan);
- Assertions.assertTrue(plan.children().get(0).child(0) instanceof PhysicalDeferMaterializeTopN);
+ Assertions.assertInstanceOf(PhysicalDeferMaterializeTopN.class, plan.children().get(0).child(0));
PhysicalDeferMaterializeTopN extends Plan> localTopN
= (PhysicalDeferMaterializeTopN extends Plan>) plan.child(0).child(0);
Assertions.assertFalse(localTopN.getPhysicalTopN().isEnableRuntimeFilter());
}
+
+ @Test
+ public void testNotUseTopNRfForComplexCase() {
+ String sql = "select * from (select 1) tl join (select * from customer order by c_custkey limit 5) tb";
+ PlanChecker checker = PlanChecker.from(connectContext).analyze(sql)
+ .rewrite()
+ .implement();
+ PhysicalPlan plan = checker.getPhysicalPlan();
+ plan = new PlanPostProcessors(checker.getCascadesContext()).process(plan);
+ Assertions.assertInstanceOf(PhysicalTopN.class, plan.child(0).child(0).child(1).child(0));
+ Assertions.assertEquals(SortPhase.LOCAL_SORT, ((PhysicalTopN extends Plan>) plan
+ .child(0).child(0).child(1).child(0)).getSortPhase());
+ PhysicalTopN extends Plan> localTopN = (PhysicalTopN extends Plan>) plan
+ .child(0).child(0).child(1).child(0);
+ Assertions.assertFalse(localTopN.isEnableRuntimeFilter());
+ }
}