Skip to content

Commit

Permalink
[SPARK-50029][SQL] Make StaticInvoke compatible with the method tha…
Browse files Browse the repository at this point in the history
…t return `Any`

### What changes were proposed in this pull request?
The pr aims to make `StaticInvoke` compatible with the method that return `Any`.

### Why are the changes needed?
Currently, our `StaticInvoke` does not support calling the method with a return type signature of `Any`(actually, the type of return value may be `different data type`), while `Invoke` supports it, let's align it.

### Does this PR introduce _any_ user-facing change?
No, only for spark developer.

### How was this patch tested?
- Add new UT.
- Pass GA.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes apache#48542 from panbingkun/SPARK-50029.

Authored-by: panbingkun <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
  • Loading branch information
panbingkun authored and cloud-fan committed Oct 21, 2024
1 parent 738dfa3 commit f9a5de4
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -322,21 +322,22 @@ case class StaticInvoke(
val evaluate = if (returnNullable && !method.getReturnType.isPrimitive) {
if (CodeGenerator.defaultValue(dataType) == "null") {
s"""
${ev.value} = $callFunc;
${ev.value} = ($javaType) $callFunc;
${ev.isNull} = ${ev.value} == null;
"""
} else {
val boxedResult = ctx.freshName("boxedResult")
val boxedJavaType = CodeGenerator.boxedType(dataType)
s"""
${CodeGenerator.boxedType(dataType)} $boxedResult = $callFunc;
$boxedJavaType $boxedResult = ($boxedJavaType) $callFunc;
${ev.isNull} = $boxedResult == null;
if (!${ev.isNull}) {
${ev.value} = $boxedResult;
}
"""
}
} else {
s"${ev.value} = $callFunc;"
s"${ev.value} = ($javaType) $callFunc;"
}

val code = code"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -755,6 +755,31 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
val genCode = StaticInvoke(TestFun.getClass, IntegerType, "foo", arguments).genCode(ctx)
assert(!genCode.code.toString.contains("boxedResult"))
}

test("StaticInvoke call return `any` method") {
val cls = TestStaticInvokeReturnAny.getClass
Seq((0, IntegerType, true), (1, IntegerType, true), (2, IntegerType, false)).foreach {
case (arg, argDataType, returnNullable) =>
val dataType = arg match {
case 0 => ObjectType(classOf[java.lang.Integer])
case 1 => ShortType
case 2 => ObjectType(classOf[java.lang.Long])
}
val arguments = Seq(Literal(arg, argDataType))
val inputTypes = Seq(IntegerType)
val expected = arg match {
case 0 => java.lang.Integer.valueOf(1)
case 1 => 0.toShort
case 2 => java.lang.Long.valueOf(2)
}
val inputRow = InternalRow.fromSeq(Seq(arg))
checkObjectExprEvaluation(
StaticInvoke(cls, dataType, "func", arguments, inputTypes,
returnNullable = returnNullable),
expected,
inputRow)
}
}
}

class TestBean extends Serializable {
Expand Down Expand Up @@ -790,3 +815,10 @@ case object TestFun {
def foo(left: Int, right: Int): Int = left + right
}

object TestStaticInvokeReturnAny {
def func(input: Int): Any = input match {
case 0 => java.lang.Integer.valueOf(1)
case 1 => 0.toShort
case 2 => java.lang.Long.valueOf(2)
}
}

0 comments on commit f9a5de4

Please sign in to comment.