mridulm commented on a change in pull request #34500:
URL: https://github.com/apache/spark/pull/34500#discussion_r791962190



##########
File path: core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
##########
@@ -2546,6 +2620,14 @@ private[spark] class DAGScheduler(
     }
     listenerBus.post(SparkListenerStageCompleted(stage.latestInfo))
     runningStages -= stage
+    if (errorMessage.isEmpty) {
+      // Add succeeded stage into the succeeded set
+      succeededStages += stage

Review comment:
       Add only if the stage is a `ShuffleMapStage`.

##########
File path: core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
##########
@@ -164,6 +164,9 @@ private[spark] class DAGScheduler(
   // Stages that must be resubmitted due to fetch failures
   private[scheduler] val failedStages = new HashSet[Stage]
 
+  // Stages that have succeeded yet whose child stages aren't done
+  private[scheduler] val succeededStages = new HashSet[Stage]

Review comment:
       Note:
   We have to be careful with this and ensure things are cleaned up.
   Otherwise it will hold on to references and prevent cleanup via GC/cleaner.
   
   One thing I am looking at is the interaction between `succeededStages`
   and `submitMapStage`, and the possibility of leaks here. For example, if
   child stages are not submitted after `submitMapStage` completes for a
   parent (due to a timeout, etc.), will we end up accumulating state and
   eventually OOM? (It is not just the cost of the `Stage` itself, but also
   the objects referenced from the `Stage`, which won't get cleaned up.)

##########
File path: core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
##########
@@ -1371,22 +1383,84 @@ private[spark] class DAGScheduler(
   private def prepareShuffleServicesForShuffleMapStage(stage: 
ShuffleMapStage): Unit = {
     assert(stage.shuffleDep.shuffleMergeEnabled && 
!stage.shuffleDep.shuffleMergeFinalized)
     if (stage.shuffleDep.getMergerLocs.isEmpty) {
-      val mergerLocs = sc.schedulerBackend.getShufflePushMergerLocations(
-        stage.shuffleDep.partitioner.numPartitions, stage.resourceProfileId)
+      getAndSetShufflePushMergerLocations(stage)
+    }
+  }
+
+  private def getAndSetShufflePushMergerLocations(stage: ShuffleMapStage): 
Seq[BlockManagerId] = {
+    if (stage.shuffleDep.getMergerLocs.isEmpty && 
!stage.shuffleDep.shuffleMergeFinalized) {
+      val coPartitionedSiblingStages = if (reuseMergerLocations) {
+        // Reuse merger locations for sibling stages (for eg: join cases) so 
that
+        // both the RDD's output will be collocated giving better locality. 
Since this
+        // method is invoked only within the event loop thread, it's safe to 
find sibling
+        // stages and set the merger locations accordingly in this way.
+        findCoPartitionedSiblingMapStages(stage)
+      } else {
+        Set.empty[ShuffleMapStage]
+      }
+      val mergerLocs = if (reuseMergerLocations) {
+        coPartitionedSiblingStages.collectFirst({
+          case s if s.shuffleDep.getMergerLocs.nonEmpty => 
s.shuffleDep.getMergerLocs})
+          .getOrElse(sc.schedulerBackend.getShufflePushMergerLocations(
+            stage.shuffleDep.partitioner.numPartitions, 
stage.resourceProfileId))

Review comment:
       Formatting is off here, please fix.

##########
File path: core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
##########
@@ -1371,22 +1383,84 @@ private[spark] class DAGScheduler(
   private def prepareShuffleServicesForShuffleMapStage(stage: 
ShuffleMapStage): Unit = {
     assert(stage.shuffleDep.shuffleMergeEnabled && 
!stage.shuffleDep.shuffleMergeFinalized)
     if (stage.shuffleDep.getMergerLocs.isEmpty) {
-      val mergerLocs = sc.schedulerBackend.getShufflePushMergerLocations(
-        stage.shuffleDep.partitioner.numPartitions, stage.resourceProfileId)
+      getAndSetShufflePushMergerLocations(stage)
+    }
+  }
+
+  private def getAndSetShufflePushMergerLocations(stage: ShuffleMapStage): 
Seq[BlockManagerId] = {
+    if (stage.shuffleDep.getMergerLocs.isEmpty && 
!stage.shuffleDep.shuffleMergeFinalized) {
+      val coPartitionedSiblingStages = if (reuseMergerLocations) {
+        // Reuse merger locations for sibling stages (for eg: join cases) so 
that
+        // both the RDD's output will be collocated giving better locality. 
Since this
+        // method is invoked only within the event loop thread, it's safe to 
find sibling
+        // stages and set the merger locations accordingly in this way.
+        findCoPartitionedSiblingMapStages(stage)
+      } else {
+        Set.empty[ShuffleMapStage]
+      }
+      val mergerLocs = if (reuseMergerLocations) {
+        coPartitionedSiblingStages.collectFirst({
+          case s if s.shuffleDep.getMergerLocs.nonEmpty => 
s.shuffleDep.getMergerLocs})
+          .getOrElse(sc.schedulerBackend.getShufflePushMergerLocations(
+            stage.shuffleDep.partitioner.numPartitions, 
stage.resourceProfileId))
+      } else {
+        sc.schedulerBackend.getShufflePushMergerLocations(
+          stage.shuffleDep.partitioner.numPartitions, stage.resourceProfileId)
+      }
       if (mergerLocs.nonEmpty) {
         stage.shuffleDep.setMergerLocs(mergerLocs)
-        logInfo(s"Push-based shuffle enabled for $stage (${stage.name}) with" +
-          s" ${stage.shuffleDep.getMergerLocs.size} merger locations")
-
-        logDebug("List of shuffle push merger locations " +
-          s"${stage.shuffleDep.getMergerLocs.map(_.host).mkString(", ")}")
+        if (reuseMergerLocations) {
+          coPartitionedSiblingStages.filter(_.shuffleDep.getMergerLocs.isEmpty)
+            .foreach(_.shuffleDep.setMergerLocs(mergerLocs))
+        }
+        stage.shuffleDep.setShuffleMergeEnabled(true)
       } else {
         stage.shuffleDep.setShuffleMergeEnabled(false)
-        logInfo(s"Push-based shuffle disabled for $stage (${stage.name})")
       }
+      logDebug(s"${stage.shuffleDep.getMergerLocs.map(_.host).mkString(", ")}")

Review comment:
       Can we reformulate this as something like the following (or something
   clearer)?
   ```
   val mergerLocs = {
     def findMergerLocs() = sc.schedulerBackend.getShufflePushMergerLocations 
...
   
     // limit copartitioning logic to this if condition
     if (reuseMergerLocations) {
       val coPartitionedSiblingStages = findCoPartitionedSiblingMapStages ...
       val siblingMergerLocs = coPartitionedSiblingStages.collectFirst ...
   
       if (siblingMergerLocs.isEmpty) {
         val mergerLocs = findMergerLocs()
         if (mergerLocs.nonEmpty) {
           // set for siblings
           coPartitionedSiblingStages.foreach(_.shuffleDep.setMergerLocs(mergerLocs))
         }
         mergerLocs
       } else {
         siblingMergerLocs
       }
     } else {
       findMergerLocs() 
     }
   }
   
   if (mergerLocs.nonEmpty) {
     stage.shuffleDep.setMergerLocs(mergerLocs)
     ...
   }
   
   ```

##########
File path: core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
##########
@@ -1371,22 +1383,84 @@ private[spark] class DAGScheduler(
   private def prepareShuffleServicesForShuffleMapStage(stage: 
ShuffleMapStage): Unit = {
     assert(stage.shuffleDep.shuffleMergeEnabled && 
!stage.shuffleDep.shuffleMergeFinalized)
     if (stage.shuffleDep.getMergerLocs.isEmpty) {
-      val mergerLocs = sc.schedulerBackend.getShufflePushMergerLocations(
-        stage.shuffleDep.partitioner.numPartitions, stage.resourceProfileId)
+      getAndSetShufflePushMergerLocations(stage)
+    }
+  }
+
+  private def getAndSetShufflePushMergerLocations(stage: ShuffleMapStage): 
Seq[BlockManagerId] = {
+    if (stage.shuffleDep.getMergerLocs.isEmpty && 
!stage.shuffleDep.shuffleMergeFinalized) {
+      val coPartitionedSiblingStages = if (reuseMergerLocations) {
+        // Reuse merger locations for sibling stages (for eg: join cases) so 
that
+        // both the RDD's output will be collocated giving better locality. 
Since this
+        // method is invoked only within the event loop thread, it's safe to 
find sibling
+        // stages and set the merger locations accordingly in this way.
+        findCoPartitionedSiblingMapStages(stage)
+      } else {
+        Set.empty[ShuffleMapStage]
+      }
+      val mergerLocs = if (reuseMergerLocations) {
+        coPartitionedSiblingStages.collectFirst({
+          case s if s.shuffleDep.getMergerLocs.nonEmpty => 
s.shuffleDep.getMergerLocs})
+          .getOrElse(sc.schedulerBackend.getShufflePushMergerLocations(
+            stage.shuffleDep.partitioner.numPartitions, 
stage.resourceProfileId))
+      } else {
+        sc.schedulerBackend.getShufflePushMergerLocations(
+          stage.shuffleDep.partitioner.numPartitions, stage.resourceProfileId)
+      }
       if (mergerLocs.nonEmpty) {
         stage.shuffleDep.setMergerLocs(mergerLocs)
-        logInfo(s"Push-based shuffle enabled for $stage (${stage.name}) with" +
-          s" ${stage.shuffleDep.getMergerLocs.size} merger locations")
-
-        logDebug("List of shuffle push merger locations " +
-          s"${stage.shuffleDep.getMergerLocs.map(_.host).mkString(", ")}")
+        if (reuseMergerLocations) {
+          coPartitionedSiblingStages.filter(_.shuffleDep.getMergerLocs.isEmpty)
+            .foreach(_.shuffleDep.setMergerLocs(mergerLocs))
+        }
+        stage.shuffleDep.setShuffleMergeEnabled(true)
       } else {
         stage.shuffleDep.setShuffleMergeEnabled(false)
-        logInfo(s"Push-based shuffle disabled for $stage (${stage.name})")
       }
+      logDebug(s"${stage.shuffleDep.getMergerLocs.map(_.host).mkString(", ")}")
+      mergerLocs
+    } else {
+      stage.shuffleDep.getMergerLocs
     }
   }
 
+  /**
+   * Identify sibling stages for a given ShuffleMapStage. A sibling stage is 
defined as
+   * one stage that shares one or more child stages with the given stage. For 
example,
+   * when we have a join, the reduce stage of the join will be the common 
child stage for
+   * the shuffle map stages shuffling the tables involved in this join. These 
shuffle map
+   * stages are sibling stages to each other. Being able to identify the 
sibling stages would
+   * allow us to set common merger locations for these shuffle map stages so 
that the reduce
+   * stage can have better locality especially for join operations.
+   */
+  private def findCoPartitionedSiblingMapStages(
+      stage: ShuffleMapStage): Set[ShuffleMapStage] = {
+    val numShufflePartitions = stage.shuffleDep.partitioner.numPartitions
+    val siblingStages = new HashSet[ShuffleMapStage]
+    siblingStages += stage
+    var prevSize = 0
+    val allStagesSoFar = waitingStages ++ runningStages ++ failedStages ++ 
succeededStages
+    do {
+      prevSize = siblingStages.size
+      siblingStages ++= 
allStagesSoFar.filter(_.parents.intersect(siblingStages.toSeq).nonEmpty)
+        .flatMap(_.parents).filter{ parentStage =>
+          parentStage.isInstanceOf[ShuffleMapStage] &&
+            parentStage.asInstanceOf[ShuffleMapStage]
+              .shuffleDep.partitioner.numPartitions == numShufflePartitions

Review comment:
       Do we want to compare the partitioner here instead of the number of partitions?

##########
File path: core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
##########
@@ -1371,22 +1383,84 @@ private[spark] class DAGScheduler(
   private def prepareShuffleServicesForShuffleMapStage(stage: 
ShuffleMapStage): Unit = {
     assert(stage.shuffleDep.shuffleMergeEnabled && 
!stage.shuffleDep.shuffleMergeFinalized)
     if (stage.shuffleDep.getMergerLocs.isEmpty) {
-      val mergerLocs = sc.schedulerBackend.getShufflePushMergerLocations(
-        stage.shuffleDep.partitioner.numPartitions, stage.resourceProfileId)
+      getAndSetShufflePushMergerLocations(stage)
+    }
+  }
+
+  private def getAndSetShufflePushMergerLocations(stage: ShuffleMapStage): 
Seq[BlockManagerId] = {
+    if (stage.shuffleDep.getMergerLocs.isEmpty && 
!stage.shuffleDep.shuffleMergeFinalized) {

Review comment:
       This will interact with #34122 - you might want to mark this PR as WIP
   and wait for that one to be merged first.
   In particular, the interaction is around state changes and the adaptive
   addition of mergers in `handleExecutorAdded`.
   After #34122 is merged, we should refactor the handling of
   `coPartitionedSiblingStages`, etc., and reuse it from both places.

##########
File path: core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
##########
@@ -1371,22 +1383,84 @@ private[spark] class DAGScheduler(
   private def prepareShuffleServicesForShuffleMapStage(stage: 
ShuffleMapStage): Unit = {
     assert(stage.shuffleDep.shuffleMergeEnabled && 
!stage.shuffleDep.shuffleMergeFinalized)
     if (stage.shuffleDep.getMergerLocs.isEmpty) {
-      val mergerLocs = sc.schedulerBackend.getShufflePushMergerLocations(
-        stage.shuffleDep.partitioner.numPartitions, stage.resourceProfileId)
+      getAndSetShufflePushMergerLocations(stage)
+    }
+  }
+
+  private def getAndSetShufflePushMergerLocations(stage: ShuffleMapStage): 
Seq[BlockManagerId] = {
+    if (stage.shuffleDep.getMergerLocs.isEmpty && 
!stage.shuffleDep.shuffleMergeFinalized) {
+      val coPartitionedSiblingStages = if (reuseMergerLocations) {
+        // Reuse merger locations for sibling stages (for eg: join cases) so 
that
+        // both the RDD's output will be collocated giving better locality. 
Since this
+        // method is invoked only within the event loop thread, it's safe to 
find sibling
+        // stages and set the merger locations accordingly in this way.
+        findCoPartitionedSiblingMapStages(stage)
+      } else {
+        Set.empty[ShuffleMapStage]
+      }
+      val mergerLocs = if (reuseMergerLocations) {
+        coPartitionedSiblingStages.collectFirst({
+          case s if s.shuffleDep.getMergerLocs.nonEmpty => 
s.shuffleDep.getMergerLocs})
+          .getOrElse(sc.schedulerBackend.getShufflePushMergerLocations(
+            stage.shuffleDep.partitioner.numPartitions, 
stage.resourceProfileId))
+      } else {
+        sc.schedulerBackend.getShufflePushMergerLocations(
+          stage.shuffleDep.partitioner.numPartitions, stage.resourceProfileId)
+      }
       if (mergerLocs.nonEmpty) {
         stage.shuffleDep.setMergerLocs(mergerLocs)
-        logInfo(s"Push-based shuffle enabled for $stage (${stage.name}) with" +
-          s" ${stage.shuffleDep.getMergerLocs.size} merger locations")
-
-        logDebug("List of shuffle push merger locations " +
-          s"${stage.shuffleDep.getMergerLocs.map(_.host).mkString(", ")}")
+        if (reuseMergerLocations) {
+          coPartitionedSiblingStages.filter(_.shuffleDep.getMergerLocs.isEmpty)
+            .foreach(_.shuffleDep.setMergerLocs(mergerLocs))
+        }
+        stage.shuffleDep.setShuffleMergeEnabled(true)
       } else {
         stage.shuffleDep.setShuffleMergeEnabled(false)
-        logInfo(s"Push-based shuffle disabled for $stage (${stage.name})")
       }
+      logDebug(s"${stage.shuffleDep.getMergerLocs.map(_.host).mkString(", ")}")
+      mergerLocs
+    } else {
+      stage.shuffleDep.getMergerLocs
     }
   }
 
+  /**
+   * Identify sibling stages for a given ShuffleMapStage. A sibling stage is 
defined as
+   * one stage that shares one or more child stages with the given stage. For 
example,
+   * when we have a join, the reduce stage of the join will be the common 
child stage for
+   * the shuffle map stages shuffling the tables involved in this join. These 
shuffle map
+   * stages are sibling stages to each other. Being able to identify the 
sibling stages would
+   * allow us to set common merger locations for these shuffle map stages so 
that the reduce
+   * stage can have better locality especially for join operations.
+   */
+  private def findCoPartitionedSiblingMapStages(
+      stage: ShuffleMapStage): Set[ShuffleMapStage] = {
+    val numShufflePartitions = stage.shuffleDep.partitioner.numPartitions
+    val siblingStages = new HashSet[ShuffleMapStage]
+    siblingStages += stage
+    var prevSize = 0
+    val allStagesSoFar = waitingStages ++ runningStages ++ failedStages ++ 
succeededStages

Review comment:
       Do a `.toSet` here and avoid the `.toSeq` in intersect below

##########
File path: core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
##########
@@ -1371,22 +1383,84 @@ private[spark] class DAGScheduler(
   private def prepareShuffleServicesForShuffleMapStage(stage: 
ShuffleMapStage): Unit = {
     assert(stage.shuffleDep.shuffleMergeEnabled && 
!stage.shuffleDep.shuffleMergeFinalized)
     if (stage.shuffleDep.getMergerLocs.isEmpty) {
-      val mergerLocs = sc.schedulerBackend.getShufflePushMergerLocations(
-        stage.shuffleDep.partitioner.numPartitions, stage.resourceProfileId)
+      getAndSetShufflePushMergerLocations(stage)
+    }
+  }
+
+  private def getAndSetShufflePushMergerLocations(stage: ShuffleMapStage): 
Seq[BlockManagerId] = {
+    if (stage.shuffleDep.getMergerLocs.isEmpty && 
!stage.shuffleDep.shuffleMergeFinalized) {
+      val coPartitionedSiblingStages = if (reuseMergerLocations) {
+        // Reuse merger locations for sibling stages (for eg: join cases) so 
that
+        // both the RDD's output will be collocated giving better locality. 
Since this
+        // method is invoked only within the event loop thread, it's safe to 
find sibling
+        // stages and set the merger locations accordingly in this way.
+        findCoPartitionedSiblingMapStages(stage)
+      } else {
+        Set.empty[ShuffleMapStage]
+      }
+      val mergerLocs = if (reuseMergerLocations) {
+        coPartitionedSiblingStages.collectFirst({
+          case s if s.shuffleDep.getMergerLocs.nonEmpty => 
s.shuffleDep.getMergerLocs})
+          .getOrElse(sc.schedulerBackend.getShufflePushMergerLocations(
+            stage.shuffleDep.partitioner.numPartitions, 
stage.resourceProfileId))
+      } else {
+        sc.schedulerBackend.getShufflePushMergerLocations(
+          stage.shuffleDep.partitioner.numPartitions, stage.resourceProfileId)
+      }
       if (mergerLocs.nonEmpty) {
         stage.shuffleDep.setMergerLocs(mergerLocs)
-        logInfo(s"Push-based shuffle enabled for $stage (${stage.name}) with" +
-          s" ${stage.shuffleDep.getMergerLocs.size} merger locations")
-
-        logDebug("List of shuffle push merger locations " +
-          s"${stage.shuffleDep.getMergerLocs.map(_.host).mkString(", ")}")
+        if (reuseMergerLocations) {
+          coPartitionedSiblingStages.filter(_.shuffleDep.getMergerLocs.isEmpty)
+            .foreach(_.shuffleDep.setMergerLocs(mergerLocs))
+        }
+        stage.shuffleDep.setShuffleMergeEnabled(true)
       } else {
         stage.shuffleDep.setShuffleMergeEnabled(false)
-        logInfo(s"Push-based shuffle disabled for $stage (${stage.name})")
       }
+      logDebug(s"${stage.shuffleDep.getMergerLocs.map(_.host).mkString(", ")}")
+      mergerLocs
+    } else {
+      stage.shuffleDep.getMergerLocs
     }
   }
 
+  /**
+   * Identify sibling stages for a given ShuffleMapStage. A sibling stage is 
defined as
+   * one stage that shares one or more child stages with the given stage. For 
example,
+   * when we have a join, the reduce stage of the join will be the common 
child stage for
+   * the shuffle map stages shuffling the tables involved in this join. These 
shuffle map
+   * stages are sibling stages to each other. Being able to identify the 
sibling stages would
+   * allow us to set common merger locations for these shuffle map stages so 
that the reduce
+   * stage can have better locality especially for join operations.
+   */
+  private def findCoPartitionedSiblingMapStages(
+      stage: ShuffleMapStage): Set[ShuffleMapStage] = {
+    val numShufflePartitions = stage.shuffleDep.partitioner.numPartitions

Review comment:
       Any reason to use `s.shuffleDep.partitioner.numPartitions` instead of 
`s.numPartitions` ? (here and below)

##########
File path: core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
##########
@@ -2546,6 +2620,14 @@ private[spark] class DAGScheduler(
     }
     listenerBus.post(SparkListenerStageCompleted(stage.latestInfo))
     runningStages -= stage
+    if (errorMessage.isEmpty) {
+      // Add succeeded stage into the succeeded set
+      succeededStages += stage
+      // Remove all parent stages with no pending child stages from the 
succeeded set
+      stage.parents.filter{ parentStage =>

Review comment:
       Do we need to do this for every parent? Or just the immediate parent?

##########
File path: core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
##########
@@ -1371,22 +1383,84 @@ private[spark] class DAGScheduler(
   private def prepareShuffleServicesForShuffleMapStage(stage: 
ShuffleMapStage): Unit = {
     assert(stage.shuffleDep.shuffleMergeEnabled && 
!stage.shuffleDep.shuffleMergeFinalized)
     if (stage.shuffleDep.getMergerLocs.isEmpty) {
-      val mergerLocs = sc.schedulerBackend.getShufflePushMergerLocations(
-        stage.shuffleDep.partitioner.numPartitions, stage.resourceProfileId)
+      getAndSetShufflePushMergerLocations(stage)
+    }
+  }
+
+  private def getAndSetShufflePushMergerLocations(stage: ShuffleMapStage): 
Seq[BlockManagerId] = {
+    if (stage.shuffleDep.getMergerLocs.isEmpty && 
!stage.shuffleDep.shuffleMergeFinalized) {
+      val coPartitionedSiblingStages = if (reuseMergerLocations) {
+        // Reuse merger locations for sibling stages (for eg: join cases) so 
that
+        // both the RDD's output will be collocated giving better locality. 
Since this
+        // method is invoked only within the event loop thread, it's safe to 
find sibling
+        // stages and set the merger locations accordingly in this way.
+        findCoPartitionedSiblingMapStages(stage)
+      } else {
+        Set.empty[ShuffleMapStage]
+      }
+      val mergerLocs = if (reuseMergerLocations) {
+        coPartitionedSiblingStages.collectFirst({
+          case s if s.shuffleDep.getMergerLocs.nonEmpty => 
s.shuffleDep.getMergerLocs})
+          .getOrElse(sc.schedulerBackend.getShufflePushMergerLocations(
+            stage.shuffleDep.partitioner.numPartitions, 
stage.resourceProfileId))
+      } else {
+        sc.schedulerBackend.getShufflePushMergerLocations(
+          stage.shuffleDep.partitioner.numPartitions, stage.resourceProfileId)
+      }
       if (mergerLocs.nonEmpty) {
         stage.shuffleDep.setMergerLocs(mergerLocs)
-        logInfo(s"Push-based shuffle enabled for $stage (${stage.name}) with" +
-          s" ${stage.shuffleDep.getMergerLocs.size} merger locations")
-
-        logDebug("List of shuffle push merger locations " +
-          s"${stage.shuffleDep.getMergerLocs.map(_.host).mkString(", ")}")
+        if (reuseMergerLocations) {
+          coPartitionedSiblingStages.filter(_.shuffleDep.getMergerLocs.isEmpty)
+            .foreach(_.shuffleDep.setMergerLocs(mergerLocs))
+        }
+        stage.shuffleDep.setShuffleMergeEnabled(true)
       } else {
         stage.shuffleDep.setShuffleMergeEnabled(false)
-        logInfo(s"Push-based shuffle disabled for $stage (${stage.name})")
       }
+      logDebug(s"${stage.shuffleDep.getMergerLocs.map(_.host).mkString(", ")}")
+      mergerLocs
+    } else {
+      stage.shuffleDep.getMergerLocs
     }
   }
 
+  /**
+   * Identify sibling stages for a given ShuffleMapStage. A sibling stage is 
defined as
+   * one stage that shares one or more child stages with the given stage. For 
example,
+   * when we have a join, the reduce stage of the join will be the common 
child stage for
+   * the shuffle map stages shuffling the tables involved in this join. These 
shuffle map
+   * stages are sibling stages to each other. Being able to identify the 
sibling stages would
+   * allow us to set common merger locations for these shuffle map stages so 
that the reduce
+   * stage can have better locality especially for join operations.
+   */
+  private def findCoPartitionedSiblingMapStages(
+      stage: ShuffleMapStage): Set[ShuffleMapStage] = {
+    val numShufflePartitions = stage.shuffleDep.partitioner.numPartitions
+    val siblingStages = new HashSet[ShuffleMapStage]
+    siblingStages += stage
+    var prevSize = 0
+    val allStagesSoFar = waitingStages ++ runningStages ++ failedStages ++ 
succeededStages
+    do {
+      prevSize = siblingStages.size
+      siblingStages ++= 
allStagesSoFar.filter(_.parents.intersect(siblingStages.toSeq).nonEmpty)
+        .flatMap(_.parents).filter{ parentStage =>
+          parentStage.isInstanceOf[ShuffleMapStage] &&

Review comment:
       Won't this always be true?




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]



---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to