This is an automated email from the ASF dual-hosted git repository.
ueshin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 07e9beaca3c2 [SPARK-55529][PYTHON] Restore Arrow-level batch merge for non-iterator applyInPandas
07e9beaca3c2 is described below
commit 07e9beaca3c299c2d3c70666843de8fe063de673
Author: Yicong Huang <[email protected]>
AuthorDate: Tue Feb 17 13:19:02 2026 -0800
[SPARK-55529][PYTHON] Restore Arrow-level batch merge for non-iterator applyInPandas
### What changes were proposed in this pull request?
Optimize the non-iterator `applyInPandas` path by merging Arrow batches at
the Arrow level before converting to pandas, instead of converting each batch
individually and reassembling via per-column `pd.concat`.
Changes:
- **`GroupPandasUDFSerializer.load_stream`**: Yield raw `Iterator[pa.RecordBatch]` instead of converting to pandas per batch via `ArrowBatchTransformer.to_pandas`.
- **Non-iterator mapper**: Collect all Arrow batches → `pa.Table.from_batches().combine_chunks()` → convert to pandas once for the entire group (see the sketch after this list).
- **`wrap_grouped_map_pandas_udf`**: Simplified to accept a list of pandas Series directly.
- **Iterator mapper**: Split into its own `elif` branch; still converts batches lazily via `ArrowBatchTransformer.to_pandas` per batch.
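For illustration, here is a minimal standalone sketch (not the worker code itself; it omits `ArrowBatchTransformer`, timezone, and schema handling) contrasting the old per-batch conversion with the restored Arrow-level merge:

```python
import pandas as pd
import pyarrow as pa

# Two Arrow batches standing in for one group's data in this sketch.
batches = [
    pa.RecordBatch.from_pydict({"id": [1, 2], "v": [0.1, 0.2]}),
    pa.RecordBatch.from_pydict({"id": [3, 4], "v": [0.3, 0.4]}),
]

# Old path: convert each batch to pandas, then reassemble in pandas.
df_per_batch = pd.concat([b.to_pandas() for b in batches], ignore_index=True)

# New path: merge at the Arrow level, convert to pandas once.
df_merged = pa.Table.from_batches(batches).combine_chunks().to_pandas()

assert df_per_batch.equals(df_merged)
```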
### Why are the changes needed?
Follow-up to SPARK-55459. After SPARK-54316 consolidated the grouped-map
serializer, the non-iterator `applyInPandas` lost its efficient Arrow-level
batch merge and instead converts each batch to pandas individually, then
reassembles via per-column `pd.concat`. This PR restores the Arrow-level merge
so that all batches within a group are merged into a single `pa.Table` and
converted to pandas once, rather than N times (once per batch).
A pure-Python microbenchmark (335 groups × 100K rows × 5 columns, 7 runs each):

| Approach | avg | min | vs Master |
|---|---|---|---|
| Master (per-batch convert + per-column concat) | 0.544s | 0.543s | baseline |
| This PR (Arrow-level merge + single convert) | 0.489s | 0.487s | **10% faster** |
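The exact benchmark harness is not part of this commit; a rough, hypothetical sketch of a comparable measurement (scaled down to a single group) might look like this:

```python
import time

import numpy as np
import pandas as pd
import pyarrow as pa

# One group: 10 batches x 10K rows x 5 float columns (a scaled-down stand-in).
batches = [
    pa.RecordBatch.from_pydict({f"c{i}": np.random.rand(10_000) for i in range(5)})
    for _ in range(10)
]

def per_batch_convert():
    # Convert each batch, then reassemble column by column in pandas.
    frames = [b.to_pandas() for b in batches]
    return pd.DataFrame(
        {name: pd.concat([f[name] for f in frames], ignore_index=True) for name in frames[0]}
    )

def arrow_level_merge():
    # Merge at the Arrow level and convert once.
    return pa.Table.from_batches(batches).combine_chunks().to_pandas()

for fn in (per_batch_convert, arrow_level_merge):
    start = time.perf_counter()
    for _ in range(7):
        fn()
    print(f"{fn.__name__}: {(time.perf_counter() - start) / 7:.4f}s per run")
```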
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Existing `applyInPandas` tests (`test_pandas_grouped_map.py`,
`test_pandas_cogrouped_map.py`).
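Those suites exercise the standard grouped-map API; a generic example of the pattern they cover (not copied from the test files) is:

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(1, 1.0), (1, 2.0), (2, 3.0)], ("id", "v"))

def subtract_mean(pdf):
    # pdf is a pandas DataFrame containing all rows of one group
    return pdf.assign(v=pdf.v - pdf.v.mean())

df.groupBy("id").applyInPandas(subtract_mean, schema="id long, v double").show()
```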
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #54327 from Yicong-Huang/SPARK-55529/optimize-apply-in-pandas-arrow-merge.
Authored-by: Yicong Huang <[email protected]>
Signed-off-by: Takuya Ueshin <[email protected]>
---
python/pyspark/sql/pandas/serializers.py | 22 ++-------
python/pyspark/worker.py | 83 ++++++++++++++++----------------
2 files changed, 45 insertions(+), 60 deletions(-)
diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py
index c6f4676790d2..fd7237c3426f 100644
--- a/python/pyspark/sql/pandas/serializers.py
+++ b/python/pyspark/sql/pandas/serializers.py
@@ -1205,27 +1205,13 @@ class GroupPandasUDFSerializer(ArrowStreamPandasUDFSerializer):
     def load_stream(self, stream):
         """
-        Deserialize Grouped ArrowRecordBatches and yield as Iterator[Iterator[pd.Series]].
-        Each outer iterator element represents a group, containing an iterator of Series lists
-        (one list per batch).
+        Deserialize Grouped ArrowRecordBatches and yield raw Iterator[pa.RecordBatch].
+        Each outer iterator element represents a group.
         """
         for (batches,) in self._load_group_dataframes(stream, num_dfs=1):
-            # Lazily read and convert Arrow batches one at a time from the stream
-            # This avoids loading all batches into memory for the group
-            series_iter = map(
-                lambda batch: ArrowBatchTransformer.to_pandas(
-                    batch,
-                    timezone=self._timezone,
-                    schema=self._input_type,
-                    struct_in_pandas=self._struct_in_pandas,
-                    ndarray_as_list=self._ndarray_as_list,
-                    df_for_struct=self._df_for_struct,
-                ),
-                batches,
-            )
-            yield series_iter
+            yield batches
             # Make sure the batches are fully iterated before getting the next group
-            for _ in series_iter:
+            for _ in batches:
                 pass
     def dump_stream(self, iterator, stream):
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index b0a4dcedd256..a590b3ed47c4 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -815,34 +815,10 @@ def wrap_grouped_map_arrow_iter_udf(f, return_type, argspec, runner_conf):
 def wrap_grouped_map_pandas_udf(f, return_type, argspec, runner_conf):
-    def wrapped(key_series, value_batches):
+    def wrapped(key_series, value_series):
         import pandas as pd
-        # Convert value_batches (Iterator[list[pd.Series]]) to a single DataFrame
-        # Optimized: Collect all Series by column, then concat once per column
-        # This avoids the expensive pd.concat(axis=0) across many DataFrames
-        all_series_by_col = {}
-
-        for value_series in value_batches:
-            for col_idx, series in enumerate(value_series):
-                if col_idx not in all_series_by_col:
-                    all_series_by_col[col_idx] = []
-                all_series_by_col[col_idx].append(series)
-
-        # Concatenate each column separately (single concat per column)
-        if all_series_by_col:
-            columns = {}
-            for col_idx, series_list in all_series_by_col.items():
-                # Use the original series name if available
-                col_name = (
-                    series_list[0].name
-                    if hasattr(series_list[0], "name") and series_list[0].name
-                    else f"col{col_idx}"
-                )
-                columns[col_name] = pd.concat(series_list, ignore_index=True)
-            value_df = pd.DataFrame(columns)
-        else:
-            value_df = pd.DataFrame()
+        value_df = pd.concat(value_series, axis=1)
         if len(argspec.args) == 1:
             result = f(value_df)
@@ -2955,10 +2931,7 @@ def read_udfs(pickleSer, infile, eval_type, runner_conf, eval_conf):
             idx += offsets_len
         return parsed
-    if (
-        eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF
-        or eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF
-    ):
+    if eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
         import pyarrow as pa
         # We assume there is only one UDF here because grouped map doesn't
@@ -2970,21 +2943,47 @@ def read_udfs(pickleSer, infile, eval_type, runner_conf, eval_conf):
         arg_offsets, f = udfs[0]
         parsed_offsets = extract_key_value_indexes(arg_offsets)
-        def mapper(series_iter):
-            # Need to materialize the first series list to get the keys
-            first_series_list = next(series_iter)
+        key_offsets = parsed_offsets[0][0]
+        value_offsets = parsed_offsets[0][1]
-            # Extract key Series from the first batch
-            key_series = [first_series_list[o] for o in parsed_offsets[0][0]]
+        def mapper(batch_iter):
+            # Collect all Arrow batches and merge at Arrow level
+            all_batches = list(batch_iter)
+            if all_batches:
+                table = pa.Table.from_batches(all_batches).combine_chunks()
+            else:
+                table = pa.table({})
+            # Convert to pandas once for the entire group
+            all_series = ArrowBatchTransformer.to_pandas(table, timezone=ser._timezone)
+            key_series = [all_series[o] for o in key_offsets]
+            value_series = [all_series[o] for o in value_offsets]
+            yield from f(key_series, value_series)
-            # Create generator for value Series lists (one list per batch)
-            value_series_gen = (
-                [series_list[o] for o in parsed_offsets[0][1]]
-                for series_list in itertools.chain((first_series_list,), series_iter)
-            )
+    elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF:
+        import pyarrow as pa
-            # Flatten one level: yield from wrapper to return Iterator[(df, spark_type)]
-            yield from f(key_series, value_series_gen)
+        # We assume there is only one UDF here because grouped map doesn't
+        # support combining multiple UDFs.
+        assert num_udfs == 1
+
+        # See FlatMapGroupsInPandasExec for how arg_offsets are used to
+        # distinguish between grouping attributes and data attributes
+        arg_offsets, f = udfs[0]
+        parsed_offsets = extract_key_value_indexes(arg_offsets)
+
+        def mapper(batch_iter):
+            # Convert first Arrow batch to pandas to extract keys
+            first_series = ArrowBatchTransformer.to_pandas(next(batch_iter), timezone=ser._timezone)
+            key_series = [first_series[o] for o in parsed_offsets[0][0]]
+
+            # Lazily convert remaining Arrow batches to pandas Series
+            def value_series_gen():
+                yield [first_series[o] for o in parsed_offsets[0][1]]
+                for batch in batch_iter:
+                    series = ArrowBatchTransformer.to_pandas(batch, timezone=ser._timezone)
+                    yield [series[o] for o in parsed_offsets[0][1]]
+
+            yield from f(key_series, value_series_gen())
     elif eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_UDF:
         # We assume there is only one UDF here because grouped map doesn't
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]