This is an automated email from the ASF dual-hosted git repository.

viirya pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
     new 5400fd7c  fix: Use RDD partition index (#1112)
5400fd7c is described below

commit 5400fd7c372d2f97ba607457f7d1098e36e2a6e8
Author: Liang-Chi Hsieh <vii...@gmail.com>
AuthorDate: Sun Nov 24 23:13:25 2024 -0800

    fix: Use RDD partition index (#1112)

    * fix: Use RDD partition index

    * fix

    * fix

    * fix
---
 .../scala/org/apache/comet/CometExecIterator.scala | 12 ++++++--
 .../apache/spark/sql/comet/CometExecUtils.scala    |  5 +--
 .../sql/comet/CometTakeOrderedAndProjectExec.scala |  7 +++--
 .../spark/sql/comet/ZippedPartitionsRDD.scala      | 11 +++++--
 .../shuffle/CometShuffleExchangeExec.scala         | 10 ++++--
 .../org/apache/spark/sql/comet/operators.scala     | 36 +++++++++++++++++-----
 .../scala/org/apache/comet/CometNativeSuite.scala  |  4 ++-
 7 files changed, 64 insertions(+), 21 deletions(-)

diff --git a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala
index 8a349bd3..bff3e792 100644
--- a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala
+++ b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala
@@ -39,13 +39,19 @@ import org.apache.comet.vector.NativeUtil
  *   The input iterators producing sequence of batches of Arrow Arrays.
  * @param protobufQueryPlan
  *   The serialized bytes of Spark execution plan.
+ * @param numParts
+ *   The number of partitions.
+ * @param partitionIndex
+ *   The index of the partition.
  */
 class CometExecIterator(
     val id: Long,
     inputs: Seq[Iterator[ColumnarBatch]],
     numOutputCols: Int,
     protobufQueryPlan: Array[Byte],
-    nativeMetrics: CometMetricNode)
+    nativeMetrics: CometMetricNode,
+    numParts: Int,
+    partitionIndex: Int)
     extends Iterator[ColumnarBatch] {

   private val nativeLib = new Native()
@@ -92,11 +98,13 @@ class CometExecIterator(
   }

   def getNextBatch(): Option[ColumnarBatch] = {
+    assert(partitionIndex >= 0 && partitionIndex < numParts)
+
     nativeUtil.getNextBatch(
       numOutputCols,
       (arrayAddrs, schemaAddrs) => {
         val ctx = TaskContext.get()
-        nativeLib.executePlan(ctx.stageId(), ctx.partitionId(), plan, arrayAddrs, schemaAddrs)
+        nativeLib.executePlan(ctx.stageId(), partitionIndex, plan, arrayAddrs, schemaAddrs)
       })
   }

diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometExecUtils.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometExecUtils.scala
index 8cc03856..9698dc98 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/CometExecUtils.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometExecUtils.scala
@@ -51,9 +51,10 @@ object CometExecUtils {
       childPlan: RDD[ColumnarBatch],
       outputAttribute: Seq[Attribute],
       limit: Int): RDD[ColumnarBatch] = {
-    childPlan.mapPartitionsInternal { iter =>
+    val numParts = childPlan.getNumPartitions
+    childPlan.mapPartitionsWithIndexInternal { case (idx, iter) =>
       val limitOp = CometExecUtils.getLimitNativePlan(outputAttribute, limit).get
-      CometExec.getCometIterator(Seq(iter), outputAttribute.length, limitOp)
+      CometExec.getCometIterator(Seq(iter), outputAttribute.length, limitOp, numParts, idx)
     }
   }

diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometTakeOrderedAndProjectExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometTakeOrderedAndProjectExec.scala
index 6220c809..5582f4d6 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/CometTakeOrderedAndProjectExec.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometTakeOrderedAndProjectExec.scala
@@ -77,12 +77,13 @@ case class CometTakeOrderedAndProjectExec(
       val localTopK = if (orderingSatisfies) {
         CometExecUtils.getNativeLimitRDD(childRDD, child.output, limit)
       } else {
-        childRDD.mapPartitionsInternal { iter =>
+        val numParts = childRDD.getNumPartitions
+        childRDD.mapPartitionsWithIndexInternal { case (idx, iter) =>
           val topK = CometExecUtils
             .getTopKNativePlan(child.output, sortOrder, child, limit)
             .get
-          CometExec.getCometIterator(Seq(iter), child.output.length, topK)
+          CometExec.getCometIterator(Seq(iter), child.output.length, topK, numParts, idx)
         }
       }

@@ -102,7 +103,7 @@ case class CometTakeOrderedAndProjectExec(
           val topKAndProjection = CometExecUtils
             .getProjectionNativePlan(projectList, child.output, sortOrder, child, limit)
             .get
-          val it = CometExec.getCometIterator(Seq(iter), output.length, topKAndProjection)
+          val it = CometExec.getCometIterator(Seq(iter), output.length, topKAndProjection, 1, 0)
           setSubqueries(it.id, this)

           Option(TaskContext.get()).foreach { context =>
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/ZippedPartitionsRDD.scala b/spark/src/main/scala/org/apache/spark/sql/comet/ZippedPartitionsRDD.scala
index 6db8c67d..fdf8bf39 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/ZippedPartitionsRDD.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/ZippedPartitionsRDD.scala
@@ -31,16 +31,20 @@ import org.apache.spark.sql.vectorized.ColumnarBatch
  */
 private[spark] class ZippedPartitionsRDD(
     sc: SparkContext,
-    var f: (Seq[Iterator[ColumnarBatch]]) => Iterator[ColumnarBatch],
+    var f: (Seq[Iterator[ColumnarBatch]], Int, Int) => Iterator[ColumnarBatch],
     var zipRdds: Seq[RDD[ColumnarBatch]],
     preservesPartitioning: Boolean = false)
     extends ZippedPartitionsBaseRDD[ColumnarBatch](sc, zipRdds, preservesPartitioning) {

+  // We need to get the number of partitions in `compute` but `getNumPartitions` is not available
+  // on the executors. So we need to capture it here.
+  private val numParts: Int = this.getNumPartitions
+
   override def compute(s: Partition, context: TaskContext): Iterator[ColumnarBatch] = {
     val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions
     val iterators =
       zipRdds.zipWithIndex.map(pair => pair._1.iterator(partitions(pair._2), context))
-    f(iterators)
+    f(iterators, numParts, s.index)
   }

   override def clearDependencies(): Unit = {
@@ -52,7 +56,8 @@ private[spark] class ZippedPartitionsRDD(

 object ZippedPartitionsRDD {
   def apply(sc: SparkContext, rdds: Seq[RDD[ColumnarBatch]])(
-      f: (Seq[Iterator[ColumnarBatch]]) => Iterator[ColumnarBatch]): RDD[ColumnarBatch] =
+      f: (Seq[Iterator[ColumnarBatch]], Int, Int) => Iterator[ColumnarBatch])
+      : RDD[ColumnarBatch] =
     withScope(sc) {
       new ZippedPartitionsRDD(sc, f, rdds)
     }
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala
index 4c3f994f..a7a33c40 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala
@@ -227,13 +227,14 @@ object CometShuffleExchangeExec extends ShimCometShuffleExchangeExec {
       outputPartitioning: Partitioning,
       serializer: Serializer,
       metrics: Map[String, SQLMetric]): ShuffleDependency[Int, ColumnarBatch, ColumnarBatch] = {
+    val numParts = rdd.getNumPartitions
     val dependency = new CometShuffleDependency[Int, ColumnarBatch, ColumnarBatch](
       rdd.map(
         (0, _)
       ), // adding fake partitionId that is always 0 because ShuffleDependency requires it
       serializer = serializer,
       shuffleWriterProcessor =
-        new CometShuffleWriteProcessor(outputPartitioning, outputAttributes, metrics),
+        new CometShuffleWriteProcessor(outputPartitioning, outputAttributes, metrics, numParts),
       shuffleType = CometNativeShuffle,
       partitioner = new Partitioner {
         override def numPartitions: Int = outputPartitioning.numPartitions
@@ -449,7 +450,8 @@ object CometShuffleExchangeExec extends ShimCometShuffleExchangeExec {
 class CometShuffleWriteProcessor(
     outputPartitioning: Partitioning,
     outputAttributes: Seq[Attribute],
-    metrics: Map[String, SQLMetric])
+    metrics: Map[String, SQLMetric],
+    numParts: Int)
     extends ShimCometShuffleWriteProcessor {

   private val OFFSET_LENGTH = 8
@@ -499,7 +501,9 @@ class CometShuffleWriteProcessor(
         Seq(newInputs.asInstanceOf[Iterator[ColumnarBatch]]),
         outputAttributes.length,
         nativePlan,
-        nativeMetrics)
+        nativeMetrics,
+        numParts,
+        context.partitionId())

       while (cometIter.hasNext) {
         cometIter.next()
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
index dd1526d8..77188312 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
@@ -120,20 +120,37 @@ object CometExec {
   def getCometIterator(
       inputs: Seq[Iterator[ColumnarBatch]],
       numOutputCols: Int,
-      nativePlan: Operator): CometExecIterator = {
-    getCometIterator(inputs, numOutputCols, nativePlan, CometMetricNode(Map.empty))
+      nativePlan: Operator,
+      numParts: Int,
+      partitionIdx: Int): CometExecIterator = {
+    getCometIterator(
+      inputs,
+      numOutputCols,
+      nativePlan,
+      CometMetricNode(Map.empty),
+      numParts,
+      partitionIdx)
   }

   def getCometIterator(
       inputs: Seq[Iterator[ColumnarBatch]],
       numOutputCols: Int,
       nativePlan: Operator,
-      nativeMetrics: CometMetricNode): CometExecIterator = {
+      nativeMetrics: CometMetricNode,
+      numParts: Int,
+      partitionIdx: Int): CometExecIterator = {
     val outputStream = new ByteArrayOutputStream()
     nativePlan.writeTo(outputStream)
     outputStream.close()
     val bytes = outputStream.toByteArray
-    new CometExecIterator(newIterId, inputs, numOutputCols, bytes, nativeMetrics)
+    new CometExecIterator(
+      newIterId,
+      inputs,
+      numOutputCols,
+      bytes,
+      nativeMetrics,
+      numParts,
+      partitionIdx)
   }

   /**
@@ -214,13 +231,18 @@ abstract class CometNativeExec extends CometExec {
     // TODO: support native metrics for all operators.
     val nativeMetrics = CometMetricNode.fromCometPlan(this)

-    def createCometExecIter(inputs: Seq[Iterator[ColumnarBatch]]): CometExecIterator = {
+    def createCometExecIter(
+        inputs: Seq[Iterator[ColumnarBatch]],
+        numParts: Int,
+        partitionIndex: Int): CometExecIterator = {
       val it = new CometExecIterator(
         CometExec.newIterId,
         inputs,
         output.length,
         serializedPlanCopy,
-        nativeMetrics)
+        nativeMetrics,
+        numParts,
+        partitionIndex)

       setSubqueries(it.id, this)

@@ -295,7 +317,7 @@ abstract class CometNativeExec extends CometExec {
         throw new CometRuntimeException(s"No input for CometNativeExec:\n $this")
       }

-      ZippedPartitionsRDD(sparkContext, inputs.toSeq)(createCometExecIter(_))
+      ZippedPartitionsRDD(sparkContext, inputs.toSeq)(createCometExecIter)
     }
   }

diff --git a/spark/src/test/scala/org/apache/comet/CometNativeSuite.scala b/spark/src/test/scala/org/apache/comet/CometNativeSuite.scala
index ef0485df..6ca38e83 100644
--- a/spark/src/test/scala/org/apache/comet/CometNativeSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometNativeSuite.scala
@@ -37,7 +37,9 @@ class CometNativeSuite extends CometTestBase {
           override def next(): ColumnarBatch = throw new NullPointerException()
         }),
         1,
-        limitOp)
+        limitOp,
+        1,
+        0)
       cometIter.next()
       cometIter.close()
       value
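A note for readers skimming the patch: the substance of the fix is that the native plan now receives the partition index of the RDD being computed (passed in via mapPartitionsWithIndexInternal or ZippedPartitionsRDD) instead of TaskContext.partitionId(), which is the index of the task's partition within the running stage; the two values are not always equal. The sketch below is a minimal standalone Spark program, not Comet code (the object name and the use of coalesce are made up for illustration), showing one way the two indices can diverge when a single task computes several upstream RDD partitions. It illustrates the general mismatch rather than the exact scenario this commit addresses.

import org.apache.spark.{SparkConf, SparkContext, TaskContext}

// Hypothetical demo, not part of Comet or of this patch.
object PartitionIndexDemo {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setMaster("local[2]").setAppName("rdd-partition-index-demo")
    val sc = new SparkContext(conf)
    try {
      // Four upstream partitions, later coalesced into two tasks. Inside the
      // closure, `idx` is the partition index of the RDD being computed (0..3),
      // while TaskContext.partitionId() is the index of the task's partition in
      // the stage (0 or 1 after coalesce), so the two values can disagree.
      val tagged = sc
        .parallelize(0 until 8, numSlices = 4)
        .mapPartitionsWithIndex { (idx, iter) =>
          val taskPartitionId = TaskContext.get().partitionId()
          iter.map(v => (idx, taskPartitionId, v))
        }

      tagged.coalesce(2).collect().foreach { case (rddIdx, taskId, v) =>
        println(s"rdd partition index=$rddIdx task partition id=$taskId value=$v")
      }
    } finally {
      sc.stop()
    }
  }
}

Run locally, the RDD partition index takes every value from 0 to 3 while the task partition id only takes 0 and 1, which is presumably why the patch threads the RDD partition index down to nativeLib.executePlan instead of relying on ctx.partitionId().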