This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch comet-parquet-exec
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/comet-parquet-exec by this push:
     new f43a1d48d chore: [comet-parquet-exec] Fix regressions related to zipping rdds (#1298)
f43a1d48d is described below

commit f43a1d48d97931a8a4b27c7de7fa056373191acd
Author: Andy Grove <[email protected]>
AuthorDate: Sat Jan 18 14:36:44 2025 -0700

    chore: [comet-parquet-exec] Fix regressions related to zipping rdds (#1298)
---
 .../org/apache/spark/sql/comet/operators.scala     | 51 ++++++++++------------
 1 file changed, 24 insertions(+), 27 deletions(-)

diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala 
b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
index 9e2ca987f..f36b41aa6 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
@@ -258,48 +258,44 @@ abstract class CometNativeExec extends CometExec {
         // If the first non broadcast plan is found, we need to adjust the 
partition number of
         // the broadcast plans to make sure they have the same partition 
number as the first non
         // broadcast plan.
-        val firstNonBroadcastPlanNumPartitions =
-          firstNonBroadcastPlan.map(_._1.outputPartitioning.numPartitions)
+        val (firstNonBroadcastPlanRDD, firstNonBroadcastPlanNumPartitions) =
+          firstNonBroadcastPlan.get._1 match {
+            case plan: CometNativeExec =>
+              (null, plan.outputPartitioning.numPartitions)
+            case plan =>
+              val rdd = plan.executeColumnar()
+              (rdd, rdd.getNumPartitions)
+          }
 
         // Spark doesn't need to zip Broadcast RDDs, so it doesn't schedule 
Broadcast RDDs with
         // same partition number. But for Comet, we need to zip them so we 
need to adjust the
         // partition number of Broadcast RDDs to make sure they have the same 
partition number.
-        sparkPlans.zipWithIndex.foreach { case (plan, _) =>
+        sparkPlans.zipWithIndex.foreach { case (plan, idx) =>
           plan match {
-            case c: CometBroadcastExchangeExec if 
firstNonBroadcastPlanNumPartitions.nonEmpty =>
-              inputs += c
-                .setNumPartitions(firstNonBroadcastPlanNumPartitions.get)
-                .executeColumnar()
-            case BroadcastQueryStageExec(_, c: CometBroadcastExchangeExec, _)
-                if firstNonBroadcastPlanNumPartitions.nonEmpty =>
-              inputs += c
-                .setNumPartitions(firstNonBroadcastPlanNumPartitions.get)
-                .executeColumnar()
-            case ReusedExchangeExec(_, c: CometBroadcastExchangeExec)
-                if firstNonBroadcastPlanNumPartitions.nonEmpty =>
-              inputs += c
-                .setNumPartitions(firstNonBroadcastPlanNumPartitions.get)
-                .executeColumnar()
+            case c: CometBroadcastExchangeExec =>
+              inputs += 
c.setNumPartitions(firstNonBroadcastPlanNumPartitions).executeColumnar()
+            case BroadcastQueryStageExec(_, c: CometBroadcastExchangeExec, _) 
=>
+              inputs += 
c.setNumPartitions(firstNonBroadcastPlanNumPartitions).executeColumnar()
+            case ReusedExchangeExec(_, c: CometBroadcastExchangeExec) =>
+              inputs += 
c.setNumPartitions(firstNonBroadcastPlanNumPartitions).executeColumnar()
             case BroadcastQueryStageExec(
                   _,
                   ReusedExchangeExec(_, c: CometBroadcastExchangeExec),
-                  _) if firstNonBroadcastPlanNumPartitions.nonEmpty =>
-              inputs += c
-                .setNumPartitions(firstNonBroadcastPlanNumPartitions.get)
-                .executeColumnar()
+                  _) =>
+              inputs += 
c.setNumPartitions(firstNonBroadcastPlanNumPartitions).executeColumnar()
             case _: CometNativeExec =>
             // no-op
-            case _ if firstNonBroadcastPlanNumPartitions.nonEmpty =>
+            case _ if idx == firstNonBroadcastPlan.get._2 =>
+              inputs += firstNonBroadcastPlanRDD
+            case _ =>
               val rdd = plan.executeColumnar()
-              if (plan.outputPartitioning.numPartitions != 
firstNonBroadcastPlanNumPartitions.get) {
+              if (rdd.getNumPartitions != firstNonBroadcastPlanNumPartitions) {
                 throw new CometRuntimeException(
                   s"Partition number mismatch: ${rdd.getNumPartitions} != " +
-                    s"${firstNonBroadcastPlanNumPartitions.get}")
+                    s"$firstNonBroadcastPlanNumPartitions")
               } else {
                 inputs += rdd
               }
-            case _ =>
-              throw new CometRuntimeException(s"Unexpected plan: $plan")
           }
         }
 
@@ -310,7 +306,7 @@ abstract class CometNativeExec extends CometExec {
         if (inputs.nonEmpty) {
           ZippedPartitionsRDD(sparkContext, inputs.toSeq)(createCometExecIter)
         } else {
-          val partitionNum = firstNonBroadcastPlanNumPartitions.get
+          val partitionNum = firstNonBroadcastPlanNumPartitions
           CometExecRDD(sparkContext, partitionNum)(createCometExecIter)
         }
     }
@@ -648,6 +644,7 @@ case class CometUnionExec(
     override val output: Seq[Attribute],
     children: Seq[SparkPlan])
     extends CometExec {
+
   override def doExecuteColumnar(): RDD[ColumnarBatch] = {
     sparkContext.union(children.map(_.executeColumnar()))
   }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to