Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add NestedLoopJoin rel #188

Merged
merged 11 commits into from
Nov 3, 2023
35 changes: 35 additions & 0 deletions core/src/main/java/io/substrait/dsl/SubstraitBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import io.substrait.relation.Set;
import io.substrait.relation.Sort;
import io.substrait.relation.physical.HashJoin;
import io.substrait.relation.physical.NestedLoopJoin;
import io.substrait.type.ImmutableType;
import io.substrait.type.NamedStruct;
import io.substrait.type.Type;
Expand Down Expand Up @@ -218,6 +219,30 @@ private NamedScan namedScan(
return NamedScan.builder().names(tableName).initialSchema(namedStruct).remap(remap).build();
}

public NestedLoopJoin nestedLoopJoin(
Function<JoinInput, Expression> conditionFn,
NestedLoopJoin.JoinType joinType,
Rel left,
Rel right) {
return nestedLoopJoin(conditionFn, joinType, Optional.empty(), left, right);
}

private NestedLoopJoin nestedLoopJoin(
Function<JoinInput, Expression> conditionFn,
NestedLoopJoin.JoinType joinType,
Optional<Rel.Remap> remap,
Rel left,
Rel right) {
var condition = conditionFn.apply(new JoinInput(left, right));
return NestedLoopJoin.builder()
.left(left)
.right(right)
.condition(condition)
.joinType(joinType)
.remap(remap)
.build();
}

public Project project(Function<Rel, Iterable<? extends Expression>> expressionsFn, Rel input) {
return project(expressionsFn, Optional.empty(), input);
}
Expand Down Expand Up @@ -286,6 +311,16 @@ public List<FieldReference> fieldReferences(Rel input, int... indexes) {
.collect(java.util.stream.Collectors.toList());
}

public FieldReference fieldReference(List<Rel> inputs, int index) {
return ImmutableFieldReference.newInputRelReference(index, inputs);
}

public List<FieldReference> fieldReferences(List<Rel> inputs, int... indexes) {
return Arrays.stream(indexes)
.mapToObj(index -> fieldReference(inputs, index))
.collect(java.util.stream.Collectors.toList());
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good call on these, they will be quite helpful.


public Expression cast(Expression input, Type type) {
return Cast.builder()
.input(input)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package io.substrait.relation;

import io.substrait.relation.physical.HashJoin;
import io.substrait.relation.physical.NestedLoopJoin;

public abstract class AbstractRelVisitor<OUTPUT, EXCEPTION extends Exception>
implements RelVisitor<OUTPUT, EXCEPTION> {
Expand Down Expand Up @@ -31,6 +32,11 @@ public OUTPUT visit(Join join) throws EXCEPTION {
return visitFallback(join);
}

@Override
public OUTPUT visit(NestedLoopJoin nestedLoopJoin) throws EXCEPTION {
return visitFallback(nestedLoopJoin);
}

@Override
public OUTPUT visit(Set set) throws EXCEPTION {
return visitFallback(set);
Expand Down
32 changes: 32 additions & 0 deletions core/src/main/java/io/substrait/relation/ProtoRelConverter.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import io.substrait.proto.FilterRel;
import io.substrait.proto.HashJoinRel;
import io.substrait.proto.JoinRel;
import io.substrait.proto.NestedLoopJoinRel;
import io.substrait.proto.ProjectRel;
import io.substrait.proto.ReadRel;
import io.substrait.proto.SetRel;
Expand All @@ -27,6 +28,7 @@
import io.substrait.relation.files.ImmutableFileFormat;
import io.substrait.relation.files.ImmutableFileOrFiles;
import io.substrait.relation.physical.HashJoin;
import io.substrait.relation.physical.NestedLoopJoin;
import io.substrait.type.ImmutableNamedStruct;
import io.substrait.type.NamedStruct;
import io.substrait.type.Type;
Expand Down Expand Up @@ -77,6 +79,9 @@ public Rel from(io.substrait.proto.Rel rel) {
case JOIN -> {
return newJoin(rel.getJoin());
}
case NESTED_LOOP_JOIN -> {
return newNestedLoopJoin(rel.getNestedLoopJoin());
}
case SET -> {
return newSet(rel.getSet());
}
Expand Down Expand Up @@ -532,6 +537,33 @@ private Rel newHashJoin(HashJoinRel rel) {
return builder.build();
}

private NestedLoopJoin newNestedLoopJoin(NestedLoopJoinRel rel) {
Rel left = from(rel.getLeft());
Rel right = from(rel.getRight());
Type.Struct leftStruct = left.getRecordType();
Type.Struct rightStruct = right.getRecordType();
Type.Struct unionedStruct = Type.Struct.builder().from(leftStruct).from(rightStruct).build();
var converter = new ProtoExpressionConverter(lookup, extensions, unionedStruct, this);
var builder =
NestedLoopJoin.builder()
.left(left)
.right(right)
.condition(
// defaults to true (aka cartesian join) if the join expression is missing
rel.hasExpression()
? converter.from(rel.getExpression())
: Expression.BoolLiteral.builder().value(true).build())
.joinType(NestedLoopJoin.JoinType.fromProto(rel.getType()));

builder
.commonExtension(optionalAdvancedExtension(rel.getCommon()))
.remap(optionalRelmap(rel.getCommon()));
if (rel.hasAdvancedExtension()) {
builder.extension(advancedExtension(rel.getAdvancedExtension()));
}
return builder.build();
}

private static Optional<Rel.Remap> optionalRelmap(io.substrait.proto.RelCommon relCommon) {
return Optional.ofNullable(
relCommon.hasEmit() ? Rel.Remap.of(relCommon.getEmit().getOutputMappingList()) : null);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import io.substrait.expression.ImmutableFieldReference;
import io.substrait.relation.physical.HashJoin;
import io.substrait.relation.physical.ImmutableHashJoin;
import io.substrait.relation.physical.ImmutableNestedLoopJoin;
import io.substrait.relation.physical.NestedLoopJoin;
import io.substrait.type.Type;
import java.util.ArrayList;
import java.util.List;
Expand Down Expand Up @@ -120,6 +122,23 @@ public Optional<Rel> visit(Join join) throws RuntimeException {
.build());
}

@Override
public Optional<Rel> visit(NestedLoopJoin nestedLoopJoin) throws RuntimeException {
var left = nestedLoopJoin.getLeft().accept(this);
var right = nestedLoopJoin.getRight().accept(this);
var condition = visitExpression(nestedLoopJoin.getCondition());
if (allEmpty(left, right, condition)) {
return Optional.empty();
}
return Optional.of(
ImmutableNestedLoopJoin.builder()
.from(nestedLoopJoin)
.left(left.orElse(nestedLoopJoin.getLeft()))
.right(right.orElse(nestedLoopJoin.getRight()))
.condition(condition.orElse(nestedLoopJoin.getCondition()))
.build());
}

@Override
public Optional<Rel> visit(Set set) throws RuntimeException {
return transformList(set.getInputs(), t -> t.accept(this))
Expand Down
16 changes: 16 additions & 0 deletions core/src/main/java/io/substrait/relation/RelProtoConverter.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import io.substrait.proto.FilterRel;
import io.substrait.proto.HashJoinRel;
import io.substrait.proto.JoinRel;
import io.substrait.proto.NestedLoopJoinRel;
import io.substrait.proto.ProjectRel;
import io.substrait.proto.ReadRel;
import io.substrait.proto.Rel;
Expand All @@ -24,6 +25,7 @@
import io.substrait.proto.SortRel;
import io.substrait.relation.files.FileOrFiles;
import io.substrait.relation.physical.HashJoin;
import io.substrait.relation.physical.NestedLoopJoin;
import io.substrait.type.proto.TypeProtoConverter;
import java.util.Collection;
import java.util.List;
Expand Down Expand Up @@ -179,6 +181,20 @@ public Rel visit(Join join) throws RuntimeException {
return Rel.newBuilder().setJoin(builder).build();
}

@Override
public Rel visit(NestedLoopJoin nestedLoopJoin) throws RuntimeException {
var builder =
NestedLoopJoinRel.newBuilder()
.setCommon(common(nestedLoopJoin))
.setLeft(toProto(nestedLoopJoin.getLeft()))
.setRight(toProto(nestedLoopJoin.getRight()))
.setExpression(toProto(nestedLoopJoin.getCondition()))
.setType(nestedLoopJoin.getJoinType().toProto());

nestedLoopJoin.getExtension().ifPresent(ae -> builder.setAdvancedExtension(ae.toProto()));
return Rel.newBuilder().setNestedLoopJoin(builder).build();
}

@Override
public Rel visit(Set set) throws RuntimeException {
var builder = SetRel.newBuilder().setCommon(common(set)).setOp(set.getSetOp().toProto());
Expand Down
3 changes: 3 additions & 0 deletions core/src/main/java/io/substrait/relation/RelVisitor.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package io.substrait.relation;

import io.substrait.relation.physical.HashJoin;
import io.substrait.relation.physical.NestedLoopJoin;

public interface RelVisitor<OUTPUT, EXCEPTION extends Exception> {
OUTPUT visit(Aggregate aggregate) throws EXCEPTION;
Expand All @@ -13,6 +14,8 @@ public interface RelVisitor<OUTPUT, EXCEPTION extends Exception> {

OUTPUT visit(Join join) throws EXCEPTION;

OUTPUT visit(NestedLoopJoin nestedLoopJoin) throws EXCEPTION;

OUTPUT visit(Set set) throws EXCEPTION;

OUTPUT visit(NamedScan namedScan) throws EXCEPTION;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
package io.substrait.relation.physical;

import io.substrait.expression.Expression;
import io.substrait.proto.NestedLoopJoinRel;
import io.substrait.relation.BiRel;
import io.substrait.relation.HasExtension;
import io.substrait.relation.RelVisitor;
import io.substrait.type.Type;
import io.substrait.type.TypeCreator;
import java.util.stream.Stream;
import org.immutables.value.Value;

@Value.Immutable
public abstract class NestedLoopJoin extends BiRel implements HasExtension {

public abstract Expression getCondition();

public abstract JoinType getJoinType();

public static enum JoinType {
UNKNOWN(NestedLoopJoinRel.JoinType.JOIN_TYPE_UNSPECIFIED),
INNER(NestedLoopJoinRel.JoinType.JOIN_TYPE_INNER),
OUTER(NestedLoopJoinRel.JoinType.JOIN_TYPE_OUTER),
LEFT(NestedLoopJoinRel.JoinType.JOIN_TYPE_LEFT),
RIGHT(NestedLoopJoinRel.JoinType.JOIN_TYPE_RIGHT),
LEFT_SEMI(NestedLoopJoinRel.JoinType.JOIN_TYPE_LEFT_SEMI),
RIGHT_SEMI(NestedLoopJoinRel.JoinType.JOIN_TYPE_RIGHT_SEMI),
LEFT_ANTI(NestedLoopJoinRel.JoinType.JOIN_TYPE_LEFT_ANTI),
RIGHT_ANTI(NestedLoopJoinRel.JoinType.JOIN_TYPE_RIGHT_ANTI);

private NestedLoopJoinRel.JoinType proto;

JoinType(NestedLoopJoinRel.JoinType proto) {
this.proto = proto;
}

public NestedLoopJoinRel.JoinType toProto() {
return proto;
}

public static JoinType fromProto(NestedLoopJoinRel.JoinType proto) {
for (var v : values()) {
if (v.proto == proto) {
return v;
}
}

throw new IllegalArgumentException("Unknown type: " + proto);
}
}

@Override
protected Type.Struct deriveRecordType() {
Stream<Type> leftTypes =
switch (getJoinType()) {
case RIGHT, OUTER -> getLeft().getRecordType().fields().stream()
.map(TypeCreator::asNullable);
case RIGHT_ANTI, RIGHT_SEMI -> Stream.empty();
default -> getLeft().getRecordType().fields().stream();
};
Stream<Type> rightTypes =
switch (getJoinType()) {
case LEFT, OUTER -> getRight().getRecordType().fields().stream()
.map(TypeCreator::asNullable);
case LEFT_ANTI, LEFT_SEMI -> Stream.empty();
default -> getRight().getRecordType().fields().stream();
};
return TypeCreator.REQUIRED.struct(Stream.concat(leftTypes, rightTypes));
}

@Override
public <O, E extends Exception> O accept(RelVisitor<O, E> visitor) throws E {
return visitor.visit(this);
}

public static ImmutableNestedLoopJoin.Builder builder() {
return ImmutableNestedLoopJoin.builder();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import io.substrait.relation.Sort;
import io.substrait.relation.VirtualTableScan;
import io.substrait.relation.physical.HashJoin;
import io.substrait.relation.physical.NestedLoopJoin;
import io.substrait.relation.utils.StringHolder;
import io.substrait.relation.utils.StringHolderHandlingProtoRelConverter;
import io.substrait.type.NamedStruct;
Expand Down Expand Up @@ -186,6 +187,19 @@ void hashJoin() {
verifyRoundTrip(relWithoutKeys);
}

@Test
void nestedLoopJoin() {
Rel rel =
NestedLoopJoin.builder()
.from(
b.nestedLoopJoin(
__ -> b.bool(true), NestedLoopJoin.JoinType.INNER, commonTable, commonTable))
.commonExtension(commonExtension)
.extension(relExtension)
.build();
verifyRoundTrip(rel);
}

@Test
void project() {
Rel rel =
Expand Down
16 changes: 16 additions & 0 deletions core/src/test/java/io/substrait/type/proto/JoinRoundtripTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import io.substrait.TestBase;
import io.substrait.relation.Rel;
import io.substrait.relation.physical.HashJoin;
import io.substrait.relation.physical.NestedLoopJoin;
import java.util.Arrays;
import java.util.List;
import org.junit.jupiter.api.Test;
Expand Down Expand Up @@ -31,4 +32,19 @@ void hashJoin() {
.build();
verifyRoundTrip(relWithoutKeys);
}

@Test
void nestedLoopJoin() {
List<Rel> inputRels = Arrays.asList(leftTable, rightTable);
Rel rel =
NestedLoopJoin.builder()
.from(
b.nestedLoopJoin(
__ -> b.equal(b.fieldReference(inputRels, 0), b.fieldReference(inputRels, 5)),
NestedLoopJoin.JoinType.INNER,
leftTable,
rightTable))
.build();
verifyRoundTrip(rel);
}
}