This is an automated email from the ASF dual-hosted git repository.

ethanfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/celeborn.git


The following commit(s) were added to refs/heads/main by this push:
     new ec62d924c [CELEBORN-2000] Ignore the getReducerFileGroup timeout 
before shuffle stage end
ec62d924c is described below

commit ec62d924c5e8fc9854c4af97c89acb14c4c1b6df
Author: Wang, Fei <[email protected]>
AuthorDate: Tue May 20 14:16:46 2025 +0800

    [CELEBORN-2000] Ignore the getReducerFileGroup timeout before shuffle stage 
end
    
    ### What changes were proposed in this pull request?
    
    Ignore a getReducerFileGroup timeout that occurs before the shuffle stage has ended, and retry instead.
    ### Why are the changes needed?
    
    1. If the getReducerFileGroup timeout is caused by a LifecycleManager 
commitFiles timeout (i.e. the stage has not ended yet),
    2. many tasks may fail without reporting a fetch failure,
    3. which eventually causes the Spark application to fail.
    
    The shuffle client should ignore the getReducerFileGroup timeout until the 
LifecycleManager's commitFiles has completed.
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    
    Unit test (added the "is shuffle stage end" test in ShuffleClientSuite).
    
    Closes #3263 from turboFei/is_stage_end.
    
    Lead-authored-by: Wang, Fei <[email protected]>
    Co-authored-by: Fei Wang <[email protected]>
    Signed-off-by: mingji <[email protected]>
---
 .../shuffle/celeborn/CelebornShuffleReader.scala   | 39 ++++++++++++++++------
 .../apache/celeborn/client/DummyShuffleClient.java |  5 +++
 .../org/apache/celeborn/client/ShuffleClient.java  |  2 ++
 .../apache/celeborn/client/ShuffleClientImpl.java  | 16 +++++++++
 .../org/apache/celeborn/client/CommitManager.scala |  4 +++
 .../apache/celeborn/client/LifecycleManager.scala  |  9 +++++
 .../celeborn/client/commit/CommitHandler.scala     |  9 +++++
 .../commit/ReducePartitionCommitHandler.scala      |  6 +++-
 .../celeborn/client/WithShuffleClientSuite.scala   |  4 +--
 common/src/main/proto/TransportMessages.proto      | 10 ++++++
 .../common/protocol/message/ControlMessages.scala  | 12 +++++++
 .../celeborn/tests/client/ShuffleClientSuite.scala | 18 ++++++++++
 12 files changed, 121 insertions(+), 13 deletions(-)

diff --git 
a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
 
b/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
index 3d9f14f23..958e08196 100644
--- 
a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
+++ 
b/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
@@ -129,16 +129,35 @@ class CelebornShuffleReader[K, C](
     val localHostAddress = Utils.localHostName(conf)
     val shuffleKey = Utils.makeShuffleKey(handle.appUniqueId, shuffleId)
     var fileGroups: ReduceFileGroups = null
-    try {
-      // startPartition is irrelevant
-      fileGroups = shuffleClient.updateFileGroup(shuffleId, startPartition)
-    } catch {
-      case ce @ (_: CelebornIOException | _: PartitionUnRetryAbleException) =>
-        // if a task is interrupted, should not report fetch failure
-        // if a task update file group timeout, should not report fetch failure
-        checkAndReportFetchFailureForUpdateFileGroupFailure(shuffleId, ce)
-      case e: Throwable => throw e
-    }
+    var isShuffleStageEnd: Boolean = false
+    var updateFileGroupsRetryTimes = 0
+    do {
+      isShuffleStageEnd =
+        try {
+          shuffleClient.isShuffleStageEnd(shuffleId)
+        } catch {
+          case e: Exception =>
+            logInfo(s"Failed to check shuffle stage end for $shuffleId, assume 
ended", e)
+            true
+        }
+      try {
+        // startPartition is irrelevant
+        fileGroups = shuffleClient.updateFileGroup(shuffleId, startPartition)
+      } catch {
+        case ce: CelebornIOException
+            if ce.getCause != null && ce.getCause.isInstanceOf[
+              TimeoutException] && !isShuffleStageEnd =>
+          updateFileGroupsRetryTimes += 1
+          logInfo(
+            s"UpdateFileGroup for $shuffleKey timeout due to shuffle stage not 
ended," +
+              s" retry again, retry times $updateFileGroupsRetryTimes",
+            ce)
+        case ce @ (_: CelebornIOException | _: PartitionUnRetryAbleException) 
=>
+          // if a task is interrupted, should not report fetch failure
+          // if a task update file group timeout, should not report fetch 
failure
+          checkAndReportFetchFailureForUpdateFileGroupFailure(shuffleId, ce)
+      }
+    } while (fileGroups == null)
 
     val batchOpenStreamStartTime = System.currentTimeMillis()
     // host-port -> (TransportClient, PartitionLocation Array, 
PbOpenStreamList)
diff --git 
a/client/src/main/java/org/apache/celeborn/client/DummyShuffleClient.java 
b/client/src/main/java/org/apache/celeborn/client/DummyShuffleClient.java
index 648309489..95e02ecfd 100644
--- a/client/src/main/java/org/apache/celeborn/client/DummyShuffleClient.java
+++ b/client/src/main/java/org/apache/celeborn/client/DummyShuffleClient.java
@@ -131,6 +131,11 @@ public class DummyShuffleClient extends ShuffleClient {
     return null;
   }
 
+  @Override
+  public boolean isShuffleStageEnd(int shuffleId) throws Exception {
+    return true;
+  }
+
   @Override
   public CelebornInputStream readPartition(
       int shuffleId,
diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java 
b/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java
index 6363e9004..206f61170 100644
--- a/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java
+++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java
@@ -217,6 +217,8 @@ public abstract class ShuffleClient {
   public abstract ShuffleClientImpl.ReduceFileGroups updateFileGroup(int 
shuffleId, int partitionId)
       throws CelebornIOException;
 
+  public abstract boolean isShuffleStageEnd(int shuffleId) throws Exception;
+
   // Reduce side read partition which is deduplicated by 
mapperId+mapperAttemptNum+batchId, batchId
   // is a self-incrementing variable hidden in the implementation when sending 
data.
   /**
diff --git 
a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java 
b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
index 6e80f90c2..9c09e3604 100644
--- a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
+++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
@@ -1870,6 +1870,22 @@ public class ShuffleClientImpl extends ShuffleClient {
     return updateFileGroup(shuffleId, partitionId, false);
   }
 
+  @Override
+  public boolean isShuffleStageEnd(int shuffleId) throws Exception {
+    if (null != lifecycleManagerRef) {
+      PbGetStageEnd request = 
PbGetStageEnd.newBuilder().setShuffleId(shuffleId).build();
+      PbGetStageEndResponse response =
+          lifecycleManagerRef.askSync(
+              request,
+              rpcMaxRetries,
+              rpcRetryWait,
+              ClassTag$.MODULE$.apply(PbGetStageEndResponse.class));
+      return response.getStageEnd();
+    } else {
+      throw new RuntimeException("Driver endpoint is null!");
+    }
+  }
+
   public ReduceFileGroups updateFileGroup(
       int shuffleId, int partitionId, boolean isSegmentGranularityVisible)
       throws CelebornIOException {
diff --git 
a/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala 
b/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala
index 702441dce..8e2e43c89 100644
--- a/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala
+++ b/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala
@@ -283,6 +283,10 @@ class CommitManager(appUniqueId: String, val conf: 
CelebornConf, lifecycleManage
     getCommitHandler(shuffleId).handleGetReducerFileGroup(context, shuffleId, 
serdeVersion)
   }
 
+  def handleGetStageEnd(context: RpcCallContext, shuffleId: Int): Unit = {
+    getCommitHandler(shuffleId).handleGetStageEnd(context, shuffleId)
+  }
+
   // exposed for test
   def getCommitHandler(shuffleId: Int): CommitHandler = {
     val partitionType = lifecycleManager.getPartitionType(shuffleId)
diff --git 
a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala 
b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
index ae692f908..69920ee84 100644
--- a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
+++ b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
@@ -444,6 +444,11 @@ class LifecycleManager(val appUniqueId: String, val conf: 
CelebornConf) extends
         s"Received GetShuffleFileGroup request for shuffleId $shuffleId, 
isSegmentGranularityVisible $isSegmentGranularityVisible")
       handleGetReducerFileGroup(context, shuffleId, 
isSegmentGranularityVisible, serdeVersion)
 
+    case pb: PbGetStageEnd =>
+      val shuffleId = pb.getShuffleId
+      logDebug(s"Received GetStageEnd request for shuffleId $shuffleId")
+      handleGetStageEnd(context, shuffleId)
+
     case pb: PbGetShuffleId =>
       val appShuffleId = pb.getAppShuffleId
       val appShuffleIdentifier = pb.getAppShuffleIdentifier
@@ -869,6 +874,10 @@ class LifecycleManager(val appUniqueId: String, val conf: 
CelebornConf) extends
     commitManager.handleGetReducerFileGroup(context, shuffleId, serdeVersion)
   }
 
+  private def handleGetStageEnd(context: RpcCallContext, shuffleId: Int): Unit 
= {
+    commitManager.handleGetStageEnd(context, shuffleId)
+  }
+
   private def handleGetShuffleIdForApp(
       context: RpcCallContext,
       appShuffleId: Int,
diff --git 
a/client/src/main/scala/org/apache/celeborn/client/commit/CommitHandler.scala 
b/client/src/main/scala/org/apache/celeborn/client/commit/CommitHandler.scala
index 42ea0379f..c6f535d33 100644
--- 
a/client/src/main/scala/org/apache/celeborn/client/commit/CommitHandler.scala
+++ 
b/client/src/main/scala/org/apache/celeborn/client/commit/CommitHandler.scala
@@ -184,6 +184,15 @@ abstract class CommitHandler(
       shuffleId: Int,
       serdeVersion: SerdeVersion): Unit
 
+  /**
+   * Only Reduce partition mode supports get stage end.
+   */
+  def handleGetStageEnd(context: RpcCallContext, shuffleId: Int): Unit = {
+    throw new UnsupportedOperationException(
+      "Failed when do handleGetStageEnd Operation, MapPartition shuffleType 
don't " +
+        "support stage end")
+  }
+
   def removeExpiredShuffle(shuffleId: Int): Unit = {
     reducerFileGroupsMap.remove(shuffleId)
     shufflePushFailedBatches.remove(shuffleId)
diff --git 
a/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala
 
b/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala
index 45c371ee5..944e07657 100644
--- 
a/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala
+++ 
b/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala
@@ -35,7 +35,7 @@ import org.apache.celeborn.common.CelebornConf
 import org.apache.celeborn.common.internal.Logging
 import org.apache.celeborn.common.meta.ShufflePartitionLocationInfo
 import org.apache.celeborn.common.network.protocol.SerdeVersion
-import org.apache.celeborn.common.protocol.{PartitionLocation, PartitionType}
+import org.apache.celeborn.common.protocol.{PartitionLocation, PartitionType, 
PbGetStageEndResponse}
 import 
org.apache.celeborn.common.protocol.message.ControlMessages.GetReducerFileGroupResponse
 import org.apache.celeborn.common.protocol.message.StatusCode
 import org.apache.celeborn.common.rpc.RpcCallContext
@@ -427,6 +427,10 @@ class ReducePartitionCommitHandler(
     }
   }
 
+  override def handleGetStageEnd(context: RpcCallContext, shuffleId: Int): 
Unit = {
+    
context.reply(PbGetStageEndResponse.newBuilder().setStageEnd(isStageEnd(shuffleId)).build())
+  }
+
   override def waitStageEnd(shuffleId: Int): (Boolean, Long) = {
     var timeout = stageEndTimeout
     val delta = 100
diff --git 
a/client/src/test/scala/org/apache/celeborn/client/WithShuffleClientSuite.scala 
b/client/src/test/scala/org/apache/celeborn/client/WithShuffleClientSuite.scala
index 0570760ce..8ec00d262 100644
--- 
a/client/src/test/scala/org/apache/celeborn/client/WithShuffleClientSuite.scala
+++ 
b/client/src/test/scala/org/apache/celeborn/client/WithShuffleClientSuite.scala
@@ -40,7 +40,7 @@ trait WithShuffleClientSuite extends CelebornFunSuite {
   private val mapId = 1
   private val attemptId = 0
 
-  private var lifecycleManager: LifecycleManager = _
+  protected var lifecycleManager: LifecycleManager = _
   protected var shuffleClient: ShuffleClientImpl = _
 
   var _shuffleId = 0
@@ -188,7 +188,7 @@ trait WithShuffleClientSuite extends CelebornFunSuite {
     Assert.assertEquals(stream.read(), -1)
   }
 
-  private def prepareService(): Unit = {
+  protected def prepareService(): Unit = {
     lifecycleManager = new LifecycleManager(APP, celebornConf)
     shuffleClient = new ShuffleClientImpl(APP, celebornConf, userIdentifier)
     shuffleClient.setupLifecycleManagerRef(lifecycleManager.self)
diff --git a/common/src/main/proto/TransportMessages.proto 
b/common/src/main/proto/TransportMessages.proto
index 4ca924582..d53d56842 100644
--- a/common/src/main/proto/TransportMessages.proto
+++ b/common/src/main/proto/TransportMessages.proto
@@ -112,6 +112,8 @@ enum MessageType {
   REVISE_LOST_SHUFFLES = 89;
   REVISE_LOST_SHUFFLES_RESPONSE = 90;
   PUSH_MERGED_DATA_SPLIT_PARTITION_INFO = 91;
+  GET_STAGE_END = 92;
+  GET_STAGE_END_RESPONSE = 93;
 }
 
 enum StreamType {
@@ -906,6 +908,14 @@ message PbPushMergedDataSplitPartitionInfo {
   repeated int32 statusCodes = 2;
 }
 
+message PbGetStageEnd {
+  int32 shuffleId = 1;
+}
+
+message PbGetStageEndResponse {
+  bool stageEnd = 1;
+}
+
 message PbChunkOffsets {
   repeated int64 chunkOffset = 1;
 }
diff --git 
a/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
 
b/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
index 009c3a4fc..04c91e03b 100644
--- 
a/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
+++ 
b/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
@@ -617,6 +617,12 @@ object ControlMessages extends Logging {
     case pb: PbPushMergedDataSplitPartitionInfo =>
       new TransportMessage(MessageType.PUSH_MERGED_DATA_SPLIT_PARTITION_INFO, 
pb.toByteArray)
 
+    case pb: PbGetStageEnd =>
+      new TransportMessage(MessageType.GET_STAGE_END, pb.toByteArray)
+
+    case pb: PbGetStageEndResponse =>
+      new TransportMessage(MessageType.GET_STAGE_END_RESPONSE, pb.toByteArray)
+
     case HeartbeatFromWorker(
           host,
           rpcPort,
@@ -1465,6 +1471,12 @@ object ControlMessages extends Logging {
 
       case PUSH_MERGED_DATA_SPLIT_PARTITION_INFO_VALUE =>
         PbPushMergedDataSplitPartitionInfo.parseFrom(message.getPayload)
+
+      case GET_STAGE_END_VALUE =>
+        PbGetStageEnd.parseFrom(message.getPayload)
+
+      case GET_STAGE_END_RESPONSE_VALUE =>
+        PbGetStageEndResponse.parseFrom(message.getPayload)
     }
   }
 }
diff --git 
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/ShuffleClientSuite.scala
 
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/ShuffleClientSuite.scala
index 112174fbf..88eb52ef5 100644
--- 
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/ShuffleClientSuite.scala
+++ 
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/ShuffleClientSuite.scala
@@ -18,9 +18,13 @@
 package org.apache.celeborn.tests.client
 
 import java.io.IOException
+import java.util
+
+import scala.collection.JavaConverters._
 
 import org.apache.celeborn.client.{LifecycleManager, ShuffleClientImpl, 
WithShuffleClientSuite}
 import org.apache.celeborn.common.CelebornConf
+import org.apache.celeborn.common.protocol.message.StatusCode
 import org.apache.celeborn.service.deploy.MiniClusterFeature
 
 class ShuffleClientSuite extends WithShuffleClientSuite with 
MiniClusterFeature {
@@ -62,6 +66,20 @@ class ShuffleClientSuite extends WithShuffleClientSuite with 
MiniClusterFeature
     lifecycleManager.stop()
   }
 
+  test("is shuffle stage end") {
+    prepareService()
+    val shuffleId = 0
+    val counts = 10
+    val ids =
+      new util.ArrayList[Integer]((0 until counts).toList.map(x => 
Integer.valueOf(x)).asJava)
+    val res = lifecycleManager.requestMasterRequestSlotsWithRetry(shuffleId, 
ids)
+    assert(res.status == StatusCode.SUCCESS)
+    lifecycleManager.registeredShuffle.add(shuffleId)
+    assert(!shuffleClient.isShuffleStageEnd(shuffleId))
+    lifecycleManager.commitManager.setStageEnd(shuffleId)
+    assert(shuffleClient.isShuffleStageEnd(shuffleId))
+  }
+
   override def afterAll(): Unit = {
     logInfo("all test complete , stop celeborn mini cluster")
     shutdownMiniCluster()

Reply via email to