rmcyang commented on code in PR #34500:
URL: https://github.com/apache/spark/pull/34500#discussion_r847769241


##########
core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala:
##########
@@ -1371,22 +1383,84 @@ private[spark] class DAGScheduler(
   private def prepareShuffleServicesForShuffleMapStage(stage: 
ShuffleMapStage): Unit = {
     assert(stage.shuffleDep.shuffleMergeEnabled && 
!stage.shuffleDep.shuffleMergeFinalized)
     if (stage.shuffleDep.getMergerLocs.isEmpty) {
-      val mergerLocs = sc.schedulerBackend.getShufflePushMergerLocations(
-        stage.shuffleDep.partitioner.numPartitions, stage.resourceProfileId)
+      getAndSetShufflePushMergerLocations(stage)
+    }
+  }
+
+  private def getAndSetShufflePushMergerLocations(stage: ShuffleMapStage): 
Seq[BlockManagerId] = {
+    if (stage.shuffleDep.getMergerLocs.isEmpty && 
!stage.shuffleDep.shuffleMergeFinalized) {
+      val coPartitionedSiblingStages = if (reuseMergerLocations) {
+        // Reuse merger locations for sibling stages (for eg: join cases) so 
that
+        // both the RDD's output will be collocated giving better locality. 
Since this
+        // method is invoked only within the event loop thread, it's safe to 
find sibling
+        // stages and set the merger locations accordingly in this way.
+        findCoPartitionedSiblingMapStages(stage)
+      } else {
+        Set.empty[ShuffleMapStage]
+      }
+      val mergerLocs = if (reuseMergerLocations) {
+        coPartitionedSiblingStages.collectFirst({
+          case s if s.shuffleDep.getMergerLocs.nonEmpty => 
s.shuffleDep.getMergerLocs})
+          .getOrElse(sc.schedulerBackend.getShufflePushMergerLocations(
+            stage.shuffleDep.partitioner.numPartitions, 
stage.resourceProfileId))
+      } else {
+        sc.schedulerBackend.getShufflePushMergerLocations(
+          stage.shuffleDep.partitioner.numPartitions, stage.resourceProfileId)
+      }
       if (mergerLocs.nonEmpty) {
         stage.shuffleDep.setMergerLocs(mergerLocs)
-        logInfo(s"Push-based shuffle enabled for $stage (${stage.name}) with" +
-          s" ${stage.shuffleDep.getMergerLocs.size} merger locations")
-
-        logDebug("List of shuffle push merger locations " +
-          s"${stage.shuffleDep.getMergerLocs.map(_.host).mkString(", ")}")
+        if (reuseMergerLocations) {
+          coPartitionedSiblingStages.filter(_.shuffleDep.getMergerLocs.isEmpty)
+            .foreach(_.shuffleDep.setMergerLocs(mergerLocs))
+        }
+        stage.shuffleDep.setShuffleMergeEnabled(true)
       } else {
         stage.shuffleDep.setShuffleMergeEnabled(false)
-        logInfo(s"Push-based shuffle disabled for $stage (${stage.name})")
       }
+      logDebug(s"${stage.shuffleDep.getMergerLocs.map(_.host).mkString(", ")}")
+      mergerLocs
+    } else {
+      stage.shuffleDep.getMergerLocs
     }
   }
 
+  /**
+   * Identify sibling stages for a given ShuffleMapStage. A sibling stage is 
defined as
+   * one stage that shares one or more child stages with the given stage. For 
example,
+   * when we have a join, the reduce stage of the join will be the common 
child stage for
+   * the shuffle map stages shuffling the tables involved in this join. These 
shuffle map
+   * stages are sibling stages to each other. Being able to identify the 
sibling stages would
+   * allow us to set common merger locations for these shuffle map stages so 
that the reduce
+   * stage can have better locality especially for join operations.
+   */
+  private def findCoPartitionedSiblingMapStages(
+      stage: ShuffleMapStage): Set[ShuffleMapStage] = {
+    val numShufflePartitions = stage.shuffleDep.partitioner.numPartitions

Review Comment:
   They are implemented a bit differently. In 
[Stage.scala](https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/scheduler/Stage.scala#L66):
   `val numPartitions = rdd.partitions.length`
   And in 
[Partitioner.scala](https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/Partitioner.scala#L204):
   ```
   def numPartitions: Int = partitions // HashPartitioner
   def numPartitions: Int = rangeBounds.length + 1 // RangePartitioner
   ```
   I think they should return the same value; I will adjust this to use 
`stage.numPartitions`.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to