This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 40993e671b87 [SPARK-53361][SS][1/2] Optimizing JVM–Python Communication in TWS by Grouping Multiple Keys into One Arrow Batch
40993e671b87 is described below

commit 40993e671b8758893a09d1471938f1581732ecf5
Author: zeruibao <zerui....@databricks.com>
AuthorDate: Mon Sep 15 17:03:29 2025 +0900

    [SPARK-53361][SS][1/2] Optimizing JVM–Python Communication in TWS by Grouping Multiple Keys into One Arrow Batch

    ### What changes were proposed in this pull request?
    This PR introduces an optimization to JVM–Python communication in TWS by allowing multiple keys to be grouped into a single Arrow batch. Currently, each Arrow batch is restricted to contain records for a single key. In high-cardinality scenarios, this results in many small Arrow batches (e.g., [(key1, value1), (key1, value2)], [(key2, value1), (key2, value2)]), which increases the overhead of Arrow batch transmission between the JVM and Python.

    With this change, records with different keys can be bin-packed into the same Arrow batch, reducing the number of batches transmitted. On the Python side, we leverage `groupby` to regroup records by key, mirroring the behavior of the Scala GroupedIterator implementation.

    This PR only handles `TransformWithStateInPySparkPythonRunner`. `TransformWithStateInPySparkPythonInitialStateRunner` only affects batch 0, so we leave it to another PR.

    This approach significantly reduces transmission overhead while preserving correct grouping semantics.

    ### Why are the changes needed?
    Benchmark results show that in high-cardinality scenarios, this optimization improves throughput by ~20% by reducing the overhead of Arrow batch transmission. For low-cardinality scenarios, the change introduces no observable regression.

    ### Does this PR introduce _any_ user-facing change?
    No

    ### How was this patch tested?
    Existing UTs and benchmarks.

    Performance improvement with PySpark without initial state (**one** shuffle partition), full cardinality (each key is distinct):
    - Without optimization: 4350 rows/second
    - With optimization: 13470 rows/second

    Performance improvement with Pandas without initial state (**six** shuffle partitions), 10,000,000 distinct keys:
    - Without optimization: 5000 records/second
    - With optimization: 6221 records/second

    ### Was this patch authored or co-authored using generative AI tooling?
    No

    Closes #52331 from zeruibao/zeruibao/SPARK-53361-only-improve-non-initial-state.

    Authored-by: zeruibao <zerui....@databricks.com>
    Signed-off-by: Jungtaek Lim <kabhwan.opensou...@gmail.com>
---
 python/pyspark/sql/pandas/serializers.py           | 37 +++++++++++++---------
 python/pyspark/worker.py                           |  7 ++--
 .../streaming/BaseStreamingArrowWriter.scala       |  6 ++--
 .../TransformWithStateInPySparkPythonRunner.scala  | 34 ++++++++++++++------
 4 files changed, 52 insertions(+), 32 deletions(-)

diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py
index 0965a2e7b546..d1bdfa9e8d01 100644
--- a/python/pyspark/sql/pandas/serializers.py
+++ b/python/pyspark/sql/pandas/serializers.py
@@ -1603,6 +1603,7 @@ class TransformWithStateInPandasSerializer(ArrowStreamPandasUDFSerializer):
         this function works in overall.
""" import pyarrow as pa + import pandas as pd from pyspark.sql.streaming.stateful_processor_util import ( TransformWithStateInPandasFuncMode, ) @@ -1617,14 +1618,20 @@ class TransformWithStateInPandasSerializer(ArrowStreamPandasUDFSerializer): This function must avoid materializing multiple Arrow RecordBatches into memory at the same time. And data chunks from the same grouping key should appear sequentially. """ - for batch in batches: - data_pandas = [ - self.arrow_to_pandas(c, i) - for i, c in enumerate(pa.Table.from_batches([batch]).itercolumns()) - ] - key_series = [data_pandas[o] for o in self.key_offsets] - batch_key = tuple(s[0] for s in key_series) - yield (batch_key, data_pandas) + + def row_stream(): + for batch in batches: + data_pandas = [ + self.arrow_to_pandas(c, i) + for i, c in enumerate(pa.Table.from_batches([batch]).itercolumns()) + ] + for row in pd.concat(data_pandas, axis=1).itertuples(index=False): + batch_key = tuple(row[s] for s in self.key_offsets) + yield (batch_key, row) + + for batch_key, group_rows in groupby(row_stream(), key=lambda x: x[0]): + df = pd.DataFrame([row for _, row in group_rows]) + yield (batch_key, df) _batches = super(ArrowStreamPandasSerializer, self).load_stream(stream) data_batches = generate_data_batches(_batches) @@ -1793,15 +1800,15 @@ class TransformWithStateInPySparkRowSerializer(ArrowStreamUDFSerializer): same time. And data chunks from the same grouping key should appear sequentially. """ for batch in batches: - DataRow = Row(*(batch.schema.names)) + DataRow = Row(*batch.schema.names) - # This is supposed to be the same. - batch_key = tuple(batch[o][0].as_py() for o in self.key_offsets) + # Iterate row by row without converting the whole batch + num_cols = batch.num_columns for row_idx in range(batch.num_rows): - row = DataRow( - *(batch.column(i)[row_idx].as_py() for i in range(batch.num_columns)) - ) - yield (batch_key, row) + # build the key for this row + row_key = tuple(batch[o][row_idx].as_py() for o in self.key_offsets) + row = DataRow(*(batch.column(i)[row_idx].as_py() for i in range(num_cols))) + yield row_key, row _batches = super(ArrowStreamUDFSerializer, self).load_stream(stream) data_batches = generate_data_batches(_batches) diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 292e4174818e..e73f16464aaa 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -694,10 +694,7 @@ def wrap_grouped_map_pandas_udf(f, return_type, argspec, runner_conf): def wrap_grouped_transform_with_state_pandas_udf(f, return_type, runner_conf): def wrapped(stateful_processor_api_client, mode, key, value_series_gen): - import pandas as pd - - values = (pd.concat(x, axis=1) for x in value_series_gen) - result_iter = f(stateful_processor_api_client, mode, key, values) + result_iter = f(stateful_processor_api_client, mode, key, value_series_gen) # TODO(SPARK-49100): add verification that elements in result_iter are # indeed of type pd.DataFrame and confirm to assigned cols @@ -2496,7 +2493,7 @@ def read_udfs(pickleSer, infile, eval_type): def values_gen(): for x in a[2]: - retVal = [x[1][o] for o in parsed_offsets[0][1]] + retVal = x[1].iloc[:, parsed_offsets[0][1]] yield retVal # This must be generator comprehension - do not materialize. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/BaseStreamingArrowWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/BaseStreamingArrowWriter.scala
index 73c70a618866..ba8b2c3ac7da 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/BaseStreamingArrowWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/BaseStreamingArrowWriter.scala
@@ -50,11 +50,12 @@ class BaseStreamingArrowWriter(
    *
    * @param dataRow The row to write for current batch.
    */
-  def writeRow(dataRow: InternalRow): Unit = {
+  def writeRow(dataRow: InternalRow): Boolean = {
     // If it exceeds the condition of batch (number of records) and there is more data for the
     // same group, finalize and construct a new batch.
-    if (totalNumRowsForBatch >= arrowMaxRecordsPerBatch) {
+    val isCurrentBatchFull = totalNumRowsForBatch >= arrowMaxRecordsPerBatch
+    if (isCurrentBatchFull) {
       finalizeCurrentChunk(isLastChunkForGroup = false)
       finalizeCurrentArrowBatch()
     }
@@ -63,6 +64,7 @@ class BaseStreamingArrowWriter(

     numRowsForCurrentChunk += 1
     totalNumRowsForBatch += 1
+    isCurrentBatchFull
   }

   /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkPythonRunner.scala
index 329bd4335265..f0df3e1f7d15 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkPythonRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkPythonRunner.scala
@@ -66,6 +66,9 @@ class TransformWithStateInPySparkPythonRunner(

   private var pandasWriter: BaseStreamingArrowWriter = _

+  private var currentDataIterator: Iterator[InternalRow] = _
+
+  // Grouping multiple keys into one arrow batch
   override protected def writeNextBatchToArrowStream(
       root: VectorSchemaRoot,
       writer: ArrowStreamWriter,
@@ -75,23 +78,34 @@ class TransformWithStateInPySparkPythonRunner(
       pandasWriter = new BaseStreamingArrowWriter(root, writer, arrowMaxRecordsPerBatch)
     }

-    if (inputIterator.hasNext) {
-      val startData = dataOut.size()
-      val next = inputIterator.next()
-      val dataIter = next._2
+    // If we don't have data left for the current group, move to the next group.
+    if (currentDataIterator == null && inputIterator.hasNext) {
+      val (_, dataIter) = inputIterator.next()
+      currentDataIterator = dataIter
+    }

-      while (dataIter.hasNext) {
-        val dataRow = dataIter.next()
-        pandasWriter.writeRow(dataRow)
+    val startData = dataOut.size()
+    val hasInput = if (currentDataIterator != null) {
+      var isCurrentBatchFull = false
+      // Stop writing when the current arrowBatch is finalized/full. If we have rows left
+      while (currentDataIterator.hasNext && !isCurrentBatchFull) {
+        val dataRow = currentDataIterator.next()
+        isCurrentBatchFull = pandasWriter.writeRow(dataRow)
       }
-      pandasWriter.finalizeCurrentArrowBatch()
-      val deltaData = dataOut.size() - startData
-      pythonMetrics("pythonDataSent") += deltaData
+
+      if (!currentDataIterator.hasNext) {
+        currentDataIterator = null
+      }
+
       true
     } else {
+      pandasWriter.finalizeCurrentArrowBatch()
       super[PythonArrowInput].close()
       false
     }
+    val deltaData = dataOut.size() - startData
+    pythonMetrics("pythonDataSent") += deltaData
+    hasInput
   }
 }

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org