wForget commented on code in PR #2417:
URL: https://github.com/apache/datafusion-comet/pull/2417#discussion_r2357652621


##########
spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala:
##########
@@ -154,405 +154,416 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
       operator2Proto(op).map(fun).getOrElse(op)
     }
 
-    plan.transformUp {
-      // Fully native scan for V1
-      case scan: CometScanExec if scan.scanImpl == CometConf.SCAN_NATIVE_DATAFUSION =>
-        val nativeOp = QueryPlanSerde.operator2Proto(scan).get
-        CometNativeScanExec(nativeOp, scan.wrapped, scan.session)
-
-      // Comet JVM + native scan for V1 and V2
-      case op if isCometScan(op) =>
-        val nativeOp = QueryPlanSerde.operator2Proto(op)
-        CometScanWrapper(nativeOp.get, op)
-
-      case op if shouldApplySparkToColumnar(conf, op) =>
-        val cometOp = CometSparkToColumnarExec(op)
-        val nativeOp = QueryPlanSerde.operator2Proto(cometOp)
-        CometScanWrapper(nativeOp.get, cometOp)
-
-      case op: ProjectExec =>
-        newPlanWithProto(
-          op,
-          CometProjectExec(_, op, op.output, op.projectList, op.child, SerializedPlan(None)))
+    def convertNode(op: SparkPlan): SparkPlan = {
+      op match {
+        // Fully native scan for V1
+        case scan: CometScanExec if scan.scanImpl == CometConf.SCAN_NATIVE_DATAFUSION =>
+          val nativeOp = QueryPlanSerde.operator2Proto(scan).get
+          CometNativeScanExec(nativeOp, scan.wrapped, scan.session)
+
+        // Comet JVM + native scan for V1 and V2
+        case op if isCometScan(op) =>
+          val nativeOp = QueryPlanSerde.operator2Proto(op)
+          CometScanWrapper(nativeOp.get, op)
+
+        case op if shouldApplySparkToColumnar(conf, op) =>
+          val cometOp = CometSparkToColumnarExec(op)
+          val nativeOp = QueryPlanSerde.operator2Proto(cometOp)
+          CometScanWrapper(nativeOp.get, cometOp)
+
+        case op: ProjectExec =>
+          newPlanWithProto(
+            op,
+            CometProjectExec(_, op, op.output, op.projectList, op.child, SerializedPlan(None)))
 
-      case op: FilterExec =>
-        newPlanWithProto(
-          op,
-          CometFilterExec(_, op, op.output, op.condition, op.child, SerializedPlan(None)))
+        case op: FilterExec =>
+          newPlanWithProto(
+            op,
+            CometFilterExec(_, op, op.output, op.condition, op.child, SerializedPlan(None)))
 
-      case op: SortExec =>
-        newPlanWithProto(
-          op,
-          CometSortExec(
-            _,
+        case op: SortExec =>
+          newPlanWithProto(
             op,
-            op.output,
-            op.outputOrdering,
-            op.sortOrder,
-            op.child,
-            SerializedPlan(None)))
+            CometSortExec(
+              _,
+              op,
+              op.output,
+              op.outputOrdering,
+              op.sortOrder,
+              op.child,
+              SerializedPlan(None)))
+
+        case op: LocalLimitExec =>
+          newPlanWithProto(
+            op,
+            CometLocalLimitExec(_, op, op.limit, op.child, SerializedPlan(None)))
 
-      case op: LocalLimitExec =>
-        newPlanWithProto(op, CometLocalLimitExec(_, op, op.limit, op.child, SerializedPlan(None)))
+        case op: GlobalLimitExec =>
+          newPlanWithProto(
+            op,
+            CometGlobalLimitExec(_, op, op.limit, op.offset, op.child, SerializedPlan(None)))
 
-      case op: GlobalLimitExec =>
-        newPlanWithProto(
-          op,
-          CometGlobalLimitExec(_, op, op.limit, op.offset, op.child, SerializedPlan(None)))
+        case op: CollectLimitExec =>
+          val fallbackReasons = new ListBuffer[String]()
+          if (!CometConf.COMET_EXEC_COLLECT_LIMIT_ENABLED.get(conf)) {
+            fallbackReasons += s"${CometConf.COMET_EXEC_COLLECT_LIMIT_ENABLED.key} is false"
+          }
+          if (!isCometShuffleEnabled(conf)) {
+            fallbackReasons += "Comet shuffle is not enabled"
+          }
+          if (fallbackReasons.nonEmpty) {
+            withInfos(op, fallbackReasons.toSet)
+          } else {
+            if (!isCometNative(op.child)) {
+              // no reason to report reason if child is not native
+              op
+            } else {
+              QueryPlanSerde
+                .operator2Proto(op)
+                .map { nativeOp =>
+                  val cometOp =
+                    CometCollectLimitExec(op, op.limit, op.offset, op.child)
+                  CometSinkPlaceHolder(nativeOp, op, cometOp)
+                }
+                .getOrElse(op)
+            }
+          }
 
-      case op: CollectLimitExec =>
-        val fallbackReasons = new ListBuffer[String]()
-        if (!CometConf.COMET_EXEC_COLLECT_LIMIT_ENABLED.get(conf)) {
-          fallbackReasons += s"${CometConf.COMET_EXEC_COLLECT_LIMIT_ENABLED.key} is false"
-        }
-        if (!isCometShuffleEnabled(conf)) {
-          fallbackReasons += "Comet shuffle is not enabled"
-        }
-        if (fallbackReasons.nonEmpty) {
-          withInfos(op, fallbackReasons.toSet)
-        } else {
-          if (!isCometNative(op.child)) {
-            // no reason to report reason if child is not native
+        case op: ExpandExec =>
+          newPlanWithProto(
+            op,
+            CometExpandExec(_, op, op.output, op.projections, op.child, SerializedPlan(None)))
+
+        // When Comet shuffle is disabled, we don't want to transform the HashAggregate
+        // to CometHashAggregate. Otherwise, we probably get partial Comet aggregation
+        // and final Spark aggregation.
+        case op: BaseAggregateExec
+            if op.isInstanceOf[HashAggregateExec] ||
+              op.isInstanceOf[ObjectHashAggregateExec] &&
+              isCometShuffleEnabled(conf) =>
+          val modes = op.aggregateExpressions.map(_.mode).distinct
+          // In distinct aggregates there can be a combination of modes
+          val multiMode = modes.size > 1
+          // For a final mode HashAggregate, we only need to transform the HashAggregate
+          // if there is Comet partial aggregation.
+          val sparkFinalMode = modes.contains(Final) && findCometPartialAgg(op.child).isEmpty
+
+          if (multiMode || sparkFinalMode) {
             op
           } else {
-            QueryPlanSerde
-              .operator2Proto(op)
-              .map { nativeOp =>
-                val cometOp =
-                  CometCollectLimitExec(op, op.limit, op.offset, op.child)
-                CometSinkPlaceHolder(nativeOp, op, cometOp)
-              }
-              .getOrElse(op)
+            newPlanWithProto(
+              op,
+              nativeOp => {
+                // The aggExprs could be empty. For example, if the aggregate functions only have
+                // distinct aggregate functions or only have group by, the aggExprs is empty and
+                // modes is empty too. If aggExprs is not empty, we need to verify all the
+                // aggregates have the same mode.
+                assert(modes.length == 1 || modes.isEmpty)
+                CometHashAggregateExec(
+                  nativeOp,
+                  op,
+                  op.output,
+                  op.groupingExpressions,
+                  op.aggregateExpressions,
+                  op.resultExpressions,
+                  op.child.output,
+                  modes.headOption,
+                  op.child,
+                  SerializedPlan(None))
+              })
           }
-        }
 
-      case op: ExpandExec =>
-        newPlanWithProto(
-          op,
-          CometExpandExec(_, op, op.output, op.projections, op.child, SerializedPlan(None)))
-
-      // When Comet shuffle is disabled, we don't want to transform the HashAggregate
-      // to CometHashAggregate. Otherwise, we probably get partial Comet aggregation
-      // and final Spark aggregation.
-      case op: BaseAggregateExec
-          if op.isInstanceOf[HashAggregateExec] ||
-            op.isInstanceOf[ObjectHashAggregateExec] &&
-            isCometShuffleEnabled(conf) =>
-        val modes = op.aggregateExpressions.map(_.mode).distinct
-        // In distinct aggregates there can be a combination of modes
-        val multiMode = modes.size > 1
-        // For a final mode HashAggregate, we only need to transform the HashAggregate
-        // if there is Comet partial aggregation.
-        val sparkFinalMode = modes.contains(Final) && findCometPartialAgg(op.child).isEmpty
-
-        if (multiMode || sparkFinalMode) {
-          op
-        } else {
+        case op: ShuffledHashJoinExec
+            if CometConf.COMET_EXEC_HASH_JOIN_ENABLED.get(conf) &&
+              op.children.forall(isCometNative) =>
           newPlanWithProto(
             op,
-            nativeOp => {
-              // The aggExprs could be empty. For example, if the aggregate functions only have
-              // distinct aggregate functions or only have group by, the aggExprs is empty and
-              // modes is empty too. If aggExprs is not empty, we need to verify all the
-              // aggregates have the same mode.
-              assert(modes.length == 1 || modes.isEmpty)
-              CometHashAggregateExec(
-                nativeOp,
-                op,
-                op.output,
-                op.groupingExpressions,
-                op.aggregateExpressions,
-                op.resultExpressions,
-                op.child.output,
-                modes.headOption,
-                op.child,
-                SerializedPlan(None))
-            })
-        }
+            CometHashJoinExec(
+              _,
+              op,
+              op.output,
+              op.outputOrdering,
+              op.leftKeys,
+              op.rightKeys,
+              op.joinType,
+              op.condition,
+              op.buildSide,
+              op.left,
+              op.right,
+              SerializedPlan(None)))
+
+        case op: ShuffledHashJoinExec if !CometConf.COMET_EXEC_HASH_JOIN_ENABLED.get(conf) =>
+          withInfo(op, "ShuffleHashJoin is not enabled")
+
+        case op: ShuffledHashJoinExec if !op.children.forall(isCometNative) =>
+          op
 
-      case op: ShuffledHashJoinExec
-          if CometConf.COMET_EXEC_HASH_JOIN_ENABLED.get(conf) &&
-            op.children.forall(isCometNative) =>
-        newPlanWithProto(
-          op,
-          CometHashJoinExec(
-            _,
-            op,
-            op.output,
-            op.outputOrdering,
-            op.leftKeys,
-            op.rightKeys,
-            op.joinType,
-            op.condition,
-            op.buildSide,
-            op.left,
-            op.right,
-            SerializedPlan(None)))
-
-      case op: ShuffledHashJoinExec if !CometConf.COMET_EXEC_HASH_JOIN_ENABLED.get(conf) =>
-        withInfo(op, "ShuffleHashJoin is not enabled")
-
-      case op: ShuffledHashJoinExec if !op.children.forall(isCometNative) =>
-        op
-
-      case op: BroadcastHashJoinExec
-          if CometConf.COMET_EXEC_BROADCAST_HASH_JOIN_ENABLED.get(conf) &&
-            op.children.forall(isCometNative) =>
-        newPlanWithProto(
-          op,
-          CometBroadcastHashJoinExec(
-            _,
+        case op: BroadcastHashJoinExec
+            if CometConf.COMET_EXEC_BROADCAST_HASH_JOIN_ENABLED.get(conf) &&
+              op.children.forall(isCometNative) =>
+          newPlanWithProto(
             op,
-            op.output,
-            op.outputOrdering,
-            op.leftKeys,
-            op.rightKeys,
-            op.joinType,
-            op.condition,
-            op.buildSide,
-            op.left,
-            op.right,
-            SerializedPlan(None)))
-
-      case op: SortMergeJoinExec
-          if CometConf.COMET_EXEC_SORT_MERGE_JOIN_ENABLED.get(conf) &&
-            op.children.forall(isCometNative) =>
-        newPlanWithProto(
-          op,
-          CometSortMergeJoinExec(
-            _,
+            CometBroadcastHashJoinExec(
+              _,
+              op,
+              op.output,
+              op.outputOrdering,
+              op.leftKeys,
+              op.rightKeys,
+              op.joinType,
+              op.condition,
+              op.buildSide,
+              op.left,
+              op.right,
+              SerializedPlan(None)))
+
+        case op: SortMergeJoinExec
+            if CometConf.COMET_EXEC_SORT_MERGE_JOIN_ENABLED.get(conf) &&
+              op.children.forall(isCometNative) =>
+          newPlanWithProto(
             op,
-            op.output,
-            op.outputOrdering,
-            op.leftKeys,
-            op.rightKeys,
-            op.joinType,
-            op.condition,
-            op.left,
-            op.right,
-            SerializedPlan(None)))
-
-      case op: SortMergeJoinExec
-          if CometConf.COMET_EXEC_SORT_MERGE_JOIN_ENABLED.get(conf) &&
-            !op.children.forall(isCometNative) =>
-        op
-
-      case op: SortMergeJoinExec if !CometConf.COMET_EXEC_SORT_MERGE_JOIN_ENABLED.get(conf) =>
-        withInfo(op, "SortMergeJoin is not enabled")
-
-      case op: SortMergeJoinExec if !op.children.forall(isCometNative) =>
-        op
-
-      case c @ CoalesceExec(numPartitions, child)
-          if CometConf.COMET_EXEC_COALESCE_ENABLED.get(conf)
-            && isCometNative(child) =>
-        QueryPlanSerde
-          .operator2Proto(c)
-          .map { nativeOp =>
-            val cometOp = CometCoalesceExec(c, c.output, numPartitions, child)
-            CometSinkPlaceHolder(nativeOp, c, cometOp)
-          }
-          .getOrElse(c)
-
-      case c @ CoalesceExec(_, _) if !CometConf.COMET_EXEC_COALESCE_ENABLED.get(conf) =>
-        withInfo(c, "Coalesce is not enabled")
-
-      case op: CoalesceExec if !op.children.forall(isCometNative) =>
-        op
-
-      case s: TakeOrderedAndProjectExec
-          if isCometNative(s.child) && CometConf.COMET_EXEC_TAKE_ORDERED_AND_PROJECT_ENABLED
-            .get(conf)
-            && isCometShuffleEnabled(conf) &&
-            CometTakeOrderedAndProjectExec.isSupported(s) =>
-        QueryPlanSerde
-          .operator2Proto(s)
-          .map { nativeOp =>
-            val cometOp =
-              CometTakeOrderedAndProjectExec(
-                s,
-                s.output,
-                s.limit,
-                s.offset,
-                s.sortOrder,
-                s.projectList,
-                s.child)
-            CometSinkPlaceHolder(nativeOp, s, cometOp)
-          }
-          .getOrElse(s)
-
-      case s: TakeOrderedAndProjectExec =>
-        val info1 = createMessage(
-          !CometConf.COMET_EXEC_TAKE_ORDERED_AND_PROJECT_ENABLED.get(conf),
-          "TakeOrderedAndProject is not enabled")
-        val info2 = createMessage(
-          !isCometShuffleEnabled(conf),
-          "TakeOrderedAndProject requires shuffle to be enabled")
-        withInfo(s, Seq(info1, info2).flatten.mkString(","))
-
-      case w: WindowExec =>
-        newPlanWithProto(
-          w,
-          CometWindowExec(
-            _,
+            CometSortMergeJoinExec(
+              _,
+              op,
+              op.output,
+              op.outputOrdering,
+              op.leftKeys,
+              op.rightKeys,
+              op.joinType,
+              op.condition,
+              op.left,
+              op.right,
+              SerializedPlan(None)))
+
+        case op: SortMergeJoinExec
+            if CometConf.COMET_EXEC_SORT_MERGE_JOIN_ENABLED.get(conf) &&
+              !op.children.forall(isCometNative) =>
+          op
+
+        case op: SortMergeJoinExec if !CometConf.COMET_EXEC_SORT_MERGE_JOIN_ENABLED.get(conf) =>
+          withInfo(op, "SortMergeJoin is not enabled")
+
+        case op: SortMergeJoinExec if !op.children.forall(isCometNative) =>
+          op
+
+        case c @ CoalesceExec(numPartitions, child)
+            if CometConf.COMET_EXEC_COALESCE_ENABLED.get(conf)
+              && isCometNative(child) =>
+          QueryPlanSerde
+            .operator2Proto(c)
+            .map { nativeOp =>
+              val cometOp = CometCoalesceExec(c, c.output, numPartitions, child)
+              CometSinkPlaceHolder(nativeOp, c, cometOp)
+            }
+            .getOrElse(c)
+
+        case c @ CoalesceExec(_, _) if !CometConf.COMET_EXEC_COALESCE_ENABLED.get(conf) =>
+          withInfo(c, "Coalesce is not enabled")
+
+        case op: CoalesceExec if !op.children.forall(isCometNative) =>
+          op
+
+        case s: TakeOrderedAndProjectExec
+            if isCometNative(s.child) && CometConf.COMET_EXEC_TAKE_ORDERED_AND_PROJECT_ENABLED
+              .get(conf)
+              && isCometShuffleEnabled(conf) &&
+              CometTakeOrderedAndProjectExec.isSupported(s) =>
+          QueryPlanSerde
+            .operator2Proto(s)
+            .map { nativeOp =>
+              val cometOp =
+                CometTakeOrderedAndProjectExec(
+                  s,
+                  s.output,
+                  s.limit,
+                  s.offset,
+                  s.sortOrder,
+                  s.projectList,
+                  s.child)
+              CometSinkPlaceHolder(nativeOp, s, cometOp)
+            }
+            .getOrElse(s)
+
+        case s: TakeOrderedAndProjectExec =>
+          val info1 = createMessage(
+            !CometConf.COMET_EXEC_TAKE_ORDERED_AND_PROJECT_ENABLED.get(conf),
+            "TakeOrderedAndProject is not enabled")
+          val info2 = createMessage(
+            !isCometShuffleEnabled(conf),
+            "TakeOrderedAndProject requires shuffle to be enabled")
+          withInfo(s, Seq(info1, info2).flatten.mkString(","))
+
+        case w: WindowExec =>
+          newPlanWithProto(
             w,
-            w.output,
-            w.windowExpression,
-            w.partitionSpec,
-            w.orderSpec,
-            w.child,
-            SerializedPlan(None)))
-
-      case u: UnionExec
-          if CometConf.COMET_EXEC_UNION_ENABLED.get(conf) &&
-            u.children.forall(isCometNative) =>
-        newPlanWithProto(
-          u, {
-            val cometOp = CometUnionExec(u, u.output, u.children)
-            CometSinkPlaceHolder(_, u, cometOp)
-          })
-
-      case u: UnionExec if !CometConf.COMET_EXEC_UNION_ENABLED.get(conf) =>
-        withInfo(u, "Union is not enabled")
-
-      case op: UnionExec if !op.children.forall(isCometNative) =>
-        op
-
-      // For AQE broadcast stage on a Comet broadcast exchange
-      case s @ BroadcastQueryStageExec(_, _: CometBroadcastExchangeExec, _) =>
-        newPlanWithProto(s, CometSinkPlaceHolder(_, s, s))
-
-      case s @ BroadcastQueryStageExec(
-            _,
-            ReusedExchangeExec(_, _: CometBroadcastExchangeExec),
-            _) =>
-        newPlanWithProto(s, CometSinkPlaceHolder(_, s, s))
-
-      // `CometBroadcastExchangeExec`'s broadcast output is not compatible with Spark's broadcast
-      // exchange. It is only used for Comet native execution. We only transform Spark broadcast
-      // exchange to Comet broadcast exchange if its downstream is a Comet native plan or if the
-      // broadcast exchange is forced to be enabled by Comet config.
-      case plan if plan.children.exists(_.isInstanceOf[BroadcastExchangeExec]) =>
-        val newChildren = plan.children.map {
-          case b: BroadcastExchangeExec
-              if isCometNative(b.child) &&
-                CometConf.COMET_EXEC_BROADCAST_EXCHANGE_ENABLED.get(conf) =>
-            QueryPlanSerde.operator2Proto(b) match {
-              case Some(nativeOp) =>
-                val cometOp = CometBroadcastExchangeExec(b, b.output, b.mode, b.child)
-                CometSinkPlaceHolder(nativeOp, b, cometOp)
-              case None => b
+            CometWindowExec(
+              _,
+              w,
+              w.output,
+              w.windowExpression,
+              w.partitionSpec,
+              w.orderSpec,
+              w.child,
+              SerializedPlan(None)))
+
+        case u: UnionExec
+            if CometConf.COMET_EXEC_UNION_ENABLED.get(conf) &&
+              u.children.forall(isCometNative) =>
+          newPlanWithProto(
+            u, {
+              val cometOp = CometUnionExec(u, u.output, u.children)
+              CometSinkPlaceHolder(_, u, cometOp)
+            })
+
+        case u: UnionExec if !CometConf.COMET_EXEC_UNION_ENABLED.get(conf) =>
+          withInfo(u, "Union is not enabled")
+
+        case op: UnionExec if !op.children.forall(isCometNative) =>
+          op
+
+        // For AQE broadcast stage on a Comet broadcast exchange
+        case s @ BroadcastQueryStageExec(_, _: CometBroadcastExchangeExec, _) =>
+          newPlanWithProto(s, CometSinkPlaceHolder(_, s, s))
+
+        case s @ BroadcastQueryStageExec(
+              _,
+              ReusedExchangeExec(_, _: CometBroadcastExchangeExec),
+              _) =>
+          newPlanWithProto(s, CometSinkPlaceHolder(_, s, s))
+
+        // `CometBroadcastExchangeExec`'s broadcast output is not compatible with Spark's broadcast
+        // exchange. It is only used for Comet native execution. We only transform Spark broadcast
+        // exchange to Comet broadcast exchange if its downstream is a Comet native plan or if the
+        // broadcast exchange is forced to be enabled by Comet config.
+        case plan if plan.children.exists(_.isInstanceOf[BroadcastExchangeExec]) =>
+          val newChildren = plan.children.map {
+            case b: BroadcastExchangeExec
+                if isCometNative(b.child) &&
+                  CometConf.COMET_EXEC_BROADCAST_EXCHANGE_ENABLED.get(conf) =>
+              QueryPlanSerde.operator2Proto(b) match {
+                case Some(nativeOp) =>
+                  val cometOp = CometBroadcastExchangeExec(b, b.output, b.mode, b.child)
+                  CometSinkPlaceHolder(nativeOp, b, cometOp)
+                case None => b
+              }
+            case other => other
+          }
+          if (!newChildren.exists(_.isInstanceOf[BroadcastExchangeExec])) {
+            val newPlan = convertNode(plan.withNewChildren(newChildren))

Review Comment:
   Replace `apply` with `convertNode`
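
   For context, one hedged reading of the suggestion (the snippet below is a sketch, not code taken from the PR): re-entering the rule through `apply` would re-run the whole `CometExecRule` transform over the already-converted subtree, while `convertNode` only re-converts the current node after its broadcast children have been rewritten. Roughly:

       // Sketch only; `plan` and `newChildren` are the values from the surrounding case above.
       // Before: val newPlan = apply(plan.withNewChildren(newChildren))
       val newPlan = convertNode(plan.withNewChildren(newChildren))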



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

