mridulm commented on a change in pull request #34122:
URL: https://github.com/apache/spark/pull/34122#discussion_r785387151



##########
File path: core/src/main/scala/org/apache/spark/Dependency.scala
##########
@@ -145,19 +155,30 @@ class ShuffleDependency[K: ClassTag, V: ClassTag, C: 
ClassTag](
   }
 
   /**
-   * Returns true if push-based shuffle is disabled for this stage or empty 
RDD,
-   * or if the shuffle merge for this stage is finalized, i.e. the shuffle 
merge
-   * results for all partitions are available.
+   * Returns true if the RDD is an empty RDD or if the shuffle merge for this 
shuffle is
+   * finalized.
    */
-  def shuffleMergeFinalized: Boolean = {
+  def isShuffleMergeFinalized: Boolean = {
     // Empty RDD won't be computed therefore shuffle merge finalized should be 
true by default.
-    if (shuffleMergeEnabled && numPartitions > 0) {
+    if (numPartitions > 0) {

Review comment:
       Since `canShuffleMergeBeEnabled` already checks for `numPartitions > 0`, do we still need this check here?

##########
File path: core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
##########
@@ -910,4 +910,32 @@ class MapOutputTrackerSuite extends SparkFunSuite with 
LocalSparkContext {
     rpcEnv.shutdown()
     slaveRpcEnv.shutdown()
   }
+
+  test("SPARK-34826: Adaptive shuffle mergers") {
+    val newConf = new SparkConf
+    newConf.set("spark.shuffle.push.based.enabled", "true")
+    newConf.set("spark.shuffle.service.enabled", "true")
+
+    // needs TorrentBroadcast so need a SparkContext
+    withSpark(new SparkContext("local", "MapOutputTrackerSuite", newConf)) { 
sc =>
+      val masterTracker = 
sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]
+      val rpcEnv = sc.env.rpcEnv
+      val masterEndpoint = new MapOutputTrackerMasterEndpoint(rpcEnv, 
masterTracker, newConf)
+      rpcEnv.stop(masterTracker.trackerEndpoint)
+      rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, masterEndpoint)
+
+      val slaveTracker = new MapOutputTrackerWorker(newConf)
+      slaveTracker.trackerEndpoint =
+        rpcEnv.setupEndpointRef(rpcEnv.address, MapOutputTracker.ENDPOINT_NAME)
+
+      masterTracker.registerShuffle(20, 100, 100)
+      slaveTracker.updateEpoch(masterTracker.getEpoch)
+      val mergerLocs = (1 to 10).map(x => BlockManagerId(s"exec-$x", 
s"host-$x", 7337))
+      masterTracker.registerShufflePushMergerLocations(20, mergerLocs)
+
+      assert(slaveTracker.getShufflePushMergerLocations(20).size == 10)
+      slaveTracker.unregisterShuffle(20)
+      assert(slaveTracker.shufflePushMergerLocations.isEmpty)

Review comment:
       Note: We would need to enhance this test with the additional case 
mentioned above for `shufflePushMergerLocations`.

##########
File path: 
core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
##########
@@ -4147,7 +4146,128 @@ class DAGSchedulerSuite extends SparkFunSuite with 
TempLocalSparkContext with Ti
     assert(finalizeTask2.delay == 0 && finalizeTask2.registerMergeResults)
   }
 
-    /**
+  test("SPARK-34826: Adaptively fetch shuffle mergers") {
+    initPushBasedShuffleConfs(conf)
+    conf.set(config.SHUFFLE_MERGER_LOCATIONS_MIN_STATIC_THRESHOLD, 6)
+    DAGSchedulerSuite.clearMergerLocs
+    DAGSchedulerSuite.addMergerLocs(Seq("host1", "host2", "host3", "host4", 
"host5"))
+    val parts = 7
+
+    val shuffleMapRdd = new MyRDD(sc, parts, Nil)
+    val shuffleDep = new ShuffleDependency(shuffleMapRdd, new 
HashPartitioner(parts))
+    val reduceRdd = new MyRDD(sc, parts, List(shuffleDep), tracker = 
mapOutputTracker)
+
+    // Submit a reduce job that depends which will create a map stage
+    submit(reduceRdd, (0 until parts).toArray)
+
+    runEvent(makeCompletionEvent(
+      taskSets(0).tasks(0), Success, makeMapStatus("hostA", parts),
+      Seq.empty, Array.empty, createFakeTaskInfoWithId(0)))
+
+    val shuffleStage1 = 
scheduler.stageIdToStage(0).asInstanceOf[ShuffleMapStage]
+    assert(shuffleStage1.shuffleDep.getMergerLocs.isEmpty)
+    assert(mapOutputTracker.getShufflePushMergerLocations(0).isEmpty)
+
+    DAGSchedulerSuite.addMergerLocs(Seq("host6", "host7", "host8"))
+
+    runEvent(makeCompletionEvent(
+      taskSets(0).tasks(1), Success, makeMapStatus("hostA", parts),
+      Seq.empty, Array.empty, createFakeTaskInfoWithId(1)))
+
+    // Dummy executor added event to trigger registering of shuffle merger 
locations
+    // as shuffle mergers are tracked separately for test
+    runEvent(ExecutorAdded("dummy", "dummy"))

Review comment:
       Let us use a valid host here (for example `host6`, from the new set of hosts), in case this code evolves in the future.

##########
File path: core/src/main/scala/org/apache/spark/MapOutputTracker.scala
##########
@@ -1281,6 +1331,22 @@ private[spark] class MapOutputTrackerWorker(conf: 
SparkConf) extends MapOutputTr
     }
   }
 
+  override def getShufflePushMergerLocations(shuffleId: Int): 
Seq[BlockManagerId] = {
+    shufflePushMergerLocations.getOrElse(shuffleId, 
getMergerLocations(shuffleId))
+  }
+
+  private def getMergerLocations(shuffleId: Int): Seq[BlockManagerId] = {
+    fetchingLock.withLock(shuffleId) {
+      val mergers = 
askTracker[Seq[BlockManagerId]](GetShufflePushMergerLocations(shuffleId))

Review comment:
       Check whether `shufflePushMergerLocations(shuffleId)` is nonEmpty before fetching it again, in case it was updated concurrently before we acquired the lock for `shuffleId`.
   (See `getStatuses` or other uses of `fetchingLock` for how this is done.)

##########
File path: core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
##########
@@ -910,4 +910,32 @@ class MapOutputTrackerSuite extends SparkFunSuite with 
LocalSparkContext {
     rpcEnv.shutdown()
     slaveRpcEnv.shutdown()
   }
+
+  test("SPARK-34826: Adaptive shuffle mergers") {
+    val newConf = new SparkConf
+    newConf.set("spark.shuffle.push.based.enabled", "true")
+    newConf.set("spark.shuffle.service.enabled", "true")
+
+    // needs TorrentBroadcast so need a SparkContext
+    withSpark(new SparkContext("local", "MapOutputTrackerSuite", newConf)) { 
sc =>
+      val masterTracker = 
sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]
+      val rpcEnv = sc.env.rpcEnv
+      val masterEndpoint = new MapOutputTrackerMasterEndpoint(rpcEnv, 
masterTracker, newConf)
+      rpcEnv.stop(masterTracker.trackerEndpoint)
+      rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, masterEndpoint)
+
+      val slaveTracker = new MapOutputTrackerWorker(newConf)

Review comment:
       Rename this to `workerTracker`.
   
   It looks like some of the other push-based shuffle PRs have reintroduced the term `slave` - we should rename those to `worker` as well.
   Can you please file a JIRA under the improvement tasks for this? Thanks.

##########
File path: 
core/src/main/scala/org/apache/spark/shuffle/ShuffleWriteProcessor.scala
##########
@@ -59,13 +59,24 @@ private[spark] class ShuffleWriteProcessor extends 
Serializable with Logging {
         rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: 
Product2[Any, Any]]])
       val mapStatus = writer.stop(success = true)
       if (mapStatus.isDefined) {
+        // Check if sufficient shuffle mergers are available now for the 
ShuffleMapTask to push
+        if (dep.shuffleMergeAllowed && dep.getMergerLocs.isEmpty && 
!dep.isShuffleMergeFinalized) {

Review comment:
       We cannot have both `dep.getMergerLocs.isEmpty == true` and 
`dep.isShuffleMergeFinalized == false`.
   Drop `dep.isShuffleMergeFinalized`  here ?

##########
File path: core/src/main/scala/org/apache/spark/MapOutputTracker.scala
##########
@@ -1364,6 +1430,7 @@ private[spark] class MapOutputTrackerWorker(conf: 
SparkConf) extends MapOutputTr
   def unregisterShuffle(shuffleId: Int): Unit = {
     mapStatuses.remove(shuffleId)
     mergeStatuses.remove(shuffleId)
+    shufflePushMergerLocations.remove(shuffleId)

Review comment:
       Review note: I was trying to see whether `shufflePushMergerLocations.clear` in `updateEpoch` is sufficient to handle the shuffle attempt id issue (see the comment on the `shufflePushMergerLocations` declaration), but I am not sure that it is enough.

##########
File path: core/src/main/scala/org/apache/spark/MapOutputTracker.scala
##########
@@ -1176,6 +1223,9 @@ private[spark] class MapOutputTrackerWorker(conf: 
SparkConf) extends MapOutputTr
   // instantiate a serializer. See the followup to SPARK-36705 for more 
details.
   private lazy val fetchMergeResult = Utils.isPushBasedShuffleEnabled(conf, 
isDriver = false)
 
+  // Exposed for testing
+  val shufflePushMergerLocations = new ConcurrentHashMap[Int, 
Seq[BlockManagerId]]().asScala

Review comment:
       This needs to account for the shuffle merge id as well, to ensure the merger locations are consistent.
   A possible flow:
   a) Attempt 0 has enough mergers, so executors cache the merger locations for attempt 0.
   b) One or more hosts get excluded (blacklisted) before attempt 1.
   c) Attempt 1 starts with merge allowed = true but merge enabled = false, so tasks start without mergers populated in the dependency (merge might get enabled later as more executors are added).
   d) Executors that ran attempt 0 will use the mergers from the previous attempt, while new executors will use the mergers from attempt 1 - so executors end up with inconsistent merger locations for the same shuffle.
   

##########
File path: 
core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
##########
@@ -4147,7 +4146,128 @@ class DAGSchedulerSuite extends SparkFunSuite with 
TempLocalSparkContext with Ti
     assert(finalizeTask2.delay == 0 && finalizeTask2.registerMergeResults)
   }
 
-    /**
+  test("SPARK-34826: Adaptively fetch shuffle mergers") {
+    initPushBasedShuffleConfs(conf)
+    conf.set(config.SHUFFLE_MERGER_LOCATIONS_MIN_STATIC_THRESHOLD, 6)
+    DAGSchedulerSuite.clearMergerLocs

Review comment:
       nit: This should be `clearMergerLocs()`, since it changes state (side-effecting arity-0 methods keep their parentheses in Scala) - though this was not introduced in this PR.

##########
File path: core/src/main/scala/org/apache/spark/Dependency.scala
##########
@@ -106,7 +106,17 @@ class ShuffleDependency[K: ClassTag, V: ClassTag, C: 
ClassTag](
 
   // By default, shuffle merge is enabled for ShuffleDependency if push based 
shuffle
   // is enabled
-  private[this] var _shuffleMergeEnabled = canShuffleMergeBeEnabled()
+  private[this] var _shuffleMergeAllowed = canShuffleMergeBeEnabled()
+
+  private[spark] def setShuffleMergeAllowed(shuffleMergeAllowed: Boolean): 
Unit = {
+    _shuffleMergeAllowed = shuffleMergeAllowed
+  }
+
+  def shuffleMergeAllowed : Boolean = _shuffleMergeAllowed
+
+  // By default, shuffle merge is enabled for ShuffleDependency if push based 
shuffle
+  // is enabled
+  private[this] var _shuffleMergeEnabled = shuffleMergeAllowed
 
   private[spark] def setShuffleMergeEnabled(shuffleMergeEnabled: Boolean): 
Unit = {
     _shuffleMergeEnabled = shuffleMergeEnabled

Review comment:
       Do we need this boolean? Would `mergerLocs.nonEmpty` not suffice? (`def shuffleMergeEnabled: Boolean = mergerLocs.nonEmpty`)
   
   Also, in `setShuffleMergeEnabled` (if we are keeping it) or in `setMergerLocs`, add `assert(!shuffleMergeEnabled || shuffleMergeAllowed)`.

##########
File path: 
core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
##########
@@ -4147,7 +4146,128 @@ class DAGSchedulerSuite extends SparkFunSuite with 
TempLocalSparkContext with Ti
     assert(finalizeTask2.delay == 0 && finalizeTask2.registerMergeResults)
   }
 
-    /**
+  test("SPARK-34826: Adaptively fetch shuffle mergers") {
+    initPushBasedShuffleConfs(conf)
+    conf.set(config.SHUFFLE_MERGER_LOCATIONS_MIN_STATIC_THRESHOLD, 6)
+    DAGSchedulerSuite.clearMergerLocs
+    DAGSchedulerSuite.addMergerLocs(Seq("host1", "host2", "host3", "host4", 
"host5"))
+    val parts = 7
+
+    val shuffleMapRdd = new MyRDD(sc, parts, Nil)
+    val shuffleDep = new ShuffleDependency(shuffleMapRdd, new 
HashPartitioner(parts))
+    val reduceRdd = new MyRDD(sc, parts, List(shuffleDep), tracker = 
mapOutputTracker)
+
+    // Submit a reduce job that depends which will create a map stage
+    submit(reduceRdd, (0 until parts).toArray)
+
+    runEvent(makeCompletionEvent(
+      taskSets(0).tasks(0), Success, makeMapStatus("hostA", parts),
+      Seq.empty, Array.empty, createFakeTaskInfoWithId(0)))
+
+    val shuffleStage1 = 
scheduler.stageIdToStage(0).asInstanceOf[ShuffleMapStage]
+    assert(shuffleStage1.shuffleDep.getMergerLocs.isEmpty)
+    assert(mapOutputTracker.getShufflePushMergerLocations(0).isEmpty)
+
+    DAGSchedulerSuite.addMergerLocs(Seq("host6", "host7", "host8"))
+
+    runEvent(makeCompletionEvent(
+      taskSets(0).tasks(1), Success, makeMapStatus("hostA", parts),
+      Seq.empty, Array.empty, createFakeTaskInfoWithId(1)))
+
+    // Dummy executor added event to trigger registering of shuffle merger 
locations
+    // as shuffle mergers are tracked separately for test
+    runEvent(ExecutorAdded("dummy", "dummy"))
+
+    // Check if new shuffle merger locations are available for push or not
+    assert(mapOutputTracker.getShufflePushMergerLocations(0).size == 7)

Review comment:
       Shouldn't this be `6`?

##########
File path: 
core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
##########
@@ -4147,7 +4146,128 @@ class DAGSchedulerSuite extends SparkFunSuite with 
TempLocalSparkContext with Ti
     assert(finalizeTask2.delay == 0 && finalizeTask2.registerMergeResults)
   }
 
-    /**
+  test("SPARK-34826: Adaptively fetch shuffle mergers") {
+    initPushBasedShuffleConfs(conf)
+    conf.set(config.SHUFFLE_MERGER_LOCATIONS_MIN_STATIC_THRESHOLD, 6)
+    DAGSchedulerSuite.clearMergerLocs
+    DAGSchedulerSuite.addMergerLocs(Seq("host1", "host2", "host3", "host4", 
"host5"))
+    val parts = 7
+
+    val shuffleMapRdd = new MyRDD(sc, parts, Nil)
+    val shuffleDep = new ShuffleDependency(shuffleMapRdd, new 
HashPartitioner(parts))
+    val reduceRdd = new MyRDD(sc, parts, List(shuffleDep), tracker = 
mapOutputTracker)
+
+    // Submit a reduce job that depends which will create a map stage
+    submit(reduceRdd, (0 until parts).toArray)
+
+    runEvent(makeCompletionEvent(
+      taskSets(0).tasks(0), Success, makeMapStatus("hostA", parts),
+      Seq.empty, Array.empty, createFakeTaskInfoWithId(0)))
+
+    val shuffleStage1 = 
scheduler.stageIdToStage(0).asInstanceOf[ShuffleMapStage]
+    assert(shuffleStage1.shuffleDep.getMergerLocs.isEmpty)
+    assert(mapOutputTracker.getShufflePushMergerLocations(0).isEmpty)
+
+    DAGSchedulerSuite.addMergerLocs(Seq("host6", "host7", "host8"))
+
+    runEvent(makeCompletionEvent(
+      taskSets(0).tasks(1), Success, makeMapStatus("hostA", parts),
+      Seq.empty, Array.empty, createFakeTaskInfoWithId(1)))

Review comment:
       I am trying to understand why `addMergerLocs` is called before task `(1)` is completed.
   Any comments? Thanks.
   
   (I understand that `addMergerLocs` needs to be called to bring the number of mergers to at least 6, but why this specific flow?)

##########
File path: core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
##########
@@ -2487,6 +2501,21 @@ private[spark] class DAGScheduler(
       executorFailureEpoch -= execId
     }
     shuffleFileLostEpoch -= execId
+
+    if (pushBasedShuffleEnabled) {
+      // Only set merger locations for stages that are not yet finished and 
have empty mergers
+      shuffleIdToMapStage.filter { case (_, stage) =>

Review comment:
       Scratch that: after the reordering of the filter checks below, the cost is minimal.

##########
File path: core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala
##########
@@ -39,7 +39,9 @@ class StageInfo(
     val taskMetrics: TaskMetrics = null,
     private[spark] val taskLocalityPreferences: Seq[Seq[TaskLocation]] = 
Seq.empty,
     private[spark] val shuffleDepId: Option[Int] = None,
-    val resourceProfileId: Int) {
+    val resourceProfileId: Int,
+    private[spark] var isPushBasedShuffleEnabled: Boolean = false,
+    private[spark] var shuffleMergerCount: Int = 0) {

Review comment:
       Sounds good. That PR is not currently consuming it, but if we are updating it to also surface this, it makes sense to include it.
   @thejdeep, it would be great if you could update #34000 once this is merged (or in a follow-up if #34000 gets merged before this PR). Thanks!

##########
File path: core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
##########
@@ -2487,6 +2501,21 @@ private[spark] class DAGScheduler(
       executorFailureEpoch -= execId
     }
     shuffleFileLostEpoch -= execId
+
+    if (pushBasedShuffleEnabled) {

Review comment:
       Thanks for clarifying, makes sense.

##########
File path: 
core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
##########
@@ -4147,7 +4146,128 @@ class DAGSchedulerSuite extends SparkFunSuite with 
TempLocalSparkContext with Ti
     assert(finalizeTask2.delay == 0 && finalizeTask2.registerMergeResults)
   }
 
-    /**
+  test("SPARK-34826: Adaptively fetch shuffle mergers") {
+    initPushBasedShuffleConfs(conf)
+    conf.set(config.SHUFFLE_MERGER_LOCATIONS_MIN_STATIC_THRESHOLD, 6)
+    DAGSchedulerSuite.clearMergerLocs
+    DAGSchedulerSuite.addMergerLocs(Seq("host1", "host2", "host3", "host4", 
"host5"))
+    val parts = 7
+
+    val shuffleMapRdd = new MyRDD(sc, parts, Nil)
+    val shuffleDep = new ShuffleDependency(shuffleMapRdd, new 
HashPartitioner(parts))
+    val reduceRdd = new MyRDD(sc, parts, List(shuffleDep), tracker = 
mapOutputTracker)
+
+    // Submit a reduce job that depends which will create a map stage
+    submit(reduceRdd, (0 until parts).toArray)
+
+    runEvent(makeCompletionEvent(
+      taskSets(0).tasks(0), Success, makeMapStatus("hostA", parts),
+      Seq.empty, Array.empty, createFakeTaskInfoWithId(0)))
+
+    val shuffleStage1 = 
scheduler.stageIdToStage(0).asInstanceOf[ShuffleMapStage]
+    assert(shuffleStage1.shuffleDep.getMergerLocs.isEmpty)
+    assert(mapOutputTracker.getShufflePushMergerLocations(0).isEmpty)
+
+    DAGSchedulerSuite.addMergerLocs(Seq("host6", "host7", "host8"))
+
+    runEvent(makeCompletionEvent(
+      taskSets(0).tasks(1), Success, makeMapStatus("hostA", parts),
+      Seq.empty, Array.empty, createFakeTaskInfoWithId(1)))
+
+    // Dummy executor added event to trigger registering of shuffle merger 
locations
+    // as shuffle mergers are tracked separately for test
+    runEvent(ExecutorAdded("dummy", "dummy"))
+
+    // Check if new shuffle merger locations are available for push or not
+    assert(mapOutputTracker.getShufflePushMergerLocations(0).size == 7)
+    assert(shuffleStage1.shuffleDep.getMergerLocs.size == 7)
+
+    // Complete remaining tasks in ShuffleMapStage 0
+    (2 to 6).foreach(x => {
+      runEvent(makeCompletionEvent(taskSets(0).tasks(x), Success, 
makeMapStatus("hostA", parts),
+        Seq.empty, Array.empty, createFakeTaskInfoWithId(x)))
+    })
+
+    completeNextResultStageWithSuccess(1, 0)
+    assert(results === Map(0 -> 42, 1 -> 42, 2 -> 42, 3 -> 42, 4 -> 42, 5 -> 
42, 6 -> 42))
+
+    results.clear()
+    assertDataStructuresEmpty()
+  }
+
+  test("SPARK-34826: Adaptively fetch shuffle mergers with stage retry") {
+    initPushBasedShuffleConfs(conf)
+    conf.set(config.SHUFFLE_MERGER_LOCATIONS_MIN_STATIC_THRESHOLD, 6)
+    DAGSchedulerSuite.clearMergerLocs
+    DAGSchedulerSuite.addMergerLocs(Seq("host1", "host2", "host3", "host4", 
"host5"))
+    val parts = 7
+
+    val shuffleMapRdd1 = new MyRDD(sc, parts, Nil)
+    val shuffleDep1 = new ShuffleDependency(shuffleMapRdd1, new 
HashPartitioner(parts))
+    val shuffleMapRdd2 = new MyRDD(sc, parts, Nil)
+    val shuffleDep2 = new ShuffleDependency(shuffleMapRdd2, new 
HashPartitioner(parts))
+    val reduceRdd = new MyRDD(sc, parts, List(shuffleDep1, shuffleDep2),
+      tracker = mapOutputTracker)
+
+    // Submit a reduce job that depends which will create a map stage
+    submit(reduceRdd, (0 until parts).toArray)
+
+    val taskResults = taskSets(0).tasks.zipWithIndex.map {
+      case (_, idx) =>
+        (Success, makeMapStatus("host" + ('A' + idx).toChar, parts))
+    }.toSeq
+
+    val shuffleStage1 = 
scheduler.stageIdToStage(0).asInstanceOf[ShuffleMapStage]
+    DAGSchedulerSuite.addMergerLocs(Seq("host6", "host7", "host8"))
+    // Dummy executor added event to trigger registering of shuffle merger 
locations
+    // as shuffle mergers are tracked separately for test
+    runEvent(ExecutorAdded("dummy", "dummy"))
+    // Check if new shuffle merger locations are available for push or not
+    assert(mapOutputTracker.getShufflePushMergerLocations(0).size == 7)
+    assert(shuffleStage1.shuffleDep.getMergerLocs.size == 7)
+    val mergerLocsBeforeRetry = shuffleStage1.shuffleDep.getMergerLocs
+
+    // Remove MapStatus on one of the host before the stage ends to trigger
+    // a scenario where stage 0 needs to be resubmitted upon finishing all 
tasks.
+    // Merge finalization should be scheduled in this case.
+    for ((result, i) <- taskResults.zipWithIndex) {
+      if (i == taskSets(0).tasks.size - 1) {
+        mapOutputTracker.removeOutputsOnHost("hostA")
+      }
+      if (i < taskSets(0).tasks.size) {
+        runEvent(makeCompletionEvent(taskSets(0).tasks(i), result._1, 
result._2))
+      }
+    }
+
+    assert(shuffleStage1.shuffleDep.isShuffleMergeFinalized)
+
+    // Successfully completing the retry of stage 0.
+    complete(taskSets(2), taskSets(2).tasks.zipWithIndex.map {
+      case (_, idx) =>
+        (Success, makeMapStatus("host" + ('A' + idx).toChar, parts))
+    }.toSeq)
+
+    DAGSchedulerSuite.addMergerLocs(Seq("host9", "host10"))
+    // Dummy executor added event to trigger registering of shuffle merger 
locations
+    runEvent(ExecutorAdded("dummy1", "dummy1"))
+    assert(shuffleStage1.shuffleDep.getMergerLocs.size == 7)
+    assert(shuffleStage1.shuffleDep.isShuffleMergeFinalized)
+    complete(taskSets(1), taskSets(1).tasks.zipWithIndex.map {
+      case (_, idx) =>
+        (Success, makeMapStatus("host" + ('A' + idx).toChar, parts, 10))
+    }.toSeq)
+    val shuffleStage2 = 
scheduler.stageIdToStage(1).asInstanceOf[ShuffleMapStage]
+    // Shuffle merger locs should not be refreshed as the shuffle is already 
finalized
+    assert(mergerLocsBeforeRetry.sortBy(_.host) ===
+      shuffleStage1.shuffleDep.getMergerLocs.sortBy(_.host))

Review comment:
       We need to test both DETERMINATE and INDETERMINATE stages.
   We should force a new set of mergers for an INDETERMINATE stage retry. In a real-world case, this would mean adding a previously selected merger to `excludeNodes`; for the test, we can simply clear `DAGSchedulerSuite.mergerLocs` and add a different set of hosts there.

##########
File path: 
core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
##########
@@ -131,7 +131,7 @@ private[spark] class SortShuffleManager(conf: SparkConf) 
extends ShuffleManager
       metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = {
     val baseShuffleHandle = handle.asInstanceOf[BaseShuffleHandle[K, _, C]]
     val (blocksByAddress, canEnableBatchFetch) =
-      if (baseShuffleHandle.dependency.shuffleMergeEnabled) {
+      if (baseShuffleHandle.dependency.isShuffleMergeFinalized) {

Review comment:
       Thanks @venkata91 !




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]



---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to