Yicong-Huang opened a new pull request, #53035:
URL: https://github.com/apache/spark/pull/53035

   ### What changes were proposed in this pull request?
   
   This PR introduces an iterator API for Arrow grouped aggregation UDFs in 
PySpark. It adds support for two new UDF patterns:
   - `Iterator[pa.Array] -> Any` for single column aggregations
   - `Iterator[Tuple[pa.Array, ...]] -> Any` for multiple column aggregations
   
   The implementation adds a new Python eval type 
`SQL_GROUPED_AGG_ARROW_ITER_UDF` with corresponding support in type inference, 
worker serialization, and Scala execution planning.
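
   As a rough illustration of the type-inference piece (a sketch only, not the code added by this PR; `uses_iterator_input` is a hypothetical helper), distinguishing the new iterator form from the existing one-shot form only requires inspecting the origin of the first parameter's type hint:

   ```python
   import collections.abc
   from typing import Iterator, get_origin, get_type_hints

   import pyarrow as pa


   def uses_iterator_input(func) -> bool:
       """Return True if the UDF's first parameter is annotated as Iterator[...]."""
       hints = get_type_hints(func)
       hints.pop("return", None)
       if not hints:
           return False
       first_hint = next(iter(hints.values()))
       return get_origin(first_hint) is collections.abc.Iterator


   def mean_iter(it: Iterator[pa.Array]) -> float: ...
   def mean_plain(v: pa.Array) -> float: ...

   assert uses_iterator_input(mean_iter)       # would map to SQL_GROUPED_AGG_ARROW_ITER_UDF
   assert not uses_iterator_input(mean_plain)  # keeps the existing grouped-agg eval type
   ```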
   
   ### Why are the changes needed?
   
   The current Arrow grouped aggregation API requires loading all data for a 
group into memory at once, which can be problematic for groups with large 
amounts of data. The iterator API allows processing data in batches, providing:
   
   1. **Memory Efficiency**: Processes data incrementally rather than loading the entire group into memory
   2. **Consistency**: Aligns with existing iterator APIs (e.g., `SQL_SCALAR_ARROW_ITER_UDF`)
   3. **Flexibility**: Allows expensive state to be initialized once per group while batches are processed iteratively (see the sketch after this list)
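
   To make point 3 concrete, here is a minimal sketch (not code from this PR; `load_expensive_scale_factor` is a hypothetical stand-in for per-group setup such as loading a model or lookup table). The setup runs once per group, and each Arrow batch is folded in as it arrives:

   ```python
   import pyarrow as pa
   import pyarrow.compute as pc
   from typing import Iterator
   from pyspark.sql.functions import arrow_udf


   def load_expensive_scale_factor() -> float:
       # Hypothetical stand-in for expensive per-group initialization.
       return 1.0


   @arrow_udf("double")
   def arrow_scored_sum(it: Iterator[pa.Array]) -> float:
       # Initialization happens once per group, before the first batch is pulled.
       scale = load_expensive_scale_factor()
       total = 0.0
       for batch in it:
           # Each `batch` is a pa.Array holding one chunk of the group's column.
           total += pc.sum(batch).as_py() * scale
       return total
   ```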
   
   ### Does this PR introduce _any_ user-facing change?
   
   Yes. This PR adds new iterator-based API patterns for Arrow grouped aggregation UDFs:
   
   **Single column aggregation:**
   ```python
   import pyarrow as pa
   import pyarrow.compute as pc
   from typing import Iterator
   from pyspark.sql.functions import arrow_udf
   
   @arrow_udf("double")
   def arrow_mean(it: Iterator[pa.Array]) -> float:
       sum_val = 0.0
       cnt = 0
       for v in it:
           sum_val += pc.sum(v).as_py()
           cnt += len(v)
       return sum_val / cnt if cnt > 0 else 0.0
   
   df.groupby("id").agg(arrow_mean(df['v'])).show()
   ```
   
   **Multiple column aggregation:**
   ```python
   import pyarrow as pa
   import pyarrow.compute as pc
   import numpy as np
   from typing import Iterator, Tuple
   from pyspark.sql.functions import arrow_udf
   
   @arrow_udf("double")
   def arrow_weighted_mean(it: Iterator[Tuple[pa.Array, pa.Array]]) -> float:
       weighted_sum = 0.0
       weight = 0.0
       for v, w in it:
           weighted_sum += np.dot(v.to_numpy(), w.to_numpy())
           weight += pc.sum(w).as_py()
       return weighted_sum / weight if weight > 0 else 0.0
   
   df.groupby("id").agg(arrow_weighted_mean(df["v"], df["w"])).show()
   ```
   
   ### How was this patch tested?
   
   Added comprehensive unit tests in 
`python/pyspark/sql/tests/arrow/test_arrow_udf_grouped_agg.py`:
   
   1. `test_iterator_grouped_agg_single_column()` - Tests single column 
iterator aggregation with `Iterator[pa.Array]`
   2. `test_iterator_grouped_agg_multiple_columns()` - Tests multiple column 
iterator aggregation with `Iterator[Tuple[pa.Array, pa.Array]]`
   3. `test_iterator_grouped_agg_eval_type()` - Verifies correct eval type inference from type hints (sketched below)
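
   As a rough sketch of the eval-type check in test 3 (assumptions: the `PythonEvalType` import path and the `.evalType` attribute are taken from how existing pandas UDF tests verify eval types, not from this PR's diff):

   ```python
   import pyarrow as pa
   import pyarrow.compute as pc
   from typing import Iterator

   from pyspark.sql.functions import arrow_udf
   from pyspark.util import PythonEvalType  # assumed import path


   @arrow_udf("double")
   def arrow_sum(it: Iterator[pa.Array]) -> float:
       return sum(pc.sum(batch).as_py() for batch in it)


   # The decorator should infer the new eval type from the Iterator type hint.
   assert arrow_sum.evalType == PythonEvalType.SQL_GROUPED_AGG_ARROW_ITER_UDF
   ```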
   
   ### Was this patch authored or co-authored using generative AI tooling?
   
   Co-Generated-by: Cursor with Claude Sonnet 4.5
   

