Skip to content

Commit

Permalink
feat: Add SingleOrList support to the Isthmus converter (#159)
Browse files Browse the repository at this point in the history
  • Loading branch information
carlyeks authored Jul 19, 2023
1 parent 9df39ed commit 297c535
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 0 deletions.
9 changes: 9 additions & 0 deletions core/src/main/java/io/substrait/dsl/SubstraitBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import io.substrait.expression.Expression.FailureBehavior;
import io.substrait.expression.FieldReference;
import io.substrait.expression.ImmutableExpression.Cast;
import io.substrait.expression.ImmutableExpression.SingleOrList;
import io.substrait.expression.ImmutableFieldReference;
import io.substrait.extension.SimpleExtension;
import io.substrait.plan.ImmutablePlan;
Expand Down Expand Up @@ -237,6 +238,10 @@ public Expression.BoolLiteral bool(boolean v) {
return Expression.BoolLiteral.builder().value(v).build();
}

public Expression.I32Literal i32(int v) {
return Expression.I32Literal.builder().value(v).build();
}

public FieldReference fieldReference(Rel input, int index) {
return ImmutableFieldReference.newInputRelReference(index, input);
}
Expand Down Expand Up @@ -266,6 +271,10 @@ public List<Expression.SortField> sortFields(Rel input, int... indexes) {
.collect(java.util.stream.Collectors.toList());
}

public Expression singleOrList(Expression condition, Expression... options) {
return SingleOrList.builder().condition(condition).addOptions(options).build();
}

// Aggregate Functions

public AggregateFunctionInvocation aggregateFn(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package io.substrait.isthmus.expression;

import io.substrait.expression.*;
import io.substrait.expression.Expression.SingleOrList;
import io.substrait.extension.SimpleExtension;
import io.substrait.isthmus.TypeConverter;
import io.substrait.type.StringTypeVisitor;
Expand Down Expand Up @@ -140,6 +141,13 @@ public RexNode visit(Expression.TimeLiteral expr) throws RuntimeException {
return rexBuilder.makeLiteral(timeString, typeConverter.toCalcite(typeFactory, expr.getType()));
}

@Override
public RexNode visit(SingleOrList expr) throws RuntimeException {
var lhs = expr.condition().accept(this);
return rexBuilder.makeIn(
lhs, expr.options().stream().map(e -> e.accept(this)).collect(Collectors.toList()));
}

@Override
public RexNode visit(Expression.DateLiteral expr) throws RuntimeException {
return rexBuilder.makeLiteral(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
package io.substrait.isthmus;

import static io.substrait.isthmus.expression.CallConverters.CREATE_SEARCH_CONV;
import static org.junit.jupiter.api.Assertions.assertEquals;

import io.substrait.dsl.SubstraitBuilder;
import io.substrait.expression.proto.ExpressionProtoConverter;
import io.substrait.extension.ExtensionCollector;
import io.substrait.extension.SimpleExtension;
import io.substrait.isthmus.expression.RexExpressionConverter;
import io.substrait.isthmus.expression.ScalarFunctionConverter;
import io.substrait.plan.Plan;
import io.substrait.relation.Rel;
import io.substrait.type.Type;
import io.substrait.type.TypeCreator;
import java.io.IOException;
import java.util.List;
import org.apache.calcite.rel.core.Filter;
import org.junit.jupiter.api.Test;

/** Tests which test that an expression can be converted to and from Calcite expressions. */
public class ExpressionConvertabilityTest extends PlanTestBase {
static final TypeCreator R = TypeCreator.of(false);
static final TypeCreator N = TypeCreator.of(true);

final SubstraitBuilder b = new SubstraitBuilder(extensions);

// Define a shared table (i.e. a NamedScan) for use in tests.
final List<Type> commonTableType = List.of(R.I32, R.FP32, N.STRING, N.BOOLEAN);
final Rel commonTable =
b.namedScan(List.of("example"), List.of("a", "b", "c", "d"), commonTableType);

final SubstraitToCalcite converter = new SubstraitToCalcite(extensions, typeFactory);

@Test
public void singleOrList() throws IOException {
Plan.Root root =
b.root(
b.filter(
input -> b.singleOrList(b.fieldReference(input, 0), b.i32(5), b.i32(10)),
commonTable));
var relNode = converter.convert(root.getInput());
var expression =
((Filter) relNode)
.getCondition()
.accept(
new RexExpressionConverter(
CREATE_SEARCH_CONV.apply(relNode.getCluster().getRexBuilder()),
new ScalarFunctionConverter(
SimpleExtension.loadDefaults().scalarFunctions(), typeFactory)));
var to = new ExpressionProtoConverter(new ExtensionCollector(), null);
assertEquals(
expression.accept(to),
b.scalarFn(
"/functions_boolean.yaml",
"or:bool",
R.BOOLEAN,
b.scalarFn(
"/functions_comparison.yaml",
"equal:any_any",
R.BOOLEAN,
b.fieldReference(commonTable, 0),
b.i32(5)),
b.scalarFn(
"/functions_comparison.yaml",
"equal:any_any",
R.BOOLEAN,
b.fieldReference(commonTable, 0),
b.i32(10)))
.accept(to));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@ public void filter() throws IOException, SqlParseException {
assertProtoPlanRoundrip("select * from lineitem WHERE L_ORDERKEY > 10");
}

@Test
public void in() throws IOException, SqlParseException {
assertProtoPlanRoundrip("select * from lineitem WHERE L_ORDERKEY IN (10, 20)");
}

@Test
public void joinWithMultiDDLInOneString() throws IOException, SqlParseException {
assertProtoPlanRoundrip(
Expand Down

0 comments on commit 297c535

Please sign in to comment.