Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: Remove redundant processing from exprToProtoInternal #1351

Merged
merged 3 commits
Feb 4, 2025
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 39 additions & 18 deletions spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
Original file line number Diff line number Diff line change
Expand Up @@ -901,7 +901,11 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
}

/**
* Convert a Spark expression to protobuf.
* Convert a Spark expression to a protocol-buffer representation of a native Comet/DataFusion
* expression.
*
* This method performs a transformation on the plan to handle decimal promotion and then calls
* into the recursive method [[exprToProtoInternal]].
*
* @param expr
* The input expression
Expand All @@ -910,7 +914,9 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
* @param binding
* Whether to bind the expression to the input attributes
* @return
* The protobuf representation of the expression, or None if the expression is not supported
* The protobuf representation of the expression, or None if the expression is not supported.
* In the case where None is returned, the expression will be tagged with the reason(s) why it
* is not supported.
*/
def exprToProto(
expr: Expression,
Expand All @@ -923,6 +929,21 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
exprToProtoInternal(newExpr, inputs, binding)
}

/**
* Convert a Spark expression to a protocol-buffer representation of a native Comet/DataFusion
* expression.
*
* @param expr
* The input expression
* @param inputs
* The input attributes
* @param binding
* Whether to bind the expression to the input attributes
* @return
* The protobuf representation of the expression, or None if the expression is not supported.
* In the case where None is returned, the expression will be tagged with the reason(s) why it
* is not supported.
*/
def exprToProtoInternal(
expr: Expression,
inputs: Seq[Attribute],
Expand Down Expand Up @@ -1237,7 +1258,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
}

if (isSupported) {
exprToProto(child, inputs, binding) match {
exprToProtoInternal(child, inputs, binding) match {
case Some(p) =>
val toJson = ExprOuterClass.ToJson
.newBuilder()
Expand Down Expand Up @@ -2235,7 +2256,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
return None
}

val valExprs = struct.valExprs.map(exprToProto(_, inputs, binding))
val valExprs = struct.valExprs.map(exprToProtoInternal(_, inputs, binding))

if (valExprs.forall(_.isDefined)) {
val structBuilder = ExprOuterClass.CreateNamedStruct.newBuilder()
Expand All @@ -2253,7 +2274,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
}

case GetStructField(child, ordinal, _) =>
exprToProto(child, inputs, binding).map { childExpr =>
exprToProtoInternal(child, inputs, binding).map { childExpr =>
val getStructFieldBuilder = ExprOuterClass.GetStructField
.newBuilder()
.setChild(childExpr)
Expand All @@ -2266,7 +2287,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
}

case CreateArray(children, _) =>
val childExprs = children.map(exprToProto(_, inputs, binding))
val childExprs = children.map(exprToProtoInternal(_, inputs, binding))

if (childExprs.forall(_.isDefined)) {
scalarExprToProto("make_array", childExprs: _*)
Expand All @@ -2276,8 +2297,8 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
}

case GetArrayItem(child, ordinal, failOnError) =>
val childExpr = exprToProto(child, inputs, binding)
val ordinalExpr = exprToProto(ordinal, inputs, binding)
val childExpr = exprToProtoInternal(child, inputs, binding)
val ordinalExpr = exprToProtoInternal(ordinal, inputs, binding)

if (childExpr.isDefined && ordinalExpr.isDefined) {
val listExtractBuilder = ExprOuterClass.ListExtract
Expand All @@ -2298,9 +2319,9 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
}

case expr if expr.prettyName == "array_insert" =>
val srcExprProto = exprToProto(expr.children(0), inputs, binding)
val posExprProto = exprToProto(expr.children(1), inputs, binding)
val itemExprProto = exprToProto(expr.children(2), inputs, binding)
val srcExprProto = exprToProtoInternal(expr.children(0), inputs, binding)
val posExprProto = exprToProtoInternal(expr.children(1), inputs, binding)
val itemExprProto = exprToProtoInternal(expr.children(2), inputs, binding)
val legacyNegativeIndex =
SQLConf.get.getConfString("spark.sql.legacy.negativeIndexInArrayInsert").toBoolean
if (srcExprProto.isDefined && posExprProto.isDefined && itemExprProto.isDefined) {
Expand Down Expand Up @@ -2328,9 +2349,9 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim

case ElementAt(child, ordinal, defaultValue, failOnError)
if child.dataType.isInstanceOf[ArrayType] =>
val childExpr = exprToProto(child, inputs, binding)
val ordinalExpr = exprToProto(ordinal, inputs, binding)
val defaultExpr = defaultValue.flatMap(exprToProto(_, inputs, binding))
val childExpr = exprToProtoInternal(child, inputs, binding)
val ordinalExpr = exprToProtoInternal(ordinal, inputs, binding)
val defaultExpr = defaultValue.flatMap(exprToProtoInternal(_, inputs, binding))

if (childExpr.isDefined && ordinalExpr.isDefined &&
defaultExpr.isDefined == defaultValue.isDefined) {
Expand All @@ -2354,7 +2375,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
}

case GetArrayStructFields(child, _, ordinal, _, _) =>
val childExpr = exprToProto(child, inputs, binding)
val childExpr = exprToProtoInternal(child, inputs, binding)

if (childExpr.isDefined) {
val arrayStructFieldsBuilder = ExprOuterClass.GetArrayStructFields
Expand Down Expand Up @@ -2397,13 +2418,13 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
binding,
(builder, binaryExpr) => builder.setArrayIntersect(binaryExpr))
case ArrayJoin(arrayExpr, delimiterExpr, nullReplacementExpr) =>
val arrayExprProto = exprToProto(arrayExpr, inputs, binding)
val delimiterExprProto = exprToProto(delimiterExpr, inputs, binding)
val arrayExprProto = exprToProtoInternal(arrayExpr, inputs, binding)
val delimiterExprProto = exprToProtoInternal(delimiterExpr, inputs, binding)

if (arrayExprProto.isDefined && delimiterExprProto.isDefined) {
val arrayJoinBuilder = nullReplacementExpr match {
case Some(nrExpr) =>
val nullReplacementExprProto = exprToProto(nrExpr, inputs, binding)
val nullReplacementExprProto = exprToProtoInternal(nrExpr, inputs, binding)
ExprOuterClass.ArrayJoin
.newBuilder()
.setArrayExpr(arrayExprProto.get)
Expand Down
Loading