From cf720e3c2d5f3c3bfed132724f5c5f57292c2c2d Mon Sep 17 00:00:00 2001 From: Gonzalo Ortiz Jaureguizar Date: Thu, 21 Nov 2024 17:34:02 +0100 Subject: [PATCH] [Spool] Introduce stage replacer and change send nodes to be able to send to more than one stage (#14495) --- .../logical/EquivalentStagesFinder.java | 16 +- .../logical/EquivalentStagesReplacer.java | 79 ++++++ .../query/planner/logical/GroupedStages.java | 8 + .../planner/plannode/MailboxReceiveNode.java | 11 +- .../planner/plannode/MailboxSendNode.java | 94 ++++++- .../planner/plannode/PlanNodeVisitor.java | 54 +++- .../logical/EquivalentStagesFinderTest.java | 47 +++- .../logical/EquivalentStagesReplacerTest.java | 146 ++++++++++ .../query/planner/logical/StagesTestBase.java | 265 +++++++++++++++++- 9 files changed, 671 insertions(+), 49 deletions(-) create mode 100644 pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/EquivalentStagesReplacer.java create mode 100644 pinot-query-planner/src/test/java/org/apache/pinot/query/planner/logical/EquivalentStagesReplacerTest.java diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/EquivalentStagesFinder.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/EquivalentStagesFinder.java index a5c98eb54c38..28bca306cd5c 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/EquivalentStagesFinder.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/EquivalentStagesFinder.java @@ -120,11 +120,15 @@ public boolean areEquivalent(MailboxSendNode stage, MailboxSendNode visitedStage return areBaseNodesEquivalent(stage, visitedStage) // Commented out fields are used in equals() method of MailboxSendNode but not needed for equivalence. // Receiver stage is not important for equivalence -// && node1.getReceiverStageId() == that.getReceiverStageId() +// && stage.getReceiverStageId() == visitedStage.getReceiverStageId() && stage.getExchangeType() == visitedStage.getExchangeType() - // Distribution type is not needed for equivalence. We deal with difference distribution types in the - // spooling logic. -// && Objects.equals(node1.getDistributionType(), that.getDistributionType()) + // TODO: Distribution type not needed for equivalence in the first substituted send nodes. Their different + // distribution can be implemented in synthetic stages. But it is important in recursive send nodes + // (a send node that is equivalent to another but where both of them send to stages that are also + // equivalent). + // This makes the equivalence check more complex and therefore we are going to consider the distribution + // type in the equivalence check. + && Objects.equals(stage.getDistributionType(), visitedStage.getDistributionType()) // TODO: Keys could probably be removed from the equivalence check, but would require to verify both // keys are present in the data schema. We are not doing that for now. && Objects.equals(stage.getKeys(), visitedStage.getKeys()) @@ -220,9 +224,7 @@ public Boolean visitMailboxReceive(MailboxReceiveNode node1, PlanNode node2) { // TODO: Keys should probably be removed from the equivalence check, but would require to verify both // keys are present in the data schema. We are not doing that for now. && Objects.equals(node1.getKeys(), that.getKeys()) - // Distribution type is not needed for equivalence. We deal with difference distribution types in the - // spooling logic. -// && node1.getDistributionType() == that.getDistributionType() + && node1.getDistributionType() == that.getDistributionType() // TODO: Sort, sort on sender and collations can probably be removed from the equivalence check, but would // require some extra checks or transformation on the spooling logic. We are not doing that for now. && node1.isSort() == that.isSort() diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/EquivalentStagesReplacer.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/EquivalentStagesReplacer.java new file mode 100644 index 000000000000..06a4cf16dac3 --- /dev/null +++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/EquivalentStagesReplacer.java @@ -0,0 +1,79 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pinot.query.planner.logical; + +import org.apache.pinot.query.planner.plannode.MailboxReceiveNode; +import org.apache.pinot.query.planner.plannode.MailboxSendNode; +import org.apache.pinot.query.planner.plannode.PlanNode; +import org.apache.pinot.query.planner.plannode.PlanNodeVisitor; + + +/** + * EquivalentStageReplacer is used to replace equivalent stages in the query plan. + * + * Given a {@link org.apache.pinot.query.planner.plannode.PlanNode} and a + * {@link GroupedStages}, modifies the plan node to replace equivalent stages. + * + * For each {@link MailboxReceiveNode} in the plan, if the sender is not the leader of the group, + * replaces the sender with the leader. + * The leader is also updated to include the receiver in its list of receivers. + */ +public class EquivalentStagesReplacer { + private EquivalentStagesReplacer() { + } + + /** + * Replaces the equivalent stages in the query plan. + * + * @param root Root plan node + * @param equivalentStages Equivalent stages + */ + public static void replaceEquivalentStages(PlanNode root, GroupedStages equivalentStages) { + root.visit(Replacer.INSTANCE, equivalentStages); + } + + private static class Replacer extends PlanNodeVisitor.DepthFirstVisitor { + private static final Replacer INSTANCE = new Replacer(); + + private Replacer() { + } + + @Override + public Void visitMailboxReceive(MailboxReceiveNode node, GroupedStages equivalenceGroups) { + MailboxSendNode sender = node.getSender(); + MailboxSendNode leader = equivalenceGroups.getGroup(sender).first(); + if (canSubstitute(sender, leader)) { + // we don't want to visit the children of the node given it is going to be pruned + node.setSender(leader); + leader.addReceiver(node); + } else { + visitMailboxSend(leader, equivalenceGroups); + } + return null; + } + + private boolean canSubstitute(MailboxSendNode actualSender, MailboxSendNode leader) { + return actualSender != leader // we don't need to replace the leader with itself + // the leader is already sending to this stage. Given we don't have the ability to send to multiple + // receivers in the same stage, we cannot optimize this case right now. + // If this case seems to be useful, it can be supported in the future. + && !leader.sharesReceiverStages(actualSender); + } + } +} diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/GroupedStages.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/GroupedStages.java index 45b5b561f9a3..823a9e9832eb 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/GroupedStages.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/GroupedStages.java @@ -22,6 +22,7 @@ import java.util.Comparator; import java.util.IdentityHashMap; import java.util.NoSuchElementException; +import java.util.Set; import java.util.SortedSet; import java.util.TreeSet; import java.util.stream.Collectors; @@ -77,6 +78,8 @@ public abstract SortedSet getGroup(MailboxSendNode stage) */ public abstract SortedSet> getGroups(); + public abstract Set getStages(); + @Override public String toString() { String content = getGroups().stream() @@ -154,6 +157,11 @@ public boolean containsStage(MailboxSendNode stage) { return _stageToGroup.containsKey(stage); } + @Override + public Set getStages() { + return _stageToGroup.keySet(); + } + @Override public SortedSet getGroup(MailboxSendNode stage) throws NoSuchElementException { diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/plannode/MailboxReceiveNode.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/plannode/MailboxReceiveNode.java index c918e9ea9116..407941e6b4c6 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/plannode/MailboxReceiveNode.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/plannode/MailboxReceiveNode.java @@ -29,7 +29,7 @@ public class MailboxReceiveNode extends BasePlanNode { - private final int _senderStageId; + private int _senderStageId; private final PinotRelExchangeType _exchangeType; private RelDistribution.Type _distributionType; private final List _keys; @@ -38,7 +38,7 @@ public class MailboxReceiveNode extends BasePlanNode { private final boolean _sortedOnSender; // NOTE: This is only available during query planning, and should not be serialized. - private final transient MailboxSendNode _sender; + private transient MailboxSendNode _sender; // NOTE: null List is converted to empty List because there is no way to differentiate them in proto during ser/de. public MailboxReceiveNode(int stageId, DataSchema dataSchema, int senderStageId, @@ -57,6 +57,8 @@ public MailboxReceiveNode(int stageId, DataSchema dataSchema, int senderStageId, } public int getSenderStageId() { + assert _sender == null || _sender.getStageId() == _senderStageId + : "_senderStageId should match _sender.getStageId()"; return _senderStageId; } @@ -93,6 +95,11 @@ public MailboxSendNode getSender() { return _sender; } + public void setSender(MailboxSendNode sender) { + _senderStageId = sender.getStageId(); + _sender = sender; + } + @Override public String explain() { return "MAIL_RECEIVE(" + _distributionType + ")"; diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/plannode/MailboxSendNode.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/plannode/MailboxSendNode.java index b4aa8677e22b..9cc2c2e65792 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/plannode/MailboxSendNode.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/plannode/MailboxSendNode.java @@ -18,6 +18,9 @@ */ package org.apache.pinot.query.planner.plannode; +import com.google.common.base.Preconditions; +import java.util.BitSet; +import java.util.Iterator; import java.util.List; import java.util.Objects; import javax.annotation.Nullable; @@ -28,7 +31,7 @@ public class MailboxSendNode extends BasePlanNode { - private final int _receiverStageId; + private final BitSet _receiverStages; private final PinotRelExchangeType _exchangeType; private RelDistribution.Type _distributionType; private final List _keys; @@ -37,11 +40,12 @@ public class MailboxSendNode extends BasePlanNode { private final boolean _sort; // NOTE: null List is converted to empty List because there is no way to differentiate them in proto during ser/de. - public MailboxSendNode(int stageId, DataSchema dataSchema, List inputs, int receiverStageId, - PinotRelExchangeType exchangeType, RelDistribution.Type distributionType, @Nullable List keys, - boolean prePartitioned, @Nullable List collations, boolean sort) { + private MailboxSendNode(int stageId, DataSchema dataSchema, List inputs, + BitSet receiverStages, PinotRelExchangeType exchangeType, + RelDistribution.Type distributionType, @Nullable List keys, boolean prePartitioned, + @Nullable List collations, boolean sort) { super(stageId, dataSchema, null, inputs); - _receiverStageId = receiverStageId; + _receiverStages = receiverStages; _exchangeType = exchangeType; _distributionType = distributionType; _keys = keys != null ? keys : List.of(); @@ -50,8 +54,74 @@ public MailboxSendNode(int stageId, DataSchema dataSchema, List inputs _sort = sort; } + public MailboxSendNode(int stageId, DataSchema dataSchema, List inputs, + int receiverStage, PinotRelExchangeType exchangeType, + RelDistribution.Type distributionType, @Nullable List keys, boolean prePartitioned, + @Nullable List collations, boolean sort) { + this(stageId, dataSchema, inputs, toBitSet(receiverStage), exchangeType, distributionType, keys, prePartitioned, + collations, sort); + } + + private static BitSet toBitSet(int receiverStage) { + BitSet bitSet = new BitSet(receiverStage + 1); + bitSet.set(receiverStage); + return bitSet; + } + + private static BitSet toBitSet(@Nullable List receiverStages) { + BitSet bitSet = new BitSet(); + if (receiverStages == null || receiverStages.isEmpty()) { + return bitSet; + } + for (int receiverStage : receiverStages) { + bitSet.set(receiverStage); + } + return bitSet; + } + + public MailboxSendNode(int stageId, DataSchema dataSchema, List inputs, + PinotRelExchangeType exchangeType, RelDistribution.Type distributionType, @Nullable List keys, + boolean prePartitioned, @Nullable List collations, boolean sort) { + this(stageId, dataSchema, inputs, new BitSet(), exchangeType, distributionType, keys, prePartitioned, collations, + sort); + } + + public boolean sharesReceiverStages(MailboxSendNode other) { + return _receiverStages.intersects(other._receiverStages); + } + + /** + * Returns the receiver stage ids, sorted in ascending order. + */ + public Iterable getReceiverStageIds() { + return () -> new Iterator<>() { + int _next = _receiverStages.nextSetBit(0); + + @Override + public boolean hasNext() { + return _next >= 0; + } + + @Override + public Integer next() { + int current = _next; + _next = _receiverStages.nextSetBit(_next + 1); + return current; + } + }; + } + + @Deprecated public int getReceiverStageId() { - return _receiverStageId; + Preconditions.checkState(!_receiverStages.isEmpty(), "Receivers not set"); + return _receiverStages.nextSetBit(0); + } + + public void addReceiver(MailboxReceiveNode node) { + if (_receiverStages.get(node.getStageId())) { + throw new IllegalStateException("Receiver already added: " + node.getStageId()); + } + _receiverStages.set(node.getStageId()); } public PinotRelExchangeType getExchangeType() { @@ -104,7 +174,7 @@ public T visit(PlanNodeVisitor visitor, C context) { @Override public PlanNode withInputs(List inputs) { - return new MailboxSendNode(_stageId, _dataSchema, inputs, _receiverStageId, _exchangeType, _distributionType, _keys, + return new MailboxSendNode(_stageId, _dataSchema, inputs, _receiverStages, _exchangeType, _distributionType, _keys, _prePartitioned, _collations, _sort); } @@ -120,14 +190,14 @@ public boolean equals(Object o) { return false; } MailboxSendNode that = (MailboxSendNode) o; - return _receiverStageId == that._receiverStageId && _prePartitioned == that._prePartitioned && _sort == that._sort - && _exchangeType == that._exchangeType && _distributionType == that._distributionType && Objects.equals(_keys, - that._keys) && Objects.equals(_collations, that._collations); + return Objects.equals(_receiverStages, that._receiverStages) && _prePartitioned == that._prePartitioned + && _sort == that._sort && _exchangeType == that._exchangeType && _distributionType == that._distributionType + && Objects.equals(_keys, that._keys) && Objects.equals(_collations, that._collations); } @Override public int hashCode() { - return Objects.hash(super.hashCode(), _receiverStageId, _exchangeType, _distributionType, _keys, _prePartitioned, + return Objects.hash(super.hashCode(), _receiverStages, _exchangeType, _distributionType, _keys, _prePartitioned, _collations, _sort); } @@ -135,7 +205,7 @@ public int hashCode() { public String toString() { return "MailboxSendNode{" + "_stageId=" + _stageId - + ", _receiverStageId=" + _receiverStageId + + ", _receivers=" + _receiverStages + '}'; } } diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/plannode/PlanNodeVisitor.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/plannode/PlanNodeVisitor.java index 0327d89e654f..49494f8df659 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/plannode/PlanNodeVisitor.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/plannode/PlanNodeVisitor.java @@ -69,7 +69,7 @@ public interface PlanNodeVisitor { * * The default implementation for each plan node type does nothing but visiting its inputs * (see {@link #visitChildren(PlanNode, Object)}) and then returning the result of calling - * {@link #defaultCase(PlanNode, Object)}. + * {@link #postChildren(PlanNode, Object)}. * * Subclasses can override each method to provide custom behavior for each plan node type. * For example: @@ -117,6 +117,17 @@ protected boolean traverseStageBoundary() { return true; } + /** + * The method that is called by default to handle a node that does not have a specific visit method. + * + * This method can be overridden to provide a default behavior for all nodes. + * + * The returned value of this method is ignored by default + */ + protected T preChildren(PlanNode node, C context) { + return null; + } + /** * The method that is called by default to handle a node that does not have a specific visit method. * @@ -124,89 +135,102 @@ protected boolean traverseStageBoundary() { * * The returned value of this method is what each default visit method will return. */ - protected T defaultCase(PlanNode node, C context) { + protected T postChildren(PlanNode node, C context) { return null; } @Override public T visitAggregate(AggregateNode node, C context) { + preChildren(node, context); visitChildren(node, context); - return defaultCase(node, context); + return postChildren(node, context); } @Override public T visitFilter(FilterNode node, C context) { + preChildren(node, context); visitChildren(node, context); - return defaultCase(node, context); + return postChildren(node, context); } @Override public T visitJoin(JoinNode node, C context) { + preChildren(node, context); visitChildren(node, context); - return defaultCase(node, context); + return postChildren(node, context); } @Override public T visitMailboxReceive(MailboxReceiveNode node, C context) { + preChildren(node, context); visitChildren(node, context); if (traverseStageBoundary()) { node.getSender().visit(this, context); } - return defaultCase(node, context); + return postChildren(node, context); } @Override public T visitMailboxSend(MailboxSendNode node, C context) { + preChildren(node, context); visitChildren(node, context); - return defaultCase(node, context); + return postChildren(node, context); } @Override public T visitProject(ProjectNode node, C context) { + preChildren(node, context); visitChildren(node, context); - return defaultCase(node, context); + return postChildren(node, context); } @Override public T visitSort(SortNode node, C context) { + preChildren(node, context); visitChildren(node, context); - return defaultCase(node, context); + return postChildren(node, context); } @Override public T visitTableScan(TableScanNode node, C context) { + preChildren(node, context); visitChildren(node, context); - return defaultCase(node, context); + return postChildren(node, context); } @Override public T visitValue(ValueNode node, C context) { + preChildren(node, context); visitChildren(node, context); - return defaultCase(node, context); + return postChildren(node, context); } @Override public T visitWindow(WindowNode node, C context) { + preChildren(node, context); visitChildren(node, context); - return defaultCase(node, context); + return postChildren(node, context); } @Override public T visitSetOp(SetOpNode node, C context) { + preChildren(node, context); visitChildren(node, context); - return defaultCase(node, context); + return postChildren(node, context); } @Override public T visitExchange(ExchangeNode node, C context) { + preChildren(node, context); visitChildren(node, context); - return defaultCase(node, context); + return postChildren(node, context); } @Override public T visitExplained(ExplainedNode node, C context) { + preChildren(node, context); visitChildren(node, context); - return defaultCase(node, context); + return postChildren(node, context); } } } diff --git a/pinot-query-planner/src/test/java/org/apache/pinot/query/planner/logical/EquivalentStagesFinderTest.java b/pinot-query-planner/src/test/java/org/apache/pinot/query/planner/logical/EquivalentStagesFinderTest.java index 54f101059bc5..1e7f71f39606 100644 --- a/pinot-query-planner/src/test/java/org/apache/pinot/query/planner/logical/EquivalentStagesFinderTest.java +++ b/pinot-query-planner/src/test/java/org/apache/pinot/query/planner/logical/EquivalentStagesFinderTest.java @@ -19,6 +19,7 @@ package org.apache.pinot.query.planner.logical; import java.util.Map; +import org.apache.calcite.rel.RelDistribution; import org.apache.pinot.common.utils.DataSchema; import org.apache.pinot.query.planner.plannode.MailboxSendNode; import org.testng.annotations.Test; @@ -68,6 +69,34 @@ public void sharedJoin() { assertEquals(result.toString(), "[[0], [1, 2]]"); } + @Test + void sameDistributionKeepEquivalence() { + when( + join( + exchange(1, tableScan("T1")) + .withDistributionType(RelDistribution.Type.RANDOM_DISTRIBUTED), + exchange(2, tableScan("T1")) + .withDistributionType(RelDistribution.Type.RANDOM_DISTRIBUTED) + ) + ); + GroupedStages groupedStages = EquivalentStagesFinder.findEquivalentStages(stage(0)); + assertEquals(groupedStages.toString(), "[[0], [1, 2]]"); + } + + @Test + void differentDistributionBreakEquivalence() { + when( + join( + exchange(1, tableScan("T1")) + .withDistributionType(RelDistribution.Type.RANDOM_DISTRIBUTED), + exchange(2, tableScan("T1")) + .withDistributionType(RelDistribution.Type.BROADCAST_DISTRIBUTED) + ) + ); + GroupedStages groupedStages = EquivalentStagesFinder.findEquivalentStages(stage(0)); + assertEquals(groupedStages.toString(), "[[0], [1], [2]]"); + } + @Test public void sameHintsDontBreakEquivalence() { when( @@ -89,7 +118,7 @@ public void sameHintsDontBreakEquivalence() { } @Test - public void differentHintsImplyNotEquivalent() { + public void differentHintsBreakEquivalence() { when( join( exchange( @@ -109,7 +138,7 @@ public void differentHintsImplyNotEquivalent() { } @Test - public void differentHintsOneNullImplyNotEquivalent() { + public void differentHintsOneNullBreakEquivalence() { when( join( exchange(1, tableScan("T1")), @@ -199,4 +228,18 @@ public void deepSharedDifferentTables() { GroupedStages result = EquivalentStagesFinder.findEquivalentStages(stage(0)); assertEquals(result.toString(), "[[0], [1, 2], [3, 5], [4, 6]]"); } + + @Test + void notUniqueReceiversInStage() { + when(// stage 0 + exchange(1, + join( + exchange(2, tableScan("T1")), + exchange(3, tableScan("T1")) + ) + ) + ); + GroupedStages groupedStages = EquivalentStagesFinder.findEquivalentStages(stage(0)); + assertEquals(groupedStages.toString(), "[[0], [1], [2, 3]]"); + } } diff --git a/pinot-query-planner/src/test/java/org/apache/pinot/query/planner/logical/EquivalentStagesReplacerTest.java b/pinot-query-planner/src/test/java/org/apache/pinot/query/planner/logical/EquivalentStagesReplacerTest.java new file mode 100644 index 000000000000..830c8a2d78f3 --- /dev/null +++ b/pinot-query-planner/src/test/java/org/apache/pinot/query/planner/logical/EquivalentStagesReplacerTest.java @@ -0,0 +1,146 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pinot.query.planner.logical; + +import org.apache.pinot.query.planner.plannode.MailboxSendNode; +import org.testng.annotations.Test; + +import static org.testng.Assert.*; + + +public class EquivalentStagesReplacerTest extends StagesTestBase { + + @Test + public void test() { + when(// stage 0 + exchange(1, + join( + exchange(2, + join( + exchange(3, tableScan("T1")), + exchange(4, tableScan("T2")) + ) + ), + exchange(5, + join( + exchange(6, tableScan("T1")), + exchange(7, tableScan("T3")) + ) + ) + ) + ) + ); + + GroupedStages groupedStages = EquivalentStagesFinder.findEquivalentStages(stage(0)); + assertEquals(groupedStages.toString(), "[[0], [1], [2], [3, 6], [4], [5], [7]]"); + + MailboxSendNode rootStage = stage(0); + EquivalentStagesReplacer.replaceEquivalentStages(rootStage, groupedStages); + + cleanup(); + SpoolBuilder readT1 = new SpoolBuilder(3, tableScan("T1")); + MailboxSendNode expected = when(// stage 0 + exchange(1, + join( + exchange(2, + join( + readT1.newReceiver(), + exchange(4, tableScan("T2")) + ) + ), + exchange(5, + join( + readT1.newReceiver(), + exchange(7, tableScan("T3")) + ) + ) + ) + ) + ); + + assertEqualPlan(rootStage, expected); + } + + @Test + void notUniqueReceiversInStage() { + when(// stage 0 + exchange(1, + join( + exchange(2, tableScan("T1")), + exchange(3, tableScan("T1")) + ) + ) + ); + GroupedStages groupedStages = EquivalentStagesFinder.findEquivalentStages(stage(0)); + assertEquals(groupedStages.toString(), "[[0], [1], [2, 3]]"); + + MailboxSendNode rootStage = stage(0); + EquivalentStagesReplacer.replaceEquivalentStages(rootStage, groupedStages); + + cleanup(); + MailboxSendNode expected = when(// stage 0 + exchange(1, + join( + exchange(2, tableScan("T1")), + exchange(3, tableScan("T1")) + ) + ) + ); + assertEqualPlan(rootStage, expected); + } + + @Test + void groupSendingToTheSameStage() { + when(// stage 0 + exchange(1, + join( + exchange(2, tableScan("T1")), + exchange(3, + join( + exchange(4, tableScan("T1")), + exchange(5, tableScan("T1")) + ) + ) + ) + ) + ); + GroupedStages groupedStages = EquivalentStagesFinder.findEquivalentStages(stage(0)); + assertEquals(groupedStages.toString(), "[[0], [1], [2, 4, 5], [3]]"); + + MailboxSendNode rootStage = stage(0); + EquivalentStagesReplacer.replaceEquivalentStages(rootStage, groupedStages); + + cleanup(); + SpoolBuilder readT1 = new SpoolBuilder(2, tableScan("T1")); + MailboxSendNode expected = when(// stage 0 + exchange(1, + join( + readT1.newReceiver(), + exchange(3, + join( + readT1.newReceiver(), + exchange(5, tableScan("T1")) + ) + ) + ) + ) + ); + assertEqualPlan(rootStage, expected); + } +} diff --git a/pinot-query-planner/src/test/java/org/apache/pinot/query/planner/logical/StagesTestBase.java b/pinot-query-planner/src/test/java/org/apache/pinot/query/planner/logical/StagesTestBase.java index 93fc109583f0..3735a829766e 100644 --- a/pinot-query-planner/src/test/java/org/apache/pinot/query/planner/logical/StagesTestBase.java +++ b/pinot-query-planner/src/test/java/org/apache/pinot/query/planner/logical/StagesTestBase.java @@ -18,20 +18,29 @@ */ package org.apache.pinot.query.planner.logical; +import com.google.common.base.Function; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableMap; import java.util.Collections; import java.util.HashMap; +import java.util.IdentityHashMap; import java.util.List; import java.util.Map; +import java.util.Objects; +import java.util.Set; import javax.annotation.Nullable; +import org.apache.calcite.rel.RelDistribution; +import org.apache.calcite.rel.RelFieldCollation; import org.apache.calcite.rel.core.JoinRelType; +import org.apache.pinot.calcite.rel.logical.PinotRelExchangeType; import org.apache.pinot.common.utils.DataSchema; import org.apache.pinot.query.planner.plannode.JoinNode; import org.apache.pinot.query.planner.plannode.MailboxReceiveNode; import org.apache.pinot.query.planner.plannode.MailboxSendNode; import org.apache.pinot.query.planner.plannode.PlanNode; +import org.apache.pinot.query.planner.plannode.PlanNodeVisitor; import org.apache.pinot.query.planner.plannode.TableScanNode; +import org.testng.Assert; import org.testng.annotations.AfterMethod; @@ -116,19 +125,41 @@ public SimpleChildBuilder join( * Although there are builder methods to create send and receive mailboxes separately, this method is recommended * because it deals with the stageId management and creates tests that are easier to read. */ - public SimpleChildBuilder exchange( + public ExchangeBuilder exchange( int nextStageId, SimpleChildBuilder childBuilder) { - return (stageId, mySchema, myHints) -> { - PlanNode input = childBuilder.build(nextStageId); - MailboxSendNode mailboxSendNode = new MailboxSendNode(nextStageId, null, List.of(input), stageId, null, null, - null, false, null, false); - MailboxSendNode old = _stageRoots.put(nextStageId, mailboxSendNode); - Preconditions.checkState(old == null, "Mailbox already exists for stageId: %s", nextStageId); - return new MailboxReceiveNode(stageId, null, nextStageId, null, null, null, null, - false, false, mailboxSendNode); + return new ExchangeBuilder() { + @Override + public MailboxReceiveNode build(int stageId, DataSchema dataSchema, PlanNode.NodeHint hints, + PinotRelExchangeType exchangeType, RelDistribution.Type distribution, List keys, + boolean prePartitioned, List collations, boolean sort, boolean sortedOnSender) { + PlanNode input = childBuilder.build(nextStageId); + MailboxSendNode mailboxSendNode = new MailboxSendNode(nextStageId, input.getDataSchema(), List.of(input), + stageId, exchangeType, distribution, keys, prePartitioned, collations, sort); + MailboxSendNode old = _stageRoots.put(nextStageId, mailboxSendNode); + Preconditions.checkState(old == null, "Mailbox already exists for stageId: %s", nextStageId); + return new MailboxReceiveNode(stageId, input.getDataSchema(), nextStageId, exchangeType, distribution, keys, + collations, sort, sortedOnSender, mailboxSendNode); + } }; } + public interface ExchangeBuilder extends SimpleChildBuilder { + MailboxReceiveNode build(int stageId, DataSchema dataSchema, PlanNode.NodeHint hints, + PinotRelExchangeType exchangeType, RelDistribution.Type distribution, List keys, + boolean prePartitioned, List collations, boolean sort, boolean sortedOnSender); + + default MailboxReceiveNode build(int stageId, DataSchema dataSchema, PlanNode.NodeHint hints) { + return build(stageId, null, null, null, null, null, false, null, false, false); + } + + default ExchangeBuilder withDistributionType(RelDistribution.Type distribution) { + return (stageId, dataSchema, hints, exchangeType, distribution1, keys, prePartitioned, collations, sort, + sortedOnSender) -> + build(stageId, dataSchema, hints, exchangeType, distribution, keys, prePartitioned, collations, sort, + sortedOnSender); + } + } + /** * Creates a table scan node with the given table name. */ @@ -159,8 +190,8 @@ public SimpleChildBuilder sendMailbox( int newStageId, SimpleChildBuilder childBuilder) { return (stageId, mySchema, myHints) -> { PlanNode input = childBuilder.build(stageId); - MailboxSendNode mailboxSendNode = new MailboxSendNode(newStageId, mySchema, List.of(input), stageId, null, null, - null, false, null, false); + MailboxSendNode mailboxSendNode = new MailboxSendNode(newStageId, mySchema, List.of(input), stageId, null, + null, null, false, null, false); MailboxSendNode old = _stageRoots.put(stageId, mailboxSendNode); Preconditions.checkState(old == null, "Mailbox already exists for stageId: %s", stageId); return mailboxSendNode; @@ -229,4 +260,216 @@ default P build(int stageId) { return build(stageId, null, null); } } + + /** + * A helper class that can be used to create a spool in the context of a test. + *

+ * These spools are used to create a single sender that will send data to multiple receivers. + * This class is just a helper to make it easier to create the sender and the receivers in a single fluent way during + * a test. A spool breaks by definition the idea that plan nodes are tree-like. Instead once spools are used, the + * plan nodes are a directed graph that should not have cycles. The latter is not enforced by this class but a + * responsibility of the test writer. + *

+ * Graphs are more complex to write in a nice readable way and require some mutation on the nodes that are created. + * In order to help, this class has two states: the initial state and the sealed state. When a new spool is created, + * it is in the initial state and can {@link #newReceiver()} can be called multiple times to create multiple + * receivers. Once one of these receivers is built, the spool is sealed and no more receivers can be created. + *

+ * Usually this class should be used in the following manner: + *

+ *

+   *   Spool readT1 = new Spool(3, tableScan("T1")); // here the spool is created
+   *   ExchangeBuilder builder = exchange(1,
+   *     join(
+   *       readT1.newReceiver(), // here a new receiver is created
+   *       readT1.newReceiver() // another receiver is created
+   *     )
+   *   );
+   *   // here the builder is called, which recursively calls the build method on the receivers, which seals the spool
+   *   when(builder);
+   * 
+ *

+ * + * Notice that usually the builder is not stored as a variable but directly used as argument to when. For example, + * {@code when(exchange(1, ...));}. This is completely fine and recommended. The snippet above splits the creation of + * the builder from the call to when to make it easier to understand the flow of the test. + *

+ * This means that if more than one spool is needed in a test, the test writer should create multiple instances of + * this class. + */ + public static class SpoolBuilder { + private final int _senderStageId; + /** + * The set of receiver builders. A new element is added every time {@link #newReceiver()} is called. + * When the first builder is built, {@link #seal()} is called, which creates the sender node. + */ + private final Set _receiverBuilder = Collections.newSetFromMap(new IdentityHashMap<>()); + private MailboxSendNode _sender; + private final SimpleChildBuilder _childBuilder; + + /** + * Creates a new spool with the given sender stage id and child builder. + * + * The child builder will be used to create the child node that will generate the data that will be sent to the + * multiple receivers. + */ + public SpoolBuilder(int senderStageId, SimpleChildBuilder spoolChildBuilder) { + _senderStageId = senderStageId; + _childBuilder = spoolChildBuilder; + } + + /** + * Returns the sender node for this spool. + * + * This method can only be called after the spool is sealed, otherwise the sender won't be available and this method + * will fail with an exception. + */ + public MailboxSendNode getSender() { + Preconditions.checkState(isSealed(), "Spool not sealed"); + return _sender; + } + + /** + * Returns whether the spool is sealed or not. + */ + public boolean isSealed() { + return _sender != null; + } + + /** + * Creates a new receiver builder that can be used to create a new receiver for this spool. + * + * This method is similar to other builder methods (like {@link #tableScan(String)} or + * {@link #join(SimpleChildBuilder, SimpleChildBuilder)}) and can be called multiple times to create multiple + * receivers. + * + * In most scenarios, the overloaded method {@link #newReceiver()} is good enough. This method is useful when the + * test writer wants to customize the receiver in some way (for example, changing the data schema or hints). + * The customize function will be called with a base builder that creates the receiver with the same data schema + * as the server and no hints. + */ + public SimpleChildBuilder newReceiver( + Function, SimpleChildBuilder> customize) { + Preconditions.checkState(!isSealed(), "Spool already sealed"); + + SpoolReceiverBuilder spoolReceiverBuilder = new SpoolReceiverBuilder(customize); + + _receiverBuilder.add(spoolReceiverBuilder); + return spoolReceiverBuilder; + } + + + /** + * Creates a new receiver builder that can be used to create a new receiver for this spool. + * + * This method is similar to other builder methods (like {@link #tableScan(String)} or + * {@link #join(SimpleChildBuilder, SimpleChildBuilder)}) and can be called multiple times to create multiple + * receivers. + * + * This method creates a receiver with the same data schema as the sender and no hints. In case the test writer + * wants to customize the receiver, the method {@link #newReceiver(Function)} should be used. + */ + public SimpleChildBuilder newReceiver() { + return newReceiver(a -> a); + } + + private void seal() { + if (isSealed()) { // for simplicity the seal method may be called multiple times + return; + } + + PlanNode input = _childBuilder.build(_senderStageId); + DataSchema mySchema = input.getDataSchema(); + _sender = new MailboxSendNode(_senderStageId, mySchema, List.of(input), null, + null, null, false, null, false); + } + + /** + * This is the internal class returned as a result of the {@link #newReceiver(Function)} method. + * + * They don't just create the receiver, but also end up sealing the spool and modify the sender to add the receiver + * to the list of receivers. + */ + private class SpoolReceiverBuilder implements SimpleChildBuilder { + @Nullable + private MailboxReceiveNode _receiver; + private final Function, SimpleChildBuilder> _customize; + + public SpoolReceiverBuilder( + Function, SimpleChildBuilder> customize) { + _customize = customize; + } + + @Override + public MailboxReceiveNode build(int stageId, @Nullable DataSchema dataSchema, @Nullable PlanNode.NodeHint hints) { + Preconditions.checkState(dataSchema == null, "Data schema for spool must be set internally"); + Preconditions.checkState(hints == null, "Hints for spool must be set internally"); + if (_receiver == null) { + seal(); + SimpleChildBuilder baseBuilder = (currentStageId, ignoreSchema, ignoreHints) -> { + DataSchema mySchema = _sender.getDataSchema(); + return new MailboxReceiveNode(currentStageId, mySchema, _senderStageId, null, null, null, null, false, + false, _sender); + }; + SimpleChildBuilder receiveBuilder = _customize.apply(baseBuilder); + _receiver = receiveBuilder.build(stageId); + _sender.addReceiver(_receiver); + } + Preconditions.checkState(_receiver.getStageId() == stageId, "Receiver stageId mismatch. " + + "Expected %s, received %s", _receiver.getStageId(), stageId); + assert _receiver != null; + return _receiver; + } + } + } + + public void assertEqualPlan(PlanNode actual, PlanNode expected) { + if (expected == null || actual == null) { + if (expected == null && actual == null) { + return; + } + throw new AssertionError("Expected: \n" + expected + ", actual: \n" + actual); + } + if (Objects.equals(expected, actual)) { + return; + } + Assert.fail("Expected: \n" + explainNode(expected) + ", actual: \n" + explainNode(actual)); + } + + private String explainNode(PlanNode node) { + StringBuilder sb = new StringBuilder(); + NodePrinter nodePrinter = new NodePrinter(sb); + node.visit(nodePrinter, null); + return sb.toString(); + } + + private static class NodePrinter extends PlanNodeVisitor.DepthFirstVisitor { + private final StringBuilder _builder; + private int _indent; + + public NodePrinter(StringBuilder builder) { + _builder = builder; + } + + @Override + protected Void preChildren(PlanNode node, Void context) { + int stageId = node.getStageId(); + for (int i = 0; i < _indent; i++) { + _builder.append(" "); + } + _builder.append('[') + .append(stageId) + .append("]: ") + .append(node.explain()) + .append('\n'); + _indent++; + return null; + } + + @Override + protected Void postChildren(PlanNode node, Void context) { + _indent--; + return super.postChildren(node, context); + } + } }