Yicong-Huang opened a new pull request, #53035:
URL: https://github.com/apache/spark/pull/53035
### What changes were proposed in this pull request?
This PR introduces an iterator API for Arrow grouped aggregation UDFs in
PySpark. It adds support for two new UDF patterns:
- `Iterator[pa.Array] -> Any` for single column aggregations
- `Iterator[Tuple[pa.Array, ...]] -> Any` for multiple column aggregations
The implementation adds a new Python eval type
`SQL_GROUPED_AGG_ARROW_ITER_UDF` with corresponding support in type inference,
worker serialization, and Scala execution planning.
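For context, a minimal sketch contrasting the existing one-shot form with the new iterator form is shown below (function names and bodies are illustrative, not code from this PR; fuller examples follow in the user-facing section). The eval type is selected from the type hints.
```python
import pyarrow as pa
import pyarrow.compute as pc
from typing import Iterator
from pyspark.sql.functions import arrow_udf

# Existing one-shot form: the group's whole column is materialized as one pa.Array.
@arrow_udf("double")
def mean_once(v: pa.Array) -> float:
    return pc.mean(v).as_py()

# New iterator form from this PR: the group's column arrives as a stream of batches,
# and the Iterator[pa.Array] hint maps to SQL_GROUPED_AGG_ARROW_ITER_UDF.
@arrow_udf("double")
def mean_iter(it: Iterator[pa.Array]) -> float:
    total, count = 0.0, 0
    for batch in it:
        total += pc.sum(batch).as_py() or 0.0
        count += len(batch) - batch.null_count
    return total / count if count else 0.0
```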
### Why are the changes needed?
The current Arrow grouped aggregation API requires loading all of a group's data
into memory at once, which can exhaust memory for large groups. The iterator API
allows processing data in batches, providing:
1. **Memory Efficiency**: Processes data incrementally rather than loading the
entire group into memory
2. **Consistency**: Aligns with existing iterator APIs (e.g.,
`SQL_SCALAR_ARROW_ITER_UDF`)
3. **Flexibility**: Allows expensive state to be initialized once per group
while batches are processed iteratively (see the sketch after this list)
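As a simplified illustration of point 3, the sketch below initializes per-group aggregation state once and then folds batches into it; the function name `arrow_stddev_pop` and the population-variance logic are illustrative only and are not part of this PR:
```python
import math
import pyarrow as pa
import pyarrow.compute as pc
from typing import Iterator
from pyspark.sql.functions import arrow_udf

@arrow_udf("double")
def arrow_stddev_pop(it: Iterator[pa.Array]) -> float:
    # State is set up once per group, before the first batch is pulled; in a real
    # UDF this is where genuinely expensive initialization (a model, a large
    # lookup table, a native handle, ...) would live.
    count, total, total_sq = 0, 0.0, 0.0
    for batch in it:
        # Each yielded pa.Array holds one batch of the group's rows, so peak
        # memory stays proportional to the batch size rather than the group size.
        count += len(batch) - batch.null_count
        total += pc.sum(batch).as_py() or 0.0
        total_sq += pc.sum(pc.multiply(batch, batch)).as_py() or 0.0
    if count == 0:
        return 0.0
    mean = total / count
    return math.sqrt(max(total_sq / count - mean * mean, 0.0))
```
It would be invoked like any other grouped aggregation UDF, e.g. `df.groupby("id").agg(arrow_stddev_pop(df["v"]))`.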
### Does this PR introduce _any_ user-facing change?
Yes. This PR adds a new API pattern for Arrow grouped aggregation UDFs:
**Single column aggregation:**
```python
import pyarrow as pa
from typing import Iterator
from pyspark.sql.functions import arrow_udf
@arrow_udf("double")
def arrow_mean(it: Iterator[pa.Array]) -> float:
    sum_val = 0.0
    cnt = 0
    for v in it:
        sum_val += pa.compute.sum(v).as_py()
        cnt += len(v)
    return sum_val / cnt if cnt > 0 else 0.0
df.groupby("id").agg(arrow_mean(df['v'])).show()
```
**Multiple column aggregation:**
```python
import pyarrow as pa
import numpy as np
from typing import Iterator, Tuple
from pyspark.sql.functions import arrow_udf
@arrow_udf("double")
def arrow_weighted_mean(it: Iterator[Tuple[pa.Array, pa.Array]]) -> float:
    weighted_sum = 0.0
    weight = 0.0
    for v, w in it:
        weighted_sum += np.dot(v.to_numpy(), w.to_numpy())
        weight += pa.compute.sum(w).as_py()
    return weighted_sum / weight if weight > 0 else 0.0
df.groupby("id").agg(arrow_weighted_mean(df["v"], df["w"])).show()
```
### How was this patch tested?
Added comprehensive unit tests in
`python/pyspark/sql/tests/arrow/test_arrow_udf_grouped_agg.py`:
1. `test_iterator_grouped_agg_single_column()` - Tests single column
iterator aggregation with `Iterator[pa.Array]`
2. `test_iterator_grouped_agg_multiple_columns()` - Tests multiple column
iterator aggregation with `Iterator[Tuple[pa.Array, pa.Array]]`
3. `test_iterator_grouped_agg_eval_type()` - Verifies correct eval type
inference from type hints (a rough sketch of this kind of check follows below)
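For reference, an eval-type check in the spirit of test 3 might look like the sketch below; the `evalType` attribute and the `PythonEvalType` import location mirror how existing pandas/Arrow UDF tests do this and are assumptions, not an excerpt of the new test:
```python
import pyarrow as pa
from typing import Iterator
from pyspark.sql.functions import arrow_udf
from pyspark.util import PythonEvalType  # import location assumed

@arrow_udf("double")
def arrow_count(it: Iterator[pa.Array]) -> float:
    return float(sum(len(batch) for batch in it))

# The Iterator[pa.Array] hint should be inferred as the new eval type added by this PR.
assert arrow_count.evalType == PythonEvalType.SQL_GROUPED_AGG_ARROW_ITER_UDF
```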
### Was this patch authored or co-authored using generative AI tooling?
Co-Generated-by: Cursor with Claude Sonnet 4.5