This is an automated email from the ASF dual-hosted git repository.
ethanfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/celeborn.git
The following commit(s) were added to refs/heads/main by this push:
new 0b5a09a9f [CELEBORN-1896] delete data from failed to fetch shuffles
0b5a09a9f is described below
commit 0b5a09a9f796fc85c7516a8d3d108ca98c938c0b
Author: CodingCat <[email protected]>
AuthorDate: Wed May 21 11:23:11 2025 +0800
[CELEBORN-1896] delete data from failed to fetch shuffles
### What changes were proposed in this pull request?
it's a joint work with YutingWang98
currently we have to wait for Spark shuffle object GC to clean the disk space
occupied by Celeborn shuffles
As a result, if a shuffle fails to fetch and is retried, the disk space
occupied by the failed attempt cannot really be cleaned. We hit this issue
internally when we have to deal with 100s-of-TB-level shuffles in a single
Spark application; any hiccup in the servers can double or even triple the disk usage
this PR implements the mechanism to delete files from failed-to-fetch
shuffles
the main idea is actually simple: it triggers cleanup in LifecycleManager
when it applies for a new Celeborn shuffle id for a retried shuffle write stage.
The tricky part is to avoid deleting shuffle files when they are still
referred to by multiple downstream stages: the PR introduces RunningStageManager
to track the dependency between stages
### Why are the changes needed?
saving disk space
### Does this PR introduce _any_ user-facing change?
a new config
### How was this patch tested?
we manually delete some files

from the above screenshot we can see that originally we have shuffle 0, 1
and after 1 faced with chunk fetch failure, it triggers a retry of 0 (shuffle
2), but at this moment, 0 has been deleted from the workers

in the logs, we can see that in the middle of the application, the unregister
shuffle request was sent for shuffle 0
Closes #3109 from CodingCat/delete_fi.
Lead-authored-by: CodingCat <[email protected]>
Co-authored-by: Wang, Fei <[email protected]>
Co-authored-by: Fei Wang <[email protected]>
Co-authored-by: Fei Wang <[email protected]>
Signed-off-by: mingji <[email protected]>
---
.../spark/shuffle/celeborn/SparkCommonUtils.java | 8 ++
.../celeborn/spark/FailedShuffleCleaner.scala | 93 +++++++++++++
.../shuffle/celeborn/SparkShuffleManager.java | 27 ++++
.../apache/spark/shuffle/celeborn/SparkUtils.java | 20 ++-
.../apache/celeborn/client/LifecycleManager.scala | 26 +++-
.../commit/ReducePartitionCommitHandler.scala | 2 +-
.../org/apache/celeborn/common/CelebornConf.scala | 20 +++
docs/configuration/client.md | 2 +
pom.xml | 4 +-
.../spark/CelebornFetchFailureDiskCleanSuite.scala | 154 +++++++++++++++++++++
.../tests/spark/CelebornFetchFailureSuite.scala | 9 +-
.../celeborn/tests/spark/SparkTestBase.scala | 42 ------
.../fetch/failure/ShuffleReaderGetHooks.scala | 93 +++++++++++++
.../spark/shuffle/celeborn/SparkUtilsSuite.scala | 3 +-
14 files changed, 444 insertions(+), 59 deletions(-)
diff --git
a/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SparkCommonUtils.java
b/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SparkCommonUtils.java
index a24e06d5a..84d74f8c1 100644
---
a/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SparkCommonUtils.java
+++
b/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SparkCommonUtils.java
@@ -52,6 +52,14 @@ public class SparkCommonUtils {
}
}
+ public static String encodeAppShuffleIdentifier(int appShuffleId,
TaskContext context) {
+ return appShuffleId + "-" + context.stageId() + "-" +
context.stageAttemptNumber();
+ }
+
+ public static String[] decodeAppShuffleIdentifier(String
appShuffleIdentifier) {
+ return appShuffleIdentifier.split("-");
+ }
+
public static int getEncodedAttemptNumber(TaskContext context) {
return (context.stageAttemptNumber() << 16) | context.attemptNumber();
}
diff --git
a/client-spark/common/src/main/scala/org/apache/celeborn/spark/FailedShuffleCleaner.scala
b/client-spark/common/src/main/scala/org/apache/celeborn/spark/FailedShuffleCleaner.scala
new file mode 100644
index 000000000..e88f6f640
--- /dev/null
+++
b/client-spark/common/src/main/scala/org/apache/celeborn/spark/FailedShuffleCleaner.scala
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.celeborn.spark
+
+import java.util
+import java.util.concurrent.{LinkedBlockingQueue, ScheduledExecutorService,
TimeUnit}
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable
+
+import org.apache.spark.shuffle.celeborn.SparkCommonUtils
+
+import org.apache.celeborn.client.LifecycleManager
+import org.apache.celeborn.common.internal.Logging
+import org.apache.celeborn.common.util.ThreadUtils
+
+private[celeborn] class FailedShuffleCleaner(lifecycleManager:
LifecycleManager) extends Logging {
+
+ // in celeborn ids
+ private val shufflesToBeCleaned = new LinkedBlockingQueue[Int]()
+ private val cleanedShuffleIds = new mutable.HashSet[Int]
+
+ private lazy val cleanInterval =
+ lifecycleManager.conf.clientFetchCleanFailedShuffleIntervalMS
+
+ // for test
+ def reset(): Unit = {
+ shufflesToBeCleaned.clear()
+ cleanedShuffleIds.clear()
+ if (cleanerThreadPool != null) {
+ cleanerThreadPool.shutdownNow()
+ cleanerThreadPool = null
+ }
+ }
+
+ def addShuffleIdToBeCleaned(appShuffleIdentifier: String): Unit = {
+ val Array(appShuffleId, _, _) =
SparkCommonUtils.decodeAppShuffleIdentifier(
+ appShuffleIdentifier)
+ lifecycleManager.getShuffleIdMapping.get(appShuffleId.toInt).foreach {
+ case (_, (celebornShuffleId, _)) =>
shufflesToBeCleaned.put(celebornShuffleId)
+ }
+ }
+
+ def init(): Unit = {
+ cleanerThreadPool = ThreadUtils.newDaemonSingleThreadScheduledExecutor(
+ "failedShuffleCleanerThreadPool")
+ cleanerThreadPool.scheduleWithFixedDelay(
+ new Runnable {
+ override def run(): Unit = {
+ try {
+ val allShuffleIds = new util.ArrayList[Int]
+ shufflesToBeCleaned.drainTo(allShuffleIds)
+ allShuffleIds.asScala.foreach { shuffleId =>
+ if (!cleanedShuffleIds.contains(shuffleId)) {
+ lifecycleManager.unregisterShuffle(shuffleId)
+ logInfo(
+ s"sent unregister shuffle request for shuffle $shuffleId
(celeborn shuffle id)")
+ cleanedShuffleIds += shuffleId
+ }
+ }
+ } catch {
+ case e: Exception =>
+ logError("unexpected exception in cleaner thread", e)
+ }
+ }
+ },
+ cleanInterval,
+ cleanInterval,
+ TimeUnit.MILLISECONDS)
+ }
+
+ init()
+
+ def removeCleanedShuffleId(celebornShuffleId: Int): Unit = {
+ cleanedShuffleIds.remove(celebornShuffleId)
+ }
+
+ private var cleanerThreadPool: ScheduledExecutorService = _
+}
diff --git
a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
index 8c099b29f..80ea5c256 100644
---
a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
+++
b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
@@ -37,6 +37,7 @@ import org.apache.celeborn.client.ShuffleClient;
import org.apache.celeborn.common.CelebornConf;
import org.apache.celeborn.common.protocol.ShuffleMode;
import org.apache.celeborn.reflect.DynMethods;
+import org.apache.celeborn.spark.FailedShuffleCleaner;
/**
* In order to support Spark Stage resubmit with ShuffleReader FetchFails,
Celeborn shuffleId has to
@@ -84,6 +85,8 @@ public class SparkShuffleManager implements ShuffleManager {
ConcurrentHashMap.newKeySet();
private final CelebornShuffleFallbackPolicyRunner fallbackPolicyRunner;
+ private FailedShuffleCleaner failedShuffleCleaner = null;
+
private long sendBufferPoolCheckInterval;
private long sendBufferPoolExpireTimeout;
@@ -158,6 +161,23 @@ public class SparkShuffleManager implements ShuffleManager
{
}
}
+ if (lifecycleManager.conf().clientFetchCleanFailedShuffle()) {
+ if (!lifecycleManager.conf().clientStageRerunEnabled()) {
+ throw new IllegalArgumentException(
+ CelebornConf.CLIENT_STAGE_RERUN_ENABLED().key()
+ + " has to be "
+ + "enabled, when "
+ + CelebornConf.CLIENT_FETCH_CLEAN_FAILED_SHUFFLE().key()
+ + " is set to true");
+ }
+ failedShuffleCleaner = new FailedShuffleCleaner(lifecycleManager);
+ lifecycleManager.registerValidateCelebornShuffleIdForCleanCallback(
+ (appShuffleIdentifier) ->
+ SparkUtils.addWriterShuffleIdsToBeCleaned(this,
appShuffleIdentifier));
+ lifecycleManager.registerUnregisterShuffleCallback(
+ (celebornShuffleId) -> SparkUtils.removeCleanedShuffleId(this,
celebornShuffleId));
+ }
+
if (celebornConf.getReducerFileGroupBroadcastEnabled()) {
lifecycleManager.registerBroadcastGetReducerFileGroupResponseCallback(
(shuffleId, getReducerFileGroupResponse) ->
@@ -249,6 +269,9 @@ public class SparkShuffleManager implements ShuffleManager {
_sortShuffleManager.stop();
_sortShuffleManager = null;
}
+ if (celebornConf.clientFetchCleanFailedShuffle()) {
+ failedShuffleCleaner.reset();
+ }
}
@Override
@@ -470,4 +493,8 @@ public class SparkShuffleManager implements ShuffleManager {
public LifecycleManager getLifecycleManager() {
return this.lifecycleManager;
}
+
+ public FailedShuffleCleaner getFailedShuffleCleaner() {
+ return this.failedShuffleCleaner;
+ }
}
diff --git
a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java
b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java
index b2e64565e..fc5d605d8 100644
---
a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java
+++
b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java
@@ -128,17 +128,14 @@ public class SparkUtils {
.getOrElse(context::applicationId);
}
- public static String getAppShuffleIdentifier(int appShuffleId, TaskContext
context) {
- return appShuffleId + "-" + context.stageId() + "-" +
context.stageAttemptNumber();
- }
-
public static int celebornShuffleId(
ShuffleClient client,
CelebornShuffleHandle<?, ?, ?> handle,
TaskContext context,
Boolean isWriter) {
if (handle.throwsFetchFailure()) {
- String appShuffleIdentifier =
getAppShuffleIdentifier(handle.shuffleId(), context);
+ String appShuffleIdentifier =
+ SparkCommonUtils.encodeAppShuffleIdentifier(handle.shuffleId(),
context);
Tuple2<Integer, Boolean> res =
client.getShuffleId(
handle.shuffleId(),
@@ -327,7 +324,8 @@ public class SparkUtils {
if (!(taskContext instanceof BarrierTaskContext)) return;
int appShuffleId = handle.shuffleId();
- String appShuffleIdentifier =
SparkUtils.getAppShuffleIdentifier(appShuffleId, taskContext);
+ String appShuffleIdentifier =
+ SparkCommonUtils.encodeAppShuffleIdentifier(appShuffleId, taskContext);
BarrierTaskContext barrierContext = (BarrierTaskContext) taskContext;
barrierContext.addTaskFailureListener(
@@ -625,4 +623,14 @@ public class SparkUtils {
return null;
});
}
+
+ public static void addWriterShuffleIdsToBeCleaned(
+ SparkShuffleManager sparkShuffleManager, String appShuffleIdentifier) {
+
sparkShuffleManager.getFailedShuffleCleaner().addShuffleIdToBeCleaned(appShuffleIdentifier);
+ }
+
+ public static void removeCleanedShuffleId(
+ SparkShuffleManager sparkShuffleManager, int celebornShuffleId) {
+
sparkShuffleManager.getFailedShuffleCleaner().removeCleanedShuffleId(celebornShuffleId);
+ }
}
diff --git
a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
index 69920ee84..6bc12b8e8 100644
--- a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
+++ b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
@@ -940,6 +940,7 @@ class LifecycleManager(val appUniqueId: String, val conf:
CelebornConf) extends
logInfo(s"reuse existing shuffleId $id for appShuffleId
$appShuffleId appShuffleIdentifier $appShuffleIdentifier")
id
} else {
+ // this branch means it is a redo of previous write stage
if (isBarrierStage) {
// unregister previous shuffle(s) which are still valid
val mapUpdates = shuffleIds.filter(_._2._2).map { kv =>
@@ -950,6 +951,8 @@ class LifecycleManager(val appUniqueId: String, val conf:
CelebornConf) extends
}
val newShuffleId = shuffleIdGenerator.getAndIncrement()
logInfo(s"generate new shuffleId $newShuffleId for
appShuffleId $appShuffleId appShuffleIdentifier $appShuffleIdentifier")
+ validateCelebornShuffleIdForClean.foreach(callback =>
+ callback.accept(appShuffleIdentifier))
shuffleIds.put(appShuffleIdentifier, (newShuffleId, true))
newShuffleId
}
@@ -963,11 +966,12 @@ class LifecycleManager(val appUniqueId: String, val conf:
CelebornConf) extends
} else {
shuffleIds.values.filter(v => v._2).map(v => v._1).toSeq.reverse.find(
areAllMapTasksEnd) match {
- case Some(shuffleId) =>
+ case Some(celebornShuffleId) =>
val pbGetShuffleIdResponse = {
logDebug(
- s"get shuffleId $shuffleId for appShuffleId $appShuffleId
appShuffleIdentifier $appShuffleIdentifier isWriter $isWriter")
-
PbGetShuffleIdResponse.newBuilder().setShuffleId(shuffleId).setSuccess(true).build()
+ s"get shuffleId $celebornShuffleId for appShuffleId
$appShuffleId appShuffleIdentifier $appShuffleIdentifier isWriter $isWriter")
+
PbGetShuffleIdResponse.newBuilder().setShuffleId(celebornShuffleId).setSuccess(
+ true).build()
}
context.reply(pbGetShuffleIdResponse)
case None =>
@@ -1169,6 +1173,7 @@ class LifecycleManager(val appUniqueId: String, val conf:
CelebornConf) extends
shuffleIds.values.map {
case (shuffleId, _) =>
unregisterShuffle(shuffleId)
+ unregisterShuffleCallback.foreach(c => c.accept(shuffleId))
})
}
} else {
@@ -1859,6 +1864,19 @@ class LifecycleManager(val appUniqueId: String, val
conf: CelebornConf) extends
appShuffleTrackerCallback = Some(callback)
}
+ // expecting celeborn shuffle id and application shuffle identifier
+ @volatile private var validateCelebornShuffleIdForClean:
Option[Consumer[String]] =
+ None
+ def registerValidateCelebornShuffleIdForCleanCallback(
+ callback: Consumer[String]): Unit = {
+ validateCelebornShuffleIdForClean = Some(callback)
+ }
+
+ @volatile private var unregisterShuffleCallback: Option[Consumer[Integer]] =
None
+ def registerUnregisterShuffleCallback(callback: Consumer[Integer]): Unit = {
+ unregisterShuffleCallback = Some(callback)
+ }
+
def registerAppShuffleDeterminate(appShuffleId: Int, determinate: Boolean):
Unit = {
appShuffleDeterminateMap.put(appShuffleId, determinate)
}
@@ -1952,4 +1970,6 @@ class LifecycleManager(val appUniqueId: String, val conf:
CelebornConf) extends
}
})
}
+
+ def getShuffleIdMapping = shuffleIdMapping
}
diff --git
a/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala
b/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala
index 944e07657..5fa4394ad 100644
---
a/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala
+++
b/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala
@@ -70,7 +70,7 @@ class ReducePartitionCommitHandler(
private val getReducerFileGroupRequest =
JavaUtils.newConcurrentHashMap[Int,
util.Set[MultiSerdeVersionRpcContext]]()
- private val dataLostShuffleSet = ConcurrentHashMap.newKeySet[Int]()
+ private[celeborn] val dataLostShuffleSet = ConcurrentHashMap.newKeySet[Int]()
private val stageEndShuffleSet = ConcurrentHashMap.newKeySet[Int]()
private val inProcessStageEndShuffleSet = ConcurrentHashMap.newKeySet[Int]()
private val shuffleMapperAttempts = JavaUtils.newConcurrentHashMap[Int,
Array[Int]]()
diff --git
a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
index 8dbccaa1f..d3a9977ef 100644
--- a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
+++ b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
@@ -1001,6 +1001,9 @@ class CelebornConf(loadDefaults: Boolean) extends
Cloneable with Logging with Se
def clientFetchMaxRetriesForEachReplica: Int =
get(CLIENT_FETCH_MAX_RETRIES_FOR_EACH_REPLICA)
def clientStageRerunEnabled: Boolean = get(CLIENT_STAGE_RERUN_ENABLED)
+ def clientFetchCleanFailedShuffle: Boolean =
get(CLIENT_FETCH_CLEAN_FAILED_SHUFFLE)
+ def clientFetchCleanFailedShuffleIntervalMS: Long =
+ get(CLIENT_FETCH_CLEAN_FAILED_SHUFFLE_INTERVAL)
def clientFetchExcludeWorkerOnFailureEnabled: Boolean =
get(CLIENT_FETCH_EXCLUDE_WORKER_ON_FAILURE_ENABLED)
def clientFetchExcludedWorkerExpireTimeout: Long =
@@ -4827,6 +4830,23 @@ object CelebornConf extends Logging {
.booleanConf
.createWithDefault(true)
+ val CLIENT_FETCH_CLEAN_FAILED_SHUFFLE: ConfigEntry[Boolean] =
+ buildConf("celeborn.client.spark.fetch.cleanFailedShuffle")
+ .categories("client")
+ .version("0.6.0")
+ .doc("whether to clean those disk space occupied by shuffles which
cannot be fetched")
+ .booleanConf
+ .createWithDefault(false)
+
+ val CLIENT_FETCH_CLEAN_FAILED_SHUFFLE_INTERVAL: ConfigEntry[Long] =
+ buildConf("celeborn.client.spark.fetch.cleanFailedShuffleInterval")
+ .categories("client")
+ .version("0.6.0")
+ .doc("the interval to clean the failed-to-fetch shuffle files, only
valid when" +
+ s" ${CLIENT_FETCH_CLEAN_FAILED_SHUFFLE.key} is enabled")
+ .timeConf(TimeUnit.MILLISECONDS)
+ .createWithDefaultString("1s")
+
val CLIENT_FETCH_EXCLUDE_WORKER_ON_FAILURE_ENABLED: ConfigEntry[Boolean] =
buildConf("celeborn.client.fetch.excludeWorkerOnFailure.enabled")
.categories("client")
diff --git a/docs/configuration/client.md b/docs/configuration/client.md
index 6c0ff752d..e4e8e0e83 100644
--- a/docs/configuration/client.md
+++ b/docs/configuration/client.md
@@ -111,6 +111,8 @@ license: |
| celeborn.client.shuffle.register.filterExcludedWorker.enabled | false |
false | Whether to filter excluded worker when register shuffle. | 0.4.0 | |
| celeborn.client.shuffle.reviseLostShuffles.enabled | false | false | Whether
to revise lost shuffles. | 0.6.0 | |
| celeborn.client.slot.assign.maxWorkers | 10000 | false | Max workers that
slots of one shuffle can be allocated on. Will choose the smaller positive one
from Master side and Client side, see `celeborn.master.slot.assign.maxWorkers`.
| 0.3.1 | |
+| celeborn.client.spark.fetch.cleanFailedShuffle | false | false | whether to
clean those disk space occupied by shuffles which cannot be fetched | 0.6.0 |
|
+| celeborn.client.spark.fetch.cleanFailedShuffleInterval | 1s | false | the
interval to clean the failed-to-fetch shuffle files, only valid when
celeborn.client.spark.fetch.cleanFailedShuffle is enabled | 0.6.0 | |
| celeborn.client.spark.push.dynamicWriteMode.enabled | false | false |
Whether to dynamically switch push write mode based on conditions.If true,
shuffle mode will be only determined by partition count | 0.5.0 | |
| celeborn.client.spark.push.dynamicWriteMode.partitionNum.threshold | 2000 |
false | Threshold of shuffle partition number for dynamically switching push
writer mode. When the shuffle partition number is greater than this value, use
the sort-based shuffle writer for memory efficiency; otherwise use the
hash-based shuffle writer for speed. This configuration only takes effect when
celeborn.client.spark.push.dynamicWriteMode.enabled is true. | 0.5.0 | |
| celeborn.client.spark.push.sort.memory.maxMemoryFactor | 0.4 | false | the
max portion of executor memory which can be used for SortBasedWriter buffer
(only valid when celeborn.client.spark.push.sort.memory.useAdaptiveThreshold is
enabled | 0.5.0 | |
diff --git a/pom.xml b/pom.xml
index 31a38cda1..e57cb72a9 100644
--- a/pom.xml
+++ b/pom.xml
@@ -907,7 +907,7 @@
<log4j.configuration>file:src/test/resources/log4j.properties</log4j.configuration>
<log4j2.configurationFile>src/test/resources/log4j2-test.xml</log4j2.configurationFile>
<java.io.tmpdir>${project.build.directory}/tmp</java.io.tmpdir>
- <spark.driver.memory>1g</spark.driver.memory>
+ <spark.driver.memory>8g</spark.driver.memory>
<spark.shuffle.sort.io.plugin.class>${spark.shuffle.plugin.class}</spark.shuffle.sort.io.plugin.class>
</systemProperties>
<environmentVariables>
@@ -946,7 +946,7 @@
<log4j.configuration>file:src/test/resources/log4j.properties</log4j.configuration>
<log4j2.configurationFile>src/test/resources/log4j2-test.xml</log4j2.configurationFile>
<java.io.tmpdir>${project.build.directory}/tmp</java.io.tmpdir>
- <spark.driver.memory>1g</spark.driver.memory>
+ <spark.driver.memory>8g</spark.driver.memory>
<spark.shuffle.sort.io.plugin.class>${spark.shuffle.plugin.class}</spark.shuffle.sort.io.plugin.class>
</systemProperties>
<environmentVariables>
diff --git
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornFetchFailureDiskCleanSuite.scala
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornFetchFailureDiskCleanSuite.scala
new file mode 100644
index 000000000..936ea6961
--- /dev/null
+++
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornFetchFailureDiskCleanSuite.scala
@@ -0,0 +1,154 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.celeborn.tests.spark
+
+import java.io.File
+
+import org.apache.spark.SparkConf
+import org.apache.spark.shuffle.celeborn.{SparkUtils,
TestCelebornShuffleManager}
+import org.apache.spark.sql.SparkSession
+import org.scalatest.BeforeAndAfterEach
+import org.scalatest.funsuite.AnyFunSuite
+
+import org.apache.celeborn.client.ShuffleClient
+import org.apache.celeborn.common.protocol.ShuffleMode
+import org.apache.celeborn.service.deploy.worker.Worker
+import org.apache.celeborn.tests.spark.fetch.failure.ShuffleReaderGetHooks
+
+class CelebornFetchFailureDiskCleanSuite extends AnyFunSuite
+ with SparkTestBase
+ with BeforeAndAfterEach {
+
+ override def beforeAll(): Unit = {
+ logInfo("test initialized , setup Celeborn mini cluster")
+ setupMiniClusterWithRandomPorts(workerNum = 1)
+ }
+
+ override def beforeEach(): Unit = {
+ ShuffleClient.reset()
+ }
+
+ override def afterEach(): Unit = {
+ System.gc()
+ }
+
+ override def createWorker(map: Map[String, String]): Worker = {
+ val storageDir = createTmpDir()
+ workerDirs = workerDirs :+ storageDir
+ super.createWorker(map ++ Map("celeborn.master.heartbeat.worker.timeout"
-> "10s"), storageDir)
+ }
+
+ test("celeborn spark integration test - the failed shuffle file is cleaned
up correctly") {
+ if (Spark3OrNewer) {
+ val sparkConf = new
SparkConf().setAppName("rss-demo").setMaster("local[2,3]")
+ val sparkSession = SparkSession.builder()
+ .config(updateSparkConf(sparkConf, ShuffleMode.HASH))
+ .config("spark.sql.shuffle.partitions", 2)
+ .config("spark.celeborn.shuffle.forceFallback.partition.enabled",
false)
+ .config("spark.celeborn.client.spark.stageRerun.enabled", "true")
+ .config(
+ "spark.shuffle.manager",
+ "org.apache.spark.shuffle.celeborn.TestCelebornShuffleManager")
+ .config("spark.celeborn.client.spark.fetch.cleanFailedShuffle", "true")
+ .getOrCreate()
+
+ val celebornConf =
SparkUtils.fromSparkConf(sparkSession.sparkContext.getConf)
+ val hook = new ShuffleReaderGetHooks(
+ celebornConf,
+ workerDirs,
+ shuffleIdToBeDeleted = Seq(0))
+ TestCelebornShuffleManager.registerReaderGetHook(hook)
+ val checkingThread =
+ triggerStorageCheckThread(Seq(0), Seq(1), sparkSession)
+ val tuples = sparkSession.sparkContext.parallelize(1 to 10000, 2)
+ .map { i => (i, i) }.groupByKey(4).collect()
+ checkStorageValidation(checkingThread)
+ // verify result
+ assert(hook.executed.get())
+ assert(tuples.length == 10000)
+ for (elem <- tuples) {
+ elem._2.foreach(i => assert(i.equals(elem._1)))
+ }
+ sparkSession.stop()
+ }
+ }
+
+ class CheckingThread(
+ shuffleIdShouldNotExist: Seq[Int],
+ shuffleIdMustExist: Seq[Int],
+ sparkSession: SparkSession)
+ extends Thread {
+ var exception: Exception = _
+
+ protected def checkDirStatus(): Boolean = {
+ val deletedSuccessfully = shuffleIdShouldNotExist.forall(shuffleId => {
+ workerDirs.forall(dir =>
+ !new File(s"$dir/celeborn-worker/shuffle_data/" +
+ s"${sparkSession.sparkContext.applicationId}/$shuffleId").exists())
+ })
+ val deletedSuccessfullyString = shuffleIdShouldNotExist.map(shuffleId =>
{
+ shuffleId.toString + ":" +
+ workerDirs.map(dir =>
+ !new File(s"$dir/celeborn-worker/shuffle_data/" +
+
s"${sparkSession.sparkContext.applicationId}/$shuffleId").exists()).toList
+ }).mkString(",")
+ val createdSuccessfully = shuffleIdMustExist.forall(shuffleId => {
+ workerDirs.exists(dir =>
+ new File(s"$dir/celeborn-worker/shuffle_data/" +
+ s"${sparkSession.sparkContext.applicationId}/$shuffleId").exists())
+ })
+ val createdSuccessfullyString = shuffleIdMustExist.map(shuffleId => {
+ shuffleId.toString + ":" +
+ workerDirs.map(dir =>
+ new File(s"$dir/celeborn-worker/shuffle_data/" +
+
s"${sparkSession.sparkContext.applicationId}/$shuffleId").exists()).toList
+ }).mkString(",")
+ println(s"shuffle-to-be-deleted status: $deletedSuccessfullyString \n" +
+ s"shuffle-to-be-created status: $createdSuccessfullyString")
+ deletedSuccessfully && createdSuccessfully
+ }
+
+ override def run(): Unit = {
+ var allDataInShape = checkDirStatus()
+ while (!allDataInShape) {
+ Thread.sleep(1000)
+ allDataInShape = checkDirStatus()
+ }
+ }
+ }
+
+ protected def triggerStorageCheckThread(
+ shuffleIdShouldNotExist: Seq[Int],
+ shuffleIdMustExist: Seq[Int],
+ sparkSession: SparkSession): CheckingThread = {
+ val checkingThread =
+ new CheckingThread(shuffleIdShouldNotExist, shuffleIdMustExist,
sparkSession)
+ checkingThread.setDaemon(true)
+ checkingThread.start()
+ checkingThread
+ }
+
+ protected def checkStorageValidation(thread: Thread, timeout: Long = 1200 *
1000): Unit = {
+ val checkingThread = thread.asInstanceOf[CheckingThread]
+ checkingThread.join(timeout)
+ if (checkingThread.isAlive || checkingThread.exception != null) {
+ throw new IllegalStateException("the storage checking status failed," +
+ s"${checkingThread.isAlive} ${if (checkingThread.exception != null)
checkingThread.exception.getMessage
+ else "NULL"}")
+ }
+ }
+}
diff --git
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornFetchFailureSuite.scala
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornFetchFailureSuite.scala
index dd0f38401..9db3912a7 100644
---
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornFetchFailureSuite.scala
+++
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornFetchFailureSuite.scala
@@ -30,6 +30,7 @@ import org.scalatest.funsuite.AnyFunSuite
import org.apache.celeborn.client.ShuffleClient
import org.apache.celeborn.common.protocol.ShuffleMode
+import org.apache.celeborn.tests.spark.fetch.failure.ShuffleReaderGetHooks
class CelebornFetchFailureSuite extends AnyFunSuite
with SparkTestBase
@@ -57,7 +58,7 @@ class CelebornFetchFailureSuite extends AnyFunSuite
.getOrCreate()
val celebornConf =
SparkUtils.fromSparkConf(sparkSession.sparkContext.getConf)
- val hook = new ShuffleReaderFetchFailureGetHook(celebornConf)
+ val hook = new ShuffleReaderGetHooks(celebornConf, workerDirs)
TestCelebornShuffleManager.registerReaderGetHook(hook)
val value = Range(1, 10000).mkString(",")
@@ -130,7 +131,7 @@ class CelebornFetchFailureSuite extends AnyFunSuite
.getOrCreate()
val celebornConf =
SparkUtils.fromSparkConf(sparkSession.sparkContext.getConf)
- val hook = new ShuffleReaderFetchFailureGetHook(celebornConf)
+ val hook = new ShuffleReaderGetHooks(celebornConf, workerDirs)
TestCelebornShuffleManager.registerReaderGetHook(hook)
import sparkSession.implicits._
@@ -161,7 +162,7 @@ class CelebornFetchFailureSuite extends AnyFunSuite
.getOrCreate()
val celebornConf =
SparkUtils.fromSparkConf(sparkSession.sparkContext.getConf)
- val hook = new ShuffleReaderFetchFailureGetHook(celebornConf)
+ val hook = new ShuffleReaderGetHooks(celebornConf, workerDirs)
TestCelebornShuffleManager.registerReaderGetHook(hook)
val sc = sparkSession.sparkContext
@@ -201,7 +202,7 @@ class CelebornFetchFailureSuite extends AnyFunSuite
.getOrCreate()
val celebornConf =
SparkUtils.fromSparkConf(sparkSession.sparkContext.getConf)
- val hook = new ShuffleReaderFetchFailureGetHook(celebornConf)
+ val hook = new ShuffleReaderGetHooks(celebornConf, workerDirs)
TestCelebornShuffleManager.registerReaderGetHook(hook)
val sc = sparkSession.sparkContext
diff --git
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/SparkTestBase.scala
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/SparkTestBase.scala
index e29b21a0c..41cbe072b 100644
---
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/SparkTestBase.scala
+++
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/SparkTestBase.scala
@@ -116,46 +116,4 @@ trait SparkTestBase extends AnyFunSuite
val outMap = result.collect().map(row => row.getString(0) ->
row.getLong(1)).toMap
outMap
}
-
- class ShuffleReaderFetchFailureGetHook(conf: CelebornConf) extends
ShuffleManagerHook {
- var executed: AtomicBoolean = new AtomicBoolean(false)
- val lock = new Object
-
- override def exec(
- handle: ShuffleHandle,
- startPartition: Int,
- endPartition: Int,
- context: TaskContext): Unit = {
- if (executed.get() == true) return
-
- lock.synchronized {
- handle match {
- case h: CelebornShuffleHandle[_, _, _] => {
- val appUniqueId = h.appUniqueId
- val shuffleClient = ShuffleClient.get(
- h.appUniqueId,
- h.lifecycleManagerHost,
- h.lifecycleManagerPort,
- conf,
- h.userIdentifier,
- h.extension)
- val celebornShuffleId =
- SparkUtils.celebornShuffleId(shuffleClient, h, context, false)
- val allFiles = workerDirs.map(dir => {
- new
File(s"$dir/celeborn-worker/shuffle_data/$appUniqueId/$celebornShuffleId")
- })
- val datafile = allFiles.filter(_.exists())
- .flatMap(_.listFiles().iterator).sortBy(_.getName).headOption
- datafile match {
- case Some(file) => file.delete()
- case None => throw new RuntimeException("unexpected, there must
be some data file" +
- s" under ${workerDirs.mkString(",")}")
- }
- }
- case _ => throw new RuntimeException("unexpected, only support
RssShuffleHandle here")
- }
- executed.set(true)
- }
- }
- }
}
diff --git
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/fetch/failure/ShuffleReaderGetHooks.scala
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/fetch/failure/ShuffleReaderGetHooks.scala
new file mode 100644
index 000000000..adac14242
--- /dev/null
+++
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/fetch/failure/ShuffleReaderGetHooks.scala
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.tests.spark.fetch.failure
+
+import java.io.File
+import java.util.concurrent.atomic.AtomicBoolean
+
+import org.apache.spark.TaskContext
+import org.apache.spark.shuffle.ShuffleHandle
+import org.apache.spark.shuffle.celeborn.{CelebornShuffleHandle,
ShuffleManagerHook, SparkCommonUtils, SparkShuffleManager, SparkUtils,
TestCelebornShuffleManager}
+
+import org.apache.celeborn.client.ShuffleClient
+import org.apache.celeborn.common.CelebornConf
+
/**
 * Test hook installed via [[TestCelebornShuffleManager.registerReaderGetHook]] that simulates a
 * chunk-fetch failure by deleting shuffle data files directly from the worker storage
 * directories. It fires at most once per instance.
 *
 * @param conf                 Celeborn configuration used to obtain the [[ShuffleClient]].
 * @param workerDirs           root storage directories of the test workers to scan for data files.
 * @param shuffleIdToBeDeleted Celeborn shuffle ids whose data should be deleted; when empty, the
 *                             shuffle id of the handle currently being read is deleted instead.
 * @param triggerStageId       if set, only fire when the reading task belongs to this stage.
 */
class ShuffleReaderGetHooks(
    conf: CelebornConf,
    workerDirs: Seq[String],
    shuffleIdToBeDeleted: Seq[Int] = Seq(),
    triggerStageId: Option[Int] = None)
  extends ShuffleManagerHook {

  // Flips to true once the deletion has been performed; read as a fast path
  // outside the lock and re-checked under it (see exec).
  var executed: AtomicBoolean = new AtomicBoolean(false)
  val lock = new Object

  /**
   * Deletes one data file of the given Celeborn shuffle from the worker directories.
   *
   * @throws RuntimeException when no data file exists under any worker directory.
   */
  private def deleteDataFile(appUniqueId: String, celebornShuffleId: Int): Unit = {
    val dataFile = workerDirs
      .map(dir => new File(s"$dir/celeborn-worker/shuffle_data/$appUniqueId/$celebornShuffleId"))
      .filter(_.exists())
      // listFiles() returns null on I/O error or when the path is not a
      // directory; guard with Option to avoid an NPE in that case.
      .flatMap(dir => Option(dir.listFiles()).map(_.toSeq).getOrElse(Seq.empty))
      .headOption
    dataFile match {
      case Some(file) => file.delete()
      case None => throw new RuntimeException("unexpected, there must be some data file" +
          s" under ${workerDirs.mkString(",")}")
    }
  }

  override def exec(
      handle: ShuffleHandle,
      startPartition: Int,
      endPartition: Int,
      context: TaskContext): Unit = {
    // Cheap unsynchronized fast path; the authoritative check happens under the lock.
    if (executed.get()) {
      return
    }
    lock.synchronized {
      // Re-check under the lock: two tasks may both pass the fast-path check
      // above, but only the first one may delete files.
      if (!executed.get()) {
        handle match {
          case h: CelebornShuffleHandle[_, _, _] =>
            val appUniqueId = h.appUniqueId
            val shuffleClient = ShuffleClient.get(
              h.appUniqueId,
              h.lifecycleManagerHost,
              h.lifecycleManagerPort,
              conf,
              h.userIdentifier,
              h.extension)
            val celebornShuffleId = SparkUtils.celebornShuffleId(shuffleClient, h, context, false)
            val appShuffleIdentifier =
              SparkCommonUtils.encodeAppShuffleIdentifier(handle.shuffleId, context)
            // assumes the identifier encodes "<appShuffleId>-<stageId>-<attemptId>" — the
            // middle token is the stage id; TODO confirm against SparkCommonUtils.
            val Array(_, stageId, _) = appShuffleIdentifier.split('-')
            if (triggerStageId.forall(_ == stageId.toInt)) {
              if (shuffleIdToBeDeleted.isEmpty) {
                // No explicit targets: delete data of the shuffle currently being read.
                deleteDataFile(appUniqueId, celebornShuffleId)
              } else {
                shuffleIdToBeDeleted.foreach(deleteDataFile(appUniqueId, _))
              }
              executed.set(true)
            }
          case x => throw new RuntimeException(s"unexpected, only support CelebornShuffleHandle" +
              s" here, but get ${x.getClass.getCanonicalName}")
        }
      }
    }
  }
}
diff --git
a/tests/spark-it/src/test/scala/org/apache/spark/shuffle/celeborn/SparkUtilsSuite.scala
b/tests/spark-it/src/test/scala/org/apache/spark/shuffle/celeborn/SparkUtilsSuite.scala
index 83a5e12f6..293ef080e 100644
---
a/tests/spark-it/src/test/scala/org/apache/spark/shuffle/celeborn/SparkUtilsSuite.scala
+++
b/tests/spark-it/src/test/scala/org/apache/spark/shuffle/celeborn/SparkUtilsSuite.scala
@@ -33,6 +33,7 @@ import
org.apache.celeborn.common.protocol.{PartitionLocation, ShuffleMode}
import
org.apache.celeborn.common.protocol.message.ControlMessages.GetReducerFileGroupResponse
import org.apache.celeborn.common.protocol.message.StatusCode
import org.apache.celeborn.tests.spark.SparkTestBase
+import org.apache.celeborn.tests.spark.fetch.failure.ShuffleReaderGetHooks
class SparkUtilsSuite extends AnyFunSuite
with SparkTestBase
@@ -60,7 +61,7 @@ class SparkUtilsSuite extends AnyFunSuite
.getOrCreate()
val celebornConf =
SparkUtils.fromSparkConf(sparkSession.sparkContext.getConf)
- val hook = new ShuffleReaderFetchFailureGetHook(celebornConf)
+ val hook = new ShuffleReaderGetHooks(celebornConf, workerDirs)
TestCelebornShuffleManager.registerReaderGetHook(hook)
try {