This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 91899986a0a1  [SPARK-53481][PYTHON] Make Hybrid queue class
91899986a0a1 is described below

commit 91899986a0a17f09547d12b5181956577ed07c2b
Author: Richard Chen <r.c...@databricks.com>
AuthorDate: Fri Sep 5 15:44:42 2025 +0800

    [SPARK-53481][PYTHON] Make Hybrid queue class

    ### What changes were proposed in this pull request?

    create an abstraction for the hybrid row queue:`HybridQueue`

    ### Why are the changes needed?

    the hybrid row queue is currently used to buffer internal rows and spill to disk if needed.
    Implementations may find it useful to create a similar interface for other objects that are not just InternalRows.

    ### Does this PR introduce _any_ user-facing change?

    no

    ### How was this patch tested?

    existing tests

    ### Was this patch authored or co-authored using generative AI tooling?

    no

    Closes #52212 from richardc-db/make_hybrid_queue_class.

    Authored-by: Richard Chen <r.c...@databricks.com>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 .../spark/sql/execution/python/HybridQueue.scala   | 178 +++++++++++++++++++++
 .../spark/sql/execution/python/RowQueue.scala      | 139 ++--------------
 .../python/streaming/PythonForeachWriter.scala     |   3 +-
 .../spark/sql/execution/python/RowQueueSuite.scala |   4 +-
 4 files changed, 196 insertions(+), 128 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/HybridQueue.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/HybridQueue.scala
new file mode 100644
index 000000000000..90996c552645
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/HybridQueue.scala
@@ -0,0 +1,178 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.python
+
+import java.io._
+
+import scala.Enumeration
+
+import org.apache.spark.memory.{MemoryConsumer, SparkOutOfMemoryError, TaskMemoryManager}
+import org.apache.spark.serializer.SerializerManager
+import org.apache.spark.sql.errors.QueryExecutionErrors
+import org.apache.spark.unsafe.memory.MemoryBlock
+
+/**
+ * Enum to represent the storage mode for hybrid queues.
+ */
+object QueueMode extends Enumeration {
+  type QueueMode = Value
+  val IN_MEMORY, DISK = Value
+}
+
+trait Queue[T] {
+  def add(item: T): Boolean
+  def remove(): T
+  def close(): Unit
+}
+
+/**
+ * A generic base class for hybrid queues that can store data either in memory or on disk.
+ * This class contains common logic for queue management, spilling, and memory management.
+ */
+abstract class HybridQueue[T, Q <: Queue[T]](
+    memManager: TaskMemoryManager,
+    tempDir: File,
+    serMgr: SerializerManager)
+  extends MemoryConsumer(memManager, memManager.getTungstenMemoryMode) {
+
+  // Each buffer should have at least one element.
+  protected var queues = new java.util.LinkedList[Q]()
+
+  private var writing: Q = _
+  protected var reading: Q = _
+
+  protected var numElementsQueuedOnDisk: Long = 0L
+  protected var numElementsQueued: Long = 0L
+
+  // exposed for testing
+  private[python] def numQueues(): Int = queues.size()
+
+  protected def createDiskQueue(): Q
+  protected def createInMemoryQueue(page: MemoryBlock): Q
+  protected def getRequiredSize(item: T): Long
+  protected def getPageSize(queue: Q): Long
+  protected def isInMemoryQueue(queue: Q): Boolean
+  protected def isReadingFromDiskQueue: Boolean = !isInMemoryQueue(reading)
+
+  def spill(size: Long, trigger: MemoryConsumer): Long = {
+    if (trigger == this) {
+      // When it's triggered by itself, it should write upcoming elements into disk instead of
+      // copying the elements already in the queue.
+      return 0L
+    }
+    var released = 0L
+    synchronized {
+      // poll out all the buffers and add them back in the same order to make sure that the elements
+      // are in correct order.
+      val newQueues = new java.util.LinkedList[Q]()
+      while (!queues.isEmpty) {
+        val queue = queues.remove()
+        val newQueue = if (!queues.isEmpty && isInMemoryQueue(queue)) {
+          val diskQueue = createDiskQueue()
+          var item = queue.remove()
+          while (item != null) {
+            diskQueue.add(item)
+            item = queue.remove()
+          }
+          released += getPageSize(queue)
+          queue.close()
+          diskQueue
+        } else {
+          queue
+        }
+        newQueues.add(newQueue)
+      }
+      queues = newQueues
+    }
+    released
+  }
+
+  private def createNewQueue(required: Long): Q = {
+    // Tests may attempt to force spills.
+    val page = try {
+      allocatePage(required)
+    } catch {
+      case _: SparkOutOfMemoryError =>
+        null
+    }
+    val buffer = if (page != null) {
+      createInMemoryQueue(page)
+    } else {
+      createDiskQueue()
+    }
+
+    synchronized {
+      queues.add(buffer)
+    }
+    buffer
+  }
+
+  def add(item: T): QueueMode.Value = {
+    if (writing == null || !writing.add(item)) {
+      writing = createNewQueue(getRequiredSize(item))
+      if (!writing.add(item)) {
+        throw QueryExecutionErrors.failedToPushRowIntoRowQueueError(writing.toString)
+      }
+    }
+    numElementsQueued += 1
+    if (isInMemoryQueue(writing)) {
+      QueueMode.IN_MEMORY
+    } else {
+      numElementsQueuedOnDisk += 1
+      QueueMode.DISK
+    }
+  }
+
+  def remove(): T = {
+    var item: T = null.asInstanceOf[T]
+    if (reading != null) {
+      item = reading.remove()
+    }
+    if (item == null) {
+      if (reading != null) {
+        reading.close()
+      }
+      synchronized {
+        reading = queues.remove()
+      }
+      assert(reading != null, s"queue should not be empty")
+      item = reading.remove()
+      assert(item != null, s"$reading should have at least one element")
+    }
+    if (!isInMemoryQueue(reading)) {
+      numElementsQueuedOnDisk -= 1
+    }
+    numElementsQueued -= 1
+    item
+  }
+
+  def close(): Unit = {
+    if (reading != null) {
+      reading.close()
+      reading = null.asInstanceOf[Q]
+    }
+    synchronized {
+      while (!queues.isEmpty) {
+        queues.remove().close()
+      }
+    }
+  }
+
+  def getNumElementsQueuedOnDisk(): Long = numElementsQueuedOnDisk
+  def getNumElementsQueued(): Long = numElementsQueued
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/RowQueue.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/RowQueue.scala
index ce30a54c8d4e..d2008cfa1309 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/RowQueue.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/RowQueue.scala
@@ -23,10 +23,9 @@ import com.google.common.io.Closeables
 
 import org.apache.spark.SparkEnv
 import org.apache.spark.io.NioBufferedFileInputStream
-import org.apache.spark.memory.{MemoryConsumer, SparkOutOfMemoryError, TaskMemoryManager}
+import org.apache.spark.memory.TaskMemoryManager
 import org.apache.spark.serializer.SerializerManager
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
-import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.unsafe.Platform
 import org.apache.spark.unsafe.memory.MemoryBlock
 import org.apache.spark.util.Utils
@@ -38,25 +37,7 @@ import org.apache.spark.util.Utils
  * reader, the reader ALWAYS ran behind the writer. See the doc of class [[BatchEvalPythonExec]]
  * on how it works.
  */
-private[python] trait RowQueue {
-
-  /**
-   * Add a row to the end of it, returns true iff the row has been added to the queue.
-   */
-  def add(row: UnsafeRow): Boolean
-
-  /**
-   * Retrieve and remove the first row, returns null if it's empty.
-   *
-   * It can only be called after add is called, otherwise it will fail (NPE).
-   */
-  def remove(): UnsafeRow
-
-  /**
-   * Cleanup all the resources.
-   */
-  def close(): Unit
-}
+trait RowQueue extends Queue[UnsafeRow]
 
 /**
  * A RowQueue that is based on in-memory page. UnsafeRows are appended into it until it's full.
@@ -171,125 +152,35 @@ private[python] case class DiskRowQueue(
  * HybridRowQueue could be safely appended in one thread, and pulled in another thread in the same
  * time.
  */
-private[python] case class HybridRowQueue(
+case class HybridRowQueue(
     memManager: TaskMemoryManager,
     tempDir: File,
     numFields: Int,
     serMgr: SerializerManager)
-  extends MemoryConsumer(memManager, memManager.getTungstenMemoryMode) with RowQueue {
+  extends HybridQueue[UnsafeRow, RowQueue](memManager, tempDir, serMgr) {
 
-  // Each buffer should have at least one row
-  private var queues = new java.util.LinkedList[RowQueue]()
-
-  private var writing: RowQueue = _
-  private var reading: RowQueue = _
-
-  // exposed for testing
-  private[python] def numQueues(): Int = queues.size()
-
-  def spill(size: Long, trigger: MemoryConsumer): Long = {
-    if (trigger == this) {
-      // When it's triggered by itself, it should write upcoming rows into disk instead of copying
-      // the rows already in the queue.
-      return 0L
-    }
-    var released = 0L
-    synchronized {
-      // poll out all the buffers and add them back in the same order to make sure that the rows
-      // are in correct order.
-      val newQueues = new java.util.LinkedList[RowQueue]()
-      while (!queues.isEmpty) {
-        val queue = queues.remove()
-        val newQueue = if (!queues.isEmpty && queue.isInstanceOf[InMemoryRowQueue]) {
-          val diskQueue = createDiskQueue()
-          var row = queue.remove()
-          while (row != null) {
-            diskQueue.add(row)
-            row = queue.remove()
-          }
-          released += queue.asInstanceOf[InMemoryRowQueue].page.size()
-          queue.close()
-          diskQueue
-        } else {
-          queue
-        }
-        newQueues.add(newQueue)
-      }
-      queues = newQueues
-    }
-    released
-  }
-
-  private def createDiskQueue(): RowQueue = {
+  override protected def createDiskQueue(): RowQueue = {
     DiskRowQueue(File.createTempFile("buffer", "", tempDir), numFields, serMgr)
   }
 
-  private def createNewQueue(required: Long): RowQueue = {
-    val page = try {
-      allocatePage(required)
-    } catch {
-      case _: SparkOutOfMemoryError =>
-        null
-    }
-    val buffer = if (page != null) {
-      new InMemoryRowQueue(page, numFields) {
-        override def close(): Unit = {
-          freePage(this.page)
-        }
+  override protected def createInMemoryQueue(page: MemoryBlock): RowQueue = {
+    new InMemoryRowQueue(page, numFields) {
+      override def close(): Unit = {
+        freePage(this.page)
       }
-    } else {
-      createDiskQueue()
     }
-
-    synchronized {
-      queues.add(buffer)
-    }
-    buffer
   }
 
-  def add(row: UnsafeRow): Boolean = {
-    if (writing == null || !writing.add(row)) {
-      writing = createNewQueue(4 + row.getSizeInBytes)
-      if (!writing.add(row)) {
-        throw QueryExecutionErrors.failedToPushRowIntoRowQueueError(writing.toString)
-      }
-    }
-    true
-  }
+  override protected def getRequiredSize(item: UnsafeRow): Long = 4 + item.getSizeInBytes
 
-  def remove(): UnsafeRow = {
-    var row: UnsafeRow = null
-    if (reading != null) {
-      row = reading.remove()
-    }
-    if (row == null) {
-      if (reading != null) {
-        reading.close()
-      }
-      synchronized {
-        reading = queues.remove()
-      }
-      assert(reading != null, s"queue should not be empty")
-      row = reading.remove()
-      assert(row != null, s"$reading should have at least one row")
-    }
-    row
-  }
+  override protected def getPageSize(queue: RowQueue): Long =
+    queue.asInstanceOf[InMemoryRowQueue].page.size()
 
-  def close(): Unit = {
-    if (reading != null) {
-      reading.close()
-      reading = null
-    }
-    synchronized {
-      while (!queues.isEmpty) {
-        queues.remove().close()
-      }
-    }
-  }
+  override protected def isInMemoryQueue(queue: RowQueue): Boolean =
+    queue.isInstanceOf[InMemoryRowQueue]
 }
 
-private[sql] object HybridRowQueue {
+object HybridRowQueue {
   def apply(taskMemoryMgr: TaskMemoryManager, file: File, fields: Int): HybridRowQueue = {
     HybridRowQueue(taskMemoryMgr,
       file, fields, SparkEnv.get.serializerManager)
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/PythonForeachWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/PythonForeachWriter.scala
index 01643af9cf30..b1b79946c2fb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/PythonForeachWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/PythonForeachWriter.scala
@@ -184,8 +184,7 @@ object PythonForeachWriter {
     }
 
     def add(row: UnsafeRow): Unit = withLock {
-      assert(queue.add(row), s"Failed to add row to HybridRowQueue while sending data to Python" +
-        s"[count = $count, allAdded = $allAdded, exception = $exception]")
+      queue.add(row)
       count += 1
       unblockRemove.signal()
       logTrace(s"Added $row, $count left")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/RowQueueSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/RowQueueSuite.scala
index 5cf1dea7d073..10d3b1429600 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/RowQueueSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/RowQueueSuite.scala
@@ -111,7 +111,7 @@ class RowQueueSuite extends SparkFunSuite with EncryptionFunSuite {
     var i = 0
     while (i < n) {
       row.setLong(0, i)
-      assert(queue.add(row), "fail to add")
+      queue.add(row)
       i += 1
     }
     assert(queue.numQueues() > 1, "should have more than one queue")
@@ -128,7 +128,7 @@ class RowQueueSuite extends SparkFunSuite with EncryptionFunSuite {
     i = 0
     while (i < n) {
       row.setLong(0, i)
-      assert(queue.add(row), "fail to add")
+      queue.add(row)
       i += 1
     }
     assert(queue.numQueues() > 1, "should have more than one queue")

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
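The PR description motivates `HybridQueue` as a reusable buffer-and-spill abstraction, and the diff above changes `HybridRowQueue.add()` to report a `QueueMode` instead of a `Boolean`. Below is a minimal usage sketch of the refactored API; it is not part of the commit: the object and method names are hypothetical, and it assumes the code runs inside a Spark task and inside an `org.apache.spark` package, so `TaskContext.get().taskMemoryManager()` is accessible and `SparkEnv` is initialized.

```scala
// Hypothetical sketch (not from the commit): buffer rows with the refactored
// HybridRowQueue, then replay them in insertion order.
package org.apache.spark.sql.execution.python

import java.io.File

import org.apache.spark.TaskContext
import org.apache.spark.sql.catalyst.expressions.UnsafeRow

object HybridRowQueueUsageSketch {
  // Returns how many rows ended up spilled to disk while buffering.
  def bufferAndReplay(rows: Iterator[UnsafeRow], tempDir: File, numFields: Int): Long = {
    val queue = HybridRowQueue(TaskContext.get().taskMemoryManager(), tempDir, numFields)
    var spilled = 0L
    try {
      rows.foreach { row =>
        // add() now returns QueueMode.IN_MEMORY or QueueMode.DISK instead of a Boolean.
        if (queue.add(row) == QueueMode.DISK) {
          spilled += 1
        }
      }
      // Rows come back in insertion order; remove() must not be called more times than add().
      var remaining = queue.getNumElementsQueued()
      while (remaining > 0) {
        val row = queue.remove()
        // ... hand `row` to the consumer here ...
        remaining -= 1
      }
      spilled
    } finally {
      queue.close()
    }
  }
}
```

This buffer-then-replay pattern is the one `PythonForeachWriter` relies on, which is why its `assert` around `add()` is dropped in this commit: `add()` no longer returns a `Boolean` to assert on, and failures surface as exceptions instead.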