This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 9a37c3d59b8e [SPARK-53898][CORE] Fix race conditions between query 
cancellation and task completion triggered by eager shuffle cleanup
9a37c3d59b8e is described below

commit 9a37c3d59b8e80308b626b412d50b114dfe2295d
Author: Yi Wu <[email protected]>
AuthorDate: Tue Nov 11 06:57:19 2025 -0800

    [SPARK-53898][CORE] Fix race conditions between query cancellation and task 
completion triggered by eager shuffle cleanup
    
    ### What changes were proposed in this pull request?
    
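    This PR proposes to explicitly handle the `SparkException` thrown by
    shuffle status operations on a non-existent shuffle ID, to avoid crashing
    the `SparkContext`. Concretely, `MapOutputTrackerMaster` now throws a
    dedicated `ShuffleStatusNotFoundException` (a `SparkException` subclass)
    when the shuffle ID is missing. `DAGSchedulerEventProcessLoop` catches it and
    aborts the affected stage, or only logs a warning if that stage is no longer
    registered, instead of letting the exception propagate.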
    
    ### Why are the changes needed?
    
    When the main query completes, we clean up its shuffle statuses and the
    data files. If a subquery is still running at that point and has not yet
    been completely cancelled, it can throw a `SparkException` from
    `DAGScheduler`. This happens because of operations (e.g.,
    `MapOutputTrackerMaster.registerMapOutput()`) on the now non-existent
    shuffle ID. This unexpected exception can crash the `SparkContext`. See the
    detailed discussion at
    https://github.com/apache/spark/pull/52213#discussion_r2415632474.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #52606 from Ngone51/fix-local-shuffle-cleanup.
    
    Lead-authored-by: Yi Wu <[email protected]>
    Co-authored-by: Wenchen Fan <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../scala/org/apache/spark/MapOutputTracker.scala  | 70 ++++++++++------------
 .../org/apache/spark/scheduler/DAGScheduler.scala  | 18 ++++++
 2 files changed, 49 insertions(+), 39 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala 
b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index 334eb832c4c2..41a1b51a4315 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -857,37 +857,34 @@ private[spark] class MapOutputTrackerMaster(
     }
   }
 
+  private def getShuffleStatusOrError(shuffleId: Int, caller: String): 
ShuffleStatus = {
+    shuffleStatuses.get(shuffleId) match {
+      case Some(shuffleStatus) => shuffleStatus
+      case None => throw new ShuffleStatusNotFoundException(shuffleId, caller)
+    }
+  }
+
   def registerMapOutput(shuffleId: Int, mapIndex: Int, status: MapStatus): 
Boolean = {
-    shuffleStatuses(shuffleId).addMapOutput(mapIndex, status)
+    getShuffleStatusOrError(shuffleId, 
"registerMapOutput").addMapOutput(mapIndex, status)
   }
 
   /** Unregister map output information of the given shuffle, mapper and block 
manager */
   def unregisterMapOutput(shuffleId: Int, mapIndex: Int, bmAddress: 
BlockManagerId): Unit = {
-    shuffleStatuses.get(shuffleId) match {
-      case Some(shuffleStatus) =>
-        shuffleStatus.removeMapOutput(mapIndex, bmAddress)
-        incrementEpoch()
-      case None =>
-        throw new SparkException("unregisterMapOutput called for nonexistent 
shuffle ID")
-    }
+    getShuffleStatusOrError(shuffleId, 
"unregisterMapOutput").removeMapOutput(mapIndex, bmAddress)
+    incrementEpoch()
   }
 
   /** Unregister all map and merge output information of the given shuffle. */
   def unregisterAllMapAndMergeOutput(shuffleId: Int): Unit = {
-    shuffleStatuses.get(shuffleId) match {
-      case Some(shuffleStatus) =>
-        shuffleStatus.removeOutputsByFilter(x => true)
-        shuffleStatus.removeMergeResultsByFilter(x => true)
-        shuffleStatus.removeShuffleMergerLocations()
-        incrementEpoch()
-      case None =>
-        throw new SparkException(
-          s"unregisterAllMapAndMergeOutput called for nonexistent shuffle ID 
$shuffleId.")
-    }
+    val shuffleStatus = getShuffleStatusOrError(shuffleId, 
"unregisterAllMapAndMergeOutput")
+    shuffleStatus.removeOutputsByFilter(x => true)
+    shuffleStatus.removeMergeResultsByFilter(x => true)
+    shuffleStatus.removeShuffleMergerLocations()
+    incrementEpoch()
   }
 
   def registerMergeResult(shuffleId: Int, reduceId: Int, status: MergeStatus): 
Unit = {
-    shuffleStatuses(shuffleId).addMergeResult(reduceId, status)
+    getShuffleStatusOrError(shuffleId, 
"registerMergeResult").addMergeResult(reduceId, status)
   }
 
   def registerMergeResults(shuffleId: Int, statuses: Seq[(Int, MergeStatus)]): 
Unit = {
@@ -899,7 +896,8 @@ private[spark] class MapOutputTrackerMaster(
   def registerShufflePushMergerLocations(
       shuffleId: Int,
       shuffleMergers: Seq[BlockManagerId]): Unit = {
-    shuffleStatuses(shuffleId).registerShuffleMergerLocations(shuffleMergers)
+    getShuffleStatusOrError(shuffleId, "registerShufflePushMergerLocations")
+      .registerShuffleMergerLocations(shuffleMergers)
   }
 
   /**
@@ -918,28 +916,19 @@ private[spark] class MapOutputTrackerMaster(
       reduceId: Int,
       bmAddress: BlockManagerId,
       mapIndex: Option[Int] = None): Unit = {
-    shuffleStatuses.get(shuffleId) match {
-      case Some(shuffleStatus) =>
-        val mergeStatus = shuffleStatus.mergeStatuses(reduceId)
-        if (mergeStatus != null &&
-          (mapIndex.isEmpty || mergeStatus.tracker.contains(mapIndex.get))) {
-          shuffleStatus.removeMergeResult(reduceId, bmAddress)
-          incrementEpoch()
-        }
-      case None =>
-        throw new SparkException("unregisterMergeResult called for nonexistent 
shuffle ID")
+    val shuffleStatus = getShuffleStatusOrError(shuffleId, 
"unregisterMergeResult")
+    val mergeStatus = shuffleStatus.mergeStatuses(reduceId)
+    if (mergeStatus != null &&
+      (mapIndex.isEmpty || mergeStatus.tracker.contains(mapIndex.get))) {
+      shuffleStatus.removeMergeResult(reduceId, bmAddress)
+      incrementEpoch()
     }
   }
 
   def unregisterAllMergeResult(shuffleId: Int): Unit = {
-    shuffleStatuses.get(shuffleId) match {
-      case Some(shuffleStatus) =>
-        shuffleStatus.removeMergeResultsByFilter(x => true)
-        incrementEpoch()
-      case None =>
-        throw new SparkException(
-          s"unregisterAllMergeResult called for nonexistent shuffle ID 
$shuffleId.")
-    }
+    getShuffleStatusOrError(shuffleId, "unregisterAllMergeResult")
+      .removeMergeResultsByFilter(x => true)
+    incrementEpoch()
   }
 
   /** Unregister shuffle data */
@@ -1022,7 +1011,7 @@ private[spark] class MapOutputTrackerMaster(
    * Return statistics about all of the outputs for a given shuffle.
    */
   def getStatistics(dep: ShuffleDependency[_, _, _]): MapOutputStatistics = {
-    shuffleStatuses(dep.shuffleId).withMapStatuses { statuses =>
+    getShuffleStatusOrError(dep.shuffleId, "getStatistics").withMapStatuses { 
statuses =>
       val totalSizes = new Array[Long](dep.partitioner.numPartitions)
       val parallelAggThreshold = conf.get(
         SHUFFLE_MAP_OUTPUT_PARALLEL_AGGREGATION_THRESHOLD)
@@ -1285,6 +1274,9 @@ private[spark] class MapOutputTrackerMaster(
   }
 }
 
+case class ShuffleStatusNotFoundException(shuffleId: Int, methodName: String)
+  extends SparkException(s"$methodName called for nonexistent shuffle ID 
$shuffleId.")
+
 /**
  * Executor-side client for fetching map output info from the driver's 
MapOutputTrackerMaster.
  * Note that this is not used in local-mode; instead, local-mode Executors 
access the
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala 
b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 7d77628c3f08..7c8bea31334b 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -2920,6 +2920,21 @@ private[spark] class DAGScheduler(
     }
   }
 
+  private[scheduler] def handleShuffleStatusNotFoundException(
+      ex: ShuffleStatusNotFoundException): Unit = {
+    val stage = shuffleIdToMapStage.get(ex.shuffleId)
+    val reason = "exceptions encountered while invoking " +
+      s"MapOutputTracker.${ex.methodName} with shuffleId=${ex.shuffleId}"
+    if (stage.isDefined) {
+      abortStage(stage.get, reason, Some(ex))
+      logWarning(s"Aborting stage because of $reason. It is possible that the 
stage is " +
+        "being cancelled.")
+    } else {
+      logWarning(s"Tried aborting stage because of $reason, but the stage was 
not found. " +
+        "It is possible that the stage has been cancelled earlier.")
+    }
+  }
+
   /**
    * Marks a stage as finished and removes it from the list of running stages.
    */
@@ -3192,6 +3207,9 @@ private[scheduler] class 
DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler
     val timerContext = timer.time()
     try {
       doOnReceive(event)
+    } catch {
+      case ex: ShuffleStatusNotFoundException =>
+        dagScheduler.handleShuffleStatusNotFoundException(ex)
     } finally {
       timerContext.stop()
     }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to