This is an automated email from the ASF dual-hosted git repository.
zhouky pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-celeborn.git
The following commit(s) were added to refs/heads/main by this push:
new aee41555c [CELEBORN-955] Re-run Spark Stage for Celeborn Shuffle Fetch
Failure
aee41555c is described below
commit aee41555c658a35b2e3f837d741d2bfd94573f2e
Author: Erik.fang <[email protected]>
AuthorDate: Sun Nov 26 16:47:58 2023 +0800
[CELEBORN-955] Re-run Spark Stage for Celeborn Shuffle Fetch Failure
### What changes were proposed in this pull request?
Currently, Celeborn uses replication to handle shuffle data loss for the
Celeborn shuffle reader; this PR implements an alternative solution via Spark
stage resubmission.
Design doc:
https://docs.google.com/document/d/1dkG6fww3g99VAb1wkphNlUES_MpngVPNg8601chmVp8/edit
### Why are the changes needed?
Spark stage resubmission uses fewer resources compared with replication, and
some Celeborn users have also been asking for it.
### Does this PR introduce _any_ user-facing change?
A new config, celeborn.client.fetch.throwsFetchFailure, is introduced to
enable this feature.
### How was this patch tested?
Two UTs are attached, and we also tested it in Ant Group's dev Spark cluster.
Closes #1924 from
ErikFang/Re-run-Spark-Stage-for-Celeborn-Shuffle-Fetch-Failure.
Lead-authored-by: Erik.fang <[email protected]>
Co-authored-by: Cheng Pan <[email protected]>
Signed-off-by: zky.zhoukeyong <[email protected]>
---
.../celeborn/ExecutorShuffleIdTracker.java} | 35 ++--
.../shuffle/celeborn/HashBasedShuffleWriter.java | 3 +-
.../shuffle/celeborn/SortBasedShuffleWriter.java | 3 +-
.../shuffle/celeborn/SparkShuffleManager.java | 38 +++-
.../apache/spark/shuffle/celeborn/SparkUtils.java | 21 ++
.../shuffle/celeborn/CelebornShuffleHandle.scala | 1 +
.../shuffle/celeborn/CelebornShuffleReader.scala | 56 +++++-
.../celeborn/CelebornShuffleWriterSuiteBase.java | 2 +-
.../celeborn/HashBasedShuffleWriterSuiteJ.java | 8 +-
.../shuffle/celeborn/ShuffleManagerHook.java} | 28 +--
.../celeborn/SortBasedShuffleWriterSuiteJ.java | 1 +
...SuiteJ.java => TestCelebornShuffleManager.java} | 31 +--
.../celeborn/ColumnarHashBasedShuffleWriter.java | 3 +-
.../celeborn/CelebornColumnarShuffleReader.scala | 6 +-
.../ColumnarHashBasedShuffleWriterSuiteJ.java | 20 +-
.../CelebornColumnarShuffleReaderSuite.scala | 8 +-
.../shuffle/celeborn/HashBasedShuffleWriter.java | 3 +-
.../shuffle/celeborn/SortBasedShuffleWriter.java | 4 +-
.../shuffle/celeborn/SparkShuffleManager.java | 50 ++++-
.../apache/spark/shuffle/celeborn/SparkUtils.java | 61 +++++-
.../shuffle/celeborn/CelebornShuffleHandle.scala | 1 +
.../shuffle/celeborn/CelebornShuffleReader.scala | 58 +++++-
.../celeborn/CelebornShuffleWriterSuiteBase.java | 2 +-
.../celeborn/HashBasedShuffleWriterSuiteJ.java | 8 +-
.../shuffle/celeborn/ShuffleManagerHook.java} | 28 +--
.../celeborn/TestCelebornShuffleManager.java | 67 +++++++
.../org/apache/celeborn/client/ShuffleClient.java | 9 +
.../apache/celeborn/client/ShuffleClientImpl.java | 38 ++++
.../apache/celeborn/client/LifecycleManager.scala | 154 ++++++++++++++
.../apache/celeborn/client/DummyShuffleClient.java | 10 +
.../common/network/protocol/TransportMessage.java | 12 ++
common/src/main/proto/TransportMessages.proto | 23 +++
.../org/apache/celeborn/common/CelebornConf.scala | 9 +
.../common/protocol/message/ControlMessages.scala | 24 +++
docs/configuration/client.md | 1 +
tests/spark-it/pom.xml | 49 +++++
.../tests/spark/CelebornFetchFailureSuite.scala | 223 +++++++++++++++++++++
.../service/deploy/MiniClusterFeature.scala | 12 +-
38 files changed, 999 insertions(+), 111 deletions(-)
diff --git
a/client-spark/spark-2/src/test/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriterSuiteJ.java
b/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/ExecutorShuffleIdTracker.java
similarity index 51%
copy from
client-spark/spark-2/src/test/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriterSuiteJ.java
copy to
client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/ExecutorShuffleIdTracker.java
index 475efee24..254eb516d 100644
---
a/client-spark/spark-2/src/test/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriterSuiteJ.java
+++
b/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/ExecutorShuffleIdTracker.java
@@ -17,24 +17,31 @@
package org.apache.spark.shuffle.celeborn;
-import java.io.IOException;
-
-import org.apache.spark.TaskContext;
-import org.apache.spark.shuffle.ShuffleWriter;
+import java.util.HashSet;
+import java.util.concurrent.ConcurrentHashMap;
import org.apache.celeborn.client.ShuffleClient;
-import org.apache.celeborn.common.CelebornConf;
+import org.apache.celeborn.common.util.JavaUtils;
+
+public class ExecutorShuffleIdTracker {
+ // track appShuffleId -> shuffleId Set in executor for cleanup
+ private ConcurrentHashMap<Integer, HashSet<Integer>> shuffleIdMap =
+ JavaUtils.newConcurrentHashMap();
-public class HashBasedShuffleWriterSuiteJ extends
CelebornShuffleWriterSuiteBase {
+ public void track(int appShuffleId, int shuffleId) {
+ HashSet<Integer> shuffleIds = shuffleIdMap.computeIfAbsent(appShuffleId,
id -> new HashSet<>());
- public HashBasedShuffleWriterSuiteJ() throws IOException {}
+ synchronized (shuffleIds) {
+ shuffleIds.add(shuffleId);
+ }
+ }
- @Override
- protected ShuffleWriter<Integer, String> createShuffleWriter(
- CelebornShuffleHandle handle, TaskContext context, CelebornConf conf,
ShuffleClient client)
- throws IOException {
- // this test case is independent of the `mapId` value
- return new HashBasedShuffleWriter<Integer, String, String>(
- handle, /*mapId=*/ 0, context, conf, client, SendBufferPool.get(1, 30,
60));
+ public void unregisterAppShuffleId(ShuffleClient shuffleClient, int
appShuffleId) {
+ HashSet<Integer> shuffleIds = shuffleIdMap.remove(appShuffleId);
+ if (shuffleIds != null) {
+ synchronized (shuffleIds) {
+ shuffleIds.forEach(shuffleClient::cleanupShuffle);
+ }
+ }
}
}
diff --git
a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
index 89826be05..407b32849 100644
---
a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
+++
b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
@@ -101,6 +101,7 @@ public class HashBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
// In order to facilitate the writing of unit test code, ShuffleClient needs
to be passed in as
// parameters. By the way, simplify the passed parameters.
public HashBasedShuffleWriter(
+ int shuffleId,
CelebornShuffleHandle<K, V, C> handle,
int mapId,
TaskContext taskContext,
@@ -110,7 +111,7 @@ public class HashBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
throws IOException {
this.mapId = mapId;
this.dep = handle.dependency();
- this.shuffleId = dep.shuffleId();
+ this.shuffleId = shuffleId;
SerializerInstance serializer = dep.serializer().newInstance();
this.partitioner = dep.partitioner();
this.writeMetrics = taskContext.taskMetrics().shuffleWriteMetrics();
diff --git
a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java
b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java
index 62c05858d..51087c6dd 100644
---
a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java
+++
b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java
@@ -94,6 +94,7 @@ public class SortBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
// In order to facilitate the writing of unit test code, ShuffleClient needs
to be passed in as
// parameters. By the way, simplify the passed parameters.
public SortBasedShuffleWriter(
+ int shuffleId,
ShuffleDependency<K, V, C> dep,
int numMappers,
TaskContext taskContext,
@@ -104,7 +105,7 @@ public class SortBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
throws IOException {
this.mapId = taskContext.partitionId();
this.dep = dep;
- this.shuffleId = dep.shuffleId();
+ this.shuffleId = shuffleId;
SerializerInstance serializer = dep.serializer().newInstance();
this.partitioner = dep.partitioner();
this.writeMetrics = taskContext.taskMetrics().shuffleWriteMetrics();
diff --git
a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
index 071810ad2..470d2e989 100644
---
a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
+++
b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
@@ -26,6 +26,7 @@ import scala.Int;
import org.apache.spark.*;
import org.apache.spark.launcher.SparkLauncher;
+import org.apache.spark.rdd.DeterministicLevel;
import org.apache.spark.shuffle.*;
import org.apache.spark.shuffle.sort.SortShuffleManager;
import org.apache.spark.util.Utils;
@@ -66,6 +67,8 @@ public class SparkShuffleManager implements ShuffleManager {
private long sendBufferPoolCheckInterval;
private long sendBufferPoolExpireTimeout;
+ private ExecutorShuffleIdTracker shuffleIdTracker = new
ExecutorShuffleIdTracker();
+
public SparkShuffleManager(SparkConf conf, boolean isDriver) {
this.conf = conf;
this.isDriver = isDriver;
@@ -105,6 +108,12 @@ public class SparkShuffleManager implements ShuffleManager
{
synchronized (this) {
if (lifecycleManager == null) {
lifecycleManager = new LifecycleManager(appId, celebornConf);
+ if (celebornConf.clientFetchThrowsFetchFailure()) {
+ MapOutputTrackerMaster mapOutputTracker =
+ (MapOutputTrackerMaster) SparkEnv.get().mapOutputTracker();
+ lifecycleManager.registerShuffleTrackerCallback(
+ shuffleId ->
mapOutputTracker.unregisterAllMapOutput(shuffleId));
+ }
}
}
}
@@ -119,6 +128,10 @@ public class SparkShuffleManager implements ShuffleManager
{
appUniqueId = SparkUtils.appUniqueId(dependency.rdd().context());
initializeLifecycleManager(appUniqueId);
+ lifecycleManager.registerAppShuffleDeterminate(
+ shuffleId,
+ dependency.rdd().getOutputDeterministicLevel() !=
DeterministicLevel.INDETERMINATE());
+
if (fallbackPolicyRunner.applyAllFallbackPolicy(
lifecycleManager, dependency.partitioner().numPartitions())) {
logger.warn("Fallback to SortShuffleManager!");
@@ -131,23 +144,24 @@ public class SparkShuffleManager implements
ShuffleManager {
lifecycleManager.getPort(),
lifecycleManager.getUserIdentifier(),
shuffleId,
+ celebornConf.clientFetchThrowsFetchFailure(),
numMaps,
dependency);
}
}
@Override
- public boolean unregisterShuffle(int shuffleId) {
- if (sortShuffleIds.contains(shuffleId)) {
- return sortShuffleManager().unregisterShuffle(shuffleId);
+ public boolean unregisterShuffle(int appShuffleId) {
+ if (sortShuffleIds.contains(appShuffleId)) {
+ return sortShuffleManager().unregisterShuffle(appShuffleId);
}
// For Spark driver side trigger unregister shuffle.
if (lifecycleManager != null) {
- lifecycleManager.unregisterShuffle(shuffleId);
+ lifecycleManager.unregisterAppShuffle(appShuffleId);
}
// For Spark executor side cleanup shuffle related info.
if (shuffleClient != null) {
- shuffleClient.cleanupShuffle(shuffleId);
+ shuffleIdTracker.unregisterAppShuffleId(shuffleClient, appShuffleId);
}
return true;
}
@@ -187,10 +201,14 @@ public class SparkShuffleManager implements
ShuffleManager {
h.lifecycleManagerPort(),
celebornConf,
h.userIdentifier());
+ int shuffleId = SparkUtils.celebornShuffleId(client, h, context, true);
+ shuffleIdTracker.track(h.shuffleId(), shuffleId);
+
if (ShuffleMode.SORT.equals(celebornConf.shuffleWriterMode())) {
ExecutorService pushThread =
celebornConf.clientPushSortPipelineEnabled() ? getPusherThread()
: null;
return new SortBasedShuffleWriter<>(
+ shuffleId,
h.dependency(),
h.numMaps(),
context,
@@ -200,6 +218,7 @@ public class SparkShuffleManager implements ShuffleManager {
SendBufferPool.get(cores, sendBufferPoolCheckInterval,
sendBufferPoolExpireTimeout));
} else if (ShuffleMode.HASH.equals(celebornConf.shuffleWriterMode())) {
return new HashBasedShuffleWriter<>(
+ shuffleId,
h,
mapId,
context,
@@ -225,7 +244,14 @@ public class SparkShuffleManager implements ShuffleManager
{
@SuppressWarnings("unchecked")
CelebornShuffleHandle<K, ?, C> h = (CelebornShuffleHandle<K, ?, C>)
handle;
return new CelebornShuffleReader<>(
- h, startPartition, endPartition, 0, Int.MaxValue(), context,
celebornConf);
+ h,
+ startPartition,
+ endPartition,
+ 0,
+ Int.MaxValue(),
+ context,
+ celebornConf,
+ shuffleIdTracker);
}
return _sortShuffleManager.getReader(handle, startPartition, endPartition,
context);
}
diff --git
a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java
b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java
index b4446f0da..0f03b5688 100644
---
a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java
+++
b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java
@@ -26,6 +26,7 @@ import scala.Tuple2;
import org.apache.spark.SparkConf;
import org.apache.spark.SparkContext;
+import org.apache.spark.TaskContext;
import org.apache.spark.scheduler.MapStatus;
import org.apache.spark.scheduler.MapStatus$;
import org.apache.spark.sql.execution.UnsafeRowSerializer;
@@ -34,12 +35,15 @@ import org.apache.spark.storage.BlockManagerId;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
+import org.apache.celeborn.client.ShuffleClient;
import org.apache.celeborn.common.CelebornConf;
import org.apache.celeborn.common.util.Utils;
public class SparkUtils {
private static final Logger logger =
LoggerFactory.getLogger(SparkUtils.class);
+ public static final String FETCH_FAILURE_ERROR_MSG = "Celeborn FetchFailure
with shuffle id ";
+
public static MapStatus createMapStatus(
BlockManagerId loc, long[] uncompressedSizes, long[]
uncompressedRecords) throws IOException {
@@ -115,6 +119,23 @@ public class SparkUtils {
}
}
+ public static String getAppShuffleIdentifier(int appShuffleId, TaskContext
context) {
+ return appShuffleId + "-" + context.stageId() + "-" +
context.stageAttemptNumber();
+ }
+
+ public static int celebornShuffleId(
+ ShuffleClient client,
+ CelebornShuffleHandle<?, ?, ?> handle,
+ TaskContext context,
+ Boolean isWriter) {
+ if (handle.throwsFetchFailure()) {
+ String appShuffleIdentifier =
getAppShuffleIdentifier(handle.shuffleId(), context);
+ return client.getShuffleId(handle.shuffleId(), appShuffleIdentifier,
isWriter);
+ } else {
+ return handle.shuffleId();
+ }
+ }
+
// Create an instance of the class with the given name, possibly
initializing it with our conf
// Copied from SparkEnv
public static <T> T instantiateClass(String className, SparkConf conf,
Boolean isDriver) {
diff --git
a/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleHandle.scala
b/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleHandle.scala
index ef1f2c881..4f67edaf3 100644
---
a/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleHandle.scala
+++
b/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleHandle.scala
@@ -28,6 +28,7 @@ class CelebornShuffleHandle[K, V, C](
val lifecycleManagerPort: Int,
val userIdentifier: UserIdentifier,
shuffleId: Int,
+ val throwsFetchFailure: Boolean,
numMappers: Int,
dependency: ShuffleDependency[K, V, C])
extends BaseShuffleHandle(shuffleId, numMappers, dependency)
diff --git
a/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
b/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
index 1059f3604..dec305225 100644
---
a/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
+++
b/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
@@ -23,7 +23,7 @@ import java.util.concurrent.atomic.AtomicReference
import org.apache.spark.{InterruptibleIterator, TaskContext}
import org.apache.spark.internal.Logging
-import org.apache.spark.shuffle.ShuffleReader
+import org.apache.spark.shuffle.{FetchFailedException, ShuffleReader}
import
org.apache.spark.shuffle.celeborn.CelebornShuffleReader.streamCreatorPool
import org.apache.spark.util.CompletionIterator
import org.apache.spark.util.collection.ExternalSorter
@@ -32,7 +32,7 @@ import org.apache.celeborn.client.ShuffleClient
import org.apache.celeborn.client.read.CelebornInputStream
import org.apache.celeborn.client.read.MetricsCallback
import org.apache.celeborn.common.CelebornConf
-import org.apache.celeborn.common.exception.CelebornIOException
+import org.apache.celeborn.common.exception.{CelebornIOException,
PartitionUnRetryAbleException}
import org.apache.celeborn.common.util.ThreadUtils
class CelebornShuffleReader[K, C](
@@ -42,7 +42,8 @@ class CelebornShuffleReader[K, C](
startMapIndex: Int = 0,
endMapIndex: Int = Int.MaxValue,
context: TaskContext,
- conf: CelebornConf)
+ conf: CelebornConf,
+ shuffleIdTracker: ExecutorShuffleIdTracker)
extends ShuffleReader[K, C] with Logging {
private val dep = handle.dependency
@@ -59,6 +60,11 @@ class CelebornShuffleReader[K, C](
val serializerInstance = dep.serializer.newInstance()
+ val shuffleId = SparkUtils.celebornShuffleId(shuffleClient, handle,
context, false)
+ shuffleIdTracker.track(handle.shuffleId, shuffleId)
+ logDebug(
+ s"get shuffleId $shuffleId for appShuffleId ${handle.shuffleId}
attemptNum ${context.stageAttemptNumber()}")
+
// Update the context task metrics for each record read.
val readMetrics = context.taskMetrics.createTempShuffleReadMetrics()
val metricsCallback = new MetricsCallback {
@@ -87,7 +93,7 @@ class CelebornShuffleReader[K, C](
if (exceptionRef.get() == null) {
try {
val inputStream = shuffleClient.readPartition(
- handle.shuffleId,
+ shuffleId,
partitionId,
context.attemptNumber(),
startMapIndex,
@@ -113,7 +119,21 @@ class CelebornShuffleReader[K, C](
var inputStream: CelebornInputStream = streams.get(partitionId)
while (inputStream == null) {
if (exceptionRef.get() != null) {
- throw exceptionRef.get()
+ exceptionRef.get() match {
+ case ce @ (_: CelebornIOException | _:
PartitionUnRetryAbleException) =>
+ if (handle.throwsFetchFailure &&
+ shuffleClient.reportShuffleFetchFailure(handle.shuffleId,
shuffleId)) {
+ throw new FetchFailedException(
+ null,
+ handle.shuffleId,
+ -1,
+ partitionId,
+ SparkUtils.FETCH_FAILURE_ERROR_MSG + shuffleId,
+ ce)
+ } else
+ throw ce
+ case e => throw e
+ }
}
Thread.sleep(50)
inputStream = streams.get(partitionId)
@@ -122,12 +142,30 @@ class CelebornShuffleReader[K, C](
TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startFetchWait))
// ensure inputStream is closed when task completes
context.addTaskCompletionListener(_ => inputStream.close())
- inputStream
+ (partitionId, inputStream)
} else {
- CelebornInputStream.empty()
+ (partitionId, CelebornInputStream.empty())
}
- }).flatMap(
- serializerInstance.deserializeStream(_).asKeyValueIterator)
+ }).map { case (partitionId, inputStream) =>
+ (partitionId,
serializerInstance.deserializeStream(inputStream).asKeyValueIterator)
+ }.flatMap { case (partitionId, iter) =>
+ try {
+ iter
+ } catch {
+ case e @ (_: CelebornIOException | _: PartitionUnRetryAbleException) =>
+ if (handle.throwsFetchFailure &&
+ shuffleClient.reportShuffleFetchFailure(handle.shuffleId,
shuffleId)) {
+ throw new FetchFailedException(
+ null,
+ handle.shuffleId,
+ -1,
+ partitionId,
+ SparkUtils.FETCH_FAILURE_ERROR_MSG + shuffleId,
+ e)
+ } else
+ throw e
+ }
+ }
val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
recordIter.map { record =>
diff --git
a/client-spark/spark-2/src/test/java/org/apache/spark/shuffle/celeborn/CelebornShuffleWriterSuiteBase.java
b/client-spark/spark-2/src/test/java/org/apache/spark/shuffle/celeborn/CelebornShuffleWriterSuiteBase.java
index bd9355190..f6f38a592 100644
---
a/client-spark/spark-2/src/test/java/org/apache/spark/shuffle/celeborn/CelebornShuffleWriterSuiteBase.java
+++
b/client-spark/spark-2/src/test/java/org/apache/spark/shuffle/celeborn/CelebornShuffleWriterSuiteBase.java
@@ -227,7 +227,7 @@ public abstract class CelebornShuffleWriterSuiteBase {
final File tempFile = new File(tempDir, UUID.randomUUID().toString());
final CelebornShuffleHandle<Integer, String, String> handle =
new CelebornShuffleHandle<>(
- appId, host, port, userIdentifier, shuffleId, numMaps, dependency);
+ appId, host, port, userIdentifier, shuffleId, false, numMaps,
dependency);
final ShuffleClient client = new DummyShuffleClient(conf, tempFile);
((DummyShuffleClient) client).initReducePartitionMap(shuffleId,
numPartitions, 1);
diff --git
a/client-spark/spark-2/src/test/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriterSuiteJ.java
b/client-spark/spark-2/src/test/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriterSuiteJ.java
index 475efee24..0d38e2841 100644
---
a/client-spark/spark-2/src/test/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriterSuiteJ.java
+++
b/client-spark/spark-2/src/test/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriterSuiteJ.java
@@ -35,6 +35,12 @@ public class HashBasedShuffleWriterSuiteJ extends
CelebornShuffleWriterSuiteBase
throws IOException {
// this test case is independent of the `mapId` value
return new HashBasedShuffleWriter<Integer, String, String>(
- handle, /*mapId=*/ 0, context, conf, client, SendBufferPool.get(1, 30,
60));
+ SparkUtils.celebornShuffleId(client, handle, context, true),
+ handle,
+ /*mapId=*/ 0,
+ context,
+ conf,
+ client,
+ SendBufferPool.get(1, 30, 60));
}
}
diff --git
a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleHandle.scala
b/client-spark/spark-2/src/test/java/org/apache/spark/shuffle/celeborn/ShuffleManagerHook.java
similarity index 61%
copy from
client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleHandle.scala
copy to
client-spark/spark-2/src/test/java/org/apache/spark/shuffle/celeborn/ShuffleManagerHook.java
index 9d411ceab..3da77f014 100644
---
a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleHandle.scala
+++
b/client-spark/spark-2/src/test/java/org/apache/spark/shuffle/celeborn/ShuffleManagerHook.java
@@ -15,19 +15,21 @@
* limitations under the License.
*/
-package org.apache.spark.shuffle.celeborn
+package org.apache.spark.shuffle.celeborn;
-import org.apache.spark.ShuffleDependency
-import org.apache.spark.shuffle.BaseShuffleHandle
+import org.apache.spark.TaskContext;
+import org.apache.spark.shuffle.ShuffleHandle;
-import org.apache.celeborn.common.identity.UserIdentifier
+public interface ShuffleManagerHook {
-class CelebornShuffleHandle[K, V, C](
- val appUniqueId: String,
- val lifecycleManagerHost: String,
- val lifecycleManagerPort: Int,
- val userIdentifier: UserIdentifier,
- shuffleId: Int,
- val numMappers: Int,
- dependency: ShuffleDependency[K, V, C])
- extends BaseShuffleHandle(shuffleId, dependency)
+ default void exec(
+ ShuffleHandle handle, int startPartition, int endPartition, TaskContext
context) {}
+
+ default void exec(
+ ShuffleHandle handle,
+ int startMapIndex,
+ int endMapIndex,
+ int startPartition,
+ int endPartition,
+ TaskContext context) {};
+}
diff --git
a/client-spark/spark-2/src/test/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriterSuiteJ.java
b/client-spark/spark-2/src/test/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriterSuiteJ.java
index efff3b97a..74ebab2b9 100644
---
a/client-spark/spark-2/src/test/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriterSuiteJ.java
+++
b/client-spark/spark-2/src/test/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriterSuiteJ.java
@@ -34,6 +34,7 @@ public class SortBasedShuffleWriterSuiteJ extends
CelebornShuffleWriterSuiteBase
CelebornShuffleHandle handle, TaskContext context, CelebornConf conf,
ShuffleClient client)
throws IOException {
return new SortBasedShuffleWriter<Integer, String, String>(
+ SparkUtils.celebornShuffleId(client, handle, context, true),
handle.dependency(),
numPartitions,
context,
diff --git
a/client-spark/spark-2/src/test/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriterSuiteJ.java
b/client-spark/spark-2/src/test/java/org/apache/spark/shuffle/celeborn/TestCelebornShuffleManager.java
similarity index 53%
copy from
client-spark/spark-2/src/test/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriterSuiteJ.java
copy to
client-spark/spark-2/src/test/java/org/apache/spark/shuffle/celeborn/TestCelebornShuffleManager.java
index 475efee24..5c995e753 100644
---
a/client-spark/spark-2/src/test/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriterSuiteJ.java
+++
b/client-spark/spark-2/src/test/java/org/apache/spark/shuffle/celeborn/TestCelebornShuffleManager.java
@@ -17,24 +17,29 @@
package org.apache.spark.shuffle.celeborn;
-import java.io.IOException;
-
+import org.apache.spark.SparkConf;
import org.apache.spark.TaskContext;
-import org.apache.spark.shuffle.ShuffleWriter;
+import org.apache.spark.shuffle.ShuffleHandle;
+import org.apache.spark.shuffle.ShuffleReader;
+
+public class TestCelebornShuffleManager extends SparkShuffleManager {
-import org.apache.celeborn.client.ShuffleClient;
-import org.apache.celeborn.common.CelebornConf;
+ private static ShuffleManagerHook shuffleReaderGetHook = null;
-public class HashBasedShuffleWriterSuiteJ extends
CelebornShuffleWriterSuiteBase {
+ public TestCelebornShuffleManager(SparkConf conf) {
+ super(conf, true);
+ }
- public HashBasedShuffleWriterSuiteJ() throws IOException {}
+ public static void registerReaderGetHook(ShuffleManagerHook hook) {
+ shuffleReaderGetHook = hook;
+ }
@Override
- protected ShuffleWriter<Integer, String> createShuffleWriter(
- CelebornShuffleHandle handle, TaskContext context, CelebornConf conf,
ShuffleClient client)
- throws IOException {
- // this test case is independent of the `mapId` value
- return new HashBasedShuffleWriter<Integer, String, String>(
- handle, /*mapId=*/ 0, context, conf, client, SendBufferPool.get(1, 30,
60));
+ public <K, C> ShuffleReader<K, C> getReader(
+ ShuffleHandle handle, int startPartition, int endPartition, TaskContext
context) {
+ if (shuffleReaderGetHook != null) {
+ shuffleReaderGetHook.exec(handle, startPartition, endPartition, context);
+ }
+ return super.getReader(handle, startPartition, endPartition, context);
}
}
diff --git
a/client-spark/spark-3-columnar-shuffle/src/main/java/org/apache/spark/shuffle/celeborn/ColumnarHashBasedShuffleWriter.java
b/client-spark/spark-3-columnar-shuffle/src/main/java/org/apache/spark/shuffle/celeborn/ColumnarHashBasedShuffleWriter.java
index be5d78c50..b468c5b96 100644
---
a/client-spark/spark-3-columnar-shuffle/src/main/java/org/apache/spark/shuffle/celeborn/ColumnarHashBasedShuffleWriter.java
+++
b/client-spark/spark-3-columnar-shuffle/src/main/java/org/apache/spark/shuffle/celeborn/ColumnarHashBasedShuffleWriter.java
@@ -58,6 +58,7 @@ public class ColumnarHashBasedShuffleWriter<K, V, C> extends
HashBasedShuffleWri
private final double columnarShuffleDictionaryMaxFactor;
public ColumnarHashBasedShuffleWriter(
+ int shuffleId,
CelebornShuffleHandle<K, V, C> handle,
TaskContext taskContext,
CelebornConf conf,
@@ -65,7 +66,7 @@ public class ColumnarHashBasedShuffleWriter<K, V, C> extends
HashBasedShuffleWri
ShuffleWriteMetricsReporter metrics,
SendBufferPool sendBufferPool)
throws IOException {
- super(handle, taskContext, conf, client, metrics, sendBufferPool);
+ super(shuffleId, handle, taskContext, conf, client, metrics,
sendBufferPool);
columnarShuffleBatchSize = conf.columnarShuffleBatchSize();
columnarShuffleCodeGenEnabled = conf.columnarShuffleCodeGenEnabled();
columnarShuffleDictionaryEnabled = conf.columnarShuffleDictionaryEnabled();
diff --git
a/client-spark/spark-3-columnar-shuffle/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReader.scala
b/client-spark/spark-3-columnar-shuffle/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReader.scala
index f47f9880c..fd888fb9d 100644
---
a/client-spark/spark-3-columnar-shuffle/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReader.scala
+++
b/client-spark/spark-3-columnar-shuffle/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReader.scala
@@ -33,7 +33,8 @@ class CelebornColumnarShuffleReader[K, C](
endMapIndex: Int = Int.MaxValue,
context: TaskContext,
conf: CelebornConf,
- metrics: ShuffleReadMetricsReporter)
+ metrics: ShuffleReadMetricsReporter,
+ shuffleIdTracker: ExecutorShuffleIdTracker)
extends CelebornShuffleReader[K, C](
handle,
startPartition,
@@ -42,7 +43,8 @@ class CelebornColumnarShuffleReader[K, C](
endMapIndex,
context,
conf,
- metrics) {
+ metrics,
+ shuffleIdTracker) {
override def newSerializerInstance(dep: ShuffleDependency[K, _, C]):
SerializerInstance = {
val schema = CustomShuffleDependencyUtils.getSchema(dep)
diff --git
a/client-spark/spark-3-columnar-shuffle/src/test/java/org/apache/spark/shuffle/celeborn/ColumnarHashBasedShuffleWriterSuiteJ.java
b/client-spark/spark-3-columnar-shuffle/src/test/java/org/apache/spark/shuffle/celeborn/ColumnarHashBasedShuffleWriterSuiteJ.java
index e481b6181..92aee552f 100644
---
a/client-spark/spark-3-columnar-shuffle/src/test/java/org/apache/spark/shuffle/celeborn/ColumnarHashBasedShuffleWriterSuiteJ.java
+++
b/client-spark/spark-3-columnar-shuffle/src/test/java/org/apache/spark/shuffle/celeborn/ColumnarHashBasedShuffleWriterSuiteJ.java
@@ -62,7 +62,7 @@ public class ColumnarHashBasedShuffleWriterSuiteJ extends
CelebornShuffleWriterS
ShuffleWriter<Integer, String> writer =
createShuffleWriterWithoutSchema(
new CelebornShuffleHandle<>(
- "appId", "host", 0, this.userIdentifier, 0, 10,
this.dependency),
+ "appId", "host", 0, this.userIdentifier, 0, false, 10,
this.dependency),
taskContext,
conf,
client,
@@ -75,7 +75,7 @@ public class ColumnarHashBasedShuffleWriterSuiteJ extends
CelebornShuffleWriterS
writer =
createShuffleWriter(
new CelebornShuffleHandle<>(
- "appId", "host", 0, this.userIdentifier, 0, 10,
this.dependency),
+ "appId", "host", 0, this.userIdentifier, 0, false, 10,
this.dependency),
taskContext,
conf,
client,
@@ -108,7 +108,13 @@ public class ColumnarHashBasedShuffleWriterSuiteJ extends
CelebornShuffleWriterS
.when(() ->
CustomShuffleDependencyUtils.getSchema(handle.dependency()))
.thenReturn(schema);
return SparkUtils.createColumnarHashBasedShuffleWriter(
- handle, context, conf, client, metrics, SendBufferPool.get(1, 30,
60));
+ SparkUtils.celebornShuffleId(client, handle, context, true),
+ handle,
+ context,
+ conf,
+ client,
+ metrics,
+ SendBufferPool.get(1, 30, 60));
}
}
@@ -119,6 +125,12 @@ public class ColumnarHashBasedShuffleWriterSuiteJ extends
CelebornShuffleWriterS
ShuffleClient client,
ShuffleWriteMetricsReporter metrics) {
return SparkUtils.createColumnarHashBasedShuffleWriter(
- handle, context, conf, client, metrics, SendBufferPool.get(1, 30, 60));
+ SparkUtils.celebornShuffleId(client, handle, context, true),
+ handle,
+ context,
+ conf,
+ client,
+ metrics,
+ SendBufferPool.get(1, 30, 60));
}
}
diff --git
a/client-spark/spark-3-columnar-shuffle/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReaderSuite.scala
b/client-spark/spark-3-columnar-shuffle/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReaderSuite.scala
index 5a14d0219..5df434f54 100644
---
a/client-spark/spark-3-columnar-shuffle/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReaderSuite.scala
+++
b/client-spark/spark-3-columnar-shuffle/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReaderSuite.scala
@@ -39,6 +39,7 @@ class CelebornColumnarShuffleReaderSuite {
0,
new UserIdentifier("mock", "mock"),
0,
+ false,
10,
null)
@@ -53,7 +54,8 @@ class CelebornColumnarShuffleReaderSuite {
10,
null,
new CelebornConf(),
- null)
+ null,
+ new ExecutorShuffleIdTracker())
assert(shuffleReader.getClass ==
classOf[CelebornColumnarShuffleReader[Int, String]])
} finally {
if (shuffleClient != null) {
@@ -74,6 +76,7 @@ class CelebornColumnarShuffleReaderSuite {
0,
new UserIdentifier("mock", "mock"),
0,
+ false,
10,
null),
0,
@@ -82,7 +85,8 @@ class CelebornColumnarShuffleReaderSuite {
10,
null,
new CelebornConf(),
- null)
+ null,
+ new ExecutorShuffleIdTracker())
val shuffleDependency = Mockito.mock(classOf[ShuffleDependency[Int,
String, String]])
Mockito.when(shuffleDependency.shuffleId).thenReturn(0)
Mockito.when(shuffleDependency.serializer).thenReturn(new KryoSerializer(
diff --git
a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
index 793d06bff..7b30101f7 100644
---
a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
+++
b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
@@ -101,6 +101,7 @@ public class HashBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
// In order to facilitate the writing of unit test code, ShuffleClient needs
to be passed in as
// parameters. By the way, simplify the passed parameters.
public HashBasedShuffleWriter(
+ int shuffleId,
CelebornShuffleHandle<K, V, C> handle,
TaskContext taskContext,
CelebornConf conf,
@@ -110,7 +111,7 @@ public class HashBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
throws IOException {
this.mapId = taskContext.partitionId();
this.dep = handle.dependency();
- this.shuffleId = dep.shuffleId();
+ this.shuffleId = shuffleId;
SerializerInstance serializer = dep.serializer().newInstance();
this.partitioner = dep.partitioner();
this.writeMetrics = metrics;
diff --git
a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java
b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java
index 98c694306..0b515c456 100644
---
a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java
+++
b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java
@@ -94,6 +94,7 @@ public class SortBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
// In order to facilitate the writing of unit test code, ShuffleClient needs
to be passed in as
// parameters. By the way, simplify the passed parameters.
public SortBasedShuffleWriter(
+ int shuffleId,
ShuffleDependency<K, V, C> dep,
int numMappers,
TaskContext taskContext,
@@ -105,7 +106,7 @@ public class SortBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
throws IOException {
this.mapId = taskContext.partitionId();
this.dep = dep;
- this.shuffleId = dep.shuffleId();
+ this.shuffleId = shuffleId;
SerializerInstance serializer = dep.serializer().newInstance();
this.partitioner = dep.partitioner();
this.writeMetrics = metrics;
@@ -179,6 +180,7 @@ public class SortBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
SendBufferPool sendBufferPool)
throws IOException {
this(
+ SparkUtils.celebornShuffleId(client, handle, taskContext, true),
handle.dependency(),
handle.numMappers(),
taskContext,
diff --git
a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
index aa1c1a2c2..a1cb458cf 100644
---
a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
+++
b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java
@@ -24,6 +24,7 @@ import java.util.concurrent.atomic.AtomicInteger;
import org.apache.spark.*;
import org.apache.spark.launcher.SparkLauncher;
+import org.apache.spark.rdd.DeterministicLevel;
import org.apache.spark.shuffle.*;
import org.apache.spark.shuffle.sort.SortShuffleManager;
import org.apache.spark.sql.internal.SQLConf;
@@ -38,6 +39,17 @@ import org.apache.celeborn.common.protocol.ShuffleMode;
import org.apache.celeborn.common.util.ThreadUtils;
import org.apache.celeborn.reflect.DynMethods;
+/**
+ * In order to support Spark Stage resubmission on ShuffleReader fetch failures, the Celeborn
+ * shuffleId has to be distinguished from the Spark shuffleId. The Spark shuffleId is assigned at
+ * ShuffleDependency construction time, and all attempts of a Spark Stage share the same
+ * shuffleId. When the Celeborn ShuffleReader fails to fetch shuffle data from a worker and throws
+ * a {@link FetchFailedException}, the Spark DAGScheduler resubmits the failed ResultStage and the
+ * corresponding ShuffleMapStage, but Celeborn cannot differentiate the shuffle data of the
+ * previous failed/resubmitted ShuffleMapStage attempts because they share the same shuffleId. The
+ * current solution takes Stage retries into account and has the LifecycleManager generate and
+ * track the mapping between the Spark shuffle id (appShuffleId) and the Celeborn shuffle id
+ * (shuffleId). Spark shuffle Readers/Writers obtain the shuffleId from the LifecycleManager via
+ * the GetShuffleId RPC.
+ */
public class SparkShuffleManager implements ShuffleManager {
private static final Logger logger =
LoggerFactory.getLogger(SparkShuffleManager.class);
@@ -79,6 +91,8 @@ public class SparkShuffleManager implements ShuffleManager {
private long sendBufferPoolCheckInterval;
private long sendBufferPoolExpireTimeout;
+ private ExecutorShuffleIdTracker shuffleIdTracker = new
ExecutorShuffleIdTracker();
+
public SparkShuffleManager(SparkConf conf, boolean isDriver) {
if (conf.getBoolean(SQLConf.LOCAL_SHUFFLE_READER_ENABLED().key(), true)) {
logger.warn(
@@ -125,6 +139,13 @@ public class SparkShuffleManager implements ShuffleManager
{
synchronized (this) {
if (lifecycleManager == null) {
lifecycleManager = new LifecycleManager(appUniqueId, celebornConf);
+ if (celebornConf.clientFetchThrowsFetchFailure()) {
+ MapOutputTrackerMaster mapOutputTracker =
+ (MapOutputTrackerMaster) SparkEnv.get().mapOutputTracker();
+
+ lifecycleManager.registerShuffleTrackerCallback(
+ shuffleId ->
SparkUtils.unregisterAllMapOutput(mapOutputTracker, shuffleId));
+ }
}
}
}
@@ -139,6 +160,10 @@ public class SparkShuffleManager implements ShuffleManager
{
appUniqueId = SparkUtils.appUniqueId(dependency.rdd().context());
initializeLifecycleManager();
+ lifecycleManager.registerAppShuffleDeterminate(
+ shuffleId,
+ dependency.rdd().getOutputDeterministicLevel() !=
DeterministicLevel.INDETERMINATE());
+
if (fallbackPolicyRunner.applyAllFallbackPolicy(
lifecycleManager, dependency.partitioner().numPartitions())) {
if (conf.getBoolean("spark.dynamicAllocation.enabled", false)
@@ -160,23 +185,24 @@ public class SparkShuffleManager implements
ShuffleManager {
lifecycleManager.getPort(),
lifecycleManager.getUserIdentifier(),
shuffleId,
+ celebornConf.clientFetchThrowsFetchFailure(),
dependency.rdd().getNumPartitions(),
dependency);
}
}
@Override
- public boolean unregisterShuffle(int shuffleId) {
- if (sortShuffleIds.contains(shuffleId)) {
- return sortShuffleManager().unregisterShuffle(shuffleId);
+ public boolean unregisterShuffle(int appShuffleId) {
+ if (sortShuffleIds.contains(appShuffleId)) {
+ return sortShuffleManager().unregisterShuffle(appShuffleId);
}
// For Spark driver side trigger unregister shuffle.
if (lifecycleManager != null) {
- lifecycleManager.unregisterShuffle(shuffleId);
+ lifecycleManager.unregisterAppShuffle(appShuffleId);
}
// For Spark executor side cleanup shuffle related info.
if (shuffleClient != null) {
- shuffleClient.cleanupShuffle(shuffleId);
+ shuffleIdTracker.unregisterAppShuffleId(shuffleClient, appShuffleId);
}
return true;
}
@@ -217,10 +243,14 @@ public class SparkShuffleManager implements
ShuffleManager {
h.lifecycleManagerPort(),
celebornConf,
h.userIdentifier());
+ int shuffleId = SparkUtils.celebornShuffleId(shuffleClient, h,
context, true);
+ shuffleIdTracker.track(h.shuffleId(), shuffleId);
+
if (ShuffleMode.SORT.equals(celebornConf.shuffleWriterMode())) {
ExecutorService pushThread =
celebornConf.clientPushSortPipelineEnabled() ? getPusherThread()
: null;
return new SortBasedShuffleWriter<>(
+ shuffleId,
h.dependency(),
h.numMappers(),
context,
@@ -234,10 +264,10 @@ public class SparkShuffleManager implements
ShuffleManager {
SendBufferPool.get(cores, sendBufferPoolCheckInterval,
sendBufferPoolExpireTimeout);
if (COLUMNAR_SHUFFLE_CLASSES_PRESENT &&
celebornConf.columnarShuffleEnabled()) {
return SparkUtils.createColumnarHashBasedShuffleWriter(
- h, context, celebornConf, shuffleClient, metrics, pool);
+ shuffleId, h, context, celebornConf, shuffleClient, metrics,
pool);
} else {
return new HashBasedShuffleWriter<>(
- h, context, celebornConf, shuffleClient, metrics, pool);
+ shuffleId, h, context, celebornConf, shuffleClient, metrics,
pool);
}
} else {
throw new UnsupportedOperationException(
@@ -340,7 +370,8 @@ public class SparkShuffleManager implements ShuffleManager {
endMapIndex,
context,
celebornConf,
- metrics);
+ metrics,
+ shuffleIdTracker);
} else {
return new CelebornShuffleReader<>(
h,
@@ -350,7 +381,8 @@ public class SparkShuffleManager implements ShuffleManager {
endMapIndex,
context,
celebornConf,
- metrics);
+ metrics,
+ shuffleIdTracker);
}
}
diff --git
a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java
b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java
index 46a54f2ff..e7a6a5b8b 100644
---
a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java
+++
b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java
@@ -21,6 +21,7 @@ import java.util.concurrent.atomic.LongAdder;
import scala.Tuple2;
+import org.apache.spark.MapOutputTrackerMaster;
import org.apache.spark.SparkConf;
import org.apache.spark.SparkContext;
import org.apache.spark.TaskContext;
@@ -46,6 +47,8 @@ import org.apache.celeborn.reflect.DynMethods;
public class SparkUtils {
private static final Logger LOG = LoggerFactory.getLogger(SparkUtils.class);
+ public static final String FETCH_FAILURE_ERROR_MSG = "Celeborn FetchFailure
with shuffle id ";
+
public static MapStatus createMapStatus(
BlockManagerId loc, long[] uncompressedSizes, long mapTaskId) {
return MapStatus$.MODULE$.apply(loc, uncompressedSizes, mapTaskId);
@@ -92,6 +95,23 @@ public class SparkUtils {
.getOrElse(context::applicationId);
}
+ public static String getAppShuffleIdentifier(int appShuffleId, TaskContext
context) {
+ return appShuffleId + "-" + context.stageId() + "-" +
context.stageAttemptNumber();
+ }
+
+ public static int celebornShuffleId(
+ ShuffleClient client,
+ CelebornShuffleHandle<?, ?, ?> handle,
+ TaskContext context,
+ Boolean isWriter) {
+ if (handle.throwsFetchFailure()) {
+ String appShuffleIdentifier =
getAppShuffleIdentifier(handle.shuffleId(), context);
+ return client.getShuffleId(handle.shuffleId(), appShuffleIdentifier,
isWriter);
+ } else {
+ return handle.shuffleId();
+ }
+ }
+
// Create an instance of the class with the given name, possibly
initializing it with our conf
// Copied from SparkEnv
public static <T> T instantiateClass(String className, SparkConf conf,
Boolean isDriver) {
@@ -164,6 +184,7 @@ public class SparkUtils {
DynConstructors.builder()
.impl(
COLUMNAR_HASH_BASED_SHUFFLE_WRITER_CLASS,
+ int.class,
CelebornShuffleHandle.class,
TaskContext.class,
CelebornConf.class,
@@ -172,6 +193,7 @@ public class SparkUtils {
SendBufferPool.class);
public static <K, V, C> HashBasedShuffleWriter<K, V, C>
createColumnarHashBasedShuffleWriter(
+ int shuffleId,
CelebornShuffleHandle<K, V, C> handle,
TaskContext taskContext,
CelebornConf conf,
@@ -180,7 +202,7 @@ public class SparkUtils {
SendBufferPool sendBufferPool) {
return COLUMNAR_HASH_BASED_SHUFFLE_WRITER_CONSTRUCTOR_BUILDER
.build()
- .invoke(null, handle, taskContext, conf, client, metrics,
sendBufferPool);
+ .invoke(null, shuffleId, handle, taskContext, conf, client, metrics,
sendBufferPool);
}
public static final String COLUMNAR_SHUFFLE_READER_CLASS =
@@ -196,7 +218,8 @@ public class SparkUtils {
int.class,
TaskContext.class,
CelebornConf.class,
- ShuffleReadMetricsReporter.class);
+ ShuffleReadMetricsReporter.class,
+ ExecutorShuffleIdTracker.class);
public static <K, C> CelebornShuffleReader<K, C> createColumnarShuffleReader(
CelebornShuffleHandle<K, ?, C> handle,
@@ -206,7 +229,8 @@ public class SparkUtils {
int endMapIndex,
TaskContext context,
CelebornConf conf,
- ShuffleReadMetricsReporter metrics) {
+ ShuffleReadMetricsReporter metrics,
+ ExecutorShuffleIdTracker shuffleIdTracker) {
return COLUMNAR_SHUFFLE_READER_CONSTRUCTOR_BUILDER
.build()
.invoke(
@@ -218,6 +242,35 @@ public class SparkUtils {
endMapIndex,
context,
conf,
- metrics);
+ metrics,
+ shuffleIdTracker);
+ }
+
+ // Added in SPARK-32920, for Spark 3.2 and above
+ private static final DynMethods.UnboundMethod
UnregisterAllMapAndMergeOutput_METHOD =
+ DynMethods.builder("unregisterAllMapAndMergeOutput")
+ .impl(MapOutputTrackerMaster.class, Integer.TYPE)
+ .orNoop()
+ .build();
+
+ // For Spark 3.1, see details in SPARK-32920
+ private static final DynMethods.UnboundMethod UnregisterAllMapOutput_METHOD =
+ DynMethods.builder("unregisterAllMapOutput")
+ .impl(MapOutputTrackerMaster.class, Integer.TYPE)
+ .orNoop()
+ .build();
+
+ public static void unregisterAllMapOutput(
+ MapOutputTrackerMaster mapOutputTracker, int shuffleId) {
+ if (!UnregisterAllMapAndMergeOutput_METHOD.isNoop()) {
+
UnregisterAllMapAndMergeOutput_METHOD.bind(mapOutputTracker).invoke(shuffleId);
+ return;
+ }
+ if (!UnregisterAllMapOutput_METHOD.isNoop()) {
+ UnregisterAllMapOutput_METHOD.bind(mapOutputTracker).invoke(shuffleId);
+ return;
+ }
+ throw new UnsupportedOperationException(
+ "unexpected! neither methods
unregisterAllMapAndMergeOutput/unregisterAllMapOutput are found in
MapOutputTrackerMaster");
}
}
diff --git
a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleHandle.scala
b/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleHandle.scala
index 9d411ceab..18a3053e0 100644
---
a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleHandle.scala
+++
b/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleHandle.scala
@@ -28,6 +28,7 @@ class CelebornShuffleHandle[K, V, C](
val lifecycleManagerPort: Int,
val userIdentifier: UserIdentifier,
shuffleId: Int,
+ val throwsFetchFailure: Boolean,
val numMappers: Int,
dependency: ShuffleDependency[K, V, C])
extends BaseShuffleHandle(shuffleId, dependency)
diff --git
a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
b/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
index d83df1a5b..fe7af8309 100644
---
a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
+++
b/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
@@ -24,7 +24,7 @@ import java.util.concurrent.atomic.AtomicReference
import org.apache.spark.{InterruptibleIterator, ShuffleDependency, TaskContext}
import org.apache.spark.internal.Logging
import org.apache.spark.serializer.SerializerInstance
-import org.apache.spark.shuffle.{ShuffleReader, ShuffleReadMetricsReporter}
+import org.apache.spark.shuffle.{FetchFailedException, ShuffleReader,
ShuffleReadMetricsReporter}
import
org.apache.spark.shuffle.celeborn.CelebornShuffleReader.streamCreatorPool
import org.apache.spark.util.CompletionIterator
import org.apache.spark.util.collection.ExternalSorter
@@ -32,7 +32,7 @@ import org.apache.spark.util.collection.ExternalSorter
import org.apache.celeborn.client.ShuffleClient
import org.apache.celeborn.client.read.{CelebornInputStream, MetricsCallback}
import org.apache.celeborn.common.CelebornConf
-import org.apache.celeborn.common.exception.CelebornIOException
+import org.apache.celeborn.common.exception.{CelebornIOException,
PartitionUnRetryAbleException}
import org.apache.celeborn.common.util.ThreadUtils
class CelebornShuffleReader[K, C](
@@ -43,7 +43,8 @@ class CelebornShuffleReader[K, C](
endMapIndex: Int = Int.MaxValue,
context: TaskContext,
conf: CelebornConf,
- metrics: ShuffleReadMetricsReporter)
+ metrics: ShuffleReadMetricsReporter,
+ shuffleIdTracker: ExecutorShuffleIdTracker)
extends ShuffleReader[K, C] with Logging {
private val dep = handle.dependency
@@ -60,6 +61,11 @@ class CelebornShuffleReader[K, C](
val serializerInstance = newSerializerInstance(dep)
+ val shuffleId = SparkUtils.celebornShuffleId(shuffleClient, handle,
context, false)
+ shuffleIdTracker.track(handle.shuffleId, shuffleId)
+ logDebug(
+ s"get shuffleId $shuffleId for appShuffleId ${handle.shuffleId}
attemptNum ${context.stageAttemptNumber()}")
+
// Update the context task metrics for each record read.
val metricsCallback = new MetricsCallback {
override def incBytesRead(bytesWritten: Long): Unit = {
@@ -89,7 +95,7 @@ class CelebornShuffleReader[K, C](
if (exceptionRef.get() == null) {
try {
val inputStream = shuffleClient.readPartition(
- handle.shuffleId,
+ shuffleId,
partitionId,
context.attemptNumber(),
startMapIndex,
@@ -115,7 +121,22 @@ class CelebornShuffleReader[K, C](
var inputStream: CelebornInputStream = streams.get(partitionId)
while (inputStream == null) {
if (exceptionRef.get() != null) {
- throw exceptionRef.get()
+ exceptionRef.get() match {
+ case ce @ (_: CelebornIOException | _:
PartitionUnRetryAbleException) =>
+ if (handle.throwsFetchFailure &&
+ shuffleClient.reportShuffleFetchFailure(handle.shuffleId,
shuffleId)) {
+ throw new FetchFailedException(
+ null,
+ handle.shuffleId,
+ -1,
+ -1,
+ partitionId,
+ SparkUtils.FETCH_FAILURE_ERROR_MSG + shuffleId,
+ ce)
+ } else
+ throw ce
+ case e => throw e
+ }
}
Thread.sleep(50)
inputStream = streams.get(partitionId)
@@ -124,12 +145,31 @@ class CelebornShuffleReader[K, C](
TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startFetchWait))
// ensure inputStream is closed when task completes
context.addTaskCompletionListener[Unit](_ => inputStream.close())
- inputStream
+ (partitionId, inputStream)
} else {
- CelebornInputStream.empty()
+ (partitionId, CelebornInputStream.empty())
}
- }).flatMap(
- serializerInstance.deserializeStream(_).asKeyValueIterator)
+ }).map { case (partitionId, inputStream) =>
+ (partitionId,
serializerInstance.deserializeStream(inputStream).asKeyValueIterator)
+ }.flatMap { case (partitionId, iter) =>
+ try {
+ iter
+ } catch {
+ case e @ (_: CelebornIOException | _: PartitionUnRetryAbleException) =>
+ if (handle.throwsFetchFailure &&
+ shuffleClient.reportShuffleFetchFailure(handle.shuffleId,
shuffleId)) {
+ throw new FetchFailedException(
+ null,
+ handle.shuffleId,
+ -1,
+ -1,
+ partitionId,
+ SparkUtils.FETCH_FAILURE_ERROR_MSG + shuffleId,
+ e)
+ } else
+ throw e
+ }
+ }
val iterWithUpdatedRecordsRead =
if (GlutenShuffleDependencyHelper.isGlutenDep(dep.getClass.getName)) {
diff --git
a/client-spark/spark-3/src/test/java/org/apache/spark/shuffle/celeborn/CelebornShuffleWriterSuiteBase.java
b/client-spark/spark-3/src/test/java/org/apache/spark/shuffle/celeborn/CelebornShuffleWriterSuiteBase.java
index f4109753e..65930d9ba 100644
---
a/client-spark/spark-3/src/test/java/org/apache/spark/shuffle/celeborn/CelebornShuffleWriterSuiteBase.java
+++
b/client-spark/spark-3/src/test/java/org/apache/spark/shuffle/celeborn/CelebornShuffleWriterSuiteBase.java
@@ -236,7 +236,7 @@ public abstract class CelebornShuffleWriterSuiteBase {
final File tempFile = new File(tempDir, UUID.randomUUID().toString());
final CelebornShuffleHandle<Integer, String, String> handle =
new CelebornShuffleHandle<>(
- appId, host, port, userIdentifier, shuffleId, numMaps, dependency);
+ appId, host, port, userIdentifier, shuffleId, false, numMaps,
dependency);
final ShuffleClient client = new DummyShuffleClient(conf, tempFile);
((DummyShuffleClient) client).initReducePartitionMap(shuffleId,
numPartitions, 1);
diff --git
a/client-spark/spark-3/src/test/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriterSuiteJ.java
b/client-spark/spark-3/src/test/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriterSuiteJ.java
index d8c514dc7..e2876ff4a 100644
---
a/client-spark/spark-3/src/test/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriterSuiteJ.java
+++
b/client-spark/spark-3/src/test/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriterSuiteJ.java
@@ -37,6 +37,12 @@ public class HashBasedShuffleWriterSuiteJ extends
CelebornShuffleWriterSuiteBase
ShuffleWriteMetricsReporter metrics)
throws IOException {
return new HashBasedShuffleWriter<Integer, String, String>(
- handle, context, conf, client, metrics, SendBufferPool.get(1, 30, 60));
+ SparkUtils.celebornShuffleId(client, handle, context, true),
+ handle,
+ context,
+ conf,
+ client,
+ metrics,
+ SendBufferPool.get(1, 30, 60));
}
}
diff --git
a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleHandle.scala
b/client-spark/spark-3/src/test/java/org/apache/spark/shuffle/celeborn/ShuffleManagerHook.java
similarity index 61%
copy from
client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleHandle.scala
copy to
client-spark/spark-3/src/test/java/org/apache/spark/shuffle/celeborn/ShuffleManagerHook.java
index 9d411ceab..3da77f014 100644
---
a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleHandle.scala
+++
b/client-spark/spark-3/src/test/java/org/apache/spark/shuffle/celeborn/ShuffleManagerHook.java
@@ -15,19 +15,21 @@
* limitations under the License.
*/
-package org.apache.spark.shuffle.celeborn
+package org.apache.spark.shuffle.celeborn;
-import org.apache.spark.ShuffleDependency
-import org.apache.spark.shuffle.BaseShuffleHandle
+import org.apache.spark.TaskContext;
+import org.apache.spark.shuffle.ShuffleHandle;
-import org.apache.celeborn.common.identity.UserIdentifier
+public interface ShuffleManagerHook {
-class CelebornShuffleHandle[K, V, C](
- val appUniqueId: String,
- val lifecycleManagerHost: String,
- val lifecycleManagerPort: Int,
- val userIdentifier: UserIdentifier,
- shuffleId: Int,
- val numMappers: Int,
- dependency: ShuffleDependency[K, V, C])
- extends BaseShuffleHandle(shuffleId, dependency)
+ default void exec(
+ ShuffleHandle handle, int startPartition, int endPartition, TaskContext
context) {}
+
+ default void exec(
+ ShuffleHandle handle,
+ int startMapIndex,
+ int endMapIndex,
+ int startPartition,
+ int endPartition,
+ TaskContext context) {};
+}
diff --git
a/client-spark/spark-3/src/test/java/org/apache/spark/shuffle/celeborn/TestCelebornShuffleManager.java
b/client-spark/spark-3/src/test/java/org/apache/spark/shuffle/celeborn/TestCelebornShuffleManager.java
new file mode 100644
index 000000000..637f33771
--- /dev/null
+++
b/client-spark/spark-3/src/test/java/org/apache/spark/shuffle/celeborn/TestCelebornShuffleManager.java
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.celeborn;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.TaskContext;
+import org.apache.spark.shuffle.ShuffleHandle;
+import org.apache.spark.shuffle.ShuffleReadMetricsReporter;
+import org.apache.spark.shuffle.ShuffleReader;
+
+public class TestCelebornShuffleManager extends SparkShuffleManager {
+
+ private static ShuffleManagerHook shuffleReaderGetHook = null;
+
+ public TestCelebornShuffleManager(SparkConf conf) {
+ super(conf, true);
+ }
+
+ public static void registerReaderGetHook(ShuffleManagerHook hook) {
+ shuffleReaderGetHook = hook;
+ }
+
+ @Override
+ public <K, C> ShuffleReader<K, C> getReader(
+ ShuffleHandle handle,
+ int startMapIndex,
+ int endMapIndex,
+ int startPartition,
+ int endPartition,
+ TaskContext context,
+ ShuffleReadMetricsReporter metrics) {
+ if (shuffleReaderGetHook != null) {
+ shuffleReaderGetHook.exec(
+ handle, startMapIndex, endMapIndex, startPartition, endPartition,
context);
+ }
+ return super.getReader(
+ handle, startMapIndex, endMapIndex, startPartition, endPartition,
context, metrics);
+ }
+
+ @Override
+ public <K, C> ShuffleReader<K, C> getReader(
+ ShuffleHandle handle,
+ int startPartition,
+ int endPartition,
+ TaskContext context,
+ ShuffleReadMetricsReporter metrics) {
+ if (shuffleReaderGetHook != null) {
+ shuffleReaderGetHook.exec(handle, startPartition, endPartition, context);
+ }
+ return super.getReader(handle, startPartition, endPartition, context,
metrics);
+ }
+}
diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java
b/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java
index b4b566d1a..72230a536 100644
--- a/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java
+++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java
@@ -212,4 +212,13 @@ public abstract class ShuffleClient {
int shuffleId, int numMappers, int numPartitions);
public abstract PushState getPushState(String mapKey);
+
+ public abstract int getShuffleId(int appShuffleId, String
appShuffleIdentifier, boolean isWriter);
+
+ /**
+ * Report a shuffle data fetch failure to the LifecycleManager for special handling, e.g. shuffle
+ * status cleanup for the Spark app. It must be a synchronous call that ensures the cleanup is
+ * done; otherwise, re-run tasks may fetch incorrect shuffle data.
+ */
+ public abstract boolean reportShuffleFetchFailure(int appShuffleId, int
shuffleId);
}
diff --git
a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
index a300ddf5f..b1a377517 100644
--- a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
+++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
@@ -86,6 +86,9 @@ public class ShuffleClientImpl extends ShuffleClient {
protected final int BATCH_HEADER_SIZE = 4 * 4;
+ // key: appShuffleIdentifier, value: shuffleId
+ protected Map<String, Integer> shuffleIdCache =
JavaUtils.newConcurrentHashMap();
+
// key: shuffleId, value: (partitionId, PartitionLocation)
final Map<Integer, ConcurrentHashMap<Integer, PartitionLocation>>
reducePartitionMap =
JavaUtils.newConcurrentHashMap();
@@ -517,6 +520,41 @@ public class ShuffleClientImpl extends ShuffleClient {
return pushStates.computeIfAbsent(mapKey, (s) -> new PushState(conf));
}
+ @Override
+ public int getShuffleId(int appShuffleId, String appShuffleIdentifier,
boolean isWriter) {
+ return shuffleIdCache.computeIfAbsent(
+ appShuffleIdentifier,
+ (id) -> {
+ PbGetShuffleId pbGetShuffleId =
+ PbGetShuffleId.newBuilder()
+ .setAppShuffleId(appShuffleId)
+ .setAppShuffleIdentifier(appShuffleIdentifier)
+ .setIsShuffleWriter(isWriter)
+ .build();
+ PbGetShuffleIdResponse pbGetShuffleIdResponse =
+ lifecycleManagerRef.askSync(
+ pbGetShuffleId,
+ conf.clientRpcRegisterShuffleRpcAskTimeout(),
+ ClassTag$.MODULE$.apply(PbGetShuffleIdResponse.class));
+ return pbGetShuffleIdResponse.getShuffleId();
+ });
+ }
+
+ @Override
+ public boolean reportShuffleFetchFailure(int appShuffleId, int shuffleId) {
+ PbReportShuffleFetchFailure pbReportShuffleFetchFailure =
+ PbReportShuffleFetchFailure.newBuilder()
+ .setAppShuffleId(appShuffleId)
+ .setShuffleId(shuffleId)
+ .build();
+ PbReportShuffleFetchFailureResponse pbReportShuffleFetchFailureResponse =
+ lifecycleManagerRef.askSync(
+ pbReportShuffleFetchFailure,
+ conf.clientRpcRegisterShuffleRpcAskTimeout(),
+
ClassTag$.MODULE$.apply(PbReportShuffleFetchFailureResponse.class));
+ return pbReportShuffleFetchFailureResponse.getSuccess();
+ }
+
private ConcurrentHashMap<Integer, PartitionLocation>
registerShuffleInternal(
int shuffleId,
int numMappers,
diff --git
a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
index cf54dfcd8..ed308840a 100644
--- a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
+++ b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
@@ -82,6 +82,13 @@ class LifecycleManager(val appUniqueId: String, val conf:
CelebornConf) extends
JavaUtils.newConcurrentHashMap[Int, ConcurrentHashMap[Int,
PartitionLocation]]()
private val userIdentifier: UserIdentifier =
IdentityProvider.instantiate(conf).provide()
private val availableStorageTypes = conf.availableStorageTypes
+ // app shuffle id -> LinkedHashMap of (app shuffle identifier, (shuffle id,
fetch status))
+ private val shuffleIdMapping = JavaUtils.newConcurrentHashMap[
+ Int,
+ scala.collection.mutable.LinkedHashMap[String, (Int, Boolean)]]()
+ private val shuffleIdGenerator = new AtomicInteger(0)
+ // app shuffle id -> whether the shuffle is determinate; a rerun of an
indeterminate shuffle gets a different result
+ private val appShuffleDeterminateMap = JavaUtils.newConcurrentHashMap[Int,
Boolean]();
private val rpcCacheSize = conf.clientRpcCacheSize
private val rpcCacheConcurrencyLevel = conf.clientRpcCacheConcurrencyLevel
@@ -312,6 +319,19 @@ class LifecycleManager(val appUniqueId: String, val conf:
CelebornConf) extends
case GetReducerFileGroup(shuffleId: Int) =>
logDebug(s"Received GetShuffleFileGroup request for shuffleId
$shuffleId.")
handleGetReducerFileGroup(context, shuffleId)
+
+ case pb: PbGetShuffleId =>
+ val appShuffleId = pb.getAppShuffleId
+ val appShuffleIdentifier = pb.getAppShuffleIdentifier
+ val isWriter = pb.getIsShuffleWriter
+ logDebug(s"Received GetShuffleId request, appShuffleId $appShuffleId
appShuffleIdentifier $appShuffleIdentifier isWriter $isWriter.")
+ handleGetShuffleIdForApp(context, appShuffleId, appShuffleIdentifier,
isWriter)
+
+ case pb: PbReportShuffleFetchFailure =>
+ val appShuffleId = pb.getAppShuffleId
+ val shuffleId = pb.getShuffleId
+ logDebug(s"Received ReportShuffleFetchFailure request, appShuffleId
$appShuffleId shuffleId $shuffleId")
+ handleReportShuffleFetchFailure(context, appShuffleId, shuffleId)
}
private def offerAndReserveSlots(
@@ -628,6 +648,118 @@ class LifecycleManager(val appUniqueId: String, val conf:
CelebornConf) extends
commitManager.handleGetReducerFileGroup(context, shuffleId)
}
+ private def handleGetShuffleIdForApp(
+ context: RpcCallContext,
+ appShuffleId: Int,
+ appShuffleIdentifier: String,
+ isWriter: Boolean): Unit = {
+ val shuffleIds = shuffleIdMapping.computeIfAbsent(
+ appShuffleId,
+ new function.Function[Int,
scala.collection.mutable.LinkedHashMap[String, (Int, Boolean)]]() {
+ override def apply(id: Int)
+ : scala.collection.mutable.LinkedHashMap[String, (Int, Boolean)] =
{
+ val newShuffleId = shuffleIdGenerator.getAndIncrement()
+ logInfo(s"generate new shuffleId $newShuffleId for appShuffleId
$appShuffleId appShuffleIdentifier $appShuffleIdentifier")
+ scala.collection.mutable.LinkedHashMap(appShuffleIdentifier ->
(newShuffleId, true))
+ }
+ })
+
+ def isAllMaptaskEnd(shuffleId: Int): Boolean = {
+ !commitManager.getMapperAttempts(shuffleId).exists(_ < 0)
+ }
+
+ shuffleIds.synchronized {
+ if (isWriter) {
+ shuffleIds.get(appShuffleIdentifier) match {
+ case Some((shuffleId, _)) =>
+ val pbGetShuffleIdResponse =
+
PbGetShuffleIdResponse.newBuilder().setShuffleId(shuffleId).build()
+ context.reply(pbGetShuffleIdResponse)
+ case None =>
+ Option(appShuffleDeterminateMap.get(appShuffleId)).map {
determinate =>
+ val candidateShuffle =
+ if (determinate)
+ shuffleIds.values.toSeq.reverse.find(e => e._2 == true)
+ else
+ None
+
+ val shuffleId: Integer =
+ if (determinate && candidateShuffle.isDefined) {
+ val id = candidateShuffle.get._1
+ logInfo(s"reuse existing shuffleId $id for appShuffleId
$appShuffleId appShuffleIdentifier $appShuffleIdentifier")
+ id
+ } else {
+ val newShuffleId = shuffleIdGenerator.getAndIncrement()
+ logInfo(s"generate new shuffleId $newShuffleId for
appShuffleId $appShuffleId appShuffleIdentifier $appShuffleIdentifier")
+ shuffleIds.put(appShuffleIdentifier, (newShuffleId, true))
+ newShuffleId
+ }
+ val pbGetShuffleIdResponse =
+
PbGetShuffleIdResponse.newBuilder().setShuffleId(shuffleId).build()
+ context.reply(pbGetShuffleIdResponse)
+ }.orElse(
+ throw new UnsupportedOperationException(
+ s"unexpected! unknown appShuffleId $appShuffleId when checking
shuffle deterministic level"))
+ }
+ } else {
+ shuffleIds.values.map(v => v._1).toSeq.reverse.find(isAllMaptaskEnd)
match {
+ case Some(shuffleId) =>
+ val pbGetShuffleIdResponse = {
+ logDebug(
+ s"get shuffleId $shuffleId for appShuffleId $appShuffleId
appShuffleIdentifier $appShuffleIdentifier isWriter $isWriter")
+
PbGetShuffleIdResponse.newBuilder().setShuffleId(shuffleId).build()
+ }
+ context.reply(pbGetShuffleIdResponse)
+ case None =>
+ throw new UnsupportedOperationException(
+ s"unexpected! there is no finished map stage associated with
appShuffleId $appShuffleId")
+ }
+ }
+ }
+ }
+
+ private def handleReportShuffleFetchFailure(
+ context: RpcCallContext,
+ appShuffleId: Int,
+ shuffleId: Int): Unit = {
+
+ val shuffleIds = shuffleIdMapping.get(appShuffleId)
+ if (shuffleIds == null) {
+ throw new UnsupportedOperationException(s"unexpected! unknown
appShuffleId $appShuffleId")
+ }
+ var ret = true
+ shuffleIds.synchronized {
+ shuffleIds.find(e => e._2._1 == shuffleId) match {
+ case Some((appShuffleIdentifier, (shuffleId, true))) =>
+ logInfo(s"handle fetch failure for appShuffleId $appShuffleId
shuffleId $shuffleId")
+ appShuffleTrackerCallback match {
+ case Some(callback) =>
+ try {
+ callback.accept(appShuffleId)
+ } catch {
+ case t: Throwable =>
+ logError(t.toString)
+ ret = false
+ }
+ shuffleIds.put(appShuffleIdentifier, (shuffleId, false))
+ case None =>
+ throw new UnsupportedOperationException(
+ "unexpected! appShuffleTrackerCallback is not registered")
+ }
+ case Some((appShuffleIdentifier, (shuffleId, false))) =>
+ logInfo(
+ s"Ignoring fetch failure from appShuffleIdentifier
$appShuffleIdentifier shuffleId $shuffleId, " +
+ "fetch failure is already reported and handled by another reader")
+ case None => throw new UnsupportedOperationException(
+ s"unexpected! unknown shuffleId $shuffleId for appShuffleId
$appShuffleId")
+ }
+ }
+
+ val pbReportShuffleFetchFailureResponse =
+ PbReportShuffleFetchFailureResponse.newBuilder().setSuccess(ret).build()
+ context.reply(pbReportShuffleFetchFailureResponse)
+ }
+
private def handleStageEnd(shuffleId: Int): Unit = {
// check whether shuffle has registered
if (!registeredShuffle.contains(shuffleId)) {
@@ -701,6 +833,19 @@ class LifecycleManager(val appUniqueId: String, val conf:
CelebornConf) extends
logInfo(s"Unregister for $shuffleId success.")
}
+ def unregisterAppShuffle(appShuffleId: Int): Unit = {
+ logInfo(s"Unregister appShuffleId $appShuffleId starts...")
+ appShuffleDeterminateMap.remove(appShuffleId)
+ val shuffleIds = shuffleIdMapping.remove(appShuffleId)
+ if (shuffleIds != null) {
+ shuffleIds.synchronized(
+ shuffleIds.values.map {
+ case (shuffleId, _) =>
+ unregisterShuffle(shuffleId)
+ })
+ }
+ }
+
/* ========================================================== *
| END OF EVENT HANDLER |
* ========================================================== */
@@ -1208,6 +1353,15 @@ class LifecycleManager(val appUniqueId: String, val
conf: CelebornConf) extends
workerStatusTracker.registerWorkerStatusListener(workerStatusListener)
}
+ @volatile private var appShuffleTrackerCallback: Option[Consumer[Integer]] =
None
+ def registerShuffleTrackerCallback(callback: Consumer[Integer]): Unit = {
+ appShuffleTrackerCallback = Some(callback)
+ }
+
+ def registerAppShuffleDeterminate(appShuffleId: Int, determinate: Boolean):
Unit = {
+ appShuffleDeterminateMap.put(appShuffleId, determinate)
+ }
+
// Initialize at the end of LifecycleManager construction.
initialize()
diff --git
a/client/src/test/java/org/apache/celeborn/client/DummyShuffleClient.java
b/client/src/test/java/org/apache/celeborn/client/DummyShuffleClient.java
index 339b8b859..bb4f4fe41 100644
--- a/client/src/test/java/org/apache/celeborn/client/DummyShuffleClient.java
+++ b/client/src/test/java/org/apache/celeborn/client/DummyShuffleClient.java
@@ -153,6 +153,16 @@ public class DummyShuffleClient extends ShuffleClient {
return new PushState(conf);
}
+ @Override
+ public int getShuffleId(int appShuffleId, String appShuffleIdentifier,
boolean isWriter) {
+ return appShuffleId;
+ }
+
+ @Override
+ public boolean reportShuffleFetchFailure(int appShuffleId, int shuffleId) {
+ return true;
+ }
+
public void initReducePartitionMap(int shuffleId, int numPartitions, int
workerNum) {
ConcurrentHashMap<Integer, PartitionLocation> map =
JavaUtils.newConcurrentHashMap();
String host = "host";
diff --git
a/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportMessage.java
b/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportMessage.java
index 8fa07a145..a716d7ded 100644
---
a/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportMessage.java
+++
b/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportMessage.java
@@ -32,11 +32,15 @@ import org.apache.celeborn.common.protocol.MessageType;
import org.apache.celeborn.common.protocol.PbBacklogAnnouncement;
import org.apache.celeborn.common.protocol.PbBufferStreamEnd;
import org.apache.celeborn.common.protocol.PbChunkFetchRequest;
+import org.apache.celeborn.common.protocol.PbGetShuffleId;
+import org.apache.celeborn.common.protocol.PbGetShuffleIdResponse;
import org.apache.celeborn.common.protocol.PbOpenStream;
import org.apache.celeborn.common.protocol.PbPushDataHandShake;
import org.apache.celeborn.common.protocol.PbReadAddCredit;
import org.apache.celeborn.common.protocol.PbRegionFinish;
import org.apache.celeborn.common.protocol.PbRegionStart;
+import org.apache.celeborn.common.protocol.PbReportShuffleFetchFailure;
+import org.apache.celeborn.common.protocol.PbReportShuffleFetchFailureResponse;
import org.apache.celeborn.common.protocol.PbStreamChunkSlice;
import org.apache.celeborn.common.protocol.PbStreamHandler;
import org.apache.celeborn.common.protocol.PbTransportableError;
@@ -90,6 +94,14 @@ public class TransportMessage implements Serializable {
return (T) PbChunkFetchRequest.parseFrom(payload);
case TRANSPORTABLE_ERROR_VALUE:
return (T) PbTransportableError.parseFrom(payload);
+ case GET_SHUFFLE_ID_VALUE:
+ return (T) PbGetShuffleId.parseFrom(payload);
+ case GET_SHUFFLE_ID_RESPONSE_VALUE:
+ return (T) PbGetShuffleIdResponse.parseFrom(payload);
+ case REPORT_SHUFFLE_FETCH_FAILURE_VALUE:
+ return (T) PbReportShuffleFetchFailure.parseFrom(payload);
+ case REPORT_SHUFFLE_FETCH_FAILURE_RESPONSE_VALUE:
+ return (T) PbReportShuffleFetchFailureResponse.parseFrom(payload);
default:
logger.error("Unexpected type {}", type);
}
diff --git a/common/src/main/proto/TransportMessages.proto
b/common/src/main/proto/TransportMessages.proto
index 56f0e7561..a595a785d 100644
--- a/common/src/main/proto/TransportMessages.proto
+++ b/common/src/main/proto/TransportMessages.proto
@@ -87,6 +87,10 @@ enum MessageType {
TRANSPORTABLE_ERROR = 64;
WORKER_EXCLUDE = 65;
WORKER_EXCLUDE_RESPONSE = 66;
+ REPORT_SHUFFLE_FETCH_FAILURE = 67;
+ REPORT_SHUFFLE_FETCH_FAILURE_RESPONSE = 68;
+ GET_SHUFFLE_ID = 69;
+ GET_SHUFFLE_ID_RESPONSE = 70;
}
enum StreamType {
@@ -296,6 +300,25 @@ message PbGetReducerFileGroupResponse {
repeated int32 partitionIds = 4;
}
+message PbGetShuffleId {
+ int32 appShuffleId = 1;
+ string appShuffleIdentifier = 2;
+ bool isShuffleWriter = 3;
+}
+
+message PbGetShuffleIdResponse {
+ int32 shuffleId = 1;
+}
+
+message PbReportShuffleFetchFailure {
+ int32 appShuffleId = 1;
+ int32 shuffleId = 2;
+}
+
+message PbReportShuffleFetchFailureResponse {
+ bool success = 1;
+}
+
message PbUnregisterShuffle {
string appId = 1;
int32 shuffleId = 2;
diff --git
a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
index 34fcfaca7..5b170eb30 100644
--- a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
+++ b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
@@ -777,6 +777,7 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable
with Logging with Se
def clientFetchTimeoutMs: Long = get(CLIENT_FETCH_TIMEOUT)
def clientFetchMaxReqsInFlight: Int = get(CLIENT_FETCH_MAX_REQS_IN_FLIGHT)
def clientFetchMaxRetriesForEachReplica: Int =
get(CLIENT_FETCH_MAX_RETRIES_FOR_EACH_REPLICA)
+ def clientFetchThrowsFetchFailure: Boolean =
get(CLIENT_FETCH_THROWS_FETCH_FAILURE)
def clientFetchExcludeWorkerOnFailureEnabled: Boolean =
get(CLIENT_FETCH_EXCLUDE_WORKER_ON_FAILURE_ENABLED)
def clientFetchExcludedWorkerExpireTimeout: Long =
@@ -3228,6 +3229,14 @@ object CelebornConf extends Logging {
.intConf
.createWithDefault(3)
+ val CLIENT_FETCH_THROWS_FETCH_FAILURE: ConfigEntry[Boolean] =
+ buildConf("celeborn.client.spark.fetch.throwsFetchFailure")
+ .categories("client")
+ .version("0.4.0")
+ .doc("client throws FetchFailedException instead of CelebornIOException")
+ .booleanConf
+ .createWithDefault(false)
+
val CLIENT_FETCH_EXCLUDE_WORKER_ON_FAILURE_ENABLED: ConfigEntry[Boolean] =
buildConf("celeborn.client.fetch.excludeWorkerOnFailure.enabled")
.categories("client")
diff --git
a/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
b/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
index 63b4f7b1f..63fae0775 100644
---
a/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
+++
b/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
@@ -487,6 +487,18 @@ object ControlMessages extends Logging {
case pb: PbRegisterWorker =>
new TransportMessage(MessageType.REGISTER_WORKER, pb.toByteArray)
+ case pb: PbGetShuffleId =>
+ new TransportMessage(MessageType.GET_SHUFFLE_ID, pb.toByteArray)
+
+ case pb: PbGetShuffleIdResponse =>
+ new TransportMessage(MessageType.GET_SHUFFLE_ID_RESPONSE, pb.toByteArray)
+
+ case pb: PbReportShuffleFetchFailure =>
+ new TransportMessage(MessageType.REPORT_SHUFFLE_FETCH_FAILURE,
pb.toByteArray)
+
+ case pb: PbReportShuffleFetchFailureResponse =>
+ new TransportMessage(MessageType.REPORT_SHUFFLE_FETCH_FAILURE_RESPONSE,
pb.toByteArray)
+
case HeartbeatFromWorker(
host,
rpcPort,
@@ -980,6 +992,18 @@ object ControlMessages extends Logging {
attempts,
partitionIds)
+ case GET_SHUFFLE_ID_VALUE =>
+ message.getParsedPayload()
+
+ case GET_SHUFFLE_ID_RESPONSE_VALUE =>
+ message.getParsedPayload()
+
+ case REPORT_SHUFFLE_FETCH_FAILURE_VALUE =>
+ message.getParsedPayload()
+
+ case REPORT_SHUFFLE_FETCH_FAILURE_RESPONSE_VALUE =>
+ message.getParsedPayload()
+
case UNREGISTER_SHUFFLE_VALUE =>
PbUnregisterShuffle.parseFrom(message.getPayload)
diff --git a/docs/configuration/client.md b/docs/configuration/client.md
index 1c1d4e7a6..4871ac550 100644
--- a/docs/configuration/client.md
+++ b/docs/configuration/client.md
@@ -98,6 +98,7 @@ license: |
| celeborn.client.shuffle.partitionSplit.threshold | 1G | Shuffle file size
threshold, if file size exceeds this, trigger split. | 0.3.0 |
| celeborn.client.shuffle.rangeReadFilter.enabled | false | If a spark
application have skewed partition, this value can set to true to improve
performance. | 0.2.0 |
| celeborn.client.slot.assign.maxWorkers | 10000 | Max workers that slots of
one shuffle can be allocated on. Will choose the smaller positive one from
Master side and Client side, see `celeborn.master.slot.assign.maxWorkers`. |
0.3.1 |
+| celeborn.client.spark.fetch.throwsFetchFailure | false | client throws
FetchFailedException instead of CelebornIOException | 0.4.0 |
| celeborn.client.spark.push.sort.memory.threshold | 64m | When
SortBasedPusher use memory over the threshold, will trigger push data. If the
pipeline push feature is enabled
(`celeborn.client.spark.push.sort.pipeline.enabled=true`), the SortBasedPusher
will trigger a data push when the memory usage exceeds half of the threshold(by
default, 32m). | 0.3.0 |
| celeborn.client.spark.push.sort.pipeline.enabled | false | Whether to enable
pipelining for sort based shuffle writer. If true, double buffering will be
used to pipeline push | 0.3.0 |
| celeborn.client.spark.push.unsafeRow.fastWrite.enabled | true | This is
Celeborn's optimization on UnsafeRow for Spark and it's true by default. If you
have changed UnsafeRow's memory layout set this to false. | 0.2.2 |
diff --git a/tests/spark-it/pom.xml b/tests/spark-it/pom.xml
index 783cbe100..75e50221b 100644
--- a/tests/spark-it/pom.xml
+++ b/tests/spark-it/pom.xml
@@ -93,6 +93,13 @@
<version>${project.version}</version>
<scope>test</scope>
</dependency>
+ <dependency>
+ <groupId>org.apache.celeborn</groupId>
+
<artifactId>celeborn-client-spark-2_${scala.binary.version}</artifactId>
+ <version>${project.version}</version>
+ <type>test-jar</type>
+ <scope>test</scope>
+ </dependency>
</dependencies>
</profile>
<profile>
@@ -104,6 +111,13 @@
<version>${project.version}</version>
<scope>test</scope>
</dependency>
+ <dependency>
+ <groupId>org.apache.celeborn</groupId>
+
<artifactId>celeborn-client-spark-3_${scala.binary.version}</artifactId>
+ <version>${project.version}</version>
+ <type>test-jar</type>
+ <scope>test</scope>
+ </dependency>
</dependencies>
</profile>
<profile>
@@ -115,6 +129,13 @@
<version>${project.version}</version>
<scope>test</scope>
</dependency>
+ <dependency>
+ <groupId>org.apache.celeborn</groupId>
+
<artifactId>celeborn-client-spark-3_${scala.binary.version}</artifactId>
+ <version>${project.version}</version>
+ <type>test-jar</type>
+ <scope>test</scope>
+ </dependency>
</dependencies>
</profile>
<profile>
@@ -126,6 +147,13 @@
<version>${project.version}</version>
<scope>test</scope>
</dependency>
+ <dependency>
+ <groupId>org.apache.celeborn</groupId>
+
<artifactId>celeborn-client-spark-3_${scala.binary.version}</artifactId>
+ <version>${project.version}</version>
+ <type>test-jar</type>
+ <scope>test</scope>
+ </dependency>
</dependencies>
</profile>
<profile>
@@ -137,6 +165,13 @@
<version>${project.version}</version>
<scope>test</scope>
</dependency>
+ <dependency>
+ <groupId>org.apache.celeborn</groupId>
+
<artifactId>celeborn-client-spark-3_${scala.binary.version}</artifactId>
+ <version>${project.version}</version>
+ <type>test-jar</type>
+ <scope>test</scope>
+ </dependency>
</dependencies>
</profile>
<profile>
@@ -148,6 +183,13 @@
<version>${project.version}</version>
<scope>test</scope>
</dependency>
+ <dependency>
+ <groupId>org.apache.celeborn</groupId>
+
<artifactId>celeborn-client-spark-3_${scala.binary.version}</artifactId>
+ <version>${project.version}</version>
+ <type>test-jar</type>
+ <scope>test</scope>
+ </dependency>
</dependencies>
</profile>
<profile>
@@ -159,6 +201,13 @@
<version>${project.version}</version>
<scope>test</scope>
</dependency>
+ <dependency>
+ <groupId>org.apache.celeborn</groupId>
+
<artifactId>celeborn-client-spark-3_${scala.binary.version}</artifactId>
+ <version>${project.version}</version>
+ <type>test-jar</type>
+ <scope>test</scope>
+ </dependency>
</dependencies>
</profile>
</profiles>
diff --git
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornFetchFailureSuite.scala
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornFetchFailureSuite.scala
new file mode 100644
index 000000000..07c6b35eb
--- /dev/null
+++
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornFetchFailureSuite.scala
@@ -0,0 +1,223 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.tests.spark
+
+import java.io.File
+import java.util.concurrent.atomic.AtomicBoolean
+
+import org.apache.spark.{SparkConf, TaskContext}
+import org.apache.spark.shuffle.ShuffleHandle
+import org.apache.spark.shuffle.celeborn.{CelebornShuffleHandle,
ShuffleManagerHook, SparkUtils, TestCelebornShuffleManager}
+import org.apache.spark.sql.SparkSession
+import org.scalatest.BeforeAndAfterEach
+import org.scalatest.funsuite.AnyFunSuite
+
+import org.apache.celeborn.client.ShuffleClient
+import org.apache.celeborn.common.CelebornConf
+import org.apache.celeborn.common.protocol.ShuffleMode
+import org.apache.celeborn.service.deploy.worker.Worker
+
+class CelebornFetchFailureSuite extends AnyFunSuite
+ with SparkTestBase
+ with BeforeAndAfterEach {
+
+ override def beforeEach(): Unit = {
+ ShuffleClient.reset()
+ }
+
+ override def afterEach(): Unit = {
+ System.gc()
+ }
+
+ var workerDirs: Seq[String] = Seq.empty
+
+ override def createWorker(map: Map[String, String]): Worker = {
+ val storageDir = createTmpDir()
+ workerDirs = workerDirs :+ storageDir
+ super.createWorker(map, storageDir)
+ }
+
+ class ShuffleReaderGetHook(conf: CelebornConf) extends ShuffleManagerHook {
+ var executed: AtomicBoolean = new AtomicBoolean(false)
+ val lock = new Object
+
+ override def exec(
+ handle: ShuffleHandle,
+ startPartition: Int,
+ endPartition: Int,
+ context: TaskContext): Unit = {
+ if (executed.get() == true) return
+
+ lock.synchronized {
+ handle match {
+ case h: CelebornShuffleHandle[_, _, _] => {
+ val appUniqueId = h.appUniqueId
+ val shuffleClient = ShuffleClient.get(
+ h.appUniqueId,
+ h.lifecycleManagerHost,
+ h.lifecycleManagerPort,
+ conf,
+ h.userIdentifier)
+ val celebornShuffleId =
SparkUtils.celebornShuffleId(shuffleClient, h, context, false)
+ val datafile =
+ workerDirs.map(dir => {
+ new
File(s"$dir/celeborn-worker/shuffle_data/$appUniqueId/$celebornShuffleId")
+ }).filter(_.exists())
+ .flatMap(_.listFiles().iterator).headOption
+ datafile match {
+ case Some(file) => file.delete()
+ case None => throw new RuntimeException("unexpected, there must
be some data file")
+ }
+ }
+ case _ => throw new RuntimeException("unexpected, only support
CelebornShuffleHandle here")
+ }
+ executed.set(true)
+ }
+ }
+ }
+
+ test("celeborn spark integration test - Fetch Failure") {
+ val sparkConf = new
SparkConf().setAppName("rss-demo").setMaster("local[2,3]")
+ val sparkSession = SparkSession.builder()
+ .config(updateSparkConf(sparkConf, ShuffleMode.HASH))
+ .config("spark.sql.shuffle.partitions", 2)
+ .config("spark.celeborn.shuffle.forceFallback.partition.enabled", false)
+ .config("spark.celeborn.shuffle.enabled", "true")
+ .config("spark.celeborn.client.spark.fetch.throwsFetchFailure", "true")
+ .config(
+ "spark.shuffle.manager",
+ "org.apache.spark.shuffle.celeborn.TestCelebornShuffleManager")
+ .getOrCreate()
+
+ val celebornConf =
SparkUtils.fromSparkConf(sparkSession.sparkContext.getConf)
+ val hook = new ShuffleReaderGetHook(celebornConf)
+ TestCelebornShuffleManager.registerReaderGetHook(hook)
+
+ val value = Range(1, 10000).mkString(",")
+ val tuples = sparkSession.sparkContext.parallelize(1 to 10000, 2)
+ .map { i => (i, value) }.groupByKey(16).collect()
+
+ // verify result
+ assert(hook.executed.get() == true)
+ assert(tuples.length == 10000)
+ for (elem <- tuples) {
+ assert(elem._2.mkString(",").equals(value))
+ }
+
+ sparkSession.stop()
+ }
+
+ test("celeborn spark integration test - Fetch Failure with multiple shuffle
data") {
+ val sparkConf = new
SparkConf().setAppName("rss-demo").setMaster("local[2,3]")
+ val sparkSession = SparkSession.builder()
+ .config(updateSparkConf(sparkConf, ShuffleMode.HASH))
+ .config("spark.sql.shuffle.partitions", 2)
+ .config("spark.celeborn.shuffle.forceFallback.partition.enabled", false)
+ .config("spark.celeborn.shuffle.enabled", "true")
+ .config("spark.celeborn.client.spark.fetch.throwsFetchFailure", "true")
+ .config(
+ "spark.shuffle.manager",
+ "org.apache.spark.shuffle.celeborn.TestCelebornShuffleManager")
+ .getOrCreate()
+
+ val celebornConf =
SparkUtils.fromSparkConf(sparkSession.sparkContext.getConf)
+ val hook = new ShuffleReaderGetHook(celebornConf)
+ TestCelebornShuffleManager.registerReaderGetHook(hook)
+
+ import sparkSession.implicits._
+
+ val df1 = Seq((1, "a"), (2, "b")).toDF("id", "data").groupBy("id").count()
+ val df2 = Seq((2, "c"), (2, "d")).toDF("id", "data").groupBy("id").count()
+ val tuples = df1.hint("merge").join(df2, "id").select("*").collect()
+
+ // verify result
+ assert(hook.executed.get() == true)
+ val expect = "[2,1,2]"
+ assert(tuples.head.toString().equals(expect))
+ sparkSession.stop()
+ }
+
+ test("celeborn spark integration test - Fetch Failure with RDD reuse") {
+ val sparkConf = new
SparkConf().setAppName("rss-demo").setMaster("local[2,3]")
+ val sparkSession = SparkSession.builder()
+ .config(updateSparkConf(sparkConf, ShuffleMode.HASH))
+ .config("spark.sql.shuffle.partitions", 2)
+ .config("spark.celeborn.shuffle.forceFallback.partition.enabled", false)
+ .config("spark.celeborn.shuffle.enabled", "true")
+ .config("spark.celeborn.client.spark.fetch.throwsFetchFailure", "true")
+ .config(
+ "spark.shuffle.manager",
+ "org.apache.spark.shuffle.celeborn.TestCelebornShuffleManager")
+ .getOrCreate()
+
+ val celebornConf =
SparkUtils.fromSparkConf(sparkSession.sparkContext.getConf)
+ val hook = new ShuffleReaderGetHook(celebornConf)
+ TestCelebornShuffleManager.registerReaderGetHook(hook)
+
+ val sc = sparkSession.sparkContext
+ val rdd1 = sc.parallelize(0 until 10000, 3).map(v => (v, v)).groupByKey()
+ val rdd2 = sc.parallelize(0 until 10000, 2).map(v => (v, v)).groupByKey()
+ val rdd3 = rdd1.map(v => (v._2, v._1))
+
+ hook.executed.set(true)
+
+ rdd1.count()
+ rdd2.count()
+
+ hook.executed.set(false)
+ rdd3.count()
+ hook.executed.set(false)
+ rdd3.count()
+ hook.executed.set(false)
+ rdd3.count()
+ hook.executed.set(false)
+ rdd3.count()
+
+ sparkSession.stop()
+ }
+
+ test("celeborn spark integration test - Fetch Failure with read write
shuffles in one stage") {
+ val sparkConf = new
SparkConf().setAppName("rss-demo").setMaster("local[2,3]")
+ val sparkSession = SparkSession.builder()
+ .config(updateSparkConf(sparkConf, ShuffleMode.HASH))
+ .config("spark.sql.shuffle.partitions", 2)
+ .config("spark.celeborn.shuffle.forceFallback.partition.enabled", false)
+ .config("spark.celeborn.shuffle.enabled", "true")
+ .config("spark.celeborn.client.spark.fetch.throwsFetchFailure", "true")
+ .config(
+ "spark.shuffle.manager",
+ "org.apache.spark.shuffle.celeborn.TestCelebornShuffleManager")
+ .getOrCreate()
+
+ val celebornConf =
SparkUtils.fromSparkConf(sparkSession.sparkContext.getConf)
+ val hook = new ShuffleReaderGetHook(celebornConf)
+ TestCelebornShuffleManager.registerReaderGetHook(hook)
+
+ val sc = sparkSession.sparkContext
+ val rdd1 = sc.parallelize(0 until 10000, 3).map(v => (v, v)).groupByKey()
+ val rdd2 = rdd1.map(v => (v._2, v._1)).groupByKey()
+
+ hook.executed.set(true)
+ rdd1.count()
+
+ hook.executed.set(false)
+ rdd2.count()
+
+ sparkSession.stop()
+ }
+}
diff --git
a/worker/src/test/scala/org/apache/celeborn/service/deploy/MiniClusterFeature.scala
b/worker/src/test/scala/org/apache/celeborn/service/deploy/MiniClusterFeature.scala
index f466eb9f1..8c8411910 100644
---
a/worker/src/test/scala/org/apache/celeborn/service/deploy/MiniClusterFeature.scala
+++
b/worker/src/test/scala/org/apache/celeborn/service/deploy/MiniClusterFeature.scala
@@ -35,13 +35,13 @@ trait MiniClusterFeature extends Logging {
var masterInfo: (Master, Thread) = _
val workerInfos = new mutable.HashMap[Worker, Thread]()
- private def runnerWrap[T](code: => T): Thread = new Thread(new Runnable {
+ def runnerWrap[T](code: => T): Thread = new Thread(new Runnable {
override def run(): Unit = {
Utils.tryLogNonFatalError(code)
}
})
- private def createTmpDir(): String = {
+ def createTmpDir(): String = {
val tmpDir = Files.createTempDirectory("celeborn-")
logInfo(s"created temp dir: $tmpDir")
tmpDir.toFile.deleteOnExit()
@@ -66,10 +66,14 @@ trait MiniClusterFeature extends Logging {
master
}
- private def createWorker(map: Map[String, String] = null): Worker = {
+ def createWorker(map: Map[String, String] = null): Worker = {
+ createWorker(map, createTmpDir())
+ }
+
+ def createWorker(map: Map[String, String], storageDir: String): Worker = {
logInfo("start create worker for mini cluster")
val conf = new CelebornConf()
- conf.set(CelebornConf.WORKER_STORAGE_DIRS.key, createTmpDir())
+ conf.set(CelebornConf.WORKER_STORAGE_DIRS.key, storageDir)
conf.set(CelebornConf.WORKER_DISK_MONITOR_ENABLED.key, "false")
conf.set(CelebornConf.CLIENT_PUSH_BUFFER_MAX_SIZE.key, "256K")
conf.set(CelebornConf.WORKER_HTTP_PORT.key,
s"${workerHttpPort.incrementAndGet()}")