Decoupled planning: improve join support (apache#17039)
There were some problematic cases:

- join branches are run with finalize=false instead of finalize=true like normal
  subqueries; this inconsistency is not good, but fixing it is a bigger undertaking
- the right-hand side of a join must always be a subquery, or a globally
  accessible datasource (see the sketch below)

To achieve the above:

- operand indexes are now tracked for the upstream reltree nodes in the generator
- source unwrapping now takes the join situation into account as well
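
To make the subquery invariant concrete, here is a minimal sketch using hypothetical stand-in types; Druid's planner operates on DruidLogicalNode and SourceDesc instead (see JoinSupportTweaks.forceSubQuery in the diff below):

// Hypothetical stand-ins for illustration only.
interface DataSource
{
  boolean isGlobal();
}

final class JoinBranchPolicy
{
  // The right-hand input of a join may be inlined only when it is globally
  // accessible (e.g. a lookup); otherwise it must be planned as a subquery.
  static boolean mustStaySubquery(boolean isRightHandSide, DataSource source)
  {
    return isRightHandSide && !source.isGlobal();
  }
}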
kgyrtkirk authored Sep 18, 2024
1 parent dd8c7de commit d84d53c
Showing 27 changed files with 2,363 additions and 826 deletions.
@@ -36,6 +36,7 @@
import org.apache.druid.sql.calcite.rel.PartialDruidQuery;
import org.apache.druid.sql.calcite.rel.PartialDruidQuery.Stage;
import org.apache.druid.sql.calcite.rel.logical.DruidAggregate;
import org.apache.druid.sql.calcite.rel.logical.DruidJoin;
import org.apache.druid.sql.calcite.rel.logical.DruidLogicalNode;
import org.apache.druid.sql.calcite.rel.logical.DruidSort;

@@ -58,32 +59,95 @@ public DruidQueryGenerator(PlannerContext plannerContext, DruidLogicalNode relRo
this.vertexFactory = new PDQVertexFactory(plannerContext, rexBuilder);
}

/**
* Tracks the upstream nodes during traversal.
*
* Its main purpose is to provide access to parent nodes,
* so that context-sensitive logic can be expressed with it.
*/
static class DruidNodeStack
{
static class Entry
{
public final DruidLogicalNode node;
public final int operandIndex;

public Entry(DruidLogicalNode node, int operandIndex)
{
this.node = node;
this.operandIndex = operandIndex;
}
}

Stack<Entry> stack = new Stack<>();

public void push(DruidLogicalNode item)
{
push(item, 0);
}

public void push(DruidLogicalNode item, int operandIndex)
{
stack.push(new Entry(item, operandIndex));
}

public void pop()
{
stack.pop();
}

public int size()
{
return stack.size();
}

public DruidLogicalNode peekNode()
{
return stack.peek().node;
}

public DruidLogicalNode parentNode()
{
return getNode(1).node;
}

public Entry getNode(int i)
{
return stack.get(stack.size() - 1 - i);
}

public int peekOperandIndex()
{
return stack.peek().operandIndex;
}
}
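
// How the stack is typically consulted during traversal (see the call
// sites below):
//   stack.peekNode()         -> the node currently being processed
//   stack.parentNode()       -> its parent, e.g. a DruidJoin
//   stack.peekOperandIndex() -> the node's position among the parent's
//                               inputs (0 = left side of a join)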

public DruidQuery buildQuery()
{
Stack<DruidLogicalNode> stack = new Stack<>();
DruidNodeStack stack = new DruidNodeStack();
stack.push(relRoot);
Vertex vertex = buildVertexFor(stack);
return vertex.buildQuery(true);
}

private Vertex buildVertexFor(Stack<DruidLogicalNode> stack)
private Vertex buildVertexFor(DruidNodeStack stack)
{
List<Vertex> newInputs = new ArrayList<>();

for (RelNode input : stack.peek().getInputs()) {
stack.push((DruidLogicalNode) input);
for (RelNode input : stack.peekNode().getInputs()) {
stack.push((DruidLogicalNode) input, newInputs.size());
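// newInputs.size() is the child's operand index among this node's inputs;
// for a DruidJoin, index 0 is the left input and index 1 is the right input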
newInputs.add(buildVertexFor(stack));
stack.pop();
}
Vertex vertex = processNodeWithInputs(stack, newInputs);
return vertex;
}

private Vertex processNodeWithInputs(Stack<DruidLogicalNode> stack, List<Vertex> newInputs)
private Vertex processNodeWithInputs(DruidNodeStack stack, List<Vertex> newInputs)
{
DruidLogicalNode node = stack.peek();
DruidLogicalNode node = stack.peekNode();
if (node instanceof SourceDescProducer) {
return vertexFactory.createVertex(PartialDruidQuery.create(node), newInputs);
return vertexFactory.createVertex(stack, PartialDruidQuery.create(node), newInputs);
}
if (newInputs.size() == 1) {
Vertex inputVertex = newInputs.get(0);
Expand All @@ -92,6 +156,7 @@ private Vertex processNodeWithInputs(Stack<DruidLogicalNode> stack, List<Vertex>
return newVertex.get();
}
inputVertex = vertexFactory.createVertex(
stack,
PartialDruidQuery.createOuterQuery(((PDQVertex) inputVertex).partialDruidQuery, vertexFactory.plannerContext),
ImmutableList.of(inputVertex)
);
@@ -116,7 +181,7 @@ private interface Vertex
/**
* Extends the current vertex to include the specified parent.
*/
Optional<Vertex> extendWith(Stack<DruidLogicalNode> stack);
Optional<Vertex> extendWith(DruidNodeStack stack);

/**
* Decides whether this {@link Vertex} can be unwrapped into a {@link SourceDesc}.
@@ -133,6 +198,42 @@ private interface Vertex
SourceDesc unwrapSourceDesc();
}

enum JoinSupportTweaks
{
NONE,
LEFT,
RIGHT;

static JoinSupportTweaks analyze(DruidNodeStack stack)
{
if (stack.size() < 2) {
return NONE;
}
DruidLogicalNode possibleJoin = stack.parentNode();
if (!(possibleJoin instanceof DruidJoin)) {
return NONE;
}
if (stack.peekOperandIndex() == 0) {
return LEFT;
} else {
return RIGHT;
}
}

boolean finalizeSubQuery()
{
return this == NONE;
}

boolean forceSubQuery(SourceDesc sourceDesc)
{
if (sourceDesc.dataSource.isGlobal()) {
return false;
}
return this == RIGHT;
}
}
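
// Summary of the tweaks (a reading of the code above):
// - NONE:  the vertex is not a direct join branch; subqueries are finalized
//          as usual.
// - LEFT/RIGHT: finalizeSubQuery() is false, preserving the legacy
//          finalize=false behavior of join branches noted in the commit
//          message.
// - RIGHT: forceSubQuery() additionally keeps the branch a subquery, unless
//          its source is a global datasource.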

/**
* {@link PartialDruidQuery} based {@link Vertex} factory.
*/
@@ -147,20 +248,23 @@ public PDQVertexFactory(PlannerContext plannerContext, RexBuilder rexBuilder)
this.rexBuilder = rexBuilder;
}

Vertex createVertex(PartialDruidQuery partialDruidQuery, List<Vertex> inputs)
Vertex createVertex(DruidNodeStack stack, PartialDruidQuery partialDruidQuery, List<Vertex> inputs)
{
return new PDQVertex(partialDruidQuery, inputs);
JoinSupportTweaks jst = JoinSupportTweaks.analyze(stack);
return new PDQVertex(partialDruidQuery, inputs, jst);
}

public class PDQVertex implements Vertex
{
final PartialDruidQuery partialDruidQuery;
final List<Vertex> inputs;
final JoinSupportTweaks jst;

public PDQVertex(PartialDruidQuery partialDruidQuery, List<Vertex> inputs)
public PDQVertex(PartialDruidQuery partialDruidQuery, List<Vertex> inputs, JoinSupportTweaks jst)
{
this.partialDruidQuery = partialDruidQuery;
this.inputs = inputs;
this.jst = jst;
}

@Override
@@ -172,7 +276,7 @@ public DruidQuery buildQuery(boolean topLevel)
source.rowSignature,
plannerContext,
rexBuilder,
!topLevel
!(topLevel) && jst.finalizeSubQuery()
);
}
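
// Restated (assuming the last build() argument keeps its pre-change meaning
// as the "finalize aggregations" flag):
//   before: finalize = !topLevel
//   after:  finalize = !topLevel && jst.finalizeSubQuery()
// so direct join branches (LEFT/RIGHT) are never finalized, matching the
// inconsistency called out in the commit message.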

@@ -207,21 +311,22 @@ private SourceDesc getSource()
* Extends the current partial query with the new parent if possible.
*/
@Override
public Optional<Vertex> extendWith(Stack<DruidLogicalNode> stack)
public Optional<Vertex> extendWith(DruidNodeStack stack)
{
Optional<PartialDruidQuery> newPartialQuery = extendPartialDruidQuery(stack);
if (!newPartialQuery.isPresent()) {
return Optional.empty();
}
return Optional.of(createVertex(newPartialQuery.get(), inputs));
return Optional.of(createVertex(stack, newPartialQuery.get(), inputs));
}

/**
* Merges the given {@link RelNode} into the current {@link PartialDruidQuery}.
*/
private Optional<PartialDruidQuery> extendPartialDruidQuery(Stack<DruidLogicalNode> stack)
private Optional<PartialDruidQuery> extendPartialDruidQuery(DruidNodeStack stack)
{
DruidLogicalNode parentNode = stack.peek();
DruidLogicalNode parentNode = stack.peekNode();
if (accepts(stack, Stage.WHERE_FILTER, Filter.class)) {
PartialDruidQuery newPartialQuery = partialDruidQuery.withWhereFilter((Filter) parentNode);
return Optional.of(newPartialQuery);
@@ -261,12 +366,12 @@ private Optional<PartialDruidQuery> extendPartialDruidQuery(Stack<DruidLogicalNo
return Optional.empty();
}

private boolean accepts(Stack<DruidLogicalNode> stack, Stage stage, Class<? extends RelNode> clazz)
private boolean accepts(DruidNodeStack stack, Stage stage, Class<? extends RelNode> clazz)
{
DruidLogicalNode currentNode = stack.peek();
DruidLogicalNode currentNode = stack.peekNode();
if (Project.class == clazz && stack.size() >= 2) {
// peek at parent and postpone project for next query stage
DruidLogicalNode parentNode = stack.get(stack.size() - 2);
DruidLogicalNode parentNode = stack.parentNode();
if (stage.ordinal() > Stage.AGGREGATE.ordinal()
&& parentNode instanceof DruidAggregate
&& !partialDruidQuery.canAccept(Stage.AGGREGATE)) {
@@ -295,6 +400,9 @@ public SourceDesc unwrapSourceDesc()
@Override
public boolean canUnwrapSourceDesc()
{
if (jst.forceSubQuery(getSource())) {
return false;
}
if (partialDruidQuery.stage() == Stage.SCAN) {
return true;
}

@@ -232,8 +232,8 @@ public void testExactTopNOnInnerJoinWithLimit()
);
}

@DecoupledTestConfig(quidemReason = QuidemTestCaseReason.EQUIV_PLAN_EXTRA_COLUMNS, separateDefaultModeTest = true)
@Test
@NotYetSupported(Modes.STACK_OVERFLOW)
public void testJoinOuterGroupByAndSubqueryHasLimit()
{
// Cannot vectorize JOIN operator.
@@ -321,7 +321,6 @@ public void testJoinOuterGroupByAndSubqueryHasLimit()

@MethodSource("provideQueryContexts")
@ParameterizedTest(name = "{0}")
@NotYetSupported(Modes.JOIN_TABLE_TABLE)
public void testJoinOuterGroupByAndSubqueryNoLimit(Map<String, Object> queryContext)
{
// Fully removing the join allows this query to vectorize.
@@ -405,7 +404,6 @@ public void testJoinOuterGroupByAndSubqueryNoLimit(Map<String, Object> queryCont
}

@Test
@NotYetSupported(Modes.JOIN_TABLE_TABLE)
public void testJoinWithLimitBeforeJoining()
{
// Cannot vectorize JOIN operator.
@@ -1532,7 +1530,6 @@ public void testManyManyInnerJoinOnManyManyLookup(Map<String, Object> queryConte
);
}

@DecoupledTestConfig(quidemReason = QuidemTestCaseReason.FINALIZING_FIELD_ACCESS)
@MethodSource("provideQueryContexts")
@ParameterizedTest(name = "{0}")
public void testInnerJoinQueryOfLookup(Map<String, Object> queryContext)
@@ -1712,7 +1709,7 @@ public void testInnerJoinTwoLookupsToTableUsingNumericColumn(Map<String, Object>
);
}

@NotYetSupported(Modes.JOIN_TABLE_TABLE)
@DecoupledTestConfig(quidemReason = QuidemTestCaseReason.EQUIV_PLAN_CAST_MATERIALIZED_EARLIER)
@MethodSource("provideQueryContexts")
@ParameterizedTest(name = "{0}")
public void testInnerJoinTwoLookupsToTableUsingNumericColumnInReverse(Map<String, Object> queryContext)
@@ -1770,7 +1767,6 @@ public void testInnerJoinTwoLookupsToTableUsingNumericColumnInReverse(Map<String
);
}

@NotYetSupported(Modes.JOIN_TABLE_TABLE)
@MethodSource("provideQueryContexts")
@ParameterizedTest(name = "{0}")
public void testInnerJoinLookupTableTable(Map<String, Object> queryContext)
@@ -1853,7 +1849,6 @@ public void testInnerJoinLookupTableTable(Map<String, Object> queryContext)
);
}

@NotYetSupported(Modes.JOIN_TABLE_TABLE)
@MethodSource("provideQueryContexts")
@ParameterizedTest(name = "{0}")
public void testInnerJoinLookupTableTableChained(Map<String, Object> queryContext)
@@ -2082,7 +2077,7 @@ public void testCommaJoinTableLookupTableMismatchedTypes(Map<String, Object> que
);
}

@NotYetSupported(Modes.JOIN_TABLE_TABLE)
@DecoupledTestConfig(quidemReason = QuidemTestCaseReason.EQUIV_PLAN_CAST_MATERIALIZED_EARLIER)
@MethodSource("provideQueryContexts")
@ParameterizedTest(name = "{0}")
public void testJoinTableLookupTableMismatchedTypesWithoutComma(Map<String, Object> queryContext)
Expand Down Expand Up @@ -3729,7 +3724,6 @@ public void testLeftJoinSubqueryWithSelectorFilter(Map<String, Object> queryCont
);
}

@NotYetSupported(Modes.JOIN_TABLE_TABLE)
@MethodSource("provideQueryContexts")
@ParameterizedTest(name = "{0}")
public void testLeftJoinWithNotNullFilter(Map<String, Object> queryContext)
@@ -3777,7 +3771,6 @@ public void testLeftJoinWithNotNullFilter(Map<String, Object> queryContext)
);
}

@NotYetSupported(Modes.JOIN_TABLE_TABLE)
@MethodSource("provideQueryContexts")
@ParameterizedTest(name = "{0}")
public void testInnerJoin(Map<String, Object> queryContext)
@@ -3832,7 +3825,6 @@ public void testInnerJoin(Map<String, Object> queryContext)
);
}

@NotYetSupported(Modes.JOIN_TABLE_TABLE)
@MethodSource("provideQueryContexts")
@ParameterizedTest(name = "{0}")
public void testJoinWithExplicitIsNotDistinctFromCondition(Map<String, Object> queryContext)
@@ -5845,7 +5837,6 @@ public void testRegressionFilteredAggregatorsSubqueryJoins(Map<String, Object> q

@SqlTestFrameworkConfig.MinTopNThreshold(1)
@Test
@NotYetSupported(Modes.JOIN_TABLE_TABLE)
public void testJoinWithAliasAndOrderByNoGroupBy()
{
Map<String, Object> context = new HashMap<>(QUERY_CONTEXT_DEFAULT);

@@ -21,8 +21,6 @@

import org.apache.calcite.rel.rules.CoreRules;
import org.apache.druid.query.QueryContexts;
import org.apache.druid.query.aggregation.post.FinalizingFieldAccessPostAggregator;

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
@@ -87,9 +85,18 @@ enum QuidemTestCaseReason
*/
DEFINETLY_WORSE_PLAN,
/**
* A new {@link FinalizingFieldAccessPostAggregator} appeared in the plan.
* Some extra unused columns are being projected.
*
* Example: ScanQuery over a join projects columns=[dim2, j0.m1, m1, m2] instead of just columns=[dim2, m2]
*/
EQUIV_PLAN_EXTRA_COLUMNS,
/**
* Materialization of a CAST was pushed down to a join branch.
*
* Instead of joining on the condition (CAST("j0.k", 'DOUBLE') == "_j0.m1"),
* a virtual column was computed for CAST("j0.k", 'DOUBLE').
*/
FINALIZING_FIELD_ACCESS;
EQUIV_PLAN_CAST_MATERIALIZED_EARLIER;

public boolean isPresent()
{