This is an automated email from the ASF dual-hosted git repository.
kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 7d1ea3dbc655 [SPARK-54392][SS] Optimize JVM-Python communication for TWS initial state
7d1ea3dbc655 is described below
commit 7d1ea3dbc6555c39762a71e9c7c2025600072682
Author: nyaapa <[email protected]>
AuthorDate: Sat Dec 6 11:59:04 2025 +0900
[SPARK-54392][SS] Optimize JVM-Python communication for TWS initial state
### What changes were proposed in this pull request?
- Group multiple keys into one Arrow batch; with high key cardinality this generally produces far fewer batches (see the sketch below).
- Do not mix `init_data` and `input_data` within batch0: serialize `init_data` first, then `input_data`. In the worst case this costs one extra chunk, but it keeps the Python-side logic much simpler.
- Do not create extra DataFrames when they are not needed; reuse a copy of a single empty DataFrame instead.
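To illustrate the first two points, here is a minimal, self-contained sketch of the chunking idea (not the actual PySpark serializer): rows from a stream that spans many grouping keys are regrouped per key on the Python side and flushed once a record limit is hit. Names such as `chunked_groups`, `row_stream`, and `max_records` are illustrative only.

```python
# Minimal sketch of per-key chunking with a record limit (illustrative names).
from itertools import groupby

import pandas as pd


def chunked_groups(row_stream, key_offsets, max_records):
    """Regroup rows by key and emit bounded-size DataFrames.

    `row_stream` yields plain tuples that may span many grouping keys per
    Arrow batch; consecutive rows with the same key are buffered and flushed
    whenever the record limit is reached, so a single Arrow batch can carry
    many keys without producing one huge pandas DataFrame per key.
    """
    def key_of(row):
        return tuple(row[o] for o in key_offsets)

    for key, rows in groupby(row_stream, key=key_of):
        buf = []
        for row in rows:
            buf.append(row)
            if len(buf) >= max_records:
                yield key, pd.DataFrame(buf)
                buf = []
        if buf:
            yield key, pd.DataFrame(buf)


# With max_records=2, key ("a",) is split into two chunks and ("b",) stays whole.
rows = [("a", 1), ("a", 2), ("a", 3), ("b", 4)]
for key, df in chunked_groups(iter(rows), key_offsets=[0], max_records=2):
    print(key, len(df))
```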
### Why are the changes needed?
Benchmark results show that in high-cardinality scenarios this optimization improves batch0 time by ~40%. No visible regressions were observed in the low-cardinality case.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Existing UTs, plus a benchmark with 10,000,000 distinct keys in the initial state (8x i3.4xlarge):
- Without optimization: 11,400 records/s
- With optimization: 30,000 records/s
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #53122 from nyaapa/SPARK-54392.
Authored-by: nyaapa <[email protected]>
Signed-off-by: Jungtaek Lim <[email protected]>
---
python/pyspark/sql/pandas/serializers.py | 215 +++++++++++++--------
.../streaming/test_pandas_transform_with_state.py | 90 +++++++++
python/pyspark/worker.py | 14 +-
.../streaming/BaseStreamingArrowWriter.scala | 6 +-
.../TransformWithStateInPySparkExec.scala | 10 +-
.../TransformWithStateInPySparkPythonRunner.scala | 66 +++++--
.../streaming/BaseStreamingArrowWriterSuite.scala | 39 ++++
7 files changed, 330 insertions(+), 110 deletions(-)
diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py
index d46e0b2052cc..768160087032 100644
--- a/python/pyspark/sql/pandas/serializers.py
+++ b/python/pyspark/sql/pandas/serializers.py
@@ -1492,7 +1492,9 @@ class ApplyInPandasWithStateSerializer(ArrowStreamPandasUDFSerializer):
self.result_state_pdf_arrow_type = to_arrow_type(
self.result_state_df_type,
prefers_large_types=prefers_large_var_types
)
- self.arrow_max_records_per_batch = arrow_max_records_per_batch
+ self.arrow_max_records_per_batch = (
+ arrow_max_records_per_batch if arrow_max_records_per_batch > 0 else 2**31 - 1
+ )
def load_stream(self, stream):
"""
@@ -1847,13 +1849,29 @@ class TransformWithStateInPandasSerializer(ArrowStreamPandasUDFSerializer):
int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
arrow_cast=True,
)
- self.arrow_max_records_per_batch = arrow_max_records_per_batch
+ self.arrow_max_records_per_batch = (
+ arrow_max_records_per_batch if arrow_max_records_per_batch > 0 else 2**31 - 1
+ )
self.arrow_max_bytes_per_batch = arrow_max_bytes_per_batch
self.key_offsets = None
self.average_arrow_row_size = 0
self.total_bytes = 0
self.total_rows = 0
+ def _update_batch_size_stats(self, batch):
+ """
+ Update batch size statistics for adaptive batching.
+ """
+ # Short circuit batch size calculation if the batch size is
+ # unlimited as computing batch size is computationally expensive.
+ if self.arrow_max_bytes_per_batch != 2**31 - 1 and batch.num_rows > 0:
+ batch_bytes = sum(
+ buf.size for col in batch.columns for buf in col.buffers() if buf is not None
+ )
+ self.total_bytes += batch_bytes
+ self.total_rows += batch.num_rows
+ self.average_arrow_row_size = self.total_bytes / self.total_rows
+
def load_stream(self, stream):
"""
Read ArrowRecordBatches from stream, deserialize them to populate a
list of data chunk, and
@@ -1881,18 +1899,7 @@ class TransformWithStateInPandasSerializer(ArrowStreamPandasUDFSerializer):
def row_stream():
for batch in batches:
- # Short circuit batch size calculation if the batch size is
- # unlimited as computing batch size is computationally expensive.
- if self.arrow_max_bytes_per_batch != 2**31 - 1 and batch.num_rows > 0:
- batch_bytes = sum(
- buf.size
- for col in batch.columns
- for buf in col.buffers()
- if buf is not None
- )
- self.total_bytes += batch_bytes
- self.total_rows += batch.num_rows
- self.average_arrow_row_size = self.total_bytes / self.total_rows
+ self._update_batch_size_stats(batch)
data_pandas = [
self.arrow_to_pandas(c, i)
for i, c in enumerate(pa.Table.from_batches([batch]).itercolumns())
@@ -1972,6 +1979,7 @@ class TransformWithStateInPandasInitStateSerializer(TransformWithStateInPandasSe
def load_stream(self, stream):
import pyarrow as pa
+ import pandas as pd
from pyspark.sql.streaming.stateful_processor_util import (
TransformWithStateInPandasFuncMode,
)
@@ -1990,6 +1998,12 @@ class TransformWithStateInPandasInitStateSerializer(TransformWithStateInPandasSe
def flatten_columns(cur_batch, col_name):
state_column = cur_batch.column(cur_batch.schema.get_field_index(col_name))
+
+ # Check if the entire column is null
+ if state_column.null_count == len(state_column):
+ # Return empty table with no columns
+ return pa.Table.from_arrays([], names=[])
+
state_field_names = [
state_column.type[i].name for i in range(state_column.type.num_fields)
]
@@ -2007,30 +2021,69 @@ class TransformWithStateInPandasInitStateSerializer(TransformWithStateInPandasSe
.add("inputData", dataSchema)
.add("initState", initStateSchema)
We'll parse batch into Tuples of (key, inputData, initState) and pass into the Python
- data generator. All rows in the same batch have the same grouping key.
+ data generator. Rows in the same batch may have different grouping keys,
+ but each batch will have either init_data or input_data, not mix.
"""
- for batch in batches:
- flatten_state_table = flatten_columns(batch, "inputData")
- data_pandas = [
- self.arrow_to_pandas(c, i)
- for i, c in enumerate(flatten_state_table.itercolumns())
- ]
- flatten_init_table = flatten_columns(batch, "initState")
- init_data_pandas = [
- self.arrow_to_pandas(c, i)
- for i, c in enumerate(flatten_init_table.itercolumns())
- ]
- key_series = [data_pandas[o] for o in self.key_offsets]
- init_key_series = [init_data_pandas[o] for o in self.init_key_offsets]
+ def row_stream():
+ for batch in batches:
+ self._update_batch_size_stats(batch)
- if any(s.empty for s in key_series):
- # If any row is empty, assign batch_key using init_key_series
- batch_key = tuple(s[0] for s in init_key_series)
- else:
- # If all rows are non-empty, create batch_key from key_series
- batch_key = tuple(s[0] for s in key_series)
- yield (batch_key, data_pandas, init_data_pandas)
+ flatten_state_table = flatten_columns(batch, "inputData")
+ data_pandas = [
+ self.arrow_to_pandas(c, i)
+ for i, c in enumerate(flatten_state_table.itercolumns())
+ ]
+
+ flatten_init_table = flatten_columns(batch, "initState")
+ init_data_pandas = [
+ self.arrow_to_pandas(c, i)
+ for i, c in enumerate(flatten_init_table.itercolumns())
+ ]
+
+ assert not (bool(init_data_pandas) and bool(data_pandas))
+
+ if bool(data_pandas):
+ for row in pd.concat(data_pandas, axis=1).itertuples(index=False):
+ batch_key = tuple(row[s] for s in self.key_offsets)
+ yield (batch_key, row, None)
+ elif bool(init_data_pandas):
+ for row in pd.concat(init_data_pandas, axis=1).itertuples(index=False):
+ batch_key = tuple(row[s] for s in self.init_key_offsets)
+ yield (batch_key, None, row)
+
+ EMPTY_DATAFRAME = pd.DataFrame()
+ for batch_key, group_rows in groupby(row_stream(), key=lambda x: x[0]):
+ rows = []
+ init_state_rows = []
+ for _, row, init_state_row in group_rows:
+ if row is not None:
+ rows.append(row)
+ if init_state_row is not None:
+ init_state_rows.append(init_state_row)
+
+ total_len = len(rows) + len(init_state_rows)
+ if (
+ total_len >= self.arrow_max_records_per_batch
+ or total_len * self.average_arrow_row_size >= self.arrow_max_bytes_per_batch
+ ):
+ yield (
+ batch_key,
+ pd.DataFrame(rows) if len(rows) > 0 else EMPTY_DATAFRAME.copy(),
+ pd.DataFrame(init_state_rows)
+ if len(init_state_rows) > 0
+ else EMPTY_DATAFRAME.copy(),
+ )
+ rows = []
+ init_state_rows = []
+ if rows or init_state_rows:
+ yield (
+ batch_key,
+ pd.DataFrame(rows) if len(rows) > 0 else EMPTY_DATAFRAME.copy(),
+ pd.DataFrame(init_state_rows)
+ if len(init_state_rows) > 0
+ else EMPTY_DATAFRAME.copy(),
+ )
_batches = super(ArrowStreamPandasSerializer, self).load_stream(stream)
data_batches = generate_data_batches(_batches)
@@ -2056,7 +2109,9 @@ class TransformWithStateInPySparkRowSerializer(ArrowStreamUDFSerializer):
def __init__(self, arrow_max_records_per_batch):
super(TransformWithStateInPySparkRowSerializer, self).__init__()
- self.arrow_max_records_per_batch = arrow_max_records_per_batch
+ self.arrow_max_records_per_batch = (
+ arrow_max_records_per_batch if arrow_max_records_per_batch > 0 else 2**31 - 1
+ )
self.key_offsets = None
def load_stream(self, stream):
@@ -2148,13 +2203,13 @@ class TransformWithStateInPySparkRowInitStateSerializer(TransformWithStateInPySp
self.init_key_offsets = None
def load_stream(self, stream):
- import itertools
import pyarrow as pa
from pyspark.sql.streaming.stateful_processor_util import (
TransformWithStateInPandasFuncMode,
)
+ from typing import Iterator, Any, Optional, Tuple
- def generate_data_batches(batches):
+ def generate_data_batches(batches) -> Iterator[Tuple[Any, Optional[Any], Optional[Any]]]:
"""
Deserialize ArrowRecordBatches and return a generator of Row.
The deserialization logic assumes that Arrow RecordBatches contain the data with the
@@ -2165,8 +2220,15 @@ class TransformWithStateInPySparkRowInitStateSerializer(TransformWithStateInPySp
into the data generator.
"""
- def extract_rows(cur_batch, col_name, key_offsets):
+ def extract_rows(
+ cur_batch, col_name, key_offsets
+ ) -> Optional[Iterator[Tuple[Any, Any]]]:
data_column = cur_batch.column(cur_batch.schema.get_field_index(col_name))
+
+ # Check if the entire column is null
+ if data_column.null_count == len(data_column):
+ return None
+
data_field_names = [
data_column.type[i].name for i in range(data_column.type.num_fields)
]
@@ -2179,18 +2241,17 @@ class TransformWithStateInPySparkRowInitStateSerializer(TransformWithStateInPySp
table = pa.Table.from_arrays(data_field_arrays, names=data_field_names)
if table.num_rows == 0:
- return (None, iter([]))
- else:
- batch_key = tuple(table.column(o)[0].as_py() for o in key_offsets)
+ return None
- rows = []
+ def row_iterator():
for row_idx in range(table.num_rows):
+ key = tuple(table.column(o)[row_idx].as_py() for o in key_offsets)
row = DataRow(
*(table.column(i)[row_idx].as_py() for i in range(table.num_columns))
)
- rows.append(row)
+ yield (key, row)
- return (batch_key, iter(rows))
+ return row_iterator()
"""
The arrow batch is written in the schema:
@@ -2198,49 +2259,45 @@ class TransformWithStateInPySparkRowInitStateSerializer(TransformWithStateInPySp
.add("inputData", dataSchema)
.add("initState", initStateSchema)
We'll parse batch into Tuples of (key, inputData, initState) and pass into the Python
- data generator. All rows in the same batch have the same grouping key.
+ data generator. Each batch will have either init_data or input_data, not mix.
"""
for batch in batches:
- (input_batch_key, input_data_iter) = extract_rows(
- batch, "inputData", self.key_offsets
- )
- (init_batch_key, init_state_iter) = extract_rows(
- batch, "initState", self.init_key_offsets
- )
+ # Detect which column has data - each batch contains only one type
+ input_result = extract_rows(batch, "inputData", self.key_offsets)
+ init_result = extract_rows(batch, "initState", self.init_key_offsets)
- if input_batch_key is None:
- batch_key = init_batch_key
- else:
- batch_key = input_batch_key
-
- for init_state_row in init_state_iter:
- yield (batch_key, None, init_state_row)
+ assert not (input_result is not None and init_result is not None)
- for input_data_row in input_data_iter:
- yield (batch_key, input_data_row, None)
+ if input_result is not None:
+ for key, input_data_row in input_result:
+ yield (key, input_data_row, None)
+ elif init_result is not None:
+ for key, init_state_row in init_result:
+ yield (key, None, init_state_row)
_batches = super(ArrowStreamUDFSerializer, self).load_stream(stream)
data_batches = generate_data_batches(_batches)
for k, g in groupby(data_batches, key=lambda x: x[0]):
- # g: list(batch_key, input_data_iter, init_state_iter)
-
- # they are sharing the iterator, hence need to copy
- input_values_iter, init_state_iter = itertools.tee(g, 2)
-
- chained_input_values = itertools.chain(map(lambda x: x[1], input_values_iter))
- chained_init_state_values = itertools.chain(map(lambda x: x[2], init_state_iter))
-
- chained_input_values_without_none = filter(
- lambda x: x is not None, chained_input_values
- )
- chained_init_state_values_without_none = filter(
- lambda x: x is not None, chained_init_state_values
- )
-
- ret_tuple = (chained_input_values_without_none, chained_init_state_values_without_none)
-
- yield (TransformWithStateInPandasFuncMode.PROCESS_DATA, k, ret_tuple)
+ input_rows = []
+ init_rows = []
+
+ for batch_key, input_row, init_row in g:
+ if input_row is not None:
+ input_rows.append(input_row)
+ if init_row is not None:
+ init_rows.append(init_row)
+
+ total_len = len(input_rows) + len(init_rows)
+ if total_len >= self.arrow_max_records_per_batch:
+ ret_tuple = (iter(input_rows), iter(init_rows))
+ yield (TransformWithStateInPandasFuncMode.PROCESS_DATA, k, ret_tuple)
+ input_rows = []
+ init_rows = []
+
+ if input_rows or init_rows:
+ ret_tuple = (iter(input_rows), iter(init_rows))
+ yield (TransformWithStateInPandasFuncMode.PROCESS_DATA, k, ret_tuple)
yield (TransformWithStateInPandasFuncMode.PROCESS_TIMER, None, None)
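The serializer changes above rely on each Arrow batch carrying either `inputData` or `initState`, with the other struct column entirely null, and the Python side detects the populated column via `null_count`. Below is a minimal pyarrow sketch of that detection, using made-up field names and sample data; only the `null_count` comparison mirrors the actual change in serializers.py.

```python
# Sketch of the "entire struct column is null" check (hypothetical fields/data).
import pyarrow as pa

input_data = pa.array(
    [{"id": "0", "value": 1}, {"id": "1", "value": 2}],
    type=pa.struct([("id", pa.string()), ("value", pa.int64())]),
)
init_state = pa.array(
    [None, None],  # this batch carries no initial state rows
    type=pa.struct([("id", pa.string()), ("initVal", pa.int64())]),
)
batch = pa.record_batch([input_data, init_state], names=["inputData", "initState"])


def is_populated(cur_batch, col_name):
    col = cur_batch.column(cur_batch.schema.get_field_index(col_name))
    # A struct column whose every row is null carries no payload for this batch.
    return col.null_count != len(col)


assert is_populated(batch, "inputData")
assert not is_populated(batch, "initState")
```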
diff --git a/python/pyspark/sql/tests/pandas/streaming/test_pandas_transform_with_state.py b/python/pyspark/sql/tests/pandas/streaming/test_pandas_transform_with_state.py
index 57125c7820c0..ecdfcfda3c1d 100644
--- a/python/pyspark/sql/tests/pandas/streaming/test_pandas_transform_with_state.py
+++ b/python/pyspark/sql/tests/pandas/streaming/test_pandas_transform_with_state.py
@@ -1483,6 +1483,96 @@ class TransformWithStateTestsMixin:
),
)
+ def test_transform_with_state_with_records_limit(self):
+ if not self.use_pandas():
+ return
+
+ def make_check_results(expected_per_batch):
+ def check_results(batch_df, batch_id):
+ batch_df.collect()
+ if batch_id == 0:
+ assert set(batch_df.sort("id").collect()) == expected_per_batch[0]
+ else:
+ assert set(batch_df.sort("id").collect()) == expected_per_batch[1]
+
+ return check_results
+
+ result_with_small_limit = [
+ {
+ Row(id="0", chunkCount=2),
+ Row(id="1", chunkCount=2),
+ },
+ {
+ Row(id="0", chunkCount=3),
+ Row(id="1", chunkCount=2),
+ },
+ ]
+
+ result_with_large_limit = [
+ {
+ Row(id="0", chunkCount=1),
+ Row(id="1", chunkCount=1),
+ },
+ {
+ Row(id="0", chunkCount=1),
+ Row(id="1", chunkCount=1),
+ },
+ ]
+
+ data = [("0", 789), ("3", 987)]
+ initial_state = self.spark.createDataFrame(data, "id string, initVal int").groupBy("id")
+
+ with self.sql_conf(
+ # Set it to a very small number so that every row would be a separate pandas df
+ {"spark.sql.execution.arrow.maxRecordsPerBatch": "1"}
+ ):
+ self._test_transform_with_state_basic(
+ ChunkCountProcessorFactory(),
+ make_check_results(result_with_small_limit),
+ output_schema=StructType(
+ [
+ StructField("id", StringType(), True),
+ StructField("chunkCount", IntegerType(), True),
+ ]
+ ),
+ )
+
+ self._test_transform_with_state_basic(
+ ChunkCountProcessorWithInitialStateFactory(),
+ make_check_results(result_with_small_limit),
+ initial_state=initial_state,
+ output_schema=StructType(
+ [
+ StructField("id", StringType(), True),
+ StructField("chunkCount", IntegerType(), True),
+ ]
+ ),
+ )
+
+ with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": "-1"}):
+ self._test_transform_with_state_basic(
+ ChunkCountProcessorFactory(),
+ make_check_results(result_with_large_limit),
+ output_schema=StructType(
+ [
+ StructField("id", StringType(), True),
+ StructField("chunkCount", IntegerType(), True),
+ ]
+ ),
+ )
+
+ self._test_transform_with_state_basic(
+ ChunkCountProcessorWithInitialStateFactory(),
+ make_check_results(result_with_large_limit),
+ initial_state=initial_state,
+ output_schema=StructType(
+ [
+ StructField("id", StringType(), True),
+ StructField("chunkCount", IntegerType(), True),
+ ]
+ ),
+ )
+
# test all state types (value, list, map) with large values (512 KB)
def test_transform_with_state_large_values(self):
def check_results(batch_df, batch_id):
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 50e71fb6da9d..109157e2c339 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -880,12 +880,14 @@ def wrap_grouped_transform_with_state_pandas_udf(f, return_type, runner_conf):
def wrap_grouped_transform_with_state_pandas_init_state_udf(f, return_type, runner_conf):
def wrapped(stateful_processor_api_client, mode, key, value_series_gen):
- import pandas as pd
-
+ # Split the generator into two using itertools.tee
state_values_gen, init_states_gen = itertools.tee(value_series_gen, 2)
- state_values = (df for x, _ in state_values_gen if not (df := pd.concat(x, axis=1)).empty)
- init_states = (df for _, x in init_states_gen if not (df := pd.concat(x, axis=1)).empty)
+ # Extract just the data DataFrames (first element of each tuple)
+ state_values = (data_df for data_df, _ in state_values_gen if not data_df.empty)
+
+ # Extract just the init DataFrames (second element of each tuple)
+ init_states = (init_df for _, init_df in init_states_gen if not init_df.empty)
result_iter = f(stateful_processor_api_client, mode, key, state_values, init_states)
# TODO(SPARK-49100): add verification that elements in result_iter are
@@ -3071,8 +3073,8 @@ def read_udfs(pickleSer, infile, eval_type):
def values_gen():
for x in a[2]:
- retVal = [x[1][o] for o in parsed_offsets[0][1]]
- initVal = [x[2][o] for o in parsed_offsets[1][1]]
+ retVal = x[1]
+ initVal = x[2]
yield retVal, initVal
# This must be generator comprehension - do not materialize.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/BaseStreamingArrowWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/BaseStreamingArrowWriter.scala
index f0371cafb72a..8c9ab2a8c636 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/BaseStreamingArrowWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/BaseStreamingArrowWriter.scala
@@ -88,10 +88,14 @@ class BaseStreamingArrowWriter(
protected def isBatchSizeLimitReached: Boolean = {
// If we have either reached the records or bytes limit
- totalNumRowsForBatch >= arrowMaxRecordsPerBatch ||
+ (arrowMaxRecordsPerBatch > 0 && totalNumRowsForBatch >= arrowMaxRecordsPerBatch) ||
// Short circuit batch size calculation if the batch size is unlimited as computing batch
// size is computationally expensive.
((arrowMaxBytesPerBatch != Int.MaxValue)
&& (arrowWriterForData.sizeInBytes() >= arrowMaxBytesPerBatch))
}
+
+ def getTotalNumRowsForBatch: Int = {
+ totalNumRowsForBatch
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkExec.scala
index a1cf71844950..1ceaf6c4bf81 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkExec.scala
@@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, PythonUDF}
import org.apache.spark.sql.catalyst.plans.logical.TransformWithStateInPySpark
import org.apache.spark.sql.catalyst.types.DataTypeUtils
-import org.apache.spark.sql.execution.{CoGroupedIterator, SparkPlan}
+import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.execution.python.ArrowPythonRunner
import org.apache.spark.sql.execution.python.PandasGroupUtils.{executePython, groupAndProject, resolveArgOffsets}
@@ -347,9 +347,9 @@ case class TransformWithStateInPySparkExec(
val initData =
groupAndProject(initStateIterator, initialStateGroupingAttrs,
initialState.output, initDedupAttributes)
- // group input rows and initial state rows by the same grouping key
- val groupedData: Iterator[(InternalRow, Iterator[InternalRow], Iterator[InternalRow])] =
- new CoGroupedIterator(data, initData, groupingAttributes)
+ // concatenate input rows and initial state rows iterators
+ val inputIter: Iterator[((InternalRow, Iterator[InternalRow]), Boolean)] =
+ initData.map { item => (item, true) } ++ data.map { item => (item, false) }
val evalType = {
if (userFacingDataType == TransformWithStateInPySpark.UserFacingDataType.PANDAS) {
@@ -374,7 +374,7 @@ case class TransformWithStateInPySparkExec(
batchTimestampMs,
eventTimeWatermarkForEviction
)
- executePython(groupedData, output, runner)
+ executePython(inputIter, output, runner)
}
CompletionIterator[InternalRow, Iterator[InternalRow]](outputIterator, {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkPythonRunner.scala
index 42d4ad68c29a..3eb7c7e64d64 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkPythonRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkPythonRunner.scala
@@ -144,6 +144,9 @@ class TransformWithStateInPySparkPythonInitialStateRunner(
private var pandasWriter: BaseStreamingArrowWriter = _
+ private var currentDataIterator: Iterator[InternalRow] = _
+ private var isCurrentIterFromInitState: Option[Boolean] = None
+
override protected def writeNextBatchToArrowStream(
root: VectorSchemaRoot,
writer: ArrowStreamWriter,
@@ -158,30 +161,53 @@ class TransformWithStateInPySparkPythonInitialStateRunner(
)
}
- if (inputIterator.hasNext) {
- val startData = dataOut.size()
- // a new grouping key with data & init state iter
- val next = inputIterator.next()
- val dataIter = next._2
- val initIter = next._3
-
- while (dataIter.hasNext || initIter.hasNext) {
- val dataRow =
- if (dataIter.hasNext) dataIter.next()
- else InternalRow.empty
- val initRow =
- if (initIter.hasNext) initIter.next()
- else InternalRow.empty
- pandasWriter.writeRow(InternalRow(dataRow, initRow))
+
+ // If we don't have data left for the current group, move to the next group.
+ if (currentDataIterator == null && inputIterator.hasNext) {
+ val ((_, data), isInitState) = inputIterator.next()
+ currentDataIterator = data
+ val isPrevIterFromInitState = isCurrentIterFromInitState
+ isCurrentIterFromInitState = Some(isInitState)
+ if (isPrevIterFromInitState.isDefined &&
+ isPrevIterFromInitState.get != isInitState &&
+ pandasWriter.getTotalNumRowsForBatch > 0) {
+ // So we won't have batches with mixed data and init state.
+ pandasWriter.finalizeCurrentArrowBatch()
+ return true
}
- pandasWriter.finalizeCurrentArrowBatch()
- val deltaData = dataOut.size() - startData
- pythonMetrics("pythonDataSent") += deltaData
+ }
+
+ val startData = dataOut.size()
+
+ val hasInput = if (currentDataIterator != null) {
+ var isCurrentBatchFull = false
+ val isCurrentIterFromInitStateVal = isCurrentIterFromInitState.get
+ // Stop writing when the current arrowBatch is finalized/full. If we have rows left
+ while (currentDataIterator.hasNext && !isCurrentBatchFull) {
+ val dataRow = currentDataIterator.next()
+ isCurrentBatchFull = if (isCurrentIterFromInitStateVal) {
+ pandasWriter.writeRow(InternalRow(null, dataRow))
+ } else {
+ pandasWriter.writeRow(InternalRow(dataRow, null))
+ }
+ }
+
+ if (!currentDataIterator.hasNext) {
+ currentDataIterator = null
+ }
+
true
} else {
+ if (pandasWriter.getTotalNumRowsForBatch > 0) {
+ pandasWriter.finalizeCurrentArrowBatch()
+ }
super[PythonArrowInput].close()
false
}
+
+ val deltaData = dataOut.size() - startData
+ pythonMetrics("pythonDataSent") += deltaData
+ hasInput
}
}
@@ -392,5 +418,7 @@ trait TransformWithStateInPySparkPythonRunnerUtils extends Logging {
object TransformWithStateInPySparkPythonRunner {
type InType = (InternalRow, Iterator[InternalRow])
- type GroupedInType = (InternalRow, Iterator[InternalRow], Iterator[InternalRow])
+
+ // ((key, rows), isInitState)
+ type GroupedInType = ((InternalRow, Iterator[InternalRow]), Boolean)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/BaseStreamingArrowWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/BaseStreamingArrowWriterSuite.scala
index fc10a102b4f5..49839fb8c985 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/BaseStreamingArrowWriterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/BaseStreamingArrowWriterSuite.scala
@@ -95,4 +95,43 @@ class BaseStreamingArrowWriterSuite extends SparkFunSuite with BeforeAndAfterEac
verify(writer, times(2)).writeBatch()
verify(arrowWriter, times(2)).reset()
}
+
+ test("test negative or zero arrowMaxRecordsPerBatch is unlimited") {
+ val root: VectorSchemaRoot = mock(classOf[VectorSchemaRoot])
+ val dataRow = mock(classOf[InternalRow])
+
+ // Test with negative value
+ transformWithStateInPySparkWriter = new BaseStreamingArrowWriter(
+ root, writer, -1, arrowMaxBytesPerBatch, arrowWriter)
+
+ // Write many rows (more than typical batch size)
+ for (_ <- 1 to 10) {
+ transformWithStateInPySparkWriter.writeRow(dataRow)
+ }
+
+ // Verify all rows were written but batch was not finalized
+ verify(arrowWriter, times(10)).write(dataRow)
+ verify(writer, never()).writeBatch()
+
+ // Only finalize when explicitly called
+ transformWithStateInPySparkWriter.finalizeCurrentArrowBatch()
+ verify(writer).writeBatch()
+
+ // Test with zero value
+ transformWithStateInPySparkWriter = new BaseStreamingArrowWriter(
+ root, writer, 0, arrowMaxBytesPerBatch, arrowWriter)
+
+ // Write many rows again
+ for (_ <- 1 to 10) {
+ transformWithStateInPySparkWriter.writeRow(dataRow)
+ }
+
+ // Verify rows were written but batch was not finalized
+ verify(arrowWriter, times(20)).write(dataRow)
+ verify(writer).writeBatch() // still 1 from before
+
+ // Only finalize when explicitly called
+ transformWithStateInPySparkWriter.finalizeCurrentArrowBatch()
+ verify(writer, times(2)).writeBatch()
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]