This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 91899986a0a1 [SPARK-53481][PYTHON] Make Hybrid queue class
91899986a0a1 is described below

commit 91899986a0a17f09547d12b5181956577ed07c2b
Author: Richard Chen <r.c...@databricks.com>
AuthorDate: Fri Sep 5 15:44:42 2025 +0800

    [SPARK-53481][PYTHON] Make Hybrid queue class
    
    ### What changes were proposed in this pull request?
    
    Create an abstraction for the hybrid row queue: `HybridQueue`.
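
    For context, this is the shape of the new abstraction, abridged from the `HybridQueue.scala` diff below (method bodies omitted here; the concrete spill/allocation logic lives in the diff):

    ```scala
    trait Queue[T] {
      def add(item: T): Boolean
      def remove(): T
      def close(): Unit
    }

    abstract class HybridQueue[T, Q <: Queue[T]](
        memManager: TaskMemoryManager,
        tempDir: File,
        serMgr: SerializerManager)
      extends MemoryConsumer(memManager, memManager.getTungstenMemoryMode) {

      // Hooks a concrete implementation provides for its storage formats.
      protected def createDiskQueue(): Q
      protected def createInMemoryQueue(page: MemoryBlock): Q
      protected def getRequiredSize(item: T): Long
      protected def getPageSize(queue: Q): Long
      protected def isInMemoryQueue(queue: Q): Boolean

      // Shared logic (bodies elided): add now reports where the item landed.
      def add(item: T): QueueMode.Value
      def remove(): T
      def close(): Unit
    }
    ```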
    
    ### Why are the changes needed?
    
    The hybrid row queue is currently used to buffer internal rows and spill them
    to disk if needed. Implementations may find it useful to have a similar
    interface for objects other than InternalRows.
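
    As a rough sketch only (not part of this patch), a specialization for a payload other than `InternalRow` could plug into the same hooks. The `Batch`, `BatchQueue`, `DiskBatchQueue`, and `InMemoryBatchQueue` names below are hypothetical and mirror how `HybridRowQueue` is rewritten in this diff:

    ```scala
    // Hypothetical sketch: none of these Batch* types exist in Spark today.
    case class HybridBatchQueue(
        memManager: TaskMemoryManager,
        tempDir: File,
        serMgr: SerializerManager)
      extends HybridQueue[Batch, BatchQueue](memManager, tempDir, serMgr) {

      // Spill target: a file-backed queue under tempDir.
      override protected def createDiskQueue(): BatchQueue =
        DiskBatchQueue(File.createTempFile("buffer", "", tempDir), serMgr)

      // In-memory queue backed by a Tungsten page; release the page on close.
      override protected def createInMemoryQueue(page: MemoryBlock): BatchQueue =
        new InMemoryBatchQueue(page) {
          override def close(): Unit = freePage(this.page)
        }

      // Bytes needed to append one item (length prefix + payload, as in HybridRowQueue).
      override protected def getRequiredSize(item: Batch): Long = 4 + item.sizeInBytes

      override protected def getPageSize(queue: BatchQueue): Long =
        queue.asInstanceOf[InMemoryBatchQueue].page.size()

      override protected def isInMemoryQueue(queue: BatchQueue): Boolean =
        queue.isInstanceOf[InMemoryBatchQueue]
    }
    ```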
    
    ### Does this PR introduce _any_ user-facing change?
    
    no
    
    ### How was this patch tested?
    
    existing tests

    ### Was this patch authored or co-authored using generative AI tooling?
    
    no
    
    Closes #52212 from richardc-db/make_hybrid_queue_class.
    
    Authored-by: Richard Chen <r.c...@databricks.com>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 .../spark/sql/execution/python/HybridQueue.scala   | 178 +++++++++++++++++++++
 .../spark/sql/execution/python/RowQueue.scala      | 139 ++--------------
 .../python/streaming/PythonForeachWriter.scala     |   3 +-
 .../spark/sql/execution/python/RowQueueSuite.scala |   4 +-
 4 files changed, 196 insertions(+), 128 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/HybridQueue.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/HybridQueue.scala
new file mode 100644
index 000000000000..90996c552645
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/HybridQueue.scala
@@ -0,0 +1,178 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.python
+
+import java.io._
+
+import scala.Enumeration
+
+import org.apache.spark.memory.{MemoryConsumer, SparkOutOfMemoryError, TaskMemoryManager}
+import org.apache.spark.serializer.SerializerManager
+import org.apache.spark.sql.errors.QueryExecutionErrors
+import org.apache.spark.unsafe.memory.MemoryBlock
+
+/**
+ * Enum to represent the storage mode for hybrid queues.
+ */
+object QueueMode extends Enumeration {
+  type QueueMode = Value
+  val IN_MEMORY, DISK = Value
+}
+
+trait Queue[T] {
+  def add(item: T): Boolean
+  def remove(): T
+  def close(): Unit
+}
+
+/**
+ * A generic base class for hybrid queues that can store data either in memory or on disk.
+ * This class contains common logic for queue management, spilling, and memory management.
+ */
+abstract class HybridQueue[T, Q <: Queue[T]](
+    memManager: TaskMemoryManager,
+    tempDir: File,
+    serMgr: SerializerManager)
+  extends MemoryConsumer(memManager, memManager.getTungstenMemoryMode) {
+
+  // Each buffer should have at least one element.
+  protected var queues = new java.util.LinkedList[Q]()
+
+  private var writing: Q = _
+  protected var reading: Q = _
+
+  protected var numElementsQueuedOnDisk: Long = 0L
+  protected var numElementsQueued: Long = 0L
+
+  // exposed for testing
+  private[python] def numQueues(): Int = queues.size()
+
+  protected def createDiskQueue(): Q
+  protected def createInMemoryQueue(page: MemoryBlock): Q
+  protected def getRequiredSize(item: T): Long
+  protected def getPageSize(queue: Q): Long
+  protected def isInMemoryQueue(queue: Q): Boolean
+  protected def isReadingFromDiskQueue: Boolean = !isInMemoryQueue(reading)
+
+  def spill(size: Long, trigger: MemoryConsumer): Long = {
+    if (trigger == this) {
+      // When it's triggered by itself, it should write upcoming elements into disk instead of
+      // copying the elements already in the queue.
+      return 0L
+    }
+    var released = 0L
+    synchronized {
+      // poll out all the buffers and add them back in the same order to make sure that the elements
+      // are in correct order.
+      val newQueues = new java.util.LinkedList[Q]()
+      while (!queues.isEmpty) {
+        val queue = queues.remove()
+        val newQueue = if (!queues.isEmpty && isInMemoryQueue(queue)) {
+          val diskQueue = createDiskQueue()
+          var item = queue.remove()
+          while (item != null) {
+            diskQueue.add(item)
+            item = queue.remove()
+          }
+          released += getPageSize(queue)
+          queue.close()
+          diskQueue
+        } else {
+          queue
+        }
+        newQueues.add(newQueue)
+      }
+      queues = newQueues
+    }
+    released
+  }
+
+  private def createNewQueue(required: Long): Q = {
+    // Tests may attempt to force spills.
+    val page = try {
+      allocatePage(required)
+    } catch {
+      case _: SparkOutOfMemoryError =>
+        null
+    }
+    val buffer = if (page != null) {
+      createInMemoryQueue(page)
+    } else {
+      createDiskQueue()
+    }
+
+    synchronized {
+      queues.add(buffer)
+    }
+    buffer
+  }
+
+  def add(item: T): QueueMode.Value = {
+    if (writing == null || !writing.add(item)) {
+      writing = createNewQueue(getRequiredSize(item))
+      if (!writing.add(item)) {
+        throw QueryExecutionErrors.failedToPushRowIntoRowQueueError(writing.toString)
+      }
+    }
+    numElementsQueued += 1
+    if (isInMemoryQueue(writing)) {
+      QueueMode.IN_MEMORY
+    } else {
+      numElementsQueuedOnDisk += 1
+      QueueMode.DISK
+    }
+  }
+
+  def remove(): T = {
+    var item: T = null.asInstanceOf[T]
+    if (reading != null) {
+      item = reading.remove()
+    }
+    if (item == null) {
+      if (reading != null) {
+        reading.close()
+      }
+      synchronized {
+        reading = queues.remove()
+      }
+      assert(reading != null, s"queue should not be empty")
+      item = reading.remove()
+      assert(item != null, s"$reading should have at least one element")
+    }
+    if (!isInMemoryQueue(reading)) {
+      numElementsQueuedOnDisk -= 1
+    }
+    numElementsQueued -= 1
+    item
+  }
+
+  def close(): Unit = {
+    if (reading != null) {
+      reading.close()
+      reading = null.asInstanceOf[Q]
+    }
+    synchronized {
+      while (!queues.isEmpty) {
+        queues.remove().close()
+      }
+    }
+  }
+
+  def getNumElementsQueuedOnDisk(): Long = numElementsQueuedOnDisk
+  def getNumElementsQueued(): Long = numElementsQueued
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/RowQueue.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/RowQueue.scala
index ce30a54c8d4e..d2008cfa1309 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/RowQueue.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/RowQueue.scala
@@ -23,10 +23,9 @@ import com.google.common.io.Closeables
 
 import org.apache.spark.SparkEnv
 import org.apache.spark.io.NioBufferedFileInputStream
-import org.apache.spark.memory.{MemoryConsumer, SparkOutOfMemoryError, TaskMemoryManager}
+import org.apache.spark.memory.TaskMemoryManager
 import org.apache.spark.serializer.SerializerManager
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
-import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.unsafe.Platform
 import org.apache.spark.unsafe.memory.MemoryBlock
 import org.apache.spark.util.Utils
@@ -38,25 +37,7 @@ import org.apache.spark.util.Utils
  * reader, the reader ALWAYS ran behind the writer. See the doc of class [[BatchEvalPythonExec]]
  * on how it works.
  */
-private[python] trait RowQueue {
-
-  /**
-   * Add a row to the end of it, returns true iff the row has been added to the queue.
-   */
-  def add(row: UnsafeRow): Boolean
-
-  /**
-   * Retrieve and remove the first row, returns null if it's empty.
-   *
-   * It can only be called after add is called, otherwise it will fail (NPE).
-   */
-  def remove(): UnsafeRow
-
-  /**
-   * Cleanup all the resources.
-   */
-  def close(): Unit
-}
+trait RowQueue extends Queue[UnsafeRow]
 
 /**
  * A RowQueue that is based on in-memory page. UnsafeRows are appended into it until it's full.
@@ -171,125 +152,35 @@ private[python] case class DiskRowQueue(
  * HybridRowQueue could be safely appended in one thread, and pulled in another thread in the same
  * time.
  */
-private[python] case class HybridRowQueue(
+case class HybridRowQueue(
     memManager: TaskMemoryManager,
     tempDir: File,
     numFields: Int,
     serMgr: SerializerManager)
-  extends MemoryConsumer(memManager, memManager.getTungstenMemoryMode) with RowQueue {
+  extends HybridQueue[UnsafeRow, RowQueue](memManager, tempDir, serMgr) {
 
-  // Each buffer should have at least one row
-  private var queues = new java.util.LinkedList[RowQueue]()
-
-  private var writing: RowQueue = _
-  private var reading: RowQueue = _
-
-  // exposed for testing
-  private[python] def numQueues(): Int = queues.size()
-
-  def spill(size: Long, trigger: MemoryConsumer): Long = {
-    if (trigger == this) {
-      // When it's triggered by itself, it should write upcoming rows into disk instead of copying
-      // the rows already in the queue.
-      return 0L
-    }
-    var released = 0L
-    synchronized {
-      // poll out all the buffers and add them back in the same order to make sure that the rows
-      // are in correct order.
-      val newQueues = new java.util.LinkedList[RowQueue]()
-      while (!queues.isEmpty) {
-        val queue = queues.remove()
-        val newQueue = if (!queues.isEmpty && queue.isInstanceOf[InMemoryRowQueue]) {
-          val diskQueue = createDiskQueue()
-          var row = queue.remove()
-          while (row != null) {
-            diskQueue.add(row)
-            row = queue.remove()
-          }
-          released += queue.asInstanceOf[InMemoryRowQueue].page.size()
-          queue.close()
-          diskQueue
-        } else {
-          queue
-        }
-        newQueues.add(newQueue)
-      }
-      queues = newQueues
-    }
-    released
-  }
-
-  private def createDiskQueue(): RowQueue = {
+  override protected def createDiskQueue(): RowQueue = {
     DiskRowQueue(File.createTempFile("buffer", "", tempDir), numFields, serMgr)
   }
 
-  private def createNewQueue(required: Long): RowQueue = {
-    val page = try {
-      allocatePage(required)
-    } catch {
-      case _: SparkOutOfMemoryError =>
-        null
-    }
-    val buffer = if (page != null) {
-      new InMemoryRowQueue(page, numFields) {
-        override def close(): Unit = {
-          freePage(this.page)
-        }
+  override protected def createInMemoryQueue(page: MemoryBlock): RowQueue = {
+    new InMemoryRowQueue(page, numFields) {
+      override def close(): Unit = {
+        freePage(this.page)
       }
-    } else {
-      createDiskQueue()
     }
-
-    synchronized {
-      queues.add(buffer)
-    }
-    buffer
   }
 
-  def add(row: UnsafeRow): Boolean = {
-    if (writing == null || !writing.add(row)) {
-      writing = createNewQueue(4 + row.getSizeInBytes)
-      if (!writing.add(row)) {
-        throw QueryExecutionErrors.failedToPushRowIntoRowQueueError(writing.toString)
-      }
-    }
-    true
-  }
+  override protected def getRequiredSize(item: UnsafeRow): Long = 4 + item.getSizeInBytes
 
-  def remove(): UnsafeRow = {
-    var row: UnsafeRow = null
-    if (reading != null) {
-      row = reading.remove()
-    }
-    if (row == null) {
-      if (reading != null) {
-        reading.close()
-      }
-      synchronized {
-        reading = queues.remove()
-      }
-      assert(reading != null, s"queue should not be empty")
-      row = reading.remove()
-      assert(row != null, s"$reading should have at least one row")
-    }
-    row
-  }
+  override protected def getPageSize(queue: RowQueue): Long =
+    queue.asInstanceOf[InMemoryRowQueue].page.size()
 
-  def close(): Unit = {
-    if (reading != null) {
-      reading.close()
-      reading = null
-    }
-    synchronized {
-      while (!queues.isEmpty) {
-        queues.remove().close()
-      }
-    }
-  }
+  override protected def isInMemoryQueue(queue: RowQueue): Boolean =
+    queue.isInstanceOf[InMemoryRowQueue]
 }
 
-private[sql] object HybridRowQueue {
+object HybridRowQueue {
   def apply(taskMemoryMgr: TaskMemoryManager, file: File, fields: Int): HybridRowQueue = {
     HybridRowQueue(taskMemoryMgr, file, fields, SparkEnv.get.serializerManager)
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/PythonForeachWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/PythonForeachWriter.scala
index 01643af9cf30..b1b79946c2fb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/PythonForeachWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/PythonForeachWriter.scala
@@ -184,8 +184,7 @@ object PythonForeachWriter {
     }
 
     def add(row: UnsafeRow): Unit = withLock {
-      assert(queue.add(row), s"Failed to add row to HybridRowQueue while sending data to Python" +
-        s"[count = $count, allAdded = $allAdded, exception = $exception]")
+      queue.add(row)
       count += 1
       unblockRemove.signal()
       logTrace(s"Added $row, $count left")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/RowQueueSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/RowQueueSuite.scala
index 5cf1dea7d073..10d3b1429600 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/RowQueueSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/RowQueueSuite.scala
@@ -111,7 +111,7 @@ class RowQueueSuite extends SparkFunSuite with EncryptionFunSuite {
       var i = 0
       while (i < n) {
         row.setLong(0, i)
-        assert(queue.add(row), "fail to add")
+        queue.add(row)
         i += 1
       }
       assert(queue.numQueues() > 1, "should have more than one queue")
@@ -128,7 +128,7 @@ class RowQueueSuite extends SparkFunSuite with EncryptionFunSuite {
       i = 0
       while (i < n) {
         row.setLong(0, i)
-        assert(queue.add(row), "fail to add")
+        queue.add(row)
         i += 1
       }
       assert(queue.numQueues() > 1, "should have more than one queue")


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
