Yicong-Huang opened a new pull request, #53400:
URL: https://github.com/apache/spark/pull/53400
### What changes were proposed in this pull request?
Add support for User-Defined Aggregate Functions (UDAF) in PySpark.
Currently PySpark supports User-Defined Functions (UDF) and User-Defined Table
Functions (UDTF), but lacks support for UDAF. Users need to write custom
aggregation logic in Scala/Java or use less efficient workarounds.
This change adds UDAF support using a two-stage aggregation pattern with
`mapInArrow` and `applyInArrow`. The basic idea is to implement aggregation
(and partial aggregation) by:
```python
df.selectExpr("rand() as key").mapInArrow(func1).groupBy("key").applyInArrow(func2)
```
Here `func1` calls `Aggregator.reduce()` to perform partial aggregation within
each partition, and `func2` calls `Aggregator.merge()` to combine the partial
results and `Aggregator.finish()` to produce the final result.
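As an illustration, here is a minimal sketch of what `func1` and `func2` could
look like for a sum aggregator. It assumes one partial row per partition under
a constant grouping key; the schemas, the `"value"` column name, and the
`MySum` aggregator (defined in the example further below) are illustrative,
not the PR's actual internals:
```python
import pyarrow as pa

agg = MySum()  # an Aggregator subclass, defined in the example below

def func1(batches):
    # mapInArrow: reduce each partition to a single partial result row.
    buffer = agg.zero()
    for batch in batches:
        for value in batch.to_pydict()["value"]:
            buffer = agg.reduce(buffer, value)
    yield pa.RecordBatch.from_pydict({"key": [0], "partial": [buffer]})

def func2(table):
    # applyInArrow: merge the per-partition partials, then finish.
    merged = agg.zero()
    for partial in table.column("partial").to_pylist():
        merged = agg.merge(merged, partial)
    return pa.table({"result": [agg.finish(merged)]})

result = (
    df.mapInArrow(func1, "key bigint, partial bigint")
      .groupBy("key")
      .applyInArrow(func2, "result bigint")
)
```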
Aligned with the Scala side, the implementation provides a Python `Aggregator`
base class that users can subclass:
```python
from typing import Any

class Aggregator:
    def zero(self) -> Any:
        """Return the zero value for the aggregation buffer."""
        raise NotImplementedError

    def reduce(self, buffer: Any, value: Any) -> Any:
        """Combine an input value into the buffer."""
        raise NotImplementedError

    def merge(self, buffer1: Any, buffer2: Any) -> Any:
        """Merge two intermediate buffers."""
        raise NotImplementedError

    def finish(self, reduction: Any) -> Any:
        """Produce the final result from the buffer."""
        raise NotImplementedError
```
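The buffer can be a structure distinct from the final output. For instance, an
average aggregator (one of the cases covered by the tests below) keeps a
`(sum, count)` pair as its buffer; a minimal sketch:
```python
class MyAvg(Aggregator):
    def zero(self):
        return (0.0, 0)  # (running sum, count)

    def reduce(self, buffer, value):
        s, n = buffer
        return (s + value, n + 1)

    def merge(self, buffer1, buffer2):
        return (buffer1[0] + buffer2[0], buffer1[1] + buffer2[1])

    def finish(self, reduction):
        s, n = reduction
        return s / n if n else None  # empty input yields null
```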
Users can create UDAF instances using the `udaf()` function and use them
with `DataFrame.agg()`:
```python
sum_udaf = udaf(MySum(), "bigint")
df.agg(sum_udaf(df.value))
df.groupBy("group").agg(sum_udaf(df.value))
```
Key changes:
- Added `pyspark.sql.udaf` module with `Aggregator` base class,
`UserDefinedAggregateFunction` wrapper, and `udaf()` factory function
- Integrated UDAF support into `GroupedData.agg()` by detecting UDAF columns
  via a `_udaf_func` attribute (sketched below)
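A minimal sketch of what that detection could look like (illustrative only;
the helper name is hypothetical and the actual implementation may differ):
```python
def _has_udaf(exprs):
    # A column produced by a UDAF call carries a `_udaf_func` attribute;
    # plain columns do not.
    return any(getattr(c, "_udaf_func", None) is not None for c in exprs)
```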
### Why are the changes needed?
Currently PySpark lacks support for User-Defined Aggregate Functions (UDAF),
which limits users' ability to express complex aggregation logic directly in
Python. Users must either write custom aggregation logic in Scala/Java or use
less efficient workarounds. This change adds UDAF support to complement
existing UDF and UDTF support in PySpark, aligning with the Scala/Java
`Aggregator` interface in `org.apache.spark.sql.expressions.Aggregator`.
### Does this PR introduce _any_ user-facing change?
Yes. This PR adds a new feature: User-Defined Aggregate Functions (UDAF)
support in PySpark. Users can now define custom aggregation logic by
subclassing the `Aggregator` class and using the `udaf()` function to create
UDAF instances that can be used with `DataFrame.agg()` and `GroupedData.agg()`.
Example:
```python
# Assuming the new module exposes both names (the PR adds them to
# `pyspark.sql.udaf`):
from pyspark.sql.udaf import Aggregator, udaf

class MySum(Aggregator):
    def zero(self):
        return 0

    def reduce(self, buffer, value):
        return buffer + value

    def merge(self, buffer1, buffer2):
        return buffer1 + buffer2

    def finish(self, reduction):
        return reduction

sum_udaf = udaf(MySum(), "bigint")
df.agg(sum_udaf(df.value))
```
### How was this patch tested?
Added comprehensive unit tests in `python/pyspark/sql/tests/test_udaf.py`
covering:
- Basic aggregation (sum, average, max)
- Grouped aggregation with `groupBy().agg()`
- Null value handling
- Empty DataFrame handling
- Large datasets (20000+ rows) distributed across partitions
- Error handling for invalid inputs
- Integration with `df.agg()` and `df.groupBy().agg()`
### Was this patch authored or co-authored using generative AI tooling?
No.