diff --git a/gradle.properties b/gradle.properties
index 07feada4..a74242bb 100644
--- a/gradle.properties
+++ b/gradle.properties
@@ -22,6 +22,8 @@ jackson.version=2.16.1
 junit.version=5.8.1
 protobuf.version=3.25.3
 slf4j.version=2.0.13
+sparkbundle.version=3.4
+spark.version=3.4.2
 
 #version that is going to be updated automatically by releases
 version = 0.33.0
diff --git a/settings.gradle.kts b/settings.gradle.kts
index 67079364..224c6b50 100644
--- a/settings.gradle.kts
+++ b/settings.gradle.kts
@@ -1,6 +1,6 @@
 rootProject.name = "substrait"
-include("bom", "core", "isthmus", "isthmus-cli")
+include("bom", "core", "isthmus", "isthmus-cli", "spark")
 
 pluginManagement {
   plugins {
diff --git a/spark/build.gradle.kts b/spark/build.gradle.kts
new file mode 100644
index 00000000..87a540ce
--- /dev/null
+++ b/spark/build.gradle.kts
@@ -0,0 +1,113 @@
+plugins {
+  `maven-publish`
+  id("java")
+  id("scala")
+  id("idea")
+  id("com.diffplug.spotless") version "6.11.0"
+  signing
+}
+
+publishing {
+  publications {
+    create<MavenPublication>("maven-publish") {
+      from(components["java"])
+
+      pom {
+        name.set("Substrait Java")
+        description.set(
+          "Create a well-defined, cross-language specification for data compute operations"
+        )
+        url.set("https://github.com/substrait-io/substrait-java")
+        licenses {
+          license {
+            name.set("The Apache License, Version 2.0")
+            url.set("http://www.apache.org/licenses/LICENSE-2.0.txt")
+          }
+        }
+        developers {
+          developer {
+            // TBD Get the list of
+          }
+        }
+        scm {
+          connection.set("scm:git:git://github.com:substrait-io/substrait-java.git")
+          developerConnection.set("scm:git:ssh://github.com:substrait-io/substrait-java")
+          url.set("https://github.com/substrait-io/substrait-java/")
+        }
+      }
+    }
+  }
+  repositories {
+    maven {
+      name = "local"
+      val releasesRepoUrl = layout.buildDirectory.dir("repos/releases")
+      val snapshotsRepoUrl = layout.buildDirectory.dir("repos/snapshots")
+      url = uri(if (version.toString().endsWith("SNAPSHOT")) snapshotsRepoUrl else releasesRepoUrl)
+    }
+  }
+}
+
+signing {
+  setRequired({ gradle.taskGraph.hasTask("publishToSonatype") })
+  val signingKeyId =
+    System.getenv("SIGNING_KEY_ID").takeUnless { it.isNullOrEmpty() }
+      ?: extra["SIGNING_KEY_ID"].toString()
+  val signingPassword =
+    System.getenv("SIGNING_PASSWORD").takeUnless { it.isNullOrEmpty() }
+      ?: extra["SIGNING_PASSWORD"].toString()
+  val signingKey =
+    System.getenv("SIGNING_KEY").takeUnless { it.isNullOrEmpty() }
+      ?: extra["SIGNING_KEY"].toString()
+  useInMemoryPgpKeys(signingKeyId, signingKey, signingPassword)
+  sign(publishing.publications["maven-publish"])
+}
+
+configurations.all {
+  if (name.startsWith("incrementalScalaAnalysis")) {
+    setExtendsFrom(emptyList())
+  }
+}
+
+java {
+  toolchain { languageVersion.set(JavaLanguageVersion.of(17)) }
+  withJavadocJar()
+  withSourcesJar()
+}
+
+tasks.withType<ScalaCompile>() {
+  targetCompatibility = ""
+  scalaCompileOptions.additionalParameters = listOf("-release:17")
+}
+
+var SLF4J_VERSION = properties.get("slf4j.version")
+var SPARKBUNDLE_VERSION = properties.get("sparkbundle.version")
+var SPARK_VERSION = properties.get("spark.version")
+
+sourceSets {
+  main { scala { setSrcDirs(listOf("src/main/spark-${SPARKBUNDLE_VERSION}")) } }
+  test { scala { setSrcDirs(listOf("src/test/scala", "src/test/spark-3.2", "src/main/scala")) } }
+}
+
+dependencies {
+  implementation(project(":core"))
+  implementation("org.scala-lang:scala-library:2.12.16")
+  implementation("org.apache.spark:spark-core_2.12:${SPARK_VERSION}")
+
implementation("org.apache.spark:spark-sql_2.12:${SPARK_VERSION}") + implementation("org.apache.spark:spark-catalyst_2.12:${SPARK_VERSION}") + implementation("org.slf4j:slf4j-api:${SLF4J_VERSION}") + + testImplementation("org.scalatest:scalatest_2.12:3.2.18") + testRuntimeOnly("org.junit.platform:junit-platform-engine:1.10.0") + testRuntimeOnly("org.junit.platform:junit-platform-launcher:1.10.0") + testRuntimeOnly("org.scalatestplus:junit-5-10_2.12:3.2.18.0") + testImplementation("org.apache.spark:spark-core_2.12:${SPARK_VERSION}:tests") + testImplementation("org.apache.spark:spark-sql_2.12:${SPARK_VERSION}:tests") + testImplementation("org.apache.spark:spark-catalyst_2.12:${SPARK_VERSION}:tests") +} + +tasks { + test { + dependsOn(":core:shadowJar") + useJUnitPlatform { includeEngines("scalatest") } + } +} diff --git a/spark/src/main/resources/spark.yml b/spark/src/main/resources/spark.yml new file mode 100644 index 00000000..e398aa3a --- /dev/null +++ b/spark/src/main/resources/spark.yml @@ -0,0 +1,34 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +%YAML 1.2 +--- +scalar_functions: + - + name: year + description: Returns the year component of the date/timestamp + impls: + - args: + - value: date + return: i32 + - + name: unscaled + description: >- + Return the unscaled Long value of a Decimal, assuming it fits in a Long. + Note: this expression is internal and created only by the optimizer, + we don't need to do type check for it. + impls: + - args: + - value: DECIMAL + return: i64 diff --git a/spark/src/main/scala/io/substrait/debug/ExpressionToString.scala b/spark/src/main/scala/io/substrait/debug/ExpressionToString.scala new file mode 100644 index 00000000..421a040d --- /dev/null +++ b/spark/src/main/scala/io/substrait/debug/ExpressionToString.scala @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.substrait.debug + +import io.substrait.spark.DefaultExpressionVisitor + +import org.apache.spark.sql.catalyst.util.DateTimeUtils + +import io.substrait.expression.{Expression, FieldReference} +import io.substrait.expression.Expression.{DateLiteral, DecimalLiteral, I32Literal, StrLiteral} +import io.substrait.function.ToTypeString +import io.substrait.util.DecimalUtil + +import scala.collection.JavaConverters.asScalaBufferConverter + +class ExpressionToString extends DefaultExpressionVisitor[String] { + + override def visit(expr: DecimalLiteral): String = { + val value = expr.value.toByteArray + val decimal = DecimalUtil.getBigDecimalFromBytes(value, expr.scale, 16) + decimal.toString + } + + override def visit(expr: StrLiteral): String = { + expr.value() + } + + override def visit(expr: I32Literal): String = { + expr.value().toString + } + + override def visit(expr: DateLiteral): String = { + DateTimeUtils.toJavaDate(expr.value()).toString + } + + override def visit(expr: FieldReference): String = { + withFieldReference(expr)(i => "$" + i.toString) + } + + override def visit(expr: Expression.SingleOrList): String = { + expr.toString + } + + override def visit(expr: Expression.ScalarFunctionInvocation): String = { + val args = expr + .arguments() + .asScala + .zipWithIndex + .map { + case (arg, i) => + arg.accept(expr.declaration(), i, this) + } + .mkString(",") + + s"${expr.declaration().key()}[${expr.outputType().accept(ToTypeString.INSTANCE)}]($args)" + } + + override def visit(expr: Expression.UserDefinedLiteral): String = { + expr.toString + } +} diff --git a/spark/src/main/scala/io/substrait/debug/RelToVerboseString.scala b/spark/src/main/scala/io/substrait/debug/RelToVerboseString.scala new file mode 100644 index 00000000..9f4f5c9f --- /dev/null +++ b/spark/src/main/scala/io/substrait/debug/RelToVerboseString.scala @@ -0,0 +1,159 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.substrait.debug + +import io.substrait.spark.DefaultRelVisitor + +import io.substrait.relation._ + +import scala.collection.mutable + +class RelToVerboseString(addSuffix: Boolean) extends DefaultRelVisitor[String] { + + private val expressionStringConverter = new ExpressionToString + + private def stringBuilder(rel: Rel, remapLength: Int): mutable.StringBuilder = { + val nodeName = rel.getClass.getSimpleName.replaceAll("Immutable", "") + val builder: mutable.StringBuilder = new mutable.StringBuilder(s"$nodeName[") + rel.getRemap.ifPresent(remap => builder.append("remap=").append(remap)) + if (builder.length > remapLength) builder.append(", ") + builder + } + + private def withBuilder(rel: Rel, remapLength: Int)(f: mutable.StringBuilder => Unit): String = { + val builder = stringBuilder(rel, remapLength) + f(builder) + builder.append("]").toString + } + + def apply(rel: Rel, maxFields: Int): String = { + rel.accept(this) + } + + override def visit(fetch: Fetch): String = { + withBuilder(fetch, 7)( + builder => { + builder.append("offset=").append(fetch.getOffset) + fetch.getCount.ifPresent( + count => { + builder.append(", ") + builder.append("count=").append(count) + }) + }) + } + override def visit(sort: Sort): String = { + withBuilder(sort, 5)( + builder => { + builder.append("sortFields=").append(sort.getSortFields) + }) + } + + override def visit(join: Join): String = { + withBuilder(join, 5)( + builder => { + join.getCondition.ifPresent( + condition => { + builder.append("condition=").append(condition) + builder.append(", ") + }) + + join.getPostJoinFilter.ifPresent( + postJoinFilter => { + builder.append("postJoinFilter=").append(postJoinFilter) + builder.append(", ") + }) + builder.append("joinType=").append(join.getJoinType) + }) + } + + override def visit(filter: Filter): String = { + withBuilder(filter, 7)( + builder => { + builder.append(filter.getCondition.accept(expressionStringConverter)) + }) + } + + def fillReadRel(read: AbstractReadRel, builder: mutable.StringBuilder): Unit = { + builder.append("initialSchema=").append(read.getInitialSchema) + read.getFilter.ifPresent( + filter => { + builder.append(", ") + builder.append("filter=").append(filter) + }) + read.getCommonExtension.ifPresent( + commonExtension => { + builder.append(", ") + builder.append("commonExtension=").append(commonExtension) + }) + } + override def visit(namedScan: NamedScan): String = { + withBuilder(namedScan, 10)( + builder => { + fillReadRel(namedScan, builder) + builder.append(", ") + builder.append("names=").append(namedScan.getNames) + + namedScan.getExtension.ifPresent( + extension => { + builder.append(", ") + builder.append("extension=").append(extension) + }) + }) + } + + override def visit(emptyScan: EmptyScan): String = { + withBuilder(emptyScan, 10)( + builder => { + fillReadRel(emptyScan, builder) + }) + } + + override def visit(project: Project): String = { + withBuilder(project, 8)( + builder => { + builder + .append("expressions=") + .append(project.getExpressions) + }) + } + + override def visit(aggregate: Aggregate): String = { + withBuilder(aggregate, 10)( + builder => { + builder + .append("groupings=") + .append(aggregate.getGroupings) + .append(", ") + .append("measures=") + .append(aggregate.getMeasures) + }) + } + + override def visit(localFiles: LocalFiles): String = { + withBuilder(localFiles, 10)( + builder => { + builder + .append("items=") + .append(localFiles.getItems) + }) + } +} + +object RelToVerboseString { + val verboseStringWithSuffix = new 
RelToVerboseString(true) + val verboseString = new RelToVerboseString(false) +} diff --git a/spark/src/main/scala/io/substrait/debug/TreePrinter.scala b/spark/src/main/scala/io/substrait/debug/TreePrinter.scala new file mode 100644 index 00000000..cd50f412 --- /dev/null +++ b/spark/src/main/scala/io/substrait/debug/TreePrinter.scala @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.substrait.debug + +import org.apache.spark.sql.catalyst.util.StringUtils.PlanStringConcat +import org.apache.spark.sql.internal.SQLConf + +import RelToVerboseString.{verboseString, verboseStringWithSuffix} +import io.substrait.relation +import io.substrait.relation.Rel + +import scala.collection.JavaConverters.asScalaBufferConverter + +trait TreePrinter[T] { + def tree(t: T): String +} + +object TreePrinter { + + implicit object SubstraitRel extends TreePrinter[relation.Rel] { + override def tree(t: Rel): String = TreePrinter.tree(t) + } + + final def tree(rel: relation.Rel): String = treeString(rel, verbose = true) + + final def treeString( + rel: relation.Rel, + verbose: Boolean, + addSuffix: Boolean = false, + maxFields: Int = SQLConf.get.maxToStringFields, + printOperatorId: Boolean = false): String = { + val concat = new PlanStringConcat() + treeString(rel, concat.append, verbose, addSuffix, maxFields, printOperatorId) + concat.toString + } + + def treeString( + rel: relation.Rel, + append: String => Unit, + verbose: Boolean, + addSuffix: Boolean, + maxFields: Int, + printOperatorId: Boolean): Unit = { + generateTreeString(rel, 0, Nil, append, verbose, "", addSuffix, maxFields, printOperatorId) + } + + /** + * Appends the string representation of this node and its children to the given Writer. + * + * The `i`-th element in `lastChildren` indicates whether the ancestor of the current node at + * depth `i + 1` is the last child of its own parent node. The depth of the root node is 0, and + * `lastChildren` for the root node should be empty. 
+ */ + def generateTreeString( + rel: relation.Rel, + depth: Int, + lastChildren: Seq[Boolean], + append: String => Unit, + verbose: Boolean, + prefix: String = "", + addSuffix: Boolean = false, + maxFields: Int, + printNodeId: Boolean, + indent: Int = 0): Unit = { + + append(" " * indent) + if (depth > 0) { + lastChildren.init.foreach(isLast => append(if (isLast) " " else ": ")) + append(if (lastChildren.last) "+- " else ":- ") + } + + val str = if (verbose) { + if (addSuffix) verboseStringWithSuffix(rel, maxFields) else verboseString(rel, maxFields) + } else { + "" + } + append(prefix) + append(str) + append("\n") + + val children = rel.getInputs.asScala + if (children.nonEmpty) { + children.init.foreach( + generateTreeString( + _, + depth + 1, + lastChildren :+ false, + append, + verbose, + prefix, + addSuffix, + maxFields, + printNodeId = printNodeId, + indent = indent)) + + generateTreeString( + children.last, + depth + 1, + lastChildren :+ true, + append, + verbose, + prefix, + addSuffix, + maxFields, + printNodeId = printNodeId, + indent = indent) + } + } +} diff --git a/spark/src/main/scala/io/substrait/spark/DefaultExpressionVisitor.scala b/spark/src/main/scala/io/substrait/spark/DefaultExpressionVisitor.scala new file mode 100644 index 00000000..d0d2e0d0 --- /dev/null +++ b/spark/src/main/scala/io/substrait/spark/DefaultExpressionVisitor.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.substrait.spark + +import io.substrait.`type`.Type +import io.substrait.expression._ +import io.substrait.extension.SimpleExtension + +class DefaultExpressionVisitor[T] + extends AbstractExpressionVisitor[T, RuntimeException] + with FunctionArg.FuncArgVisitor[T, RuntimeException] { + + override def visitFallback(expr: Expression): T = + throw new UnsupportedOperationException( + s"Expression type ${expr.getClass.getCanonicalName} " + + s"not handled by visitor type ${getClass.getCanonicalName}.") + + override def visitType(fnDef: SimpleExtension.Function, argIdx: Int, t: Type): T = + throw new UnsupportedOperationException( + s"FunctionArg $t not handled by visitor type ${getClass.getCanonicalName}.") + + override def visitEnumArg(fnDef: SimpleExtension.Function, argIdx: Int, e: EnumArg): T = + throw new UnsupportedOperationException( + s"EnumArg(value=${e.value()}) not handled by visitor type ${getClass.getCanonicalName}.") + + protected def withFieldReference(fieldReference: FieldReference)(f: Int => T): T = { + if (fieldReference.isSimpleRootReference) { + val segment = fieldReference.segments().get(0) + segment match { + case s: FieldReference.StructField => f(s.offset()) + case _ => throw new IllegalArgumentException(s"Unhandled type: $segment") + } + } else { + visitFallback(fieldReference) + } + } + + override def visitExpr(fnDef: SimpleExtension.Function, argIdx: Int, e: Expression): T = + e.accept(this) + + override def visit(userDefinedLiteral: Expression.UserDefinedLiteral): T = { + visitFallback(userDefinedLiteral) + } +} diff --git a/spark/src/main/scala/io/substrait/spark/DefaultRelVisitor.scala b/spark/src/main/scala/io/substrait/spark/DefaultRelVisitor.scala new file mode 100644 index 00000000..7f1e181b --- /dev/null +++ b/spark/src/main/scala/io/substrait/spark/DefaultRelVisitor.scala @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.substrait.spark + +import io.substrait.relation +import io.substrait.relation.AbstractRelVisitor + +class DefaultRelVisitor[T] extends AbstractRelVisitor[T, RuntimeException] { + + override def visitFallback(rel: relation.Rel): T = + throw new UnsupportedOperationException( + s"Type ${rel.getClass.getCanonicalName}" + + s" not handled by visitor type ${getClass.getCanonicalName}.") +} diff --git a/spark/src/main/scala/io/substrait/spark/HasOutputStack.scala b/spark/src/main/scala/io/substrait/spark/HasOutputStack.scala new file mode 100644 index 00000000..3ff41699 --- /dev/null +++ b/spark/src/main/scala/io/substrait/spark/HasOutputStack.scala @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.substrait.spark + +import scala.collection.mutable + +trait HasOutputStack[T] { + private val outputStack = mutable.Stack[T]() + def currentOutput: T = outputStack.top + def pushOutput(e: T): Unit = outputStack.push(e) + def popOutput(): T = outputStack.pop() +} diff --git a/spark/src/main/scala/io/substrait/spark/SparkExtension.scala b/spark/src/main/scala/io/substrait/spark/SparkExtension.scala new file mode 100644 index 00000000..0d6d84b7 --- /dev/null +++ b/spark/src/main/scala/io/substrait/spark/SparkExtension.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.substrait.spark + +import io.substrait.spark.expression.ToAggregateFunction + +import io.substrait.extension.SimpleExtension + +import java.util.Collections + +import scala.collection.JavaConverters +import scala.collection.JavaConverters.asScalaBufferConverter + +object SparkExtension { + private val SparkImpls: SimpleExtension.ExtensionCollection = + SimpleExtension.load(Collections.singletonList("/spark.yml")) + + private val EXTENSION_COLLECTION: SimpleExtension.ExtensionCollection = + SimpleExtension.loadDefaults() + + lazy val SparkScalarFunctions: Seq[SimpleExtension.ScalarFunctionVariant] = { + val ret = new collection.mutable.ArrayBuffer[SimpleExtension.ScalarFunctionVariant]() + ret.appendAll(EXTENSION_COLLECTION.scalarFunctions().asScala) + ret.appendAll(SparkImpls.scalarFunctions().asScala) + ret + } + + val toAggregateFunction: ToAggregateFunction = ToAggregateFunction( + JavaConverters.asScalaBuffer(EXTENSION_COLLECTION.aggregateFunctions())) +} diff --git a/spark/src/main/scala/io/substrait/spark/expression/FunctionConverter.scala b/spark/src/main/scala/io/substrait/spark/expression/FunctionConverter.scala new file mode 100644 index 00000000..3205b568 --- /dev/null +++ b/spark/src/main/scala/io/substrait/spark/expression/FunctionConverter.scala @@ -0,0 +1,302 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.substrait.spark.expression + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.SQLConfHelper +import org.apache.spark.sql.catalyst.analysis.{AnsiTypeCoercion, TypeCoercion} +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression +import org.apache.spark.sql.types.DataType +import org.apache.spark.substrait.ToSubstraitType + +import com.google.common.collect.{ArrayListMultimap, Multimap} +import io.substrait.`type`.Type +import io.substrait.expression.{Expression => SExpression, ExpressionCreator, FunctionArg} +import io.substrait.expression.Expression.FailureBehavior +import io.substrait.extension.SimpleExtension +import io.substrait.function.{ParameterizedType, ToTypeString} +import io.substrait.utils.Util + +import java.{util => ju} + +import scala.annotation.tailrec +import scala.collection.JavaConverters +import scala.collection.JavaConverters.collectionAsScalaIterableConverter + +abstract class FunctionConverter[F <: SimpleExtension.Function, T](functions: Seq[F]) + extends Logging { + + protected val (signatures, substraitFuncKeyToSig) = init(functions) + + def generateBinding( + sparkExp: Expression, + function: F, + arguments: Seq[FunctionArg], + outputType: Type): T + def getSigs: Seq[Sig] + + private def init( + functions: Seq[F]): (ju.Map[Class[_], FunctionFinder[F, T]], Multimap[String, Sig]) = { + val alm = ArrayListMultimap.create[String, F]() + functions.foreach(f => alm.put(f.name().toLowerCase(ju.Locale.ROOT), f)) + + val sparkExpressions = ArrayListMultimap.create[String, Sig]() + getSigs.foreach(f => sparkExpressions.put(f.name, f)) + val matcherMap = + new ju.IdentityHashMap[Class[_], FunctionFinder[F, T]] + + JavaConverters + .asScalaSet(alm.keySet()) + .foreach( + key => { + val sigs = sparkExpressions.get(key) + if (sigs == null) { + logInfo("Dropping function due to no binding:" + key) + } else { + JavaConverters + .asScalaBuffer(sigs) + .foreach( + sig => { + val implList = alm.get(key) + if (implList != null && !implList.isEmpty) { + matcherMap + .put(sig.expClass, createFinder(key, JavaConverters.asScalaBuffer(implList))) + } + }) + } + }) + val keyMap = ArrayListMultimap.create[String, Sig] + + alm.entries.asScala.foreach( + entry => + sparkExpressions + .get(entry.getKey) + .asScala + .foreach(keyMap.put(entry.getValue.key(), _))) + + (matcherMap, keyMap) + } + + def getSparkExpressionFromSubstraitFunc(key: String, outputType: Type): Option[Sig] = { + val sigs = substraitFuncKeyToSig.get(key) + sigs.size() match { + case 0 => None + case 1 => Some(sigs.iterator().next()) + case _ => None + } + } + private def createFinder(name: String, functions: Seq[F]): FunctionFinder[F, T] = { + new FunctionFinder[F, T]( + name, + functions + .flatMap( + func => 
+ if (func.requiredArguments().size() != func.args().size()) { + Seq( + func.key() -> func, + SimpleExtension.Function.constructKey(name, func.requiredArguments()) -> func) + } else { + Seq(func.key() -> func) + }) + .toMap, + FunctionFinder.getSingularInputType(functions), + parent = this + ) + } +} + +object FunctionFinder extends SQLConfHelper { + + /** + * Returns the most general of a set of types (that is, one type to which they can all be cast), + * or [[None]] if conversion is not possible. The result may be a new type that is less + * restrictive than any of the input types, e.g. leastRestrictive(INT, NUMERIC(3, 2)) + * could be NUMERIC(12, 2). + * + * @param types + * input types to be combined using union (not null, not empty) + * @return + * canonical union type descriptor + */ + def leastRestrictive(types: Seq[DataType]): Option[DataType] = { + val typeCoercion = if (conf.ansiEnabled) { + AnsiTypeCoercion + } else { + TypeCoercion + } + typeCoercion.findWiderCommonType(types) + } + + /** + * If some of the function variants for this function name have single, repeated argument type, we + * will attempt to find matches using these patterns and least-restrictive casting. + * + *
If this exists, the function finder will attempt to find a least-restrictive match using + * these. + */ + def getSingularInputType[F <: SimpleExtension.Function]( + functions: Seq[F]): Option[SingularArgumentMatcher[F]] = { + + @tailrec + def determineFirstType( + first: ParameterizedType, + index: Int, + list: ju.List[SimpleExtension.Argument]): ParameterizedType = + if (index >= list.size()) { + first + } else { + list.get(index) match { + case argument: SimpleExtension.ValueArgument => + val pt = argument.value() + val first_or_pt = if (first == null) pt else first + if (first == null || isMatch(first, pt)) { + determineFirstType(first_or_pt, index + 1, list) + } else { + null + } + case _ => null + } + } + + val matchers = functions + .map(f => (f, determineFirstType(null, 0, f.requiredArguments()))) + .filter(_._2 != null) + .map(f => singular(f._1, f._2)) + + matchers.size match { + case 0 => None + case 1 => Some(matchers.head) + case _ => Some(chained(matchers)) + } + } + + private def isMatch( + inputType: ParameterizedType, + parameterizedType: ParameterizedType): Boolean = { + if (parameterizedType.isWildcard) { + true + } else { + inputType.accept(new IgnoreNullableAndParameters(parameterizedType)) + } + } + + private def isMatch(inputType: Type, parameterizedType: ParameterizedType): Boolean = { + if (parameterizedType.isWildcard) { + true + } else { + inputType.accept(new IgnoreNullableAndParameters(parameterizedType)) + } + } + + def singular[F <: SimpleExtension.Function]( + function: F, + t: ParameterizedType): SingularArgumentMatcher[F] = + (inputType: Type, outputType: Type) => if (isMatch(inputType, t)) Some(function) else None + + def collectFirst[F <: SimpleExtension.Function]( + matchers: Seq[SingularArgumentMatcher[F]], + inputType: Type, + outputType: Type): Option[F] = { + val iter = matchers.iterator + while (iter.hasNext) { + val s = iter.next() + val result = s.apply(inputType, outputType) + if (result.isDefined) { + return result + } + } + None + } + + def chained[F <: SimpleExtension.Function]( + matchers: Seq[SingularArgumentMatcher[F]]): SingularArgumentMatcher[F] = + (inputType: Type, outputType: Type) => collectFirst(matchers, inputType, outputType) +} + +trait SingularArgumentMatcher[F <: SimpleExtension.Function] extends ((Type, Type) => Option[F]) + +class FunctionFinder[F <: SimpleExtension.Function, T]( + val name: String, + val directMap: Map[String, F], + val singularInputType: Option[SingularArgumentMatcher[F]], + val parent: FunctionConverter[F, T]) { + + def attemptMatch(expression: Expression, operands: Seq[SExpression]): Option[T] = { + + val opTypes = operands.map(_.getType) + val outputType = ToSubstraitType.apply(expression.dataType, expression.nullable) + val opTypesStr = opTypes.map(t => t.accept(ToTypeString.INSTANCE)) + + val possibleKeys = + Util.crossProduct(opTypesStr.map(s => Seq(s))).map(list => list.mkString("_")) + + val directMatchKey = possibleKeys + .map(name + ":" + _) + .find(k => directMap.contains(k)) + + if (directMatchKey.isDefined) { + val variant = directMap(directMatchKey.get) + variant.validateOutputType(JavaConverters.bufferAsJavaList(operands.toBuffer), outputType) + val funcArgs: Seq[FunctionArg] = operands + Option(parent.generateBinding(expression, variant, funcArgs, outputType)) + } else if (singularInputType.isDefined) { + val types = expression match { + case agg: AggregateExpression => agg.aggregateFunction.children.map(_.dataType) + case other => other.children.map(_.dataType) + } + val nullable = 
expression.children.exists(e => e.nullable) + FunctionFinder + .leastRestrictive(types) + .flatMap( + leastRestrictive => { + val leastRestrictiveSubstraitT = + ToSubstraitType.apply(leastRestrictive, nullable = nullable) + singularInputType + .flatMap(f => f(leastRestrictiveSubstraitT, outputType)) + .map( + declaration => { + val coercedArgs = coerceArguments(operands, leastRestrictiveSubstraitT) + declaration.validateOutputType( + JavaConverters.bufferAsJavaList(coercedArgs.toBuffer), + outputType) + val funcArgs: Seq[FunctionArg] = coercedArgs + parent.generateBinding(expression, declaration, funcArgs, outputType) + }) + }) + } else { + None + } + } + + /** + * Coerced types according to an expected output type. Coercion is only done for type mismatches, + * not for nullability or parameter mismatches. + */ + private def coerceArguments(arguments: Seq[SExpression], t: Type): Seq[SExpression] = { + arguments.map( + a => { + if (FunctionFinder.isMatch(t, a.getType)) { + a + } else { + ExpressionCreator.cast(t, a, FailureBehavior.THROW_EXCEPTION) + } + }) + } + + def allowedArgCount(count: Int): Boolean = true +} diff --git a/spark/src/main/scala/io/substrait/spark/expression/FunctionMappings.scala b/spark/src/main/scala/io/substrait/spark/expression/FunctionMappings.scala new file mode 100644 index 00000000..08326454 --- /dev/null +++ b/spark/src/main/scala/io/substrait/spark/expression/FunctionMappings.scala @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.substrait.spark.expression + +import org.apache.spark.sql.catalyst.analysis.FunctionRegistryBase +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ + +import scala.reflect.ClassTag + +case class Sig(expClass: Class[_], name: String, builder: Seq[Expression] => Expression) { + def makeCall(args: Seq[Expression]): Expression = + builder(args) +} + +class FunctionMappings { + + private def s[T <: Expression: ClassTag](name: String): Sig = { + val builder = FunctionRegistryBase.build[T](name, None)._2 + Sig(scala.reflect.classTag[T].runtimeClass, name, builder) + } + + val SCALAR_SIGS: Seq[Sig] = Seq( + s[Add]("add"), + s[Subtract]("subtract"), + s[Multiply]("multiply"), + s[Divide]("divide"), + s[And]("and"), + s[Or]("or"), + s[Not]("not"), + s[LessThan]("lt"), + s[LessThanOrEqual]("lte"), + s[GreaterThan]("gt"), + s[GreaterThanOrEqual]("gte"), + s[EqualTo]("equal"), + // s[BitwiseXor]("xor"), + s[IsNull]("is_null"), + s[IsNotNull]("is_not_null"), + s[EndsWith]("ends_with"), + s[Like]("like"), + s[Contains]("contains"), + s[StartsWith]("starts_with"), + s[Substring]("substring"), + s[Year]("year"), + + // internal + s[UnscaledValue]("unscaled") + ) + + val AGGREGATE_SIGS: Seq[Sig] = Seq( + s[Sum]("sum"), + s[Average]("avg"), + s[Count]("count"), + s[Min]("min"), + s[Max]("max"), + s[HyperLogLogPlusPlus]("approx_count_distinct") + ) + + lazy val scalar_functions_map: Map[Class[_], Sig] = SCALAR_SIGS.map(s => (s.expClass, s)).toMap + lazy val aggregate_functions_map: Map[Class[_], Sig] = + AGGREGATE_SIGS.map(s => (s.expClass, s)).toMap +} + +object FunctionMappings extends FunctionMappings diff --git a/spark/src/main/scala/io/substrait/spark/expression/IgnoreNullableAndParameters.scala b/spark/src/main/scala/io/substrait/spark/expression/IgnoreNullableAndParameters.scala new file mode 100644 index 00000000..962a98b1 --- /dev/null +++ b/spark/src/main/scala/io/substrait/spark/expression/IgnoreNullableAndParameters.scala @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.substrait.spark.expression + +import io.substrait.`type`.Type +import io.substrait.function.{ParameterizedType, ParameterizedTypeVisitor} + +import scala.annotation.nowarn + +class IgnoreNullableAndParameters(val typeToMatch: ParameterizedType) + extends ParameterizedTypeVisitor[Boolean, RuntimeException] { + + override def visit(`type`: Type.Bool): Boolean = typeToMatch.isInstanceOf[Type.Bool] + + override def visit(`type`: Type.I8): Boolean = typeToMatch.isInstanceOf[Type.I8] + + override def visit(`type`: Type.I16): Boolean = typeToMatch.isInstanceOf[Type.I16] + + override def visit(`type`: Type.I32): Boolean = typeToMatch.isInstanceOf[Type.I32] + + override def visit(`type`: Type.I64): Boolean = typeToMatch.isInstanceOf[Type.I64] + + override def visit(`type`: Type.FP32): Boolean = typeToMatch.isInstanceOf[Type.FP32] + + override def visit(`type`: Type.FP64): Boolean = typeToMatch.isInstanceOf[Type.FP64] + + override def visit(`type`: Type.Str): Boolean = typeToMatch.isInstanceOf[Type.Str] + + override def visit(`type`: Type.Binary): Boolean = typeToMatch.isInstanceOf[Type.Binary] + + override def visit(`type`: Type.Date): Boolean = typeToMatch.isInstanceOf[Type.Date] + + override def visit(`type`: Type.Time): Boolean = typeToMatch.isInstanceOf[Type.Time] + + @nowarn + override def visit(`type`: Type.TimestampTZ): Boolean = typeToMatch.isInstanceOf[Type.TimestampTZ] + + @nowarn + override def visit(`type`: Type.Timestamp): Boolean = typeToMatch.isInstanceOf[Type.Timestamp] + + override def visit(`type`: Type.IntervalYear): Boolean = + typeToMatch.isInstanceOf[Type.IntervalYear] + + override def visit(`type`: Type.IntervalDay): Boolean = typeToMatch.isInstanceOf[Type.IntervalDay] + + override def visit(`type`: Type.UUID): Boolean = typeToMatch.isInstanceOf[Type.UUID] + + override def visit(`type`: Type.FixedChar): Boolean = + typeToMatch.isInstanceOf[Type.FixedChar] || typeToMatch + .isInstanceOf[ParameterizedType.FixedChar] + + override def visit(`type`: Type.VarChar): Boolean = + typeToMatch.isInstanceOf[Type.VarChar] || typeToMatch.isInstanceOf[ParameterizedType.VarChar] + + override def visit(`type`: Type.FixedBinary): Boolean = + typeToMatch.isInstanceOf[Type.FixedBinary] || typeToMatch + .isInstanceOf[ParameterizedType.FixedBinary] + + override def visit(`type`: Type.Decimal): Boolean = + typeToMatch.isInstanceOf[Type.Decimal] || typeToMatch.isInstanceOf[ParameterizedType.Decimal] + + override def visit(`type`: Type.Struct): Boolean = + typeToMatch.isInstanceOf[Type.Struct] || typeToMatch.isInstanceOf[ParameterizedType.Struct] + + override def visit(`type`: Type.ListType): Boolean = + typeToMatch.isInstanceOf[Type.ListType] || typeToMatch.isInstanceOf[ParameterizedType.ListType] + + override def visit(`type`: Type.Map): Boolean = + typeToMatch.isInstanceOf[Type.Map] || typeToMatch.isInstanceOf[ParameterizedType.Map] + + override def visit(`type`: Type.UserDefined): Boolean = + typeToMatch.isInstanceOf[Type.UserDefined] + + @throws[RuntimeException] + override def visit(expr: ParameterizedType.FixedChar): Boolean = + typeToMatch.isInstanceOf[Type.FixedChar] || typeToMatch + .isInstanceOf[ParameterizedType.FixedChar] + + @throws[RuntimeException] + override def visit(expr: ParameterizedType.VarChar): Boolean = + typeToMatch.isInstanceOf[Type.VarChar] || typeToMatch.isInstanceOf[ParameterizedType.VarChar] + + @throws[RuntimeException] + override def visit(expr: ParameterizedType.FixedBinary): Boolean = + typeToMatch.isInstanceOf[Type.FixedBinary] || typeToMatch + 
.isInstanceOf[ParameterizedType.FixedBinary] + + @throws[RuntimeException] + override def visit(expr: ParameterizedType.Decimal): Boolean = + typeToMatch.isInstanceOf[Type.Decimal] || typeToMatch.isInstanceOf[ParameterizedType.Decimal] + + @throws[RuntimeException] + override def visit(expr: ParameterizedType.Struct): Boolean = + typeToMatch.isInstanceOf[Type.Struct] || typeToMatch.isInstanceOf[ParameterizedType.Struct] + + @throws[RuntimeException] + override def visit(expr: ParameterizedType.ListType): Boolean = + typeToMatch.isInstanceOf[Type.ListType] || typeToMatch.isInstanceOf[ParameterizedType.ListType] + + @throws[RuntimeException] + override def visit(expr: ParameterizedType.Map): Boolean = + typeToMatch.isInstanceOf[Type.Map] || typeToMatch.isInstanceOf[ParameterizedType.Map] + + @throws[RuntimeException] + override def visit(stringLiteral: ParameterizedType.StringLiteral): Boolean = false + + @throws[RuntimeException] + override def visit(precisionTimestamp: ParameterizedType.PrecisionTimestamp): Boolean = + typeToMatch.isInstanceOf[ParameterizedType.PrecisionTimestamp] + + @throws[RuntimeException] + override def visit(precisionTimestampTZ: ParameterizedType.PrecisionTimestampTZ): Boolean = + typeToMatch.isInstanceOf[ParameterizedType.PrecisionTimestampTZ] + + @throws[RuntimeException] + override def visit(precisionTimestamp: Type.PrecisionTimestamp): Boolean = + typeToMatch.isInstanceOf[Type.PrecisionTimestamp] + + @throws[RuntimeException] + override def visit(precisionTimestampTZ: Type.PrecisionTimestampTZ): Boolean = + typeToMatch.isInstanceOf[Type.PrecisionTimestampTZ] +} diff --git a/spark/src/main/scala/io/substrait/spark/expression/ToAggregateFunction.scala b/spark/src/main/scala/io/substrait/spark/expression/ToAggregateFunction.scala new file mode 100644 index 00000000..0c5b50c6 --- /dev/null +++ b/spark/src/main/scala/io/substrait/spark/expression/ToAggregateFunction.scala @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.substrait.spark.expression + +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.aggregate._ + +import io.substrait.`type`.Type +import io.substrait.expression.{AggregateFunctionInvocation, Expression => SExpression, ExpressionCreator, FunctionArg} +import io.substrait.extension.SimpleExtension + +import java.util.Collections + +import scala.collection.JavaConverters + +abstract class ToAggregateFunction(functions: Seq[SimpleExtension.AggregateFunctionVariant]) + extends FunctionConverter[SimpleExtension.AggregateFunctionVariant, AggregateFunctionInvocation]( + functions) { + + override def generateBinding( + sparkExp: Expression, + function: SimpleExtension.AggregateFunctionVariant, + arguments: Seq[FunctionArg], + outputType: Type): AggregateFunctionInvocation = { + + val sparkAggregate = sparkExp.asInstanceOf[AggregateExpression] + + ExpressionCreator.aggregateFunction( + function, + outputType, + ToAggregateFunction.fromSpark(sparkAggregate.mode), + Collections.emptyList[SExpression.SortField](), + ToAggregateFunction.fromSpark(sparkAggregate.isDistinct), + JavaConverters.asJavaIterable(arguments) + ) + } + + def convert( + expression: AggregateExpression, + operands: Seq[SExpression]): Option[AggregateFunctionInvocation] = { + Option(signatures.get(expression.aggregateFunction.getClass)) + .filter(m => m.allowedArgCount(2)) + .flatMap(m => m.attemptMatch(expression, operands)) + } + + def apply( + expression: AggregateExpression, + operands: Seq[SExpression]): AggregateFunctionInvocation = { + convert(expression, operands).getOrElse( + throw new UnsupportedOperationException( + s"Unable to find binding for call ${expression.aggregateFunction}")) + } +} + +object ToAggregateFunction { + def fromSpark(mode: AggregateMode): SExpression.AggregationPhase = mode match { + case Partial => SExpression.AggregationPhase.INITIAL_TO_INTERMEDIATE + case PartialMerge => SExpression.AggregationPhase.INTERMEDIATE_TO_INTERMEDIATE + case Final => SExpression.AggregationPhase.INTERMEDIATE_TO_RESULT + case Complete => SExpression.AggregationPhase.INITIAL_TO_RESULT + case other => throw new UnsupportedOperationException(s"not currently supported: $other.") + } + def toSpark(phase: SExpression.AggregationPhase): AggregateMode = phase match { + case SExpression.AggregationPhase.INITIAL_TO_INTERMEDIATE => Partial + case SExpression.AggregationPhase.INTERMEDIATE_TO_INTERMEDIATE => PartialMerge + case SExpression.AggregationPhase.INTERMEDIATE_TO_RESULT => Final + case SExpression.AggregationPhase.INITIAL_TO_RESULT => Complete + } + def fromSpark(isDistinct: Boolean): SExpression.AggregationInvocation = if (isDistinct) { + SExpression.AggregationInvocation.DISTINCT + } else { + SExpression.AggregationInvocation.ALL + } + + def toSpark(innovation: SExpression.AggregationInvocation): Boolean = innovation match { + case SExpression.AggregationInvocation.DISTINCT => true + case _ => false + } + + def apply(functions: Seq[SimpleExtension.AggregateFunctionVariant]): ToAggregateFunction = { + new ToAggregateFunction(functions) { + override def getSigs: Seq[Sig] = FunctionMappings.AGGREGATE_SIGS + } + } + +} diff --git a/spark/src/main/scala/io/substrait/spark/expression/ToScalarFunction.scala b/spark/src/main/scala/io/substrait/spark/expression/ToScalarFunction.scala new file mode 100644 index 00000000..cd23611e --- /dev/null +++ b/spark/src/main/scala/io/substrait/spark/expression/ToScalarFunction.scala @@ -0,0 +1,56 @@ +/* + * Licensed to 
the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.substrait.spark.expression + +import org.apache.spark.sql.catalyst.expressions.Expression + +import io.substrait.`type`.Type +import io.substrait.expression.{Expression => SExpression, FunctionArg} +import io.substrait.extension.SimpleExtension + +import scala.collection.JavaConverters + +abstract class ToScalarFunction(functions: Seq[SimpleExtension.ScalarFunctionVariant]) + extends FunctionConverter[SimpleExtension.ScalarFunctionVariant, SExpression](functions) { + + override def generateBinding( + sparkExp: Expression, + function: SimpleExtension.ScalarFunctionVariant, + arguments: Seq[FunctionArg], + outputType: Type): SExpression = { + SExpression.ScalarFunctionInvocation + .builder() + .outputType(outputType) + .declaration(function) + .addAllArguments(JavaConverters.asJavaIterable(arguments)) + .build() + } + + def convert(expression: Expression, operands: Seq[SExpression]): Option[SExpression] = { + Option(signatures.get(expression.getClass)) + .filter(m => m.allowedArgCount(2)) + .flatMap(m => m.attemptMatch(expression, operands)) + } +} + +object ToScalarFunction { + def apply(functions: Seq[SimpleExtension.ScalarFunctionVariant]): ToScalarFunction = { + new ToScalarFunction(functions) { + override def getSigs: Seq[Sig] = FunctionMappings.SCALAR_SIGS + } + } +} diff --git a/spark/src/main/scala/io/substrait/spark/expression/ToSparkExpression.scala b/spark/src/main/scala/io/substrait/spark/expression/ToSparkExpression.scala new file mode 100644 index 00000000..430748c0 --- /dev/null +++ b/spark/src/main/scala/io/substrait/spark/expression/ToSparkExpression.scala @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.substrait.spark.expression + +import io.substrait.spark.{DefaultExpressionVisitor, HasOutputStack} +import io.substrait.spark.logical.ToLogicalPlan + +import org.apache.spark.sql.catalyst.expressions.{CaseWhen, Cast, Expression, In, Literal, NamedExpression, ScalarSubquery} +import org.apache.spark.sql.types.Decimal +import org.apache.spark.substrait.{SparkTypeUtil, ToSubstraitType} +import org.apache.spark.unsafe.types.UTF8String + +import io.substrait.`type`.{StringTypeVisitor, Type} +import io.substrait.{expression => exp} +import io.substrait.expression.{Expression => SExpression} +import io.substrait.util.DecimalUtil + +import scala.collection.JavaConverters.asScalaBufferConverter + +class ToSparkExpression( + val scalarFunctionConverter: ToScalarFunction, + val toLogicalPlan: Option[ToLogicalPlan] = None) + extends DefaultExpressionVisitor[Expression] + with HasOutputStack[Seq[NamedExpression]] { + + override def visit(expr: SExpression.BoolLiteral): Expression = { + if (expr.value()) { + Literal.TrueLiteral + } else { + Literal.FalseLiteral + } + } + override def visit(expr: SExpression.I32Literal): Expression = { + Literal(expr.value(), ToSubstraitType.convert(expr.getType)) + } + + override def visit(expr: SExpression.I64Literal): Expression = { + Literal(expr.value(), ToSubstraitType.convert(expr.getType)) + } + + override def visit(expr: SExpression.FP64Literal): Expression = { + Literal(expr.value(), ToSubstraitType.convert(expr.getType)) + } + + override def visit(expr: SExpression.StrLiteral): Expression = { + Literal(UTF8String.fromString(expr.value()), ToSubstraitType.convert(expr.getType)) + } + + override def visit(expr: SExpression.FixedCharLiteral): Expression = { + Literal(UTF8String.fromString(expr.value()), ToSubstraitType.convert(expr.getType)) + } + + override def visit(expr: SExpression.VarCharLiteral): Expression = { + Literal(UTF8String.fromString(expr.value()), ToSubstraitType.convert(expr.getType)) + } + + override def visit(expr: SExpression.DecimalLiteral): Expression = { + val value = expr.value.toByteArray + val decimal = DecimalUtil.getBigDecimalFromBytes(value, expr.scale, 16) + Literal(Decimal(decimal), ToSubstraitType.convert(expr.getType)) + } + override def visit(expr: SExpression.DateLiteral): Expression = { + Literal(expr.value(), ToSubstraitType.convert(expr.getType)) + } + + override def visit(expr: SExpression.Cast): Expression = { + val childExp = expr.input().accept(this) + Cast(childExp, ToSubstraitType.convert(expr.getType)) + } + + override def visit(expr: exp.FieldReference): Expression = { + withFieldReference(expr)(i => currentOutput(i).clone()) + } + override def visit(expr: SExpression.IfThen): Expression = { + val branches = expr + .ifClauses() + .asScala + .map( + ifClause => { + val predicate = ifClause.condition().accept(this) + val elseValue = ifClause.`then`().accept(this) + (predicate, elseValue) + }) + val default = expr.elseClause().accept(this) match { + case l: Literal if l.nullable => None + case other => Some(other) + } + CaseWhen(branches, default) + } + + override def visit(expr: SExpression.ScalarSubquery): Expression = { + val rel = expr.input() + val dataType = ToSubstraitType.convert(expr.getType) + toLogicalPlan + .map( + relConverter => { + val plan = rel.accept(relConverter) + require(plan.resolved) + val result = ScalarSubquery(plan) + SparkTypeUtil.sameType(result.dataType, dataType) + result + }) + .getOrElse(visitFallback(expr)) + } + + override def visit(expr: SExpression.SingleOrList): 
Expression = { + val value = expr.condition().accept(this) + val list = expr.options().asScala.map(e => e.accept(this)) + In(value, list) + } + override def visit(expr: SExpression.ScalarFunctionInvocation): Expression = { + val eArgs = expr.arguments().asScala + val args = eArgs.zipWithIndex.map { + case (arg, i) => + arg.accept(expr.declaration(), i, this) + } + + scalarFunctionConverter + .getSparkExpressionFromSubstraitFunc(expr.declaration().key(), expr.outputType()) + .flatMap(sig => Option(sig.makeCall(args))) + .getOrElse({ + val msg = String.format( + "Unable to convert scalar function %s(%s).", + expr.declaration.name, + expr.arguments.asScala + .map { + case ea: exp.EnumArg => ea.value.toString + case e: SExpression => e.getType.accept(new StringTypeVisitor) + case t: Type => t.accept(new StringTypeVisitor) + case a => throw new IllegalStateException("Unexpected value: " + a) + } + .mkString(", ") + ) + throw new IllegalArgumentException(msg) + }) + } +} diff --git a/spark/src/main/scala/io/substrait/spark/expression/ToSubstraitExpression.scala b/spark/src/main/scala/io/substrait/spark/expression/ToSubstraitExpression.scala new file mode 100644 index 00000000..cf05aef4 --- /dev/null +++ b/spark/src/main/scala/io/substrait/spark/expression/ToSubstraitExpression.scala @@ -0,0 +1,156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.substrait.spark.expression + +import io.substrait.spark.HasOutputStack + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.substrait.{SparkTypeUtil, ToSubstraitType} + +import io.substrait.expression.{Expression => SExpression, ExpressionCreator, FieldReference, ImmutableExpression} +import io.substrait.expression.Expression.FailureBehavior +import io.substrait.utils.Util + +import scala.collection.JavaConverters.asJavaIterableConverter + +/** The builder to generate substrait expressions from catalyst expressions. 
*/ +abstract class ToSubstraitExpression extends HasOutputStack[Seq[Attribute]] { + + object ScalarFunction { + def unapply(e: Expression): Option[Seq[Expression]] = e match { + case BinaryExpression(left, right) => Some(Seq(left, right)) + case UnaryExpression(child) => Some(Seq(child)) + case t: TernaryExpression => Some(Seq(t.first, t.second, t.third)) + case _ => None + } + } + + type OutputT = Seq[Attribute] + + protected val toScalarFunction: ToScalarFunction + + protected def default(e: Expression): Option[SExpression] = { + throw new UnsupportedOperationException(s"Unable to convert the expression $e") + } + + def apply(e: Expression, output: OutputT = Nil): SExpression = { + convert(e, output).getOrElse( + throw new UnsupportedOperationException(s"Unable to convert the expression $e") + ) + } + def convert(expr: Expression, output: OutputT = Nil): Option[SExpression] = { + pushOutput(output) + try { + translateUp(expr) + } finally { + popOutput() + } + } + + protected def translateSubQuery(expr: PlanExpression[_]): Option[SExpression] = default(expr) + + protected def translateAttribute(a: AttributeReference): Option[SExpression] = { + val bindReference = + BindReferences.bindReference[Expression](a, currentOutput, allowFailures = false) + if (bindReference == a) { + default(a) + } else { + Some( + FieldReference.newRootStructReference( + bindReference.asInstanceOf[BoundReference].ordinal, + ToSubstraitType.apply(a.dataType, a.nullable)) + ) + } + } + + protected def translateCaseWhen( + branches: Seq[(Expression, Expression)], + elseValue: Option[Expression]): Option[SExpression] = { + val cases = + for ((predicate, trueValue) <- branches) + yield translateUp(predicate).flatMap( + p => + translateUp(trueValue).map( + t => { + ImmutableExpression.IfClause.builder + .condition(p) + .`then`(t) + .build() + })) + val sparkElse = elseValue.getOrElse(Literal.create(null, branches.head._2.dataType)) + Util + .seqToOption(cases.toList) + .flatMap( + caseConditions => + translateUp(sparkElse).map( + defaultResult => { + ExpressionCreator.ifThenStatement(defaultResult, caseConditions.asJava) + })) + } + protected def translateIn(value: Expression, list: Seq[Expression]): Option[SExpression] = { + Util + .seqToOption(list.map(translateUp).toList) + .flatMap( + inList => + translateUp(value).map( + inValue => { + SExpression.SingleOrList + .builder() + .condition(inValue) + .options(inList.asJava) + .build() + })) + } + + protected def translateUp(expr: Expression): Option[SExpression] = { + expr match { + case c @ Cast(child, dataType, _, _) => + translateUp(child) + .map(ExpressionCreator + .cast(ToSubstraitType.apply(dataType, c.nullable), _, FailureBehavior.THROW_EXCEPTION)) + case c @ CheckOverflow(child, dataType, _) => + // CheckOverflow similar with cast + translateUp(child) + .map( + childExpr => { + if (SparkTypeUtil.sameType(dataType, child.dataType)) { + childExpr + } else { + ExpressionCreator.cast( + ToSubstraitType.apply(dataType, c.nullable), + childExpr, + FailureBehavior.THROW_EXCEPTION) + } + }) + case SubstraitLiteral(substraitLiteral) => Some(substraitLiteral) + case a: AttributeReference if currentOutput.nonEmpty => translateAttribute(a) + case a: Alias => translateUp(a.child) + case p + if p.getClass.getCanonicalName.equals( // removed in spark-3.3 + "org.apache.spark.sql.catalyst.expressions.PromotePrecision") => + translateUp(p.children.head) + case CaseWhen(branches, elseValue) => translateCaseWhen(branches, elseValue) + case scalar @ ScalarFunction(children) => + 
Util + .seqToOption(children.map(translateUp)) + .flatMap(toScalarFunction.convert(scalar, _)) + case In(value, list) => translateIn(value, list) + case p: PlanExpression[_] => translateSubQuery(p) + case other => default(other) + } + } +} diff --git a/spark/src/main/scala/io/substrait/spark/expression/ToSubstraitLiteral.scala b/spark/src/main/scala/io/substrait/spark/expression/ToSubstraitLiteral.scala new file mode 100644 index 00000000..d10b04f4 --- /dev/null +++ b/spark/src/main/scala/io/substrait/spark/expression/ToSubstraitLiteral.scala @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.substrait.spark.expression + +import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.types._ +import org.apache.spark.substrait.ToSubstraitType +import org.apache.spark.unsafe.types.UTF8String + +import io.substrait.expression.{Expression => SExpression} +import io.substrait.expression.ExpressionCreator._ + +class ToSubstraitLiteral { + + object Nonnull { + private def sparkDecimal2Substrait( + d: Decimal, + precision: Int, + scale: Int): SExpression.Literal = + decimal(false, d.toJavaBigDecimal, precision, scale) + + val _bool: Boolean => SExpression.Literal = bool(false, _) + val _i8: Byte => SExpression.Literal = i8(false, _) + val _i16: Short => SExpression.Literal = i16(false, _) + val _i32: Int => SExpression.Literal = i32(false, _) + val _i64: Long => SExpression.Literal = i64(false, _) + val _fp32: Float => SExpression.Literal = fp32(false, _) + val _fp64: Double => SExpression.Literal = fp64(false, _) + val _decimal: (Decimal, Int, Int) => SExpression.Literal = sparkDecimal2Substrait + val _date: Int => SExpression.Literal = date(false, _) + val _string: String => SExpression.Literal = string(false, _) + } + + private def convertWithValue(literal: Literal): Option[SExpression.Literal] = { + Option.apply( + literal match { + case Literal(b: Boolean, BooleanType) => Nonnull._bool(b) + case Literal(b: Byte, ByteType) => Nonnull._i8(b) + case Literal(s: Short, ShortType) => Nonnull._i16(s) + case Literal(i: Integer, IntegerType) => Nonnull._i32(i) + case Literal(l: Long, LongType) => Nonnull._i64(l) + case Literal(f: Float, FloatType) => Nonnull._fp32(f) + case Literal(d: Double, DoubleType) => Nonnull._fp64(d) + case Literal(d: Decimal, dataType: DecimalType) => + Nonnull._decimal(d, dataType.precision, dataType.scale) + case Literal(d: Integer, DateType) => Nonnull._date(d) + case Literal(u: UTF8String, StringType) => Nonnull._string(u.toString) + case _ => null + } + ) + } + + def convert(literal: Literal): Option[SExpression.Literal] = { + if (literal.nullable) { + ToSubstraitType + .convert(literal.dataType, nullable = true) + .map(typedNull) + } else { + convertWithValue(literal) + } + 
} + + def apply(literal: Literal): SExpression.Literal = { + convert(literal) + .getOrElse( + throw new UnsupportedOperationException( + s"Unable to convert the type ${literal.dataType.typeName}")) + } +} + +object ToSubstraitLiteral extends ToSubstraitLiteral + +object SubstraitLiteral { + def unapply(literal: Literal): Option[SExpression.Literal] = { + ToSubstraitLiteral.convert(literal) + } +} diff --git a/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala b/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala new file mode 100644 index 00000000..5d76f58b --- /dev/null +++ b/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala @@ -0,0 +1,280 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.substrait.spark.logical + +import io.substrait.spark.{DefaultRelVisitor, SparkExtension} +import io.substrait.spark.expression._ + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, UnresolvedRelation} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction} +import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, LeftAnti, LeftOuter, LeftSemi, RightOuter} +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.util.toPrettySQL +import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, InMemoryFileIndex, LogicalRelation} +import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat +import org.apache.spark.sql.types.{DataTypes, IntegerType, StructType} +import org.apache.spark.substrait.ToSubstraitType + +import io.substrait.`type`.{StringTypeVisitor, Type} +import io.substrait.{expression => exp} +import io.substrait.expression.{Expression => SExpression} +import io.substrait.plan.Plan +import io.substrait.relation +import io.substrait.relation.LocalFiles +import org.apache.hadoop.fs.Path + +import scala.collection.JavaConverters.asScalaBufferConverter +import scala.collection.mutable.ArrayBuffer + +/** + * RelVisitor to convert Substrait Rel plan to [[LogicalPlan]]. Unsupported Rel node will call + * visitFallback and throw UnsupportedOperationException. 
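 * Named scans are resolved against the given SparkSession (see resolve below), so referenced
 * tables must already exist in that session's catalog; LocalFiles reads are mapped onto a
 * HadoopFsRelation using a CSV file format.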
+ */ +class ToLogicalPlan(spark: SparkSession) extends DefaultRelVisitor[LogicalPlan] { + + private val expressionConverter = + new ToSparkExpression(ToScalarFunction(SparkExtension.SparkScalarFunctions), Some(this)) + + private def fromMeasure(measure: relation.Aggregate.Measure): AggregateExpression = { + // this functions is called in createParentwithChild + val function = measure.getFunction + var arguments = function.arguments().asScala.zipWithIndex.map { + case (arg, i) => + arg.accept(function.declaration(), i, expressionConverter) + } + if (function.declaration.name == "count" && function.arguments.size == 0) { + // HACK - count() needs to be rewritten as count(1) + arguments = ArrayBuffer(Literal(1)) + } + + val aggregateFunction = SparkExtension.toAggregateFunction + .getSparkExpressionFromSubstraitFunc(function.declaration.key, function.outputType) + .map(sig => sig.makeCall(arguments)) + .map(_.asInstanceOf[AggregateFunction]) + .getOrElse({ + val msg = String.format( + "Unable to convert Aggregate function %s(%s).", + function.declaration.name, + function.arguments.asScala + .map { + case ea: exp.EnumArg => ea.value.toString + case e: SExpression => e.getType.accept(new StringTypeVisitor) + case t: Type => t.accept(new StringTypeVisitor) + case a => throw new IllegalStateException("Unexpected value: " + a) + } + .mkString(", ") + ) + throw new IllegalArgumentException(msg) + }) + AggregateExpression( + aggregateFunction, + ToAggregateFunction.toSpark(function.aggregationPhase()), + ToAggregateFunction.toSpark(function.invocation()), + None + ) + } + + private def toNamedExpression(e: Expression): NamedExpression = e match { + case ne: NamedExpression => ne + case other => Alias(other, toPrettySQL(other))() + } + + override def visit(aggregate: relation.Aggregate): LogicalPlan = { + require(aggregate.getGroupings.size() == 1) + val child = aggregate.getInput.accept(this) + withChild(child) { + val groupBy = aggregate.getGroupings + .get(0) + .getExpressions + .asScala + .map(expr => expr.accept(expressionConverter)) + + val outputs = groupBy.map(toNamedExpression) + val aggregateExpressions = + aggregate.getMeasures.asScala.map(fromMeasure).map(toNamedExpression) + Aggregate(groupBy, outputs ++= aggregateExpressions, child) + } + } + + override def visit(join: relation.Join): LogicalPlan = { + val left = join.getLeft.accept(this) + val right = join.getRight.accept(this) + withChild(left, right) { + val condition = Option(join.getCondition.orElse(null)) + .map(_.accept(expressionConverter)) + + val joinType = join.getJoinType match { + case relation.Join.JoinType.INNER => Inner + case relation.Join.JoinType.LEFT => LeftOuter + case relation.Join.JoinType.RIGHT => RightOuter + case relation.Join.JoinType.OUTER => FullOuter + case relation.Join.JoinType.SEMI => LeftSemi + case relation.Join.JoinType.ANTI => LeftAnti + case relation.Join.JoinType.UNKNOWN => + throw new UnsupportedOperationException("Unknown join type is not supported") + } + Join(left, right, joinType, condition, hint = JoinHint.NONE) + } + } + + override def visit(join: relation.Cross): LogicalPlan = { + val left = join.getLeft.accept(this) + val right = join.getRight.accept(this) + withChild(left, right) { + // TODO: Support different join types here when join types are added to cross rel for BNLJ + // Currently, this will change both cross and inner join types to inner join + Join(left, right, Inner, Option(null), hint = JoinHint.NONE) + } + } + + private def toSortOrder(sortField: SExpression.SortField): 
SortOrder = { + val expression = sortField.expr().accept(expressionConverter) + val (direction, nullOrdering) = sortField.direction() match { + case SExpression.SortDirection.ASC_NULLS_FIRST => (Ascending, NullsFirst) + case SExpression.SortDirection.DESC_NULLS_FIRST => (Descending, NullsFirst) + case SExpression.SortDirection.ASC_NULLS_LAST => (Ascending, NullsLast) + case SExpression.SortDirection.DESC_NULLS_LAST => (Descending, NullsLast) + case other => + throw new RuntimeException(s"Unexpected Expression.SortDirection enum: $other !") + } + SortOrder(expression, direction, nullOrdering, Seq.empty) + } + override def visit(fetch: relation.Fetch): LogicalPlan = { + val child = fetch.getInput.accept(this) + val limit = Literal(fetch.getCount.getAsLong.intValue(), IntegerType) + fetch.getOffset match { + case 1L => GlobalLimit(limitExpr = limit, child = child) + case -1L => LocalLimit(limitExpr = limit, child = child) + case _ => visitFallback(fetch) + } + } + override def visit(sort: relation.Sort): LogicalPlan = { + val child = sort.getInput.accept(this) + withChild(child) { + val sortOrders = sort.getSortFields.asScala.map(toSortOrder) + Sort(sortOrders, global = true, child) + } + } + + override def visit(project: relation.Project): LogicalPlan = { + val child = project.getInput.accept(this) + val (output, createProject) = child match { + case a: Aggregate => (a.aggregateExpressions, false) + case other => (other.output, true) + } + + withOutput(output) { + val projectList = + project.getExpressions.asScala + .map(expr => expr.accept(expressionConverter)) + .map(toNamedExpression) + if (createProject) { + Project(projectList, child) + } else { + val aggregate: Aggregate = child.asInstanceOf[Aggregate] + aggregate.copy(aggregateExpressions = projectList) + } + } + } + + override def visit(filter: relation.Filter): LogicalPlan = { + val child = filter.getInput.accept(this) + withChild(child) { + val condition = filter.getCondition.accept(expressionConverter) + Filter(condition, child) + } + } + + override def visit(emptyScan: relation.EmptyScan): LogicalPlan = { + LocalRelation(ToSubstraitType.toAttribute(emptyScan.getInitialSchema)) + } + override def visit(namedScan: relation.NamedScan): LogicalPlan = { + resolve(UnresolvedRelation(namedScan.getNames.asScala)) match { + case m: MultiInstanceRelation => m.newInstance() + case other => other + } + } + + override def visit(localFiles: LocalFiles): LogicalPlan = { + val schema = ToSubstraitType.toStructType(localFiles.getInitialSchema) + val output = schema.map(f => AttributeReference(f.name, f.dataType, f.nullable, f.metadata)()) + new LogicalRelation( + relation = HadoopFsRelation( + location = new InMemoryFileIndex( + spark, + localFiles.getItems.asScala.map(i => new Path(i.getPath.get())), + Map(), + Some(schema)), + partitionSchema = new StructType(), + dataSchema = schema, + bucketSpec = None, + fileFormat = new CSVFileFormat(), + options = Map() + )(spark), + output = output, + catalogTable = None, + isStreaming = false + ) + } + + private def withChild(child: LogicalPlan*)(body: => LogicalPlan): LogicalPlan = { + val output = child.flatMap(_.output) + withOutput(output)(body) + } + + private def withOutput(output: Seq[NamedExpression])(body: => LogicalPlan): LogicalPlan = { + expressionConverter.pushOutput(output) + try { + body + } finally { + expressionConverter.popOutput() + } + } + private def resolve(plan: LogicalPlan): LogicalPlan = { + val qe = new QueryExecution(spark, plan) + qe.analyzed match { + case SubqueryAlias(_, 
child) => child + case other => other + } + } + + def convert(plan: Plan): LogicalPlan = { + val root = plan.getRoots.get(0) + val names = root.getNames.asScala + val output = names.map(name => AttributeReference(name, DataTypes.StringType)()) + withOutput(output) { + val logicalPlan = root.getInput.accept(this); + val projectList: List[NamedExpression] = logicalPlan.output.zipWithIndex + .map( + z => { + val (e, i) = z; + if (e.name == names(i)) { + e + } else { + Alias(e, names(i))() + } + }) + .toList + val wrapper = Project(projectList, logicalPlan) + require(wrapper.resolved) + wrapper + } + } +} diff --git a/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala b/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala new file mode 100644 index 00000000..1d723cb9 --- /dev/null +++ b/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala @@ -0,0 +1,406 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.substrait.spark.logical + +import io.substrait.spark.SparkExtension +import io.substrait.spark.expression._ + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.catalog.HiveTableRelation +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation} +import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2ScanRelation} +import org.apache.spark.sql.types.StructType +import org.apache.spark.substrait.ToSubstraitType +import org.apache.spark.substrait.ToSubstraitType.toNamedStruct + +import io.substrait.{proto, relation} +import io.substrait.debug.TreePrinter +import io.substrait.expression.{Expression => SExpression, ExpressionCreator} +import io.substrait.extension.ExtensionCollector +import io.substrait.plan.{ImmutablePlan, ImmutableRoot, Plan} +import io.substrait.relation.RelProtoConverter +import io.substrait.relation.files.{FileFormat, ImmutableFileOrFiles} +import io.substrait.relation.files.FileOrFiles.PathType + +import java.util.Collections + +import scala.collection.JavaConverters.asJavaIterableConverter +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + +class ToSubstraitRel extends AbstractLogicalPlanVisitor with Logging { + + private val toSubstraitExp = new WithLogicalSubQuery(this) + + private val TRUE = ExpressionCreator.bool(false, true) + + override def default(p: LogicalPlan): relation.Rel = p match { + case p: LeafNode => convertReadOperator(p) + case s: SubqueryAlias => visit(s.child) + case other => t(other) + } + + private 
def fromGroupSet( + e: Seq[Expression], + output: Seq[Attribute]): relation.Aggregate.Grouping = { + + relation.Aggregate.Grouping.builder + .addAllExpressions(e.map(toExpression(output)).asJava) + .build() + } + + private def fromAggCall( + expression: AggregateExpression, + output: Seq[Attribute]): relation.Aggregate.Measure = { + val substraitExps = expression.aggregateFunction.children.map(toExpression(output)) + val invocation = + SparkExtension.toAggregateFunction.apply(expression, substraitExps) + relation.Aggregate.Measure.builder.function(invocation).build() + } + + private def collectAggregates( + resultExpressions: Seq[NamedExpression], + aggExprToOutputOrdinal: mutable.HashMap[Expression, Int]): Seq[AggregateExpression] = { + var ordinal = 0 + resultExpressions.flatMap { + expr => + expr.collect { + // Do not push down duplicated aggregate expressions. For example, + // `SELECT max(a) + 1, max(a) + 2 FROM ...`, we should only push down one + // `max(a)` to the data source. + case agg: AggregateExpression if !aggExprToOutputOrdinal.contains(agg.canonicalized) => + aggExprToOutputOrdinal(agg.canonicalized) = ordinal + ordinal += 1 + agg + } + } + } + + private def translateAggregation( + groupBy: Seq[Expression], + aggregates: Seq[AggregateExpression], + output: Seq[Attribute], + input: relation.Rel): relation.Aggregate = { + val groupings = Collections.singletonList(fromGroupSet(groupBy, output)) + val aggCalls = aggregates.map(fromAggCall(_, output)).asJava + + relation.Aggregate.builder + .input(input) + .addAllGroupings(groupings) + .addAllMeasures(aggCalls) + .build + } + + /** + * The current substrait [[relation.Aggregate]] can't specify output, but spark [[Aggregate]] + * allow. So To support #1 select max(b) from table group by a, and #2 select + * a, max(b) + 1 from table group by a, We need create [[Project]] on top of [[Aggregate]] + * to correctly support it. 
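 * For example, case #2 (select a, max(b) + 1 from table group by a) is emitted roughly as an
 * Aggregate with grouping [a] and measure [max(b)], whose columns are exposed through the
 * synthetic attributes group_col_0 and agg_func_0, with a Project on top computing
 * [group_col_0, agg_func_0 + 1] (see the remap/offset handling below).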
+ * + * TODO: support [[Rollup]] and [[GroupingSets]] + */ + override def visitAggregate(agg: Aggregate): relation.Rel = { + val input = visit(agg.child) + val actualResultExprs = agg.aggregateExpressions + val actualGroupExprs = agg.groupingExpressions + + val aggExprToOutputOrdinal = mutable.HashMap.empty[Expression, Int] + val aggregates = collectAggregates(actualResultExprs, aggExprToOutputOrdinal) + val aggOutputMap = aggregates.zipWithIndex.map { + case (e, i) => + AttributeReference(s"agg_func_$i", e.dataType)() -> e + } + val aggOutput = aggOutputMap.map(_._1) + + // collect group by + val groupByExprToOutputOrdinal = mutable.HashMap.empty[Expression, Int] + actualGroupExprs.zipWithIndex.foreach { + case (expr, ordinal) => + if (!groupByExprToOutputOrdinal.contains(expr.canonicalized)) { + groupByExprToOutputOrdinal(expr.canonicalized) = ordinal + } + } + val groupOutputMap = actualGroupExprs.zipWithIndex.map { + case (e, i) => + AttributeReference(s"group_col_$i", e.dataType)() -> e + } + val groupOutput = groupOutputMap.map(_._1) + + val substraitAgg = translateAggregation(actualGroupExprs, aggregates, agg.child.output, input) + val newOutput = groupOutput ++ aggOutput + + val projectExpressions = actualResultExprs.map { + expr => + expr.transformDown { + case agg: AggregateExpression => + val ordinal = aggExprToOutputOrdinal(agg.canonicalized) + aggOutput(ordinal) + case expr if groupByExprToOutputOrdinal.contains(expr.canonicalized) => + val ordinal = groupByExprToOutputOrdinal(expr.canonicalized) + groupOutput(ordinal) + } + } + val projects = projectExpressions.map(toExpression(newOutput)) + + relation.Project.builder + .remap(relation.Rel.Remap.offset(newOutput.size, projects.size)) + .expressions(projects.asJava) + .input(substraitAgg) + .build() + } + + private def asLong(e: Expression): Long = e match { + case IntegerLiteral(limit) => limit + case other => throw new UnsupportedOperationException(s"Unknown type: $other") + } + + private def fetchBuilder(limit: Long, global: Boolean): relation.ImmutableFetch.Builder = { + val offset = if (global) 1L else -1L + relation.Fetch + .builder() + .count(limit) + .offset(offset) + } + override def visitGlobalLimit(p: GlobalLimit): relation.Rel = { + fetchBuilder(asLong(p.limitExpr), global = true) + .input(visit(p.child)) + .build() + } + + override def visitLocalLimit(p: LocalLimit): relation.Rel = { + fetchBuilder(asLong(p.limitExpr), global = false) + .input(visit(p.child)) + .build() + } + + override def visitFilter(p: Filter): relation.Rel = { + val condition = toExpression(p.child.output)(p.condition) + relation.Filter.builder().condition(condition).input(visit(p.child)).build() + } + + private def toSubstraitJoin(joinType: JoinType): relation.Join.JoinType = joinType match { + case Inner | Cross => relation.Join.JoinType.INNER + case LeftOuter => relation.Join.JoinType.LEFT + case RightOuter => relation.Join.JoinType.RIGHT + case FullOuter => relation.Join.JoinType.OUTER + case LeftSemi => relation.Join.JoinType.SEMI + case LeftAnti => relation.Join.JoinType.ANTI + case other => throw new UnsupportedOperationException(s"Unsupported join type $other") + } + + override def visitJoin(p: Join): relation.Rel = { + val left = visit(p.left) + val right = visit(p.right) + val condition = p.condition.map(toExpression(p.left.output ++ p.right.output)).getOrElse(TRUE) + val joinType = toSubstraitJoin(p.joinType) + + if (joinType == relation.Join.JoinType.INNER && TRUE == condition) { + relation.Cross.builder + .left(left) + 
.right(right) + .build + } else { + relation.Join.builder + .condition(condition) + .joinType(joinType) + .left(left) + .right(right) + .build + } + } + + override def visitProject(p: Project): relation.Rel = { + val expressions = p.projectList.map(toExpression(p.child.output)).toList + + relation.Project.builder + .remap(relation.Rel.Remap.offset(p.child.output.size, expressions.size)) + .expressions(expressions.asJava) + .input(visit(p.child)) + .build() + } + + private def toSortField(output: Seq[Attribute] = Nil)(order: SortOrder): SExpression.SortField = { + val direction = (order.direction, order.nullOrdering) match { + case (Ascending, NullsFirst) => SExpression.SortDirection.ASC_NULLS_FIRST + case (Descending, NullsFirst) => SExpression.SortDirection.DESC_NULLS_FIRST + case (Ascending, NullsLast) => SExpression.SortDirection.ASC_NULLS_LAST + case (Descending, NullsLast) => SExpression.SortDirection.DESC_NULLS_LAST + } + val expr = toExpression(output)(order.child) + SExpression.SortField.builder().expr(expr).direction(direction).build() + } + override def visitSort(sort: Sort): relation.Rel = { + val input = visit(sort.child) + val fields = sort.order.map(toSortField(sort.child.output)).asJava + relation.Sort.builder.addAllSortFields(fields).input(input).build + } + + private def toExpression(output: Seq[Attribute])(e: Expression): SExpression = { + toSubstraitExp(e, output) + } + + private def buildNamedScan(schema: StructType, tableNames: List[String]): relation.NamedScan = { + val namedStruct = toNamedStruct(schema) + + val namedScan = relation.NamedScan.builder + .initialSchema(namedStruct) + .addAllNames(tableNames.asJava) + .build + namedScan + } + private def buildVirtualTableScan(localRelation: LocalRelation): relation.AbstractReadRel = { + val namedStruct = toNamedStruct(localRelation.schema) + + if (localRelation.data.isEmpty) { + relation.EmptyScan.builder().initialSchema(namedStruct).build() + } else { + relation.VirtualTableScan + .builder() + .addAllDfsNames(namedStruct.names()) + .addAllRows( + localRelation.data + .map( + row => { + var idx = 0 + val buf = new ArrayBuffer[SExpression.Literal](row.numFields) + while (idx < row.numFields) { + val l = Literal.apply(row.get(idx, localRelation.schema(idx).dataType)) + buf += ToSubstraitLiteral.apply(l) + idx += 1 + } + ExpressionCreator.struct(false, buf.asJava) + }) + .asJava) + .build() + } + } + + private def buildLocalFileScan(fsRelation: HadoopFsRelation): relation.AbstractReadRel = { + val namedStruct = toNamedStruct(fsRelation.schema) + + val ff = new FileFormat.ParquetReadOptions { + override def toString: String = "csv" // TODO this is hardcoded at the moment + } + + relation.LocalFiles + .builder() + .initialSchema(namedStruct) + .addAllItems( + fsRelation.location.inputFiles + .map( + file => { + ImmutableFileOrFiles + .builder() + .fileFormat(ff) + .partitionIndex(0) + .start(0) + .length(fsRelation.sizeInBytes) + .path(file) + .pathType(PathType.URI_FILE) + .build() + }) + .toList + .asJava + ) + .build() + } + + /** Read Operator: https://substrait.io/relations/logical_relations/#read-operator */ + private def convertReadOperator(plan: LeafNode): relation.AbstractReadRel = { + var tableNames: List[String] = null + plan match { + case logicalRelation: LogicalRelation if logicalRelation.catalogTable.isDefined => + tableNames = logicalRelation.catalogTable.get.identifier.unquotedString.split("\\.").toList + buildNamedScan(logicalRelation.schema, tableNames) + case dataSourceV2ScanRelation: 
DataSourceV2ScanRelation => + tableNames = dataSourceV2ScanRelation.relation.identifier.get.toString.split("\\.").toList + buildNamedScan(dataSourceV2ScanRelation.schema, tableNames) + case dataSourceV2Relation: DataSourceV2Relation => + tableNames = dataSourceV2Relation.identifier.get.toString.split("\\.").toList + buildNamedScan(dataSourceV2Relation.schema, tableNames) + case hiveTableRelation: HiveTableRelation => + tableNames = hiveTableRelation.tableMeta.identifier.unquotedString.split("\\.").toList + buildNamedScan(hiveTableRelation.schema, tableNames) + case localRelation: LocalRelation => buildVirtualTableScan(localRelation) + case logicalRelation: LogicalRelation => + logicalRelation.relation match { + case fsRelation: HadoopFsRelation => + buildLocalFileScan(fsRelation) + case _ => + throw new UnsupportedOperationException( + s"******* Unable to convert the plan to a substrait relation: " + + s"${logicalRelation.relation.toString}") + } + case _ => + throw new UnsupportedOperationException( + s"******* Unable to convert the plan to a substrait NamedScan: $plan") + } + } + def convert(p: LogicalPlan): Plan = { + val rel = visit(p) + ImmutablePlan.builder + .roots( + Collections.singletonList( + ImmutableRoot.builder().input(rel).addAllNames(p.output.map(_.name).asJava).build() + )) + .build() + } + + def tree(p: LogicalPlan): String = { + TreePrinter.tree(visit(p)) + } + + def toProtoSubstrait(p: LogicalPlan): Array[Byte] = { + val substraitRel = visit(p) + + val extensionCollector = new ExtensionCollector + val relProtoConverter = new RelProtoConverter(extensionCollector) + val builder = proto.Plan + .newBuilder() + .addRelations( + proto.PlanRel + .newBuilder() + .setRel(substraitRel + .accept(relProtoConverter)) + ) + extensionCollector.addExtensionsToPlan(builder) + builder.build().toByteArray + } +} + +private[logical] class WithLogicalSubQuery(toSubstraitRel: ToSubstraitRel) + extends ToSubstraitExpression { + override protected val toScalarFunction: ToScalarFunction = + ToScalarFunction(SparkExtension.SparkScalarFunctions) + + override protected def translateSubQuery(expr: PlanExpression[_]): Option[SExpression] = { + expr match { + case s: ScalarSubquery if s.outerAttrs.isEmpty && s.joinCond.isEmpty => + val rel = toSubstraitRel.visit(s.plan) + Some( + SExpression.ScalarSubquery.builder + .input(rel) + .`type`(ToSubstraitType.apply(s.dataType, s.nullable)) + .build()) + case other => default(other) + } + } +} diff --git a/spark/src/main/scala/io/substrait/utils/Util.scala b/spark/src/main/scala/io/substrait/utils/Util.scala new file mode 100644 index 00000000..165d5995 --- /dev/null +++ b/spark/src/main/scala/io/substrait/utils/Util.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.substrait.utils + +import scala.annotation.tailrec +import scala.collection.mutable.ArrayBuffer + +object Util { + + /** + * Compute the cartesian product for n lists. + * + *
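 * For example, crossProduct(Seq(Seq(1, 2), Seq(3, 4))) yields
 * Seq(Seq(1, 3), Seq(1, 4), Seq(2, 3), Seq(2, 4)).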

Based on Soln by + * Thomas Preissler + */ + def crossProduct[T](lists: Seq[Seq[T]]): Seq[Seq[T]] = { + + /** list [a, b], element 1 => list + element => [a, b, 1] */ + val appendElementToList: (Seq[T], T) => Seq[T] = + (list, element) => list :+ element + + /** ([a, b], [1, 2]) ==> [a, b, 1], [a, b, 2] */ + val appendAndGen: (Seq[T], Seq[T]) => Seq[Seq[T]] = + (list, elemsToAppend) => elemsToAppend.map(e => appendElementToList(list, e)) + + val firstListToJoin = lists.head + val startProduct = appendAndGen(new ArrayBuffer[T], firstListToJoin) + + /** ([ [a, b], [c, d] ], [1, 2]) -> [a, b, 1], [a, b, 2], [c, d, 1], [c, d, 2] */ + val appendAndGenLists: (Seq[Seq[T]], Seq[T]) => Seq[Seq[T]] = + (products, toJoin) => products.flatMap(product => appendAndGen(product, toJoin)) + lists.tail.foldLeft(startProduct)(appendAndGenLists) + } + + def seqToOption[T](s: Seq[Option[T]]): Option[Seq[T]] = { + @tailrec + def seqToOptionHelper(s: Seq[Option[T]], accum: Seq[T] = Seq[T]()): Option[Seq[T]] = { + s match { + case Some(head) :: Nil => + Option(accum :+ head) + case Some(head) :: tail => + seqToOptionHelper(tail, accum :+ head) + case _ => None + } + } + seqToOptionHelper(s) + } + +} diff --git a/spark/src/main/scala/org/apache/spark/substrait/SparkTypeUtil.scala b/spark/src/main/scala/org/apache/spark/substrait/SparkTypeUtil.scala new file mode 100644 index 00000000..7af79644 --- /dev/null +++ b/spark/src/main/scala/org/apache/spark/substrait/SparkTypeUtil.scala @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.substrait + +import org.apache.spark.sql.types.DataType + +object SparkTypeUtil { + + def sameType(left: DataType, right: DataType): Boolean = { + left.sameType(right) + } + +} diff --git a/spark/src/main/scala/org/apache/spark/substrait/ToSubstraitType.scala b/spark/src/main/scala/org/apache/spark/substrait/ToSubstraitType.scala new file mode 100644 index 00000000..1ad0040f --- /dev/null +++ b/spark/src/main/scala/org/apache/spark/substrait/ToSubstraitType.scala @@ -0,0 +1,141 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.substrait + +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} +import org.apache.spark.sql.types._ + +import io.substrait.`type`.{NamedStruct, Type, TypeVisitor} +import io.substrait.function.TypeExpression +import io.substrait.utils.Util + +import scala.collection.JavaConverters +import scala.collection.JavaConverters.asScalaBufferConverter + +private class ToSparkType + extends TypeVisitor.TypeThrowsVisitor[DataType, RuntimeException]("Unknown expression type.") { + + override def visit(expr: Type.I32): DataType = IntegerType + override def visit(expr: Type.I64): DataType = LongType + + override def visit(expr: Type.FP32): DataType = FloatType + override def visit(expr: Type.FP64): DataType = DoubleType + + override def visit(expr: Type.Decimal): DataType = + DecimalType(expr.precision(), expr.scale()) + + override def visit(expr: Type.Date): DataType = DateType + + override def visit(expr: Type.Str): DataType = StringType + + override def visit(expr: Type.FixedChar): DataType = StringType + + override def visit(expr: Type.VarChar): DataType = StringType +} +class ToSubstraitType { + + def convert(typeExpression: TypeExpression): DataType = { + typeExpression.accept(new ToSparkType) + } + + def convert(dataType: DataType, nullable: Boolean): Option[Type] = { + convert(dataType, Seq.empty, nullable) + } + + def apply(dataType: DataType, nullable: Boolean): Type = { + convert(dataType, Seq.empty, nullable) + .getOrElse( + throw new UnsupportedOperationException(s"Unable to convert the type ${dataType.typeName}")) + } + + protected def convert(dataType: DataType, names: Seq[String], nullable: Boolean): Option[Type] = { + val creator = Type.withNullability(nullable) + dataType match { + case BooleanType => Some(creator.BOOLEAN) + case ByteType => Some(creator.I8) + case ShortType => Some(creator.I16) + case IntegerType => Some(creator.I32) + case LongType => Some(creator.I64) + case FloatType => Some(creator.FP32) + case DoubleType => Some(creator.FP64) + case decimal: DecimalType if decimal.precision <= 38 => + Some(creator.decimal(decimal.precision, decimal.scale)) + case charType: CharType => Some(creator.fixedChar(charType.length)) + case varcharType: VarcharType => Some(creator.varChar(varcharType.length)) + case StringType => Some(creator.STRING) + case DateType => Some(creator.DATE) + case TimestampType => Some(creator.TIMESTAMP) + case TimestampNTZType => Some(creator.TIMESTAMP_TZ) + case BinaryType => Some(creator.BINARY) + case ArrayType(elementType, containsNull) => + convert(elementType, Seq.empty, containsNull).map(creator.list) + case MapType(keyType, valueType, valueContainsNull) => + convert(keyType, Seq.empty, nullable = false) + .flatMap( + keyT => + convert(valueType, Seq.empty, valueContainsNull) + .map(valueT => creator.map(keyT, valueT))) + case _ => + None + } + } + def toNamedStruct(output: Seq[Attribute]): Option[NamedStruct] = { + val names = JavaConverters.seqAsJavaList(output.map(_.name)) + val creator = Type.withNullability(false) + Util + .seqToOption(output.map(a => convert(a.dataType, a.nullable))) + .map(l => creator.struct(JavaConverters.asJavaIterable(l))) + .map(NamedStruct.of(names, _)) + } + def toNamedStruct(schema: StructType): NamedStruct = { + val creator = Type.withNullability(false) + val names = new java.util.ArrayList[String] + val children = new java.util.ArrayList[Type] + 
schema.fields.foreach( + field => { + names.add(field.name) + children.add(apply(field.dataType, field.nullable)) + }) + val struct = creator.struct(children) + NamedStruct.of(names, struct) + } + + def toStructType(namedStruct: NamedStruct): StructType = { + StructType( + fields = namedStruct + .struct() + .fields() + .asScala + .map(t => (t, convert(t))) + .zip(namedStruct.names().asScala) + .map { case ((t, d), name) => StructField(name, d, t.nullable()) } + ) + } + + def toAttribute(namedStruct: NamedStruct): Seq[AttributeReference] = { + namedStruct + .struct() + .fields() + .asScala + .map(t => (t, convert(t))) + .zip(namedStruct.names().asScala) + .map { case ((t, d), name) => StructField(name, d, t.nullable()) } + .map(f => AttributeReference(f.name, f.dataType, f.nullable, f.metadata)()) + } +} + +object ToSubstraitType extends ToSubstraitType diff --git a/spark/src/main/spark-3.2/io/substrait/spark/logical/AbstractLogicalPlanVisitor.scala b/spark/src/main/spark-3.2/io/substrait/spark/logical/AbstractLogicalPlanVisitor.scala new file mode 100644 index 00000000..836a087f --- /dev/null +++ b/spark/src/main/spark-3.2/io/substrait/spark/logical/AbstractLogicalPlanVisitor.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.substrait.spark.logical + +import org.apache.spark.sql.catalyst.plans.logical._ + +import io.substrait.relation +import io.substrait.relation.Rel + +class AbstractLogicalPlanVisitor extends LogicalPlanVisitor[relation.Rel] { + + protected def t(p: LogicalPlan): relation.Rel = + throw new UnsupportedOperationException(s"Unable to convert the LogicalPlan ${p.nodeName}") + + override def visitDistinct(p: Distinct): relation.Rel = t(p) + + override def visitExcept(p: Except): relation.Rel = t(p) + + override def visitExpand(p: Expand): relation.Rel = t(p) + + override def visitRepartition(p: Repartition): relation.Rel = t(p) + + override def visitRepartitionByExpr(p: RepartitionByExpression): relation.Rel = t(p) + + override def visitSample(p: Sample): relation.Rel = t(p) + + override def visitScriptTransform(p: ScriptTransformation): relation.Rel = t(p) + + override def visitUnion(p: Union): relation.Rel = t(p) + + override def visitWindow(p: Window): relation.Rel = t(p) + + override def visitTail(p: Tail): relation.Rel = t(p) + + override def visitGenerate(p: Generate): relation.Rel = t(p) + + override def visitGlobalLimit(p: GlobalLimit): relation.Rel = t(p) + + override def visitIntersect(p: Intersect): relation.Rel = t(p) + + override def visitLocalLimit(p: LocalLimit): relation.Rel = t(p) + + override def visitPivot(p: Pivot): relation.Rel = t(p) + + override def default(p: LogicalPlan): Rel = t(p) + + override def visitAggregate(p: Aggregate): Rel = t(p) + + override def visitFilter(p: Filter): Rel = t(p) + + override def visitJoin(p: Join): Rel = t(p) + + override def visitProject(p: Project): Rel = t(p) + + override def visitSort(sort: Sort): Rel = t(sort) + + override def visitWithCTE(p: WithCTE): Rel = t(p) +} diff --git a/spark/src/main/spark-3.3/io/substrait/spark/logical/AbstractLogicalPlanVisitor.scala b/spark/src/main/spark-3.3/io/substrait/spark/logical/AbstractLogicalPlanVisitor.scala new file mode 100644 index 00000000..345cb215 --- /dev/null +++ b/spark/src/main/spark-3.3/io/substrait/spark/logical/AbstractLogicalPlanVisitor.scala @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.substrait.spark.logical + +import org.apache.spark.sql.catalyst.plans.logical._ + +import io.substrait.relation +import io.substrait.relation.Rel + +class AbstractLogicalPlanVisitor extends LogicalPlanVisitor[relation.Rel] { + + protected def t(p: LogicalPlan): relation.Rel = + throw new UnsupportedOperationException(s"Unable to convert the expression ${p.nodeName}") + + override def visitDistinct(p: Distinct): relation.Rel = t(p) + + override def visitExcept(p: Except): relation.Rel = t(p) + + override def visitExpand(p: Expand): relation.Rel = t(p) + + override def visitRepartition(p: Repartition): relation.Rel = t(p) + + override def visitRepartitionByExpr(p: RepartitionByExpression): relation.Rel = t(p) + + override def visitSample(p: Sample): relation.Rel = t(p) + + override def visitScriptTransform(p: ScriptTransformation): relation.Rel = t(p) + + override def visitUnion(p: Union): relation.Rel = t(p) + + override def visitWindow(p: Window): relation.Rel = t(p) + + override def visitTail(p: Tail): relation.Rel = t(p) + + override def visitGenerate(p: Generate): relation.Rel = t(p) + + override def visitGlobalLimit(p: GlobalLimit): relation.Rel = t(p) + + override def visitIntersect(p: Intersect): relation.Rel = t(p) + + override def visitLocalLimit(p: LocalLimit): relation.Rel = t(p) + + override def visitPivot(p: Pivot): relation.Rel = t(p) + + override def default(p: LogicalPlan): Rel = t(p) + + override def visitAggregate(p: Aggregate): Rel = t(p) + + override def visitFilter(p: Filter): Rel = t(p) + + override def visitJoin(p: Join): Rel = t(p) + + override def visitProject(p: Project): Rel = t(p) + + override def visitSort(sort: Sort): Rel = t(sort) + + override def visitWithCTE(p: WithCTE): Rel = t(p) + + override def visitRebalancePartitions(p: RebalancePartitions): Rel = t(p) +} diff --git a/spark/src/main/spark-3.4/io/substrait/spark/logical/AbstractLogicalPlanVisitor.scala b/spark/src/main/spark-3.4/io/substrait/spark/logical/AbstractLogicalPlanVisitor.scala new file mode 100644 index 00000000..ec3ee78e --- /dev/null +++ b/spark/src/main/spark-3.4/io/substrait/spark/logical/AbstractLogicalPlanVisitor.scala @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.substrait.spark.logical + +import org.apache.spark.sql.catalyst.plans.logical._ + +import io.substrait.relation +import io.substrait.relation.Rel + +class AbstractLogicalPlanVisitor extends LogicalPlanVisitor[relation.Rel] { + + protected def t(p: LogicalPlan): relation.Rel = + throw new UnsupportedOperationException(s"Unable to convert the expression ${p.nodeName}") + + override def visitDistinct(p: Distinct): relation.Rel = t(p) + + override def visitExcept(p: Except): relation.Rel = t(p) + + override def visitExpand(p: Expand): relation.Rel = t(p) + + override def visitRepartition(p: Repartition): relation.Rel = t(p) + + override def visitRepartitionByExpr(p: RepartitionByExpression): relation.Rel = t(p) + + override def visitSample(p: Sample): relation.Rel = t(p) + + override def visitScriptTransform(p: ScriptTransformation): relation.Rel = t(p) + + override def visitUnion(p: Union): relation.Rel = t(p) + + override def visitWindow(p: Window): relation.Rel = t(p) + + override def visitTail(p: Tail): relation.Rel = t(p) + + override def visitGenerate(p: Generate): relation.Rel = t(p) + + override def visitGlobalLimit(p: GlobalLimit): relation.Rel = t(p) + + override def visitIntersect(p: Intersect): relation.Rel = t(p) + + override def visitLocalLimit(p: LocalLimit): relation.Rel = t(p) + + override def visitPivot(p: Pivot): relation.Rel = t(p) + + override def default(p: LogicalPlan): Rel = t(p) + + override def visitAggregate(p: Aggregate): Rel = t(p) + + override def visitFilter(p: Filter): Rel = t(p) + + override def visitJoin(p: Join): Rel = t(p) + + override def visitProject(p: Project): Rel = t(p) + + override def visitSort(sort: Sort): Rel = t(sort) + + override def visitWithCTE(p: WithCTE): Rel = t(p) + + override def visitOffset(p: Offset): Rel = t(p) + + override def visitRebalancePartitions(p: RebalancePartitions): Rel = t(p) +} diff --git a/spark/src/test/scala/io/substrait/spark/SubstraitPlanTestBase.scala b/spark/src/test/scala/io/substrait/spark/SubstraitPlanTestBase.scala new file mode 100644 index 00000000..4fa9ec26 --- /dev/null +++ b/spark/src/test/scala/io/substrait/spark/SubstraitPlanTestBase.scala @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.substrait.spark + +import io.substrait.spark.logical.{ToLogicalPlan, ToSubstraitRel} + +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.util.resourceToString +import org.apache.spark.sql.test.SharedSparkSession + +import io.substrait.debug.TreePrinter +import io.substrait.extension.ExtensionCollector +import io.substrait.plan.{Plan, PlanProtoConverter, ProtoPlanConverter} +import io.substrait.proto +import io.substrait.relation.RelProtoConverter +import org.scalactic.Equality +import org.scalactic.source.Position +import org.scalatest.Succeeded +import org.scalatest.compatible.Assertion +import org.scalatest.exceptions.{StackDepthException, TestFailedException} + +trait SubstraitPlanTestBase { self: SharedSparkSession => + + implicit class PlainEquality[T: TreePrinter](actual: T) { + // Like should equal, but does not try to mark diffs in strings with square brackets, + // so that IntelliJ can show a proper diff. + def shouldEqualPlainly(expected: T)(implicit equality: Equality[T]): Assertion = + if (!equality.areEqual(actual, expected)) { + + throw new TestFailedException( + (e: StackDepthException) => + Some( + s"${implicitly[TreePrinter[T]].tree(actual)}" + + s" did not equal ${implicitly[TreePrinter[T]].tree(expected)}"), + None, + Position.here + ) + } else Succeeded + } + + def sqlToProtoPlan(sql: String): proto.Plan = { + val convert = new ToSubstraitRel() + val logicalPlan = plan(sql) + val substraitRel = convert.visit(logicalPlan) + + val extensionCollector = new ExtensionCollector + val relProtoConverter = new RelProtoConverter(extensionCollector) + val builder = proto.Plan + .newBuilder() + .addRelations( + proto.PlanRel + .newBuilder() + .setRoot( + proto.RelRoot + .newBuilder() + .setInput(substraitRel + .accept(relProtoConverter)) + ) + ) + extensionCollector.addExtensionsToPlan(builder) + builder.build() + } + + def assertProtoPlanRoundrip(sql: String): Plan = { + val protoPlan1 = sqlToProtoPlan(sql) + val plan = new ProtoPlanConverter().from(protoPlan1) + val protoPlan2 = new PlanProtoConverter().toProto(plan) + assertResult(protoPlan1)(protoPlan2) + assertResult(1)(plan.getRoots.size()) + plan + } + + def assertSqlSubstraitRelRoundTrip(query: String): LogicalPlan = { + // TODO need a more robust way of testing this than round-tripping. 
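// The round trip below is: Spark logical plan -> Substrait POJO -> Spark logical plan ->
// Substrait POJO; the two POJO plans are then required to be equal (shouldEqualPlainly prints
// both trees on a mismatch).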
+ val logicalPlan = plan(query) + val pojoRel = new ToSubstraitRel().visit(logicalPlan) + val converter = new ToLogicalPlan(spark = spark); + val logicalPlan2 = pojoRel.accept(converter); + require(logicalPlan2.resolved); + val pojoRel2 = new ToSubstraitRel().visit(logicalPlan2) + + pojoRel2.shouldEqualPlainly(pojoRel) + logicalPlan2 + } + + def plan(sql: String): LogicalPlan = { + spark.sql(sql).queryExecution.optimizedPlan + } + + def assertPlanRoundrip(plan: Plan): Unit = { + val protoPlan1 = new PlanProtoConverter().toProto(plan) + val protoPlan2 = new PlanProtoConverter().toProto(new ProtoPlanConverter().from(protoPlan1)) + assertResult(protoPlan1)(protoPlan2) + } + + def testQuery(group: String, query: String, suffix: String = ""): Unit = { + val queryString = resourceToString( + s"$group/$query.sql", + classLoader = Thread.currentThread().getContextClassLoader) + assert(queryString != null) + assertSqlSubstraitRelRoundTrip(queryString) + } +} diff --git a/spark/src/test/scala/io/substrait/spark/TPCDSPlan.scala b/spark/src/test/scala/io/substrait/spark/TPCDSPlan.scala new file mode 100644 index 00000000..7cfb3cd2 --- /dev/null +++ b/spark/src/test/scala/io/substrait/spark/TPCDSPlan.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.substrait.spark + +import org.apache.spark.sql.TPCDSBase +import org.apache.spark.sql.internal.SQLConf + +class TPCDSPlan extends TPCDSBase with SubstraitPlanTestBase { + + private val runAllQueriesIncludeFailed = false + override def beforeAll(): Unit = { + super.beforeAll() + sparkContext.setLogLevel("WARN") + + conf.setConf(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED, false) + // introduced in spark 3.4 + spark.conf.set("spark.sql.readSideCharPadding", "false") + } + + // "q9" failed in spark 3.3 + val successfulSQL: Set[String] = Set("q41", "q62", "q93", "q96", "q99") + + tpcdsQueries.foreach { + q => + if (runAllQueriesIncludeFailed || successfulSQL.contains(q)) { + test(s"check simplified (tpcds-v1.4/$q)") { + testQuery("tpcds", q) + } + } else { + ignore(s"check simplified (tpcds-v1.4/$q)") { + testQuery("tpcds", q) + } + } + } + + ignore("window") { + val qry = s"""(SELECT + | item_sk, + | rank() + | OVER ( + | ORDER BY rank_col DESC) rnk + | FROM (SELECT + | ss_item_sk item_sk, + | avg(ss_net_profit) rank_col + | FROM store_sales ss1 + | WHERE ss_store_sk = 4 + | GROUP BY ss_item_sk + | HAVING avg(ss_net_profit) > 0.9 * (SELECT avg(ss_net_profit) rank_col + | FROM store_sales + | WHERE ss_store_sk = 4 + | AND ss_addr_sk IS NULL + | GROUP BY ss_store_sk)) V2) """.stripMargin + assertSqlSubstraitRelRoundTrip(qry) + } +} diff --git a/spark/src/test/scala/io/substrait/spark/TPCHPlan.scala b/spark/src/test/scala/io/substrait/spark/TPCHPlan.scala new file mode 100644 index 00000000..7f2e9978 --- /dev/null +++ b/spark/src/test/scala/io/substrait/spark/TPCHPlan.scala @@ -0,0 +1,192 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.substrait.spark + +import org.apache.spark.sql.TPCHBase + +class TPCHPlan extends TPCHBase with SubstraitPlanTestBase { + + override def beforeAll(): Unit = { + super.beforeAll() + sparkContext.setLogLevel("WARN") + } + + tpchQueries.foreach { + q => + test(s"check simplified (tpch/$q)") { + testQuery("tpch", q) + } + } + + test("Decimal") { + assertSqlSubstraitRelRoundTrip("select l_returnflag," + + " sum(l_extendedprice * (1 - l_discount) * (1 + l_tax)) from lineitem group by l_returnflag") + } + + test("simpleJoin") { + assertSqlSubstraitRelRoundTrip( + "select l_partkey + l_orderkey, l_extendedprice * 0.1 + 100.0, o_orderkey from lineitem " + + "join orders on l_orderkey = o_orderkey where l_shipdate < date '1998-01-01' ") + assertSqlSubstraitRelRoundTrip( + "select l_partkey + l_orderkey, l_extendedprice * 0.1 + 100.0, o_orderkey from lineitem " + + "left join orders on l_orderkey = o_orderkey where l_shipdate < date '1998-01-01' ") + assertSqlSubstraitRelRoundTrip( + "select l_partkey + l_orderkey, l_extendedprice * 0.1 + 100.0, o_orderkey from lineitem " + + "right join orders on l_orderkey = o_orderkey where l_shipdate < date '1998-01-01' ") + assertSqlSubstraitRelRoundTrip( + "select l_partkey + l_orderkey, l_extendedprice * 0.1 + 100.0, o_orderkey from lineitem " + + "full join orders on l_orderkey = o_orderkey where l_shipdate < date '1998-01-01' ") + } + + test("simpleOrderByClause") { + assertSqlSubstraitRelRoundTrip( + "select l_partkey from lineitem where l_shipdate < date '1998-01-01' " + + "order by l_shipdate, l_discount") + assertSqlSubstraitRelRoundTrip( + "select l_partkey from lineitem where l_shipdate < date '1998-01-01' " + + "order by l_shipdate asc, l_discount desc") +// assertSqlSubstraitRelRoundTrip( +// "select l_partkey from lineitem where l_shipdate < date '1998-01-01' " + +// "order by l_shipdate asc, l_discount desc limit 100 offset 1000") +// assertSqlSubstraitRelRoundTrip( +// "select l_partkey from lineitem where l_shipdate < date '1998-01-01' " + +// "order by l_shipdate asc, l_discount desc limit 100") +// assertSqlSubstraitRelRoundTrip( +// "select l_partkey from lineitem where l_shipdate < date '1998-01-01' limit 100") + assertSqlSubstraitRelRoundTrip( + "select l_partkey from lineitem where l_shipdate < date '1998-01-01' " + + "order by l_shipdate asc, l_discount desc nulls first") + assertSqlSubstraitRelRoundTrip( + "select l_partkey from lineitem where l_shipdate < date '1998-01-01' " + + "order by l_shipdate asc, l_discount desc nulls last") + } + + test("simpleTest") { + val query = "select p_size from part where p_partkey > cast(100 as bigint)" + assertSqlSubstraitRelRoundTrip(query) + } + + test("simpleTest2") { + val query = "select l_partkey, l_discount from lineitem where l_orderkey > cast(100 as bigint)" + assertSqlSubstraitRelRoundTrip(query) + } + + test("simpleTestAgg") { + assertSqlSubstraitRelRoundTrip( + "select l_partkey, count(l_tax), COUNT(distinct l_discount) from lineitem group by l_partkey") + + assertSqlSubstraitRelRoundTrip( + "select count(l_tax), COUNT(distinct l_discount)" + + " from lineitem group by l_partkey + l_orderkey") + + assertSqlSubstraitRelRoundTrip( + "select l_partkey + l_orderkey, count(l_tax), COUNT(distinct l_discount)" + + " from lineitem group by l_partkey + l_orderkey") + } + + ignore("avg(distinct)") { + assertSqlSubstraitRelRoundTrip( + "select l_partkey, sum(l_tax), sum(distinct l_tax)," + + " avg(l_discount), avg(distinct l_discount) from lineitem group by l_partkey") + } + + 
test("simpleTestAgg3") { + assertSqlSubstraitRelRoundTrip( + "select l_partkey, sum(l_extendedprice * (1.0-l_discount)) from lineitem group by l_partkey") + } + + ignore("simpleTestAggFilter") { + assertSqlSubstraitRelRoundTrip( + "select sum(l_tax) filter(WHERE l_orderkey > l_partkey) from lineitem") + // cast is added to avoid the difference by implicit cast + assertSqlSubstraitRelRoundTrip( + "select sum(l_tax) filter(WHERE l_orderkey > cast(10.0 as bigint)) from lineitem") + } + + test("simpleTestAggNoGB") { + assertSqlSubstraitRelRoundTrip("select count(l_tax), count(distinct l_discount) from lineitem") + } + + test("simpleTestApproxCountDistinct") { + val query = "select approx_count_distinct(l_tax) from lineitem" + val plan = assertSqlSubstraitRelRoundTrip(query) + } + + test("simpleTestDateInterval") { + assertSqlSubstraitRelRoundTrip( + "select l_partkey+l_orderkey, l_shipdate from lineitem where l_shipdate < date '1998-01-01' ") + assertSqlSubstraitRelRoundTrip( + "select l_partkey+l_orderkey, l_shipdate from lineitem " + + "where l_shipdate < date '1998-01-01' + interval '3' month ") + assertSqlSubstraitRelRoundTrip( + "select l_partkey+l_orderkey, l_shipdate from lineitem " + + "where l_shipdate < date '1998-01-01' + interval '1' year") + assertSqlSubstraitRelRoundTrip( + "select l_partkey+l_orderkey, l_shipdate from lineitem " + + "where l_shipdate < date '1998-01-01' + interval '1-3' year to month") + } + + test("simpleTestDecimal") { + assertSqlSubstraitRelRoundTrip( + "select l_partkey + l_orderkey, l_extendedprice * 0.1 + 100.0 from lineitem" + + " where l_shipdate < date '1998-01-01' ") + } + + ignore("simpleTestGroupingSets [has Expand]") { + assertSqlSubstraitRelRoundTrip( + "select sum(l_discount) from lineitem group by grouping sets " + + "((l_orderkey, L_COMMITDATE), l_shipdate)") + + assertSqlSubstraitRelRoundTrip( + "select sum(l_discount) from lineitem group by grouping sets " + + "((l_orderkey, L_COMMITDATE), l_shipdate), l_linestatus") + + assertSqlSubstraitRelRoundTrip( + "select sum(l_discount) from lineitem group by grouping sets " + + "((l_orderkey, L_COMMITDATE), l_shipdate, ())") + + assertSqlSubstraitRelRoundTrip( + "select sum(l_discount) from lineitem group by grouping sets " + + "((l_orderkey, L_COMMITDATE), l_shipdate, ()), l_linestatus") + assertSqlSubstraitRelRoundTrip( + "select sum(l_discount) from lineitem group by grouping sets " + + "((l_orderkey, L_COMMITDATE), (l_orderkey, L_COMMITDATE, l_linestatus), l_shipdate, ())") + } + + test("tpch_q1_variant") { + // difference from tpch_q1 : 1) remove order by clause; 2) remove interval date literal + assertSqlSubstraitRelRoundTrip( + "select\n" + + " l_returnflag,\n" + + " l_linestatus,\n" + + " sum(l_quantity) as sum_qty,\n" + + " sum(l_extendedprice) as sum_base_price,\n" + + " sum(l_extendedprice * (1.0 - l_discount)) as sum_disc_price,\n" + + " sum(l_extendedprice * (1.0 - l_discount) * (1.0 + l_tax)) as sum_charge,\n" + + " avg(l_quantity) as avg_qty,\n" + + " avg(l_extendedprice) as avg_price,\n" + + " avg(l_discount) as avg_disc,\n" + + " count(*) as count_order\n" + + "from\n" + + " lineitem\n" + + "where\n" + + " l_shipdate <= date '1998-12-01' \n" + + "group by\n" + + " l_returnflag,\n" + + " l_linestatus\n") + } +} diff --git a/spark/src/test/scala/io/substrait/spark/expression/ArithmeticExpressionSuite.scala b/spark/src/test/scala/io/substrait/spark/expression/ArithmeticExpressionSuite.scala new file mode 100644 index 00000000..f94230b1 --- /dev/null +++ 
b/spark/src/test/scala/io/substrait/spark/expression/ArithmeticExpressionSuite.scala @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.substrait.spark.expression + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types.{IntegerType, LongType} + +import io.substrait.`type`.TypeCreator +import io.substrait.expression.{Expression => SExpression, ExpressionCreator} +import io.substrait.expression.Expression.FailureBehavior + +class ArithmeticExpressionSuite extends SparkFunSuite with SubstraitExpressionTestBase { + + test("+ (Add)") { + runTest( + "add:i64_i64", + Add(Literal(1), Literal(2L)), + func => { + assertResult(true)(func.arguments().get(1).isInstanceOf[SExpression.I64Literal]) + assertResult( + ExpressionCreator.cast( + TypeCreator.REQUIRED.I64, + ExpressionCreator.i32(false, 1), + FailureBehavior.THROW_EXCEPTION + ))(func.arguments().get(0)) + }, + bidirectional = false + ) // TODO: implicit calcite cast + + runTest( + "add:i64_i64", + Add(Cast(Literal(1), LongType), Literal(2L)), + func => {}, + bidirectional = true) + + runTest("add:i32_i32", Add(Literal(1), Cast(Literal(2L), IntegerType))) + + runTest( + "add:i32_i32", + Add(Literal(1), Literal(2)), + func => { + assertResult(true)(func.arguments().get(0).isInstanceOf[SExpression.I32Literal]) + assertResult(true)(func.arguments().get(1).isInstanceOf[SExpression.I32Literal]) + }, + bidirectional = true + ) + } +} diff --git a/spark/src/test/scala/io/substrait/spark/expression/PredicateSuite.scala b/spark/src/test/scala/io/substrait/spark/expression/PredicateSuite.scala new file mode 100644 index 00000000..254ba99c --- /dev/null +++ b/spark/src/test/scala/io/substrait/spark/expression/PredicateSuite.scala @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.substrait.spark.expression + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.expressions.{And, Literal} + +class PredicateSuite extends SparkFunSuite with SubstraitExpressionTestBase { + + test("And") { + runTest("and:bool", And(Literal(true), Literal(false))) + } +} diff --git a/spark/src/test/scala/io/substrait/spark/expression/SubstraitExpressionTestBase.scala b/spark/src/test/scala/io/substrait/spark/expression/SubstraitExpressionTestBase.scala new file mode 100644 index 00000000..45de335b --- /dev/null +++ b/spark/src/test/scala/io/substrait/spark/expression/SubstraitExpressionTestBase.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.substrait.spark.expression + +import io.substrait.spark.SparkExtension + +import org.apache.spark.sql.catalyst.expressions.Expression + +import io.substrait.expression.{Expression => SExpression} +import org.scalatest.Assertions.assertResult + +trait SubstraitExpressionTestBase { + + private val toSparkExpression = + new ToSparkExpression(ToScalarFunction(SparkExtension.SparkScalarFunctions)) + + private val toSubstraitExpression = new ToSubstraitExpression { + override protected val toScalarFunction: ToScalarFunction = + ToScalarFunction(SparkExtension.SparkScalarFunctions) + } + + protected def runTest(expectedName: String, expression: Expression): Unit = { + runTest(expectedName, expression, func => {}, bidirectional = true) + } + + protected def runTest( + expectedName: String, + expression: Expression, + f: SExpression.ScalarFunctionInvocation => Unit, + bidirectional: Boolean): Unit = { + val substraitExp = toSubstraitExpression(expression) + .asInstanceOf[SExpression.ScalarFunctionInvocation] + assertResult(expectedName)(substraitExp.declaration().key()) + f(substraitExp) + + if (bidirectional) { + val convertedExpression = substraitExp.accept(toSparkExpression) + assertResult(expression)(convertedExpression) + } + } +} diff --git a/spark/src/test/scala/io/substrait/spark/expression/YamlTest.scala b/spark/src/test/scala/io/substrait/spark/expression/YamlTest.scala new file mode 100644 index 00000000..e855e0d4 --- /dev/null +++ b/spark/src/test/scala/io/substrait/spark/expression/YamlTest.scala @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.substrait.spark.expression + +import io.substrait.spark.SparkExtension + +import org.apache.spark.SparkFunSuite + +class YamlTest extends SparkFunSuite { + + test("has_year_definition") { + assert( + SparkExtension.SparkScalarFunctions + .map(f => f.key()) + .exists(p => p.equals("year:date"))) + assert( + SparkExtension.SparkScalarFunctions + .map(f => f.key()) + .exists(p => p.equals("unscaled:dec"))) + } +} diff --git a/spark/src/test/spark-3.2/org/apache/spark/sql/TPCBase.scala b/spark/src/test/spark-3.2/org/apache/spark/sql/TPCBase.scala new file mode 100644 index 00000000..c2c0beac --- /dev/null +++ b/spark/src/test/spark-3.2/org/apache/spark/sql/TPCBase.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql + +import org.apache.spark.SparkConf +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSparkSession + +trait TPCBase extends SharedSparkSession { + + protected def injectStats: Boolean = false + + override protected def sparkConf: SparkConf = { + if (injectStats) { + super.sparkConf + .set(SQLConf.MAX_TO_STRING_FIELDS, Int.MaxValue) + .set(SQLConf.CBO_ENABLED, true) + .set(SQLConf.PLAN_STATS_ENABLED, true) + .set(SQLConf.JOIN_REORDER_ENABLED, true) + } else { + super.sparkConf.set(SQLConf.MAX_TO_STRING_FIELDS, Int.MaxValue) + } + } + + override def beforeAll(): Unit = { + super.beforeAll() + createTables() + } + + override def afterAll(): Unit = { + dropTables() + super.afterAll() + } + + protected def createTables(): Unit + + protected def dropTables(): Unit +} diff --git a/spark/src/test/spark-3.2/org/apache/spark/sql/TPCHBase.scala b/spark/src/test/spark-3.2/org/apache/spark/sql/TPCHBase.scala new file mode 100644 index 00000000..c3247c5c --- /dev/null +++ b/spark/src/test/spark-3.2/org/apache/spark/sql/TPCHBase.scala @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql + +import org.apache.spark.sql.catalyst.TableIdentifier + +trait TPCHBase extends TPCBase { + + override def createTables(): Unit = { + tpchCreateTable.values.foreach(sql => spark.sql(sql)) + } + + override def dropTables(): Unit = { + tpchCreateTable.keys.foreach { + tableName => spark.sessionState.catalog.dropTable(TableIdentifier(tableName), true, true) + } + } + + val tpchCreateTable = Map( + "orders" -> + """ + |CREATE TABLE `orders` ( + |`o_orderkey` BIGINT, `o_custkey` BIGINT, `o_orderstatus` STRING, + |`o_totalprice` DECIMAL(10,0), `o_orderdate` DATE, `o_orderpriority` STRING, + |`o_clerk` STRING, `o_shippriority` INT, `o_comment` STRING) + |USING parquet + """.stripMargin, + "nation" -> + """ + |CREATE TABLE `nation` ( + |`n_nationkey` BIGINT, `n_name` STRING, `n_regionkey` BIGINT, `n_comment` STRING) + |USING parquet + """.stripMargin, + "region" -> + """ + |CREATE TABLE `region` ( + |`r_regionkey` BIGINT, `r_name` STRING, `r_comment` STRING) + |USING parquet + """.stripMargin, + "part" -> + """ + |CREATE TABLE `part` (`p_partkey` BIGINT, `p_name` STRING, `p_mfgr` STRING, + |`p_brand` STRING, `p_type` STRING, `p_size` INT, `p_container` STRING, + |`p_retailprice` DECIMAL(10,0), `p_comment` STRING) + |USING parquet + """.stripMargin, + "partsupp" -> + """ + |CREATE TABLE `partsupp` (`ps_partkey` BIGINT, `ps_suppkey` BIGINT, + |`ps_availqty` INT, `ps_supplycost` DECIMAL(10,0), `ps_comment` STRING) + |USING parquet + """.stripMargin, + "customer" -> + """ + |CREATE TABLE `customer` (`c_custkey` BIGINT, `c_name` STRING, `c_address` STRING, + |`c_nationkey` BIGINT, `c_phone` STRING, `c_acctbal` DECIMAL(10,0), + |`c_mktsegment` STRING, `c_comment` STRING) + |USING parquet + """.stripMargin, + "supplier" -> + """ + |CREATE TABLE `supplier` (`s_suppkey` BIGINT, `s_name` STRING, `s_address` STRING, + |`s_nationkey` BIGINT, `s_phone` STRING, `s_acctbal` DECIMAL(10,0), `s_comment` STRING) + |USING parquet + """.stripMargin, + "lineitem" -> + """ + |CREATE TABLE `lineitem` (`l_orderkey` BIGINT, `l_partkey` BIGINT, `l_suppkey` BIGINT, + |`l_linenumber` INT, `l_quantity` DECIMAL(10,0), `l_extendedprice` DECIMAL(10,0), + |`l_discount` DECIMAL(10,0), `l_tax` DECIMAL(10,0), `l_returnflag` STRING, + |`l_linestatus` STRING, `l_shipdate` DATE, `l_commitdate` DATE, `l_receiptdate` DATE, + |`l_shipinstruct` STRING, `l_shipmode` STRING, `l_comment` STRING) + |USING parquet + """.stripMargin + ) + + val tpchQueries = Seq( + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11", + "q12", + "q13", + "q14", + "q15", + "q16", + "q17", + "q18", + "q19", + "q20", + "q21", + "q22") +}
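The suites above all funnel through the two test bases added in this patch: SubstraitPlanTestBase (plan-level round trips via assertSqlSubstraitRelRoundTrip and assertPlanRoundrip) and SubstraitExpressionTestBase (expression-level round trips via runTest). For readers who want the conversion flow without the test scaffolding, the sketch below condenses the plan-level round trip: Spark SQL -> optimized LogicalPlan -> Substrait POJO rel -> LogicalPlan again. It is illustrative only, not part of the patch: the converter classes (ToSubstraitRel, ToLogicalPlan, PlanProtoConverter, ProtoPlanConverter) come from this change and the substrait-java core module, the import locations shown are assumptions, and the local SparkSession plus the stand-in table are likewise assumptions rather than anything the harness does.

// Illustrative sketch only -- not part of the patch. The io.substrait import
// location below is assumed from elsewhere in this change; the SparkSession
// setup and the stand-in table are assumptions as well.
import org.apache.spark.sql.SparkSession

import io.substrait.spark.logical.{ToLogicalPlan, ToSubstraitRel}

object SubstraitRoundTripSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder()
      .master("local[1]")
      .appName("substrait-spark-sketch")
      .getOrCreate()
    try {
      // Stand-in table; the real suites create the full TPC-H / TPC-DS schemas instead.
      spark.sql("CREATE TABLE lineitem (l_partkey BIGINT, l_tax DECIMAL(10,0)) USING parquet")

      // 1. SQL -> optimized LogicalPlan (what SubstraitPlanTestBase.plan does).
      val logicalPlan = spark
        .sql("SELECT l_partkey, sum(l_tax) FROM lineitem GROUP BY l_partkey")
        .queryExecution
        .optimizedPlan

      // 2. LogicalPlan -> Substrait POJO rel -> LogicalPlan again
      //    (the core of assertSqlSubstraitRelRoundTrip).
      val pojoRel = new ToSubstraitRel().visit(logicalPlan)
      val restored = pojoRel.accept(new ToLogicalPlan(spark = spark))
      require(restored.resolved)
      val pojoRel2 = new ToSubstraitRel().visit(restored)
      // The test base then asserts pojoRel2.shouldEqualPlainly(pojoRel).

      // 3. Proto round trip (what assertPlanRoundrip checks); building a full
      //    Substrait Plan from the rel is elided in this sketch:
      // val proto = new PlanProtoConverter().toProto(plan)
      // val plan2 = new ProtoPlanConverter().from(proto)
    } finally {
      spark.stop()
    }
  }
}

The expression-level path in SubstraitExpressionTestBase is the same idea one level down: a Catalyst expression such as Add(Literal(1), Literal(2)) is converted to a Substrait ScalarFunctionInvocation, the suites assert on its declaration key (e.g. "add:i32_i32"), and, when the test is bidirectional, the invocation is converted back to a Catalyst expression and compared with the original.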