This is an automated email from the ASF dual-hosted git repository.

rexxiong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/celeborn.git


The following commit(s) were added to refs/heads/main by this push:
     new 045411ac3 [CELEBORN-1855] LifecycleManager return appshuffleId for non 
barrier stage when fetch fail has been reported
045411ac3 is described below

commit 045411ac34c749d8ca2d7bed8d0b9c3554f71287
Author: lijianfu03 <[email protected]>
AuthorDate: Tue May 13 16:14:03 2025 +0800

    [CELEBORN-1855] LifecycleManager return appshuffleId for non barrier stage 
when fetch fail has been reported
    
    ### What changes were proposed in this pull request?
    For non-barrier shuffle read stages, 
LifecycleManager#handleGetShuffleIdForApp now always returns the appShuffleId, 
regardless of whether its fetch status is true or false.
    
    ### Why are the changes needed?
    
    As described in 
[jira](https://issues.apache.org/jira/browse/CELEBORN-1855), if 
LifecycleManager only returns an appShuffleId whose fetch status is success, the 
task fails immediately with "there is no finished map stage associated with", 
even though the previously reported fetch failure may not be fatal. This change 
gives such tasks a chance to proceed instead of failing outright.
    
    ### Does this PR introduce _any_ user-facing change?
    
    ### How was this patch tested?
    
    Closes #3090 from buska88/celeborn-1855.
    
    Authored-by: lijianfu03 <[email protected]>
    Signed-off-by: Shuang <[email protected]>
---
 .../apache/spark/shuffle/celeborn/SparkUtils.java  | 17 ++++++++++++-----
 .../shuffle/celeborn/CelebornShuffleReader.scala   | 22 +++++++++++++++++++---
 .../apache/spark/shuffle/celeborn/SparkUtils.java  | 17 ++++++++++++-----
 .../shuffle/celeborn/CelebornShuffleReader.scala   | 22 ++++++++++++++++++++--
 .../apache/celeborn/client/DummyShuffleClient.java |  6 ++++--
 .../org/apache/celeborn/client/ShuffleClient.java  |  4 +++-
 .../apache/celeborn/client/ShuffleClientImpl.java  |  7 ++++---
 .../apache/celeborn/client/LifecycleManager.scala  | 18 ++++++++++++------
 common/src/main/proto/TransportMessages.proto      |  1 +
 .../celeborn/tests/spark/SparkTestBase.scala       |  3 ++-
 10 files changed, 89 insertions(+), 28 deletions(-)

diff --git 
a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java
 
b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java
index 44dd1ea28..b52758581 100644
--- 
a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java
+++ 
b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java
@@ -63,6 +63,7 @@ import org.slf4j.LoggerFactory;
 
 import org.apache.celeborn.client.ShuffleClient;
 import org.apache.celeborn.common.CelebornConf;
+import org.apache.celeborn.common.exception.CelebornRuntimeException;
 import org.apache.celeborn.common.network.protocol.TransportMessage;
 import 
org.apache.celeborn.common.protocol.message.ControlMessages.GetReducerFileGroupResponse;
 import org.apache.celeborn.common.util.JavaUtils;
@@ -161,11 +162,17 @@ public class SparkUtils {
       Boolean isWriter) {
     if (handle.throwsFetchFailure()) {
       String appShuffleIdentifier = 
getAppShuffleIdentifier(handle.shuffleId(), context);
-      return client.getShuffleId(
-          handle.shuffleId(),
-          appShuffleIdentifier,
-          isWriter,
-          context instanceof BarrierTaskContext);
+      Tuple2<Integer, Boolean> res =
+          client.getShuffleId(
+              handle.shuffleId(),
+              appShuffleIdentifier,
+              isWriter,
+              context instanceof BarrierTaskContext);
+      if (!res._2) {
+        throw new CelebornRuntimeException(String.format("Get invalid shuffle 
id %s", res._1));
+      } else {
+        return res._1;
+      }
     } else {
       return handle.shuffleId();
     }
diff --git 
a/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
 
b/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
index 2269aaf68..df63a94b1 100644
--- 
a/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
+++ 
b/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
@@ -33,7 +33,7 @@ import org.apache.celeborn.client.ShuffleClient
 import org.apache.celeborn.client.read.CelebornInputStream
 import org.apache.celeborn.client.read.MetricsCallback
 import org.apache.celeborn.common.CelebornConf
-import org.apache.celeborn.common.exception.{CelebornIOException, 
PartitionUnRetryAbleException}
+import org.apache.celeborn.common.exception.{CelebornIOException, 
CelebornRuntimeException, PartitionUnRetryAbleException}
 import 
org.apache.celeborn.common.protocol.message.ControlMessages.GetReducerFileGroupResponse
 import org.apache.celeborn.common.util.{JavaUtils, ThreadUtils}
 
@@ -63,8 +63,24 @@ class CelebornShuffleReader[K, C](
   override def read(): Iterator[Product2[K, C]] = {
 
     val serializerInstance = dep.serializer.newInstance()
-
-    val shuffleId = SparkUtils.celebornShuffleId(shuffleClient, handle, 
context, false)
+    val shuffleId =
+      try {
+        SparkUtils.celebornShuffleId(shuffleClient, handle, context, false)
+      } catch {
+        case e: CelebornRuntimeException =>
+          logError(s"Failed to get shuffleId for appShuffleId 
${handle.shuffleId}", e)
+          if (handle.throwsFetchFailure) {
+            throw new FetchFailedException(
+              null,
+              handle.shuffleId,
+              -1,
+              startPartition,
+              SparkUtils.FETCH_FAILURE_ERROR_MSG + handle.shuffleId,
+              e)
+          } else {
+            throw e
+          }
+      }
     shuffleIdTracker.track(handle.shuffleId, shuffleId)
     logDebug(
       s"get shuffleId $shuffleId for appShuffleId ${handle.shuffleId} 
attemptNum ${context.stageAttemptNumber()}")
diff --git 
a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java
 
b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java
index edaeb28bd..b2e64565e 100644
--- 
a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java
+++ 
b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java
@@ -66,6 +66,7 @@ import org.slf4j.LoggerFactory;
 
 import org.apache.celeborn.client.ShuffleClient;
 import org.apache.celeborn.common.CelebornConf;
+import org.apache.celeborn.common.exception.CelebornRuntimeException;
 import org.apache.celeborn.common.network.protocol.TransportMessage;
 import 
org.apache.celeborn.common.protocol.message.ControlMessages.GetReducerFileGroupResponse;
 import org.apache.celeborn.common.util.JavaUtils;
@@ -138,11 +139,17 @@ public class SparkUtils {
       Boolean isWriter) {
     if (handle.throwsFetchFailure()) {
       String appShuffleIdentifier = 
getAppShuffleIdentifier(handle.shuffleId(), context);
-      return client.getShuffleId(
-          handle.shuffleId(),
-          appShuffleIdentifier,
-          isWriter,
-          context instanceof BarrierTaskContext);
+      Tuple2<Integer, Boolean> res =
+          client.getShuffleId(
+              handle.shuffleId(),
+              appShuffleIdentifier,
+              isWriter,
+              context instanceof BarrierTaskContext);
+      if (!res._2) {
+        throw new CelebornRuntimeException(String.format("Get invalid shuffle 
id %s", res._1));
+      } else {
+        return res._1;
+      }
     } else {
       return handle.shuffleId();
     }
diff --git 
a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
 
b/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
index 3e296b310..3d9f14f23 100644
--- 
a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
+++ 
b/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
@@ -40,7 +40,7 @@ import org.apache.celeborn.client.{ClientUtils, ShuffleClient}
 import org.apache.celeborn.client.ShuffleClientImpl.ReduceFileGroups
 import org.apache.celeborn.client.read.{CelebornInputStream, MetricsCallback}
 import org.apache.celeborn.common.CelebornConf
-import org.apache.celeborn.common.exception.{CelebornIOException, 
PartitionUnRetryAbleException}
+import org.apache.celeborn.common.exception.{CelebornIOException, 
CelebornRuntimeException, PartitionUnRetryAbleException}
 import org.apache.celeborn.common.network.client.TransportClient
 import org.apache.celeborn.common.network.protocol.TransportMessage
 import org.apache.celeborn.common.protocol._
@@ -79,7 +79,25 @@ class CelebornShuffleReader[K, C](
 
     val startTime = System.currentTimeMillis()
     val serializerInstance = newSerializerInstance(dep)
-    val shuffleId = SparkUtils.celebornShuffleId(shuffleClient, handle, 
context, false)
+    val shuffleId =
+      try {
+        SparkUtils.celebornShuffleId(shuffleClient, handle, context, false)
+      } catch {
+        case e: CelebornRuntimeException =>
+          logError(s"Failed to get shuffleId for appShuffleId 
${handle.shuffleId}", e)
+          if (throwsFetchFailure) {
+            throw new FetchFailedException(
+              null,
+              handle.shuffleId,
+              -1,
+              -1,
+              startPartition,
+              SparkUtils.FETCH_FAILURE_ERROR_MSG + handle.shuffleId + "/" + 
handle.shuffleId,
+              e)
+          } else {
+            throw e
+          }
+      }
     shuffleIdTracker.track(handle.shuffleId, shuffleId)
     logDebug(
       s"get shuffleId $shuffleId for appShuffleId ${handle.shuffleId} 
attemptNum ${context.stageAttemptNumber()}")
diff --git 
a/client/src/main/java/org/apache/celeborn/client/DummyShuffleClient.java 
b/client/src/main/java/org/apache/celeborn/client/DummyShuffleClient.java
index 6b3673b18..dd1a032c8 100644
--- a/client/src/main/java/org/apache/celeborn/client/DummyShuffleClient.java
+++ b/client/src/main/java/org/apache/celeborn/client/DummyShuffleClient.java
@@ -32,6 +32,8 @@ import java.util.Set;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.atomic.AtomicInteger;
 
+import scala.Tuple2;
+
 import org.apache.commons.lang3.tuple.Pair;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -182,9 +184,9 @@ public class DummyShuffleClient extends ShuffleClient {
   }
 
   @Override
-  public int getShuffleId(
+  public Tuple2<Integer, Boolean> getShuffleId(
       int appShuffleId, String appShuffleIdentifier, boolean isWriter, boolean 
isBarrierStage) {
-    return appShuffleId;
+    return Tuple2.apply(appShuffleId, true);
   }
 
   @Override
diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java 
b/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java
index dde2b36c4..bf0192e4a 100644
--- a/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java
+++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java
@@ -26,6 +26,8 @@ import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.atomic.LongAdder;
 import java.util.function.BiFunction;
 
+import scala.Tuple2;
+
 import org.apache.commons.lang3.tuple.Pair;
 import org.apache.hadoop.fs.FileSystem;
 import org.slf4j.Logger;
@@ -285,7 +287,7 @@ public abstract class ShuffleClient {
 
   public abstract PushState getPushState(String mapKey);
 
-  public abstract int getShuffleId(
+  public abstract Tuple2<Integer, Boolean> getShuffleId(
       int appShuffleId, String appShuffleIdentifier, boolean isWriter, boolean 
isBarrierStage);
 
   /**
diff --git 
a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java 
b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
index c6fc7f6bd..81329e8d1 100644
--- a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
+++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
@@ -103,7 +103,7 @@ public class ShuffleClientImpl extends ShuffleClient {
   protected byte[] extension;
 
   // key: appShuffleIdentifier, value: shuffleId
-  protected Map<String, Integer> shuffleIdCache = 
JavaUtils.newConcurrentHashMap();
+  protected Map<String, Tuple2<Integer, Boolean>> shuffleIdCache = 
JavaUtils.newConcurrentHashMap();
 
   // key: shuffleId, value: (partitionId, PartitionLocation)
   final Map<Integer, ConcurrentHashMap<Integer, PartitionLocation>> 
reducePartitionMap =
@@ -626,7 +626,7 @@ public class ShuffleClientImpl extends ShuffleClient {
   }
 
   @Override
-  public int getShuffleId(
+  public Tuple2<Integer, Boolean> getShuffleId(
       int appShuffleId, String appShuffleIdentifier, boolean isWriter, boolean 
isBarrierStage) {
     return shuffleIdCache.computeIfAbsent(
         appShuffleIdentifier,
@@ -643,7 +643,8 @@ public class ShuffleClientImpl extends ShuffleClient {
                   pbGetShuffleId,
                   conf.clientRpcRegisterShuffleAskTimeout(),
                   ClassTag$.MODULE$.apply(PbGetShuffleIdResponse.class));
-          return pbGetShuffleIdResponse.getShuffleId();
+          return Tuple2.apply(
+              pbGetShuffleIdResponse.getShuffleId(), 
pbGetShuffleIdResponse.getSuccess());
         });
   }
 
diff --git 
a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala 
b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
index 923d2f838..20e1099d2 100644
--- a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
+++ b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
@@ -892,7 +892,8 @@ class LifecycleManager(val appUniqueId: String, val conf: 
CelebornConf) extends
     if (shuffleIds == null) {
       logWarning(s"unknown appShuffleId $appShuffleId, maybe no shuffle data 
for this shuffle")
       val pbGetShuffleIdResponse =
-        
PbGetShuffleIdResponse.newBuilder().setShuffleId(UNKNOWN_APP_SHUFFLE_ID).build()
+        
PbGetShuffleIdResponse.newBuilder().setShuffleId(UNKNOWN_APP_SHUFFLE_ID).setSuccess(
+          true).build()
       context.reply(pbGetShuffleIdResponse)
       return
     }
@@ -906,7 +907,7 @@ class LifecycleManager(val appUniqueId: String, val conf: 
CelebornConf) extends
         shuffleIds.get(appShuffleIdentifier) match {
           case Some((shuffleId, _)) =>
             val pbGetShuffleIdResponse =
-              
PbGetShuffleIdResponse.newBuilder().setShuffleId(shuffleId).build()
+              
PbGetShuffleIdResponse.newBuilder().setShuffleId(shuffleId).setSuccess(true).build()
             context.reply(pbGetShuffleIdResponse)
           case None =>
             Option(appShuffleDeterminateMap.get(appShuffleId)).map { 
determinate =>
@@ -940,7 +941,7 @@ class LifecycleManager(val appUniqueId: String, val conf: 
CelebornConf) extends
                   newShuffleId
                 }
               val pbGetShuffleIdResponse =
-                
PbGetShuffleIdResponse.newBuilder().setShuffleId(shuffleId).build()
+                
PbGetShuffleIdResponse.newBuilder().setShuffleId(shuffleId).setSuccess(true).build()
               context.reply(pbGetShuffleIdResponse)
             }.orElse(
               throw new UnsupportedOperationException(
@@ -953,12 +954,17 @@ class LifecycleManager(val appUniqueId: String, val conf: 
CelebornConf) extends
             val pbGetShuffleIdResponse = {
               logDebug(
                 s"get shuffleId $shuffleId for appShuffleId $appShuffleId 
appShuffleIdentifier $appShuffleIdentifier isWriter $isWriter")
-              
PbGetShuffleIdResponse.newBuilder().setShuffleId(shuffleId).build()
+              
PbGetShuffleIdResponse.newBuilder().setShuffleId(shuffleId).setSuccess(true).build()
             }
             context.reply(pbGetShuffleIdResponse)
           case None =>
-            throw new UnsupportedOperationException(
-              s"unexpected! there is no finished map stage associated with 
appShuffleId $appShuffleId")
+            val pbGetShuffleIdResponse = {
+              logInfo(
+                s"there is no finished map stage associated with appShuffleId 
$appShuffleId")
+              
PbGetShuffleIdResponse.newBuilder().setShuffleId(UNKNOWN_APP_SHUFFLE_ID).setSuccess(
+                false).build()
+            }
+            context.reply(pbGetShuffleIdResponse)
         }
       }
     }
diff --git a/common/src/main/proto/TransportMessages.proto 
b/common/src/main/proto/TransportMessages.proto
index acf355756..7b0d0bec2 100644
--- a/common/src/main/proto/TransportMessages.proto
+++ b/common/src/main/proto/TransportMessages.proto
@@ -403,6 +403,7 @@ message PbGetShuffleId {
 
 message PbGetShuffleIdResponse {
   int32 shuffleId = 1;
+  bool success = 2;
 }
 
 message PbReportShuffleFetchFailure {
diff --git 
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/SparkTestBase.scala
 
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/SparkTestBase.scala
index 055b763e5..e29b21a0c 100644
--- 
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/SparkTestBase.scala
+++ 
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/SparkTestBase.scala
@@ -139,7 +139,8 @@ trait SparkTestBase extends AnyFunSuite
               conf,
               h.userIdentifier,
               h.extension)
-            val celebornShuffleId = 
SparkUtils.celebornShuffleId(shuffleClient, h, context, false)
+            val celebornShuffleId =
+              SparkUtils.celebornShuffleId(shuffleClient, h, context, false)
             val allFiles = workerDirs.map(dir => {
               new 
File(s"$dir/celeborn-worker/shuffle_data/$appUniqueId/$celebornShuffleId")
             })

Reply via email to