mridulm commented on a change in pull request #34122:
URL: https://github.com/apache/spark/pull/34122#discussion_r785387151
##########
File path: core/src/main/scala/org/apache/spark/Dependency.scala
##########
@@ -145,19 +155,30 @@ class ShuffleDependency[K: ClassTag, V: ClassTag, C:
ClassTag](
}
/**
- * Returns true if push-based shuffle is disabled for this stage or empty
RDD,
- * or if the shuffle merge for this stage is finalized, i.e. the shuffle
merge
- * results for all partitions are available.
+ * Returns true if the RDD is an empty RDD or if the shuffle merge for this
shuffle is
+ * finalized.
*/
- def shuffleMergeFinalized: Boolean = {
+ def isShuffleMergeFinalized: Boolean = {
// Empty RDD won't be computed therefore shuffle merge finalized should be
true by default.
- if (shuffleMergeEnabled && numPartitions > 0) {
+ if (numPartitions > 0) {
Review comment:
Since `canShuffleMergeBeEnabled` already checks for `numPartitions > 0`,
do we need this check here?
##########
File path: core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
##########
@@ -910,4 +910,32 @@ class MapOutputTrackerSuite extends SparkFunSuite with
LocalSparkContext {
rpcEnv.shutdown()
slaveRpcEnv.shutdown()
}
+
+ test("SPARK-34826: Adaptive shuffle mergers") {
+ val newConf = new SparkConf
+ newConf.set("spark.shuffle.push.based.enabled", "true")
+ newConf.set("spark.shuffle.service.enabled", "true")
+
+ // needs TorrentBroadcast so need a SparkContext
+ withSpark(new SparkContext("local", "MapOutputTrackerSuite", newConf)) {
sc =>
+ val masterTracker =
sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]
+ val rpcEnv = sc.env.rpcEnv
+ val masterEndpoint = new MapOutputTrackerMasterEndpoint(rpcEnv,
masterTracker, newConf)
+ rpcEnv.stop(masterTracker.trackerEndpoint)
+ rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, masterEndpoint)
+
+ val slaveTracker = new MapOutputTrackerWorker(newConf)
+ slaveTracker.trackerEndpoint =
+ rpcEnv.setupEndpointRef(rpcEnv.address, MapOutputTracker.ENDPOINT_NAME)
+
+ masterTracker.registerShuffle(20, 100, 100)
+ slaveTracker.updateEpoch(masterTracker.getEpoch)
+ val mergerLocs = (1 to 10).map(x => BlockManagerId(s"exec-$x",
s"host-$x", 7337))
+ masterTracker.registerShufflePushMergerLocations(20, mergerLocs)
+
+ assert(slaveTracker.getShufflePushMergerLocations(20).size == 10)
+ slaveTracker.unregisterShuffle(20)
+ assert(slaveTracker.shufflePushMergerLocations.isEmpty)
Review comment:
Note: We will need to extend this test with the additional case described in
the comment on `shufflePushMergerLocations` (merger locations cached across stage attempts).
##########
File path:
core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
##########
@@ -4147,7 +4146,128 @@ class DAGSchedulerSuite extends SparkFunSuite with
TempLocalSparkContext with Ti
assert(finalizeTask2.delay == 0 && finalizeTask2.registerMergeResults)
}
- /**
+ test("SPARK-34826: Adaptively fetch shuffle mergers") {
+ initPushBasedShuffleConfs(conf)
+ conf.set(config.SHUFFLE_MERGER_LOCATIONS_MIN_STATIC_THRESHOLD, 6)
+ DAGSchedulerSuite.clearMergerLocs
+ DAGSchedulerSuite.addMergerLocs(Seq("host1", "host2", "host3", "host4",
"host5"))
+ val parts = 7
+
+ val shuffleMapRdd = new MyRDD(sc, parts, Nil)
+ val shuffleDep = new ShuffleDependency(shuffleMapRdd, new
HashPartitioner(parts))
+ val reduceRdd = new MyRDD(sc, parts, List(shuffleDep), tracker =
mapOutputTracker)
+
+ // Submit a reduce job that depends which will create a map stage
+ submit(reduceRdd, (0 until parts).toArray)
+
+ runEvent(makeCompletionEvent(
+ taskSets(0).tasks(0), Success, makeMapStatus("hostA", parts),
+ Seq.empty, Array.empty, createFakeTaskInfoWithId(0)))
+
+ val shuffleStage1 =
scheduler.stageIdToStage(0).asInstanceOf[ShuffleMapStage]
+ assert(shuffleStage1.shuffleDep.getMergerLocs.isEmpty)
+ assert(mapOutputTracker.getShufflePushMergerLocations(0).isEmpty)
+
+ DAGSchedulerSuite.addMergerLocs(Seq("host6", "host7", "host8"))
+
+ runEvent(makeCompletionEvent(
+ taskSets(0).tasks(1), Success, makeMapStatus("hostA", parts),
+ Seq.empty, Array.empty, createFakeTaskInfoWithId(1)))
+
+ // Dummy executor added event to trigger registering of shuffle merger
locations
+ // as shuffle mergers are tracked separately for test
+ runEvent(ExecutorAdded("dummy", "dummy"))
Review comment:
Let us use a valid host (for example `host6`, from the new set of
hosts), in case this code evolves in the future.
##########
File path: core/src/main/scala/org/apache/spark/MapOutputTracker.scala
##########
@@ -1281,6 +1331,22 @@ private[spark] class MapOutputTrackerWorker(conf:
SparkConf) extends MapOutputTr
}
}
+ override def getShufflePushMergerLocations(shuffleId: Int):
Seq[BlockManagerId] = {
+ shufflePushMergerLocations.getOrElse(shuffleId,
getMergerLocations(shuffleId))
+ }
+
+ private def getMergerLocations(shuffleId: Int): Seq[BlockManagerId] = {
+ fetchingLock.withLock(shuffleId) {
+ val mergers =
askTracker[Seq[BlockManagerId]](GetShufflePushMergerLocations(shuffleId))
Review comment:
Check whether `shufflePushMergerLocations` already has a non-empty entry for `shuffleId`
before fetching it again, in case it was updated concurrently before we acquired the
lock for `shuffleId`
(see `getStatuses` or other uses of `fetchingLock` for how this is done).
##########
File path: core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
##########
@@ -910,4 +910,32 @@ class MapOutputTrackerSuite extends SparkFunSuite with
LocalSparkContext {
rpcEnv.shutdown()
slaveRpcEnv.shutdown()
}
+
+ test("SPARK-34826: Adaptive shuffle mergers") {
+ val newConf = new SparkConf
+ newConf.set("spark.shuffle.push.based.enabled", "true")
+ newConf.set("spark.shuffle.service.enabled", "true")
+
+ // needs TorrentBroadcast so need a SparkContext
+ withSpark(new SparkContext("local", "MapOutputTrackerSuite", newConf)) {
sc =>
+ val masterTracker =
sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]
+ val rpcEnv = sc.env.rpcEnv
+ val masterEndpoint = new MapOutputTrackerMasterEndpoint(rpcEnv,
masterTracker, newConf)
+ rpcEnv.stop(masterTracker.trackerEndpoint)
+ rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, masterEndpoint)
+
+ val slaveTracker = new MapOutputTrackerWorker(newConf)
Review comment:
Rename to `workerTracker`.
It looks like some of the other push-based shuffle PRs have reintroduced the `slave`
term - we should rename those to `worker` as well.
Can you please file a JIRA under the improvement tasks for this? Thanks.
##########
File path:
core/src/main/scala/org/apache/spark/shuffle/ShuffleWriteProcessor.scala
##########
@@ -59,13 +59,24 @@ private[spark] class ShuffleWriteProcessor extends
Serializable with Logging {
rdd.iterator(partition, context).asInstanceOf[Iterator[_ <:
Product2[Any, Any]]])
val mapStatus = writer.stop(success = true)
if (mapStatus.isDefined) {
+ // Check if sufficient shuffle mergers are available now for the
ShuffleMapTask to push
+ if (dep.shuffleMergeAllowed && dep.getMergerLocs.isEmpty &&
!dep.isShuffleMergeFinalized) {
Review comment:
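We cannot have both `dep.getMergerLocs.isEmpty == true` and
`dep.isShuffleMergeFinalized == true`: if no merger locations were set, the merge
cannot have been finalized, so `!dep.isShuffleMergeFinalized` is redundant here.
Drop the `dep.isShuffleMergeFinalized` check?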
##########
File path: core/src/main/scala/org/apache/spark/MapOutputTracker.scala
##########
@@ -1364,6 +1430,7 @@ private[spark] class MapOutputTrackerWorker(conf:
SparkConf) extends MapOutputTr
def unregisterShuffle(shuffleId: Int): Unit = {
mapStatuses.remove(shuffleId)
mergeStatuses.remove(shuffleId)
+ shufflePushMergerLocations.remove(shuffleId)
Review comment:
Review Note: I was trying to see whether `shufflePushMergerLocations.clear`
in `updateEpoch` is sufficient to handle the shuffle attempt id issue (merger locations
cached across stage attempts), but I am not sure it is enough.
##########
File path: core/src/main/scala/org/apache/spark/MapOutputTracker.scala
##########
@@ -1176,6 +1223,9 @@ private[spark] class MapOutputTrackerWorker(conf:
SparkConf) extends MapOutputTr
// instantiate a serializer. See the followup to SPARK-36705 for more
details.
private lazy val fetchMergeResult = Utils.isPushBasedShuffleEnabled(conf,
isDriver = false)
+ // Exposed for testing
+ val shufflePushMergerLocations = new ConcurrentHashMap[Int,
Seq[BlockManagerId]]().asScala
Review comment:
This needs to account for the shuffle merge id as well, to ensure the
merger locations are consistent.
A possible flow is:
a) attempt 0 has enough mergers, so executors cache the merger locations for
attempt 0.
b) one or more hosts get excluded (blacklisted) before attempt 1.
c) attempt 1 starts with merge allowed = true but merge enabled = false,
so tasks start without mergers populated in the dependency (merge might get enabled
later as more executors are added).
d) executors which ran attempt 0 will use the mergers from the previous attempt, while
new executors will use the mergers from attempt 1 - so tasks of the same stage attempt
can end up pushing to inconsistent sets of mergers.
##########
File path:
core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
##########
@@ -4147,7 +4146,128 @@ class DAGSchedulerSuite extends SparkFunSuite with
TempLocalSparkContext with Ti
assert(finalizeTask2.delay == 0 && finalizeTask2.registerMergeResults)
}
- /**
+ test("SPARK-34826: Adaptively fetch shuffle mergers") {
+ initPushBasedShuffleConfs(conf)
+ conf.set(config.SHUFFLE_MERGER_LOCATIONS_MIN_STATIC_THRESHOLD, 6)
+ DAGSchedulerSuite.clearMergerLocs
Review comment:
nit: This should be `clearMergerLocs()`, since it mutates state -
though that was not introduced in this PR.
##########
File path: core/src/main/scala/org/apache/spark/Dependency.scala
##########
@@ -106,7 +106,17 @@ class ShuffleDependency[K: ClassTag, V: ClassTag, C:
ClassTag](
// By default, shuffle merge is enabled for ShuffleDependency if push based
shuffle
// is enabled
- private[this] var _shuffleMergeEnabled = canShuffleMergeBeEnabled()
+ private[this] var _shuffleMergeAllowed = canShuffleMergeBeEnabled()
+
+ private[spark] def setShuffleMergeAllowed(shuffleMergeAllowed: Boolean):
Unit = {
+ _shuffleMergeAllowed = shuffleMergeAllowed
+ }
+
+ def shuffleMergeAllowed : Boolean = _shuffleMergeAllowed
+
+ // By default, shuffle merge is enabled for ShuffleDependency if push based
shuffle
+ // is enabled
+ private[this] var _shuffleMergeEnabled = shuffleMergeAllowed
private[spark] def setShuffleMergeEnabled(shuffleMergeEnabled: Boolean):
Unit = {
_shuffleMergeEnabled = shuffleMergeEnabled
Review comment:
Do we need this boolean? Wouldn't `mergerLocs.nonEmpty` suffice? (`def
shuffleMergeEnabled : Boolean = mergerLocs.nonEmpty`)
Also, in `setShuffleMergeEnabled` (if we are keeping it) or in
`setMergerLocs`, add `assert(!shuffleMergeEnabled || shuffleMergeAllowed)`.
##########
File path:
core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
##########
@@ -4147,7 +4146,128 @@ class DAGSchedulerSuite extends SparkFunSuite with
TempLocalSparkContext with Ti
assert(finalizeTask2.delay == 0 && finalizeTask2.registerMergeResults)
}
- /**
+ test("SPARK-34826: Adaptively fetch shuffle mergers") {
+ initPushBasedShuffleConfs(conf)
+ conf.set(config.SHUFFLE_MERGER_LOCATIONS_MIN_STATIC_THRESHOLD, 6)
+ DAGSchedulerSuite.clearMergerLocs
+ DAGSchedulerSuite.addMergerLocs(Seq("host1", "host2", "host3", "host4",
"host5"))
+ val parts = 7
+
+ val shuffleMapRdd = new MyRDD(sc, parts, Nil)
+ val shuffleDep = new ShuffleDependency(shuffleMapRdd, new
HashPartitioner(parts))
+ val reduceRdd = new MyRDD(sc, parts, List(shuffleDep), tracker =
mapOutputTracker)
+
+ // Submit a reduce job that depends which will create a map stage
+ submit(reduceRdd, (0 until parts).toArray)
+
+ runEvent(makeCompletionEvent(
+ taskSets(0).tasks(0), Success, makeMapStatus("hostA", parts),
+ Seq.empty, Array.empty, createFakeTaskInfoWithId(0)))
+
+ val shuffleStage1 =
scheduler.stageIdToStage(0).asInstanceOf[ShuffleMapStage]
+ assert(shuffleStage1.shuffleDep.getMergerLocs.isEmpty)
+ assert(mapOutputTracker.getShufflePushMergerLocations(0).isEmpty)
+
+ DAGSchedulerSuite.addMergerLocs(Seq("host6", "host7", "host8"))
+
+ runEvent(makeCompletionEvent(
+ taskSets(0).tasks(1), Success, makeMapStatus("hostA", parts),
+ Seq.empty, Array.empty, createFakeTaskInfoWithId(1)))
+
+ // Dummy executor added event to trigger registering of shuffle merger
locations
+ // as shuffle mergers are tracked separately for test
+ runEvent(ExecutorAdded("dummy", "dummy"))
+
+ // Check if new shuffle merger locations are available for push or not
+ assert(mapOutputTracker.getShufflePushMergerLocations(0).size == 7)
Review comment:
Shouldn't this be `6`?
##########
File path:
core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
##########
@@ -4147,7 +4146,128 @@ class DAGSchedulerSuite extends SparkFunSuite with
TempLocalSparkContext with Ti
assert(finalizeTask2.delay == 0 && finalizeTask2.registerMergeResults)
}
- /**
+ test("SPARK-34826: Adaptively fetch shuffle mergers") {
+ initPushBasedShuffleConfs(conf)
+ conf.set(config.SHUFFLE_MERGER_LOCATIONS_MIN_STATIC_THRESHOLD, 6)
+ DAGSchedulerSuite.clearMergerLocs
+ DAGSchedulerSuite.addMergerLocs(Seq("host1", "host2", "host3", "host4",
"host5"))
+ val parts = 7
+
+ val shuffleMapRdd = new MyRDD(sc, parts, Nil)
+ val shuffleDep = new ShuffleDependency(shuffleMapRdd, new
HashPartitioner(parts))
+ val reduceRdd = new MyRDD(sc, parts, List(shuffleDep), tracker =
mapOutputTracker)
+
+ // Submit a reduce job that depends which will create a map stage
+ submit(reduceRdd, (0 until parts).toArray)
+
+ runEvent(makeCompletionEvent(
+ taskSets(0).tasks(0), Success, makeMapStatus("hostA", parts),
+ Seq.empty, Array.empty, createFakeTaskInfoWithId(0)))
+
+ val shuffleStage1 =
scheduler.stageIdToStage(0).asInstanceOf[ShuffleMapStage]
+ assert(shuffleStage1.shuffleDep.getMergerLocs.isEmpty)
+ assert(mapOutputTracker.getShufflePushMergerLocations(0).isEmpty)
+
+ DAGSchedulerSuite.addMergerLocs(Seq("host6", "host7", "host8"))
+
+ runEvent(makeCompletionEvent(
+ taskSets(0).tasks(1), Success, makeMapStatus("hostA", parts),
+ Seq.empty, Array.empty, createFakeTaskInfoWithId(1)))
Review comment:
I am trying to understand why `addMergerLocs` is called before task `(1)`
is completed.
Any comments? Thanks.
(I understand that `addMergerLocs` needs to be called to bring the merger count to >= 6,
but why this specific ordering?)
##########
File path: core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
##########
@@ -2487,6 +2501,21 @@ private[spark] class DAGScheduler(
executorFailureEpoch -= execId
}
shuffleFileLostEpoch -= execId
+
+ if (pushBasedShuffleEnabled) {
+ // Only set merger locations for stages that are not yet finished and
have empty mergers
+ shuffleIdToMapStage.filter { case (_, stage) =>
Review comment:
Scratch that - the cost is minimal after the reordering of the filter
checks below.
##########
File path: core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala
##########
@@ -39,7 +39,9 @@ class StageInfo(
val taskMetrics: TaskMetrics = null,
private[spark] val taskLocalityPreferences: Seq[Seq[TaskLocation]] =
Seq.empty,
private[spark] val shuffleDepId: Option[Int] = None,
- val resourceProfileId: Int) {
+ val resourceProfileId: Int,
+ private[spark] var isPushBasedShuffleEnabled: Boolean = false,
+ private[spark] var shuffleMergerCount: Int = 0) {
Review comment:
Sounds good. That PR is not currently consuming it, but if we are
updating it to also surface this, it makes sense to include it.
@thejdeep, it would be great if you could update #34000 once this is merged (or in a
follow-up if #34000 gets merged before this PR). Thanks!
##########
File path: core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
##########
@@ -2487,6 +2501,21 @@ private[spark] class DAGScheduler(
executorFailureEpoch -= execId
}
shuffleFileLostEpoch -= execId
+
+ if (pushBasedShuffleEnabled) {
Review comment:
Thanks for clarifying, makes sense.
##########
File path:
core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
##########
@@ -4147,7 +4146,128 @@ class DAGSchedulerSuite extends SparkFunSuite with
TempLocalSparkContext with Ti
assert(finalizeTask2.delay == 0 && finalizeTask2.registerMergeResults)
}
- /**
+ test("SPARK-34826: Adaptively fetch shuffle mergers") {
+ initPushBasedShuffleConfs(conf)
+ conf.set(config.SHUFFLE_MERGER_LOCATIONS_MIN_STATIC_THRESHOLD, 6)
+ DAGSchedulerSuite.clearMergerLocs
+ DAGSchedulerSuite.addMergerLocs(Seq("host1", "host2", "host3", "host4",
"host5"))
+ val parts = 7
+
+ val shuffleMapRdd = new MyRDD(sc, parts, Nil)
+ val shuffleDep = new ShuffleDependency(shuffleMapRdd, new
HashPartitioner(parts))
+ val reduceRdd = new MyRDD(sc, parts, List(shuffleDep), tracker =
mapOutputTracker)
+
+ // Submit a reduce job that depends which will create a map stage
+ submit(reduceRdd, (0 until parts).toArray)
+
+ runEvent(makeCompletionEvent(
+ taskSets(0).tasks(0), Success, makeMapStatus("hostA", parts),
+ Seq.empty, Array.empty, createFakeTaskInfoWithId(0)))
+
+ val shuffleStage1 =
scheduler.stageIdToStage(0).asInstanceOf[ShuffleMapStage]
+ assert(shuffleStage1.shuffleDep.getMergerLocs.isEmpty)
+ assert(mapOutputTracker.getShufflePushMergerLocations(0).isEmpty)
+
+ DAGSchedulerSuite.addMergerLocs(Seq("host6", "host7", "host8"))
+
+ runEvent(makeCompletionEvent(
+ taskSets(0).tasks(1), Success, makeMapStatus("hostA", parts),
+ Seq.empty, Array.empty, createFakeTaskInfoWithId(1)))
+
+ // Dummy executor added event to trigger registering of shuffle merger
locations
+ // as shuffle mergers are tracked separately for test
+ runEvent(ExecutorAdded("dummy", "dummy"))
+
+ // Check if new shuffle merger locations are available for push or not
+ assert(mapOutputTracker.getShufflePushMergerLocations(0).size == 7)
+ assert(shuffleStage1.shuffleDep.getMergerLocs.size == 7)
+
+ // Complete remaining tasks in ShuffleMapStage 0
+ (2 to 6).foreach(x => {
+ runEvent(makeCompletionEvent(taskSets(0).tasks(x), Success,
makeMapStatus("hostA", parts),
+ Seq.empty, Array.empty, createFakeTaskInfoWithId(x)))
+ })
+
+ completeNextResultStageWithSuccess(1, 0)
+ assert(results === Map(0 -> 42, 1 -> 42, 2 -> 42, 3 -> 42, 4 -> 42, 5 ->
42, 6 -> 42))
+
+ results.clear()
+ assertDataStructuresEmpty()
+ }
+
+ test("SPARK-34826: Adaptively fetch shuffle mergers with stage retry") {
+ initPushBasedShuffleConfs(conf)
+ conf.set(config.SHUFFLE_MERGER_LOCATIONS_MIN_STATIC_THRESHOLD, 6)
+ DAGSchedulerSuite.clearMergerLocs
+ DAGSchedulerSuite.addMergerLocs(Seq("host1", "host2", "host3", "host4",
"host5"))
+ val parts = 7
+
+ val shuffleMapRdd1 = new MyRDD(sc, parts, Nil)
+ val shuffleDep1 = new ShuffleDependency(shuffleMapRdd1, new
HashPartitioner(parts))
+ val shuffleMapRdd2 = new MyRDD(sc, parts, Nil)
+ val shuffleDep2 = new ShuffleDependency(shuffleMapRdd2, new
HashPartitioner(parts))
+ val reduceRdd = new MyRDD(sc, parts, List(shuffleDep1, shuffleDep2),
+ tracker = mapOutputTracker)
+
+ // Submit a reduce job that depends which will create a map stage
+ submit(reduceRdd, (0 until parts).toArray)
+
+ val taskResults = taskSets(0).tasks.zipWithIndex.map {
+ case (_, idx) =>
+ (Success, makeMapStatus("host" + ('A' + idx).toChar, parts))
+ }.toSeq
+
+ val shuffleStage1 =
scheduler.stageIdToStage(0).asInstanceOf[ShuffleMapStage]
+ DAGSchedulerSuite.addMergerLocs(Seq("host6", "host7", "host8"))
+ // Dummy executor added event to trigger registering of shuffle merger
locations
+ // as shuffle mergers are tracked separately for test
+ runEvent(ExecutorAdded("dummy", "dummy"))
+ // Check if new shuffle merger locations are available for push or not
+ assert(mapOutputTracker.getShufflePushMergerLocations(0).size == 7)
+ assert(shuffleStage1.shuffleDep.getMergerLocs.size == 7)
+ val mergerLocsBeforeRetry = shuffleStage1.shuffleDep.getMergerLocs
+
+ // Remove MapStatus on one of the host before the stage ends to trigger
+ // a scenario where stage 0 needs to be resubmitted upon finishing all
tasks.
+ // Merge finalization should be scheduled in this case.
+ for ((result, i) <- taskResults.zipWithIndex) {
+ if (i == taskSets(0).tasks.size - 1) {
+ mapOutputTracker.removeOutputsOnHost("hostA")
+ }
+ if (i < taskSets(0).tasks.size) {
+ runEvent(makeCompletionEvent(taskSets(0).tasks(i), result._1,
result._2))
+ }
+ }
+
+ assert(shuffleStage1.shuffleDep.isShuffleMergeFinalized)
+
+ // Successfully completing the retry of stage 0.
+ complete(taskSets(2), taskSets(2).tasks.zipWithIndex.map {
+ case (_, idx) =>
+ (Success, makeMapStatus("host" + ('A' + idx).toChar, parts))
+ }.toSeq)
+
+ DAGSchedulerSuite.addMergerLocs(Seq("host9", "host10"))
+ // Dummy executor added event to trigger registering of shuffle merger
locations
+ runEvent(ExecutorAdded("dummy1", "dummy1"))
+ assert(shuffleStage1.shuffleDep.getMergerLocs.size == 7)
+ assert(shuffleStage1.shuffleDep.isShuffleMergeFinalized)
+ complete(taskSets(1), taskSets(1).tasks.zipWithIndex.map {
+ case (_, idx) =>
+ (Success, makeMapStatus("host" + ('A' + idx).toChar, parts, 10))
+ }.toSeq)
+ val shuffleStage2 =
scheduler.stageIdToStage(1).asInstanceOf[ShuffleMapStage]
+ // Shuffle merger locs should not be refreshed as the shuffle is already
finalized
+ assert(mergerLocsBeforeRetry.sortBy(_.host) ===
+ shuffleStage1.shuffleDep.getMergerLocs.sortBy(_.host))
Review comment:
We need to test both DETERMINATE and INDETERMINATE stages.
We should force a new set of mergers for an INDETERMINATE stage retry. In a
real-world case, this would mean adding a previously selected merger to the
excludeNodes; for the test, we can simply clear `DAGSchedulerSuite.mergerLocs`
and add a different set of hosts there.
##########
File path:
core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
##########
@@ -131,7 +131,7 @@ private[spark] class SortShuffleManager(conf: SparkConf)
extends ShuffleManager
metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = {
val baseShuffleHandle = handle.asInstanceOf[BaseShuffleHandle[K, _, C]]
val (blocksByAddress, canEnableBatchFetch) =
- if (baseShuffleHandle.dependency.shuffleMergeEnabled) {
+ if (baseShuffleHandle.dependency.isShuffleMergeFinalized) {
Review comment:
Thanks @venkata91 !
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]