This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 922adad1e3cc [SPARK-53575][CORE] Retry entire consumer stages when 
checksum mismatch detected for a retried shuffle map task
922adad1e3cc is described below

commit 922adad1e3cc4eae70a7d10f852bf585835ef21a
Author: Tengfei Huang <[email protected]>
AuthorDate: Mon Sep 29 17:53:37 2025 +0800

    [SPARK-53575][CORE] Retry entire consumer stages when checksum mismatch 
detected for a retried shuffle map task
    
    ### What changes were proposed in this pull request?
    This PR proposes to retry all tasks of the consumer stages when checksum
    mismatches are detected on their producer stages. If we cannot roll back and
    retry all tasks of a consumer stage, we have to abort that stage (and thus
    the job).
    
    How nondeterminism was detected and handled before:
    - Stages are labeled as indeterminate at planning time, prior to query
      execution.
    - When a task completes and a `FetchFailed` is detected, we abort all
      unrollbackable succeeding stages of the map stage and resubmit the failed
      stages.
    - In `submitMissingTasks()`, if a stage itself is indeterminate
      (`isIndeterminate`), we call `unregisterAllMapAndMergeOutput()` and retry
      all tasks of the stage.
    
    How nondeterminism is detected and handled now:
    - During query execution, we keep track of the checksums produced by each
      map task.
    - When a task completes and a checksum mismatch is detected, we abort the
      unrollbackable succeeding stages of the stage with checksum mismatches.
      Failed stages are still resubmitted in the same places as before.
    - In `submitMissingTasks()`, if a parent of a stage has checksum
      mismatches, we call `unregisterAllMapAndMergeOutput()` and retry all
      tasks of the stage.
    
    Note that (1) if a stage `isReliablyCheckpointed`, its consumer stages do
    not need a whole-stage retry, and (2) when mismatches are detected for a
    stage in a chain (e.g., the first stage in stage_i -> stage_i+1 ->
    stage_i+2 -> ...), the direct consumer (e.g., stage_i+1) gets a whole-stage
    retry, while an indirect consumer (e.g., stage_i+2) gets a whole-stage retry
    only if checksum mismatches are in turn detected for its own parent
    (e.g., stage_i+1).
    
    ### Why are the changes needed?
    Handle nondeterminism caused by retries of shuffle map tasks.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    UTs added.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #52336 from ivoson/SPARK-53575.
    
    Authored-by: Tengfei Huang <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../main/scala/org/apache/spark/Dependency.scala   |   3 +-
 .../scala/org/apache/spark/MapOutputTracker.scala  |  10 +-
 core/src/main/scala/org/apache/spark/rdd/RDD.scala |   2 +-
 .../org/apache/spark/scheduler/DAGScheduler.scala  | 101 +++++---
 .../scala/org/apache/spark/scheduler/Stage.scala   |  22 ++
 .../apache/spark/scheduler/DAGSchedulerSuite.scala | 271 ++++++++++++++++++++-
 .../org/apache/spark/sql/internal/SQLConf.scala    |  11 +
 .../execution/exchange/ShuffleExchangeExec.scala   |   9 +-
 .../apache/spark/sql/MapStatusEndToEndSuite.scala  |  43 ++--
 9 files changed, 407 insertions(+), 65 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala 
b/core/src/main/scala/org/apache/spark/Dependency.scala
index 93a2bbe25157..c436025e06bb 100644
--- a/core/src/main/scala/org/apache/spark/Dependency.scala
+++ b/core/src/main/scala/org/apache/spark/Dependency.scala
@@ -89,7 +89,8 @@ class ShuffleDependency[K: ClassTag, V: ClassTag, C: 
ClassTag](
     val aggregator: Option[Aggregator[K, V, C]] = None,
     val mapSideCombine: Boolean = false,
     val shuffleWriterProcessor: ShuffleWriteProcessor = new 
ShuffleWriteProcessor,
-    val rowBasedChecksums: Array[RowBasedChecksum] = 
ShuffleDependency.EMPTY_ROW_BASED_CHECKSUMS)
+    val rowBasedChecksums: Array[RowBasedChecksum] = 
ShuffleDependency.EMPTY_ROW_BASED_CHECKSUMS,
+    val checksumMismatchFullRetryEnabled: Boolean = false)
   extends Dependency[Product2[K, V]] with Logging {
 
   def this(
diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala 
b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index 3f823b60156a..334eb832c4c2 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -165,9 +165,11 @@ private class ShuffleStatus(
 
   /**
    * Register a map output. If there is already a registered location for the 
map output then it
-   * will be replaced by the new location.
+   * will be replaced by the new location. Returns true if the checksum in the 
new MapStatus is
+   * different from a previously registered MapStatus. Otherwise, returns false.
    */
-  def addMapOutput(mapIndex: Int, status: MapStatus): Unit = withWriteLock {
+  def addMapOutput(mapIndex: Int, status: MapStatus): Boolean = withWriteLock {
+    var isChecksumMismatch: Boolean = false
     val currentMapStatus = mapStatuses(mapIndex)
     if (currentMapStatus == null) {
       _numAvailableMapOutputs += 1
@@ -183,9 +185,11 @@ private class ShuffleStatus(
       logInfo(s"Checksum of map output changes from ${preStatus.checksumValue} 
to " +
         s"${status.checksumValue} for task ${status.mapId}.")
       checksumMismatchIndices.add(mapIndex)
+      isChecksumMismatch = true
     }
     mapStatuses(mapIndex) = status
     mapIdToMapIndex(status.mapId) = mapIndex
+    isChecksumMismatch
   }
 
   /**
@@ -853,7 +857,7 @@ private[spark] class MapOutputTrackerMaster(
     }
   }
 
-  def registerMapOutput(shuffleId: Int, mapIndex: Int, status: MapStatus): 
Unit = {
+  def registerMapOutput(shuffleId: Int, mapIndex: Int, status: MapStatus): 
Boolean = {
     shuffleStatuses(shuffleId).addMapOutput(mapIndex, status)
   }
 
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala 
b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 117b2925710d..d1408ee774ce 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -1773,7 +1773,7 @@ abstract class RDD[T: ClassTag](
   /**
    * Return whether this RDD is reliably checkpointed and materialized.
    */
-  private[rdd] def isReliablyCheckpointed: Boolean = {
+  private[spark] def isReliablyCheckpointed: Boolean = {
     checkpointData match {
       case Some(reliable: ReliableRDDCheckpointData[_]) if 
reliable.isCheckpointed => true
       case _ => false
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala 
b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 30eb49b0c079..3b719a2c7d24 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -1551,29 +1551,46 @@ private[spark] class DAGScheduler(
     // The operation here can make sure for the partially completed 
intermediate stage,
     // `findMissingPartitions()` returns all partitions every time.
     stage match {
-      case sms: ShuffleMapStage if stage.isIndeterminate && !sms.isAvailable =>
-        // already executed at least once
-        if (sms.getNextAttemptId > 0) {
-          // While we previously validated possible rollbacks during the 
handling of a FetchFailure,
-          // where we were fetching from an indeterminate source map stages, 
this later check
-          // covers additional cases like recalculating an indeterminate stage 
after an executor
-          // loss. Moreover, because this check occurs later in the process, 
if a result stage task
-          // has successfully completed, we can detect this and abort the job, 
as rolling back a
-          // result stage is not possible.
-          val stagesToRollback = collectSucceedingStages(sms)
-          abortStageWithInvalidRollBack(stagesToRollback)
-          // stages which cannot be rolled back were aborted which leads to 
removing the
-          // the dependant job(s) from the active jobs set
-          val numActiveJobsWithStageAfterRollback =
-            activeJobs.count(job => stagesToRollback.contains(job.finalStage))
-          if (numActiveJobsWithStageAfterRollback == 0) {
-            logInfo(log"All jobs depending on the indeterminate stage " +
-              log"(${MDC(STAGE_ID, stage.id)}) were aborted so this stage is 
not needed anymore.")
-            return
+      case sms: ShuffleMapStage if !sms.isAvailable =>
+        val needFullStageRetry = if 
(sms.shuffleDep.checksumMismatchFullRetryEnabled) {
+          // When the parents of this stage are indeterminate (e.g., some 
parents are not
+          // checkpointed and checksum mismatches are detected), the output 
data of the parents
+          // may have changed due to task retries. For correctness reason, we 
need to
+          // retry all tasks of the current stage. The legacy way of using 
current stage's
+          // deterministic level to trigger full stage retry is not accurate.
+          stage.isParentIndeterminate
+        } else {
+          if (stage.isIndeterminate) {
+            // already executed at least once
+            if (sms.getNextAttemptId > 0) {
+              // While we previously validated possible rollbacks during the 
handling of a FetchFailure,
+              // where we were fetching from an indeterminate source map 
stages, this later check
+              // covers additional cases like recalculating an indeterminate 
stage after an executor
+              // loss. Moreover, because this check occurs later in the 
process, if a result stage task
+              // has successfully completed, we can detect this and abort the 
job, as rolling back a
+              // result stage is not possible.
+              val stagesToRollback = collectSucceedingStages(sms)
+              abortStageWithInvalidRollBack(stagesToRollback)
+              // stages which cannot be rolled back were aborted which leads 
to removing the
+              // the dependant job(s) from the active jobs set
+              val numActiveJobsWithStageAfterRollback =
+                activeJobs.count(job => 
stagesToRollback.contains(job.finalStage))
+              if (numActiveJobsWithStageAfterRollback == 0) {
+                logInfo(log"All jobs depending on the indeterminate stage " +
+                  log"(${MDC(STAGE_ID, stage.id)}) were aborted so this stage 
is not needed anymore.")
+                return
+              }
+            }
+            true
+          } else {
+            false
           }
         }
-        
mapOutputTracker.unregisterAllMapAndMergeOutput(sms.shuffleDep.shuffleId)
-        sms.shuffleDep.newShuffleMergeState()
+
+        if (needFullStageRetry) {
+          
mapOutputTracker.unregisterAllMapAndMergeOutput(sms.shuffleDep.shuffleId)
+          sms.shuffleDep.newShuffleMergeState()
+        }
       case _ =>
     }
 
@@ -1886,6 +1903,20 @@ private[spark] class DAGScheduler(
     }
   }
 
+  /**
+   * If a map stage is non-deterministic, its map tasks may return different results when
+   * retried. To guarantee data correctness, we need to retry all tasks of its succeeding stages,
+   * as their input data may change after the map tasks are retried. Succeeding stages that
+   * cannot be rolled back and fully retried have to be aborted.
+   */
+  private[scheduler] def abortUnrollbackableStages(mapStage: ShuffleMapStage): 
Unit = {
+    val stagesToRollback = collectSucceedingStages(mapStage)
+    val rollingBackStages = abortStageWithInvalidRollBack(stagesToRollback)
+    logInfo(log"The shuffle map stage ${MDC(SHUFFLE_ID, mapStage)} with 
indeterminate output " +
+      log"was failed, we will roll back and rerun below stages which include 
itself and all its " +
+      log"indeterminate child stages: ${MDC(STAGES, rollingBackStages)}")
+  }
+
   /**
    * Responds to a task finishing. This is called inside the event loop so it 
assumes that it can
    * modify the scheduler's internal state. Use taskEnded() to post a task end 
event from outside.
@@ -2022,8 +2053,26 @@ private[spark] class DAGScheduler(
                 // The epoch of the task is acceptable (i.e., the task was 
launched after the most
                 // recent failure we're aware of for the executor), so mark 
the task's output as
                 // available.
-                mapOutputTracker.registerMapOutput(
+                val isChecksumMismatched = mapOutputTracker.registerMapOutput(
                   shuffleStage.shuffleDep.shuffleId, smt.partitionId, status)
+                if (isChecksumMismatched) {
+                  shuffleStage.isChecksumMismatched = isChecksumMismatched
+                  // There could be multiple checksum mismatches detected for 
a single stage attempt.
+                  // We check for stage abortion once and only once when we 
first detect checksum
+                  // mismatch for each stage attempt. For example, assume that 
we have
+                  // stage1 -> stage2, and we encounter checksum mismatch 
during the retry of stage1.
+                  // In this case, we need to call abortUnrollbackableStages() 
for the succeeding
+                  // stages. Assume that when stage2 is retried, some tasks 
finish and some tasks
+                  // failed again with FetchFailed. In case that we encounter 
checksum mismatch again
+                  // during the retry of stage1, we need to call 
abortUnrollbackableStages() again.
+                  if (shuffleStage.maxChecksumMismatchedId < 
smt.stageAttemptId) {
+                    shuffleStage.maxChecksumMismatchedId = smt.stageAttemptId
+                    if 
(shuffleStage.shuffleDep.checksumMismatchFullRetryEnabled
+                      && shuffleStage.isStageIndeterminate) {
+                      abortUnrollbackableStages(shuffleStage)
+                    }
+                  }
+                }
               }
             } else {
               logInfo(log"Ignoring ${MDC(TASK_NAME, smt)} completion from an 
older attempt of indeterminate stage")
@@ -2148,12 +2197,8 @@ private[spark] class DAGScheduler(
               // Note that, if map stage is UNORDERED, we are fine. The 
shuffle partitioner is
               // guaranteed to be determinate, so the input data of the 
reducers will not change
               // even if the map tasks are re-tried.
-              if (mapStage.isIndeterminate) {
-                val stagesToRollback = collectSucceedingStages(mapStage)
-                val rollingBackStages = 
abortStageWithInvalidRollBack(stagesToRollback)
-                logInfo(log"The shuffle map stage ${MDC(SHUFFLE_ID, mapStage)} 
with indeterminate output was failed, " +
-                  log"we will roll back and rerun below stages which include 
itself and all its " +
-                  log"indeterminate child stages: ${MDC(STAGES, 
rollingBackStages)}")
+              if (mapStage.isIndeterminate && 
!mapStage.shuffleDep.checksumMismatchFullRetryEnabled) {
+                abortUnrollbackableStages(mapStage)
               }
 
               // We expect one executor failure to trigger many FetchFailures 
in rapid succession,
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala 
b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
index f35beafd8748..9bf604e9a83c 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
@@ -72,6 +72,18 @@ private[scheduler] abstract class Stage(
   private var nextAttemptId: Int = 0
   private[scheduler] def getNextAttemptId: Int = nextAttemptId
 
+  /**
+   * Whether checksum mismatches have been detected across different attempts of the stage,
+   * where checksum mismatches typically indicate that different stage attempts have produced
+   * different data.
+   */
+  private[scheduler] var isChecksumMismatched: Boolean = false
+
+  /**
+   * The maximum stage attempt id in which checksum mismatches have been detected.
+   */
+  private[scheduler] var maxChecksumMismatchedId: Int = nextAttemptId
+
   val name: String = callSite.shortForm
   val details: String = callSite.longForm
 
@@ -131,4 +143,14 @@ private[scheduler] abstract class Stage(
   def isIndeterminate: Boolean = {
     rdd.outputDeterministicLevel == DeterministicLevel.INDETERMINATE
   }
+
+  // Returns true if any parents of this stage are indeterminate.
+  def isParentIndeterminate: Boolean = {
+    parents.exists(_.isStageIndeterminate)
+  }
+
+  // Returns true if the stage itself is indeterminate.
+  def isStageIndeterminate: Boolean = {
+    !rdd.isReliablyCheckpointed && isChecksumMismatched
+  }
 }
diff --git 
a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala 
b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index 1ada81cbdd0e..c20866fda0a3 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -3415,6 +3415,19 @@ class DAGSchedulerSuite extends SparkFunSuite with 
TempLocalSparkContext with Ti
     assertDataStructuresEmpty()
   }
 
+  private def checkAndCompleteRetryStage(
+      taskSetIndex: Int,
+      stageId: Int,
+      shuffleId: Int,
+      numTasks: Int = 2,
+      checksumVal: Long = 0): Unit = {
+    assert(taskSets(taskSetIndex).stageId == stageId)
+    assert(taskSets(taskSetIndex).stageAttemptId == 1)
+    assert(taskSets(taskSetIndex).tasks.length == numTasks)
+    completeShuffleMapStageSuccessfully(stageId, 1, 2, checksumVal = 
checksumVal)
+    assert(mapOutputTracker.findMissingPartitions(shuffleId) === 
Some(Seq.empty))
+  }
+
   test("SPARK-25341: continuous indeterminate stage roll back") {
     // shuffleMapRdd1/2/3 are all indeterminate.
     val shuffleMapRdd1 = new MyRDD(sc, 2, Nil, indeterminate = true)
@@ -3454,17 +3467,6 @@ class DAGSchedulerSuite extends SparkFunSuite with 
TempLocalSparkContext with Ti
     assert(scheduler.failedStages.toSeq.map(_.id) == Seq(1, 2))
     scheduler.resubmitFailedStages()
 
-    def checkAndCompleteRetryStage(
-        taskSetIndex: Int,
-        stageId: Int,
-        shuffleId: Int): Unit = {
-      assert(taskSets(taskSetIndex).stageId == stageId)
-      assert(taskSets(taskSetIndex).stageAttemptId == 1)
-      assert(taskSets(taskSetIndex).tasks.length == 2)
-      completeShuffleMapStageSuccessfully(stageId, 1, 2)
-      assert(mapOutputTracker.findMissingPartitions(shuffleId) === 
Some(Seq.empty))
-    }
-
     // Check all indeterminate stage roll back.
     checkAndCompleteRetryStage(3, 0, shuffleId1)
     checkAndCompleteRetryStage(4, 1, shuffleId2)
@@ -3477,6 +3479,253 @@ class DAGSchedulerSuite extends SparkFunSuite with 
TempLocalSparkContext with Ti
     assertDataStructuresEmpty()
   }
 
+  // Construct the scenario of stages with checksum mismatches and FetchFailed.
+  private def constructChecksumMismatchStageFetchFailed(): (Int, Int) = {
+    val shuffleMapRdd1 = new MyRDD(sc, 2, Nil)
+
+    val shuffleDep1 = new ShuffleDependency(
+      shuffleMapRdd1,
+      new HashPartitioner(2),
+      checksumMismatchFullRetryEnabled = true
+    )
+    val shuffleId1 = shuffleDep1.shuffleId
+    val shuffleMapRdd2 = new MyRDD(sc, 2, List(shuffleDep1), tracker = 
mapOutputTracker)
+
+    val shuffleDep2 = new ShuffleDependency(
+      shuffleMapRdd2,
+      new HashPartitioner(2),
+      checksumMismatchFullRetryEnabled = true
+    )
+    val shuffleId2 = shuffleDep2.shuffleId
+    val finalRdd = new MyRDD(sc, 2, List(shuffleDep2), tracker = 
mapOutputTracker)
+
+    submit(finalRdd, Array(0, 1))
+
+    // Finish the first shuffle map stage.
+    completeShuffleMapStageSuccessfully(
+      0, 0, 2, Seq("hostA", "hostB"), checksumVal = 100)
+    assert(mapOutputTracker.findMissingPartitions(shuffleId1) === 
Some(Seq.empty))
+
+    // The first task of the second shuffle map stage failed with FetchFailed.
+    runEvent(makeCompletionEvent(
+      taskSets(1).tasks(0),
+      FetchFailed(makeBlockManagerId("hostA"), shuffleId1, 0L, 0, 0, 
"ignored"),
+      null))
+
+    // Finish the second task of the second shuffle map stage.
+    runEvent(makeCompletionEvent(
+      taskSets(1).tasks(1), Success, makeMapStatus("hostB", 2),
+      Seq.empty, Array.empty, createFakeTaskInfoWithId(1)))
+
+    (shuffleId1, shuffleId2)
+  }
+
+  // Construct the scenario of stages with checksum mismatches and FetchFailed.
+  // This function assumes that the input `mapRdd` has a single stage with 2 
partitions.
+  private def constructChecksumMismatchStageFetchFailed(mapRdd: MyRDD): Unit = 
{
+    val shuffleDep = new ShuffleDependency(
+      mapRdd,
+      new HashPartitioner(2),
+      checksumMismatchFullRetryEnabled = true
+    )
+    val shuffleId = shuffleDep.shuffleId
+    val finalRdd = new MyRDD(sc, 2, List(shuffleDep), tracker = 
mapOutputTracker)
+
+    submit(finalRdd, Array(0, 1))
+
+    completeShuffleMapStageSuccessfully(
+      0, 0, numShufflePartitions = 2, Seq("hostA", "hostB"), checksumVal = 100)
+    assert(mapOutputTracker.findMissingPartitions(shuffleId) === 
Some(Seq.empty))
+
+    // Fail the first task of the result stage with FetchFailed.
+    runEvent(makeCompletionEvent(
+      taskSets(1).tasks(0),
+      FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0L, 0, 0, "ignored"),
+      null))
+
+    // Finish the second task of the result stage.
+    runEvent(makeCompletionEvent(
+      taskSets(1).tasks(1), Success, 42,
+      Seq.empty, Array.empty, createFakeTaskInfoWithId(0)))
+
+    // Check status for all failedStages.
+    val failedStages = scheduler.failedStages.toSeq
+    // Shuffle blocks of "hostA" is lost, so first task of the shuffle map 
stage and
+    // result stage needs to retry.
+    assert(failedStages.map(_.id) == Seq(0, 1))
+    assert(failedStages.forall(_.findMissingPartitions() == Seq(0)))
+
+    scheduler.resubmitFailedStages()
+
+    // First shuffle map stage reran failed tasks with a different checksum.
+    completeShuffleMapStageSuccessfully(0, 1, 2, checksumVal = 101)
+  }
+
+  private def assertChecksumMismatchResultStageFailToRollback(mapRdd: MyRDD): 
Unit = {
+    constructChecksumMismatchStageFetchFailed(mapRdd)
+
+    // The job should fail because Spark can't rollback the result stage.
+    assert(failure != null && failure.getMessage.contains("Spark cannot 
rollback"))
+  }
+
+  private def assertChecksumMismatchResultStageNotRolledBack(mapRdd: MyRDD): 
Unit = {
+    constructChecksumMismatchStageFetchFailed(mapRdd)
+
+    assert(failure == null, "job should not fail")
+    // Result stage success, all job ended.
+    complete(taskSets(3), Seq((Success, 41)))
+    assert(results === Map(0 -> 41, 1 -> 42))
+    results.clear()
+    assertDataStructuresEmpty()
+  }
+
+  test("SPARK-53575: abort stage while using old fetch protocol") {
+    conf.set(config.SHUFFLE_USE_OLD_FETCH_PROTOCOL.key, "true")
+    constructChecksumMismatchStageFetchFailed()
+
+    scheduler.resubmitFailedStages()
+    completeShuffleMapStageSuccessfully(0, 1, 2, checksumVal = 101)
+
+    // The job should fail because Spark can't rollback the shuffle map stage 
while
+    // using old protocol.
+    assert(failure != null && failure.getMessage.contains(
+      "Spark can only do this while using the new shuffle block fetching 
protocol"))
+  }
+
+  test("SPARK-53575: retry all the succeeding stages when the map stage has 
checksum mismatches") {
+    val (shuffleId1, shuffleId2) =
+      constructChecksumMismatchStageFetchFailed()
+
+    // Check status for all failedStages.
+    val failedStages = scheduler.failedStages.toSeq
+    // Shuffle blocks of "hostA" is lost, so first task of the 
`shuffleMapRdd1` and
+    // `shuffleMapRdd2` needs to retry.
+    assert(failedStages.map(_.id) == Seq(0, 1))
+    assert(failedStages.forall(_.findMissingPartitions() == Seq(0)))
+
+    scheduler.resubmitFailedStages()
+
+    // First shuffle map stage reran failed tasks with a different checksum.
+    checkAndCompleteRetryStage(2, 0, shuffleId1, numTasks = 1, checksumVal = 
101)
+
+    // Second shuffle map stage reran all tasks.
+    checkAndCompleteRetryStage(3, 1, shuffleId2, numTasks = 2)
+
+    complete(taskSets(4), Seq((Success, 11), (Success, 12)))
+
+    // Job successful ended.
+    assert(results === Map(0 -> 11, 1 -> 12))
+    results.clear()
+    assertDataStructuresEmpty()
+  }
+
+  test("SPARK-53575: continuous checksum mismatch stage roll back") {
+    // shuffleMapRdd1/2 have checksum mismatches, and shuffleMapRdd2/3 
requires full stage retries.
+    val shuffleMapRdd1 = new MyRDD(sc, 2, Nil)
+    val shuffleDep1 = new ShuffleDependency(
+      shuffleMapRdd1,
+      new HashPartitioner(2),
+      checksumMismatchFullRetryEnabled = true
+    )
+    val shuffleId1 = shuffleDep1.shuffleId
+
+    val shuffleMapRdd2 = new MyRDD(
+      sc, 2, List(shuffleDep1), tracker = mapOutputTracker)
+    val shuffleDep2 = new ShuffleDependency(
+      shuffleMapRdd2,
+      new HashPartitioner(2),
+      checksumMismatchFullRetryEnabled = true
+    )
+    val shuffleId2 = shuffleDep2.shuffleId
+
+    val shuffleMapRdd3 = new MyRDD(
+      sc, 2, List(shuffleDep2), tracker = mapOutputTracker)
+    val shuffleDep3 = new ShuffleDependency(
+      shuffleMapRdd3,
+      new HashPartitioner(2),
+      checksumMismatchFullRetryEnabled = true
+    )
+    val shuffleId3 = shuffleDep3.shuffleId
+    val finalRdd = new MyRDD(sc, 2, List(shuffleDep3), tracker = 
mapOutputTracker)
+
+    submit(finalRdd, Array(0, 1), properties = new Properties())
+
+    // Finish the first 2 shuffle map stages.
+    completeShuffleMapStageSuccessfully(0, 0, 2, Seq("hostA", "hostB"), 
checksumVal = 100)
+    assert(mapOutputTracker.findMissingPartitions(shuffleId1) === 
Some(Seq.empty))
+    completeShuffleMapStageSuccessfully(1, 0, 2, Seq("hostA", "hostB"), 
checksumVal = 200)
+    assert(mapOutputTracker.findMissingPartitions(shuffleId2) === 
Some(Seq.empty))
+
+    // Fail the first task of the third shuffle map stage with FetchFailed.
+    runEvent(makeCompletionEvent(
+      taskSets(2).tasks(0),
+      FetchFailed(makeBlockManagerId("hostA"), shuffleId2, 0L, 0, 0, 
"ignored"),
+      null))
+
+    // Finish the second task of the third shuffle map stage.
+    runEvent(makeCompletionEvent(
+      taskSets(2).tasks(1), Success, makeMapStatus("hostB", 2),
+      Seq.empty, Array.empty, createFakeTaskInfoWithId(1)))
+    mapOutputTracker.removeOutputsOnHost("hostA")
+
+    // Check status for all failedStages.
+    val failedStages = scheduler.failedStages.toSeq
+    // Shuffle blocks of "hostA" is lost, so first task of the 
`shuffleMapRdd2` and
+    // `shuffleMapRdd3` needs to retry.
+    assert(failedStages.map(_.id) == Seq(1, 2))
+    assert(failedStages.forall(_.findMissingPartitions() == Seq(0)))
+
+    scheduler.resubmitFailedStages()
+
+    // First shuffle map stage reran failed tasks with a different checksum.
+    checkAndCompleteRetryStage(3, 0, shuffleId1, numTasks = 1, checksumVal = 
101)
+    // Second and third shuffle map stages reran all tasks with a different 
checksum.
+    checkAndCompleteRetryStage(4, 1, shuffleId2, numTasks = 2, checksumVal = 
201)
+    checkAndCompleteRetryStage(5, 2, shuffleId3, numTasks = 2, checksumVal = 
301)
+    // Result stage success, all job ended.
+    complete(taskSets(6), Seq((Success, 11), (Success, 12)))
+    assert(results === Map(0 -> 11, 1 -> 12))
+    results.clear()
+    assertDataStructuresEmpty()
+  }
+
+  test("SPARK-53575: cannot rollback a result stage") {
+    val shuffleMapRdd = new MyRDD(sc, 2, Nil)
+    assertChecksumMismatchResultStageFailToRollback(shuffleMapRdd)
+  }
+
+  test("SPARK-53575: local checkpoint fail to rollback (checkpointed before)") 
{
+    val shuffleMapRdd = new MyCheckpointRDD(sc, 2, Nil)
+    shuffleMapRdd.localCheckpoint()
+    shuffleMapRdd.doCheckpoint()
+    assertChecksumMismatchResultStageFailToRollback(shuffleMapRdd)
+  }
+
+  test("SPARK-53575: local checkpoint fail to rollback (checkpointing now)") {
+    val shuffleMapRdd = new MyCheckpointRDD(sc, 2, Nil)
+    shuffleMapRdd.localCheckpoint()
+    assertChecksumMismatchResultStageFailToRollback(shuffleMapRdd)
+  }
+
+  test("SPARK-53575: reliable checkpoint can avoid rollback (checkpointed 
before)") {
+    withTempDir { dir =>
+      sc.setCheckpointDir(dir.getCanonicalPath)
+      val shuffleMapRdd = new MyCheckpointRDD(sc, 2, Nil)
+      shuffleMapRdd.checkpoint()
+      shuffleMapRdd.doCheckpoint()
+      assertChecksumMismatchResultStageNotRolledBack(shuffleMapRdd)
+    }
+  }
+
+  test("SPARK-53575: reliable checkpoint fail to rollback (checkpointing 
now)") {
+    withTempDir { dir =>
+      sc.setCheckpointDir(dir.getCanonicalPath)
+      val shuffleMapRdd = new MyCheckpointRDD(sc, 2, Nil)
+      shuffleMapRdd.checkpoint()
+      assertChecksumMismatchResultStageFailToRollback(shuffleMapRdd)
+    }
+  }
+
   test("SPARK-29042: Sampled RDD with unordered input should be 
indeterminate") {
     val shuffleMapRdd1 = new MyRDD(sc, 2, Nil, indeterminate = false)
 
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 17b8dd493cf8..477d09d29a05 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -890,6 +890,14 @@ object SQLConf {
       .booleanConf
       .createWithDefault(false)
 
+  private[spark] val SHUFFLE_CHECKSUM_MISMATCH_FULL_RETRY_ENABLED =
+    
buildConf("spark.sql.shuffle.orderIndependentChecksum.enableFullRetryOnMismatch")
+      .doc("Whether to retry all tasks of a consumer stage when we detect 
checksum mismatches " +
+        "with its producer stages.")
+      .version("4.1.0")
+      .booleanConf
+      .createWithDefault(false)
+
   val SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE =
     buildConf("spark.sql.adaptive.shuffle.targetPostShuffleInputSize")
       .internal()
@@ -6651,6 +6659,9 @@ class SQLConf extends Serializable with Logging with 
SqlApiConf {
   def shuffleOrderIndependentChecksumEnabled: Boolean =
     getConf(SHUFFLE_ORDER_INDEPENDENT_CHECKSUM_ENABLED)
 
+  def shuffleChecksumMismatchFullRetryEnabled: Boolean =
+    getConf(SHUFFLE_CHECKSUM_MISMATCH_FULL_RETRY_ENABLED)
+
   def allowCollationsInMapKeys: Boolean = getConf(ALLOW_COLLATIONS_IN_MAP_KEYS)
 
   def objectLevelCollationsEnabled: Boolean = 
getConf(OBJECT_LEVEL_COLLATIONS_ENABLED)
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
index 9c86bbb606a5..f052bd906880 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
@@ -480,19 +480,22 @@ object ShuffleExchangeExec {
     // Now, we manually create a ShuffleDependency. Because pairs in rddWithPartitionIds
     // are in the form of (partitionId, row) and every partitionId is in the expected range
     // [0, part.numPartitions - 1]. The partitioner of this is a PartitionIdPassthrough.
-    val checksumSize =
-      if (SQLConf.get.shuffleOrderIndependentChecksumEnabled) {
+    val checksumSize = {
+      if (SQLConf.get.shuffleOrderIndependentChecksumEnabled ||
+        SQLConf.get.shuffleChecksumMismatchFullRetryEnabled) {
         part.numPartitions
       } else {
         0
       }
+    }
     val dependency =
       new ShuffleDependency[Int, InternalRow, InternalRow](
         rddWithPartitionIds,
         new PartitionIdPassthrough(part.numPartitions),
         serializer,
         shuffleWriterProcessor = createShuffleWriteProcessor(writeMetrics),
-        rowBasedChecksums = UnsafeRowChecksum.createUnsafeRowChecksums(checksumSize))
+        rowBasedChecksums = UnsafeRowChecksum.createUnsafeRowChecksums(checksumSize),
+        checksumMismatchFullRetryEnabled = SQLConf.get.shuffleChecksumMismatchFullRetryEnabled)
 
     dependency
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MapStatusEndToEndSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MapStatusEndToEndSuite.scala
index 0fe660312210..abcd346c3277 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/MapStatusEndToEndSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/MapStatusEndToEndSuite.scala
@@ -25,7 +25,6 @@ import org.apache.spark.sql.test.SQLTestUtils
 class MapStatusEndToEndSuite extends SparkFunSuite with SQLTestUtils {
     override def spark: SparkSession = SparkSession.builder()
       .master("local")
-      .config(SQLConf.SHUFFLE_ORDER_INDEPENDENT_CHECKSUM_ENABLED.key, value = true)
       .config(SQLConf.LEAF_NODE_DEFAULT_PARALLELISM.key, value = 5)
       .config(SQLConf.CLASSIC_SHUFFLE_DEPENDENCY_FILE_CLEANUP_ENABLED.key, value = false)
       .getOrCreate()
@@ -39,26 +38,34 @@ class MapStatusEndToEndSuite extends SparkFunSuite with SQLTestUtils {
   }
 
   test("Propagate checksum from executor to driver") {
-    assert(spark.sparkContext.conf
-      .get("spark.sql.shuffle.orderIndependentChecksum.enabled") == "true")
-    assert(spark.conf.get("spark.sql.shuffle.orderIndependentChecksum.enabled") == "true")
-    assert(spark.sparkContext.conf.get("spark.sql.leafNodeDefaultParallelism") == "5")
-    assert(spark.conf.get("spark.sql.leafNodeDefaultParallelism") == "5")
-    assert(spark.sparkContext.conf.get("spark.sql.classic.shuffleDependency.fileCleanup.enabled")
+    assert(spark.sparkContext.conf.get(SQLConf.LEAF_NODE_DEFAULT_PARALLELISM.key) == "5")
+    assert(spark.conf.get(SQLConf.LEAF_NODE_DEFAULT_PARALLELISM.key) == "5")
+    assert(spark.sparkContext.conf.get(SQLConf.CLASSIC_SHUFFLE_DEPENDENCY_FILE_CLEANUP_ENABLED.key)
       == "false")
-    assert(spark.conf.get("spark.sql.classic.shuffleDependency.fileCleanup.enabled") == "false")
+    assert(spark.conf.get(SQLConf.CLASSIC_SHUFFLE_DEPENDENCY_FILE_CLEANUP_ENABLED.key) == "false")
 
-    withTable("t") {
-      spark.range(1000).repartition(10).write.mode("overwrite").
-        saveAsTable("t")
-    }
+    var shuffleId = 0
+    Seq(("true", "false"), ("false", "true"), ("true", "true")).foreach {
+      case (orderIndependentChecksumEnabled: String, checksumMismatchFullRetryEnabled: String) =>
+        withSQLConf(
+          SQLConf.SHUFFLE_ORDER_INDEPENDENT_CHECKSUM_ENABLED.key ->
+            orderIndependentChecksumEnabled,
+          SQLConf.SHUFFLE_CHECKSUM_MISMATCH_FULL_RETRY_ENABLED.key ->
+            checksumMismatchFullRetryEnabled) {
+          withTable("t") {
+            spark.range(1000).repartition(10).write.mode("overwrite").
+              saveAsTable("t")
+          }
 
-    val shuffleStatuses = spark.sparkContext.env.mapOutputTracker.
-      asInstanceOf[MapOutputTrackerMaster].shuffleStatuses
-    assert(shuffleStatuses.size == 1)
+          val shuffleStatuses = spark.sparkContext.env.mapOutputTracker.
+            asInstanceOf[MapOutputTrackerMaster].shuffleStatuses
+          assert(shuffleStatuses.contains(shuffleId))
 
-    val mapStatuses = shuffleStatuses(0).mapStatuses
-    assert(mapStatuses.length == 5)
-    assert(mapStatuses.forall(_.checksumValue != 0))
+          val mapStatuses = shuffleStatuses(shuffleId).mapStatuses
+          assert(mapStatuses.length == 5)
+          assert(mapStatuses.forall(_.checksumValue != 0))
+          shuffleId += 1
+        }
+    }
   }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to