Skip to content

Commit

Permalink
chore: Remove redundant processing from exprToProtoInternal (#1351)
Browse files Browse the repository at this point in the history
* Remove redundant processing from exprToProtoInternal

* fix typo
  • Loading branch information
andygrove authored Feb 4, 2025
1 parent 996362e commit 6cf140f
Showing 1 changed file with 36 additions and 15 deletions.
51 changes: 36 additions & 15 deletions spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
Original file line number Diff line number Diff line change
Expand Up @@ -901,7 +901,11 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
}

/**
* Convert a Spark expression to protobuf.
* Convert a Spark expression to a protocol-buffer representation of a native Comet/DataFusion
* expression.
*
* This method performs a transformation on the expression to handle decimal promotion and then
* calls into the recursive method [[exprToProtoInternal]].
*
* @param expr
* The input expression
Expand All @@ -910,7 +914,9 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
* @param binding
* Whether to bind the expression to the input attributes
* @return
* The protobuf representation of the expression, or None if the expression is not supported
* The protobuf representation of the expression, or None if the expression is not supported.
* In the case where None is returned, the expression will be tagged with the reason(s) why it
* is not supported.
*/
def exprToProto(
expr: Expression,
Expand All @@ -923,6 +929,21 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
exprToProtoInternal(newExpr, inputs, binding)
}

/**
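* Recursively convert a Spark expression to a protocol-buffer representation of a native
* Comet/DataFusion expression. This is the recursive method that [[exprToProto]] calls after it
* has handled decimal promotion; child expressions are converted by calling this method directly.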
*
* @param expr
* The input expression
* @param inputs
* The input attributes
* @param binding
* Whether to bind the expression to the input attributes
* @return
* The protobuf representation of the expression, or None if the expression is not supported.
* In the case where None is returned, the expression will be tagged with the reason(s) why it
* is not supported.
*/
def exprToProtoInternal(
expr: Expression,
inputs: Seq[Attribute],
Expand Down Expand Up @@ -1250,7 +1271,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
}

if (isSupported) {
exprToProto(child, inputs, binding) match {
exprToProtoInternal(child, inputs, binding) match {
case Some(p) =>
val toJson = ExprOuterClass.ToJson
.newBuilder()
Expand Down Expand Up @@ -2222,7 +2243,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
return None
}

val valExprs = struct.valExprs.map(exprToProto(_, inputs, binding))
val valExprs = struct.valExprs.map(exprToProtoInternal(_, inputs, binding))

if (valExprs.forall(_.isDefined)) {
val structBuilder = ExprOuterClass.CreateNamedStruct.newBuilder()
Expand All @@ -2240,7 +2261,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
}

case GetStructField(child, ordinal, _) =>
exprToProto(child, inputs, binding).map { childExpr =>
exprToProtoInternal(child, inputs, binding).map { childExpr =>
val getStructFieldBuilder = ExprOuterClass.GetStructField
.newBuilder()
.setChild(childExpr)
Expand All @@ -2253,7 +2274,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
}

case CreateArray(children, _) =>
val childExprs = children.map(exprToProto(_, inputs, binding))
val childExprs = children.map(exprToProtoInternal(_, inputs, binding))

if (childExprs.forall(_.isDefined)) {
scalarExprToProto("make_array", childExprs: _*)
Expand All @@ -2263,8 +2284,8 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
}

case GetArrayItem(child, ordinal, failOnError) =>
val childExpr = exprToProto(child, inputs, binding)
val ordinalExpr = exprToProto(ordinal, inputs, binding)
val childExpr = exprToProtoInternal(child, inputs, binding)
val ordinalExpr = exprToProtoInternal(ordinal, inputs, binding)

if (childExpr.isDefined && ordinalExpr.isDefined) {
val listExtractBuilder = ExprOuterClass.ListExtract
Expand All @@ -2285,9 +2306,9 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
}

case expr if expr.prettyName == "array_insert" =>
val srcExprProto = exprToProto(expr.children(0), inputs, binding)
val posExprProto = exprToProto(expr.children(1), inputs, binding)
val itemExprProto = exprToProto(expr.children(2), inputs, binding)
val srcExprProto = exprToProtoInternal(expr.children(0), inputs, binding)
val posExprProto = exprToProtoInternal(expr.children(1), inputs, binding)
val itemExprProto = exprToProtoInternal(expr.children(2), inputs, binding)
val legacyNegativeIndex =
SQLConf.get.getConfString("spark.sql.legacy.negativeIndexInArrayInsert").toBoolean
if (srcExprProto.isDefined && posExprProto.isDefined && itemExprProto.isDefined) {
Expand Down Expand Up @@ -2315,9 +2336,9 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim

case ElementAt(child, ordinal, defaultValue, failOnError)
if child.dataType.isInstanceOf[ArrayType] =>
val childExpr = exprToProto(child, inputs, binding)
val ordinalExpr = exprToProto(ordinal, inputs, binding)
val defaultExpr = defaultValue.flatMap(exprToProto(_, inputs, binding))
val childExpr = exprToProtoInternal(child, inputs, binding)
val ordinalExpr = exprToProtoInternal(ordinal, inputs, binding)
val defaultExpr = defaultValue.flatMap(exprToProtoInternal(_, inputs, binding))

if (childExpr.isDefined && ordinalExpr.isDefined &&
defaultExpr.isDefined == defaultValue.isDefined) {
Expand All @@ -2341,7 +2362,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
}

case GetArrayStructFields(child, _, ordinal, _, _) =>
val childExpr = exprToProto(child, inputs, binding)
val childExpr = exprToProtoInternal(child, inputs, binding)

if (childExpr.isDefined) {
val arrayStructFieldsBuilder = ExprOuterClass.GetArrayStructFields
Expand Down

0 comments on commit 6cf140f

Please sign in to comment.