cloud-fan commented on a change in pull request #26516: [SPARK-29893] improve the local shuffle reader performance by changing the reading task number from 1 to multi.
URL: https://github.com/apache/spark/pull/26516#discussion_r347728403
 
 

 ##########
 File path: sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/LocalShuffledRowRDD.scala
 ##########
 @@ -52,19 +56,52 @@ private final class LocalShuffledRowRDDPartition(
  */
 class LocalShuffledRowRDD(
      var dependency: ShuffleDependency[Int, InternalRow, InternalRow],
-     metrics: Map[String, SQLMetric])
+     metrics: Map[String, SQLMetric],
+     advisoryParallelism : Option[Int] = None)
   extends RDD[InternalRow](dependency.rdd.context, Nil) {
 
   private[this] val numReducers = dependency.partitioner.numPartitions
   private[this] val numMappers = dependency.rdd.partitions.length
 
   override def getDependencies: Seq[Dependency[_]] = List(dependency)
 
-  override def getPartitions: Array[Partition] = {
+  /**
+   * To equally divide n elements into m buckets, basically each bucket should have n/m elements,
+   * for the remaining n%m elements, add one more element to the first n%m buckets each. Returns
+   * a sequence with length numBuckets and each value represents the start index of each bucket.
+   */
+  def equallyDivide(numElements: Int, numBuckets: Int): Seq[Int] = {
+    val elementsPerBucket = numElements / numBuckets
+    val remaining = numElements % numBuckets
+    val splitPoint = (elementsPerBucket + 1) * remaining
+    (0 until remaining).map(_ * (elementsPerBucket + 1)) ++
+      (remaining until numBuckets).map(i => splitPoint + (i - remaining) * elementsPerBucket)
+  }
+
+  private[this] val partitionStartIndices: Array[Int] = {
+    val expectedParallelism = advisoryParallelism.getOrElse(numReducers)
+    // TODO split by data size in the future.
+    equallyDivide(numReducers, math.max(1, expectedParallelism / numMappers)).toArray
+  }
+
+  private[this] val partitionEndIndices: Array[Int] =
+    Array.tabulate[Int](partitionStartIndices.length) { i =>
+      if (i < partitionStartIndices.length -1) {
+        partitionStartIndices(i + 1)
+      } else numReducers
+  }
 
-    Array.tabulate[Partition](numMappers) { i =>
-      new LocalShuffledRowRDDPartition(i)
+  override def getPartitions: Array[Partition] = {
+    assert(partitionStartIndices.length == partitionEndIndices.length)
 
 Review comment:
   we don't need to create `partitionEndIndices`. We can
   ```
   var partitionIndex = 0
   for (mapIndex <- 0 until numMappers) {
    (partitionStartIndices :+ numReducers).sliding(2, 1).foreach { case Seq(start, end) =>
      partitions += new LocalShuffledRowRDDPartition(partitionIndex, mapIndex, start, end)
       partitionIndex += 1
     }
   }
   
   ```
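
   For illustration, a minimal, self-contained Scala sketch of the suggested approach (not the PR's actual code): it reuses the `equallyDivide` logic from the diff and pairs each start index with the next one via `sliding(2, 1)`, appending `numReducers` as the sentinel end bound, so no separate `partitionEndIndices` array is needed. `HypotheticalPartition` is a hypothetical stand-in for `LocalShuffledRowRDDPartition`.
   ```
   object LocalShuffleReaderSketch {

     // Same splitting logic as `equallyDivide` in the diff: the first n % m buckets
     // each get one extra element; the result holds every bucket's start index.
     def equallyDivide(numElements: Int, numBuckets: Int): Seq[Int] = {
       val elementsPerBucket = numElements / numBuckets
       val remaining = numElements % numBuckets
       val splitPoint = (elementsPerBucket + 1) * remaining
       (0 until remaining).map(_ * (elementsPerBucket + 1)) ++
         (remaining until numBuckets).map(i => splitPoint + (i - remaining) * elementsPerBucket)
     }

     // Hypothetical stand-in for LocalShuffledRowRDDPartition, for illustration only.
     case class HypotheticalPartition(index: Int, mapIndex: Int, startReducer: Int, endReducer: Int)

     def main(args: Array[String]): Unit = {
       val numReducers = 10
       val numMappers = 2
       val partitionStartIndices = equallyDivide(numReducers, 3) // Seq(0, 4, 7)

       val partitions = scala.collection.mutable.ArrayBuffer.empty[HypotheticalPartition]
       var partitionIndex = 0
       for (mapIndex <- 0 until numMappers) {
         // Append numReducers as the final end bound, then walk consecutive (start, end) pairs.
         (partitionStartIndices :+ numReducers).sliding(2, 1).foreach { case Seq(start, end) =>
           partitions += HypotheticalPartition(partitionIndex, mapIndex, start, end)
           partitionIndex += 1
         }
       }
       partitions.foreach(println)
     }
   }
   ```
   With `numReducers = 10` split into 3 buckets, `equallyDivide` returns `Seq(0, 4, 7)`, and the sliding pass yields the reducer ranges [0, 4), [4, 7) and [7, 10) for each of the two mappers, i.e. 6 partitions in total.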
