This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 40993e671b87 [SPARK-53361][SS][1/2] Optimizing JVM–Python 
Communication in TWS by Grouping Multiple Keys into One Arrow Batch
40993e671b87 is described below

commit 40993e671b8758893a09d1471938f1581732ecf5
Author: zeruibao <zerui....@databricks.com>
AuthorDate: Mon Sep 15 17:03:29 2025 +0900

    [SPARK-53361][SS][1/2] Optimizing JVM–Python Communication in TWS by 
Grouping Multiple Keys into One Arrow Batch
    
    ### What changes were proposed in this pull request?
    This PR introduces an optimization to JVM–Python communication in TWS by 
allowing multiple keys to be grouped into a single Arrow batch.
    
    Currently, each Arrow batch is restricted to contain records for a single 
key. In high-cardinality scenarios, this results in many small Arrow batches 
(e.g., [(key1, value1), (key1, value2)], [(key2, value1), (key2, value2)]), 
which increases the overhead of Arrow batch transmission between the JVM and 
Python.
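    
    For intuition, here is a minimal, self-contained sketch (illustrative only, 
not the actual Spark code; `max_records_per_batch` stands in for 
`arrowMaxRecordsPerBatch`) contrasting one-batch-per-key with bin-packing rows 
across keys up to a size cap:
    
    ```python
    from itertools import islice

    # Hypothetical high-cardinality input: 1000 distinct keys, 2 rows each,
    # already grouped by key as the shuffle guarantees.
    rows = [(f"key{i}", v) for i in range(1000) for v in range(2)]

    # Current behavior: one Arrow batch per grouping key.
    def batches_per_key(rows):
        batches, current, current_key = [], [], None
        for key, value in rows:
            if key != current_key and current:
                batches.append(current)
                current = []
            current_key = key
            current.append((key, value))
        if current:
            batches.append(current)
        return batches

    # Proposed behavior: pack rows from many keys into one batch, capped in size.
    def packed_batches(rows, max_records_per_batch=1000):
        it = iter(rows)
        while chunk := list(islice(it, max_records_per_batch)):
            yield chunk

    print(len(batches_per_key(rows)))        # 1000 small batches
    print(len(list(packed_batches(rows))))   # 2 large batches
    ```
    
    Fewer, larger batches mean fewer trips through the Arrow stream between the 
JVM and the Python worker.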
    
    With this change, records with different keys can be bin-packed into the 
same Arrow batch, reducing the number of batches transmitted. On the Python 
side, we leverage `groupby` to regroup records by key, mirroring the behavior 
of the Scala `GroupedIterator` implementation.
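    
    A rough sketch of that regrouping step (names are illustrative, not the 
serializer's actual code; it assumes rows for the same key arrive contiguously, 
which the JVM writer guarantees):
    
    ```python
    from itertools import groupby

    # Flattened stream of (key, row) pairs; one Arrow batch may span several keys.
    rows = [(("a",), 1), (("a",), 2), (("b",), 3), (("b",), 4), (("c",), 5)]

    def regroup_by_key(row_stream):
        # Because rows with the same key are contiguous, a single streaming pass
        # with itertools.groupby recovers per-key groups without buffering the
        # whole stream, mirroring Scala's GroupedIterator behavior.
        for key, group in groupby(row_stream, key=lambda kv: kv[0]):
            yield key, [row for _, row in group]

    for key, group in regroup_by_key(rows):
        print(key, group)
    # (('a',), [1, 2])
    # (('b',), [3, 4])
    # (('c',), [5])
    ```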
    
    This PR only handles `TransformWithStateInPySparkPythonRunner`. 
`TransformWithStateInPySparkPythonInitialStateRunner` only affects batch 0, so 
it is left to a follow-up PR.
    
    This approach significantly reduces transmission overhead while preserving 
correct grouping semantics.
    
    ### Why are the changes needed?
    
    Benchmark results show that in high-cardinality scenarios, this 
optimization improves throughput by ~20% by reducing the overhead of Arrow 
batch transmission. For low-cardinality scenarios, the change introduces no 
observable regression.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Existing UT and Benchmark.
    
    Performance improvement with PySpark without initial state (**one** shuffle 
partition).
    Full cardinality (each key is distinct):
    - Without Optimization: 4350 rows/second
    - With Optimization: 13470 rows/second
    
    Performance improvement with Pandas without initial state (**six** shuffle 
partitions).
    10,000,000 distinct keys:
    - Without Optimization: 5000 records/s
    - With Optimization: 6221 records/s
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #52331 from 
zeruibao/zeruibao/SPARK-53361-only-improve-non-initial-state.
    
    Authored-by: zeruibao <zerui....@databricks.com>
    Signed-off-by: Jungtaek Lim <kabhwan.opensou...@gmail.com>
---
 python/pyspark/sql/pandas/serializers.py           | 37 +++++++++++++---------
 python/pyspark/worker.py                           |  7 ++--
 .../streaming/BaseStreamingArrowWriter.scala       |  6 ++--
 .../TransformWithStateInPySparkPythonRunner.scala  | 34 ++++++++++++++------
 4 files changed, 52 insertions(+), 32 deletions(-)

diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py
index 0965a2e7b546..d1bdfa9e8d01 100644
--- a/python/pyspark/sql/pandas/serializers.py
+++ b/python/pyspark/sql/pandas/serializers.py
@@ -1603,6 +1603,7 @@ class TransformWithStateInPandasSerializer(ArrowStreamPandasUDFSerializer):
         this function works in overall.
         """
         import pyarrow as pa
+        import pandas as pd
         from pyspark.sql.streaming.stateful_processor_util import (
             TransformWithStateInPandasFuncMode,
         )
@@ -1617,14 +1618,20 @@ class TransformWithStateInPandasSerializer(ArrowStreamPandasUDFSerializer):
             This function must avoid materializing multiple Arrow RecordBatches into memory at the
             same time. And data chunks from the same grouping key should appear sequentially.
             """
-            for batch in batches:
-                data_pandas = [
-                    self.arrow_to_pandas(c, i)
-                    for i, c in 
enumerate(pa.Table.from_batches([batch]).itercolumns())
-                ]
-                key_series = [data_pandas[o] for o in self.key_offsets]
-                batch_key = tuple(s[0] for s in key_series)
-                yield (batch_key, data_pandas)
+
+            def row_stream():
+                for batch in batches:
+                    data_pandas = [
+                        self.arrow_to_pandas(c, i)
+                        for i, c in 
enumerate(pa.Table.from_batches([batch]).itercolumns())
+                    ]
+                    for row in pd.concat(data_pandas, 
axis=1).itertuples(index=False):
+                        batch_key = tuple(row[s] for s in self.key_offsets)
+                        yield (batch_key, row)
+
+            for batch_key, group_rows in groupby(row_stream(), key=lambda x: 
x[0]):
+                df = pd.DataFrame([row for _, row in group_rows])
+                yield (batch_key, df)
 
         _batches = super(ArrowStreamPandasSerializer, self).load_stream(stream)
         data_batches = generate_data_batches(_batches)
@@ -1793,15 +1800,15 @@ class TransformWithStateInPySparkRowSerializer(ArrowStreamUDFSerializer):
             same time. And data chunks from the same grouping key should appear sequentially.
             """
             for batch in batches:
-                DataRow = Row(*(batch.schema.names))
+                DataRow = Row(*batch.schema.names)
 
-                # This is supposed to be the same.
-                batch_key = tuple(batch[o][0].as_py() for o in self.key_offsets)
+                # Iterate row by row without converting the whole batch
+                num_cols = batch.num_columns
                 for row_idx in range(batch.num_rows):
-                    row = DataRow(
-                        *(batch.column(i)[row_idx].as_py() for i in range(batch.num_columns))
-                    )
-                    yield (batch_key, row)
+                    # build the key for this row
+                    row_key = tuple(batch[o][row_idx].as_py() for o in self.key_offsets)
+                    row = DataRow(*(batch.column(i)[row_idx].as_py() for i in range(num_cols)))
+                    yield row_key, row
 
         _batches = super(ArrowStreamUDFSerializer, self).load_stream(stream)
         data_batches = generate_data_batches(_batches)
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 292e4174818e..e73f16464aaa 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -694,10 +694,7 @@ def wrap_grouped_map_pandas_udf(f, return_type, argspec, runner_conf):
 
 def wrap_grouped_transform_with_state_pandas_udf(f, return_type, runner_conf):
     def wrapped(stateful_processor_api_client, mode, key, value_series_gen):
-        import pandas as pd
-
-        values = (pd.concat(x, axis=1) for x in value_series_gen)
-        result_iter = f(stateful_processor_api_client, mode, key, values)
+        result_iter = f(stateful_processor_api_client, mode, key, value_series_gen)
 
         # TODO(SPARK-49100): add verification that elements in result_iter are
         # indeed of type pd.DataFrame and confirm to assigned cols
@@ -2496,7 +2493,7 @@ def read_udfs(pickleSer, infile, eval_type):
 
                 def values_gen():
                     for x in a[2]:
-                        retVal = [x[1][o] for o in parsed_offsets[0][1]]
+                        retVal = x[1].iloc[:, parsed_offsets[0][1]]
                         yield retVal
 
                 # This must be generator comprehension - do not materialize.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/BaseStreamingArrowWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/BaseStreamingArrowWriter.scala
index 73c70a618866..ba8b2c3ac7da 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/BaseStreamingArrowWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/BaseStreamingArrowWriter.scala
@@ -50,11 +50,12 @@ class BaseStreamingArrowWriter(
    *
    * @param dataRow The row to write for current batch.
    */
-  def writeRow(dataRow: InternalRow): Unit = {
+  def writeRow(dataRow: InternalRow): Boolean = {
     // If it exceeds the condition of batch (number of records) and there is more data for the
     // same group, finalize and construct a new batch.
 
-    if (totalNumRowsForBatch >= arrowMaxRecordsPerBatch) {
+    val isCurrentBatchFull = totalNumRowsForBatch >= arrowMaxRecordsPerBatch
+    if (isCurrentBatchFull) {
       finalizeCurrentChunk(isLastChunkForGroup = false)
       finalizeCurrentArrowBatch()
     }
@@ -63,6 +64,7 @@ class BaseStreamingArrowWriter(
 
     numRowsForCurrentChunk += 1
     totalNumRowsForBatch += 1
+    isCurrentBatchFull
   }
 
   /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkPythonRunner.scala
index 329bd4335265..f0df3e1f7d15 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkPythonRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkPythonRunner.scala
@@ -66,6 +66,9 @@ class TransformWithStateInPySparkPythonRunner(
 
   private var pandasWriter: BaseStreamingArrowWriter = _
 
+  private var currentDataIterator: Iterator[InternalRow] = _
+
+  // Grouping multiple keys into one arrow batch
   override protected def writeNextBatchToArrowStream(
       root: VectorSchemaRoot,
       writer: ArrowStreamWriter,
@@ -75,23 +78,34 @@ class TransformWithStateInPySparkPythonRunner(
       pandasWriter = new BaseStreamingArrowWriter(root, writer, arrowMaxRecordsPerBatch)
     }
 
-    if (inputIterator.hasNext) {
-      val startData = dataOut.size()
-      val next = inputIterator.next()
-      val dataIter = next._2
+    // If we don't have data left for the current group, move to the next group.
+    if (currentDataIterator == null && inputIterator.hasNext) {
+      val (_, dataIter) = inputIterator.next()
+      currentDataIterator = dataIter
+    }
 
-      while (dataIter.hasNext) {
-        val dataRow = dataIter.next()
-        pandasWriter.writeRow(dataRow)
+    val startData = dataOut.size()
+    val hasInput = if (currentDataIterator != null) {
+      var isCurrentBatchFull = false
+      // Stop writing when the current arrowBatch is finalized/full; rows left in the current group go into the next batch.
+      while (currentDataIterator.hasNext && !isCurrentBatchFull) {
+        val dataRow = currentDataIterator.next()
+        isCurrentBatchFull = pandasWriter.writeRow(dataRow)
       }
-      pandasWriter.finalizeCurrentArrowBatch()
-      val deltaData = dataOut.size() - startData
-      pythonMetrics("pythonDataSent") += deltaData
+
+      if (!currentDataIterator.hasNext) {
+        currentDataIterator = null
+      }
+
       true
     } else {
+      pandasWriter.finalizeCurrentArrowBatch()
       super[PythonArrowInput].close()
       false
     }
+    val deltaData = dataOut.size() - startData
+    pythonMetrics("pythonDataSent") += deltaData
+    hasInput
   }
 }
 

