Add 3.5.1-SNAPSHOT Shim #9962

Merged 17 commits on Dec 25, 2023
Changes from 15 commits

@@ -34,6 +34,7 @@
{"spark": "341"}
{"spark": "341db"}
{"spark": "350"}
{"spark": "351"}
spark-rapids-shim-json-lines ***/
package org.apache.spark.sql.tests.datagen

5 changes: 3 additions & 2 deletions integration_tests/src/main/python/collection_ops_test.py
@@ -20,7 +20,7 @@
from string_test import mk_str_gen
import pyspark.sql.functions as f
import pyspark.sql.utils
from spark_session import with_cpu_session, with_gpu_session
from spark_session import with_cpu_session, with_gpu_session, is_before_spark_351
from conftest import get_datagen_seed
from marks import allow_non_gpu

@@ -326,11 +326,12 @@ def test_sequence_illegal_boundaries(start_gen, stop_gen, step_gen):
@pytest.mark.parametrize('stop_gen', sequence_too_long_length_gens, ids=idfn)
@allow_non_gpu(*non_utc_allow)
def test_sequence_too_long_sequence(stop_gen):
msg = "Too long sequence" if is_before_spark_351() else "Unsuccessful try to create array with"
assert_gpu_and_cpu_error(
# To avoid OOM, reduce the row number to 1, it is enough to verify this case.
lambda spark:unary_op_df(spark, stop_gen, 1).selectExpr(
"sequence(0, a)").collect(),
conf = {}, error_message = "Too long sequence")
conf = {}, error_message = msg)

def get_sequence_cases_mixed_df(spark, length=2048):
# Generate the sequence data following the 3 rules mixed in a single dataset.

3 changes: 3 additions & 0 deletions integration_tests/src/main/python/spark_session.py
@@ -196,6 +196,9 @@ def is_before_spark_341():
def is_before_spark_350():
return spark_version() < "3.5.0"

def is_before_spark_351():
return spark_version() < "3.5.1"

def is_spark_320_or_later():
return spark_version() >= "3.2.0"

22 changes: 22 additions & 0 deletions pom.xml
@@ -554,6 +554,26 @@
<module>delta-lake/delta-stub</module>
</modules>
</profile>
<profile>
<id>release351</id>
<activation>
<property>
<name>buildver</name>
<value>351</value>
</property>
</activation>
<properties>
<buildver>351</buildver>
<spark.version>${spark351.version}</spark.version>
<spark.test.version>${spark351.version}</spark.test.version>
<parquet.hadoop.version>1.13.1</parquet.hadoop.version>
<iceberg.version>${spark330.iceberg.version}</iceberg.version>
<slf4j.version>2.0.7</slf4j.version>
</properties>
<modules>
<module>delta-lake/delta-stub</module>
</modules>
</profile>
<profile>
<id>source-javadoc</id>
<build>
@@ -718,6 +738,7 @@
<spark332db.version>3.3.2-databricks</spark332db.version>
<spark341db.version>3.4.1-databricks</spark341db.version>
<spark350.version>3.5.0</spark350.version>
<spark351.version>3.5.1-SNAPSHOT</spark351.version>
<mockito.version>3.12.4</mockito.version>
<scala.plugin.version>4.3.0</scala.plugin.version>
<maven.install.plugin.version>3.1.1</maven.install.plugin.version>
@@ -767,6 +788,7 @@
350
</noSnapshot.buildvers>
<snapshot.buildvers>
351
</snapshot.buildvers>
<databricks.buildvers>
321db,

22 changes: 22 additions & 0 deletions scala2.13/pom.xml
@@ -554,6 +554,26 @@
<module>delta-lake/delta-stub</module>
</modules>
</profile>
<profile>
<id>release351</id>
<activation>
<property>
<name>buildver</name>
<value>351</value>
</property>
</activation>
<properties>
<buildver>351</buildver>
<spark.version>${spark351.version}</spark.version>
<spark.test.version>${spark351.version}</spark.test.version>
<parquet.hadoop.version>1.13.1</parquet.hadoop.version>
<iceberg.version>${spark330.iceberg.version}</iceberg.version>
<slf4j.version>2.0.7</slf4j.version>
</properties>
<modules>
<module>delta-lake/delta-stub</module>
</modules>
</profile>
<profile>
<id>source-javadoc</id>
<build>
@@ -718,6 +738,7 @@
<spark332db.version>3.3.2-databricks</spark332db.version>
<spark341db.version>3.4.1-databricks</spark341db.version>
<spark350.version>3.5.0</spark350.version>
<spark351.version>3.5.1-SNAPSHOT</spark351.version>
<mockito.version>3.12.4</mockito.version>
<scala.plugin.version>4.3.0</scala.plugin.version>
<maven.install.plugin.version>3.1.1</maven.install.plugin.version>
@@ -767,6 +788,7 @@
350
</noSnapshot.buildvers>
<snapshot.buildvers>
351
</snapshot.buildvers>
<databricks.buildvers>
321db,

@@ -23,7 +23,7 @@ import ai.rapids.cudf.ast.BinaryOperator
import com.nvidia.spark.rapids._
import com.nvidia.spark.rapids.Arm.withResource
import com.nvidia.spark.rapids.RapidsPluginImplicits._
import com.nvidia.spark.rapids.shims.{GpuTypeShims, ShimExpression, SparkShimImpl}
import com.nvidia.spark.rapids.shims.{DecimalMultiply128, GpuTypeShims, ShimExpression, SparkShimImpl}

import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
import org.apache.spark.sql.catalyst.expressions.{ComplexTypeMergingExpression, ExpectsInputTypes, Expression, NullIntolerant}
@@ -38,7 +38,8 @@ object AddOverflowChecks {
lhs: BinaryOperable,
rhs: BinaryOperable,
ret: ColumnVector): Unit = {
// Check overflow. It is true when both arguments have the opposite sign of the result.
// Check overflow. It is true if both arguments have the same sign and
// the sign of the result differs from the sign of the arguments.
// This is equivalent to "((x ^ r) & (y ^ r)) < 0" in arithmetic form.
val signCV = withResource(ret.bitXor(lhs)) { lXor =>
withResource(ret.bitXor(rhs)) { rXor =>
@@ -53,9 +54,8 @@
withResource(signDiffCV) { signDiff =>
withResource(signDiff.any()) { any =>
if (any.isValid && any.getBoolean) {
throw RapidsErrorUtils.arithmeticOverflowError(
"One or more rows overflow for Add operation."
)
throw RapidsErrorUtils.
arithmeticOverflowError("One or more rows overflow for Add operation.")
Collaborator review comment: let us leave formatting-only changes to dedicated PRs

}
}
}
@@ -109,6 +109,35 @@
}
}

object SubtractOverflowChecks {
def basicOpOverflowCheck(
lhs: BinaryOperable,
rhs: BinaryOperable,
ret: ColumnVector): Unit = {
// Check overflow. It is true if the arguments have different signs and
// the sign of the result is different from the sign of x.
// Which is equal to "((x ^ y) & (x ^ r)) < 0" in the form of arithmetic.
val signCV = withResource(lhs.bitXor(rhs)) { xyXor =>
withResource(lhs.bitXor(ret)) { xrXor =>
xyXor.bitAnd(xrXor)
}
}
val signDiffCV = withResource(signCV) { sign =>
withResource(Scalar.fromInt(0)) { zero =>
sign.lessThan(zero)
}
}
withResource(signDiffCV) { signDiff =>
withResource(signDiff.any()) { any =>
if (any.isValid && any.getBoolean) {
throw RapidsErrorUtils.
arithmeticOverflowError("One or more rows overflow for Subtract operation.")
}
}
}
}
}
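
A side note for readers of this hunk: both helper objects rely on the classic two's-complement sign-bit trick. A minimal scalar sketch of the same checks (illustrative only; the objects above apply the identical formulas column-wise with cudf and throw arithmeticOverflowError instead of returning a Boolean):

def addOverflows(x: Int, y: Int): Boolean = {
  val r = x + y
  // Addition overflows when both operands share a sign and the result's sign differs.
  ((x ^ r) & (y ^ r)) < 0
}

def subtractOverflows(x: Int, y: Int): Boolean = {
  val r = x - y
  // Subtraction overflows when the operands have different signs and the
  // result's sign differs from the sign of x.
  ((x ^ y) & (x ^ r)) < 0
}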

object GpuAnsi {
def needBasicOpOverflowCheck(dt: DataType): Boolean =
dt.isInstanceOf[IntegralType]
@@ -289,35 +318,6 @@ abstract class GpuSubtractBase extends CudfBinaryArithmetic with Serializable {
override def binaryOp: BinaryOp = BinaryOp.SUB
override def astOperator: Option[BinaryOperator] = Some(ast.BinaryOperator.SUB)

private[this] def basicOpOverflowCheck(
lhs: BinaryOperable,
rhs: BinaryOperable,
ret: ColumnVector): Unit = {
// Check overflow. It is true if the arguments have different signs and
// the sign of the result is different from the sign of x.
// Which is equal to "((x ^ y) & (x ^ r)) < 0" in the form of arithmetic.

val signCV = withResource(lhs.bitXor(rhs)) { xyXor =>
withResource(lhs.bitXor(ret)) { xrXor =>
xyXor.bitAnd(xrXor)
}
}
val signDiffCV = withResource(signCV) { sign =>
withResource(Scalar.fromInt(0)) { zero =>
sign.lessThan(zero)
}
}
withResource(signDiffCV) { signDiff =>
withResource(signDiff.any()) { any =>
if (any.isValid && any.getBoolean) {
throw RapidsErrorUtils.arithmeticOverflowError(
"One or more rows overflow for Subtract operation."
)
}
}
}
}

private[this] def decimalOpOverflowCheck(
lhs: BinaryOperable,
rhs: BinaryOperable,
@@ -367,7 +367,7 @@ abstract class GpuSubtractBase extends CudfBinaryArithmetic with Serializable {
GpuTypeShims.isSupportedYearMonthType(dataType)) {
// For day time interval, Spark throws an exception when overflow,
// regardless of whether `SQLConf.get.ansiEnabled` is true or false
basicOpOverflowCheck(lhs, rhs, ret)
SubtractOverflowChecks.basicOpOverflowCheck(lhs, rhs, ret)
}

if (dataType.isInstanceOf[DecimalType]) {
@@ -452,7 +452,7 @@ trait GpuDecimalMultiplyBase extends GpuExpression {
rhs.getBase.castTo(DType.create(DType.DTypeEnum.DECIMAL128, rhs.getBase.getType.getScale))
}
withResource(castRhs) { castRhs =>
com.nvidia.spark.rapids.jni.DecimalUtils.multiply128(castLhs, castRhs, -dataType.scale)
DecimalMultiply128(castLhs, castRhs, -dataType.scale)
}
}
val retCol = withResource(retTab) { retTab =>
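
A hedged sketch of the shim indirection introduced above: the call site previously invoked com.nvidia.spark.rapids.jni.DecimalUtils.multiply128 directly, so a pre-3.5.1 DecimalMultiply128 shim could plausibly be a thin forwarder like the following. This is an assumption for illustration; the actual shim sources are not part of this diff.

import ai.rapids.cudf.{ColumnVector, Table}

// Hypothetical forwarding shim; the real per-version implementations may differ.
object DecimalMultiply128 {
  def apply(lhs: ColumnVector, rhs: ColumnVector, outputScale: Int): Table =
    com.nvidia.spark.rapids.jni.DecimalUtils.multiply128(lhs, rhs, outputScale)
}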

@@ -25,15 +25,14 @@ import com.nvidia.spark.rapids.Arm._
import com.nvidia.spark.rapids.ArrayIndexUtils.firstIndexAndNumElementUnchecked
import com.nvidia.spark.rapids.BoolUtils.isAllValidTrue
import com.nvidia.spark.rapids.RapidsPluginImplicits._
import com.nvidia.spark.rapids.shims.ShimExpression
import com.nvidia.spark.rapids.shims.{ComputeSequenceSizes, ShimExpression}

import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
import org.apache.spark.sql.catalyst.expressions.{ElementAt, ExpectsInputTypes, Expression, ImplicitCastInputTypes, NamedExpression, NullIntolerant, RowOrdering, Sequence, TimeZoneAwareExpression}
import org.apache.spark.sql.catalyst.util.GenericArrayData
import org.apache.spark.sql.rapids.shims.RapidsErrorUtils
import org.apache.spark.sql.types._
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH

case class GpuConcat(children: Seq[Expression]) extends GpuComplexTypeMergingExpression {

@@ -1311,7 +1310,7 @@ class GpuSequenceMeta(

object GpuSequenceUtil {

private def checkSequenceInputs(
def checkSequenceInputs(
start: ColumnVector,
stop: ColumnVector,
step: ColumnVector): Unit = {
@@ -1364,77 +1363,12 @@ object GpuSequenceUtil {
} // end of zero
}

/**
* Compute the size of each sequence according to 'start', 'stop' and 'step'.
* A row (Row[start, stop, step]) that contains at least one null element will produce
* a null in the output.
*
* The returned column should be closed.
*/
def computeSequenceSizes(
start: ColumnVector,
stop: ColumnVector,
step: ColumnVector): ColumnVector = {
checkSequenceInputs(start, stop, step)

// Spark's algorithm to get the length (aka size)
// ``` Scala
// size = if (start == stop) 1L else 1L + (stop.toLong - start.toLong) / estimatedStep.toLong
// require(size <= MAX_ROUNDED_ARRAY_LENGTH,
// s"Too long sequence: $len. Should be <= $MAX_ROUNDED_ARRAY_LENGTH")
// size.toInt
// ```
val sizeAsLong = withResource(Scalar.fromLong(1L)) { one =>
val diff = withResource(stop.castTo(DType.INT64)) { stopAsLong =>
withResource(start.castTo(DType.INT64)) { startAsLong =>
stopAsLong.sub(startAsLong)
}
}
val quotient = withResource(diff) { _ =>
withResource(step.castTo(DType.INT64)) { stepAsLong =>
diff.div(stepAsLong)
}
}
// actualSize = 1L + (stop.toLong - start.toLong) / estimatedStep.toLong
val actualSize = withResource(quotient) { quotient =>
quotient.add(one, DType.INT64)
}
withResource(actualSize) { _ =>
val mergedEquals = withResource(start.equalTo(stop)) { equals =>
if (step.hasNulls) {
// Also set the row to null where step is null.
equals.mergeAndSetValidity(BinaryOp.BITWISE_AND, equals, step)
} else {
equals.incRefCount()
}
}
withResource(mergedEquals) { _ =>
mergedEquals.ifElse(one, actualSize)
}
}
}

withResource(sizeAsLong) { _ =>
// check max size
withResource(Scalar.fromInt(MAX_ROUNDED_ARRAY_LENGTH)) { maxLen =>
withResource(sizeAsLong.lessOrEqualTo(maxLen)) { allValid =>
require(isAllValidTrue(allValid),
s"Too long sequence found. Should be <= $MAX_ROUNDED_ARRAY_LENGTH")
}
}
// cast to int and return
sizeAsLong.castTo(DType.INT32)
}
}

}
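
For reference, the removed column-wise code implements Spark's scalar length formula quoted in its comment; moving it behind the per-version ComputeSequenceSizes shim presumably lets the 3.5.1 build raise the new error text exercised in the collection_ops_test.py change above. A minimal scalar sketch of that formula, with illustrative names that are not part of this PR:

import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH

def sequenceSize(start: Long, stop: Long, step: Long): Int = {
  // size = 1 when start == stop, otherwise 1 + (stop - start) / step
  val size = if (start == stop) 1L else 1L + (stop - start) / step
  require(size <= MAX_ROUNDED_ARRAY_LENGTH,
    s"Too long sequence: $size. Should be <= $MAX_ROUNDED_ARRAY_LENGTH")
  size.toInt
}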

case class GpuSequence(start: Expression, stop: Expression, stepOpt: Option[Expression],
timeZoneId: Option[String] = None) extends TimeZoneAwareExpression with GpuExpression
with ShimExpression {

import GpuSequenceUtil._

override def dataType: ArrayType = ArrayType(start.dataType, containsNull = false)

override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
@@ -1460,7 +1394,7 @@ case class GpuSequence(start: Expression, stop: Expression, stepOpt: Option[Expr
val steps = stepGpuColOpt.map(_.getBase.incRefCount())
.getOrElse(defaultStepsFunc(startCol, stopCol))
closeOnExcept(steps) { _ =>
(computeSequenceSizes(startCol, stopCol, steps), steps)
(ComputeSequenceSizes(startCol, stopCol, steps), steps)
}
}

@@ -13,6 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.rapids.execution.python

import scala.collection.JavaConverters.seqAsJavaListConverter

@@ -33,6 +33,7 @@
{"spark": "340"}
{"spark": "341"}
{"spark": "350"}
{"spark": "351"}
spark-rapids-shim-json-lines ***/
package com.nvidia.spark.rapids.shims

@@ -33,6 +33,7 @@
{"spark": "340"}
{"spark": "341"}
{"spark": "350"}
{"spark": "351"}
spark-rapids-shim-json-lines ***/
package com.nvidia.spark.rapids.shims
