Skip to content

Commit

Permalink
fix: incorrect mapping of floating point + and - ops (#131)
Browse files Browse the repository at this point in the history
* refactor: allow for custom schema creates in tests

* test: mapping of arithmetic operations to Substrait

* fix: incorrect mapping of fp + and - ops

* feat: map numeric negation

* feat: map mod

* feat: map power

* feat: map exp

* feat: map trigonometric functions

* feat: map abs

* feat: map sign

* refactor: simplify arithmetic tests
  • Loading branch information
vbarua authored Mar 2, 2023
1 parent fd7cd5f commit 963c72f
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,20 @@ public class FunctionMappings {
.add(
s(SqlStdOperatorTable.PLUS, "add"),
s(SqlStdOperatorTable.MINUS, "subtract"),
s(SqlStdOperatorTable.UNARY_MINUS, "negate"),
s(SqlStdOperatorTable.MULTIPLY, "multiply"),
s(SqlStdOperatorTable.DIVIDE, "divide"),
s(SqlStdOperatorTable.ABS, "abs"),
s(SqlStdOperatorTable.MOD, "modulus"),
s(SqlStdOperatorTable.POWER, "power"),
s(SqlStdOperatorTable.EXP, "exp"),
s(SqlStdOperatorTable.SIN, "sin"),
s(SqlStdOperatorTable.COS, "cos"),
s(SqlStdOperatorTable.TAN, "tan"),
s(SqlStdOperatorTable.ASIN, "asin"),
s(SqlStdOperatorTable.ACOS, "acos"),
s(SqlStdOperatorTable.ATAN, "atan"),
s(SqlStdOperatorTable.SIGN, "sign"),
s(SqlStdOperatorTable.AND),
s(SqlStdOperatorTable.OR),
s(SqlStdOperatorTable.NOT),
Expand Down Expand Up @@ -78,13 +90,13 @@ public class FunctionMappings {
SqlStdOperatorTable.PLUS,
resolver(
SqlStdOperatorTable.PLUS,
Set.of("i8", "i16", "i32", "i64", "f32", "f64", "dec")),
Set.of("i8", "i16", "i32", "i64", "fp32", "fp64", "dec")),
SqlStdOperatorTable.DATETIME_PLUS,
resolver(SqlStdOperatorTable.PLUS, Set.of("date", "time", "timestamp")),
SqlStdOperatorTable.MINUS,
resolver(
SqlStdOperatorTable.MINUS,
Set.of("i8", "i16", "i32", "i64", "f32", "f64", "dec")),
Set.of("i8", "i16", "i32", "i64", "fp32", "fp64", "dec")),
SqlStdOperatorTable.MINUS_DATE,
resolver(
SqlStdOperatorTable.MINUS_DATE, Set.of("date", "timestamp_tz", "timestamp")));
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
package io.substrait.isthmus;

import java.util.List;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;

public class ArithmeticFunctionTest extends PlanTestBase {

static List<String> CREATES =
List.of(
"CREATE TABLE numbers (i8 TINYINT, i16 SMALLINT, i32 INT, i64 BIGINT, fp32 FLOAT, fp64 DOUBLE)");

@ParameterizedTest
@ValueSource(strings = {"i8", "i16", "i32", "i64", "fp32", "fp64"})
void arithmetic(String c) throws Exception {
String query =
String.format(
"SELECT %s + %s, %s - %s, %s * %s, %s / %s FROM numbers", c, c, c, c, c, c, c, c);
assertSqlSubstraitRelRoundTrip(query, CREATES);
}

@ParameterizedTest
@ValueSource(strings = {"i8", "i16", "i32", "i64", "fp32", "fp64"})
void abs(String column) throws Exception {
String query = String.format("SELECT abs(%s) FROM numbers", column);
assertSqlSubstraitRelRoundTrip(query, CREATES);
}

@ParameterizedTest
@ValueSource(strings = {"fp32", "fp64"})
void exponential(String column) throws Exception {
String query = String.format("SELECT exp(%s) FROM numbers", column);
assertSqlSubstraitRelRoundTrip(query, CREATES);
}

@ParameterizedTest
@ValueSource(strings = {"i8", "i16", "i32", "i64"})
void mod(String column) throws Exception {
String query = String.format("SELECT mod(%s, %s) FROM numbers", column, column);
assertSqlSubstraitRelRoundTrip(query, CREATES);
}

@ParameterizedTest
@ValueSource(strings = {"i8", "i16", "i32", "i64", "fp32", "fp64"})
void negation(String column) throws Exception {
String query = String.format("SELECT -%s FROM numbers", column);
assertSqlSubstraitRelRoundTrip(query, CREATES);
}

@ParameterizedTest
@ValueSource(strings = {"i64", "fp32", "fp64"})
void power(String column) throws Exception {
String query = String.format("SELECT power(%s, %s) FROM numbers", column, column);
assertSqlSubstraitRelRoundTrip(query, CREATES);
}

@ParameterizedTest
@ValueSource(strings = {"sin", "cos", "tan", "asin", "acos", "atan"})
void trigonometric(String fname) throws Exception {
String query = String.format("SELECT %s(fp32), %s(fp64) FROM numbers", fname, fname);
assertSqlSubstraitRelRoundTrip(query, CREATES);
}

@ParameterizedTest
@ValueSource(strings = {"i8", "i16", "i32", "i64", "fp32", "fp64"})
void sign(String column) throws Exception {
String query = String.format("SELECT sign(%s) FROM numbers", column);
assertSqlSubstraitRelRoundTrip(query, CREATES);
}
}
27 changes: 17 additions & 10 deletions isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java
Original file line number Diff line number Diff line change
Expand Up @@ -32,17 +32,24 @@ public static String asString(String resource) throws IOException {
return Resources.toString(Resources.getResource(resource), Charsets.UTF_8);
}

public static List<String> tpchSchemaCreateStatements() throws IOException {
String[] values = asString("tpch/schema.sql").split(";");
return Arrays.stream(values)
.filter(t -> !t.trim().isBlank())
.collect(java.util.stream.Collectors.toList());
}

protected Plan assertProtoPlanRoundrip(String query) throws IOException, SqlParseException {
return assertProtoPlanRoundrip(query, new SqlToSubstrait());
}

protected Plan assertProtoPlanRoundrip(String query, SqlToSubstrait s)
throws IOException, SqlParseException {
String[] values = asString("tpch/schema.sql").split(";");
var creates =
Arrays.stream(values)
.filter(t -> !t.trim().isBlank())
.collect(java.util.stream.Collectors.toList());
return assertProtoPlanRoundrip(query, s, tpchSchemaCreateStatements());
}

protected Plan assertProtoPlanRoundrip(String query, SqlToSubstrait s, List<String> creates)
throws SqlParseException {
io.substrait.proto.Plan protoPlan1 = s.execute(query, creates);
Plan plan = new ProtoPlanConverter(EXTENSION_COLLECTION).from(protoPlan1);
io.substrait.proto.Plan protoPlan2 = new PlanProtoConverter().toProto(plan);
Expand All @@ -64,14 +71,14 @@ protected void assertPlanRoundrip(Plan plan) throws IOException, SqlParseExcepti
}

protected List<RelNode> assertSqlSubstraitRelRoundTrip(String query) throws Exception {
return assertSqlSubstraitRelRoundTrip(query, tpchSchemaCreateStatements());
}

protected List<RelNode> assertSqlSubstraitRelRoundTrip(String query, List<String> creates)
throws Exception {
// sql <--> substrait round trip test.
// Assert (sql -> substrait) and (sql -> substrait -> calcite rel -> substrait) are same.
// Return list of sql -> substrait rel -> Calcite rel.
String[] values = asString("tpch/schema.sql").split(";");
var creates =
Arrays.stream(values)
.filter(t -> !t.trim().isBlank())
.collect(java.util.stream.Collectors.toList());
List<RelNode> relNodeList = new ArrayList<>();

// 1. sql -> substrait rel
Expand Down

0 comments on commit 963c72f

Please sign in to comment.