This is an automated email from the ASF dual-hosted git repository.

rexxiong pushed a commit to branch branch-0.5
in repository https://gitbox.apache.org/repos/asf/celeborn.git


The following commit(s) were added to refs/heads/branch-0.5 by this push:
     new caa060bb4 [CELEBORN-1686] Avoid return the same pushTaskQueue
caa060bb4 is described below

commit caa060bb40532c850386db301c050faf89ecd9e7
Author: sychen <[email protected]>
AuthorDate: Mon Nov 11 10:35:02 2024 +0800

    [CELEBORN-1686] Avoid return the same pushTaskQueue
    
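    ### What changes were proposed in this pull request?
    Replace `DataPusher#getIdleQueue` with `DataPusher#getAndResetIdleQueue`, 
which returns the idle queue and then sets the field to `null`, so the queue 
can be handed back at most once. `SendBufferPool#returnPushTaskQueue` now 
ignores a `null` queue, and the callers in `SortBasedPusher` and 
`HashBasedShuffleWriter` (spark-2 and spark-3) use the new method.
    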
    ### Why are the changes needed?
    The `close` method called from `SortBasedShuffleWriter#write` calls 
`sendBufferPool.returnPushTaskQueue(dataPusher.getIdleQueue());`, but `close` 
may be interrupted after the queue has already been returned.
    
    After the interruption, `SortBasedShuffleWriter#cleanupPusher` is 
called, and `sendBufferPool.returnPushTaskQueue(dataPusher.getIdleQueue());` 
is called a second time with the same queue.
    
    Since `SendBufferPool#pushTaskQueues` is a `LinkedList`, the repeated add 
stores the same `idleQueue` instance twice. The pool may then hand that queue 
to multiple tasks running in parallel, which end up sharing the same 
`idleQueue`, resulting in incorrect data.
    
    ### Does this PR introduce _any_ user-facing change?
    
    ### How was this patch tested?
    Production environment verification
    
    Closes #2878 from cxzl25/CELEBORN-1686.
    
    Authored-by: sychen <[email protected]>
    Signed-off-by: Shuang <[email protected]>
    (cherry picked from commit 8f34d1555b2169159c5bf2d701ae50b206017dd6)
    Signed-off-by: Shuang <[email protected]>
---
 .../java/org/apache/spark/shuffle/celeborn/SendBufferPool.java    | 3 +++
 .../java/org/apache/spark/shuffle/celeborn/SortBasedPusher.java   | 2 +-
 .../org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java | 4 ++--
 .../org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java | 4 ++--
 .../main/java/org/apache/celeborn/client/write/DataPusher.java    | 8 +++++---
 5 files changed, 13 insertions(+), 8 deletions(-)

diff --git 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SendBufferPool.java
 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SendBufferPool.java
index c77fb4462..0d8a2301f 100644
--- 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SendBufferPool.java
+++ 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SendBufferPool.java
@@ -103,6 +103,9 @@ public class SendBufferPool {
   }
 
   public synchronized void returnPushTaskQueue(LinkedBlockingQueue<PushTask> 
pushTaskQueue) {
+    if (pushTaskQueue == null) {
+      return;
+    }
     if (pushTaskQueues.size() == capacity) {
       pushTaskQueues.removeFirst();
     }
diff --git 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedPusher.java
 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedPusher.java
index dd272d6c2..013785ecc 100644
--- 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedPusher.java
+++ 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedPusher.java
@@ -505,7 +505,7 @@ public class SortBasedPusher extends MemoryConsumer {
     cleanupResources();
     try {
       dataPusher.waitOnTermination();
-      sendBufferPool.returnPushTaskQueue(dataPusher.getIdleQueue());
+      sendBufferPool.returnPushTaskQueue(dataPusher.getAndResetIdleQueue());
     } catch (InterruptedException e) {
       if (throwTaskKilledOnInterruption) {
         TaskInterruptedHelper.throwTaskKillException();
diff --git 
a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
 
b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
index 6db620b41..456da7e9a 100644
--- 
a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
+++ 
b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
@@ -325,7 +325,7 @@ public class HashBasedShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
   private void cleanupPusher() throws IOException {
     try {
       dataPusher.waitOnTermination();
-      sendBufferPool.returnPushTaskQueue(dataPusher.getIdleQueue());
+      sendBufferPool.returnPushTaskQueue(dataPusher.getAndResetIdleQueue());
     } catch (InterruptedException e) {
       TaskInterruptedHelper.throwTaskKillException();
     }
@@ -334,7 +334,7 @@ public class HashBasedShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
   private void close() throws IOException, InterruptedException {
     // here we wait for all the in-flight batches to return which sent by 
dataPusher thread
     dataPusher.waitOnTermination();
-    sendBufferPool.returnPushTaskQueue(dataPusher.getIdleQueue());
+    sendBufferPool.returnPushTaskQueue(dataPusher.getAndResetIdleQueue());
     shuffleClient.prepareForMergeData(shuffleId, mapId, encodedAttemptId);
 
     // merge and push residual data to reduce network traffic
diff --git 
a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
 
b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
index 7ed869844..4c5e6739b 100644
--- 
a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
+++ 
b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
@@ -359,7 +359,7 @@ public class HashBasedShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
   private void cleanupPusher() throws IOException {
     try {
       dataPusher.waitOnTermination();
-      sendBufferPool.returnPushTaskQueue(dataPusher.getIdleQueue());
+      sendBufferPool.returnPushTaskQueue(dataPusher.getAndResetIdleQueue());
     } catch (InterruptedException e) {
       TaskInterruptedHelper.throwTaskKillException();
     }
@@ -369,7 +369,7 @@ public class HashBasedShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
     // here we wait for all the in-flight batches to return which sent by 
dataPusher thread
     long pushMergedDataTime = System.nanoTime();
     dataPusher.waitOnTermination();
-    sendBufferPool.returnPushTaskQueue(dataPusher.getIdleQueue());
+    sendBufferPool.returnPushTaskQueue(dataPusher.getAndResetIdleQueue());
     shuffleClient.prepareForMergeData(shuffleId, mapId, encodedAttemptId);
     closeWrite();
     shuffleClient.pushMergedData(shuffleId, mapId, encodedAttemptId);
diff --git 
a/client/src/main/java/org/apache/celeborn/client/write/DataPusher.java 
b/client/src/main/java/org/apache/celeborn/client/write/DataPusher.java
index 4c8b83abe..9f6e4e007 100644
--- a/client/src/main/java/org/apache/celeborn/client/write/DataPusher.java
+++ b/client/src/main/java/org/apache/celeborn/client/write/DataPusher.java
@@ -41,7 +41,7 @@ public class DataPusher {
 
   private final long WAIT_TIME_NANOS = TimeUnit.MILLISECONDS.toNanos(500);
 
-  private final LinkedBlockingQueue<PushTask> idleQueue;
+  private LinkedBlockingQueue<PushTask> idleQueue;
   // partition -> PushTask Queue
   private final DataPushQueue dataPushQueue;
   private final ReentrantLock idleLock = new ReentrantLock();
@@ -235,7 +235,9 @@ public class DataPusher {
     return dataPushQueue;
   }
 
-  public LinkedBlockingQueue<PushTask> getIdleQueue() {
-    return idleQueue;
+  public LinkedBlockingQueue<PushTask> getAndResetIdleQueue() {
+    LinkedBlockingQueue<PushTask> queue = idleQueue;
+    idleQueue = null;
+    return queue;
   }
 }

Reply via email to