diff --git a/core/src/main/java/io/substrait/expression/Expression.java b/core/src/main/java/io/substrait/expression/Expression.java index b592341b..a17c9733 100644 --- a/core/src/main/java/io/substrait/expression/Expression.java +++ b/core/src/main/java/io/substrait/expression/Expression.java @@ -584,7 +584,13 @@ abstract static class IfThen implements Expression { public abstract Expression elseClause(); public Type getType() { - return elseClause().getType(); + Type elseType = elseClause().getType(); + + // If any of the clauses are nullable, the whole expression is also nullable. + if (ifClauses().stream().anyMatch(clause -> clause.then().getType().nullable())) { + return TypeCreator.asNullable(elseType); + } + return elseType; } public static ImmutableExpression.IfThen.Builder builder() { diff --git a/core/src/test/java/io/substrait/type/proto/IfThenRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/IfThenRoundtripTest.java new file mode 100644 index 00000000..5d0a7e46 --- /dev/null +++ b/core/src/test/java/io/substrait/type/proto/IfThenRoundtripTest.java @@ -0,0 +1,45 @@ +package io.substrait.type.proto; + +import static io.substrait.expression.proto.ProtoExpressionConverter.EMPTY_TYPE; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import io.substrait.TestBase; +import io.substrait.expression.Expression; +import io.substrait.expression.ExpressionCreator; +import io.substrait.expression.proto.ExpressionProtoConverter; +import io.substrait.expression.proto.ProtoExpressionConverter; +import java.util.Arrays; +import org.junit.jupiter.api.Test; + +public class IfThenRoundtripTest extends TestBase { + + @Test + void ifThenNotNullable() { + final Expression.IfThen ifRel = + b.ifThen( + Arrays.asList( + b.ifClause(ExpressionCreator.bool(false, false), ExpressionCreator.i64(false, 1))), + ExpressionCreator.i64(false, 2)); + assertFalse(ifRel.getType().nullable()); + + var to = new ExpressionProtoConverter(null, null); + var from = new ProtoExpressionConverter(null, null, EMPTY_TYPE, protoRelConverter); + assertEquals(ifRel, from.from(ifRel.accept(to))); + } + + @Test + void ifThenNullable() { + final Expression.IfThen ifRel = + b.ifThen( + Arrays.asList( + b.ifClause(ExpressionCreator.bool(true, false), ExpressionCreator.i64(true, 1))), + ExpressionCreator.i64(false, 2)); + assertTrue(ifRel.getType().nullable()); + + var to = new ExpressionProtoConverter(null, null); + var from = new ProtoExpressionConverter(null, null, EMPTY_TYPE, protoRelConverter); + assertEquals(ifRel, from.from(ifRel.accept(to))); + } +}