This is an automated email from the ASF dual-hosted git repository.
nicholasjiang pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/celeborn.git
The following commit(s) were added to refs/heads/main by this push:
new 46a49a828 [CELEBORN-2250] Fix lock contention in
ReducePartitionCommitHandler.finishMapperAttempt via fine-grained locks
46a49a828 is described below
commit 46a49a8285b3a77f51b1a3223760cab7d8182667
Author: yew1eb <[email protected]>
AuthorDate: Tue Jan 27 21:15:44 2026 +0800
[CELEBORN-2250] Fix lock contention in
ReducePartitionCommitHandler.finishMapperAttempt via fine-grained locks
### What changes were proposed in this pull request?
Add `shuffleIdLocks` (one fine-grained lock object per shuffleId) and use it to replace the
global `shuffleMapperAttempts` lock in `initMapperAttempts` and `finishMapperAttempt`.
### Why are the changes needed?
Under high concurrency, all shuffles contend for the single lock on
`shuffleMapperAttempts` in `finishMapperAttempt`, leading to abnormally long
shuffle write times for small queries in Kyuubi Shared Mode. Per-shuffle
fine-grained locks eliminate cross-shuffle blocking and improve concurrency.
### Does this PR resolve a correctness bug?
No.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
CI.
Closes #3586 from yew1eb/CELEBORN-2250.
Authored-by: yew1eb <[email protected]>
Signed-off-by: SteNicholas <[email protected]>
---
.../client/commit/ReducePartitionCommitHandler.scala | 13 +++++++++++--
1 file changed, 11 insertions(+), 2 deletions(-)
diff --git
a/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala
b/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala
index fb557ebbe..1733455c6 100644
---
a/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala
+++
b/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala
@@ -77,6 +77,8 @@ class ReducePartitionCommitHandler(
private val shuffleMapperAttempts = JavaUtils.newConcurrentHashMap[Int,
Array[Int]]()
// TODO: Move this to native Int -> Int Map
private val shuffleToCompletedMappers = JavaUtils.newConcurrentHashMap[Int,
Int]()
+ private val shuffleIdLocks = JavaUtils.newConcurrentHashMap[Int, Object]()
+
private val stageEndTimeout = conf.clientPushStageEndTimeout
private val mockShuffleLost = conf.testMockShuffleLost
private val mockShuffleLostShuffle = conf.testMockShuffleLostShuffle
@@ -119,6 +121,10 @@ class ReducePartitionCommitHandler(
}
}
+ private val shuffleIdLocksRegisterFunc = new util.function.Function[Int,
Object] {
+ override def apply(key: Int): Object = new Object()
+ }
+
override def getPartitionType(): PartitionType = {
PartitionType.REDUCE
}
@@ -178,6 +184,7 @@ class ReducePartitionCommitHandler(
inProcessStageEndShuffleSet.remove(shuffleId)
shuffleMapperAttempts.remove(shuffleId)
shuffleToCompletedMappers.remove(shuffleId)
+ shuffleIdLocks.remove(shuffleId)
commitMetadataForReducer.remove(shuffleId)
skewPartitionCompletenessValidator.remove(shuffleId)
super.removeExpiredShuffle(shuffleId)
@@ -314,7 +321,8 @@ class ReducePartitionCommitHandler(
numPartitions: Int,
crc32PerPartition: Array[Int],
bytesWrittenPerPartition: Array[Long]): (Boolean, Boolean) = {
- val (mapperAttemptFinishedSuccess, allMapperFinished) =
shuffleMapperAttempts.synchronized {
+ val shuffleLock = shuffleIdLocks.computeIfAbsent(shuffleId,
shuffleIdLocksRegisterFunc)
+ val (mapperAttemptFinishedSuccess, allMapperFinished) =
shuffleLock.synchronized {
if (getMapperAttempts(shuffleId) == null) {
logDebug(s"[handleMapperEnd] $shuffleId not registered, create one.")
initMapperAttempts(shuffleId, numMappers, numPartitions)
@@ -428,7 +436,8 @@ class ReducePartitionCommitHandler(
}
private def initMapperAttempts(shuffleId: Int, numMappers: Int,
numPartitions: Int): Unit = {
- shuffleMapperAttempts.synchronized {
+ val shuffleLock = shuffleIdLocks.computeIfAbsent(shuffleId,
shuffleIdLocksRegisterFunc)
+ shuffleLock.synchronized {
if (!shuffleMapperAttempts.containsKey(shuffleId)) {
val attempts = new Array[Int](numMappers)
util.Arrays.fill(attempts, -1)