Skip to content

Commit

Permalink
move 3 topn rules from rbo to cbo
Browse files Browse the repository at this point in the history
  • Loading branch information
englefly committed Jan 13, 2025
1 parent 23cddd0 commit 41d9205
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,10 @@
import org.apache.doris.nereids.rules.rewrite.PushDownFilterThroughSort;
import org.apache.doris.nereids.rules.rewrite.PushDownFilterThroughWindow;
import org.apache.doris.nereids.rules.rewrite.PushDownJoinOtherCondition;
import org.apache.doris.nereids.rules.rewrite.PushDownLimitDistinctThroughJoin;
import org.apache.doris.nereids.rules.rewrite.PushDownProjectThroughLimit;
import org.apache.doris.nereids.rules.rewrite.PushDownTopNDistinctThroughJoin;
import org.apache.doris.nereids.rules.rewrite.PushDownTopNThroughJoin;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableList.Builder;
Expand All @@ -132,6 +135,9 @@ public class RuleSet {
.add(PushDownProjectThroughInnerOuterJoin.INSTANCE)
.add(PushDownProjectThroughSemiJoin.INSTANCE)
.add(TransposeAggSemiJoinProject.INSTANCE)
.addAll(new PushDownTopNThroughJoin().buildRules())
.addAll(new PushDownLimitDistinctThroughJoin().buildRules())
.addAll(new PushDownTopNDistinctThroughJoin().buildRules())
.build();

public static final List<RuleFactory> PUSH_DOWN_FILTERS = ImmutableList.of(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalLimit;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.qe.ConnectContext;

import com.google.common.collect.ImmutableList;

Expand All @@ -42,6 +43,10 @@ public List<Rule> buildRules() {
// limit -> distinct -> join
logicalLimit(logicalAggregate(logicalJoin())
.when(LogicalAggregate::isDistinct))
.when(limit ->
ConnectContext.get() != null
&& ConnectContext.get().getSessionVariable().topnOptLimitThreshold
>= limit.getLimit() + limit.getOffset())
.then(limit -> {
LogicalAggregate<LogicalJoin<Plan, Plan>> agg = limit.child();
LogicalJoin<Plan, Plan> join = agg.child();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalTopN;
import org.apache.doris.nereids.util.PlanUtils;
import org.apache.doris.qe.ConnectContext;

import com.google.common.collect.ImmutableList;

Expand All @@ -45,6 +46,10 @@ public List<Rule> buildRules() {
// topN -> join
logicalTopN(logicalAggregate(logicalJoin()).when(LogicalAggregate::isDistinct))
// TODO: complex order by
.when(topn ->
ConnectContext.get() != null
&& ConnectContext.get().getSessionVariable().topnOptLimitThreshold
>= topn.getLimit() + topn.getOffset())
.when(topN -> topN.getOrderKeys().stream().map(OrderKey::getExpr)
.allMatch(Slot.class::isInstance))
.then(topN -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalTopN;
import org.apache.doris.qe.ConnectContext;

import com.google.common.collect.ImmutableList;

Expand All @@ -43,6 +44,10 @@ public List<Rule> buildRules() {
// topN -> join
logicalTopN(logicalJoin())
// TODO: complex orderby
.when(topn ->
ConnectContext.get() != null
&& ConnectContext.get().getSessionVariable().topnOptLimitThreshold
>= topn.getLimit() + topn.getOffset())
.when(topN -> topN.getOrderKeys().stream().map(OrderKey::getExpr)
.allMatch(Slot.class::isInstance))
.then(topN -> {
Expand Down Expand Up @@ -102,7 +107,6 @@ private Plan pushLimitThroughJoin(LogicalTopN<? extends Plan> topN, LogicalJoin<
}
return null;
case CROSS_JOIN:

if (join.left().getOutputSet().containsAll(orderbySlots)) {
return join.withChildren(
topN.withLimitChild(topN.getLimit() + topN.getOffset(), 0, join.left()),
Expand Down

0 comments on commit 41d9205

Please sign in to comment.