This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 7e4a040e07c9 [SPARK-55064][SQL][CORE] Support query level 
indeterminate shuffle retry
7e4a040e07c9 is described below

commit 7e4a040e07c90d539d26d7878e1dfdadd289f075
Author: Tengfei Huang <[email protected]>
AuthorDate: Thu Jan 29 15:47:49 2026 +0800

    [SPARK-55064][SQL][CORE] Support query level indeterminate shuffle retry
    
    ### What changes were proposed in this pull request?
    Currently, when a checksum mismatch is detected for an indeterminate shuffle map
    stage, we find and validate all the succeeding stages in `active jobs`.
    
    There are a few problems here, especially for complicated queries with
    sub-queries:
    1. Jobs submitted by sub-queries that reuse the same shuffle output may have
    already finished, so they cannot be tracked through active jobs;
    2. The stage (stage id) in a sub-query job can differ even though the jobs
    share the same shuffle output;
    3. Succeeding shuffle map stages with available shuffle output are skipped,
    even though that output may be inconsistent;
    4. Succeeding stages may still be running and consuming the old shuffle
    data;
    
    To further cover these scenarios, this PR proposes to:
    1. Track all the jobs submitted by the same query execution, so that all the
    succeeding stages within the same query can be found;
    2. Find the succeeding stages in completed jobs by comparing shuffle IDs
    instead of stage IDs;
    3. Abort the current indeterminate shuffle map stage if any succeeding
    result stage has already completed;
    4. Abort the running succeeding result stages;
    5. Roll back all succeeding shuffle map stages:
        a. Cancel running stages, clean up their shuffle output and resubmit them;
        b. Clean up shuffle data for available completed shuffle map stages;
    
    ### Why are the changes needed?
    Ensure correctness for queries with indeterminate shuffle map stage retries.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    UTs
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #53868 from ivoson/SPARK-55064.
    
    Authored-by: Tengfei Huang <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../java/org/apache/spark/internal/LogKeys.java    |   1 +
 .../main/scala/org/apache/spark/Dependency.scala   |   3 +-
 .../main/scala/org/apache/spark/SparkContext.scala |   1 +
 .../org/apache/spark/scheduler/DAGScheduler.scala  | 181 +++++++++++-
 .../apache/spark/scheduler/DAGSchedulerEvent.scala |   2 +
 .../apache/spark/scheduler/DAGSchedulerSuite.scala | 302 +++++++++++++++++++++
 .../org/apache/spark/sql/internal/SQLConf.scala    |  14 +
 .../apache/spark/sql/execution/SQLExecution.scala  |  10 +
 .../execution/exchange/ShuffleExchangeExec.scala   |   4 +-
 .../spark/sql/execution/ui/SQLListener.scala       |   3 +
 .../spark/sql/execution/QueryExecutionSuite.scala  |  63 +++++
 11 files changed, 579 insertions(+), 5 deletions(-)

diff --git 
a/common/utils-java/src/main/java/org/apache/spark/internal/LogKeys.java 
b/common/utils-java/src/main/java/org/apache/spark/internal/LogKeys.java
index ad51da5bbc6e..59df0423fad2 100644
--- a/common/utils-java/src/main/java/org/apache/spark/internal/LogKeys.java
+++ b/common/utils-java/src/main/java/org/apache/spark/internal/LogKeys.java
@@ -316,6 +316,7 @@ public enum LogKeys implements LogKey {
   JAVA_VERSION,
   JAVA_VM_NAME,
   JOB_ID,
+  JOB_IDS,
   JOIN_CONDITION,
   JOIN_CONDITION_SUB_EXPR,
   JOIN_TYPE,
diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala 
b/core/src/main/scala/org/apache/spark/Dependency.scala
index c94ce35cb250..da97aff5b344 100644
--- a/core/src/main/scala/org/apache/spark/Dependency.scala
+++ b/core/src/main/scala/org/apache/spark/Dependency.scala
@@ -90,7 +90,8 @@ class ShuffleDependency[K: ClassTag, V: ClassTag, C: 
ClassTag](
     val mapSideCombine: Boolean = false,
     val shuffleWriterProcessor: ShuffleWriteProcessor = new 
ShuffleWriteProcessor,
     val rowBasedChecksums: Array[RowBasedChecksum] = 
ShuffleDependency.EMPTY_ROW_BASED_CHECKSUMS,
-    private val _checksumMismatchFullRetryEnabled: Boolean = false)
+    private val _checksumMismatchFullRetryEnabled: Boolean = false,
+    val checksumMismatchQueryLevelRollbackEnabled: Boolean = false)
   extends Dependency[Product2[K, V]] with Logging {
 
   def this(
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala 
b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 6f8be49e3959..8c92f4c10aa5 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -3152,6 +3152,7 @@ object SparkContext extends Logging {
   private[spark] val SPARK_SCHEDULER_POOL = "spark.scheduler.pool"
   private[spark] val RDD_SCOPE_KEY = "spark.rdd.scope"
   private[spark] val RDD_SCOPE_NO_OVERRIDE_KEY = "spark.rdd.scope.noOverride"
+  private[spark] val SQL_EXECUTION_ID_KEY = "spark.sql.execution.id"
 
   /**
    * Executor id for the driver.  In earlier versions of Spark, this was 
`<driver>`, but this was
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala 
b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 67785a0ce9ea..0c8d437c98bd 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -171,6 +171,14 @@ private[spark] class DAGScheduler(
 
   private[scheduler] val activeJobs = new HashSet[ActiveJob]
 
+  // Track all the jobs submitted by the same query execution; the entries are cleaned up
+  // after the query finishes. Use ConcurrentHashMap to allow thread-safe access from
+  // different threads (e.g., SQLExecution during testing).
+  private[spark] val activeQueryToJobs =
+    new ConcurrentHashMap[Long, java.util.Set[ActiveJob]]()
+
+  private[spark] val jobIdToQueryExecutionId = new ConcurrentHashMap[Int, 
java.lang.Long]()
+
   // Job groups that are cancelled with `cancelFutureJobs` as true, with at 
most
   // `NUM_CANCELLED_JOB_GROUPS_TO_TRACK` stored. On a new job submission, if 
its job group is in
   // this set, the job will be immediately cancelled.
@@ -1145,6 +1153,14 @@ private[spark] class DAGScheduler(
     eventProcessLoop.post(AllJobsCancelled)
   }
 
+  /**
+   * Cleanup the jobs of query execution tracked in the DAGScheduler.
+   */
+  def cleanupQueryJobs(executionId: Long): Unit = {
+    logInfo(s"Asked to cleanup jobs for query execution $executionId")
+    eventProcessLoop.post(CleanupQueryJobs(executionId))
+  }
+
   private[scheduler] def doCancelAllJobs(): Unit = {
     // Cancel all running jobs.
     runningStages.map(_.firstJobId).foreach(handleJobCancellation(_,
@@ -1153,6 +1169,12 @@ private[spark] class DAGScheduler(
     jobIdToActiveJob.clear() // but just in case we lost track of some jobs...
   }
 
+  private[spark] def doCleanupQueryJobs(executionId: Long): Unit = {
+    Option(activeQueryToJobs.remove(executionId)).foreach { jobs =>
+      jobs.forEach(job => jobIdToQueryExecutionId.remove(job.jobId))
+    }
+  }
+
   /**
    * Cancel all jobs associated with a running or scheduled stage.
    */
@@ -1325,6 +1347,18 @@ private[spark] class DAGScheduler(
     listenerBus.post(SparkListenerTaskGettingResult(taskInfo))
   }
 
+  private def getQueryExecutionIdFromProperties(properties: Properties): 
Option[Long] = {
+    try {
+      Option(properties)
+        .flatMap(properties => 
Option(properties.getProperty(SparkContext.SQL_EXECUTION_ID_KEY)))
+        .map(_.toLong)
+    } catch {
+      case e: Throwable =>
+        logWarning(log"Failed to get query execution ID from job properties", 
e)
+        None
+    }
+  }
+
   private[scheduler] def handleJobSubmitted(
       jobId: Int,
       finalRDD: RDD[_],
@@ -1398,6 +1432,10 @@ private[spark] class DAGScheduler(
     val jobSubmissionTime = clock.getTimeMillis()
     jobIdToActiveJob(jobId) = job
     activeJobs += job
+    getQueryExecutionIdFromProperties(properties).foreach { qeId =>
+      activeQueryToJobs.computeIfAbsent(qeId, _ => 
ConcurrentHashMap.newKeySet()).add(job)
+      jobIdToQueryExecutionId.put(jobId, qeId)
+    }
     finalStage.setActiveJob(job)
     val stageIds = jobIdToStageIds(jobId).toArray
     val stageInfos =
@@ -1441,6 +1479,10 @@ private[spark] class DAGScheduler(
     val jobSubmissionTime = clock.getTimeMillis()
     jobIdToActiveJob(jobId) = job
     activeJobs += job
+    getQueryExecutionIdFromProperties(properties).foreach { qeId =>
+      activeQueryToJobs.computeIfAbsent(qeId, _ => 
ConcurrentHashMap.newKeySet()).add(job)
+      jobIdToQueryExecutionId.put(jobId, qeId)
+    }
     finalStage.addActiveJob(job)
     val stageIds = jobIdToStageIds(jobId).toArray
     val stageInfos =
@@ -1898,22 +1940,28 @@ private[spark] class DAGScheduler(
    * stages will be resubmitted and re-try all the tasks, as the input data 
may be changed after
    * the map tasks are re-tried. For stages where rollback and retry all tasks 
are not possible,
    * we will need to abort the stages.
+   *
+   * @return true if the corresponding active jobs are not aborted and will 
continue to run after rollback, otherwise false
    */
   private[scheduler] def rollbackSucceedingStages(
       mapStage: ShuffleMapStage,
-      rollbackCurrentStage: Boolean = false): Unit = {
+      rollbackCurrentStage: Boolean = false): Boolean = {
     val stagesToRollback = if (rollbackCurrentStage) {
       collectSucceedingStages(mapStage)
     } else {
       collectSucceedingStages(mapStage).filterNot(_ == mapStage)
     }
+    logInfo(log"Found succeeding stages ${MDC(STAGES, stagesToRollback)} of " +
+      log"shuffle checksum mismatch stage ${MDC(STAGE, mapStage)} in active 
jobs")
     val stagesCanRollback = 
filterAndAbortUnrollbackableStages(stagesToRollback)
+
     // stages which cannot be rolled back were aborted which leads to removing 
the
     // the dependant job(s) from the active jobs set, there could be no active 
jobs
     // left depending on the indeterminate stage and hence no need to roll 
back any stages.
     val numActiveJobsWithStageAfterRollback =
       activeJobs.count(job => stagesToRollback.contains(job.finalStage))
-    if (numActiveJobsWithStageAfterRollback == 0) {
+    val hasActiveJobs = numActiveJobsWithStageAfterRollback > 0
+    if (!hasActiveJobs) {
       logInfo(log"All jobs depending on the indeterminate stage " +
         log"(${MDC(STAGE_ID, mapStage.id)}) were aborted.")
     } else {
@@ -1940,6 +1988,94 @@ private[spark] class DAGScheduler(
         log"was retried, we will roll back and rerun its succeeding " +
         log"stages: ${MDC(STAGES, stagesCanRollback)}")
     }
+    // Whether there are still active jobs which need to be processed
+    hasActiveJobs
+  }
+
+  private def getCompletedJobsFromSameQuery(mapStage: ShuffleMapStage): 
Array[ActiveJob] = {
+    import scala.jdk.CollectionConverters._
+    val executionIds = mapStage
+      .jobIds
+      .flatMap(jobId => Option(jobIdToQueryExecutionId.get(jobId)))
+    if (executionIds.size > 1) {
+      logWarning(log"There are multiple queries reuse the same stage: 
${MDC(STAGE, mapStage)}")
+    }
+    executionIds
+      .flatMap(qeId => 
Option(activeQueryToJobs.get(qeId)).map(_.asScala).getOrElse(Set.empty))
+      .diff(activeJobs)
+      .toArray
+      .sortBy(_.jobId)
+  }
+
+  private def abortSucceedingFinalStages(mapStage: ShuffleMapStage, reason: 
String): Unit = {
+    val jobFinalStages = activeJobs.map(_.finalStage).toSet
+
+    collectSucceedingStages(mapStage)
+      .intersect(jobFinalStages)
+      .foreach { stage =>
+        // Aborting a stage fails the jobs depending on it and cleans up the stages
+        // of these jobs:
+        // 1. running stages are cancelled;
+        // 2. stages with no remaining active jobs depending on them are removed from
+        //    the DAGScheduler, so waiting stages are dropped and won't be submitted.
+        // Since all the succeeding active jobs are aborted, all the succeeding stages
+        // should be cleaned up.
+        abortStage(stage, reason, exception = None)
+      }
+  }
+
+  // In rollbackSucceedingStages, we assume that the jobs are independent even 
if they reuse
+  // the same shuffle. If an active job hits fetch failure and keeps retrying 
upstream stages
+  // cascadingly, and a shuffle checksum mismatch is detected, these retried 
stages will be
+  // fully rolled back if possible, or the job will be aborted. However, we 
won't do anything for
+  // the other active jobs.
+  //
+  // For SQL query execution, all the jobs from the same query are related as 
they all contribute
+  // to the query result in some ways. For example, the jobs for subquery 
expressions are not part
+  // of the main query RDD, but they are still related. We need to guarantee 
the consistency of all
+  // the jobs, which means if a shuffle stage hits checksum mismatch, all its 
downstream stages in
+  // all the jobs of the same query execution should be rolled back if 
possible.
+  //
+  // This method does:
+  // 1. Find all the jobs triggered by the same query execution including the 
completed ones.
+  // 2. Find all the succeeding stages of the checksum mismatch shuffle stage 
based on shuffle id
+  //    as the stage can differ across jobs even if they share the same shuffle.
+  // 3. Abort succeeding stages if their leaf result stages have started 
running.
+  // 4. Clean up shuffle outputs of the checksum mismatched shuffle stages 
which are available to
+  //    make sure all these stages would be resubmitted and fully retried.
+  // 5. Cancel running shuffle map stages and resubmit.
+  private[scheduler] def rollbackSucceedingStagesForQuery(mapStage: 
ShuffleMapStage): Unit = {
+    // Find the completed jobs triggered by the same query execution.
+    val completedJobs = getCompletedJobsFromSameQuery(mapStage)
+    val succeedingStagesInCompletedJobs =
+      collectSucceedingStagesByShuffleId(mapStage, completedJobs)
+    logInfo(log"Found succeeding stages ${MDC(STAGES, 
succeedingStagesInCompletedJobs)} of " +
+      log"shuffle checksum mismatch stage ${MDC(STAGE, mapStage)} in completed 
jobs: (" +
+      log"${MDC(JOB_IDS, completedJobs.map(_.jobId).mkString(","))})")
+
+    // Abort all the succeeding final stages in active jobs to fail fast and 
avoid wasting
+    // resources if there are succeeding result stages in completed jobs.
+    val completedResultStages =
+      succeedingStagesInCompletedJobs.collect { case r: ResultStage => r }
+    if (completedResultStages.nonEmpty) {
+      val reason = s"cannot rollback completed result stages 
${completedResultStages}, " +
+        s"please re-run the query to ensure data correctness"
+      abortSucceedingFinalStages(mapStage, reason)
+      return
+    }
+
+    // Rollback the succeeding stages in active jobs. If there are no active 
jobs left after
+    // rollback, we can skip the rollback for completed jobs.
+    val hasActiveJobs = rollbackSucceedingStages(mapStage)
+    if (hasActiveJobs) {
+      // Rollback the shuffle map stages in completed jobs to make sure the 
completed shuffle
+      // map stages would be re-submitted and fully retried.
+      succeedingStagesInCompletedJobs.collect { case s: ShuffleMapStage => s }
+        .foreach(stage => rollbackShuffleMapStage(stage, "rolling back due to 
indeterminate " +
+          s"output of shuffle map stage $mapStage"))
+    } else {
+      logInfo(log"All jobs depending on the checksum mismatch stage " +
+        log"(${MDC(STAGE_ID, mapStage.id)}) were aborted, skip the rollback.")
+    }
   }
 
   /**
@@ -2189,7 +2325,11 @@ private[spark] class DAGScheduler(
                       && shuffleStage.isRuntimeIndeterminate) {
                     if (shuffleStage.maxChecksumMismatchedId < 
smt.stageAttemptId) {
                       shuffleStage.maxChecksumMismatchedId = smt.stageAttemptId
-                      rollbackSucceedingStages(shuffleStage)
+                      if 
(shuffleStage.shuffleDep.checksumMismatchQueryLevelRollbackEnabled) {
+                        rollbackSucceedingStagesForQuery(shuffleStage)
+                      } else {
+                        rollbackSucceedingStages(shuffleStage)
+                      }
                     }
                   }
                 }
@@ -2484,6 +2624,32 @@ private[spark] class DAGScheduler(
     }
   }
 
+  private def collectSucceedingStagesByShuffleId(
+      mapStage: ShuffleMapStage, jobs: Array[ActiveJob]): HashSet[Stage] = {
+    val succeedingStages = HashSet[Stage]()
+    val shuffleId = mapStage.shuffleDep.shuffleId
+
+    def getShuffleId(stage: Stage): Option[Int] = stage match {
+      case s: ShuffleMapStage => Some(s.shuffleDep.shuffleId)
+      case _ => None
+    }
+
+    def collectSucceedingStagesInternal(stageChain: List[Stage]): Unit = {
+      val head = stageChain.head
+      if (getShuffleId(head).contains(shuffleId)) {
+        stageChain.drop(1).foreach { s =>
+          succeedingStages += s
+        }
+      } else {
+        head.parents.foreach { s =>
+          collectSucceedingStagesInternal(s :: stageChain)
+        }
+      }
+    }
+    jobs.foreach(job => collectSucceedingStagesInternal(job.finalStage :: Nil))
+    succeedingStages
+  }
+
   private def collectSucceedingStages(mapStage: ShuffleMapStage): 
HashSet[Stage] = {
     // TODO: perhaps materialize this if we are going to compute it often 
enough ?
     // It's a little tricky to find all the succeeding stages of `mapStage`, 
because
@@ -2543,6 +2709,8 @@ private[spark] class DAGScheduler(
           } else {
             rollingBackStages += mapStage
           }
+        } else if (runningStages.contains(mapStage)) {
+          rollingBackStages += mapStage
         }
 
       case resultStage: ResultStage if resultStage.activeJob.isDefined =>
@@ -2550,6 +2718,10 @@ private[spark] class DAGScheduler(
         if (numMissingPartitions < resultStage.numTasks) {
           // TODO: support to rollback result tasks.
           abortStage(resultStage, generateErrorMessage(resultStage), None)
+        } else if (runningStages.contains(resultStage)) {
+          val reason = "cannot rollback a running result stage, please re-run 
the query " +
+            "to ensure data correctness"
+          abortStage(resultStage, reason, None)
         }
 
       case _ =>
@@ -3362,6 +3534,9 @@ private[scheduler] class 
DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler
     case AllJobsCancelled =>
       dagScheduler.doCancelAllJobs()
 
+    case CleanupQueryJobs(executionId) =>
+      dagScheduler.doCleanupQueryJobs(executionId)
+
     case ExecutorAdded(execId, host) =>
       dagScheduler.handleExecutorAdded(execId, host)
 
diff --git 
a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala 
b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
index 8932d2ef323b..cc788e5c65bc 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
@@ -78,6 +78,8 @@ private[scheduler] case class JobTagCancelled(
 
 private[scheduler] case object AllJobsCancelled extends DAGSchedulerEvent
 
+private[scheduler] case class CleanupQueryJobs(executionId: Long) extends 
DAGSchedulerEvent
+
 private[scheduler]
 case class BeginEvent(task: Task[_], taskInfo: TaskInfo) extends 
DAGSchedulerEvent
 
diff --git 
a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala 
b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index 6f0fd9608334..5cbb0654eb37 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -3493,6 +3493,307 @@ class DAGSchedulerSuite extends SparkFunSuite with 
TempLocalSparkContext with Ti
     assertDataStructuresEmpty()
   }
 
+  test("SPARK-55064: abort stage if checksum mismatch detected with succeeding 
" +
+    "result stage in completed jobs") {
+    val executionId = 55064L
+    val properties = new Properties()
+    properties.setProperty(SparkContext.SQL_EXECUTION_ID_KEY, 
executionId.toString)
+    try {
+      val shuffleMapRdd1 = new MyRDD(sc, 2, Nil)
+      val shuffleDep1 = new ShuffleDependency(
+        shuffleMapRdd1,
+        new HashPartitioner(2),
+        _checksumMismatchFullRetryEnabled = true,
+        checksumMismatchQueryLevelRollbackEnabled = true)
+      val shuffleId1 = shuffleDep1.shuffleId
+
+      // Submit and complete the 1st job depending on shuffleDep1
+      val finalRdd1 = new MyRDD(sc, 2, List(shuffleDep1), tracker = 
mapOutputTracker)
+      submit(finalRdd1, Array(0, 1), properties = properties)
+      // Finish the first shuffle map stage.
+      completeShuffleMapStageSuccessfully(
+        0, 0, 2, Seq("hostA", "hostB"), checksumVal = 100)
+      assert(mapOutputTracker.findMissingPartitions(shuffleId1) === 
Some(Seq.empty))
+      // Finish the result stage.
+      completeNextResultStageWithSuccess(1, 0)
+      assertDataStructuresEmpty()
+
+      // Submit the 2nd job depending on shuffleDep1, and fail it by checksum 
mismatch
+      val finalRdd2 = new MyRDD(sc, 2, List(shuffleDep1), tracker = 
mapOutputTracker)
+      submit(finalRdd2, Array(0, 1), properties = properties)
+      // The first task failed with FetchFailed.
+      runEvent(makeCompletionEvent(
+        taskSets(2).tasks(0),
+        FetchFailed(makeBlockManagerId("hostA"), shuffleId1, 0L, 0, 0, 
"ignored"),
+        null))
+
+      // Check status for all failedStages.
+      val failedStages = scheduler.failedStages.toSeq
+      assert(failedStages.map(_.id) == Seq(2, 3))
+      scheduler.resubmitFailedStages()
+
+      // Complete the re-attempt of shuffle map stage 2(shuffleId1) with a 
different checksum.
+      completeShuffleMapStageSuccessfully(2, 1, 2, checksumVal = 101)
+      assert(failure != null && failure.getMessage.contains(
+        "cannot rollback completed result stages"))
+    } finally {
+      scheduler.cleanupQueryJobs(executionId)
+    }
+  }
+
+  test("SPARK-55064: clean up shuffle data for the succeeding available 
shuffle map stages " +
+    "in completed jobs if checksum mismatch detected") {
+    val executionId = 55064L
+    val properties = new Properties()
+    properties.setProperty(SparkContext.SQL_EXECUTION_ID_KEY, 
executionId.toString)
+    try {
+      val shuffleMapRdd1 = new MyRDD(sc, 2, Nil)
+      val shuffleDep1 = new ShuffleDependency(
+        shuffleMapRdd1,
+        new HashPartitioner(2),
+        _checksumMismatchFullRetryEnabled = true,
+        checksumMismatchQueryLevelRollbackEnabled = true)
+      val shuffleId1 = shuffleDep1.shuffleId
+
+      // Submit and complete the 1st shuffle map stage job depending on 
shuffleDep2
+      val shuffleMapRdd2 = new MyRDD(
+        sc, 2, List(shuffleDep1), tracker = mapOutputTracker)
+      val shuffleDep2 = new ShuffleDependency(
+        shuffleMapRdd2,
+        new HashPartitioner(2),
+        _checksumMismatchFullRetryEnabled = true,
+        checksumMismatchQueryLevelRollbackEnabled = true)
+      val shuffleId2 = shuffleDep2.shuffleId
+      val mapStageJobId = submitMapStage(shuffleDep2)
+      scheduler.activeQueryToJobs
+        .computeIfAbsent(executionId, _ => 
java.util.concurrent.ConcurrentHashMap.newKeySet())
+        .add(scheduler.jobIdToActiveJob(mapStageJobId))
+      scheduler.jobIdToQueryExecutionId.put(mapStageJobId, executionId)
+
+      // Finish the stages.
+      completeShuffleMapStageSuccessfully(
+        0, 0, 2, Seq("hostA", "hostB"), checksumVal = 100)
+      completeShuffleMapStageSuccessfully(
+        1, 0, 2, Seq("hostB", "hostC"), checksumVal = 200)
+
+      assert(mapOutputTracker.findMissingPartitions(shuffleId1) === 
Some(Seq.empty))
+      assert(mapOutputTracker.findMissingPartitions(shuffleId2) === 
Some(Seq.empty))
+
+      // Submit the 2nd job depending on shuffleDep1, and fail it by checksum 
mismatch
+      val finalRdd1 = new MyRDD(sc, 2, List(shuffleDep1), tracker = 
mapOutputTracker)
+      submit(finalRdd1, Array(0, 1), properties = properties)
+      // The first task failed with FetchFailed.
+      val resultTaskSet = taskSets.last
+      runEvent(makeCompletionEvent(
+        resultTaskSet.tasks(0),
+        FetchFailed(makeBlockManagerId("hostA"), shuffleId1, 0L, 0, 0, 
"ignored"),
+        null))
+
+      // Check status for all failedStages.
+      val failedStages = scheduler.failedStages.toSeq
+      assert(failedStages.exists {
+        case stage: ShuffleMapStage => stage.shuffleDep.shuffleId == shuffleId1
+        case _ => false
+      })
+      assert(failedStages.exists(_.isInstanceOf[ResultStage]))
+      scheduler.resubmitFailedStages()
+
+      assert(mapOutputTracker.getNumAvailableOutputs(shuffleId2) === 2)
+      // Complete the re-attempt of shuffle map stage (shuffleId1) with a 
different checksum.
+      val retryShuffleTaskSet = taskSets.last
+      assert(retryShuffleTaskSet.shuffleId.contains(shuffleId1))
+      completeShuffleMapStageSuccessfully(
+        retryShuffleTaskSet.stageId,
+        retryShuffleTaskSet.stageAttemptId,
+        2,
+        checksumVal = 101)
+      // Output of shuffleId2 should be cleaned up and job on finalRdd1 will 
be resubmitted.
+      assert(mapOutputTracker.getNumAvailableOutputs(shuffleId2) === 0)
+      val expectedResults = Seq((Success, 11), (Success, 12))
+      val expectedResultsMap: Map[Int, Any] = Map(0 -> 11, 1 -> 12)
+      completeAndCheckAnswer(taskSets.last, expectedResults, 
expectedResultsMap)
+
+      // Retry all tasks of shuffleId2 for the new jobs
+      val finalRdd2 = new MyRDD(sc, 2, List(shuffleDep2), tracker = 
mapOutputTracker)
+      submit(finalRdd2, Array(0, 1), properties = properties)
+      val shuffleRetryTaskSet = taskSets.last
+      assert(shuffleRetryTaskSet.shuffleId.contains(shuffleId2))
+      completeShuffleMapStageSuccessfully(
+        shuffleRetryTaskSet.stageId,
+        shuffleRetryTaskSet.stageAttemptId,
+        2,
+        checksumVal = 200)
+      completeAndCheckAnswer(taskSets.last, expectedResults, 
expectedResultsMap)
+    } finally {
+      scheduler.cleanupQueryJobs(executionId)
+    }
+  }
+
+  test("SPARK-55064: clean up shuffle data for the succeeding available 
shuffle map stages " +
+    "in active jobs if checksum mismatch detected") {
+    val shuffleMapRdd1 = new MyRDD(sc, 2, Nil)
+    val shuffleDep1 = new ShuffleDependency(
+      shuffleMapRdd1,
+      new HashPartitioner(2),
+      _checksumMismatchFullRetryEnabled = true,
+      checksumMismatchQueryLevelRollbackEnabled = true)
+    val shuffleId1 = shuffleDep1.shuffleId
+
+    val shuffleMapRdd2 = new MyRDD(
+      sc, 2, List(shuffleDep1), tracker = mapOutputTracker)
+    val shuffleMapRdd3 = new MyRDD(
+      sc, 2, List(shuffleDep1), tracker = mapOutputTracker)
+
+    val shuffleDep2 = new ShuffleDependency(
+      shuffleMapRdd2,
+      new HashPartitioner(2),
+      _checksumMismatchFullRetryEnabled = true,
+      checksumMismatchQueryLevelRollbackEnabled = true)
+    val shuffleDep3 = new ShuffleDependency(
+      shuffleMapRdd3,
+      new HashPartitioner(2),
+      _checksumMismatchFullRetryEnabled = true,
+      checksumMismatchQueryLevelRollbackEnabled = true)
+
+    val finalRdd = new MyRDD(sc, 2, List(shuffleDep2, shuffleDep3), tracker = 
mapOutputTracker)
+    submit(finalRdd, Array(0, 1))
+
+    // Finish shuffle map stage 0 and 1.
+    completeShuffleMapStageSuccessfully(
+      0, 0, 2, Seq("hostA", "hostB"), checksumVal = 100)
+    completeShuffleMapStageSuccessfully(
+      1, 0, 2, Seq("hostB", "hostC"), checksumVal = 200)
+    assert(mapOutputTracker.findMissingPartitions(shuffleId1) === 
Some(Seq.empty))
+    val shuffleId2 = taskSets(1).shuffleId.get
+    assert(mapOutputTracker.findMissingPartitions(shuffleId2) === 
Some(Seq.empty))
+
+    val shuffleId3 = taskSets(2).shuffleId.get
+    // The first task of shuffle map stage 2 failed with FetchFailed.
+    runEvent(makeCompletionEvent(
+      taskSets(2).tasks(0),
+      FetchFailed(makeBlockManagerId("hostA"), shuffleId1, 0L, 0, 0, 
"ignored"),
+      null))
+
+    // Check status for all failedStages.
+    val failedStages = scheduler.failedStages.toSeq
+    assert(failedStages.map(_.id) == Seq(0, 2))
+    scheduler.resubmitFailedStages()
+
+    assert(mapOutputTracker.getNumAvailableOutputs(shuffleId2) === 2)
+    // Complete the re-attempt of shuffle map stage 0(shuffleId1) with a 
different checksum.
+    checkAndCompleteRetryStage(3, 0, shuffleId1, 1, checksumVal = 101)
+    // Output of shuffleId2 should be cleaned up and will be resubmitted.
+    assert(mapOutputTracker.getNumAvailableOutputs(shuffleId2) === 0)
+    checkAndCompleteRetryStage(4, 2, shuffleId3, 2, checksumVal = 300)
+    checkAndCompleteRetryStage(5, 1, shuffleId2, 2, checksumVal = 200)
+    completeAndCheckAnswer(taskSets(6), Seq((Success, 11), (Success, 12)), 
Map(0 -> 11, 1 -> 12))
+  }
+
+  test("SPARK-55064: cancel and resubmit running succeeding shuffle map stages 
" +
+    "if checksum mismatch detected") {
+    val shuffleMapRdd1 = new MyRDD(sc, 2, Nil)
+    val shuffleDep1 = new ShuffleDependency(
+      shuffleMapRdd1,
+      new HashPartitioner(2),
+      _checksumMismatchFullRetryEnabled = true,
+      checksumMismatchQueryLevelRollbackEnabled = true)
+    val shuffleId1 = shuffleDep1.shuffleId
+
+    val shuffleMapRdd2 = new MyRDD(
+      sc, 2, List(shuffleDep1), tracker = mapOutputTracker)
+    val shuffleMapRdd3 = new MyRDD(
+      sc, 2, List(shuffleDep1), tracker = mapOutputTracker)
+
+    val shuffleDep2 = new ShuffleDependency(
+      shuffleMapRdd2,
+      new HashPartitioner(2),
+      _checksumMismatchFullRetryEnabled = true,
+      checksumMismatchQueryLevelRollbackEnabled = true)
+    val shuffleDep3 = new ShuffleDependency(
+      shuffleMapRdd3,
+      new HashPartitioner(2),
+      _checksumMismatchFullRetryEnabled = true,
+      checksumMismatchQueryLevelRollbackEnabled = true)
+
+    val finalRdd = new MyRDD(sc, 2, List(shuffleDep2, shuffleDep3), tracker = 
mapOutputTracker)
+    submit(finalRdd, Array(0, 1))
+
+    // Finish shuffle map stage 0
+    completeShuffleMapStageSuccessfully(
+      0, 0, 2, Seq("hostA", "hostB"), checksumVal = 100)
+    assert(mapOutputTracker.findMissingPartitions(shuffleId1) === 
Some(Seq.empty))
+
+    val shuffleId2 = taskSets(1).shuffleId.get
+    val shuffleId3 = taskSets(2).shuffleId.get
+
+    // The first task of shuffle map stage 2 failed with FetchFailed.
+    runEvent(makeCompletionEvent(
+      taskSets(2).tasks(0),
+      FetchFailed(makeBlockManagerId("hostA"), shuffleId1, 0L, 0, 0, 
"ignored"),
+      null))
+
+    // Check status for all failedStages.
+    val failedStages = scheduler.failedStages.toSeq
+    assert(failedStages.map(_.id) == Seq(0, 2))
+    scheduler.resubmitFailedStages()
+
+    // Complete the re-attempt of shuffle map stage 0(shuffleId1) with a 
different checksum.
+    checkAndCompleteRetryStage(3, 0, shuffleId1, 1, checksumVal = 101)
+    // Shuffle map stage 1 should be cancelled and resubmitted.
+    assert(scheduler.failedStages.map(_.id).toSeq == Seq(1))
+    scheduler.resubmitFailedStages()
+
+    checkAndCompleteRetryStage(4, 2, shuffleId3, 2, checksumVal = 200)
+    checkAndCompleteRetryStage(5, 1, shuffleId2, 2, checksumVal = 300)
+    completeAndCheckAnswer(taskSets(6), Seq((Success, 11), (Success, 12)), 
Map(0 -> 11, 1 -> 12))
+  }
+
+  test("SPARK-55064: abort stage if checksum mismatch detected with succeeding " +
+    "running result stage in active jobs") {
+    val executionId = 55064L
+    val properties = new Properties()
+    properties.setProperty(SparkContext.SQL_EXECUTION_ID_KEY, executionId.toString)
+    try {
+      val shuffleMapRdd1 = new MyRDD(sc, 2, Nil)
+      val shuffleDep1 = new ShuffleDependency(
+        shuffleMapRdd1,
+        new HashPartitioner(2),
+        _checksumMismatchFullRetryEnabled = true,
+        checksumMismatchQueryLevelRollbackEnabled = true)
+      val shuffleId1 = shuffleDep1.shuffleId
+
+      // Submit 2 jobs depending on shuffleDep1
+      val finalRdd1 = new MyRDD(
+        sc, 2, List(shuffleDep1), tracker = mapOutputTracker)
+      submit(finalRdd1, Array(0, 1), properties = properties)
+      val finalRdd2 = new MyRDD(sc, 2, List(shuffleDep1), tracker = mapOutputTracker)
+      submit(finalRdd2, Array(0, 1), properties = properties)
+
+      // Finish the stages.
+      completeShuffleMapStageSuccessfully(
+        0, 0, 2, Seq("hostA", "hostB"), checksumVal = 100)
+      assert(mapOutputTracker.findMissingPartitions(shuffleId1) === Some(Seq.empty))
+
+      // The first task of result stage 2 failed with FetchFailed.
+      runEvent(makeCompletionEvent(
+        taskSets(2).tasks(0),
+        FetchFailed(makeBlockManagerId("hostA"), shuffleId1, 0L, 0, 0, "ignored"),
+        null))
+
+      // Check status for all failedStages.
+      val failedStages = scheduler.failedStages.toSeq
+      assert(failedStages.map(_.id) == Seq(0, 2))
+      scheduler.resubmitFailedStages()
+
+      // Complete the re-attempt of shuffle map stage 0(shuffleId1) with a different checksum.
+      completeShuffleMapStageSuccessfully(0, 1, 2, checksumVal = 101)
+      assert(failure != null && failure.getMessage.contains(
+        "cannot rollback a running result stage"))
+    } finally {
+      scheduler.cleanupQueryJobs(executionId)
+    }
+  }
+
   private def checkAndCompleteRetryStage(
       taskSetIndex: Int,
       stageId: Int,
@@ -3502,6 +3803,7 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
       stageAttemptId: Int = 1): Unit = {
     assert(taskSets(taskSetIndex).stageId == stageId)
     assert(taskSets(taskSetIndex).stageAttemptId == stageAttemptId)
+    assert(taskSets(taskSetIndex).shuffleId.contains(shuffleId))
     assert(taskSets(taskSetIndex).tasks.length == numTasks)
     completeShuffleMapStageSuccessfully(stageId, stageAttemptId, 2, checksumVal = checksumVal)
     assert(mapOutputTracker.findMissingPartitions(shuffleId) === Some(Seq.empty))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 585dca249d3d..e86466826019 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -917,6 +917,17 @@ object SQLConf {
       .booleanConf
       .createWithDefault(true)
 
+  private[spark] val SHUFFLE_CHECKSUM_MISMATCH_QUERY_LEVEL_ROLLBACK_ENABLED =
+    buildConf("spark.sql.shuffle.orderIndependentChecksum.enableQueryLevelRollbackOnMismatch")
+      .internal()
+      .doc("Whether to rollback all the consumer stages from the same query executor " +
+        "when we detect checksum mismatches with its producer stages, including cancel " +
+        "running shuffle map stages and resubmit, clean up all the shuffle data written " +
+        "for available shuffle map stages and abort the running result stages.")
+      .version("4.2.0")
+      .booleanConf
+      .createWithDefault(false)
+
   val SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE =
     buildConf("spark.sql.adaptive.shuffle.targetPostShuffleInputSize")
       .internal()
@@ -7201,6 +7212,9 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
   def shuffleChecksumMismatchFullRetryEnabled: Boolean =
     getConf(SHUFFLE_CHECKSUM_MISMATCH_FULL_RETRY_ENABLED)
 
+  def shuffleChecksumMismatchQueryLevelRollbackEnabled: Boolean =
+    getConf(SHUFFLE_CHECKSUM_MISMATCH_QUERY_LEVEL_ROLLBACK_ENABLED)
+
   def allowCollationsInMapKeys: Boolean = getConf(ALLOW_COLLATIONS_IN_MAP_KEYS)
 
   def objectLevelCollationsEnabled: Boolean = getConf(OBJECT_LEVEL_COLLATIONS_ENABLED)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
index 45fa6c60f465..cf26b2991652 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
@@ -238,6 +238,16 @@ object SQLExecution extends Logging {
               event.duration = endTime - startTime
               event.qe = queryExecution
               event.executionFailure = ex
+              if (Utils.isTesting) {
+                import scala.jdk.CollectionConverters._
+                event.jobIds = Option(sc.dagScheduler.activeQueryToJobs.get(executionId))
+                  .map(_.asScala.map(_.jobId).toSet)
+                  .getOrElse(Set.empty)
+              }
+
+              // Clean up jobs tracked by DAGScheduler for this query execution.
+              sc.dagScheduler.cleanupQueryJobs(executionId)
+
               sc.listenerBus.post(event)
 
               // Observation.tryComplete is called here to ensure the observation is completed,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
index a1f693ef5c15..95120039a6f9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
@@ -495,7 +495,9 @@ object ShuffleExchangeExec {
         serializer,
         shuffleWriterProcessor = createShuffleWriteProcessor(writeMetrics),
         rowBasedChecksums = UnsafeRowChecksum.createUnsafeRowChecksums(checksumSize),
-        _checksumMismatchFullRetryEnabled = SQLConf.get.shuffleChecksumMismatchFullRetryEnabled)
+        _checksumMismatchFullRetryEnabled = SQLConf.get.shuffleChecksumMismatchFullRetryEnabled,
+        checksumMismatchQueryLevelRollbackEnabled =
+          SQLConf.get.shuffleChecksumMismatchQueryLevelRollbackEnabled)
 
     dependency
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala
index 4e773390b2f7..6fbd6ad62a2d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala
@@ -90,6 +90,9 @@ case class SparkListenerSQLExecutionEnd(
 
   // The exception object that caused this execution to fail. None if the execution doesn't fail.
   @JsonIgnore private[sql] var executionFailure: Option[Throwable] = None
+
+  // The jobs for this execution. Test only.
+  @JsonIgnore private[sql] var jobIds: Set[Int] = Set.empty
 }
 
 /**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala
index af3a0f3e3710..877e6970368b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala
@@ -17,9 +17,11 @@
 package org.apache.spark.sql.execution
 
 import scala.collection.mutable
+import scala.concurrent.duration.DurationInt
 import scala.io.Source
 import scala.util.Try
 
+import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent, SparkListenerJobStart}
 import org.apache.spark.sql.{AnalysisException, ExtendedExplainGenerator, FastOperator, SaveMode}
 import org.apache.spark.sql.catalyst.{QueryPlanningTracker, QueryPlanningTrackerCallback, TableIdentifier}
 import org.apache.spark.sql.catalyst.analysis.{CurrentNamespace, UnresolvedFunction, UnresolvedRelation}
@@ -33,6 +35,7 @@ import org.apache.spark.sql.connector.catalog.CatalogManager.SESSION_CATALOG_NAM
 import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, QueryStageExec}
 import org.apache.spark.sql.execution.datasources.v2.ShowTablesExec
 import org.apache.spark.sql.execution.joins.SortMergeJoinExec
+import org.apache.spark.sql.execution.ui.SparkListenerSQLExecutionEnd
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.storage.ShuffleIndexBlockId
@@ -519,6 +522,66 @@ class QueryExecutionSuite extends SharedSparkSession {
     }
   }
 
+  test("SPARK-55064: Jobs are tracked in DAGScheduler before query finished") {
+    val jobs = new mutable.HashSet[Int]()
+
+    var sqlExecutionEndVerified: Boolean = false
+    val listener = new SparkListener {
+      override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
+        jobs += jobStart.jobId
+      }
+
+      override def onOtherEvent(event: SparkListenerEvent): Unit = {
+        event match {
+          case e: SparkListenerSQLExecutionEnd =>
+            val jobsTracked = e.jobIds
+            assert(jobsTracked == jobs)
+            jobs.clear()
+            sqlExecutionEndVerified = true
+          case _ =>
+        }
+      }
+    }
+    spark.sparkContext.addSparkListener(listener)
+
+    try {
+      withTable("t1", "t2") {
+        spark.range(10).selectExpr("id as A", "id as D")
+          .write.partitionBy("A").mode("overwrite").saveAsTable("t1")
+        spark.range(2).selectExpr("id as B", "id as C")
+          .write.mode("overwrite").saveAsTable("t2")
+
+        Seq(true, false).foreach { adaptiveEnabled =>
+          withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> adaptiveEnabled.toString) {
+            sqlExecutionEndVerified = false
+            val df = sql(
+              """
+                |SELECT avg(A) FROM t1
+                |WHERE A = (
+                |  SELECT B from t2 where C = 1
+                |)
+              """.stripMargin)
+
+            df.collect()
+            spark.sparkContext.listenerBus.waitUntilEmpty()
+            assert(sqlExecutionEndVerified)
+            // The jobs tracked by DAGScheduler should be cleared after the query is done.
+            eventually(timeout(10.seconds)) {
+              assert(spark.sparkContext.dagScheduler.activeQueryToJobs.isEmpty)
+              assert(spark.sparkContext.dagScheduler.jobIdToQueryExecutionId.isEmpty)
+            }
+          }
+        }
+      }
+      eventually(timeout(10.seconds)) {
+        assert(spark.sparkContext.dagScheduler.activeQueryToJobs.isEmpty)
+        assert(spark.sparkContext.dagScheduler.jobIdToQueryExecutionId.isEmpty)
+      }
+    } finally {
+      spark.sparkContext.removeSparkListener(listener)
+    }
+  }
+
   case class MockCallbackEagerCommand(
       var trackerAnalyzed: QueryPlanningTracker = null,
       var trackerReadyForExecution: QueryPlanningTracker = null)


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to