This is an automated email from the ASF dual-hosted git repository.

viirya pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new 5400fd7c fix: Use RDD partition index (#1112)
5400fd7c is described below

commit 5400fd7c372d2f97ba607457f7d1098e36e2a6e8
Author: Liang-Chi Hsieh <vii...@gmail.com>
AuthorDate: Sun Nov 24 23:13:25 2024 -0800

    fix: Use RDD partition index (#1112)
    
    * fix: Use RDD partition index
    
    * fix
    
    * fix
    
    * fix
---
 .../scala/org/apache/comet/CometExecIterator.scala | 12 ++++++--
 .../apache/spark/sql/comet/CometExecUtils.scala    |  5 +--
 .../sql/comet/CometTakeOrderedAndProjectExec.scala |  7 +++--
 .../spark/sql/comet/ZippedPartitionsRDD.scala      | 11 +++++--
 .../shuffle/CometShuffleExchangeExec.scala         | 10 ++++--
 .../org/apache/spark/sql/comet/operators.scala     | 36 +++++++++++++++++-----
 .../scala/org/apache/comet/CometNativeSuite.scala  |  4 ++-
 7 files changed, 64 insertions(+), 21 deletions(-)
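
The change threads two values from the driver into the native execution: the RDD's partition count (read via getNumPartitions before the closure is shipped) and the RDD partition index (from mapPartitionsWithIndexInternal or from the Partition passed to compute), instead of calling TaskContext.partitionId() inside the task. A minimal standalone sketch of that pattern using only public Spark APIs follows; it is not part of the patch, and the object and variable names are illustrative only.

    import org.apache.spark.sql.SparkSession

    object PartitionIndexSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder()
          .master("local[2]")
          .appName("partition-index-sketch")
          .getOrCreate()
        val rdd = spark.sparkContext.parallelize(1 to 100, numSlices = 4)

        // Evaluated here on the driver and captured by the closure below;
        // getNumPartitions is not usable inside a task on an executor.
        val numParts = rdd.getNumPartitions

        val described = rdd.mapPartitionsWithIndex { case (idx, iter) =>
          // idx is the index of this RDD's partition, always in [0, numParts),
          // whereas TaskContext.partitionId() reflects the stage's final RDD.
          assert(idx >= 0 && idx < numParts)
          Iterator.single(s"partition $idx of $numParts holds ${iter.size} rows")
        }

        described.collect().foreach(println)
        spark.stop()
      }
    }

The assertion mirrors the one added to CometExecIterator.getNextBatch(): the index handed to the native plan must lie within [0, numParts) for the RDD that produced it.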

diff --git a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala
index 8a349bd3..bff3e792 100644
--- a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala
+++ b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala
@@ -39,13 +39,19 @@ import org.apache.comet.vector.NativeUtil
  *   The input iterators producing sequence of batches of Arrow Arrays.
  * @param protobufQueryPlan
  *   The serialized bytes of Spark execution plan.
+ * @param numParts
+ *   The number of partitions.
+ * @param partitionIndex
+ *   The index of the partition.
  */
 class CometExecIterator(
     val id: Long,
     inputs: Seq[Iterator[ColumnarBatch]],
     numOutputCols: Int,
     protobufQueryPlan: Array[Byte],
-    nativeMetrics: CometMetricNode)
+    nativeMetrics: CometMetricNode,
+    numParts: Int,
+    partitionIndex: Int)
     extends Iterator[ColumnarBatch] {
 
   private val nativeLib = new Native()
@@ -92,11 +98,13 @@ class CometExecIterator(
   }
 
   def getNextBatch(): Option[ColumnarBatch] = {
+    assert(partitionIndex >= 0 && partitionIndex < numParts)
+
     nativeUtil.getNextBatch(
       numOutputCols,
       (arrayAddrs, schemaAddrs) => {
         val ctx = TaskContext.get()
-        nativeLib.executePlan(ctx.stageId(), ctx.partitionId(), plan, arrayAddrs, schemaAddrs)
+        nativeLib.executePlan(ctx.stageId(), partitionIndex, plan, arrayAddrs, schemaAddrs)
       })
   }
 
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometExecUtils.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometExecUtils.scala
index 8cc03856..9698dc98 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/CometExecUtils.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometExecUtils.scala
@@ -51,9 +51,10 @@ object CometExecUtils {
       childPlan: RDD[ColumnarBatch],
       outputAttribute: Seq[Attribute],
       limit: Int): RDD[ColumnarBatch] = {
-    childPlan.mapPartitionsInternal { iter =>
+    val numParts = childPlan.getNumPartitions
+    childPlan.mapPartitionsWithIndexInternal { case (idx, iter) =>
       val limitOp = CometExecUtils.getLimitNativePlan(outputAttribute, limit).get
-      CometExec.getCometIterator(Seq(iter), outputAttribute.length, limitOp)
+      CometExec.getCometIterator(Seq(iter), outputAttribute.length, limitOp, numParts, idx)
     }
   }
 
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometTakeOrderedAndProjectExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometTakeOrderedAndProjectExec.scala
index 6220c809..5582f4d6 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/CometTakeOrderedAndProjectExec.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometTakeOrderedAndProjectExec.scala
@@ -77,12 +77,13 @@ case class CometTakeOrderedAndProjectExec(
         val localTopK = if (orderingSatisfies) {
           CometExecUtils.getNativeLimitRDD(childRDD, child.output, limit)
         } else {
-          childRDD.mapPartitionsInternal { iter =>
+          val numParts = childRDD.getNumPartitions
+          childRDD.mapPartitionsWithIndexInternal { case (idx, iter) =>
             val topK =
               CometExecUtils
                 .getTopKNativePlan(child.output, sortOrder, child, limit)
                 .get
-            CometExec.getCometIterator(Seq(iter), child.output.length, topK)
+            CometExec.getCometIterator(Seq(iter), child.output.length, topK, numParts, idx)
           }
         }
 
@@ -102,7 +103,7 @@ case class CometTakeOrderedAndProjectExec(
         val topKAndProjection = CometExecUtils
           .getProjectionNativePlan(projectList, child.output, sortOrder, child, limit)
           .get
-        val it = CometExec.getCometIterator(Seq(iter), output.length, topKAndProjection)
+        val it = CometExec.getCometIterator(Seq(iter), output.length, topKAndProjection, 1, 0)
         setSubqueries(it.id, this)
 
         Option(TaskContext.get()).foreach { context =>
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/ZippedPartitionsRDD.scala b/spark/src/main/scala/org/apache/spark/sql/comet/ZippedPartitionsRDD.scala
index 6db8c67d..fdf8bf39 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/ZippedPartitionsRDD.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/ZippedPartitionsRDD.scala
@@ -31,16 +31,20 @@ import org.apache.spark.sql.vectorized.ColumnarBatch
  */
 private[spark] class ZippedPartitionsRDD(
     sc: SparkContext,
-    var f: (Seq[Iterator[ColumnarBatch]]) => Iterator[ColumnarBatch],
+    var f: (Seq[Iterator[ColumnarBatch]], Int, Int) => Iterator[ColumnarBatch],
     var zipRdds: Seq[RDD[ColumnarBatch]],
     preservesPartitioning: Boolean = false)
     extends ZippedPartitionsBaseRDD[ColumnarBatch](sc, zipRdds, preservesPartitioning) {
 
+  // We need to get the number of partitions in `compute` but `getNumPartitions` is not available
+  // on the executors. So we need to capture it here.
+  private val numParts: Int = this.getNumPartitions
+
   override def compute(s: Partition, context: TaskContext): Iterator[ColumnarBatch] = {
     val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions
     val iterators =
       zipRdds.zipWithIndex.map(pair => pair._1.iterator(partitions(pair._2), context))
-    f(iterators)
+    f(iterators, numParts, s.index)
   }
 
   override def clearDependencies(): Unit = {
@@ -52,7 +56,8 @@ private[spark] class ZippedPartitionsRDD(
 
 object ZippedPartitionsRDD {
   def apply(sc: SparkContext, rdds: Seq[RDD[ColumnarBatch]])(
-      f: (Seq[Iterator[ColumnarBatch]]) => Iterator[ColumnarBatch]): RDD[ColumnarBatch] =
+      f: (Seq[Iterator[ColumnarBatch]], Int, Int) => Iterator[ColumnarBatch])
+      : RDD[ColumnarBatch] =
     withScope(sc) {
       new ZippedPartitionsRDD(sc, f, rdds)
     }
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala
index 4c3f994f..a7a33c40 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala
@@ -227,13 +227,14 @@ object CometShuffleExchangeExec extends ShimCometShuffleExchangeExec {
       outputPartitioning: Partitioning,
       serializer: Serializer,
       metrics: Map[String, SQLMetric]): ShuffleDependency[Int, ColumnarBatch, ColumnarBatch] = {
+    val numParts = rdd.getNumPartitions
     val dependency = new CometShuffleDependency[Int, ColumnarBatch, ColumnarBatch](
       rdd.map(
         (0, _)
       ), // adding fake partitionId that is always 0 because ShuffleDependency requires it
       serializer = serializer,
       shuffleWriterProcessor =
-        new CometShuffleWriteProcessor(outputPartitioning, outputAttributes, metrics),
+        new CometShuffleWriteProcessor(outputPartitioning, outputAttributes, metrics, numParts),
       shuffleType = CometNativeShuffle,
       partitioner = new Partitioner {
         override def numPartitions: Int = outputPartitioning.numPartitions
@@ -449,7 +450,8 @@ object CometShuffleExchangeExec extends ShimCometShuffleExchangeExec {
 class CometShuffleWriteProcessor(
     outputPartitioning: Partitioning,
     outputAttributes: Seq[Attribute],
-    metrics: Map[String, SQLMetric])
+    metrics: Map[String, SQLMetric],
+    numParts: Int)
     extends ShimCometShuffleWriteProcessor {
 
   private val OFFSET_LENGTH = 8
@@ -499,7 +501,9 @@ class CometShuffleWriteProcessor(
       Seq(newInputs.asInstanceOf[Iterator[ColumnarBatch]]),
       outputAttributes.length,
       nativePlan,
-      nativeMetrics)
+      nativeMetrics,
+      numParts,
+      context.partitionId())
 
     while (cometIter.hasNext) {
       cometIter.next()
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
index dd1526d8..77188312 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
@@ -120,20 +120,37 @@ object CometExec {
   def getCometIterator(
       inputs: Seq[Iterator[ColumnarBatch]],
       numOutputCols: Int,
-      nativePlan: Operator): CometExecIterator = {
-    getCometIterator(inputs, numOutputCols, nativePlan, CometMetricNode(Map.empty))
+      nativePlan: Operator,
+      numParts: Int,
+      partitionIdx: Int): CometExecIterator = {
+    getCometIterator(
+      inputs,
+      numOutputCols,
+      nativePlan,
+      CometMetricNode(Map.empty),
+      numParts,
+      partitionIdx)
   }
 
   def getCometIterator(
       inputs: Seq[Iterator[ColumnarBatch]],
       numOutputCols: Int,
       nativePlan: Operator,
-      nativeMetrics: CometMetricNode): CometExecIterator = {
+      nativeMetrics: CometMetricNode,
+      numParts: Int,
+      partitionIdx: Int): CometExecIterator = {
     val outputStream = new ByteArrayOutputStream()
     nativePlan.writeTo(outputStream)
     outputStream.close()
     val bytes = outputStream.toByteArray
-    new CometExecIterator(newIterId, inputs, numOutputCols, bytes, nativeMetrics)
+    new CometExecIterator(
+      newIterId,
+      inputs,
+      numOutputCols,
+      bytes,
+      nativeMetrics,
+      numParts,
+      partitionIdx)
   }
 
   /**
@@ -214,13 +231,18 @@ abstract class CometNativeExec extends CometExec {
         // TODO: support native metrics for all operators.
         val nativeMetrics = CometMetricNode.fromCometPlan(this)
 
-        def createCometExecIter(inputs: Seq[Iterator[ColumnarBatch]]): CometExecIterator = {
+        def createCometExecIter(
+            inputs: Seq[Iterator[ColumnarBatch]],
+            numParts: Int,
+            partitionIndex: Int): CometExecIterator = {
           val it = new CometExecIterator(
             CometExec.newIterId,
             inputs,
             output.length,
             serializedPlanCopy,
-            nativeMetrics)
+            nativeMetrics,
+            numParts,
+            partitionIndex)
 
           setSubqueries(it.id, this)
 
@@ -295,7 +317,7 @@ abstract class CometNativeExec extends CometExec {
           throw new CometRuntimeException(s"No input for CometNativeExec:\n 
$this")
         }
 
-        ZippedPartitionsRDD(sparkContext, inputs.toSeq)(createCometExecIter(_))
+        ZippedPartitionsRDD(sparkContext, inputs.toSeq)(createCometExecIter)
     }
   }
 
diff --git a/spark/src/test/scala/org/apache/comet/CometNativeSuite.scala b/spark/src/test/scala/org/apache/comet/CometNativeSuite.scala
index ef0485df..6ca38e83 100644
--- a/spark/src/test/scala/org/apache/comet/CometNativeSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometNativeSuite.scala
@@ -37,7 +37,9 @@ class CometNativeSuite extends CometTestBase {
           override def next(): ColumnarBatch = throw new NullPointerException()
         }),
         1,
-        limitOp)
+        limitOp,
+        1,
+        0)
       cometIter.next()
       cometIter.close()
       value

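The ZippedPartitionsRDD change above captures getNumPartitions in a field while the RDD is still on the driver, because it cannot be called from compute() on an executor, and compute() takes the partition index from the Partition it receives (s.index) rather than from TaskContext. A minimal sketch of that shape with a hypothetical custom RDD (IndexTaggingRDD is illustrative, not Comet code):

    import scala.reflect.ClassTag

    import org.apache.spark.{Partition, TaskContext}
    import org.apache.spark.rdd.RDD

    // Wraps a parent RDD and tags every element with (partition index, partition count).
    class IndexTaggingRDD[T: ClassTag](parent: RDD[T]) extends RDD[(Int, Int, T)](parent) {

      // Evaluated once on the driver when the RDD is built and shipped with it,
      // since getNumPartitions is not available inside compute() on the executors.
      private val numParts: Int = parent.getNumPartitions

      override protected def getPartitions: Array[Partition] = parent.partitions

      override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int, T)] = {
        // split.index is this RDD's own partition index; it can differ from
        // context.partitionId() when this RDD is not the last one in the stage.
        parent.iterator(split, context).map(x => (split.index, numParts, x))
      }
    }

Used from a driver program, new IndexTaggingRDD(sc.parallelize(1 to 10, 3)).collect() would return triples whose first element is always the index of the partition that produced them.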
