Skip to content

Commit

Permalink
feat: port of substrait-spark module from Gluten
Browse files Browse the repository at this point in the history
This module was part of the Gluten project and subsequently removed.
It is useful for converting Spark query plans to and from Substrait.

Signed-off-by: andrew-coleman <[email protected]>
  • Loading branch information
andrew-coleman committed Jun 18, 2024
1 parent 44c078d commit 157276e
Show file tree
Hide file tree
Showing 36 changed files with 3,597 additions and 1 deletion.
2 changes: 2 additions & 0 deletions gradle.properties
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ jackson.version=2.16.1
junit.version=5.8.1
protobuf.version=3.25.3
slf4j.version=2.0.13
sparkbundle.version=3.4
spark.version=3.4.2

#version that is going to be updated automatically by releases
version = 0.33.0
Expand Down
2 changes: 1 addition & 1 deletion settings.gradle.kts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
rootProject.name = "substrait"

include("bom", "core", "isthmus", "isthmus-cli")
include("bom", "core", "isthmus", "isthmus-cli", "spark")

pluginManagement {
plugins {
Expand Down
113 changes: 113 additions & 0 deletions spark/build.gradle.kts
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
// Build plugins for the Spark module: Java + Scala compilation, IDE metadata,
// Spotless code formatting, and Maven publishing with artifact signing.
plugins {
  `maven-publish`
  id("java")
  id("scala")
  id("idea")
  id("com.diffplug.spotless") version "6.11.0"
  signing
}

// Maven publication metadata plus a local staging repository under build/repos.
publishing {
  publications {
    // Single publication carrying the compiled Java/Scala components.
    create<MavenPublication>("maven-publish") {
      from(components["java"])

      pom {
        name.set("Substrait Java")
        description.set(
          "Create a well-defined, cross-language specification for data compute operations"
        )
        url.set("https://github.com/substrait-io/substrait-java")
        licenses {
          license {
            name.set("The Apache License, Version 2.0")
            url.set("http://www.apache.org/licenses/LICENSE-2.0.txt")
          }
        }
        developers {
          developer {
            // TBD: populate the developer list before release
          }
        }
        scm {
          connection.set("scm:git:git://github.com:substrait-io/substrait-java.git")
          developerConnection.set("scm:git:ssh://github.com:substrait-io/substrait-java")
          url.set("https://github.com/substrait-io/substrait-java/")
        }
      }
    }
  }
  repositories {
    // Stage artifacts locally, split by release vs snapshot based on the version string.
    maven {
      name = "local"
      val releasesRepoUrl = layout.buildDirectory.dir("repos/releases")
      val snapshotsRepoUrl = layout.buildDirectory.dir("repos/snapshots")
      url = uri(if (version.toString().endsWith("SNAPSHOT")) snapshotsRepoUrl else releasesRepoUrl)
    }
  }
}

// GPG signing of the publication, activated only for Sonatype publishing.
signing {
  // Signatures are required only when the publishToSonatype task is in the graph.
  setRequired({ gradle.taskGraph.hasTask("publishToSonatype") })
  // Credentials come from environment variables, falling back to Gradle `extra` properties.
  // NOTE(review): the `extra[...]` fallback is evaluated eagerly at configuration time,
  // so a build with neither the env var nor the extra property set will fail even when
  // signing is not required — confirm this is acceptable for local/CI non-release builds.
  val signingKeyId =
    System.getenv("SIGNING_KEY_ID").takeUnless { it.isNullOrEmpty() }
      ?: extra["SIGNING_KEY_ID"].toString()
  val signingPassword =
    System.getenv("SIGNING_PASSWORD").takeUnless { it.isNullOrEmpty() }
      ?: extra["SIGNING_PASSWORD"].toString()
  val signingKey =
    System.getenv("SIGNING_KEY").takeUnless { it.isNullOrEmpty() }
      ?: extra["SIGNING_KEY"].toString()
  useInMemoryPgpKeys(signingKeyId, signingKey, signingPassword)
  sign(publishing.publications["maven-publish"])
}

// Detach the incremental Scala analysis configurations from their parents.
// NOTE(review): presumably a workaround for a dependency-resolution problem in the
// `scala` plugin's incremental-compilation configurations — confirm why this is needed
// and whether it is still required on the current Gradle version.
configurations.all {
  if (name.startsWith("incrementalScalaAnalysis")) {
    setExtendsFrom(emptyList())
  }
}

// Compile with a Java 17 toolchain and publish javadoc + sources jars
// (both are required for Maven Central publication).
java {
  toolchain { languageVersion.set(JavaLanguageVersion.of(17)) }
  withJavadocJar()
  withSourcesJar()
}

tasks.withType<ScalaCompile>() {
  // Clear targetCompatibility so the `-release:17` compiler flag below is the
  // single source of truth for the bytecode target (avoids passing a conflicting
  // -target option to the Scala compiler).
  targetCompatibility = ""
  scalaCompileOptions.additionalParameters = listOf("-release:17")
}

// Dependency versions read from gradle.properties.
// These are never reassigned, so declare them as immutable `val`s rather than `var`s.
val SLF4J_VERSION = properties.get("slf4j.version")
val SPARKBUNDLE_VERSION = properties.get("sparkbundle.version")
val SPARK_VERSION = properties.get("spark.version")

sourceSets {
  // Main sources are selected per Spark bundle version (e.g. src/main/spark-3.4).
  main { scala { setSrcDirs(listOf("src/main/spark-${SPARKBUNDLE_VERSION}")) } }
  // NOTE(review): the test set hard-codes `src/test/spark-3.2` even though
  // sparkbundle.version is 3.4, and it also adds `src/main/scala` (which the
  // main set above does NOT compile). Confirm both are intentional — otherwise
  // `src/test/spark-${SPARKBUNDLE_VERSION}` would seem more consistent.
  test { scala { setSrcDirs(listOf("src/test/scala", "src/test/spark-3.2", "src/main/scala")) } }
}

dependencies {
  // Substrait core (protobuf model + expression/relation builders).
  implementation(project(":core"))
  // NOTE(review): scala-library is pinned to 2.12.16 while Spark ${SPARK_VERSION}
  // brings its own scala-library transitively — confirm the pinned patch version
  // matches the one shipped by this Spark release to avoid mixed stdlib versions.
  implementation("org.scala-lang:scala-library:2.12.16")
  implementation("org.apache.spark:spark-core_2.12:${SPARK_VERSION}")
  implementation("org.apache.spark:spark-sql_2.12:${SPARK_VERSION}")
  implementation("org.apache.spark:spark-catalyst_2.12:${SPARK_VERSION}")
  implementation("org.slf4j:slf4j-api:${SLF4J_VERSION}")

  // ScalaTest suites are executed through the JUnit Platform (see tasks below).
  testImplementation("org.scalatest:scalatest_2.12:3.2.18")
  testRuntimeOnly("org.junit.platform:junit-platform-engine:1.10.0")
  testRuntimeOnly("org.junit.platform:junit-platform-launcher:1.10.0")
  testRuntimeOnly("org.scalatestplus:junit-5-10_2.12:3.2.18.0")
  // Spark `tests` classifier jars provide shared test fixtures (e.g. SharedSparkSession).
  testImplementation("org.apache.spark:spark-core_2.12:${SPARK_VERSION}:tests")
  testImplementation("org.apache.spark:spark-sql_2.12:${SPARK_VERSION}:tests")
  testImplementation("org.apache.spark:spark-catalyst_2.12:${SPARK_VERSION}:tests")
}

tasks {
  test {
    // NOTE(review): tests depend on the shaded core jar — presumably it must be
    // built so the test runtime sees the relocated classes; confirm.
    dependsOn(":core:shadowJar")
    // Run ScalaTest suites via the JUnit Platform scalatest engine.
    useJUnitPlatform { includeEngines("scalatest") }
  }
}
34 changes: 34 additions & 0 deletions spark/src/main/resources/spark.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
%YAML 1.2
---
scalar_functions:
-
name: year
description: Returns the year component of the date/timestamp
impls:
- args:
- value: date
return: i32
-
name: unscaled
description: >-
Return the unscaled Long value of a Decimal, assuming it fits in a Long.
Note: this expression is internal and created only by the optimizer,
we don't need to do type check for it.
impls:
- args:
- value: DECIMAL<P,S>
return: i64
75 changes: 75 additions & 0 deletions spark/src/main/scala/io/substrait/debug/ExpressionToString.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.substrait.debug

import io.substrait.spark.DefaultExpressionVisitor

import org.apache.spark.sql.catalyst.util.DateTimeUtils

import io.substrait.expression.{Expression, FieldReference}
import io.substrait.expression.Expression.{DateLiteral, DecimalLiteral, I32Literal, StrLiteral}
import io.substrait.function.ToTypeString
import io.substrait.util.DecimalUtil

import scala.collection.JavaConverters.asScalaBufferConverter

/**
 * Renders Substrait [[Expression]] nodes as short human-readable strings,
 * used for debug output of converted plans.
 */
class ExpressionToString extends DefaultExpressionVisitor[String] {

  override def visit(expr: DecimalLiteral): String = {
    // Decode the literal's byte payload (16 bytes) into a BigDecimal at the literal's scale.
    val bytes = expr.value.toByteArray
    DecimalUtil.getBigDecimalFromBytes(bytes, expr.scale, 16).toString
  }

  override def visit(expr: StrLiteral): String = expr.value()

  override def visit(expr: I32Literal): String = expr.value().toString

  override def visit(expr: DateLiteral): String = {
    // The literal holds Spark's internal date encoding; render it via java.sql.Date.
    DateTimeUtils.toJavaDate(expr.value()).toString
  }

  override def visit(expr: FieldReference): String =
    // Field references print as positional placeholders, e.g. "$0".
    withFieldReference(expr)(index => "$" + index.toString)

  override def visit(expr: Expression.SingleOrList): String = expr.toString

  override def visit(expr: Expression.ScalarFunctionInvocation): String = {
    // Render each argument through this visitor, then format as key[type](args).
    val renderedArgs = expr
      .arguments()
      .asScala
      .zipWithIndex
      .map { case (argument, position) => argument.accept(expr.declaration(), position, this) }
      .mkString(",")
    val outputTypeString = expr.outputType().accept(ToTypeString.INSTANCE)
    s"${expr.declaration().key()}[$outputTypeString]($renderedArgs)"
  }

  override def visit(expr: Expression.UserDefinedLiteral): String = expr.toString
}
Loading

0 comments on commit 157276e

Please sign in to comment.