This is an automated email from the ASF dual-hosted git repository.

ethanfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/celeborn.git


The following commit(s) were added to refs/heads/main by this push:
     new 0b5a09a9f [CELEBORN-1896] delete data from failed to fetch shuffles
0b5a09a9f is described below

commit 0b5a09a9f796fc85c7516a8d3d108ca98c938c0b
Author: CodingCat <[email protected]>
AuthorDate: Wed May 21 11:23:11 2025 +0800

    [CELEBORN-1896] delete data from failed to fetch shuffles
    
    ### What changes were proposed in this pull request?
    
    This is joint work with YutingWang98.
    
    Currently, the disk space occupied by Celeborn shuffles is only reclaimed after Spark garbage-collects the corresponding shuffle objects.
    
    As a result, when a shuffle fails to be fetched and is retried, the disk space occupied by the failed attempt cannot be reclaimed. We hit this issue internally with Spark applications that shuffle hundreds of terabytes: any hiccup on the servers can double or even triple disk usage.
    
    This PR implements a mechanism to delete the files of failed-to-fetch shuffles.
    
    The main idea is simple: LifecycleManager triggers cleanup when it allocates a new Celeborn shuffle id for a retried shuffle write stage.
    
    The tricky part is avoiding deletion of shuffle files that are still referenced by multiple downstream stages: the PR introduces RunningStageManager to track the dependencies between stages.
    
    ### Why are the changes needed?
    
    To save disk space.
    
    ### Does this PR introduce _any_ user-facing change?
    
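    Yes, two new configs: `celeborn.client.spark.fetch.cleanFailedShuffle` and `celeborn.client.spark.fetch.cleanFailedShuffleInterval`.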
    
    ### How was this patch tested?
    
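    We manually deleted some shuffle files to trigger fetch failures; a new integration test, CelebornFetchFailureDiskCleanSuite, is also added.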
    
    
![image](https://github.com/user-attachments/assets/4136cd52-78b2-44e7-8244-db3c5bf9d9c4)
    
    The screenshot above shows that we initially have shuffles 0 and 1. After shuffle 1 hits a chunk fetch failure, a retry of shuffle 0 is triggered (as shuffle 2), and at that point shuffle 0 has already been deleted from the workers.
    
    
![image](https://github.com/user-attachments/assets/7d3b4d90-ae5a-4a54-8dec-a5005850ef0a)
    
    The logs show that, midway through the application, the unregister shuffle request was sent for shuffle 0.
    
    Closes #3109 from CodingCat/delete_fi.
    
    Lead-authored-by: CodingCat <[email protected]>
    Co-authored-by: Wang, Fei <[email protected]>
    Co-authored-by: Fei Wang <[email protected]>
    Co-authored-by: Fei Wang <[email protected]>
    Signed-off-by: mingji <[email protected]>
---
 .../spark/shuffle/celeborn/SparkCommonUtils.java   |   8 ++
 .../celeborn/spark/FailedShuffleCleaner.scala      |  93 +++++++++++++
 .../shuffle/celeborn/SparkShuffleManager.java      |  27 ++++
 .../apache/spark/shuffle/celeborn/SparkUtils.java  |  20 ++-
 .../apache/celeborn/client/LifecycleManager.scala  |  26 +++-
 .../commit/ReducePartitionCommitHandler.scala      |   2 +-
 .../org/apache/celeborn/common/CelebornConf.scala  |  20 +++
 docs/configuration/client.md                       |   2 +
 pom.xml                                            |   4 +-
 .../spark/CelebornFetchFailureDiskCleanSuite.scala | 154 +++++++++++++++++++++
 .../tests/spark/CelebornFetchFailureSuite.scala    |   9 +-
 .../celeborn/tests/spark/SparkTestBase.scala       |  42 ------
 .../fetch/failure/ShuffleReaderGetHooks.scala      |  93 +++++++++++++
 .../spark/shuffle/celeborn/SparkUtilsSuite.scala   |   3 +-
 14 files changed, 444 insertions(+), 59 deletions(-)

diff --git 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SparkCommonUtils.java
 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SparkCommonUtils.java
index a24e06d5a..84d74f8c1 100644
--- 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SparkCommonUtils.java
+++ 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SparkCommonUtils.java
@@ -52,6 +52,14 @@ public class SparkCommonUtils {
     }
   }
 
+  public static String encodeAppShuffleIdentifier(int appShuffleId, 
TaskContext context) {
+    return appShuffleId + "-" + context.stageId() + "-" + 
context.stageAttemptNumber();
+  }
+
+  public static String[] decodeAppShuffleIdentifier(String 
appShuffleIdentifier) {
+    return appShuffleIdentifier.split("-");
+  }
+
   public static int getEncodedAttemptNumber(TaskContext context) {
     return (context.stageAttemptNumber() << 16) | context.attemptNumber();
   }
diff --git 
a/client-spark/common/src/main/scala/org/apache/celeborn/spark/FailedShuffleCleaner.scala
 
b/client-spark/common/src/main/scala/org/apache/celeborn/spark/FailedShuffleCleaner.scala
new file mode 100644
index 000000000..e88f6f640
--- /dev/null
+++ 
b/client-spark/common/src/main/scala/org/apache/celeborn/spark/FailedShuffleCleaner.scala
@@ -0,0 +1,137 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.celeborn.spark
+
+import java.util
+import java.util.concurrent.ScheduledExecutorService
+import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue, TimeUnit}
+
+import scala.collection.JavaConverters._
+import scala.util.Try
+import scala.util.control.NonFatal
+
+import org.apache.spark.shuffle.celeborn.SparkCommonUtils
+
+import org.apache.celeborn.client.LifecycleManager
+import org.apache.celeborn.common.internal.Logging
+import org.apache.celeborn.common.util.ThreadUtils
+
+/**
+ * Unregisters Celeborn shuffles whose output can no longer be fetched, so that workers can
+ * reclaim their disk space without waiting for Spark to garbage-collect the shuffle.
+ *
+ * Celeborn shuffle ids are queued when a shuffle write stage is re-run and are drained
+ * periodically by a single daemon thread.
+ */
+private[celeborn] class FailedShuffleCleaner(lifecycleManager: LifecycleManager) extends Logging {
+
+  // Celeborn (not application) shuffle ids waiting to be unregistered.
+  private val shufflesToBeCleaned = new LinkedBlockingQueue[Int]()
+  // Celeborn shuffle ids already unregistered. Mutated by the cleaner thread, by the
+  // unregister-shuffle callback and by reset(), so it has to be a concurrent set.
+  private val cleanedShuffleIds = ConcurrentHashMap.newKeySet[Int]()
+
+  private lazy val cleanInterval =
+    lifecycleManager.conf.clientFetchCleanFailedShuffleIntervalMS
+
+  // Declared ahead of the constructor-time init() call so initialization order is explicit.
+  private var cleanerThreadPool: ScheduledExecutorService = _
+
+  /**
+   * Stops the cleaner thread, then drops all pending and cleaned state. Called when the
+   * shuffle manager stops and by tests.
+   */
+  def reset(): Unit = {
+    if (cleanerThreadPool != null) {
+      cleanerThreadPool.shutdownNow()
+      cleanerThreadPool = null
+    }
+    shufflesToBeCleaned.clear()
+    cleanedShuffleIds.clear()
+  }
+
+  /**
+   * Queues every Celeborn shuffle id registered so far for the application shuffle encoded in
+   * `appShuffleIdentifier` (`appShuffleId-stageId-stageAttemptNumber`). A malformed identifier
+   * is logged and ignored rather than thrown back into the caller.
+   */
+  def addShuffleIdToBeCleaned(appShuffleIdentifier: String): Unit = {
+    val appShuffleId = SparkCommonUtils.decodeAppShuffleIdentifier(appShuffleIdentifier) match {
+      case Array(id, _, _) => Try(id.toInt).toOption
+      case _ => None
+    }
+    appShuffleId match {
+      case Some(id) =>
+        // Guard against a missing entry (null from the underlying map lookup).
+        Option(lifecycleManager.getShuffleIdMapping.get(id)).foreach(_.foreach {
+          case (_, (celebornShuffleId, _)) => shufflesToBeCleaned.put(celebornShuffleId)
+        })
+      case None =>
+        logWarning(s"ignore malformed app shuffle identifier: $appShuffleIdentifier")
+    }
+  }
+
+  /** Starts the single daemon thread that periodically unregisters queued shuffles. */
+  def init(): Unit = {
+    cleanerThreadPool = ThreadUtils.newDaemonSingleThreadScheduledExecutor(
+      "failedShuffleCleanerThreadPool")
+    cleanerThreadPool.scheduleWithFixedDelay(
+      new Runnable {
+        override def run(): Unit = cleanPendingShuffles()
+      },
+      cleanInterval,
+      cleanInterval,
+      TimeUnit.MILLISECONDS)
+  }
+
+  /**
+   * Forgets that `celebornShuffleId` has been cleaned, so it is unregistered again if it is
+   * queued later. Invoked through the LifecycleManager unregister-shuffle callback.
+   */
+  def removeCleanedShuffleId(celebornShuffleId: Int): Unit = {
+    cleanedShuffleIds.remove(celebornShuffleId)
+  }
+
+  // Drains the queue and unregisters each shuffle not cleaned before. Failures are isolated
+  // per shuffle so one bad id neither drops the rest of the batch nor escapes run(): an
+  // exception escaping a scheduleWithFixedDelay task suppresses all later executions.
+  private def cleanPendingShuffles(): Unit = {
+    try {
+      val pending = new util.ArrayList[Int]
+      shufflesToBeCleaned.drainTo(pending)
+      pending.asScala.foreach { shuffleId =>
+        if (!cleanedShuffleIds.contains(shuffleId)) {
+          try {
+            lifecycleManager.unregisterShuffle(shuffleId)
+            logInfo(
+              s"sent unregister shuffle request for shuffle $shuffleId (celeborn shuffle id)")
+            cleanedShuffleIds.add(shuffleId)
+          } catch {
+            case NonFatal(e) =>
+              logError(s"failed to unregister shuffle $shuffleId (celeborn shuffle id)", e)
+          }
+        }
+      }
+    } catch {
+      case NonFatal(e) =>
+        logError("unexpected exception in cleaner thread", e)
+    }
+  }
+
+  // Must stay last so every field above is initialized before the cleaner thread starts.
+  init()
+}
diff --git 
a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
 
b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
index 8c099b29f..80ea5c256 100644
--- 
a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
+++ 
b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
@@ -37,6 +37,7 @@ import org.apache.celeborn.client.ShuffleClient;
 import org.apache.celeborn.common.CelebornConf;
 import org.apache.celeborn.common.protocol.ShuffleMode;
 import org.apache.celeborn.reflect.DynMethods;
+import org.apache.celeborn.spark.FailedShuffleCleaner;
 
 /**
  * In order to support Spark Stage resubmit with ShuffleReader FetchFails, 
Celeborn shuffleId has to
@@ -84,6 +85,8 @@ public class SparkShuffleManager implements ShuffleManager {
       ConcurrentHashMap.newKeySet();
   private final CelebornShuffleFallbackPolicyRunner fallbackPolicyRunner;
 
+  private FailedShuffleCleaner failedShuffleCleaner = null;
+
   private long sendBufferPoolCheckInterval;
   private long sendBufferPoolExpireTimeout;
 
@@ -158,6 +161,23 @@ public class SparkShuffleManager implements ShuffleManager 
{
             }
           }
 
+          if (lifecycleManager.conf().clientFetchCleanFailedShuffle()) {
+            if (!lifecycleManager.conf().clientStageRerunEnabled()) {
+              throw new IllegalArgumentException(
+                  CelebornConf.CLIENT_STAGE_RERUN_ENABLED().key()
+                      + " has to be "
+                      + "enabled, when "
+                      + CelebornConf.CLIENT_FETCH_CLEAN_FAILED_SHUFFLE().key()
+                      + " is set to true");
+            }
+            failedShuffleCleaner = new FailedShuffleCleaner(lifecycleManager);
+            lifecycleManager.registerValidateCelebornShuffleIdForCleanCallback(
+                (appShuffleIdentifier) ->
+                    SparkUtils.addWriterShuffleIdsToBeCleaned(this, 
appShuffleIdentifier));
+            lifecycleManager.registerUnregisterShuffleCallback(
+                (celebornShuffleId) -> SparkUtils.removeCleanedShuffleId(this, 
celebornShuffleId));
+          }
+
           if (celebornConf.getReducerFileGroupBroadcastEnabled()) {
             
lifecycleManager.registerBroadcastGetReducerFileGroupResponseCallback(
                 (shuffleId, getReducerFileGroupResponse) ->
@@ -249,6 +269,9 @@ public class SparkShuffleManager implements ShuffleManager {
       _sortShuffleManager.stop();
       _sortShuffleManager = null;
     }
+    if (celebornConf.clientFetchCleanFailedShuffle()) {
+      failedShuffleCleaner.reset();
+    }
   }
 
   @Override
@@ -470,4 +493,8 @@ public class SparkShuffleManager implements ShuffleManager {
   public LifecycleManager getLifecycleManager() {
     return this.lifecycleManager;
   }
+
+  public FailedShuffleCleaner getFailedShuffleCleaner() {
+    return this.failedShuffleCleaner;
+  }
 }
diff --git 
a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java
 
b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java
index b2e64565e..fc5d605d8 100644
--- 
a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java
+++ 
b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java
@@ -128,17 +128,14 @@ public class SparkUtils {
         .getOrElse(context::applicationId);
   }
 
-  public static String getAppShuffleIdentifier(int appShuffleId, TaskContext 
context) {
-    return appShuffleId + "-" + context.stageId() + "-" + 
context.stageAttemptNumber();
-  }
-
   public static int celebornShuffleId(
       ShuffleClient client,
       CelebornShuffleHandle<?, ?, ?> handle,
       TaskContext context,
       Boolean isWriter) {
     if (handle.throwsFetchFailure()) {
-      String appShuffleIdentifier = 
getAppShuffleIdentifier(handle.shuffleId(), context);
+      String appShuffleIdentifier =
+          SparkCommonUtils.encodeAppShuffleIdentifier(handle.shuffleId(), 
context);
       Tuple2<Integer, Boolean> res =
           client.getShuffleId(
               handle.shuffleId(),
@@ -327,7 +324,8 @@ public class SparkUtils {
 
     if (!(taskContext instanceof BarrierTaskContext)) return;
     int appShuffleId = handle.shuffleId();
-    String appShuffleIdentifier = 
SparkUtils.getAppShuffleIdentifier(appShuffleId, taskContext);
+    String appShuffleIdentifier =
+        SparkCommonUtils.encodeAppShuffleIdentifier(appShuffleId, taskContext);
 
     BarrierTaskContext barrierContext = (BarrierTaskContext) taskContext;
     barrierContext.addTaskFailureListener(
@@ -625,4 +623,14 @@ public class SparkUtils {
           return null;
         });
   }
+
+  public static void addWriterShuffleIdsToBeCleaned(
+      SparkShuffleManager sparkShuffleManager, String appShuffleIdentifier) {
+    
sparkShuffleManager.getFailedShuffleCleaner().addShuffleIdToBeCleaned(appShuffleIdentifier);
+  }
+
+  public static void removeCleanedShuffleId(
+      SparkShuffleManager sparkShuffleManager, int celebornShuffleId) {
+    
sparkShuffleManager.getFailedShuffleCleaner().removeCleanedShuffleId(celebornShuffleId);
+  }
 }
diff --git 
a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala 
b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
index 69920ee84..6bc12b8e8 100644
--- a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
+++ b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
@@ -940,6 +940,7 @@ class LifecycleManager(val appUniqueId: String, val conf: 
CelebornConf) extends
                   logInfo(s"reuse existing shuffleId $id for appShuffleId 
$appShuffleId appShuffleIdentifier $appShuffleIdentifier")
                   id
                 } else {
+                  // this branch means it is a redo of previous write stage
                   if (isBarrierStage) {
                     // unregister previous shuffle(s) which are still valid
                     val mapUpdates = shuffleIds.filter(_._2._2).map { kv =>
@@ -950,6 +951,8 @@ class LifecycleManager(val appUniqueId: String, val conf: 
CelebornConf) extends
                   }
                   val newShuffleId = shuffleIdGenerator.getAndIncrement()
                   logInfo(s"generate new shuffleId $newShuffleId for 
appShuffleId $appShuffleId appShuffleIdentifier $appShuffleIdentifier")
+                  validateCelebornShuffleIdForClean.foreach(callback =>
+                    callback.accept(appShuffleIdentifier))
                   shuffleIds.put(appShuffleIdentifier, (newShuffleId, true))
                   newShuffleId
                 }
@@ -963,11 +966,12 @@ class LifecycleManager(val appUniqueId: String, val conf: 
CelebornConf) extends
       } else {
         shuffleIds.values.filter(v => v._2).map(v => v._1).toSeq.reverse.find(
           areAllMapTasksEnd) match {
-          case Some(shuffleId) =>
+          case Some(celebornShuffleId) =>
             val pbGetShuffleIdResponse = {
               logDebug(
-                s"get shuffleId $shuffleId for appShuffleId $appShuffleId 
appShuffleIdentifier $appShuffleIdentifier isWriter $isWriter")
-              
PbGetShuffleIdResponse.newBuilder().setShuffleId(shuffleId).setSuccess(true).build()
+                s"get shuffleId $celebornShuffleId for appShuffleId 
$appShuffleId appShuffleIdentifier $appShuffleIdentifier isWriter $isWriter")
+              
PbGetShuffleIdResponse.newBuilder().setShuffleId(celebornShuffleId).setSuccess(
+                true).build()
             }
             context.reply(pbGetShuffleIdResponse)
           case None =>
@@ -1169,6 +1173,7 @@ class LifecycleManager(val appUniqueId: String, val conf: 
CelebornConf) extends
           shuffleIds.values.map {
             case (shuffleId, _) =>
               unregisterShuffle(shuffleId)
+              unregisterShuffleCallback.foreach(c => c.accept(shuffleId))
           })
       }
     } else {
@@ -1859,6 +1864,19 @@ class LifecycleManager(val appUniqueId: String, val 
conf: CelebornConf) extends
     appShuffleTrackerCallback = Some(callback)
   }
 
+  // expecting celeborn shuffle id and application shuffle identifier
+  @volatile private var validateCelebornShuffleIdForClean: 
Option[Consumer[String]] =
+    None
+  def registerValidateCelebornShuffleIdForCleanCallback(
+      callback: Consumer[String]): Unit = {
+    validateCelebornShuffleIdForClean = Some(callback)
+  }
+
+  @volatile private var unregisterShuffleCallback: Option[Consumer[Integer]] = 
None
+  def registerUnregisterShuffleCallback(callback: Consumer[Integer]): Unit = {
+    unregisterShuffleCallback = Some(callback)
+  }
+
   def registerAppShuffleDeterminate(appShuffleId: Int, determinate: Boolean): 
Unit = {
     appShuffleDeterminateMap.put(appShuffleId, determinate)
   }
@@ -1952,4 +1970,6 @@ class LifecycleManager(val appUniqueId: String, val conf: 
CelebornConf) extends
         }
       })
   }
+
+  def getShuffleIdMapping = shuffleIdMapping
 }
diff --git 
a/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala
 
b/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala
index 944e07657..5fa4394ad 100644
--- 
a/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala
+++ 
b/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala
@@ -70,7 +70,7 @@ class ReducePartitionCommitHandler(
 
   private val getReducerFileGroupRequest =
     JavaUtils.newConcurrentHashMap[Int, 
util.Set[MultiSerdeVersionRpcContext]]()
-  private val dataLostShuffleSet = ConcurrentHashMap.newKeySet[Int]()
+  private[celeborn] val dataLostShuffleSet = ConcurrentHashMap.newKeySet[Int]()
   private val stageEndShuffleSet = ConcurrentHashMap.newKeySet[Int]()
   private val inProcessStageEndShuffleSet = ConcurrentHashMap.newKeySet[Int]()
   private val shuffleMapperAttempts = JavaUtils.newConcurrentHashMap[Int, 
Array[Int]]()
diff --git 
a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala 
b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
index 8dbccaa1f..d3a9977ef 100644
--- a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
+++ b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
@@ -1001,6 +1001,9 @@ class CelebornConf(loadDefaults: Boolean) extends 
Cloneable with Logging with Se
 
   def clientFetchMaxRetriesForEachReplica: Int = 
get(CLIENT_FETCH_MAX_RETRIES_FOR_EACH_REPLICA)
   def clientStageRerunEnabled: Boolean = get(CLIENT_STAGE_RERUN_ENABLED)
+  def clientFetchCleanFailedShuffle: Boolean = 
get(CLIENT_FETCH_CLEAN_FAILED_SHUFFLE)
+  def clientFetchCleanFailedShuffleIntervalMS: Long =
+    get(CLIENT_FETCH_CLEAN_FAILED_SHUFFLE_INTERVAL)
   def clientFetchExcludeWorkerOnFailureEnabled: Boolean =
     get(CLIENT_FETCH_EXCLUDE_WORKER_ON_FAILURE_ENABLED)
   def clientFetchExcludedWorkerExpireTimeout: Long =
@@ -4827,6 +4830,23 @@ object CelebornConf extends Logging {
       .booleanConf
       .createWithDefault(true)
 
+  val CLIENT_FETCH_CLEAN_FAILED_SHUFFLE: ConfigEntry[Boolean] =
+    buildConf("celeborn.client.spark.fetch.cleanFailedShuffle")
+      .categories("client")
+      .version("0.6.0")
+      .doc("whether to clean those disk space occupied by shuffles which 
cannot be fetched")
+      .booleanConf
+      .createWithDefault(false)
+
+  val CLIENT_FETCH_CLEAN_FAILED_SHUFFLE_INTERVAL: ConfigEntry[Long] =
+    buildConf("celeborn.client.spark.fetch.cleanFailedShuffleInterval")
+      .categories("client")
+      .version("0.6.0")
+      .doc("the interval to clean the failed-to-fetch shuffle files, only 
valid when" +
+        s" ${CLIENT_FETCH_CLEAN_FAILED_SHUFFLE.key} is enabled")
+      .timeConf(TimeUnit.MILLISECONDS)
+      .createWithDefaultString("1s")
+
   val CLIENT_FETCH_EXCLUDE_WORKER_ON_FAILURE_ENABLED: ConfigEntry[Boolean] =
     buildConf("celeborn.client.fetch.excludeWorkerOnFailure.enabled")
       .categories("client")
diff --git a/docs/configuration/client.md b/docs/configuration/client.md
index 6c0ff752d..e4e8e0e83 100644
--- a/docs/configuration/client.md
+++ b/docs/configuration/client.md
@@ -111,6 +111,8 @@ license: |
 | celeborn.client.shuffle.register.filterExcludedWorker.enabled | false | 
false | Whether to filter excluded worker when register shuffle. | 0.4.0 |  | 
 | celeborn.client.shuffle.reviseLostShuffles.enabled | false | false | Whether 
to revise lost shuffles. | 0.6.0 |  | 
 | celeborn.client.slot.assign.maxWorkers | 10000 | false | Max workers that 
slots of one shuffle can be allocated on. Will choose the smaller positive one 
from Master side and Client side, see `celeborn.master.slot.assign.maxWorkers`. 
| 0.3.1 |  | 
+| celeborn.client.spark.fetch.cleanFailedShuffle | false | false | whether to 
clean those disk space occupied by shuffles which cannot be fetched | 0.6.0 |  
| 
+| celeborn.client.spark.fetch.cleanFailedShuffleInterval | 1s | false | the 
interval to clean the failed-to-fetch shuffle files, only valid when 
celeborn.client.spark.fetch.cleanFailedShuffle is enabled | 0.6.0 |  | 
 | celeborn.client.spark.push.dynamicWriteMode.enabled | false | false | 
Whether to dynamically switch push write mode based on conditions.If true, 
shuffle mode will be only determined by partition count | 0.5.0 |  | 
 | celeborn.client.spark.push.dynamicWriteMode.partitionNum.threshold | 2000 | 
false | Threshold of shuffle partition number for dynamically switching push 
writer mode. When the shuffle partition number is greater than this value, use 
the sort-based shuffle writer for memory efficiency; otherwise use the 
hash-based shuffle writer for speed. This configuration only takes effect when 
celeborn.client.spark.push.dynamicWriteMode.enabled is true. | 0.5.0 |  | 
 | celeborn.client.spark.push.sort.memory.maxMemoryFactor | 0.4 | false | the 
max portion of executor memory which can be used for SortBasedWriter buffer 
(only valid when celeborn.client.spark.push.sort.memory.useAdaptiveThreshold is 
enabled | 0.5.0 |  | 
diff --git a/pom.xml b/pom.xml
index 31a38cda1..e57cb72a9 100644
--- a/pom.xml
+++ b/pom.xml
@@ -907,7 +907,7 @@
               
<log4j.configuration>file:src/test/resources/log4j.properties</log4j.configuration>
               
<log4j2.configurationFile>src/test/resources/log4j2-test.xml</log4j2.configurationFile>
               <java.io.tmpdir>${project.build.directory}/tmp</java.io.tmpdir>
-              <spark.driver.memory>1g</spark.driver.memory>
+              <spark.driver.memory>8g</spark.driver.memory>
               
<spark.shuffle.sort.io.plugin.class>${spark.shuffle.plugin.class}</spark.shuffle.sort.io.plugin.class>
             </systemProperties>
             <environmentVariables>
@@ -946,7 +946,7 @@
               
<log4j.configuration>file:src/test/resources/log4j.properties</log4j.configuration>
               
<log4j2.configurationFile>src/test/resources/log4j2-test.xml</log4j2.configurationFile>
               <java.io.tmpdir>${project.build.directory}/tmp</java.io.tmpdir>
-              <spark.driver.memory>1g</spark.driver.memory>
+              <spark.driver.memory>8g</spark.driver.memory>
               
<spark.shuffle.sort.io.plugin.class>${spark.shuffle.plugin.class}</spark.shuffle.sort.io.plugin.class>
             </systemProperties>
             <environmentVariables>
diff --git 
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornFetchFailureDiskCleanSuite.scala
 
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornFetchFailureDiskCleanSuite.scala
new file mode 100644
index 000000000..936ea6961
--- /dev/null
+++ 
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornFetchFailureDiskCleanSuite.scala
@@ -0,0 +1,154 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.celeborn.tests.spark
+
+import java.io.File
+
+import org.apache.spark.SparkConf
+import org.apache.spark.shuffle.celeborn.{SparkUtils, 
TestCelebornShuffleManager}
+import org.apache.spark.sql.SparkSession
+import org.scalatest.BeforeAndAfterEach
+import org.scalatest.funsuite.AnyFunSuite
+
+import org.apache.celeborn.client.ShuffleClient
+import org.apache.celeborn.common.protocol.ShuffleMode
+import org.apache.celeborn.service.deploy.worker.Worker
+import org.apache.celeborn.tests.spark.fetch.failure.ShuffleReaderGetHooks
+
+class CelebornFetchFailureDiskCleanSuite extends AnyFunSuite
+  with SparkTestBase
+  with BeforeAndAfterEach {
+
+  override def beforeAll(): Unit = {
+    logInfo("test initialized , setup Celeborn mini cluster")
+    setupMiniClusterWithRandomPorts(workerNum = 1)
+  }
+
+  override def beforeEach(): Unit = {
+    ShuffleClient.reset()
+  }
+
+  override def afterEach(): Unit = {
+    System.gc()
+  }
+
+  override def createWorker(map: Map[String, String]): Worker = {
+    val storageDir = createTmpDir()
+    workerDirs = workerDirs :+ storageDir
+    super.createWorker(map ++ Map("celeborn.master.heartbeat.worker.timeout" 
-> "10s"), storageDir)
+  }
+
+  test("celeborn spark integration test - the failed shuffle file is cleaned 
up correctly") {
+    if (Spark3OrNewer) {
+      val sparkConf = new 
SparkConf().setAppName("rss-demo").setMaster("local[2,3]")
+      val sparkSession = SparkSession.builder()
+        .config(updateSparkConf(sparkConf, ShuffleMode.HASH))
+        .config("spark.sql.shuffle.partitions", 2)
+        .config("spark.celeborn.shuffle.forceFallback.partition.enabled", 
false)
+        .config("spark.celeborn.client.spark.stageRerun.enabled", "true")
+        .config(
+          "spark.shuffle.manager",
+          "org.apache.spark.shuffle.celeborn.TestCelebornShuffleManager")
+        .config("spark.celeborn.client.spark.fetch.cleanFailedShuffle", "true")
+        .getOrCreate()
+
+      val celebornConf = 
SparkUtils.fromSparkConf(sparkSession.sparkContext.getConf)
+      val hook = new ShuffleReaderGetHooks(
+        celebornConf,
+        workerDirs,
+        shuffleIdToBeDeleted = Seq(0))
+      TestCelebornShuffleManager.registerReaderGetHook(hook)
+      val checkingThread =
+        triggerStorageCheckThread(Seq(0), Seq(1), sparkSession)
+      val tuples = sparkSession.sparkContext.parallelize(1 to 10000, 2)
+        .map { i => (i, i) }.groupByKey(4).collect()
+      checkStorageValidation(checkingThread)
+      // verify result
+      assert(hook.executed.get())
+      assert(tuples.length == 10000)
+      for (elem <- tuples) {
+        elem._2.foreach(i => assert(i.equals(elem._1)))
+      }
+      sparkSession.stop()
+    }
+  }
+
+  class CheckingThread(
+      shuffleIdShouldNotExist: Seq[Int],
+      shuffleIdMustExist: Seq[Int],
+      sparkSession: SparkSession)
+    extends Thread {
+    var exception: Exception = _
+
+    protected def checkDirStatus(): Boolean = {
+      val deletedSuccessfully = shuffleIdShouldNotExist.forall(shuffleId => {
+        workerDirs.forall(dir =>
+          !new File(s"$dir/celeborn-worker/shuffle_data/" +
+            s"${sparkSession.sparkContext.applicationId}/$shuffleId").exists())
+      })
+      val deletedSuccessfullyString = shuffleIdShouldNotExist.map(shuffleId => 
{
+        shuffleId.toString + ":" +
+          workerDirs.map(dir =>
+            !new File(s"$dir/celeborn-worker/shuffle_data/" +
+              
s"${sparkSession.sparkContext.applicationId}/$shuffleId").exists()).toList
+      }).mkString(",")
+      val createdSuccessfully = shuffleIdMustExist.forall(shuffleId => {
+        workerDirs.exists(dir =>
+          new File(s"$dir/celeborn-worker/shuffle_data/" +
+            s"${sparkSession.sparkContext.applicationId}/$shuffleId").exists())
+      })
+      val createdSuccessfullyString = shuffleIdMustExist.map(shuffleId => {
+        shuffleId.toString + ":" +
+          workerDirs.map(dir =>
+            new File(s"$dir/celeborn-worker/shuffle_data/" +
+              
s"${sparkSession.sparkContext.applicationId}/$shuffleId").exists()).toList
+      }).mkString(",")
+      println(s"shuffle-to-be-deleted status: $deletedSuccessfullyString \n" +
+        s"shuffle-to-be-created status: $createdSuccessfullyString")
+      deletedSuccessfully && createdSuccessfully
+    }
+
+    override def run(): Unit = {
+      var allDataInShape = checkDirStatus()
+      while (!allDataInShape) {
+        Thread.sleep(1000)
+        allDataInShape = checkDirStatus()
+      }
+    }
+  }
+
+  protected def triggerStorageCheckThread(
+      shuffleIdShouldNotExist: Seq[Int],
+      shuffleIdMustExist: Seq[Int],
+      sparkSession: SparkSession): CheckingThread = {
+    val checkingThread =
+      new CheckingThread(shuffleIdShouldNotExist, shuffleIdMustExist, 
sparkSession)
+    checkingThread.setDaemon(true)
+    checkingThread.start()
+    checkingThread
+  }
+
+  protected def checkStorageValidation(thread: Thread, timeout: Long = 1200 * 
1000): Unit = {
+    val checkingThread = thread.asInstanceOf[CheckingThread]
+    checkingThread.join(timeout)
+    if (checkingThread.isAlive || checkingThread.exception != null) {
+      throw new IllegalStateException("the storage checking status failed," +
+        s"${checkingThread.isAlive} ${if (checkingThread.exception != null) 
checkingThread.exception.getMessage
+        else "NULL"}")
+    }
+  }
+}
diff --git 
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornFetchFailureSuite.scala
 
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornFetchFailureSuite.scala
index dd0f38401..9db3912a7 100644
--- 
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornFetchFailureSuite.scala
+++ 
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornFetchFailureSuite.scala
@@ -30,6 +30,7 @@ import org.scalatest.funsuite.AnyFunSuite
 
 import org.apache.celeborn.client.ShuffleClient
 import org.apache.celeborn.common.protocol.ShuffleMode
+import org.apache.celeborn.tests.spark.fetch.failure.ShuffleReaderGetHooks
 
 class CelebornFetchFailureSuite extends AnyFunSuite
   with SparkTestBase
@@ -57,7 +58,7 @@ class CelebornFetchFailureSuite extends AnyFunSuite
         .getOrCreate()
 
       val celebornConf = 
SparkUtils.fromSparkConf(sparkSession.sparkContext.getConf)
-      val hook = new ShuffleReaderFetchFailureGetHook(celebornConf)
+      val hook = new ShuffleReaderGetHooks(celebornConf, workerDirs)
       TestCelebornShuffleManager.registerReaderGetHook(hook)
 
       val value = Range(1, 10000).mkString(",")
@@ -130,7 +131,7 @@ class CelebornFetchFailureSuite extends AnyFunSuite
         .getOrCreate()
 
       val celebornConf = 
SparkUtils.fromSparkConf(sparkSession.sparkContext.getConf)
-      val hook = new ShuffleReaderFetchFailureGetHook(celebornConf)
+      val hook = new ShuffleReaderGetHooks(celebornConf, workerDirs)
       TestCelebornShuffleManager.registerReaderGetHook(hook)
 
       import sparkSession.implicits._
@@ -161,7 +162,7 @@ class CelebornFetchFailureSuite extends AnyFunSuite
         .getOrCreate()
 
       val celebornConf = 
SparkUtils.fromSparkConf(sparkSession.sparkContext.getConf)
-      val hook = new ShuffleReaderFetchFailureGetHook(celebornConf)
+      val hook = new ShuffleReaderGetHooks(celebornConf, workerDirs)
       TestCelebornShuffleManager.registerReaderGetHook(hook)
 
       val sc = sparkSession.sparkContext
@@ -201,7 +202,7 @@ class CelebornFetchFailureSuite extends AnyFunSuite
         .getOrCreate()
 
       val celebornConf = 
SparkUtils.fromSparkConf(sparkSession.sparkContext.getConf)
-      val hook = new ShuffleReaderFetchFailureGetHook(celebornConf)
+      val hook = new ShuffleReaderGetHooks(celebornConf, workerDirs)
       TestCelebornShuffleManager.registerReaderGetHook(hook)
 
       val sc = sparkSession.sparkContext
diff --git 
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/SparkTestBase.scala
 
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/SparkTestBase.scala
index e29b21a0c..41cbe072b 100644
--- 
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/SparkTestBase.scala
+++ 
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/SparkTestBase.scala
@@ -116,46 +116,4 @@ trait SparkTestBase extends AnyFunSuite
     val outMap = result.collect().map(row => row.getString(0) -> 
row.getLong(1)).toMap
     outMap
   }
-
-  class ShuffleReaderFetchFailureGetHook(conf: CelebornConf) extends 
ShuffleManagerHook {
-    var executed: AtomicBoolean = new AtomicBoolean(false)
-    val lock = new Object
-
-    override def exec(
-        handle: ShuffleHandle,
-        startPartition: Int,
-        endPartition: Int,
-        context: TaskContext): Unit = {
-      if (executed.get() == true) return
-
-      lock.synchronized {
-        handle match {
-          case h: CelebornShuffleHandle[_, _, _] => {
-            val appUniqueId = h.appUniqueId
-            val shuffleClient = ShuffleClient.get(
-              h.appUniqueId,
-              h.lifecycleManagerHost,
-              h.lifecycleManagerPort,
-              conf,
-              h.userIdentifier,
-              h.extension)
-            val celebornShuffleId =
-              SparkUtils.celebornShuffleId(shuffleClient, h, context, false)
-            val allFiles = workerDirs.map(dir => {
-              new 
File(s"$dir/celeborn-worker/shuffle_data/$appUniqueId/$celebornShuffleId")
-            })
-            val datafile = allFiles.filter(_.exists())
-              .flatMap(_.listFiles().iterator).sortBy(_.getName).headOption
-            datafile match {
-              case Some(file) => file.delete()
-              case None => throw new RuntimeException("unexpected, there must 
be some data file" +
-                  s" under ${workerDirs.mkString(",")}")
-            }
-          }
-          case _ => throw new RuntimeException("unexpected, only support 
RssShuffleHandle here")
-        }
-        executed.set(true)
-      }
-    }
-  }
 }
diff --git 
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/fetch/failure/ShuffleReaderGetHooks.scala
 
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/fetch/failure/ShuffleReaderGetHooks.scala
new file mode 100644
index 000000000..adac14242
--- /dev/null
+++ 
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/fetch/failure/ShuffleReaderGetHooks.scala
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.tests.spark.fetch.failure
+
+import java.io.File
+import java.util.concurrent.atomic.AtomicBoolean
+
+import org.apache.spark.TaskContext
+import org.apache.spark.shuffle.ShuffleHandle
+import org.apache.spark.shuffle.celeborn.{CelebornShuffleHandle, 
ShuffleManagerHook, SparkCommonUtils, SparkShuffleManager, SparkUtils, 
TestCelebornShuffleManager}
+
+import org.apache.celeborn.client.ShuffleClient
+import org.apache.celeborn.common.CelebornConf
+
+/**
+ * Shuffle-reader hook used by fetch-failure tests: deletes Celeborn shuffle data files from the
+ * local worker directories so that a subsequent fetch fails.
+ *
+ * Files are deleted at most once per hook instance (guarded by `executed`, re-checked under
+ * `lock`).
+ *
+ * @param conf                 Celeborn configuration used to obtain a [[ShuffleClient]]
+ * @param workerDirs           local worker directories searched for shuffle data
+ * @param shuffleIdToBeDeleted Celeborn shuffle ids to delete one data file from; when empty, the
+ *                             Celeborn shuffle id read by the current task is used instead
+ * @param triggerStageId       when set, only tasks whose encoded stage id equals this value fire
+ *                             the hook; when `None`, the first task that reaches the hook fires
+ */
+class ShuffleReaderGetHooks(
+    conf: CelebornConf,
+    workerDirs: Seq[String],
+    shuffleIdToBeDeleted: Seq[Int] = Seq(),
+    triggerStageId: Option[Int] = None)
+  extends ShuffleManagerHook {
+
+  // Becomes true once files have been deleted; afterwards the hook is a no-op.
+  var executed: AtomicBoolean = new AtomicBoolean(false)
+  val lock = new Object
+
+  /**
+   * Deletes one data file (the first by file name) of `celebornShuffleId` under any worker dir.
+   *
+   * @throws RuntimeException if no data file exists, or the file could not be deleted
+   */
+  private def deleteDataFile(appUniqueId: String, celebornShuffleId: Int): Unit = {
+    val shuffleDirs = workerDirs.map { dir =>
+      new File(s"$dir/celeborn-worker/shuffle_data/$appUniqueId/$celebornShuffleId")
+    }
+    // listFiles() returns null for non-directories or on I/O errors; treat that as empty.
+    // Sorting by name makes the chosen file deterministic across file systems.
+    val dataFile = shuffleDirs
+      .filter(_.exists())
+      .flatMap(dir => Option(dir.listFiles()).map(_.toSeq).getOrElse(Seq.empty[File]))
+      .sortBy(_.getName)
+      .headOption
+    dataFile match {
+      case Some(file) =>
+        if (!file.delete()) {
+          throw new RuntimeException(s"failed to delete data file ${file.getAbsolutePath}")
+        }
+      case None =>
+        throw new RuntimeException("unexpected, there must be some data file" +
+          s" under ${shuffleDirs.map(_.getPath).mkString(",")}")
+    }
+  }
+
+  override def exec(
+      handle: ShuffleHandle,
+      startPartition: Int,
+      endPartition: Int,
+      context: TaskContext): Unit = {
+    // Unsynchronized fast path; the flag is re-checked under the lock below.
+    if (!executed.get()) {
+      lock.synchronized {
+        // Another task may have fired the hook while this one waited for the lock; firing again
+        // would delete a second file or fail on an already-emptied directory.
+        if (!executed.get()) {
+          fireIfTriggered(handle, context)
+        }
+      }
+    }
+  }
+
+  /** Deletes the configured data files when `context` belongs to the trigger stage. */
+  private def fireIfTriggered(handle: ShuffleHandle, context: TaskContext): Unit = {
+    handle match {
+      case h: CelebornShuffleHandle[_, _, _] =>
+        val shuffleClient = ShuffleClient.get(
+          h.appUniqueId,
+          h.lifecycleManagerHost,
+          h.lifecycleManagerPort,
+          conf,
+          h.userIdentifier,
+          h.extension)
+        val celebornShuffleId = SparkUtils.celebornShuffleId(shuffleClient, h, context, false)
+        val appShuffleIdentifier =
+          SparkCommonUtils.encodeAppShuffleIdentifier(handle.shuffleId, context)
+        // Assumes the identifier has exactly three '-'-separated fields with the stage id in the
+        // middle (a MatchError is thrown otherwise) — confirm against SparkCommonUtils.
+        val Array(_, stageId, _) = appShuffleIdentifier.split('-')
+        if (triggerStageId.forall(_ == stageId.toInt)) {
+          if (shuffleIdToBeDeleted.isEmpty) {
+            deleteDataFile(h.appUniqueId, celebornShuffleId)
+          } else {
+            shuffleIdToBeDeleted.foreach(shuffleId => deleteDataFile(h.appUniqueId, shuffleId))
+          }
+          executed.set(true)
+        }
+      case other =>
+        throw new RuntimeException("unexpected, only support CelebornShuffleHandle here," +
+          s" but get ${Option(other).map(_.getClass.getCanonicalName).getOrElse("null")}")
+    }
+  }
+}
diff --git 
a/tests/spark-it/src/test/scala/org/apache/spark/shuffle/celeborn/SparkUtilsSuite.scala
 
b/tests/spark-it/src/test/scala/org/apache/spark/shuffle/celeborn/SparkUtilsSuite.scala
index 83a5e12f6..293ef080e 100644
--- 
a/tests/spark-it/src/test/scala/org/apache/spark/shuffle/celeborn/SparkUtilsSuite.scala
+++ 
b/tests/spark-it/src/test/scala/org/apache/spark/shuffle/celeborn/SparkUtilsSuite.scala
@@ -33,6 +33,7 @@ import 
org.apache.celeborn.common.protocol.{PartitionLocation, ShuffleMode}
 import 
org.apache.celeborn.common.protocol.message.ControlMessages.GetReducerFileGroupResponse
 import org.apache.celeborn.common.protocol.message.StatusCode
 import org.apache.celeborn.tests.spark.SparkTestBase
+import org.apache.celeborn.tests.spark.fetch.failure.ShuffleReaderGetHooks
 
 class SparkUtilsSuite extends AnyFunSuite
   with SparkTestBase
@@ -60,7 +61,7 @@ class SparkUtilsSuite extends AnyFunSuite
         .getOrCreate()
 
       val celebornConf = 
SparkUtils.fromSparkConf(sparkSession.sparkContext.getConf)
-      val hook = new ShuffleReaderFetchFailureGetHook(celebornConf)
+      val hook = new ShuffleReaderGetHooks(celebornConf, workerDirs)
       TestCelebornShuffleManager.registerReaderGetHook(hook)
 
       try {


Reply via email to