This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 663a6c479b22 [SPARK-55351][PYTHON][SQL] PythonArrowInput encapsulate resource allocation inside `newWriter`
663a6c479b22 is described below

commit 663a6c479b22d4d63c58210db45f2015a24182b8
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Wed Feb 4 13:58:27 2026 +0800

    [SPARK-55351][PYTHON][SQL] PythonArrowInput encapsulate resource allocation inside `newWriter`
    
    ### What changes were proposed in this pull request?
    PythonArrowInput now encapsulates resource allocation inside `newWriter`.
    
    ### Why are the changes needed?
    It is up to the writer to manage its resources.
    PythonArrowInput is just a helper layer that builds the writer.
    Currently, subclasses always have to release the `allocator`/`root` resources, even if they are not used in the subclass.
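    To make the ownership shift concrete, below is a minimal, self-contained Scala sketch of the pattern this change adopts: the writer created by `newWriter` allocates its own resources, releases them itself once the input is exhausted, and also registers a completion callback so cleanup still happens if the task ends early. This is an illustration only; `FakeAllocator`, `FakeTaskContext`, and `OwningWriter` are made-up names for the sketch, not Spark classes.
    
    ```scala
    // Illustrative sketch of "the writer owns and releases its resources".
    import scala.collection.mutable.ArrayBuffer
    
    // Stands in for an Arrow BufferAllocator that must be closed exactly once.
    final class FakeAllocator(name: String) {
      private var open = true
      def close(): Unit = if (open) { open = false; println(s"closed allocator: $name") }
    }
    
    // Stands in for TaskContext: it lets the writer register a cleanup callback.
    final class FakeTaskContext {
      private val listeners = ArrayBuffer.empty[() => Unit]
      def addCompletionListener(f: () => Unit): Unit = listeners += f
      def complete(): Unit = listeners.foreach(_.apply())
    }
    
    // The writer allocates its own resources; callers never have to touch them.
    final class OwningWriter(ctx: FakeTaskContext, input: Iterator[Int]) {
      private val allocator = new FakeAllocator("stdout writer")
      private var terminated = false
    
      // Also release resources on task completion, in case the task ends early.
      ctx.addCompletionListener(() => terminate())
    
      def writeNextInput(): Boolean = {
        val hasInput = input.hasNext
        if (hasInput) println(s"wrote batch ${input.next()}")
        else terminate() // input exhausted: the writer cleans up after itself
        hasInput
      }
    
      private def terminate(): Unit = if (!terminated) {
        terminated = true
        allocator.close()
      }
    }
    
    object OwnershipSketch extends App {
      val ctx = new FakeTaskContext
      val writer = new OwningWriter(ctx, Iterator(1, 2, 3))
      while (writer.writeNextInput()) () // writes 1, 2, 3, then closes the allocator
      ctx.complete() // safe: terminate() is idempotent, so nothing is closed twice
    }
    ```
    
    In the actual patch, the anonymous `Writer` returned by `newWriter` plays the owning role: its `terminate()` ends the `ArrowStreamWriter` inside `Utils.tryWithSafeFinally`, closes `root` and `allocator` in the finally block, and is also registered as a task completion listener.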
    
    ### Does this PR introduce _any_ user-facing change?
    no, internal refactoring
    
    ### How was this patch tested?
    ci
    
    ### Was this patch authored or co-authored using generative AI tooling?
    no
    
    Closes #54128 from zhengruifeng/refactor_pai_new_writer.
    
    Authored-by: Ruifeng Zheng <[email protected]>
    Signed-off-by: Ruifeng Zheng <[email protected]>
---
 .../sql/execution/python/PythonArrowInput.scala    | 73 +++++++++++++---------
 .../ApplyInPandasWithStatePythonRunner.scala       |  1 -
 .../TransformWithStateInPySparkPythonRunner.scala  |  2 -
 3 files changed, 43 insertions(+), 33 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
index 58a48b1815e1..a8f0f8ba5c56 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
@@ -53,6 +53,10 @@ private[python] trait PythonArrowInput[IN] { self: BasePythonRunner[IN, _] =>
 
   protected def pythonMetrics: Map[String, SQLMetric]
 
+  /**
+   * Writes input batch to the stream connected to the Python worker.
+   * Returns true if any data was written to the stream, false if the input is exhausted.
+   */
   protected def writeNextBatchToArrowStream(
       root: VectorSchemaRoot,
       writer: ArrowStreamWriter,
@@ -61,15 +65,6 @@ private[python] trait PythonArrowInput[IN] { self: BasePythonRunner[IN, _] =>
 
   protected def writeUDF(dataOut: DataOutputStream): Unit
 
-  protected lazy val allocator: BufferAllocator =
-    ArrowUtils.rootAllocator.newChildAllocator(s"stdout writer for $pythonExec", 0, Long.MaxValue)
-
-  protected lazy val root: VectorSchemaRoot = {
-    val arrowSchema = ArrowUtils.toArrowSchema(
-      schema, timeZoneId, errorOnDuplicatedFieldNames, largeVarTypes)
-    VectorSchemaRoot.create(arrowSchema, allocator)
-  }
-
   // Create compression codec based on config
  protected def codec: CompressionCodec = SQLConf.get.arrowCompressionCodec match {
     case "none" => NoCompressionCodec.INSTANCE
@@ -87,20 +82,6 @@ private[python] trait PythonArrowInput[IN] { self: BasePythonRunner[IN, _] =>
         s"Unsupported Arrow compression codec: $other. Supported values: none, 
zstd, lz4")
   }
 
-  protected var writer: ArrowStreamWriter = _
-
-  protected def close(): Unit = {
-    Utils.tryWithSafeFinally {
-      // end writes footer to the output stream and doesn't clean any resources.
-      // It could throw exception if the output stream is closed, so it should be
-      // in the try block.
-      writer.end()
-    } {
-      root.close()
-      allocator.close()
-    }
-  }
-
   protected override def newWriter(
       env: SparkEnv,
       worker: PythonWorker,
@@ -108,20 +89,45 @@ private[python] trait PythonArrowInput[IN] { self: BasePythonRunner[IN, _] =>
       partitionIndex: Int,
       context: TaskContext): Writer = {
     new Writer(env, worker, inputIterator, partitionIndex, context) {
+      private val arrowSchema = ArrowUtils.toArrowSchema(
+        schema, timeZoneId, errorOnDuplicatedFieldNames, largeVarTypes)
+      private val allocator: BufferAllocator = ArrowUtils.rootAllocator
+        .newChildAllocator(s"stdout writer for $pythonExec", 0, Long.MaxValue)
+      private val root: VectorSchemaRoot = VectorSchemaRoot.create(arrowSchema, allocator)
+      private var writer: ArrowStreamWriter = _
+
+      context.addTaskCompletionListener[Unit] { _ => this.terminate() }
 
       protected override def writeCommand(dataOut: DataOutputStream): Unit = {
         writeUDF(dataOut)
       }
 
       override def writeNextInputToStream(dataOut: DataOutputStream): Boolean = {
-
         if (writer == null) {
           writer = new ArrowStreamWriter(root, null, dataOut)
           writer.start()
         }
 
         assert(writer != null)
-        writeNextBatchToArrowStream(root, writer, dataOut, inputIterator)
+        val hasInput = writeNextBatchToArrowStream(root, writer, dataOut, inputIterator)
+        if (!hasInput) {
+          this.terminate()
+        }
+        hasInput
+      }
+
+      private def terminate(): Unit = {
+        Utils.tryWithSafeFinally {
+          // end writes footer to the output stream and doesn't clean any resources.
+          // It could throw exception if the output stream is closed, so it should be
+          // in the try block.
+          if (writer != null) {
+            writer.end()
+          }
+        } {
+          root.close()
+          allocator.close()
+        }
       }
     }
   }
@@ -129,9 +135,6 @@ private[python] trait PythonArrowInput[IN] { self: BasePythonRunner[IN, _] =>
 
 private[python] trait BasicPythonArrowInput extends PythonArrowInput[Iterator[InternalRow]] {
   self: BasePythonRunner[Iterator[InternalRow], _] =>
-  protected lazy val arrowWriter: arrow.ArrowWriter = ArrowWriter.create(root)
-  protected lazy val unloader = new VectorUnloader(root, true, codec, true)
-
   protected val maxRecordsPerBatch: Int = {
     val v = SQLConf.get.arrowMaxRecordsPerBatch
     if (v > 0) v else Int.MaxValue
@@ -139,11 +142,18 @@ private[python] trait BasicPythonArrowInput extends PythonArrowInput[Iterator[In
 
   protected val maxBytesPerBatch: Long = SQLConf.get.arrowMaxBytesPerBatch
 
+  protected var arrowWriter: arrow.ArrowWriter = _
+  protected var unloader: VectorUnloader = _
+
   protected def writeNextBatchToArrowStream(
       root: VectorSchemaRoot,
       writer: ArrowStreamWriter,
       dataOut: DataOutputStream,
       inputIterator: Iterator[Iterator[InternalRow]]): Boolean = {
+    if (arrowWriter == null && unloader == null) {
+      arrowWriter = ArrowWriter.create(root)
+      unloader = new VectorUnloader(root, true, codec, true)
+    }
 
     if (inputIterator.hasNext) {
       val startData = dataOut.size()
@@ -167,7 +177,6 @@ private[python] trait BasicPythonArrowInput extends PythonArrowInput[Iterator[In
       pythonMetrics("pythonDataSent") += deltaData
       true
     } else {
-      super[PythonArrowInput].close()
       false
     }
   }
@@ -184,6 +193,11 @@ private[python] trait BatchedPythonArrowInput extends BasicPythonArrowInput {
       writer: ArrowStreamWriter,
       dataOut: DataOutputStream,
       inputIterator: Iterator[Iterator[InternalRow]]): Boolean = {
+    if (arrowWriter == null && unloader == null) {
+      arrowWriter = ArrowWriter.create(root)
+      unloader = new VectorUnloader(root, true, codec, true)
+    }
+
     if (!nextBatchStart.hasNext) {
       if (inputIterator.hasNext) {
         nextBatchStart = inputIterator.next()
@@ -201,7 +215,6 @@ private[python] trait BatchedPythonArrowInput extends BasicPythonArrowInput {
       pythonMetrics("pythonDataSent") += deltaData
       true
     } else {
-      super[BasicPythonArrowInput].close()
       false
     }
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/ApplyInPandasWithStatePythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/ApplyInPandasWithStatePythonRunner.scala
index 89d8e425fd2b..7de7f140e03d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/ApplyInPandasWithStatePythonRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/ApplyInPandasWithStatePythonRunner.scala
@@ -163,7 +163,6 @@ class ApplyInPandasWithStatePythonRunner(
       true
     } else {
       pandasWriter.finalizeData()
-      super[PythonArrowInput].close()
       false
     }
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkPythonRunner.scala
index 05771d38cd84..792ba2b100b7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkPythonRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkPythonRunner.scala
@@ -105,7 +105,6 @@ class TransformWithStateInPySparkPythonRunner(
       true
     } else {
       pandasWriter.finalizeCurrentArrowBatch()
-      super[PythonArrowInput].close()
       false
     }
     val deltaData = dataOut.size() - startData
@@ -201,7 +200,6 @@ class TransformWithStateInPySparkPythonInitialStateRunner(
       if (pandasWriter.getTotalNumRowsForBatch > 0) {
         pandasWriter.finalizeCurrentArrowBatch()
       }
-      super[PythonArrowInput].close()
       false
     }
 

