From 297c535875b5084967c27d7f72639c16b9563477 Mon Sep 17 00:00:00 2001 From: Carl Yeksigian Date: Tue, 18 Jul 2023 20:56:49 -0400 Subject: [PATCH] feat: Add SingleOrList support to the Isthmus converter (#159) --- .../io/substrait/dsl/SubstraitBuilder.java | 9 +++ .../expression/ExpressionRexConverter.java | 8 +++ .../isthmus/ExpressionConvertabilityTest.java | 72 +++++++++++++++++++ .../io/substrait/isthmus/SimplePlansTest.java | 5 ++ 4 files changed, 94 insertions(+) create mode 100644 isthmus/src/test/java/io/substrait/isthmus/ExpressionConvertabilityTest.java diff --git a/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java b/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java index dbdfaf46..75e46b4e 100644 --- a/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java +++ b/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java @@ -6,6 +6,7 @@ import io.substrait.expression.Expression.FailureBehavior; import io.substrait.expression.FieldReference; import io.substrait.expression.ImmutableExpression.Cast; +import io.substrait.expression.ImmutableExpression.SingleOrList; import io.substrait.expression.ImmutableFieldReference; import io.substrait.extension.SimpleExtension; import io.substrait.plan.ImmutablePlan; @@ -237,6 +238,10 @@ public Expression.BoolLiteral bool(boolean v) { return Expression.BoolLiteral.builder().value(v).build(); } + public Expression.I32Literal i32(int v) { + return Expression.I32Literal.builder().value(v).build(); + } + public FieldReference fieldReference(Rel input, int index) { return ImmutableFieldReference.newInputRelReference(index, input); } @@ -266,6 +271,10 @@ public List sortFields(Rel input, int... indexes) { .collect(java.util.stream.Collectors.toList()); } + public Expression singleOrList(Expression condition, Expression... options) { + return SingleOrList.builder().condition(condition).addOptions(options).build(); + } + // Aggregate Functions public AggregateFunctionInvocation aggregateFn( diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java index 5b346ab0..d173c5d7 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java @@ -1,6 +1,7 @@ package io.substrait.isthmus.expression; import io.substrait.expression.*; +import io.substrait.expression.Expression.SingleOrList; import io.substrait.extension.SimpleExtension; import io.substrait.isthmus.TypeConverter; import io.substrait.type.StringTypeVisitor; @@ -140,6 +141,13 @@ public RexNode visit(Expression.TimeLiteral expr) throws RuntimeException { return rexBuilder.makeLiteral(timeString, typeConverter.toCalcite(typeFactory, expr.getType())); } + @Override + public RexNode visit(SingleOrList expr) throws RuntimeException { + var lhs = expr.condition().accept(this); + return rexBuilder.makeIn( + lhs, expr.options().stream().map(e -> e.accept(this)).collect(Collectors.toList())); + } + @Override public RexNode visit(Expression.DateLiteral expr) throws RuntimeException { return rexBuilder.makeLiteral( diff --git a/isthmus/src/test/java/io/substrait/isthmus/ExpressionConvertabilityTest.java b/isthmus/src/test/java/io/substrait/isthmus/ExpressionConvertabilityTest.java new file mode 100644 index 00000000..55328ede --- /dev/null +++ b/isthmus/src/test/java/io/substrait/isthmus/ExpressionConvertabilityTest.java @@ -0,0 +1,72 @@ +package io.substrait.isthmus; + +import static io.substrait.isthmus.expression.CallConverters.CREATE_SEARCH_CONV; +import static org.junit.jupiter.api.Assertions.assertEquals; + +import io.substrait.dsl.SubstraitBuilder; +import io.substrait.expression.proto.ExpressionProtoConverter; +import io.substrait.extension.ExtensionCollector; +import io.substrait.extension.SimpleExtension; +import io.substrait.isthmus.expression.RexExpressionConverter; +import io.substrait.isthmus.expression.ScalarFunctionConverter; +import io.substrait.plan.Plan; +import io.substrait.relation.Rel; +import io.substrait.type.Type; +import io.substrait.type.TypeCreator; +import java.io.IOException; +import java.util.List; +import org.apache.calcite.rel.core.Filter; +import org.junit.jupiter.api.Test; + +/** Tests which test that an expression can be converted to and from Calcite expressions. */ +public class ExpressionConvertabilityTest extends PlanTestBase { + static final TypeCreator R = TypeCreator.of(false); + static final TypeCreator N = TypeCreator.of(true); + + final SubstraitBuilder b = new SubstraitBuilder(extensions); + + // Define a shared table (i.e. a NamedScan) for use in tests. + final List commonTableType = List.of(R.I32, R.FP32, N.STRING, N.BOOLEAN); + final Rel commonTable = + b.namedScan(List.of("example"), List.of("a", "b", "c", "d"), commonTableType); + + final SubstraitToCalcite converter = new SubstraitToCalcite(extensions, typeFactory); + + @Test + public void singleOrList() throws IOException { + Plan.Root root = + b.root( + b.filter( + input -> b.singleOrList(b.fieldReference(input, 0), b.i32(5), b.i32(10)), + commonTable)); + var relNode = converter.convert(root.getInput()); + var expression = + ((Filter) relNode) + .getCondition() + .accept( + new RexExpressionConverter( + CREATE_SEARCH_CONV.apply(relNode.getCluster().getRexBuilder()), + new ScalarFunctionConverter( + SimpleExtension.loadDefaults().scalarFunctions(), typeFactory))); + var to = new ExpressionProtoConverter(new ExtensionCollector(), null); + assertEquals( + expression.accept(to), + b.scalarFn( + "/functions_boolean.yaml", + "or:bool", + R.BOOLEAN, + b.scalarFn( + "/functions_comparison.yaml", + "equal:any_any", + R.BOOLEAN, + b.fieldReference(commonTable, 0), + b.i32(5)), + b.scalarFn( + "/functions_comparison.yaml", + "equal:any_any", + R.BOOLEAN, + b.fieldReference(commonTable, 0), + b.i32(10))) + .accept(to)); + } +} diff --git a/isthmus/src/test/java/io/substrait/isthmus/SimplePlansTest.java b/isthmus/src/test/java/io/substrait/isthmus/SimplePlansTest.java index 4bac2b0e..01a95c49 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/SimplePlansTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/SimplePlansTest.java @@ -24,6 +24,11 @@ public void filter() throws IOException, SqlParseException { assertProtoPlanRoundrip("select * from lineitem WHERE L_ORDERKEY > 10"); } + @Test + public void in() throws IOException, SqlParseException { + assertProtoPlanRoundrip("select * from lineitem WHERE L_ORDERKEY IN (10, 20)"); + } + @Test public void joinWithMultiDDLInOneString() throws IOException, SqlParseException { assertProtoPlanRoundrip(