This is an automated email from the ASF dual-hosted git repository.
zhouky pushed a commit to branch branch-0.3
in repository https://gitbox.apache.org/repos/asf/incubator-celeborn.git
The following commit(s) were added to refs/heads/branch-0.3 by this push:
new c3f8d4ce4 [CELEBORN-1160] Avoid calling parmap when commit files
c3f8d4ce4 is described below
commit c3f8d4ce4f0d0663fe7cd69150045e1904c565c7
Author: zky.zhoukeyong <[email protected]>
AuthorDate: Wed Dec 13 14:36:48 2023 +0800
[CELEBORN-1160] Avoid calling parmap when commit files
### What changes were proposed in this pull request?
As title
### Why are the changes needed?
One user reported that LifecycleManager's parmap can create a huge number of
threads and cause OOM.

There are four places where parmap is called:
1. When LifecycleManager commits files
2. When LifecycleManager reserves slots
3. When LifecycleManager setup connection to workers
4. When StorageManager calls close
This PR fixes the first one. In more detail, this PR eliminates `parmap`
when committing files, and also replaces `askSync` with `ask`.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Manual test and GA.
Closes #2145 from waitinfuture/1160.
Lead-authored-by: zky.zhoukeyong <[email protected]>
Co-authored-by: Keyong Zhou <[email protected]>
Signed-off-by: zky.zhoukeyong <[email protected]>
(cherry picked from commit 92bebd305d9a71acb7c9cd256e9f924ede77ef95)
Signed-off-by: zky.zhoukeyong <[email protected]>
---
.../org/apache/celeborn/client/CommitManager.scala | 32 +-
.../apache/celeborn/client/LifecycleManager.scala | 2 +
.../celeborn/client/commit/CommitHandler.scala | 353 +++++++++++++--------
.../client/commit/MapPartitionCommitHandler.scala | 7 +-
.../commit/ReducePartitionCommitHandler.scala | 12 +-
.../org/apache/celeborn/common/CelebornConf.scala | 9 +
docs/configuration/client.md | 1 +
.../celeborn/tests/spark/RetryReviveTest.scala | 1 +
8 files changed, 271 insertions(+), 146 deletions(-)
diff --git
a/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala
b/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala
index beea26e20..c05646026 100644
--- a/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala
+++ b/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala
@@ -22,13 +22,14 @@ import java.util.concurrent.{ConcurrentHashMap,
ScheduledExecutorService, Schedu
import java.util.concurrent.atomic.{AtomicInteger, LongAdder}
import scala.collection.JavaConverters._
+import scala.collection.mutable.ArrayBuffer
import scala.concurrent.duration.DurationInt
import org.roaringbitmap.RoaringBitmap
import org.apache.celeborn.client.CommitManager.CommittedPartitionInfo
import org.apache.celeborn.client.LifecycleManager.ShuffleFailedWorkers
-import org.apache.celeborn.client.commit.{CommitHandler,
MapPartitionCommitHandler, ReducePartitionCommitHandler}
+import org.apache.celeborn.client.commit.{CommitFilesParam, CommitHandler,
MapPartitionCommitHandler, ReducePartitionCommitHandler}
import org.apache.celeborn.client.listener.{WorkersStatus,
WorkerStatusListener}
import org.apache.celeborn.common.CelebornConf
import org.apache.celeborn.common.internal.Logging
@@ -113,13 +114,9 @@ class CommitManager(appUniqueId: String, val conf:
CelebornConf, lifecycleManage
if (workerToRequests.nonEmpty) {
val commitFilesFailedWorkers = new ShuffleFailedWorkers()
- val parallelism =
- Math.min(workerToRequests.size,
conf.clientRpcMaxParallelism)
try {
- ThreadUtils.parmap(
- workerToRequests,
- "CommitFiles",
- parallelism) {
+ val params = new
ArrayBuffer[CommitFilesParam](workerToRequests.size)
+ workerToRequests.foreach {
case (worker, requests) =>
val workerInfo =
lifecycleManager.shuffleAllocatedWorkers
@@ -141,15 +138,18 @@ class CommitManager(appUniqueId: String, val conf:
CelebornConf, lifecycleManage
.toList
.asJava
- commitHandler.commitFiles(
- appUniqueId,
- shuffleId,
- shuffleCommittedInfo,
+ params += CommitFilesParam(
workerInfo,
primaryIds,
- replicaIds,
- commitFilesFailedWorkers)
+ replicaIds)
}
+
+ commitHandler.doParallelCommitFiles(
+ shuffleId,
+ shuffleCommittedInfo,
+ params,
+ commitFilesFailedWorkers)
+
lifecycleManager.workerStatusTracker.recordWorkerFailure(
commitFilesFailedWorkers)
} finally {
@@ -277,13 +277,15 @@ class CommitManager(appUniqueId: String, val conf:
CelebornConf, lifecycleManage
conf,
lifecycleManager.shuffleAllocatedWorkers,
committedPartitionInfo,
- lifecycleManager.workerStatusTracker)
+ lifecycleManager.workerStatusTracker,
+ lifecycleManager.rpcSharedThreadPool)
case PartitionType.MAP => new MapPartitionCommitHandler(
appUniqueId,
conf,
lifecycleManager.shuffleAllocatedWorkers,
committedPartitionInfo,
- lifecycleManager.workerStatusTracker)
+ lifecycleManager.workerStatusTracker,
+ lifecycleManager.rpcSharedThreadPool)
case _ => throw new UnsupportedOperationException(
s"Unexpected ShufflePartitionType for CommitManager:
$partitionType")
}
diff --git
a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
index cf54dfcd8..c7756febb 100644
--- a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
+++ b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
@@ -125,6 +125,8 @@ class LifecycleManager(val appUniqueId: String, val conf:
CelebornConf) extends
private val forwardMessageThread =
ThreadUtils.newDaemonSingleThreadScheduledExecutor("master-forward-message-thread")
private var checkForShuffleRemoval: ScheduledFuture[_] = _
+ val rpcSharedThreadPool =
+ ThreadUtils.newDaemonCachedThreadPool("shared-rpc-pool",
conf.clientRpcSharedThreads, 30)
// init driver celeborn LifecycleManager rpc service
override val rpcEnv: RpcEnv = RpcEnv.create(
diff --git
a/client/src/main/scala/org/apache/celeborn/client/commit/CommitHandler.scala
b/client/src/main/scala/org/apache/celeborn/client/commit/CommitHandler.scala
index 947546416..641868519 100644
---
a/client/src/main/scala/org/apache/celeborn/client/commit/CommitHandler.scala
+++
b/client/src/main/scala/org/apache/celeborn/client/commit/CommitHandler.scala
@@ -18,11 +18,15 @@
package org.apache.celeborn.client.commit
import java.util
-import java.util.concurrent.{ConcurrentHashMap, TimeUnit}
+import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue,
ThreadPoolExecutor, TimeUnit}
import java.util.concurrent.atomic.{AtomicLong, LongAdder}
import scala.collection.JavaConverters._
+import scala.collection.generic.CanBuildFrom
import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+import scala.concurrent.{ExecutionContext, Future}
+import scala.concurrent.duration.Duration
import org.apache.celeborn.client.{ShuffleCommittedInfo, WorkerStatusTracker}
import org.apache.celeborn.client.CommitManager.CommittedPartitionInfo
@@ -37,6 +41,17 @@ import org.apache.celeborn.common.rpc.{RpcCallContext,
RpcEndpointRef}
import org.apache.celeborn.common.util.{CollectionUtils, JavaUtils,
ThreadUtils, Utils}
// Can Remove this if celeborn don't support scala211 in future
import org.apache.celeborn.common.util.FunctionConverter._
+import org.apache.celeborn.common.util.ThreadUtils.awaitResult
+
+case class CommitFilesParam(
+ worker: WorkerInfo,
+ primaryIds: util.List[String],
+ replicaIds: util.List[String])
+
+case class FutureWithStatus(
+ var future: Future[CommitFilesResponse],
+ commitFilesParam: CommitFilesParam,
+ var retriedTimes: Int)
case class CommitResult(
primaryPartitionLocationMap: ConcurrentHashMap[String, PartitionLocation],
@@ -47,7 +62,8 @@ abstract class CommitHandler(
appUniqueId: String,
conf: CelebornConf,
committedPartitionInfo: CommittedPartitionInfo,
- workerStatusTracker: WorkerStatusTracker) extends Logging {
+ workerStatusTracker: WorkerStatusTracker,
+ val sharedRpcPool: ThreadPoolExecutor) extends Logging {
private val pushReplicateEnabled = conf.clientPushReplicateEnabled
private val testRetryCommitFiles = conf.testRetryCommitFiles
@@ -57,6 +73,8 @@ abstract class CommitHandler(
private val fileCount = new LongAdder
protected val reducerFileGroupsMap = new ShuffleFileGroups
+ val ec = ExecutionContext.fromExecutor(sharedRpcPool)
+
def getPartitionType(): PartitionType
def isStageEnd(shuffleId: Int): Boolean = false
@@ -180,6 +198,146 @@ abstract class CommitHandler(
reducerFileGroupsMap.put(shuffleId, JavaUtils.newConcurrentHashMap())
}
+ def doParallelCommitFiles(
+ shuffleId: Int,
+ shuffleCommittedInfo: ShuffleCommittedInfo,
+ params: ArrayBuffer[CommitFilesParam],
+ commitFilesFailedWorkers: ShuffleFailedWorkers): Unit = {
+
+ def processResponse(res: CommitFilesResponse, worker: WorkerInfo): Unit = {
+ shuffleCommittedInfo.synchronized {
+ // record committed partitionIds
+ res.committedPrimaryIds.asScala.foreach {
+ case commitPrimaryId =>
+ val partitionUniqueIdList =
shuffleCommittedInfo.committedPrimaryIds.computeIfAbsent(
+ Utils.splitPartitionLocationUniqueId(commitPrimaryId)._1,
+ (k: Int) => new util.ArrayList[String]())
+ partitionUniqueIdList.add(commitPrimaryId)
+ }
+
+ res.committedReplicaIds.asScala.foreach {
+ case commitReplicaId =>
+ val partitionUniqueIdList =
shuffleCommittedInfo.committedReplicaIds.computeIfAbsent(
+ Utils.splitPartitionLocationUniqueId(commitReplicaId)._1,
+ (k: Int) => new util.ArrayList[String]())
+ partitionUniqueIdList.add(commitReplicaId)
+ }
+
+ // record committed partitions storage hint and disk hint
+
shuffleCommittedInfo.committedPrimaryStorageInfos.putAll(res.committedPrimaryStorageInfos)
+
shuffleCommittedInfo.committedReplicaStorageInfos.putAll(res.committedReplicaStorageInfos)
+
+ // record failed partitions
+ shuffleCommittedInfo.failedPrimaryPartitionIds.putAll(
+ res.failedPrimaryIds.asScala.map((_, worker)).toMap.asJava)
+ shuffleCommittedInfo.failedReplicaPartitionIds.putAll(
+ res.failedReplicaIds.asScala.map((_, worker)).toMap.asJava)
+
+
shuffleCommittedInfo.committedMapIdBitmap.putAll(res.committedMapIdBitMap)
+
+ totalWritten.add(res.totalWritten)
+ fileCount.add(res.fileCount)
+ shuffleCommittedInfo.currentShuffleFileCount.add(res.fileCount)
+ }
+ }
+
+ val futures = new LinkedBlockingQueue[FutureWithStatus]()
+
+ val outFutures = params.filter(param =>
+ !CollectionUtils.isEmpty(param.primaryIds) ||
+ !CollectionUtils.isEmpty(param.replicaIds)) map { param =>
+ Future {
+ val future = commitFiles(
+ appUniqueId,
+ shuffleId,
+ param.worker,
+ param.primaryIds,
+ param.replicaIds)
+
+ futures.add(FutureWithStatus(future, param, 1))
+ }(ec)
+ }
+ val cbf =
+ implicitly[
+ CanBuildFrom[ArrayBuffer[Future[Boolean]], Boolean,
ArrayBuffer[Boolean]]]
+ val futureSeq = Future.sequence(outFutures)(cbf, ec)
+ awaitResult(futureSeq, Duration.Inf)
+
+ val maxRetries = conf.clientRequestCommitFilesMaxRetries
+ var timeout = conf.rpcAskTimeout.duration.toMillis * maxRetries
+ val delta = 50
+ while (timeout >= 0 && !futures.isEmpty) {
+ val iter = futures.iterator()
+ while (iter.hasNext) {
+ val status = iter.next()
+ if (status.future.isCompleted) {
+ status.future.value.get match {
+ case scala.util.Success(res) =>
+ val worker = status.commitFilesParam.worker
+ res.status match {
+ case StatusCode.SUCCESS => // do nothing
+ case StatusCode.PARTIAL_SUCCESS |
StatusCode.SHUFFLE_NOT_REGISTERED | StatusCode.REQUEST_FAILED |
StatusCode.WORKER_EXCLUDED =>
+ logInfo(s"Request commitFiles return ${res.status} for " +
+ s"${Utils.makeShuffleKey(appUniqueId, shuffleId)}")
+ if (res.status != StatusCode.WORKER_EXCLUDED) {
+ commitFilesFailedWorkers.put(worker, (res.status,
System.currentTimeMillis()))
+ }
+ case _ =>
+ logError(s"Should never reach here! commit files response
status ${res.status}")
+ }
+
+ processResponse(res, worker)
+ iter.remove()
+ case scala.util.Failure(e) =>
+ val worker = status.commitFilesParam.worker
+ logError(
+ s"Ask worker($worker) CommitFiles for $shuffleId failed" +
+ s" (attempt ${status.retriedTimes}/$maxRetries).",
+ e)
+ if (status.retriedTimes < maxRetries) {
+ status.retriedTimes = status.retriedTimes + 1
+ status.future = commitFiles(
+ appUniqueId,
+ shuffleId,
+ status.commitFilesParam.worker,
+ status.commitFilesParam.primaryIds,
+ status.commitFilesParam.replicaIds)
+ } else {
+ val res = CommitFilesResponse(
+ StatusCode.REQUEST_FAILED,
+ List.empty.asJava,
+ List.empty.asJava,
+ status.commitFilesParam.primaryIds,
+ status.commitFilesParam.replicaIds)
+ processResponse(res, status.commitFilesParam.worker)
+ iter.remove()
+ }
+ }
+ }
+ }
+
+ if (!futures.isEmpty) {
+ Thread.sleep(delta)
+ }
+ timeout = timeout - delta
+ }
+
+ val iter = futures.iterator()
+ while (iter.hasNext) {
+ val status = iter.next()
+ logError(
+ s"Ask worker(${status.commitFilesParam.worker}) CommitFiles for
$shuffleId timed out")
+ val res = CommitFilesResponse(
+ StatusCode.REQUEST_FAILED,
+ List.empty.asJava,
+ List.empty.asJava,
+ status.commitFilesParam.primaryIds,
+ status.commitFilesParam.replicaIds)
+ processResponse(res, status.commitFilesParam.worker)
+ iter.remove()
+ }
+ }
+
def parallelCommitFiles(
shuffleId: Int,
allocatedWorkers: util.Map[WorkerInfo, ShufflePartitionLocationInfo],
@@ -195,11 +353,9 @@ abstract class CommitHandler(
val commitFileStartTime = System.nanoTime()
val workerPartitionLocations =
allocatedWorkers.asScala.filter(!_._2.isEmpty)
- val parallelism = Math.min(workerPartitionLocations.size,
conf.clientRpcMaxParallelism)
- ThreadUtils.parmap(
- workerPartitionLocations,
- "CommitFiles",
- parallelism) { case (worker, partitionLocationInfo) =>
+
+ val params = new
ArrayBuffer[CommitFilesParam](workerPartitionLocations.size)
+ workerPartitionLocations.foreach { case (worker, partitionLocationInfo) =>
val primaryParts =
partitionLocationInfo.getPrimaryPartitions(partitionIdOpt)
val replicaParts =
partitionLocationInfo.getReplicaPartitions(partitionIdOpt)
@@ -226,17 +382,14 @@ abstract class CommitHandler(
.map(_.getUniqueId).toList.asJava)
}
- commitFiles(
- appUniqueId,
- shuffleId,
- shuffleCommittedInfo,
+ params += CommitFilesParam(
worker,
primaryIds,
- replicaIds,
- commitFilesFailedWorkers)
-
+ replicaIds)
}
+ doParallelCommitFiles(shuffleId, shuffleCommittedInfo, params,
commitFilesFailedWorkers)
+
logInfo(s"Shuffle $shuffleId " +
s"commit files complete. File count
${shuffleCommittedInfo.currentShuffleFileCount.sum()} " +
s"using ${TimeUnit.NANOSECONDS.toMillis(System.nanoTime() -
commitFileStartTime)} ms")
@@ -247,119 +400,68 @@ abstract class CommitHandler(
def commitFiles(
applicationId: String,
shuffleId: Int,
- shuffleCommittedInfo: ShuffleCommittedInfo,
worker: WorkerInfo,
primaryIds: util.List[String],
- replicaIds: util.List[String],
- commitFilesFailedWorkers: ShuffleFailedWorkers): Unit = {
-
- if (CollectionUtils.isEmpty(primaryIds) &&
CollectionUtils.isEmpty(replicaIds)) {
- return
- }
+ replicaIds: util.List[String]): Future[CommitFilesResponse] = {
- val res =
- if (!testRetryCommitFiles) {
- val commitFiles = CommitFiles(
- applicationId,
- shuffleId,
- primaryIds,
- replicaIds,
- getMapperAttempts(shuffleId),
- commitEpoch.incrementAndGet())
- val res =
- if (conf.clientCommitFilesIgnoreExcludedWorkers &&
- workerStatusTracker.excludedWorkers.containsKey(worker)) {
- CommitFilesResponse(
- StatusCode.WORKER_EXCLUDED,
- List.empty.asJava,
- List.empty.asJava,
- primaryIds,
- replicaIds)
- } else {
- requestCommitFilesWithRetry(worker.endpoint, commitFiles)
- }
-
- res.status match {
- case StatusCode.SUCCESS => // do nothing
- case StatusCode.PARTIAL_SUCCESS | StatusCode.SHUFFLE_NOT_REGISTERED
| StatusCode.REQUEST_FAILED | StatusCode.WORKER_EXCLUDED =>
- logInfo(s"Request $commitFiles return ${res.status} for " +
- s"${Utils.makeShuffleKey(applicationId, shuffleId)}")
- if (res.status != StatusCode.WORKER_EXCLUDED) {
- commitFilesFailedWorkers.put(worker, (res.status,
System.currentTimeMillis()))
- }
- case _ =>
- logError(s"Should never reach here! commit files response status
${res.status}")
- }
- res
+ if (!testRetryCommitFiles) {
+ val commitFiles = CommitFiles(
+ applicationId,
+ shuffleId,
+ primaryIds,
+ replicaIds,
+ getMapperAttempts(shuffleId),
+ commitEpoch.incrementAndGet())
+
+ if (conf.clientCommitFilesIgnoreExcludedWorkers &&
+ workerStatusTracker.excludedWorkers.containsKey(worker)) {
+ Future {
+ CommitFilesResponse(
+ StatusCode.WORKER_EXCLUDED,
+ List.empty.asJava,
+ List.empty.asJava,
+ primaryIds,
+ replicaIds)
+ }(ec)
} else {
- // for test
- val commitFiles1 = CommitFiles(
- applicationId,
- shuffleId,
- primaryIds.subList(0, primaryIds.size() / 2),
- replicaIds.subList(0, replicaIds.size() / 2),
- getMapperAttempts(shuffleId),
- commitEpoch.incrementAndGet())
- val res1 = requestCommitFilesWithRetry(worker.endpoint, commitFiles1)
-
- val commitFiles = CommitFiles(
- applicationId,
- shuffleId,
- primaryIds.subList(primaryIds.size() / 2, primaryIds.size()),
- replicaIds.subList(replicaIds.size() / 2, replicaIds.size()),
- getMapperAttempts(shuffleId),
- commitEpoch.incrementAndGet())
- val res2 = requestCommitFilesWithRetry(worker.endpoint, commitFiles)
-
-
res1.committedPrimaryStorageInfos.putAll(res2.committedPrimaryStorageInfos)
-
res1.committedReplicaStorageInfos.putAll(res2.committedReplicaStorageInfos)
- res1.committedMapIdBitMap.putAll(res2.committedMapIdBitMap)
- CommitFilesResponse(
- status = if (res1.status == StatusCode.SUCCESS) res2.status else
res1.status,
- (res1.committedPrimaryIds.asScala ++
res2.committedPrimaryIds.asScala).toList.asJava,
- (res1.committedReplicaIds.asScala ++
res2.committedReplicaIds.asScala).toList.asJava,
- (res1.failedPrimaryIds.asScala ++
res2.failedPrimaryIds.asScala).toList.asJava,
- (res1.failedReplicaIds.asScala ++
res2.failedReplicaIds.asScala).toList.asJava,
- res1.committedPrimaryStorageInfos,
- res1.committedReplicaStorageInfos,
- res1.committedMapIdBitMap,
- res1.totalWritten + res2.totalWritten,
- res1.fileCount + res2.fileCount)
+ worker.endpoint.ask[CommitFilesResponse](commitFiles)
}
-
- shuffleCommittedInfo.synchronized {
- // record committed partitionIds
- res.committedPrimaryIds.asScala.foreach({
- case commitPrimaryId =>
- val partitionUniqueIdList =
shuffleCommittedInfo.committedPrimaryIds.computeIfAbsent(
- Utils.splitPartitionLocationUniqueId(commitPrimaryId)._1,
- (k: Int) => new util.ArrayList[String]())
- partitionUniqueIdList.add(commitPrimaryId)
- })
-
- res.committedReplicaIds.asScala.foreach({
- case commitReplicaId =>
- val partitionUniqueIdList =
shuffleCommittedInfo.committedReplicaIds.computeIfAbsent(
- Utils.splitPartitionLocationUniqueId(commitReplicaId)._1,
- (k: Int) => new util.ArrayList[String]())
- partitionUniqueIdList.add(commitReplicaId)
- })
-
- // record committed partitions storage hint and disk hint
-
shuffleCommittedInfo.committedPrimaryStorageInfos.putAll(res.committedPrimaryStorageInfos)
-
shuffleCommittedInfo.committedReplicaStorageInfos.putAll(res.committedReplicaStorageInfos)
-
- // record failed partitions
- shuffleCommittedInfo.failedPrimaryPartitionIds.putAll(
- res.failedPrimaryIds.asScala.map((_, worker)).toMap.asJava)
- shuffleCommittedInfo.failedReplicaPartitionIds.putAll(
- res.failedReplicaIds.asScala.map((_, worker)).toMap.asJava)
-
-
shuffleCommittedInfo.committedMapIdBitmap.putAll(res.committedMapIdBitMap)
-
- totalWritten.add(res.totalWritten)
- fileCount.add(res.fileCount)
- shuffleCommittedInfo.currentShuffleFileCount.add(res.fileCount)
+ } else {
+ // for test
+ val commitFiles1 = CommitFiles(
+ applicationId,
+ shuffleId,
+ primaryIds.subList(0, primaryIds.size() / 2),
+ replicaIds.subList(0, replicaIds.size() / 2),
+ getMapperAttempts(shuffleId),
+ commitEpoch.incrementAndGet())
+ val res1 = requestCommitFilesWithRetryForTest(worker.endpoint,
commitFiles1)
+
+ val commitFiles = CommitFiles(
+ applicationId,
+ shuffleId,
+ primaryIds.subList(primaryIds.size() / 2, primaryIds.size()),
+ replicaIds.subList(replicaIds.size() / 2, replicaIds.size()),
+ getMapperAttempts(shuffleId),
+ commitEpoch.incrementAndGet())
+ val res2 = requestCommitFilesWithRetryForTest(worker.endpoint,
commitFiles)
+
+
res1.committedPrimaryStorageInfos.putAll(res2.committedPrimaryStorageInfos)
+
res1.committedReplicaStorageInfos.putAll(res2.committedReplicaStorageInfos)
+ res1.committedMapIdBitMap.putAll(res2.committedMapIdBitMap)
+ val res = CommitFilesResponse(
+ status = if (res1.status == StatusCode.SUCCESS) res2.status else
res1.status,
+ (res1.committedPrimaryIds.asScala ++
res2.committedPrimaryIds.asScala).toList.asJava,
+ (res1.committedReplicaIds.asScala ++
res2.committedReplicaIds.asScala).toList.asJava,
+ (res1.failedPrimaryIds.asScala ++
res2.failedPrimaryIds.asScala).toList.asJava,
+ (res1.failedReplicaIds.asScala ++
res2.failedReplicaIds.asScala).toList.asJava,
+ res1.committedPrimaryStorageInfos,
+ res1.committedReplicaStorageInfos,
+ res1.committedMapIdBitMap,
+ res1.totalWritten + res2.totalWritten,
+ res1.fileCount + res2.fileCount)
+
+ Future { res }(ec)
}
}
@@ -402,14 +504,14 @@ abstract class CommitHandler(
}
}
- private def requestCommitFilesWithRetry(
+ private def requestCommitFilesWithRetryForTest(
endpoint: RpcEndpointRef,
message: CommitFiles): CommitFilesResponse = {
val maxRetries = conf.clientRequestCommitFilesMaxRetries
var retryTimes = 0
while (retryTimes < maxRetries) {
try {
- if (testRetryCommitFiles && retryTimes < maxRetries - 1) {
+ if (retryTimes < maxRetries - 1) {
endpoint.ask[CommitFilesResponse](message)
Thread.sleep(1000)
throw new Exception("Mock fail for CommitFiles")
@@ -420,7 +522,8 @@ abstract class CommitHandler(
case e: Throwable =>
retryTimes += 1
logError(
- s"AskSync worker(${endpoint.address}) CommitFiles for
${message.shuffleId} failed (attempt $retryTimes/$maxRetries).",
+ s"Ask worker(${endpoint.address}) CommitFiles for
${message.shuffleId} failed" +
+ s" (attempt $retryTimes/$maxRetries).",
e)
}
}
diff --git
a/client/src/main/scala/org/apache/celeborn/client/commit/MapPartitionCommitHandler.scala
b/client/src/main/scala/org/apache/celeborn/client/commit/MapPartitionCommitHandler.scala
index 54d05671f..b799d0870 100644
---
a/client/src/main/scala/org/apache/celeborn/client/commit/MapPartitionCommitHandler.scala
+++
b/client/src/main/scala/org/apache/celeborn/client/commit/MapPartitionCommitHandler.scala
@@ -19,7 +19,7 @@ package org.apache.celeborn.client.commit
import java.util
import java.util.Collections
-import java.util.concurrent.ConcurrentHashMap
+import java.util.concurrent.{ConcurrentHashMap, ThreadPoolExecutor}
import java.util.concurrent.atomic.AtomicInteger
import scala.collection.JavaConverters._
@@ -52,8 +52,9 @@ class MapPartitionCommitHandler(
conf: CelebornConf,
shuffleAllocatedWorkers: ShuffleAllocatedWorkers,
committedPartitionInfo: CommittedPartitionInfo,
- workerStatusTracker: WorkerStatusTracker)
- extends CommitHandler(appId, conf, committedPartitionInfo,
workerStatusTracker)
+ workerStatusTracker: WorkerStatusTracker,
+ sharedRpcPool: ThreadPoolExecutor)
+ extends CommitHandler(appId, conf, committedPartitionInfo,
workerStatusTracker, sharedRpcPool)
with Logging {
private val shuffleSucceedPartitionIds = JavaUtils.newConcurrentHashMap[Int,
util.Set[Integer]]()
diff --git
a/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala
b/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala
index b93d36998..e86bc2317 100644
---
a/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala
+++
b/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala
@@ -19,7 +19,7 @@ package org.apache.celeborn.client.commit
import java.nio.ByteBuffer
import java.util
-import java.util.concurrent.{Callable, ConcurrentHashMap, TimeUnit}
+import java.util.concurrent.{Callable, ConcurrentHashMap, ThreadPoolExecutor,
TimeUnit}
import scala.collection.JavaConverters._
import scala.collection.mutable
@@ -51,8 +51,14 @@ class ReducePartitionCommitHandler(
conf: CelebornConf,
shuffleAllocatedWorkers: ShuffleAllocatedWorkers,
committedPartitionInfo: CommittedPartitionInfo,
- workerStatusTracker: WorkerStatusTracker)
- extends CommitHandler(appUniqueId, conf, committedPartitionInfo,
workerStatusTracker)
+ workerStatusTracker: WorkerStatusTracker,
+ sharedRpcPool: ThreadPoolExecutor)
+ extends CommitHandler(
+ appUniqueId,
+ conf,
+ committedPartitionInfo,
+ workerStatusTracker,
+ sharedRpcPool)
with Logging {
private val getReducerFileGroupRequest =
diff --git
a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
index ef7b63567..e288ea7b1 100644
--- a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
+++ b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
@@ -800,6 +800,7 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable
with Logging with Se
def clientPushStageEndTimeout: Long = get(CLIENT_PUSH_STAGE_END_TIMEOUT)
def clientPushUnsafeRowFastWrite: Boolean =
get(CLIENT_PUSH_UNSAFEROW_FASTWRITE_ENABLED)
def clientRpcCacheExpireTime: Long = get(CLIENT_RPC_CACHE_EXPIRE_TIME)
+ def clientRpcSharedThreads: Int = get(CLIENT_RPC_SHARED_THREADS)
def pushDataTimeoutMs: Long = get(CLIENT_PUSH_DATA_TIMEOUT)
def clientPushLimitStrategy: String = get(CLIENT_PUSH_LIMIT_STRATEGY)
def clientPushSlowStartInitialSleepTime: Long =
get(CLIENT_PUSH_SLOW_START_INITIAL_SLEEP_TIME)
@@ -3524,6 +3525,14 @@ object CelebornConf extends Logging {
.timeConf(TimeUnit.MILLISECONDS)
.createWithDefaultString("15s")
+ val CLIENT_RPC_SHARED_THREADS: ConfigEntry[Int] =
+ buildConf("celeborn.client.rpc.shared.threads")
+ .categories("client")
+ .version("0.4.0")
+ .doc("Number of shared rpc threads in LifecycleManager.")
+ .intConf
+ .createWithDefault(16)
+
val CLIENT_RESERVE_SLOTS_RACKAWARE_ENABLED: ConfigEntry[Boolean] =
buildConf("celeborn.client.reserveSlots.rackaware.enabled")
.withAlternative("celeborn.client.reserveSlots.rackware.enabled")
diff --git a/docs/configuration/client.md b/docs/configuration/client.md
index 0fa473010..4b553a0ef 100644
--- a/docs/configuration/client.md
+++ b/docs/configuration/client.md
@@ -80,6 +80,7 @@ license: |
| celeborn.client.rpc.registerShuffle.askTimeout | <value of
celeborn.<module>.io.connectionTimeout> | Timeout for ask operations
during register shuffle. During this process, there are two times for retry
opportunities for requesting slots, one request for establishing a connection
with Worker and `celeborn.client.reserveSlots.maxRetries` times for retry
opportunities for reserving slots. User can customize this value according to
your setting. By default, the value is the m [...]
| celeborn.client.rpc.requestPartition.askTimeout | <value of
celeborn.<module>.io.connectionTimeout> | Timeout for ask operations
during requesting change partition location, such as reviving or splitting
partition. During this process, there are
`celeborn.client.reserveSlots.maxRetries` times for retry opportunities for
reserving slots. User can customize this value according to your setting. By
default, the value is the max timeout value `celeborn.<module>.io.connectionTim
[...]
| celeborn.client.rpc.reserveSlots.askTimeout | <value of
celeborn.rpc.askTimeout> | Timeout for LifecycleManager request reserve
slots. | 0.3.0 |
+| celeborn.client.rpc.shared.threads | 16 | Number of shared rpc threads in
LifecycleManager. | 0.4.0 |
| celeborn.client.shuffle.batchHandleChangePartition.interval | 100ms |
Interval for LifecycleManager to schedule handling change partition requests in
batch. | 0.3.0 |
| celeborn.client.shuffle.batchHandleChangePartition.threads | 8 | Threads
number for LifecycleManager to handle change partition request in batch. |
0.3.0 |
| celeborn.client.shuffle.batchHandleCommitPartition.interval | 5s | Interval
for LifecycleManager to schedule handling commit partition requests in batch. |
0.3.0 |
diff --git
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/RetryReviveTest.scala
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/RetryReviveTest.scala
index 4d0b42ded..280c16575 100644
---
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/RetryReviveTest.scala
+++
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/RetryReviveTest.scala
@@ -46,6 +46,7 @@ class RetryReviveTest extends AnyFunSuite
test("celeborn spark integration test - retry revive as configured times") {
val sparkConf = new SparkConf()
.set(s"spark.${CelebornConf.TEST_CLIENT_RETRY_REVIVE.key}", "true")
+ .set(s"spark.${CelebornConf.CLIENT_PUSH_MAX_REVIVE_TIMES.key}", "3")
.setAppName("celeborn-demo").setMaster("local[2]")
val ss = SparkSession.builder()
.config(updateSparkConf(sparkConf, ShuffleMode.HASH))