From 963c72f8bd403c6d8d2b6f6e095c788ff9627f13 Mon Sep 17 00:00:00 2001 From: Victor Barua Date: Thu, 2 Mar 2023 14:34:34 -0800 Subject: [PATCH] fix: incorrect mapping of floating point + and - ops (#131) * refactor: allow for custom schema creates in tests * test: mapping of arithmetic operations to Substrait * fix: incorrect mapping of fp + and - ops * feat: map numeric negation * feat: map mod * feat: map power * feat: map exp * feat: map trigonometric functions * feat: map abs * feat: map sign * refactor: simplify arithmetic tests --- .../isthmus/expression/FunctionMappings.java | 16 ++++- .../isthmus/ArithmeticFunctionTest.java | 70 +++++++++++++++++++ .../io/substrait/isthmus/PlanTestBase.java | 27 ++++--- 3 files changed, 101 insertions(+), 12 deletions(-) create mode 100644 isthmus/src/test/java/io/substrait/isthmus/ArithmeticFunctionTest.java diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionMappings.java b/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionMappings.java index 75b50a44..1a32a64f 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionMappings.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionMappings.java @@ -21,8 +21,20 @@ public class FunctionMappings { .add( s(SqlStdOperatorTable.PLUS, "add"), s(SqlStdOperatorTable.MINUS, "subtract"), + s(SqlStdOperatorTable.UNARY_MINUS, "negate"), s(SqlStdOperatorTable.MULTIPLY, "multiply"), s(SqlStdOperatorTable.DIVIDE, "divide"), + s(SqlStdOperatorTable.ABS, "abs"), + s(SqlStdOperatorTable.MOD, "modulus"), + s(SqlStdOperatorTable.POWER, "power"), + s(SqlStdOperatorTable.EXP, "exp"), + s(SqlStdOperatorTable.SIN, "sin"), + s(SqlStdOperatorTable.COS, "cos"), + s(SqlStdOperatorTable.TAN, "tan"), + s(SqlStdOperatorTable.ASIN, "asin"), + s(SqlStdOperatorTable.ACOS, "acos"), + s(SqlStdOperatorTable.ATAN, "atan"), + s(SqlStdOperatorTable.SIGN, "sign"), s(SqlStdOperatorTable.AND), s(SqlStdOperatorTable.OR), s(SqlStdOperatorTable.NOT), @@ -78,13 +90,13 @@ public class FunctionMappings { SqlStdOperatorTable.PLUS, resolver( SqlStdOperatorTable.PLUS, - Set.of("i8", "i16", "i32", "i64", "f32", "f64", "dec")), + Set.of("i8", "i16", "i32", "i64", "fp32", "fp64", "dec")), SqlStdOperatorTable.DATETIME_PLUS, resolver(SqlStdOperatorTable.PLUS, Set.of("date", "time", "timestamp")), SqlStdOperatorTable.MINUS, resolver( SqlStdOperatorTable.MINUS, - Set.of("i8", "i16", "i32", "i64", "f32", "f64", "dec")), + Set.of("i8", "i16", "i32", "i64", "fp32", "fp64", "dec")), SqlStdOperatorTable.MINUS_DATE, resolver( SqlStdOperatorTable.MINUS_DATE, Set.of("date", "timestamp_tz", "timestamp"))); diff --git a/isthmus/src/test/java/io/substrait/isthmus/ArithmeticFunctionTest.java b/isthmus/src/test/java/io/substrait/isthmus/ArithmeticFunctionTest.java new file mode 100644 index 00000000..bb1cec36 --- /dev/null +++ b/isthmus/src/test/java/io/substrait/isthmus/ArithmeticFunctionTest.java @@ -0,0 +1,70 @@ +package io.substrait.isthmus; + +import java.util.List; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; + +public class ArithmeticFunctionTest extends PlanTestBase { + + static List CREATES = + List.of( + "CREATE TABLE numbers (i8 TINYINT, i16 SMALLINT, i32 INT, i64 BIGINT, fp32 FLOAT, fp64 DOUBLE)"); + + @ParameterizedTest + @ValueSource(strings = {"i8", "i16", "i32", "i64", "fp32", "fp64"}) + void arithmetic(String c) throws Exception { + String query = + String.format( + "SELECT %s + %s, %s - %s, %s * %s, %s / %s FROM numbers", c, c, c, c, c, c, c, c); + assertSqlSubstraitRelRoundTrip(query, CREATES); + } + + @ParameterizedTest + @ValueSource(strings = {"i8", "i16", "i32", "i64", "fp32", "fp64"}) + void abs(String column) throws Exception { + String query = String.format("SELECT abs(%s) FROM numbers", column); + assertSqlSubstraitRelRoundTrip(query, CREATES); + } + + @ParameterizedTest + @ValueSource(strings = {"fp32", "fp64"}) + void exponential(String column) throws Exception { + String query = String.format("SELECT exp(%s) FROM numbers", column); + assertSqlSubstraitRelRoundTrip(query, CREATES); + } + + @ParameterizedTest + @ValueSource(strings = {"i8", "i16", "i32", "i64"}) + void mod(String column) throws Exception { + String query = String.format("SELECT mod(%s, %s) FROM numbers", column, column); + assertSqlSubstraitRelRoundTrip(query, CREATES); + } + + @ParameterizedTest + @ValueSource(strings = {"i8", "i16", "i32", "i64", "fp32", "fp64"}) + void negation(String column) throws Exception { + String query = String.format("SELECT -%s FROM numbers", column); + assertSqlSubstraitRelRoundTrip(query, CREATES); + } + + @ParameterizedTest + @ValueSource(strings = {"i64", "fp32", "fp64"}) + void power(String column) throws Exception { + String query = String.format("SELECT power(%s, %s) FROM numbers", column, column); + assertSqlSubstraitRelRoundTrip(query, CREATES); + } + + @ParameterizedTest + @ValueSource(strings = {"sin", "cos", "tan", "asin", "acos", "atan"}) + void trigonometric(String fname) throws Exception { + String query = String.format("SELECT %s(fp32), %s(fp64) FROM numbers", fname, fname); + assertSqlSubstraitRelRoundTrip(query, CREATES); + } + + @ParameterizedTest + @ValueSource(strings = {"i8", "i16", "i32", "i64", "fp32", "fp64"}) + void sign(String column) throws Exception { + String query = String.format("SELECT sign(%s) FROM numbers", column); + assertSqlSubstraitRelRoundTrip(query, CREATES); + } +} diff --git a/isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java b/isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java index da79489e..2f93a8c1 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java +++ b/isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java @@ -32,17 +32,24 @@ public static String asString(String resource) throws IOException { return Resources.toString(Resources.getResource(resource), Charsets.UTF_8); } + public static List tpchSchemaCreateStatements() throws IOException { + String[] values = asString("tpch/schema.sql").split(";"); + return Arrays.stream(values) + .filter(t -> !t.trim().isBlank()) + .collect(java.util.stream.Collectors.toList()); + } + protected Plan assertProtoPlanRoundrip(String query) throws IOException, SqlParseException { return assertProtoPlanRoundrip(query, new SqlToSubstrait()); } protected Plan assertProtoPlanRoundrip(String query, SqlToSubstrait s) throws IOException, SqlParseException { - String[] values = asString("tpch/schema.sql").split(";"); - var creates = - Arrays.stream(values) - .filter(t -> !t.trim().isBlank()) - .collect(java.util.stream.Collectors.toList()); + return assertProtoPlanRoundrip(query, s, tpchSchemaCreateStatements()); + } + + protected Plan assertProtoPlanRoundrip(String query, SqlToSubstrait s, List creates) + throws SqlParseException { io.substrait.proto.Plan protoPlan1 = s.execute(query, creates); Plan plan = new ProtoPlanConverter(EXTENSION_COLLECTION).from(protoPlan1); io.substrait.proto.Plan protoPlan2 = new PlanProtoConverter().toProto(plan); @@ -64,14 +71,14 @@ protected void assertPlanRoundrip(Plan plan) throws IOException, SqlParseExcepti } protected List assertSqlSubstraitRelRoundTrip(String query) throws Exception { + return assertSqlSubstraitRelRoundTrip(query, tpchSchemaCreateStatements()); + } + + protected List assertSqlSubstraitRelRoundTrip(String query, List creates) + throws Exception { // sql <--> substrait round trip test. // Assert (sql -> substrait) and (sql -> substrait -> calcite rel -> substrait) are same. // Return list of sql -> substrait rel -> Calcite rel. - String[] values = asString("tpch/schema.sql").split(";"); - var creates = - Arrays.stream(values) - .filter(t -> !t.trim().isBlank()) - .collect(java.util.stream.Collectors.toList()); List relNodeList = new ArrayList<>(); // 1. sql -> substrait rel