This is an automated email from the ASF dual-hosted git repository.

zhouky pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-celeborn.git


The following commit(s) were added to refs/heads/main by this push:
     new 92bebd305 [CELEBORN-1160] Avoid calling parmap when commit files
92bebd305 is described below

commit 92bebd305d9a71acb7c9cd256e9f924ede77ef95
Author: zky.zhoukeyong <[email protected]>
AuthorDate: Wed Dec 13 14:36:48 2023 +0800

    [CELEBORN-1160] Avoid calling parmap when commit files
    
    ### What changes were proposed in this pull request?
    As title
    
    ### Why are the changes needed?
    One user reported that LifecycleManager's parmap can create a huge number of 
threads and cause OOM.
    
    
![image](https://github.com/apache/incubator-celeborn/assets/948245/1e9a0b83-32fe-40d5-8739-2b370e030fc8)
    
    There are four places where parmap is called:
    
    1. When LifecycleManager commits files
    2. When LifecycleManager reserves slots
    3. When LifecycleManager setup connection to workers
    4. When StorageManager calls close
    
    This PR fixes the first one. To be more specific, this PR eliminates `parmap` 
when committing files, and also replaces `askSync` with `ask`.
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    Manual test and GA.
    
    Closes #2145 from waitinfuture/1160.
    
    Lead-authored-by: zky.zhoukeyong <[email protected]>
    Co-authored-by: Keyong Zhou <[email protected]>
    Signed-off-by: zky.zhoukeyong <[email protected]>
---
 .../org/apache/celeborn/client/CommitManager.scala |  32 +-
 .../apache/celeborn/client/LifecycleManager.scala  |   2 +
 .../celeborn/client/commit/CommitHandler.scala     | 353 +++++++++++++--------
 .../client/commit/MapPartitionCommitHandler.scala  |   7 +-
 .../commit/ReducePartitionCommitHandler.scala      |  12 +-
 .../org/apache/celeborn/common/CelebornConf.scala  |   9 +
 docs/configuration/client.md                       |   1 +
 .../celeborn/tests/spark/RetryReviveTest.scala     |   1 +
 8 files changed, 271 insertions(+), 146 deletions(-)

diff --git 
a/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala 
b/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala
index a59d1ca91..e6b8d0e24 100644
--- a/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala
+++ b/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala
@@ -22,13 +22,14 @@ import java.util.concurrent.{ConcurrentHashMap, 
ScheduledExecutorService, Schedu
 import java.util.concurrent.atomic.{AtomicInteger, LongAdder}
 
 import scala.collection.JavaConverters._
+import scala.collection.mutable.ArrayBuffer
 import scala.concurrent.duration.DurationInt
 
 import org.roaringbitmap.RoaringBitmap
 
 import org.apache.celeborn.client.CommitManager.CommittedPartitionInfo
 import org.apache.celeborn.client.LifecycleManager.ShuffleFailedWorkers
-import org.apache.celeborn.client.commit.{CommitHandler, 
MapPartitionCommitHandler, ReducePartitionCommitHandler}
+import org.apache.celeborn.client.commit.{CommitFilesParam, CommitHandler, 
MapPartitionCommitHandler, ReducePartitionCommitHandler}
 import org.apache.celeborn.client.listener.{WorkersStatus, 
WorkerStatusListener}
 import org.apache.celeborn.common.CelebornConf
 import org.apache.celeborn.common.internal.Logging
@@ -113,13 +114,9 @@ class CommitManager(appUniqueId: String, val conf: 
CelebornConf, lifecycleManage
 
                     if (workerToRequests.nonEmpty) {
                       val commitFilesFailedWorkers = new ShuffleFailedWorkers()
-                      val parallelism =
-                        Math.min(workerToRequests.size, 
conf.clientRpcMaxParallelism)
                       try {
-                        ThreadUtils.parmap(
-                          workerToRequests,
-                          "CommitFiles",
-                          parallelism) {
+                        val params = new 
ArrayBuffer[CommitFilesParam](workerToRequests.size)
+                        workerToRequests.foreach {
                           case (worker, requests) =>
                             val workerInfo =
                               lifecycleManager.shuffleAllocatedWorkers
@@ -141,15 +138,18 @@ class CommitManager(appUniqueId: String, val conf: 
CelebornConf, lifecycleManage
                                 .toList
                                 .asJava
 
-                            commitHandler.commitFiles(
-                              appUniqueId,
-                              shuffleId,
-                              shuffleCommittedInfo,
+                            params += CommitFilesParam(
                               workerInfo,
                               primaryIds,
-                              replicaIds,
-                              commitFilesFailedWorkers)
+                              replicaIds)
                         }
+
+                        commitHandler.doParallelCommitFiles(
+                          shuffleId,
+                          shuffleCommittedInfo,
+                          params,
+                          commitFilesFailedWorkers)
+
                         
lifecycleManager.workerStatusTracker.recordWorkerFailure(
                           commitFilesFailedWorkers)
                       } finally {
@@ -277,13 +277,15 @@ class CommitManager(appUniqueId: String, val conf: 
CelebornConf, lifecycleManage
               conf,
               lifecycleManager.shuffleAllocatedWorkers,
               committedPartitionInfo,
-              lifecycleManager.workerStatusTracker)
+              lifecycleManager.workerStatusTracker,
+              lifecycleManager.rpcSharedThreadPool)
           case PartitionType.MAP => new MapPartitionCommitHandler(
               appUniqueId,
               conf,
               lifecycleManager.shuffleAllocatedWorkers,
               committedPartitionInfo,
-              lifecycleManager.workerStatusTracker)
+              lifecycleManager.workerStatusTracker,
+              lifecycleManager.rpcSharedThreadPool)
           case _ => throw new UnsupportedOperationException(
               s"Unexpected ShufflePartitionType for CommitManager: 
$partitionType")
         }
diff --git 
a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala 
b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
index 9454ae2c0..3f1cf7aab 100644
--- a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
+++ b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
@@ -135,6 +135,8 @@ class LifecycleManager(val appUniqueId: String, val conf: 
CelebornConf) extends
   private val forwardMessageThread =
     
ThreadUtils.newDaemonSingleThreadScheduledExecutor("master-forward-message-thread")
   private var checkForShuffleRemoval: ScheduledFuture[_] = _
+  val rpcSharedThreadPool =
+    ThreadUtils.newDaemonCachedThreadPool("shared-rpc-pool", 
conf.clientRpcSharedThreads, 30)
 
   // init driver celeborn LifecycleManager rpc service
   override val rpcEnv: RpcEnv = RpcEnv.create(
diff --git 
a/client/src/main/scala/org/apache/celeborn/client/commit/CommitHandler.scala 
b/client/src/main/scala/org/apache/celeborn/client/commit/CommitHandler.scala
index 947546416..641868519 100644
--- 
a/client/src/main/scala/org/apache/celeborn/client/commit/CommitHandler.scala
+++ 
b/client/src/main/scala/org/apache/celeborn/client/commit/CommitHandler.scala
@@ -18,11 +18,15 @@
 package org.apache.celeborn.client.commit
 
 import java.util
-import java.util.concurrent.{ConcurrentHashMap, TimeUnit}
+import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue, 
ThreadPoolExecutor, TimeUnit}
 import java.util.concurrent.atomic.{AtomicLong, LongAdder}
 
 import scala.collection.JavaConverters._
+import scala.collection.generic.CanBuildFrom
 import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+import scala.concurrent.{ExecutionContext, Future}
+import scala.concurrent.duration.Duration
 
 import org.apache.celeborn.client.{ShuffleCommittedInfo, WorkerStatusTracker}
 import org.apache.celeborn.client.CommitManager.CommittedPartitionInfo
@@ -37,6 +41,17 @@ import org.apache.celeborn.common.rpc.{RpcCallContext, 
RpcEndpointRef}
 import org.apache.celeborn.common.util.{CollectionUtils, JavaUtils, 
ThreadUtils, Utils}
 // Can Remove this if celeborn don't support scala211 in future
 import org.apache.celeborn.common.util.FunctionConverter._
+import org.apache.celeborn.common.util.ThreadUtils.awaitResult
+
+case class CommitFilesParam(
+    worker: WorkerInfo,
+    primaryIds: util.List[String],
+    replicaIds: util.List[String])
+
+case class FutureWithStatus(
+    var future: Future[CommitFilesResponse],
+    commitFilesParam: CommitFilesParam,
+    var retriedTimes: Int)
 
 case class CommitResult(
     primaryPartitionLocationMap: ConcurrentHashMap[String, PartitionLocation],
@@ -47,7 +62,8 @@ abstract class CommitHandler(
     appUniqueId: String,
     conf: CelebornConf,
     committedPartitionInfo: CommittedPartitionInfo,
-    workerStatusTracker: WorkerStatusTracker) extends Logging {
+    workerStatusTracker: WorkerStatusTracker,
+    val sharedRpcPool: ThreadPoolExecutor) extends Logging {
 
   private val pushReplicateEnabled = conf.clientPushReplicateEnabled
   private val testRetryCommitFiles = conf.testRetryCommitFiles
@@ -57,6 +73,8 @@ abstract class CommitHandler(
   private val fileCount = new LongAdder
   protected val reducerFileGroupsMap = new ShuffleFileGroups
 
+  val ec = ExecutionContext.fromExecutor(sharedRpcPool)
+
   def getPartitionType(): PartitionType
 
   def isStageEnd(shuffleId: Int): Boolean = false
@@ -180,6 +198,146 @@ abstract class CommitHandler(
     reducerFileGroupsMap.put(shuffleId, JavaUtils.newConcurrentHashMap())
   }
 
+  def doParallelCommitFiles(
+      shuffleId: Int,
+      shuffleCommittedInfo: ShuffleCommittedInfo,
+      params: ArrayBuffer[CommitFilesParam],
+      commitFilesFailedWorkers: ShuffleFailedWorkers): Unit = {
+
+    def processResponse(res: CommitFilesResponse, worker: WorkerInfo): Unit = {
+      shuffleCommittedInfo.synchronized {
+        // record committed partitionIds
+        res.committedPrimaryIds.asScala.foreach {
+          case commitPrimaryId =>
+            val partitionUniqueIdList = 
shuffleCommittedInfo.committedPrimaryIds.computeIfAbsent(
+              Utils.splitPartitionLocationUniqueId(commitPrimaryId)._1,
+              (k: Int) => new util.ArrayList[String]())
+            partitionUniqueIdList.add(commitPrimaryId)
+        }
+
+        res.committedReplicaIds.asScala.foreach {
+          case commitReplicaId =>
+            val partitionUniqueIdList = 
shuffleCommittedInfo.committedReplicaIds.computeIfAbsent(
+              Utils.splitPartitionLocationUniqueId(commitReplicaId)._1,
+              (k: Int) => new util.ArrayList[String]())
+            partitionUniqueIdList.add(commitReplicaId)
+        }
+
+        // record committed partitions storage hint and disk hint
+        
shuffleCommittedInfo.committedPrimaryStorageInfos.putAll(res.committedPrimaryStorageInfos)
+        
shuffleCommittedInfo.committedReplicaStorageInfos.putAll(res.committedReplicaStorageInfos)
+
+        // record failed partitions
+        shuffleCommittedInfo.failedPrimaryPartitionIds.putAll(
+          res.failedPrimaryIds.asScala.map((_, worker)).toMap.asJava)
+        shuffleCommittedInfo.failedReplicaPartitionIds.putAll(
+          res.failedReplicaIds.asScala.map((_, worker)).toMap.asJava)
+
+        
shuffleCommittedInfo.committedMapIdBitmap.putAll(res.committedMapIdBitMap)
+
+        totalWritten.add(res.totalWritten)
+        fileCount.add(res.fileCount)
+        shuffleCommittedInfo.currentShuffleFileCount.add(res.fileCount)
+      }
+    }
+
+    val futures = new LinkedBlockingQueue[FutureWithStatus]()
+
+    val outFutures = params.filter(param =>
+      !CollectionUtils.isEmpty(param.primaryIds) ||
+        !CollectionUtils.isEmpty(param.replicaIds)) map { param =>
+      Future {
+        val future = commitFiles(
+          appUniqueId,
+          shuffleId,
+          param.worker,
+          param.primaryIds,
+          param.replicaIds)
+
+        futures.add(FutureWithStatus(future, param, 1))
+      }(ec)
+    }
+    val cbf =
+      implicitly[
+        CanBuildFrom[ArrayBuffer[Future[Boolean]], Boolean, 
ArrayBuffer[Boolean]]]
+    val futureSeq = Future.sequence(outFutures)(cbf, ec)
+    awaitResult(futureSeq, Duration.Inf)
+
+    val maxRetries = conf.clientRequestCommitFilesMaxRetries
+    var timeout = conf.rpcAskTimeout.duration.toMillis * maxRetries
+    val delta = 50
+    while (timeout >= 0 && !futures.isEmpty) {
+      val iter = futures.iterator()
+      while (iter.hasNext) {
+        val status = iter.next()
+        if (status.future.isCompleted) {
+          status.future.value.get match {
+            case scala.util.Success(res) =>
+              val worker = status.commitFilesParam.worker
+              res.status match {
+                case StatusCode.SUCCESS => // do nothing
+                case StatusCode.PARTIAL_SUCCESS | 
StatusCode.SHUFFLE_NOT_REGISTERED | StatusCode.REQUEST_FAILED | 
StatusCode.WORKER_EXCLUDED =>
+                  logInfo(s"Request commitFiles return ${res.status} for " +
+                    s"${Utils.makeShuffleKey(appUniqueId, shuffleId)}")
+                  if (res.status != StatusCode.WORKER_EXCLUDED) {
+                    commitFilesFailedWorkers.put(worker, (res.status, 
System.currentTimeMillis()))
+                  }
+                case _ =>
+                  logError(s"Should never reach here! commit files response 
status ${res.status}")
+              }
+
+              processResponse(res, worker)
+              iter.remove()
+            case scala.util.Failure(e) =>
+              val worker = status.commitFilesParam.worker
+              logError(
+                s"Ask worker($worker) CommitFiles for $shuffleId failed" +
+                  s" (attempt ${status.retriedTimes}/$maxRetries).",
+                e)
+              if (status.retriedTimes < maxRetries) {
+                status.retriedTimes = status.retriedTimes + 1
+                status.future = commitFiles(
+                  appUniqueId,
+                  shuffleId,
+                  status.commitFilesParam.worker,
+                  status.commitFilesParam.primaryIds,
+                  status.commitFilesParam.replicaIds)
+              } else {
+                val res = CommitFilesResponse(
+                  StatusCode.REQUEST_FAILED,
+                  List.empty.asJava,
+                  List.empty.asJava,
+                  status.commitFilesParam.primaryIds,
+                  status.commitFilesParam.replicaIds)
+                processResponse(res, status.commitFilesParam.worker)
+                iter.remove()
+              }
+          }
+        }
+      }
+
+      if (!futures.isEmpty) {
+        Thread.sleep(delta)
+      }
+      timeout = timeout - delta
+    }
+
+    val iter = futures.iterator()
+    while (iter.hasNext) {
+      val status = iter.next()
+      logError(
+        s"Ask worker(${status.commitFilesParam.worker}) CommitFiles for 
$shuffleId timed out")
+      val res = CommitFilesResponse(
+        StatusCode.REQUEST_FAILED,
+        List.empty.asJava,
+        List.empty.asJava,
+        status.commitFilesParam.primaryIds,
+        status.commitFilesParam.replicaIds)
+      processResponse(res, status.commitFilesParam.worker)
+      iter.remove()
+    }
+  }
+
   def parallelCommitFiles(
       shuffleId: Int,
       allocatedWorkers: util.Map[WorkerInfo, ShufflePartitionLocationInfo],
@@ -195,11 +353,9 @@ abstract class CommitHandler(
 
     val commitFileStartTime = System.nanoTime()
     val workerPartitionLocations = 
allocatedWorkers.asScala.filter(!_._2.isEmpty)
-    val parallelism = Math.min(workerPartitionLocations.size, 
conf.clientRpcMaxParallelism)
-    ThreadUtils.parmap(
-      workerPartitionLocations,
-      "CommitFiles",
-      parallelism) { case (worker, partitionLocationInfo) =>
+
+    val params = new 
ArrayBuffer[CommitFilesParam](workerPartitionLocations.size)
+    workerPartitionLocations.foreach { case (worker, partitionLocationInfo) =>
       val primaryParts =
         partitionLocationInfo.getPrimaryPartitions(partitionIdOpt)
       val replicaParts = 
partitionLocationInfo.getReplicaPartitions(partitionIdOpt)
@@ -226,17 +382,14 @@ abstract class CommitHandler(
             .map(_.getUniqueId).toList.asJava)
       }
 
-      commitFiles(
-        appUniqueId,
-        shuffleId,
-        shuffleCommittedInfo,
+      params += CommitFilesParam(
         worker,
         primaryIds,
-        replicaIds,
-        commitFilesFailedWorkers)
-
+        replicaIds)
     }
 
+    doParallelCommitFiles(shuffleId, shuffleCommittedInfo, params, 
commitFilesFailedWorkers)
+
     logInfo(s"Shuffle $shuffleId " +
       s"commit files complete. File count 
${shuffleCommittedInfo.currentShuffleFileCount.sum()} " +
       s"using ${TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - 
commitFileStartTime)} ms")
@@ -247,119 +400,68 @@ abstract class CommitHandler(
   def commitFiles(
       applicationId: String,
       shuffleId: Int,
-      shuffleCommittedInfo: ShuffleCommittedInfo,
       worker: WorkerInfo,
       primaryIds: util.List[String],
-      replicaIds: util.List[String],
-      commitFilesFailedWorkers: ShuffleFailedWorkers): Unit = {
-
-    if (CollectionUtils.isEmpty(primaryIds) && 
CollectionUtils.isEmpty(replicaIds)) {
-      return
-    }
+      replicaIds: util.List[String]): Future[CommitFilesResponse] = {
 
-    val res =
-      if (!testRetryCommitFiles) {
-        val commitFiles = CommitFiles(
-          applicationId,
-          shuffleId,
-          primaryIds,
-          replicaIds,
-          getMapperAttempts(shuffleId),
-          commitEpoch.incrementAndGet())
-        val res =
-          if (conf.clientCommitFilesIgnoreExcludedWorkers &&
-            workerStatusTracker.excludedWorkers.containsKey(worker)) {
-            CommitFilesResponse(
-              StatusCode.WORKER_EXCLUDED,
-              List.empty.asJava,
-              List.empty.asJava,
-              primaryIds,
-              replicaIds)
-          } else {
-            requestCommitFilesWithRetry(worker.endpoint, commitFiles)
-          }
-
-        res.status match {
-          case StatusCode.SUCCESS => // do nothing
-          case StatusCode.PARTIAL_SUCCESS | StatusCode.SHUFFLE_NOT_REGISTERED 
| StatusCode.REQUEST_FAILED | StatusCode.WORKER_EXCLUDED =>
-            logInfo(s"Request $commitFiles return ${res.status} for " +
-              s"${Utils.makeShuffleKey(applicationId, shuffleId)}")
-            if (res.status != StatusCode.WORKER_EXCLUDED) {
-              commitFilesFailedWorkers.put(worker, (res.status, 
System.currentTimeMillis()))
-            }
-          case _ =>
-            logError(s"Should never reach here! commit files response status 
${res.status}")
-        }
-        res
+    if (!testRetryCommitFiles) {
+      val commitFiles = CommitFiles(
+        applicationId,
+        shuffleId,
+        primaryIds,
+        replicaIds,
+        getMapperAttempts(shuffleId),
+        commitEpoch.incrementAndGet())
+
+      if (conf.clientCommitFilesIgnoreExcludedWorkers &&
+        workerStatusTracker.excludedWorkers.containsKey(worker)) {
+        Future {
+          CommitFilesResponse(
+            StatusCode.WORKER_EXCLUDED,
+            List.empty.asJava,
+            List.empty.asJava,
+            primaryIds,
+            replicaIds)
+        }(ec)
       } else {
-        // for test
-        val commitFiles1 = CommitFiles(
-          applicationId,
-          shuffleId,
-          primaryIds.subList(0, primaryIds.size() / 2),
-          replicaIds.subList(0, replicaIds.size() / 2),
-          getMapperAttempts(shuffleId),
-          commitEpoch.incrementAndGet())
-        val res1 = requestCommitFilesWithRetry(worker.endpoint, commitFiles1)
-
-        val commitFiles = CommitFiles(
-          applicationId,
-          shuffleId,
-          primaryIds.subList(primaryIds.size() / 2, primaryIds.size()),
-          replicaIds.subList(replicaIds.size() / 2, replicaIds.size()),
-          getMapperAttempts(shuffleId),
-          commitEpoch.incrementAndGet())
-        val res2 = requestCommitFilesWithRetry(worker.endpoint, commitFiles)
-
-        
res1.committedPrimaryStorageInfos.putAll(res2.committedPrimaryStorageInfos)
-        
res1.committedReplicaStorageInfos.putAll(res2.committedReplicaStorageInfos)
-        res1.committedMapIdBitMap.putAll(res2.committedMapIdBitMap)
-        CommitFilesResponse(
-          status = if (res1.status == StatusCode.SUCCESS) res2.status else 
res1.status,
-          (res1.committedPrimaryIds.asScala ++ 
res2.committedPrimaryIds.asScala).toList.asJava,
-          (res1.committedReplicaIds.asScala ++ 
res2.committedReplicaIds.asScala).toList.asJava,
-          (res1.failedPrimaryIds.asScala ++ 
res2.failedPrimaryIds.asScala).toList.asJava,
-          (res1.failedReplicaIds.asScala ++ 
res2.failedReplicaIds.asScala).toList.asJava,
-          res1.committedPrimaryStorageInfos,
-          res1.committedReplicaStorageInfos,
-          res1.committedMapIdBitMap,
-          res1.totalWritten + res2.totalWritten,
-          res1.fileCount + res2.fileCount)
+        worker.endpoint.ask[CommitFilesResponse](commitFiles)
       }
-
-    shuffleCommittedInfo.synchronized {
-      // record committed partitionIds
-      res.committedPrimaryIds.asScala.foreach({
-        case commitPrimaryId =>
-          val partitionUniqueIdList = 
shuffleCommittedInfo.committedPrimaryIds.computeIfAbsent(
-            Utils.splitPartitionLocationUniqueId(commitPrimaryId)._1,
-            (k: Int) => new util.ArrayList[String]())
-          partitionUniqueIdList.add(commitPrimaryId)
-      })
-
-      res.committedReplicaIds.asScala.foreach({
-        case commitReplicaId =>
-          val partitionUniqueIdList = 
shuffleCommittedInfo.committedReplicaIds.computeIfAbsent(
-            Utils.splitPartitionLocationUniqueId(commitReplicaId)._1,
-            (k: Int) => new util.ArrayList[String]())
-          partitionUniqueIdList.add(commitReplicaId)
-      })
-
-      // record committed partitions storage hint and disk hint
-      
shuffleCommittedInfo.committedPrimaryStorageInfos.putAll(res.committedPrimaryStorageInfos)
-      
shuffleCommittedInfo.committedReplicaStorageInfos.putAll(res.committedReplicaStorageInfos)
-
-      // record failed partitions
-      shuffleCommittedInfo.failedPrimaryPartitionIds.putAll(
-        res.failedPrimaryIds.asScala.map((_, worker)).toMap.asJava)
-      shuffleCommittedInfo.failedReplicaPartitionIds.putAll(
-        res.failedReplicaIds.asScala.map((_, worker)).toMap.asJava)
-
-      
shuffleCommittedInfo.committedMapIdBitmap.putAll(res.committedMapIdBitMap)
-
-      totalWritten.add(res.totalWritten)
-      fileCount.add(res.fileCount)
-      shuffleCommittedInfo.currentShuffleFileCount.add(res.fileCount)
+    } else {
+      // for test
+      val commitFiles1 = CommitFiles(
+        applicationId,
+        shuffleId,
+        primaryIds.subList(0, primaryIds.size() / 2),
+        replicaIds.subList(0, replicaIds.size() / 2),
+        getMapperAttempts(shuffleId),
+        commitEpoch.incrementAndGet())
+      val res1 = requestCommitFilesWithRetryForTest(worker.endpoint, 
commitFiles1)
+
+      val commitFiles = CommitFiles(
+        applicationId,
+        shuffleId,
+        primaryIds.subList(primaryIds.size() / 2, primaryIds.size()),
+        replicaIds.subList(replicaIds.size() / 2, replicaIds.size()),
+        getMapperAttempts(shuffleId),
+        commitEpoch.incrementAndGet())
+      val res2 = requestCommitFilesWithRetryForTest(worker.endpoint, 
commitFiles)
+
+      
res1.committedPrimaryStorageInfos.putAll(res2.committedPrimaryStorageInfos)
+      
res1.committedReplicaStorageInfos.putAll(res2.committedReplicaStorageInfos)
+      res1.committedMapIdBitMap.putAll(res2.committedMapIdBitMap)
+      val res = CommitFilesResponse(
+        status = if (res1.status == StatusCode.SUCCESS) res2.status else 
res1.status,
+        (res1.committedPrimaryIds.asScala ++ 
res2.committedPrimaryIds.asScala).toList.asJava,
+        (res1.committedReplicaIds.asScala ++ 
res2.committedReplicaIds.asScala).toList.asJava,
+        (res1.failedPrimaryIds.asScala ++ 
res2.failedPrimaryIds.asScala).toList.asJava,
+        (res1.failedReplicaIds.asScala ++ 
res2.failedReplicaIds.asScala).toList.asJava,
+        res1.committedPrimaryStorageInfos,
+        res1.committedReplicaStorageInfos,
+        res1.committedMapIdBitMap,
+        res1.totalWritten + res2.totalWritten,
+        res1.fileCount + res2.fileCount)
+
+      Future { res }(ec)
     }
   }
 
@@ -402,14 +504,14 @@ abstract class CommitHandler(
     }
   }
 
-  private def requestCommitFilesWithRetry(
+  private def requestCommitFilesWithRetryForTest(
       endpoint: RpcEndpointRef,
       message: CommitFiles): CommitFilesResponse = {
     val maxRetries = conf.clientRequestCommitFilesMaxRetries
     var retryTimes = 0
     while (retryTimes < maxRetries) {
       try {
-        if (testRetryCommitFiles && retryTimes < maxRetries - 1) {
+        if (retryTimes < maxRetries - 1) {
           endpoint.ask[CommitFilesResponse](message)
           Thread.sleep(1000)
           throw new Exception("Mock fail for CommitFiles")
@@ -420,7 +522,8 @@ abstract class CommitHandler(
         case e: Throwable =>
           retryTimes += 1
           logError(
-            s"AskSync worker(${endpoint.address}) CommitFiles for 
${message.shuffleId} failed (attempt $retryTimes/$maxRetries).",
+            s"Ask worker(${endpoint.address}) CommitFiles for 
${message.shuffleId} failed" +
+              s" (attempt $retryTimes/$maxRetries).",
             e)
       }
     }
diff --git 
a/client/src/main/scala/org/apache/celeborn/client/commit/MapPartitionCommitHandler.scala
 
b/client/src/main/scala/org/apache/celeborn/client/commit/MapPartitionCommitHandler.scala
index 54d05671f..b799d0870 100644
--- 
a/client/src/main/scala/org/apache/celeborn/client/commit/MapPartitionCommitHandler.scala
+++ 
b/client/src/main/scala/org/apache/celeborn/client/commit/MapPartitionCommitHandler.scala
@@ -19,7 +19,7 @@ package org.apache.celeborn.client.commit
 
 import java.util
 import java.util.Collections
-import java.util.concurrent.ConcurrentHashMap
+import java.util.concurrent.{ConcurrentHashMap, ThreadPoolExecutor}
 import java.util.concurrent.atomic.AtomicInteger
 
 import scala.collection.JavaConverters._
@@ -52,8 +52,9 @@ class MapPartitionCommitHandler(
     conf: CelebornConf,
     shuffleAllocatedWorkers: ShuffleAllocatedWorkers,
     committedPartitionInfo: CommittedPartitionInfo,
-    workerStatusTracker: WorkerStatusTracker)
-  extends CommitHandler(appId, conf, committedPartitionInfo, 
workerStatusTracker)
+    workerStatusTracker: WorkerStatusTracker,
+    sharedRpcPool: ThreadPoolExecutor)
+  extends CommitHandler(appId, conf, committedPartitionInfo, 
workerStatusTracker, sharedRpcPool)
   with Logging {
 
   private val shuffleSucceedPartitionIds = JavaUtils.newConcurrentHashMap[Int, 
util.Set[Integer]]()
diff --git 
a/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala
 
b/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala
index b93d36998..e86bc2317 100644
--- 
a/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala
+++ 
b/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala
@@ -19,7 +19,7 @@ package org.apache.celeborn.client.commit
 
 import java.nio.ByteBuffer
 import java.util
-import java.util.concurrent.{Callable, ConcurrentHashMap, TimeUnit}
+import java.util.concurrent.{Callable, ConcurrentHashMap, ThreadPoolExecutor, 
TimeUnit}
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable
@@ -51,8 +51,14 @@ class ReducePartitionCommitHandler(
     conf: CelebornConf,
     shuffleAllocatedWorkers: ShuffleAllocatedWorkers,
     committedPartitionInfo: CommittedPartitionInfo,
-    workerStatusTracker: WorkerStatusTracker)
-  extends CommitHandler(appUniqueId, conf, committedPartitionInfo, 
workerStatusTracker)
+    workerStatusTracker: WorkerStatusTracker,
+    sharedRpcPool: ThreadPoolExecutor)
+  extends CommitHandler(
+    appUniqueId,
+    conf,
+    committedPartitionInfo,
+    workerStatusTracker,
+    sharedRpcPool)
   with Logging {
 
   private val getReducerFileGroupRequest =
diff --git 
a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala 
b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
index c24e84524..41934f50e 100644
--- a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
+++ b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
@@ -825,6 +825,7 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable 
with Logging with Se
   def clientPushStageEndTimeout: Long = get(CLIENT_PUSH_STAGE_END_TIMEOUT)
   def clientPushUnsafeRowFastWrite: Boolean = 
get(CLIENT_PUSH_UNSAFEROW_FASTWRITE_ENABLED)
   def clientRpcCacheExpireTime: Long = get(CLIENT_RPC_CACHE_EXPIRE_TIME)
+  def clientRpcSharedThreads: Int = get(CLIENT_RPC_SHARED_THREADS)
   def pushDataTimeoutMs: Long = get(CLIENT_PUSH_DATA_TIMEOUT)
   def clientPushLimitStrategy: String = get(CLIENT_PUSH_LIMIT_STRATEGY)
   def clientPushSlowStartInitialSleepTime: Long = 
get(CLIENT_PUSH_SLOW_START_INITIAL_SLEEP_TIME)
@@ -3703,6 +3704,14 @@ object CelebornConf extends Logging {
       .timeConf(TimeUnit.MILLISECONDS)
       .createWithDefaultString("15s")
 
+  val CLIENT_RPC_SHARED_THREADS: ConfigEntry[Int] =
+    buildConf("celeborn.client.rpc.shared.threads")
+      .categories("client")
+      .version("0.4.0")
+      .doc("Number of shared rpc threads in LifecycleManager.")
+      .intConf
+      .createWithDefault(16)
+
   val CLIENT_RESERVE_SLOTS_RACKAWARE_ENABLED: ConfigEntry[Boolean] =
     buildConf("celeborn.client.reserveSlots.rackaware.enabled")
       .withAlternative("celeborn.client.reserveSlots.rackware.enabled")
diff --git a/docs/configuration/client.md b/docs/configuration/client.md
index dfab73579..7890bf77e 100644
--- a/docs/configuration/client.md
+++ b/docs/configuration/client.md
@@ -82,6 +82,7 @@ license: |
 | celeborn.client.rpc.registerShuffle.askTimeout | &lt;value of 
celeborn.&lt;module&gt;.io.connectionTimeout&gt; | Timeout for ask operations 
during register shuffle. During this process, there are two times for retry 
opportunities for requesting slots, one request for establishing a connection 
with Worker and `celeborn.client.reserveSlots.maxRetries` times for retry 
opportunities for reserving slots. User can customize this value according to 
your setting. By default, the value is the m [...]
 | celeborn.client.rpc.requestPartition.askTimeout | &lt;value of 
celeborn.&lt;module&gt;.io.connectionTimeout&gt; | Timeout for ask operations 
during requesting change partition location, such as reviving or splitting 
partition. During this process, there are 
`celeborn.client.reserveSlots.maxRetries` times for retry opportunities for 
reserving slots. User can customize this value according to your setting. By 
default, the value is the max timeout value `celeborn.<module>.io.connectionTim 
[...]
 | celeborn.client.rpc.reserveSlots.askTimeout | &lt;value of 
celeborn.rpc.askTimeout&gt; | Timeout for LifecycleManager request reserve 
slots. | 0.3.0 | 
+| celeborn.client.rpc.shared.threads | 16 | Number of shared rpc threads in 
LifecycleManager. | 0.4.0 | 
 | celeborn.client.shuffle.batchHandleChangePartition.interval | 100ms | 
Interval for LifecycleManager to schedule handling change partition requests in 
batch. | 0.3.0 | 
 | celeborn.client.shuffle.batchHandleChangePartition.threads | 8 | Threads 
number for LifecycleManager to handle change partition request in batch. | 
0.3.0 | 
 | celeborn.client.shuffle.batchHandleCommitPartition.interval | 5s | Interval 
for LifecycleManager to schedule handling commit partition requests in batch. | 
0.3.0 | 
diff --git 
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/RetryReviveTest.scala
 
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/RetryReviveTest.scala
index 4d0b42ded..280c16575 100644
--- 
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/RetryReviveTest.scala
+++ 
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/RetryReviveTest.scala
@@ -46,6 +46,7 @@ class RetryReviveTest extends AnyFunSuite
   test("celeborn spark integration test - retry revive as configured times") {
     val sparkConf = new SparkConf()
       .set(s"spark.${CelebornConf.TEST_CLIENT_RETRY_REVIVE.key}", "true")
+      .set(s"spark.${CelebornConf.CLIENT_PUSH_MAX_REVIVE_TIMES.key}", "3")
       .setAppName("celeborn-demo").setMaster("local[2]")
     val ss = SparkSession.builder()
       .config(updateSparkConf(sparkConf, ShuffleMode.HASH))

Reply via email to