This is an automated email from the ASF dual-hosted git repository.

zhouky pushed a commit to branch branch-0.5
in repository https://gitbox.apache.org/repos/asf/celeborn.git


The following commit(s) were added to refs/heads/branch-0.5 by this push:
     new f639be462 [CELEBORN-1544][FOLLOWUP] ShuffleWriter needs to catch 
exception and call abort to avoid memory leaks
f639be462 is described below

commit f639be462d961cf798bb90f01283a8ea12044fa1
Author: sychen <[email protected]>
AuthorDate: Wed Aug 21 14:45:57 2024 +0800

    [CELEBORN-1544][FOLLOWUP] ShuffleWriter needs to catch exception and call 
abort to avoid memory leaks
    
    ### What changes were proposed in this pull request?
    This PR aims to fix a possible memory leak in ShuffleWriter.
    
    Introduce a private `cleanupPusher` method (the "abort" path), which is 
called to release memory when an exception occurs.
    
    ### Why are the changes needed?
    https://github.com/apache/celeborn/pull/2661 calls the close method in the 
finally block, but the close method calls `shuffleClient.mapperEnd`, which is 
dangerous for incomplete tasks because the data may be inaccurate.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    GA
    
    Closes #2663 from cxzl25/CELEBORN-1544-followup.
    
    Authored-by: sychen <[email protected]>
    Signed-off-by: zky.zhoukeyong <[email protected]>
    (cherry picked from commit bc3bd460e8069878395c140cc017817fe28e1954)
    Signed-off-by: zky.zhoukeyong <[email protected]>
---
 .../apache/spark/shuffle/celeborn/SortBasedPusher.java |  6 ++++--
 .../spark/shuffle/celeborn/SortBasedPusherSuiteJ.java  |  2 +-
 .../spark/shuffle/celeborn/HashBasedShuffleWriter.java | 18 ++++++++++++++----
 .../spark/shuffle/celeborn/SortBasedShuffleWriter.java | 15 +++++++++++++--
 .../spark/shuffle/celeborn/HashBasedShuffleWriter.java | 18 ++++++++++++++----
 .../spark/shuffle/celeborn/SortBasedShuffleWriter.java | 15 +++++++++++++--
 6 files changed, 59 insertions(+), 15 deletions(-)

diff --git 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedPusher.java
 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedPusher.java
index 01d852017..8cd2a4874 100644
--- 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedPusher.java
+++ 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedPusher.java
@@ -501,13 +501,15 @@ public class SortBasedPusher extends MemoryConsumer {
     return this.pushSortMemoryThreshold;
   }
 
-  public void close() throws IOException {
+  public void close(boolean throwTaskKilledOnInterruption) throws IOException {
     cleanupResources();
     try {
       dataPusher.waitOnTermination();
       sendBufferPool.returnPushTaskQueue(dataPusher.getIdleQueue());
     } catch (InterruptedException e) {
-      TaskInterruptedHelper.throwTaskKillException();
+      if (throwTaskKilledOnInterruption) {
+        TaskInterruptedHelper.throwTaskKillException();
+      }
     }
   }
 
diff --git 
a/client-spark/common/src/test/java/org/apache/spark/shuffle/celeborn/SortBasedPusherSuiteJ.java
 
b/client-spark/common/src/test/java/org/apache/spark/shuffle/celeborn/SortBasedPusherSuiteJ.java
index 0962c98c4..73c15bb70 100644
--- 
a/client-spark/common/src/test/java/org/apache/spark/shuffle/celeborn/SortBasedPusherSuiteJ.java
+++ 
b/client-spark/common/src/test/java/org/apache/spark/shuffle/celeborn/SortBasedPusherSuiteJ.java
@@ -127,7 +127,7 @@ public class SortBasedPusherSuiteJ {
         !pusher.insertRecord(
             row5k.getBaseObject(), row5k.getBaseOffset(), 
row5k.getSizeInBytes(), 0, true));
 
-    pusher.close();
+    pusher.close(true);
 
     assertEquals(taskContext.taskMetrics().memoryBytesSpilled(), 2097152);
   }
diff --git 
a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
 
b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
index 0986c23ef..fd83ec2e1 100644
--- 
a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
+++ 
b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
@@ -162,6 +162,7 @@ public class HashBasedShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
 
   @Override
   public void write(scala.collection.Iterator<Product2<K, V>> records) throws 
IOException {
+    boolean needCleanupPusher = true;
     try {
       if (canUseFastWrite()) {
         fastWrite0(records);
@@ -174,13 +175,13 @@ public class HashBasedShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
       } else {
         write0(records);
       }
+      close();
+      needCleanupPusher = false;
     } catch (InterruptedException e) {
       TaskInterruptedHelper.throwTaskKillException();
     } finally {
-      try {
-        close();
-      } catch (InterruptedException e) {
-        TaskInterruptedHelper.throwTaskKillException();
+      if (needCleanupPusher) {
+        cleanupPusher();
       }
     }
   }
@@ -319,6 +320,15 @@ public class HashBasedShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
     writeMetrics.incWriteTime(System.nanoTime() - start);
   }
 
+  private void cleanupPusher() throws IOException {
+    try {
+      dataPusher.waitOnTermination();
+      sendBufferPool.returnPushTaskQueue(dataPusher.getIdleQueue());
+    } catch (InterruptedException e) {
+      TaskInterruptedHelper.throwTaskKillException();
+    }
+  }
+
   private void close() throws IOException, InterruptedException {
     // here we wait for all the in-flight batches to return which sent by 
dataPusher thread
     dataPusher.waitOnTermination();
diff --git 
a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java
 
b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java
index 5ecf1edba..7b8baaf06 100644
--- 
a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java
+++ 
b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java
@@ -143,6 +143,7 @@ public class SortBasedShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
 
   @Override
   public void write(scala.collection.Iterator<Product2<K, V>> records) throws 
IOException {
+    boolean needCleanupPusher = true;
     try {
       if (canUseFastWrite()) {
         fastWrite0(records);
@@ -155,8 +156,12 @@ public class SortBasedShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
       } else {
         write0(records);
       }
-    } finally {
       close();
+      needCleanupPusher = false;
+    } finally {
+      if (needCleanupPusher) {
+        cleanupPusher();
+      }
     }
   }
 
@@ -291,11 +296,17 @@ public class SortBasedShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
     writeMetrics.incBytesWritten(bytesWritten);
   }
 
+  private void cleanupPusher() throws IOException {
+    if (pusher != null) {
+      pusher.close(false);
+    }
+  }
+
   private void close() throws IOException {
     logger.info("Memory used {}", Utils.bytesToString(pusher.getUsed()));
     long pushStartTime = System.nanoTime();
     pusher.pushData(false);
-    pusher.close();
+    pusher.close(true);
     writeMetrics.incWriteTime(System.nanoTime() - pushStartTime);
 
     shuffleClient.pushMergedData(shuffleId, mapId, 
taskContext.attemptNumber());
diff --git 
a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
 
b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
index 5a9b04552..bd1eeb16c 100644
--- 
a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
+++ 
b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
@@ -158,6 +158,7 @@ public class HashBasedShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
 
   @Override
   public void write(scala.collection.Iterator<Product2<K, V>> records) throws 
IOException {
+    boolean needCleanupPusher = true;
     try {
       if (canUseFastWrite()) {
         fastWrite0(records);
@@ -170,13 +171,13 @@ public class HashBasedShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
       } else {
         write0(records);
       }
+      close();
+      needCleanupPusher = false;
     } catch (InterruptedException e) {
       TaskInterruptedHelper.throwTaskKillException();
     } finally {
-      try {
-        close();
-      } catch (InterruptedException e) {
-        TaskInterruptedHelper.throwTaskKillException();
+      if (needCleanupPusher) {
+        cleanupPusher();
       }
     }
   }
@@ -353,6 +354,15 @@ public class HashBasedShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
     writeMetrics.incBytesWritten(bytesWritten);
   }
 
+  private void cleanupPusher() throws IOException {
+    try {
+      dataPusher.waitOnTermination();
+      sendBufferPool.returnPushTaskQueue(dataPusher.getIdleQueue());
+    } catch (InterruptedException e) {
+      TaskInterruptedHelper.throwTaskKillException();
+    }
+  }
+
   private void close() throws IOException, InterruptedException {
     // here we wait for all the in-flight batches to return which sent by 
dataPusher thread
     long pushMergedDataTime = System.nanoTime();
diff --git 
a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java
 
b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java
index f3b856394..1c8900bb2 100644
--- 
a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java
+++ 
b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java
@@ -226,10 +226,15 @@ public class SortBasedShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
 
   @Override
   public void write(scala.collection.Iterator<Product2<K, V>> records) throws 
IOException {
+    boolean needCleanupPusher = true;
     try {
       doWrite(records);
-    } finally {
       close();
+      needCleanupPusher = false;
+    } finally {
+      if (needCleanupPusher) {
+        cleanupPusher();
+      }
     }
   }
 
@@ -354,11 +359,17 @@ public class SortBasedShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
     writeMetrics.incBytesWritten(bytesWritten);
   }
 
+  private void cleanupPusher() throws IOException {
+    if (pusher != null) {
+      pusher.close(false);
+    }
+  }
+
   private void close() throws IOException {
     logger.info("Memory used {}", Utils.bytesToString(pusher.getUsed()));
     long pushStartTime = System.nanoTime();
     pusher.pushData(false);
-    pusher.close();
+    pusher.close(true);
 
     shuffleClient.pushMergedData(shuffleId, mapId, 
taskContext.attemptNumber());
     writeMetrics.incWriteTime(System.nanoTime() - pushStartTime);

Reply via email to