This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new e490dd737dbf [SPARK-50274][CORE] Guard against use-after-close in
DirectByteBufferOutputStream
e490dd737dbf is described below
commit e490dd737dbf3325ccb18eaf2e713155ec112f21
Author: Ankur Dave <[email protected]>
AuthorDate: Sun Nov 10 18:40:38 2024 +0900
[SPARK-50274][CORE] Guard against use-after-close in
DirectByteBufferOutputStream
### What changes were proposed in this pull request?
`DirectByteBufferOutputStream#close()` calls `StorageUtils.dispose()` to
free its direct byte buffer. This puts the object into an unspecified and
dangerous state after being closed, and can cause unpredictable JVM crashes if
the object is used after close.
This PR makes this safer by modifying `close()` to place the object into a
known-closed state, and modifying all methods to assert not closed.
To minimize the performance impact from the extra checks, this PR also
changes `DirectByteBufferOutputStream#buffer` from `private` to
`private[this]`, which should produce more efficient direct field accesses.
### Why are the changes needed?
Improves debuggability for users of DirectByteBufferOutputStream such as
PythonRunner.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Added a test in DirectByteBufferOutputStreamSuite to verify that use after
close throws IllegalStateException.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #48807 from
ankurdave/SPARK-50274-DirectByteBufferOutputStream-checkNotClosed.
Authored-by: Ankur Dave <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
.../spark/util/DirectByteBufferOutputStream.scala | 24 +++++++++++--
.../util/DirectByteBufferOutputStreamSuite.scala | 41 ++++++++++++++++++++++
2 files changed, 62 insertions(+), 3 deletions(-)
diff --git
a/core/src/main/scala/org/apache/spark/util/DirectByteBufferOutputStream.scala
b/core/src/main/scala/org/apache/spark/util/DirectByteBufferOutputStream.scala
index 1683e892511f..fd10d60a13fd 100644
---
a/core/src/main/scala/org/apache/spark/util/DirectByteBufferOutputStream.scala
+++
b/core/src/main/scala/org/apache/spark/util/DirectByteBufferOutputStream.scala
@@ -20,6 +20,7 @@ package org.apache.spark.util
import java.io.OutputStream
import java.nio.ByteBuffer
+import org.apache.spark.SparkException
import org.apache.spark.storage.StorageUtils
import org.apache.spark.unsafe.Platform
@@ -29,16 +30,18 @@ import org.apache.spark.unsafe.Platform
* @param capacity The initial capacity of the direct byte buffer
*/
private[spark] class DirectByteBufferOutputStream(capacity: Int) extends
OutputStream {
- private var buffer = Platform.allocateDirectBuffer(capacity)
+ private[this] var buffer = Platform.allocateDirectBuffer(capacity)
def this() = this(32)
override def write(b: Int): Unit = {
+ checkNotClosed()
ensureCapacity(buffer.position() + 1)
buffer.put(b.toByte)
}
override def write(b: Array[Byte], off: Int, len: Int): Unit = {
+ checkNotClosed()
ensureCapacity(buffer.position() + len)
buffer.put(b, off, len)
}
@@ -63,15 +66,29 @@ private[spark] class DirectByteBufferOutputStream(capacity:
Int) extends OutputS
buffer = newBuffer
}
- def reset(): Unit = buffer.clear()
+ private def checkNotClosed(): Unit = {
+ if (buffer == null) {
+ throw SparkException.internalError(
+ "Cannot call methods on a closed DirectByteBufferOutputStream")
+ }
+ }
+
+ def reset(): Unit = {
+ checkNotClosed()
+ buffer.clear()
+ }
- def size(): Int = buffer.position()
+ def size(): Int = {
+ checkNotClosed()
+ buffer.position()
+ }
/**
* Any subsequent call to [[close()]], [[write()]], [[reset()]] will
invalidate the buffer
* returned by this method.
*/
def toByteBuffer: ByteBuffer = {
+ checkNotClosed()
val outputBuffer = buffer.duplicate()
outputBuffer.flip()
outputBuffer
@@ -80,6 +97,7 @@ private[spark] class DirectByteBufferOutputStream(capacity:
Int) extends OutputS
override def close(): Unit = {
// Eagerly free the direct byte buffer without waiting for GC to reduce
memory pressure.
StorageUtils.dispose(buffer)
+ buffer = null
}
}
diff --git
a/core/src/test/scala/org/apache/spark/util/DirectByteBufferOutputStreamSuite.scala
b/core/src/test/scala/org/apache/spark/util/DirectByteBufferOutputStreamSuite.scala
new file mode 100644
index 000000000000..7fd9d1fc05c9
--- /dev/null
+++
b/core/src/test/scala/org/apache/spark/util/DirectByteBufferOutputStreamSuite.scala
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import org.apache.spark.{SparkException, SparkFunSuite}
+
+class DirectByteBufferOutputStreamSuite extends SparkFunSuite {
+ test("use after close") {
+ val o = new DirectByteBufferOutputStream()
+ val size = 1000
+ o.write(new Array[Byte](size), 0, size)
+ val b = o.toByteBuffer
+ o.close()
+
+ // Using `o` after close should throw an exception rather than crashing.
+ assertThrows[SparkException] { o.write(123) }
+ assertThrows[SparkException] { o.write(new Array[Byte](size), 0, size) }
+ assertThrows[SparkException] { o.reset() }
+ assertThrows[SparkException] { o.size() }
+ assertThrows[SparkException] { o.toByteBuffer }
+
+ // Using `b` after `o` is closed may crash.
+ // val arr = new Array[Byte](size)
+ // b.get(arr)
+ }
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]