This is an automated email from the ASF dual-hosted git repository.

zhouky pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-celeborn.git


The following commit(s) were added to refs/heads/main by this push:
     new 4303be323 [CELEBORN-1165] Avoid calling parmap when reserving slots
4303be323 is described below

commit 4303be32317cbbb473095035bb19c876094c1024
Author: zky.zhoukeyong <[email protected]>
AuthorDate: Wed Dec 13 16:37:20 2023 +0800

    [CELEBORN-1165] Avoid calling parmap when reserving slots
    
    ### What changes were proposed in this pull request?
    As title
    
    ### Why are the changes needed?
    One user reported that LifecycleManager's parmap can create a huge number of 
threads and cause OOM.
    
    
![image](https://github.com/apache/incubator-celeborn/assets/948245/1e9a0b83-32fe-40d5-8739-2b370e030fc8)
    
    There are four places where parmap is called:
    
    1. When LifecycleManager commits files
    2. When LifecycleManager reserves slots
    3. When LifecycleManager sets up connections to workers
    4. When StorageManager calls close
    
    This PR fixes the second one. In more detail, this PR eliminates 
`parmap` when reserving slots, and also replaces `askSync` with `ask`.
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    Manual test and GA.
    
    Closes #2152 from waitinfuture/1165-1.
    
    Lead-authored-by: zky.zhoukeyong <[email protected]>
    Co-authored-by: Cheng Pan <[email protected]>
    Signed-off-by: zky.zhoukeyong <[email protected]>
---
 .../apache/celeborn/client/LifecycleManager.scala  | 126 +++++++++++++++------
 1 file changed, 92 insertions(+), 34 deletions(-)

diff --git 
a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala 
b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
index 3f1cf7aab..5d39e8b76 100644
--- a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
+++ b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
@@ -20,12 +20,16 @@ package org.apache.celeborn.client
 import java.nio.ByteBuffer
 import java.util
 import java.util.{function, List => JList}
-import java.util.concurrent.{Callable, ConcurrentHashMap, ScheduledFuture, 
TimeUnit}
+import java.util.concurrent.{Callable, ConcurrentHashMap, LinkedBlockingQueue, 
ScheduledFuture, TimeUnit}
 import java.util.concurrent.atomic.AtomicInteger
 import java.util.function.Consumer
 
 import scala.collection.JavaConverters._
+import scala.collection.generic.CanBuildFrom
 import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+import scala.concurrent.{ExecutionContext, Future}
+import scala.concurrent.duration.Duration
 import scala.util.Random
 
 import com.google.common.annotations.VisibleForTesting
@@ -48,6 +52,7 @@ import 
org.apache.celeborn.common.rpc.netty.{LocalNettyRpcCallContext, RemoteNet
 import org.apache.celeborn.common.util.{JavaUtils, PbSerDeUtils, ThreadUtils, 
Utils}
 // Can Remove this if celeborn don't support scala211 in future
 import org.apache.celeborn.common.util.FunctionConverter._
+import org.apache.celeborn.common.util.ThreadUtils.awaitResult
 import org.apache.celeborn.common.util.Utils.UNKNOWN_APP_SHUFFLE_ID
 
 object LifecycleManager {
@@ -137,6 +142,7 @@ class LifecycleManager(val appUniqueId: String, val conf: 
CelebornConf) extends
   private var checkForShuffleRemoval: ScheduledFuture[_] = _
   val rpcSharedThreadPool =
     ThreadUtils.newDaemonCachedThreadPool("shared-rpc-pool", 
conf.clientRpcSharedThreads, 30)
+  val ec = ExecutionContext.fromExecutor(rpcSharedThreadPool)
 
   // init driver celeborn LifecycleManager rpc service
   override val rpcEnv: RpcEnv = RpcEnv.create(
@@ -888,42 +894,94 @@ class LifecycleManager(val appUniqueId: String, val conf: 
CelebornConf) extends
     val reserveSlotFailedWorkers = new ShuffleFailedWorkers()
     val failureInfos = new util.concurrent.CopyOnWriteArrayList[String]()
     val workerPartitionLocations = slots.asScala.filter(p => !p._2._1.isEmpty 
|| !p._2._2.isEmpty)
-    val parallelism =
-      Math.min(Math.max(1, workerPartitionLocations.size), 
conf.clientRpcMaxParallelism)
-    ThreadUtils.parmap(workerPartitionLocations, "ReserveSlot", parallelism) {
-      case (workerInfo, (primaryLocations, replicaLocations)) =>
-        val res =
-          if (workerInfo.endpoint == null) {
-            ReserveSlotsResponse(StatusCode.REQUEST_FAILED, s"$workerInfo 
endpoint is NULL!")
-          } else {
-            requestWorkerReserveSlots(
-              workerInfo.endpoint,
-              ReserveSlots(
-                appUniqueId,
-                shuffleId,
-                primaryLocations,
-                replicaLocations,
-                partitionSplitThreshold,
-                partitionSplitMode,
-                getPartitionType(shuffleId),
-                rangeReadFilter,
-                userIdentifier,
-                conf.pushDataTimeoutMs,
-                if (getPartitionType(shuffleId) == PartitionType.MAP)
-                  conf.clientShuffleMapPartitionSplitEnabled
-                else true))
+
+    val (locsWithNullEndpoint, locs) = 
workerPartitionLocations.partition(_._1.endpoint == null)
+    val futures = new LinkedBlockingQueue[(Future[ReserveSlotsResponse], 
WorkerInfo)]()
+    val outFutures = locs.map { case (workerInfo, (primaryLocations, 
replicaLocations)) =>
+      Future {
+        val future = workerInfo.endpoint.ask[ReserveSlotsResponse](
+          ReserveSlots(
+            appUniqueId,
+            shuffleId,
+            primaryLocations,
+            replicaLocations,
+            partitionSplitThreshold,
+            partitionSplitMode,
+            getPartitionType(shuffleId),
+            rangeReadFilter,
+            userIdentifier,
+            conf.pushDataTimeoutMs,
+            if (getPartitionType(shuffleId) == PartitionType.MAP)
+              conf.clientShuffleMapPartitionSplitEnabled
+            else true))
+        futures.add((future, workerInfo))
+      }(ec)
+    }
+    val cbf =
+      implicitly[
+        CanBuildFrom[mutable.Iterable[Future[Boolean]], Boolean, 
mutable.Iterable[Boolean]]]
+    val futureSeq = Future.sequence(outFutures)(cbf, ec)
+    awaitResult(futureSeq, Duration.Inf)
+
+    var timeout = conf.rpcAskTimeout.duration.toMillis
+    val delta = 50
+    while (timeout >= 0 && !futures.isEmpty) {
+      val iter = futures.iterator()
+      while (iter.hasNext) {
+        val (future, workerInfo) = iter.next()
+        if (future.isCompleted) {
+          future.value.get match {
+            case scala.util.Success(res) =>
+              if (res.status.equals(StatusCode.SUCCESS)) {
+                logDebug(s"Successfully allocated " +
+                  s"partitions buffer for shuffleId $shuffleId" +
+                  s" from worker ${workerInfo.readableAddress()}.")
+              } else {
+                failureInfos.add(s"[reserveSlots] Failed to" +
+                  s" reserve buffers for shuffleId $shuffleId" +
+                  s" from worker ${workerInfo.readableAddress()}. Reason: 
${res.reason}")
+                reserveSlotFailedWorkers.put(workerInfo, (res.status, 
System.currentTimeMillis()))
+              }
+            case scala.util.Failure(e) =>
+              failureInfos.add(s"[reserveSlots] Failed to" +
+                s" reserve buffers for shuffleId $shuffleId" +
+                s" from worker ${workerInfo.readableAddress()}. Reason: $e")
+              reserveSlotFailedWorkers.put(
+                workerInfo,
+                (StatusCode.REQUEST_FAILED, System.currentTimeMillis()))
           }
-        if (res.status.equals(StatusCode.SUCCESS)) {
-          logDebug(s"Successfully allocated " +
-            s"partitions buffer for shuffleId $shuffleId" +
-            s" from worker ${workerInfo.readableAddress()}.")
-        } else {
-          failureInfos.add(s"[reserveSlots] Failed to" +
-            s" reserve buffers for shuffleId $shuffleId" +
-            s" from worker ${workerInfo.readableAddress()}. Reason: 
${res.reason}")
-          reserveSlotFailedWorkers.put(workerInfo, (res.status, 
System.currentTimeMillis()))
+          iter.remove()
         }
+      }
+
+      if (!futures.isEmpty) {
+        Thread.sleep(delta)
+      }
+      timeout = timeout - delta
+    }
+
+    val iter = futures.iterator()
+    while (iter.hasNext) {
+      val futureStatus = iter.next()
+      val workerInfo = futureStatus._2
+      failureInfos.add(s"[reserveSlots] Failed to" +
+        s" reserve buffers for shuffleId $shuffleId" +
+        s" from worker ${workerInfo.readableAddress()}. Reason: Timeout")
+      reserveSlotFailedWorkers.put(
+        workerInfo,
+        (StatusCode.REQUEST_FAILED, System.currentTimeMillis()))
+      iter.remove()
+    }
+
+    locsWithNullEndpoint.foreach { case (workerInfo, (_, _)) =>
+      failureInfos.add(s"[reserveSlots] Failed to" +
+        s" reserve buffers for shuffleId $shuffleId" +
+        s" from worker ${workerInfo.readableAddress()}. Reason: null endpoint")
+      reserveSlotFailedWorkers.put(
+        workerInfo,
+        (StatusCode.REQUEST_FAILED, System.currentTimeMillis()))
     }
+
     if (failureInfos.asScala.nonEmpty) {
       logError(s"Aggregated error of reserveSlots for " +
         s"shuffleId $shuffleId " +

Reply via email to