[spark] Support varchar/char type (apache#3361)
Zouxxyy authored May 21, 2024
1 parent 7579c83 commit 36dcebc
Showing 10 changed files with 121 additions and 13 deletions.
12 changes: 11 additions & 1 deletion docs/content/spark/quick-start.md
@@ -290,7 +290,17 @@ All Spark's data types are available in package `org.apache.spark.sql.types`.
</tr>
<tr>
<td><code>StringType</code></td>
<td><code>VarCharType</code>, <code>CharType</code></td>
<td><code>VarCharType(Integer.MAX_VALUE)</code></td>
<td>true</td>
</tr>
<tr>
<td><code>VarCharType(length)</code></td>
<td><code>VarCharType(length)</code></td>
<td>true</td>
</tr>
<tr>
<td><code>CharType(length)</code></td>
<td><code>CharType(length)</code></td>
<td>true</td>
</tr>
<tr>
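A minimal sketch of how the updated mapping behaves from Spark SQL, assuming a `spark` session with the Paimon catalog configured (the table and column names below are hypothetical): `CHAR(n)` and `VARCHAR(n)` are now preserved with their declared lengths, while a plain `STRING` column still maps to Paimon's `VarCharType(Integer.MAX_VALUE)`.

```scala
// Sketch only: char/varchar lengths declared in Spark SQL are kept in the Paimon schema.
spark.sql(
  """
    |CREATE TABLE char_demo (
    |  c1 CHAR(5),     -- Paimon CharType(5)
    |  c2 VARCHAR(10), -- Paimon VarCharType(10)
    |  c3 STRING       -- Paimon VarCharType(Integer.MAX_VALUE)
    |) USING paimon
    |""".stripMargin)

// DESC now reports char(5)/varchar(10) for c1/c2 instead of plain string.
spark.sql("DESC char_demo").show()
```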
@@ -40,6 +40,7 @@
import org.apache.spark.sql.types.BooleanType;
import org.apache.spark.sql.types.ByteType;
import org.apache.spark.sql.types.CalendarIntervalType;
import org.apache.spark.sql.types.CharType;
import org.apache.spark.sql.types.DateType;
import org.apache.spark.sql.types.Decimal;
import org.apache.spark.sql.types.DecimalType;
@@ -53,6 +54,7 @@
import org.apache.spark.sql.types.StructType;
import org.apache.spark.sql.types.TimestampType;
import org.apache.spark.sql.types.UserDefinedType;
import org.apache.spark.sql.types.VarcharType;
import org.apache.spark.unsafe.types.CalendarInterval;
import org.apache.spark.unsafe.types.UTF8String;

@@ -205,7 +207,9 @@ public Object get(int ordinal, org.apache.spark.sql.types.DataType dataType) {
if (dataType instanceof DoubleType) {
return getDouble(ordinal);
}
if (dataType instanceof StringType) {
if (dataType instanceof StringType
|| dataType instanceof CharType
|| dataType instanceof VarcharType) {
return getUTF8String(ordinal);
}
if (dataType instanceof DecimalType) {
@@ -75,12 +75,16 @@ private static class PaimonToSparkTypeVisitor extends DataTypeDefaultVisitor<Dat

@Override
public DataType visit(CharType charType) {
return DataTypes.StringType;
return new org.apache.spark.sql.types.CharType(charType.getLength());
}

@Override
public DataType visit(VarCharType varCharType) {
return DataTypes.StringType;
if (varCharType.getLength() == VarCharType.MAX_LENGTH) {
return DataTypes.StringType;
} else {
return new org.apache.spark.sql.types.VarcharType(varCharType.getLength());
}
}

@Override
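A rough sketch of the conversion rule introduced here (the `SparkTypeUtils.fromPaimonType` entry point is taken from the test further down; its package and exact signature are assumptions): bounded char/varchar keep their length, and only the unbounded varchar still collapses to Spark's `StringType`.

```scala
import org.apache.paimon.spark.SparkTypeUtils // package assumed
import org.apache.paimon.types.DataTypes

object CharVarcharMappingSketch {
  def main(args: Array[String]): Unit = {
    // Bounded types keep their declared length on the Spark side...
    println(SparkTypeUtils.fromPaimonType(DataTypes.CHAR(10)))    // CharType(10)
    println(SparkTypeUtils.fromPaimonType(DataTypes.VARCHAR(10))) // VarcharType(10)
    // ...while Paimon STRING, i.e. VARCHAR(Integer.MAX_VALUE), still maps to StringType.
    println(SparkTypeUtils.fromPaimonType(DataTypes.STRING()))    // StringType
  }
}
```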
@@ -23,10 +23,11 @@ import org.apache.paimon.operation.FileStoreCommit
import org.apache.paimon.table.FileStoreTable
import org.apache.paimon.table.sink.BatchWriteBuilder
import org.apache.paimon.types.RowType
import org.apache.paimon.utils.{FileStorePathFactory, RowDataPartitionComputer}
import org.apache.paimon.utils.RowDataPartitionComputer

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.util.CharVarcharUtils
import org.apache.spark.sql.connector.catalog.SupportsPartitionManagement
import org.apache.spark.sql.types.StructType

@@ -51,7 +52,7 @@ trait PaimonPartitionManagement extends SupportsPartitionManagement {
override def dropPartition(internalRow: InternalRow): Boolean = {
// convert internalRow to row
val row: Row = CatalystTypeConverters
.createToScalaConverter(partitionSchema())
.createToScalaConverter(CharVarcharUtils.replaceCharVarcharWithString(partitionSchema()))
.apply(internalRow)
.asInstanceOf[Row]
val rowDataPartitionComputer = new RowDataPartitionComputer(
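The schema is wrapped because `CatalystTypeConverters` has no converter for `CharType`/`VarcharType` (see the comment added in the converter test below); `replaceCharVarcharWithString` swaps those types for `StringType` before the converter is built. A minimal sketch with a hypothetical partition schema:

```scala
import org.apache.spark.sql.catalyst.util.CharVarcharUtils
import org.apache.spark.sql.types.{CharType, StructField, StructType}

// Hypothetical partition schema with a char column.
val partitionSchema = StructType(Seq(StructField("pt", CharType(2))))

// CharType(2) becomes StringType, so CatalystTypeConverters can build a converter for it.
val converted = CharVarcharUtils.replaceCharVarcharWithString(partitionSchema)
// converted: StructType(StructField(pt,StringType,true))
```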
@@ -28,6 +28,7 @@
import org.apache.paimon.utils.DateTimeUtils;

import org.apache.spark.sql.catalyst.CatalystTypeConverters;
import org.apache.spark.sql.catalyst.util.CharVarcharUtils;
import org.junit.jupiter.api.Test;

import java.math.BigDecimal;
@@ -54,6 +55,8 @@ public void test() {
GenericRow.of(
1,
fromString("jingsong"),
fromString("apache"),
fromString("paimon"),
22.2,
new GenericMap(
Stream.of(
@@ -79,9 +82,12 @@
Decimal.fromBigDecimal(BigDecimal.valueOf(65782123123.01), 38, 2),
Decimal.fromBigDecimal(BigDecimal.valueOf(62123123.5), 10, 1));

// CatalystTypeConverters does not support char and varchar, so we replace them with string
// before converting
Function1<Object, Object> sparkConverter =
CatalystTypeConverters.createToScalaConverter(
SparkTypeUtils.fromPaimonType(ALL_TYPES));
CharVarcharUtils.replaceCharVarcharWithString(
SparkTypeUtils.fromPaimonType(ALL_TYPES)));
org.apache.spark.sql.Row sparkRow =
(org.apache.spark.sql.Row)
sparkConverter.apply(new SparkInternalRow(ALL_TYPES).replace(rowData));
@@ -90,6 +96,8 @@
"{"
+ "\"id\":1,"
+ "\"name\":\"jingsong\","
+ "\"char\":\"apache\","
+ "\"varchar\":\"paimon\","
+ "\"salary\":22.2,"
+ "\"locations\":{\"key1\":{\"posX\":1.2,\"posY\":2.3},\"key2\":{\"posX\":2.4,\"posY\":3.5}},"
+ "\"strArray\":[\"v1\",\"v5\"],"
@@ -177,7 +177,8 @@ public void testCreateTableAs() {
spark.sql("CREATE TABLE testCreateTableAs AS SELECT * FROM testCreateTable");
List<Row> result = spark.sql("SELECT * FROM testCreateTableAs").collectAsList();

assertThat(result.stream().map(Row::toString)).containsExactlyInAnyOrder("[1,a,b]");
assertThat(result.stream().map(Row::toString))
.containsExactlyInAnyOrder("[1,a,b ]");

// partitioned table
spark.sql(
@@ -224,11 +225,13 @@ public void testCreateTableAs() {
+ " 'file.format' = 'parquet',\n"
+ " 'path' = '%s')\n"
+ "]]",
showCreateString("testTableAs", "a BIGINT", "b STRING", "c STRING"),
showCreateString(
"testTableAs", "a BIGINT", "b VARCHAR(10)", "c CHAR(10)"),
new Path(warehousePath, "default.db/testTableAs")));
List<Row> resultProp = spark.sql("SELECT * FROM testTableAs").collectAsList();

assertThat(resultProp.stream().map(Row::toString)).containsExactlyInAnyOrder("[1,a,b]");
assertThat(resultProp.stream().map(Row::toString))
.containsExactlyInAnyOrder("[1,a,b ]");

// primary key
spark.sql(
@@ -39,6 +39,8 @@ public class SparkTypeTest {
1)) // posX and posY have field id 0 and 1, here we start from 2
.field("id", DataTypes.INT().notNull())
.field("name", DataTypes.STRING()) /* optional by default */
.field("char", DataTypes.CHAR(10))
.field("varchar", DataTypes.VARCHAR(10))
.field("salary", DataTypes.DOUBLE().notNull())
.field(
"locations",
@@ -79,6 +81,8 @@ public void testAllTypes() {
"StructType("
+ "StructField(id,IntegerType,true),"
+ "StructField(name,StringType,true),"
+ "StructField(char,CharType(10),true),"
+ "StructField(varchar,VarcharType(10),true),"
+ "StructField(salary,DoubleType,true),"
+ nestedRowMapType
+ ","
@@ -20,6 +20,7 @@ package org.apache.paimon.spark.sql

import org.apache.paimon.spark.PaimonSparkTestBase

import org.apache.spark.sql.Row
import org.junit.jupiter.api.Assertions

abstract class DDLTestBase extends PaimonSparkTestBase {
@@ -84,4 +85,70 @@ abstract class DDLTestBase extends PaimonSparkTestBase {
"SparkCatalog can only create paimon table, but current provider is parquet"))
}
}

test("Paimon DDL: create table with char/varchar/string") {
Seq("orc", "avro").foreach(
format => {
withTable("paimon_tbl") {
spark.sql(
s"""
|CREATE TABLE paimon_tbl (id int, col_s1 char(9), col_s2 varchar(10), col_s3 string)
|USING PAIMON
|TBLPROPERTIES ('file.format' = '$format')
|""".stripMargin)

spark.sql(s"""
|insert into paimon_tbl values
|(1, 'Wednesday', 'Wednesday', 'Wednesday'),
|(2, 'Friday', 'Friday', 'Friday')
|""".stripMargin)

// check description
checkAnswer(
spark
.sql(s"DESC paimon_tbl")
.select("col_name", "data_type")
.where("col_name LIKE 'col_%'")
.orderBy("col_name"),
Row("col_s1", "char(9)") :: Row("col_s2", "varchar(10)") :: Row(
"col_s3",
"string") :: Nil
)

// check select
if (format == "orc" && !gteqSpark3_4) {
// The ORC reader right-trims char values, e.g. "Friday   " => "Friday" (see ORC's `CharTreeReader`),
// and Spark only pads char columns on the read side since 3.4, via `spark.sql.readSideCharPadding` (default true).
// So with ORC and Spark < 3.4, this query returns "Friday".
checkAnswer(
spark.sql(s"select col_s1 from paimon_tbl where id = 2"),
Row("Friday") :: Nil
)
// For char columns Spark pads the comparison literal, producing a filter like
// Filter(isnotnull(col_s1#124) AND (col_s1#124 = Friday   )), so no rows are returned here.
checkAnswer(
spark.sql(s"select col_s1 from paimon_tbl where col_s1 = 'Friday'"),
Nil
)
} else {
checkAnswer(
spark.sql(s"select col_s1 from paimon_tbl where id = 2"),
Row("Friday ") :: Nil
)
checkAnswer(
spark.sql(s"select col_s1 from paimon_tbl where col_s1 = 'Friday'"),
Row("Friday ") :: Nil
)
}
checkAnswer(
spark.sql(s"select col_s2 from paimon_tbl where col_s2 = 'Friday'"),
Row("Friday") :: Nil
)
checkAnswer(
spark.sql(s"select col_s3 from paimon_tbl where col_s3 = 'Friday'"),
Row("Friday") :: Nil
)
}
})
}
}
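As a footnote to the ORC branch above, the padding behaviour the comments describe can be toggled directly via `spark.sql.readSideCharPadding` on Spark 3.4+; a sketch reusing the test's table (results shown assume the ORC format):

```scala
// Read-side padding (default true since Spark 3.4): char(9) values come back right-padded.
spark.conf.set("spark.sql.readSideCharPadding", "true")
spark.sql("SELECT col_s1 FROM paimon_tbl WHERE id = 2").collect() // Row("Friday   ")

// Turning it off reproduces the pre-3.4 behaviour for formats that trim char values, such as ORC.
spark.conf.set("spark.sql.readSideCharPadding", "false")
spark.sql("SELECT col_s1 FROM paimon_tbl WHERE id = 2").collect() // Row("Friday")
```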
@@ -161,12 +161,17 @@ class PaimonPartitionManagementTest extends PaimonSparkTestBase {

checkAnswer(
spark.sql("select * from T"),
Row("a", "b", 1L, 20230816L, "1132") :: Row("a", "b", 1L, 20230816L, "1133") :: Row(
Row("a", "b ", 1L, 20230816L, "1132") :: Row(
"a",
"b",
"b ",
1L,
20230816L,
"1133") :: Row("a", "b ", 2L, 20230817L, "1132") :: Row(
"a",
"b ",
2L,
20230817L,
"1132") :: Row("a", "b", 2L, 20230817L, "1134") :: Nil
"1134") :: Nil
)
}
}
@@ -24,4 +24,6 @@ trait SparkVersionSupport {
lazy val sparkVersion: String = SPARK_VERSION

lazy val gteqSpark3_3: Boolean = sparkVersion >= "3.3"

lazy val gteqSpark3_4: Boolean = sparkVersion >= "3.4"
}
