peter-toth commented on code in PR #3349:
URL: https://github.com/apache/datafusion-comet/pull/3349#discussion_r2754880669
##########
spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala:
##########
@@ -19,39 +19,207 @@
package org.apache.spark.sql.comet
-import org.apache.spark.{Partition, SparkContext, TaskContext}
-import org.apache.spark.rdd.{RDD, RDDOperationScope}
+import org.apache.spark._
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.execution.ScalarSubquery
import org.apache.spark.sql.vectorized.ColumnarBatch
+import org.apache.spark.util.SerializableConfiguration
+
+import org.apache.comet.CometExecIterator
+import org.apache.comet.serde.OperatorOuterClass
+
+/**
+ * Partition that carries per-partition planning data, avoiding closure capture of all partitions.
+ */
+private[spark] class CometExecPartition(
+ override val index: Int,
+ val inputPartitions: Array[Partition],
+ val planDataByKey: Map[String, Array[Byte]])
+ extends Partition
/**
- * A RDD that executes Spark SQL query in Comet native execution to generate ColumnarBatch.
+ * Unified RDD for Comet native execution.
+ *
+ * Solves the closure capture problem: instead of capturing all partitions' data in the closure
+ * (which gets serialized to every task), each Partition object carries only its own data.
+ *
+ * Handles three cases:
+ * - With inputs + per-partition data: injects planning data into operator tree
+ * - With inputs + no per-partition data: just zips inputs (no injection overhead)
+ * - No inputs: uses numPartitions to create partitions
+ *
+ * NOTE: This RDD does not handle DPP (InSubqueryExec), which is resolved in
+ * CometIcebergNativeScanExec.serializedPartitionData before this RDD is created. It also handles
+ * ScalarSubquery expressions by registering them with CometScalarSubquery before execution.
*/
private[spark] class CometExecRDD(
sc: SparkContext,
- partitionNum: Int,
- var f: (Seq[Iterator[ColumnarBatch]], Int, Int) => Iterator[ColumnarBatch])
+ inputRDDs: Seq[RDD[ColumnarBatch]],
+ commonByKey: Map[String, Array[Byte]],
+ @transient perPartitionByKey: Map[String, Array[Array[Byte]]],
+ serializedPlan: Array[Byte],
+ defaultNumPartitions: Int,
+ numOutputCols: Int,
+ nativeMetrics: CometMetricNode,
+ subqueries: Seq[ScalarSubquery],
+ broadcastedHadoopConfForEncryption: Option[Broadcast[SerializableConfiguration]] = None,
+ encryptedFilePaths: Seq[String] = Seq.empty)
extends RDD[ColumnarBatch](sc, Nil) {
- override def compute(s: Partition, context: TaskContext): Iterator[ColumnarBatch] = {
- f(Seq.empty, partitionNum, s.index)
+ // Determine partition count: from inputs if available, otherwise from parameter
+ private val numPartitions: Int = if (inputRDDs.nonEmpty) {
+ inputRDDs.head.partitions.length
+ } else if (perPartitionByKey.nonEmpty) {
+ perPartitionByKey.values.head.length
+ } else {
+ defaultNumPartitions
}
+ // Validate all per-partition arrays have the same length to prevent
+ // ArrayIndexOutOfBoundsException in getPartitions (e.g., from broadcast scans with
+ // different partition counts after DPP filtering)
+ require(
+ perPartitionByKey.values.forall(_.length == numPartitions),
+ s"All per-partition arrays must have length $numPartitions, but found: " +
+ perPartitionByKey.map { case (key, arr) => s"$key -> ${arr.length}" }.mkString(", "))
+
override protected def getPartitions: Array[Partition] = {
- Array.tabulate(partitionNum)(i =>
- new Partition {
- override def index: Int = i
- })
+ (0 until numPartitions).map { idx =>
+ val inputParts = inputRDDs.map(_.partitions(idx)).toArray
+ val planData = perPartitionByKey.map { case (key, arr) => key -> arr(idx) }
+ new CometExecPartition(idx, inputParts, planData)
+ }.toArray
+ }
+
+ override def compute(split: Partition, context: TaskContext): Iterator[ColumnarBatch] = {
+ val partition = split.asInstanceOf[CometExecPartition]
+
+ val inputs = inputRDDs.zip(partition.inputPartitions).map { case (rdd, part) =>
+ rdd.iterator(part, context)
+ }
+
+ // Only inject if we have per-partition planning data
+ val actualPlan = if (commonByKey.nonEmpty) {
+ val basePlan = OperatorOuterClass.Operator.parseFrom(serializedPlan)
+ val injected =
+ PlanDataInjector.injectPlanData(basePlan, commonByKey, partition.planDataByKey)
+ PlanDataInjector.serializeOperator(injected)
+ } else {
+ serializedPlan
+ }
+
+ val it = new CometExecIterator(
+ CometExec.newIterId,
+ inputs,
+ numOutputCols,
+ actualPlan,
+ nativeMetrics,
+ numPartitions,
+ partition.index,
+ broadcastedHadoopConfForEncryption,
+ encryptedFilePaths)
+
+ // Register ScalarSubqueries so native code can look them up
+ subqueries.foreach(sub => CometScalarSubquery.setSubquery(it.id, sub))
+
+ Option(context).foreach { ctx =>
+ ctx.addTaskCompletionListener[Unit] { _ =>
+ it.close()
+ subqueries.foreach(sub => CometScalarSubquery.removeSubquery(it.id, sub))
+ }
+ }
+
+ it
+ }
+
+ override def getDependencies: Seq[Dependency[_]] =
Review Comment:
You can pass this in (instead of `Nil`) when you extend `RDD[ColumnarBatch](sc, Nil)`.
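
   For illustration, here is a minimal, self-contained sketch of that suggestion. The `ZippedPartition` / `ZippedColumnarRDD` names and the trivial `compute` body are hypothetical stand-ins for the actual CometExecRDD; only `OneToOneDependency` and the `RDD(sc, deps)` constructor are real Spark API:

   ```scala
   import org.apache.spark.{OneToOneDependency, Partition, SparkContext, TaskContext}
   import org.apache.spark.rdd.RDD
   import org.apache.spark.sql.vectorized.ColumnarBatch

   // Hypothetical partition type mirroring CometExecPartition: each partition keeps only
   // its own index and the matching partitions of the input RDDs.
   class ZippedPartition(override val index: Int, val inputPartitions: Array[Partition])
     extends Partition

   // Hypothetical RDD illustrating the suggestion: the one-to-one dependencies on the
   // input RDDs are built up front and passed to the RDD superclass constructor in place
   // of Nil, so no getDependencies override is needed.
   class ZippedColumnarRDD(sc: SparkContext, inputRDDs: Seq[RDD[ColumnarBatch]])
     extends RDD[ColumnarBatch](sc, inputRDDs.map(new OneToOneDependency(_))) {

     override protected def getPartitions: Array[Partition] =
       Array.tabulate[Partition](inputRDDs.head.partitions.length) { idx =>
         new ZippedPartition(idx, inputRDDs.map(_.partitions(idx)).toArray)
       }

     override def compute(split: Partition, context: TaskContext): Iterator[ColumnarBatch] = {
       val p = split.asInstanceOf[ZippedPartition]
       // Concatenate the input iterators; the real CometExecRDD hands them to native execution.
       inputRDDs.iterator.zip(p.inputPartitions.iterator).flatMap { case (rdd, part) =>
         rdd.iterator(part, context)
       }
     }
   }
   ```

   Since the RDD base class records whatever `Seq[Dependency[_]]` it is constructed with, supplying the one-to-one dependencies there makes a separate `getDependencies` override redundant.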