This is an automated email from the ASF dual-hosted git repository.
mridulm80 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new b5a1503 [SPARK-32920][SHUFFLE] Finalization of Shuffle push/merge
with Push based shuffle and preparation step for the reduce stage
b5a1503 is described below
commit b5a15035851bfba12ef1c68d10103cec42cbac0c
Author: Venkata krishnan Sowrirajan <[email protected]>
AuthorDate: Thu Jun 10 13:06:15 2021 -0500
[SPARK-32920][SHUFFLE] Finalization of Shuffle push/merge with Push based
shuffle and preparation step for the reduce stage
### What changes were proposed in this pull request?
Summary of the changes made as part of this PR:
1. `DAGScheduler` changes to finalize a ShuffleMapStage which involves
talking to all the shuffle mergers (`ExternalShuffleService`) and getting all
the completed merge statuses.
2. Once the `ShuffleMapStage` finalization is complete, mark the
`ShuffleMapStage` to be finalized which marks the stage as complete and
subsequently letting the child stage start.
3. Also added the relevant tests to `DAGSchedulerSuite` for changes made as
part of [SPARK-32919](https://issues.apache.org/jira/browse/SPARK-32919)
Lead-authored-by: Min Shen <mshen@linkedin.com>
Co-authored-by: Venkata krishnan Sowrirajan <vsowrirajan@linkedin.com>
Co-authored-by: Chandni Singh <chsingh@linkedin.com>
### Why are the changes needed?
Refer to [SPARK-30602](https://issues.apache.org/jira/browse/SPARK-30602)
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Added unit tests to DAGSchedulerSuite
Closes #30691 from venkata91/SPARK-32920.
Lead-authored-by: Venkata krishnan Sowrirajan <[email protected]>
Co-authored-by: Min Shen <[email protected]>
Co-authored-by: Chandni Singh <[email protected]>
Signed-off-by: Mridul Muralidharan <mridul<at>gmail.com>
---
.../main/scala/org/apache/spark/Dependency.scala | 38 ++
.../scala/org/apache/spark/MapOutputTracker.scala | 44 +-
.../org/apache/spark/internal/config/package.scala | 23 +-
.../org/apache/spark/scheduler/DAGScheduler.scala | 257 +++++++++---
.../apache/spark/scheduler/DAGSchedulerEvent.scala | 6 +
.../org/apache/spark/scheduler/StageInfo.scala | 2 +-
...g.apache.spark.scheduler.ExternalClusterManager | 1 +
.../apache/spark/scheduler/DAGSchedulerSuite.scala | 448 ++++++++++++++++++++-
8 files changed, 747 insertions(+), 72 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala
b/core/src/main/scala/org/apache/spark/Dependency.scala
index d21b9d9..0a9acf4 100644
--- a/core/src/main/scala/org/apache/spark/Dependency.scala
+++ b/core/src/main/scala/org/apache/spark/Dependency.scala
@@ -24,6 +24,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.{ShuffleHandle, ShuffleWriteProcessor}
import org.apache.spark.storage.BlockManagerId
+import org.apache.spark.util.Utils
/**
* :: DeveloperApi ::
@@ -96,12 +97,31 @@ class ShuffleDependency[K: ClassTag, V: ClassTag, C:
ClassTag](
val shuffleHandle: ShuffleHandle =
_rdd.context.env.shuffleManager.registerShuffle(
shuffleId, this)
+ // By default, shuffle merge is enabled for ShuffleDependency if push based
shuffle
+ // is enabled
+ private[this] var _shuffleMergeEnabled =
+ Utils.isPushBasedShuffleEnabled(rdd.sparkContext.getConf) &&
+ // TODO: SPARK-35547: Push based shuffle is currently unsupported for
Barrier stages
+ !rdd.isBarrier()
+
+ private[spark] def setShuffleMergeEnabled(shuffleMergeEnabled: Boolean):
Unit = {
+ _shuffleMergeEnabled = shuffleMergeEnabled
+ }
+
+ def shuffleMergeEnabled : Boolean = _shuffleMergeEnabled
+
/**
* Stores the location of the list of chosen external shuffle services for
handling the
* shuffle merge requests from mappers in this shuffle map stage.
*/
private[spark] var mergerLocs: Seq[BlockManagerId] = Nil
+ /**
+ * Stores the information about whether the shuffle merge is finalized for
the shuffle map stage
+ * associated with this shuffle dependency
+ */
+ private[this] var _shuffleMergedFinalized: Boolean = false
+
def setMergerLocs(mergerLocs: Seq[BlockManagerId]): Unit = {
if (mergerLocs != null) {
this.mergerLocs = mergerLocs
@@ -110,6 +130,24 @@ class ShuffleDependency[K: ClassTag, V: ClassTag, C:
ClassTag](
def getMergerLocs: Seq[BlockManagerId] = mergerLocs
+ private[spark] def markShuffleMergeFinalized(): Unit = {
+ _shuffleMergedFinalized = true
+ }
+
+ /**
+ * Returns true if push-based shuffle is disabled for this stage or empty
RDD,
+ * or if the shuffle merge for this stage is finalized, i.e. the shuffle
merge
+ * results for all partitions are available.
+ */
+ def shuffleMergeFinalized: Boolean = {
+ // Empty RDD won't be computed therefore shuffle merge finalized should be
true by default.
+ if (shuffleMergeEnabled && rdd.getNumPartitions > 0) {
+ _shuffleMergedFinalized
+ } else {
+ true
+ }
+ }
+
_rdd.sparkContext.cleaner.foreach(_.registerShuffleForCleanup(this))
_rdd.sparkContext.shuffleDriverComponents.registerShuffle(shuffleId)
}
diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index ea9e641..9f2228b 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -214,6 +214,7 @@ private class ShuffleStatus(
def removeOutputsOnHost(host: String): Unit = withWriteLock {
logDebug(s"Removing outputs for host ${host}")
removeOutputsByFilter(x => x.host == host)
+ removeMergeResultsByFilter(x => x.host == host)
}
/**
@@ -238,6 +239,12 @@ private class ShuffleStatus(
invalidateSerializedMapOutputStatusCache()
}
}
+ }
+
+ /**
+ * Removes all shuffle merge result which satisfies the filter.
+ */
+ def removeMergeResultsByFilter(f: BlockManagerId => Boolean): Unit =
withWriteLock {
for (reduceId <- mergeStatuses.indices) {
if (mergeStatuses(reduceId) != null &&
f(mergeStatuses(reduceId).location)) {
_numAvailableMergeResults -= 1
@@ -708,15 +715,16 @@ private[spark] class MapOutputTrackerMaster(
}
}
- /** Unregister all map output information of the given shuffle. */
- def unregisterAllMapOutput(shuffleId: Int): Unit = {
+ /** Unregister all map and merge output information of the given shuffle. */
+ def unregisterAllMapAndMergeOutput(shuffleId: Int): Unit = {
shuffleStatuses.get(shuffleId) match {
case Some(shuffleStatus) =>
shuffleStatus.removeOutputsByFilter(x => true)
+ shuffleStatus.removeMergeResultsByFilter(x => true)
incrementEpoch()
case None =>
throw new SparkException(
- s"unregisterAllMapOutput called for nonexistent shuffle ID
$shuffleId.")
+ s"unregisterAllMapAndMergeOutput called for nonexistent shuffle ID
$shuffleId.")
}
}
@@ -731,25 +739,26 @@ private[spark] class MapOutputTrackerMaster(
}
/**
- * Unregisters a merge result corresponding to the reduceId if present. If
the optional mapId
- * is specified, it will only unregister the merge result if the mapId is
part of that merge
+ * Unregisters a merge result corresponding to the reduceId if present. If
the optional mapIndex
+ * is specified, it will only unregister the merge result if the mapIndex is
part of that merge
* result.
*
* @param shuffleId the shuffleId.
* @param reduceId the reduceId.
* @param bmAddress block manager address.
- * @param mapId the optional mapId which should be checked to see it was
part of the merge
- * result.
+ * @param mapIndex the optional mapIndex which should be checked to see it
was part of the
+ * merge result.
*/
def unregisterMergeResult(
- shuffleId: Int,
- reduceId: Int,
- bmAddress: BlockManagerId,
- mapId: Option[Int] = None): Unit = {
+ shuffleId: Int,
+ reduceId: Int,
+ bmAddress: BlockManagerId,
+ mapIndex: Option[Int] = None): Unit = {
shuffleStatuses.get(shuffleId) match {
case Some(shuffleStatus) =>
val mergeStatus = shuffleStatus.mergeStatuses(reduceId)
- if (mergeStatus != null && (mapId.isEmpty ||
mergeStatus.tracker.contains(mapId.get))) {
+ if (mergeStatus != null &&
+ (mapIndex.isEmpty || mergeStatus.tracker.contains(mapIndex.get))) {
shuffleStatus.removeMergeResult(reduceId, bmAddress)
incrementEpoch()
}
@@ -758,6 +767,17 @@ private[spark] class MapOutputTrackerMaster(
}
}
+ def unregisterAllMergeResult(shuffleId: Int): Unit = {
+ shuffleStatuses.get(shuffleId) match {
+ case Some(shuffleStatus) =>
+ shuffleStatus.removeMergeResultsByFilter(x => true)
+ incrementEpoch()
+ case None =>
+ throw new SparkException(
+ s"unregisterAllMergeResult called for nonexistent shuffle ID
$shuffleId.")
+ }
+ }
+
/** Unregister shuffle data */
def unregisterShuffle(shuffleId: Int): Unit = {
shuffleStatuses.remove(shuffleId).foreach { shuffleStatus =>
diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala
b/core/src/main/scala/org/apache/spark/internal/config/package.scala
index 9574416..84bd8cc 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/package.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -2084,6 +2084,27 @@ package object config {
.booleanConf
.createWithDefault(false)
+ private[spark] val PUSH_BASED_SHUFFLE_MERGE_RESULTS_TIMEOUT =
+ ConfigBuilder("spark.shuffle.push.merge.results.timeout")
+ .doc("Specify the max amount of time DAGScheduler waits for the merge
results from " +
+ "all remote shuffle services for a given shuffle. DAGScheduler will
start to submit " +
+ "following stages if not all results are received within the timeout.")
+ .version("3.2.0")
+ .timeConf(TimeUnit.SECONDS)
+ .checkValue(_ >= 0L, "Timeout must be >= 0.")
+ .createWithDefaultString("10s")
+
+ private[spark] val PUSH_BASED_SHUFFLE_MERGE_FINALIZE_TIMEOUT =
+ ConfigBuilder("spark.shuffle.push.merge.finalize.timeout")
+ .doc("Specify the amount of time DAGScheduler waits after all mappers
finish for " +
+ "a given shuffle map stage before it starts sending merge finalize
requests to " +
+ "remote shuffle services. This allows the shuffle services some extra
time to " +
+ "merge as many blocks as possible.")
+ .version("3.2.0")
+ .timeConf(TimeUnit.SECONDS)
+ .checkValue(_ >= 0L, "Timeout must be >= 0.")
+ .createWithDefaultString("10s")
+
private[spark] val SHUFFLE_MERGER_MAX_RETAINED_LOCATIONS =
ConfigBuilder("spark.shuffle.push.maxRetainedMergerLocations")
.doc("Maximum number of shuffle push merger locations cached for push
based shuffle. " +
@@ -2117,7 +2138,7 @@ package object config {
s"${SHUFFLE_MERGER_LOCATIONS_MIN_THRESHOLD_RATIO.key} set to 0.05, we
would need " +
"at least 50 mergers to enable push based shuffle for that stage.")
.version("3.1.0")
- .doubleConf
+ .intConf
.createWithDefault(5)
private[spark] val SHUFFLE_NUM_PUSH_THREADS =
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index b359501..1f37638 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -19,7 +19,7 @@ package org.apache.spark.scheduler
import java.io.NotSerializableException
import java.util.Properties
-import java.util.concurrent.{ConcurrentHashMap, TimeUnit}
+import java.util.concurrent.{ConcurrentHashMap, TimeoutException, TimeUnit}
import java.util.concurrent.atomic.AtomicInteger
import scala.annotation.tailrec
@@ -29,12 +29,16 @@ import scala.collection.mutable.{HashMap, HashSet,
ListBuffer}
import scala.concurrent.duration._
import scala.util.control.NonFatal
+import com.google.common.util.concurrent.{Futures, SettableFuture}
+
import org.apache.spark._
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.executor.{ExecutorMetrics, TaskMetrics}
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config
import org.apache.spark.internal.config.Tests.TEST_NO_STAGE_RETRY
+import org.apache.spark.network.shuffle.{BlockStoreClient,
MergeFinalizerListener}
+import org.apache.spark.network.shuffle.protocol.MergeStatuses
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.partial.{ApproximateActionListener,
ApproximateEvaluator, PartialResult}
import org.apache.spark.rdd.{RDD, RDDCheckpointData}
@@ -254,6 +258,24 @@ private[spark] class DAGScheduler(
private val blockManagerMasterDriverHeartbeatTimeout =
sc.getConf.get(config.STORAGE_BLOCKMANAGER_MASTER_DRIVER_HEARTBEAT_TIMEOUT).millis
+ private val shuffleMergeResultsTimeoutSec =
+ sc.getConf.get(config.PUSH_BASED_SHUFFLE_MERGE_RESULTS_TIMEOUT)
+
+ private val shuffleMergeFinalizeWaitSec =
+ sc.getConf.get(config.PUSH_BASED_SHUFFLE_MERGE_FINALIZE_TIMEOUT)
+
+ // Since SparkEnv gets initialized after DAGScheduler, externalShuffleClient
needs to be
+ // initialized lazily
+ private lazy val externalShuffleClient: Option[BlockStoreClient] =
+ if (pushBasedShuffleEnabled) {
+ Some(env.blockManager.blockStoreClient)
+ } else {
+ None
+ }
+
+ private val shuffleMergeFinalizeScheduler =
+
ThreadUtils.newDaemonThreadPoolScheduledExecutor("shuffle-merge-finalizer", 8)
+
/**
* Called by the TaskSetManager to report task's starting.
*/
@@ -689,7 +711,10 @@ private[spark] class DAGScheduler(
dep match {
case shufDep: ShuffleDependency[_, _, _] =>
val mapStage = getOrCreateShuffleMapStage(shufDep,
stage.firstJobId)
- if (!mapStage.isAvailable) {
+ // Mark mapStage as available with shuffle outputs only after
shuffle merge is
+ // finalized with push based shuffle. If not, subsequent
ShuffleMapStage won't
+ // read from merged output as the MergeStatuses are not
available.
+ if (!mapStage.isAvailable ||
!mapStage.shuffleDep.shuffleMergeFinalized) {
missing += mapStage
}
case narrowDep: NarrowDependency[_] =>
@@ -1271,21 +1296,21 @@ private[spark] class DAGScheduler(
* locations for block push/merge by getting the historical locations of
past executors.
*/
private def prepareShuffleServicesForShuffleMapStage(stage:
ShuffleMapStage): Unit = {
- // TODO(SPARK-32920) Handle stage reuse/retry cases separately as without
finalize
- // TODO changes we cannot disable shuffle merge for the retry/reuse cases
- val mergerLocs = sc.schedulerBackend.getShufflePushMergerLocations(
- stage.shuffleDep.partitioner.numPartitions, stage.resourceProfileId)
-
- if (mergerLocs.nonEmpty) {
- stage.shuffleDep.setMergerLocs(mergerLocs)
- logInfo(s"Push-based shuffle enabled for $stage (${stage.name}) with" +
- s" ${stage.shuffleDep.getMergerLocs.size} merger locations")
-
- logDebug("List of shuffle push merger locations " +
- s"${stage.shuffleDep.getMergerLocs.map(_.host).mkString(", ")}")
- } else {
- logInfo("No available merger locations." +
- s" Push-based shuffle disabled for $stage (${stage.name})")
+ assert(stage.shuffleDep.shuffleMergeEnabled &&
!stage.shuffleDep.shuffleMergeFinalized)
+ if (stage.shuffleDep.getMergerLocs.isEmpty) {
+ val mergerLocs = sc.schedulerBackend.getShufflePushMergerLocations(
+ stage.shuffleDep.partitioner.numPartitions, stage.resourceProfileId)
+ if (mergerLocs.nonEmpty) {
+ stage.shuffleDep.setMergerLocs(mergerLocs)
+ logInfo(s"Push-based shuffle enabled for $stage (${stage.name}) with" +
+ s" ${stage.shuffleDep.getMergerLocs.size} merger locations")
+
+ logDebug("List of shuffle push merger locations " +
+ s"${stage.shuffleDep.getMergerLocs.map(_.host).mkString(", ")}")
+ } else {
+ stage.shuffleDep.setShuffleMergeEnabled(false)
+ logInfo("Push-based shuffle disabled for $stage (${stage.name})")
+ }
}
}
@@ -1298,7 +1323,9 @@ private[spark] class DAGScheduler(
// `findMissingPartitions()` returns all partitions every time.
stage match {
case sms: ShuffleMapStage if stage.isIndeterminate && !sms.isAvailable =>
- mapOutputTracker.unregisterAllMapOutput(sms.shuffleDep.shuffleId)
+ // TODO: SPARK-32923: Clean all push-based shuffle metadata like merge
enabled and
+ // TODO: finalized as we are clearing all the merge results.
+
mapOutputTracker.unregisterAllMapAndMergeOutput(sms.shuffleDep.shuffleId)
case _ =>
}
@@ -1318,11 +1345,19 @@ private[spark] class DAGScheduler(
stage match {
case s: ShuffleMapStage =>
outputCommitCoordinator.stageStart(stage = s.id, maxPartitionId =
s.numPartitions - 1)
- // Only generate merger location for a given shuffle dependency once.
This way, even if
- // this stage gets retried, it would still be merging blocks using the
same set of
- // shuffle services.
- if (pushBasedShuffleEnabled) {
- prepareShuffleServicesForShuffleMapStage(s)
+ // Only generate merger location for a given shuffle dependency once.
+ if (s.shuffleDep.shuffleMergeEnabled) {
+ if (!s.shuffleDep.shuffleMergeFinalized) {
+ prepareShuffleServicesForShuffleMapStage(s)
+ } else {
+ // Disable Shuffle merge for the retry/reuse of the same shuffle
dependency if it has
+ // already been merge finalized. If the shuffle dependency was
previously assigned
+ // merger locations but the corresponding shuffle map stage did
not complete
+ // successfully, we would still enable push for its retry.
+ s.shuffleDep.setShuffleMergeEnabled(false)
+ logInfo("Push-based shuffle disabled for $stage (${stage.name})
since it" +
+ " is already shuffle merge finalized")
+ }
}
case s: ResultStage =>
outputCommitCoordinator.stageStart(
@@ -1678,38 +1713,16 @@ private[spark] class DAGScheduler(
}
if (runningStages.contains(shuffleStage) &&
shuffleStage.pendingPartitions.isEmpty) {
- markStageAsFinished(shuffleStage)
- logInfo("looking for newly runnable stages")
- logInfo("running: " + runningStages)
- logInfo("waiting: " + waitingStages)
- logInfo("failed: " + failedStages)
-
- // This call to increment the epoch may not be strictly
necessary, but it is retained
- // for now in order to minimize the changes in behavior from an
earlier version of the
- // code. This existing behavior of always incrementing the epoch
following any
- // successful shuffle map stage completion may have benefits by
causing unneeded
- // cached map outputs to be cleaned up earlier on executors. In
the future we can
- // consider removing this call, but this will require some extra
investigation.
- // See
https://github.com/apache/spark/pull/17955/files#r117385673 for more details.
- mapOutputTracker.incrementEpoch()
-
- clearCacheLocs()
-
- if (!shuffleStage.isAvailable) {
- // Some tasks had failed; let's resubmit this shuffleStage.
- // TODO: Lower-level scheduler should also deal with this
- logInfo("Resubmitting " + shuffleStage + " (" +
shuffleStage.name +
- ") because some of its tasks had failed: " +
- shuffleStage.findMissingPartitions().mkString(", "))
- submitStage(shuffleStage)
+ if (!shuffleStage.shuffleDep.shuffleMergeFinalized &&
+ shuffleStage.shuffleDep.getMergerLocs.nonEmpty) {
+ scheduleShuffleMergeFinalize(shuffleStage)
} else {
- markMapStageJobsAsFinished(shuffleStage)
- submitWaitingChildStages(shuffleStage)
+ processShuffleMapStageCompletion(shuffleStage)
}
}
}
- case FetchFailed(bmAddress, shuffleId, _, mapIndex, _, failureMessage) =>
+ case FetchFailed(bmAddress, shuffleId, _, mapIndex, reduceId,
failureMessage) =>
val failedStage = stageIdToStage(task.stageId)
val mapStage = shuffleIdToMapStage(shuffleId)
@@ -1739,10 +1752,18 @@ private[spark] class DAGScheduler(
if (mapStage.rdd.isBarrier()) {
// Mark all the map as broken in the map stage, to ensure retry
all the tasks on
// resubmitted stage attempt.
- mapOutputTracker.unregisterAllMapOutput(shuffleId)
+ // TODO: SPARK-35547: Clean all push-based shuffle metadata like
merge enabled and
+ // TODO: finalized as we are clearing all the merge results.
+ mapOutputTracker.unregisterAllMapAndMergeOutput(shuffleId)
} else if (mapIndex != -1) {
// Mark the map whose fetch failed as broken in the map stage
mapOutputTracker.unregisterMapOutput(shuffleId, mapIndex,
bmAddress)
+ if (pushBasedShuffleEnabled) {
+ // Possibly unregister the merge result <shuffleId, reduceId>,
if the FetchFailed
+ // mapIndex is part of the merge result of <shuffleId, reduceId>
+ mapOutputTracker.
+ unregisterMergeResult(shuffleId, reduceId, bmAddress,
Option(mapIndex))
+ }
}
if (failedStage.rdd.isBarrier()) {
@@ -1750,7 +1771,7 @@ private[spark] class DAGScheduler(
case failedMapStage: ShuffleMapStage =>
// Mark all the map as broken in the map stage, to ensure
retry all the tasks on
// resubmitted stage attempt.
-
mapOutputTracker.unregisterAllMapOutput(failedMapStage.shuffleDep.shuffleId)
+
mapOutputTracker.unregisterAllMapAndMergeOutput(failedMapStage.shuffleDep.shuffleId)
case failedResultStage: ResultStage =>
// Abort the failed result stage since we may have committed
output for some
@@ -1959,7 +1980,7 @@ private[spark] class DAGScheduler(
case failedMapStage: ShuffleMapStage =>
// Mark all the map as broken in the map stage, to ensure
retry all the tasks on
// resubmitted stage attempt.
-
mapOutputTracker.unregisterAllMapOutput(failedMapStage.shuffleDep.shuffleId)
+
mapOutputTracker.unregisterAllMapAndMergeOutput(failedMapStage.shuffleDep.shuffleId)
case failedResultStage: ResultStage =>
// Abort the failed result stage since we may have committed
output for some
@@ -2000,6 +2021,130 @@ private[spark] class DAGScheduler(
}
}
+ /**
+ * Schedules shuffle merge finalize.
+ */
+ private[scheduler] def scheduleShuffleMergeFinalize(stage: ShuffleMapStage):
Unit = {
+ // TODO: SPARK-33701: Instead of waiting for a constant amount of time for
finalization
+ // TODO: for all the stages, adaptively tune timeout for merge finalization
+ logInfo(("%s (%s) scheduled for finalizing" +
+ " shuffle merge in %s s").format(stage, stage.name,
shuffleMergeFinalizeWaitSec))
+ shuffleMergeFinalizeScheduler.schedule(
+ new Runnable {
+ override def run(): Unit = finalizeShuffleMerge(stage)
+ },
+ shuffleMergeFinalizeWaitSec,
+ TimeUnit.SECONDS
+ )
+ }
+
+ /**
+ * DAGScheduler notifies all the remote shuffle services chosen to serve
shuffle merge request for
+ * the given shuffle map stage to finalize the shuffle merge process for
this shuffle. This is
+ * invoked in a separate thread to reduce the impact on the DAGScheduler
main thread, as the
+ * scheduler might need to talk to 1000s of shuffle services to finalize
shuffle merge.
+ */
+ private[scheduler] def finalizeShuffleMerge(stage: ShuffleMapStage): Unit = {
+ logInfo("%s (%s) finalizing the shuffle merge".format(stage, stage.name))
+ externalShuffleClient.foreach { shuffleClient =>
+ val shuffleId = stage.shuffleDep.shuffleId
+ val numMergers = stage.shuffleDep.getMergerLocs.length
+ val results = (0 until numMergers).map(_ =>
SettableFuture.create[Boolean]())
+
+ stage.shuffleDep.getMergerLocs.zipWithIndex.foreach {
+ case (shuffleServiceLoc, index) =>
+ // Sends async request to shuffle service to finalize shuffle merge
on that host
+ // TODO: SPARK-35536: Cancel finalizeShuffleMerge if the stage is
cancelled
+ // TODO: during shuffleMergeFinalizeWaitSec
+ shuffleClient.finalizeShuffleMerge(shuffleServiceLoc.host,
+ shuffleServiceLoc.port, shuffleId,
+ new MergeFinalizerListener {
+ override def onShuffleMergeSuccess(statuses: MergeStatuses):
Unit = {
+ assert(shuffleId == statuses.shuffleId)
+ eventProcessLoop.post(RegisterMergeStatuses(stage, MergeStatus.
+ convertMergeStatusesToMergeStatusArr(statuses,
shuffleServiceLoc)))
+ results(index).set(true)
+ }
+
+ override def onShuffleMergeFailure(e: Throwable): Unit = {
+ logWarning(s"Exception encountered when trying to finalize
shuffle " +
+ s"merge on ${shuffleServiceLoc.host} for shuffle
$shuffleId", e)
+ // Do not fail the future as this would cause dag scheduler to
prematurely
+ // give up on waiting for merge results from the remaining
shuffle services
+ // if one fails
+ results(index).set(false)
+ }
+ })
+ }
+ // DAGScheduler only waits for a limited amount of time for the merge
results.
+ // It will attempt to submit the next stage(s) irrespective of whether
merge results
+ // from all shuffle services are received or not.
+ try {
+ Futures.allAsList(results: _*).get(shuffleMergeResultsTimeoutSec,
TimeUnit.SECONDS)
+ } catch {
+ case _: TimeoutException =>
+ logInfo(s"Timed out on waiting for merge results from all " +
+ s"$numMergers mergers for shuffle $shuffleId")
+ } finally {
+ eventProcessLoop.post(ShuffleMergeFinalized(stage))
+ }
+ }
+ }
+
+ private def processShuffleMapStageCompletion(shuffleStage: ShuffleMapStage):
Unit = {
+ markStageAsFinished(shuffleStage)
+ logInfo("looking for newly runnable stages")
+ logInfo("running: " + runningStages)
+ logInfo("waiting: " + waitingStages)
+ logInfo("failed: " + failedStages)
+
+ // This call to increment the epoch may not be strictly necessary, but it
is retained
+ // for now in order to minimize the changes in behavior from an earlier
version of the
+ // code. This existing behavior of always incrementing the epoch following
any
+ // successful shuffle map stage completion may have benefits by causing
unneeded
+ // cached map outputs to be cleaned up earlier on executors. In the future
we can
+ // consider removing this call, but this will require some extra
investigation.
+ // See https://github.com/apache/spark/pull/17955/files#r117385673 for
more details.
+ mapOutputTracker.incrementEpoch()
+
+ clearCacheLocs()
+
+ if (!shuffleStage.isAvailable) {
+ // Some tasks had failed; let's resubmit this shuffleStage.
+ // TODO: Lower-level scheduler should also deal with this
+ logInfo("Resubmitting " + shuffleStage + " (" + shuffleStage.name +
+ ") because some of its tasks had failed: " +
+ shuffleStage.findMissingPartitions().mkString(", "))
+ submitStage(shuffleStage)
+ } else {
+ markMapStageJobsAsFinished(shuffleStage)
+ submitWaitingChildStages(shuffleStage)
+ }
+ }
+
+ private[scheduler] def handleRegisterMergeStatuses(
+ stage: ShuffleMapStage,
+ mergeStatuses: Seq[(Int, MergeStatus)]): Unit = {
+ // Register merge statuses if the stage is still running and shuffle merge
is not finalized yet.
+ // TODO: SPARK-35549: Currently merge statuses results which come after
shuffle merge
+ // TODO: is finalized is not registered.
+ if (runningStages.contains(stage) &&
!stage.shuffleDep.shuffleMergeFinalized) {
+ mapOutputTracker.registerMergeResults(stage.shuffleDep.shuffleId,
mergeStatuses)
+ }
+ }
+
+ private[scheduler] def handleShuffleMergeFinalized(stage: ShuffleMapStage):
Unit = {
+ // Only update MapOutputTracker metadata if the stage is still active. i.e
not cancelled.
+ if (runningStages.contains(stage)) {
+ stage.shuffleDep.markShuffleMergeFinalized()
+ processShuffleMapStageCompletion(stage)
+ } else {
+ // Unregister all merge results if the stage is currently not
+ // active (i.e. the stage is cancelled)
+ mapOutputTracker.unregisterAllMergeResult(stage.shuffleDep.shuffleId)
+ }
+ }
+
private def handleResubmittedFailure(task: Task[_], stage: Stage): Unit = {
logInfo(s"Resubmitted $task, so marking it as still running.")
stage match {
@@ -2447,6 +2592,12 @@ private[scheduler] class
DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler
case ResubmitFailedStages =>
dagScheduler.resubmitFailedStages()
+
+ case RegisterMergeStatuses(stage, mergeStatuses) =>
+ dagScheduler.handleRegisterMergeStatuses(stage, mergeStatuses)
+
+ case ShuffleMergeFinalized(stage) =>
+ dagScheduler.handleShuffleMergeFinalized(stage)
}
override def onError(e: Throwable): Unit = {
diff --git
a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
index d226fe8..307844c 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
@@ -105,3 +105,9 @@ private[scheduler]
case class UnschedulableTaskSetRemoved(stageId: Int, stageAttemptId: Int)
extends DAGSchedulerEvent
+private[scheduler] case class RegisterMergeStatuses(
+ stage: ShuffleMapStage, mergeStatuses: Seq[(Int, MergeStatus)])
+ extends DAGSchedulerEvent
+
+private[scheduler] case class ShuffleMergeFinalized(stage: ShuffleMapStage)
+ extends DAGSchedulerEvent
diff --git a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala
b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala
index 556478d..7b681bf 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala
@@ -42,7 +42,7 @@ class StageInfo(
val resourceProfileId: Int) {
/** When this stage was submitted from the DAGScheduler to a TaskScheduler.
*/
var submissionTime: Option[Long] = None
- /** Time when all tasks in the stage completed or when the stage was
cancelled. */
+ /** Time when the stage completed or when the stage was cancelled. */
var completionTime: Option[Long] = None
/** If the stage failed, the reason why. */
var failureReason: Option[String] = None
diff --git
a/core/src/test/resources/META-INF/services/org.apache.spark.scheduler.ExternalClusterManager
b/core/src/test/resources/META-INF/services/org.apache.spark.scheduler.ExternalClusterManager
index 60054c8..33b162e 100644
---
a/core/src/test/resources/META-INF/services/org.apache.spark.scheduler.ExternalClusterManager
+++
b/core/src/test/resources/META-INF/services/org.apache.spark.scheduler.ExternalClusterManager
@@ -1,3 +1,4 @@
org.apache.spark.scheduler.DummyExternalClusterManager
org.apache.spark.scheduler.MockExternalClusterManager
org.apache.spark.scheduler.CSMockExternalClusterManager
+org.apache.spark.scheduler.PushBasedClusterManager
diff --git
a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index 4c74e4f..f6e87ee 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -25,9 +25,8 @@ import scala.annotation.meta.param
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map}
import scala.util.control.NonFatal
-import org.mockito.Mockito.spy
-import org.mockito.Mockito.times
-import org.mockito.Mockito.verify
+import org.mockito.Mockito._
+import org.roaringbitmap.RoaringBitmap
import org.scalatest.concurrent.{Signaler, ThreadSignaler, TimeLimits}
import org.scalatest.exceptions.TestFailedException
import org.scalatest.time.SpanSugar._
@@ -40,9 +39,10 @@ import org.apache.spark.rdd.{DeterministicLevel, RDD}
import org.apache.spark.resource.{ExecutorResourceRequests, ResourceProfile,
ResourceProfileBuilder, TaskResourceRequests}
import org.apache.spark.resource.ResourceUtils.{FPGA, GPU}
import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
+import org.apache.spark.scheduler.local.LocalSchedulerBackend
import org.apache.spark.shuffle.{FetchFailedException,
MetadataFetchFailedException}
import org.apache.spark.storage.{BlockId, BlockManagerId, BlockManagerMaster}
-import org.apache.spark.util.{AccumulatorContext, AccumulatorV2, CallSite,
LongAccumulator, Utils}
+import org.apache.spark.util.{AccumulatorContext, AccumulatorV2, CallSite,
Clock, LongAccumulator, SystemClock, Utils}
class DAGSchedulerEventProcessLoopTester(dagScheduler: DAGScheduler)
extends DAGSchedulerEventProcessLoop(dagScheduler) {
@@ -295,6 +295,35 @@ class DAGSchedulerSuite extends SparkFunSuite with
TempLocalSparkContext with Ti
}
}
+ class MyDAGScheduler(
+ sc: SparkContext,
+ taskScheduler: TaskScheduler,
+ listenerBus: LiveListenerBus,
+ mapOutputTracker: MapOutputTrackerMaster,
+ blockManagerMaster: BlockManagerMaster,
+ env: SparkEnv,
+ clock: Clock = new SystemClock(),
+ shuffleMergeFinalize: Boolean = true,
+ shuffleMergeRegister: Boolean = true
+ ) extends DAGScheduler(
+ sc, taskScheduler, listenerBus, mapOutputTracker, blockManagerMaster,
env, clock) {
+ /**
+ * Schedules shuffle merge finalize.
+ */
+ override private[scheduler] def scheduleShuffleMergeFinalize(
+ shuffleMapStage: ShuffleMapStage): Unit = {
+ if (shuffleMergeRegister) {
+ for (part <- 0 until
shuffleMapStage.shuffleDep.partitioner.numPartitions) {
+ val mergeStatuses = Seq((part, makeMergeStatus("")))
+ handleRegisterMergeStatuses(shuffleMapStage, mergeStatuses)
+ }
+ if (shuffleMergeFinalize) {
+ handleShuffleMergeFinalized(shuffleMapStage)
+ }
+ }
+ }
+ }
+
override def beforeEach(): Unit = {
super.beforeEach()
firstInit = true
@@ -322,13 +351,14 @@ class DAGSchedulerSuite extends SparkFunSuite with
TempLocalSparkContext with Ti
broadcastManager = new BroadcastManager(true, sc.getConf)
mapOutputTracker = spy(new MyMapOutputTrackerMaster(sc.getConf,
broadcastManager))
blockManagerMaster = spy(new MyBlockManagerMaster(sc.getConf))
- scheduler = new DAGScheduler(
+ scheduler = new MyDAGScheduler(
sc,
taskScheduler,
sc.listenerBus,
mapOutputTracker,
blockManagerMaster,
sc.env)
+
dagEventProcessLoopTester = new
DAGSchedulerEventProcessLoopTester(scheduler)
}
@@ -3393,6 +3423,359 @@ class DAGSchedulerSuite extends SparkFunSuite with
TempLocalSparkContext with Ti
assert(rprofsE === Set())
}
+ private def initPushBasedShuffleConfs(conf: SparkConf) = {
+ conf.set(config.SHUFFLE_SERVICE_ENABLED, true)
+ conf.set(config.PUSH_BASED_SHUFFLE_ENABLED, true)
+ conf.set("spark.master", "pushbasedshuffleclustermanager")
+ }
+
+ test("SPARK-32920: shuffle merge finalization") {
+ initPushBasedShuffleConfs(conf)
+ DAGSchedulerSuite.clearMergerLocs
+ DAGSchedulerSuite.addMergerLocs(Seq("host1", "host2", "host3", "host4",
"host5"))
+ val parts = 2
+ val shuffleMapRdd = new MyRDD(sc, parts, Nil)
+ val shuffleDep = new ShuffleDependency(shuffleMapRdd, new
HashPartitioner(parts))
+ val reduceRdd = new MyRDD(sc, parts, List(shuffleDep), tracker =
mapOutputTracker)
+
+    // Submit a reduce job that depends on the shuffle, which will create a map stage
+ submit(reduceRdd, (0 until parts).toArray)
+ completeShuffleMapStageSuccessfully(0, 0, parts)
+ assert(mapOutputTracker.getNumAvailableMergeResults(shuffleDep.shuffleId)
== parts)
+ completeNextResultStageWithSuccess(1, 0)
+ assert(results === Map(0 -> 42, 1 -> 42))
+ results.clear()
+ assertDataStructuresEmpty()
+ }
+
+ test("SPARK-32920: merger locations not empty") {
+ initPushBasedShuffleConfs(conf)
+ conf.set(config.SHUFFLE_MERGER_LOCATIONS_MIN_STATIC_THRESHOLD, 3)
+ DAGSchedulerSuite.clearMergerLocs
+ DAGSchedulerSuite.addMergerLocs(Seq("host1", "host2", "host3", "host4",
"host5"))
+ val parts = 2
+
+ val shuffleMapRdd = new MyRDD(sc, parts, Nil)
+ val shuffleDep = new ShuffleDependency(shuffleMapRdd, new
HashPartitioner(parts))
+ val reduceRdd = new MyRDD(sc, parts, List(shuffleDep), tracker =
mapOutputTracker)
+
+    // Submit a reduce job that depends on the shuffle, which will create a map stage
+ submit(reduceRdd, (0 until parts).toArray)
+ completeShuffleMapStageSuccessfully(0, 0, parts)
+ val shuffleStage =
scheduler.stageIdToStage(0).asInstanceOf[ShuffleMapStage]
+ assert(shuffleStage.shuffleDep.getMergerLocs.nonEmpty)
+
+ assert(mapOutputTracker.getNumAvailableMergeResults(shuffleDep.shuffleId)
== parts)
+ completeNextResultStageWithSuccess(1, 0)
+ assert(results === Map(0 -> 42, 1 -> 42))
+
+ results.clear()
+ assertDataStructuresEmpty()
+ }
+
+ test("SPARK-32920: merger locations reuse from shuffle dependency") {
+ initPushBasedShuffleConfs(conf)
+ conf.set(config.SHUFFLE_MERGER_MAX_RETAINED_LOCATIONS, 3)
+ DAGSchedulerSuite.clearMergerLocs
+ DAGSchedulerSuite.addMergerLocs(Seq("host1", "host2", "host3", "host4",
"host5"))
+ val parts = 2
+
+ val shuffleMapRdd = new MyRDD(sc, parts, Nil)
+ val shuffleDep = new ShuffleDependency(shuffleMapRdd, new
HashPartitioner(parts))
+ val reduceRdd = new MyRDD(sc, parts, List(shuffleDep), tracker =
mapOutputTracker)
+ submit(reduceRdd, Array(0, 1))
+
+ completeShuffleMapStageSuccessfully(0, 0, parts)
+ assert(shuffleDep.getMergerLocs.nonEmpty)
+ val mergerLocs = shuffleDep.getMergerLocs
+    completeNextResultStageWithSuccess(1, 0)
+
+ // submit another job w/ the shared dependency, and have a fetch failure
+ val reduce2 = new MyRDD(sc, 2, List(shuffleDep))
+ submit(reduce2, Array(0, 1))
+ // Note that the stage numbering here is only b/c the shared dependency
produces a new, skipped
+ // stage. If instead it reused the existing stage, then this would be
stage 2
+ completeNextStageWithFetchFailure(3, 0, shuffleDep)
+ scheduler.resubmitFailedStages()
+
+ assert(scheduler.runningStages.nonEmpty)
+ assert(scheduler.stageIdToStage(2)
+ .asInstanceOf[ShuffleMapStage].shuffleDep.getMergerLocs.nonEmpty)
+ val newMergerLocs = scheduler.stageIdToStage(2)
+ .asInstanceOf[ShuffleMapStage].shuffleDep.getMergerLocs
+
+ // Check if same merger locs is reused for the new stage with shared
shuffle dependency
+ assert(mergerLocs.zip(newMergerLocs).forall(x => x._1.host == x._2.host))
+ completeShuffleMapStageSuccessfully(2, 0, 2)
+ completeNextResultStageWithSuccess(3, 1, idx => idx + 1234)
+ assert(results === Map(0 -> 1234, 1 -> 1235))
+
+ assertDataStructuresEmpty()
+ }
+
+ test("SPARK-32920: Disable shuffle merge due to not enough mergers
available") {
+ initPushBasedShuffleConfs(conf)
+ conf.set(config.SHUFFLE_MERGER_LOCATIONS_MIN_STATIC_THRESHOLD, 6)
+ DAGSchedulerSuite.clearMergerLocs
+ DAGSchedulerSuite.addMergerLocs(Seq("host1", "host2", "host3", "host4",
"host5"))
+ val parts = 7
+
+ val shuffleMapRdd = new MyRDD(sc, parts, Nil)
+ val shuffleDep = new ShuffleDependency(shuffleMapRdd, new
HashPartitioner(parts))
+ val reduceRdd = new MyRDD(sc, parts, List(shuffleDep), tracker =
mapOutputTracker)
+
+    // Submit a reduce job that depends on the shuffle, which will create a map stage
+    submit(reduceRdd, (0 until parts).toArray)
+ completeShuffleMapStageSuccessfully(0, 0, parts)
+ val shuffleStage =
scheduler.stageIdToStage(0).asInstanceOf[ShuffleMapStage]
+ assert(!shuffleStage.shuffleDep.shuffleMergeEnabled)
+
+ completeNextResultStageWithSuccess(1, 0)
+ assert(results === Map(2 -> 42, 5 -> 42, 4 -> 42, 1 -> 42, 3 -> 42, 6 ->
42, 0 -> 42))
+
+ results.clear()
+ assertDataStructuresEmpty()
+ }
+
+ test("SPARK-32920: Ensure child stage should not start before all the" +
+ " parent stages are completed with shuffle merge finalized for all the
parent stages") {
+ initPushBasedShuffleConfs(conf)
+ DAGSchedulerSuite.clearMergerLocs
+ DAGSchedulerSuite.addMergerLocs(Seq("host1", "host2", "host3", "host4",
"host5"))
+ val parts = 1
+ val shuffleMapRdd1 = new MyRDD(sc, parts, Nil)
+ val shuffleDep1 = new ShuffleDependency(shuffleMapRdd1, new
HashPartitioner(parts))
+
+ val shuffleMapRdd2 = new MyRDD(sc, parts, Nil)
+ val shuffleDep2 = new ShuffleDependency(shuffleMapRdd2, new
HashPartitioner(parts))
+
+ val reduceRdd = new MyRDD(sc, parts, List(shuffleDep1, shuffleDep2),
tracker = mapOutputTracker)
+
+ // Submit a reduce job
+ submit(reduceRdd, (0 until parts).toArray)
+
+ complete(taskSets(0), Seq((Success, makeMapStatus("hostA", 1))))
+ val shuffleStage1 =
scheduler.stageIdToStage(0).asInstanceOf[ShuffleMapStage]
+ assert(shuffleStage1.shuffleDep.getMergerLocs.nonEmpty)
+
+ complete(taskSets(1), Seq((Success, makeMapStatus("hostA", 1))))
+ val shuffleStage2 =
scheduler.stageIdToStage(1).asInstanceOf[ShuffleMapStage]
+ assert(shuffleStage2.shuffleDep.getMergerLocs.nonEmpty)
+
+ assert(shuffleStage2.shuffleDep.shuffleMergeFinalized)
+ assert(shuffleStage1.shuffleDep.shuffleMergeFinalized)
+ assert(mapOutputTracker.getNumAvailableMergeResults(shuffleDep1.shuffleId)
== parts)
+ assert(mapOutputTracker.getNumAvailableMergeResults(shuffleDep2.shuffleId)
== parts)
+
+ completeNextResultStageWithSuccess(2, 0)
+ assert(results === Map(0 -> 42))
+ results.clear()
+ assertDataStructuresEmpty()
+ }
+
+ test("SPARK-32920: Reused ShuffleDependency with Shuffle Merge disabled for
the corresponding" +
+ " ShuffleDependency should not cause DAGScheduler to hang") {
+ initPushBasedShuffleConfs(conf)
+ conf.set(config.SHUFFLE_MERGER_LOCATIONS_MIN_STATIC_THRESHOLD, 10)
+ DAGSchedulerSuite.clearMergerLocs
+ DAGSchedulerSuite.addMergerLocs(Seq("host1", "host2", "host3", "host4",
"host5"))
+ val parts = 20
+
+ val shuffleMapRdd = new MyRDD(sc, parts, Nil)
+ val shuffleDep = new ShuffleDependency(shuffleMapRdd, new
HashPartitioner(parts))
+ val reduceRdd = new MyRDD(sc, parts, List(shuffleDep), tracker =
mapOutputTracker)
+ val partitions = (0 until parts).toArray
+ submit(reduceRdd, partitions)
+
+ completeShuffleMapStageSuccessfully(0, 0, parts)
+ val shuffleStage =
scheduler.stageIdToStage(0).asInstanceOf[ShuffleMapStage]
+ assert(!shuffleStage.shuffleDep.shuffleMergeEnabled)
+
+ completeNextResultStageWithSuccess(1, 0)
+ val reduce2 = new MyRDD(sc, parts, List(shuffleDep))
+ submit(reduce2, partitions)
+ // Stage 2 should not be executed as it should reuse the already computed
shuffle output
+ assert(scheduler.stageIdToStage(2).latestInfo.taskMetrics == null)
+ completeNextResultStageWithSuccess(3, 0, idx => idx + 1234)
+
+ val expected = (0 until parts).map(idx => (idx, idx + 1234))
+ assert(results === expected.toMap)
+
+ assertDataStructuresEmpty()
+ }
+
+ test("SPARK-32920: Reused ShuffleDependency with Shuffle Merge disabled for
the corresponding" +
+ " ShuffleDependency with shuffle data loss should recompute missing
partitions") {
+ initPushBasedShuffleConfs(conf)
+ conf.set(config.SHUFFLE_MERGER_LOCATIONS_MIN_STATIC_THRESHOLD, 10)
+ DAGSchedulerSuite.clearMergerLocs
+ DAGSchedulerSuite.addMergerLocs(Seq("host1", "host2", "host3", "host4",
"host5"))
+ val parts = 20
+
+ val shuffleMapRdd = new MyRDD(sc, parts, Nil)
+ val shuffleDep = new ShuffleDependency(shuffleMapRdd, new
HashPartitioner(parts))
+ val reduceRdd = new MyRDD(sc, parts, List(shuffleDep), tracker =
mapOutputTracker)
+ val partitions = (0 until parts).toArray
+ submit(reduceRdd, partitions)
+
+ completeShuffleMapStageSuccessfully(0, 0, parts)
+ val shuffleStage =
scheduler.stageIdToStage(0).asInstanceOf[ShuffleMapStage]
+ assert(!shuffleStage.shuffleDep.shuffleMergeEnabled)
+
+ completeNextResultStageWithSuccess(1, 0)
+
+ DAGSchedulerSuite.clearMergerLocs
+    val hosts = (6 to parts).map { x => s"Host$x" }
+ DAGSchedulerSuite.addMergerLocs(hosts)
+
+ val reduce2 = new MyRDD(sc, parts, List(shuffleDep))
+ submit(reduce2, partitions)
+ // Note that the stage numbering here is only b/c the shared dependency
produces a new, skipped
+ // stage. If instead it reused the existing stage, then this would be
stage 2
+ completeNextStageWithFetchFailure(3, 0, shuffleDep)
+ scheduler.resubmitFailedStages()
+
+ // Make sure shuffle merge is disabled for the retry
+ val stage2 = scheduler.stageIdToStage(2).asInstanceOf[ShuffleMapStage]
+ assert(!stage2.shuffleDep.shuffleMergeEnabled)
+
+ // the scheduler now creates a new task set to regenerate the missing map
output, but this time
+ // using a different stage, the "skipped" one
+ assert(scheduler.stageIdToStage(2).latestInfo.taskMetrics != null)
+ completeShuffleMapStageSuccessfully(2, 0, 2)
+ completeNextResultStageWithSuccess(3, 1, idx => idx + 1234)
+
+ val expected = (0 until parts).map(idx => (idx, idx + 1234))
+ assert(results === expected.toMap)
+ assertDataStructuresEmpty()
+ }
+
+ test("SPARK-32920: Empty RDD should not be computed") {
+ initPushBasedShuffleConfs(conf)
+ val data = sc.emptyRDD[Int]
+ data.sortBy(x => x).collect()
+ assert(taskSets.isEmpty)
+ assertDataStructuresEmpty()
+ }
+
+ test("SPARK-32920: Merge results should be unregistered if the running stage
is cancelled" +
+ " before shuffle merge is finalized") {
+ initPushBasedShuffleConfs(conf)
+ DAGSchedulerSuite.clearMergerLocs
+ DAGSchedulerSuite.addMergerLocs(Seq("host1", "host2", "host3", "host4",
"host5"))
+ scheduler = new MyDAGScheduler(
+ sc,
+ taskScheduler,
+ sc.listenerBus,
+ mapOutputTracker,
+ blockManagerMaster,
+ sc.env,
+ shuffleMergeFinalize = false)
+ dagEventProcessLoopTester = new
DAGSchedulerEventProcessLoopTester(scheduler)
+
+ val parts = 2
+ val shuffleMapRdd = new MyRDD(sc, parts, Nil)
+ val shuffleDep = new ShuffleDependency(shuffleMapRdd, new
HashPartitioner(parts))
+ val reduceRdd = new MyRDD(sc, parts, List(shuffleDep), tracker =
mapOutputTracker)
+
+    // Submit a reduce job that depends on the shuffle, which will create a map stage
+ submit(reduceRdd, (0 until parts).toArray)
+ // Complete shuffle map stage successfully on hostA
+ complete(taskSets(0), taskSets(0).tasks.zipWithIndex.map {
+ case (task, _) =>
+ (Success, makeMapStatus("hostA", parts))
+ }.toSeq)
+
+ assert(mapOutputTracker.getNumAvailableMergeResults(shuffleDep.shuffleId)
== parts)
+ val shuffleMapStageToCancel =
scheduler.stageIdToStage(0).asInstanceOf[ShuffleMapStage]
+ runEvent(StageCancelled(0, Option("Explicit cancel check")))
+ scheduler.handleShuffleMergeFinalized(shuffleMapStageToCancel)
+ assert(mapOutputTracker.getNumAvailableMergeResults(shuffleDep.shuffleId)
== 0)
+ }
+
+ test("SPARK-32920: SPARK-35549: Merge results should not get registered" +
+ " after shuffle merge finalization") {
+ initPushBasedShuffleConfs(conf)
+ DAGSchedulerSuite.clearMergerLocs
+ DAGSchedulerSuite.addMergerLocs(Seq("host1", "host2", "host3", "host4",
"host5"))
+
+ scheduler = new MyDAGScheduler(
+ sc,
+ taskScheduler,
+ sc.listenerBus,
+ mapOutputTracker,
+ blockManagerMaster,
+ sc.env,
+ shuffleMergeFinalize = false,
+ shuffleMergeRegister = false)
+ dagEventProcessLoopTester = new
DAGSchedulerEventProcessLoopTester(scheduler)
+
+ val parts = 2
+ val shuffleMapRdd = new MyRDD(sc, parts, Nil)
+ val shuffleDep = new ShuffleDependency(shuffleMapRdd, new
HashPartitioner(parts))
+ val reduceRdd = new MyRDD(sc, parts, List(shuffleDep), tracker =
mapOutputTracker)
+
+    // Submit a reduce job that depends on the shuffle, which will create a map stage
+ submit(reduceRdd, (0 until parts).toArray)
+ // Complete shuffle map stage successfully on hostA
+ complete(taskSets(0), taskSets(0).tasks.zipWithIndex.map {
+ case (task, _) =>
+ (Success, makeMapStatus("hostA", parts))
+ }.toSeq)
+ val shuffleMapStage =
scheduler.stageIdToStage(0).asInstanceOf[ShuffleMapStage]
+ scheduler.handleRegisterMergeStatuses(shuffleMapStage, Seq((0,
makeMergeStatus("hostA"))))
+ scheduler.handleShuffleMergeFinalized(shuffleMapStage)
+ scheduler.handleRegisterMergeStatuses(shuffleMapStage, Seq((1,
makeMergeStatus("hostA"))))
+ assert(mapOutputTracker.getNumAvailableMergeResults(shuffleDep.shuffleId)
== 1)
+ }
+
+ test("SPARK-32920: Disable push based shuffle in the case of a barrier
stage") {
+ initPushBasedShuffleConfs(conf)
+ DAGSchedulerSuite.clearMergerLocs
+ DAGSchedulerSuite.addMergerLocs(Seq("host1", "host2", "host3", "host4",
"host5"))
+
+ val parts = 2
+ val shuffleMapRdd = new MyRDD(sc, parts, Nil).barrier().mapPartitions(iter
=> iter)
+ val shuffleDep = new ShuffleDependency(shuffleMapRdd, new
HashPartitioner(parts))
+ val reduceRdd = new MyRDD(sc, parts, List(shuffleDep), tracker =
mapOutputTracker)
+
+    // Submit a reduce job that depends on the shuffle, which will create a map stage
+    submit(reduceRdd, (0 until parts).toArray)
+ completeShuffleMapStageSuccessfully(0, 0, reduceRdd.partitions.length)
+ val shuffleMapStage =
scheduler.stageIdToStage(0).asInstanceOf[ShuffleMapStage]
+ assert(!shuffleMapStage.shuffleDep.shuffleMergeEnabled)
+ }
+
+ test("SPARK-32920: metadata fetch failure should not unregister map status")
{
+ initPushBasedShuffleConfs(conf)
+ val parts = 2
+ val shuffleMapRdd = new MyRDD(sc, parts, Nil)
+ val shuffleDep = new ShuffleDependency(shuffleMapRdd, new
HashPartitioner(parts))
+
+ val reduceRdd = new MyRDD(sc, parts, List(shuffleDep), tracker =
mapOutputTracker)
+ submit(reduceRdd, (0 until parts).toArray)
+ assert(taskSets.length == 1)
+
+ // Complete shuffle map stage successfully on hostA
+ complete(taskSets(0), taskSets(0).tasks.zipWithIndex.map {
+ case (task, _) =>
+ (Success, makeMapStatus("hostA", parts))
+ }.toSeq)
+
+ assert(mapOutputTracker.getNumAvailableOutputs(shuffleDep.shuffleId) ==
parts)
+
+ // Finish the first task
+ runEvent(makeCompletionEvent(
+ taskSets(1).tasks(0), Success, makeMapStatus("hostA", parts)))
+
+ // The second task fails with Metadata Failed exception.
+ val metadataFetchFailedEx = new MetadataFetchFailedException(
+ shuffleDep.shuffleId, 1, "metadata failure");
+ runEvent(makeCompletionEvent(
+ taskSets(1).tasks(1), metadataFetchFailedEx.toTaskFailedReason, null))
+ assert(mapOutputTracker.getNumAvailableOutputs(shuffleDep.shuffleId) ==
parts)
+ }
+
/**
* Assert that the supplied TaskSet has exactly the given hosts as its
preferred locations.
* Note that this checks only the host and not the executor ID.
@@ -3448,14 +3831,69 @@ class DAGSchedulerSuite extends SparkFunSuite with
TempLocalSparkContext with Ti
}
object DAGSchedulerSuite {
+ val mergerLocs = ArrayBuffer[BlockManagerId]()
+
def makeMapStatus(host: String, reduces: Int, sizes: Byte = 2, mapTaskId:
Long = -1): MapStatus =
MapStatus(makeBlockManagerId(host), Array.fill[Long](reduces)(sizes),
mapTaskId)
def makeBlockManagerId(host: String): BlockManagerId = {
BlockManagerId(host + "-exec", host, 12345)
}
+
+ def makeMergeStatus(host: String, size: Long = 1000): MergeStatus =
+ MergeStatus(makeBlockManagerId(host), mock(classOf[RoaringBitmap]), size)
+
+ def addMergerLocs(locs: Seq[String]): Unit = {
+ locs.foreach { loc => mergerLocs.append(makeBlockManagerId(loc)) }
+ }
+
+ def clearMergerLocs: Unit = mergerLocs.clear()
+
}
object FailThisAttempt {
val _fail = new AtomicBoolean(true)
}
+
+private class PushBasedSchedulerBackend(
+ conf: SparkConf,
+ scheduler: TaskSchedulerImpl,
+ cores: Int) extends LocalSchedulerBackend(conf, scheduler, cores) {
+
+ override def getShufflePushMergerLocations(
+ numPartitions: Int,
+ resourceProfileId: Int): Seq[BlockManagerId] = {
+ val mergerLocations =
Utils.randomize(DAGSchedulerSuite.mergerLocs).take(numPartitions)
+ if (mergerLocations.size < numPartitions && mergerLocations.size <
+ conf.getInt(config.SHUFFLE_MERGER_LOCATIONS_MIN_STATIC_THRESHOLD.key,
5)) {
+ Seq.empty[BlockManagerId]
+ } else {
+ mergerLocations
+ }
+ }
+
+ // Currently this is only used in tests specifically for Push based shuffle
+ override def maxNumConcurrentTasks(rp: ResourceProfile): Int = {
+ 2
+ }
+}
+
+private class PushBasedClusterManager extends ExternalClusterManager {
+ def canCreate(masterURL: String): Boolean = masterURL ==
"pushbasedshuffleclustermanager"
+
+ override def createSchedulerBackend(
+ sc: SparkContext,
+ masterURL: String,
+ scheduler: TaskScheduler): SchedulerBackend = {
+ new PushBasedSchedulerBackend(sc.conf,
scheduler.asInstanceOf[TaskSchedulerImpl], 1)
+ }
+
+ override def createTaskScheduler(
+ sc: SparkContext,
+ masterURL: String): TaskScheduler = new TaskSchedulerImpl(sc, 1, isLocal
= true)
+
+ override def initialize(scheduler: TaskScheduler, backend:
SchedulerBackend): Unit = {
+ val sc = scheduler.asInstanceOf[TaskSchedulerImpl]
+ sc.initialize(backend)
+ }
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]