This is an automated email from the ASF dual-hosted git repository.

nicholasjiang pushed a commit to branch branch-0.6
in repository https://gitbox.apache.org/repos/asf/celeborn.git

commit db9585f718514b5d9a450a4cb4726bfe357ac5c0
Author: SteNicholas <[email protected]>
AuthorDate: Tue Feb 17 16:03:13 2026 +0800

    [CELEBORN-2063] Parallelize partition writer creation in handleReserveSlots to speed up reserveSlots RPC processing time
    
    ### What changes were proposed in this pull request?
    
    Parallelize partition writer creation in `handleReserveSlots` to speed up the reserveSlots RPC processing time.
    
    ### Why are the changes needed?
    
    The creation of partition writers in `handleReserveSlots` can be parallelized to speed up the reserveSlots RPC processing time.
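
    As a rough illustration (a toy sketch, not code from this patch; the
    50ms per-creation latency and the pool size are made up), fanning the
    blocking creations out to a fixed pool bounds the handler's wait near
    the latency of a single creation instead of the sum of all of them:

        import java.util.concurrent.{Callable, Executors, TimeUnit}

        val pool = Executors.newFixedThreadPool(8)
        // Stand-in for one blocking partition writer creation.
        def createOne(i: Int): Int = { Thread.sleep(50); i }

        val t0 = System.nanoTime()
        (1 to 8).foreach(createOne) // serial: ~8 * 50ms on the RPC thread
        println(s"serial: ${(System.nanoTime() - t0) / 1000000}ms")

        val t1 = System.nanoTime()
        val futures = (1 to 8).map { i =>
          pool.submit(new Callable[Int] { override def call(): Int = createOne(i) })
        }
        futures.foreach(_.get(10, TimeUnit.SECONDS)) // parallel: ~50ms plus overhead
        println(s"parallel: ${(System.nanoTime() - t1) / 1000000}ms")
        pool.shutdown()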
    
    ### Does this PR introduce _any_ user-facing change?
    
    Introduces `celeborn.worker.writer.create.parallel.enabled`,
    `celeborn.worker.writer.create.parallel.threads` and
    `celeborn.worker.writer.create.parallel.timeout` to configure the parallel
    creation of file writers.
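
    For example, a deployment opting in might set (illustrative values; the
    thread count defaults to the number of available processors and the
    timeout to 120s):

        celeborn.worker.writer.create.parallel.enabled = true
        celeborn.worker.writer.create.parallel.threads = 16
        celeborn.worker.writer.create.parallel.timeout = 120s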
    
    ### How was this patch tested?
    
    CI.
    
    Closes #3387 from SteNicholas/CELEBORN-2063.
    
    Authored-by: SteNicholas <[email protected]>
    Signed-off-by: SteNicholas <[email protected]>
    (cherry picked from commit 8e6f4d5f95f58238913bf6f5bc769e5508d64efe)
    Signed-off-by: SteNicholas <[email protected]>
---
 .../org/apache/celeborn/common/CelebornConf.scala  |  28 +++
 .../org/apache/celeborn/common/util/Utils.scala    |  82 +++++++-
 docs/configuration/worker.md                       |   3 +
 .../service/deploy/worker/Controller.scala         | 212 ++++++++++++++-------
 .../celeborn/service/deploy/worker/Worker.scala    |  14 ++
 5 files changed, 256 insertions(+), 83 deletions(-)

diff --git a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
index 3447d5dbd..6a5f0ed7e 100644
--- a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
+++ b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
@@ -1341,6 +1341,10 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable with Logging with Se
   def workerS3FlusherThreads: Int = get(WORKER_FLUSHER_S3_THREADS)
   def workerOssFlusherThreads: Int = get(WORKER_FLUSHER_OSS_THREADS)
   def workerCreateWriterMaxAttempts: Int = get(WORKER_WRITER_CREATE_MAX_ATTEMPTS)
+  def workerCreateWriterParallelEnabled: Boolean = get(WORKER_WRITER_CREATE_PARALLEL_ENABLED)
+  def workerCreateWriterParallelThreads: Int =
+    get(WORKER_WRITER_CREATE_PARALLEL_THREADS).getOrElse(Runtime.getRuntime.availableProcessors)
+  def workerCreateWriterParallelTimeout: Long = get(WORKER_WRITER_CREATE_PARALLEL_TIMEOUT)
   def workerFlusherLocalGatherAPIEnabled: Boolean = get(WORKER_FLUSHER_LOCAL_GATHER_API_ENABLED)
 
   // //////////////////////////////////////////////////////
@@ -4038,6 +4042,30 @@ object CelebornConf extends Logging {
       .intConf
       .createWithDefault(3)
 
+  val WORKER_WRITER_CREATE_PARALLEL_ENABLED: ConfigEntry[Boolean] =
+    buildConf("celeborn.worker.writer.create.parallel.enabled")
+      .categories("worker")
+      .version("0.6.3")
+      .doc("Whether to parallelize the creation of file writer.")
+      .booleanConf
+      .createWithDefault(false)
+
+  val WORKER_WRITER_CREATE_PARALLEL_THREADS: OptionalConfigEntry[Int] =
+    buildConf("celeborn.worker.writer.create.parallel.threads")
+      .categories("worker")
+      .version("0.6.3")
+      .doc("Thread number of worker to parallelize the creation of file 
writer.")
+      .intConf
+      .createOptional
+
+  val WORKER_WRITER_CREATE_PARALLEL_TIMEOUT: ConfigEntry[Long] =
+    buildConf("celeborn.worker.writer.create.parallel.timeout")
+      .categories("worker")
+      .version("0.6.3")
+      .doc("Timeout for a worker to create a file writer in parallel.")
+      .timeConf(TimeUnit.MILLISECONDS)
+      .createWithDefaultString("120s")
+
   val WORKER_FLUSHER_LOCAL_GATHER_API_ENABLED: ConfigEntry[Boolean] =
     buildConf("celeborn.worker.flusher.local.gatherAPI.enabled")
       .internal
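
For reference, the new entries surface through the accessors added above. A
minimal sketch of reading them (the no-argument CelebornConf construction is
an assumption for illustration; the accessor names come from this diff):

    import org.apache.celeborn.common.CelebornConf

    val conf = new CelebornConf() // assumed default constructor, illustration only
    if (conf.workerCreateWriterParallelEnabled) { // default: false
      val threads = conf.workerCreateWriterParallelThreads // falls back to availableProcessors
      val timeoutMs = conf.workerCreateWriterParallelTimeout // milliseconds; default 120s
      println(s"parallel writer creation: $threads threads, ${timeoutMs}ms timeout")
    }
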
diff --git a/common/src/main/scala/org/apache/celeborn/common/util/Utils.scala b/common/src/main/scala/org/apache/celeborn/common/util/Utils.scala
index 7dec9f086..900356c9d 100644
--- a/common/src/main/scala/org/apache/celeborn/common/util/Utils.scala
+++ b/common/src/main/scala/org/apache/celeborn/common/util/Utils.scala
@@ -1004,11 +1004,14 @@ object Utils extends Logging {
   }
 
   /**
-   * if the action is timeout, will return the callback result
-   * if other exception will be thrown directly
+   * If the action times out, the callback result is returned.
+   * Any other exception is thrown directly.
+   *
    * @param block the normal action block
    * @param callback callback if timeout
+   * @param threadPool thread pool to submit the block to
    * @param timeoutInSeconds timeout limit value in seconds
+   * @param errorMessage error message to log exception
    * @tparam T result type
    * @return result
    */
@@ -1016,15 +1019,45 @@ object Utils extends Logging {
       threadPool: ThreadPoolExecutor,
       timeoutInSeconds: Long = 10,
       errorMessage: String = "none"): T = {
-    val futureTask = new Callable[T] {
+    tryFutureWithTimeoutAndCallback(callback)(
+      future(block)(threadPool),
+      timeoutInSeconds,
+      errorMessage)
+  }
+
+  /**
+   * Submits the given block to the thread pool and returns the resulting future.
+   *
+   * @param block the normal action block
+   * @param threadPool thread pool to submit the block to
+   * @tparam T result type
+   * @return future
+   */
+  def future[T](block: => T)(
+      threadPool: ThreadPoolExecutor): java.util.concurrent.Future[T] = {
+    threadPool.submit(new Callable[T] {
       override def call(): T = {
         block
       }
-    }
+    })
+  }
 
-    var future: java.util.concurrent.Future[T] = null
+  /**
+   * If the action times out, the callback result is returned.
+   * Any other exception is thrown directly.
+   *
+   * @param callback callback if timeout
+   * @param future future to try with timeout and callback
+   * @param timeoutInSeconds timeout limit value in seconds
+   * @param errorMessage error message to log exception
+   * @tparam T result type
+   * @return result
+   */
+  def tryFutureWithTimeoutAndCallback[T](callback: => T)(
+      future: java.util.concurrent.Future[T],
+      timeoutInSeconds: Long = 10,
+      errorMessage: String = "none"): T = {
     try {
-      future = threadPool.submit(futureTask)
       future.get(timeoutInSeconds, TimeUnit.SECONDS)
     } catch {
       case _: TimeoutException =>
@@ -1034,9 +1067,40 @@ object Utils extends Logging {
       case throwable: Throwable =>
         throw throwable
     } finally {
-      if (null != future && !future.isCancelled) {
-        future.cancel(true)
-      }
+      cancelFuture(future)
+    }
+  }
+
+  /**
+   * Waits for all futures to complete. If any future times out or fails,
+   * the exception is logged and rethrown; all futures are cancelled afterwards.
+   *
+   * @param futures futures to wait for, each bounded by the given timeout
+   * @param timeoutInSeconds timeout limit value in seconds
+   * @param errorMessage error message to log exception
+   * @tparam T result type
+   * @return results
+   */
+  def tryFuturesWithTimeout[T](
+      futures: List[java.util.concurrent.Future[T]],
+      timeoutInSeconds: Long = 10,
+      errorMessage: String = "none"): List[T] = {
+    try {
+      futures.map(_.get(timeoutInSeconds, TimeUnit.SECONDS))
+    } catch {
+      case throwable: Throwable =>
+        logError(
+          s"${throwable.getClass.getSimpleName} in thread ${Thread.currentThread().getName}," +
+            s" error message: $errorMessage")
+        throw throwable
+    } finally {
+      futures.foreach(cancelFuture)
+    }
+  }
+
+  def cancelFuture[T](future: java.util.concurrent.Future[T]): Unit = {
+    if (null != future && !future.isCancelled) {
+      future.cancel(true)
     }
   }
 
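A minimal usage sketch of the new helpers (the pool construction here is
illustrative; Celeborn builds named daemon pools via ThreadUtils). Note that
tryFuturesWithTimeout calls get on each future in turn, so timeoutInSeconds
bounds each wait individually rather than the whole batch:

    import java.util.concurrent.{LinkedBlockingQueue, ThreadPoolExecutor, TimeUnit}

    import org.apache.celeborn.common.util.Utils

    val pool = new ThreadPoolExecutor(
      4, 4, 60L, TimeUnit.SECONDS, new LinkedBlockingQueue[Runnable]())

    // Fan out one task per input, then collect every result; on timeout or
    // failure the helper logs, rethrows, and cancels all futures in `finally`.
    val squares = Utils.tryFuturesWithTimeout(
      List(1, 2, 3, 4).map(i => Utils.future(i * i)(pool)),
      timeoutInSeconds = 10,
      errorMessage = "square computation timed out")
    // squares == List(1, 4, 9, 16)
    pool.shutdown()
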
diff --git a/docs/configuration/worker.md b/docs/configuration/worker.md
index d9db38d0b..2b540884a 100644
--- a/docs/configuration/worker.md
+++ b/docs/configuration/worker.md
@@ -197,5 +197,8 @@ license: |
 | celeborn.worker.storage.workingDir | celeborn-worker/shuffle_data | false | Worker's working dir path name. | 0.3.0 | celeborn.worker.workingDir | 
 | celeborn.worker.writer.close.timeout | 120s | false | Timeout for a file writer to close | 0.2.0 |  | 
 | celeborn.worker.writer.create.maxAttempts | 3 | false | Retry count for a file writer to create if its creation was failed. | 0.2.0 |  | 
+| celeborn.worker.writer.create.parallel.enabled | false | false | Whether to parallelize the creation of file writers. | 0.6.3 |  | 
+| celeborn.worker.writer.create.parallel.threads | &lt;undefined&gt; | false | Number of threads for a worker to parallelize the creation of file writers. | 0.6.3 |  | 
+| celeborn.worker.writer.create.parallel.timeout | 120s | false | Timeout for a worker to create a file writer in parallel. | 0.6.3 |  | 
 | worker.flush.reuseCopyBuffer.enabled | true | false | Whether to enable reuse copy buffer for flush. Note that this copy buffer must not be referenced again after flushing. This means that, for example, the Hdfs(Oss or S3) client will not asynchronously access this buffer after the flush method returns, otherwise data modification problems will occur. | 0.6.1 |  | 
 <!--end-include-->
diff --git a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Controller.scala b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Controller.scala
index cd2485b0a..d281b2dbb 100644
--- a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Controller.scala
+++ b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Controller.scala
@@ -32,14 +32,14 @@ import org.roaringbitmap.RoaringBitmap
 import org.apache.celeborn.common.CelebornConf
 import org.apache.celeborn.common.identity.UserIdentifier
 import org.apache.celeborn.common.internal.Logging
-import org.apache.celeborn.common.meta.{ReduceFileMeta, WorkerInfo, WorkerPartitionLocationInfo}
+import org.apache.celeborn.common.meta.{WorkerInfo, WorkerPartitionLocationInfo}
 import org.apache.celeborn.common.metrics.MetricsSystem
 import org.apache.celeborn.common.protocol.{PartitionLocation, PartitionSplitMode, PartitionType, StorageInfo}
 import org.apache.celeborn.common.protocol.message.ControlMessages._
 import org.apache.celeborn.common.protocol.message.StatusCode
 import org.apache.celeborn.common.rpc._
 import org.apache.celeborn.common.util.{JavaUtils, Utils}
-import org.apache.celeborn.service.deploy.worker.storage.{MapPartitionMetaHandler, PartitionDataWriter, SegmentMapPartitionMetaHandler, StorageManager}
+import org.apache.celeborn.service.deploy.worker.storage.{MapPartitionMetaHandler, PartitionDataWriter, StorageManager}
 
 private[deploy] class Controller(
     override val rpcEnv: RpcEnv,
@@ -64,11 +64,13 @@ private[deploy] class Controller(
   var commitThreadPool: ThreadPoolExecutor = _
   var commitFinishedChecker: ScheduledExecutorService = _
   var asyncReplyPool: ScheduledExecutorService = _
+  var createWriterThreadPool: ThreadPoolExecutor = _
   val minPartitionSizeToEstimate = conf.minPartitionSizeToEstimate
   var shutdown: AtomicBoolean = _
   val defaultPushdataTimeout = conf.pushDataTimeoutMs
   val mockCommitFilesFailure = conf.testMockCommitFilesFailure
   val shuffleCommitTimeout = conf.workerShuffleCommitTimeout
+  val createWriterParallelTimeout = conf.workerCreateWriterParallelTimeout
   val workerCommitFilesCheckInterval = conf.workerCommitFilesCheckInterval
 
   def init(worker: Worker): Unit = {
@@ -83,6 +85,7 @@ private[deploy] class Controller(
     timer = worker.timer
     commitThreadPool = worker.commitThreadPool
     asyncReplyPool = worker.asyncReplyPool
+    createWriterThreadPool = worker.createWriterThreadPool
     shutdown = worker.shutdown
 
     commitFinishedChecker = worker.commitFinishedChecker
@@ -192,88 +195,45 @@ private[deploy] class Controller(
      context.reply(ReserveSlotsResponse(StatusCode.NO_AVAILABLE_WORKING_DIR, msg))
       return
     }
-    val primaryLocs = new jArrayList[PartitionLocation]()
-    try {
-      for (ind <- 0 until requestPrimaryLocs.size()) {
-        var location = partitionLocationInfo.getPrimaryLocation(
-          shuffleKey,
-          requestPrimaryLocs.get(ind).getUniqueId)
-        if (location == null) {
-          location = requestPrimaryLocs.get(ind)
-          val writer = storageManager.createPartitionDataWriter(
-            applicationId,
-            shuffleId,
-            location,
-            splitThreshold,
-            splitMode,
-            partitionType,
-            rangeReadFilter,
-            userIdentifier,
-            partitionSplitEnabled,
-            isSegmentGranularityVisible)
-          primaryLocs.add(new WorkingPartition(location, writer))
-        } else {
-          primaryLocs.add(location)
-        }
-      }
-    } catch {
-      case e: Exception =>
-        logError(s"CreateWriter for $shuffleKey failed.", e)
-    }
+    val primaryLocs = createWriters(
+      shuffleKey,
+      applicationId,
+      shuffleId,
+      requestPrimaryLocs,
+      splitThreshold,
+      splitMode,
+      partitionType,
+      rangeReadFilter,
+      userIdentifier,
+      partitionSplitEnabled,
+      isSegmentGranularityVisible,
+      isPrimary = true)
     if (primaryLocs.size() < requestPrimaryLocs.size()) {
       val msg = s"Not all primary partition satisfied for $shuffleKey"
       logWarning(s"[handleReserveSlots] $msg, will destroy writers.")
-      primaryLocs.asScala.foreach { partitionLocation =>
-        val fileWriter = partitionLocation.asInstanceOf[WorkingPartition].getFileWriter
-        fileWriter.destroy(new IOException(s"Destroy FileWriter $fileWriter caused by " +
-          s"reserving slots failed for $shuffleKey."))
-      }
+      destroyWriters(primaryLocs, shuffleKey)
       context.reply(ReserveSlotsResponse(StatusCode.RESERVE_SLOTS_FAILED, msg))
       return
     }
 
-    val replicaLocs = new jArrayList[PartitionLocation]()
-    try {
-      for (ind <- 0 until requestReplicaLocs.size()) {
-        var location =
-          partitionLocationInfo.getReplicaLocation(
-            shuffleKey,
-            requestReplicaLocs.get(ind).getUniqueId)
-        if (location == null) {
-          location = requestReplicaLocs.get(ind)
-          val writer = storageManager.createPartitionDataWriter(
-            applicationId,
-            shuffleId,
-            location,
-            splitThreshold,
-            splitMode,
-            partitionType,
-            rangeReadFilter,
-            userIdentifier,
-            partitionSplitEnabled,
-            isSegmentGranularityVisible)
-          replicaLocs.add(new WorkingPartition(location, writer))
-        } else {
-          replicaLocs.add(location)
-        }
-      }
-    } catch {
-      case e: Exception =>
-        logError(s"CreateWriter for $shuffleKey failed.", e)
-    }
+    val replicaLocs = createWriters(
+      shuffleKey,
+      applicationId,
+      shuffleId,
+      requestReplicaLocs,
+      splitThreshold,
+      splitMode,
+      partitionType,
+      rangeReadFilter,
+      userIdentifier,
+      partitionSplitEnabled,
+      isSegmentGranularityVisible,
+      isPrimary = false)
     if (replicaLocs.size() < requestReplicaLocs.size()) {
       val msg = s"Not all replica partition satisfied for $shuffleKey"
       logWarning(s"[handleReserveSlots] $msg, destroy writers.")
-      primaryLocs.asScala.foreach { partitionLocation =>
-        val fileWriter = partitionLocation.asInstanceOf[WorkingPartition].getFileWriter
-        fileWriter.destroy(new IOException(s"Destroy FileWriter $fileWriter caused by " +
-          s"reserving slots failed for $shuffleKey."))
-      }
-      replicaLocs.asScala.foreach { partitionLocation =>
-        val fileWriter = partitionLocation.asInstanceOf[WorkingPartition].getFileWriter
-        fileWriter.destroy(new IOException(s"Destroy FileWriter $fileWriter caused by " +
-          s"reserving slots failed for $shuffleKey."))
-      }
+      destroyWriters(primaryLocs, shuffleKey)
+      destroyWriters(replicaLocs, shuffleKey)
       context.reply(ReserveSlotsResponse(StatusCode.RESERVE_SLOTS_FAILED, msg))
       return
     }
@@ -299,6 +259,110 @@ private[deploy] class Controller(
     context.reply(ReserveSlotsResponse(StatusCode.SUCCESS))
   }
 
+  private def createWriters(
+      shuffleKey: String,
+      applicationId: String,
+      shuffleId: Int,
+      requestLocs: jList[PartitionLocation],
+      splitThreshold: Long,
+      splitMode: PartitionSplitMode,
+      partitionType: PartitionType,
+      rangeReadFilter: Boolean,
+      userIdentifier: UserIdentifier,
+      partitionSplitEnabled: Boolean,
+      isSegmentGranularityVisible: Boolean,
+      isPrimary: Boolean): jList[PartitionLocation] = {
+    val partitionLocations = new jArrayList[PartitionLocation]()
+    try {
+      def createWriter(partitionLocation: PartitionLocation): PartitionLocation = {
+        createPartitionDataWriter(
+          shuffleKey,
+          applicationId,
+          shuffleId,
+          partitionLocation,
+          splitThreshold,
+          splitMode,
+          partitionType,
+          rangeReadFilter,
+          userIdentifier,
+          partitionSplitEnabled,
+          isSegmentGranularityVisible,
+          isPrimary)
+      }
+      if (createWriterThreadPool == null) {
+        partitionLocations.addAll(requestLocs.asScala.map(createWriter).asJava)
+      } else {
+        partitionLocations.addAll(Utils.tryFuturesWithTimeout(
+          requestLocs.asScala.map(requestLoc =>
+            Utils.future(createWriter(requestLoc))(createWriterThreadPool)).toList,
+          createWriterParallelTimeout,
+          s"Create FileWriter for $shuffleKey timeout.").asJava)
+      }
+    } catch {
+      case e: Exception =>
+        logError(s"Create FileWriter for $shuffleKey failed.", e)
+    }
+    partitionLocations
+  }
+
+  private def createPartitionDataWriter(
+      shuffleKey: String,
+      applicationId: String,
+      shuffleId: Int,
+      requestLoc: PartitionLocation,
+      splitThreshold: Long,
+      splitMode: PartitionSplitMode,
+      partitionType: PartitionType,
+      rangeReadFilter: Boolean,
+      userIdentifier: UserIdentifier,
+      partitionSplitEnabled: Boolean,
+      isSegmentGranularityVisible: Boolean,
+      isPrimary: Boolean): PartitionLocation = {
+    try {
+      var location =
+        if (isPrimary) {
+          partitionLocationInfo.getPrimaryLocation(
+            shuffleKey,
+            requestLoc.getUniqueId)
+        } else {
+          partitionLocationInfo.getReplicaLocation(
+            shuffleKey,
+            requestLoc.getUniqueId)
+        }
+      if (location == null) {
+        location = requestLoc
+        val writer = storageManager.createPartitionDataWriter(
+          applicationId,
+          shuffleId,
+          location,
+          splitThreshold,
+          splitMode,
+          partitionType,
+          rangeReadFilter,
+          userIdentifier,
+          partitionSplitEnabled,
+          isSegmentGranularityVisible)
+        new WorkingPartition(location, writer)
+      } else {
+        location
+      }
+    } catch {
+      case e: Exception =>
+        logError(s"Create FileWriter for $requestLoc $shuffleKey failed.", e)
+        throw e
+    }
+  }
+
+  private def destroyWriters(
+      partitionLocations: jList[PartitionLocation],
+      shuffleKey: String): Unit = {
+    partitionLocations.asScala.foreach { partitionLocation =>
+      val fileWriter = partitionLocation.asInstanceOf[WorkingPartition].getFileWriter
+      fileWriter.destroy(new IOException(s"Destroy FileWriter $fileWriter caused by " +
+        s"reserving slots failed for $shuffleKey."))
+    }
+  }
+
   private def commitFiles(
       shuffleKey: String,
       uniqueIds: jList[String],
diff --git a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Worker.scala b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Worker.scala
index be382e36f..8786ad232 100644
--- a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Worker.scala
+++ b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Worker.scala
@@ -346,6 +346,7 @@ private[celeborn] class Worker(
       conf.workerCleanThreads)
   val asyncReplyPool: ScheduledExecutorService =
     ThreadUtils.newDaemonSingleThreadScheduledExecutor("worker-rpc-async-replier")
+  var createWriterThreadPool: ThreadPoolExecutor = _
   val timer = new HashedWheelTimer(ThreadUtils.namedSingleThreadFactory("worker-timer"))
 
   // Configs
@@ -581,6 +582,13 @@ private[celeborn] class Worker(
       }
     })
 
+    if (conf.workerCreateWriterParallelEnabled) {
+      createWriterThreadPool =
+        ThreadUtils.newDaemonFixedThreadPool(
+          conf.workerCreateWriterParallelThreads,
+          "worker-writer-creator")
+    }
+
     pushDataHandler.init(this)
     replicateHandler.init(this)
     fetchHandler.init(this)
@@ -628,12 +636,18 @@ private[celeborn] class Worker(
         commitThreadPool.shutdown()
         commitFinishedChecker.shutdown();
         asyncReplyPool.shutdown()
+        if (createWriterThreadPool != null) {
+          createWriterThreadPool.shutdown()
+        }
       } else {
         forwardMessageScheduler.shutdownNow()
         replicateThreadPool.shutdownNow()
         commitThreadPool.shutdownNow()
         commitFinishedChecker.shutdownNow();
         asyncReplyPool.shutdownNow()
+        if (createWriterThreadPool != null) {
+          createWriterThreadPool.shutdownNow()
+        }
       }
       workerSource.appActiveConnections.clear()
       partitionsSorter.close(exitKind)
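
A side note on the two branches above (standard java.util.concurrent
semantics, not Celeborn-specific): shutdown() lets already-queued tasks
drain, while shutdownNow() interrupts running workers and returns the tasks
that never started. A minimal sketch:

    import java.util.concurrent.{Executors, TimeUnit}

    val pool = Executors.newFixedThreadPool(2)
    pool.submit(new Runnable { override def run(): Unit = Thread.sleep(100) })
    pool.shutdown() // graceful: no new tasks accepted, queued work still runs
    if (!pool.awaitTermination(1, TimeUnit.SECONDS)) {
      val neverStarted = pool.shutdownNow() // forceful: interrupts running tasks
      println(s"${neverStarted.size()} task(s) never started")
    }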
