This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new c6ca25a8028b [SPARK-54956][CORE] Unify the indeterminate shuffle retry 
solution
c6ca25a8028b is described below

commit c6ca25a8028bae0b7b8f6a1bf0105b1dd236ceed
Author: Tengfei Huang <[email protected]>
AuthorDate: Mon Jan 19 13:39:32 2026 +0800

    [SPARK-54956][CORE] Unify the indeterminate shuffle retry solution
    
    ### What changes were proposed in this pull request?
    Unify the approach to handle indeterminate shuffle retry.
    
    Currently we have two mechanisms to detect an indeterminate shuffle retry:
    shuffle checksum mismatch detection and the stage's indeterminate attribute.
    In either case, once an indeterminate shuffle retry is detected, we must
    ensure that all succeeding stages do a full retry.
    
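    This PR proposes unifying both mechanisms onto the newer approach that was
    introduced with shuffle checksum mismatch detection, which covers more cases.
    Under this approach, the affected stages are rolled back through
    `rollbackSucceedingStages`, and results from old stage attempts are ignored
    via `maxAttemptIdToIgnore`.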
    
    ### Why are the changes needed?
    Handling indeterminate shuffle retries with a single approach covers more
    scenarios and reduces code maintenance effort.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Unit tests (new and updated cases in DAGSchedulerSuite).
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #53782 from ivoson/SPARK-54956.
    
    Lead-authored-by: Tengfei Huang <[email protected]>
    Co-authored-by: Wenchen Fan <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../org/apache/spark/scheduler/DAGScheduler.scala  | 148 +++++++++++----------
 .../apache/spark/scheduler/ShuffleMapStage.scala   |  18 ++-
 .../scala/org/apache/spark/scheduler/Stage.scala   |  16 +--
 .../apache/spark/scheduler/DAGSchedulerSuite.scala | 104 +++++++++++++--
 4 files changed, 185 insertions(+), 101 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala 
b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 309fcb28b757..67785a0ce9ea 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -1555,34 +1555,28 @@ private[spark] class DAGScheduler(
   private def submitMissingTasks(stage: Stage, jobId: Int): Unit = {
     logDebug("submitMissingTasks(" + stage + ")")
 
-    // Before find missing partition, do the intermediate state clean work 
first.
-    // The operation here can make sure for the partially completed 
intermediate stage,
-    // `findMissingPartitions()` returns all partitions every time.
+    // For statically indeterminate stages being retried, we trigger rollback 
BEFORE task
+    // submission. This is more efficient than deferring to task completion 
because:
+    // 1. It avoids submitting a partial stage that would need to be cancelled
+    // 2. It ensures findMissingPartitions() returns ALL partitions for the 
retry
+    //
+    // For runtime detection (checksum mismatch), we must defer to task 
completion because
+    // we don't know the stage is indeterminate until we see the checksum 
differ.
     stage match {
       case sms: ShuffleMapStage if !sms.isAvailable =>
-        if (!sms.shuffleDep.checksumMismatchFullRetryEnabled && 
stage.isIndeterminate) {
-          // already executed at least once
-          if (sms.getNextAttemptId > 0) {
-            // While we previously validated possible rollbacks during the 
handling of a FetchFailure,
-            // where we were fetching from an indeterminate source map stages, 
this later check
-            // covers additional cases like recalculating an indeterminate 
stage after an executor
-            // loss. Moreover, because this check occurs later in the process, 
if a result stage task
-            // has successfully completed, we can detect this and abort the 
job, as rolling back a
-            // result stage is not possible.
-            val stagesToRollback = collectSucceedingStages(sms)
-            filterAndAbortUnrollbackableStages(stagesToRollback)
-            // stages which cannot be rolled back were aborted which leads to 
removing the
-            // the dependant job(s) from the active jobs set
-            val numActiveJobsWithStageAfterRollback =
-              activeJobs.count(job => 
stagesToRollback.contains(job.finalStage))
-            if (numActiveJobsWithStageAfterRollback == 0) {
-              logInfo(log"All jobs depending on the indeterminate stage " +
-                log"(${MDC(STAGE_ID, stage.id)}) were aborted so this stage is 
not needed anymore.")
-              return
-            }
+        if (sms.isStaticallyIndeterminate &&
+            !sms.shuffleDep.checksumMismatchFullRetryEnabled &&
+            sms.getNextAttemptId > 0) {
+          // The `rollbackCurrentStage = true` parameter ensures the current 
submitting stage
+          // is included in the cleanup for a fresh start: clearing its 
shuffle outputs, marking
+          // old task results to be ignored, and creating a new shuffle merge 
state for the
+          // upcoming retry.
+          rollbackSucceedingStages(sms, rollbackCurrentStage = true)
+          if (!stageIdToStage.contains(stage.id)) {
+            logInfo(log"All jobs depending on the indeterminate stage " +
+              log"(${MDC(STAGE_ID, stage.id)}) were aborted so this stage is 
not needed anymore.")
+            return
           }
-          
mapOutputTracker.unregisterAllMapAndMergeOutput(sms.shuffleDep.shuffleId)
-          sms.shuffleDep.newShuffleMergeState()
         }
       case _ =>
     }
@@ -1880,7 +1874,9 @@ private[spark] class DAGScheduler(
         if (shuffleStage.isAvailable) {
           computedTotalSize
         } else {
-          if (shuffleStage.isIndeterminate) {
+          // For indeterminate stages, don't use partial merge results as the 
stage output
+          // may change on retry
+          if (shuffleStage.isStaticallyIndeterminate) {
             0L
           } else {
             computedTotalSize
@@ -1903,8 +1899,14 @@ private[spark] class DAGScheduler(
    * the map tasks are re-tried. For stages where rollback and retry all tasks 
are not possible,
    * we will need to abort the stages.
    */
-  private[scheduler] def rollbackSucceedingStages(mapStage: ShuffleMapStage): 
Unit = {
-    val stagesToRollback = collectSucceedingStages(mapStage).filterNot(_ == 
mapStage)
+  private[scheduler] def rollbackSucceedingStages(
+      mapStage: ShuffleMapStage,
+      rollbackCurrentStage: Boolean = false): Unit = {
+    val stagesToRollback = if (rollbackCurrentStage) {
+      collectSucceedingStages(mapStage)
+    } else {
+      collectSucceedingStages(mapStage).filterNot(_ == mapStage)
+    }
     val stagesCanRollback = 
filterAndAbortUnrollbackableStages(stagesToRollback)
     // stages which cannot be rolled back were aborted which leads to removing 
the
     // the dependant job(s) from the active jobs set, there could be no active 
jobs
@@ -1920,9 +1922,15 @@ private[spark] class DAGScheduler(
       // Rollback the running stages first to avoid triggering more fetch 
failures.
       stagesToRollback.toSeq.sortBy(!runningStages.contains(_)).foreach {
         case sms: ShuffleMapStage =>
-          rollbackShuffleMapStage(sms, "rolling back due to indeterminate " +
-            s"output of shuffle map stage $mapStage")
-          sms.markAsRollingBack()
+          // Iterate over the stages to rollback and checking whether the 
stage has been rolled back
+          // for the current attempt to avoid rolling back the same stage 
attempt multiple times.
+          val alreadyRollingBack =
+            sms.maxAttemptIdToIgnore.contains(sms.latestInfo.attemptNumber())
+          if (sms.getNextAttemptId > 0 && !alreadyRollingBack) {
+            rollbackShuffleMapStage(sms, "rolling back due to indeterminate " +
+              s"output of shuffle map stage $mapStage")
+            sms.markAsRollingBack()
+          }
 
         case rs: ResultStage =>
           rs.markAsRollingBack()
@@ -2086,19 +2094,11 @@ private[spark] class DAGScheduler(
         // tasks complete, they still count and we can mark the corresponding 
partitions as
         // finished if the stage is determinate. Here we notify the task 
scheduler to skip running
         // tasks for the same partition to save resource.
-        def stageWithChecksumMismatchFullRetryEnabled(stage: Stage): Boolean = 
{
-          stage match {
-            case s: ShuffleMapStage => 
s.shuffleDep.checksumMismatchFullRetryEnabled
-            case _ => 
stage.parents.exists(stageWithChecksumMismatchFullRetryEnabled)
-          }
-        }
 
-        // Ignore task completion for old attempt of indeterminate stage
-        val ignoreOldTaskAttempts = if 
(stageWithChecksumMismatchFullRetryEnabled(stage)) {
+        // Ignore task completion for old attempt of stages with 
nondeterministic output.
+        // This is tracked via maxAttemptIdToIgnore which is set when a stage 
is rolled back.
+        val ignoreOldTaskAttempts =
           stage.maxAttemptIdToIgnore.exists(_ >= task.stageAttemptId)
-        } else {
-          stage.isIndeterminate && task.stageAttemptId < 
stage.latestInfo.attemptNumber()
-        }
 
         if (!ignoreOldTaskAttempts && task.stageAttemptId < 
stage.latestInfo.attemptNumber()) {
           taskScheduler.notifyPartitionCompletion(stageId, task.partitionId)
@@ -2179,18 +2179,16 @@ private[spark] class DAGScheduler(
                   shuffleStage.shuffleDep.shuffleId, smt.partitionId, status)
                 if (isChecksumMismatched) {
                   shuffleStage.isChecksumMismatched = isChecksumMismatched
-                  // There could be multiple checksum mismatches detected for 
a single stage attempt.
-                  // We check for stage abortion once and only once when we 
first detect checksum
-                  // mismatch for each stage attempt. For example, assume that 
we have
-                  // stage1 -> stage2, and we encounter checksum mismatch 
during the retry of stage1.
-                  // In this case, we need to call abortUnrollbackableStages() 
for the succeeding
-                  // stages. Assume that when stage2 is retried, some tasks 
finish and some tasks
-                  // failed again with FetchFailed. In case that we encounter 
checksum mismatch again
-                  // during the retry of stage1, we need to call 
abortUnrollbackableStages() again.
-                  if (shuffleStage.maxChecksumMismatchedId < 
smt.stageAttemptId) {
-                    shuffleStage.maxChecksumMismatchedId = smt.stageAttemptId
-                    if 
(shuffleStage.shuffleDep.checksumMismatchFullRetryEnabled
-                      && shuffleStage.isStageIndeterminate) {
+                  // Runtime detection of nondeterministic output via checksum 
mismatch.
+                  // This is the trigger point for runtime detection - we only 
know the stage
+                  // is indeterminate when we see different checksums from 
different attempts.
+                  //
+                  // Note: Static detection (isStaticallyIndeterminate) is 
triggered earlier
+                  // in FetchFailed handler and submitMissingTasks for 
efficiency.
+                  if (shuffleStage.shuffleDep.checksumMismatchFullRetryEnabled
+                      && shuffleStage.isRuntimeIndeterminate) {
+                    if (shuffleStage.maxChecksumMismatchedId < 
smt.stageAttemptId) {
+                      shuffleStage.maxChecksumMismatchedId = smt.stageAttemptId
                       rollbackSucceedingStages(shuffleStage)
                     }
                   }
@@ -2312,19 +2310,20 @@ private[spark] class DAGScheduler(
             failedStages += failedStage
             failedStages += mapStage
             if (noResubmitEnqueued) {
-              // If the map stage is INDETERMINATE, which means the map tasks 
may return
-              // different result when re-try, we need to re-try all the tasks 
of the failed
-              // stage and its succeeding stages, because the input data will 
be changed after the
-              // map tasks are re-tried.
-              // Note that, if map stage is UNORDERED, we are fine. The 
shuffle partitioner is
-              // guaranteed to be determinate, so the input data of the 
reducers will not change
-              // even if the map tasks are re-tried.
-              if (mapStage.isIndeterminate && 
!mapStage.shuffleDep.checksumMismatchFullRetryEnabled) {
-                val stagesToRollback = collectSucceedingStages(mapStage)
-                val stagesCanRollback = 
filterAndAbortUnrollbackableStages(stagesToRollback)
-                logInfo(log"The shuffle map stage ${MDC(STAGE, mapStage)} with 
indeterminate output " +
-                  log"was failed, we will roll back and rerun below stages 
which include itself and all " +
-                  log"its indeterminate child stages: ${MDC(STAGES, 
stagesCanRollback)}")
+              // For statically indeterminate stages, trigger rollback early 
(here and in
+              // submitMissingTasks) rather than deferring to task completion. 
This is more
+              // efficient because it clears shuffle outputs before the retry 
is submitted,
+              // ensuring findMissingPartitions() returns all partitions.
+              //
+              // For runtime detection (checksum mismatch), rollback is 
triggered at task
+              // completion when the mismatch is discovered.
+              //
+              // The `rollbackCurrentStage = true` parameter ensures the 
failed map stage is
+              // included in the cleanup: clearing its shuffle outputs, 
marking old task results
+              // to be ignored, and creating a new shuffle merge state for the 
upcoming retry.
+              if (mapStage.isStaticallyIndeterminate &&
+                  !mapStage.shuffleDep.checksumMismatchFullRetryEnabled) {
+                rollbackSucceedingStages(mapStage, rollbackCurrentStage = true)
               }
 
               // We expect one executor failure to trigger many FetchFailures 
in rapid succession,
@@ -2535,7 +2534,12 @@ private[spark] class DAGScheduler(
               "shuffle block fetching protocol. Please check the config " +
               "'spark.shuffle.useOldFetchProtocol', see more detail in " +
               "SPARK-27665 and SPARK-25341."
-            abortStage(mapStage, reason, None)
+            // We should abort the final stages to make sure the job fails. 
Otherwise, nothing
+            // would happen if the shuffle output are available since the job 
would be considered
+            // as no longer depending on the stage.
+            mapStage.jobIds.flatMap(jobIdToActiveJob.get)
+              .map(_.finalStage)
+              .foreach(abortStage(_, reason, None))
           } else {
             rollingBackStages += mapStage
           }
@@ -2819,11 +2823,11 @@ private[spark] class DAGScheduler(
       if (stage.pendingPartitions.isEmpty)
         if (runningStages.contains(stage)) {
           processShuffleMapStageCompletion(stage)
-        } else if (stage.isIndeterminate) {
+        } else if (stage.isStaticallyIndeterminate) {
           // There are 2 possibilities here - stage is either cancelled or it 
will be resubmitted.
-          // If this is an indeterminate stage which is cancelled, we 
unregister all its merge
-          // results here just to free up some memory. If the indeterminate 
stage is resubmitted,
-          // merge results are cleared again when the newer attempt is 
submitted.
+          // If this is a statically indeterminate stage which is cancelled, 
we unregister all its
+          // merge results here just to free up some memory. If the 
indeterminate stage is
+          // resubmitted, merge results are cleared again when the newer 
attempt is submitted.
           mapOutputTracker.unregisterAllMergeResult(stage.shuffleDep.shuffleId)
           // For determinate stages, which have completed merge finalization, 
we don't need to
           // unregister merge results - since the stage retry, or any other 
stage computing the
diff --git 
a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala 
b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala
index db09d19d0acf..79f7af48f102 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala
@@ -20,7 +20,7 @@ package org.apache.spark.scheduler
 import scala.collection.mutable.HashSet
 
 import org.apache.spark.{MapOutputTrackerMaster, ShuffleDependency}
-import org.apache.spark.rdd.RDD
+import org.apache.spark.rdd.{DeterministicLevel, RDD}
 import org.apache.spark.util.CallSite
 
 /**
@@ -94,4 +94,20 @@ private[spark] class ShuffleMapStage(
       .findMissingPartitions(shuffleDep.shuffleId)
       .getOrElse(0 until numPartitions)
   }
+
+  /**
+   * Whether the stage is statically declared as indeterminate based on the 
RDD's
+   * outputDeterministicLevel property. This is known at RDD creation time.
+   */
+  def isStaticallyIndeterminate: Boolean = {
+    rdd.outputDeterministicLevel == DeterministicLevel.INDETERMINATE
+  }
+
+  /**
+   * Whether the stage has been detected as indeterminate at runtime via 
checksum mismatch.
+   * This means different stage attempts have produced different data for the 
same partition.
+   */
+  def isRuntimeIndeterminate: Boolean = {
+    !rdd.isReliablyCheckpointed && isChecksumMismatched
+  }
 }
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala 
b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
index d8aaea013ee6..f22776bbc196 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
@@ -21,7 +21,7 @@ import scala.collection.mutable.HashSet
 
 import org.apache.spark.executor.TaskMetrics
 import org.apache.spark.internal.Logging
-import org.apache.spark.rdd.{DeterministicLevel, RDD}
+import org.apache.spark.rdd.RDD
 import org.apache.spark.util.CallSite
 
 /**
@@ -155,18 +155,4 @@ private[scheduler] abstract class Stage(
 
   /** Returns the sequence of partition ids that are missing (i.e. needs to be 
computed). */
   def findMissingPartitions(): Seq[Int]
-
-  def isIndeterminate: Boolean = {
-    rdd.outputDeterministicLevel == DeterministicLevel.INDETERMINATE
-  }
-
-  // Returns true if any parents of this stage are indeterminate.
-  def isParentIndeterminate: Boolean = {
-    parents.exists(_.isStageIndeterminate)
-  }
-
-  // Returns true if the stage itself is indeterminate.
-  def isStageIndeterminate: Boolean = {
-    !rdd.isReliablyCheckpointed && isChecksumMismatched
-  }
 }
diff --git 
a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala 
b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index aa11148514a1..6f0fd9608334 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -190,6 +190,11 @@ class DAGSchedulerSuite extends SparkFunSuite with 
TempLocalSparkContext with Ti
   private var firstInit: Boolean = _
   /** Set of TaskSets the DAGScheduler has requested executed. */
   val taskSets = scala.collection.mutable.Buffer[TaskSet]()
+
+  def taskSet(stageId: Int, attemptId: Int): TaskSet = {
+    taskSets.find(ts => ts.stageId == stageId && ts.stageAttemptId == 
attemptId).get
+  }
+
   /** Track running tasks, the key is the task's stageId , the value is the 
task's partitionId */
   var runningTaskInfos = new HashMap[Int, HashSet[Int]]()
 
@@ -3288,10 +3293,10 @@ class DAGSchedulerSuite extends SparkFunSuite with 
TempLocalSparkContext with Ti
     // Check status for all failedStages
     val failedStages = scheduler.failedStages.toSeq
     assert(failedStages.map(_.id) == Seq(1, 2))
-    // Shuffle blocks of "hostC" is lost, so first task of the 
`shuffleMapRdd2` needs to retry.
+    // Shuffle blocks of "hostC" is lost, rollback to do a fully-retry.
     assert(failedStages.collect {
       case stage: ShuffleMapStage if stage.shuffleDep.shuffleId == shuffleId2 
=> stage
-    }.head.findMissingPartitions() == Seq(0))
+    }.head.findMissingPartitions() == Seq(0, 1))
     // The result stage is still waiting for its 2 tasks to complete
     assert(failedStages.collect {
       case stage: ResultStage => stage
@@ -3330,6 +3335,78 @@ class DAGSchedulerSuite extends SparkFunSuite with 
TempLocalSparkContext with Ti
     assertDataStructuresEmpty()
   }
 
+  test("SPARK-54956: abort stage if ever executed while using old fetch 
protocol") {
+    conf.set(config.SHUFFLE_USE_OLD_FETCH_PROTOCOL.key, "true")
+    val shuffleMapRdd1 = new MyRDD(sc, 2, Nil)
+    val shuffleDep1 = new ShuffleDependency(
+      shuffleMapRdd1,
+      new HashPartitioner(2),
+      _checksumMismatchFullRetryEnabled = true)
+    val shuffleMapRdd2 = new MyRDD(sc, 2, List(shuffleDep1), tracker = 
mapOutputTracker)
+    val shuffleMapRdd3 = new MyRDD(sc, 2, List(shuffleDep1), tracker = 
mapOutputTracker)
+
+    val shuffleDep2 = new ShuffleDependency(
+      shuffleMapRdd2,
+      new HashPartitioner(2),
+      _checksumMismatchFullRetryEnabled = true)
+    val shuffleDep3 = new ShuffleDependency(
+      shuffleMapRdd3,
+      new HashPartitioner(2),
+      _checksumMismatchFullRetryEnabled = true)
+    val finalRdd = new MyRDD(sc, 2, List(shuffleDep2, shuffleDep3), tracker = 
mapOutputTracker)
+
+    submit(finalRdd, Array(0, 1))
+
+    // Finish shuffle map stage 0 and 1
+    completeShuffleMapStageSuccessfully(0, 0, 2, checksumVal = 100)
+    assert(mapOutputTracker.findMissingPartitions(shuffleDep1.shuffleId) === 
Some(Seq.empty))
+    completeShuffleMapStageSuccessfully(1, 0, 2, Seq("hostC", "hostD"), 
checksumVal = 200)
+    assert(mapOutputTracker.findMissingPartitions(taskSet(1, 0).shuffleId.get) 
=== Some(Seq.empty))
+
+    // Fail a task in Stage 2 with fetch failure from Stage 0.
+    runEvent(makeCompletionEvent(
+      taskSet(2, 0).tasks(0),
+      FetchFailed(makeBlockManagerId("hostA"), shuffleDep1.shuffleId, 0L, 0, 
0, "ignored"),
+      null))
+    assert(scheduler.failedStages.map(_.id).toSeq == Seq(0, 2))
+    scheduler.resubmitFailedStages()
+
+    // Finish retry of stage 0 with different checksum values.
+    completeShuffleMapStageSuccessfully(0, 1, 2, checksumVal = 101)
+
+    // Expect failure
+    assert(failure != null && failure.getMessage.contains(
+      "Spark can only do this while using the new shuffle block fetching 
protocol"))
+  }
+
+  test("SPARK-54956: avoid redundant indeterminate stage rollback") {
+    val shuffleMapRdd1 = new MyRDD(sc, 2, Nil, indeterminate = true)
+    val shuffleDep1 = new ShuffleDependency(shuffleMapRdd1, new 
HashPartitioner(2))
+    val shuffleId1 = shuffleDep1.shuffleId
+    val shuffleMapRdd2 = new MyRDD(sc, 2, List(shuffleDep1), tracker = 
mapOutputTracker)
+    val shuffleDep2 = new ShuffleDependency(shuffleMapRdd2, new 
HashPartitioner(2))
+    val finalRdd = new MyRDD(sc, 2, List(shuffleDep2), tracker = 
mapOutputTracker)
+
+    submit(finalRdd, Array(0, 1))
+
+    // Finish the first shuffle map stage.
+    completeShuffleMapStageSuccessfully(0, 0, 2)
+
+    // Trigger failure in Stage 1 (attempt 0)
+    runEvent(makeCompletionEvent(
+      taskSets(1).tasks(0),
+      FetchFailed(makeBlockManagerId("hostA"), shuffleId1, 0L, 0, 0, 
"ignored"),
+      null))
+    assert(scheduler.failedStages.map(_.id).toSeq == Seq(0, 1))
+    scheduler.resubmitFailedStages()
+
+    // Verify shuffleMergeId is 1 (incremented once from 0).
+    val stage0 = scheduler.stageIdToStage(0).asInstanceOf[ShuffleMapStage]
+    assert(stage0.shuffleDep.shuffleMergeId == 1)
+    val stage1 = scheduler.stageIdToStage(1).asInstanceOf[ShuffleMapStage]
+    assert(stage1.shuffleDep.shuffleMergeId == 1)
+  }
+
   test("SPARK-45182: Ignore task completion from old stage after retrying 
indeterminate stages") {
     val (shuffleId1, shuffleId2) = constructTwoIndeterminateStage()
 
@@ -4632,10 +4709,10 @@ class DAGSchedulerSuite extends SparkFunSuite with 
TempLocalSparkContext with Ti
     // Check status for all failedStages
     val failedStages = scheduler.failedStages.toSeq
     assert(failedStages.map(_.id) == Seq(1, 2))
-    // Shuffle blocks of "hostC" is lost, so first task of the 
`shuffleMapRdd2` needs to retry.
+    // Shuffle blocks of "hostC" is lost, rollback to do a fully-retry.
     assert(failedStages.collect {
       case stage: ShuffleMapStage if stage.shuffleDep.shuffleId == shuffleId2 
=> stage
-    }.head.findMissingPartitions() == Seq(0))
+    }.head.findMissingPartitions() == Seq(0, 1))
     // The result stage is still waiting for its 2 tasks to complete
     assert(failedStages.collect {
       case stage: ResultStage => stage
@@ -4654,10 +4731,11 @@ class DAGSchedulerSuite extends SparkFunSuite with 
TempLocalSparkContext with Ti
 
     val newFailedStages = scheduler.failedStages.toSeq
     assert(newFailedStages.map(_.id) == Seq(0, 1))
-    // shuffleMergeId for indeterminate failed stages should be 2
-    assert(failedStages.collect {
-      case stage: ShuffleMapStage => stage.shuffleDep.shuffleMergeId
-    }.forall(x => x == 2))
+    // shuffleMergeId for indeterminate failed stages should be increased
+    assert(newFailedStages.filter(_.id == 0)
+      .exists(_.asInstanceOf[ShuffleMapStage].shuffleDep.shuffleMergeId == 1))
+    assert(newFailedStages.filter(_.id == 1)
+      .exists(_.asInstanceOf[ShuffleMapStage].shuffleDep.shuffleMergeId == 2))
     scheduler.resubmitFailedStages()
 
     // First shuffle map stage resubmitted and reran all tasks.
@@ -4668,16 +4746,16 @@ class DAGSchedulerSuite extends SparkFunSuite with 
TempLocalSparkContext with Ti
     // Finish all stage.
     completeShuffleMapStageSuccessfully(0, 1, 2)
     assert(mapOutputTracker.findMissingPartitions(shuffleId1) === 
Some(Seq.empty))
-    // shuffleMergeId should be 2 for the attempt number 1 for stage 0
+    // shuffleMergeId should be 1 for the attempt number 1 for stage 0
     assert(mapOutputTracker.shuffleStatuses.get(shuffleId1).forall(
-      _.mergeStatuses.forall(x => x.shuffleMergeId == 2)))
+      _.mergeStatuses.forall(x => x.shuffleMergeId == 1)))
     assert(mapOutputTracker.getNumAvailableMergeResults(shuffleId1) == 2)
 
     completeShuffleMapStageSuccessfully(1, 2, 2, Seq("hostC", "hostD"))
     assert(mapOutputTracker.findMissingPartitions(shuffleId2) === 
Some(Seq.empty))
     // shuffleMergeId should be 2 for the attempt number 2 for stage 1
     assert(mapOutputTracker.shuffleStatuses.get(shuffleId2).forall(
-      _.mergeStatuses.forall(x => x.shuffleMergeId == 3)))
+      _.mergeStatuses.forall(x => x.shuffleMergeId == 2)))
     assert(mapOutputTracker.getNumAvailableMergeResults(shuffleId2) == 2)
 
     complete(taskSets(6), Seq((Success, 11), (Success, 12)))
@@ -4866,7 +4944,7 @@ class DAGSchedulerSuite extends SparkFunSuite with 
TempLocalSparkContext with Ti
     }
 
     val shuffleIndeterminateStage = 
scheduler.stageIdToStage(3).asInstanceOf[ShuffleMapStage]
-    assert(shuffleIndeterminateStage.isIndeterminate)
+    assert(shuffleIndeterminateStage.isStaticallyIndeterminate)
     scheduler.handleShuffleMergeFinalized(shuffleIndeterminateStage, 2)
     assert(shuffleIndeterminateStage.shuffleDep.shuffleMergeEnabled)
     assert(!shuffleIndeterminateStage.shuffleDep.isShuffleMergeFinalizedMarked)
@@ -5115,7 +5193,7 @@ class DAGSchedulerSuite extends SparkFunSuite with 
TempLocalSparkContext with Ti
         (Success, makeMapStatus("host" + idx, parts))
     }.toSeq)
 
-    assert(shuffleStage1.shuffleDep.shuffleMergeId == 2)
+    assert(shuffleStage1.shuffleDep.shuffleMergeId == 1)
     assert(shuffleStage1.shuffleDep.isShuffleMergeFinalizedMarked)
     val newMergerLocs =
       
scheduler.stageIdToStage(0).asInstanceOf[ShuffleMapStage].shuffleDep.getMergerLocs


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to