This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 7e4a040e07c9 [SPARK-55064][SQL][CORE] Support query level indeterminate shuffle retry
7e4a040e07c9 is described below
commit 7e4a040e07c90d539d26d7878e1dfdadd289f075
Author: Tengfei Huang <[email protected]>
AuthorDate: Thu Jan 29 15:47:49 2026 +0800
[SPARK-55064][SQL][CORE] Support query level indeterminate shuffle retry
### What changes were proposed in this pull request?
Currently, when a checksum mismatch is detected for an indeterminate shuffle map
stage, we find and validate all the succeeding stages in the `active jobs`.
There are a few problems here, especially for complicated queries with
sub-queries:
1. Jobs submitted by sub-queries that reuse the same shuffle output may have
already finished, so they cannot be tracked through the active jobs;
2. The stage (stage id) in a sub-query job can be different even though the
jobs share the same shuffle output;
3. Succeeding shuffle map stages with available shuffle output are skipped,
even though their data may be inconsistent;
4. There could be succeeding stages still running that consumed the old
shuffle data.
To cover these scenarios, this PR proposes to (see the sketch after this list):
1. Track all the jobs submitted by the same query execution, so that all the
succeeding stages within the same query can be found;
2. Find the succeeding stages in completed jobs by comparing the shuffle id
instead of the stage id;
3. Abort the current indeterminate shuffle map stage if there are completed
succeeding result stages;
4. Abort the running succeeding result stages;
5. Roll back all succeeding shuffle map stages:
   a. Cancel running stages, clean up their shuffle output and resubmit;
   b. Clean up the shuffle data of available completed shuffle map stages.
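
Conceptually, the bookkeeping for item 1 is just two maps keyed by the SQL execution id: jobs are registered under the execution id that submitted them and dropped again when the query finishes. A minimal sketch with simplified names and types (not the actual `DAGScheduler` fields added in the diff below):

```scala
import java.util.concurrent.ConcurrentHashMap

// Minimal sketch of the per-query job bookkeeping; names and types are simplified
// and are not the real DAGScheduler fields.
final class QueryJobTracker {
  // SQL execution id -> ids of the jobs submitted by that query execution.
  private val queryToJobs = new ConcurrentHashMap[Long, java.util.Set[Int]]()
  // Job id -> SQL execution id, used to find the sibling jobs of a stage's jobs.
  private val jobToQuery = new ConcurrentHashMap[Int, java.lang.Long]()

  // Called when a job carrying the SQL execution id in its properties is submitted.
  def registerJob(executionId: Long, jobId: Int): Unit = {
    queryToJobs
      .computeIfAbsent(executionId, _ => ConcurrentHashMap.newKeySet[Int]())
      .add(jobId)
    jobToQuery.put(jobId, executionId)
  }

  // Called once the query execution ends, so entries for finished queries do not leak.
  def cleanupQuery(executionId: Long): Unit = {
    val jobs = queryToJobs.remove(executionId)
    if (jobs != null) jobs.forEach(jobId => jobToQuery.remove(jobId))
  }

  def jobsForQuery(executionId: Long): java.util.Set[Int] =
    queryToJobs.getOrDefault(executionId, java.util.Collections.emptySet[Int]())
}
```

In the actual change, registration happens when a job is submitted with `spark.sql.execution.id` set in its properties, and cleanup is triggered by a `CleanupQueryJobs` event posted from `SQLExecution` when the execution ends.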
### Why are the changes needed?
Ensure correctness for queries with indeterminate shuffle map stage retries.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
UTs
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #53868 from ivoson/SPARK-55064.
Authored-by: Tengfei Huang <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../java/org/apache/spark/internal/LogKeys.java | 1 +
.../main/scala/org/apache/spark/Dependency.scala | 3 +-
.../main/scala/org/apache/spark/SparkContext.scala | 1 +
.../org/apache/spark/scheduler/DAGScheduler.scala | 181 +++++++++++-
.../apache/spark/scheduler/DAGSchedulerEvent.scala | 2 +
.../apache/spark/scheduler/DAGSchedulerSuite.scala | 302 +++++++++++++++++++++
.../org/apache/spark/sql/internal/SQLConf.scala | 14 +
.../apache/spark/sql/execution/SQLExecution.scala | 10 +
.../execution/exchange/ShuffleExchangeExec.scala | 4 +-
.../spark/sql/execution/ui/SQLListener.scala | 3 +
.../spark/sql/execution/QueryExecutionSuite.scala | 63 +++++
11 files changed, 579 insertions(+), 5 deletions(-)
diff --git a/common/utils-java/src/main/java/org/apache/spark/internal/LogKeys.java b/common/utils-java/src/main/java/org/apache/spark/internal/LogKeys.java
index ad51da5bbc6e..59df0423fad2 100644
--- a/common/utils-java/src/main/java/org/apache/spark/internal/LogKeys.java
+++ b/common/utils-java/src/main/java/org/apache/spark/internal/LogKeys.java
@@ -316,6 +316,7 @@ public enum LogKeys implements LogKey {
JAVA_VERSION,
JAVA_VM_NAME,
JOB_ID,
+ JOB_IDS,
JOIN_CONDITION,
JOIN_CONDITION_SUB_EXPR,
JOIN_TYPE,
diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala b/core/src/main/scala/org/apache/spark/Dependency.scala
index c94ce35cb250..da97aff5b344 100644
--- a/core/src/main/scala/org/apache/spark/Dependency.scala
+++ b/core/src/main/scala/org/apache/spark/Dependency.scala
@@ -90,7 +90,8 @@ class ShuffleDependency[K: ClassTag, V: ClassTag, C:
ClassTag](
val mapSideCombine: Boolean = false,
val shuffleWriterProcessor: ShuffleWriteProcessor = new
ShuffleWriteProcessor,
val rowBasedChecksums: Array[RowBasedChecksum] =
ShuffleDependency.EMPTY_ROW_BASED_CHECKSUMS,
- private val _checksumMismatchFullRetryEnabled: Boolean = false)
+ private val _checksumMismatchFullRetryEnabled: Boolean = false,
+ val checksumMismatchQueryLevelRollbackEnabled: Boolean = false)
extends Dependency[Product2[K, V]] with Logging {
def this(
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 6f8be49e3959..8c92f4c10aa5 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -3152,6 +3152,7 @@ object SparkContext extends Logging {
private[spark] val SPARK_SCHEDULER_POOL = "spark.scheduler.pool"
private[spark] val RDD_SCOPE_KEY = "spark.rdd.scope"
private[spark] val RDD_SCOPE_NO_OVERRIDE_KEY = "spark.rdd.scope.noOverride"
+ private[spark] val SQL_EXECUTION_ID_KEY = "spark.sql.execution.id"
/**
* Executor id for the driver. In earlier versions of Spark, this was
`<driver>`, but this was
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 67785a0ce9ea..0c8d437c98bd 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -171,6 +171,14 @@ private[spark] class DAGScheduler(
private[scheduler] val activeJobs = new HashSet[ActiveJob]
+ // Track all the jobs submitted by the same query execution, will clean up
after
+ // the query finishes. Use ConcurrentHashMap to allow thread-safe access from
+ // different threads (e.g., SQLExecution during testing).
+ private[spark] val activeQueryToJobs =
+ new ConcurrentHashMap[Long, java.util.Set[ActiveJob]]()
+
+ private[spark] val jobIdToQueryExecutionId = new ConcurrentHashMap[Int,
java.lang.Long]()
+
// Job groups that are cancelled with `cancelFutureJobs` as true, with at
most
// `NUM_CANCELLED_JOB_GROUPS_TO_TRACK` stored. On a new job submission, if
its job group is in
// this set, the job will be immediately cancelled.
@@ -1145,6 +1153,14 @@ private[spark] class DAGScheduler(
eventProcessLoop.post(AllJobsCancelled)
}
+ /**
+ * Cleanup the jobs of query execution tracked in the DAGScheduler.
+ */
+ def cleanupQueryJobs(executionId: Long): Unit = {
+ logInfo(s"Asked to cleanup jobs for query execution $executionId")
+ eventProcessLoop.post(CleanupQueryJobs(executionId))
+ }
+
private[scheduler] def doCancelAllJobs(): Unit = {
// Cancel all running jobs.
runningStages.map(_.firstJobId).foreach(handleJobCancellation(_,
@@ -1153,6 +1169,12 @@ private[spark] class DAGScheduler(
jobIdToActiveJob.clear() // but just in case we lost track of some jobs...
}
+ private[spark] def doCleanupQueryJobs(executionId: Long): Unit = {
+ Option(activeQueryToJobs.remove(executionId)).foreach { jobs =>
+ jobs.forEach(job => jobIdToQueryExecutionId.remove(job.jobId))
+ }
+ }
+
/**
* Cancel all jobs associated with a running or scheduled stage.
*/
@@ -1325,6 +1347,18 @@ private[spark] class DAGScheduler(
listenerBus.post(SparkListenerTaskGettingResult(taskInfo))
}
+ private def getQueryExecutionIdFromProperties(properties: Properties):
Option[Long] = {
+ try {
+ Option(properties)
+ .flatMap(properties =>
Option(properties.getProperty(SparkContext.SQL_EXECUTION_ID_KEY)))
+ .map(_.toLong)
+ } catch {
+ case e: Throwable =>
+ logWarning(log"Failed to get query execution ID from job properties",
e)
+ None
+ }
+ }
+
private[scheduler] def handleJobSubmitted(
jobId: Int,
finalRDD: RDD[_],
@@ -1398,6 +1432,10 @@ private[spark] class DAGScheduler(
val jobSubmissionTime = clock.getTimeMillis()
jobIdToActiveJob(jobId) = job
activeJobs += job
+ getQueryExecutionIdFromProperties(properties).foreach { qeId =>
+ activeQueryToJobs.computeIfAbsent(qeId, _ =>
ConcurrentHashMap.newKeySet()).add(job)
+ jobIdToQueryExecutionId.put(jobId, qeId)
+ }
finalStage.setActiveJob(job)
val stageIds = jobIdToStageIds(jobId).toArray
val stageInfos =
@@ -1441,6 +1479,10 @@ private[spark] class DAGScheduler(
val jobSubmissionTime = clock.getTimeMillis()
jobIdToActiveJob(jobId) = job
activeJobs += job
+ getQueryExecutionIdFromProperties(properties).foreach { qeId =>
+ activeQueryToJobs.computeIfAbsent(qeId, _ =>
ConcurrentHashMap.newKeySet()).add(job)
+ jobIdToQueryExecutionId.put(jobId, qeId)
+ }
finalStage.addActiveJob(job)
val stageIds = jobIdToStageIds(jobId).toArray
val stageInfos =
@@ -1898,22 +1940,28 @@ private[spark] class DAGScheduler(
* stages will be resubmitted and re-try all the tasks, as the input data
may be changed after
* the map tasks are re-tried. For stages where rollback and retry all tasks
are not possible,
* we will need to abort the stages.
+ *
+ * @return true if the corresponding active jobs are not aborted and will
continue to run after rollback, otherwise false
*/
private[scheduler] def rollbackSucceedingStages(
mapStage: ShuffleMapStage,
- rollbackCurrentStage: Boolean = false): Unit = {
+ rollbackCurrentStage: Boolean = false): Boolean = {
val stagesToRollback = if (rollbackCurrentStage) {
collectSucceedingStages(mapStage)
} else {
collectSucceedingStages(mapStage).filterNot(_ == mapStage)
}
+ logInfo(log"Found succeeding stages ${MDC(STAGES, stagesToRollback)} of " +
+ log"shuffle checksum mismatch stage ${MDC(STAGE, mapStage)} in active
jobs")
val stagesCanRollback =
filterAndAbortUnrollbackableStages(stagesToRollback)
+
// stages which cannot be rolled back were aborted which leads to removing
the
// the dependant job(s) from the active jobs set, there could be no active
jobs
// left depending on the indeterminate stage and hence no need to roll
back any stages.
val numActiveJobsWithStageAfterRollback =
activeJobs.count(job => stagesToRollback.contains(job.finalStage))
- if (numActiveJobsWithStageAfterRollback == 0) {
+ val hasActiveJobs = numActiveJobsWithStageAfterRollback > 0
+ if (!hasActiveJobs) {
logInfo(log"All jobs depending on the indeterminate stage " +
log"(${MDC(STAGE_ID, mapStage.id)}) were aborted.")
} else {
@@ -1940,6 +1988,94 @@ private[spark] class DAGScheduler(
log"was retried, we will roll back and rerun its succeeding " +
log"stages: ${MDC(STAGES, stagesCanRollback)}")
}
+ // Whether there are still active jobs which need to be processed
+ hasActiveJobs
+ }
+
+ private def getCompletedJobsFromSameQuery(mapStage: ShuffleMapStage):
Array[ActiveJob] = {
+ import scala.jdk.CollectionConverters._
+ val executionIds = mapStage
+ .jobIds
+ .flatMap(jobId => Option(jobIdToQueryExecutionId.get(jobId)))
+ if (executionIds.size > 1) {
+ logWarning(log"There are multiple queries reuse the same stage:
${MDC(STAGE, mapStage)}")
+ }
+ executionIds
+ .flatMap(qeId =>
Option(activeQueryToJobs.get(qeId)).map(_.asScala).getOrElse(Set.empty))
+ .diff(activeJobs)
+ .toArray
+ .sortBy(_.jobId)
+ }
+
+ private def abortSucceedingFinalStages(mapStage: ShuffleMapStage, reason:
String): Unit = {
+ val jobFinalStages = activeJobs.map(_.finalStage).toSet
+
+ collectSucceedingStages(mapStage)
+ .intersect(jobFinalStages)
+ .foreach { stage =>
+ // Abort stage will fail the jobs depending on it, cleaning up the
stages for these jobs:
+ // 1. cancel running stages
+ // 2. clean up the stages from DAGScheduler if no active jobs
depending on them, waiting
+ // stages will be removed and won't be submitted.
+ // As we abort all the succeeding active jobs, all the succeeding
stages should be
+ // cleaned up.
+ abortStage(stage, reason, exception = None)
+ }
+ }
+
+ // In rollbackSucceedingStages, we assume that the jobs are independent even
if they reuse
+ // the same shuffle. If an active job hits fetch failure and keeps retrying
upstream stages
+ // cascadingly, and a shuffle checksum mismatch is detected, these retried
stages will be
+ // fully rolled back if possible, or the job will be aborted. However, we
won't do anything for
+ // the other active jobs.
+ //
+ // For SQL query execution, all the jobs from the same query are related as
they all contribute
+ // to the query result in some ways. For example, the jobs for subquery
expressions are not part
+ // of the main query RDD, but they are still related. We need to guarantee
the consistency of all
+ // the jobs, which means if a shuffle stage hits checksum mismatch, all its
downstream stages in
+ // all the jobs of the same query execution should be rolled back if
possible.
+ //
+ // This method does:
+ // 1. Find all the jobs triggered by the same query execution including the
completed ones.
+ // 2. Find all the succeeding stages of the checksum mismatch shuffle stage
based on shuffle id
+ // as stage can be different in different jobs even they share the same
shuffle.
+ // 3. Abort succeeding stages if their leaf result stages have started
running.
+ // 4. Clean up shuffle outputs of the checksum mismatched shuffle stages
which are available to
+ // make sure all these stages would be resubmitted and fully retried.
+ // 5. Cancel running shuffle map stages and resubmit.
+ private[scheduler] def rollbackSucceedingStagesForQuery(mapStage:
ShuffleMapStage): Unit = {
+ // Find the completed jobs triggered by the same query execution.
+ val completedJobs = getCompletedJobsFromSameQuery(mapStage)
+ val succeedingStagesInCompletedJobs =
+ collectSucceedingStagesByShuffleId(mapStage, completedJobs)
+ logInfo(log"Found succeeding stages ${MDC(STAGES,
succeedingStagesInCompletedJobs)} of " +
+ log"shuffle checksum mismatch stage ${MDC(STAGE, mapStage)} in completed
jobs: (" +
+ log"${MDC(JOB_IDS, completedJobs.map(_.jobId).mkString(","))})")
+
+ // Abort all the succeeding final stages in active jobs to fail fast and
avoid wasting
+ // resources if there are succeeding result stages in completed jobs.
+ val completedResultStages =
+ succeedingStagesInCompletedJobs.collect { case r: ResultStage => r }
+ if (completedResultStages.nonEmpty) {
+ val reason = s"cannot rollback completed result stages
${completedResultStages}, " +
+ s"please re-run the query to ensure data correctness"
+ abortSucceedingFinalStages(mapStage, reason)
+ return
+ }
+
+ // Rollback the succeeding stages in active jobs. If there are no active
jobs left after
+ // rollback, we can skip the rollback for completed jobs.
+ val hasActiveJobs = rollbackSucceedingStages(mapStage)
+ if (hasActiveJobs) {
+ // Rollback the shuffle map stages in completed jobs to make sure the
completed shuffle
+ // map stages would be re-submitted and fully retried.
+ succeedingStagesInCompletedJobs.collect { case s: ShuffleMapStage => s }
+ .foreach(stage => rollbackShuffleMapStage(stage, "rolling back due to
indeterminate " +
+ s"output of shuffle map stage $mapStage"))
+ } else {
+ logInfo(log"All jobs depending on the checksum mismatch stage " +
+ log"(${MDC(STAGE_ID, mapStage.id)}) were aborted, skip the rollback.")
+ }
}
/**
@@ -2189,7 +2325,11 @@ private[spark] class DAGScheduler(
&& shuffleStage.isRuntimeIndeterminate) {
if (shuffleStage.maxChecksumMismatchedId <
smt.stageAttemptId) {
shuffleStage.maxChecksumMismatchedId = smt.stageAttemptId
- rollbackSucceedingStages(shuffleStage)
+ if
(shuffleStage.shuffleDep.checksumMismatchQueryLevelRollbackEnabled) {
+ rollbackSucceedingStagesForQuery(shuffleStage)
+ } else {
+ rollbackSucceedingStages(shuffleStage)
+ }
}
}
}
@@ -2484,6 +2624,32 @@ private[spark] class DAGScheduler(
}
}
+ private def collectSucceedingStagesByShuffleId(
+ mapStage: ShuffleMapStage, jobs: Array[ActiveJob]): HashSet[Stage] = {
+ val succeedingStages = HashSet[Stage]()
+ val shuffleId = mapStage.shuffleDep.shuffleId
+
+ def getShuffleId(stage: Stage): Option[Int] = stage match {
+ case s: ShuffleMapStage => Some(s.shuffleDep.shuffleId)
+ case _ => None
+ }
+
+ def collectSucceedingStagesInternal(stageChain: List[Stage]): Unit = {
+ val head = stageChain.head
+ if (getShuffleId(head).contains(shuffleId)) {
+ stageChain.drop(1).foreach { s =>
+ succeedingStages += s
+ }
+ } else {
+ head.parents.foreach { s =>
+ collectSucceedingStagesInternal(s :: stageChain)
+ }
+ }
+ }
+ jobs.foreach(job => collectSucceedingStagesInternal(job.finalStage :: Nil))
+ succeedingStages
+ }
+
private def collectSucceedingStages(mapStage: ShuffleMapStage):
HashSet[Stage] = {
// TODO: perhaps materialize this if we are going to compute it often
enough ?
// It's a little tricky to find all the succeeding stages of `mapStage`,
because
@@ -2543,6 +2709,8 @@ private[spark] class DAGScheduler(
} else {
rollingBackStages += mapStage
}
+ } else if (runningStages.contains(mapStage)) {
+ rollingBackStages += mapStage
}
case resultStage: ResultStage if resultStage.activeJob.isDefined =>
@@ -2550,6 +2718,10 @@ private[spark] class DAGScheduler(
if (numMissingPartitions < resultStage.numTasks) {
// TODO: support to rollback result tasks.
abortStage(resultStage, generateErrorMessage(resultStage), None)
+ } else if (runningStages.contains(resultStage)) {
+ val reason = "cannot rollback a running result stage, please re-run
the query " +
+ "to ensure data correctness"
+ abortStage(resultStage, reason, None)
}
case _ =>
@@ -3362,6 +3534,9 @@ private[scheduler] class
DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler
case AllJobsCancelled =>
dagScheduler.doCancelAllJobs()
+ case CleanupQueryJobs(executionId) =>
+ dagScheduler.doCleanupQueryJobs(executionId)
+
case ExecutorAdded(execId, host) =>
dagScheduler.handleExecutorAdded(execId, host)
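
To make the shuffle-id-based stage search (item 2 of the description) concrete, here is a self-contained toy model that mirrors the traversal in `collectSucceedingStagesByShuffleId` from the hunk above; the `Stage` type here is a simplification, not the Spark scheduler's class:

```scala
// Toy model of collecting succeeding stages by shuffle id; Stage is a simplified
// stand-in for the scheduler's stage types.
object SucceedingStagesSketch {
  final case class Stage(id: Int, shuffleId: Option[Int], parents: List[Stage])

  def collectSucceedingStages(finalStages: Seq[Stage], mismatchedShuffleId: Int): Set[Stage] = {
    val succeeding = scala.collection.mutable.HashSet[Stage]()
    // stageChain.head is the stage currently visited; stageChain.tail holds the stages
    // we walked through to reach it, i.e. its downstream consumers on this path.
    def visit(stageChain: List[Stage]): Unit = {
      val head = stageChain.head
      if (head.shuffleId.contains(mismatchedShuffleId)) {
        succeeding ++= stageChain.tail
      } else {
        head.parents.foreach(parent => visit(parent :: stageChain))
      }
    }
    finalStages.foreach(finalStage => visit(finalStage :: Nil))
    succeeding.toSet
  }

  def main(args: Array[String]): Unit = {
    val producer = Stage(0, shuffleId = Some(7), parents = Nil)
    val middle = Stage(1, shuffleId = Some(8), parents = List(producer))
    val result = Stage(2, shuffleId = None, parents = List(middle))
    // Stages 1 and 2 both sit downstream of shuffle 7; in the real scheduler the
    // shuffle map stage would be rolled back and the result stage aborted if running.
    println(collectSucceedingStages(Seq(result), mismatchedShuffleId = 7).map(_.id))
  }
}
```

Matching on the shuffle id rather than the stage id is what lets the search also cover completed sub-query jobs, where the same shuffle can appear under different stage ids.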
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
index 8932d2ef323b..cc788e5c65bc 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
@@ -78,6 +78,8 @@ private[scheduler] case class JobTagCancelled(
private[scheduler] case object AllJobsCancelled extends DAGSchedulerEvent
+private[scheduler] case class CleanupQueryJobs(executionId: Long) extends
DAGSchedulerEvent
+
private[scheduler]
case class BeginEvent(task: Task[_], taskInfo: TaskInfo) extends
DAGSchedulerEvent
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index 6f0fd9608334..5cbb0654eb37 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -3493,6 +3493,307 @@ class DAGSchedulerSuite extends SparkFunSuite with
TempLocalSparkContext with Ti
assertDataStructuresEmpty()
}
+ test("SPARK-55064: abort stage if checksum mismatch detected with succeeding
" +
+ "result stage in completed jobs") {
+ val executionId = 55064L
+ val properties = new Properties()
+ properties.setProperty(SparkContext.SQL_EXECUTION_ID_KEY,
executionId.toString)
+ try {
+ val shuffleMapRdd1 = new MyRDD(sc, 2, Nil)
+ val shuffleDep1 = new ShuffleDependency(
+ shuffleMapRdd1,
+ new HashPartitioner(2),
+ _checksumMismatchFullRetryEnabled = true,
+ checksumMismatchQueryLevelRollbackEnabled = true)
+ val shuffleId1 = shuffleDep1.shuffleId
+
+ // Submit and complete the 1st job depending on shuffleDep1
+ val finalRdd1 = new MyRDD(sc, 2, List(shuffleDep1), tracker =
mapOutputTracker)
+ submit(finalRdd1, Array(0, 1), properties = properties)
+ // Finish the first shuffle map stage.
+ completeShuffleMapStageSuccessfully(
+ 0, 0, 2, Seq("hostA", "hostB"), checksumVal = 100)
+ assert(mapOutputTracker.findMissingPartitions(shuffleId1) ===
Some(Seq.empty))
+ // Finish the result stage.
+ completeNextResultStageWithSuccess(1, 0)
+ assertDataStructuresEmpty()
+
+ // Submit the 2nd job depending on shuffleDep1, and fail it by checksum
mismatch
+ val finalRdd2 = new MyRDD(sc, 2, List(shuffleDep1), tracker =
mapOutputTracker)
+ submit(finalRdd2, Array(0, 1), properties = properties)
+ // The first task failed with FetchFailed.
+ runEvent(makeCompletionEvent(
+ taskSets(2).tasks(0),
+ FetchFailed(makeBlockManagerId("hostA"), shuffleId1, 0L, 0, 0,
"ignored"),
+ null))
+
+ // Check status for all failedStages.
+ val failedStages = scheduler.failedStages.toSeq
+ assert(failedStages.map(_.id) == Seq(2, 3))
+ scheduler.resubmitFailedStages()
+
+ // Complete the re-attempt of shuffle map stage 2(shuffleId1) with a
different checksum.
+ completeShuffleMapStageSuccessfully(2, 1, 2, checksumVal = 101)
+ assert(failure != null && failure.getMessage.contains(
+ "cannot rollback completed result stages"))
+ } finally {
+ scheduler.cleanupQueryJobs(executionId)
+ }
+ }
+
+ test("SPARK-55064: clean up shuffle data for the succeeding available
shuffle map stages " +
+ "in completed jobs if checksum mismatch detected") {
+ val executionId = 55064L
+ val properties = new Properties()
+ properties.setProperty(SparkContext.SQL_EXECUTION_ID_KEY,
executionId.toString)
+ try {
+ val shuffleMapRdd1 = new MyRDD(sc, 2, Nil)
+ val shuffleDep1 = new ShuffleDependency(
+ shuffleMapRdd1,
+ new HashPartitioner(2),
+ _checksumMismatchFullRetryEnabled = true,
+ checksumMismatchQueryLevelRollbackEnabled = true)
+ val shuffleId1 = shuffleDep1.shuffleId
+
+ // Submit and complete the 1st shuffle map stage job depending on
shuffleDep2
+ val shuffleMapRdd2 = new MyRDD(
+ sc, 2, List(shuffleDep1), tracker = mapOutputTracker)
+ val shuffleDep2 = new ShuffleDependency(
+ shuffleMapRdd2,
+ new HashPartitioner(2),
+ _checksumMismatchFullRetryEnabled = true,
+ checksumMismatchQueryLevelRollbackEnabled = true)
+ val shuffleId2 = shuffleDep2.shuffleId
+ val mapStageJobId = submitMapStage(shuffleDep2)
+ scheduler.activeQueryToJobs
+ .computeIfAbsent(executionId, _ =>
java.util.concurrent.ConcurrentHashMap.newKeySet())
+ .add(scheduler.jobIdToActiveJob(mapStageJobId))
+ scheduler.jobIdToQueryExecutionId.put(mapStageJobId, executionId)
+
+ // Finish the stages.
+ completeShuffleMapStageSuccessfully(
+ 0, 0, 2, Seq("hostA", "hostB"), checksumVal = 100)
+ completeShuffleMapStageSuccessfully(
+ 1, 0, 2, Seq("hostB", "hostC"), checksumVal = 200)
+
+ assert(mapOutputTracker.findMissingPartitions(shuffleId1) ===
Some(Seq.empty))
+ assert(mapOutputTracker.findMissingPartitions(shuffleId2) ===
Some(Seq.empty))
+
+ // Submit the 2nd job depending on shuffleDep1, and fail it by checksum
mismatch
+ val finalRdd1 = new MyRDD(sc, 2, List(shuffleDep1), tracker =
mapOutputTracker)
+ submit(finalRdd1, Array(0, 1), properties = properties)
+ // The first task failed with FetchFailed.
+ val resultTaskSet = taskSets.last
+ runEvent(makeCompletionEvent(
+ resultTaskSet.tasks(0),
+ FetchFailed(makeBlockManagerId("hostA"), shuffleId1, 0L, 0, 0,
"ignored"),
+ null))
+
+ // Check status for all failedStages.
+ val failedStages = scheduler.failedStages.toSeq
+ assert(failedStages.exists {
+ case stage: ShuffleMapStage => stage.shuffleDep.shuffleId == shuffleId1
+ case _ => false
+ })
+ assert(failedStages.exists(_.isInstanceOf[ResultStage]))
+ scheduler.resubmitFailedStages()
+
+ assert(mapOutputTracker.getNumAvailableOutputs(shuffleId2) === 2)
+ // Complete the re-attempt of shuffle map stage (shuffleId1) with a
different checksum.
+ val retryShuffleTaskSet = taskSets.last
+ assert(retryShuffleTaskSet.shuffleId.contains(shuffleId1))
+ completeShuffleMapStageSuccessfully(
+ retryShuffleTaskSet.stageId,
+ retryShuffleTaskSet.stageAttemptId,
+ 2,
+ checksumVal = 101)
+ // Output of shuffleId2 should be cleaned up and job on finalRdd1 will
be resubmitted.
+ assert(mapOutputTracker.getNumAvailableOutputs(shuffleId2) === 0)
+ val expectedResults = Seq((Success, 11), (Success, 12))
+ val expectedResultsMap: Map[Int, Any] = Map(0 -> 11, 1 -> 12)
+ completeAndCheckAnswer(taskSets.last, expectedResults,
expectedResultsMap)
+
+ // Retry all tasks of shuffleId2 for the new jobs
+ val finalRdd2 = new MyRDD(sc, 2, List(shuffleDep2), tracker =
mapOutputTracker)
+ submit(finalRdd2, Array(0, 1), properties = properties)
+ val shuffleRetryTaskSet = taskSets.last
+ assert(shuffleRetryTaskSet.shuffleId.contains(shuffleId2))
+ completeShuffleMapStageSuccessfully(
+ shuffleRetryTaskSet.stageId,
+ shuffleRetryTaskSet.stageAttemptId,
+ 2,
+ checksumVal = 200)
+ completeAndCheckAnswer(taskSets.last, expectedResults,
expectedResultsMap)
+ } finally {
+ scheduler.cleanupQueryJobs(executionId)
+ }
+ }
+
+ test("SPARK-55064: clean up shuffle data for the succeeding available
shuffle map stages " +
+ "in active jobs if checksum mismatch detected") {
+ val shuffleMapRdd1 = new MyRDD(sc, 2, Nil)
+ val shuffleDep1 = new ShuffleDependency(
+ shuffleMapRdd1,
+ new HashPartitioner(2),
+ _checksumMismatchFullRetryEnabled = true,
+ checksumMismatchQueryLevelRollbackEnabled = true)
+ val shuffleId1 = shuffleDep1.shuffleId
+
+ val shuffleMapRdd2 = new MyRDD(
+ sc, 2, List(shuffleDep1), tracker = mapOutputTracker)
+ val shuffleMapRdd3 = new MyRDD(
+ sc, 2, List(shuffleDep1), tracker = mapOutputTracker)
+
+ val shuffleDep2 = new ShuffleDependency(
+ shuffleMapRdd2,
+ new HashPartitioner(2),
+ _checksumMismatchFullRetryEnabled = true,
+ checksumMismatchQueryLevelRollbackEnabled = true)
+ val shuffleDep3 = new ShuffleDependency(
+ shuffleMapRdd3,
+ new HashPartitioner(2),
+ _checksumMismatchFullRetryEnabled = true,
+ checksumMismatchQueryLevelRollbackEnabled = true)
+
+ val finalRdd = new MyRDD(sc, 2, List(shuffleDep2, shuffleDep3), tracker =
mapOutputTracker)
+ submit(finalRdd, Array(0, 1))
+
+ // Finish shuffle map stage 0 and 1.
+ completeShuffleMapStageSuccessfully(
+ 0, 0, 2, Seq("hostA", "hostB"), checksumVal = 100)
+ completeShuffleMapStageSuccessfully(
+ 1, 0, 2, Seq("hostB", "hostC"), checksumVal = 200)
+ assert(mapOutputTracker.findMissingPartitions(shuffleId1) ===
Some(Seq.empty))
+ val shuffleId2 = taskSets(1).shuffleId.get
+ assert(mapOutputTracker.findMissingPartitions(shuffleId2) ===
Some(Seq.empty))
+
+ val shuffleId3 = taskSets(2).shuffleId.get
+ // The first task of shuffle map stage 2 failed with FetchFailed.
+ runEvent(makeCompletionEvent(
+ taskSets(2).tasks(0),
+ FetchFailed(makeBlockManagerId("hostA"), shuffleId1, 0L, 0, 0,
"ignored"),
+ null))
+
+ // Check status for all failedStages.
+ val failedStages = scheduler.failedStages.toSeq
+ assert(failedStages.map(_.id) == Seq(0, 2))
+ scheduler.resubmitFailedStages()
+
+ assert(mapOutputTracker.getNumAvailableOutputs(shuffleId2) === 2)
+ // Complete the re-attempt of shuffle map stage 0(shuffleId1) with a
different checksum.
+ checkAndCompleteRetryStage(3, 0, shuffleId1, 1, checksumVal = 101)
+ // Output of shuffleId2 should be cleaned up and will be resubmitted.
+ assert(mapOutputTracker.getNumAvailableOutputs(shuffleId2) === 0)
+ checkAndCompleteRetryStage(4, 2, shuffleId3, 2, checksumVal = 300)
+ checkAndCompleteRetryStage(5, 1, shuffleId2, 2, checksumVal = 200)
+ completeAndCheckAnswer(taskSets(6), Seq((Success, 11), (Success, 12)),
Map(0 -> 11, 1 -> 12))
+ }
+
+ test("SPARK-55064: cancel and resubmit running succeeding shuffle map stages
" +
+ "if checksum mismatch detected") {
+ val shuffleMapRdd1 = new MyRDD(sc, 2, Nil)
+ val shuffleDep1 = new ShuffleDependency(
+ shuffleMapRdd1,
+ new HashPartitioner(2),
+ _checksumMismatchFullRetryEnabled = true,
+ checksumMismatchQueryLevelRollbackEnabled = true)
+ val shuffleId1 = shuffleDep1.shuffleId
+
+ val shuffleMapRdd2 = new MyRDD(
+ sc, 2, List(shuffleDep1), tracker = mapOutputTracker)
+ val shuffleMapRdd3 = new MyRDD(
+ sc, 2, List(shuffleDep1), tracker = mapOutputTracker)
+
+ val shuffleDep2 = new ShuffleDependency(
+ shuffleMapRdd2,
+ new HashPartitioner(2),
+ _checksumMismatchFullRetryEnabled = true,
+ checksumMismatchQueryLevelRollbackEnabled = true)
+ val shuffleDep3 = new ShuffleDependency(
+ shuffleMapRdd3,
+ new HashPartitioner(2),
+ _checksumMismatchFullRetryEnabled = true,
+ checksumMismatchQueryLevelRollbackEnabled = true)
+
+ val finalRdd = new MyRDD(sc, 2, List(shuffleDep2, shuffleDep3), tracker =
mapOutputTracker)
+ submit(finalRdd, Array(0, 1))
+
+ // Finish shuffle map stage 0
+ completeShuffleMapStageSuccessfully(
+ 0, 0, 2, Seq("hostA", "hostB"), checksumVal = 100)
+ assert(mapOutputTracker.findMissingPartitions(shuffleId1) ===
Some(Seq.empty))
+
+ val shuffleId2 = taskSets(1).shuffleId.get
+ val shuffleId3 = taskSets(2).shuffleId.get
+
+ // The first task of shuffle map stage 2 failed with FetchFailed.
+ runEvent(makeCompletionEvent(
+ taskSets(2).tasks(0),
+ FetchFailed(makeBlockManagerId("hostA"), shuffleId1, 0L, 0, 0,
"ignored"),
+ null))
+
+ // Check status for all failedStages.
+ val failedStages = scheduler.failedStages.toSeq
+ assert(failedStages.map(_.id) == Seq(0, 2))
+ scheduler.resubmitFailedStages()
+
+ // Complete the re-attempt of shuffle map stage 0(shuffleId1) with a
different checksum.
+ checkAndCompleteRetryStage(3, 0, shuffleId1, 1, checksumVal = 101)
+ // Shuffle map stage 1 should be cancelled and resubmitted.
+ assert(scheduler.failedStages.map(_.id).toSeq == Seq(1))
+ scheduler.resubmitFailedStages()
+
+ checkAndCompleteRetryStage(4, 2, shuffleId3, 2, checksumVal = 200)
+ checkAndCompleteRetryStage(5, 1, shuffleId2, 2, checksumVal = 300)
+ completeAndCheckAnswer(taskSets(6), Seq((Success, 11), (Success, 12)),
Map(0 -> 11, 1 -> 12))
+ }
+
+ test("SPARK-55064: abort stage if checksum mismatch detected with succeeding
" +
+ "running result stage in active jobs") {
+ val executionId = 55064L
+ val properties = new Properties()
+ properties.setProperty(SparkContext.SQL_EXECUTION_ID_KEY,
executionId.toString)
+ try {
+ val shuffleMapRdd1 = new MyRDD(sc, 2, Nil)
+ val shuffleDep1 = new ShuffleDependency(
+ shuffleMapRdd1,
+ new HashPartitioner(2),
+ _checksumMismatchFullRetryEnabled = true,
+ checksumMismatchQueryLevelRollbackEnabled = true)
+ val shuffleId1 = shuffleDep1.shuffleId
+
+ // Submit 2 jobs depending on shuffleDep1
+ val finalRdd1 = new MyRDD(
+ sc, 2, List(shuffleDep1), tracker = mapOutputTracker)
+ submit(finalRdd1, Array(0, 1), properties = properties)
+ val finalRdd2 = new MyRDD(sc, 2, List(shuffleDep1), tracker =
mapOutputTracker)
+ submit(finalRdd2, Array(0, 1), properties = properties)
+
+ // Finish the stages.
+ completeShuffleMapStageSuccessfully(
+ 0, 0, 2, Seq("hostA", "hostB"), checksumVal = 100)
+ assert(mapOutputTracker.findMissingPartitions(shuffleId1) ===
Some(Seq.empty))
+
+ // The first task of result stage 2 failed with FetchFailed.
+ runEvent(makeCompletionEvent(
+ taskSets(2).tasks(0),
+ FetchFailed(makeBlockManagerId("hostA"), shuffleId1, 0L, 0, 0,
"ignored"),
+ null))
+
+ // Check status for all failedStages.
+ val failedStages = scheduler.failedStages.toSeq
+ assert(failedStages.map(_.id) == Seq(0, 2))
+ scheduler.resubmitFailedStages()
+
+ // Complete the re-attempt of shuffle map stage 0(shuffleId1) with a
different checksum.
+ completeShuffleMapStageSuccessfully(0, 1, 2, checksumVal = 101)
+ assert(failure != null && failure.getMessage.contains(
+ "cannot rollback a running result stage"))
+ } finally {
+ scheduler.cleanupQueryJobs(executionId)
+ }
+ }
+
private def checkAndCompleteRetryStage(
taskSetIndex: Int,
stageId: Int,
@@ -3502,6 +3803,7 @@ class DAGSchedulerSuite extends SparkFunSuite with
TempLocalSparkContext with Ti
stageAttemptId: Int = 1): Unit = {
assert(taskSets(taskSetIndex).stageId == stageId)
assert(taskSets(taskSetIndex).stageAttemptId == stageAttemptId)
+ assert(taskSets(taskSetIndex).shuffleId.contains(shuffleId))
assert(taskSets(taskSetIndex).tasks.length == numTasks)
completeShuffleMapStageSuccessfully(stageId, stageAttemptId, 2,
checksumVal = checksumVal)
assert(mapOutputTracker.findMissingPartitions(shuffleId) ===
Some(Seq.empty))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 585dca249d3d..e86466826019 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -917,6 +917,17 @@ object SQLConf {
.booleanConf
.createWithDefault(true)
+ private[spark] val SHUFFLE_CHECKSUM_MISMATCH_QUERY_LEVEL_ROLLBACK_ENABLED =
+
buildConf("spark.sql.shuffle.orderIndependentChecksum.enableQueryLevelRollbackOnMismatch")
+ .internal()
+ .doc("Whether to rollback all the consumer stages from the same query
executor " +
+ "when we detect checksum mismatches with its producer stages,
including cancel " +
+ "running shuffle map stages and resubmit, clean up all the shuffle
data written " +
+ "for available shuffle map stages and abort the running result
stages.")
+ .version("4.2.0")
+ .booleanConf
+ .createWithDefault(false)
+
val SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE =
buildConf("spark.sql.adaptive.shuffle.targetPostShuffleInputSize")
.internal()
@@ -7201,6 +7212,9 @@ class SQLConf extends Serializable with Logging with
SqlApiConf {
def shuffleChecksumMismatchFullRetryEnabled: Boolean =
getConf(SHUFFLE_CHECKSUM_MISMATCH_FULL_RETRY_ENABLED)
+ def shuffleChecksumMismatchQueryLevelRollbackEnabled: Boolean =
+ getConf(SHUFFLE_CHECKSUM_MISMATCH_QUERY_LEVEL_ROLLBACK_ENABLED)
+
def allowCollationsInMapKeys: Boolean = getConf(ALLOW_COLLATIONS_IN_MAP_KEYS)
def objectLevelCollationsEnabled: Boolean =
getConf(OBJECT_LEVEL_COLLATIONS_ENABLED)
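
For anyone who wants to experiment with the behavior, the new flag is internal and disabled by default. A hedged example of enabling it on an existing `SparkSession` named `spark` (only the config key is taken from the diff above; the rest is ordinary configuration code):

```scala
// Enable query-level rollback on shuffle checksum mismatch for this session.
// The key is the internal conf added above; it defaults to false.
spark.conf.set(
  "spark.sql.shuffle.orderIndependentChecksum.enableQueryLevelRollbackOnMismatch",
  "true")
```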
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
index 45fa6c60f465..cf26b2991652 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
@@ -238,6 +238,16 @@ object SQLExecution extends Logging {
event.duration = endTime - startTime
event.qe = queryExecution
event.executionFailure = ex
+ if (Utils.isTesting) {
+ import scala.jdk.CollectionConverters._
+ event.jobIds =
Option(sc.dagScheduler.activeQueryToJobs.get(executionId))
+ .map(_.asScala.map(_.jobId).toSet)
+ .getOrElse(Set.empty)
+ }
+
+ // Clean up jobs tracked by DAGScheduler for this query
execution.
+ sc.dagScheduler.cleanupQueryJobs(executionId)
+
sc.listenerBus.post(event)
// Observation.tryComplete is called here to ensure the
observation is completed,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
index a1f693ef5c15..95120039a6f9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
@@ -495,7 +495,9 @@ object ShuffleExchangeExec {
serializer,
shuffleWriterProcessor = createShuffleWriteProcessor(writeMetrics),
rowBasedChecksums =
UnsafeRowChecksum.createUnsafeRowChecksums(checksumSize),
- _checksumMismatchFullRetryEnabled =
SQLConf.get.shuffleChecksumMismatchFullRetryEnabled)
+ _checksumMismatchFullRetryEnabled =
SQLConf.get.shuffleChecksumMismatchFullRetryEnabled,
+ checksumMismatchQueryLevelRollbackEnabled =
+ SQLConf.get.shuffleChecksumMismatchQueryLevelRollbackEnabled)
dependency
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala
index 4e773390b2f7..6fbd6ad62a2d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala
@@ -90,6 +90,9 @@ case class SparkListenerSQLExecutionEnd(
// The exception object that caused this execution to fail. None if the
execution doesn't fail.
@JsonIgnore private[sql] var executionFailure: Option[Throwable] = None
+
+ // The jobs for this execution. Test only.
+ @JsonIgnore private[sql] var jobIds: Set[Int] = Set.empty
}
/**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala
index af3a0f3e3710..877e6970368b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala
@@ -17,9 +17,11 @@
package org.apache.spark.sql.execution
import scala.collection.mutable
+import scala.concurrent.duration.DurationInt
import scala.io.Source
import scala.util.Try
+import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent,
SparkListenerJobStart}
import org.apache.spark.sql.{AnalysisException, ExtendedExplainGenerator,
FastOperator, SaveMode}
import org.apache.spark.sql.catalyst.{QueryPlanningTracker,
QueryPlanningTrackerCallback, TableIdentifier}
import org.apache.spark.sql.catalyst.analysis.{CurrentNamespace,
UnresolvedFunction, UnresolvedRelation}
@@ -33,6 +35,7 @@ import
org.apache.spark.sql.connector.catalog.CatalogManager.SESSION_CATALOG_NAM
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec,
QueryStageExec}
import org.apache.spark.sql.execution.datasources.v2.ShowTablesExec
import org.apache.spark.sql.execution.joins.SortMergeJoinExec
+import org.apache.spark.sql.execution.ui.SparkListenerSQLExecutionEnd
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.storage.ShuffleIndexBlockId
@@ -519,6 +522,66 @@ class QueryExecutionSuite extends SharedSparkSession {
}
}
+ test("SPARK-55064: Jobs are tracked in DAGScheduler before query finished") {
+ val jobs = new mutable.HashSet[Int]()
+
+ var sqlExecutionEndVerified: Boolean = false
+ val listener = new SparkListener {
+ override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
+ jobs += jobStart.jobId
+ }
+
+ override def onOtherEvent(event: SparkListenerEvent): Unit = {
+ event match {
+ case e: SparkListenerSQLExecutionEnd =>
+ val jobsTracked = e.jobIds
+ assert(jobsTracked == jobs)
+ jobs.clear()
+ sqlExecutionEndVerified = true
+ case _ =>
+ }
+ }
+ }
+ spark.sparkContext.addSparkListener(listener)
+
+ try {
+ withTable("t1", "t2") {
+ spark.range(10).selectExpr("id as A", "id as D")
+ .write.partitionBy("A").mode("overwrite").saveAsTable("t1")
+ spark.range(2).selectExpr("id as B", "id as C")
+ .write.mode("overwrite").saveAsTable("t2")
+
+ Seq(true, false).foreach { adaptiveEnabled =>
+ withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key ->
adaptiveEnabled.toString) {
+ sqlExecutionEndVerified = false
+ val df = sql(
+ """
+ |SELECT avg(A) FROM t1
+ |WHERE A = (
+ | SELECT B from t2 where C = 1
+ |)
+ """.stripMargin)
+
+ df.collect()
+ spark.sparkContext.listenerBus.waitUntilEmpty()
+ assert(sqlExecutionEndVerified)
+ // The jobs tracked by DAGScheduler should be cleared after the
query is done.
+ eventually(timeout(10.seconds)) {
+ assert(spark.sparkContext.dagScheduler.activeQueryToJobs.isEmpty)
+
assert(spark.sparkContext.dagScheduler.jobIdToQueryExecutionId.isEmpty)
+ }
+ }
+ }
+ }
+ eventually(timeout(10.seconds)) {
+ assert(spark.sparkContext.dagScheduler.activeQueryToJobs.isEmpty)
+ assert(spark.sparkContext.dagScheduler.jobIdToQueryExecutionId.isEmpty)
+ }
+ } finally {
+ spark.sparkContext.removeSparkListener(listener)
+ }
+ }
+
case class MockCallbackEagerCommand(
var trackerAnalyzed: QueryPlanningTracker = null,
var trackerReadyForExecution: QueryPlanningTracker = null)