Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Spool] Introduce stage replacer and change send nodes to be able to send to more than one stage #14495

Merged
merged 12 commits into from
Nov 21, 2024
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -120,11 +120,15 @@ public boolean areEquivalent(MailboxSendNode stage, MailboxSendNode visitedStage
return areBaseNodesEquivalent(stage, visitedStage)
// Commented out fields are used in equals() method of MailboxSendNode but not needed for equivalence.
// Receiver stage is not important for equivalence
// && node1.getReceiverStageId() == that.getReceiverStageId()
// && stage.getReceiverStageId() == visitedStage.getReceiverStageId()
&& stage.getExchangeType() == visitedStage.getExchangeType()
// Distribution type is not needed for equivalence. We deal with difference distribution types in the
// spooling logic.
// && Objects.equals(node1.getDistributionType(), that.getDistributionType())
// TODO: Distribution type not needed for equivalence in the first substituted send nodes. Their different
// distribution can be implemented in synthetic stages. But it is important in recursive send nodes
// (a send node that is equivalent to another but where both of them send to stages that are also
// equivalent).
// This makes the equivalence check more complex and therefore we are going to consider the distribution
// type in the equivalence check.
&& Objects.equals(stage.getDistributionType(), visitedStage.getDistributionType())
// TODO: Keys could probably be removed from the equivalence check, but would require to verify both
// keys are present in the data schema. We are not doing that for now.
&& Objects.equals(stage.getKeys(), visitedStage.getKeys())
Expand Down Expand Up @@ -220,9 +224,7 @@ public Boolean visitMailboxReceive(MailboxReceiveNode node1, PlanNode node2) {
// TODO: Keys should probably be removed from the equivalence check, but would require to verify both
// keys are present in the data schema. We are not doing that for now.
&& Objects.equals(node1.getKeys(), that.getKeys())
// Distribution type is not needed for equivalence. We deal with difference distribution types in the
// spooling logic.
// && node1.getDistributionType() == that.getDistributionType()
&& node1.getDistributionType() == that.getDistributionType()
// TODO: Sort, sort on sender and collations can probably be removed from the equivalence check, but would
// require some extra checks or transformation on the spooling logic. We are not doing that for now.
&& node1.isSort() == that.isSort()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.pinot.query.planner.logical;

import org.apache.pinot.query.planner.plannode.MailboxReceiveNode;
import org.apache.pinot.query.planner.plannode.MailboxSendNode;
import org.apache.pinot.query.planner.plannode.PlanNode;
import org.apache.pinot.query.planner.plannode.PlanNodeVisitor;


/**
* EquivalentStageReplacer is used to replace equivalent stages in the query plan.
*
* Given a {@link org.apache.pinot.query.planner.plannode.PlanNode} and a
* {@link GroupedStages}, modifies the plan node to replace equivalent stages.
*
* For each {@link MailboxReceiveNode} in the plan, if the sender is not the leader of the group,
* replaces the sender with the leader.
* The leader is also updated to include the receiver in its list of receivers.
*/
public class EquivalentStagesReplacer {
private EquivalentStagesReplacer() {
}

/**
* Replaces the equivalent stages in the query plan.
*
* @param root Root plan node
* @param equivalentStages Equivalent stages
*/
public static void replaceEquivalentStages(PlanNode root, GroupedStages equivalentStages) {
root.visit(Replacer.INSTANCE, equivalentStages);
}

private static class Replacer extends PlanNodeVisitor.DepthFirstVisitor<Void, GroupedStages> {
yashmayya marked this conversation as resolved.
Show resolved Hide resolved
private static final Replacer INSTANCE = new Replacer();

private Replacer() {
}

@Override
public Void visitMailboxReceive(MailboxReceiveNode node, GroupedStages equivalenceGroups) {
MailboxSendNode sender = node.getSender();
MailboxSendNode leader = equivalenceGroups.getGroup(sender).first();
if (canSubstitute(sender, leader)) {
// we don't want to visit the children of the node given it is going to be pruned
node.setSender(leader);
leader.addReceiver(node);
} else {
visitMailboxSend(leader, equivalenceGroups);
}
return null;
}

private boolean canSubstitute(MailboxSendNode actualSender, MailboxSendNode leader) {
return actualSender != leader // we don't need to replace the leader with itself
// the leader is already sending to this stage. Given we don't have the ability to send to multiple
// receivers in the same stage, we cannot optimize this case right now.
// If this case seems to be useful, it can be supported in the future.
&& !leader.sharesReceiverStages(actualSender);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import java.util.Comparator;
import java.util.IdentityHashMap;
import java.util.NoSuchElementException;
import java.util.Set;
import java.util.SortedSet;
import java.util.TreeSet;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -77,6 +78,8 @@ public abstract SortedSet<MailboxSendNode> getGroup(MailboxSendNode stage)
*/
public abstract SortedSet<SortedSet<MailboxSendNode>> getGroups();

public abstract Set<MailboxSendNode> getStages();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Small doc comment might be useful here to distinguish this method from getLeaders / getGroups.


@Override
public String toString() {
String content = getGroups().stream()
Expand Down Expand Up @@ -154,6 +157,11 @@ public boolean containsStage(MailboxSendNode stage) {
return _stageToGroup.containsKey(stage);
}

@Override
public Set<MailboxSendNode> getStages() {
return _stageToGroup.keySet();
}

@Override
public SortedSet<MailboxSendNode> getGroup(MailboxSendNode stage)
throws NoSuchElementException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ public class MailboxReceiveNode extends BasePlanNode {
private final boolean _sortedOnSender;

// NOTE: This is only available during query planning, and should not be serialized.
private final transient MailboxSendNode _sender;
private transient MailboxSendNode _sender;

// NOTE: null List is converted to empty List because there is no way to differentiate them in proto during ser/de.
public MailboxReceiveNode(int stageId, DataSchema dataSchema, int senderStageId,
Expand All @@ -57,6 +57,9 @@ public MailboxReceiveNode(int stageId, DataSchema dataSchema, int senderStageId,
}

public int getSenderStageId() {
if (_sender != null) {
return _sender.getStageId();
}
return _senderStageId;
}

Expand Down Expand Up @@ -93,6 +96,10 @@ public MailboxSendNode getSender() {
return _sender;
}

public void setSender(MailboxSendNode sender) {
_sender = sender;
}
gortiz marked this conversation as resolved.
Show resolved Hide resolved

@Override
public String explain() {
return "MAIL_RECEIVE(" + _distributionType + ")";
Expand Down Expand Up @@ -126,7 +133,8 @@ public boolean equals(Object o) {
return false;
}
MailboxReceiveNode that = (MailboxReceiveNode) o;
return _senderStageId == that._senderStageId && _sort == that._sort && _sortedOnSender == that._sortedOnSender
return getSenderStageId() == that.getSenderStageId() && _sort == that._sort
&& _sortedOnSender == that._sortedOnSender
&& _exchangeType == that._exchangeType && _distributionType == that._distributionType && Objects.equals(_keys,
that._keys) && Objects.equals(_collations, that._collations);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
*/
package org.apache.pinot.query.planner.plannode;

import com.google.common.base.Preconditions;
import java.util.BitSet;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import javax.annotation.Nullable;
Expand All @@ -28,7 +31,7 @@


public class MailboxSendNode extends BasePlanNode {
private final int _receiverStageId;
private final BitSet _receiverStages;
gortiz marked this conversation as resolved.
Show resolved Hide resolved
private final PinotRelExchangeType _exchangeType;
private RelDistribution.Type _distributionType;
private final List<Integer> _keys;
Expand All @@ -37,11 +40,12 @@ public class MailboxSendNode extends BasePlanNode {
private final boolean _sort;

// NOTE: null List is converted to empty List because there is no way to differentiate them in proto during ser/de.
public MailboxSendNode(int stageId, DataSchema dataSchema, List<PlanNode> inputs, int receiverStageId,
PinotRelExchangeType exchangeType, RelDistribution.Type distributionType, @Nullable List<Integer> keys,
boolean prePartitioned, @Nullable List<RelFieldCollation> collations, boolean sort) {
private MailboxSendNode(int stageId, DataSchema dataSchema, List<PlanNode> inputs,
BitSet receiverStages, PinotRelExchangeType exchangeType,
RelDistribution.Type distributionType, @Nullable List<Integer> keys, boolean prePartitioned,
@Nullable List<RelFieldCollation> collations, boolean sort) {
super(stageId, dataSchema, null, inputs);
_receiverStageId = receiverStageId;
_receiverStages = receiverStages;
_exchangeType = exchangeType;
_distributionType = distributionType;
_keys = keys != null ? keys : List.of();
Expand All @@ -50,8 +54,74 @@ public MailboxSendNode(int stageId, DataSchema dataSchema, List<PlanNode> inputs
_sort = sort;
}

public MailboxSendNode(int stageId, DataSchema dataSchema, List<PlanNode> inputs,
int receiverStage, PinotRelExchangeType exchangeType,
RelDistribution.Type distributionType, @Nullable List<Integer> keys, boolean prePartitioned,
@Nullable List<RelFieldCollation> collations, boolean sort) {
this(stageId, dataSchema, inputs, toBitSet(receiverStage), exchangeType, distributionType, keys, prePartitioned,
collations, sort);
}

private static BitSet toBitSet(int receiverStage) {
BitSet bitSet = new BitSet(receiverStage + 1);
bitSet.set(receiverStage);
return bitSet;
}

private static BitSet toBitSet(@Nullable List<Integer> receiverStages) {
BitSet bitSet = new BitSet();
if (receiverStages == null || receiverStages.isEmpty()) {
return bitSet;
}
for (int receiverStage : receiverStages) {
bitSet.set(receiverStage);
}
return bitSet;
}

public MailboxSendNode(int stageId, DataSchema dataSchema, List<PlanNode> inputs,
PinotRelExchangeType exchangeType, RelDistribution.Type distributionType, @Nullable List<Integer> keys,
boolean prePartitioned, @Nullable List<RelFieldCollation> collations, boolean sort) {
this(stageId, dataSchema, inputs, new BitSet(), exchangeType, distributionType, keys, prePartitioned, collations,
sort);
}

public boolean sharesReceiverStages(MailboxSendNode other) {
return _receiverStages.intersects(other._receiverStages);
}

/**
* Returns the receiver stage ids, sorted in ascending order.
*/
public Iterable<Integer> getReceiverStageIds() {
return () -> new Iterator<>() {
int _next = _receiverStages.nextSetBit(0);

@Override
public boolean hasNext() {
return _next >= 0;
}

@Override
public Integer next() {
int current = _next;
_next = _receiverStages.nextSetBit(_next + 1);
return current;
}
};
}

@Deprecated
public int getReceiverStageId() {
return _receiverStageId;
Preconditions.checkState(!_receiverStages.isEmpty(), "Receivers not set");
return _receiverStages.nextSetBit(0);
}

public void addReceiver(MailboxReceiveNode node) {
if (_receiverStages.get(node.getStageId())) {
throw new IllegalStateException("Receiver already added: " + node.getStageId());
}
_receiverStages.set(node.getStageId());
gortiz marked this conversation as resolved.
Show resolved Hide resolved
}

public PinotRelExchangeType getExchangeType() {
Expand Down Expand Up @@ -104,7 +174,7 @@ public <T, C> T visit(PlanNodeVisitor<T, C> visitor, C context) {

@Override
public PlanNode withInputs(List<PlanNode> inputs) {
return new MailboxSendNode(_stageId, _dataSchema, inputs, _receiverStageId, _exchangeType, _distributionType, _keys,
return new MailboxSendNode(_stageId, _dataSchema, inputs, _receiverStages, _exchangeType, _distributionType, _keys,
_prePartitioned, _collations, _sort);
}

Expand All @@ -120,22 +190,22 @@ public boolean equals(Object o) {
return false;
}
MailboxSendNode that = (MailboxSendNode) o;
return _receiverStageId == that._receiverStageId && _prePartitioned == that._prePartitioned && _sort == that._sort
&& _exchangeType == that._exchangeType && _distributionType == that._distributionType && Objects.equals(_keys,
that._keys) && Objects.equals(_collations, that._collations);
return Objects.equals(_receiverStages, that._receiverStages) && _prePartitioned == that._prePartitioned
&& _sort == that._sort && _exchangeType == that._exchangeType && _distributionType == that._distributionType
&& Objects.equals(_keys, that._keys) && Objects.equals(_collations, that._collations);
}

@Override
public int hashCode() {
return Objects.hash(super.hashCode(), _receiverStageId, _exchangeType, _distributionType, _keys, _prePartitioned,
return Objects.hash(super.hashCode(), _receiverStages, _exchangeType, _distributionType, _keys, _prePartitioned,
_collations, _sort);
}

@Override
public String toString() {
return "MailboxSendNode{"
+ "_stageId=" + _stageId
+ ", _receiverStageId=" + _receiverStageId
+ ", _receivers=" + _receiverStages
+ '}';
}
}
Loading
Loading