This is an automated email from the ASF dual-hosted git repository.

mridulm80 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new b5a1503  [SPARK-32920][SHUFFLE] Finalization of Shuffle push/merge 
with Push based shuffle and preparation step for the reduce stage
b5a1503 is described below

commit b5a15035851bfba12ef1c68d10103cec42cbac0c
Author: Venkata krishnan Sowrirajan <vsowrira...@linkedin.com>
AuthorDate: Thu Jun 10 13:06:15 2021 -0500

    [SPARK-32920][SHUFFLE] Finalization of Shuffle push/merge with Push based 
shuffle and preparation step for the reduce stage
    
    ### What changes were proposed in this pull request?
    
    Summary of the changes made as part of this PR:
    
    1. `DAGScheduler` changes to finalize a `ShuffleMapStage`, which involves 
talking to all the shuffle mergers (`ExternalShuffleService`) and getting all 
the completed merge statuses.
    2. Once the `ShuffleMapStage` finalization is complete, mark the 
`ShuffleMapStage` as finalized, which marks the stage as complete and 
subsequently lets the child stages start.
    3. Also added the relevant tests to `DAGSchedulerSuite` for changes made as 
part of [SPARK-32919](https://issues.apache.org/jira/browse/SPARK-32919)
    
    Lead-authored-by: Min Shen mshenlinkedin.com
    Co-authored-by: Venkata krishnan Sowrirajan vsowrirajanlinkedin.com
    Co-authored-by: Chandni Singh chsinghlinkedin.com
    
    ### Why are the changes needed?
    
    Refer to [SPARK-30602](https://issues.apache.org/jira/browse/SPARK-30602)
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    Added unit tests to DAGSchedulerSuite
    
    Closes #30691 from venkata91/SPARK-32920.
    
    Lead-authored-by: Venkata krishnan Sowrirajan <vsowrira...@linkedin.com>
    Co-authored-by: Min Shen <ms...@linkedin.com>
    Co-authored-by: Chandni Singh <chsi...@linkedin.com>
    Signed-off-by: Mridul Muralidharan <mridul<at>gmail.com>
---
 .../main/scala/org/apache/spark/Dependency.scala   |  38 ++
 .../scala/org/apache/spark/MapOutputTracker.scala  |  44 +-
 .../org/apache/spark/internal/config/package.scala |  23 +-
 .../org/apache/spark/scheduler/DAGScheduler.scala  | 257 +++++++++---
 .../apache/spark/scheduler/DAGSchedulerEvent.scala |   6 +
 .../org/apache/spark/scheduler/StageInfo.scala     |   2 +-
 ...g.apache.spark.scheduler.ExternalClusterManager |   1 +
 .../apache/spark/scheduler/DAGSchedulerSuite.scala | 448 ++++++++++++++++++++-
 8 files changed, 747 insertions(+), 72 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala 
b/core/src/main/scala/org/apache/spark/Dependency.scala
index d21b9d9..0a9acf4 100644
--- a/core/src/main/scala/org/apache/spark/Dependency.scala
+++ b/core/src/main/scala/org/apache/spark/Dependency.scala
@@ -24,6 +24,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.shuffle.{ShuffleHandle, ShuffleWriteProcessor}
 import org.apache.spark.storage.BlockManagerId
+import org.apache.spark.util.Utils
 
 /**
  * :: DeveloperApi ::
@@ -96,12 +97,31 @@ class ShuffleDependency[K: ClassTag, V: ClassTag, C: 
ClassTag](
   val shuffleHandle: ShuffleHandle = 
_rdd.context.env.shuffleManager.registerShuffle(
     shuffleId, this)
 
+  // By default, shuffle merge is enabled for ShuffleDependency if push based 
shuffle
+  // is enabled
+  private[this] var _shuffleMergeEnabled =
+    Utils.isPushBasedShuffleEnabled(rdd.sparkContext.getConf) &&
+    // TODO: SPARK-35547: Push based shuffle is currently unsupported for 
Barrier stages
+    !rdd.isBarrier()
+
+  private[spark] def setShuffleMergeEnabled(shuffleMergeEnabled: Boolean): 
Unit = {
+    _shuffleMergeEnabled = shuffleMergeEnabled
+  }
+
+  def shuffleMergeEnabled : Boolean = _shuffleMergeEnabled
+
   /**
    * Stores the location of the list of chosen external shuffle services for 
handling the
    * shuffle merge requests from mappers in this shuffle map stage.
    */
   private[spark] var mergerLocs: Seq[BlockManagerId] = Nil
 
+  /**
+   * Stores the information about whether the shuffle merge is finalized for 
the shuffle map stage
+   * associated with this shuffle dependency
+   */
+  private[this] var _shuffleMergedFinalized: Boolean = false
+
   def setMergerLocs(mergerLocs: Seq[BlockManagerId]): Unit = {
     if (mergerLocs != null) {
       this.mergerLocs = mergerLocs
@@ -110,6 +130,24 @@ class ShuffleDependency[K: ClassTag, V: ClassTag, C: 
ClassTag](
 
   def getMergerLocs: Seq[BlockManagerId] = mergerLocs
 
+  private[spark] def markShuffleMergeFinalized(): Unit = {
+    _shuffleMergedFinalized = true
+  }
+
+  /**
+   * Returns true if push-based shuffle is disabled for this stage or empty 
RDD,
+   * or if the shuffle merge for this stage is finalized, i.e. the shuffle 
merge
+   * results for all partitions are available.
+   */
+  def shuffleMergeFinalized: Boolean = {
+    // Empty RDD won't be computed therefore shuffle merge finalized should be 
true by default.
+    if (shuffleMergeEnabled && rdd.getNumPartitions > 0) {
+      _shuffleMergedFinalized
+    } else {
+      true
+    }
+  }
+
   _rdd.sparkContext.cleaner.foreach(_.registerShuffleForCleanup(this))
   _rdd.sparkContext.shuffleDriverComponents.registerShuffle(shuffleId)
 }
diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala 
b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index ea9e641..9f2228b 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -214,6 +214,7 @@ private class ShuffleStatus(
   def removeOutputsOnHost(host: String): Unit = withWriteLock {
     logDebug(s"Removing outputs for host ${host}")
     removeOutputsByFilter(x => x.host == host)
+    removeMergeResultsByFilter(x => x.host == host)
   }
 
   /**
@@ -238,6 +239,12 @@ private class ShuffleStatus(
         invalidateSerializedMapOutputStatusCache()
       }
     }
+  }
+
+  /**
+   * Removes all shuffle merge result which satisfies the filter.
+   */
+  def removeMergeResultsByFilter(f: BlockManagerId => Boolean): Unit = 
withWriteLock {
     for (reduceId <- mergeStatuses.indices) {
       if (mergeStatuses(reduceId) != null && 
f(mergeStatuses(reduceId).location)) {
         _numAvailableMergeResults -= 1
@@ -708,15 +715,16 @@ private[spark] class MapOutputTrackerMaster(
     }
   }
 
-  /** Unregister all map output information of the given shuffle. */
-  def unregisterAllMapOutput(shuffleId: Int): Unit = {
+  /** Unregister all map and merge output information of the given shuffle. */
+  def unregisterAllMapAndMergeOutput(shuffleId: Int): Unit = {
     shuffleStatuses.get(shuffleId) match {
       case Some(shuffleStatus) =>
         shuffleStatus.removeOutputsByFilter(x => true)
+        shuffleStatus.removeMergeResultsByFilter(x => true)
         incrementEpoch()
       case None =>
         throw new SparkException(
-          s"unregisterAllMapOutput called for nonexistent shuffle ID 
$shuffleId.")
+          s"unregisterAllMapAndMergeOutput called for nonexistent shuffle ID 
$shuffleId.")
     }
   }
 
@@ -731,25 +739,26 @@ private[spark] class MapOutputTrackerMaster(
   }
 
   /**
-   * Unregisters a merge result corresponding to the reduceId if present. If 
the optional mapId
-   * is specified, it will only unregister the merge result if the mapId is 
part of that merge
+   * Unregisters a merge result corresponding to the reduceId if present. If 
the optional mapIndex
+   * is specified, it will only unregister the merge result if the mapIndex is 
part of that merge
    * result.
    *
    * @param shuffleId the shuffleId.
    * @param reduceId  the reduceId.
    * @param bmAddress block manager address.
-   * @param mapId     the optional mapId which should be checked to see it was 
part of the merge
-   *                  result.
+   * @param mapIndex  the optional mapIndex which should be checked to see it 
was part of the
+   *                  merge result.
    */
   def unregisterMergeResult(
-    shuffleId: Int,
-    reduceId: Int,
-    bmAddress: BlockManagerId,
-    mapId: Option[Int] = None): Unit = {
+      shuffleId: Int,
+      reduceId: Int,
+      bmAddress: BlockManagerId,
+      mapIndex: Option[Int] = None): Unit = {
     shuffleStatuses.get(shuffleId) match {
       case Some(shuffleStatus) =>
         val mergeStatus = shuffleStatus.mergeStatuses(reduceId)
-        if (mergeStatus != null && (mapId.isEmpty || 
mergeStatus.tracker.contains(mapId.get))) {
+        if (mergeStatus != null &&
+          (mapIndex.isEmpty || mergeStatus.tracker.contains(mapIndex.get))) {
           shuffleStatus.removeMergeResult(reduceId, bmAddress)
           incrementEpoch()
         }
@@ -758,6 +767,17 @@ private[spark] class MapOutputTrackerMaster(
     }
   }
 
+  def unregisterAllMergeResult(shuffleId: Int): Unit = {
+    shuffleStatuses.get(shuffleId) match {
+      case Some(shuffleStatus) =>
+        shuffleStatus.removeMergeResultsByFilter(x => true)
+        incrementEpoch()
+      case None =>
+        throw new SparkException(
+          s"unregisterAllMergeResult called for nonexistent shuffle ID 
$shuffleId.")
+    }
+  }
+
   /** Unregister shuffle data */
   def unregisterShuffle(shuffleId: Int): Unit = {
     shuffleStatuses.remove(shuffleId).foreach { shuffleStatus =>
diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala 
b/core/src/main/scala/org/apache/spark/internal/config/package.scala
index 9574416..84bd8cc 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/package.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -2084,6 +2084,27 @@ package object config {
       .booleanConf
       .createWithDefault(false)
 
+  private[spark] val PUSH_BASED_SHUFFLE_MERGE_RESULTS_TIMEOUT =
+    ConfigBuilder("spark.shuffle.push.merge.results.timeout")
+      .doc("Specify the max amount of time DAGScheduler waits for the merge 
results from " +
+        "all remote shuffle services for a given shuffle. DAGScheduler will 
start to submit " +
+        "following stages if not all results are received within the timeout.")
+      .version("3.2.0")
+      .timeConf(TimeUnit.SECONDS)
+      .checkValue(_ >= 0L, "Timeout must be >= 0.")
+      .createWithDefaultString("10s")
+
+  private[spark] val PUSH_BASED_SHUFFLE_MERGE_FINALIZE_TIMEOUT =
+    ConfigBuilder("spark.shuffle.push.merge.finalize.timeout")
+      .doc("Specify the amount of time DAGScheduler waits after all mappers 
finish for " +
+        "a given shuffle map stage before it starts sending merge finalize 
requests to " +
+        "remote shuffle services. This allows the shuffle services some extra 
time to " +
+        "merge as many blocks as possible.")
+      .version("3.2.0")
+      .timeConf(TimeUnit.SECONDS)
+      .checkValue(_ >= 0L, "Timeout must be >= 0.")
+      .createWithDefaultString("10s")
+
   private[spark] val SHUFFLE_MERGER_MAX_RETAINED_LOCATIONS =
     ConfigBuilder("spark.shuffle.push.maxRetainedMergerLocations")
       .doc("Maximum number of shuffle push merger locations cached for push 
based shuffle. " +
@@ -2117,7 +2138,7 @@ package object config {
         s"${SHUFFLE_MERGER_LOCATIONS_MIN_THRESHOLD_RATIO.key} set to 0.05, we 
would need " +
         "at least 50 mergers to enable push based shuffle for that stage.")
       .version("3.1.0")
-      .doubleConf
+      .intConf
       .createWithDefault(5)
 
   private[spark] val SHUFFLE_NUM_PUSH_THREADS =
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala 
b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index b359501..1f37638 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -19,7 +19,7 @@ package org.apache.spark.scheduler
 
 import java.io.NotSerializableException
 import java.util.Properties
-import java.util.concurrent.{ConcurrentHashMap, TimeUnit}
+import java.util.concurrent.{ConcurrentHashMap, TimeoutException, TimeUnit}
 import java.util.concurrent.atomic.AtomicInteger
 
 import scala.annotation.tailrec
@@ -29,12 +29,16 @@ import scala.collection.mutable.{HashMap, HashSet, 
ListBuffer}
 import scala.concurrent.duration._
 import scala.util.control.NonFatal
 
+import com.google.common.util.concurrent.{Futures, SettableFuture}
+
 import org.apache.spark._
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.executor.{ExecutorMetrics, TaskMetrics}
 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.config
 import org.apache.spark.internal.config.Tests.TEST_NO_STAGE_RETRY
+import org.apache.spark.network.shuffle.{BlockStoreClient, 
MergeFinalizerListener}
+import org.apache.spark.network.shuffle.protocol.MergeStatuses
 import org.apache.spark.network.util.JavaUtils
 import org.apache.spark.partial.{ApproximateActionListener, 
ApproximateEvaluator, PartialResult}
 import org.apache.spark.rdd.{RDD, RDDCheckpointData}
@@ -254,6 +258,24 @@ private[spark] class DAGScheduler(
   private val blockManagerMasterDriverHeartbeatTimeout =
     
sc.getConf.get(config.STORAGE_BLOCKMANAGER_MASTER_DRIVER_HEARTBEAT_TIMEOUT).millis
 
+  private val shuffleMergeResultsTimeoutSec =
+    sc.getConf.get(config.PUSH_BASED_SHUFFLE_MERGE_RESULTS_TIMEOUT)
+
+  private val shuffleMergeFinalizeWaitSec =
+    sc.getConf.get(config.PUSH_BASED_SHUFFLE_MERGE_FINALIZE_TIMEOUT)
+
+  // Since SparkEnv gets initialized after DAGScheduler, externalShuffleClient 
needs to be
+  // initialized lazily
+  private lazy val externalShuffleClient: Option[BlockStoreClient] =
+    if (pushBasedShuffleEnabled) {
+      Some(env.blockManager.blockStoreClient)
+    } else {
+      None
+    }
+
+  private val shuffleMergeFinalizeScheduler =
+    
ThreadUtils.newDaemonThreadPoolScheduledExecutor("shuffle-merge-finalizer", 8)
+
   /**
    * Called by the TaskSetManager to report task's starting.
    */
@@ -689,7 +711,10 @@ private[spark] class DAGScheduler(
             dep match {
               case shufDep: ShuffleDependency[_, _, _] =>
                 val mapStage = getOrCreateShuffleMapStage(shufDep, 
stage.firstJobId)
-                if (!mapStage.isAvailable) {
+                // Mark mapStage as available with shuffle outputs only after 
shuffle merge is
+                // finalized with push based shuffle. If not, subsequent 
ShuffleMapStage won't
+                // read from merged output as the MergeStatuses are not 
available.
+                if (!mapStage.isAvailable || 
!mapStage.shuffleDep.shuffleMergeFinalized) {
                   missing += mapStage
                 }
               case narrowDep: NarrowDependency[_] =>
@@ -1271,21 +1296,21 @@ private[spark] class DAGScheduler(
    * locations for block push/merge by getting the historical locations of 
past executors.
    */
   private def prepareShuffleServicesForShuffleMapStage(stage: 
ShuffleMapStage): Unit = {
-    // TODO(SPARK-32920) Handle stage reuse/retry cases separately as without 
finalize
-    // TODO changes we cannot disable shuffle merge for the retry/reuse cases
-    val mergerLocs = sc.schedulerBackend.getShufflePushMergerLocations(
-      stage.shuffleDep.partitioner.numPartitions, stage.resourceProfileId)
-
-    if (mergerLocs.nonEmpty) {
-      stage.shuffleDep.setMergerLocs(mergerLocs)
-      logInfo(s"Push-based shuffle enabled for $stage (${stage.name}) with" +
-        s" ${stage.shuffleDep.getMergerLocs.size} merger locations")
-
-      logDebug("List of shuffle push merger locations " +
-        s"${stage.shuffleDep.getMergerLocs.map(_.host).mkString(", ")}")
-    } else {
-      logInfo("No available merger locations." +
-        s" Push-based shuffle disabled for $stage (${stage.name})")
+    assert(stage.shuffleDep.shuffleMergeEnabled && 
!stage.shuffleDep.shuffleMergeFinalized)
+    if (stage.shuffleDep.getMergerLocs.isEmpty) {
+      val mergerLocs = sc.schedulerBackend.getShufflePushMergerLocations(
+        stage.shuffleDep.partitioner.numPartitions, stage.resourceProfileId)
+      if (mergerLocs.nonEmpty) {
+        stage.shuffleDep.setMergerLocs(mergerLocs)
+        logInfo(s"Push-based shuffle enabled for $stage (${stage.name}) with" +
+          s" ${stage.shuffleDep.getMergerLocs.size} merger locations")
+
+        logDebug("List of shuffle push merger locations " +
+          s"${stage.shuffleDep.getMergerLocs.map(_.host).mkString(", ")}")
+      } else {
+        stage.shuffleDep.setShuffleMergeEnabled(false)
+        logInfo("Push-based shuffle disabled for $stage (${stage.name})")
+      }
     }
   }
 
@@ -1298,7 +1323,9 @@ private[spark] class DAGScheduler(
     // `findMissingPartitions()` returns all partitions every time.
     stage match {
       case sms: ShuffleMapStage if stage.isIndeterminate && !sms.isAvailable =>
-        mapOutputTracker.unregisterAllMapOutput(sms.shuffleDep.shuffleId)
+        // TODO: SPARK-32923: Clean all push-based shuffle metadata like merge 
enabled and
+        // TODO: finalized as we are clearing all the merge results.
+        
mapOutputTracker.unregisterAllMapAndMergeOutput(sms.shuffleDep.shuffleId)
       case _ =>
     }
 
@@ -1318,11 +1345,19 @@ private[spark] class DAGScheduler(
     stage match {
       case s: ShuffleMapStage =>
         outputCommitCoordinator.stageStart(stage = s.id, maxPartitionId = 
s.numPartitions - 1)
-        // Only generate merger location for a given shuffle dependency once. 
This way, even if
-        // this stage gets retried, it would still be merging blocks using the 
same set of
-        // shuffle services.
-        if (pushBasedShuffleEnabled) {
-          prepareShuffleServicesForShuffleMapStage(s)
+        // Only generate merger location for a given shuffle dependency once.
+        if (s.shuffleDep.shuffleMergeEnabled) {
+          if (!s.shuffleDep.shuffleMergeFinalized) {
+            prepareShuffleServicesForShuffleMapStage(s)
+          } else {
+            // Disable Shuffle merge for the retry/reuse of the same shuffle 
dependency if it has
+            // already been merge finalized. If the shuffle dependency was 
previously assigned
+            // merger locations but the corresponding shuffle map stage did 
not complete
+            // successfully, we would still enable push for its retry.
+            s.shuffleDep.setShuffleMergeEnabled(false)
+            logInfo("Push-based shuffle disabled for $stage (${stage.name}) 
since it" +
+              " is already shuffle merge finalized")
+          }
         }
       case s: ResultStage =>
         outputCommitCoordinator.stageStart(
@@ -1678,38 +1713,16 @@ private[spark] class DAGScheduler(
             }
 
             if (runningStages.contains(shuffleStage) && 
shuffleStage.pendingPartitions.isEmpty) {
-              markStageAsFinished(shuffleStage)
-              logInfo("looking for newly runnable stages")
-              logInfo("running: " + runningStages)
-              logInfo("waiting: " + waitingStages)
-              logInfo("failed: " + failedStages)
-
-              // This call to increment the epoch may not be strictly 
necessary, but it is retained
-              // for now in order to minimize the changes in behavior from an 
earlier version of the
-              // code. This existing behavior of always incrementing the epoch 
following any
-              // successful shuffle map stage completion may have benefits by 
causing unneeded
-              // cached map outputs to be cleaned up earlier on executors. In 
the future we can
-              // consider removing this call, but this will require some extra 
investigation.
-              // See 
https://github.com/apache/spark/pull/17955/files#r117385673 for more details.
-              mapOutputTracker.incrementEpoch()
-
-              clearCacheLocs()
-
-              if (!shuffleStage.isAvailable) {
-                // Some tasks had failed; let's resubmit this shuffleStage.
-                // TODO: Lower-level scheduler should also deal with this
-                logInfo("Resubmitting " + shuffleStage + " (" + 
shuffleStage.name +
-                  ") because some of its tasks had failed: " +
-                  shuffleStage.findMissingPartitions().mkString(", "))
-                submitStage(shuffleStage)
+              if (!shuffleStage.shuffleDep.shuffleMergeFinalized &&
+                shuffleStage.shuffleDep.getMergerLocs.nonEmpty) {
+                scheduleShuffleMergeFinalize(shuffleStage)
               } else {
-                markMapStageJobsAsFinished(shuffleStage)
-                submitWaitingChildStages(shuffleStage)
+                processShuffleMapStageCompletion(shuffleStage)
               }
             }
         }
 
-      case FetchFailed(bmAddress, shuffleId, _, mapIndex, _, failureMessage) =>
+      case FetchFailed(bmAddress, shuffleId, _, mapIndex, reduceId, 
failureMessage) =>
         val failedStage = stageIdToStage(task.stageId)
         val mapStage = shuffleIdToMapStage(shuffleId)
 
@@ -1739,10 +1752,18 @@ private[spark] class DAGScheduler(
           if (mapStage.rdd.isBarrier()) {
             // Mark all the map as broken in the map stage, to ensure retry 
all the tasks on
             // resubmitted stage attempt.
-            mapOutputTracker.unregisterAllMapOutput(shuffleId)
+            // TODO: SPARK-35547: Clean all push-based shuffle metadata like 
merge enabled and
+            // TODO: finalized as we are clearing all the merge results.
+            mapOutputTracker.unregisterAllMapAndMergeOutput(shuffleId)
           } else if (mapIndex != -1) {
             // Mark the map whose fetch failed as broken in the map stage
             mapOutputTracker.unregisterMapOutput(shuffleId, mapIndex, 
bmAddress)
+            if (pushBasedShuffleEnabled) {
+              // Possibly unregister the merge result <shuffleId, reduceId>, 
if the FetchFailed
+              // mapIndex is part of the merge result of <shuffleId, reduceId>
+              mapOutputTracker.
+                unregisterMergeResult(shuffleId, reduceId, bmAddress, 
Option(mapIndex))
+            }
           }
 
           if (failedStage.rdd.isBarrier()) {
@@ -1750,7 +1771,7 @@ private[spark] class DAGScheduler(
               case failedMapStage: ShuffleMapStage =>
                 // Mark all the map as broken in the map stage, to ensure 
retry all the tasks on
                 // resubmitted stage attempt.
-                
mapOutputTracker.unregisterAllMapOutput(failedMapStage.shuffleDep.shuffleId)
+                
mapOutputTracker.unregisterAllMapAndMergeOutput(failedMapStage.shuffleDep.shuffleId)
 
               case failedResultStage: ResultStage =>
                 // Abort the failed result stage since we may have committed 
output for some
@@ -1959,7 +1980,7 @@ private[spark] class DAGScheduler(
               case failedMapStage: ShuffleMapStage =>
                 // Mark all the map as broken in the map stage, to ensure 
retry all the tasks on
                 // resubmitted stage attempt.
-                
mapOutputTracker.unregisterAllMapOutput(failedMapStage.shuffleDep.shuffleId)
+                
mapOutputTracker.unregisterAllMapAndMergeOutput(failedMapStage.shuffleDep.shuffleId)
 
               case failedResultStage: ResultStage =>
                 // Abort the failed result stage since we may have committed 
output for some
@@ -2000,6 +2021,130 @@ private[spark] class DAGScheduler(
     }
   }
 
+  /**
+   * Schedules shuffle merge finalize.
+   */
+  private[scheduler] def scheduleShuffleMergeFinalize(stage: ShuffleMapStage): 
Unit = {
+    // TODO: SPARK-33701: Instead of waiting for a constant amount of time for 
finalization
+    // TODO: for all the stages, adaptively tune timeout for merge finalization
+    logInfo(("%s (%s) scheduled for finalizing" +
+      " shuffle merge in %s s").format(stage, stage.name, 
shuffleMergeFinalizeWaitSec))
+    shuffleMergeFinalizeScheduler.schedule(
+      new Runnable {
+        override def run(): Unit = finalizeShuffleMerge(stage)
+      },
+      shuffleMergeFinalizeWaitSec,
+      TimeUnit.SECONDS
+    )
+  }
+
+  /**
+   * DAGScheduler notifies all the remote shuffle services chosen to serve 
shuffle merge request for
+   * the given shuffle map stage to finalize the shuffle merge process for 
this shuffle. This is
+   * invoked in a separate thread to reduce the impact on the DAGScheduler 
main thread, as the
+   * scheduler might need to talk to 1000s of shuffle services to finalize 
shuffle merge.
+   */
+  private[scheduler] def finalizeShuffleMerge(stage: ShuffleMapStage): Unit = {
+    logInfo("%s (%s) finalizing the shuffle merge".format(stage, stage.name))
+    externalShuffleClient.foreach { shuffleClient =>
+      val shuffleId = stage.shuffleDep.shuffleId
+      val numMergers = stage.shuffleDep.getMergerLocs.length
+      val results = (0 until numMergers).map(_ => 
SettableFuture.create[Boolean]())
+
+      stage.shuffleDep.getMergerLocs.zipWithIndex.foreach {
+        case (shuffleServiceLoc, index) =>
+          // Sends async request to shuffle service to finalize shuffle merge 
on that host
+          // TODO: SPARK-35536: Cancel finalizeShuffleMerge if the stage is 
cancelled
+          // TODO: during shuffleMergeFinalizeWaitSec
+          shuffleClient.finalizeShuffleMerge(shuffleServiceLoc.host,
+            shuffleServiceLoc.port, shuffleId,
+            new MergeFinalizerListener {
+              override def onShuffleMergeSuccess(statuses: MergeStatuses): 
Unit = {
+                assert(shuffleId == statuses.shuffleId)
+                eventProcessLoop.post(RegisterMergeStatuses(stage, MergeStatus.
+                  convertMergeStatusesToMergeStatusArr(statuses, 
shuffleServiceLoc)))
+                results(index).set(true)
+              }
+
+              override def onShuffleMergeFailure(e: Throwable): Unit = {
+                logWarning(s"Exception encountered when trying to finalize 
shuffle " +
+                  s"merge on ${shuffleServiceLoc.host} for shuffle 
$shuffleId", e)
+                // Do not fail the future as this would cause dag scheduler to 
prematurely
+                // give up on waiting for merge results from the remaining 
shuffle services
+                // if one fails
+                results(index).set(false)
+              }
+            })
+      }
+      // DAGScheduler only waits for a limited amount of time for the merge 
results.
+      // It will attempt to submit the next stage(s) irrespective of whether 
merge results
+      // from all shuffle services are received or not.
+      try {
+        Futures.allAsList(results: _*).get(shuffleMergeResultsTimeoutSec, 
TimeUnit.SECONDS)
+      } catch {
+        case _: TimeoutException =>
+          logInfo(s"Timed out on waiting for merge results from all " +
+            s"$numMergers mergers for shuffle $shuffleId")
+      } finally {
+        eventProcessLoop.post(ShuffleMergeFinalized(stage))
+      }
+    }
+  }
+
+  private def processShuffleMapStageCompletion(shuffleStage: ShuffleMapStage): 
Unit = {
+    markStageAsFinished(shuffleStage)
+    logInfo("looking for newly runnable stages")
+    logInfo("running: " + runningStages)
+    logInfo("waiting: " + waitingStages)
+    logInfo("failed: " + failedStages)
+
+    // This call to increment the epoch may not be strictly necessary, but it 
is retained
+    // for now in order to minimize the changes in behavior from an earlier 
version of the
+    // code. This existing behavior of always incrementing the epoch following 
any
+    // successful shuffle map stage completion may have benefits by causing 
unneeded
+    // cached map outputs to be cleaned up earlier on executors. In the future 
we can
+    // consider removing this call, but this will require some extra 
investigation.
+    // See https://github.com/apache/spark/pull/17955/files#r117385673 for 
more details.
+    mapOutputTracker.incrementEpoch()
+
+    clearCacheLocs()
+
+    if (!shuffleStage.isAvailable) {
+      // Some tasks had failed; let's resubmit this shuffleStage.
+      // TODO: Lower-level scheduler should also deal with this
+      logInfo("Resubmitting " + shuffleStage + " (" + shuffleStage.name +
+        ") because some of its tasks had failed: " +
+        shuffleStage.findMissingPartitions().mkString(", "))
+      submitStage(shuffleStage)
+    } else {
+      markMapStageJobsAsFinished(shuffleStage)
+      submitWaitingChildStages(shuffleStage)
+    }
+  }
+
+  private[scheduler] def handleRegisterMergeStatuses(
+      stage: ShuffleMapStage,
+      mergeStatuses: Seq[(Int, MergeStatus)]): Unit = {
+    // Register merge statuses if the stage is still running and shuffle merge 
is not finalized yet.
+    // TODO: SPARK-35549: Currently merge statuses results which come after 
shuffle merge
+    // TODO: is finalized is not registered.
+    if (runningStages.contains(stage) && 
!stage.shuffleDep.shuffleMergeFinalized) {
+      mapOutputTracker.registerMergeResults(stage.shuffleDep.shuffleId, 
mergeStatuses)
+    }
+  }
+
+  private[scheduler] def handleShuffleMergeFinalized(stage: ShuffleMapStage): 
Unit = {
+    // Only update MapOutputTracker metadata if the stage is still active. i.e 
not cancelled.
+    if (runningStages.contains(stage)) {
+      stage.shuffleDep.markShuffleMergeFinalized()
+      processShuffleMapStageCompletion(stage)
+    } else {
+      // Unregister all merge results if the stage is currently not
+      // active (i.e. the stage is cancelled)
+      mapOutputTracker.unregisterAllMergeResult(stage.shuffleDep.shuffleId)
+    }
+  }
+
   private def handleResubmittedFailure(task: Task[_], stage: Stage): Unit = {
     logInfo(s"Resubmitted $task, so marking it as still running.")
     stage match {
@@ -2447,6 +2592,12 @@ private[scheduler] class 
DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler
 
     case ResubmitFailedStages =>
       dagScheduler.resubmitFailedStages()
+
+    case RegisterMergeStatuses(stage, mergeStatuses) =>
+      dagScheduler.handleRegisterMergeStatuses(stage, mergeStatuses)
+
+    case ShuffleMergeFinalized(stage) =>
+      dagScheduler.handleShuffleMergeFinalized(stage)
   }
 
   override def onError(e: Throwable): Unit = {
diff --git 
a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala 
b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
index d226fe8..307844c 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
@@ -105,3 +105,9 @@ private[scheduler]
 case class UnschedulableTaskSetRemoved(stageId: Int, stageAttemptId: Int)
   extends DAGSchedulerEvent
 
+private[scheduler] case class RegisterMergeStatuses(
+    stage: ShuffleMapStage, mergeStatuses: Seq[(Int, MergeStatus)])
+  extends DAGSchedulerEvent
+
+private[scheduler] case class ShuffleMergeFinalized(stage: ShuffleMapStage)
+  extends DAGSchedulerEvent
diff --git a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala 
b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala
index 556478d..7b681bf 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala
@@ -42,7 +42,7 @@ class StageInfo(
     val resourceProfileId: Int) {
   /** When this stage was submitted from the DAGScheduler to a TaskScheduler. 
*/
   var submissionTime: Option[Long] = None
-  /** Time when all tasks in the stage completed or when the stage was 
cancelled. */
+  /** Time when the stage completed or when the stage was cancelled. */
   var completionTime: Option[Long] = None
   /** If the stage failed, the reason why. */
   var failureReason: Option[String] = None
diff --git 
a/core/src/test/resources/META-INF/services/org.apache.spark.scheduler.ExternalClusterManager
 
b/core/src/test/resources/META-INF/services/org.apache.spark.scheduler.ExternalClusterManager
index 60054c8..33b162e 100644
--- 
a/core/src/test/resources/META-INF/services/org.apache.spark.scheduler.ExternalClusterManager
+++ 
b/core/src/test/resources/META-INF/services/org.apache.spark.scheduler.ExternalClusterManager
@@ -1,3 +1,4 @@
 org.apache.spark.scheduler.DummyExternalClusterManager
 org.apache.spark.scheduler.MockExternalClusterManager
 org.apache.spark.scheduler.CSMockExternalClusterManager
+org.apache.spark.scheduler.PushBasedClusterManager
diff --git 
a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala 
b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index 4c74e4f..f6e87ee 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -25,9 +25,8 @@ import scala.annotation.meta.param
 import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map}
 import scala.util.control.NonFatal
 
-import org.mockito.Mockito.spy
-import org.mockito.Mockito.times
-import org.mockito.Mockito.verify
+import org.mockito.Mockito._
+import org.roaringbitmap.RoaringBitmap
 import org.scalatest.concurrent.{Signaler, ThreadSignaler, TimeLimits}
 import org.scalatest.exceptions.TestFailedException
 import org.scalatest.time.SpanSugar._
@@ -40,9 +39,10 @@ import org.apache.spark.rdd.{DeterministicLevel, RDD}
 import org.apache.spark.resource.{ExecutorResourceRequests, ResourceProfile, 
ResourceProfileBuilder, TaskResourceRequests}
 import org.apache.spark.resource.ResourceUtils.{FPGA, GPU}
 import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
+import org.apache.spark.scheduler.local.LocalSchedulerBackend
 import org.apache.spark.shuffle.{FetchFailedException, 
MetadataFetchFailedException}
 import org.apache.spark.storage.{BlockId, BlockManagerId, BlockManagerMaster}
-import org.apache.spark.util.{AccumulatorContext, AccumulatorV2, CallSite, 
LongAccumulator, Utils}
+import org.apache.spark.util.{AccumulatorContext, AccumulatorV2, CallSite, 
Clock, LongAccumulator, SystemClock, Utils}
 
 class DAGSchedulerEventProcessLoopTester(dagScheduler: DAGScheduler)
   extends DAGSchedulerEventProcessLoop(dagScheduler) {
@@ -295,6 +295,35 @@ class DAGSchedulerSuite extends SparkFunSuite with 
TempLocalSparkContext with Ti
     }
   }
 
+  class MyDAGScheduler(
+      sc: SparkContext,
+      taskScheduler: TaskScheduler,
+      listenerBus: LiveListenerBus,
+      mapOutputTracker: MapOutputTrackerMaster,
+      blockManagerMaster: BlockManagerMaster,
+      env: SparkEnv,
+      clock: Clock = new SystemClock(),
+      shuffleMergeFinalize: Boolean = true,
+      shuffleMergeRegister: Boolean = true
+  ) extends DAGScheduler(
+      sc, taskScheduler, listenerBus, mapOutputTracker, blockManagerMaster, 
env, clock) {
+    /**
+     * Schedules shuffle merge finalize.
+     */
+    override private[scheduler] def scheduleShuffleMergeFinalize(
+        shuffleMapStage: ShuffleMapStage): Unit = {
+      if (shuffleMergeRegister) {
+        for (part <- 0 until 
shuffleMapStage.shuffleDep.partitioner.numPartitions) {
+          val mergeStatuses = Seq((part, makeMergeStatus("")))
+          handleRegisterMergeStatuses(shuffleMapStage, mergeStatuses)
+        }
+        if (shuffleMergeFinalize) {
+          handleShuffleMergeFinalized(shuffleMapStage)
+        }
+      }
+    }
+  }
+
   override def beforeEach(): Unit = {
     super.beforeEach()
     firstInit = true
@@ -322,13 +351,14 @@ class DAGSchedulerSuite extends SparkFunSuite with 
TempLocalSparkContext with Ti
     broadcastManager = new BroadcastManager(true, sc.getConf)
     mapOutputTracker = spy(new MyMapOutputTrackerMaster(sc.getConf, 
broadcastManager))
     blockManagerMaster = spy(new MyBlockManagerMaster(sc.getConf))
-    scheduler = new DAGScheduler(
+    scheduler = new MyDAGScheduler(
       sc,
       taskScheduler,
       sc.listenerBus,
       mapOutputTracker,
       blockManagerMaster,
       sc.env)
+
     dagEventProcessLoopTester = new 
DAGSchedulerEventProcessLoopTester(scheduler)
   }
 
@@ -3393,6 +3423,359 @@ class DAGSchedulerSuite extends SparkFunSuite with 
TempLocalSparkContext with Ti
     assert(rprofsE === Set())
   }
 
+  private def initPushBasedShuffleConfs(conf: SparkConf) = {
+    conf.set(config.SHUFFLE_SERVICE_ENABLED, true)
+    conf.set(config.PUSH_BASED_SHUFFLE_ENABLED, true)
+    conf.set("spark.master", "pushbasedshuffleclustermanager")
+  }
+
+  test("SPARK-32920: shuffle merge finalization") {
+    initPushBasedShuffleConfs(conf)
+    DAGSchedulerSuite.clearMergerLocs
+    DAGSchedulerSuite.addMergerLocs(Seq("host1", "host2", "host3", "host4", 
"host5"))
+    val parts = 2
+    val shuffleMapRdd = new MyRDD(sc, parts, Nil)
+    val shuffleDep = new ShuffleDependency(shuffleMapRdd, new 
HashPartitioner(parts))
+    val reduceRdd = new MyRDD(sc, parts, List(shuffleDep), tracker = 
mapOutputTracker)
+
+    // Submit a reduce job that depends on the shuffle, which will create a map stage
+    submit(reduceRdd, (0 until parts).toArray)
+    completeShuffleMapStageSuccessfully(0, 0, parts)
+    assert(mapOutputTracker.getNumAvailableMergeResults(shuffleDep.shuffleId) 
== parts)
+    completeNextResultStageWithSuccess(1, 0)
+    assert(results === Map(0 -> 42, 1 -> 42))
+    results.clear()
+    assertDataStructuresEmpty()
+  }
+
+  test("SPARK-32920: merger locations not empty") {
+    initPushBasedShuffleConfs(conf)
+    conf.set(config.SHUFFLE_MERGER_LOCATIONS_MIN_STATIC_THRESHOLD, 3)
+    DAGSchedulerSuite.clearMergerLocs
+    DAGSchedulerSuite.addMergerLocs(Seq("host1", "host2", "host3", "host4", 
"host5"))
+    val parts = 2
+
+    val shuffleMapRdd = new MyRDD(sc, parts, Nil)
+    val shuffleDep = new ShuffleDependency(shuffleMapRdd, new 
HashPartitioner(parts))
+    val reduceRdd = new MyRDD(sc, parts, List(shuffleDep), tracker = 
mapOutputTracker)
+
+    // Submit a reduce job that depends on the shuffle, which will create a map stage
+    submit(reduceRdd, (0 until parts).toArray)
+    completeShuffleMapStageSuccessfully(0, 0, parts)
+    val shuffleStage = 
scheduler.stageIdToStage(0).asInstanceOf[ShuffleMapStage]
+    assert(shuffleStage.shuffleDep.getMergerLocs.nonEmpty)
+
+    assert(mapOutputTracker.getNumAvailableMergeResults(shuffleDep.shuffleId) 
== parts)
+    completeNextResultStageWithSuccess(1, 0)
+    assert(results === Map(0 -> 42, 1 -> 42))
+
+    results.clear()
+    assertDataStructuresEmpty()
+  }
+
+  test("SPARK-32920: merger locations reuse from shuffle dependency") {
+    initPushBasedShuffleConfs(conf)
+    conf.set(config.SHUFFLE_MERGER_MAX_RETAINED_LOCATIONS, 3)
+    DAGSchedulerSuite.clearMergerLocs
+    DAGSchedulerSuite.addMergerLocs(Seq("host1", "host2", "host3", "host4", 
"host5"))
+    val parts = 2
+
+    val shuffleMapRdd = new MyRDD(sc, parts, Nil)
+    val shuffleDep = new ShuffleDependency(shuffleMapRdd, new 
HashPartitioner(parts))
+    val reduceRdd = new MyRDD(sc, parts, List(shuffleDep), tracker = 
mapOutputTracker)
+    submit(reduceRdd, Array(0, 1))
+
+    completeShuffleMapStageSuccessfully(0, 0, parts)
+    assert(shuffleDep.getMergerLocs.nonEmpty)
+    val mergerLocs = shuffleDep.getMergerLocs
+    completeNextResultStageWithSuccess(1, 0 )
+
+    // submit another job w/ the shared dependency, and have a fetch failure
+    val reduce2 = new MyRDD(sc, 2, List(shuffleDep))
+    submit(reduce2, Array(0, 1))
+    // Note that the stage numbering here is only b/c the shared dependency 
produces a new, skipped
+    // stage.  If instead it reused the existing stage, then this would be 
stage 2
+    completeNextStageWithFetchFailure(3, 0, shuffleDep)
+    scheduler.resubmitFailedStages()
+
+    assert(scheduler.runningStages.nonEmpty)
+    assert(scheduler.stageIdToStage(2)
+      .asInstanceOf[ShuffleMapStage].shuffleDep.getMergerLocs.nonEmpty)
+    val newMergerLocs = scheduler.stageIdToStage(2)
+      .asInstanceOf[ShuffleMapStage].shuffleDep.getMergerLocs
+
+    // Check that the same merger locs are reused for the new stage with the shared 
shuffle dependency
+    assert(mergerLocs.zip(newMergerLocs).forall(x => x._1.host == x._2.host))
+    completeShuffleMapStageSuccessfully(2, 0, 2)
+    completeNextResultStageWithSuccess(3, 1, idx => idx + 1234)
+    assert(results === Map(0 -> 1234, 1 -> 1235))
+
+    assertDataStructuresEmpty()
+  }
+
+  test("SPARK-32920: Disable shuffle merge due to not enough mergers 
available") {
+    initPushBasedShuffleConfs(conf)
+    conf.set(config.SHUFFLE_MERGER_LOCATIONS_MIN_STATIC_THRESHOLD, 6)
+    DAGSchedulerSuite.clearMergerLocs
+    DAGSchedulerSuite.addMergerLocs(Seq("host1", "host2", "host3", "host4", 
"host5"))
+    val parts = 7
+
+    val shuffleMapRdd = new MyRDD(sc, parts, Nil)
+    val shuffleDep = new ShuffleDependency(shuffleMapRdd, new 
HashPartitioner(parts))
+    val reduceRdd = new MyRDD(sc, parts, List(shuffleDep), tracker = 
mapOutputTracker)
+
+    // Submit a reduce job that depends on the shuffle, which will create a map stage
+    submit(reduceRdd, (0 until parts).toArray)
+    completeShuffleMapStageSuccessfully(0, 0, parts)
+    val shuffleStage = 
scheduler.stageIdToStage(0).asInstanceOf[ShuffleMapStage]
+    assert(!shuffleStage.shuffleDep.shuffleMergeEnabled)
+
+    completeNextResultStageWithSuccess(1, 0)
+    assert(results === Map(2 -> 42, 5 -> 42, 4 -> 42, 1 -> 42, 3 -> 42, 6 -> 
42, 0 -> 42))
+
+    results.clear()
+    assertDataStructuresEmpty()
+  }
+
+  test("SPARK-32920: Ensure child stage should not start before all the" +
+      " parent stages are completed with shuffle merge finalized for all the 
parent stages") {
+    initPushBasedShuffleConfs(conf)
+    DAGSchedulerSuite.clearMergerLocs
+    DAGSchedulerSuite.addMergerLocs(Seq("host1", "host2", "host3", "host4", 
"host5"))
+    val parts = 1
+    val shuffleMapRdd1 = new MyRDD(sc, parts, Nil)
+    val shuffleDep1 = new ShuffleDependency(shuffleMapRdd1, new 
HashPartitioner(parts))
+
+    val shuffleMapRdd2 = new MyRDD(sc, parts, Nil)
+    val shuffleDep2 = new ShuffleDependency(shuffleMapRdd2, new 
HashPartitioner(parts))
+
+    val reduceRdd = new MyRDD(sc, parts, List(shuffleDep1, shuffleDep2), 
tracker = mapOutputTracker)
+
+    // Submit a reduce job
+    submit(reduceRdd, (0 until parts).toArray)
+
+    complete(taskSets(0), Seq((Success, makeMapStatus("hostA", 1))))
+    val shuffleStage1 = 
scheduler.stageIdToStage(0).asInstanceOf[ShuffleMapStage]
+    assert(shuffleStage1.shuffleDep.getMergerLocs.nonEmpty)
+
+    complete(taskSets(1), Seq((Success, makeMapStatus("hostA", 1))))
+    val shuffleStage2 = 
scheduler.stageIdToStage(1).asInstanceOf[ShuffleMapStage]
+    assert(shuffleStage2.shuffleDep.getMergerLocs.nonEmpty)
+
+    assert(shuffleStage2.shuffleDep.shuffleMergeFinalized)
+    assert(shuffleStage1.shuffleDep.shuffleMergeFinalized)
+    assert(mapOutputTracker.getNumAvailableMergeResults(shuffleDep1.shuffleId) 
== parts)
+    assert(mapOutputTracker.getNumAvailableMergeResults(shuffleDep2.shuffleId) 
== parts)
+
+    completeNextResultStageWithSuccess(2, 0)
+    assert(results === Map(0 -> 42))
+    results.clear()
+    assertDataStructuresEmpty()
+  }
+
+  test("SPARK-32920: Reused ShuffleDependency with Shuffle Merge disabled for 
the corresponding" +
+      " ShuffleDependency should not cause DAGScheduler to hang") {
+    initPushBasedShuffleConfs(conf)
+    conf.set(config.SHUFFLE_MERGER_LOCATIONS_MIN_STATIC_THRESHOLD, 10)
+    DAGSchedulerSuite.clearMergerLocs
+    DAGSchedulerSuite.addMergerLocs(Seq("host1", "host2", "host3", "host4", 
"host5"))
+    val parts = 20
+
+    val shuffleMapRdd = new MyRDD(sc, parts, Nil)
+    val shuffleDep = new ShuffleDependency(shuffleMapRdd, new 
HashPartitioner(parts))
+    val reduceRdd = new MyRDD(sc, parts, List(shuffleDep), tracker = 
mapOutputTracker)
+    val partitions = (0 until parts).toArray
+    submit(reduceRdd, partitions)
+
+    completeShuffleMapStageSuccessfully(0, 0, parts)
+    val shuffleStage = 
scheduler.stageIdToStage(0).asInstanceOf[ShuffleMapStage]
+    assert(!shuffleStage.shuffleDep.shuffleMergeEnabled)
+
+    completeNextResultStageWithSuccess(1, 0)
+    val reduce2 = new MyRDD(sc, parts, List(shuffleDep))
+    submit(reduce2, partitions)
+    // Stage 2 should not be executed as it should reuse the already computed 
shuffle output
+    assert(scheduler.stageIdToStage(2).latestInfo.taskMetrics == null)
+    completeNextResultStageWithSuccess(3, 0, idx => idx + 1234)
+
+    val expected = (0 until parts).map(idx => (idx, idx + 1234))
+    assert(results === expected.toMap)
+
+    assertDataStructuresEmpty()
+  }
+
+  test("SPARK-32920: Reused ShuffleDependency with Shuffle Merge disabled for 
the corresponding" +
+      " ShuffleDependency with shuffle data loss should recompute missing 
partitions") {
+    initPushBasedShuffleConfs(conf)
+    conf.set(config.SHUFFLE_MERGER_LOCATIONS_MIN_STATIC_THRESHOLD, 10)
+    DAGSchedulerSuite.clearMergerLocs
+    DAGSchedulerSuite.addMergerLocs(Seq("host1", "host2", "host3", "host4", 
"host5"))
+    val parts = 20
+
+    val shuffleMapRdd = new MyRDD(sc, parts, Nil)
+    val shuffleDep = new ShuffleDependency(shuffleMapRdd, new 
HashPartitioner(parts))
+    val reduceRdd = new MyRDD(sc, parts, List(shuffleDep), tracker = 
mapOutputTracker)
+    val partitions = (0 until parts).toArray
+    submit(reduceRdd, partitions)
+
+    completeShuffleMapStageSuccessfully(0, 0, parts)
+    val shuffleStage = 
scheduler.stageIdToStage(0).asInstanceOf[ShuffleMapStage]
+    assert(!shuffleStage.shuffleDep.shuffleMergeEnabled)
+
+    completeNextResultStageWithSuccess(1, 0)
+
+    DAGSchedulerSuite.clearMergerLocs
+    val hosts = (6 to parts).map {x => s"Host$x" }
+    DAGSchedulerSuite.addMergerLocs(hosts)
+
+    val reduce2 = new MyRDD(sc, parts, List(shuffleDep))
+    submit(reduce2, partitions)
+    // Note that the stage numbering here is only b/c the shared dependency 
produces a new, skipped
+    // stage. If instead it reused the existing stage, then this would be 
stage 2
+    completeNextStageWithFetchFailure(3, 0, shuffleDep)
+    scheduler.resubmitFailedStages()
+
+    // Make sure shuffle merge is disabled for the retry
+    val stage2 = scheduler.stageIdToStage(2).asInstanceOf[ShuffleMapStage]
+    assert(!stage2.shuffleDep.shuffleMergeEnabled)
+
+    // the scheduler now creates a new task set to regenerate the missing map 
output, but this time
+    // using a different stage, the "skipped" one
+    assert(scheduler.stageIdToStage(2).latestInfo.taskMetrics != null)
+    completeShuffleMapStageSuccessfully(2, 0, 2)
+    completeNextResultStageWithSuccess(3, 1, idx => idx + 1234)
+
+    val expected = (0 until parts).map(idx => (idx, idx + 1234))
+    assert(results === expected.toMap)
+    assertDataStructuresEmpty()
+  }
+
+  test("SPARK-32920: Empty RDD should not be computed") {
+    initPushBasedShuffleConfs(conf)
+    val data = sc.emptyRDD[Int]
+    data.sortBy(x => x).collect()
+    assert(taskSets.isEmpty)
+    assertDataStructuresEmpty()
+  }
+
+  test("SPARK-32920: Merge results should be unregistered if the running stage 
is cancelled" +
+    " before shuffle merge is finalized") {
+    initPushBasedShuffleConfs(conf)
+    DAGSchedulerSuite.clearMergerLocs
+    DAGSchedulerSuite.addMergerLocs(Seq("host1", "host2", "host3", "host4", 
"host5"))
+    scheduler = new MyDAGScheduler(
+      sc,
+      taskScheduler,
+      sc.listenerBus,
+      mapOutputTracker,
+      blockManagerMaster,
+      sc.env,
+      shuffleMergeFinalize = false)
+    dagEventProcessLoopTester = new 
DAGSchedulerEventProcessLoopTester(scheduler)
+
+    val parts = 2
+    val shuffleMapRdd = new MyRDD(sc, parts, Nil)
+    val shuffleDep = new ShuffleDependency(shuffleMapRdd, new 
HashPartitioner(parts))
+    val reduceRdd = new MyRDD(sc, parts, List(shuffleDep), tracker = 
mapOutputTracker)
+
+    // Submit a reduce job that depends on the shuffle, which will create a map stage
+    submit(reduceRdd, (0 until parts).toArray)
+    // Complete shuffle map stage successfully on hostA
+    complete(taskSets(0), taskSets(0).tasks.zipWithIndex.map {
+      case (task, _) =>
+        (Success, makeMapStatus("hostA", parts))
+    }.toSeq)
+
+    assert(mapOutputTracker.getNumAvailableMergeResults(shuffleDep.shuffleId) 
== parts)
+    val shuffleMapStageToCancel = 
scheduler.stageIdToStage(0).asInstanceOf[ShuffleMapStage]
+    runEvent(StageCancelled(0, Option("Explicit cancel check")))
+    scheduler.handleShuffleMergeFinalized(shuffleMapStageToCancel)
+    assert(mapOutputTracker.getNumAvailableMergeResults(shuffleDep.shuffleId) 
== 0)
+  }
+
+  test("SPARK-32920: SPARK-35549: Merge results should not get registered" +
+    " after shuffle merge finalization") {
+    initPushBasedShuffleConfs(conf)
+    DAGSchedulerSuite.clearMergerLocs
+    DAGSchedulerSuite.addMergerLocs(Seq("host1", "host2", "host3", "host4", 
"host5"))
+
+    scheduler = new MyDAGScheduler(
+      sc,
+      taskScheduler,
+      sc.listenerBus,
+      mapOutputTracker,
+      blockManagerMaster,
+      sc.env,
+      shuffleMergeFinalize = false,
+      shuffleMergeRegister = false)
+    dagEventProcessLoopTester = new 
DAGSchedulerEventProcessLoopTester(scheduler)
+
+    val parts = 2
+    val shuffleMapRdd = new MyRDD(sc, parts, Nil)
+    val shuffleDep = new ShuffleDependency(shuffleMapRdd, new 
HashPartitioner(parts))
+    val reduceRdd = new MyRDD(sc, parts, List(shuffleDep), tracker = 
mapOutputTracker)
+
+    // Submit a reduce job that depends on the shuffle, which will create a map stage
+    submit(reduceRdd, (0 until parts).toArray)
+    // Complete shuffle map stage successfully on hostA
+    complete(taskSets(0), taskSets(0).tasks.zipWithIndex.map {
+      case (task, _) =>
+        (Success, makeMapStatus("hostA", parts))
+    }.toSeq)
+    val shuffleMapStage = 
scheduler.stageIdToStage(0).asInstanceOf[ShuffleMapStage]
+    scheduler.handleRegisterMergeStatuses(shuffleMapStage, Seq((0, 
makeMergeStatus("hostA"))))
+    scheduler.handleShuffleMergeFinalized(shuffleMapStage)
+    scheduler.handleRegisterMergeStatuses(shuffleMapStage, Seq((1, 
makeMergeStatus("hostA"))))
+    assert(mapOutputTracker.getNumAvailableMergeResults(shuffleDep.shuffleId) 
== 1)
+  }
+
+  test("SPARK-32920: Disable push based shuffle in the case of a barrier 
stage") {
+    initPushBasedShuffleConfs(conf)
+    DAGSchedulerSuite.clearMergerLocs
+    DAGSchedulerSuite.addMergerLocs(Seq("host1", "host2", "host3", "host4", 
"host5"))
+
+    val parts = 2
+    val shuffleMapRdd = new MyRDD(sc, parts, Nil).barrier().mapPartitions(iter 
=> iter)
+    val shuffleDep = new ShuffleDependency(shuffleMapRdd, new 
HashPartitioner(parts))
+    val reduceRdd = new MyRDD(sc, parts, List(shuffleDep), tracker = 
mapOutputTracker)
+
+    // Submit a reduce job that depends on the shuffle, which will create a map stage
+    submit(reduceRdd, (0 until parts).toArray)
+    completeShuffleMapStageSuccessfully(0, 0, reduceRdd.partitions.length)
+    val shuffleMapStage = 
scheduler.stageIdToStage(0).asInstanceOf[ShuffleMapStage]
+    assert(!shuffleMapStage.shuffleDep.shuffleMergeEnabled)
+  }
+
+  test("SPARK-32920: metadata fetch failure should not unregister map status") 
{
+    initPushBasedShuffleConfs(conf)
+    val parts = 2
+    val shuffleMapRdd = new MyRDD(sc, parts, Nil)
+    val shuffleDep = new ShuffleDependency(shuffleMapRdd, new 
HashPartitioner(parts))
+
+    val reduceRdd = new MyRDD(sc, parts, List(shuffleDep), tracker = 
mapOutputTracker)
+    submit(reduceRdd, (0 until parts).toArray)
+    assert(taskSets.length == 1)
+
+    // Complete shuffle map stage successfully on hostA
+    complete(taskSets(0), taskSets(0).tasks.zipWithIndex.map {
+      case (task, _) =>
+        (Success, makeMapStatus("hostA", parts))
+    }.toSeq)
+
+    assert(mapOutputTracker.getNumAvailableOutputs(shuffleDep.shuffleId) == 
parts)
+
+    // Finish the first task
+    runEvent(makeCompletionEvent(
+      taskSets(1).tasks(0), Success, makeMapStatus("hostA", parts)))
+
+    // The second task fails with a MetadataFetchFailedException.
+    val metadataFetchFailedEx = new MetadataFetchFailedException(
+      shuffleDep.shuffleId, 1, "metadata failure");
+    runEvent(makeCompletionEvent(
+      taskSets(1).tasks(1), metadataFetchFailedEx.toTaskFailedReason, null))
+    assert(mapOutputTracker.getNumAvailableOutputs(shuffleDep.shuffleId) == 
parts)
+  }
+
   /**
    * Assert that the supplied TaskSet has exactly the given hosts as its 
preferred locations.
    * Note that this checks only the host and not the executor ID.
@@ -3448,14 +3831,69 @@ class DAGSchedulerSuite extends SparkFunSuite with 
TempLocalSparkContext with Ti
 }
 
 object DAGSchedulerSuite {
+  val mergerLocs = ArrayBuffer[BlockManagerId]()
+
   def makeMapStatus(host: String, reduces: Int, sizes: Byte = 2, mapTaskId: 
Long = -1): MapStatus =
     MapStatus(makeBlockManagerId(host), Array.fill[Long](reduces)(sizes), 
mapTaskId)
 
   def makeBlockManagerId(host: String): BlockManagerId = {
     BlockManagerId(host + "-exec", host, 12345)
   }
+
+  def makeMergeStatus(host: String, size: Long = 1000): MergeStatus =
+    MergeStatus(makeBlockManagerId(host), mock(classOf[RoaringBitmap]), size)
+
+  def addMergerLocs(locs: Seq[String]): Unit = {
+    locs.foreach { loc => mergerLocs.append(makeBlockManagerId(loc)) }
+  }
+
+  def clearMergerLocs: Unit = mergerLocs.clear()
+
 }
 
 object FailThisAttempt {
   val _fail = new AtomicBoolean(true)
 }
+
+private class PushBasedSchedulerBackend(
+    conf: SparkConf,
+    scheduler: TaskSchedulerImpl,
+    cores: Int) extends LocalSchedulerBackend(conf, scheduler, cores) {
+
+  override def getShufflePushMergerLocations(
+      numPartitions: Int,
+      resourceProfileId: Int): Seq[BlockManagerId] = {
+    val mergerLocations = 
Utils.randomize(DAGSchedulerSuite.mergerLocs).take(numPartitions)
+    if (mergerLocations.size < numPartitions && mergerLocations.size <
+      conf.getInt(config.SHUFFLE_MERGER_LOCATIONS_MIN_STATIC_THRESHOLD.key, 
5)) {
+      Seq.empty[BlockManagerId]
+    } else {
+      mergerLocations
+    }
+  }
+
+  // Currently this is only used in tests specifically for Push based shuffle
+  override def maxNumConcurrentTasks(rp: ResourceProfile): Int = {
+    2
+  }
+}
+
+private class PushBasedClusterManager extends ExternalClusterManager {
+  def canCreate(masterURL: String): Boolean = masterURL == 
"pushbasedshuffleclustermanager"
+
+  override def createSchedulerBackend(
+      sc: SparkContext,
+      masterURL: String,
+      scheduler: TaskScheduler): SchedulerBackend = {
+    new PushBasedSchedulerBackend(sc.conf, 
scheduler.asInstanceOf[TaskSchedulerImpl], 1)
+  }
+
+  override def createTaskScheduler(
+      sc: SparkContext,
+      masterURL: String): TaskScheduler = new TaskSchedulerImpl(sc, 1, isLocal 
= true)
+
+  override def initialize(scheduler: TaskScheduler, backend: 
SchedulerBackend): Unit = {
+    val sc = scheduler.asInstanceOf[TaskSchedulerImpl]
+    sc.initialize(backend)
+  }
+}

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to