Skip to content

Commit

Permalink
feat: spark-substrait example
Browse files Browse the repository at this point in the history
Signed-off-by: MBWhite <[email protected]>
  • Loading branch information
mbwhite committed Sep 9, 2024
1 parent c8c31ec commit 21e97f2
Show file tree
Hide file tree
Showing 28 changed files with 1,593 additions and 1 deletion.
2 changes: 1 addition & 1 deletion .editorconfig
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ trim_trailing_whitespace = true
[*.{yaml,yml}]
indent_size = 2

[{**/*.sql,**/OuterReferenceResolver.md,gradlew.bat}]
[{**/*.sql,**/OuterReferenceResolver.md,**gradlew.bat}]
charset = unset
end_of_line = unset
insert_final_newline = unset
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ gen
out/**
*.iws
.vscode
.pmdCache
2 changes: 2 additions & 0 deletions examples/subtrait-spark/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
_apps
_data
587 changes: 587 additions & 0 deletions examples/subtrait-spark/README.md

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions examples/subtrait-spark/app/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
spark-warehouse
derby.log
62 changes: 62 additions & 0 deletions examples/subtrait-spark/app/build.gradle
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
/*
 * This file was generated by the Gradle 'init' task.
 *
 * This project uses @Incubating APIs which are subject to change.
 */

plugins {
// Shared convention plugin providing the Java application setup for this repo
id 'buildlogic.java-application-conventions'
}

dependencies {
implementation 'org.apache.commons:commons-text'
// for running as a Spark application for real, this could be compile-only


implementation libs.substrait.core
implementation libs.substrait.spark
implementation libs.spark.sql

// For a real Spark application, these would not be required since they would be in the Spark server classpath
runtimeOnly libs.spark.core
// https://mvnrepository.com/artifact/org.apache.spark/spark-hive
runtimeOnly libs.spark.hive



}

// JVM flags Spark needs on modern JDKs (module exports/opens for its internal
// NIO access), plus a local master so the example runs without a cluster.
def jvmArguments = [
"--add-exports",
"java.base/sun.nio.ch=ALL-UNNAMED",
"--add-opens=java.base/java.net=ALL-UNNAMED",
"--add-opens=java.base/java.nio=ALL-UNNAMED",
"-Dspark.master=local"
]

application {
// Define the main class for the application.
mainClass = 'io.substrait.examples.App'
applicationDefaultJvmArgs = jvmArguments
}

// Build a self-contained ("fat") jar so the example can be spark-submitted.
jar {
zip64 = true
duplicatesStrategy = DuplicatesStrategy.EXCLUDE

manifest {
attributes 'Main-Class': 'io.substrait.examples.App'
}

// Merge every runtime dependency into the jar
from {
configurations.runtimeClasspath.collect { it.isDirectory() ? it : zipTree(it) }
}

// Drop signature files that would invalidate the merged jar
exclude 'META-INF/*.RSA'
exclude 'META-INF/*.SF'
exclude 'META-INF/*.DSA'
}

repositories {

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package io.substrait.examples;

import java.nio.file.Files;
import java.nio.file.Paths;

import io.substrait.plan.Plan;
import io.substrait.plan.ProtoPlanConverter;

public class App {

  /** Contract implemented by every runnable example in this package. */
  public interface Action {
    /**
     * Runs the example.
     *
     * @param arg optional argument (e.g. a plan file name); may be {@code null}
     */
    void run(String arg);
  }

  private App() {
  }

  /**
   * Entry point: reflectively loads the example class named by the first CLI
   * argument (defaulting to {@code SparkDataset}) from this package and runs
   * it, forwarding the optional second argument.
   *
   * @param args {@code args[0]} simple class name of the example,
   *     {@code args[1]} optional argument passed to {@link Action#run}
   */
  public static void main(String[] args) {
    try {

      if (args.length == 0) {
        args = new String[] { "SparkDataset" };
      }
      String exampleClass = args[0];

      var clz = Class.forName(App.class.getPackageName() + "." + exampleClass);
      var action = (Action) clz.getDeclaredConstructor().newInstance();

      // null signals "no argument" to the example
      action.run(args.length == 2 ? args[1] : null);

    } catch (Exception e) {
      // Examples are run interactively; report the failure and return.
      e.printStackTrace();
    }
  }

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package io.substrait.examples;

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Paths;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan;

import io.substrait.plan.Plan;
import io.substrait.plan.ProtoPlanConverter;
import io.substrait.spark.logical.ToLogicalPlan;

import static io.substrait.examples.SparkHelper.ROOT_DIR;

/** Minimal Spark application */
public class SparkConsumeSubstrait implements App.Action {

  public SparkConsumeSubstrait() {
  }

  /**
   * Reads a binary Substrait plan from disk, converts it back into a Spark
   * logical plan, executes it, and prints the resulting rows.
   *
   * @param arg file name of the serialized plan, resolved under ROOT_DIR
   */
  @Override
  public void run(String arg) {

    // Connect to a local in-process Spark instance
    try (SparkSession spark = SparkHelper.connectLocalSpark()) {

      System.out.println("Reading from " + arg);
      byte[] planBytes = Files.readAllBytes(Paths.get(ROOT_DIR, arg));

      // protobuf bytes -> proto Plan -> Substrait POJO plan
      io.substrait.proto.Plan protoPlan = io.substrait.proto.Plan.parseFrom(planBytes);
      Plan substraitPlan = new ProtoPlanConverter().from(protoPlan);

      // Translate the Substrait plan into Spark's logical plan
      LogicalPlan logicalPlan = new ToLogicalPlan(spark).convert(substraitPlan);

      System.out.println(logicalPlan);

      // Execute the plan and display the rows
      Dataset.ofRows(spark, logicalPlan).show();

      spark.stop();
    } catch (IOException e) {
      e.printStackTrace();
    }
  }

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
package io.substrait.examples;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan;
import java.io.IOException;
import java.nio.file.*;
import io.substrait.plan.PlanProtoConverter;
import io.substrait.spark.logical.ToSubstraitRel;
import static io.substrait.examples.SparkHelper.ROOT_DIR;
import static io.substrait.examples.SparkHelper.TESTS_CSV;
import static io.substrait.examples.SparkHelper.VEHICLES_CSV;

/** Minimal Spark application */
public class SparkDataset implements App.Action {

  // Output file name for the serialized Substrait plan
  private static final String PLAN_FILE = "spark_dataset_substrait.plan";

  public SparkDataset() {

  }

  /**
   * Loads the vehicles/tests CSV files with the DataFrame API, joins them to
   * count passed tests per vehicle colour, then serializes Spark's optimized
   * logical plan to a Substrait plan file.
   *
   * @param arg unused
   */
  @Override
  public void run(String arg) {

    // Connect to a local in-process Spark instance
    try (SparkSession spark = SparkHelper.connectLocalSpark()) {

      Dataset<Row> dsVehicles;
      Dataset<Row> dsTests;

      // load from CSV files
      String vehiclesFile = Paths.get(ROOT_DIR, VEHICLES_CSV).toString();
      String testsFile = Paths.get(ROOT_DIR, TESTS_CSV).toString();

      System.out.println("Reading " + vehiclesFile);
      System.out.println("Reading " + testsFile);

      dsVehicles = spark.read().option("delimiter", ",").option("header", "true").csv(vehiclesFile);
      dsVehicles.show();

      dsTests = spark.read().option("delimiter", ",").option("header", "true").csv(testsFile);
      dsTests.show();

      // create the joined dataset: count of passed ("P") tests per colour
      Dataset<Row> joinedDs = dsVehicles.join(dsTests, dsVehicles.col("vehicle_id").equalTo(dsTests.col("vehicle_id")))
          .filter(dsTests.col("test_result").equalTo("P"))
          .groupBy(dsVehicles.col("colour"))
          .count();

      joinedDs = joinedDs.orderBy(joinedDs.col("count"));
      joinedDs.show();

      // the optimized logical plan is the input to the Substrait converter
      LogicalPlan plan = joinedDs.queryExecution().optimizedPlan();

      System.out.println(plan);
      createSubstrait(plan);

      spark.stop();
    } catch (Exception e) {
      e.printStackTrace(System.out);
    }
  }

  /**
   * Converts a Spark logical plan to a Substrait plan and writes the
   * protobuf-serialized bytes to {@code PLAN_FILE} under ROOT_DIR.
   *
   * @param enginePlan optimized Spark logical plan
   */
  public void createSubstrait(LogicalPlan enginePlan) {
    ToSubstraitRel toSubstrait = new ToSubstraitRel();
    io.substrait.plan.Plan plan = toSubstrait.convert(enginePlan);

    System.out.println(plan);

    PlanProtoConverter planToProto = new PlanProtoConverter();
    byte[] buffer = planToProto.toProto(plan).toByteArray();
    try {
      // BUG FIX: the log message previously reported "spark_sql_substrait.plan"
      // while the data was written to "spark_dataset_substrait.plan"; a single
      // Path keeps the written file and the message in sync.
      var outFile = Paths.get(ROOT_DIR, PLAN_FILE);
      Files.write(outFile, buffer);
      System.out.println("File written to " + outFile);
    } catch (IOException e) {
      e.printStackTrace(System.out);
    }
  }

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package io.substrait.examples;

import org.apache.spark.sql.SparkSession;

/** Shared constants and SparkSession factories for the examples. */
public final class SparkHelper {

  // Utility class: constants and static factories only, never instantiated
  private SparkHelper() {
  }

  public static final String NAMESPACE = "demo_db";
  public static final String VEHICLE_TABLE = "vehicles";
  public static final String TESTS_TABLE = "tests";

  public static final String VEHICLES_PQ = "vehicles_subset_2023.parquet";
  public static final String TESTS_PQ = "tests_subset_2023.parquet";

  public static final String VEHICLES_CSV = "vehicles_subset_2023.csv";
  public static final String TESTS_CSV = "tests_subset_2023.csv";

  /** Directory the examples read CSV data from and write plan files to. */
  public static final String ROOT_DIR = "/opt/spark-data";

  /**
   * Connects to the given Spark master for demo purposes.
   *
   * @param sparkMaster Spark master URL (e.g. {@code local})
   * @return a SparkSession with Hive support and ERROR-level logging
   */
  public static SparkSession connectSpark(String sparkMaster) {

    SparkSession spark = SparkSession.builder()
        // .config("spark.sql.warehouse.dir", "spark-warehouse")
        .config("spark.master", sparkMaster)
        .enableHiveSupport()
        .getOrCreate();

    // keep example output readable
    spark.sparkContext().setLogLevel("ERROR");

    return spark;
  }

  /**
   * Connects to a local in-process Spark instance (master comes from the
   * {@code spark.master} JVM argument set in build.gradle).
   *
   * @return a SparkSession with Hive support and ERROR-level logging
   */
  public static SparkSession connectLocalSpark() {

    SparkSession spark = SparkSession.builder()
        .enableHiveSupport()
        .getOrCreate();

    // keep example output readable
    spark.sparkContext().setLogLevel("ERROR");

    return spark;
  }

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
package io.substrait.examples;

import static io.substrait.examples.SparkHelper.ROOT_DIR;
import static io.substrait.examples.SparkHelper.TESTS_CSV;
import static io.substrait.examples.SparkHelper.TESTS_TABLE;
import static io.substrait.examples.SparkHelper.VEHICLES_CSV;
import static io.substrait.examples.SparkHelper.VEHICLE_TABLE;

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Paths;

import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan;

import io.substrait.plan.PlanProtoConverter;
import io.substrait.spark.logical.ToSubstraitRel;

/** Minimal Spark application */
public class SparkSQL implements App.Action {

  // Output file name for the serialized Substrait plan
  private static final String PLAN_FILE = "spark_sql_substrait.plan";

  public SparkSQL() {

  }

  /**
   * Registers the vehicles/tests CSV files as temp views, runs a SQL query
   * counting passed tests per vehicle colour, then serializes the optimized
   * logical plan to a Substrait plan file.
   *
   * @param arg unused
   */
  @Override
  public void run(String arg) {

    // Connect to a local in-process Spark instance
    try (SparkSession spark = SparkHelper.connectLocalSpark()) {
      spark.catalog().listDatabases().show();

      // load from CSV files
      String vehiclesFile = Paths.get(ROOT_DIR, VEHICLES_CSV).toString();
      String testsFile = Paths.get(ROOT_DIR, TESTS_CSV).toString();

      System.out.println("Reading " + vehiclesFile);
      System.out.println("Reading " + testsFile);

      spark.read().option("delimiter", ",").option("header", "true").csv(vehiclesFile)
          .createOrReplaceTempView(VEHICLE_TABLE);
      spark.read().option("delimiter", ",").option("header", "true").csv(testsFile)
          .createOrReplaceTempView(TESTS_TABLE);

      String sqlQuery = """
          SELECT vehicles.colour, count(*) as colourcount
          FROM vehicles
          INNER JOIN tests ON vehicles.vehicle_id=tests.vehicle_id
          WHERE tests.test_result = 'P'
          GROUP BY vehicles.colour
          ORDER BY count(*)
          """;

      var result = spark.sql(sqlQuery);
      result.show();

      LogicalPlan logical = result.logicalPlan();
      System.out.println(logical);

      // the optimized plan is the input to the Substrait converter
      LogicalPlan optimised = result.queryExecution().optimizedPlan();
      System.out.println(optimised);

      createSubstrait(optimised);
      spark.stop();
    } catch (Exception e) {
      e.printStackTrace(System.out);
    }
  }

  /**
   * Converts a Spark logical plan to a Substrait plan and writes the
   * protobuf-serialized bytes to {@code PLAN_FILE} under ROOT_DIR.
   *
   * @param enginePlan optimized Spark logical plan
   */
  public void createSubstrait(LogicalPlan enginePlan) {
    ToSubstraitRel toSubstrait = new ToSubstraitRel();
    io.substrait.plan.Plan plan = toSubstrait.convert(enginePlan);
    System.out.println(plan);

    PlanProtoConverter planToProto = new PlanProtoConverter();
    byte[] buffer = planToProto.toProto(plan).toByteArray();
    try {
      // one Path keeps the written file and the log message in sync
      var outFile = Paths.get(ROOT_DIR, PLAN_FILE);
      Files.write(outFile, buffer);
      System.out.println("File written to " + outFile);

    } catch (IOException e) {
      // consistent with run(): report errors on stdout
      e.printStackTrace(System.out);
    }
  }

}
Loading

0 comments on commit 21e97f2

Please sign in to comment.