mridulm commented on a change in pull request #34122:
URL: https://github.com/apache/spark/pull/34122#discussion_r790361031



##########
File path: core/src/main/scala/org/apache/spark/Dependency.scala
##########
@@ -135,6 +144,7 @@ class ShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag](
   def shuffleMergeId: Int = _shuffleMergeId
 
   def setMergerLocs(mergerLocs: Seq[BlockManagerId]): Unit = {
+    assert(shuffleMergeEnabled || shuffleMergeAllowed)

Review comment:
       `shuffleMergeAllowed` is a superset - we don't need `shuffleMergeEnabled` here (and actually, it is getting set after `setMergerLocs` in this PR).
   Also, see the comment [above](https://github.com/apache/spark/pull/34122/files#r785389352) on whether we need the variable (`_shuffleMergeEnabled`) at all.
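   If `shuffleMergeAllowed` is indeed the superset condition, the assertion reduces to the following (a sketch of the suggestion, not necessarily the final code):
   ```scala
   def setMergerLocs(mergerLocs: Seq[BlockManagerId]): Unit = {
     // Sketch: shuffleMergeAllowed alone is the precondition here, since
     // shuffleMergeEnabled is only set after setMergerLocs in this PR.
     assert(shuffleMergeAllowed)
     // ... existing body unchanged ...
   }
   ```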

##########
File path: core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
##########
@@ -3686,14 +3686,13 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
     completeNextStageWithFetchFailure(3, 0, shuffleDep)
     scheduler.resubmitFailedStages()
 
-    // Make sure shuffle merge is disabled for the retry
     val stage2 = scheduler.stageIdToStage(2).asInstanceOf[ShuffleMapStage]
-    assert(!stage2.shuffleDep.shuffleMergeEnabled)
+    assert(stage2.shuffleDep.shuffleMergeEnabled)

Review comment:
       Why did this behavior change? I would expect it to be the same before/after this PR.
   
   For a DETERMINATE stage, we should not have merge enabled for the retry if the previous attempt did finalize.

##########
File path: core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
##########
@@ -4147,7 +4146,210 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
     assert(finalizeTask2.delay == 0 && finalizeTask2.registerMergeResults)
   }
 
-    /**
+  test("SPARK-34826: Adaptively fetch shuffle mergers") {
+    initPushBasedShuffleConfs(conf)
+    conf.set(config.SHUFFLE_MERGER_LOCATIONS_MIN_STATIC_THRESHOLD, 2)
+    DAGSchedulerSuite.clearMergerLocs()
+    DAGSchedulerSuite.addMergerLocs(Seq("host1"))
+    val parts = 2
+
+    val shuffleMapRdd = new MyRDD(sc, parts, Nil)
+    val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(parts))
+    val reduceRdd = new MyRDD(sc, parts, List(shuffleDep), tracker = mapOutputTracker)
+
+    // Submit a reduce job that depends which will create a map stage
+    submit(reduceRdd, (0 until parts).toArray)
+
+    runEvent(makeCompletionEvent(
+      taskSets(0).tasks(0), Success, makeMapStatus("hostA", parts),
+      Seq.empty, Array.empty, createFakeTaskInfoWithId(0)))
+
+    val shuffleStage1 = scheduler.stageIdToStage(0).asInstanceOf[ShuffleMapStage]
+    assert(shuffleStage1.shuffleDep.getMergerLocs.isEmpty)
+    assert(mapOutputTracker.getShufflePushMergerLocations(0).isEmpty)
+
+    DAGSchedulerSuite.addMergerLocs(Seq("host2", "host3"))
+
+    // host2 executor added event to trigger registering of shuffle merger locations
+    // as shuffle mergers are tracked separately for test
+    runEvent(ExecutorAdded("host2", "host2"))

Review comment:
       nit: Use 'exec' instead of 'host' as prefix for executor id
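   For example (illustrative only, using the event from this test):
   ```scala
   // 'exec2' as the executor id, 'host2' as the host, per the naming nit above.
   runEvent(ExecutorAdded("exec2", "host2"))
   ```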

##########
File path: core/src/main/scala/org/apache/spark/MapOutputTracker.scala
##########
@@ -1176,6 +1223,9 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr
   // instantiate a serializer. See the followup to SPARK-36705 for more details.
   private lazy val fetchMergeResult = Utils.isPushBasedShuffleEnabled(conf, isDriver = false)
 
+  // Exposed for testing
+  val shufflePushMergerLocations = new ConcurrentHashMap[Int, Seq[BlockManagerId]]().asScala

Review comment:
       Sounds good, this looks fine.

##########
File path: core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
##########
@@ -1369,24 +1370,34 @@ private[spark] class DAGScheduler(
    * locations for block push/merge by getting the historical locations of past executors.
    */
   private def prepareShuffleServicesForShuffleMapStage(stage: ShuffleMapStage): Unit = {
-    assert(stage.shuffleDep.shuffleMergeEnabled && !stage.shuffleDep.shuffleMergeFinalized)
+    assert(stage.shuffleDep.shuffleMergeAllowed && !stage.shuffleDep.shuffleMergeFinalized)
     if (stage.shuffleDep.getMergerLocs.isEmpty) {
-      val mergerLocs = sc.schedulerBackend.getShufflePushMergerLocations(
-        stage.shuffleDep.partitioner.numPartitions, stage.resourceProfileId)
-      if (mergerLocs.nonEmpty) {
-        stage.shuffleDep.setMergerLocs(mergerLocs)
-        logInfo(s"Push-based shuffle enabled for $stage (${stage.name}) with" +
-          s" ${stage.shuffleDep.getMergerLocs.size} merger locations")
-
-        logDebug("List of shuffle push merger locations " +
-          s"${stage.shuffleDep.getMergerLocs.map(_.host).mkString(", ")}")
-      } else {
-        stage.shuffleDep.setShuffleMergeEnabled(false)
-        logInfo(s"Push-based shuffle disabled for $stage (${stage.name})")
-      }
+      getAndSetShufflePushMergerLocations(stage)
+    }
+
+    if (stage.shuffleDep.shuffleMergeEnabled) {
+      logInfo(("Shuffle merge enabled before starting the stage for %s (%s) with %d" +
+        " merger locations").format(stage, stage.name, stage.shuffleDep.getMergerLocs.size))

Review comment:
       Please add the shuffle id and shuffle merge id as well for all log messages - while debugging an issue recently, we found the lack of this detail made it difficult to reason about the state.
   @otterc is planning to file a jira for this anyway, but let us address it for all log messages in this PR as well.
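   For example, something along these lines (illustrative only; the exact wording is up to you):
   ```scala
   logInfo(("Shuffle merge enabled before starting the stage for %s (%s) with %d merger" +
     " locations for shuffle %d with shuffle merge id %d").format(stage, stage.name,
     stage.shuffleDep.getMergerLocs.size, stage.shuffleDep.shuffleId,
     stage.shuffleDep.shuffleMergeId))
   ```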

##########
File path: core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
##########
@@ -4147,7 +4146,210 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
     assert(finalizeTask2.delay == 0 && finalizeTask2.registerMergeResults)
   }
 
-    /**
+  test("SPARK-34826: Adaptively fetch shuffle mergers") {
+    initPushBasedShuffleConfs(conf)
+    conf.set(config.SHUFFLE_MERGER_LOCATIONS_MIN_STATIC_THRESHOLD, 2)
+    DAGSchedulerSuite.clearMergerLocs()
+    DAGSchedulerSuite.addMergerLocs(Seq("host1"))
+    val parts = 2
+
+    val shuffleMapRdd = new MyRDD(sc, parts, Nil)
+    val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(parts))
+    val reduceRdd = new MyRDD(sc, parts, List(shuffleDep), tracker = mapOutputTracker)
+
+    // Submit a reduce job that depends which will create a map stage
+    submit(reduceRdd, (0 until parts).toArray)
+
+    runEvent(makeCompletionEvent(
+      taskSets(0).tasks(0), Success, makeMapStatus("hostA", parts),
+      Seq.empty, Array.empty, createFakeTaskInfoWithId(0)))
+
+    val shuffleStage1 = scheduler.stageIdToStage(0).asInstanceOf[ShuffleMapStage]
+    assert(shuffleStage1.shuffleDep.getMergerLocs.isEmpty)
+    assert(mapOutputTracker.getShufflePushMergerLocations(0).isEmpty)
+
+    DAGSchedulerSuite.addMergerLocs(Seq("host2", "host3"))
+
+    // host2 executor added event to trigger registering of shuffle merger locations
+    // as shuffle mergers are tracked separately for test
+    runEvent(ExecutorAdded("host2", "host2"))
+
+    // Check if new shuffle merger locations are available for push or not
+    assert(mapOutputTracker.getShufflePushMergerLocations(0).size == 2)
+    assert(shuffleStage1.shuffleDep.getMergerLocs.size == 2)
+
+    // Complete remaining tasks in ShuffleMapStage 0
+    runEvent(makeCompletionEvent(taskSets(0).tasks(1), Success,
+      makeMapStatus("host1", parts), Seq.empty, Array.empty, createFakeTaskInfoWithId(1)))
+
+    completeNextResultStageWithSuccess(1, 0)
+    assert(results === Map(0 -> 42, 1 -> 42))
+
+    results.clear()
+    assertDataStructuresEmpty()
+  }
+
+  test("SPARK-34826: Adaptively fetch shuffle mergers with stage retry") {
+    initPushBasedShuffleConfs(conf)
+    conf.set(config.SHUFFLE_MERGER_LOCATIONS_MIN_STATIC_THRESHOLD, 2)
+    DAGSchedulerSuite.clearMergerLocs()
+    DAGSchedulerSuite.addMergerLocs(Seq("host1"))
+    val parts = 2
+
+    val shuffleMapRdd1 = new MyRDD(sc, parts, Nil)
+    val shuffleDep1 = new ShuffleDependency(shuffleMapRdd1, new HashPartitioner(parts))
+    val shuffleMapRdd2 = new MyRDD(sc, parts, Nil)
+    val shuffleDep2 = new ShuffleDependency(shuffleMapRdd2, new HashPartitioner(parts))
+    val reduceRdd = new MyRDD(sc, parts, List(shuffleDep1, shuffleDep2),
+      tracker = mapOutputTracker)
+
+    // Submit a reduce job that depends which will create a map stage
+    submit(reduceRdd, (0 until parts).toArray)
+
+    val taskResults = taskSets(0).tasks.zipWithIndex.map {
+      case (_, idx) =>
+        (Success, makeMapStatus("host" + idx, parts))
+    }.toSeq
+
+    val shuffleStage1 = scheduler.stageIdToStage(0).asInstanceOf[ShuffleMapStage]
+    DAGSchedulerSuite.addMergerLocs(Seq("host2", "host3"))
+    // host2 executor added event to trigger registering of shuffle merger locations
+    // as shuffle mergers are tracked separately for test
+    runEvent(ExecutorAdded("host2", "host2"))
+    // Check if new shuffle merger locations are available for push or not
+    assert(mapOutputTracker.getShufflePushMergerLocations(0).size == 2)
+    assert(shuffleStage1.shuffleDep.getMergerLocs.size == 2)
+    val mergerLocsBeforeRetry = shuffleStage1.shuffleDep.getMergerLocs
+
+    // Clear merger locations to check if new mergers are not getting set for the
+    // retry of determinate stage
+    DAGSchedulerSuite.clearMergerLocs()
+
+    // Remove MapStatus on one of the host before the stage ends to trigger
+    // a scenario where stage 0 needs to be resubmitted upon finishing all tasks.
+    // Merge finalization should be scheduled in this case.
+    for ((result, i) <- taskResults.zipWithIndex) {
+      if (i == taskSets(0).tasks.size - 1) {
+        mapOutputTracker.removeOutputsOnHost("host0")
+      }
+      if (i < taskSets(0).tasks.size) {
+        runEvent(makeCompletionEvent(taskSets(0).tasks(i), result._1, result._2))
+      }
+    }
+    assert(shuffleStage1.shuffleDep.shuffleMergeFinalized)
+
+    DAGSchedulerSuite.addMergerLocs(Seq("host4", "host5"))
+    // host4 executor added event shouldn't reset merger locations given merger locations
+    // are already set
+    runEvent(ExecutorAdded("host4", "host4"))
+
+    // Successfully completing the retry of stage 0.
+    complete(taskSets(2), taskSets(2).tasks.zipWithIndex.map {
+      case (_, idx) =>
+        (Success, makeMapStatus("host" + idx, parts))
+    }.toSeq)
+
+    assert(shuffleStage1.shuffleDep.shuffleMergeId == 0)
+    assert(shuffleStage1.shuffleDep.getMergerLocs.size == 2)
+    assert(shuffleStage1.shuffleDep.shuffleMergeFinalized)
+    val newMergerLocs =
+      scheduler.stageIdToStage(0).asInstanceOf[ShuffleMapStage].shuffleDep.getMergerLocs
+    assert(mergerLocsBeforeRetry.sortBy(_.host) === newMergerLocs.sortBy(_.host))
+    val shuffleStage2 = scheduler.stageIdToStage(1).asInstanceOf[ShuffleMapStage]
+    complete(taskSets(1), taskSets(1).tasks.zipWithIndex.map {
+      case (_, idx) =>
+        (Success, makeMapStatus("host" + idx, parts, 10))
+    }.toSeq)
+    assert(shuffleStage2.shuffleDep.getMergerLocs.size == 2)
+    completeNextResultStageWithSuccess(2, 0)
+    assert(results === Map(0 -> 42, 1 -> 42))
+
+    results.clear()
+    assertDataStructuresEmpty()
+  }
+
+  test("SPARK-34826: Adaptively fetch shuffle mergers with stage retry for indeterminate stage") {
+    initPushBasedShuffleConfs(conf)
+    conf.set(config.SHUFFLE_MERGER_LOCATIONS_MIN_STATIC_THRESHOLD, 2)
+    DAGSchedulerSuite.clearMergerLocs()
+    DAGSchedulerSuite.addMergerLocs(Seq("host1"))
+    val parts = 2
+
+    val shuffleMapRdd1 = new MyRDD(sc, parts, Nil, indeterminate = true)
+    val shuffleDep1 = new ShuffleDependency(shuffleMapRdd1, new HashPartitioner(parts))
+    val shuffleMapRdd2 = new MyRDD(sc, parts, Nil, indeterminate = true)
+    val shuffleDep2 = new ShuffleDependency(shuffleMapRdd2, new HashPartitioner(parts))
+    val reduceRdd = new MyRDD(sc, parts, List(shuffleDep1, shuffleDep2),
+      tracker = mapOutputTracker)
+
+    // Submit a reduce job that depends which will create a map stage
+    submit(reduceRdd, (0 until parts).toArray)
+
+    val taskResults = taskSets(0).tasks.zipWithIndex.map {
+      case (_, idx) =>
+        (Success, makeMapStatus("host" + idx, parts))
+    }.toSeq
+
+    val shuffleStage1 = scheduler.stageIdToStage(0).asInstanceOf[ShuffleMapStage]
+    DAGSchedulerSuite.addMergerLocs(Seq("host2", "host3"))
+    // host2 executor added event to trigger registering of shuffle merger locations
+    // as shuffle mergers are tracked separately for test
+    runEvent(ExecutorAdded("host2", "host2"))
+    // Check if new shuffle merger locations are available for push or not
+    assert(mapOutputTracker.getShufflePushMergerLocations(0).size == 2)
+    assert(shuffleStage1.shuffleDep.getMergerLocs.size == 2)
+    val mergerLocsBeforeRetry = shuffleStage1.shuffleDep.getMergerLocs
+
+    // Clear merger locations to check if new mergers are getting set for the
+    // retry of indeterminate stage
+    DAGSchedulerSuite.clearMergerLocs()
+
+    // Remove MapStatus on one of the host before the stage ends to trigger
+    // a scenario where stage 0 needs to be resubmitted upon finishing all tasks.
+    // Merge finalization should be scheduled in this case.
+    for ((result, i) <- taskResults.zipWithIndex) {
+      if (i == taskSets(0).tasks.size - 1) {
+        mapOutputTracker.removeOutputsOnHost("host0")
+      }
+      if (i < taskSets(0).tasks.size) {
+        runEvent(makeCompletionEvent(taskSets(0).tasks(i), result._1, result._2))
+      }
+    }
+
+    // Indeterminate stage should recompute all partitions, hence
+    // shuffleMergeFinalized should be false here
+    assert(!shuffleStage1.shuffleDep.shuffleMergeFinalized)
+
+    DAGSchedulerSuite.addMergerLocs(Seq("host4", "host5"))
+    // host4 executor added event should reset merger locations given merger locations
+    // are already reset
+    runEvent(ExecutorAdded("host4", "host4"))
+    // Successfully completing the retry of stage 0.
+    complete(taskSets(2), taskSets(2).tasks.zipWithIndex.map {
+      case (_, idx) =>
+        (Success, makeMapStatus("host" + idx, parts))
+    }.toSeq)
+
+    assert(shuffleStage1.shuffleDep.shuffleMergeId == 2)

Review comment:
       Note: We don't really care about the merge id value - except that it should be different from the earlier merge id.
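   A sketch of the weaker check implied here (`mergeIdBeforeRetry` is a hypothetical value captured before the retry, not code from this PR):
   ```scala
   // Capture the merge id before triggering the retry of the indeterminate stage,
   // then only assert that it changed rather than pinning the exact value.
   val mergeIdBeforeRetry = shuffleStage1.shuffleDep.shuffleMergeId
   // ... trigger the retry ...
   assert(shuffleStage1.shuffleDep.shuffleMergeId != mergeIdBeforeRetry)
   ```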

##########
File path: core/src/main/scala/org/apache/spark/MapOutputTracker.scala
##########
@@ -1281,6 +1341,28 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr
     }
   }
 
+  override def getShufflePushMergerLocations(shuffleId: Int): Seq[BlockManagerId] = {
+    shufflePushMergerLocations.getOrElse(shuffleId, getMergerLocations(shuffleId))
+  }
+
+  private def getMergerLocations(shuffleId: Int): Seq[BlockManagerId] = {
+    val mergers = shufflePushMergerLocations.get(shuffleId).orNull
+    if (mergers == null) {
+      fetchingLock.withLock(shuffleId) {
+        val fetchedMergers =
+          askTracker[Seq[BlockManagerId]](GetShufflePushMergerLocations(shuffleId))
+        if (fetchedMergers.nonEmpty) {
+          shufflePushMergerLocations(shuffleId) = fetchedMergers
+          fetchedMergers
+        } else {
+          Seq.empty[BlockManagerId]
+        }
+      }
+    } else {

Review comment:
       This is what I meant [here](https://github.com/apache/spark/pull/34122/#discussion_r785387856).
   
   ```suggestion
       if (mergers == null) {
         fetchingLock.withLock(shuffleId) {
           var fetchedMergers = shufflePushMergerLocations.get(shuffleId).orNull
           
           if (null == fetchedMergers) {
            fetchedMergers = askTracker[Seq[BlockManagerId]](GetShufflePushMergerLocations(shuffleId))
             if (fetchedMergers.nonEmpty) {
               shufflePushMergerLocations(shuffleId) = fetchedMergers
             } else {
               fetchedMergers = Seq.empty[BlockManagerId]
             }
           }
           fetchedMergers
       } else {
   ```
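   For reference, the suggestion boils down to double-checked caching under a lock; a minimal, self-contained sketch of that pattern (hypothetical names, and a single lock object standing in for the per-shuffle `fetchingLock`):
   ```scala
   import java.util.concurrent.ConcurrentHashMap

   // Illustrative sketch only, not the PR's code.
   class MergerLocationCache[V](fetch: Int => Seq[V]) {
     private val cache = new ConcurrentHashMap[Int, Seq[V]]()
     private val lock = new Object // stand-in for fetchingLock.withLock(shuffleId)

     def get(shuffleId: Int): Seq[V] = {
       // Fast path: no locking when the locations are already cached.
       val cached = cache.get(shuffleId)
       if (cached != null) return cached
       lock.synchronized {
         // Re-check under the lock: another thread may have fetched in the meantime.
         var mergers = cache.get(shuffleId)
         if (mergers == null) {
           mergers = fetch(shuffleId) // e.g. the askTracker RPC
           if (mergers.nonEmpty) cache.put(shuffleId, mergers)
         }
         mergers
       }
     }
   }
   ```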
   

##########
File path: core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
##########
@@ -4147,7 +4146,210 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
     assert(finalizeTask2.delay == 0 && finalizeTask2.registerMergeResults)
   }
 
-    /**
+  test("SPARK-34826: Adaptively fetch shuffle mergers") {
+    initPushBasedShuffleConfs(conf)
+    conf.set(config.SHUFFLE_MERGER_LOCATIONS_MIN_STATIC_THRESHOLD, 2)
+    DAGSchedulerSuite.clearMergerLocs()
+    DAGSchedulerSuite.addMergerLocs(Seq("host1"))
+    val parts = 2
+
+    val shuffleMapRdd = new MyRDD(sc, parts, Nil)
+    val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(parts))
+    val reduceRdd = new MyRDD(sc, parts, List(shuffleDep), tracker = mapOutputTracker)
+
+    // Submit a reduce job that depends which will create a map stage
+    submit(reduceRdd, (0 until parts).toArray)
+
+    runEvent(makeCompletionEvent(
+      taskSets(0).tasks(0), Success, makeMapStatus("hostA", parts),
+      Seq.empty, Array.empty, createFakeTaskInfoWithId(0)))
+
+    val shuffleStage1 = scheduler.stageIdToStage(0).asInstanceOf[ShuffleMapStage]
+    assert(shuffleStage1.shuffleDep.getMergerLocs.isEmpty)
+    assert(mapOutputTracker.getShufflePushMergerLocations(0).isEmpty)
+
+    DAGSchedulerSuite.addMergerLocs(Seq("host2", "host3"))
+
+    // host2 executor added event to trigger registering of shuffle merger locations
+    // as shuffle mergers are tracked separately for test
+    runEvent(ExecutorAdded("host2", "host2"))
+
+    // Check if new shuffle merger locations are available for push or not
+    assert(mapOutputTracker.getShufflePushMergerLocations(0).size == 2)
+    assert(shuffleStage1.shuffleDep.getMergerLocs.size == 2)
+
+    // Complete remaining tasks in ShuffleMapStage 0
+    runEvent(makeCompletionEvent(taskSets(0).tasks(1), Success,
+      makeMapStatus("host1", parts), Seq.empty, Array.empty, createFakeTaskInfoWithId(1)))
+
+    completeNextResultStageWithSuccess(1, 0)
+    assert(results === Map(0 -> 42, 1 -> 42))
+
+    results.clear()
+    assertDataStructuresEmpty()
+  }
+
+  test("SPARK-34826: Adaptively fetch shuffle mergers with stage retry") {
+    initPushBasedShuffleConfs(conf)
+    conf.set(config.SHUFFLE_MERGER_LOCATIONS_MIN_STATIC_THRESHOLD, 2)
+    DAGSchedulerSuite.clearMergerLocs()
+    DAGSchedulerSuite.addMergerLocs(Seq("host1"))
+    val parts = 2
+
+    val shuffleMapRdd1 = new MyRDD(sc, parts, Nil)
+    val shuffleDep1 = new ShuffleDependency(shuffleMapRdd1, new HashPartitioner(parts))
+    val shuffleMapRdd2 = new MyRDD(sc, parts, Nil)
+    val shuffleDep2 = new ShuffleDependency(shuffleMapRdd2, new HashPartitioner(parts))
+    val reduceRdd = new MyRDD(sc, parts, List(shuffleDep1, shuffleDep2),
+      tracker = mapOutputTracker)
+
+    // Submit a reduce job that depends which will create a map stage
+    submit(reduceRdd, (0 until parts).toArray)
+
+    val taskResults = taskSets(0).tasks.zipWithIndex.map {
+      case (_, idx) =>
+        (Success, makeMapStatus("host" + idx, parts))
+    }.toSeq
+
+    val shuffleStage1 = scheduler.stageIdToStage(0).asInstanceOf[ShuffleMapStage]
+    DAGSchedulerSuite.addMergerLocs(Seq("host2", "host3"))
+    // host2 executor added event to trigger registering of shuffle merger locations
+    // as shuffle mergers are tracked separately for test
+    runEvent(ExecutorAdded("host2", "host2"))
+    // Check if new shuffle merger locations are available for push or not
+    assert(mapOutputTracker.getShufflePushMergerLocations(0).size == 2)
+    assert(shuffleStage1.shuffleDep.getMergerLocs.size == 2)
+    val mergerLocsBeforeRetry = shuffleStage1.shuffleDep.getMergerLocs
+
+    // Clear merger locations to check if new mergers are not getting set for the
+    // retry of determinate stage
+    DAGSchedulerSuite.clearMergerLocs()
+
+    // Remove MapStatus on one of the host before the stage ends to trigger
+    // a scenario where stage 0 needs to be resubmitted upon finishing all tasks.
+    // Merge finalization should be scheduled in this case.
+    for ((result, i) <- taskResults.zipWithIndex) {
+      if (i == taskSets(0).tasks.size - 1) {
+        mapOutputTracker.removeOutputsOnHost("host0")
+      }
+      if (i < taskSets(0).tasks.size) {
+        runEvent(makeCompletionEvent(taskSets(0).tasks(i), result._1, result._2))
+      }
+    }
+    assert(shuffleStage1.shuffleDep.shuffleMergeFinalized)
+
+    DAGSchedulerSuite.addMergerLocs(Seq("host4", "host5"))
+    // host4 executor added event shouldn't reset merger locations given merger locations
+    // are already set
+    runEvent(ExecutorAdded("host4", "host4"))
+
+    // Successfully completing the retry of stage 0.
+    complete(taskSets(2), taskSets(2).tasks.zipWithIndex.map {
+      case (_, idx) =>
+        (Success, makeMapStatus("host" + idx, parts))
+    }.toSeq)
+
+    assert(shuffleStage1.shuffleDep.shuffleMergeId == 0)
+    assert(shuffleStage1.shuffleDep.getMergerLocs.size == 2)
+    assert(shuffleStage1.shuffleDep.shuffleMergeFinalized)
+    val newMergerLocs =
+      scheduler.stageIdToStage(0).asInstanceOf[ShuffleMapStage].shuffleDep.getMergerLocs
+    assert(mergerLocsBeforeRetry.sortBy(_.host) === newMergerLocs.sortBy(_.host))
+    val shuffleStage2 = scheduler.stageIdToStage(1).asInstanceOf[ShuffleMapStage]
+    complete(taskSets(1), taskSets(1).tasks.zipWithIndex.map {
+      case (_, idx) =>
+        (Success, makeMapStatus("host" + idx, parts, 10))
+    }.toSeq)
+    assert(shuffleStage2.shuffleDep.getMergerLocs.size == 2)
+    completeNextResultStageWithSuccess(2, 0)
+    assert(results === Map(0 -> 42, 1 -> 42))
+
+    results.clear()
+    assertDataStructuresEmpty()
+  }
+
+  test("SPARK-34826: Adaptively fetch shuffle mergers with stage retry for indeterminate stage") {
+    initPushBasedShuffleConfs(conf)
+    conf.set(config.SHUFFLE_MERGER_LOCATIONS_MIN_STATIC_THRESHOLD, 2)
+    DAGSchedulerSuite.clearMergerLocs()
+    DAGSchedulerSuite.addMergerLocs(Seq("host1"))
+    val parts = 2
+
+    val shuffleMapRdd1 = new MyRDD(sc, parts, Nil, indeterminate = true)
+    val shuffleDep1 = new ShuffleDependency(shuffleMapRdd1, new HashPartitioner(parts))
+    val shuffleMapRdd2 = new MyRDD(sc, parts, Nil, indeterminate = true)
+    val shuffleDep2 = new ShuffleDependency(shuffleMapRdd2, new HashPartitioner(parts))
+    val reduceRdd = new MyRDD(sc, parts, List(shuffleDep1, shuffleDep2),
+      tracker = mapOutputTracker)
+
+    // Submit a reduce job that depends which will create a map stage
+    submit(reduceRdd, (0 until parts).toArray)
+
+    val taskResults = taskSets(0).tasks.zipWithIndex.map {
+      case (_, idx) =>
+        (Success, makeMapStatus("host" + idx, parts))
+    }.toSeq
+
+    val shuffleStage1 = scheduler.stageIdToStage(0).asInstanceOf[ShuffleMapStage]
+    DAGSchedulerSuite.addMergerLocs(Seq("host2", "host3"))
+    // host2 executor added event to trigger registering of shuffle merger locations
+    // as shuffle mergers are tracked separately for test
+    runEvent(ExecutorAdded("host2", "host2"))
+    // Check if new shuffle merger locations are available for push or not
+    assert(mapOutputTracker.getShufflePushMergerLocations(0).size == 2)
+    assert(shuffleStage1.shuffleDep.getMergerLocs.size == 2)
+    val mergerLocsBeforeRetry = shuffleStage1.shuffleDep.getMergerLocs
+
+    // Clear merger locations to check if new mergers are getting set for the
+    // retry of indeterminate stage
+    DAGSchedulerSuite.clearMergerLocs()
+
+    // Remove MapStatus on one of the host before the stage ends to trigger
+    // a scenario where stage 0 needs to be resubmitted upon finishing all tasks.
+    // Merge finalization should be scheduled in this case.
+    for ((result, i) <- taskResults.zipWithIndex) {
+      if (i == taskSets(0).tasks.size - 1) {
+        mapOutputTracker.removeOutputsOnHost("host0")
+      }
+      if (i < taskSets(0).tasks.size) {
+        runEvent(makeCompletionEvent(taskSets(0).tasks(i), result._1, result._2))
+      }
+    }
+
+    // Indeterminate stage should recompute all partitions, hence
+    // shuffleMergeFinalized should be false here
+    assert(!shuffleStage1.shuffleDep.shuffleMergeFinalized)
+
+    DAGSchedulerSuite.addMergerLocs(Seq("host4", "host5"))
+    // host4 executor added event should reset merger locations given merger locations
+    // are already reset
+    runEvent(ExecutorAdded("host4", "host4"))

Review comment:
       nit: assert here that `getMergerLocs` is not empty?
   (Move the `assert(shuffleStage1.shuffleDep.getMergerLocs.size == 2)` from below to here.)
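   Roughly (a sketch of the reordering, using the names already in this test):
   ```scala
   runEvent(ExecutorAdded("host4", "host4"))
   // Mergers should be re-registered right away for the indeterminate retry.
   assert(shuffleStage1.shuffleDep.getMergerLocs.nonEmpty)
   ```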

##########
File path: core/src/main/scala/org/apache/spark/Dependency.scala
##########
@@ -144,12 +144,16 @@ class ShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag](
     _shuffleMergedFinalized = true
   }
 
+  def shuffleMergeFinalized: Boolean = {

Review comment:
       > I refactored this in to `shuffleMergeAllowed` and `shuffleMergeEnabled` to have a clear distinction where `shuffleMergeAllowed` controls the static level of knobs like the following:
   > 
   > 1. `Is RDD barrier?`
   > 2. `Can Push shuffle be enabled?`
   > 3. `Disabling push shuffle for retry once the determinate stage attempt is finalized` etc.
   >    and `shuffleMergeEnabled` will be checked only when `shuffleMergeAllowed` is true along with that if sufficient mergers are available then it becomes `true`.
   > 
   > Given all the above, I think we still require two separate methods one for `isShuffleMergeFinalized` only checking for the `shuffleMergedFinalized` value and `numPartitions > 0` and another `isShuffleMergeFinalizedIfEnabled` checking both `shuffleMergeEnabled` and `isShuffleMergeFinalized`.
   
   
   Sounds good.
   
   > 
   > In the existing code, before the `ShuffleMapStage` starts we are checking `if (!shuffleMergeFinalized)` then only we are calling `prepareShuffleServicesForShuffleMapStage` but it is possible the previous stage never had `shuffleMergeEnabled` due to not enough mergers therefore even the retry also not be shuffle merge enabled, ideally this can be shuffle merge enabled if enough mergers are available. But for proceeding with the next stage, we need to check for `isShuffleMergeFinalizedIfEnabled` which is checking `if (shuffleMergeEnabled) isShuffleMergeFinalized else true`
   > 
   > Let me know what you think.
   
   There are two cases here:
   
     * retry for an INDETERMINATE shuffle.
       * We are recomputing all partitions, so we want to enable shuffle merge (based on shuffleMergeAllowed, of course).
         * Note: `newShuffleMergeState` will set `_shuffleMergedFinalized` to `false`.
     * retry for a DETERMINATE shuffle.
       * If the shuffle was finalized, we don't want to enable merge - else we want to enable it.
         * (Since it is a determinate shuffle, recomputation generates identical data as the previous attempt anyway.)
       * The existing codepath, which does this, should continue to work as-is with adaptive merge as well?
   
   Given this, does anything change due to adaptive shuffle merge? I don't see a difference, but please let me know.
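   For clarity, a minimal sketch of the predicate semantics discussed above, with hypothetical field names (not the PR's actual `ShuffleDependency` code):
   ```scala
   // Illustrative only.
   class MergeState(numPartitions: Int) {
     private var _shuffleMergeAllowed = true
     private var _shuffleMergeEnabled = false // set once enough mergers are known
     private var _shuffleMergedFinalized = false

     def shuffleMergeEnabled: Boolean = _shuffleMergeAllowed && _shuffleMergeEnabled

     // "Finalized", regardless of whether merge ended up enabled for this attempt.
     def isShuffleMergeFinalized: Boolean = _shuffleMergedFinalized && numPartitions > 0

     // Used before starting the next stage: if merge was never enabled, treat it as finalized.
     def isShuffleMergeFinalizedIfEnabled: Boolean =
       if (shuffleMergeEnabled) isShuffleMergeFinalized else true

     // On an INDETERMINATE retry all partitions are recomputed, so the merge state is reset.
     def newShuffleMergeState(): Unit = {
       _shuffleMergedFinalized = false
       _shuffleMergeEnabled = false
     }
   }
   ```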

##########
File path: core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
##########
@@ -4147,7 +4146,210 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
     assert(finalizeTask2.delay == 0 && finalizeTask2.registerMergeResults)
   }
 
-    /**
+  test("SPARK-34826: Adaptively fetch shuffle mergers") {
+    initPushBasedShuffleConfs(conf)
+    conf.set(config.SHUFFLE_MERGER_LOCATIONS_MIN_STATIC_THRESHOLD, 2)
+    DAGSchedulerSuite.clearMergerLocs()
+    DAGSchedulerSuite.addMergerLocs(Seq("host1"))
+    val parts = 2
+
+    val shuffleMapRdd = new MyRDD(sc, parts, Nil)
+    val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(parts))
+    val reduceRdd = new MyRDD(sc, parts, List(shuffleDep), tracker = mapOutputTracker)
+
+    // Submit a reduce job that depends which will create a map stage
+    submit(reduceRdd, (0 until parts).toArray)
+
+    runEvent(makeCompletionEvent(
+      taskSets(0).tasks(0), Success, makeMapStatus("hostA", parts),
+      Seq.empty, Array.empty, createFakeTaskInfoWithId(0)))
+
+    val shuffleStage1 = scheduler.stageIdToStage(0).asInstanceOf[ShuffleMapStage]
+    assert(shuffleStage1.shuffleDep.getMergerLocs.isEmpty)
+    assert(mapOutputTracker.getShufflePushMergerLocations(0).isEmpty)
+
+    DAGSchedulerSuite.addMergerLocs(Seq("host2", "host3"))
+
+    // host2 executor added event to trigger registering of shuffle merger locations
+    // as shuffle mergers are tracked separately for test
+    runEvent(ExecutorAdded("host2", "host2"))
+
+    // Check if new shuffle merger locations are available for push or not
+    assert(mapOutputTracker.getShufflePushMergerLocations(0).size == 2)
+    assert(shuffleStage1.shuffleDep.getMergerLocs.size == 2)
+
+    // Complete remaining tasks in ShuffleMapStage 0
+    runEvent(makeCompletionEvent(taskSets(0).tasks(1), Success,
+      makeMapStatus("host1", parts), Seq.empty, Array.empty, createFakeTaskInfoWithId(1)))
+
+    completeNextResultStageWithSuccess(1, 0)
+    assert(results === Map(0 -> 42, 1 -> 42))
+
+    results.clear()
+    assertDataStructuresEmpty()
+  }
+
+  test("SPARK-34826: Adaptively fetch shuffle mergers with stage retry") {
+    initPushBasedShuffleConfs(conf)
+    conf.set(config.SHUFFLE_MERGER_LOCATIONS_MIN_STATIC_THRESHOLD, 2)
+    DAGSchedulerSuite.clearMergerLocs()
+    DAGSchedulerSuite.addMergerLocs(Seq("host1"))
+    val parts = 2
+
+    val shuffleMapRdd1 = new MyRDD(sc, parts, Nil)
+    val shuffleDep1 = new ShuffleDependency(shuffleMapRdd1, new HashPartitioner(parts))
+    val shuffleMapRdd2 = new MyRDD(sc, parts, Nil)
+    val shuffleDep2 = new ShuffleDependency(shuffleMapRdd2, new HashPartitioner(parts))
+    val reduceRdd = new MyRDD(sc, parts, List(shuffleDep1, shuffleDep2),
+      tracker = mapOutputTracker)
+
+    // Submit a reduce job that depends which will create a map stage
+    submit(reduceRdd, (0 until parts).toArray)
+
+    val taskResults = taskSets(0).tasks.zipWithIndex.map {
+      case (_, idx) =>
+        (Success, makeMapStatus("host" + idx, parts))
+    }.toSeq
+
+    val shuffleStage1 = scheduler.stageIdToStage(0).asInstanceOf[ShuffleMapStage]
+    DAGSchedulerSuite.addMergerLocs(Seq("host2", "host3"))
+    // host2 executor added event to trigger registering of shuffle merger locations
+    // as shuffle mergers are tracked separately for test
+    runEvent(ExecutorAdded("host2", "host2"))
+    // Check if new shuffle merger locations are available for push or not
+    assert(mapOutputTracker.getShufflePushMergerLocations(0).size == 2)
+    assert(shuffleStage1.shuffleDep.getMergerLocs.size == 2)
+    val mergerLocsBeforeRetry = shuffleStage1.shuffleDep.getMergerLocs
+
+    // Clear merger locations to check if new mergers are not getting set for the
+    // retry of determinate stage
+    DAGSchedulerSuite.clearMergerLocs()
+
+    // Remove MapStatus on one of the host before the stage ends to trigger
+    // a scenario where stage 0 needs to be resubmitted upon finishing all tasks.
+    // Merge finalization should be scheduled in this case.
+    for ((result, i) <- taskResults.zipWithIndex) {
+      if (i == taskSets(0).tasks.size - 1) {
+        mapOutputTracker.removeOutputsOnHost("host0")
+      }
+      if (i < taskSets(0).tasks.size) {
+        runEvent(makeCompletionEvent(taskSets(0).tasks(i), result._1, result._2))
+      }
+    }
+    assert(shuffleStage1.shuffleDep.shuffleMergeFinalized)
+
+    DAGSchedulerSuite.addMergerLocs(Seq("host4", "host5"))
+    // host4 executor added event shouldn't reset merger locations given merger locations
+    // are already set
+    runEvent(ExecutorAdded("host4", "host4"))
+
+    // Successfully completing the retry of stage 0.
+    complete(taskSets(2), taskSets(2).tasks.zipWithIndex.map {
+      case (_, idx) =>
+        (Success, makeMapStatus("host" + idx, parts))
+    }.toSeq)
+
+    assert(shuffleStage1.shuffleDep.shuffleMergeId == 0)
+    assert(shuffleStage1.shuffleDep.getMergerLocs.size == 2)
+    assert(shuffleStage1.shuffleDep.shuffleMergeFinalized)
+    val newMergerLocs =
+      scheduler.stageIdToStage(0).asInstanceOf[ShuffleMapStage].shuffleDep.getMergerLocs
+    assert(mergerLocsBeforeRetry.sortBy(_.host) === newMergerLocs.sortBy(_.host))
+    val shuffleStage2 = scheduler.stageIdToStage(1).asInstanceOf[ShuffleMapStage]
+    complete(taskSets(1), taskSets(1).tasks.zipWithIndex.map {
+      case (_, idx) =>
+        (Success, makeMapStatus("host" + idx, parts, 10))
+    }.toSeq)
+    assert(shuffleStage2.shuffleDep.getMergerLocs.size == 2)
+    completeNextResultStageWithSuccess(2, 0)
+    assert(results === Map(0 -> 42, 1 -> 42))
+
+    results.clear()
+    assertDataStructuresEmpty()
+  }
+
+  test("SPARK-34826: Adaptively fetch shuffle mergers with stage retry for indeterminate stage") {
+    initPushBasedShuffleConfs(conf)
+    conf.set(config.SHUFFLE_MERGER_LOCATIONS_MIN_STATIC_THRESHOLD, 2)
+    DAGSchedulerSuite.clearMergerLocs()
+    DAGSchedulerSuite.addMergerLocs(Seq("host1"))
+    val parts = 2
+
+    val shuffleMapRdd1 = new MyRDD(sc, parts, Nil, indeterminate = true)
+    val shuffleDep1 = new ShuffleDependency(shuffleMapRdd1, new HashPartitioner(parts))
+    val shuffleMapRdd2 = new MyRDD(sc, parts, Nil, indeterminate = true)
+    val shuffleDep2 = new ShuffleDependency(shuffleMapRdd2, new HashPartitioner(parts))
+    val reduceRdd = new MyRDD(sc, parts, List(shuffleDep1, shuffleDep2),
+      tracker = mapOutputTracker)
+
+    // Submit a reduce job that depends which will create a map stage
+    submit(reduceRdd, (0 until parts).toArray)
+
+    val taskResults = taskSets(0).tasks.zipWithIndex.map {
+      case (_, idx) =>
+        (Success, makeMapStatus("host" + idx, parts))
+    }.toSeq
+
+    val shuffleStage1 = scheduler.stageIdToStage(0).asInstanceOf[ShuffleMapStage]
+    DAGSchedulerSuite.addMergerLocs(Seq("host2", "host3"))
+    // host2 executor added event to trigger registering of shuffle merger locations
+    // as shuffle mergers are tracked separately for test
+    runEvent(ExecutorAdded("host2", "host2"))
+    // Check if new shuffle merger locations are available for push or not
+    assert(mapOutputTracker.getShufflePushMergerLocations(0).size == 2)
+    assert(shuffleStage1.shuffleDep.getMergerLocs.size == 2)
+    val mergerLocsBeforeRetry = shuffleStage1.shuffleDep.getMergerLocs
+
+    // Clear merger locations to check if new mergers are getting set for the
+    // retry of indeterminate stage
+    DAGSchedulerSuite.clearMergerLocs()
+
+    // Remove MapStatus on one of the host before the stage ends to trigger
+    // a scenario where stage 0 needs to be resubmitted upon finishing all tasks.
+    // Merge finalization should be scheduled in this case.
+    for ((result, i) <- taskResults.zipWithIndex) {
+      if (i == taskSets(0).tasks.size - 1) {
+        mapOutputTracker.removeOutputsOnHost("host0")
+      }
+      if (i < taskSets(0).tasks.size) {

Review comment:
       This condition is always true - did you mean something else?
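   If the intent was simply to complete every task, the guard can be dropped (sketch):
   ```scala
   for ((result, i) <- taskResults.zipWithIndex) {
     if (i == taskSets(0).tasks.size - 1) {
       mapOutputTracker.removeOutputsOnHost("host0")
     }
     runEvent(makeCompletionEvent(taskSets(0).tasks(i), result._1, result._2))
   }
   ```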




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


