This is an automated email from the ASF dual-hosted git repository.

nicholasjiang pushed a commit to branch branch-0.6
in repository https://gitbox.apache.org/repos/asf/celeborn.git
commit db9585f718514b5d9a450a4cb4726bfe357ac5c0
Author: SteNicholas <[email protected]>
AuthorDate: Tue Feb 17 16:03:13 2026 +0800

    [CELEBORN-2063] Parallelize the creation of partition writers in handleReserveSlots to speed up the reserveSlots RPC processing time

    ### What changes were proposed in this pull request?

    Parallelize the creation of partition writers in `handleReserveSlots` to speed up the reserveSlots RPC processing time.

    ### Why are the changes needed?

    The partition writers in `handleReserveSlots` are currently created sequentially; creating them in parallel speeds up the reserveSlots RPC processing time.

    ### Does this PR introduce _any_ user-facing change?

    Introduces `celeborn.worker.writer.create.parallel.enabled`, `celeborn.worker.writer.create.parallel.threads` and `celeborn.worker.writer.create.parallel.timeout` to configure the parallel creation of file writers.

    ### How was this patch tested?

    CI.

    Closes #3387 from SteNicholas/CELEBORN-2063.

    Authored-by: SteNicholas <[email protected]>
    Signed-off-by: SteNicholas <[email protected]>
    (cherry picked from commit 8e6f4d5f95f58238913bf6f5bc769e5508d64efe)
    Signed-off-by: SteNicholas <[email protected]>
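As a usage sketch (not part of the patch), enabling the feature on a worker might look like this in `celeborn-defaults.conf`, assuming the usual space-separated key/value format; the thread count shown is an illustrative value, since it defaults to the number of available processors when unset, and 120s is already the default timeout:

    celeborn.worker.writer.create.parallel.enabled    true
    celeborn.worker.writer.create.parallel.threads    16
    celeborn.worker.writer.create.parallel.timeout    120s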
---
 .../org/apache/celeborn/common/CelebornConf.scala |  28 +++
 .../org/apache/celeborn/common/util/Utils.scala   |  82 +++++++-
 docs/configuration/worker.md                      |   3 +
 .../service/deploy/worker/Controller.scala        | 212 ++++++++++++++-------
 .../celeborn/service/deploy/worker/Worker.scala   |  14 ++
 5 files changed, 256 insertions(+), 83 deletions(-)

diff --git a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
index 3447d5dbd..6a5f0ed7e 100644
--- a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
+++ b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
@@ -1341,6 +1341,10 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable with Logging with Se
   def workerS3FlusherThreads: Int = get(WORKER_FLUSHER_S3_THREADS)
   def workerOssFlusherThreads: Int = get(WORKER_FLUSHER_OSS_THREADS)
   def workerCreateWriterMaxAttempts: Int = get(WORKER_WRITER_CREATE_MAX_ATTEMPTS)
+  def workerCreateWriterParallelEnabled: Boolean = get(WORKER_WRITER_CREATE_PARALLEL_ENABLED)
+  def workerCreateWriterParallelThreads: Int =
+    get(WORKER_WRITER_CREATE_PARALLEL_THREADS).getOrElse(Runtime.getRuntime.availableProcessors)
+  def workerCreateWriterParallelTimeout: Long = get(WORKER_WRITER_CREATE_PARALLEL_TIMEOUT)
   def workerFlusherLocalGatherAPIEnabled: Boolean = get(WORKER_FLUSHER_LOCAL_GATHER_API_ENABLED)
 
   // //////////////////////////////////////////////////////
@@ -4038,6 +4042,30 @@ object CelebornConf extends Logging {
       .intConf
       .createWithDefault(3)
 
+  val WORKER_WRITER_CREATE_PARALLEL_ENABLED: ConfigEntry[Boolean] =
+    buildConf("celeborn.worker.writer.create.parallel.enabled")
+      .categories("worker")
+      .version("0.6.3")
+      .doc("Whether to parallelize the creation of file writers.")
+      .booleanConf
+      .createWithDefault(false)
+
+  val WORKER_WRITER_CREATE_PARALLEL_THREADS: OptionalConfigEntry[Int] =
+    buildConf("celeborn.worker.writer.create.parallel.threads")
+      .categories("worker")
+      .version("0.6.3")
+      .doc("Number of threads for a worker to parallelize the creation of file writers.")
+      .intConf
+      .createOptional
+
+  val WORKER_WRITER_CREATE_PARALLEL_TIMEOUT: ConfigEntry[Long] =
+    buildConf("celeborn.worker.writer.create.parallel.timeout")
+      .categories("worker")
+      .version("0.6.3")
+      .doc("Timeout for a worker to create a file writer in parallel.")
+      .timeConf(TimeUnit.MILLISECONDS)
+      .createWithDefaultString("120s")
+
   val WORKER_FLUSHER_LOCAL_GATHER_API_ENABLED: ConfigEntry[Boolean] =
     buildConf("celeborn.worker.flusher.local.gatherAPI.enabled")
       .internal
diff --git a/common/src/main/scala/org/apache/celeborn/common/util/Utils.scala b/common/src/main/scala/org/apache/celeborn/common/util/Utils.scala
index 7dec9f086..900356c9d 100644
--- a/common/src/main/scala/org/apache/celeborn/common/util/Utils.scala
+++ b/common/src/main/scala/org/apache/celeborn/common/util/Utils.scala
@@ -1004,11 +1004,14 @@ object Utils extends Logging {
   }
 
   /**
-   * if the action is timeout, will return the callback result
-   * if other exception will be thrown directly
+   * If the action times out, the callback result is returned.
+   * Any other exception is thrown directly.
+   *
    * @param block the normal action block
    * @param callback callback if timeout
+   * @param threadPool thread pool to submit the action to
    * @param timeoutInSeconds timeout limit value in seconds
+   * @param errorMessage error message to log on exception
    * @tparam T result type
    * @return result
    */
@@ -1016,15 +1019,45 @@
       threadPool: ThreadPoolExecutor,
       timeoutInSeconds: Long = 10,
       errorMessage: String = "none"): T = {
-    val futureTask = new Callable[T] {
+    tryFutureWithTimeoutAndCallback(callback)(
+      future(block)(threadPool),
+      timeoutInSeconds,
+      errorMessage)
+  }
+
+  /**
+   * Creates a future by submitting the given block to the thread pool.
+   *
+   * @param block the normal action block
+   * @param threadPool thread pool to submit the block to
+   * @tparam T result type
+   * @return future
+   */
+  def future[T](block: => T)(
+      threadPool: ThreadPoolExecutor): java.util.concurrent.Future[T] = {
+    threadPool.submit(new Callable[T] {
       override def call(): T = {
         block
       }
-    }
+    })
+  }
 
-    var future: java.util.concurrent.Future[T] = null
+  /**
+   * If the future times out, the callback result is returned.
+   * Any other exception is thrown directly.
+   *
+   * @param callback callback if timeout
+   * @param future future to await with timeout and callback
+   * @param timeoutInSeconds timeout limit value in seconds
+   * @param errorMessage error message to log on exception
+   * @tparam T result type
+   * @return result
+   */
+  def tryFutureWithTimeoutAndCallback[T](callback: => T)(
+      future: java.util.concurrent.Future[T],
+      timeoutInSeconds: Long = 10,
+      errorMessage: String = "none"): T = {
     try {
-      future = threadPool.submit(futureTask)
       future.get(timeoutInSeconds, TimeUnit.SECONDS)
     } catch {
      case _: TimeoutException =>
@@ -1034,9 +1067,40 @@
       case throwable: Throwable =>
         throw throwable
     } finally {
-      if (null != future && !future.isCancelled) {
-        future.cancel(true)
-      }
+      cancelFuture(future)
+    }
+  }
+
+  /**
+   * Awaits each future with the given timeout.
+   * Any exception, including timeout, is logged and thrown directly;
+   * all futures are cancelled afterwards.
+   *
+   * @param futures futures to await with timeout
+   * @param timeoutInSeconds timeout limit value in seconds
+   * @param errorMessage error message to log on exception
+   * @tparam T result type
+   * @return results
+   */
+  def tryFuturesWithTimeout[T](
+      futures: List[java.util.concurrent.Future[T]],
+      timeoutInSeconds: Long = 10,
+      errorMessage: String = "none"): List[T] = {
+    try {
+      futures.map(_.get(timeoutInSeconds, TimeUnit.SECONDS))
+    } catch {
+      case throwable: Throwable =>
+        logError(
+          s"${throwable.getClass.getSimpleName} in thread ${Thread.currentThread().getName}," +
+            s" error message: $errorMessage")
+        throw throwable
+    } finally {
+      futures.foreach(cancelFuture)
+    }
+  }
+
+  def cancelFuture[T](future: java.util.concurrent.Future[T]): Unit = {
+    if (null != future && !future.isCancelled) {
+      future.cancel(true)
     }
   }
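The refactor splits the old submit-and-wait helper into composable pieces: `future` submits a block, `tryFutureWithTimeoutAndCallback` awaits a single future, and `tryFuturesWithTimeout` awaits a batch, giving each future the same timeout and cancelling them all afterwards. A minimal sketch of the fan-out/fan-in pattern these enable (the pool, sizes, and task body are illustrative, not from the patch):

    import java.util.concurrent.ThreadPoolExecutor

    import org.apache.celeborn.common.util.{ThreadUtils, Utils}

    // Hypothetical fixed-size daemon pool, mirroring how Worker creates one.
    val pool: ThreadPoolExecutor =
      ThreadUtils.newDaemonFixedThreadPool(4, "example-task-pool")

    // Fan out: submit one task per input without blocking the caller.
    val futures = (1 to 8).toList.map { i =>
      Utils.future { i * i }(pool)
    }

    // Fan in: await the results, giving each future up to 30 seconds; on
    // timeout or failure the exception is rethrown and all futures cancelled.
    val squares: List[Int] =
      Utils.tryFuturesWithTimeout(futures, 30, "square tasks timed out")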
diff --git a/docs/configuration/worker.md b/docs/configuration/worker.md
index d9db38d0b..2b540884a 100644
--- a/docs/configuration/worker.md
+++ b/docs/configuration/worker.md
@@ -197,5 +197,8 @@ license: |
 | celeborn.worker.storage.workingDir | celeborn-worker/shuffle_data | false | Worker's working dir path name. | 0.3.0 | celeborn.worker.workingDir |
 | celeborn.worker.writer.close.timeout | 120s | false | Timeout for a file writer to close | 0.2.0 | |
 | celeborn.worker.writer.create.maxAttempts | 3 | false | Retry count for a file writer to create if its creation was failed. | 0.2.0 | |
+| celeborn.worker.writer.create.parallel.enabled | false | false | Whether to parallelize the creation of file writers. | 0.6.3 | |
+| celeborn.worker.writer.create.parallel.threads | <undefined> | false | Number of threads for a worker to parallelize the creation of file writers. | 0.6.3 | |
+| celeborn.worker.writer.create.parallel.timeout | 120s | false | Timeout for a worker to create a file writer in parallel. | 0.6.3 | |
 | worker.flush.reuseCopyBuffer.enabled | true | false | Whether to enable reuse copy buffer for flush. Note that this copy buffer must not be referenced again after flushing. This means that, for example, the Hdfs(Oss or S3) client will not asynchronously access this buffer after the flush method returns, otherwise data modification problems will occur. | 0.6.1 | |
 <!--end-include-->

diff --git a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Controller.scala b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Controller.scala
index cd2485b0a..d281b2dbb 100644
--- a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Controller.scala
+++ b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Controller.scala
@@ -32,14 +32,14 @@ import org.roaringbitmap.RoaringBitmap
 
 import org.apache.celeborn.common.CelebornConf
 import org.apache.celeborn.common.identity.UserIdentifier
 import org.apache.celeborn.common.internal.Logging
-import org.apache.celeborn.common.meta.{ReduceFileMeta, WorkerInfo, WorkerPartitionLocationInfo}
+import org.apache.celeborn.common.meta.{WorkerInfo, WorkerPartitionLocationInfo}
 import org.apache.celeborn.common.metrics.MetricsSystem
 import org.apache.celeborn.common.protocol.{PartitionLocation, PartitionSplitMode, PartitionType, StorageInfo}
 import org.apache.celeborn.common.protocol.message.ControlMessages._
 import org.apache.celeborn.common.protocol.message.StatusCode
 import org.apache.celeborn.common.rpc._
 import org.apache.celeborn.common.util.{JavaUtils, Utils}
-import org.apache.celeborn.service.deploy.worker.storage.{MapPartitionMetaHandler, PartitionDataWriter, SegmentMapPartitionMetaHandler, StorageManager}
+import org.apache.celeborn.service.deploy.worker.storage.{MapPartitionMetaHandler, PartitionDataWriter, StorageManager}
 
 private[deploy] class Controller(
     override val rpcEnv: RpcEnv,
@@ -64,11 +64,13 @@ private[deploy] class Controller(
   var commitThreadPool: ThreadPoolExecutor = _
   var commitFinishedChecker: ScheduledExecutorService = _
   var asyncReplyPool: ScheduledExecutorService = _
+  var createWriterThreadPool: ThreadPoolExecutor = _
   val minPartitionSizeToEstimate = conf.minPartitionSizeToEstimate
   var shutdown: AtomicBoolean = _
   val defaultPushdataTimeout = conf.pushDataTimeoutMs
   val mockCommitFilesFailure = conf.testMockCommitFilesFailure
   val shuffleCommitTimeout = conf.workerShuffleCommitTimeout
+  val createWriterParallelTimeout = conf.workerCreateWriterParallelTimeout
   val workerCommitFilesCheckInterval = conf.workerCommitFilesCheckInterval
 
   def init(worker: Worker): Unit = {
@@ -83,6 +85,7 @@ private[deploy] class Controller(
     timer = worker.timer
     commitThreadPool = worker.commitThreadPool
     asyncReplyPool = worker.asyncReplyPool
+    createWriterThreadPool = worker.createWriterThreadPool
     shutdown = worker.shutdown
     commitFinishedChecker = worker.commitFinishedChecker
 
@@ -192,88 +195,45 @@ private[deploy] class Controller(
       context.reply(ReserveSlotsResponse(StatusCode.NO_AVAILABLE_WORKING_DIR, msg))
       return
     }
-    val primaryLocs = new jArrayList[PartitionLocation]()
-    try {
-      for (ind <- 0 until requestPrimaryLocs.size()) {
-        var location = partitionLocationInfo.getPrimaryLocation(
-          shuffleKey,
-          requestPrimaryLocs.get(ind).getUniqueId)
-        if (location == null) {
-          location = requestPrimaryLocs.get(ind)
-          val writer = storageManager.createPartitionDataWriter(
-            applicationId,
-            shuffleId,
-            location,
-            splitThreshold,
-            splitMode,
-            partitionType,
-            rangeReadFilter,
-            userIdentifier,
-            partitionSplitEnabled,
-            isSegmentGranularityVisible)
-          primaryLocs.add(new WorkingPartition(location, writer))
-        } else {
-          primaryLocs.add(location)
-        }
-      }
-    } catch {
-      case e: Exception =>
-        logError(s"CreateWriter for $shuffleKey failed.", e)
-    }
+    val primaryLocs = createWriters(
+      shuffleKey,
+      applicationId,
+      shuffleId,
+      requestPrimaryLocs,
+      splitThreshold,
+      splitMode,
+      partitionType,
+      rangeReadFilter,
+      userIdentifier,
+      partitionSplitEnabled,
+      isSegmentGranularityVisible,
+      isPrimary = true)
     if (primaryLocs.size() < requestPrimaryLocs.size()) {
       val msg = s"Not all primary partition satisfied for $shuffleKey"
       logWarning(s"[handleReserveSlots] $msg, will destroy writers.")
-      primaryLocs.asScala.foreach { partitionLocation =>
-        val fileWriter = partitionLocation.asInstanceOf[WorkingPartition].getFileWriter
-        fileWriter.destroy(new IOException(s"Destroy FileWriter $fileWriter caused by " +
-          s"reserving slots failed for $shuffleKey."))
-      }
+      destroyWriters(primaryLocs, shuffleKey)
       context.reply(ReserveSlotsResponse(StatusCode.RESERVE_SLOTS_FAILED, msg))
       return
     }
-    val replicaLocs = new jArrayList[PartitionLocation]()
-    try {
-      for (ind <- 0 until requestReplicaLocs.size()) {
-        var location =
-          partitionLocationInfo.getReplicaLocation(
-            shuffleKey,
-            requestReplicaLocs.get(ind).getUniqueId)
-        if (location == null) {
-          location = requestReplicaLocs.get(ind)
-          val writer = storageManager.createPartitionDataWriter(
-            applicationId,
-            shuffleId,
-            location,
-            splitThreshold,
-            splitMode,
-            partitionType,
-            rangeReadFilter,
-            userIdentifier,
-            partitionSplitEnabled,
-            isSegmentGranularityVisible)
-          replicaLocs.add(new WorkingPartition(location, writer))
-        } else {
-          replicaLocs.add(location)
-        }
-      }
-    } catch {
-      case e: Exception =>
-        logError(s"CreateWriter for $shuffleKey failed.", e)
-    }
+    val replicaLocs = createWriters(
+      shuffleKey,
+      applicationId,
+      shuffleId,
+      requestReplicaLocs,
+      splitThreshold,
+      splitMode,
+      partitionType,
+      rangeReadFilter,
+      userIdentifier,
+      partitionSplitEnabled,
+      isSegmentGranularityVisible,
+      isPrimary = false)
     if (replicaLocs.size() < requestReplicaLocs.size()) {
       val msg = s"Not all replica partition satisfied for $shuffleKey"
       logWarning(s"[handleReserveSlots] $msg, destroy writers.")
-      primaryLocs.asScala.foreach { partitionLocation =>
-        val fileWriter = partitionLocation.asInstanceOf[WorkingPartition].getFileWriter
-        fileWriter.destroy(new IOException(s"Destroy FileWriter $fileWriter caused by " +
-          s"reserving slots failed for $shuffleKey."))
-      }
-      replicaLocs.asScala.foreach { partitionLocation =>
-        val fileWriter = partitionLocation.asInstanceOf[WorkingPartition].getFileWriter
-        fileWriter.destroy(new IOException(s"Destroy FileWriter $fileWriter caused by " +
-          s"reserving slots failed for $shuffleKey."))
-      }
+      destroyWriters(primaryLocs, shuffleKey)
+      destroyWriters(replicaLocs, shuffleKey)
       context.reply(ReserveSlotsResponse(StatusCode.RESERVE_SLOTS_FAILED, msg))
       return
     }
@@ -299,6 +259,110 @@ private[deploy] class Controller(
     context.reply(ReserveSlotsResponse(StatusCode.SUCCESS))
   }
 
+  private def createWriters(
+      shuffleKey: String,
+      applicationId: String,
+      shuffleId: Int,
+      requestLocs: jList[PartitionLocation],
+      splitThreshold: Long,
+      splitMode: PartitionSplitMode,
+      partitionType: PartitionType,
+      rangeReadFilter: Boolean,
+      userIdentifier: UserIdentifier,
+      partitionSplitEnabled: Boolean,
+      isSegmentGranularityVisible: Boolean,
+      isPrimary: Boolean): jList[PartitionLocation] = {
+    val partitionLocations = new jArrayList[PartitionLocation]()
+    try {
+      def createWriter(partitionLocation: PartitionLocation): PartitionLocation = {
+        createPartitionDataWriter(
+          shuffleKey,
+          applicationId,
+          shuffleId,
+          partitionLocation,
+          splitThreshold,
+          splitMode,
+          partitionType,
+          rangeReadFilter,
+          userIdentifier,
+          partitionSplitEnabled,
+          isSegmentGranularityVisible,
+          isPrimary)
+      }
+      if (createWriterThreadPool == null) {
+        partitionLocations.addAll(requestLocs.asScala.map(createWriter).asJava)
+      } else {
+        partitionLocations.addAll(Utils.tryFuturesWithTimeout(
+          requestLocs.asScala.map(requestLoc =>
+            Utils.future(createWriter(requestLoc))(createWriterThreadPool)).toList,
+          createWriterParallelTimeout,
+          s"Create FileWriter for $shuffleKey timeout.").asJava)
+      }
+    } catch {
+      case e: Exception =>
+        logError(s"Create FileWriter for $shuffleKey failed.", e)
+    }
+    partitionLocations
+  }
+
+  private def createPartitionDataWriter(
+      shuffleKey: String,
+      applicationId: String,
+      shuffleId: Int,
+      requestLoc: PartitionLocation,
+      splitThreshold: Long,
+      splitMode: PartitionSplitMode,
+      partitionType: PartitionType,
+      rangeReadFilter: Boolean,
+      userIdentifier: UserIdentifier,
+      partitionSplitEnabled: Boolean,
+      isSegmentGranularityVisible: Boolean,
+      isPrimary: Boolean): PartitionLocation = {
+    try {
+      var location =
+        if (isPrimary) {
+          partitionLocationInfo.getPrimaryLocation(
+            shuffleKey,
+            requestLoc.getUniqueId)
+        } else {
+          partitionLocationInfo.getReplicaLocation(
+            shuffleKey,
+            requestLoc.getUniqueId)
+        }
+      if (location == null) {
+        location = requestLoc
+        val writer = storageManager.createPartitionDataWriter(
+          applicationId,
+          shuffleId,
+          location,
+          splitThreshold,
+          splitMode,
+          partitionType,
+          rangeReadFilter,
+          userIdentifier,
+          partitionSplitEnabled,
+          isSegmentGranularityVisible)
+        new WorkingPartition(location, writer)
+      } else {
+        location
+      }
+    } catch {
+      case e: Exception =>
+        logError(s"Create FileWriter for $requestLoc $shuffleKey failed.", e)
+        throw e
+    }
+  }
+
+  private def destroyWriters(
+      partitionLocations: jList[PartitionLocation],
+      shuffleKey: String): Unit = {
+    partitionLocations.asScala.foreach { partitionLocation =>
+      val fileWriter = partitionLocation.asInstanceOf[WorkingPartition].getFileWriter
+      fileWriter.destroy(new IOException(s"Destroy FileWriter $fileWriter caused by " +
+        s"reserving slots failed for $shuffleKey."))
+    }
+  }
+
   private def commitFiles(
       shuffleKey: String,
       uniqueIds: jList[String],
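`createWriters` keeps the original sequential path when no pool was configured (i.e. the feature is off) and otherwise fans the per-location work out over `createWriterThreadPool`, awaiting the futures with the configured `createWriterParallelTimeout`. Because the surrounding catch swallows the exception, a timeout or failure simply yields a shorter list, which `handleReserveSlots` treats as a reservation failure and answers with RESERVE_SLOTS_FAILED after destroying the returned writers. A self-contained sketch of the same fallback shape, with hypothetical stand-in types:

    import java.util.concurrent.ThreadPoolExecutor

    import org.apache.celeborn.common.util.Utils

    case class Location(id: Int)
    // Placeholder for storageManager.createPartitionDataWriter(...).
    def createWriter(loc: Location): Location = loc

    // Sequential when no pool is configured, parallel submit plus a timed
    // await otherwise -- the same branch Controller.createWriters takes.
    def createAll(
        locs: List[Location],
        pool: ThreadPoolExecutor,
        timeoutSeconds: Long): List[Location] = {
      if (pool == null) {
        locs.map(createWriter)
      } else {
        Utils.tryFuturesWithTimeout(
          locs.map(loc => Utils.future(createWriter(loc))(pool)),
          timeoutSeconds,
          "create writers timeout")
      }
    }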
diff --git a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Worker.scala b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Worker.scala
index be382e36f..8786ad232 100644
--- a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Worker.scala
+++ b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Worker.scala
@@ -346,6 +346,7 @@ private[celeborn] class Worker(
       conf.workerCleanThreads)
   val asyncReplyPool: ScheduledExecutorService =
     ThreadUtils.newDaemonSingleThreadScheduledExecutor("worker-rpc-async-replier")
+  var createWriterThreadPool: ThreadPoolExecutor = _
   val timer = new HashedWheelTimer(ThreadUtils.namedSingleThreadFactory("worker-timer"))
 
   // Configs
@@ -581,6 +582,13 @@ private[celeborn] class Worker(
       }
     })
 
+    if (conf.workerCreateWriterParallelEnabled) {
+      createWriterThreadPool =
+        ThreadUtils.newDaemonFixedThreadPool(
+          conf.workerCreateWriterParallelThreads,
+          "worker-writer-creator")
+    }
+
     pushDataHandler.init(this)
     replicateHandler.init(this)
     fetchHandler.init(this)
@@ -628,12 +636,18 @@
       commitThreadPool.shutdown()
       commitFinishedChecker.shutdown();
       asyncReplyPool.shutdown()
+      if (createWriterThreadPool != null) {
+        createWriterThreadPool.shutdown()
+      }
     } else {
       forwardMessageScheduler.shutdownNow()
       replicateThreadPool.shutdownNow()
       commitThreadPool.shutdownNow()
       commitFinishedChecker.shutdownNow();
       asyncReplyPool.shutdownNow()
+      if (createWriterThreadPool != null) {
+        createWriterThreadPool.shutdownNow()
+      }
     }
     workerSource.appActiveConnections.clear()
     partitionsSorter.close(exitKind)
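The pool is created only when `celeborn.worker.writer.create.parallel.enabled` is true, so the null check in `Controller.createWriters` doubles as the feature switch; on exit, a graceful stop calls `shutdown` (drain queued tasks) while an abrupt stop calls `shutdownNow` (interrupt in-flight tasks). An illustrative condensation of that lifecycle (names are made up):

    import java.util.concurrent.ThreadPoolExecutor

    import org.apache.celeborn.common.util.ThreadUtils

    object PoolLifecycle {
      // Present only when the feature is enabled, mirroring
      // Worker.createWriterThreadPool.
      var pool: ThreadPoolExecutor = _

      def start(enabled: Boolean, threads: Int): Unit = {
        if (enabled) {
          pool = ThreadUtils.newDaemonFixedThreadPool(threads, "example-writer-creator")
        }
      }

      // Graceful exit drains queued tasks; immediate exit interrupts them.
      def stop(graceful: Boolean): Unit = {
        if (pool != null) {
          if (graceful) pool.shutdown() else pool.shutdownNow()
        }
      }
    }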
