This is an automated email from the ASF dual-hosted git repository. richox pushed a commit to branch dev-fix-spill-deadlock in repository https://gitbox.apache.org/repos/asf/auron.git
commit c94b18b60dab8d9a9d550b3b3e784964c2799b8d Author: zhangli20 <[email protected]> AuthorDate: Thu Dec 25 21:23:28 2025 +0800 fix possible deadlock in OnHeapSpillManager --- .../spark/sql/auron/memory/OnHeapSpill.scala | 38 ++++++++++++++++++---- .../sql/auron/memory/SparkOnHeapSpillManager.scala | 2 +- 2 files changed, 32 insertions(+), 8 deletions(-) diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/auron/memory/OnHeapSpill.scala b/spark-extension/src/main/scala/org/apache/spark/sql/auron/memory/OnHeapSpill.scala index 541f4b39..0580cad3 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/auron/memory/OnHeapSpill.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/auron/memory/OnHeapSpill.scala @@ -17,20 +17,32 @@ package org.apache.spark.sql.auron.memory import java.nio.ByteBuffer - import org.apache.spark.internal.Logging +import org.apache.spark.memory.MemoryConsumer + +import java.util.concurrent.locks.ReentrantLock case class OnHeapSpill(hsm: SparkOnHeapSpillManager, id: Int) extends Logging { private var spillBuf: SpillBuf = new MemBasedSpillBuf + private val lock = new ReentrantLock def memUsed: Long = spillBuf.memUsed def diskUsed: Long = spillBuf.diskUsed def size: Long = spillBuf.size def diskIOTime: Long = spillBuf.diskIOTime + private def withLock[T](f: => T): T = { + lock.lock() + try { + f + } finally { + lock.unlock() + } + } + def write(buf: ByteBuffer): Unit = { var needSpill = false - synchronized { + withLock { spillBuf match { case _: MemBasedSpillBuf => val acquiredMemory = hsm.acquireMemory(buf.capacity()) @@ -46,13 +58,13 @@ case class OnHeapSpill(hsm: SparkOnHeapSpillManager, id: Int) extends Logging { spillInternal() } - synchronized { + withLock { spillBuf.write(buf) } } def read(buf: ByteBuffer): Int = { - synchronized { + withLock { val oldMemUsed = memUsed val startPosition = buf.position() spillBuf.read(buf) @@ -69,7 +81,7 @@ case class OnHeapSpill(hsm: SparkOnHeapSpillManager, id: Int) extends Logging { } def release(): Unit = { - synchronized { + withLock { val oldMemUsed = memUsed spillBuf = new ReleasedSpillBuf(spillBuf) @@ -79,8 +91,20 @@ case class OnHeapSpill(hsm: SparkOnHeapSpillManager, id: Int) extends Logging { } } - def spill(): Long = { - synchronized { + def spill(trigger: MemoryConsumer): Long = { + // this might have been locked if the spilling is triggered by OnHeapSpill.write + if (trigger == this.hsm) { + if (lock.tryLock()) { + try { + return spillInternal() + } finally { + lock.unlock() + } + } + return 0L + } + + withLock { spillInternal() } } diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/auron/memory/SparkOnHeapSpillManager.scala b/spark-extension/src/main/scala/org/apache/spark/sql/auron/memory/SparkOnHeapSpillManager.scala index e77aaaf3..e9d3add9 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/auron/memory/SparkOnHeapSpillManager.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/auron/memory/SparkOnHeapSpillManager.scala @@ -168,7 +168,7 @@ class SparkOnHeapSpillManager(taskContext: TaskContext) val sortedSpills = spills.seq.sortBy(0 - _.map(_.memUsed).getOrElse(0L)) sortedSpills.foreach { case Some(spill) if spill.memUsed > 0 => - totalFreed += spill.spill() + totalFreed += spill.spill(trigger) if (totalFreed >= size) { return totalFreed }
