This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 4c19158b835b Revert "[SPARK-55351][PYTHON][SQL] PythonArrowInput encapsulate resource allocation inside `newWriter`"
4c19158b835b is described below

commit 4c19158b835b85a2b6be1be071af4acdf1d02c4f
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Wed Feb 4 19:49:55 2026 +0800

    Revert "[SPARK-55351][PYTHON][SQL] PythonArrowInput encapsulate resource 
allocation inside `newWriter`"
    
    Reverts https://github.com/apache/spark/pull/54128 due to a potential memory leak issue.
    
    Closes #54138 from zhengruifeng/revert_new_writer.
    
    Authored-by: Ruifeng Zheng <[email protected]>
    Signed-off-by: Ruifeng Zheng <[email protected]>
---
 .../sql/execution/python/PythonArrowInput.scala    | 73 +++++++++-------------
 .../ApplyInPandasWithStatePythonRunner.scala       |  1 +
 .../TransformWithStateInPySparkPythonRunner.scala  |  2 +
 3 files changed, 33 insertions(+), 43 deletions(-)
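
For context, the revert restores the pre-#54128 resource layout: the Arrow allocator, VectorSchemaRoot and stream writer live as lazy members of the PythonArrowInput trait and are released by an explicit close() once the input iterator is exhausted, instead of being scoped to the Writer built in newWriter and torn down by a task-completion listener. Below is a minimal, simplified sketch of that restored pattern; it is not the actual Spark trait, and names such as ExampleArrowInput, arrowSchema and ensureWriter are illustrative stand-ins.

    // Minimal sketch, assuming the Arrow Java libraries (arrow-memory, arrow-vector)
    // are on the classpath. All names here are illustrative, not Spark's.
    import java.io.OutputStream

    import org.apache.arrow.memory.{BufferAllocator, RootAllocator}
    import org.apache.arrow.vector.VectorSchemaRoot
    import org.apache.arrow.vector.ipc.ArrowStreamWriter
    import org.apache.arrow.vector.types.pojo.Schema

    trait ExampleArrowInput {
      // Supplied by the concrete runner in the real code path.
      protected def arrowSchema: Schema

      // In Spark the parent allocator is the shared ArrowUtils.rootAllocator; a
      // private one is created here only to keep the sketch self-contained.
      private lazy val parentAllocator = new RootAllocator(Long.MaxValue)

      // Trait-level, lazily created Arrow resources (what the revert restores).
      protected lazy val allocator: BufferAllocator =
        parentAllocator.newChildAllocator("example stdout writer", 0, Long.MaxValue)

      protected lazy val root: VectorSchemaRoot =
        VectorSchemaRoot.create(arrowSchema, allocator)

      protected var writer: ArrowStreamWriter = _

      // Created on the first write, mirroring writeNextInputToStream in the diff below.
      protected def ensureWriter(out: OutputStream): Unit = {
        if (writer == null) {
          writer = new ArrowStreamWriter(root, null, out)
          writer.start()
        }
      }

      // Called once the input iterator is exhausted.
      protected def close(): Unit = {
        try {
          // end() writes the stream footer and can throw if the stream is
          // already closed, so it stays inside the try block.
          if (writer != null) writer.end()
        } finally {
          root.close()
          allocator.close()   // child allocator only; Spark's shared root stays open
        }
      }
    }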

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
index a8f0f8ba5c56..58a48b1815e1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
@@ -53,10 +53,6 @@ private[python] trait PythonArrowInput[IN] { self: BasePythonRunner[IN, _] =>
 
   protected def pythonMetrics: Map[String, SQLMetric]
 
-  /**
-   * Writes input batch to the stream connected to the Python worker.
-   * Returns true if any data was written to the stream, false if the input is exhausted.
-   */
   protected def writeNextBatchToArrowStream(
       root: VectorSchemaRoot,
       writer: ArrowStreamWriter,
@@ -65,6 +61,15 @@ private[python] trait PythonArrowInput[IN] { self: BasePythonRunner[IN, _] =>
 
   protected def writeUDF(dataOut: DataOutputStream): Unit
 
+  protected lazy val allocator: BufferAllocator =
+    ArrowUtils.rootAllocator.newChildAllocator(s"stdout writer for $pythonExec", 0, Long.MaxValue)
+
+  protected lazy val root: VectorSchemaRoot = {
+    val arrowSchema = ArrowUtils.toArrowSchema(
+      schema, timeZoneId, errorOnDuplicatedFieldNames, largeVarTypes)
+    VectorSchemaRoot.create(arrowSchema, allocator)
+  }
+
   // Create compression codec based on config
   protected def codec: CompressionCodec = SQLConf.get.arrowCompressionCodec match {
     case "none" => NoCompressionCodec.INSTANCE
@@ -82,6 +87,20 @@ private[python] trait PythonArrowInput[IN] { self: BasePythonRunner[IN, _] =>
         s"Unsupported Arrow compression codec: $other. Supported values: none, 
zstd, lz4")
   }
 
+  protected var writer: ArrowStreamWriter = _
+
+  protected def close(): Unit = {
+    Utils.tryWithSafeFinally {
+      // end writes footer to the output stream and doesn't clean any resources.
+      // It could throw exception if the output stream is closed, so it should be
+      // in the try block.
+      writer.end()
+    } {
+      root.close()
+      allocator.close()
+    }
+  }
+
   protected override def newWriter(
       env: SparkEnv,
       worker: PythonWorker,
@@ -89,45 +108,20 @@ private[python] trait PythonArrowInput[IN] { self: BasePythonRunner[IN, _] =>
       partitionIndex: Int,
       context: TaskContext): Writer = {
     new Writer(env, worker, inputIterator, partitionIndex, context) {
-      private val arrowSchema = ArrowUtils.toArrowSchema(
-        schema, timeZoneId, errorOnDuplicatedFieldNames, largeVarTypes)
-      private val allocator: BufferAllocator = ArrowUtils.rootAllocator
-        .newChildAllocator(s"stdout writer for $pythonExec", 0, Long.MaxValue)
-      private val root: VectorSchemaRoot = VectorSchemaRoot.create(arrowSchema, allocator)
-      private var writer: ArrowStreamWriter = _
-
-      context.addTaskCompletionListener[Unit] { _ => this.terminate() }
 
       protected override def writeCommand(dataOut: DataOutputStream): Unit = {
         writeUDF(dataOut)
       }
 
        override def writeNextInputToStream(dataOut: DataOutputStream): Boolean = {
+
         if (writer == null) {
           writer = new ArrowStreamWriter(root, null, dataOut)
           writer.start()
         }
 
         assert(writer != null)
-        val hasInput = writeNextBatchToArrowStream(root, writer, dataOut, inputIterator)
-        if (!hasInput) {
-          this.terminate()
-        }
-        hasInput
-      }
-
-      private def terminate(): Unit = {
-        Utils.tryWithSafeFinally {
-          // end writes footer to the output stream and doesn't clean any resources.
-          // It could throw exception if the output stream is closed, so it should be
-          // in the try block.
-          if (writer != null) {
-            writer.end()
-          }
-        } {
-          root.close()
-          allocator.close()
-        }
+        writeNextBatchToArrowStream(root, writer, dataOut, inputIterator)
       }
     }
   }
@@ -135,6 +129,9 @@ private[python] trait PythonArrowInput[IN] { self: BasePythonRunner[IN, _] =>
 
 private[python] trait BasicPythonArrowInput extends PythonArrowInput[Iterator[InternalRow]] {
   self: BasePythonRunner[Iterator[InternalRow], _] =>
+  protected lazy val arrowWriter: arrow.ArrowWriter = ArrowWriter.create(root)
+  protected lazy val unloader = new VectorUnloader(root, true, codec, true)
+
   protected val maxRecordsPerBatch: Int = {
     val v = SQLConf.get.arrowMaxRecordsPerBatch
     if (v > 0) v else Int.MaxValue
@@ -142,18 +139,11 @@ private[python] trait BasicPythonArrowInput extends PythonArrowInput[Iterator[In
 
   protected val maxBytesPerBatch: Long = SQLConf.get.arrowMaxBytesPerBatch
 
-  protected var arrowWriter: arrow.ArrowWriter = _
-  protected var unloader: VectorUnloader = _
-
   protected def writeNextBatchToArrowStream(
       root: VectorSchemaRoot,
       writer: ArrowStreamWriter,
       dataOut: DataOutputStream,
       inputIterator: Iterator[Iterator[InternalRow]]): Boolean = {
-    if (arrowWriter == null && unloader == null) {
-      arrowWriter = ArrowWriter.create(root)
-      unloader = new VectorUnloader(root, true, codec, true)
-    }
 
     if (inputIterator.hasNext) {
       val startData = dataOut.size()
@@ -177,6 +167,7 @@ private[python] trait BasicPythonArrowInput extends PythonArrowInput[Iterator[In
       pythonMetrics("pythonDataSent") += deltaData
       true
     } else {
+      super[PythonArrowInput].close()
       false
     }
   }
@@ -193,11 +184,6 @@ private[python] trait BatchedPythonArrowInput extends BasicPythonArrowInput {
       writer: ArrowStreamWriter,
       dataOut: DataOutputStream,
       inputIterator: Iterator[Iterator[InternalRow]]): Boolean = {
-    if (arrowWriter == null && unloader == null) {
-      arrowWriter = ArrowWriter.create(root)
-      unloader = new VectorUnloader(root, true, codec, true)
-    }
-
     if (!nextBatchStart.hasNext) {
       if (inputIterator.hasNext) {
         nextBatchStart = inputIterator.next()
@@ -215,6 +201,7 @@ private[python] trait BatchedPythonArrowInput extends 
BasicPythonArrowInput {
       pythonMetrics("pythonDataSent") += deltaData
       true
     } else {
+      super[BasicPythonArrowInput].close()
       false
     }
   }
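
With the task-completion listener gone, each writeNextBatchToArrowStream implementation now has to release the Arrow resources itself the first time the input iterator comes up empty, which is what the super[PythonArrowInput].close() and super[BasicPythonArrowInput].close() calls added above (and in the two streaming runners below) do. A condensed sketch of that control flow, building on the illustrative ExampleArrowInput above (hasMoreInput and writeOneBatch are made-up stand-ins for the real batching logic):

    // Sketch of the restored write loop: emit one Arrow batch per call and,
    // the first time the input is exhausted, end the stream and free buffers.
    import java.io.OutputStream

    trait ExampleBatchWriter extends ExampleArrowInput {
      protected def hasMoreInput: Boolean
      protected def writeOneBatch(out: OutputStream): Unit

      def writeNextInput(out: OutputStream): Boolean = {
        ensureWriter(out)
        if (hasMoreInput) {
          writeOneBatch(out)   // serialize the next batch to the Python worker
          true                 // more data may follow
        } else {
          close()              // counterpart of super[PythonArrowInput].close()
          false
        }
      }
    }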
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/ApplyInPandasWithStatePythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/ApplyInPandasWithStatePythonRunner.scala
index 7de7f140e03d..89d8e425fd2b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/ApplyInPandasWithStatePythonRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/ApplyInPandasWithStatePythonRunner.scala
@@ -163,6 +163,7 @@ class ApplyInPandasWithStatePythonRunner(
       true
     } else {
       pandasWriter.finalizeData()
+      super[PythonArrowInput].close()
       false
     }
   }
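
The restored calls take the form super[PythonArrowInput].close(), Scala's qualified-super syntax, which names exactly which parent trait's close() is invoked. A tiny stand-alone illustration with made-up trait names (in this toy case a plain close() would resolve the same way; the qualified form just makes the target explicit):

    // Qualified super: pick close() from a specific parent trait by name.
    trait ArrowSide {
      def close(): Unit = println("ArrowSide.close: release Arrow buffers")
    }
    trait StateSide {
      def finalizeData(): Unit = println("StateSide.finalizeData: flush state rows")
    }

    class ExampleRunner extends ArrowSide with StateSide {
      def onInputExhausted(): Unit = {
        finalizeData()
        // Invoke close() exactly as defined in ArrowSide, independent of any
        // other overrides the mixin linearization might contribute.
        super[ArrowSide].close()
      }
    }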
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkPythonRunner.scala
index 792ba2b100b7..05771d38cd84 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkPythonRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkPythonRunner.scala
@@ -105,6 +105,7 @@ class TransformWithStateInPySparkPythonRunner(
       true
     } else {
       pandasWriter.finalizeCurrentArrowBatch()
+      super[PythonArrowInput].close()
       false
     }
     val deltaData = dataOut.size() - startData
@@ -200,6 +201,7 @@ class TransformWithStateInPySparkPythonInitialStateRunner(
       if (pandasWriter.getTotalNumRowsForBatch > 0) {
         pandasWriter.finalizeCurrentArrowBatch()
       }
+      super[PythonArrowInput].close()
       false
     }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
