peter-toth commented on code in PR #3349:
URL: https://github.com/apache/datafusion-comet/pull/3349#discussion_r2754880669


##########
spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala:
##########
@@ -19,39 +19,207 @@
 
 package org.apache.spark.sql.comet
 
-import org.apache.spark.{Partition, SparkContext, TaskContext}
-import org.apache.spark.rdd.{RDD, RDDOperationScope}
+import org.apache.spark._
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.execution.ScalarSubquery
 import org.apache.spark.sql.vectorized.ColumnarBatch
+import org.apache.spark.util.SerializableConfiguration
+
+import org.apache.comet.CometExecIterator
+import org.apache.comet.serde.OperatorOuterClass
+
+/**
+ * Partition that carries per-partition planning data, avoiding closure capture of all partitions.
+ */
+private[spark] class CometExecPartition(
+    override val index: Int,
+    val inputPartitions: Array[Partition],
+    val planDataByKey: Map[String, Array[Byte]])
+    extends Partition
 
 /**
- * A RDD that executes Spark SQL query in Comet native execution to generate ColumnarBatch.
+ * Unified RDD for Comet native execution.
+ *
+ * Solves the closure capture problem: instead of capturing all partitions' data in the closure
+ * (which gets serialized to every task), each Partition object carries only its own data.
+ *
+ * Handles three cases:
+ *   - With inputs + per-partition data: injects planning data into operator tree
+ *   - With inputs + no per-partition data: just zips inputs (no injection overhead)
+ *   - No inputs: uses numPartitions to create partitions
+ *
+ * NOTE: This RDD does not handle DPP (InSubqueryExec), which is resolved in
+ * CometIcebergNativeScanExec.serializedPartitionData before this RDD is created. It also handles
+ * ScalarSubquery expressions by registering them with CometScalarSubquery before execution.
  */
 private[spark] class CometExecRDD(
     sc: SparkContext,
-    partitionNum: Int,
-    var f: (Seq[Iterator[ColumnarBatch]], Int, Int) => Iterator[ColumnarBatch])
+    inputRDDs: Seq[RDD[ColumnarBatch]],
+    commonByKey: Map[String, Array[Byte]],
+    @transient perPartitionByKey: Map[String, Array[Array[Byte]]],
+    serializedPlan: Array[Byte],
+    defaultNumPartitions: Int,
+    numOutputCols: Int,
+    nativeMetrics: CometMetricNode,
+    subqueries: Seq[ScalarSubquery],
+    broadcastedHadoopConfForEncryption: Option[Broadcast[SerializableConfiguration]] = None,
+    encryptedFilePaths: Seq[String] = Seq.empty)
     extends RDD[ColumnarBatch](sc, Nil) {
 
-  override def compute(s: Partition, context: TaskContext): Iterator[ColumnarBatch] = {
-    f(Seq.empty, partitionNum, s.index)
+  // Determine partition count: from inputs if available, otherwise from parameter
+  private val numPartitions: Int = if (inputRDDs.nonEmpty) {
+    inputRDDs.head.partitions.length
+  } else if (perPartitionByKey.nonEmpty) {
+    perPartitionByKey.values.head.length
+  } else {
+    defaultNumPartitions
   }
 
+  // Validate all per-partition arrays have the same length to prevent
+  // ArrayIndexOutOfBoundsException in getPartitions (e.g., from broadcast scans with
+  // different partition counts after DPP filtering)
+  require(
+    perPartitionByKey.values.forall(_.length == numPartitions),
+    s"All per-partition arrays must have length $numPartitions, but found: " +
+      perPartitionByKey.map { case (key, arr) => s"$key -> ${arr.length}" }.mkString(", "))
+
   override protected def getPartitions: Array[Partition] = {
-    Array.tabulate(partitionNum)(i =>
-      new Partition {
-        override def index: Int = i
-      })
+    (0 until numPartitions).map { idx =>
+      val inputParts = inputRDDs.map(_.partitions(idx)).toArray
+      val planData = perPartitionByKey.map { case (key, arr) => key -> arr(idx) }
+      new CometExecPartition(idx, inputParts, planData)
+    }.toArray
+  }
+
+  override def compute(split: Partition, context: TaskContext): Iterator[ColumnarBatch] = {
+    val partition = split.asInstanceOf[CometExecPartition]
+
+    val inputs = inputRDDs.zip(partition.inputPartitions).map { case (rdd, part) =>
+      rdd.iterator(part, context)
+    }
+
+    // Only inject if we have per-partition planning data
+    val actualPlan = if (commonByKey.nonEmpty) {
+      val basePlan = OperatorOuterClass.Operator.parseFrom(serializedPlan)
+      val injected =
+        PlanDataInjector.injectPlanData(basePlan, commonByKey, partition.planDataByKey)
+      PlanDataInjector.serializeOperator(injected)
+    } else {
+      serializedPlan
+    }
+
+    val it = new CometExecIterator(
+      CometExec.newIterId,
+      inputs,
+      numOutputCols,
+      actualPlan,
+      nativeMetrics,
+      numPartitions,
+      partition.index,
+      broadcastedHadoopConfForEncryption,
+      encryptedFilePaths)
+
+    // Register ScalarSubqueries so native code can look them up
+    subqueries.foreach(sub => CometScalarSubquery.setSubquery(it.id, sub))
+
+    Option(context).foreach { ctx =>
+      ctx.addTaskCompletionListener[Unit] { _ =>
+        it.close()
+        subqueries.foreach(sub => CometScalarSubquery.removeSubquery(it.id, sub))
+      }
+    }
+
+    it
+  }
+
+  override def getDependencies: Seq[Dependency[_]] =

Review Comment:
   You can pass these dependencies to the `RDD` constructor (in place of `Nil`) where you extend `RDD[ColumnarBatch](sc, Nil)`, rather than overriding `getDependencies`.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to