This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 7d1ea3dbc655 [SPARK-54392][SS] Optimize JVM-Python communication for TWS initial state
7d1ea3dbc655 is described below

commit 7d1ea3dbc6555c39762a71e9c7c2025600072682
Author: nyaapa <[email protected]>
AuthorDate: Sat Dec 6 11:59:04 2025 +0900

    [SPARK-54392][SS] Optimize JVM-Python communication for TWS initial state
    
    ### What changes were proposed in this pull request?
    
    - Group multiple keys into one Arrow batch; this generally produces far
    fewer batches when key cardinality is high (a hedged sketch of the idea
    follows this list).
    - Do not mix `init_data` and `input_data` in batch0: serialize `init_data`
    first, then `input_data`. In the worst case this adds one extra chunk, but
    it keeps the Python-side logic much simpler.
    - Do not create extra DataFrames when they are not needed; reuse a copy of
    a shared empty DataFrame instead.
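    
    For illustration only, here is a minimal sketch of the per-key chunking idea
    behind the first two bullets (the real change lives in the
    `TransformWithStateInPandasInitStateSerializer.load_stream` path). The helper
    name `chunk_rows_by_key`, the plain-tuple input stream, and the `max_records`
    argument are assumptions made for this example, not code from the patch:
    
        from itertools import groupby
    
        def chunk_rows_by_key(tagged_rows, max_records):
            """tagged_rows yields (key, input_row_or_None, init_row_or_None)
            tuples, already sorted by key; emit (key, input_rows, init_rows)
            chunks of at most max_records rows each, so a single Arrow batch
            can carry rows for several keys."""
            for key, group in groupby(tagged_rows, key=lambda t: t[0]):
                input_rows, init_rows = [], []
                for _, input_row, init_row in group:
                    if input_row is not None:
                        input_rows.append(input_row)
                    if init_row is not None:
                        init_rows.append(init_row)
                    # Flush a chunk once the per-batch record limit is reached.
                    if len(input_rows) + len(init_rows) >= max_records:
                        yield key, input_rows, init_rows
                        input_rows, init_rows = [], []
                # Flush whatever is left for this key.
                if input_rows or init_rows:
                    yield key, input_rows, init_rows
    
        # Key "a" has 3 input rows and the limit is 2, so "a" is split into two
        # chunks; key "b" arrives only as initial state and gets its own chunk.
        stream = [("a", 1, None), ("a", 2, None), ("a", 3, None), ("b", None, 10)]
        print(list(chunk_rows_by_key(stream, max_records=2)))
    
    The actual patch applies the same idea with an additional byte-size limit and
    keeps initial-state rows and input rows in separate Arrow batches.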
    
    ### Why are the changes needed?
    Benchmark results show that in high-cardinality scenarios this optimization
    improves batch0 time by ~40%, with no visible regression in the
    low-cardinality case.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Existing unit tests and a benchmark:
    
    10,000,000 distinct keys in the initial state (8x i3.4xlarge):
        - Without Optimization: 11400 records/s
        - With Optimization: 30000 records/s
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #53122 from nyaapa/SPARK-54392.
    
    Authored-by: nyaapa <[email protected]>
    Signed-off-by: Jungtaek Lim <[email protected]>
---
 python/pyspark/sql/pandas/serializers.py           | 215 +++++++++++++--------
 .../streaming/test_pandas_transform_with_state.py  |  90 +++++++++
 python/pyspark/worker.py                           |  14 +-
 .../streaming/BaseStreamingArrowWriter.scala       |   6 +-
 .../TransformWithStateInPySparkExec.scala          |  10 +-
 .../TransformWithStateInPySparkPythonRunner.scala  |  66 +++++--
 .../streaming/BaseStreamingArrowWriterSuite.scala  |  39 ++++
 7 files changed, 330 insertions(+), 110 deletions(-)

diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py
index d46e0b2052cc..768160087032 100644
--- a/python/pyspark/sql/pandas/serializers.py
+++ b/python/pyspark/sql/pandas/serializers.py
@@ -1492,7 +1492,9 @@ class ApplyInPandasWithStateSerializer(ArrowStreamPandasUDFSerializer):
         self.result_state_pdf_arrow_type = to_arrow_type(
            self.result_state_df_type, prefers_large_types=prefers_large_var_types
         )
-        self.arrow_max_records_per_batch = arrow_max_records_per_batch
+        self.arrow_max_records_per_batch = (
+            arrow_max_records_per_batch if arrow_max_records_per_batch > 0 else 2**31 - 1
+        )
 
     def load_stream(self, stream):
         """
@@ -1847,13 +1849,29 @@ class TransformWithStateInPandasSerializer(ArrowStreamPandasUDFSerializer):
             int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
             arrow_cast=True,
         )
-        self.arrow_max_records_per_batch = arrow_max_records_per_batch
+        self.arrow_max_records_per_batch = (
+            arrow_max_records_per_batch if arrow_max_records_per_batch > 0 else 2**31 - 1
+        )
         self.arrow_max_bytes_per_batch = arrow_max_bytes_per_batch
         self.key_offsets = None
         self.average_arrow_row_size = 0
         self.total_bytes = 0
         self.total_rows = 0
 
+    def _update_batch_size_stats(self, batch):
+        """
+        Update batch size statistics for adaptive batching.
+        """
+        # Short circuit batch size calculation if the batch size is
+        # unlimited as computing batch size is computationally expensive.
+        if self.arrow_max_bytes_per_batch != 2**31 - 1 and batch.num_rows > 0:
+            batch_bytes = sum(
+                buf.size for col in batch.columns for buf in col.buffers() if buf is not None
+            )
+            self.total_bytes += batch_bytes
+            self.total_rows += batch.num_rows
+            self.average_arrow_row_size = self.total_bytes / self.total_rows
+
     def load_stream(self, stream):
         """
        Read ArrowRecordBatches from stream, deserialize them to populate a list of data chunk, and
@@ -1881,18 +1899,7 @@ class TransformWithStateInPandasSerializer(ArrowStreamPandasUDFSerializer):
 
             def row_stream():
                 for batch in batches:
-                    # Short circuit batch size calculation if the batch size is
-                    # unlimited as computing batch size is computationally expensive.
-                    if self.arrow_max_bytes_per_batch != 2**31 - 1 and batch.num_rows > 0:
-                        batch_bytes = sum(
-                            buf.size
-                            for col in batch.columns
-                            for buf in col.buffers()
-                            if buf is not None
-                        )
-                        self.total_bytes += batch_bytes
-                        self.total_rows += batch.num_rows
-                        self.average_arrow_row_size = self.total_bytes / self.total_rows
+                    self._update_batch_size_stats(batch)
                     data_pandas = [
                         self.arrow_to_pandas(c, i)
                        for i, c in enumerate(pa.Table.from_batches([batch]).itercolumns())
@@ -1972,6 +1979,7 @@ class TransformWithStateInPandasInitStateSerializer(TransformWithStateInPandasSe
 
     def load_stream(self, stream):
         import pyarrow as pa
+        import pandas as pd
         from pyspark.sql.streaming.stateful_processor_util import (
             TransformWithStateInPandasFuncMode,
         )
@@ -1990,6 +1998,12 @@ class TransformWithStateInPandasInitStateSerializer(TransformWithStateInPandasSe
 
             def flatten_columns(cur_batch, col_name):
                state_column = cur_batch.column(cur_batch.schema.get_field_index(col_name))
+
+                # Check if the entire column is null
+                if state_column.null_count == len(state_column):
+                    # Return empty table with no columns
+                    return pa.Table.from_arrays([], names=[])
+
                 state_field_names = [
                    state_column.type[i].name for i in range(state_column.type.num_fields)
                 ]
@@ -2007,30 +2021,69 @@ class TransformWithStateInPandasInitStateSerializer(TransformWithStateInPandasSe
                 .add("inputData", dataSchema)
                 .add("initState", initStateSchema)
            We'll parse batch into Tuples of (key, inputData, initState) and pass into the Python
-             data generator. All rows in the same batch have the same grouping key.
+             data generator. Rows in the same batch may have different grouping keys,
+             but each batch will have either init_data or input_data, not mix.
             """
-            for batch in batches:
-                flatten_state_table = flatten_columns(batch, "inputData")
-                data_pandas = [
-                    self.arrow_to_pandas(c, i)
-                    for i, c in enumerate(flatten_state_table.itercolumns())
-                ]
 
-                flatten_init_table = flatten_columns(batch, "initState")
-                init_data_pandas = [
-                    self.arrow_to_pandas(c, i)
-                    for i, c in enumerate(flatten_init_table.itercolumns())
-                ]
-                key_series = [data_pandas[o] for o in self.key_offsets]
-                init_key_series = [init_data_pandas[o] for o in self.init_key_offsets]
+            def row_stream():
+                for batch in batches:
+                    self._update_batch_size_stats(batch)
 
-                if any(s.empty for s in key_series):
-                    # If any row is empty, assign batch_key using init_key_series
-                    batch_key = tuple(s[0] for s in init_key_series)
-                else:
-                    # If all rows are non-empty, create batch_key from key_series
-                    batch_key = tuple(s[0] for s in key_series)
-                yield (batch_key, data_pandas, init_data_pandas)
+                    flatten_state_table = flatten_columns(batch, "inputData")
+                    data_pandas = [
+                        self.arrow_to_pandas(c, i)
+                        for i, c in enumerate(flatten_state_table.itercolumns())
+                    ]
+
+                    flatten_init_table = flatten_columns(batch, "initState")
+                    init_data_pandas = [
+                        self.arrow_to_pandas(c, i)
+                        for i, c in enumerate(flatten_init_table.itercolumns())
+                    ]
+
+                    assert not (bool(init_data_pandas) and bool(data_pandas))
+
+                    if bool(data_pandas):
+                        for row in pd.concat(data_pandas, axis=1).itertuples(index=False):
+                            batch_key = tuple(row[s] for s in self.key_offsets)
+                            yield (batch_key, row, None)
+                    elif bool(init_data_pandas):
+                        for row in pd.concat(init_data_pandas, axis=1).itertuples(index=False):
+                            batch_key = tuple(row[s] for s in self.init_key_offsets)
+                            yield (batch_key, None, row)
+
+            EMPTY_DATAFRAME = pd.DataFrame()
+            for batch_key, group_rows in groupby(row_stream(), key=lambda x: x[0]):
+                rows = []
+                init_state_rows = []
+                for _, row, init_state_row in group_rows:
+                    if row is not None:
+                        rows.append(row)
+                    if init_state_row is not None:
+                        init_state_rows.append(init_state_row)
+
+                    total_len = len(rows) + len(init_state_rows)
+                    if (
+                        total_len >= self.arrow_max_records_per_batch
+                        or total_len * self.average_arrow_row_size >= self.arrow_max_bytes_per_batch
+                    ):
+                        yield (
+                            batch_key,
+                            pd.DataFrame(rows) if len(rows) > 0 else EMPTY_DATAFRAME.copy(),
+                            pd.DataFrame(init_state_rows)
+                            if len(init_state_rows) > 0
+                            else EMPTY_DATAFRAME.copy(),
+                        )
+                        rows = []
+                        init_state_rows = []
+                if rows or init_state_rows:
+                    yield (
+                        batch_key,
+                        pd.DataFrame(rows) if len(rows) > 0 else EMPTY_DATAFRAME.copy(),
+                        pd.DataFrame(init_state_rows)
+                        if len(init_state_rows) > 0
+                        else EMPTY_DATAFRAME.copy(),
+                    )
 
         _batches = super(ArrowStreamPandasSerializer, self).load_stream(stream)
         data_batches = generate_data_batches(_batches)
@@ -2056,7 +2109,9 @@ class TransformWithStateInPySparkRowSerializer(ArrowStreamUDFSerializer):
 
     def __init__(self, arrow_max_records_per_batch):
         super(TransformWithStateInPySparkRowSerializer, self).__init__()
-        self.arrow_max_records_per_batch = arrow_max_records_per_batch
+        self.arrow_max_records_per_batch = (
+            arrow_max_records_per_batch if arrow_max_records_per_batch > 0 else 2**31 - 1
+        )
         self.key_offsets = None
 
     def load_stream(self, stream):
@@ -2148,13 +2203,13 @@ class TransformWithStateInPySparkRowInitStateSerializer(TransformWithStateInPySp
         self.init_key_offsets = None
 
     def load_stream(self, stream):
-        import itertools
         import pyarrow as pa
         from pyspark.sql.streaming.stateful_processor_util import (
             TransformWithStateInPandasFuncMode,
         )
+        from typing import Iterator, Any, Optional, Tuple
 
-        def generate_data_batches(batches):
+        def generate_data_batches(batches) -> Iterator[Tuple[Any, Optional[Any], Optional[Any]]]:
             """
             Deserialize ArrowRecordBatches and return a generator of Row.
            The deserialization logic assumes that Arrow RecordBatches contain the data with the
@@ -2165,8 +2220,15 @@ class TransformWithStateInPySparkRowInitStateSerializer(TransformWithStateInPySp
              into the data generator.
             """
 
-            def extract_rows(cur_batch, col_name, key_offsets):
+            def extract_rows(
+                cur_batch, col_name, key_offsets
+            ) -> Optional[Iterator[Tuple[Any, Any]]]:
                data_column = cur_batch.column(cur_batch.schema.get_field_index(col_name))
+
+                # Check if the entire column is null
+                if data_column.null_count == len(data_column):
+                    return None
+
                 data_field_names = [
                    data_column.type[i].name for i in range(data_column.type.num_fields)
                 ]
@@ -2179,18 +2241,17 @@ class TransformWithStateInPySparkRowInitStateSerializer(TransformWithStateInPySp
                table = pa.Table.from_arrays(data_field_arrays, names=data_field_names)
 
                 if table.num_rows == 0:
-                    return (None, iter([]))
-                else:
-                    batch_key = tuple(table.column(o)[0].as_py() for o in key_offsets)
+                    return None
 
-                    rows = []
+                def row_iterator():
                     for row_idx in range(table.num_rows):
+                        key = tuple(table.column(o)[row_idx].as_py() for o in key_offsets)
                         row = DataRow(
                            *(table.column(i)[row_idx].as_py() for i in range(table.num_columns))
                         )
-                        rows.append(row)
+                        yield (key, row)
 
-                    return (batch_key, iter(rows))
+                return row_iterator()
 
             """
             The arrow batch is written in the schema:
@@ -2198,49 +2259,45 @@ class TransformWithStateInPySparkRowInitStateSerializer(TransformWithStateInPySp
                 .add("inputData", dataSchema)
                 .add("initState", initStateSchema)
            We'll parse batch into Tuples of (key, inputData, initState) and pass into the Python
-             data generator. All rows in the same batch have the same grouping key.
+             data generator. Each batch will have either init_data or input_data, not mix.
             """
             for batch in batches:
-                (input_batch_key, input_data_iter) = extract_rows(
-                    batch, "inputData", self.key_offsets
-                )
-                (init_batch_key, init_state_iter) = extract_rows(
-                    batch, "initState", self.init_key_offsets
-                )
+                # Detect which column has data - each batch contains only one type
+                input_result = extract_rows(batch, "inputData", self.key_offsets)
+                init_result = extract_rows(batch, "initState", self.init_key_offsets)
 
-                if input_batch_key is None:
-                    batch_key = init_batch_key
-                else:
-                    batch_key = input_batch_key
-
-                for init_state_row in init_state_iter:
-                    yield (batch_key, None, init_state_row)
+                assert not (input_result is not None and init_result is not None)
 
-                for input_data_row in input_data_iter:
-                    yield (batch_key, input_data_row, None)
+                if input_result is not None:
+                    for key, input_data_row in input_result:
+                        yield (key, input_data_row, None)
+                elif init_result is not None:
+                    for key, init_state_row in init_result:
+                        yield (key, None, init_state_row)
 
         _batches = super(ArrowStreamUDFSerializer, self).load_stream(stream)
         data_batches = generate_data_batches(_batches)
 
         for k, g in groupby(data_batches, key=lambda x: x[0]):
-            # g: list(batch_key, input_data_iter, init_state_iter)
-
-            # they are sharing the iterator, hence need to copy
-            input_values_iter, init_state_iter = itertools.tee(g, 2)
-
-            chained_input_values = itertools.chain(map(lambda x: x[1], input_values_iter))
-            chained_init_state_values = itertools.chain(map(lambda x: x[2], init_state_iter))
-
-            chained_input_values_without_none = filter(
-                lambda x: x is not None, chained_input_values
-            )
-            chained_init_state_values_without_none = filter(
-                lambda x: x is not None, chained_init_state_values
-            )
-
-            ret_tuple = (chained_input_values_without_none, chained_init_state_values_without_none)
-
-            yield (TransformWithStateInPandasFuncMode.PROCESS_DATA, k, ret_tuple)
+            input_rows = []
+            init_rows = []
+
+            for batch_key, input_row, init_row in g:
+                if input_row is not None:
+                    input_rows.append(input_row)
+                if init_row is not None:
+                    init_rows.append(init_row)
+
+                total_len = len(input_rows) + len(init_rows)
+                if total_len >= self.arrow_max_records_per_batch:
+                    ret_tuple = (iter(input_rows), iter(init_rows))
+                    yield (TransformWithStateInPandasFuncMode.PROCESS_DATA, k, ret_tuple)
+                    input_rows = []
+                    init_rows = []
+
+            if input_rows or init_rows:
+                ret_tuple = (iter(input_rows), iter(init_rows))
+                yield (TransformWithStateInPandasFuncMode.PROCESS_DATA, k, 
ret_tuple)
 
         yield (TransformWithStateInPandasFuncMode.PROCESS_TIMER, None, None)
 
diff --git a/python/pyspark/sql/tests/pandas/streaming/test_pandas_transform_with_state.py b/python/pyspark/sql/tests/pandas/streaming/test_pandas_transform_with_state.py
index 57125c7820c0..ecdfcfda3c1d 100644
--- a/python/pyspark/sql/tests/pandas/streaming/test_pandas_transform_with_state.py
+++ b/python/pyspark/sql/tests/pandas/streaming/test_pandas_transform_with_state.py
@@ -1483,6 +1483,96 @@ class TransformWithStateTestsMixin:
                 ),
             )
 
+    def test_transform_with_state_with_records_limit(self):
+        if not self.use_pandas():
+            return
+
+        def make_check_results(expected_per_batch):
+            def check_results(batch_df, batch_id):
+                batch_df.collect()
+                if batch_id == 0:
+                    assert set(batch_df.sort("id").collect()) == expected_per_batch[0]
+                else:
+                    assert set(batch_df.sort("id").collect()) == expected_per_batch[1]
+
+            return check_results
+
+        result_with_small_limit = [
+            {
+                Row(id="0", chunkCount=2),
+                Row(id="1", chunkCount=2),
+            },
+            {
+                Row(id="0", chunkCount=3),
+                Row(id="1", chunkCount=2),
+            },
+        ]
+
+        result_with_large_limit = [
+            {
+                Row(id="0", chunkCount=1),
+                Row(id="1", chunkCount=1),
+            },
+            {
+                Row(id="0", chunkCount=1),
+                Row(id="1", chunkCount=1),
+            },
+        ]
+
+        data = [("0", 789), ("3", 987)]
+        initial_state = self.spark.createDataFrame(data, "id string, initVal int").groupBy("id")
+
+        with self.sql_conf(
+            # Set it to a very small number so that every row would be a separate pandas df
+            {"spark.sql.execution.arrow.maxRecordsPerBatch": "1"}
+        ):
+            self._test_transform_with_state_basic(
+                ChunkCountProcessorFactory(),
+                make_check_results(result_with_small_limit),
+                output_schema=StructType(
+                    [
+                        StructField("id", StringType(), True),
+                        StructField("chunkCount", IntegerType(), True),
+                    ]
+                ),
+            )
+
+            self._test_transform_with_state_basic(
+                ChunkCountProcessorWithInitialStateFactory(),
+                make_check_results(result_with_small_limit),
+                initial_state=initial_state,
+                output_schema=StructType(
+                    [
+                        StructField("id", StringType(), True),
+                        StructField("chunkCount", IntegerType(), True),
+                    ]
+                ),
+            )
+
+        with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": "-1"}):
+            self._test_transform_with_state_basic(
+                ChunkCountProcessorFactory(),
+                make_check_results(result_with_large_limit),
+                output_schema=StructType(
+                    [
+                        StructField("id", StringType(), True),
+                        StructField("chunkCount", IntegerType(), True),
+                    ]
+                ),
+            )
+
+            self._test_transform_with_state_basic(
+                ChunkCountProcessorWithInitialStateFactory(),
+                make_check_results(result_with_large_limit),
+                initial_state=initial_state,
+                output_schema=StructType(
+                    [
+                        StructField("id", StringType(), True),
+                        StructField("chunkCount", IntegerType(), True),
+                    ]
+                ),
+            )
+
     # test all state types (value, list, map) with large values (512 KB)
     def test_transform_with_state_large_values(self):
         def check_results(batch_df, batch_id):
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 50e71fb6da9d..109157e2c339 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -880,12 +880,14 @@ def wrap_grouped_transform_with_state_pandas_udf(f, return_type, runner_conf):
 
def wrap_grouped_transform_with_state_pandas_init_state_udf(f, return_type, runner_conf):
     def wrapped(stateful_processor_api_client, mode, key, value_series_gen):
-        import pandas as pd
-
+        # Split the generator into two using itertools.tee
         state_values_gen, init_states_gen = itertools.tee(value_series_gen, 2)
-        state_values = (df for x, _ in state_values_gen if not (df := pd.concat(x, axis=1)).empty)
-        init_states = (df for _, x in init_states_gen if not (df := pd.concat(x, axis=1)).empty)
 
+        # Extract just the data DataFrames (first element of each tuple)
+        state_values = (data_df for data_df, _ in state_values_gen if not data_df.empty)
+
+        # Extract just the init DataFrames (second element of each tuple)
+        init_states = (init_df for _, init_df in init_states_gen if not init_df.empty)
        result_iter = f(stateful_processor_api_client, mode, key, state_values, init_states)
 
         # TODO(SPARK-49100): add verification that elements in result_iter are
@@ -3071,8 +3073,8 @@ def read_udfs(pickleSer, infile, eval_type):
 
                 def values_gen():
                     for x in a[2]:
-                        retVal = [x[1][o] for o in parsed_offsets[0][1]]
-                        initVal = [x[2][o] for o in parsed_offsets[1][1]]
+                        retVal = x[1]
+                        initVal = x[2]
                         yield retVal, initVal
 
                 # This must be generator comprehension - do not materialize.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/BaseStreamingArrowWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/BaseStreamingArrowWriter.scala
index f0371cafb72a..8c9ab2a8c636 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/BaseStreamingArrowWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/BaseStreamingArrowWriter.scala
@@ -88,10 +88,14 @@ class BaseStreamingArrowWriter(
 
   protected def isBatchSizeLimitReached: Boolean = {
     // If we have either reached the records or bytes limit
-    totalNumRowsForBatch >= arrowMaxRecordsPerBatch ||
+    (arrowMaxRecordsPerBatch > 0 && totalNumRowsForBatch >= arrowMaxRecordsPerBatch) ||
      // Short circuit batch size calculation if the batch size is unlimited as computing batch
       // size is computationally expensive.
       ((arrowMaxBytesPerBatch != Int.MaxValue)
         && (arrowWriterForData.sizeInBytes() >= arrowMaxBytesPerBatch))
   }
+
+  def getTotalNumRowsForBatch: Int = {
+    totalNumRowsForBatch
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkExec.scala
index a1cf71844950..1ceaf6c4bf81 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkExec.scala
@@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, PythonUDF}
 import org.apache.spark.sql.catalyst.plans.logical.TransformWithStateInPySpark
 import org.apache.spark.sql.catalyst.types.DataTypeUtils
-import org.apache.spark.sql.execution.{CoGroupedIterator, SparkPlan}
+import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.execution.python.ArrowPythonRunner
 import org.apache.spark.sql.execution.python.PandasGroupUtils.{executePython, groupAndProject, resolveArgOffsets}
@@ -347,9 +347,9 @@ case class TransformWithStateInPySparkExec(
       val initData =
         groupAndProject(initStateIterator, initialStateGroupingAttrs,
           initialState.output, initDedupAttributes)
-      // group input rows and initial state rows by the same grouping key
-      val groupedData: Iterator[(InternalRow, Iterator[InternalRow], Iterator[InternalRow])] =
-        new CoGroupedIterator(data, initData, groupingAttributes)
+      // concatenate input rows and initial state rows iterators
+      val inputIter: Iterator[((InternalRow, Iterator[InternalRow]), Boolean)] =
+          initData.map { item => (item, true) } ++ data.map { item => (item, false) }
 
       val evalType = {
        if (userFacingDataType == TransformWithStateInPySpark.UserFacingDataType.PANDAS) {
@@ -374,7 +374,7 @@ case class TransformWithStateInPySparkExec(
         batchTimestampMs,
         eventTimeWatermarkForEviction
       )
-      executePython(groupedData, output, runner)
+      executePython(inputIter, output, runner)
     }
 
     CompletionIterator[InternalRow, Iterator[InternalRow]](outputIterator, {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkPythonRunner.scala
index 42d4ad68c29a..3eb7c7e64d64 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkPythonRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkPythonRunner.scala
@@ -144,6 +144,9 @@ class TransformWithStateInPySparkPythonInitialStateRunner(
 
   private var pandasWriter: BaseStreamingArrowWriter = _
 
+  private var currentDataIterator: Iterator[InternalRow] = _
+  private var isCurrentIterFromInitState: Option[Boolean] = None
+
   override protected def writeNextBatchToArrowStream(
       root: VectorSchemaRoot,
       writer: ArrowStreamWriter,
@@ -158,30 +161,53 @@ class TransformWithStateInPySparkPythonInitialStateRunner(
       )
     }
 
-    if (inputIterator.hasNext) {
-      val startData = dataOut.size()
-      // a new grouping key with data & init state iter
-      val next = inputIterator.next()
-      val dataIter = next._2
-      val initIter = next._3
-
-      while (dataIter.hasNext || initIter.hasNext) {
-        val dataRow =
-          if (dataIter.hasNext) dataIter.next()
-          else InternalRow.empty
-        val initRow =
-          if (initIter.hasNext) initIter.next()
-          else InternalRow.empty
-        pandasWriter.writeRow(InternalRow(dataRow, initRow))
+
+    // If we don't have data left for the current group, move to the next group.
+    if (currentDataIterator == null && inputIterator.hasNext) {
+      val ((_, data), isInitState) = inputIterator.next()
+      currentDataIterator = data
+      val isPrevIterFromInitState = isCurrentIterFromInitState
+      isCurrentIterFromInitState = Some(isInitState)
+      if (isPrevIterFromInitState.isDefined &&
+        isPrevIterFromInitState.get != isInitState &&
+        pandasWriter.getTotalNumRowsForBatch > 0) {
+        // So we won't have batches with mixed data and init state.
+        pandasWriter.finalizeCurrentArrowBatch()
+        return true
       }
-      pandasWriter.finalizeCurrentArrowBatch()
-      val deltaData = dataOut.size() - startData
-      pythonMetrics("pythonDataSent") += deltaData
+    }
+
+    val startData = dataOut.size()
+
+    val hasInput = if (currentDataIterator != null) {
+      var isCurrentBatchFull = false
+      val isCurrentIterFromInitStateVal = isCurrentIterFromInitState.get
+      // Stop writing when the current arrowBatch is finalized/full. If we have rows left
+      while (currentDataIterator.hasNext && !isCurrentBatchFull) {
+        val dataRow = currentDataIterator.next()
+        isCurrentBatchFull = if (isCurrentIterFromInitStateVal) {
+          pandasWriter.writeRow(InternalRow(null, dataRow))
+        } else {
+          pandasWriter.writeRow(InternalRow(dataRow, null))
+        }
+      }
+
+      if (!currentDataIterator.hasNext) {
+        currentDataIterator = null
+      }
+
       true
     } else {
+      if (pandasWriter.getTotalNumRowsForBatch > 0) {
+        pandasWriter.finalizeCurrentArrowBatch()
+      }
       super[PythonArrowInput].close()
       false
     }
+
+    val deltaData = dataOut.size() - startData
+    pythonMetrics("pythonDataSent") += deltaData
+    hasInput
   }
 }
 
@@ -392,5 +418,7 @@ trait TransformWithStateInPySparkPythonRunnerUtils extends Logging {
 
 object TransformWithStateInPySparkPythonRunner {
   type InType = (InternalRow, Iterator[InternalRow])
-  type GroupedInType = (InternalRow, Iterator[InternalRow], Iterator[InternalRow])
+
+  // ((key, rows), isInitState)
+  type GroupedInType = ((InternalRow, Iterator[InternalRow]), Boolean)
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/BaseStreamingArrowWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/BaseStreamingArrowWriterSuite.scala
index fc10a102b4f5..49839fb8c985 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/BaseStreamingArrowWriterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/BaseStreamingArrowWriterSuite.scala
@@ -95,4 +95,43 @@ class BaseStreamingArrowWriterSuite extends SparkFunSuite with BeforeAndAfterEac
     verify(writer, times(2)).writeBatch()
     verify(arrowWriter, times(2)).reset()
   }
+
+  test("test negative or zero arrowMaxRecordsPerBatch is unlimited") {
+    val root: VectorSchemaRoot = mock(classOf[VectorSchemaRoot])
+    val dataRow = mock(classOf[InternalRow])
+
+    // Test with negative value
+    transformWithStateInPySparkWriter = new BaseStreamingArrowWriter(
+      root, writer, -1, arrowMaxBytesPerBatch, arrowWriter)
+
+    // Write many rows (more than typical batch size)
+    for (_ <- 1 to 10) {
+      transformWithStateInPySparkWriter.writeRow(dataRow)
+    }
+
+    // Verify all rows were written but batch was not finalized
+    verify(arrowWriter, times(10)).write(dataRow)
+    verify(writer, never()).writeBatch()
+
+    // Only finalize when explicitly called
+    transformWithStateInPySparkWriter.finalizeCurrentArrowBatch()
+    verify(writer).writeBatch()
+
+    // Test with zero value
+    transformWithStateInPySparkWriter = new BaseStreamingArrowWriter(
+      root, writer, 0, arrowMaxBytesPerBatch, arrowWriter)
+
+    // Write many rows again
+    for (_ <- 1 to 10) {
+      transformWithStateInPySparkWriter.writeRow(dataRow)
+    }
+
+    // Verify rows were written but batch was not finalized
+    verify(arrowWriter, times(20)).write(dataRow)
+    verify(writer).writeBatch()  // still 1 from before
+
+    // Only finalize when explicitly called
+    transformWithStateInPySparkWriter.finalizeCurrentArrowBatch()
+    verify(writer, times(2)).writeBatch()
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
