This is an automated email from the ASF dual-hosted git repository.
mridulm80 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new f6128a6 [SPARK-33701][SHUFFLE] Adaptive shuffle merge finalization
for push-based shuffle
f6128a6 is described below
commit f6128a6f4215dc45a19209d799dd9bf98fab6d8a
Author: Venkata krishnan Sowrirajan <[email protected]>
AuthorDate: Wed Jan 5 01:47:01 2022 -0600
[SPARK-33701][SHUFFLE] Adaptive shuffle merge finalization for push-based
shuffle
### What changes were proposed in this pull request?
SPARK-32920 implemented a simple approach to finalization for push-based
shuffle. Shuffle merge finalization is the final operation, which happens
at the end of the stage once all the tasks have completed: it asks all the
external shuffle services to complete the shuffle merge for the stage. Once
this request is completed, no more shuffle pushes will be accepted. With this
approach, `DAGScheduler` waits for a fixed time of 10s
(`spark.shuffle.push.finalize.timeout`) to allow some time [...]
In this PR, instead of waiting for a fixed amount of time before shuffle
merge finalization, the wait is controlled adaptively: once a minimum
threshold ratio of map tasks have completed their shuffle push
(`spark.shuffle.push.minCompletedPushRatio`), shuffle merge finalization is
scheduled. Additionally, if the total shuffle size generated is less than
the minimum threshold shuffle size
(`spark.shuffle.push.minShuffleSizeToWait`), shuffle merge finalization is
scheduled immediately.
### Why are the changes needed?
This is a performance improvement to the existing functionality
### Does this PR introduce _any_ user-facing change?
Yes, it adds the user-facing configs
`spark.shuffle.push.minCompletedPushRatio` and
`spark.shuffle.push.minShuffleSizeToWait`.
### How was this patch tested?
Added unit tests in `DAGSchedulerSuite`, `ShuffleBlockPusherSuite`
Lead-authored-by: Min Shen <mshenlinkedin.com>
Co-authored-by: Venkata krishnan Sowrirajan <vsowrirajanlinkedin.com>
Closes #33896 from venkata91/SPARK-33701.
Lead-authored-by: Venkata krishnan Sowrirajan <[email protected]>
Co-authored-by: Min Shen <[email protected]>
Signed-off-by: Mridul Muralidharan <mridul<at>gmail.com>
---
.../main/scala/org/apache/spark/Dependency.scala | 35 ++-
.../scala/org/apache/spark/MapOutputTracker.scala | 6 +-
.../src/main/scala/org/apache/spark/SparkEnv.scala | 3 +
.../executor/CoarseGrainedExecutorBackend.scala | 6 +
.../org/apache/spark/internal/config/package.scala | 27 ++
.../org/apache/spark/scheduler/DAGScheduler.scala | 278 +++++++++++++----
.../apache/spark/scheduler/DAGSchedulerEvent.scala | 4 +
.../cluster/CoarseGrainedClusterMessage.scala | 3 +
.../cluster/CoarseGrainedSchedulerBackend.scala | 3 +
.../apache/spark/shuffle/ShuffleBlockPusher.scala | 39 ++-
.../apache/spark/scheduler/DAGSchedulerSuite.scala | 340 +++++++++++++++++++--
.../spark/shuffle/ShuffleBlockPusherSuite.scala | 101 +++++-
docs/configuration.md | 16 +
13 files changed, 772 insertions(+), 89 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala
b/core/src/main/scala/org/apache/spark/Dependency.scala
index 1b4e7ba..8e348ee 100644
--- a/core/src/main/scala/org/apache/spark/Dependency.scala
+++ b/core/src/main/scala/org/apache/spark/Dependency.scala
@@ -17,8 +17,12 @@
package org.apache.spark
+import java.util.concurrent.ScheduledFuture
+
import scala.reflect.ClassTag
+import org.roaringbitmap.RoaringBitmap
+
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
@@ -131,9 +135,7 @@ class ShuffleDependency[K: ClassTag, V: ClassTag, C:
ClassTag](
def shuffleMergeId: Int = _shuffleMergeId
def setMergerLocs(mergerLocs: Seq[BlockManagerId]): Unit = {
- if (mergerLocs != null) {
- this.mergerLocs = mergerLocs
- }
+ this.mergerLocs = mergerLocs
}
def getMergerLocs: Seq[BlockManagerId] = mergerLocs
@@ -160,6 +162,8 @@ class ShuffleDependency[K: ClassTag, V: ClassTag, C:
ClassTag](
_shuffleMergedFinalized = false
mergerLocs = Nil
_shuffleMergeId += 1
+ finalizeTask = None
+ shufflePushCompleted.clear()
}
private def canShuffleMergeBeEnabled(): Boolean = {
@@ -169,11 +173,34 @@ class ShuffleDependency[K: ClassTag, V: ClassTag, C:
ClassTag](
if (isPushShuffleEnabled && rdd.isBarrier()) {
logWarning("Push-based shuffle is currently not supported for barrier
stages")
}
- isPushShuffleEnabled &&
+ isPushShuffleEnabled && numPartitions > 0 &&
// TODO: SPARK-35547: Push based shuffle is currently unsupported for
Barrier stages
!rdd.isBarrier()
}
+ @transient private[this] val shufflePushCompleted = new RoaringBitmap()
+
+ /**
+ * Mark a given map task as push completed in the tracking bitmap.
+ * Using the bitmap ensures that the same map task launched multiple times
due to
+ * either speculation or stage retry is only counted once.
+ * @param mapIndex Map task index
+ * @return number of map tasks with block push completed
+ */
+ def incPushCompleted(mapIndex: Int): Int = {
+ shufflePushCompleted.add(mapIndex)
+ shufflePushCompleted.getCardinality
+ }
+
+ // Only used by DAGScheduler to coordinate shuffle merge finalization
+ @transient private[this] var finalizeTask: Option[ScheduledFuture[_]] = None
+
+ def getFinalizeTask: Option[ScheduledFuture[_]] = finalizeTask
+
+ def setFinalizeTask(task: ScheduledFuture[_]): Unit = {
+ finalizeTask = Option(task)
+ }
+
_rdd.sparkContext.cleaner.foreach(_.registerShuffleForCleanup(this))
_rdd.sparkContext.shuffleDriverComponents.registerShuffle(shuffleId)
}
diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index af26abc..d71fb09 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -917,7 +917,7 @@ private[spark] class MapOutputTrackerMaster(
Runtime.getRuntime.availableProcessors(),
statuses.length.toLong * totalSizes.length / parallelAggThreshold +
1).toInt
if (parallelism <= 1) {
- for (s <- statuses) {
+ statuses.filter(_ != null).foreach { s =>
for (i <- 0 until totalSizes.length) {
totalSizes(i) += s.getSizeForBlock(i)
}
@@ -928,8 +928,8 @@ private[spark] class MapOutputTrackerMaster(
implicit val executionContext =
ExecutionContext.fromExecutor(threadPool)
val mapStatusSubmitTasks = equallyDivide(totalSizes.length,
parallelism).map {
reduceIds => Future {
- for (s <- statuses; i <- reduceIds) {
- totalSizes(i) += s.getSizeForBlock(i)
+ statuses.filter(_ != null).foreach { s =>
+ reduceIds.foreach(i => totalSizes(i) += s.getSizeForBlock(i))
}
}
}
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala
b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index 0388c7b..d07614a 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -32,6 +32,7 @@ import org.apache.hadoop.conf.Configuration
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.api.python.PythonWorkerFactory
import org.apache.spark.broadcast.BroadcastManager
+import org.apache.spark.executor.ExecutorBackend
import org.apache.spark.internal.{config, Logging}
import org.apache.spark.internal.config._
import org.apache.spark.memory.{MemoryManager, UnifiedMemoryManager}
@@ -81,6 +82,8 @@ class SparkEnv (
private[spark] var driverTmpDir: Option[String] = None
+ private[spark] var executorBackend: Option[ExecutorBackend] = None
+
private[spark] def stop(): Unit = {
if (!isStopped) {
diff --git
a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
index 43887a7..fb7b4e6 100644
---
a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
+++
b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
@@ -106,6 +106,7 @@ private[spark] class CoarseGrainedExecutorBackend(
rpcEnv.asyncSetupEndpointRefByURI(driverUrl).flatMap { ref =>
// This is a very fast action so we can use "ThreadUtils.sameThread"
driver = Some(ref)
+ env.executorBackend = Option(this)
ref.ask[Boolean](RegisterExecutor(executorId, self, hostname, cores,
extractLogUrls,
extractAttributes, _resources, resourceProfile.id))
}(ThreadUtils.sameThread).onComplete {
@@ -162,6 +163,11 @@ private[spark] class CoarseGrainedExecutorBackend(
.map(e => (e._1.substring(prefix.length).toUpperCase(Locale.ROOT),
e._2)).toMap
}
+ def notifyDriverAboutPushCompletion(shuffleId: Int, shuffleMergeId: Int,
mapIndex: Int): Unit = {
+ val msg = ShufflePushCompletion(shuffleId, shuffleMergeId, mapIndex)
+ driver.foreach(_.send(msg))
+ }
+
override def receive: PartialFunction[Any, Unit] = {
case RegisteredExecutor =>
logInfo("Successfully registered with driver")
diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala
b/core/src/main/scala/org/apache/spark/internal/config/package.scala
index 71a11f6..a942ba5 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/package.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -2193,6 +2193,33 @@ package object config {
// with small MB sized chunk of data.
.createWithDefaultString("3m")
+ private[spark] val PUSH_BASED_SHUFFLE_MERGE_FINALIZE_THREADS =
+ ConfigBuilder("spark.shuffle.push.merge.finalizeThreads")
+ .doc("Number of threads used by driver to finalize shuffle merge. Since
it could" +
+ " potentially take seconds for a large shuffle to finalize, having
multiple threads helps" +
+ " driver to handle concurrent shuffle merge finalize requests when
push-based" +
+ " shuffle is enabled.")
+ .version("3.3.0")
+ .intConf
+ .createWithDefault(3)
+
+ private[spark] val PUSH_BASED_SHUFFLE_SIZE_MIN_SHUFFLE_SIZE_TO_WAIT =
+ ConfigBuilder("spark.shuffle.push.minShuffleSizeToWait")
+ .doc("Driver will wait for merge finalization to complete only if total
shuffle size is" +
+ " more than this threshold. If total shuffle size is less, driver will
immediately" +
+ " finalize the shuffle output")
+ .version("3.3.0")
+ .bytesConf(ByteUnit.BYTE)
+ .createWithDefaultString("500m")
+
+ private[spark] val PUSH_BASED_SHUFFLE_MIN_PUSH_RATIO =
+ ConfigBuilder("spark.shuffle.push.minCompletedPushRatio")
+ .doc("Fraction of map partitions that should be push complete before
driver starts" +
+ " shuffle merge finalization during push based shuffle")
+ .version("3.3.0")
+ .doubleConf
+ .createWithDefault(1.0)
+
private[spark] val JAR_IVY_REPO_PATH =
ConfigBuilder("spark.jars.ivy")
.doc("Path to specify the Ivy user directory, used for the local Ivy
cache and " +
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 4ed734c..eed71038 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -19,7 +19,7 @@ package org.apache.spark.scheduler
import java.io.NotSerializableException
import java.util.Properties
-import java.util.concurrent.{ConcurrentHashMap, TimeoutException, TimeUnit}
+import java.util.concurrent.{ConcurrentHashMap, ScheduledFuture,
TimeoutException, TimeUnit }
import java.util.concurrent.atomic.AtomicInteger
import scala.annotation.tailrec
@@ -265,6 +265,14 @@ private[spark] class DAGScheduler(
private val shuffleMergeFinalizeWaitSec =
sc.getConf.get(config.PUSH_BASED_SHUFFLE_MERGE_FINALIZE_TIMEOUT)
+ private val shuffleMergeWaitMinSizeThreshold =
+ sc.getConf.get(config.PUSH_BASED_SHUFFLE_SIZE_MIN_SHUFFLE_SIZE_TO_WAIT)
+
+ private val shufflePushMinRatio =
sc.getConf.get(config.PUSH_BASED_SHUFFLE_MIN_PUSH_RATIO)
+
+ private val shuffleMergeFinalizeNumThreads =
+ sc.getConf.get(config.PUSH_BASED_SHUFFLE_MERGE_FINALIZE_THREADS)
+
// Since SparkEnv gets initialized after DAGScheduler, externalShuffleClient
needs to be
// initialized lazily
private lazy val externalShuffleClient: Option[BlockStoreClient] =
@@ -274,8 +282,12 @@ private[spark] class DAGScheduler(
None
}
+ // Use multi-threaded scheduled executor. The merge finalization task could
take some time,
+ // depending on the time to establish connections to mergers, and the time
to get MergeStatuses
+ // from all the mergers.
private val shuffleMergeFinalizeScheduler =
-
ThreadUtils.newDaemonThreadPoolScheduledExecutor("shuffle-merge-finalizer", 8)
+ ThreadUtils.newDaemonThreadPoolScheduledExecutor("shuffle-merge-finalizer",
+ shuffleMergeFinalizeNumThreads)
/**
* Called by the TaskSetManager to report task's starting.
@@ -1065,6 +1077,14 @@ private[spark] class DAGScheduler(
}
/**
+ * Receives notification about shuffle push for a given shuffle from one map
+ * task has completed
+ */
+ def shufflePushCompleted(shuffleId: Int, shuffleMergeId: Int, mapIndex:
Int): Unit = {
+ eventProcessLoop.post(ShufflePushCompleted(shuffleId, shuffleMergeId,
mapIndex))
+ }
+
+ /**
* Kill a given task. It will be retried.
*
* @return Whether the task was successfully killed.
@@ -1407,7 +1427,7 @@ private[spark] class DAGScheduler(
// merger locations but the corresponding shuffle map stage did
not complete
// successfully, we would still enable push for its retry.
s.shuffleDep.setShuffleMergeEnabled(false)
- logInfo("Push-based shuffle disabled for $stage (${stage.name})
since it" +
+ logInfo(s"Push-based shuffle disabled for $stage (${stage.name})
since it" +
" is already shuffle merge finalized")
}
}
@@ -1636,6 +1656,42 @@ private[spark] class DAGScheduler(
}
}
+ private[scheduler] def checkAndScheduleShuffleMergeFinalize(
+ shuffleStage: ShuffleMapStage): Unit = {
+ // Check if a finalize task has already been scheduled. This is to prevent
scenarios
+ // where we don't schedule multiple shuffle merge finalization which can
happen due to
+ // stage retry or shufflePushMinRatio is already hit etc.
+ if (shuffleStage.shuffleDep.getFinalizeTask.isEmpty) {
+ // 1. Stage indeterminate and some map outputs are not available -
finalize
+ // immediately without registering shuffle merge results.
+ // 2. Stage determinate and some map outputs are not available - decide
to
+ // register merge results based on map outputs size available and
+ // shuffleMergeWaitMinSizeThreshold.
+ // 3. All shuffle outputs available - decide to register merge results
based
+ // on map outputs size available and shuffleMergeWaitMinSizeThreshold.
+ val totalSize = {
+ lazy val computedTotalSize =
+ mapOutputTracker.getStatistics(shuffleStage.shuffleDep).
+ bytesByPartitionId.filter(_ > 0).sum
+ if (shuffleStage.isAvailable) {
+ computedTotalSize
+ } else {
+ if (shuffleStage.isIndeterminate) {
+ 0L
+ } else {
+ computedTotalSize
+ }
+ }
+ }
+
+ if (totalSize < shuffleMergeWaitMinSizeThreshold) {
+ scheduleShuffleMergeFinalize(shuffleStage, delay = 0,
registerMergeResults = false)
+ } else {
+ scheduleShuffleMergeFinalize(shuffleStage, shuffleMergeFinalizeWaitSec)
+ }
+ }
+ }
+
/**
* Responds to a task finishing. This is called inside the event loop so it
assumes that it can
* modify the scheduler's internal state. Use taskEnded() to post a task end
event from outside.
@@ -1767,7 +1823,7 @@ private[spark] class DAGScheduler(
if (runningStages.contains(shuffleStage) &&
shuffleStage.pendingPartitions.isEmpty) {
if (!shuffleStage.shuffleDep.shuffleMergeFinalized &&
shuffleStage.shuffleDep.getMergerLocs.nonEmpty) {
- scheduleShuffleMergeFinalize(shuffleStage)
+ checkAndScheduleShuffleMergeFinalize(shuffleStage)
} else {
processShuffleMapStageCompletion(shuffleStage)
}
@@ -2074,20 +2130,63 @@ private[spark] class DAGScheduler(
}
/**
- * Schedules shuffle merge finalize.
+ *
+ * Schedules shuffle merge finalization.
+ *
+ * @param stage the stage to finalize shuffle merge
+ * @param delay how long to wait before finalizing shuffle merge
+ * @param registerMergeResults indicate whether DAGScheduler would register
the received
+ * MergeStatus with MapOutputTracker and wait to
schedule the reduce
+ * stage until MergeStatus have been received
from all mergers or
+ * reaches timeout. For very small shuffle, this
could be set to
+ * false to avoid impact to job runtime.
*/
- private[scheduler] def scheduleShuffleMergeFinalize(stage: ShuffleMapStage):
Unit = {
- // TODO: SPARK-33701: Instead of waiting for a constant amount of time for
finalization
- // TODO: for all the stages, adaptively tune timeout for merge finalization
- logInfo(("%s (%s) scheduled for finalizing" +
- " shuffle merge in %s s").format(stage, stage.name,
shuffleMergeFinalizeWaitSec))
- shuffleMergeFinalizeScheduler.schedule(
- new Runnable {
- override def run(): Unit = finalizeShuffleMerge(stage)
- },
- shuffleMergeFinalizeWaitSec,
- TimeUnit.SECONDS
- )
+ private[scheduler] def scheduleShuffleMergeFinalize(
+ stage: ShuffleMapStage,
+ delay: Long,
+ registerMergeResults: Boolean = true): Unit = {
+ val shuffleDep = stage.shuffleDep
+ val scheduledTask: Option[ScheduledFuture[_]] = shuffleDep.getFinalizeTask
+ scheduledTask match {
+ case Some(task) =>
+ // If we find an already scheduled task, check if the task has been
triggered yet.
+ // If it's already triggered, do nothing. Otherwise, cancel it and
schedule a new
+ // one for immediate execution. Note that we should get here only when
+ // handleShufflePushCompleted schedules a finalize task after the
shuffle map stage
+ // completed earlier and scheduled a task with default delay.
+ // The current task should be coming from handleShufflePushCompleted,
thus the
+ // delay should be 0 and registerMergeResults should be true.
+ assert(delay == 0 && registerMergeResults)
+ if (task.getDelay(TimeUnit.NANOSECONDS) > 0 && task.cancel(false)) {
+ logInfo(s"$stage (${stage.name}) scheduled for finalizing shuffle
merge immediately " +
+ s"after cancelling previously scheduled task.")
+ shuffleDep.setFinalizeTask(
+ shuffleMergeFinalizeScheduler.schedule(
+ new Runnable {
+ override def run(): Unit = finalizeShuffleMerge(stage,
registerMergeResults)
+ },
+ 0,
+ TimeUnit.SECONDS
+ )
+ )
+ } else {
+ logInfo(s"$stage (${stage.name}) existing scheduled task for
finalizing shuffle merge" +
+ s"would either be in-progress or finished. No need to schedule
shuffle merge" +
+ s" finalization again.")
+ }
+ case None =>
+ // If no previous finalization task is scheduled, schedule the
finalization task.
+ logInfo(s"$stage (${stage.name}) scheduled for finalizing shuffle
merge in $delay s")
+ shuffleDep.setFinalizeTask(
+ shuffleMergeFinalizeScheduler.schedule(
+ new Runnable {
+ override def run(): Unit = finalizeShuffleMerge(stage,
registerMergeResults)
+ },
+ delay,
+ TimeUnit.SECONDS
+ )
+ )
+ }
}
/**
@@ -2095,38 +2194,72 @@ private[spark] class DAGScheduler(
* the given shuffle map stage to finalize the shuffle merge process for
this shuffle. This is
* invoked in a separate thread to reduce the impact on the DAGScheduler
main thread, as the
* scheduler might need to talk to 1000s of shuffle services to finalize
shuffle merge.
+ *
+ * @param stage ShuffleMapStage to finalize shuffle merge for
+ * @param registerMergeResults indicate whether DAGScheduler would register
the received
+ * MergeStatus with MapOutputTracker and wait to
schedule the reduce
+ * stage until MergeStatus have been received
from all mergers or
+ * reaches timeout. For very small shuffle, this
could be set to
+ * false to avoid impact to job runtime.
*/
- private[scheduler] def finalizeShuffleMerge(stage: ShuffleMapStage): Unit = {
- logInfo("%s (%s) finalizing the shuffle merge".format(stage, stage.name))
+ private[scheduler] def finalizeShuffleMerge(
+ stage: ShuffleMapStage,
+ registerMergeResults: Boolean = true): Unit = {
+ logInfo(s"$stage (${stage.name}) finalizing the shuffle merge with
registering merge " +
+ s"results set to $registerMergeResults")
+ val shuffleId = stage.shuffleDep.shuffleId
+ val shuffleMergeId = stage.shuffleDep.shuffleMergeId
+ val numMergers = stage.shuffleDep.getMergerLocs.length
+ val results = (0 until numMergers).map(_ =>
SettableFuture.create[Boolean]())
externalShuffleClient.foreach { shuffleClient =>
- val shuffleId = stage.shuffleDep.shuffleId
- val numMergers = stage.shuffleDep.getMergerLocs.length
- val results = (0 until numMergers).map(_ =>
SettableFuture.create[Boolean]())
-
- stage.shuffleDep.getMergerLocs.zipWithIndex.foreach {
- case (shuffleServiceLoc, index) =>
- // Sends async request to shuffle service to finalize shuffle merge
on that host
- // TODO: SPARK-35536: Cancel finalizeShuffleMerge if the stage is
cancelled
- // TODO: during shuffleMergeFinalizeWaitSec
- shuffleClient.finalizeShuffleMerge(shuffleServiceLoc.host,
- shuffleServiceLoc.port, shuffleId, stage.shuffleDep.shuffleMergeId,
- new MergeFinalizerListener {
- override def onShuffleMergeSuccess(statuses: MergeStatuses):
Unit = {
- assert(shuffleId == statuses.shuffleId)
- eventProcessLoop.post(RegisterMergeStatuses(stage, MergeStatus.
- convertMergeStatusesToMergeStatusArr(statuses,
shuffleServiceLoc)))
- results(index).set(true)
- }
+ if (!registerMergeResults) {
+ results.foreach(_.set(true))
+ // Finalize in separate thread as shuffle merge is a no-op in this case
+ shuffleMergeFinalizeScheduler.schedule(new Runnable {
+ override def run(): Unit = {
+ stage.shuffleDep.getMergerLocs.foreach {
+ case shuffleServiceLoc =>
+ // Sends async request to shuffle service to finalize shuffle
merge on that host.
+ // Since merge statuses will not be registered in this case,
+ // we pass a no-op listener.
+ shuffleClient.finalizeShuffleMerge(shuffleServiceLoc.host,
+ shuffleServiceLoc.port, shuffleId, shuffleMergeId,
+ new MergeFinalizerListener {
+ override def onShuffleMergeSuccess(statuses:
MergeStatuses): Unit = {
+ }
- override def onShuffleMergeFailure(e: Throwable): Unit = {
- logWarning(s"Exception encountered when trying to finalize
shuffle " +
- s"merge on ${shuffleServiceLoc.host} for shuffle
$shuffleId", e)
- // Do not fail the future as this would cause dag scheduler to
prematurely
- // give up on waiting for merge results from the remaining
shuffle services
- // if one fails
- results(index).set(false)
- }
- })
+ override def onShuffleMergeFailure(e: Throwable): Unit = {
+ }
+ })
+ }
+ }
+ }, 0, TimeUnit.SECONDS)
+ } else {
+ stage.shuffleDep.getMergerLocs.zipWithIndex.foreach {
+ case (shuffleServiceLoc, index) =>
+ // Sends async request to shuffle service to finalize shuffle
merge on that host
+ // TODO: SPARK-35536: Cancel finalizeShuffleMerge if the stage is
cancelled
+ // TODO: during shuffleMergeFinalizeWaitSec
+ shuffleClient.finalizeShuffleMerge(shuffleServiceLoc.host,
+ shuffleServiceLoc.port, shuffleId, shuffleMergeId,
+ new MergeFinalizerListener {
+ override def onShuffleMergeSuccess(statuses: MergeStatuses):
Unit = {
+ assert(shuffleId == statuses.shuffleId)
+ eventProcessLoop.post(RegisterMergeStatuses(stage,
MergeStatus.
+ convertMergeStatusesToMergeStatusArr(statuses,
shuffleServiceLoc)))
+ results(index).set(true)
+ }
+
+ override def onShuffleMergeFailure(e: Throwable): Unit = {
+ logWarning(s"Exception encountered when trying to finalize
shuffle " +
+ s"merge on ${shuffleServiceLoc.host} for shuffle
$shuffleId", e)
+ // Do not fail the future as this would cause dag scheduler
to prematurely
+ // give up on waiting for merge results from the remaining
shuffle services
+ // if one fails
+ results(index).set(false)
+ }
+ })
+ }
}
// DAGScheduler only waits for a limited amount of time for the merge
results.
// It will attempt to submit the next stage(s) irrespective of whether
merge results
@@ -2185,15 +2318,45 @@ private[spark] class DAGScheduler(
}
}
- private[scheduler] def handleShuffleMergeFinalized(stage: ShuffleMapStage):
Unit = {
- // Only update MapOutputTracker metadata if the stage is still active. i.e
not cancelled.
- if (runningStages.contains(stage)) {
- stage.shuffleDep.markShuffleMergeFinalized()
- processShuffleMapStageCompletion(stage)
- } else {
- // Unregister all merge results if the stage is currently not
- // active (i.e. the stage is cancelled)
- mapOutputTracker.unregisterAllMergeResult(stage.shuffleDep.shuffleId)
+ private[scheduler] def handleShuffleMergeFinalized(stage: ShuffleMapStage,
+ shuffleMergeId: Int): Unit = {
+ // Check if update is for the same merge id - finalization might have
completed for an earlier
+ // adaptive attempt while the stage might have failed/killed and shuffle
id is getting
+ // re-executing now.
+ if (stage.shuffleDep.shuffleMergeId == shuffleMergeId) {
+ if (stage.pendingPartitions.isEmpty) {
+ if (runningStages.contains(stage)) {
+ stage.shuffleDep.markShuffleMergeFinalized()
+ processShuffleMapStageCompletion(stage)
+ } else {
+ // Unregister all merge results if the stage is currently not
+ // active (i.e. the stage is cancelled)
+ mapOutputTracker.unregisterAllMergeResult(stage.shuffleDep.shuffleId)
+ }
+ } else {
+ // stage still running, mark merge finalized. Stage completion will
invoke
+ // processShuffleMapStageCompletion
+ stage.shuffleDep.markShuffleMergeFinalized()
+ }
+ }
+ }
+
+ private[scheduler] def handleShufflePushCompleted(
+ shuffleId: Int, shuffleMergeId: Int, mapIndex: Int): Unit = {
+ shuffleIdToMapStage.get(shuffleId) match {
+ case Some(mapStage) =>
+ val shuffleDep = mapStage.shuffleDep
+ // Only update shufflePushCompleted events for the current active
stage map tasks.
+ // This is required to prevent shuffle merge finalization by dangling
tasks of a
+ // previous attempt in the case of indeterminate stage.
+ if (shuffleDep.shuffleMergeId == shuffleMergeId) {
+ if (!shuffleDep.shuffleMergeFinalized &&
+ shuffleDep.incPushCompleted(mapIndex).toDouble /
shuffleDep.rdd.partitions.length
+ >= shufflePushMinRatio) {
+ scheduleShuffleMergeFinalize(mapStage, delay = 0)
+ }
+ }
+ case None =>
}
}
@@ -2649,7 +2812,10 @@ private[scheduler] class
DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler
dagScheduler.handleRegisterMergeStatuses(stage, mergeStatuses)
case ShuffleMergeFinalized(stage) =>
- dagScheduler.handleShuffleMergeFinalized(stage)
+ dagScheduler.handleShuffleMergeFinalized(stage,
stage.shuffleDep.shuffleMergeId)
+
+ case ShufflePushCompleted(shuffleId, shuffleMergeId, mapIndex) =>
+ dagScheduler.handleShufflePushCompleted(shuffleId, shuffleMergeId,
mapIndex)
}
override def onError(e: Throwable): Unit = {
diff --git
a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
index 307844c..f3798da 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
@@ -111,3 +111,7 @@ private[scheduler] case class RegisterMergeStatuses(
private[scheduler] case class ShuffleMergeFinalized(stage: ShuffleMapStage)
extends DAGSchedulerEvent
+
+private[scheduler] case class ShufflePushCompleted(
+ shuffleId: Int, shuffleMergeId: Int, mapIndex: Int)
+ extends DAGSchedulerEvent
diff --git
a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala
b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala
index 66ac40f..61ee865 100644
---
a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala
+++
b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala
@@ -85,6 +85,9 @@ private[spark] object CoarseGrainedClusterMessages {
}
}
+ case class ShufflePushCompletion(shuffleId: Int, shuffleMergeId: Int,
mapIndex: Int)
+ extends CoarseGrainedClusterMessage
+
// Internal messages in driver
case object ReviveOffers extends CoarseGrainedClusterMessage
diff --git
a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
index 326ea83..13a7183 100644
---
a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
+++
b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
@@ -168,6 +168,9 @@ class CoarseGrainedSchedulerBackend(scheduler:
TaskSchedulerImpl, val rpcEnv: Rp
}
}
+ case ShufflePushCompletion(shuffleId, shuffleMergeId, mapIndex) =>
+ scheduler.dagScheduler.shufflePushCompleted(shuffleId, shuffleMergeId,
mapIndex)
+
case ReviveOffers =>
makeOffers()
diff --git
a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockPusher.scala
b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockPusher.scala
index 8790371..d6972cd 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockPusher.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockPusher.scala
@@ -26,6 +26,7 @@ import scala.collection.mutable.{ArrayBuffer, HashMap,
HashSet, Queue}
import org.apache.spark.{ShuffleDependency, SparkConf, SparkContext, SparkEnv}
import org.apache.spark.annotation.Since
+import org.apache.spark.executor.{CoarseGrainedExecutorBackend,
ExecutorBackend}
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config._
import org.apache.spark.launcher.SparkLauncher
@@ -53,7 +54,7 @@ private[spark] class ShuffleBlockPusher(conf: SparkConf)
extends Logging {
private[this] val maxBytesInFlight = conf.get(REDUCER_MAX_SIZE_IN_FLIGHT) *
1024 * 1024
private[this] val maxReqsInFlight = conf.get(REDUCER_MAX_REQS_IN_FLIGHT)
private[this] val maxBlocksInFlightPerAddress =
conf.get(REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS)
- private[this] var bytesInFlight = 0L
+ private[shuffle] var bytesInFlight = 0L
private[this] var reqsInFlight = 0
private[this] val numBlocksInFlightPerAddress = new HashMap[BlockManagerId,
Int]()
private[this] val deferredPushRequests = new HashMap[BlockManagerId,
Queue[PushRequest]]()
@@ -61,6 +62,10 @@ private[spark] class ShuffleBlockPusher(conf: SparkConf)
extends Logging {
private[this] val errorHandler = createErrorHandler()
// VisibleForTesting
private[shuffle] val unreachableBlockMgrs = new HashSet[BlockManagerId]()
+ private[this] var shuffleId = -1
+ private[this] var mapIndex = -1
+ private[this] var shuffleMergeId = -1
+ private[this] var pushCompletionNotified = false
// VisibleForTesting
private[shuffle] def createErrorHandler(): BlockPushErrorHandler = {
@@ -84,6 +89,8 @@ private[spark] class ShuffleBlockPusher(conf: SparkConf)
extends Logging {
}
}
}
+ // VisibleForTesting
+ private[shuffle] def isPushCompletionNotified = pushCompletionNotified
/**
* Initiates the block push.
@@ -101,11 +108,17 @@ private[spark] class ShuffleBlockPusher(conf: SparkConf)
extends Logging {
mapIndex: Int): Unit = {
val numPartitions = dep.partitioner.numPartitions
val transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle")
+ this.shuffleId = dep.shuffleId
+ this.shuffleMergeId = dep.shuffleMergeId
+ this.mapIndex = mapIndex
val requests = prepareBlockPushRequests(numPartitions, mapIndex,
dep.shuffleId,
dep.shuffleMergeId, dataFile, partitionLengths, dep.getMergerLocs,
transportConf)
// Randomize the orders of the PushRequest, so different mappers pushing
blocks at the same
// time won't be pushing the same ranges of shuffle partitions.
pushRequests ++= Utils.randomize(requests)
+ if (pushRequests.isEmpty) {
+ notifyDriverAboutPushCompletion()
+ }
submitTask(() => {
tryPushUpToMax()
@@ -327,11 +340,35 @@ private[spark] class ShuffleBlockPusher(conf: SparkConf)
extends Logging {
s"stop.")
return false
} else {
+ if (reqsInFlight <= 0 && pushRequests.isEmpty &&
deferredPushRequests.isEmpty) {
+ notifyDriverAboutPushCompletion()
+ }
remainingBlocks.isEmpty && (pushRequests.nonEmpty ||
deferredPushRequests.nonEmpty)
}
}
/**
+ * Notify the driver about all the blocks generated by the current map task
having been pushed.
+ * This enables the DAGScheduler to finalize shuffle merge as soon as
sufficient map tasks have
+ * completed push instead of always waiting for a fixed amount of time.
+ *
+ * VisibleForTesting
+ */
+ protected def notifyDriverAboutPushCompletion(): Unit = {
+ assert(shuffleId >= 0 && mapIndex >= 0)
+ if (!pushCompletionNotified) {
+ SparkEnv.get.executorBackend match {
+ case Some(cb: CoarseGrainedExecutorBackend) =>
+ cb.notifyDriverAboutPushCompletion(shuffleId, shuffleMergeId,
mapIndex)
+ case Some(eb: ExecutorBackend) =>
+ logWarning(s"Currently $eb doesn't support push-based shuffle")
+ case None =>
+ }
+ pushCompletionNotified = true
+ }
+ }
+
+ /**
* Convert the shuffle data file of the current mapper into a list of
PushRequest. Basically,
* continuous blocks in the shuffle file are grouped into a single request
to allow more
* efficient read of the block data. Each mapper for a given shuffle will
receive the same
diff --git
a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index afea912..76612cb 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -18,7 +18,7 @@
package org.apache.spark.scheduler
import java.util.Properties
-import java.util.concurrent.{CountDownLatch, TimeUnit}
+import java.util.concurrent.{CountDownLatch, Delayed, ScheduledFuture,
TimeUnit}
import java.util.concurrent.atomic.{AtomicBoolean, AtomicLong, AtomicReference}
import scala.annotation.meta.param
@@ -124,6 +124,31 @@ class MyRDD(
override def toString: String = "DAGSchedulerSuiteRDD " + id
}
+class DummyScheduledFuture(
+ val delay: Long,
+ val registerMergeResults: Boolean)
+ extends ScheduledFuture[Int] {
+
+ override def get(timeout: Long, unit: TimeUnit): Int =
+ throw new IllegalStateException("should not be reached")
+
+ override def getDelay(unit: TimeUnit): Long = delay
+
+ override def compareTo(o: Delayed): Int =
+ throw new IllegalStateException("should not be reached")
+
+ override def cancel(mayInterruptIfRunning: Boolean): Boolean = true
+
+ override def isCancelled: Boolean =
+ throw new IllegalStateException("should not be reached")
+
+ override def isDone: Boolean =
+ throw new IllegalStateException("should not be reached")
+
+ override def get(): Int =
+ throw new IllegalStateException("should not be reached")
+}
+
class DAGSchedulerSuiteDummyException extends Exception
class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with
TimeLimits {
@@ -312,16 +337,27 @@ class DAGSchedulerSuite extends SparkFunSuite with
TempLocalSparkContext with Ti
* Schedules shuffle merge finalize.
*/
override private[scheduler] def scheduleShuffleMergeFinalize(
- shuffleMapStage: ShuffleMapStage): Unit = {
- if (shuffleMergeRegister) {
+ shuffleMapStage: ShuffleMapStage,
+ delay: Long,
+ registerMergeResults: Boolean = true): Unit = {
+ if (shuffleMergeRegister && registerMergeResults) {
for (part <- 0 until
shuffleMapStage.shuffleDep.partitioner.numPartitions) {
val mergeStatuses = Seq((part, makeMergeStatus("",
shuffleMapStage.shuffleDep.shuffleMergeId)))
handleRegisterMergeStatuses(shuffleMapStage, mergeStatuses)
}
- if (shuffleMergeFinalize) {
- handleShuffleMergeFinalized(shuffleMapStage)
- }
+ }
+
+ shuffleMapStage.shuffleDep.getFinalizeTask match {
+ case Some(_) =>
+ assert(delay == 0 && registerMergeResults)
+ case None =>
+ }
+
+ shuffleMapStage.shuffleDep.setFinalizeTask(
+ new DummyScheduledFuture(delay, registerMergeResults))
+ if (shuffleMergeFinalize) {
+ handleShuffleMergeFinalized(shuffleMapStage,
shuffleMapStage.shuffleDep.shuffleMergeId)
}
}
}
@@ -472,6 +508,12 @@ class DAGSchedulerSuite extends SparkFunSuite with
TempLocalSparkContext with Ti
assert(this.results === expected)
}
+ /** Sends ShufflePushCompleted to the DAG scheduler. */
+ private def pushComplete(
+ shuffleId: Int, shuffleMergeId: Int, mapIndex: Int): Unit = {
+ runEvent(ShufflePushCompleted(shuffleId, shuffleMergeId, mapIndex))
+ }
+
test("[SPARK-3353] parent stage should have lower stage id") {
sc.parallelize(1 to 10).map(x => (x, x)).reduceByKey(_ + _, 4).count()
val stageByOrderOfExecution = sparkListener.stageByOrderOfExecution
@@ -3428,6 +3470,7 @@ class DAGSchedulerSuite extends SparkFunSuite with
TempLocalSparkContext with Ti
private def initPushBasedShuffleConfs(conf: SparkConf) = {
conf.set(config.SHUFFLE_SERVICE_ENABLED, true)
conf.set(config.PUSH_BASED_SHUFFLE_ENABLED, true)
+ conf.set(config.PUSH_BASED_SHUFFLE_SIZE_MIN_SHUFFLE_SIZE_TO_WAIT, 1L)
conf.set("spark.master", "pushbasedshuffleclustermanager")
// Needed to run push-based shuffle tests in ad-hoc manner through IDE
conf.set(Tests.IS_TESTING, true)
@@ -3439,7 +3482,7 @@ class DAGSchedulerSuite extends SparkFunSuite with
TempLocalSparkContext with Ti
test("SPARK-32920: shuffle merge finalization") {
initPushBasedShuffleConfs(conf)
- DAGSchedulerSuite.clearMergerLocs
+ DAGSchedulerSuite.clearMergerLocs()
DAGSchedulerSuite.addMergerLocs(Seq("host1", "host2", "host3", "host4",
"host5"))
val parts = 2
val shuffleMapRdd = new MyRDD(sc, parts, Nil)
@@ -3459,7 +3502,7 @@ class DAGSchedulerSuite extends SparkFunSuite with
TempLocalSparkContext with Ti
test("SPARK-32920: merger locations not empty") {
initPushBasedShuffleConfs(conf)
conf.set(config.SHUFFLE_MERGER_LOCATIONS_MIN_STATIC_THRESHOLD, 3)
- DAGSchedulerSuite.clearMergerLocs
+ DAGSchedulerSuite.clearMergerLocs()
DAGSchedulerSuite.addMergerLocs(Seq("host1", "host2", "host3", "host4",
"host5"))
val parts = 2
@@ -3484,7 +3527,7 @@ class DAGSchedulerSuite extends SparkFunSuite with
TempLocalSparkContext with Ti
test("SPARK-32920: merger locations reuse from shuffle dependency") {
initPushBasedShuffleConfs(conf)
conf.set(config.SHUFFLE_MERGER_MAX_RETAINED_LOCATIONS, 3)
- DAGSchedulerSuite.clearMergerLocs
+ DAGSchedulerSuite.clearMergerLocs()
DAGSchedulerSuite.addMergerLocs(Seq("host1", "host2", "host3", "host4",
"host5"))
val parts = 2
@@ -3524,7 +3567,7 @@ class DAGSchedulerSuite extends SparkFunSuite with
TempLocalSparkContext with Ti
test("SPARK-32920: Disable shuffle merge due to not enough mergers
available") {
initPushBasedShuffleConfs(conf)
conf.set(config.SHUFFLE_MERGER_LOCATIONS_MIN_STATIC_THRESHOLD, 6)
- DAGSchedulerSuite.clearMergerLocs
+ DAGSchedulerSuite.clearMergerLocs()
DAGSchedulerSuite.addMergerLocs(Seq("host1", "host2", "host3", "host4",
"host5"))
val parts = 7
@@ -3548,7 +3591,7 @@ class DAGSchedulerSuite extends SparkFunSuite with
TempLocalSparkContext with Ti
test("SPARK-32920: Ensure child stage should not start before all the" +
" parent stages are completed with shuffle merge finalized for all the
parent stages") {
initPushBasedShuffleConfs(conf)
- DAGSchedulerSuite.clearMergerLocs
+ DAGSchedulerSuite.clearMergerLocs()
DAGSchedulerSuite.addMergerLocs(Seq("host1", "host2", "host3", "host4",
"host5"))
val parts = 1
val shuffleMapRdd1 = new MyRDD(sc, parts, Nil)
@@ -3585,7 +3628,7 @@ class DAGSchedulerSuite extends SparkFunSuite with
TempLocalSparkContext with Ti
" ShuffleDependency should not cause DAGScheduler to hang") {
initPushBasedShuffleConfs(conf)
conf.set(config.SHUFFLE_MERGER_LOCATIONS_MIN_STATIC_THRESHOLD, 10)
- DAGSchedulerSuite.clearMergerLocs
+ DAGSchedulerSuite.clearMergerLocs()
DAGSchedulerSuite.addMergerLocs(Seq("host1", "host2", "host3", "host4",
"host5"))
val parts = 20
@@ -3616,7 +3659,7 @@ class DAGSchedulerSuite extends SparkFunSuite with
TempLocalSparkContext with Ti
" ShuffleDependency with shuffle data loss should recompute missing
partitions") {
initPushBasedShuffleConfs(conf)
conf.set(config.SHUFFLE_MERGER_LOCATIONS_MIN_STATIC_THRESHOLD, 10)
- DAGSchedulerSuite.clearMergerLocs
+ DAGSchedulerSuite.clearMergerLocs()
DAGSchedulerSuite.addMergerLocs(Seq("host1", "host2", "host3", "host4",
"host5"))
val parts = 20
@@ -3632,7 +3675,7 @@ class DAGSchedulerSuite extends SparkFunSuite with
TempLocalSparkContext with Ti
completeNextResultStageWithSuccess(1, 0)
- DAGSchedulerSuite.clearMergerLocs
+ DAGSchedulerSuite.clearMergerLocs()
val hosts = (6 to parts).map {x => s"Host$x" }
DAGSchedulerSuite.addMergerLocs(hosts)
@@ -3669,7 +3712,7 @@ class DAGSchedulerSuite extends SparkFunSuite with
TempLocalSparkContext with Ti
test("SPARK-32920: Merge results should be unregistered if the running stage
is cancelled" +
" before shuffle merge is finalized") {
initPushBasedShuffleConfs(conf)
- DAGSchedulerSuite.clearMergerLocs
+ DAGSchedulerSuite.clearMergerLocs()
DAGSchedulerSuite.addMergerLocs(Seq("host1", "host2", "host3", "host4",
"host5"))
scheduler = new MyDAGScheduler(
sc,
@@ -3697,14 +3740,15 @@ class DAGSchedulerSuite extends SparkFunSuite with
TempLocalSparkContext with Ti
assert(mapOutputTracker.getNumAvailableMergeResults(shuffleDep.shuffleId)
== parts)
val shuffleMapStageToCancel =
scheduler.stageIdToStage(0).asInstanceOf[ShuffleMapStage]
runEvent(StageCancelled(0, Option("Explicit cancel check")))
- scheduler.handleShuffleMergeFinalized(shuffleMapStageToCancel)
+ scheduler.handleShuffleMergeFinalized(shuffleMapStageToCancel,
+ shuffleMapStageToCancel.shuffleDep.shuffleMergeId)
assert(mapOutputTracker.getNumAvailableMergeResults(shuffleDep.shuffleId)
== 0)
}
test("SPARK-32920: SPARK-35549: Merge results should not get registered" +
" after shuffle merge finalization") {
initPushBasedShuffleConfs(conf)
- DAGSchedulerSuite.clearMergerLocs
+ DAGSchedulerSuite.clearMergerLocs()
DAGSchedulerSuite.addMergerLocs(Seq("host1", "host2", "host3", "host4",
"host5"))
scheduler = new MyDAGScheduler(
@@ -3733,7 +3777,8 @@ class DAGSchedulerSuite extends SparkFunSuite with
TempLocalSparkContext with Ti
val shuffleMapStage =
scheduler.stageIdToStage(0).asInstanceOf[ShuffleMapStage]
scheduler.handleRegisterMergeStatuses(shuffleMapStage, Seq((0,
makeMergeStatus("hostA",
shuffleDep.shuffleMergeId))))
- scheduler.handleShuffleMergeFinalized(shuffleMapStage)
+ scheduler.handleShuffleMergeFinalized(shuffleMapStage,
+ shuffleMapStage.shuffleDep.shuffleMergeId)
scheduler.handleRegisterMergeStatuses(shuffleMapStage, Seq((1,
makeMergeStatus("hostA",
shuffleDep.shuffleMergeId))))
assert(mapOutputTracker.getNumAvailableMergeResults(shuffleDep.shuffleId)
== 1)
@@ -3741,7 +3786,7 @@ class DAGSchedulerSuite extends SparkFunSuite with
TempLocalSparkContext with Ti
test("SPARK-32920: Disable push based shuffle in the case of a barrier
stage") {
initPushBasedShuffleConfs(conf)
- DAGSchedulerSuite.clearMergerLocs
+ DAGSchedulerSuite.clearMergerLocs()
DAGSchedulerSuite.addMergerLocs(Seq("host1", "host2", "host3", "host4",
"host5"))
val parts = 2
@@ -3788,7 +3833,7 @@ class DAGSchedulerSuite extends SparkFunSuite with
TempLocalSparkContext with Ti
test("SPARK-32923: handle stage failure for indeterminate map stage with
push-based shuffle") {
initPushBasedShuffleConfs(conf)
- DAGSchedulerSuite.clearMergerLocs
+ DAGSchedulerSuite.clearMergerLocs()
DAGSchedulerSuite.addMergerLocs(Seq("host1", "host2", "host3", "host4",
"host5"))
val (shuffleId1, shuffleId2) = constructIndeterminateStageFetchFailed()
@@ -3847,11 +3892,262 @@ class DAGSchedulerSuite extends SparkFunSuite with
TempLocalSparkContext with Ti
// Job successful ended.
assert(results === Map(0 -> 11, 1 -> 12))
+ }
+
+ test("SPARK-33701: check adaptive shuffle merge finalization triggered
after" +
+ " stage completion") {
+ initPushBasedShuffleConfs(conf)
+ conf.set(config.SHUFFLE_MERGER_LOCATIONS_MIN_STATIC_THRESHOLD, 3)
+ DAGSchedulerSuite.clearMergerLocs()
+ DAGSchedulerSuite.addMergerLocs(Seq("host1", "host2", "host3", "host4",
"host5"))
+ val parts = 2
+
+ val shuffleMapRdd1 = new MyRDD(sc, parts, Nil)
+ val shuffleDep1 = new ShuffleDependency(shuffleMapRdd1, new
HashPartitioner(parts))
+ val shuffleMapRdd2 = new MyRDD(sc, parts, Nil)
+ val shuffleDep2 = new ShuffleDependency(shuffleMapRdd2, new
HashPartitioner(parts))
+ val reduceRdd = new MyRDD(sc, parts, List(shuffleDep1, shuffleDep2),
+ tracker = mapOutputTracker)
+
+ // Submit a reduce job that depends on both shuffles, which will create the map stages
+ submit(reduceRdd, (0 until parts).toArray)
+
+ val taskResults = taskSets(0).tasks.zipWithIndex.map {
+ case (_, idx) =>
+ (Success, makeMapStatus("host" + ('A' + idx).toChar, parts))
+ }.toSeq
+ for ((result, i) <- taskResults.zipWithIndex) {
+ runEvent(makeCompletionEvent(taskSets(0).tasks(i), result._1, result._2))
+ }
+ // Verify finalize task is set with default delay of 10s and merge results
are marked
+ // for registration
+ val shuffleStage1 =
scheduler.stageIdToStage(0).asInstanceOf[ShuffleMapStage]
+ assert(shuffleStage1.shuffleDep.shuffleMergeEnabled)
+ val finalizeTask1 = shuffleStage1.shuffleDep.getFinalizeTask.get
+ .asInstanceOf[DummyScheduledFuture]
+ assert(finalizeTask1.delay == 10 && finalizeTask1.registerMergeResults)
+ assert(shuffleStage1.shuffleDep.shuffleMergeFinalized)
+
+ complete(taskSets(1), taskSets(1).tasks.zipWithIndex.map {
+ case (_, idx) =>
+ (Success, makeMapStatus("host" + ('A' + idx).toChar, parts, 10))
+ }.toSeq)
+ val shuffleStage2 =
scheduler.stageIdToStage(1).asInstanceOf[ShuffleMapStage]
+ assert(shuffleStage2.shuffleDep.getFinalizeTask.nonEmpty)
+ val finalizeTask2 = shuffleStage2.shuffleDep.getFinalizeTask.get
+ .asInstanceOf[DummyScheduledFuture]
+ assert(finalizeTask2.delay == 10 && finalizeTask2.registerMergeResults)
+
+ assert(mapOutputTracker.
+ getNumAvailableMergeResults(shuffleStage1.shuffleDep.shuffleId) == parts)
+ assert(mapOutputTracker.
+ getNumAvailableMergeResults(shuffleStage2.shuffleDep.shuffleId) == parts)
+ completeNextResultStageWithSuccess(2, 0)
+ assert(results === Map(0 -> 42, 1 -> 42))
+
results.clear()
assertDataStructuresEmpty()
}
- /**
+ test("SPARK-33701: check adaptive shuffle merge finalization triggered after
minimum" +
+ " threshold push complete") {
+ initPushBasedShuffleConfs(conf)
+ conf.set(config.PUSH_BASED_SHUFFLE_SIZE_MIN_SHUFFLE_SIZE_TO_WAIT, 10L)
+ conf.set(config.SHUFFLE_MERGER_LOCATIONS_MIN_STATIC_THRESHOLD, 5)
+ conf.set(config.PUSH_BASED_SHUFFLE_MIN_PUSH_RATIO, 0.5)
+ DAGSchedulerSuite.clearMergerLocs()
+ DAGSchedulerSuite.addMergerLocs(Seq("host1", "host2", "host3", "host4",
"host5"))
+ val parts = 4
+
+ val shuffleMapRdd1 = new MyRDD(sc, parts, Nil)
+ val shuffleDep1 = new ShuffleDependency(shuffleMapRdd1, new
HashPartitioner(parts))
+ val shuffleMapRdd2 = new MyRDD(sc, parts, Nil)
+ val shuffleDep2 = new ShuffleDependency(shuffleMapRdd2, new
HashPartitioner(parts))
+ val reduceRdd = new MyRDD(sc, parts, List(shuffleDep1, shuffleDep2),
+ tracker = mapOutputTracker)
+
+ // Submit a reduce job that depends on both shuffles, which will create the map stages
+ submit(reduceRdd, (0 until parts).toArray)
+
+ val taskResults = taskSets(0).tasks.zipWithIndex.map {
+ case (_, idx) =>
+ (Success, makeMapStatus("host" + ('A' + idx).toChar, parts))
+ }.toSeq
+
+ runEvent(makeCompletionEvent(taskSets(0).tasks(0), taskResults(0)._1,
taskResults(0)._2))
+ runEvent(makeCompletionEvent(taskSets(0).tasks(1), taskResults(0)._1,
taskResults(0)._2))
+
+ val shuffleStage1 =
scheduler.stageIdToStage(0).asInstanceOf[ShuffleMapStage]
+ assert(shuffleStage1.shuffleDep.shuffleMergeEnabled)
+
+ pushComplete(shuffleStage1.shuffleDep.shuffleId, 0, 0)
+ pushComplete(shuffleStage1.shuffleDep.shuffleId, 0, 1)
+
+ // Pushes completed for 2 of the 4 map tasks (min push ratio = 0.5), so merge
finalization should have been scheduled
+ val finalizeTask = shuffleStage1.shuffleDep.getFinalizeTask.get
+ .asInstanceOf[DummyScheduledFuture]
+ assert(finalizeTask.registerMergeResults && finalizeTask.delay == 0)
+
+ runEvent(makeCompletionEvent(taskSets(0).tasks(2), taskResults(0)._1,
taskResults(0)._2))
+ runEvent(makeCompletionEvent(taskSets(0).tasks(3), taskResults(0)._1,
taskResults(0)._2))
+
+ completeShuffleMapStageSuccessfully(1, 0, parts)
+
+ completeNextResultStageWithSuccess(2, 0)
+ assert(results === Map(0 -> 42, 1 -> 42, 2 -> 42, 3 -> 42))
+
+ results.clear()
+ assertDataStructuresEmpty()
+ }
+
+ // Test the behavior of stage cancellation during the
spark.shuffle.push.finalize.timeout
+ // wait for shuffle merge finalization, there are 2 different cases:
+ // 1. Deterministic stage - With a deterministic stage, the shuffleMergeId = 0
for multiple
+ // stage attempts, so if the stage is cancelled before the shuffle merge is
finalized then
+ // the merge results are unregistered from MapOutputTracker
+ // 2. Indeterminate stage - A different attempt of the same stage can trigger
shuffle merge
+ // finalization, but it is validated by the shuffleMergeId (unique across
stages and stage
+ // attempts for indeterminate stages) and only the shuffle merge with the
matching shuffleMergeId is finalized
+ test("SPARK-33701: check adaptive shuffle merge finalization behavior with
stage" +
+ " cancellation during spark.shuffle.push.finalize.timeout wait") {
+ initPushBasedShuffleConfs(conf)
+ conf.set(config.PUSH_BASED_SHUFFLE_SIZE_MIN_SHUFFLE_SIZE_TO_WAIT, 10L)
+ conf.set(config.SHUFFLE_MERGER_LOCATIONS_MIN_STATIC_THRESHOLD, 5)
+ conf.set(config.PUSH_BASED_SHUFFLE_MIN_PUSH_RATIO, 0.5)
+ DAGSchedulerSuite.clearMergerLocs()
+ DAGSchedulerSuite.addMergerLocs(Seq("host1", "host2", "host3", "host4",
"host5"))
+ val parts = 4
+
+ scheduler = new MyDAGScheduler(
+ sc,
+ taskScheduler,
+ sc.listenerBus,
+ mapOutputTracker,
+ blockManagerMaster,
+ sc.env,
+ shuffleMergeFinalize = false)
+ dagEventProcessLoopTester = new
DAGSchedulerEventProcessLoopTester(scheduler)
+
+ // Determinate stage
+ val shuffleMapRdd1 = new MyRDD(sc, parts, Nil)
+ val shuffleDep1 = new ShuffleDependency(shuffleMapRdd1, new
HashPartitioner(parts))
+ val shuffleMapRdd2 = new MyRDD(sc, parts, Nil)
+ val shuffleDep2 = new ShuffleDependency(shuffleMapRdd2, new
HashPartitioner(parts))
+ val reduceRdd = new MyRDD(sc, parts, List(shuffleDep1, shuffleDep2),
+ tracker = mapOutputTracker)
+
+ // Submit a reduce job that depends on both shuffles, which will create the map stages
+ submit(reduceRdd, (0 until parts).toArray)
+
+ val taskResults = taskSets(0).tasks.zipWithIndex.map {
+ case (_, idx) =>
+ (Success, makeMapStatus("host" + ('A' + idx).toChar, parts))
+ }.toSeq
+
+ for ((result, i) <- taskResults.zipWithIndex) {
+ runEvent(makeCompletionEvent(taskSets(0).tasks(i), result._1, result._2))
+ }
+ val shuffleStage1 =
scheduler.stageIdToStage(0).asInstanceOf[ShuffleMapStage]
+ runEvent(StageCancelled(0, Option("Explicit cancel check")))
+ scheduler.handleShuffleMergeFinalized(shuffleStage1,
shuffleStage1.shuffleDep.shuffleMergeId)
+
+ assert(shuffleStage1.shuffleDep.shuffleMergeEnabled)
+ assert(!shuffleStage1.shuffleDep.shuffleMergeFinalized)
+ assert(mapOutputTracker.
+ getNumAvailableMergeResults(shuffleStage1.shuffleDep.shuffleId) == 0)
+
+ // Indeterminate stage
+ val shuffleMapIndeterminateRdd1 = new MyRDD(sc, parts, Nil, indeterminate
= true)
+ val shuffleIndeterminateDep1 = new ShuffleDependency(
+ shuffleMapIndeterminateRdd1, new HashPartitioner(parts))
+ val shuffleMapIndeterminateRdd2 = new MyRDD(sc, parts, Nil, indeterminate
= true)
+ val shuffleIndeterminateDep2 = new ShuffleDependency(
+ shuffleMapIndeterminateRdd2, new HashPartitioner(parts))
+ val reduceIndeterminateRdd = new MyRDD(sc, parts, List(
+ shuffleIndeterminateDep1, shuffleIndeterminateDep2), tracker =
mapOutputTracker)
+
+ // Submit a reduce job that depends on both indeterminate shuffles, which will create the map stages
+ submit(reduceIndeterminateRdd, (0 until parts).toArray)
+
+ val indeterminateResults = taskSets(0).tasks.zipWithIndex.map {
+ case (_, idx) =>
+ (Success, makeMapStatus("host" + ('A' + idx).toChar, parts))
+ }.toSeq
+
+ for ((result, i) <- indeterminateResults.zipWithIndex) {
+ runEvent(makeCompletionEvent(taskSets(0).tasks(i), result._1, result._2))
+ }
+
+ val shuffleIndeterminateStage =
scheduler.stageIdToStage(3).asInstanceOf[ShuffleMapStage]
+ assert(shuffleIndeterminateStage.isIndeterminate)
+ scheduler.handleShuffleMergeFinalized(shuffleIndeterminateStage, 2)
+ assert(shuffleIndeterminateStage.shuffleDep.shuffleMergeEnabled)
+ assert(!shuffleIndeterminateStage.shuffleDep.shuffleMergeFinalized)
+ }
+
+ // With Adaptive shuffle merge finalization, once minimum shuffle pushes
complete after stage
+ // completion, the existing shuffle merge finalization task with
+ // delay = spark.shuffle.push.finalize.timeout should be replaced with a new
shuffle merge
+ // finalization task with delay = 0
+ test("SPARK-33701: check adaptive shuffle merge finalization with minimum
pushes complete" +
+ " after the stage completion replacing the finalize task with delay = 0") {
+ initPushBasedShuffleConfs(conf)
+ conf.set(config.PUSH_BASED_SHUFFLE_SIZE_MIN_SHUFFLE_SIZE_TO_WAIT, 10L)
+ conf.set(config.SHUFFLE_MERGER_LOCATIONS_MIN_STATIC_THRESHOLD, 5)
+ conf.set(config.PUSH_BASED_SHUFFLE_MIN_PUSH_RATIO, 0.5)
+ DAGSchedulerSuite.clearMergerLocs()
+ DAGSchedulerSuite.addMergerLocs(Seq("host1", "host2", "host3", "host4",
"host5"))
+ val parts = 4
+
+ scheduler = new MyDAGScheduler(
+ sc,
+ taskScheduler,
+ sc.listenerBus,
+ mapOutputTracker,
+ blockManagerMaster,
+ sc.env,
+ shuffleMergeFinalize = false)
+ dagEventProcessLoopTester = new
DAGSchedulerEventProcessLoopTester(scheduler)
+
+ // Determinate stage
+ val shuffleMapRdd1 = new MyRDD(sc, parts, Nil)
+ val shuffleDep1 = new ShuffleDependency(shuffleMapRdd1, new
HashPartitioner(parts))
+ val shuffleMapRdd2 = new MyRDD(sc, parts, Nil)
+ val shuffleDep2 = new ShuffleDependency(shuffleMapRdd2, new
HashPartitioner(parts))
+ val reduceRdd = new MyRDD(sc, parts, List(shuffleDep1, shuffleDep2),
+ tracker = mapOutputTracker)
+
+ // Submit a reduce job that depends on both shuffles, which will create the map stages
+ submit(reduceRdd, (0 until parts).toArray)
+
+ val taskResults = taskSets(0).tasks.zipWithIndex.map {
+ case (_, idx) =>
+ (Success, makeMapStatus("host" + ('A' + idx).toChar, parts))
+ }.toSeq
+
+ for ((result, i) <- taskResults.zipWithIndex) {
+ runEvent(makeCompletionEvent(taskSets(0).tasks(i), result._1, result._2))
+ }
+ val shuffleStage1 =
scheduler.stageIdToStage(0).asInstanceOf[ShuffleMapStage]
+ assert(shuffleStage1.shuffleDep.shuffleMergeEnabled)
+ assert(!shuffleStage1.shuffleDep.shuffleMergeFinalized)
+ val finalizeTask1 = shuffleStage1.shuffleDep.getFinalizeTask.get.
+ asInstanceOf[DummyScheduledFuture]
+ assert(finalizeTask1.delay == 10 && finalizeTask1.registerMergeResults)
+
+ // Minimum shuffle pushes complete, replace the finalizeTask with delay =
10
+ // with a finalizeTask with delay = 0
+ pushComplete(shuffleStage1.shuffleDep.shuffleId, 0, 0)
+ pushComplete(shuffleStage1.shuffleDep.shuffleId, 0, 1)
+
+ // Existing finalizeTask with delay = 10 should be replaced with
finalizeTask
+ // with delay = 0
+ val finalizeTask2 = shuffleStage1.shuffleDep.getFinalizeTask.get.
+ asInstanceOf[DummyScheduledFuture]
+ assert(finalizeTask2.delay == 0 && finalizeTask2.registerMergeResults)
+ }
+
+ /**
* Assert that the supplied TaskSet has exactly the given hosts as its
preferred locations.
* Note that this checks only the host and not the executor ID.
*/
@@ -3922,7 +4218,7 @@ object DAGSchedulerSuite {
locs.foreach { loc => mergerLocs.append(makeBlockManagerId(loc)) }
}
- def clearMergerLocs: Unit = mergerLocs.clear()
+ def clearMergerLocs(): Unit = mergerLocs.clear()
}
diff --git
a/core/src/test/scala/org/apache/spark/shuffle/ShuffleBlockPusherSuite.scala
b/core/src/test/scala/org/apache/spark/shuffle/ShuffleBlockPusherSuite.scala
index 298ba50..94c0417 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/ShuffleBlockPusherSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/ShuffleBlockPusherSuite.scala
@@ -20,11 +20,11 @@ package org.apache.spark.shuffle
import java.io.{File, FileNotFoundException, IOException}
import java.net.ConnectException
import java.nio.ByteBuffer
-import java.util.concurrent.LinkedBlockingQueue
+import java.util.concurrent.{CountDownLatch, LinkedBlockingQueue, Semaphore}
import scala.collection.mutable.ArrayBuffer
-import org.mockito.{Mock, MockitoAnnotations}
+import org.mockito.{ArgumentMatchers, Mock, MockitoAnnotations}
import org.mockito.Answers.RETURNS_SMART_NULLS
import org.mockito.ArgumentMatchers.any
import org.mockito.Mockito._
@@ -32,6 +32,8 @@ import org.mockito.invocation.InvocationOnMock
import org.scalatest.BeforeAndAfterEach
import org.apache.spark._
+import org.apache.spark.executor.CoarseGrainedExecutorBackend
+import
org.apache.spark.internal.config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS
import org.apache.spark.network.buffer.ManagedBuffer
import org.apache.spark.network.server.BlockPushNonFatalFailure
import org.apache.spark.network.server.BlockPushNonFatalFailure.ReturnCode
@@ -40,12 +42,14 @@ import org.apache.spark.network.util.TransportConf
import org.apache.spark.serializer.JavaSerializer
import org.apache.spark.shuffle.ShuffleBlockPusher.PushRequest
import org.apache.spark.storage._
+import org.apache.spark.util.ThreadUtils
class ShuffleBlockPusherSuite extends SparkFunSuite with BeforeAndAfterEach {
@Mock(answer = RETURNS_SMART_NULLS) private var blockManager: BlockManager =
_
@Mock(answer = RETURNS_SMART_NULLS) private var dependency:
ShuffleDependency[Int, Int, Int] = _
@Mock(answer = RETURNS_SMART_NULLS) private var shuffleClient:
BlockStoreClient = _
+ @Mock(answer = RETURNS_SMART_NULLS) private var executorBackend:
CoarseGrainedExecutorBackend = _
private var conf: SparkConf = _
private var pushedBlocks = new ArrayBuffer[String]
@@ -54,6 +58,7 @@ class ShuffleBlockPusherSuite extends SparkFunSuite with
BeforeAndAfterEach {
super.beforeEach()
conf = new SparkConf(loadDefaults = false)
MockitoAnnotations.openMocks(this).close()
+ when(dependency.shuffleId).thenReturn(0)
when(dependency.partitioner).thenReturn(new HashPartitioner(8))
when(dependency.serializer).thenReturn(new JavaSerializer(conf))
when(dependency.getMergerLocs).thenReturn(Seq(BlockManagerId("test-client",
"test-client", 1)))
@@ -62,6 +67,7 @@ class ShuffleBlockPusherSuite extends SparkFunSuite with
BeforeAndAfterEach {
when(mockEnv.conf).thenReturn(conf)
when(mockEnv.blockManager).thenReturn(blockManager)
SparkEnv.set(mockEnv)
+ when(SparkEnv.get.executorBackend).thenReturn(Some(executorBackend))
when(blockManager.blockStoreClient).thenReturn(shuffleClient)
}
@@ -91,37 +97,104 @@ class ShuffleBlockPusherSuite extends SparkFunSuite with
BeforeAndAfterEach {
})
}
+ private def verifyBlockPushCompleted(
+ blockPusher: ShuffleBlockPusher): Unit = {
+ verify(executorBackend, times(1))
+ .notifyDriverAboutPushCompletion(dependency.shuffleId, 0, 0)
+ assert(blockPusher.isPushCompletionNotified)
+ }
+
test("A batch of blocks is limited by maxBlocksBatchSize") {
+ interceptPushedBlocksForSuccess()
conf.set("spark.shuffle.push.maxBlockBatchSize", "1m")
conf.set("spark.shuffle.push.maxBlockSizeToPush", "2048k")
val blockPusher = new TestShuffleBlockPusher(conf)
val mergerLocs = dependency.getMergerLocs.map(loc => BlockManagerId("",
loc.host, loc.port))
val largeBlockSize = 2 * 1024 * 1024
+ blockPusher.initiateBlockPush(mock(classOf[File]),
+ Array.fill(dependency.partitioner.numPartitions) { 5 }, dependency, 0)
val pushRequests = blockPusher.prepareBlockPushRequests(5, 0, 0, 0,
mock(classOf[File]), Array(2, 2, 2, largeBlockSize, largeBlockSize),
mergerLocs,
mock(classOf[TransportConf]))
+ blockPusher.runPendingTasks()
assert(pushRequests.length == 3)
+ verifyBlockPushCompleted(blockPusher)
verifyPushRequests(pushRequests, Seq(6, largeBlockSize, largeBlockSize))
}
test("Large blocks are excluded in the preparation") {
+ interceptPushedBlocksForSuccess()
conf.set("spark.shuffle.push.maxBlockSizeToPush", "1k")
val blockPusher = new TestShuffleBlockPusher(conf)
val mergerLocs = dependency.getMergerLocs.map(loc => BlockManagerId("",
loc.host, loc.port))
+ blockPusher.initiateBlockPush(mock(classOf[File]),
+ Array.fill(dependency.partitioner.numPartitions) { 5 }, dependency, 0)
val pushRequests = blockPusher.prepareBlockPushRequests(5, 0, 0, 0,
mock(classOf[File]), Array(2, 2, 2, 1028, 1024), mergerLocs,
mock(classOf[TransportConf]))
+ blockPusher.runPendingTasks()
assert(pushRequests.length == 2)
verifyPushRequests(pushRequests, Seq(6, 1024))
+ verifyBlockPushCompleted(blockPusher)
}
test("Number of blocks in a push request are limited by
maxBlocksInFlightPerAddress ") {
+ interceptPushedBlocksForSuccess()
conf.set("spark.reducer.maxBlocksInFlightPerAddress", "1")
val blockPusher = new TestShuffleBlockPusher(conf)
val mergerLocs = dependency.getMergerLocs.map(loc => BlockManagerId("",
loc.host, loc.port))
+ blockPusher.initiateBlockPush(mock(classOf[File]),
+ Array.fill(dependency.partitioner.numPartitions) { 5 }, dependency, 0)
val pushRequests = blockPusher.prepareBlockPushRequests(5, 0, 0, 0,
mock(classOf[File]), Array(2, 2, 2, 2, 2), mergerLocs,
mock(classOf[TransportConf]))
+ blockPusher.runPendingTasks()
assert(pushRequests.length == 5)
verifyPushRequests(pushRequests, Seq(2, 2, 2, 2, 2))
+ verifyBlockPushCompleted(blockPusher)
+ }
+
+ test("SPARK-33701: Ensure all the blocks are pushed before notifying driver"
+
+ " about push completion") {
+ conf.set(REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS, 12)
+ conf.set("spark.shuffle.push.maxBlockBatchSize", "20b")
+ val latch = new CountDownLatch(1)
+ // Use two different remote servers so that 2 separate requests are sent, to
ensure that all the blocks
+ // are pushed before the driver is notified about push completion
+
when(dependency.getMergerLocs).thenReturn(Seq(BlockManagerId("test-client",
"test-client", 1),
+ BlockManagerId("slow-client", "slow-client", 1)))
+ when(shuffleClient.pushBlocks(ArgumentMatchers.eq("slow-client"), any(),
any(), any(), any()))
+ .thenAnswer((invocation: InvocationOnMock) => {
+ val blocks = invocation.getArguments()(2).asInstanceOf[Array[String]]
+ val blockPushListener =
invocation.getArguments()(4).asInstanceOf[BlockPushingListener]
+ latch.await()
+ // Add a small wait here to delay the "onBlockPushSuccess" to mimic
the real world
+ Thread.sleep(500)
+ blocks.foreach { blockId =>
+ blockPushListener.onBlockPushSuccess(blockId,
mock(classOf[ManagedBuffer]))
+ }
+ })
+ when(shuffleClient.pushBlocks(ArgumentMatchers.eq("test-client"), any(),
any(), any(), any()))
+ .thenAnswer((invocation: InvocationOnMock) => {
+ val blocks = invocation.getArguments()(2).asInstanceOf[Array[String]]
+ val blockPushListener =
invocation.getArguments()(4).asInstanceOf[BlockPushingListener]
+ latch.await()
+ blocks.foreach { blockId =>
+ blockPushListener.onBlockPushSuccess(blockId,
mock(classOf[ManagedBuffer]))
+ }
+ })
+ val semaphore = new Semaphore(0)
+ val blockPusher = new ConcurrentTestBlockPusher(conf, semaphore)
+ val mergerLocs = dependency.getMergerLocs.map(loc => BlockManagerId("",
loc.host, loc.port))
+ blockPusher.initiateBlockPush(mock(classOf[File]),
+ Array.fill(dependency.partitioner.numPartitions) { 5 }, dependency, 0)
+ val pushRequests = blockPusher.prepareBlockPushRequests(5, 0, 0, 0,
+ mock(classOf[File]), Array(2, 2, 2, 2, 2), mergerLocs,
mock(classOf[TransportConf]))
+ latch.countDown()
+ latch.countDown()
+ semaphore.acquire()
+ assert(blockPusher.bytesInFlight <= 0)
+ assert(pushRequests.length == 2)
+ verifyPushRequests(pushRequests, Seq(6, 4))
+ verifyBlockPushCompleted(blockPusher)
}
test("Basic block push") {
@@ -133,6 +206,7 @@ class ShuffleBlockPusherSuite extends SparkFunSuite with
BeforeAndAfterEach {
verify(shuffleClient, times(1))
.pushBlocks(any(), any(), any(), any(), any())
assert(pushedBlocks.length == dependency.partitioner.numPartitions)
+ verifyBlockPushCompleted(blockPusher)
ShuffleBlockPusher.stop()
}
@@ -146,6 +220,7 @@ class ShuffleBlockPusherSuite extends SparkFunSuite with
BeforeAndAfterEach {
verify(shuffleClient, times(1))
.pushBlocks(any(), any(), any(), any(), any())
assert(pushedBlocks.length == dependency.partitioner.numPartitions - 1)
+ verifyBlockPushCompleted(pusher)
ShuffleBlockPusher.stop()
}
@@ -159,6 +234,7 @@ class ShuffleBlockPusherSuite extends SparkFunSuite with
BeforeAndAfterEach {
verify(shuffleClient, times(8))
.pushBlocks(any(), any(), any(), any(), any())
assert(pushedBlocks.length == dependency.partitioner.numPartitions)
+ verifyBlockPushCompleted(pusher)
ShuffleBlockPusher.stop()
}
@@ -199,6 +275,7 @@ class ShuffleBlockPusherSuite extends SparkFunSuite with
BeforeAndAfterEach {
verify(shuffleClient, times(4))
.pushBlocks(any(), any(), any(), any(), any())
assert(pushedBlocks.length == 8)
+ verifyBlockPushCompleted(pusher)
ShuffleBlockPusher.stop()
}
@@ -213,6 +290,7 @@ class ShuffleBlockPusherSuite extends SparkFunSuite with
BeforeAndAfterEach {
verify(shuffleClient, times(4))
.pushBlocks(any(), any(), any(), any(), any())
assert(pushedBlocks.length == dependency.partitioner.numPartitions)
+ verifyBlockPushCompleted(pusher)
ShuffleBlockPusher.stop()
}
@@ -279,6 +357,7 @@ class ShuffleBlockPusherSuite extends SparkFunSuite with
BeforeAndAfterEach {
verify(shuffleClient, times(8))
.pushBlocks(any(), any(), any(), any(), any())
assert(pushedBlocks.length == 7)
+ verifyBlockPushCompleted(pusher)
}
test("More blocks are not pushed when a block push fails with too late " +
@@ -333,6 +412,7 @@ class ShuffleBlockPusherSuite extends SparkFunSuite with
BeforeAndAfterEach {
// 2 blocks for each merger locations
assert(pushedBlocks.length == 4)
assert(pusher.unreachableBlockMgrs.size == 2)
+ verifyBlockPushCompleted(pusher)
}
test("SPARK-36255: FileNotFoundException stops the push") {
@@ -359,7 +439,8 @@ class ShuffleBlockPusherSuite extends SparkFunSuite with
BeforeAndAfterEach {
ShuffleBlockPusher.stop()
}
- private class TestShuffleBlockPusher(conf: SparkConf) extends
ShuffleBlockPusher(conf) {
+ private class TestShuffleBlockPusher(
+ conf: SparkConf) extends ShuffleBlockPusher(conf) {
val tasks = new LinkedBlockingQueue[Runnable]
override protected def submitTask(task: Runnable): Unit = {
@@ -385,4 +466,18 @@ class ShuffleBlockPusherSuite extends SparkFunSuite with
BeforeAndAfterEach {
managedBuffer
}
}
+
+ private class ConcurrentTestBlockPusher(conf: SparkConf, semaphore:
Semaphore)
+ extends TestShuffleBlockPusher(conf) {
+ val blockPusher = ThreadUtils.newDaemonFixedThreadPool(1,
"test-block-pusher")
+
+ override protected def submitTask(task: Runnable): Unit = {
+ blockPusher.execute(task)
+ }
+
+ override def notifyDriverAboutPushCompletion(): Unit = {
+ super.notifyDriverAboutPushCompletion()
+ semaphore.release()
+ }
+ }
}
diff --git a/docs/configuration.md b/docs/configuration.md
index 2d4164f..80f17a8 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -3268,4 +3268,20 @@ Push-based shuffle helps improve the reliability and
performance of spark shuffl
</td>
<td>3.2.0</td>
</tr>
+<tr>
+ <td><code>spark.shuffle.push.minShuffleSizeToWait</code></td>
+ <td><code>500m</code></td>
+ <td>
+ The driver waits for merge finalization to complete only if the total shuffle
data size exceeds this threshold. If the total shuffle size is smaller, the driver
finalizes the shuffle output immediately.
+ </td>
+ <td>3.3.0</td>
+</tr>
+<tr>
+ <td><code>spark.shuffle.push.minCompletedPushRatio</code></td>
+ <td><code>1.0</code></td>
+ <td>
+ Minimum fraction of map partitions whose shuffle push must be complete before
the driver starts shuffle merge finalization during push-based shuffle.
+ </td>
+ <td>3.3.0</td>
+</tr>
</table>
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]