Skip to content

Commit

Permalink
fix: adding new ExpressionRoundtripTest for substrait to calcite (pr …
Browse files Browse the repository at this point in the history
…comments)
  • Loading branch information
carlyeks committed Jul 18, 2023
1 parent 3e2083f commit 77a97e1
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 11 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
package io.substrait.isthmus;

import static io.substrait.isthmus.expression.CallConverters.CREATE_SEARCH_CONV;
import static org.junit.jupiter.api.Assertions.assertEquals;

import io.substrait.dsl.SubstraitBuilder;
import io.substrait.expression.proto.ExpressionProtoConverter;
import io.substrait.extension.ExtensionCollector;
import io.substrait.extension.SimpleExtension;
import io.substrait.isthmus.expression.RexExpressionConverter;
import io.substrait.isthmus.expression.ScalarFunctionConverter;
import io.substrait.plan.Plan;
import io.substrait.relation.Rel;
import io.substrait.type.Type;
import io.substrait.type.TypeCreator;
import java.io.IOException;
import java.util.List;
import org.apache.calcite.rel.core.Filter;
import org.junit.jupiter.api.Test;

/** Tests which test that an expression round-trips correctly to and from Calcite expressions. */
public class ExpressionRoundtripTest extends PlanTestBase {
static final TypeCreator R = TypeCreator.of(false);
static final TypeCreator N = TypeCreator.of(true);

final SubstraitBuilder b = new SubstraitBuilder(extensions);

// Define a shared table (i.e. a NamedScan) for use in tests.
final List<Type> commonTableType = List.of(R.I32, R.FP32, N.STRING, N.BOOLEAN);
final Rel commonTable =
b.namedScan(List.of("example"), List.of("a", "b", "c", "d"), commonTableType);

final SubstraitToCalcite converter = new SubstraitToCalcite(extensions, typeFactory);

@Test
public void singleOrList() throws IOException {
Plan.Root root =
b.root(
b.filter(
input -> b.singleOrList(b.fieldReference(input, 0), b.i32(5), b.i32(10)),
commonTable));
var relNode = converter.convert(root.getInput());
var expression =
((Filter) relNode)
.getCondition()
.accept(
new RexExpressionConverter(
CREATE_SEARCH_CONV.apply(relNode.getCluster().getRexBuilder()),
new ScalarFunctionConverter(
SimpleExtension.loadDefaults().scalarFunctions(), typeFactory)));
var to = new ExpressionProtoConverter(new ExtensionCollector(), null);
assertEquals(
expression.accept(to),
b.scalarFn(
"/functions_boolean.yaml",
"or:bool",
R.BOOLEAN,
b.scalarFn(
"/functions_comparison.yaml",
"equal:any_any",
R.BOOLEAN,
b.fieldReference(commonTable, 0),
b.i32(5)),
b.scalarFn(
"/functions_comparison.yaml",
"equal:any_any",
R.BOOLEAN,
b.fieldReference(commonTable, 0),
b.i32(10)))
.accept(to));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -131,17 +131,6 @@ public void emit() {
var relNode = converter.convert(root.getInput());
assertRowMatch(relNode.getRowType(), R.I32, N.STRING);
}

@Test
public void singleOrList() {
Plan.Root root =
b.root(
b.filter(
input -> b.singleOrList(b.fieldReference(input, 0), b.i32(5), b.i32(10)),
commonTable));
var relNode = converter.convert(root.getInput());
assertRowMatch(relNode.getRowType(), commonTableType);
}
}

@Nested
Expand Down

0 comments on commit 77a97e1

Please sign in to comment.