Skip to content

Commit

Permalink
fix: left/right/outer joins have nullable fields
Browse files Browse the repository at this point in the history
During a left, right, or outer join, some of the input fields become
nullable because their side of the join may have no matching row in the
output. For a left join, the right fields become nullable. For a right
join, the left fields become nullable. For an outer join, both sets
become nullable.

The added test fails when assertions are enabled in Calcite code unless
the right-side fields of the left join are typed as nullable.
  • Loading branch information
carlyeks committed Jul 13, 2023
1 parent fb1cc75 commit a2e8680
Show file tree
Hide file tree
Showing 3 changed files with 198 additions and 4 deletions.
10 changes: 10 additions & 0 deletions core/src/main/java/io/substrait/dsl/SubstraitBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
import com.github.bsideup.jabel.Desugar;
import io.substrait.expression.AggregateFunctionInvocation;
import io.substrait.expression.Expression;
import io.substrait.expression.Expression.FailureBehavior;
import io.substrait.expression.FieldReference;
import io.substrait.expression.ImmutableExpression.Cast;
import io.substrait.expression.ImmutableFieldReference;
import io.substrait.extension.SimpleExtension;
import io.substrait.plan.ImmutablePlan;
Expand Down Expand Up @@ -245,6 +247,14 @@ public List<FieldReference> fieldReferences(Rel input, int... indexes) {
.collect(java.util.stream.Collectors.toList());
}

/**
 * Builds a cast expression that converts {@code input} to {@code type}.
 *
 * <p>The failure behavior of the cast is always set to {@link FailureBehavior#UNSPECIFIED}.
 *
 * @param input the expression to cast
 * @param type the target type of the cast
 * @return the cast expression produced by the {@code Cast} builder
 */
public Expression cast(Expression input, Type type) {
  var castBuilder = Cast.builder();
  castBuilder.input(input);
  castBuilder.type(type);
  castBuilder.failureBehavior(FailureBehavior.UNSPECIFIED);
  return castBuilder.build();
}

public List<Expression.SortField> sortFields(Rel input, int... indexes) {
return Arrays.stream(indexes)
.mapToObj(
Expand Down
166 changes: 162 additions & 4 deletions core/src/main/java/io/substrait/relation/Join.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,32 @@
import io.substrait.expression.Expression;
import io.substrait.proto.JoinRel;
import io.substrait.type.Type;
import io.substrait.type.Type.Binary;
import io.substrait.type.Type.Bool;
import io.substrait.type.Type.Date;
import io.substrait.type.Type.Decimal;
import io.substrait.type.Type.FP32;
import io.substrait.type.Type.FP64;
import io.substrait.type.Type.FixedBinary;
import io.substrait.type.Type.FixedChar;
import io.substrait.type.Type.I16;
import io.substrait.type.Type.I32;
import io.substrait.type.Type.I64;
import io.substrait.type.Type.I8;
import io.substrait.type.Type.IntervalDay;
import io.substrait.type.Type.IntervalYear;
import io.substrait.type.Type.ListType;
import io.substrait.type.Type.Map;
import io.substrait.type.Type.Str;
import io.substrait.type.Type.Struct;
import io.substrait.type.Type.Time;
import io.substrait.type.Type.Timestamp;
import io.substrait.type.Type.TimestampTZ;
import io.substrait.type.Type.UUID;
import io.substrait.type.Type.UserDefined;
import io.substrait.type.Type.VarChar;
import io.substrait.type.TypeCreator;
import io.substrait.type.TypeVisitor;
import java.util.Optional;
import java.util.stream.Stream;
import org.immutables.value.Value;
Expand Down Expand Up @@ -47,12 +72,145 @@ public static JoinType fromProto(JoinRel.JoinType proto) {
}
}

/**
 * Type visitor that maps every visited type to the counterpart produced by {@code
 * TypeCreator.NULLABLE}.
 *
 * <p>Parameterized types carry their parameters over unchanged (length, precision/scale, struct
 * fields, list element type, map key/value types, user-defined URI and name). Nested types are
 * passed through as-is; only the outermost type is rebuilt.
 */
private static final class NullableTypeVisitor implements TypeVisitor<Type, RuntimeException> {

  // ---- simple (parameterless) types ----

  @Override
  public Type visit(Bool bool) {
    return TypeCreator.NULLABLE.BOOLEAN;
  }

  @Override
  public Type visit(I8 i8) {
    return TypeCreator.NULLABLE.I8;
  }

  @Override
  public Type visit(I16 i16) {
    return TypeCreator.NULLABLE.I16;
  }

  @Override
  public Type visit(I32 i32) {
    return TypeCreator.NULLABLE.I32;
  }

  @Override
  public Type visit(I64 i64) {
    return TypeCreator.NULLABLE.I64;
  }

  @Override
  public Type visit(FP32 fp32) {
    return TypeCreator.NULLABLE.FP32;
  }

  @Override
  public Type visit(FP64 fp64) {
    return TypeCreator.NULLABLE.FP64;
  }

  @Override
  public Type visit(Str str) {
    return TypeCreator.NULLABLE.STRING;
  }

  @Override
  public Type visit(Binary binary) {
    return TypeCreator.NULLABLE.BINARY;
  }

  @Override
  public Type visit(Date date) {
    return TypeCreator.NULLABLE.DATE;
  }

  @Override
  public Type visit(Time time) {
    return TypeCreator.NULLABLE.TIME;
  }

  @Override
  public Type visit(TimestampTZ timestampTz) {
    return TypeCreator.NULLABLE.TIMESTAMP_TZ;
  }

  @Override
  public Type visit(Timestamp timestamp) {
    return TypeCreator.NULLABLE.TIMESTAMP;
  }

  @Override
  public Type visit(IntervalYear intervalYear) {
    return TypeCreator.NULLABLE.INTERVAL_YEAR;
  }

  @Override
  public Type visit(IntervalDay intervalDay) {
    return TypeCreator.NULLABLE.INTERVAL_DAY;
  }

  @Override
  public Type visit(UUID uuid) {
    return TypeCreator.NULLABLE.UUID;
  }

  // ---- parameterized types: parameters are copied from the visited type ----

  @Override
  public Type visit(FixedChar fixedChar) {
    return TypeCreator.NULLABLE.fixedChar(fixedChar.length());
  }

  @Override
  public Type visit(VarChar varChar) {
    return TypeCreator.NULLABLE.varChar(varChar.length());
  }

  @Override
  public Type visit(FixedBinary fixedBinary) {
    return TypeCreator.NULLABLE.fixedBinary(fixedBinary.length());
  }

  @Override
  public Type visit(Decimal decimal) {
    return TypeCreator.NULLABLE.decimal(decimal.precision(), decimal.scale());
  }

  // ---- compound types: nested types are handed over untouched ----

  @Override
  public Type visit(Struct struct) {
    return TypeCreator.NULLABLE.struct(struct.fields());
  }

  @Override
  public Type visit(ListType list) {
    return TypeCreator.NULLABLE.list(list.elementType());
  }

  @Override
  public Type visit(Map map) {
    return TypeCreator.NULLABLE.map(map.key(), map.value());
  }

  @Override
  public Type visit(UserDefined userDefined) {
    return TypeCreator.NULLABLE.userDefined(userDefined.uri(), userDefined.name());
  }
}

/**
 * Derives the record type of this join: the left input's fields followed by the right input's
 * fields, wrapped in a struct created via {@code TypeCreator.REQUIRED}.
 *
 * <p>Fields on a side that may go unmatched are mapped through {@link NullableTypeVisitor}:
 *
 * <ul>
 *   <li>{@code RIGHT} and {@code OUTER} joins convert the left fields;
 *   <li>{@code LEFT} and {@code OUTER} joins convert the right fields;
 *   <li>every other join type passes both sides through unchanged.
 * </ul>
 *
 * @return the struct type describing the rows produced by this join
 */
@Override
protected Type.Struct deriveRecordType() {
  // The earlier unconditional "return REQUIRED.struct(concat(left, right))" was removed here:
  // it made the nullability handling below unreachable (and the method uncompilable).
  var nullable = new NullableTypeVisitor();
  Stream<Type> leftTypes =
      switch (getJoinType()) {
        case RIGHT, OUTER -> getLeft().getRecordType().fields().stream()
            .map(t -> t.accept(nullable));
        default -> getLeft().getRecordType().fields().stream();
      };
  Stream<Type> rightTypes =
      switch (getJoinType()) {
        case LEFT, OUTER -> getRight().getRecordType().fields().stream()
            .map(t -> t.accept(nullable));
        default -> getRight().getRecordType().fields().stream();
      };
  return TypeCreator.REQUIRED.struct(Stream.concat(leftTypes, rightTypes));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import io.substrait.dsl.SubstraitBuilder;
import io.substrait.plan.Plan;
import io.substrait.relation.Join.JoinType;
import io.substrait.relation.Rel;
import io.substrait.relation.Set.SetOp;
import io.substrait.type.Type;
Expand Down Expand Up @@ -150,6 +151,31 @@ public void emit() {
var relNode = converter.convert(root.getInput());
assertRowMatch(relNode.getRowType(), R.I32, N.STRING);
}

/**
 * Converts a LEFT join whose right input projects a cast to a required string, and checks the
 * converted row type against {@code N.STRING, N.FP64, N.STRING}.
 */
@Test
public void leftJoin() {
  final List<Type> columnTypes = List.of(N.STRING, N.FP64, N.BINARY);
  final Rel joinTable = b.namedScan(List.of("join"), List.of("a", "b", "c"), columnTypes);

  // Right input: two projected expressions (cast of field 2 to R.STRING, and field 1),
  // remapped to output positions 3 and 4.
  final var rightInput =
      b.project(
          r -> List.of(b.cast(b.fieldReference(r, 2), R.STRING), b.fieldReference(r, 1)),
          b.remap(3, 4),
          joinTable);

  // The same scan is used as the left input; the join condition is the literal true.
  final var join = b.join(ji -> b.bool(true), JoinType.LEFT, joinTable, rightInput);

  final var projected = b.project(r -> b.fieldReferences(r, 0, 1, 3), b.remap(0, 1, 3), join);
  Plan.Root root = b.root(projected);

  var relNode = converter.convert(root.getInput());
  assertRowMatch(relNode.getRowType(), N.STRING, N.FP64, N.STRING);
}
}

@Nested
Expand Down

0 comments on commit a2e8680

Please sign in to comment.