This is an automated email from the ASF dual-hosted git repository.

ethanfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/celeborn.git


The following commit(s) were added to refs/heads/main by this push:
     new fd490013a Revert "[CELEBORN-1388] Use finer grained locks in changePartitionManager"
fd490013a is described below

commit fd490013aecc2971a7c396f117b8df1465f3775e
Author: mingji <[email protected]>
AuthorDate: Thu May 30 11:18:58 2024 +0800

    Revert "[CELEBORN-1388] Use finer grained locks in changePartitionManager"
    
    This reverts commit 9f304798cb2147fe4e9d900e85832c1034397863.
---
 .../celeborn/client/ChangePartitionManager.scala   | 81 +++++++---------------
 .../org/apache/celeborn/common/CelebornConf.scala  | 10 ---
 docs/configuration/client.md                       |  1 -
 3 files changed, 26 insertions(+), 66 deletions(-)

diff --git 
a/client/src/main/scala/org/apache/celeborn/client/ChangePartitionManager.scala 
b/client/src/main/scala/org/apache/celeborn/client/ChangePartitionManager.scala
index 25a2ab09f..9cf1dca52 100644
--- 
a/client/src/main/scala/org/apache/celeborn/client/ChangePartitionManager.scala
+++ 
b/client/src/main/scala/org/apache/celeborn/client/ChangePartitionManager.scala
@@ -47,14 +47,8 @@ class ChangePartitionManager(
   // shuffleId -> (partitionId -> set of ChangePartition)
   private val changePartitionRequests =
     JavaUtils.newConcurrentHashMap[Int, ConcurrentHashMap[Integer, 
JSet[ChangePartitionRequest]]]()
-
-  // shuffleId -> locks
-  private val locks = JavaUtils.newConcurrentHashMap[Int, Array[AnyRef]]()
-  private val lockBucketSize = conf.batchHandleChangePartitionBuckets
-
   // shuffleId -> set of partition id
-  private val inBatchPartitions =
-    JavaUtils.newConcurrentHashMap[Int, ConcurrentHashMap.KeySetView[Int, 
java.lang.Boolean]]()
+  private val inBatchPartitions = JavaUtils.newConcurrentHashMap[Int, 
JSet[Integer]]()
 
   private val batchHandleChangePartitionEnabled = 
conf.batchHandleChangePartitionEnabled
   private val batchHandleChangePartitionExecutors = 
ThreadUtils.newDaemonCachedThreadPool(
@@ -85,19 +79,14 @@ class ChangePartitionManager(
                 batchHandleChangePartitionExecutors.submit {
                   new Runnable {
                     override def run(): Unit = {
-                      val distinctPartitions = {
-                        val requestSet = inBatchPartitions.get(shuffleId)
-                        val locksForShuffle = locks.computeIfAbsent(shuffleId, 
locksRegisterFunc)
-                        requests.asScala.map { case (partitionId, request) =>
-                          locksForShuffle(partitionId % 
locksForShuffle.length).synchronized {
-                            if (!requestSet.contains(partitionId)) {
-                              requestSet.add(partitionId)
-                              Some(request.asScala.toArray.maxBy(_.epoch))
-                            } else {
-                              None
-                            }
-                          }
-                        }.filter(_.isDefined).map(_.get).toArray
+                      val distinctPartitions = requests.synchronized {
+                        // For each partition only need handle one request
+                        requests.asScala.filter { case (partitionId, _) =>
+                          
!inBatchPartitions.get(shuffleId).contains(partitionId)
+                        }.map { case (partitionId, request) =>
+                          inBatchPartitions.get(shuffleId).add(partitionId)
+                          request.asScala.toArray.maxBy(_.epoch)
+                        }.toArray
                       }
                       if (distinctPartitions.nonEmpty) {
                         handleRequestPartitions(
@@ -134,16 +123,8 @@ class ChangePartitionManager(
         JavaUtils.newConcurrentHashMap()
     }
 
-  private val inBatchShuffleIdRegisterFunc =
-    new util.function.Function[Int, ConcurrentHashMap.KeySetView[Int, 
java.lang.Boolean]]() {
-      override def apply(s: Int): ConcurrentHashMap.KeySetView[Int, 
java.lang.Boolean] =
-        ConcurrentHashMap.newKeySet[Int]()
-    }
-
-  private val locksRegisterFunc = new util.function.Function[Int, 
Array[AnyRef]] {
-    override def apply(t: Int): Array[AnyRef] = {
-      Array.fill(lockBucketSize)(new AnyRef())
-    }
+  private val inBatchShuffleIdRegisterFunc = new util.function.Function[Int, 
util.Set[Integer]]() {
+    override def apply(s: Int): util.Set[Integer] = new util.HashSet[Integer]()
   }
 
   def handleRequestPartitionLocation(
@@ -170,22 +151,15 @@ class ChangePartitionManager(
       oldPartition,
       cause)
 
-    val locksForShuffle = locks.computeIfAbsent(shuffleId, locksRegisterFunc)
-    locksForShuffle(partitionId % locksForShuffle.length).synchronized {
-      var newEntry = false
-      val set = requests.computeIfAbsent(
-        partitionId,
-        new java.util.function.Function[Integer, 
util.Set[ChangePartitionRequest]] {
-          override def apply(t: Integer): util.Set[ChangePartitionRequest] = {
-            newEntry = true
-            new util.HashSet[ChangePartitionRequest]()
-          }
-        })
-
-      if (newEntry) {
+    requests.synchronized {
+      if (requests.containsKey(partitionId)) {
+        requests.get(partitionId).add(changePartition)
         logTrace(s"[handleRequestPartitionLocation] For $shuffleId, request 
for same partition" +
           s"$partitionId-$oldEpoch exists, register context.")
+        return
       } else {
+        // If new slot for the partition has been allocated, reply and return.
+        // Else register and allocate for it.
         getLatestPartition(shuffleId, partitionId, oldEpoch).foreach { 
latestLoc =>
           context.reply(
             partitionId,
@@ -196,8 +170,10 @@ class ChangePartitionManager(
             s" shuffleId: $shuffleId $latestLoc")
           return
         }
+        val set = new util.HashSet[ChangePartitionRequest]()
+        set.add(changePartition)
+        requests.put(partitionId, set)
       }
-      set.add(changePartition)
     }
     if (!batchHandleChangePartitionEnabled) {
       handleRequestPartitions(shuffleId, Array(changePartition))
@@ -240,16 +216,14 @@ class ChangePartitionManager(
 
     // remove together to reduce lock time
     def replySuccess(locations: Array[PartitionLocation]): Unit = {
-      val locksForShuffle = locks.computeIfAbsent(shuffleId, locksRegisterFunc)
-      locations.map { location =>
-        locksForShuffle(location.getId % locksForShuffle.length).synchronized {
-          val ret = requestsMap.remove(location.getId)
+      requestsMap.synchronized {
+        locations.map { location =>
           if (batchHandleChangePartitionEnabled) {
             inBatchPartitions.get(shuffleId).remove(location.getId)
           }
           // Here one partition id can be remove more than once,
           // so need to filter null result before reply.
-          location -> Option(ret)
+          location -> Option(requestsMap.remove(location.getId))
         }
       }.foreach { case (newLocation, requests) =>
         requests.map(_.asScala.toList.foreach(req =>
@@ -263,14 +237,12 @@ class ChangePartitionManager(
 
     // remove together to reduce lock time
     def replyFailure(status: StatusCode): Unit = {
-      changePartitions.map { changePartition =>
-        val locksForShuffle = locks.computeIfAbsent(shuffleId, 
locksRegisterFunc)
-        locksForShuffle(changePartition.partitionId % 
locksForShuffle.length).synchronized {
-          val r = requestsMap.remove(changePartition.partitionId)
+      requestsMap.synchronized {
+        changePartitions.map { changePartition =>
           if (batchHandleChangePartitionEnabled) {
             
inBatchPartitions.get(shuffleId).remove(changePartition.partitionId)
           }
-          Option(r)
+          Option(requestsMap.remove(changePartition.partitionId))
         }
       }.foreach { requests =>
         requests.map(_.asScala.toList.foreach(req =>
@@ -353,6 +325,5 @@ class ChangePartitionManager(
   def removeExpiredShuffle(shuffleId: Int): Unit = {
     changePartitionRequests.remove(shuffleId)
     inBatchPartitions.remove(shuffleId)
-    locks.remove(shuffleId)
   }
 }
diff --git 
a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala 
b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
index 68fa938b5..36635e84e 100644
--- a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
+++ b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
@@ -1023,8 +1023,6 @@ class CelebornConf(loadDefaults: Boolean) extends 
Cloneable with Logging with Se
     PartitionSplitMode.valueOf(get(SHUFFLE_PARTITION_SPLIT_MODE))
   def shufflePartitionSplitThreshold: Long = 
get(SHUFFLE_PARTITION_SPLIT_THRESHOLD)
   def batchHandleChangePartitionEnabled: Boolean = 
get(CLIENT_BATCH_HANDLE_CHANGE_PARTITION_ENABLED)
-  def batchHandleChangePartitionBuckets: Int =
-    get(CLIENT_BATCH_HANDLE_CHANGE_PARTITION_BUCKETS)
   def batchHandleChangePartitionNumThreads: Int = 
get(CLIENT_BATCH_HANDLE_CHANGE_PARTITION_THREADS)
   def batchHandleChangePartitionRequestInterval: Long =
     get(CLIENT_BATCH_HANDLE_CHANGE_PARTITION_INTERVAL)
@@ -4086,14 +4084,6 @@ object CelebornConf extends Logging {
       .booleanConf
       .createWithDefault(true)
 
-  val CLIENT_BATCH_HANDLE_CHANGE_PARTITION_BUCKETS: ConfigEntry[Int] =
-    
buildConf("celeborn.client.shuffle.batchHandleChangePartition.partitionBuckets")
-      .categories("client")
-      .doc("Max number of change partition requests which can be concurrently 
processed ")
-      .version("0.5.0")
-      .intConf
-      .createWithDefault(256)
-
   val CLIENT_BATCH_HANDLE_CHANGE_PARTITION_THREADS: ConfigEntry[Int] =
     buildConf("celeborn.client.shuffle.batchHandleChangePartition.threads")
       .withAlternative("celeborn.shuffle.batchHandleChangePartition.threads")
diff --git a/docs/configuration/client.md b/docs/configuration/client.md
index 0fe90d0ec..6d290080a 100644
--- a/docs/configuration/client.md
+++ b/docs/configuration/client.md
@@ -82,7 +82,6 @@ license: |
 | celeborn.client.rpc.reserveSlots.askTimeout | &lt;value of 
celeborn.rpc.askTimeout&gt; | false | Timeout for LifecycleManager request 
reserve slots. | 0.3.0 |  | 
 | celeborn.client.rpc.shared.threads | 16 | false | Number of shared rpc 
threads in LifecycleManager. | 0.3.2 |  | 
 | celeborn.client.shuffle.batchHandleChangePartition.interval | 100ms | false 
| Interval for LifecycleManager to schedule handling change partition requests 
in batch. | 0.3.0 | celeborn.shuffle.batchHandleChangePartition.interval | 
-| celeborn.client.shuffle.batchHandleChangePartition.partitionBuckets | 256 | 
false | Max number of change partition requests which can be concurrently 
processed  | 0.5.0 |  | 
 | celeborn.client.shuffle.batchHandleChangePartition.threads | 8 | false | 
Threads number for LifecycleManager to handle change partition request in 
batch. | 0.3.0 | celeborn.shuffle.batchHandleChangePartition.threads | 
 | celeborn.client.shuffle.batchHandleCommitPartition.interval | 5s | false | 
Interval for LifecycleManager to schedule handling commit partition requests in 
batch. | 0.3.0 | celeborn.shuffle.batchHandleCommitPartition.interval | 
 | celeborn.client.shuffle.batchHandleCommitPartition.threads | 8 | false | 
Threads number for LifecycleManager to handle commit partition request in 
batch. | 0.3.0 | celeborn.shuffle.batchHandleCommitPartition.threads | 

Reply via email to