Added logic to support reading unknown collation name as utf8_binary
vladanvasi-db committed Nov 5, 2024
commit b95f6ff (parent 46fe10a)
Showing 5 changed files with 143 additions and 3 deletions.
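In short: when spark.sql.collation.unknown.enabled is set, deserializing a data type whose collation metadata names a collation the current build does not recognize now falls back to UTF8_BINARY instead of failing with COLLATION_INVALID_NAME. Below is a minimal sketch of the intended effect, assuming a build that contains this change; the literal "__COLLATIONS" metadata key and the made-up collation name are illustrative only (the tests further down use DataType.COLLATIONS_METADATA_KEY).

import org.apache.spark.SparkException
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.DataType

object UnknownCollationFallbackSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("unknown-collation-fallback")
      .getOrCreate()

    // Schema JSON whose collation metadata names a collation this build does not know.
    val json =
      """{
        |  "type": "struct",
        |  "fields": [{
        |    "name": "c1",
        |    "type": "string",
        |    "nullable": true,
        |    "metadata": { "__COLLATIONS": { "c1": "icu.SOME_FUTURE_COLLATION" } }
        |  }]
        |}""".stripMargin

    // Default behavior: strict failure with the COLLATION_INVALID_NAME condition.
    try DataType.fromJson(json)
    catch { case e: SparkException => println(e.getCondition) } // COLLATION_INVALID_NAME

    // With the new flag enabled, the unknown name quietly degrades to UTF8_BINARY.
    spark.conf.set("spark.sql.collation.unknown.enabled", "true")
    println(DataType.fromJson(json)) // parses; c1 becomes a plain UTF8_BINARY string

    spark.stop()
  }
}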
@@ -46,6 +46,7 @@ private[sql] trait SqlApiConf {
def defaultStringType: StringType
def stackTracesInDataFrameContext: Int
def legacyAllowUntypedScalaUDFs: Boolean
def unknownCollationNameEnabled: Boolean
}

private[sql] object SqlApiConf {
@@ -58,6 +59,7 @@ private[sql] object SqlApiConf {
SqlApiConfHelper.LOCAL_RELATION_CACHE_THRESHOLD_KEY
}
val DEFAULT_COLLATION: String = SqlApiConfHelper.DEFAULT_COLLATION
val UNKNOWN_COLLATION_NAME_ENABLED: String = SqlApiConfHelper.UNKNOWN_COLLATION_NAME_ENABLED

def get: SqlApiConf = SqlApiConfHelper.getConfGetter.get()()

@@ -85,4 +87,5 @@ private[sql] object DefaultSqlApiConf extends SqlApiConf {
override def defaultStringType: StringType = StringType
override def stackTracesInDataFrameContext: Int = 1
override def legacyAllowUntypedScalaUDFs: Boolean = false
override def unknownCollationNameEnabled: Boolean = false
}
@@ -33,6 +33,7 @@ private[sql] object SqlApiConfHelper {
val SESSION_LOCAL_TIMEZONE_KEY: String = "spark.sql.session.timeZone"
val LOCAL_RELATION_CACHE_THRESHOLD_KEY: String = "spark.sql.session.localRelationCacheThreshold"
val DEFAULT_COLLATION: String = "spark.sql.session.collation.default"
val UNKNOWN_COLLATION_NAME_ENABLED: String = "spark.sql.collation.unknown.enabled"

val confGetter: AtomicReference[() => SqlApiConf] = {
new AtomicReference[() => SqlApiConf](() => DefaultSqlApiConf)
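To connect the pieces: SqlApiConf.get resolves through the getter registered in SqlApiConfHelper, so code in the sql/api module (such as DataType below) can consult the session value of the new flag without depending on SQLConf directly; when no session has registered a getter, DefaultSqlApiConf keeps the flag off. A rough illustrative sketch follows — the object name is hypothetical, and these are private[sql] APIs, so it only compiles inside the org.apache.spark.sql packages:

package org.apache.spark.sql

import org.apache.spark.sql.internal.SqlApiConf

private[sql] object UnknownCollationFlagSketch {
  // How sql/api code reads the flag (this is what DataType.stringTypeWithCollation does below).
  def fallbackEnabled: Boolean = SqlApiConf.get.unknownCollationNameEnabled

  // Without a SparkSession-registered conf getter, SqlApiConfHelper.confGetter still points at
  // DefaultSqlApiConf, whose override returns false, so the strict behavior stays the default.
}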
11 changes: 9 additions & 2 deletions sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala
@@ -27,7 +27,7 @@ import org.json4s.JsonAST.JValue
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._

import org.apache.spark.{SparkIllegalArgumentException, SparkThrowable}
import org.apache.spark.{SparkException, SparkIllegalArgumentException, SparkThrowable}
import org.apache.spark.annotation.Stable
import org.apache.spark.sql.catalyst.analysis.SqlApiAnalysis
import org.apache.spark.sql.catalyst.parser.DataTypeParser
@@ -350,7 +350,14 @@ object DataType {
}

private def stringTypeWithCollation(collationName: String): StringType = {
StringType(CollationFactory.collationNameToId(collationName))
try {
StringType(CollationFactory.collationNameToId(collationName))
} catch {
case e: SparkException if e.getCondition == "COLLATION_INVALID_NAME" &&
SqlApiConf.get.unknownCollationNameEnabled =>
StringType(CollationFactory.UTF8_BINARY_COLLATION_ID)
}
}

protected[types] def buildFormattedString(
@@ -770,6 +770,14 @@ object SQLConf {
.booleanConf
.createWithDefault(Utils.isTesting)

lazy val UNKNOWN_COLLATION_NAME_ENABLED =
buildConf(SqlApiConfHelper.UNKNOWN_COLLATION_NAME_ENABLED)
.internal()
.doc("Enables spark to read unknown collation name as UTF8_BINARY.")
.version("4.0.0")
.booleanConf
.createWithDefault(false)

val DEFAULT_COLLATION =
buildConf(SqlApiConfHelper.DEFAULT_COLLATION)
.doc("Sets default collation to use for string literals, parameter markers or the string" +
@@ -5522,6 +5530,8 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
}
}

override def unknownCollationNameEnabled: Boolean = getConf(UNKNOWN_COLLATION_NAME_ENABLED)

def adaptiveExecutionEnabled: Boolean = getConf(ADAPTIVE_EXECUTION_ENABLED)

def adaptiveExecutionLogLevel: String = getConf(ADAPTIVE_EXECUTION_LOG_LEVEL)
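For completeness, a hedged sketch of toggling the new conf per session (assuming a SparkSession in scope as spark, e.g. in spark-shell; the conf is internal and defaults to false, as registered above):

// SQL form: enable the fallback for the current session only.
spark.sql("SET spark.sql.collation.unknown.enabled=true")

// Scala equivalent, plus reading the effective value back.
spark.conf.set("spark.sql.collation.unknown.enabled", "true")
println(spark.conf.get("spark.sql.collation.unknown.enabled")) // "true"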
@@ -23,11 +23,13 @@ import org.json4s.jackson.JsonMethods
import org.apache.spark.{SparkException, SparkFunSuite, SparkIllegalArgumentException}
import org.apache.spark.sql.catalyst.analysis.{caseInsensitiveResolution, caseSensitiveResolution}
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.catalyst.plans.SQLHelper
import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.catalyst.util.{CollationFactory, StringConcat}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.DataTypeTestUtils.{dayTimeIntervalTypes, yearMonthIntervalTypes}

class DataTypeSuite extends SparkFunSuite {
class DataTypeSuite extends SparkFunSuite with SQLHelper {

private val UNICODE_COLLATION_ID = CollationFactory.collationNameToId("UNICODE")

@@ -876,6 +878,49 @@ class DataTypeSuite extends SparkFunSuite {
}
}

test("string field with invalid collation name") {
val collationProviders = Seq("spark", "icu")
collationProviders.foreach { provider =>
val json =
s"""
|{
| "type": "struct",
| "fields": [
| {
| "name": "c1",
| "type": "string",
| "nullable": false,
| "metadata": {
| "${DataType.COLLATIONS_METADATA_KEY}": {
| "c1": "$provider.INVALID"
| }
| }
| }
| ]
|}
|""".stripMargin

// Parsing should fail with COLLATION_INVALID_NAME while the
// UNKNOWN_COLLATION_NAME_ENABLED conf is disabled.
checkError(
exception = intercept[SparkException] {
DataType.fromJson(json)
},
condition = "COLLATION_INVALID_NAME",
parameters = Map(
"proposals" -> "id",
"collationName" -> "INVALID"))

// With UNKNOWN_COLLATION_NAME_ENABLED set, the invalid name should not throw;
// it should fall back to UTF8_BINARY instead.
withSQLConf(SQLConf.UNKNOWN_COLLATION_NAME_ENABLED.key -> "true") {
val dataType = DataType.fromJson(json)
assert(dataType === StructType(
StructField("c1", StringType(CollationFactory.UTF8_BINARY_COLLATION_ID), false) :: Nil))
}
}
}

test("non string field has collation metadata") {
val json =
s"""
@@ -1023,6 +1068,42 @@
assert(parsedWithCollations === ArrayType(StringType(unicodeCollationId)))
}

test("parse array type with invalid collation metadata") {
val utf8BinaryCollationId = CollationFactory.UTF8_BINARY_COLLATION_ID
val arrayJson =
s"""
|{
| "type": "array",
| "elementType": "string",
| "containsNull": true
|}
|""".stripMargin

val collationsMap = Map("element" -> "INVALID")

// Parse without collations map
assert(DataType.parseDataType(JsonMethods.parse(arrayJson)) === ArrayType(StringType))

// Parsing should fail with COLLATION_INVALID_NAME while the
// UNKNOWN_COLLATION_NAME_ENABLED conf is disabled.
checkError(
exception = intercept[SparkException] {
DataType.parseDataType(JsonMethods.parse(arrayJson), collationsMap = collationsMap)
},
condition = "COLLATION_INVALID_NAME",
parameters = Map(
"proposals" -> "id",
"collationName" -> "INVALID"))

// With UNKNOWN_COLLATION_NAME_ENABLED set, the invalid name should not throw;
// it should fall back to UTF8_BINARY instead.
withSQLConf(SQLConf.UNKNOWN_COLLATION_NAME_ENABLED.key -> "true") {
val dataType = DataType.parseDataType(
JsonMethods.parse(arrayJson), collationsMap = collationsMap)
assert(dataType === ArrayType(StringType(utf8BinaryCollationId)))
}
}

test("parse map type with collation metadata") {
val unicodeCollationId = CollationFactory.collationNameToId("UNICODE")
val mapJson =
@@ -1046,6 +1127,44 @@
MapType(StringType(unicodeCollationId), StringType(unicodeCollationId)))
}

test("parse map type with invalid collation metadata") {
val utf8BinaryCollationId = CollationFactory.UTF8_BINARY_COLLATION_ID
val mapJson =
s"""
|{
| "type": "map",
| "keyType": "string",
| "valueType": "string",
| "valueContainsNull": true
|}
|""".stripMargin

val collationsMap = Map("key" -> "INVALID", "value" -> "INVALID")

// Parse without collations map
assert(DataType.parseDataType(JsonMethods.parse(mapJson)) === MapType(StringType, StringType))

// Parsing should fail with COLLATION_INVALID_NAME while the
// UNKNOWN_COLLATION_NAME_ENABLED conf is disabled.
checkError(
exception = intercept[SparkException] {
DataType.parseDataType(JsonMethods.parse(mapJson), collationsMap = collationsMap)
},
condition = "COLLATION_INVALID_NAME",
parameters = Map(
"proposals" -> "id",
"collationName" -> "INVALID"))

// With UNKNOWN_COLLATION_NAME_ENABLED set, the invalid name should not throw;
// it should fall back to UTF8_BINARY instead.
withSQLConf(SQLConf.UNKNOWN_COLLATION_NAME_ENABLED.key -> "true") {
val dataType = DataType.parseDataType(
JsonMethods.parse(mapJson), collationsMap = collationsMap)
assert(dataType === MapType(
StringType(utf8BinaryCollationId), StringType(utf8BinaryCollationId)))
}
}

test("SPARK-48680: Add CharType and VarcharType to DataTypes JAVA API") {
assert(DataTypes.createCharType(1) === CharType(1))
assert(DataTypes.createVarcharType(100) === VarcharType(100))