This is an automated email from the ASF dual-hosted git repository.
rexxiong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/celeborn.git
The following commit(s) were added to refs/heads/main by this push:
new 045411ac3 [CELEBORN-1855] LifecycleManager return appshuffleId for non
barrier stage when fetch fail has been reported
045411ac3 is described below
commit 045411ac34c749d8ca2d7bed8d0b9c3554f71287
Author: lijianfu03 <[email protected]>
AuthorDate: Tue May 13 16:14:03 2025 +0800
[CELEBORN-1855] LifecycleManager return appshuffleId for non barrier stage
when fetch fail has been reported
### What changes were proposed in this pull request?
For a non-barrier shuffle read stage,
LifecycleManager#handleGetShuffleIdForApp always returns the appShuffleId whether
the fetch status is true or not.
### Why are the changes needed?
As described in
[jira](https://issues.apache.org/jira/browse/CELEBORN-1855), If
LifecycleManager only returns an appShuffleId whose fetch status is success, the
task will fail directly with "there is no finished map stage associated with",
but the previously reported fetch failure event may not be fatal. So just give it a chance
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
Closes #3090 from buska88/celeborn-1855.
Authored-by: lijianfu03 <[email protected]>
Signed-off-by: Shuang <[email protected]>
---
.../apache/spark/shuffle/celeborn/SparkUtils.java | 17 ++++++++++++-----
.../shuffle/celeborn/CelebornShuffleReader.scala | 22 +++++++++++++++++++---
.../apache/spark/shuffle/celeborn/SparkUtils.java | 17 ++++++++++++-----
.../shuffle/celeborn/CelebornShuffleReader.scala | 22 ++++++++++++++++++++--
.../apache/celeborn/client/DummyShuffleClient.java | 6 ++++--
.../org/apache/celeborn/client/ShuffleClient.java | 4 +++-
.../apache/celeborn/client/ShuffleClientImpl.java | 7 ++++---
.../apache/celeborn/client/LifecycleManager.scala | 18 ++++++++++++------
common/src/main/proto/TransportMessages.proto | 1 +
.../celeborn/tests/spark/SparkTestBase.scala | 3 ++-
10 files changed, 89 insertions(+), 28 deletions(-)
diff --git
a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java
b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java
index 44dd1ea28..b52758581 100644
---
a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java
+++
b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java
@@ -63,6 +63,7 @@ import org.slf4j.LoggerFactory;
import org.apache.celeborn.client.ShuffleClient;
import org.apache.celeborn.common.CelebornConf;
+import org.apache.celeborn.common.exception.CelebornRuntimeException;
import org.apache.celeborn.common.network.protocol.TransportMessage;
import
org.apache.celeborn.common.protocol.message.ControlMessages.GetReducerFileGroupResponse;
import org.apache.celeborn.common.util.JavaUtils;
@@ -161,11 +162,17 @@ public class SparkUtils {
Boolean isWriter) {
if (handle.throwsFetchFailure()) {
String appShuffleIdentifier =
getAppShuffleIdentifier(handle.shuffleId(), context);
- return client.getShuffleId(
- handle.shuffleId(),
- appShuffleIdentifier,
- isWriter,
- context instanceof BarrierTaskContext);
+ Tuple2<Integer, Boolean> res =
+ client.getShuffleId(
+ handle.shuffleId(),
+ appShuffleIdentifier,
+ isWriter,
+ context instanceof BarrierTaskContext);
+ if (!res._2) {
+ throw new CelebornRuntimeException(String.format("Get invalid shuffle
id %s", res._1));
+ } else {
+ return res._1;
+ }
} else {
return handle.shuffleId();
}
diff --git
a/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
b/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
index 2269aaf68..df63a94b1 100644
---
a/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
+++
b/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
@@ -33,7 +33,7 @@ import org.apache.celeborn.client.ShuffleClient
import org.apache.celeborn.client.read.CelebornInputStream
import org.apache.celeborn.client.read.MetricsCallback
import org.apache.celeborn.common.CelebornConf
-import org.apache.celeborn.common.exception.{CelebornIOException,
PartitionUnRetryAbleException}
+import org.apache.celeborn.common.exception.{CelebornIOException,
CelebornRuntimeException, PartitionUnRetryAbleException}
import
org.apache.celeborn.common.protocol.message.ControlMessages.GetReducerFileGroupResponse
import org.apache.celeborn.common.util.{JavaUtils, ThreadUtils}
@@ -63,8 +63,24 @@ class CelebornShuffleReader[K, C](
override def read(): Iterator[Product2[K, C]] = {
val serializerInstance = dep.serializer.newInstance()
-
- val shuffleId = SparkUtils.celebornShuffleId(shuffleClient, handle,
context, false)
+ val shuffleId =
+ try {
+ SparkUtils.celebornShuffleId(shuffleClient, handle, context, false)
+ } catch {
+ case e: CelebornRuntimeException =>
+ logError(s"Failed to get shuffleId for appShuffleId
${handle.shuffleId}", e)
+ if (handle.throwsFetchFailure) {
+ throw new FetchFailedException(
+ null,
+ handle.shuffleId,
+ -1,
+ startPartition,
+ SparkUtils.FETCH_FAILURE_ERROR_MSG + handle.shuffleId,
+ e)
+ } else {
+ throw e
+ }
+ }
shuffleIdTracker.track(handle.shuffleId, shuffleId)
logDebug(
s"get shuffleId $shuffleId for appShuffleId ${handle.shuffleId}
attemptNum ${context.stageAttemptNumber()}")
diff --git
a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java
b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java
index edaeb28bd..b2e64565e 100644
---
a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java
+++
b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java
@@ -66,6 +66,7 @@ import org.slf4j.LoggerFactory;
import org.apache.celeborn.client.ShuffleClient;
import org.apache.celeborn.common.CelebornConf;
+import org.apache.celeborn.common.exception.CelebornRuntimeException;
import org.apache.celeborn.common.network.protocol.TransportMessage;
import
org.apache.celeborn.common.protocol.message.ControlMessages.GetReducerFileGroupResponse;
import org.apache.celeborn.common.util.JavaUtils;
@@ -138,11 +139,17 @@ public class SparkUtils {
Boolean isWriter) {
if (handle.throwsFetchFailure()) {
String appShuffleIdentifier =
getAppShuffleIdentifier(handle.shuffleId(), context);
- return client.getShuffleId(
- handle.shuffleId(),
- appShuffleIdentifier,
- isWriter,
- context instanceof BarrierTaskContext);
+ Tuple2<Integer, Boolean> res =
+ client.getShuffleId(
+ handle.shuffleId(),
+ appShuffleIdentifier,
+ isWriter,
+ context instanceof BarrierTaskContext);
+ if (!res._2) {
+ throw new CelebornRuntimeException(String.format("Get invalid shuffle
id %s", res._1));
+ } else {
+ return res._1;
+ }
} else {
return handle.shuffleId();
}
diff --git
a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
b/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
index 3e296b310..3d9f14f23 100644
---
a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
+++
b/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
@@ -40,7 +40,7 @@ import org.apache.celeborn.client.{ClientUtils, ShuffleClient}
import org.apache.celeborn.client.ShuffleClientImpl.ReduceFileGroups
import org.apache.celeborn.client.read.{CelebornInputStream, MetricsCallback}
import org.apache.celeborn.common.CelebornConf
-import org.apache.celeborn.common.exception.{CelebornIOException,
PartitionUnRetryAbleException}
+import org.apache.celeborn.common.exception.{CelebornIOException,
CelebornRuntimeException, PartitionUnRetryAbleException}
import org.apache.celeborn.common.network.client.TransportClient
import org.apache.celeborn.common.network.protocol.TransportMessage
import org.apache.celeborn.common.protocol._
@@ -79,7 +79,25 @@ class CelebornShuffleReader[K, C](
val startTime = System.currentTimeMillis()
val serializerInstance = newSerializerInstance(dep)
- val shuffleId = SparkUtils.celebornShuffleId(shuffleClient, handle,
context, false)
+ val shuffleId =
+ try {
+ SparkUtils.celebornShuffleId(shuffleClient, handle, context, false)
+ } catch {
+ case e: CelebornRuntimeException =>
+ logError(s"Failed to get shuffleId for appShuffleId
${handle.shuffleId}", e)
+ if (throwsFetchFailure) {
+ throw new FetchFailedException(
+ null,
+ handle.shuffleId,
+ -1,
+ -1,
+ startPartition,
+ SparkUtils.FETCH_FAILURE_ERROR_MSG + handle.shuffleId + "/" +
handle.shuffleId,
+ e)
+ } else {
+ throw e
+ }
+ }
shuffleIdTracker.track(handle.shuffleId, shuffleId)
logDebug(
s"get shuffleId $shuffleId for appShuffleId ${handle.shuffleId}
attemptNum ${context.stageAttemptNumber()}")
diff --git
a/client/src/main/java/org/apache/celeborn/client/DummyShuffleClient.java
b/client/src/main/java/org/apache/celeborn/client/DummyShuffleClient.java
index 6b3673b18..dd1a032c8 100644
--- a/client/src/main/java/org/apache/celeborn/client/DummyShuffleClient.java
+++ b/client/src/main/java/org/apache/celeborn/client/DummyShuffleClient.java
@@ -32,6 +32,8 @@ import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
+import scala.Tuple2;
+
import org.apache.commons.lang3.tuple.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -182,9 +184,9 @@ public class DummyShuffleClient extends ShuffleClient {
}
@Override
- public int getShuffleId(
+ public Tuple2<Integer, Boolean> getShuffleId(
int appShuffleId, String appShuffleIdentifier, boolean isWriter, boolean
isBarrierStage) {
- return appShuffleId;
+ return Tuple2.apply(appShuffleId, true);
}
@Override
diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java
b/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java
index dde2b36c4..bf0192e4a 100644
--- a/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java
+++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java
@@ -26,6 +26,8 @@ import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.LongAdder;
import java.util.function.BiFunction;
+import scala.Tuple2;
+
import org.apache.commons.lang3.tuple.Pair;
import org.apache.hadoop.fs.FileSystem;
import org.slf4j.Logger;
@@ -285,7 +287,7 @@ public abstract class ShuffleClient {
public abstract PushState getPushState(String mapKey);
- public abstract int getShuffleId(
+ public abstract Tuple2<Integer, Boolean> getShuffleId(
int appShuffleId, String appShuffleIdentifier, boolean isWriter, boolean
isBarrierStage);
/**
diff --git
a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
index c6fc7f6bd..81329e8d1 100644
--- a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
+++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
@@ -103,7 +103,7 @@ public class ShuffleClientImpl extends ShuffleClient {
protected byte[] extension;
// key: appShuffleIdentifier, value: shuffleId
- protected Map<String, Integer> shuffleIdCache =
JavaUtils.newConcurrentHashMap();
+ protected Map<String, Tuple2<Integer, Boolean>> shuffleIdCache =
JavaUtils.newConcurrentHashMap();
// key: shuffleId, value: (partitionId, PartitionLocation)
final Map<Integer, ConcurrentHashMap<Integer, PartitionLocation>>
reducePartitionMap =
@@ -626,7 +626,7 @@ public class ShuffleClientImpl extends ShuffleClient {
}
@Override
- public int getShuffleId(
+ public Tuple2<Integer, Boolean> getShuffleId(
int appShuffleId, String appShuffleIdentifier, boolean isWriter, boolean
isBarrierStage) {
return shuffleIdCache.computeIfAbsent(
appShuffleIdentifier,
@@ -643,7 +643,8 @@ public class ShuffleClientImpl extends ShuffleClient {
pbGetShuffleId,
conf.clientRpcRegisterShuffleAskTimeout(),
ClassTag$.MODULE$.apply(PbGetShuffleIdResponse.class));
- return pbGetShuffleIdResponse.getShuffleId();
+ return Tuple2.apply(
+ pbGetShuffleIdResponse.getShuffleId(),
pbGetShuffleIdResponse.getSuccess());
});
}
diff --git
a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
index 923d2f838..20e1099d2 100644
--- a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
+++ b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
@@ -892,7 +892,8 @@ class LifecycleManager(val appUniqueId: String, val conf:
CelebornConf) extends
if (shuffleIds == null) {
logWarning(s"unknown appShuffleId $appShuffleId, maybe no shuffle data
for this shuffle")
val pbGetShuffleIdResponse =
-
PbGetShuffleIdResponse.newBuilder().setShuffleId(UNKNOWN_APP_SHUFFLE_ID).build()
+
PbGetShuffleIdResponse.newBuilder().setShuffleId(UNKNOWN_APP_SHUFFLE_ID).setSuccess(
+ true).build()
context.reply(pbGetShuffleIdResponse)
return
}
@@ -906,7 +907,7 @@ class LifecycleManager(val appUniqueId: String, val conf:
CelebornConf) extends
shuffleIds.get(appShuffleIdentifier) match {
case Some((shuffleId, _)) =>
val pbGetShuffleIdResponse =
-
PbGetShuffleIdResponse.newBuilder().setShuffleId(shuffleId).build()
+
PbGetShuffleIdResponse.newBuilder().setShuffleId(shuffleId).setSuccess(true).build()
context.reply(pbGetShuffleIdResponse)
case None =>
Option(appShuffleDeterminateMap.get(appShuffleId)).map {
determinate =>
@@ -940,7 +941,7 @@ class LifecycleManager(val appUniqueId: String, val conf:
CelebornConf) extends
newShuffleId
}
val pbGetShuffleIdResponse =
-
PbGetShuffleIdResponse.newBuilder().setShuffleId(shuffleId).build()
+
PbGetShuffleIdResponse.newBuilder().setShuffleId(shuffleId).setSuccess(true).build()
context.reply(pbGetShuffleIdResponse)
}.orElse(
throw new UnsupportedOperationException(
@@ -953,12 +954,17 @@ class LifecycleManager(val appUniqueId: String, val conf:
CelebornConf) extends
val pbGetShuffleIdResponse = {
logDebug(
s"get shuffleId $shuffleId for appShuffleId $appShuffleId
appShuffleIdentifier $appShuffleIdentifier isWriter $isWriter")
-
PbGetShuffleIdResponse.newBuilder().setShuffleId(shuffleId).build()
+
PbGetShuffleIdResponse.newBuilder().setShuffleId(shuffleId).setSuccess(true).build()
}
context.reply(pbGetShuffleIdResponse)
case None =>
- throw new UnsupportedOperationException(
- s"unexpected! there is no finished map stage associated with
appShuffleId $appShuffleId")
+ val pbGetShuffleIdResponse = {
+ logInfo(
+ s"there is no finished map stage associated with appShuffleId
$appShuffleId")
+
PbGetShuffleIdResponse.newBuilder().setShuffleId(UNKNOWN_APP_SHUFFLE_ID).setSuccess(
+ false).build()
+ }
+ context.reply(pbGetShuffleIdResponse)
}
}
}
diff --git a/common/src/main/proto/TransportMessages.proto
b/common/src/main/proto/TransportMessages.proto
index acf355756..7b0d0bec2 100644
--- a/common/src/main/proto/TransportMessages.proto
+++ b/common/src/main/proto/TransportMessages.proto
@@ -403,6 +403,7 @@ message PbGetShuffleId {
message PbGetShuffleIdResponse {
int32 shuffleId = 1;
+ bool success = 2;
}
message PbReportShuffleFetchFailure {
diff --git
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/SparkTestBase.scala
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/SparkTestBase.scala
index 055b763e5..e29b21a0c 100644
---
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/SparkTestBase.scala
+++
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/SparkTestBase.scala
@@ -139,7 +139,8 @@ trait SparkTestBase extends AnyFunSuite
conf,
h.userIdentifier,
h.extension)
- val celebornShuffleId =
SparkUtils.celebornShuffleId(shuffleClient, h, context, false)
+ val celebornShuffleId =
+ SparkUtils.celebornShuffleId(shuffleClient, h, context, false)
val allFiles = workerDirs.map(dir => {
new
File(s"$dir/celeborn-worker/shuffle_data/$appUniqueId/$celebornShuffleId")
})