Skip to content

Commit

Permalink
[fix](Nereids) topn runtime filter only support simplest case (#29312) (
Browse files Browse the repository at this point in the history
  • Loading branch information
morrySnow authored Jan 3, 2024
1 parent ae9f70a commit 8bc3c36
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,41 +24,73 @@
import org.apache.doris.nereids.trees.plans.algebra.Filter;
import org.apache.doris.nereids.trees.plans.algebra.OlapScan;
import org.apache.doris.nereids.trees.plans.algebra.Project;
import org.apache.doris.nereids.trees.plans.algebra.TopN;
import org.apache.doris.nereids.trees.plans.physical.AbstractPhysicalSort;
import org.apache.doris.nereids.trees.plans.physical.PhysicalDeferMaterializeTopN;
import org.apache.doris.nereids.trees.plans.physical.PhysicalDistribute;
import org.apache.doris.nereids.trees.plans.physical.PhysicalSink;
import org.apache.doris.nereids.trees.plans.physical.PhysicalTopN;
import org.apache.doris.qe.ConnectContext;

import com.google.common.collect.ImmutableList;

/**
* topN opt
* refer to:
* <a href="https://github.com/apache/doris/pull/15558">...</a>
* <a href="https://github.com/apache/doris/pull/15663">...</a>
*
* // only support simple case: select ... from tbl [where ...] order by ... limit ...
*/

public class TopNScanOpt extends PlanPostProcessor {

@Override
public Plan visit(Plan plan, CascadesContext context) {
return plan;
}

@Override
public Plan visitPhysicalSink(PhysicalSink<? extends Plan> physicalSink, CascadesContext context) {
if (physicalSink.child() instanceof TopN) {
return super.visit(physicalSink, context);
}
return physicalSink;
}

@Override
public Plan visitPhysicalDistribute(PhysicalDistribute<? extends Plan> distribute, CascadesContext context) {
if (distribute.child() instanceof TopN && distribute.child() instanceof AbstractPhysicalSort
&& ((AbstractPhysicalSort<?>) distribute.child()).getSortPhase() == SortPhase.LOCAL_SORT) {
return super.visit(distribute, context);
}
return distribute;
}

@Override
public PhysicalTopN<? extends Plan> visitPhysicalTopN(PhysicalTopN<? extends Plan> topN, CascadesContext ctx) {
Plan child = topN.child().accept(this, ctx);
topN = rewriteTopN(topN);
if (child != topN.child()) {
topN = ((PhysicalTopN) topN.withChildren(child)).copyStatsAndGroupIdFrom(topN);
if (topN.getSortPhase() == SortPhase.LOCAL_SORT) {
Plan child = topN.child();
topN = rewriteTopN(topN);
if (child != topN.child()) {
topN = ((PhysicalTopN<? extends Plan>) topN.withChildren(child)).copyStatsAndGroupIdFrom(topN);
}
return topN;
} else if (topN.getSortPhase() == SortPhase.MERGE_SORT) {
return (PhysicalTopN<? extends Plan>) super.visit(topN, ctx);
}
return topN;
}

@Override
public Plan visitPhysicalDeferMaterializeTopN(PhysicalDeferMaterializeTopN<? extends Plan> topN,
CascadesContext context) {
Plan child = topN.child().accept(this, context);
if (child != topN.child()) {
topN = topN.withChildren(ImmutableList.of(child)).copyStatsAndGroupIdFrom(topN);
}
PhysicalTopN<? extends Plan> rewrittenTopN = rewriteTopN(topN.getPhysicalTopN());
if (topN.getPhysicalTopN() != rewrittenTopN) {
topN = topN.withPhysicalTopN(rewrittenTopN).copyStatsAndGroupIdFrom(topN);
if (topN.getSortPhase() == SortPhase.LOCAL_SORT) {
PhysicalTopN<? extends Plan> rewrittenTopN = rewriteTopN(topN.getPhysicalTopN());
if (topN.getPhysicalTopN() != rewrittenTopN) {
topN = topN.withPhysicalTopN(rewrittenTopN).copyStatsAndGroupIdFrom(topN);
}
return topN;
} else if (topN.getSortPhase() == SortPhase.MERGE_SORT) {
return super.visit(topN, context);
}
return topN;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@
import org.apache.doris.nereids.datasets.ssb.SSBTestBase;
import org.apache.doris.nereids.processor.post.PlanPostProcessors;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.SortPhase;
import org.apache.doris.nereids.trees.plans.physical.PhysicalDeferMaterializeTopN;
import org.apache.doris.nereids.trees.plans.physical.PhysicalPlan;
import org.apache.doris.nereids.trees.plans.physical.PhysicalTopN;
import org.apache.doris.nereids.util.PlanChecker;

import org.junit.jupiter.api.Assertions;
Expand All @@ -41,7 +43,7 @@ public void testUseTopNRf() {
.implement();
PhysicalPlan plan = checker.getPhysicalPlan();
plan = new PlanPostProcessors(checker.getCascadesContext()).process(plan);
Assertions.assertTrue(plan.children().get(0).child(0) instanceof PhysicalDeferMaterializeTopN);
Assertions.assertInstanceOf(PhysicalDeferMaterializeTopN.class, plan.children().get(0).child(0));
PhysicalDeferMaterializeTopN<? extends Plan> localTopN
= (PhysicalDeferMaterializeTopN<? extends Plan>) plan.child(0).child(0);
Assertions.assertTrue(localTopN.getPhysicalTopN().isEnableRuntimeFilter());
Expand All @@ -56,9 +58,25 @@ public void testNotUseTopNRf() {
.implement();
PhysicalPlan plan = checker.getPhysicalPlan();
plan = new PlanPostProcessors(checker.getCascadesContext()).process(plan);
Assertions.assertTrue(plan.children().get(0).child(0) instanceof PhysicalDeferMaterializeTopN);
Assertions.assertInstanceOf(PhysicalDeferMaterializeTopN.class, plan.children().get(0).child(0));
PhysicalDeferMaterializeTopN<? extends Plan> localTopN
= (PhysicalDeferMaterializeTopN<? extends Plan>) plan.child(0).child(0);
Assertions.assertFalse(localTopN.getPhysicalTopN().isEnableRuntimeFilter());
}

@Test
public void testNotUseTopNRfForComplexCase() {
String sql = "select * from (select 1) tl join (select * from customer order by c_custkey limit 5) tb";
PlanChecker checker = PlanChecker.from(connectContext).analyze(sql)
.rewrite()
.implement();
PhysicalPlan plan = checker.getPhysicalPlan();
plan = new PlanPostProcessors(checker.getCascadesContext()).process(plan);
Assertions.assertInstanceOf(PhysicalTopN.class, plan.child(0).child(0).child(1).child(0));
Assertions.assertEquals(SortPhase.LOCAL_SORT, ((PhysicalTopN<? extends Plan>) plan
.child(0).child(0).child(1).child(0)).getSortPhase());
PhysicalTopN<? extends Plan> localTopN = (PhysicalTopN<? extends Plan>) plan
.child(0).child(0).child(1).child(0);
Assertions.assertFalse(localTopN.isEnableRuntimeFilter());
}
}

0 comments on commit 8bc3c36

Please sign in to comment.