
Commit

addressed review comments
razajafri committed Dec 21, 2023
1 parent 79ed020 commit 51cae93
Showing 6 changed files with 162 additions and 217 deletions.
@@ -54,8 +54,9 @@ object AddOverflowChecks {
withResource(signDiffCV) { signDiff =>
withResource(signDiff.any()) { any =>
if (any.isValid && any.getBoolean) {
throw RapidsErrorUtils.
arithmeticOverflowError("One or more rows overflow for Add operation.")
throw RapidsErrorUtils.arithmeticOverflowError(
"One or more rows overflow for Add operation."
)
}
}
}
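For context (not part of this commit), the check this hunk reformats treats an addition as overflowed when both operands share a sign that the result does not. A minimal per-element sketch of that rule in plain Scala, with illustrative names:

// Two's-complement addition overflows exactly when the operands agree in sign
// but the result's sign differs; the GPU code applies the same rule column-wise
// before raising arithmeticOverflowError.
def addOverflowed(lhs: Long, rhs: Long, result: Long): Boolean =
  ((lhs ^ result) & (rhs ^ result)) < 0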
@@ -25,14 +25,15 @@ import com.nvidia.spark.rapids.Arm._
import com.nvidia.spark.rapids.ArrayIndexUtils.firstIndexAndNumElementUnchecked
import com.nvidia.spark.rapids.BoolUtils.isAllValidTrue
import com.nvidia.spark.rapids.RapidsPluginImplicits._
import com.nvidia.spark.rapids.shims.{ComputeSequenceSizes, ShimExpression}
import com.nvidia.spark.rapids.shims.{GetSequenceSize, ShimExpression}

import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
import org.apache.spark.sql.catalyst.expressions.{ElementAt, ExpectsInputTypes, Expression, ImplicitCastInputTypes, NamedExpression, NullIntolerant, RowOrdering, Sequence, TimeZoneAwareExpression}
import org.apache.spark.sql.catalyst.util.GenericArrayData
import org.apache.spark.sql.rapids.shims.RapidsErrorUtils
import org.apache.spark.sql.types._
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH

case class GpuConcat(children: Seq[Expression]) extends GpuComplexTypeMergingExpression {

@@ -1363,12 +1364,53 @@ object GpuSequenceUtil {
} // end of zero
}

/**
* Compute the size of each sequence according to 'start', 'stop' and 'step'.
* A row (Row[start, stop, step]) that contains at least one null element will produce
* a null in the output.
*
* The returned column should be closed.
*/
def computeSequenceSize(
start: ColumnVector,
stop: ColumnVector,
step: ColumnVector): ColumnVector = {
checkSequenceInputs(start, stop, step)
val actualSize = GetSequenceSize(start, stop, step)
val sizeAsLong = withResource(actualSize) { _ =>
val mergedEquals = withResource(start.equalTo(stop)) { equals =>
if (step.hasNulls) {
// Also set the row to null where step is null.
equals.mergeAndSetValidity(BinaryOp.BITWISE_AND, equals, step)
} else {
equals.incRefCount()
}
}
withResource(mergedEquals) { _ =>
withResource(Scalar.fromLong(1L)) { one =>
mergedEquals.ifElse(one, actualSize)
}
}
}
withResource(sizeAsLong) { _ =>
// check max size
withResource(Scalar.fromInt(MAX_ROUNDED_ARRAY_LENGTH)) { maxLen =>
withResource(sizeAsLong.lessOrEqualTo(maxLen)) { allValid =>
require(isAllValidTrue(allValid), GetSequenceSize.TOO_LONG_SEQUENCE)
}
}
// cast to int and return
sizeAsLong.castTo(DType.INT32)
}
}
}
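As a worked example of the size rule above (illustrative numbers, not taken from the diff): start = 1, stop = 10, step = 3 gives 1 + (10 - 1) / 3 = 4 elements (1, 4, 7, 10); equal start and stop are forced to a size of 1; and a row with a null step produces a null size. A per-element sketch of the same rule in plain Scala, with maxLen standing in for MAX_ROUNDED_ARRAY_LENGTH:

// Scalar mirror of computeSequenceSize: equal endpoints yield one element,
// otherwise 1 + (stop - start) / step, rejected if it exceeds maxLen.
def sequenceSize(start: Long, stop: Long, step: Long, maxLen: Long): Int = {
  val size = if (start == stop) 1L else 1L + (stop - start) / step
  require(size <= maxLen, s"Too long sequence found. Should be <= $maxLen")
  size.toInt
}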

case class GpuSequence(start: Expression, stop: Expression, stepOpt: Option[Expression],
timeZoneId: Option[String] = None) extends TimeZoneAwareExpression with GpuExpression
with ShimExpression {

import GpuSequenceUtil._

override def dataType: ArrayType = ArrayType(start.dataType, containsNull = false)

override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
@@ -1394,7 +1436,7 @@ case class GpuSequence(start: Expression, stop: Expression, stepOpt: Option[Expr
val steps = stepGpuColOpt.map(_.getBase.incRefCount())
.getOrElse(defaultStepsFunc(startCol, stopCol))
closeOnExcept(steps) { _ =>
(ComputeSequenceSizes(startCol, stopCol, steps), steps)
(computeSequenceSize(startCol, stopCol, steps), steps)
}
}

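When stepOpt is empty, the step column comes from defaultStepsFunc. Assuming it follows Spark's documented default for sequence (step 1 for ascending ranges, -1 for descending ones), a per-element sketch would be:

// Illustrative scalar version of the default step: 1 when start <= stop,
// otherwise -1 (this mirrors Spark's sequence() default, not code shown in this diff).
def defaultStep(start: Long, stop: Long): Long =
  if (start <= stop) 1L else -1L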
@@ -42,12 +42,11 @@ package com.nvidia.spark.rapids.shims

import ai.rapids.cudf._
import com.nvidia.spark.rapids.Arm._
import com.nvidia.spark.rapids.BoolUtils.isAllValidTrue

import org.apache.spark.sql.rapids.GpuSequenceUtil
import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH

object ComputeSequenceSizes {
object GetSequenceSize {
val TOO_LONG_SEQUENCE = s"Too long sequence found. Should be <= $MAX_ROUNDED_ARRAY_LENGTH"
/**
* Compute the size of each sequence according to 'start', 'stop' and 'step'.
* A row (Row[start, stop, step]) that contains at least one null element will produce
@@ -59,54 +58,29 @@ object ComputeSequenceSizes {
start: ColumnVector,
stop: ColumnVector,
step: ColumnVector): ColumnVector = {
GpuSequenceUtil.checkSequenceInputs(start, stop, step)
// Spark's algorithm to get the length (aka size)
// ``` Scala
// size = if (start == stop) 1L else 1L + (stop.toLong - start.toLong) / estimatedStep.toLong
// require(size <= MAX_ROUNDED_ARRAY_LENGTH,
// s"Too long sequence: $len. Should be <= $MAX_ROUNDED_ARRAY_LENGTH")
// size.toInt
// ```
val sizeAsLong = withResource(Scalar.fromLong(1L)) { one =>
val diff = withResource(stop.castTo(DType.INT64)) { stopAsLong =>
withResource(start.castTo(DType.INT64)) { startAsLong =>
stopAsLong.sub(startAsLong)
}
val diff = withResource(stop.castTo(DType.INT64)) { stopAsLong =>
withResource(start.castTo(DType.INT64)) { startAsLong =>
stopAsLong.sub(startAsLong)
}
val quotient = withResource(diff) { _ =>
withResource(step.castTo(DType.INT64)) { stepAsLong =>
diff.div(stepAsLong)
}
}
// actualSize = 1L + (stop.toLong - start.toLong) / estimatedStep.toLong
val actualSize = withResource(quotient) { quotient =>
quotient.add(one, DType.INT64)
}
withResource(actualSize) { _ =>
val mergedEquals = withResource(start.equalTo(stop)) { equals =>
if (step.hasNulls) {
// Also set the row to null where step is null.
equals.mergeAndSetValidity(BinaryOp.BITWISE_AND, equals, step)
} else {
equals.incRefCount()
}
}
withResource(mergedEquals) { _ =>
mergedEquals.ifElse(one, actualSize)
}
}
val quotient = withResource(diff) { _ =>
withResource(step.castTo(DType.INT64)) { stepAsLong =>
diff.div(stepAsLong)
}
}

withResource(sizeAsLong) { _ =>
// check max size
withResource(Scalar.fromInt(MAX_ROUNDED_ARRAY_LENGTH)) { maxLen =>
withResource(sizeAsLong.lessOrEqualTo(maxLen)) { allValid =>
require(isAllValidTrue(allValid),
s"Too long sequence found. Should be <= $MAX_ROUNDED_ARRAY_LENGTH")
}
// actualSize = 1L + (stop.toLong - start.toLong) / estimatedStep.toLong
val actualSize = withResource(Scalar.fromLong(1L)) { one =>
withResource(quotient) { quotient =>
quotient.add(one, DType.INT64)
}
// cast to int and return
sizeAsLong.castTo(DType.INT32)
}
actualSize
}
}

This file was deleted.

@@ -0,0 +1,101 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/*** spark-rapids-shim-json-lines
{"spark": "351"}
spark-rapids-shim-json-lines ***/
package com.nvidia.spark.rapids.shims

import ai.rapids.cudf._
import com.nvidia.spark.rapids.Arm._

import org.apache.spark.sql.rapids.{AddOverflowChecks, SubtractOverflowChecks}
import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH

object GetSequenceSize {
val TOO_LONG_SEQUENCE = "Unsuccessful try to create array with elements exceeding the array " +
s"size limit $MAX_ROUNDED_ARRAY_LENGTH"
/**
* Compute the size of each sequence according to 'start', 'stop' and 'step'.
* A row (Row[start, stop, step]) that contains at least one null element will produce
* a null in the output.
*
* The returned column should be closed.
*/
def apply(
start: ColumnVector,
stop: ColumnVector,
step: ColumnVector): ColumnVector = {

// Spark's algorithm to get the length (aka size)
// ``` Scala
// val delta = Math.subtractExact(stop, start)
// if (delta == Long.MinValue && step == -1L) {
// // We must special-case division of Long.MinValue by -1 to catch potential unchecked
// // overflow in next operation. Division does not have a builtin overflow check. We
// // previously special-case div-by-zero.
// throw new ArithmeticException("Long overflow (Long.MinValue / -1)")
// }
// val len = if (stop == start) 1L else Math.addExact(1L, (delta / step))
// if (len > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
// throw QueryExecutionErrors.createArrayWithElementsExceedLimitError(len)
// }
// len.toInt
// ```
val delta = withResource(stop.castTo(DType.INT64)) { stopAsLong =>
withResource(start.castTo(DType.INT64)) { startAsLong =>
closeOnExcept(stopAsLong.sub(startAsLong)) { ret =>
// Throw an exception if stop - start overflows
SubtractOverflowChecks.basicOpOverflowCheck(stopAsLong, startAsLong, ret)
ret
}
}
}
withResource(Scalar.fromLong(Long.MinValue)) { longMin =>
withResource(delta.equalTo(longMin)) { hasLongMin =>
withResource(Scalar.fromInt(-1)) { minusOne =>
withResource(step.equalTo(minusOne)) { stepEqualsMinusOne =>
withResource(hasLongMin.and(stepEqualsMinusOne)) { hasLongMinAndStepMinusOne =>
withResource(hasLongMinAndStepMinusOne.any()) { result =>
if (result.isValid && result.getBoolean) {
// Overflow, throw an exception
throw new ArithmeticException("Unsuccessful try to create array with " +
s"elements exceeding the array size limit $MAX_ROUNDED_ARRAY_LENGTH")
}
}
}
}
}
}
}
val quotient = withResource(delta) { _ =>
withResource(step.castTo(DType.INT64)) { stepAsLong =>
delta.div(stepAsLong)
}
}
// delta = (stop.toLong - start.toLong) / estimatedStep.toLong
// actualSize = 1L + delta
val actualSize = withResource(Scalar.fromLong(1L)) { one =>
withResource(quotient) { quotient =>
closeOnExcept(quotient.add(one, DType.INT64)) { ret =>
AddOverflowChecks.basicOpOverflowCheck(quotient, one, ret)
ret
}
}
}
actualSize
}
}
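To make the checked path concrete (illustrative numbers, not from the diff): start = 2, stop = 10, step = 2 gives delta = 8, a quotient of 4, and a checked add of 1 for a size of 5 (elements 2, 4, 6, 8, 10). The Long.MinValue guard matters because, for example, start = 0, stop = Long.MinValue, step = -1 yields delta = Long.MinValue, and Long.MinValue / -1 is not representable as a Long. A scalar sketch of the Spark algorithm quoted in the comment above (the GPU code applies the same steps column-wise, with the start == stop case handled in GpuSequenceUtil.computeSequenceSize):

// Overflow-checked subtraction, the Long.MinValue / -1 guard, then an
// overflow-checked add of 1; mirrors the quoted Spark logic per element.
def checkedSequenceSize(start: Long, stop: Long, step: Long): Long = {
  val delta = Math.subtractExact(stop, start)
  if (delta == Long.MinValue && step == -1L) {
    throw new ArithmeticException("Long overflow (Long.MinValue / -1)")
  }
  if (start == stop) 1L else Math.addExact(1L, delta / step)
}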
