From cea494f23f722d1863281c7c857bd9bf6121463a Mon Sep 17 00:00:00 2001 From: Dane Pitkin Date: Tue, 10 Oct 2023 19:09:17 -0400 Subject: [PATCH 01/11] feat: add NestedLoopJoin rel --- .../io/substrait/dsl/SubstraitBuilder.java | 25 +++++++ .../relation/AbstractRelVisitor.java | 5 ++ .../io/substrait/relation/NestedLoopJoin.java | 75 +++++++++++++++++++ .../substrait/relation/ProtoRelConverter.java | 27 +++++++ .../relation/RelCopyOnWriteVisitor.java | 19 +++++ .../substrait/relation/RelProtoConverter.java | 16 ++++ .../io/substrait/relation/RelVisitor.java | 2 + .../type/proto/ExtensionRoundtripTest.java | 14 ++++ 8 files changed, 183 insertions(+) create mode 100644 core/src/main/java/io/substrait/relation/NestedLoopJoin.java diff --git a/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java b/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java index 71e816ce..b85c05fe 100644 --- a/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java +++ b/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java @@ -19,6 +19,7 @@ import io.substrait.relation.Filter; import io.substrait.relation.Join; import io.substrait.relation.NamedScan; +import io.substrait.relation.NestedLoopJoin; import io.substrait.relation.Project; import io.substrait.relation.Rel; import io.substrait.relation.Set; @@ -218,6 +219,30 @@ private NamedScan namedScan( return NamedScan.builder().names(tableName).initialSchema(namedStruct).remap(remap).build(); } + public NestedLoopJoin nestedLoopJoin( + Function conditionFn, + NestedLoopJoin.JoinType joinType, + Rel left, + Rel right) { + return nestedLoopJoin(conditionFn, joinType, Optional.empty(), left, right); + } + + private NestedLoopJoin nestedLoopJoin( + Function conditionFn, + NestedLoopJoin.JoinType joinType, + Optional remap, + Rel left, + Rel right) { + var condition = conditionFn.apply(new JoinInput(left, right)); + return NestedLoopJoin.builder() + .left(left) + .right(right) + .condition(condition) + .joinType(joinType) + .remap(remap) + .build(); + } + public Project project(Function> expressionsFn, Rel input) { return project(expressionsFn, Optional.empty(), input); } diff --git a/core/src/main/java/io/substrait/relation/AbstractRelVisitor.java b/core/src/main/java/io/substrait/relation/AbstractRelVisitor.java index 645f692e..82a5b054 100644 --- a/core/src/main/java/io/substrait/relation/AbstractRelVisitor.java +++ b/core/src/main/java/io/substrait/relation/AbstractRelVisitor.java @@ -31,6 +31,11 @@ public OUTPUT visit(Join join) throws EXCEPTION { return visitFallback(join); } + @Override + public OUTPUT visit(NestedLoopJoin nestedLoopJoin) throws EXCEPTION { + return visitFallback(nestedLoopJoin); + } + @Override public OUTPUT visit(Set set) throws EXCEPTION { return visitFallback(set); diff --git a/core/src/main/java/io/substrait/relation/NestedLoopJoin.java b/core/src/main/java/io/substrait/relation/NestedLoopJoin.java new file mode 100644 index 00000000..8581601a --- /dev/null +++ b/core/src/main/java/io/substrait/relation/NestedLoopJoin.java @@ -0,0 +1,75 @@ +package io.substrait.relation; + +import io.substrait.expression.Expression; +import io.substrait.proto.NestedLoopJoinRel; +import io.substrait.type.Type; +import io.substrait.type.TypeCreator; +import java.util.Optional; +import java.util.stream.Stream; +import org.immutables.value.Value; + +@Value.Immutable +public abstract class NestedLoopJoin extends BiRel implements HasExtension { + + public abstract Optional getCondition(); + + public abstract JoinType getJoinType(); + + public static enum JoinType { + UNKNOWN(NestedLoopJoinRel.JoinType.JOIN_TYPE_UNSPECIFIED), + INNER(NestedLoopJoinRel.JoinType.JOIN_TYPE_INNER), + OUTER(NestedLoopJoinRel.JoinType.JOIN_TYPE_OUTER), + LEFT(NestedLoopJoinRel.JoinType.JOIN_TYPE_LEFT), + RIGHT(NestedLoopJoinRel.JoinType.JOIN_TYPE_RIGHT), + LEFT_SEMI(NestedLoopJoinRel.JoinType.JOIN_TYPE_LEFT_SEMI), + RIGHT_SEMI(NestedLoopJoinRel.JoinType.JOIN_TYPE_RIGHT_SEMI), + LEFT_ANTI(NestedLoopJoinRel.JoinType.JOIN_TYPE_LEFT_ANTI), + RIGHT_ANTI(NestedLoopJoinRel.JoinType.JOIN_TYPE_RIGHT_ANTI); + + private NestedLoopJoinRel.JoinType proto; + + JoinType(NestedLoopJoinRel.JoinType proto) { + this.proto = proto; + } + + public NestedLoopJoinRel.JoinType toProto() { + return proto; + } + + public static JoinType fromProto(NestedLoopJoinRel.JoinType proto) { + for (var v : values()) { + if (v.proto == proto) { + return v; + } + } + + throw new IllegalArgumentException("Unknown type: " + proto); + } + } + + @Override + protected Type.Struct deriveRecordType() { + Stream leftTypes = + switch (getJoinType()) { + case RIGHT, RIGHT_SEMI, RIGHT_ANTI, OUTER -> getLeft().getRecordType().fields().stream() + .map(TypeCreator::asNullable); + default -> getLeft().getRecordType().fields().stream(); + }; + Stream rightTypes = + switch (getJoinType()) { + case LEFT, LEFT_SEMI, LEFT_ANTI, OUTER -> getRight().getRecordType().fields().stream() + .map(TypeCreator::asNullable); + default -> getRight().getRecordType().fields().stream(); + }; + return TypeCreator.REQUIRED.struct(Stream.concat(leftTypes, rightTypes)); + } + + @Override + public O accept(RelVisitor visitor) throws E { + return visitor.visit(this); + } + + public static ImmutableNestedLoopJoin.Builder builder() { + return ImmutableNestedLoopJoin.builder(); + } +} diff --git a/core/src/main/java/io/substrait/relation/ProtoRelConverter.java b/core/src/main/java/io/substrait/relation/ProtoRelConverter.java index bb479b83..51d64043 100644 --- a/core/src/main/java/io/substrait/relation/ProtoRelConverter.java +++ b/core/src/main/java/io/substrait/relation/ProtoRelConverter.java @@ -17,6 +17,7 @@ import io.substrait.proto.FilterRel; import io.substrait.proto.HashJoinRel; import io.substrait.proto.JoinRel; +import io.substrait.proto.NestedLoopJoinRel; import io.substrait.proto.ProjectRel; import io.substrait.proto.ReadRel; import io.substrait.proto.SetRel; @@ -77,6 +78,9 @@ public Rel from(io.substrait.proto.Rel rel) { case JOIN -> { return newJoin(rel.getJoin()); } + case NESTED_LOOP_JOIN -> { + return newNestedLoopJoin(rel.getNestedLoopJoin()); + } case SET -> { return newSet(rel.getSet()); } @@ -470,6 +474,29 @@ private Join newJoin(JoinRel rel) { return builder.build(); } + private NestedLoopJoin newNestedLoopJoin(NestedLoopJoinRel rel) { + Rel left = from(rel.getLeft()); + Rel right = from(rel.getRight()); + Type.Struct leftStruct = left.getRecordType(); + Type.Struct rightStruct = right.getRecordType(); + Type.Struct unionedStruct = Type.Struct.builder().from(leftStruct).from(rightStruct).build(); + var converter = new ProtoExpressionConverter(lookup, extensions, unionedStruct, this); + var builder = + NestedLoopJoin.builder() + .left(left) + .right(right) + .condition(converter.from(rel.getExpression())) + .joinType(NestedLoopJoin.JoinType.fromProto(rel.getType())); + + builder + .commonExtension(optionalAdvancedExtension(rel.getCommon())) + .remap(optionalRelmap(rel.getCommon())); + if (rel.hasAdvancedExtension()) { + builder.extension(advancedExtension(rel.getAdvancedExtension())); + } + return builder.build(); + } + private Rel newCross(CrossRel rel) { Rel left = from(rel.getLeft()); Rel right = from(rel.getRight()); diff --git a/core/src/main/java/io/substrait/relation/RelCopyOnWriteVisitor.java b/core/src/main/java/io/substrait/relation/RelCopyOnWriteVisitor.java index 0dddfbd9..077b1b4c 100644 --- a/core/src/main/java/io/substrait/relation/RelCopyOnWriteVisitor.java +++ b/core/src/main/java/io/substrait/relation/RelCopyOnWriteVisitor.java @@ -120,6 +120,25 @@ public Optional visit(Join join) throws RuntimeException { .build()); } + @Override + public Optional visit(NestedLoopJoin nestedLoopJoin) throws RuntimeException { + var left = nestedLoopJoin.getLeft().accept(this); + var right = nestedLoopJoin.getRight().accept(this); + var condition = nestedLoopJoin.getCondition().flatMap(t -> visitExpression(t)); + if (allEmpty(left, right, condition)) { + return Optional.empty(); + } + return Optional.of( + ImmutableNestedLoopJoin.builder() + .from(nestedLoopJoin) + .left(left.orElse(nestedLoopJoin.getLeft())) + .right(right.orElse(nestedLoopJoin.getRight())) + .condition( + Optional.ofNullable( + condition.orElseGet(() -> nestedLoopJoin.getCondition().orElse(null)))) + .build()); + } + @Override public Optional visit(Set set) throws RuntimeException { return transformList(set.getInputs(), t -> t.accept(this)) diff --git a/core/src/main/java/io/substrait/relation/RelProtoConverter.java b/core/src/main/java/io/substrait/relation/RelProtoConverter.java index 566ae5c3..1f1a3782 100644 --- a/core/src/main/java/io/substrait/relation/RelProtoConverter.java +++ b/core/src/main/java/io/substrait/relation/RelProtoConverter.java @@ -15,6 +15,7 @@ import io.substrait.proto.FilterRel; import io.substrait.proto.HashJoinRel; import io.substrait.proto.JoinRel; +import io.substrait.proto.NestedLoopJoinRel; import io.substrait.proto.ProjectRel; import io.substrait.proto.ReadRel; import io.substrait.proto.Rel; @@ -179,6 +180,21 @@ public Rel visit(Join join) throws RuntimeException { return Rel.newBuilder().setJoin(builder).build(); } + @Override + public Rel visit(NestedLoopJoin nestedLoopJoin) throws RuntimeException { + var builder = + NestedLoopJoinRel.newBuilder() + .setCommon(common(nestedLoopJoin)) + .setLeft(toProto(nestedLoopJoin.getLeft())) + .setRight(toProto(nestedLoopJoin.getRight())) + .setType(nestedLoopJoin.getJoinType().toProto()); + + nestedLoopJoin.getCondition().ifPresent(t -> builder.setExpression(toProto(t))); + + nestedLoopJoin.getExtension().ifPresent(ae -> builder.setAdvancedExtension(ae.toProto())); + return Rel.newBuilder().setNestedLoopJoin(builder).build(); + } + @Override public Rel visit(Set set) throws RuntimeException { var builder = SetRel.newBuilder().setCommon(common(set)).setOp(set.getSetOp().toProto()); diff --git a/core/src/main/java/io/substrait/relation/RelVisitor.java b/core/src/main/java/io/substrait/relation/RelVisitor.java index e8e78aaf..44d28ddb 100644 --- a/core/src/main/java/io/substrait/relation/RelVisitor.java +++ b/core/src/main/java/io/substrait/relation/RelVisitor.java @@ -13,6 +13,8 @@ public interface RelVisitor { OUTPUT visit(Join join) throws EXCEPTION; + OUTPUT visit(NestedLoopJoin nestedLoopJoin) throws EXCEPTION; + OUTPUT visit(Set set) throws EXCEPTION; OUTPUT visit(NamedScan namedScan) throws EXCEPTION; diff --git a/core/src/test/java/io/substrait/type/proto/ExtensionRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/ExtensionRoundtripTest.java index 5417625b..584892e6 100644 --- a/core/src/test/java/io/substrait/type/proto/ExtensionRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/ExtensionRoundtripTest.java @@ -16,6 +16,7 @@ import io.substrait.relation.Join; import io.substrait.relation.LocalFiles; import io.substrait.relation.NamedScan; +import io.substrait.relation.NestedLoopJoin; import io.substrait.relation.Project; import io.substrait.relation.ProtoRelConverter; import io.substrait.relation.Rel; @@ -186,6 +187,19 @@ void hashJoin() { verifyRoundTrip(relWithoutKeys); } + @Test + void nested_loop_join() { + Rel rel = + NestedLoopJoin.builder() + .from( + b.nestedLoopJoin( + __ -> b.bool(true), NestedLoopJoin.JoinType.INNER, commonTable, commonTable)) + .commonExtension(commonExtension) + .extension(relExtension) + .build(); + verifyRoundTrip(rel); + } + @Test void project() { Rel rel = From dc5c1fe00b0829d4989ab596e7143c9e90cdccb4 Mon Sep 17 00:00:00 2001 From: Dane Pitkin Date: Thu, 12 Oct 2023 16:59:48 -0400 Subject: [PATCH 02/11] fix: record type for anti, semi joins --- .../src/main/java/io/substrait/relation/NestedLoopJoin.java | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/core/src/main/java/io/substrait/relation/NestedLoopJoin.java b/core/src/main/java/io/substrait/relation/NestedLoopJoin.java index 8581601a..ec1dfe7a 100644 --- a/core/src/main/java/io/substrait/relation/NestedLoopJoin.java +++ b/core/src/main/java/io/substrait/relation/NestedLoopJoin.java @@ -51,14 +51,16 @@ public static JoinType fromProto(NestedLoopJoinRel.JoinType proto) { protected Type.Struct deriveRecordType() { Stream leftTypes = switch (getJoinType()) { - case RIGHT, RIGHT_SEMI, RIGHT_ANTI, OUTER -> getLeft().getRecordType().fields().stream() + case RIGHT, OUTER -> getLeft().getRecordType().fields().stream() .map(TypeCreator::asNullable); + case RIGHT_ANTI, RIGHT_SEMI -> Stream.empty(); default -> getLeft().getRecordType().fields().stream(); }; Stream rightTypes = switch (getJoinType()) { - case LEFT, LEFT_SEMI, LEFT_ANTI, OUTER -> getRight().getRecordType().fields().stream() + case LEFT, OUTER -> getRight().getRecordType().fields().stream() .map(TypeCreator::asNullable); + case LEFT_ANTI, LEFT_SEMI -> Stream.empty(); default -> getRight().getRecordType().fields().stream(); }; return TypeCreator.REQUIRED.struct(Stream.concat(leftTypes, rightTypes)); From 90e62d9d31a557a55f9e97c995bad5be8a7266d8 Mon Sep 17 00:00:00 2001 From: Dane Pitkin Date: Wed, 25 Oct 2023 15:53:43 -0400 Subject: [PATCH 03/11] fix: move nlj operator into physical subfolder --- core/src/main/java/io/substrait/dsl/SubstraitBuilder.java | 2 +- .../main/java/io/substrait/relation/AbstractRelVisitor.java | 1 + .../main/java/io/substrait/relation/ProtoRelConverter.java | 1 + .../java/io/substrait/relation/RelCopyOnWriteVisitor.java | 2 ++ .../main/java/io/substrait/relation/RelProtoConverter.java | 1 + core/src/main/java/io/substrait/relation/RelVisitor.java | 1 + .../io/substrait/relation/{ => physical}/NestedLoopJoin.java | 5 ++++- .../java/io/substrait/type/proto/ExtensionRoundtripTest.java | 2 +- 8 files changed, 12 insertions(+), 3 deletions(-) rename core/src/main/java/io/substrait/relation/{ => physical}/NestedLoopJoin.java (93%) diff --git a/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java b/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java index b85c05fe..17063c91 100644 --- a/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java +++ b/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java @@ -19,12 +19,12 @@ import io.substrait.relation.Filter; import io.substrait.relation.Join; import io.substrait.relation.NamedScan; -import io.substrait.relation.NestedLoopJoin; import io.substrait.relation.Project; import io.substrait.relation.Rel; import io.substrait.relation.Set; import io.substrait.relation.Sort; import io.substrait.relation.physical.HashJoin; +import io.substrait.relation.physical.NestedLoopJoin; import io.substrait.type.ImmutableType; import io.substrait.type.NamedStruct; import io.substrait.type.Type; diff --git a/core/src/main/java/io/substrait/relation/AbstractRelVisitor.java b/core/src/main/java/io/substrait/relation/AbstractRelVisitor.java index 82a5b054..52a70bf3 100644 --- a/core/src/main/java/io/substrait/relation/AbstractRelVisitor.java +++ b/core/src/main/java/io/substrait/relation/AbstractRelVisitor.java @@ -1,6 +1,7 @@ package io.substrait.relation; import io.substrait.relation.physical.HashJoin; +import io.substrait.relation.physical.NestedLoopJoin; public abstract class AbstractRelVisitor implements RelVisitor { diff --git a/core/src/main/java/io/substrait/relation/ProtoRelConverter.java b/core/src/main/java/io/substrait/relation/ProtoRelConverter.java index 51d64043..3294f748 100644 --- a/core/src/main/java/io/substrait/relation/ProtoRelConverter.java +++ b/core/src/main/java/io/substrait/relation/ProtoRelConverter.java @@ -28,6 +28,7 @@ import io.substrait.relation.files.ImmutableFileFormat; import io.substrait.relation.files.ImmutableFileOrFiles; import io.substrait.relation.physical.HashJoin; +import io.substrait.relation.physical.NestedLoopJoin; import io.substrait.type.ImmutableNamedStruct; import io.substrait.type.NamedStruct; import io.substrait.type.Type; diff --git a/core/src/main/java/io/substrait/relation/RelCopyOnWriteVisitor.java b/core/src/main/java/io/substrait/relation/RelCopyOnWriteVisitor.java index 077b1b4c..676c4e7b 100644 --- a/core/src/main/java/io/substrait/relation/RelCopyOnWriteVisitor.java +++ b/core/src/main/java/io/substrait/relation/RelCopyOnWriteVisitor.java @@ -8,6 +8,8 @@ import io.substrait.expression.ImmutableFieldReference; import io.substrait.relation.physical.HashJoin; import io.substrait.relation.physical.ImmutableHashJoin; +import io.substrait.relation.physical.ImmutableNestedLoopJoin; +import io.substrait.relation.physical.NestedLoopJoin; import io.substrait.type.Type; import java.util.ArrayList; import java.util.List; diff --git a/core/src/main/java/io/substrait/relation/RelProtoConverter.java b/core/src/main/java/io/substrait/relation/RelProtoConverter.java index 1f1a3782..61a3b0fd 100644 --- a/core/src/main/java/io/substrait/relation/RelProtoConverter.java +++ b/core/src/main/java/io/substrait/relation/RelProtoConverter.java @@ -25,6 +25,7 @@ import io.substrait.proto.SortRel; import io.substrait.relation.files.FileOrFiles; import io.substrait.relation.physical.HashJoin; +import io.substrait.relation.physical.NestedLoopJoin; import io.substrait.type.proto.TypeProtoConverter; import java.util.Collection; import java.util.List; diff --git a/core/src/main/java/io/substrait/relation/RelVisitor.java b/core/src/main/java/io/substrait/relation/RelVisitor.java index 44d28ddb..38b70816 100644 --- a/core/src/main/java/io/substrait/relation/RelVisitor.java +++ b/core/src/main/java/io/substrait/relation/RelVisitor.java @@ -1,6 +1,7 @@ package io.substrait.relation; import io.substrait.relation.physical.HashJoin; +import io.substrait.relation.physical.NestedLoopJoin; public interface RelVisitor { OUTPUT visit(Aggregate aggregate) throws EXCEPTION; diff --git a/core/src/main/java/io/substrait/relation/NestedLoopJoin.java b/core/src/main/java/io/substrait/relation/physical/NestedLoopJoin.java similarity index 93% rename from core/src/main/java/io/substrait/relation/NestedLoopJoin.java rename to core/src/main/java/io/substrait/relation/physical/NestedLoopJoin.java index ec1dfe7a..6d8bbcca 100644 --- a/core/src/main/java/io/substrait/relation/NestedLoopJoin.java +++ b/core/src/main/java/io/substrait/relation/physical/NestedLoopJoin.java @@ -1,7 +1,10 @@ -package io.substrait.relation; +package io.substrait.relation.physical; import io.substrait.expression.Expression; import io.substrait.proto.NestedLoopJoinRel; +import io.substrait.relation.BiRel; +import io.substrait.relation.HasExtension; +import io.substrait.relation.RelVisitor; import io.substrait.type.Type; import io.substrait.type.TypeCreator; import java.util.Optional; diff --git a/core/src/test/java/io/substrait/type/proto/ExtensionRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/ExtensionRoundtripTest.java index 584892e6..6258a253 100644 --- a/core/src/test/java/io/substrait/type/proto/ExtensionRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/ExtensionRoundtripTest.java @@ -16,7 +16,6 @@ import io.substrait.relation.Join; import io.substrait.relation.LocalFiles; import io.substrait.relation.NamedScan; -import io.substrait.relation.NestedLoopJoin; import io.substrait.relation.Project; import io.substrait.relation.ProtoRelConverter; import io.substrait.relation.Rel; @@ -24,6 +23,7 @@ import io.substrait.relation.Sort; import io.substrait.relation.VirtualTableScan; import io.substrait.relation.physical.HashJoin; +import io.substrait.relation.physical.NestedLoopJoin; import io.substrait.relation.utils.StringHolder; import io.substrait.relation.utils.StringHolderHandlingProtoRelConverter; import io.substrait.type.NamedStruct; From f7cec8c7e03b0c3bb06b634cb119b2b1fa9d913c Mon Sep 17 00:00:00 2001 From: Dane Pitkin Date: Wed, 25 Oct 2023 18:40:49 -0400 Subject: [PATCH 04/11] feat: add test --- .../io/substrait/type/proto/JoinRoundtripTest.java | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/core/src/test/java/io/substrait/type/proto/JoinRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/JoinRoundtripTest.java index 9b0156dc..5fdf89cd 100644 --- a/core/src/test/java/io/substrait/type/proto/JoinRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/JoinRoundtripTest.java @@ -3,6 +3,7 @@ import io.substrait.TestBase; import io.substrait.relation.Rel; import io.substrait.relation.physical.HashJoin; +import io.substrait.relation.physical.NestedLoopJoin; import java.util.Arrays; import java.util.List; import org.junit.jupiter.api.Test; @@ -31,4 +32,15 @@ void hashJoin() { .build(); verifyRoundTrip(relWithoutKeys); } + + @Test + void nestedLoopJoin() { + Rel rel = + NestedLoopJoin.builder() + .from( + b.nestedLoopJoin( + __ -> b.bool(true), NestedLoopJoin.JoinType.INNER, leftTable, rightTable)) + .build(); + verifyRoundTrip(rel); + } } From 3ba7fc0e75b8826b27d8f5324cba649c2d25dba3 Mon Sep 17 00:00:00 2001 From: Dane Pitkin Date: Wed, 25 Oct 2023 22:22:55 -0400 Subject: [PATCH 05/11] fix: improve tests --- .../io/substrait/dsl/SubstraitBuilder.java | 13 +++++ .../substrait/relation/ProtoRelConverter.java | 52 +++++++++++-------- .../type/proto/JoinRoundtripTest.java | 17 +++++- 3 files changed, 57 insertions(+), 25 deletions(-) diff --git a/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java b/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java index 17063c91..b79e0343 100644 --- a/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java +++ b/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java @@ -219,6 +219,10 @@ private NamedScan namedScan( return NamedScan.builder().names(tableName).initialSchema(namedStruct).remap(remap).build(); } + public NestedLoopJoin nestedLoopJoin(NestedLoopJoin.JoinType joinType, Rel left, Rel right) { + return nestedLoopJoin(Optional.empty(), joinType, Optional.empty(), left, right); + } + public NestedLoopJoin nestedLoopJoin( Function conditionFn, NestedLoopJoin.JoinType joinType, @@ -234,6 +238,15 @@ private NestedLoopJoin nestedLoopJoin( Rel left, Rel right) { var condition = conditionFn.apply(new JoinInput(left, right)); + return nestedLoopJoin(Optional.of(condition), joinType, remap, left, right); + } + + private NestedLoopJoin nestedLoopJoin( + Optional condition, + NestedLoopJoin.JoinType joinType, + Optional remap, + Rel left, + Rel right) { return NestedLoopJoin.builder() .left(left) .right(right) diff --git a/core/src/main/java/io/substrait/relation/ProtoRelConverter.java b/core/src/main/java/io/substrait/relation/ProtoRelConverter.java index 3294f748..76965976 100644 --- a/core/src/main/java/io/substrait/relation/ProtoRelConverter.java +++ b/core/src/main/java/io/substrait/relation/ProtoRelConverter.java @@ -475,29 +475,6 @@ private Join newJoin(JoinRel rel) { return builder.build(); } - private NestedLoopJoin newNestedLoopJoin(NestedLoopJoinRel rel) { - Rel left = from(rel.getLeft()); - Rel right = from(rel.getRight()); - Type.Struct leftStruct = left.getRecordType(); - Type.Struct rightStruct = right.getRecordType(); - Type.Struct unionedStruct = Type.Struct.builder().from(leftStruct).from(rightStruct).build(); - var converter = new ProtoExpressionConverter(lookup, extensions, unionedStruct, this); - var builder = - NestedLoopJoin.builder() - .left(left) - .right(right) - .condition(converter.from(rel.getExpression())) - .joinType(NestedLoopJoin.JoinType.fromProto(rel.getType())); - - builder - .commonExtension(optionalAdvancedExtension(rel.getCommon())) - .remap(optionalRelmap(rel.getCommon())); - if (rel.hasAdvancedExtension()) { - builder.extension(advancedExtension(rel.getAdvancedExtension())); - } - return builder.build(); - } - private Rel newCross(CrossRel rel) { Rel left = from(rel.getLeft()); Rel right = from(rel.getRight()); @@ -560,6 +537,35 @@ private Rel newHashJoin(HashJoinRel rel) { return builder.build(); } + private NestedLoopJoin newNestedLoopJoin(NestedLoopJoinRel rel) { + Rel left = from(rel.getLeft()); + Rel right = from(rel.getRight()); + Type.Struct leftStruct = left.getRecordType(); + Type.Struct rightStruct = right.getRecordType(); + Type.Struct unionedStruct = Type.Struct.builder().from(leftStruct).from(rightStruct).build(); + var converter = new ProtoExpressionConverter(lookup, extensions, unionedStruct, this); + var builder = + NestedLoopJoin.builder() + .left(left) + .right(right) + .condition( + // defaults to true if the join expression is unassigned, resulting in a cartesian + // join + Optional.of( + rel.hasExpression() + ? converter.from(rel.getExpression()) + : Expression.BoolLiteral.builder().value(true).build())) + .joinType(NestedLoopJoin.JoinType.fromProto(rel.getType())); + + builder + .commonExtension(optionalAdvancedExtension(rel.getCommon())) + .remap(optionalRelmap(rel.getCommon())); + if (rel.hasAdvancedExtension()) { + builder.extension(advancedExtension(rel.getAdvancedExtension())); + } + return builder.build(); + } + private static Optional optionalRelmap(io.substrait.proto.RelCommon relCommon) { return Optional.ofNullable( relCommon.hasEmit() ? Rel.Remap.of(relCommon.getEmit().getOutputMappingList()) : null); diff --git a/core/src/test/java/io/substrait/type/proto/JoinRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/JoinRoundtripTest.java index 5fdf89cd..b2123e5b 100644 --- a/core/src/test/java/io/substrait/type/proto/JoinRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/JoinRoundtripTest.java @@ -1,5 +1,11 @@ package io.substrait.type.proto; +<<<<<<< HEAD +======= +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotEquals; + +>>>>>>> 408672a (fix: improve tests) import io.substrait.TestBase; import io.substrait.relation.Rel; import io.substrait.relation.physical.HashJoin; @@ -35,12 +41,19 @@ void hashJoin() { @Test void nestedLoopJoin() { - Rel rel = + Rel relWithDefaultExpression = NestedLoopJoin.builder() .from( b.nestedLoopJoin( __ -> b.bool(true), NestedLoopJoin.JoinType.INNER, leftTable, rightTable)) .build(); - verifyRoundTrip(rel); + verifyRoundTrip(relWithDefaultExpression); + + Rel relWithoutExpression = + NestedLoopJoin.builder() + .from(b.nestedLoopJoin(NestedLoopJoin.JoinType.INNER, leftTable, rightTable)) + .build(); + assertNotEquals(relWithDefaultExpression, relWithoutExpression); + assertEquals(relWithDefaultExpression, roundTrip(relWithoutExpression)); } } From bb82de7dd95d4ca8fb4e4642104bde6fbc86d5d3 Mon Sep 17 00:00:00 2001 From: Dane Pitkin Date: Wed, 25 Oct 2023 22:32:54 -0400 Subject: [PATCH 06/11] fix: improve test case further --- .../java/io/substrait/type/proto/JoinRoundtripTest.java | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/core/src/test/java/io/substrait/type/proto/JoinRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/JoinRoundtripTest.java index b2123e5b..d8bb34d0 100644 --- a/core/src/test/java/io/substrait/type/proto/JoinRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/JoinRoundtripTest.java @@ -41,6 +41,14 @@ void hashJoin() { @Test void nestedLoopJoin() { + Rel relWithCustomExpression = + NestedLoopJoin.builder() + .from( + b.nestedLoopJoin( + __ -> b.bool(false), NestedLoopJoin.JoinType.INNER, leftTable, rightTable)) + .build(); + verifyRoundTrip(relWithCustomExpression); + Rel relWithDefaultExpression = NestedLoopJoin.builder() .from( From d75f2a20d39b29d9d3e1214a2b023d90655cc125 Mon Sep 17 00:00:00 2001 From: Dane Pitkin Date: Wed, 1 Nov 2023 10:48:52 -0400 Subject: [PATCH 07/11] fix: use camelCase --- .../java/io/substrait/type/proto/ExtensionRoundtripTest.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/test/java/io/substrait/type/proto/ExtensionRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/ExtensionRoundtripTest.java index 6258a253..b076f03c 100644 --- a/core/src/test/java/io/substrait/type/proto/ExtensionRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/ExtensionRoundtripTest.java @@ -188,7 +188,7 @@ void hashJoin() { } @Test - void nested_loop_join() { + void nestedLoopJoin() { Rel rel = NestedLoopJoin.builder() .from( From 53d10210ec6b4c596bf8fe0ff71c748c4b32c1e9 Mon Sep 17 00:00:00 2001 From: Dane Pitkin Date: Wed, 1 Nov 2023 15:40:34 -0400 Subject: [PATCH 08/11] feat: make expression required --- .../io/substrait/dsl/SubstraitBuilder.java | 13 ------------ .../substrait/relation/ProtoRelConverter.java | 8 +------- .../relation/RelCopyOnWriteVisitor.java | 6 ++---- .../substrait/relation/RelProtoConverter.java | 3 +-- .../relation/physical/NestedLoopJoin.java | 3 +-- .../type/proto/JoinRoundtripTest.java | 20 ++----------------- 6 files changed, 7 insertions(+), 46 deletions(-) diff --git a/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java b/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java index b79e0343..17063c91 100644 --- a/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java +++ b/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java @@ -219,10 +219,6 @@ private NamedScan namedScan( return NamedScan.builder().names(tableName).initialSchema(namedStruct).remap(remap).build(); } - public NestedLoopJoin nestedLoopJoin(NestedLoopJoin.JoinType joinType, Rel left, Rel right) { - return nestedLoopJoin(Optional.empty(), joinType, Optional.empty(), left, right); - } - public NestedLoopJoin nestedLoopJoin( Function conditionFn, NestedLoopJoin.JoinType joinType, @@ -238,15 +234,6 @@ private NestedLoopJoin nestedLoopJoin( Rel left, Rel right) { var condition = conditionFn.apply(new JoinInput(left, right)); - return nestedLoopJoin(Optional.of(condition), joinType, remap, left, right); - } - - private NestedLoopJoin nestedLoopJoin( - Optional condition, - NestedLoopJoin.JoinType joinType, - Optional remap, - Rel left, - Rel right) { return NestedLoopJoin.builder() .left(left) .right(right) diff --git a/core/src/main/java/io/substrait/relation/ProtoRelConverter.java b/core/src/main/java/io/substrait/relation/ProtoRelConverter.java index 76965976..fdd9f26d 100644 --- a/core/src/main/java/io/substrait/relation/ProtoRelConverter.java +++ b/core/src/main/java/io/substrait/relation/ProtoRelConverter.java @@ -548,13 +548,7 @@ private NestedLoopJoin newNestedLoopJoin(NestedLoopJoinRel rel) { NestedLoopJoin.builder() .left(left) .right(right) - .condition( - // defaults to true if the join expression is unassigned, resulting in a cartesian - // join - Optional.of( - rel.hasExpression() - ? converter.from(rel.getExpression()) - : Expression.BoolLiteral.builder().value(true).build())) + .condition(converter.from(rel.getExpression())) .joinType(NestedLoopJoin.JoinType.fromProto(rel.getType())); builder diff --git a/core/src/main/java/io/substrait/relation/RelCopyOnWriteVisitor.java b/core/src/main/java/io/substrait/relation/RelCopyOnWriteVisitor.java index 676c4e7b..e67a6e11 100644 --- a/core/src/main/java/io/substrait/relation/RelCopyOnWriteVisitor.java +++ b/core/src/main/java/io/substrait/relation/RelCopyOnWriteVisitor.java @@ -126,7 +126,7 @@ public Optional visit(Join join) throws RuntimeException { public Optional visit(NestedLoopJoin nestedLoopJoin) throws RuntimeException { var left = nestedLoopJoin.getLeft().accept(this); var right = nestedLoopJoin.getRight().accept(this); - var condition = nestedLoopJoin.getCondition().flatMap(t -> visitExpression(t)); + var condition = visitExpression(nestedLoopJoin.getCondition()); if (allEmpty(left, right, condition)) { return Optional.empty(); } @@ -135,9 +135,7 @@ public Optional visit(NestedLoopJoin nestedLoopJoin) throws RuntimeExceptio .from(nestedLoopJoin) .left(left.orElse(nestedLoopJoin.getLeft())) .right(right.orElse(nestedLoopJoin.getRight())) - .condition( - Optional.ofNullable( - condition.orElseGet(() -> nestedLoopJoin.getCondition().orElse(null)))) + .condition(condition.orElse(nestedLoopJoin.getCondition())) .build()); } diff --git a/core/src/main/java/io/substrait/relation/RelProtoConverter.java b/core/src/main/java/io/substrait/relation/RelProtoConverter.java index 61a3b0fd..2ab4c052 100644 --- a/core/src/main/java/io/substrait/relation/RelProtoConverter.java +++ b/core/src/main/java/io/substrait/relation/RelProtoConverter.java @@ -188,10 +188,9 @@ public Rel visit(NestedLoopJoin nestedLoopJoin) throws RuntimeException { .setCommon(common(nestedLoopJoin)) .setLeft(toProto(nestedLoopJoin.getLeft())) .setRight(toProto(nestedLoopJoin.getRight())) + .setExpression(toProto(nestedLoopJoin.getCondition())) .setType(nestedLoopJoin.getJoinType().toProto()); - nestedLoopJoin.getCondition().ifPresent(t -> builder.setExpression(toProto(t))); - nestedLoopJoin.getExtension().ifPresent(ae -> builder.setAdvancedExtension(ae.toProto())); return Rel.newBuilder().setNestedLoopJoin(builder).build(); } diff --git a/core/src/main/java/io/substrait/relation/physical/NestedLoopJoin.java b/core/src/main/java/io/substrait/relation/physical/NestedLoopJoin.java index 6d8bbcca..722fdb47 100644 --- a/core/src/main/java/io/substrait/relation/physical/NestedLoopJoin.java +++ b/core/src/main/java/io/substrait/relation/physical/NestedLoopJoin.java @@ -7,14 +7,13 @@ import io.substrait.relation.RelVisitor; import io.substrait.type.Type; import io.substrait.type.TypeCreator; -import java.util.Optional; import java.util.stream.Stream; import org.immutables.value.Value; @Value.Immutable public abstract class NestedLoopJoin extends BiRel implements HasExtension { - public abstract Optional getCondition(); + public abstract Expression getCondition(); public abstract JoinType getJoinType(); diff --git a/core/src/test/java/io/substrait/type/proto/JoinRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/JoinRoundtripTest.java index d8bb34d0..74132921 100644 --- a/core/src/test/java/io/substrait/type/proto/JoinRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/JoinRoundtripTest.java @@ -3,7 +3,6 @@ <<<<<<< HEAD ======= import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertNotEquals; >>>>>>> 408672a (fix: improve tests) import io.substrait.TestBase; @@ -41,27 +40,12 @@ void hashJoin() { @Test void nestedLoopJoin() { - Rel relWithCustomExpression = - NestedLoopJoin.builder() - .from( - b.nestedLoopJoin( - __ -> b.bool(false), NestedLoopJoin.JoinType.INNER, leftTable, rightTable)) - .build(); - verifyRoundTrip(relWithCustomExpression); - - Rel relWithDefaultExpression = + Rel rel = NestedLoopJoin.builder() .from( b.nestedLoopJoin( __ -> b.bool(true), NestedLoopJoin.JoinType.INNER, leftTable, rightTable)) .build(); - verifyRoundTrip(relWithDefaultExpression); - - Rel relWithoutExpression = - NestedLoopJoin.builder() - .from(b.nestedLoopJoin(NestedLoopJoin.JoinType.INNER, leftTable, rightTable)) - .build(); - assertNotEquals(relWithDefaultExpression, relWithoutExpression); - assertEquals(relWithDefaultExpression, roundTrip(relWithoutExpression)); + verifyRoundTrip(rel); } } From 47aa9d8808a8c238a804730eec04b8a7f6e9c765 Mon Sep 17 00:00:00 2001 From: Dane Pitkin Date: Fri, 3 Nov 2023 11:53:59 -0400 Subject: [PATCH 09/11] fix: use equality expression in test case --- .../io/substrait/relation/ProtoRelConverter.java | 6 +++++- .../io/substrait/type/proto/JoinRoundtripTest.java | 14 ++++++++------ 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/core/src/main/java/io/substrait/relation/ProtoRelConverter.java b/core/src/main/java/io/substrait/relation/ProtoRelConverter.java index fdd9f26d..05c59541 100644 --- a/core/src/main/java/io/substrait/relation/ProtoRelConverter.java +++ b/core/src/main/java/io/substrait/relation/ProtoRelConverter.java @@ -548,7 +548,11 @@ private NestedLoopJoin newNestedLoopJoin(NestedLoopJoinRel rel) { NestedLoopJoin.builder() .left(left) .right(right) - .condition(converter.from(rel.getExpression())) + .condition( + // defaults to true (aka cartesian join) if the join expression is missing + rel.hasExpression() + ? converter.from(rel.getExpression()) + : Expression.BoolLiteral.builder().value(true).build()) .joinType(NestedLoopJoin.JoinType.fromProto(rel.getType())); builder diff --git a/core/src/test/java/io/substrait/type/proto/JoinRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/JoinRoundtripTest.java index 74132921..328dd83f 100644 --- a/core/src/test/java/io/substrait/type/proto/JoinRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/JoinRoundtripTest.java @@ -1,10 +1,5 @@ package io.substrait.type.proto; -<<<<<<< HEAD -======= -import static org.junit.jupiter.api.Assertions.assertEquals; - ->>>>>>> 408672a (fix: improve tests) import io.substrait.TestBase; import io.substrait.relation.Rel; import io.substrait.relation.physical.HashJoin; @@ -44,8 +39,15 @@ void nestedLoopJoin() { NestedLoopJoin.builder() .from( b.nestedLoopJoin( - __ -> b.bool(true), NestedLoopJoin.JoinType.INNER, leftTable, rightTable)) + __ -> b.equal(b.fieldReference(leftTable, 1), b.fieldReference(rightTable, 0)), + NestedLoopJoin.JoinType.INNER, + leftTable, + rightTable)) .build(); + io.substrait.proto.Rel protoRel = relProtoConverter.toProto(rel); + Rel relReturned = protoRelConverter.from(protoRel); + System.out.println(rel); + System.out.println(relReturned); verifyRoundTrip(rel); } } From 476a3e1a473f02f54ca279414c1947da40822c37 Mon Sep 17 00:00:00 2001 From: Dane Pitkin Date: Fri, 3 Nov 2023 11:54:56 -0400 Subject: [PATCH 10/11] fix: remove debug output and fix indices --- .../java/io/substrait/type/proto/JoinRoundtripTest.java | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/core/src/test/java/io/substrait/type/proto/JoinRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/JoinRoundtripTest.java index 328dd83f..b78f7978 100644 --- a/core/src/test/java/io/substrait/type/proto/JoinRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/JoinRoundtripTest.java @@ -39,15 +39,11 @@ void nestedLoopJoin() { NestedLoopJoin.builder() .from( b.nestedLoopJoin( - __ -> b.equal(b.fieldReference(leftTable, 1), b.fieldReference(rightTable, 0)), + __ -> b.equal(b.fieldReference(leftTable, 0), b.fieldReference(rightTable, 2)), NestedLoopJoin.JoinType.INNER, leftTable, rightTable)) .build(); - io.substrait.proto.Rel protoRel = relProtoConverter.toProto(rel); - Rel relReturned = protoRelConverter.from(protoRel); - System.out.println(rel); - System.out.println(relReturned); verifyRoundTrip(rel); } } From 9f6c6eedfcb547023d6644269a37f2dab6269559 Mon Sep 17 00:00:00 2001 From: Dane Pitkin Date: Fri, 3 Nov 2023 12:40:25 -0400 Subject: [PATCH 11/11] fix: use list of rels as input to equal expression --- .../main/java/io/substrait/dsl/SubstraitBuilder.java | 10 ++++++++++ .../io/substrait/type/proto/JoinRoundtripTest.java | 3 ++- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java b/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java index 17063c91..c385c45e 100644 --- a/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java +++ b/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java @@ -311,6 +311,16 @@ public List fieldReferences(Rel input, int... indexes) { .collect(java.util.stream.Collectors.toList()); } + public FieldReference fieldReference(List inputs, int index) { + return ImmutableFieldReference.newInputRelReference(index, inputs); + } + + public List fieldReferences(List inputs, int... indexes) { + return Arrays.stream(indexes) + .mapToObj(index -> fieldReference(inputs, index)) + .collect(java.util.stream.Collectors.toList()); + } + public Expression cast(Expression input, Type type) { return Cast.builder() .input(input) diff --git a/core/src/test/java/io/substrait/type/proto/JoinRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/JoinRoundtripTest.java index b78f7978..8ae8a7da 100644 --- a/core/src/test/java/io/substrait/type/proto/JoinRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/JoinRoundtripTest.java @@ -35,11 +35,12 @@ void hashJoin() { @Test void nestedLoopJoin() { + List inputRels = Arrays.asList(leftTable, rightTable); Rel rel = NestedLoopJoin.builder() .from( b.nestedLoopJoin( - __ -> b.equal(b.fieldReference(leftTable, 0), b.fieldReference(rightTable, 2)), + __ -> b.equal(b.fieldReference(inputRels, 0), b.fieldReference(inputRels, 5)), NestedLoopJoin.JoinType.INNER, leftTable, rightTable))