From 6cf140f43879e4b2aa5d0471aa1b89ad93b3ced5 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 4 Feb 2025 08:59:35 -0800 Subject: [PATCH] chore: Remove redundant processing from exprToProtoInternal (#1351) * Remove redundant processing from exprToProtoInternal * fix typo --- .../apache/comet/serde/QueryPlanSerde.scala | 51 +++++++++++++------ 1 file changed, 36 insertions(+), 15 deletions(-) diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index f4699af8d..9ce6ea9c2 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -901,7 +901,11 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim } /** - * Convert a Spark expression to protobuf. + * Convert a Spark expression to a protocol-buffer representation of a native Comet/DataFusion + * expression. + * + * This method performs a transformation on the plan to handle decimal promotion and then calls + * into the recursive method [[exprToProtoInternal]]. * * @param expr * The input expression @@ -910,7 +914,9 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim * @param binding * Whether to bind the expression to the input attributes * @return - * The protobuf representation of the expression, or None if the expression is not supported + * The protobuf representation of the expression, or None if the expression is not supported. + * In the case where None is returned, the expression will be tagged with the reason(s) why it + * is not supported. */ def exprToProto( expr: Expression, @@ -923,6 +929,21 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim exprToProtoInternal(newExpr, inputs, binding) } + /** + * Convert a Spark expression to a protocol-buffer representation of a native Comet/DataFusion + * expression. + * + * @param expr + * The input expression + * @param inputs + * The input attributes + * @param binding + * Whether to bind the expression to the input attributes + * @return + * The protobuf representation of the expression, or None if the expression is not supported. + * In the case where None is returned, the expression will be tagged with the reason(s) why it + * is not supported. + */ def exprToProtoInternal( expr: Expression, inputs: Seq[Attribute], @@ -1250,7 +1271,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim } if (isSupported) { - exprToProto(child, inputs, binding) match { + exprToProtoInternal(child, inputs, binding) match { case Some(p) => val toJson = ExprOuterClass.ToJson .newBuilder() @@ -2222,7 +2243,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim return None } - val valExprs = struct.valExprs.map(exprToProto(_, inputs, binding)) + val valExprs = struct.valExprs.map(exprToProtoInternal(_, inputs, binding)) if (valExprs.forall(_.isDefined)) { val structBuilder = ExprOuterClass.CreateNamedStruct.newBuilder() @@ -2240,7 +2261,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim } case GetStructField(child, ordinal, _) => - exprToProto(child, inputs, binding).map { childExpr => + exprToProtoInternal(child, inputs, binding).map { childExpr => val getStructFieldBuilder = ExprOuterClass.GetStructField .newBuilder() .setChild(childExpr) @@ -2253,7 +2274,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim } case CreateArray(children, _) => - val childExprs = children.map(exprToProto(_, inputs, binding)) + val childExprs = children.map(exprToProtoInternal(_, inputs, binding)) if (childExprs.forall(_.isDefined)) { scalarExprToProto("make_array", childExprs: _*) @@ -2263,8 +2284,8 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim } case GetArrayItem(child, ordinal, failOnError) => - val childExpr = exprToProto(child, inputs, binding) - val ordinalExpr = exprToProto(ordinal, inputs, binding) + val childExpr = exprToProtoInternal(child, inputs, binding) + val ordinalExpr = exprToProtoInternal(ordinal, inputs, binding) if (childExpr.isDefined && ordinalExpr.isDefined) { val listExtractBuilder = ExprOuterClass.ListExtract @@ -2285,9 +2306,9 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim } case expr if expr.prettyName == "array_insert" => - val srcExprProto = exprToProto(expr.children(0), inputs, binding) - val posExprProto = exprToProto(expr.children(1), inputs, binding) - val itemExprProto = exprToProto(expr.children(2), inputs, binding) + val srcExprProto = exprToProtoInternal(expr.children(0), inputs, binding) + val posExprProto = exprToProtoInternal(expr.children(1), inputs, binding) + val itemExprProto = exprToProtoInternal(expr.children(2), inputs, binding) val legacyNegativeIndex = SQLConf.get.getConfString("spark.sql.legacy.negativeIndexInArrayInsert").toBoolean if (srcExprProto.isDefined && posExprProto.isDefined && itemExprProto.isDefined) { @@ -2315,9 +2336,9 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim case ElementAt(child, ordinal, defaultValue, failOnError) if child.dataType.isInstanceOf[ArrayType] => - val childExpr = exprToProto(child, inputs, binding) - val ordinalExpr = exprToProto(ordinal, inputs, binding) - val defaultExpr = defaultValue.flatMap(exprToProto(_, inputs, binding)) + val childExpr = exprToProtoInternal(child, inputs, binding) + val ordinalExpr = exprToProtoInternal(ordinal, inputs, binding) + val defaultExpr = defaultValue.flatMap(exprToProtoInternal(_, inputs, binding)) if (childExpr.isDefined && ordinalExpr.isDefined && defaultExpr.isDefined == defaultValue.isDefined) { @@ -2341,7 +2362,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim } case GetArrayStructFields(child, _, ordinal, _, _) => - val childExpr = exprToProto(child, inputs, binding) + val childExpr = exprToProtoInternal(child, inputs, binding) if (childExpr.isDefined) { val arrayStructFieldsBuilder = ExprOuterClass.GetArrayStructFields