This is an automated email from the ASF dual-hosted git repository.

mridulm80 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 500f3097111 [SPARK-40096] Fix finalize shuffle stage slow due to 
connection creation slow
500f3097111 is described below

commit 500f3097111a6bf024acf41400660c199a150350
Author: Kun Wan <[email protected]>
AuthorDate: Thu Sep 22 20:30:51 2022 -0500

    [SPARK-40096] Fix finalize shuffle stage slow due to connection creation 
slow
    
    ### What changes were proposed in this pull request?
    
    This PR runs `scheduleShuffleMergeFinalize()` and sends the `finalizeShuffleMerge` RPCs in two separate thread pools, and stops all outstanding work after `PUSH_BASED_SHUFFLE_MERGE_RESULTS_TIMEOUT` regardless of success or failure.
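
    The core pattern is the standard submit-then-cancel-on-timeout idiom; the sketch below is a minimal, standalone Scala illustration (the host names, the 5 s timeout and the `FinalizeRpcSketch` object are made up for this example, not the actual `DAGScheduler` code):

    ```scala
    import java.util.concurrent.{Callable, Executors, TimeUnit, TimeoutException}

    object FinalizeRpcSketch {
      def main(args: Array[String]): Unit = {
        // Stand-in for the "shuffle-merge-finalize-rpc" pool: one task per merger RPC.
        val rpcPool = Executors.newFixedThreadPool(8)
        val mergers = Seq("hostA", "hostB", "hostC")

        val futures = mergers.map { host =>
          rpcPool.submit(new Callable[String] {
            override def call(): String = {
              // A slow connection to hostA blocks only this task, not the others.
              if (host == "hostA") Thread.sleep(60 * 1000)
              s"finalized on $host"
            }
          })
        }

        // Wait at most the (illustrative) merge-results timeout for all RPCs together.
        val deadline = System.nanoTime() + TimeUnit.SECONDS.toNanos(5)
        futures.foreach { f =>
          val remainingNanos = math.max(0L, deadline - System.nanoTime())
          try println(f.get(remainingNanos, TimeUnit.NANOSECONDS))
          catch { case _: TimeoutException => println("timed out waiting for a merger") }
        }

        // Regardless of success or failure, interrupt whatever is still running so a
        // single slow merger cannot stall finalization of the whole stage.
        futures.foreach(_.cancel(true))
        rpcPool.shutdown()
      }
    }
    ```

    In the actual change the two pools are `shuffle-merge-finalizer` (scheduling and waiting) and `shuffle-merge-finalize-rpc` (sending the RPCs), as shown in the diff below.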
    
    Currently `removeShufflePushMergerLocation` is only called when a shuffle fetch fails. This PR also prevents these merger nodes from being selected as mergeLocations when connection creation fails: such bad merger nodes are added to finalizeBlackNodes, so subsequent shuffle map stages will not try to connect to them.
    
    ### Why are the changes needed?
    
    The DAGScheduler finalizes each shuffle map stage on a single `shuffle-merge-finalizer` thread and holds `clientPool.locks[clientIndex]` while creating the connection to the ESS merger node, so the other `shuffle-merge-finalizer` threads (one stage per thread) can wait for up to `SPARK_NETWORK_IO_CONNECTIONCREATIONTIMEOUT_KEY`.
    Although reducing `SPARK_NETWORK_IO_CONNECTIONCREATIONTIMEOUT_KEY` helps, the total wait time (`SPARK_NETWORK_IO_CONNECTIONCREATIONTIMEOUT_KEY` * lostMergerNodesSize * stageSize) will still be long.
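    As a rough, hypothetical illustration: with a 30 s connection creation timeout, 3 lost merger nodes and 20 shuffle map stages, the finalizer threads could spend on the order of 30 * 3 * 20 = 1800 s (30 minutes) just waiting on connection attempts.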
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    Added a unit test in `DAGSchedulerSuite`.
    
    Closes #37533 from wankunde/SPARK-40096.
    
    Authored-by: Kun Wan <[email protected]>
    Signed-off-by: Mridul Muralidharan <mridul<at>gmail.com>
---
 .../org/apache/spark/internal/config/package.scala |  13 ++-
 .../org/apache/spark/scheduler/DAGScheduler.scala  | 129 +++++++++++++--------
 .../apache/spark/scheduler/DAGSchedulerSuite.scala |  62 +++++++++-
 3 files changed, 155 insertions(+), 49 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala 
b/core/src/main/scala/org/apache/spark/internal/config/package.scala
index 2d2f3c9428a..9aaf6e4c6f0 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/package.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -2327,7 +2327,18 @@ package object config {
         " shuffle is enabled.")
       .version("3.3.0")
       .intConf
-      .createWithDefault(3)
+      .createWithDefault(8)
+
+  private[spark] val PUSH_SHUFFLE_FINALIZE_RPC_THREADS =
+    ConfigBuilder("spark.shuffle.push.sendFinalizeRPCThreads")
+      .internal()
+      .doc("Number of threads used by the driver to send finalize shuffle RPC 
to mergers" +
+        " location and then get MergeStatus. The thread will run for up to " +
+        " PUSH_BASED_SHUFFLE_MERGE_RESULTS_TIMEOUT. The merger ESS may open 
too many files" +
+        " if the finalize rpc is not received.")
+      .version("3.4.0")
+      .intConf
+      .createWithDefault(8)
 
   private[spark] val PUSH_BASED_SHUFFLE_SIZE_MIN_SHUFFLE_SIZE_TO_WAIT =
     ConfigBuilder("spark.shuffle.push.minShuffleSizeToWait")
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala 
b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 9bd4a6f4478..475afd01d00 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -19,7 +19,8 @@ package org.apache.spark.scheduler
 
 import java.io.NotSerializableException
 import java.util.Properties
-import java.util.concurrent.{ConcurrentHashMap, ScheduledFuture, 
TimeoutException, TimeUnit }
+import java.util.concurrent.{ConcurrentHashMap, ExecutorService, 
ScheduledFuture, TimeoutException, TimeUnit}
+import java.util.concurrent.{Future => JFuture}
 import java.util.concurrent.atomic.AtomicInteger
 
 import scala.annotation.tailrec
@@ -273,6 +274,8 @@ private[spark] class DAGScheduler(
   private val shuffleMergeFinalizeNumThreads =
     sc.getConf.get(config.PUSH_BASED_SHUFFLE_MERGE_FINALIZE_THREADS)
 
+  private val shuffleFinalizeRpcThreads = 
sc.getConf.get(config.PUSH_SHUFFLE_FINALIZE_RPC_THREADS)
+
   // Since SparkEnv gets initialized after DAGScheduler, externalShuffleClient 
needs to be
   // initialized lazily
   private lazy val externalShuffleClient: Option[BlockStoreClient] =
@@ -282,13 +285,17 @@ private[spark] class DAGScheduler(
       None
     }
 
-  // Use multi-threaded scheduled executor. The merge finalization task could 
take some time,
-  // depending on the time to establish connections to mergers, and the time 
to get MergeStatuses
-  // from all the mergers.
+  // When push-based shuffle is enabled, spark driver will submit a finalize 
task which will send
+  // a finalize rpc to each merger ESS after the shuffle map stage is 
complete. The merge
+  // finalization takes up to PUSH_BASED_SHUFFLE_MERGE_RESULTS_TIMEOUT.
   private val shuffleMergeFinalizeScheduler =
     ThreadUtils.newDaemonThreadPoolScheduledExecutor("shuffle-merge-finalizer",
       shuffleMergeFinalizeNumThreads)
 
+  // Send finalize RPC tasks to merger ESS
+  private val shuffleSendFinalizeRpcExecutor: ExecutorService =
+    ThreadUtils.newDaemonFixedThreadPool(shuffleFinalizeRpcThreads, 
"shuffle-merge-finalize-rpc")
+
   /**
    * Called by the TaskSetManager to report task's starting.
    */
@@ -2242,70 +2249,98 @@ private[spark] class DAGScheduler(
     val numMergers = stage.shuffleDep.getMergerLocs.length
     val results = (0 until numMergers).map(_ => 
SettableFuture.create[Boolean]())
     externalShuffleClient.foreach { shuffleClient =>
-      if (!registerMergeResults) {
-        results.foreach(_.set(true))
-        // Finalize in separate thread as shuffle merge is a no-op in this case
-        shuffleMergeFinalizeScheduler.schedule(new Runnable {
-          override def run(): Unit = {
-            stage.shuffleDep.getMergerLocs.foreach {
-              case shuffleServiceLoc =>
-                // Sends async request to shuffle service to finalize shuffle 
merge on that host.
-                // Since merge statuses will not be registered in this case,
-                // we pass a no-op listener.
-                shuffleClient.finalizeShuffleMerge(shuffleServiceLoc.host,
-                  shuffleServiceLoc.port, shuffleId, shuffleMergeId,
-                  new MergeFinalizerListener {
-                    override def onShuffleMergeSuccess(statuses: 
MergeStatuses): Unit = {
-                    }
+      val scheduledFutures =
+        if (!registerMergeResults) {
+          results.foreach(_.set(true))
+          // Finalize in separate thread as shuffle merge is a no-op in this 
case
+          stage.shuffleDep.getMergerLocs.map {
+            case shuffleServiceLoc =>
+              // Sends async request to shuffle service to finalize shuffle 
merge on that host.
+              // Since merge statuses will not be registered in this case,
+              // we pass a no-op listener.
+              shuffleSendFinalizeRpcExecutor.submit(new Runnable() {
+                override def run(): Unit = {
+                  shuffleClient.finalizeShuffleMerge(shuffleServiceLoc.host,
+                    shuffleServiceLoc.port, shuffleId, shuffleMergeId,
+                    new MergeFinalizerListener {
+                      override def onShuffleMergeSuccess(statuses: 
MergeStatuses): Unit = {
+                      }
 
-                    override def onShuffleMergeFailure(e: Throwable): Unit = {
-                    }
-                  })
-            }
-          }
-        }, 0, TimeUnit.SECONDS)
-      } else {
-        stage.shuffleDep.getMergerLocs.zipWithIndex.foreach {
-          case (shuffleServiceLoc, index) =>
-            // Sends async request to shuffle service to finalize shuffle 
merge on that host
-            // TODO: SPARK-35536: Cancel finalizeShuffleMerge if the stage is 
cancelled
-            // TODO: during shuffleMergeFinalizeWaitSec
-            shuffleClient.finalizeShuffleMerge(shuffleServiceLoc.host,
-              shuffleServiceLoc.port, shuffleId, shuffleMergeId,
-              new MergeFinalizerListener {
-                override def onShuffleMergeSuccess(statuses: MergeStatuses): 
Unit = {
-                  assert(shuffleId == statuses.shuffleId)
-                  eventProcessLoop.post(RegisterMergeStatuses(stage, 
MergeStatus.
-                    convertMergeStatusesToMergeStatusArr(statuses, 
shuffleServiceLoc)))
-                  results(index).set(true)
+                      override def onShuffleMergeFailure(e: Throwable): Unit = 
{
+                      }
+                    })
                 }
+              })
+          }
+        } else {
+          stage.shuffleDep.getMergerLocs.zipWithIndex.map {
+            case (shuffleServiceLoc, index) =>
+              // Sends async request to shuffle service to finalize shuffle 
merge on that host
+              // TODO: SPARK-35536: Cancel finalizeShuffleMerge if the stage 
is cancelled
+              // TODO: during shuffleMergeFinalizeWaitSec
+              shuffleSendFinalizeRpcExecutor.submit(new Runnable() {
+                override def run(): Unit = {
+                  shuffleClient.finalizeShuffleMerge(shuffleServiceLoc.host,
+                    shuffleServiceLoc.port, shuffleId, shuffleMergeId,
+                    new MergeFinalizerListener {
+                      override def onShuffleMergeSuccess(statuses: 
MergeStatuses): Unit = {
+                        assert(shuffleId == statuses.shuffleId)
+                        eventProcessLoop.post(RegisterMergeStatuses(stage, 
MergeStatus.
+                          convertMergeStatusesToMergeStatusArr(statuses, 
shuffleServiceLoc)))
+                        results(index).set(true)
+                      }
 
-                override def onShuffleMergeFailure(e: Throwable): Unit = {
-                  logWarning(s"Exception encountered when trying to finalize 
shuffle " +
-                    s"merge on ${shuffleServiceLoc.host} for shuffle 
$shuffleId", e)
-                  // Do not fail the future as this would cause dag scheduler 
to prematurely
-                  // give up on waiting for merge results from the remaining 
shuffle services
-                  // if one fails
-                  results(index).set(false)
+                      override def onShuffleMergeFailure(e: Throwable): Unit = 
{
+                        logWarning(s"Exception encountered when trying to 
finalize shuffle " +
+                          s"merge on ${shuffleServiceLoc.host} for shuffle 
$shuffleId", e)
+                        // Do not fail the future as this would cause dag 
scheduler to prematurely
+                        // give up on waiting for merge results from the 
remaining shuffle services
+                        // if one fails
+                        results(index).set(false)
+                      }
+                    })
                 }
               })
+          }
         }
-      }
       // DAGScheduler only waits for a limited amount of time for the merge 
results.
       // It will attempt to submit the next stage(s) irrespective of whether 
merge results
       // from all shuffle services are received or not.
+      var timedOut = false
       try {
         Futures.allAsList(results: _*).get(shuffleMergeResultsTimeoutSec, 
TimeUnit.SECONDS)
       } catch {
         case _: TimeoutException =>
+          timedOut = true
           logInfo(s"Timed out on waiting for merge results from all " +
             s"$numMergers mergers for shuffle $shuffleId")
       } finally {
+        if (timedOut || !registerMergeResults) {
+          cancelFinalizeShuffleMergeFutures(scheduledFutures,
+            if (timedOut) 0L else shuffleMergeResultsTimeoutSec)
+        }
         eventProcessLoop.post(ShuffleMergeFinalized(stage))
       }
     }
   }
 
+  private def cancelFinalizeShuffleMergeFutures(
+      futures: Seq[JFuture[_]],
+      delayInSecs: Long): Unit = {
+
+    def cancelFutures(): Unit = futures.foreach(_.cancel(true))
+
+    if (delayInSecs > 0) {
+      shuffleMergeFinalizeScheduler.schedule(new Runnable {
+        override def run(): Unit = {
+          cancelFutures()
+        }
+      }, delayInSecs, TimeUnit.SECONDS)
+    } else {
+      cancelFutures()
+    }
+  }
+
   private def processShuffleMapStageCompletion(shuffleStage: ShuffleMapStage): 
Unit = {
     markStageAsFinished(shuffleStage)
     logInfo("looking for newly runnable stages")
diff --git 
a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala 
b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index fc7aa06e41e..10cd136d564 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -25,7 +25,9 @@ import scala.annotation.meta.param
 import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map}
 import scala.util.control.NonFatal
 
+import org.mockito.ArgumentMatchers.any
 import org.mockito.Mockito._
+import org.mockito.invocation.InvocationOnMock
 import org.roaringbitmap.RoaringBitmap
 import org.scalatest.concurrent.{Signaler, ThreadSignaler, TimeLimits}
 import org.scalatest.exceptions.TestFailedException
@@ -36,13 +38,14 @@ import org.apache.spark.broadcast.BroadcastManager
 import org.apache.spark.executor.ExecutorMetrics
 import org.apache.spark.internal.config
 import org.apache.spark.internal.config.Tests
+import org.apache.spark.network.shuffle.ExternalBlockStoreClient
 import org.apache.spark.rdd.{DeterministicLevel, RDD}
 import org.apache.spark.resource.{ExecutorResourceRequests, ResourceProfile, 
ResourceProfileBuilder, TaskResourceRequests}
 import org.apache.spark.resource.ResourceUtils.{FPGA, GPU}
 import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
 import org.apache.spark.scheduler.local.LocalSchedulerBackend
 import org.apache.spark.shuffle.{FetchFailedException, 
MetadataFetchFailedException}
-import org.apache.spark.storage.{BlockId, BlockManagerId, BlockManagerMaster}
+import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerId, 
BlockManagerMaster}
 import org.apache.spark.util.{AccumulatorContext, AccumulatorV2, CallSite, 
Clock, LongAccumulator, SystemClock, Utils}
 
 class DAGSchedulerEventProcessLoopTester(dagScheduler: DAGScheduler)
@@ -4440,6 +4443,63 @@ class DAGSchedulerSuite extends SparkFunSuite with 
TempLocalSparkContext with Ti
     assert(mapStatuses.count(s => s != null && s.location.executorId == 
"hostB-exec") === 1)
   }
 
+  Seq(true, false).foreach { registerMergeResults =>
+    test("SPARK-40096: Send finalize events even if shuffle merger blocks 
indefinitely " +
+      s"with registerMergeResults is ${registerMergeResults}") {
+      initPushBasedShuffleConfs(conf)
+
+      sc.conf.set("spark.shuffle.push.results.timeout", "1s")
+      val myScheduler = new MyDAGScheduler(
+        sc,
+        taskScheduler,
+        sc.listenerBus,
+        mapOutputTracker,
+        blockManagerMaster,
+        sc.env,
+        shuffleMergeFinalize = false)
+
+      val mergerLocs = Seq(makeBlockManagerId("hostA"), 
makeBlockManagerId("hostB"))
+      val timeoutSecs = 1
+      val sendRequestsLatch = new CountDownLatch(mergerLocs.size)
+      val completeLatch = new CountDownLatch(mergerLocs.size)
+      val canSendRequestLatch = new CountDownLatch(1)
+
+      val blockStoreClient = mock(classOf[ExternalBlockStoreClient])
+      val blockStoreClientField = 
classOf[BlockManager].getDeclaredField("blockStoreClient")
+      blockStoreClientField.setAccessible(true)
+      blockStoreClientField.set(sc.env.blockManager, blockStoreClient)
+
+      val sentHosts = ArrayBuffer[String]()
+      var hostAInterrupted = false
+      doAnswer { (invoke: InvocationOnMock) =>
+        val host = invoke.getArgument[String](0)
+        sendRequestsLatch.countDown()
+        try {
+          if (host == "hostA") {
+            canSendRequestLatch.await(timeoutSecs * 2, TimeUnit.SECONDS)
+          }
+          sentHosts += host
+        } catch {
+          case _: InterruptedException => hostAInterrupted = true
+        } finally {
+          completeLatch.countDown()
+        }
+      }.when(blockStoreClient).finalizeShuffleMerge(any(), any(), any(), 
any(), any())
+
+      val shuffleMapRdd = new MyRDD(sc, 1, Nil)
+      val shuffleDep = new ShuffleDependency(shuffleMapRdd, new 
HashPartitioner(2))
+      shuffleDep.setMergerLocs(mergerLocs)
+      val shuffleStage = myScheduler.createShuffleMapStage(shuffleDep, 0)
+
+      myScheduler.finalizeShuffleMerge(shuffleStage, registerMergeResults)
+      sendRequestsLatch.await()
+      verify(blockStoreClient, times(2))
+        .finalizeShuffleMerge(any(), any(), any(), any(), any())
+      assert(sentHosts === Seq("hostB"))
+      completeLatch.await()
+      assert(hostAInterrupted)
+    }
+  }
 
   /**
    * Assert that the supplied TaskSet has exactly the given hosts as its 
preferred locations.


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
