This is an automated email from the ASF dual-hosted git repository.

rexxiong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/celeborn.git


The following commit(s) were added to refs/heads/main by this push:
     new 4bacd1f21 [CELEBORN-1856] Support stage rerun when reading partitions by
chunkOffsets with the skew partition read optimization enabled
4bacd1f21 is described below

commit 4bacd1f211fa500460fb389ff96ae2c9e0591360
Author: wangshengjie3 <[email protected]>
AuthorDate: Mon Mar 24 22:03:15 2025 +0800

    [CELEBORN-1856] Support stage rerun when reading partitions by chunkOffsets
    with the skew partition read optimization enabled
    
    ### What changes were proposed in this pull request?
    Support stage rerun when partitions are read by chunkOffsets with the
    skew partition read optimization enabled.
    
    ### Why are the changes needed?
    In [CELEBORN-1319](https://issues.apache.org/jira/browse/CELEBORN-1319), we
    already implemented the skew partition read optimization based on chunk
    offsets, but shuffle retry is not supported for skewed partitions read this
    way, so we need to support stage rerun for this case.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Tested on a cluster.
    
    Closes #3118 from wangshengjie123/support-stage-rerun.
    
    Lead-authored-by: wangshengjie3 <[email protected]>
    Co-authored-by: Wang, Fei <[email protected]>
    Signed-off-by: Shuang <[email protected]>
---
 ...eleborn-Optimize-Skew-Partitions-spark3_2.patch | 79 ++++++++++++++++---
 ...eleborn-Optimize-Skew-Partitions-spark3_3.patch | 79 ++++++++++++++++---
 ...eleborn-Optimize-Skew-Partitions-spark3_4.patch | 77 +++++++++++++++++--
 ...eleborn-Optimize-Skew-Partitions-spark3_5.patch | 89 ++++++++++++++++++++--
 .../shuffle/celeborn/SparkShuffleManager.java      |  5 ++
 .../apache/spark/shuffle/celeborn/SparkUtils.java  | 14 ++++
 .../apache/celeborn/client/LifecycleManager.scala  | 18 ++++-
 7 files changed, 326 insertions(+), 35 deletions(-)

diff --git 
a/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_2.patch 
b/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_2.patch
index 0cb1fc812..0e3be7a8e 100644
--- a/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_2.patch
+++ b/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_2.patch
@@ -135,7 +135,7 @@ index 00000000000..5e190c512df
 +
 +}
 diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala 
b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
-index b950c07f3d8..2cb430c3c3d 100644
+index b950c07f3d8..9e339db4fb4 100644
 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
 +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
 @@ -33,6 +33,7 @@ import com.google.common.util.concurrent.{Futures, 
SettableFuture}
@@ -146,15 +146,76 @@ index b950c07f3d8..2cb430c3c3d 100644
  import org.apache.spark.executor.{ExecutorMetrics, TaskMetrics}
  import org.apache.spark.internal.Logging
  import org.apache.spark.internal.config
-@@ -1780,7 +1781,7 @@ private[spark] class DAGScheduler(
-           failedStage.failedAttemptIds.add(task.stageAttemptId)
-           val shouldAbortStage =
-             failedStage.failedAttemptIds.size >= maxConsecutiveStageAttempts 
||
--            disallowStageRetryForTest
-+            disallowStageRetryForTest || 
CelebornShuffleState.isCelebornSkewedShuffle(shuffleId)
+@@ -1369,7 +1370,10 @@ private[spark] class DAGScheduler(
+     // The operation here can make sure for the partially completed 
intermediate stage,
+     // `findMissingPartitions()` returns all partitions every time.
+     stage match {
+-      case sms: ShuffleMapStage if stage.isIndeterminate && !sms.isAvailable 
=>
++      case sms: ShuffleMapStage if (stage.isIndeterminate ||
++        
CelebornShuffleState.isCelebornSkewedShuffle(sms.shuffleDep.shuffleId)) && 
!sms.isAvailable =>
++        logInfo(s"Unregistering shuffle output for stage ${stage.id}" +
++          s" shuffle ${sms.shuffleDep.shuffleId}")
+         
mapOutputTracker.unregisterAllMapAndMergeOutput(sms.shuffleDep.shuffleId)
+         sms.shuffleDep.newShuffleMergeState()
+       case _ =>
+@@ -1689,7 +1693,15 @@ private[spark] class DAGScheduler(
+         // tasks complete, they still count and we can mark the corresponding 
partitions as
+         // finished. Here we notify the task scheduler to skip running tasks 
for the same partition,
+         // to save resource.
+-        if (task.stageAttemptId < stage.latestInfo.attemptNumber()) {
++        // CELEBORN-1856, if stage is indeterminate or shuffleMapStage is 
skewed and read by
++        // Celeborn chunkOffsets, should not call notifyPartitionCompletion, 
otherwise will
++        // skip running tasks for the same partition because 
TaskSetManager.dequeueTaskFromList
++        // will skip running task which TaskSetManager.successful(taskIndex) 
is true.
++        // TODO: Suggest cherry-pick SPARK-45182 and SPARK-45498, ResultStage 
may has result commit and other issues
++        val isStageIndeterminate = stage.isInstanceOf[ShuffleMapStage] &&
++          CelebornShuffleState.isCelebornSkewedShuffle(
++            stage.asInstanceOf[ShuffleMapStage].shuffleDep.shuffleId)
++        if (task.stageAttemptId < stage.latestInfo.attemptNumber() && 
!isStageIndeterminate) {
+           taskScheduler.notifyPartitionCompletion(stageId, task.partitionId)
+         }
+ 
+@@ -1772,6 +1784,14 @@ private[spark] class DAGScheduler(
+         val failedStage = stageIdToStage(task.stageId)
+         val mapStage = shuffleIdToMapStage(shuffleId)
+ 
++        // In Celeborn-1139 we support read skew partition by Celeborn 
chunkOffsets,
++        // it will make shuffle be indeterminate, so abort the ResultStage 
directly here.
++        if (failedStage.isInstanceOf[ResultStage] && 
CelebornShuffleState.isCelebornSkewedShuffle(shuffleId)) {
++          val shuffleFailedReason = s"ResultStage:${failedStage.id} fetch 
failed and the shuffle:$shuffleId " +
++            s"is skewed partition read by Celeborn, so abort it."
++          abortStage(failedStage, shuffleFailedReason, None)
++        }
++
+         if (failedStage.latestInfo.attemptNumber != task.stageAttemptId) {
+           logInfo(s"Ignoring fetch failure from $task as it's from 
$failedStage attempt" +
+             s" ${task.stageAttemptId} and there is a more recent attempt for 
that stage " +
+@@ -1850,7 +1870,7 @@ private[spark] class DAGScheduler(
+               // Note that, if map stage is UNORDERED, we are fine. The 
shuffle partitioner is
+               // guaranteed to be determinate, so the input data of the 
reducers will not change
+               // even if the map tasks are re-tried.
+-              if (mapStage.isIndeterminate) {
++              if (mapStage.isIndeterminate || 
CelebornShuffleState.isCelebornSkewedShuffle(shuffleId)) {
+                 // It's a little tricky to find all the succeeding stages of 
`mapStage`, because
+                 // each stage only know its parents not children. Here we 
traverse the stages from
+                 // the leaf nodes (the result stages of active jobs), and 
rollback all the stages
+@@ -1861,7 +1881,15 @@ private[spark] class DAGScheduler(
  
-           // It is likely that we receive multiple FetchFailed for a single 
stage (because we have
-           // multiple tasks running concurrently on different executors). In 
that case, it is
+                 def collectStagesToRollback(stageChain: List[Stage]): Unit = {
+                   if (stagesToRollback.contains(stageChain.head)) {
+-                    stageChain.drop(1).foreach(s => stagesToRollback += s)
++                    stageChain.drop(1).foreach(s => {
++                      stagesToRollback += s
++                      s match {
++                        case currentMapStage: ShuffleMapStage =>
++                          
CelebornShuffleState.registerCelebornSkewedShuffle(currentMapStage.shuffleDep.shuffleId)
++                        case _: ResultStage =>
++                          // do nothing, should abort celeborn skewed read 
stage
++                      }
++                    })
+                   } else {
+                     stageChain.head.parents.foreach { s =>
+                       collectStagesToRollback(s :: stageChain)
 diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CelebornShuffleUtil.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CelebornShuffleUtil.scala
 new file mode 100644
 index 00000000000..3dc60678461
diff --git 
a/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_3.patch 
b/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_3.patch
index f8e38615c..6bb8be966 100644
--- a/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_3.patch
+++ b/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_3.patch
@@ -135,7 +135,7 @@ index 00000000000..5e190c512df
 +
 +}
 diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala 
b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
-index bd2823bcac1..d0c88081527 100644
+index bd2823bcac1..e97218b046b 100644
 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
 +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
 @@ -33,6 +33,7 @@ import com.google.common.util.concurrent.{Futures, 
SettableFuture}
@@ -146,15 +146,76 @@ index bd2823bcac1..d0c88081527 100644
  import org.apache.spark.errors.SparkCoreErrors
  import org.apache.spark.executor.{ExecutorMetrics, TaskMetrics}
  import org.apache.spark.internal.Logging
-@@ -1851,7 +1852,7 @@ private[spark] class DAGScheduler(
-           failedStage.failedAttemptIds.add(task.stageAttemptId)
-           val shouldAbortStage =
-             failedStage.failedAttemptIds.size >= maxConsecutiveStageAttempts 
||
--            disallowStageRetryForTest
-+            disallowStageRetryForTest || 
CelebornShuffleState.isCelebornSkewedShuffle(shuffleId)
+@@ -1404,7 +1405,10 @@ private[spark] class DAGScheduler(
+     // The operation here can make sure for the partially completed 
intermediate stage,
+     // `findMissingPartitions()` returns all partitions every time.
+     stage match {
+-      case sms: ShuffleMapStage if stage.isIndeterminate && !sms.isAvailable 
=>
++      case sms: ShuffleMapStage if (stage.isIndeterminate ||
++        
CelebornShuffleState.isCelebornSkewedShuffle(sms.shuffleDep.shuffleId)) && 
!sms.isAvailable =>
++        logInfo(s"Unregistering shuffle output for stage ${stage.id}" +
++          s" shuffle ${sms.shuffleDep.shuffleId}")
+         
mapOutputTracker.unregisterAllMapAndMergeOutput(sms.shuffleDep.shuffleId)
+         sms.shuffleDep.newShuffleMergeState()
+       case _ =>
+@@ -1760,7 +1764,15 @@ private[spark] class DAGScheduler(
+         // tasks complete, they still count and we can mark the corresponding 
partitions as
+         // finished. Here we notify the task scheduler to skip running tasks 
for the same partition,
+         // to save resource.
+-        if (task.stageAttemptId < stage.latestInfo.attemptNumber()) {
++        // CELEBORN-1856, if stage is indeterminate or shuffleMapStage is 
skewed and read by
++        // Celeborn chunkOffsets, should not call notifyPartitionCompletion, 
otherwise will
++        // skip running tasks for the same partition because 
TaskSetManager.dequeueTaskFromList
++        // will skip running task which TaskSetManager.successful(taskIndex) 
is true.
++        // TODO: Suggest cherry-pick SPARK-45182 and SPARK-45498, ResultStage 
may has result commit and other issues
++        val isStageIndeterminate = stage.isInstanceOf[ShuffleMapStage] &&
++          CelebornShuffleState.isCelebornSkewedShuffle(
++            stage.asInstanceOf[ShuffleMapStage].shuffleDep.shuffleId)
++        if (task.stageAttemptId < stage.latestInfo.attemptNumber() && 
!isStageIndeterminate) {
+           taskScheduler.notifyPartitionCompletion(stageId, task.partitionId)
+         }
+ 
+@@ -1843,6 +1855,14 @@ private[spark] class DAGScheduler(
+         val failedStage = stageIdToStage(task.stageId)
+         val mapStage = shuffleIdToMapStage(shuffleId)
+ 
++        // In Celeborn-1139 we support read skew partition by Celeborn 
chunkOffsets,
++        // it will make shuffle be indeterminate, so abort the ResultStage 
directly here.
++        if (failedStage.isInstanceOf[ResultStage] && 
CelebornShuffleState.isCelebornSkewedShuffle(shuffleId)) {
++          val shuffleFailedReason = s"ResultStage:${failedStage.id} fetch 
failed and the shuffle:$shuffleId " +
++            s"is skewed partition read by Celeborn, so abort it."
++          abortStage(failedStage, shuffleFailedReason, None)
++        }
++
+         if (failedStage.latestInfo.attemptNumber != task.stageAttemptId) {
+           logInfo(s"Ignoring fetch failure from $task as it's from 
$failedStage attempt" +
+             s" ${task.stageAttemptId} and there is a more recent attempt for 
that stage " +
+@@ -1921,7 +1941,7 @@ private[spark] class DAGScheduler(
+               // Note that, if map stage is UNORDERED, we are fine. The 
shuffle partitioner is
+               // guaranteed to be determinate, so the input data of the 
reducers will not change
+               // even if the map tasks are re-tried.
+-              if (mapStage.isIndeterminate) {
++              if (mapStage.isIndeterminate || 
CelebornShuffleState.isCelebornSkewedShuffle(shuffleId)) {
+                 // It's a little tricky to find all the succeeding stages of 
`mapStage`, because
+                 // each stage only know its parents not children. Here we 
traverse the stages from
+                 // the leaf nodes (the result stages of active jobs), and 
rollback all the stages
+@@ -1932,7 +1952,15 @@ private[spark] class DAGScheduler(
  
-           // It is likely that we receive multiple FetchFailed for a single 
stage (because we have
-           // multiple tasks running concurrently on different executors). In 
that case, it is
+                 def collectStagesToRollback(stageChain: List[Stage]): Unit = {
+                   if (stagesToRollback.contains(stageChain.head)) {
+-                    stageChain.drop(1).foreach(s => stagesToRollback += s)
++                    stageChain.drop(1).foreach(s => {
++                      stagesToRollback += s
++                      s match {
++                        case currentMapStage: ShuffleMapStage =>
++                          
CelebornShuffleState.registerCelebornSkewedShuffle(currentMapStage.shuffleDep.shuffleId)
++                        case _: ResultStage =>
++                          // do nothing, should abort celeborn skewed read 
stage
++                      }
++                    })
+                   } else {
+                     stageChain.head.parents.foreach { s =>
+                       collectStagesToRollback(s :: stageChain)
 diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CelebornShuffleUtil.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CelebornShuffleUtil.scala
 new file mode 100644
 index 00000000000..3dc60678461
diff --git 
a/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_4.patch 
b/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_4.patch
index 9aed835fe..9f38d8026 100644
--- a/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_4.patch
+++ b/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_4.patch
@@ -135,7 +135,7 @@ index 00000000000..5e190c512df
 +
 +}
 diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala 
b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
-index 26be8c72bbc..81feaba962c 100644
+index 26be8c72bbc..4323b6d1a75 100644
 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
 +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
 @@ -34,6 +34,7 @@ import com.google.common.util.concurrent.{Futures, 
SettableFuture}
@@ -146,15 +146,76 @@ index 26be8c72bbc..81feaba962c 100644
  import org.apache.spark.errors.SparkCoreErrors
  import org.apache.spark.executor.{ExecutorMetrics, TaskMetrics}
  import org.apache.spark.internal.Logging
-@@ -1897,7 +1898,7 @@ private[spark] class DAGScheduler(
+@@ -1435,7 +1436,10 @@ private[spark] class DAGScheduler(
+     // The operation here can make sure for the partially completed 
intermediate stage,
+     // `findMissingPartitions()` returns all partitions every time.
+     stage match {
+-      case sms: ShuffleMapStage if stage.isIndeterminate && !sms.isAvailable 
=>
++      case sms: ShuffleMapStage if (stage.isIndeterminate ||
++        
CelebornShuffleState.isCelebornSkewedShuffle(sms.shuffleDep.shuffleId)) && 
!sms.isAvailable =>
++        logInfo(s"Unregistering shuffle output for stage ${stage.id}" +
++          s" shuffle ${sms.shuffleDep.shuffleId}")
+         
mapOutputTracker.unregisterAllMapAndMergeOutput(sms.shuffleDep.shuffleId)
+         sms.shuffleDep.newShuffleMergeState()
+       case _ =>
+@@ -1796,7 +1800,15 @@ private[spark] class DAGScheduler(
+         // tasks complete, they still count and we can mark the corresponding 
partitions as
+         // finished. Here we notify the task scheduler to skip running tasks 
for the same partition,
+         // to save resource.
+-        if (task.stageAttemptId < stage.latestInfo.attemptNumber()) {
++        // CELEBORN-1856, if stage is indeterminate or shuffleMapStage is 
skewed and read by
++        // Celeborn chunkOffsets, should not call notifyPartitionCompletion, 
otherwise will
++        // skip running tasks for the same partition because 
TaskSetManager.dequeueTaskFromList
++        // will skip running task which TaskSetManager.successful(taskIndex) 
is true.
++        // TODO: Suggest cherry-pick SPARK-45182 and SPARK-45498, ResultStage 
may has result commit and other issues
++        val isStageIndeterminate = stage.isInstanceOf[ShuffleMapStage] &&
++          CelebornShuffleState.isCelebornSkewedShuffle(
++            stage.asInstanceOf[ShuffleMapStage].shuffleDep.shuffleId)
++        if (task.stageAttemptId < stage.latestInfo.attemptNumber() && 
!isStageIndeterminate) {
+           taskScheduler.notifyPartitionCompletion(stageId, task.partitionId)
+         }
+ 
+@@ -1879,6 +1891,14 @@ private[spark] class DAGScheduler(
+         val failedStage = stageIdToStage(task.stageId)
+         val mapStage = shuffleIdToMapStage(shuffleId)
  
-           val shouldAbortStage =
-             failedStage.failedAttemptIds.size >= maxConsecutiveStageAttempts 
||
--            disallowStageRetryForTest
-+            disallowStageRetryForTest || 
CelebornShuffleState.isCelebornSkewedShuffle(shuffleId)
++        // In Celeborn-1139 we support read skew partition by Celeborn 
chunkOffsets,
++        // it will make shuffle be indeterminate, so abort the ResultStage 
directly here.
++        if (failedStage.isInstanceOf[ResultStage] && 
CelebornShuffleState.isCelebornSkewedShuffle(shuffleId)) {
++          val shuffleFailedReason = s"ResultStage:${failedStage.id} fetch 
failed and the shuffle:$shuffleId " +
++            s"is skewed partition read by Celeborn, so abort it."
++          abortStage(failedStage, shuffleFailedReason, None)
++        }
++
+         if (failedStage.latestInfo.attemptNumber != task.stageAttemptId) {
+           logInfo(s"Ignoring fetch failure from $task as it's from 
$failedStage attempt" +
+             s" ${task.stageAttemptId} and there is a more recent attempt for 
that stage " +
+@@ -1977,7 +1997,7 @@ private[spark] class DAGScheduler(
+               // Note that, if map stage is UNORDERED, we are fine. The 
shuffle partitioner is
+               // guaranteed to be determinate, so the input data of the 
reducers will not change
+               // even if the map tasks are re-tried.
+-              if (mapStage.isIndeterminate) {
++              if (mapStage.isIndeterminate || 
CelebornShuffleState.isCelebornSkewedShuffle(shuffleId)) {
+                 // It's a little tricky to find all the succeeding stages of 
`mapStage`, because
+                 // each stage only know its parents not children. Here we 
traverse the stages from
+                 // the leaf nodes (the result stages of active jobs), and 
rollback all the stages
+@@ -1988,7 +2008,15 @@ private[spark] class DAGScheduler(
  
-           // It is likely that we receive multiple FetchFailed for a single 
stage (because we have
-           // multiple tasks running concurrently on different executors). In 
that case, it is
+                 def collectStagesToRollback(stageChain: List[Stage]): Unit = {
+                   if (stagesToRollback.contains(stageChain.head)) {
+-                    stageChain.drop(1).foreach(s => stagesToRollback += s)
++                    stageChain.drop(1).foreach(s => {
++                      stagesToRollback += s
++                      s match {
++                        case currentMapStage: ShuffleMapStage =>
++                          
CelebornShuffleState.registerCelebornSkewedShuffle(currentMapStage.shuffleDep.shuffleId)
++                        case _: ResultStage =>
++                          // do nothing, should abort celeborn skewed read 
stage
++                      }
++                    })
+                   } else {
+                     stageChain.head.parents.foreach { s =>
+                       collectStagesToRollback(s :: stageChain)
 diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CelebornShuffleUtil.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CelebornShuffleUtil.scala
 new file mode 100644
 index 00000000000..3dc60678461
diff --git 
a/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_5.patch 
b/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_5.patch
index 553bdeae6..71d0f9859 100644
--- a/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_5.patch
+++ b/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_5.patch
@@ -135,7 +135,7 @@ index 00000000000..5e190c512df
 +
 +}
 diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala 
b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
-index 89d16e57934..3b9094f3254 100644
+index 89d16e57934..36ce50093c0 100644
 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
 +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
 @@ -34,6 +34,7 @@ import com.google.common.util.concurrent.{Futures, 
SettableFuture}
@@ -146,15 +146,88 @@ index 89d16e57934..3b9094f3254 100644
  import org.apache.spark.errors.SparkCoreErrors
  import org.apache.spark.executor.{ExecutorMetrics, TaskMetrics}
  import org.apache.spark.internal.Logging
-@@ -1962,7 +1963,7 @@ private[spark] class DAGScheduler(
+@@ -1480,7 +1481,10 @@ private[spark] class DAGScheduler(
+     // The operation here can make sure for the partially completed 
intermediate stage,
+     // `findMissingPartitions()` returns all partitions every time.
+     stage match {
+-      case sms: ShuffleMapStage if stage.isIndeterminate && !sms.isAvailable 
=>
++      case sms: ShuffleMapStage if (stage.isIndeterminate ||
++        
CelebornShuffleState.isCelebornSkewedShuffle(sms.shuffleDep.shuffleId)) && 
!sms.isAvailable =>
++        logInfo(s"Unregistering shuffle output for stage ${stage.id}" +
++          s" shuffle ${sms.shuffleDep.shuffleId}")
+         
mapOutputTracker.unregisterAllMapAndMergeOutput(sms.shuffleDep.shuffleId)
+         sms.shuffleDep.newShuffleMergeState()
+       case _ =>
+@@ -1854,7 +1858,18 @@ private[spark] class DAGScheduler(
+         // tasks complete, they still count and we can mark the corresponding 
partitions as
+         // finished if the stage is determinate. Here we notify the task 
scheduler to skip running
+         // tasks for the same partition to save resource.
+-        if (!stage.isIndeterminate && task.stageAttemptId < 
stage.latestInfo.attemptNumber()) {
++        // finished. Here we notify the task scheduler to skip running tasks 
for the same partition,
++        // to save resource.
++        // CELEBORN-1856, if stage is indeterminate or shuffleMapStage is 
skewed and read by
++        // Celeborn chunkOffsets, should not call notifyPartitionCompletion, 
otherwise will
++        // skip running tasks for the same partition because 
TaskSetManager.dequeueTaskFromList
++        // will skip running task which TaskSetManager.successful(taskIndex) 
is true.
++        // TODO: ResultStage has result commit and other issues
++        val isCelebornShuffleIndeterminate = 
stage.isInstanceOf[ShuffleMapStage] &&
++          CelebornShuffleState.isCelebornSkewedShuffle(
++            stage.asInstanceOf[ShuffleMapStage].shuffleDep.shuffleId)
++        if (!stage.isIndeterminate && task.stageAttemptId < 
stage.latestInfo.attemptNumber()
++          && !isCelebornShuffleIndeterminate) {
+           taskScheduler.notifyPartitionCompletion(stageId, task.partitionId)
+         }
+ 
+@@ -1909,7 +1924,7 @@ private[spark] class DAGScheduler(
+           case smt: ShuffleMapTask =>
+             val shuffleStage = stage.asInstanceOf[ShuffleMapStage]
+             // Ignore task completion for old attempt of indeterminate stage
+-            val ignoreIndeterminate = stage.isIndeterminate &&
++            val ignoreIndeterminate = (stage.isIndeterminate || 
isCelebornShuffleIndeterminate) &&
+               task.stageAttemptId < stage.latestInfo.attemptNumber()
+             if (!ignoreIndeterminate) {
+               shuffleStage.pendingPartitions -= task.partitionId
+@@ -1944,6 +1959,14 @@ private[spark] class DAGScheduler(
+         val failedStage = stageIdToStage(task.stageId)
+         val mapStage = shuffleIdToMapStage(shuffleId)
  
-           val shouldAbortStage =
-             failedStage.failedAttemptIds.size >= maxConsecutiveStageAttempts 
||
--            disallowStageRetryForTest
-+            disallowStageRetryForTest || 
CelebornShuffleState.isCelebornSkewedShuffle(shuffleId)
++        // In Celeborn-1139 we support read skew partition by Celeborn 
chunkOffsets,
++        // it will make shuffle be indeterminate, so abort the ResultStage 
directly here.
++        if (failedStage.isInstanceOf[ResultStage] && 
CelebornShuffleState.isCelebornSkewedShuffle(shuffleId)) {
++          val shuffleFailedReason = s"ResultStage:${failedStage.id} fetch 
failed and the shuffle:$shuffleId " +
++            s"is skewed partition read by Celeborn, so abort it."
++          abortStage(failedStage, shuffleFailedReason, None)
++        }
++
+         if (failedStage.latestInfo.attemptNumber != task.stageAttemptId) {
+           logInfo(s"Ignoring fetch failure from $task as it's from 
$failedStage attempt" +
+             s" ${task.stageAttemptId} and there is a more recent attempt for 
that stage " +
+@@ -2042,7 +2065,7 @@ private[spark] class DAGScheduler(
+               // Note that, if map stage is UNORDERED, we are fine. The 
shuffle partitioner is
+               // guaranteed to be determinate, so the input data of the 
reducers will not change
+               // even if the map tasks are re-tried.
+-              if (mapStage.isIndeterminate) {
++              if (mapStage.isIndeterminate || 
CelebornShuffleState.isCelebornSkewedShuffle(shuffleId)) {
+                 // It's a little tricky to find all the succeeding stages of 
`mapStage`, because
+                 // each stage only know its parents not children. Here we 
traverse the stages from
+                 // the leaf nodes (the result stages of active jobs), and 
rollback all the stages
+@@ -2053,7 +2076,15 @@ private[spark] class DAGScheduler(
  
-           // It is likely that we receive multiple FetchFailed for a single 
stage (because we have
-           // multiple tasks running concurrently on different executors). In 
that case, it is
+                 def collectStagesToRollback(stageChain: List[Stage]): Unit = {
+                   if (stagesToRollback.contains(stageChain.head)) {
+-                    stageChain.drop(1).foreach(s => stagesToRollback += s)
++                    stageChain.drop(1).foreach(s => {
++                      stagesToRollback += s
++                      s match {
++                        case currentMapStage: ShuffleMapStage =>
++                          
CelebornShuffleState.registerCelebornSkewedShuffle(currentMapStage.shuffleDep.shuffleId)
++                        case _: ResultStage =>
++                          // do nothing, should abort celeborn skewed read 
stage
++                      }
++                    })
+                   } else {
+                     stageChain.head.parents.foreach { s =>
+                       collectStagesToRollback(s :: stageChain)
 diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CelebornShuffleUtil.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CelebornShuffleUtil.scala
 new file mode 100644
 index 00000000000..3dc60678461
diff --git 
a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
 
b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
index df28143c6..234fba1fe 100644
--- 
a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
+++ 
b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
@@ -150,6 +150,11 @@ public class SparkShuffleManager implements ShuffleManager 
{
 
             lifecycleManager.registerShuffleTrackerCallback(
                 shuffleId -> 
SparkUtils.unregisterAllMapOutput(mapOutputTracker, shuffleId));
+
+            if 
(celebornConf.clientAdaptiveOptimizeSkewedPartitionReadEnabled()) {
+              lifecycleManager.registerCelebornSkewShuffleCheckCallback(
+                  SparkUtils::isCelebornSkewShuffleOrChildShuffle);
+            }
           }
         }
       }
diff --git 
a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java
 
b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java
index 6c2e5120e..6443c2163 100644
--- 
a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java
+++ 
b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java
@@ -462,4 +462,18 @@ public class SparkUtils {
       sparkContext.addSparkListener(listener);
     }
   }
+
+  private static final DynMethods.UnboundMethod isCelebornSkewShuffle_METHOD =
+      DynMethods.builder("isCelebornSkewedShuffle")
+          .hiddenImpl("org.apache.spark.celeborn.CelebornShuffleState", 
Integer.TYPE)
+          .orNoop()
+          .build();
+
+  public static boolean isCelebornSkewShuffleOrChildShuffle(int appShuffleId) {
+    if (!isCelebornSkewShuffle_METHOD.isNoop()) {
+      return isCelebornSkewShuffle_METHOD.asStatic().invoke(appShuffleId);
+    } else {
+      return false;
+    }
+  }
 }
diff --git 
a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala 
b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
index 285da2296..f706eeb90 100644
--- a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
+++ b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
@@ -909,7 +909,8 @@ class LifecycleManager(val appUniqueId: String, val conf: 
CelebornConf) extends
                 // For barrier stages, all tasks are re-executed when it is 
re-run : similar to indeterminate stage.
                 // So if a barrier stage is getting reexecuted, previous 
stage/attempt needs to
                 // be cleaned up as it is entirely unusuable
-                if (determinate && !isBarrierStage)
+                if (determinate && !isBarrierStage && 
!isCelebornSkewShuffleOrChildShuffle(
+                    appShuffleId))
                   shuffleIds.values.toSeq.reverse.find(e => e._2 == true)
                 else
                   None
@@ -1057,6 +1058,14 @@ class LifecycleManager(val appUniqueId: String, val 
conf: CelebornConf) extends
     }
   }
 
+  private def isCelebornSkewShuffleOrChildShuffle(appShuffleId: Int): Boolean 
= {
+    celebornSkewShuffleCheckCallback match {
+      case Some(skewShuffleCallback) =>
+        skewShuffleCallback.apply(appShuffleId)
+      case None => false
+    }
+  }
+
   private def handleStageEnd(shuffleId: Int): Unit = {
     // check whether shuffle has registered
     if (!registeredShuffle.contains(shuffleId)) {
@@ -1843,6 +1852,13 @@ class LifecycleManager(val appUniqueId: String, val 
conf: CelebornConf) extends
     registerShuffleResponseRpcCache.invalidate(shuffleId)
   }
 
+  @volatile private var celebornSkewShuffleCheckCallback
+      : Option[function.Function[Integer, Boolean]] = None
+  def registerCelebornSkewShuffleCheckCallback(callback: 
function.Function[Integer, Boolean])
+      : Unit = {
+    celebornSkewShuffleCheckCallback = Some(callback)
+  }
+
   // Initialize at the end of LifecycleManager construction.
   initialize()
 

Reply via email to