This is an automated email from the ASF dual-hosted git repository.
zhouky pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/celeborn.git
The following commit(s) were added to refs/heads/main by this push:
new bc3bd460e [CELEBORN-1544][FOLLOWUP] ShuffleWriter needs to catch
exception and call abort to avoid memory leaks
bc3bd460e is described below
commit bc3bd460e8069878395c140cc017817fe28e1954
Author: sychen <[email protected]>
AuthorDate: Wed Aug 21 14:45:57 2024 +0800
[CELEBORN-1544][FOLLOWUP] ShuffleWriter needs to catch exception and call
abort to avoid memory leaks
### What changes were proposed in this pull request?
This PR aims to fix a possible memory leak in ShuffleWriter.
It introduces a private abort method that can be called to release memory
when an exception occurs.
### Why are the changes needed?
https://github.com/apache/celeborn/pull/2661 calls the close method in the
finally block, but the close method invokes `shuffleClient.mapperEnd`, which is
dangerous for incomplete tasks and may make the reported data inaccurate.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
GA
Closes #2663 from cxzl25/CELEBORN-1544-followup.
Authored-by: sychen <[email protected]>
Signed-off-by: zky.zhoukeyong <[email protected]>
---
.../apache/spark/shuffle/celeborn/SortBasedPusher.java | 6 ++++--
.../spark/shuffle/celeborn/SortBasedPusherSuiteJ.java | 2 +-
.../spark/shuffle/celeborn/HashBasedShuffleWriter.java | 18 ++++++++++++++----
.../spark/shuffle/celeborn/SortBasedShuffleWriter.java | 15 +++++++++++++--
.../spark/shuffle/celeborn/HashBasedShuffleWriter.java | 18 ++++++++++++++----
.../spark/shuffle/celeborn/SortBasedShuffleWriter.java | 15 +++++++++++++--
6 files changed, 59 insertions(+), 15 deletions(-)
diff --git
a/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedPusher.java
b/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedPusher.java
index 01d852017..8cd2a4874 100644
---
a/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedPusher.java
+++
b/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedPusher.java
@@ -501,13 +501,15 @@ public class SortBasedPusher extends MemoryConsumer {
return this.pushSortMemoryThreshold;
}
- public void close() throws IOException {
+ public void close(boolean throwTaskKilledOnInterruption) throws IOException {
cleanupResources();
try {
dataPusher.waitOnTermination();
sendBufferPool.returnPushTaskQueue(dataPusher.getIdleQueue());
} catch (InterruptedException e) {
- TaskInterruptedHelper.throwTaskKillException();
+ if (throwTaskKilledOnInterruption) {
+ TaskInterruptedHelper.throwTaskKillException();
+ }
}
}
diff --git
a/client-spark/common/src/test/java/org/apache/spark/shuffle/celeborn/SortBasedPusherSuiteJ.java
b/client-spark/common/src/test/java/org/apache/spark/shuffle/celeborn/SortBasedPusherSuiteJ.java
index 0962c98c4..73c15bb70 100644
---
a/client-spark/common/src/test/java/org/apache/spark/shuffle/celeborn/SortBasedPusherSuiteJ.java
+++
b/client-spark/common/src/test/java/org/apache/spark/shuffle/celeborn/SortBasedPusherSuiteJ.java
@@ -127,7 +127,7 @@ public class SortBasedPusherSuiteJ {
!pusher.insertRecord(
row5k.getBaseObject(), row5k.getBaseOffset(),
row5k.getSizeInBytes(), 0, true));
- pusher.close();
+ pusher.close(true);
assertEquals(taskContext.taskMetrics().memoryBytesSpilled(), 2097152);
}
diff --git
a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
index 0986c23ef..fd83ec2e1 100644
---
a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
+++
b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
@@ -162,6 +162,7 @@ public class HashBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
@Override
public void write(scala.collection.Iterator<Product2<K, V>> records) throws
IOException {
+ boolean needCleanupPusher = true;
try {
if (canUseFastWrite()) {
fastWrite0(records);
@@ -174,13 +175,13 @@ public class HashBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
} else {
write0(records);
}
+ close();
+ needCleanupPusher = false;
} catch (InterruptedException e) {
TaskInterruptedHelper.throwTaskKillException();
} finally {
- try {
- close();
- } catch (InterruptedException e) {
- TaskInterruptedHelper.throwTaskKillException();
+ if (needCleanupPusher) {
+ cleanupPusher();
}
}
}
@@ -319,6 +320,15 @@ public class HashBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
writeMetrics.incWriteTime(System.nanoTime() - start);
}
+ private void cleanupPusher() throws IOException {
+ try {
+ dataPusher.waitOnTermination();
+ sendBufferPool.returnPushTaskQueue(dataPusher.getIdleQueue());
+ } catch (InterruptedException e) {
+ TaskInterruptedHelper.throwTaskKillException();
+ }
+ }
+
private void close() throws IOException, InterruptedException {
// here we wait for all the in-flight batches to return which sent by
dataPusher thread
dataPusher.waitOnTermination();
diff --git
a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java
b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java
index 5ecf1edba..7b8baaf06 100644
---
a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java
+++
b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java
@@ -143,6 +143,7 @@ public class SortBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
@Override
public void write(scala.collection.Iterator<Product2<K, V>> records) throws
IOException {
+ boolean needCleanupPusher = true;
try {
if (canUseFastWrite()) {
fastWrite0(records);
@@ -155,8 +156,12 @@ public class SortBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
} else {
write0(records);
}
- } finally {
close();
+ needCleanupPusher = false;
+ } finally {
+ if (needCleanupPusher) {
+ cleanupPusher();
+ }
}
}
@@ -291,11 +296,17 @@ public class SortBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
writeMetrics.incBytesWritten(bytesWritten);
}
+ private void cleanupPusher() throws IOException {
+ if (pusher != null) {
+ pusher.close(false);
+ }
+ }
+
private void close() throws IOException {
logger.info("Memory used {}", Utils.bytesToString(pusher.getUsed()));
long pushStartTime = System.nanoTime();
pusher.pushData(false);
- pusher.close();
+ pusher.close(true);
writeMetrics.incWriteTime(System.nanoTime() - pushStartTime);
shuffleClient.pushMergedData(shuffleId, mapId,
taskContext.attemptNumber());
diff --git
a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
index 5a9b04552..bd1eeb16c 100644
---
a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
+++
b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
@@ -158,6 +158,7 @@ public class HashBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
@Override
public void write(scala.collection.Iterator<Product2<K, V>> records) throws
IOException {
+ boolean needCleanupPusher = true;
try {
if (canUseFastWrite()) {
fastWrite0(records);
@@ -170,13 +171,13 @@ public class HashBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
} else {
write0(records);
}
+ close();
+ needCleanupPusher = false;
} catch (InterruptedException e) {
TaskInterruptedHelper.throwTaskKillException();
} finally {
- try {
- close();
- } catch (InterruptedException e) {
- TaskInterruptedHelper.throwTaskKillException();
+ if (needCleanupPusher) {
+ cleanupPusher();
}
}
}
@@ -353,6 +354,15 @@ public class HashBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
writeMetrics.incBytesWritten(bytesWritten);
}
+ private void cleanupPusher() throws IOException {
+ try {
+ dataPusher.waitOnTermination();
+ sendBufferPool.returnPushTaskQueue(dataPusher.getIdleQueue());
+ } catch (InterruptedException e) {
+ TaskInterruptedHelper.throwTaskKillException();
+ }
+ }
+
private void close() throws IOException, InterruptedException {
// here we wait for all the in-flight batches to return which sent by
dataPusher thread
long pushMergedDataTime = System.nanoTime();
diff --git
a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java
b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java
index f3b856394..1c8900bb2 100644
---
a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java
+++
b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java
@@ -226,10 +226,15 @@ public class SortBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
@Override
public void write(scala.collection.Iterator<Product2<K, V>> records) throws
IOException {
+ boolean needCleanupPusher = true;
try {
doWrite(records);
- } finally {
close();
+ needCleanupPusher = false;
+ } finally {
+ if (needCleanupPusher) {
+ cleanupPusher();
+ }
}
}
@@ -354,11 +359,17 @@ public class SortBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
writeMetrics.incBytesWritten(bytesWritten);
}
+ private void cleanupPusher() throws IOException {
+ if (pusher != null) {
+ pusher.close(false);
+ }
+ }
+
private void close() throws IOException {
logger.info("Memory used {}", Utils.bytesToString(pusher.getUsed()));
long pushStartTime = System.nanoTime();
pusher.pushData(false);
- pusher.close();
+ pusher.close(true);
shuffleClient.pushMergedData(shuffleId, mapId,
taskContext.attemptNumber());
writeMetrics.incWriteTime(System.nanoTime() - pushStartTime);