This is an automated email from the ASF dual-hosted git repository.
mridulm80 pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.5 by this push:
new 145c045654d7 [SPARK-47910][CORE] close stream when DiskBlockObjectWriter closeResources to avoid memory leak
145c045654d7 is described below
commit 145c045654d7b8e3738c1bbfe83a58d087fbef24
Author: JacobZheng0927 <[email protected]>
AuthorDate: Tue Jun 18 00:05:26 2024 -0500
[SPARK-47910][CORE] close stream when DiskBlockObjectWriter closeResources to avoid memory leak
### What changes were proposed in this pull request?
Close the wrapped output streams in `DiskBlockObjectWriter#closeResources` to avoid a memory leak.
### Why are the changes needed?
[SPARK-34647](https://issues.apache.org/jira/browse/SPARK-34647) replaced
ZstdInputStream with ZstdInputStreamNoFinalizer. As a result, every user of
CompressionCodec.compressedOutputStream must close the stream explicitly,
because cleanup is no longer handled by the finalizer mechanism.
When zstd is used for shuffle write compression and the write is interrupted
for some reason (e.g. spark.sql.execution.interruptOnCancel is enabled and the
job is cancelled), the memory held by `ZstdInputStreamNoFinalizer` may never
be freed, causing a memory leak.
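To make the contract concrete, here is a minimal sketch (not part of this
patch, and going straight to zstd-jni's `ZstdOutputStreamNoFinalizer` rather
than through `CompressionCodec`): the native context behind a `*NoFinalizer`
stream is released only by an explicit `close()`, so a stream abandoned
mid-write keeps that memory alive.
```scala
// Minimal sketch (assumes zstd-jni on the classpath); not code from this patch.
import java.io.ByteArrayOutputStream
import com.github.luben.zstd.ZstdOutputStreamNoFinalizer

val underlying = new ByteArrayOutputStream()
// Allocates a native compression context; no finalizer will ever release it.
val zstdOut = new ZstdOutputStreamNoFinalizer(underlying)
try {
  zstdOut.write("some shuffle bytes".getBytes("UTF-8"))
} finally {
  zstdOut.close() // the only way the native memory is freed
}
```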
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
#### Spark Shell Configuration
```
$> export SPARK_SUBMIT_OPTS="-XX:+AlwaysPreTouch -Xms1g"
$> $SPARK_HOME/bin/spark-shell --conf spark.io.compression.codec=zstd
```
#### Test Script
```scala
import java.util.concurrent.TimeUnit
import scala.concurrent.Future
import scala.concurrent.ExecutionContext.Implicits.global
import scala.util.Random
sc.setJobGroup("jobA", "this is a job to be cancelled", interruptOnCancel = true)

(1 to 50).foreach { batch =>
  val jobA = Future {
    val df1 = spark.range(2000000).map { _ =>
      (Random.nextString(20), Random.nextInt(1000), Random.nextInt(1000), Random.nextInt(10))
    }.toDF("a", "b", "c", "d")
    val df2 = spark.range(2000000).map { _ =>
      (Random.nextString(20), Random.nextInt(1000), Random.nextInt(1000), Random.nextInt(10))
    }.toDF("a", "b", "c", "d")
    df1.join(df2, "b").show()
  }
  Thread.sleep(5000)
  sc.cancelJobGroup("jobA")
}
```
#### Memory Monitor
```
$> while true; do echo \"$(date +%Y-%m-%d' '%H:%M:%S)\",$(pmap -x <PID> | grep "total kB" | awk '{print $4}'); sleep 10; done;
```
#### Results
##### Before
```
"2024-05-13 16:54:23",1332384
"2024-05-13 16:54:33",1417112
"2024-05-13 16:54:43",2211684
"2024-05-13 16:54:53",3060820
"2024-05-13 16:55:03",3850444
"2024-05-13 16:55:14",4631744
"2024-05-13 16:55:24",5317200
"2024-05-13 16:55:34",6019464
"2024-05-13 16:55:44",6489180
"2024-05-13 16:55:54",7255548
"2024-05-13 16:56:05",7718728
"2024-05-13 16:56:15",8388392
"2024-05-13 16:56:25",8927636
"2024-05-13 16:56:36",9473412
"2024-05-13 16:56:46",10000380
"2024-05-13 16:56:56",10344024
"2024-05-13 16:57:07",10734204
"2024-05-13 16:57:17",11211900
"2024-05-13 16:57:27",11665524
"2024-05-13 16:57:38",12268976
"2024-05-13 16:57:48",12896264
"2024-05-13 16:57:58",13572244
"2024-05-13 16:58:09",14252416
"2024-05-13 16:58:19",14915560
"2024-05-13 16:58:30",15484196
"2024-05-13 16:58:40",16170324
```
##### After
```
"2024-05-13 16:35:44",1355428
"2024-05-13 16:35:54",1391028
"2024-05-13 16:36:04",1673720
"2024-05-13 16:36:14",2103716
"2024-05-13 16:36:24",2129876
"2024-05-13 16:36:35",2166412
"2024-05-13 16:36:45",2177672
"2024-05-13 16:36:55",2188340
"2024-05-13 16:37:05",2190688
"2024-05-13 16:37:15",2195168
"2024-05-13 16:37:26",2199296
"2024-05-13 16:37:36",2228052
"2024-05-13 16:37:46",2238104
"2024-05-13 16:37:56",2260624
"2024-05-13 16:38:06",2307184
"2024-05-13 16:38:16",2331140
"2024-05-13 16:38:27",2323388
"2024-05-13 16:38:37",2357552
"2024-05-13 16:38:47",2352948
"2024-05-13 16:38:57",2364744
"2024-05-13 16:39:07",2368528
"2024-05-13 16:39:18",2385492
"2024-05-13 16:39:28",2389184
"2024-05-13 16:39:38",2388060
"2024-05-13 16:39:48",2388336
"2024-05-13 16:39:58",2386916
```
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #46131 from JacobZheng0927/zstdMemoryLeak.
Authored-by: JacobZheng0927 <[email protected]>
Signed-off-by: Mridul Muralidharan <mridul<at>gmail.com>
(cherry picked from commit e265c6041aa2c29d73647b7bd69012c1dc152a1c)
Signed-off-by: Mridul Muralidharan <mridulatgmail.com>
---
.../spark/storage/DiskBlockObjectWriter.scala | 49 +++++++++----
.../spark/storage/DiskBlockObjectWriterSuite.scala | 82 ++++++++++++++++++++--
2 files changed, 114 insertions(+), 17 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala
index f8bd73e65617..5110870b4fac 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala
@@ -125,6 +125,12 @@ private[spark] class DiskBlockObjectWriter(
    */
   private var numRecordsCommitted = 0L
 
+  // For testing only.
+  private[storage] def getSerializerWrappedStream: OutputStream = bs
+
+  // For testing only.
+  private[storage] def getSerializationStream: SerializationStream = objOut
+
   /**
    * Set the checksum that the checksumOutputStream should use
    */
@@ -173,19 +179,36 @@ private[spark] class DiskBlockObjectWriter(
    * Should call after committing or reverting partial writes.
    */
   private def closeResources(): Unit = {
-    if (initialized) {
-      Utils.tryWithSafeFinally {
-        mcs.manualClose()
-      } {
-        channel = null
-        mcs = null
-        bs = null
-        fos = null
-        ts = null
-        objOut = null
-        initialized = false
-        streamOpen = false
-        hasBeenClosed = true
+    try {
+      if (streamOpen) {
+        Utils.tryWithSafeFinally {
+          if (null != objOut) objOut.close()
+          bs = null
+        } {
+          objOut = null
+          if (null != bs) bs.close()
+          bs = null
+        }
+      }
+    } catch {
+      case e: IOException =>
+        logInfo(log"Exception occurred while closing the output stream" +
+          log"${MDC(ERROR, e.getMessage)}")
+    } finally {
+      if (initialized) {
+        Utils.tryWithSafeFinally {
+          mcs.manualClose()
+        } {
+          channel = null
+          mcs = null
+          bs = null
+          fos = null
+          ts = null
+          objOut = null
+          initialized = false
+          streamOpen = false
+          hasBeenClosed = true
+        }
       }
     }
   }
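The ordering in the hunk above can be read as: close `objOut` first (which
also closes the wrapped `bs`), null out `bs` so the fallback branch does not
close it twice, and only then run the pre-existing `manualClose()` cleanup.
A condensed, standalone sketch of that pattern (hypothetical helper, not the
real writer):
```scala
// Condensed sketch of the close ordering used by closeResources() above.
// closeWrappedStreams is a hypothetical standalone helper, not Spark code.
import java.io.{Closeable, OutputStream}

def closeWrappedStreams(objOut: Closeable, bs: OutputStream): Unit = {
  var pendingBs: OutputStream = bs
  try {
    if (objOut != null) objOut.close() // also closes the wrapped bs underneath
    pendingBs = null                   // reached only if close() did not throw
  } finally {
    if (pendingBs != null) pendingBs.close() // fallback when objOut.close() threw
  }
}
```
The real code expresses this with `Utils.tryWithSafeFinally`, which also keeps
an exception thrown by the fallback close from masking the original failure.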
diff --git a/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala
index 70a57eed07ac..4352436c872f 100644
--- a/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala
@@ -16,11 +16,14 @@
  */
 package org.apache.spark.storage
 
-import java.io.File
+import java.io.{File, InputStream, OutputStream}
+import java.nio.ByteBuffer
+
+import scala.reflect.ClassTag
 
 import org.apache.spark.{SparkConf, SparkException, SparkFunSuite}
 import org.apache.spark.executor.ShuffleWriteMetrics
-import org.apache.spark.serializer.{JavaSerializer, SerializerManager}
+import org.apache.spark.serializer.{DeserializationStream, JavaSerializer, SerializationStream, Serializer, SerializerInstance, SerializerManager}
 import org.apache.spark.util.Utils
 
 class DiskBlockObjectWriterSuite extends SparkFunSuite {
@@ -43,10 +46,14 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite {
   private def createWriter(): (DiskBlockObjectWriter, File, ShuffleWriteMetrics) = {
     val file = new File(tempDir, "somefile")
     val conf = new SparkConf()
-    val serializerManager = new SerializerManager(new JavaSerializer(conf), conf)
+    val serializerManager = new CustomSerializerManager(new JavaSerializer(conf), conf, None)
     val writeMetrics = new ShuffleWriteMetrics()
     val writer = new DiskBlockObjectWriter(
-      file, serializerManager, new JavaSerializer(new SparkConf()).newInstance(), 1024, true,
+      file,
+      serializerManager,
+      new CustomJavaSerializer(new SparkConf()).newInstance(),
+      1024,
+      true,
       writeMetrics)
     (writer, file, writeMetrics)
   }
@@ -196,9 +203,76 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite {
     for (i <- 1 to 500) {
       writer.write(i, i)
     }
+
+    val bs = writer.getSerializerWrappedStream.asInstanceOf[OutputStreamWithCloseDetecting]
+    val objOut = writer.getSerializationStream.asInstanceOf[SerializationStreamWithCloseDetecting]
+
     writer.closeAndDelete()
     assert(!file.exists())
     assert(writeMetrics.bytesWritten == 0)
     assert(writeMetrics.recordsWritten == 0)
+    assert(bs.isClosed)
+    assert(objOut.isClosed)
+  }
+}
+
+trait CloseDetecting {
+  var isClosed = false
+}
+
+class OutputStreamWithCloseDetecting(outputStream: OutputStream)
+  extends OutputStream
+  with CloseDetecting {
+  override def write(b: Int): Unit = outputStream.write(b)
+
+  override def close(): Unit = {
+    isClosed = true
+    outputStream.close()
+  }
+}
+
+class CustomSerializerManager(
+    defaultSerializer: Serializer,
+    conf: SparkConf,
+    encryptionKey: Option[Array[Byte]])
+  extends SerializerManager(defaultSerializer, conf, encryptionKey) {
+  override def wrapStream(blockId: BlockId, s: OutputStream): OutputStream = {
+    new OutputStreamWithCloseDetecting(wrapForCompression(blockId, wrapForEncryption(s)))
+  }
+}
+
+class CustomJavaSerializer(conf: SparkConf) extends JavaSerializer(conf) {
+
+  override def newInstance(): SerializerInstance = {
+    new CustomJavaSerializerInstance(super.newInstance())
   }
 }
+
+class SerializationStreamWithCloseDetecting(serializationStream: SerializationStream)
+  extends SerializationStream with CloseDetecting {
+
+  override def close(): Unit = {
+    isClosed = true
+    serializationStream.close()
+  }
+
+  override def writeObject[T: ClassTag](t: T): SerializationStream =
+    serializationStream.writeObject(t)
+
+  override def flush(): Unit = serializationStream.flush()
+}
+
+class CustomJavaSerializerInstance(instance: SerializerInstance) extends SerializerInstance {
+  override def serializeStream(s: OutputStream): SerializationStream =
+    new SerializationStreamWithCloseDetecting(instance.serializeStream(s))
+
+  override def serialize[T: ClassTag](t: T): ByteBuffer = instance.serialize(t)
+
+  override def deserialize[T: ClassTag](bytes: ByteBuffer): T = instance.deserialize(bytes)
+
+  override def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T =
+    instance.deserialize(bytes, loader)
+
+  override def deserializeStream(s: InputStream): DeserializationStream =
+    instance.deserializeStream(s)
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]