This is an automated email from the ASF dual-hosted git repository.

comphead pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new c1e112035 fix: Avoid spark plan execution cache preventing CometBatchRDD numPartitions change (#2420)
c1e112035 is described below

commit c1e112035496befcc7cf996b814984e6841382a0
Author: Zhen Wang <[email protected]>
AuthorDate: Sat Sep 20 12:14:16 2025 +0800

    fix: Avoid spark plan execution cache preventing CometBatchRDD numPartitions change (#2420)
    
    * fix: Avoid spark plan execution cache preventing CometBatchRDD numPartitions change
    
    * refactor
---
 .../sql/comet/CometBroadcastExchangeExec.scala     | 30 +++++++++++-----------
 .../org/apache/spark/sql/comet/operators.scala     | 20 +++------------
 2 files changed, 19 insertions(+), 31 deletions(-)
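
For context on the failure mode: as the comment added in the first hunk below
notes, Spark caches the RDD created by a plan after SPARK-48195, so mutating
numPartitions on that shared RDD leaks across callers. A minimal standalone
sketch of the problem and of the fix's shape (hypothetical Plan/BatchRDD
types, not Comet's or Spark's classes):

    object CacheSketch {
      // Stand-in for an RDD whose partition count is fixed at construction.
      final case class BatchRDD(numPartitions: Int)

      final class Plan {
        private var cached: Option[BatchRDD] = None

        // Post-SPARK-48195 shape: the first RDD is memoized, so mutating its
        // partition count in place would be visible to every later caller.
        def executeColumnar(): BatchRDD = cached.getOrElse {
          val rdd = BatchRDD(numPartitions = 1)
          cached = Some(rdd)
          rdd
        }

        // The fix's shape: bypass the cache and build a fresh RDD per request.
        def executeColumnar(numPartitions: Int): BatchRDD = BatchRDD(numPartitions)
      }

      def main(args: Array[String]): Unit = {
        val plan = new Plan
        println(plan.executeColumnar().numPartitions)  // 1, now cached
        println(plan.executeColumnar(8).numPartitions) // 8, fresh instance
        println(plan.executeColumnar().numPartitions)  // still 1: cache untouched
      }
    }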

diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala
index 9114caf6e..95770592f 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala
@@ -26,7 +26,7 @@ import scala.concurrent.{ExecutionContext, Promise}
 import scala.concurrent.duration.NANOSECONDS
 import scala.util.control.NonFatal
 
-import org.apache.spark.{broadcast, Partition, SparkContext, TaskContext}
+import org.apache.spark.{broadcast, Partition, SparkContext, SparkException, TaskContext}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.Attribute
@@ -102,14 +102,8 @@ case class CometBroadcastExchangeExec(
   @transient
   private lazy val maxBroadcastRows = 512000000
 
-  private var numPartitions: Option[Int] = None
-
-  def setNumPartitions(numPartitions: Int): CometBroadcastExchangeExec = {
-    this.numPartitions = Some(numPartitions)
-    this
-  }
   def getNumPartitions(): Int = {
-    numPartitions.getOrElse(child.executeColumnar().getNumPartitions)
+    child.executeColumnar().getNumPartitions
   }
 
   @transient
@@ -224,6 +218,18 @@ case class CometBroadcastExchangeExec(
     new CometBatchRDD(sparkContext, getNumPartitions(), broadcasted)
   }
 
+  // After https://issues.apache.org/jira/browse/SPARK-48195, the Spark plan will cache the created RDD.
+  // Since we may change the number of partitions in CometBatchRDD,
+  // we need a method that always creates a new CometBatchRDD.
+  def executeColumnar(numPartitions: Int): RDD[ColumnarBatch] = {
+    if (isCanonicalizedPlan) {
+      throw SparkException.internalError("A canonicalized plan is not supposed to be executed.")
+    }
+
+    val broadcasted = executeBroadcast[Array[ChunkedByteBuffer]]()
+    new CometBatchRDD(sparkContext, numPartitions, broadcasted)
+  }
+
   override protected[sql] def doExecuteBroadcast[T](): broadcast.Broadcast[T] = {
     try {
       relationFuture.get(timeout, TimeUnit.SECONDS).asInstanceOf[broadcast.Broadcast[T]]
@@ -276,7 +282,7 @@ object CometBroadcastExchangeExec {
  */
 class CometBatchRDD(
     sc: SparkContext,
-    @volatile var numPartitions: Int,
+    val numPartitions: Int,
     value: broadcast.Broadcast[Array[ChunkedByteBuffer]])
     extends RDD[ColumnarBatch](sc, Nil) {
 
@@ -289,12 +295,6 @@ class CometBatchRDD(
     partition.value.value.toIterator
       .flatMap(Utils.decodeBatches(_, this.getClass.getSimpleName))
   }
-
-  def withNumPartitions(numPartitions: Int): CometBatchRDD = {
-    this.numPartitions = numPartitions
-    this
-  }
-
 }
 
 class CometBatchPartition(
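
A note on why rebuilding the RDD is cheap and safe here: each partition of a
CometBatchRDD decodes the same broadcast payload, so partitions carry no
per-split state and the partition count is purely a construction-time choice.
A simplified sketch of that broadcast-backed pattern (String payload instead
of ChunkedByteBuffer batches; assumes a live SparkContext):

    import org.apache.spark.{Partition, SparkContext, TaskContext}
    import org.apache.spark.broadcast.Broadcast
    import org.apache.spark.rdd.RDD

    // Simplified broadcast-backed RDD: numPartitions is fixed at construction,
    // and every partition re-reads the full broadcast value.
    class BroadcastBackedRDD(
        sc: SparkContext,
        val numPartitions: Int,
        value: Broadcast[Array[String]])
        extends RDD[String](sc, Nil) {

      override protected def getPartitions: Array[Partition] =
        Array.tabulate[Partition](numPartitions)(i =>
          new Partition { override def index: Int = i })

      override def compute(split: Partition, context: TaskContext): Iterator[String] =
        value.value.iterator // same payload in every partition, like CometBatchRDD
    }

Requesting a different fan-out is then just new BroadcastBackedRDD(sc, n, value),
which is the move the new executeColumnar(numPartitions) overload makes with
CometBatchRDD.
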
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
index aa0ecdcb6..a7cfacc47 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
@@ -274,28 +274,16 @@ abstract class CometNativeExec extends CometExec {
         sparkPlans.zipWithIndex.foreach { case (plan, idx) =>
           plan match {
             case c: CometBroadcastExchangeExec =>
-              inputs += c
-                .executeColumnar()
-                .asInstanceOf[CometBatchRDD]
-                .withNumPartitions(firstNonBroadcastPlanNumPartitions)
+              inputs += c.executeColumnar(firstNonBroadcastPlanNumPartitions)
            case BroadcastQueryStageExec(_, c: CometBroadcastExchangeExec, _) =>
-              inputs += c
-                .executeColumnar()
-                .asInstanceOf[CometBatchRDD]
-                .withNumPartitions(firstNonBroadcastPlanNumPartitions)
+              inputs += c.executeColumnar(firstNonBroadcastPlanNumPartitions)
             case ReusedExchangeExec(_, c: CometBroadcastExchangeExec) =>
-              inputs += c
-                .executeColumnar()
-                .asInstanceOf[CometBatchRDD]
-                .withNumPartitions(firstNonBroadcastPlanNumPartitions)
+              inputs += c.executeColumnar(firstNonBroadcastPlanNumPartitions)
             case BroadcastQueryStageExec(
                   _,
                   ReusedExchangeExec(_, c: CometBroadcastExchangeExec),
                   _) =>
-              inputs += c
-                .executeColumnar()
-                .asInstanceOf[CometBatchRDD]
-                .withNumPartitions(firstNonBroadcastPlanNumPartitions)
+              inputs += c.executeColumnar(firstNonBroadcastPlanNumPartitions)
             case _: CometNativeExec =>
             // no-op
             case _ if idx == firstNonBroadcastPlan.get._2 =>


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
