[ https://issues.apache.org/jira/browse/SPARK-54647?page=com.atlassian.jira.plugin.system.issuetabpanels:all-tabpanel ]

ASF GitHub Bot updated SPARK-54647:
-----------------------------------
    Labels: pull-request-available  (was: )

> Support User-Defined Aggregate Functions (UDAF) in PySpark
> ----------------------------------------------------------
>
>                 Key: SPARK-54647
>                 URL: https://issues.apache.org/jira/browse/SPARK-54647
>             Project: Spark
>          Issue Type: New Feature
>          Components: PySpark
>    Affects Versions: 4.2.0
>            Reporter: Yicong Huang
>            Priority: Major
>              Labels: pull-request-available
>
> Currently, PySpark supports User-Defined Functions (UDFs) and User-Defined 
> Table Functions (UDTFs) but lacks support for User-Defined Aggregate 
> Functions (UDAFs). Users must write custom aggregation logic in Scala/Java 
> or resort to less efficient workarounds, which limits the ability to express 
> complex aggregation logic directly in Python.
>
> This change adds UDAF support to PySpark using a two-stage aggregation 
> pattern built on mapInArrow and applyInArrow. The basic idea is to implement 
> aggregation (and partial aggregation) as:
> {code}
> df.selectExpr("*", "rand() as key").mapInArrow(func1).groupBy("key").applyInArrow(func2)
> {code}
> Here func1 calls Aggregator.reduce() to perform partial aggregation within 
> each partition, and func2 calls Aggregator.merge() to combine the partial 
> results, followed by Aggregator.finish() to produce the final results.
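The two-stage flow can be illustrated without Spark. The sketch below uses a toy SumAgg aggregator and plain Python lists as stand-ins for Arrow batches; SumAgg, func1, and func2 are illustrative names modeled on the description above, not the actual implementation:

```python
class SumAgg:
    """Toy aggregator standing in for the proposed Aggregator API."""
    def zero(self): return 0
    def reduce(self, buf, value): return buf + value
    def merge(self, b1, b2): return b1 + b2
    def finish(self, buf): return buf

def func1(agg, partition):
    """Stage 1 (mapInArrow analogue): fold one partition's rows
    into a single partial buffer via Aggregator.reduce()."""
    buf = agg.zero()
    for value in partition:
        buf = agg.reduce(buf, value)
    return buf

def func2(agg, partial_buffers):
    """Stage 2 (applyInArrow analogue): combine partial buffers with
    Aggregator.merge(), then finish() to produce the final result."""
    merged = agg.zero()
    for buf in partial_buffers:
        merged = agg.merge(merged, buf)
    return agg.finish(merged)

agg = SumAgg()
partitions = [[1, 2, 3], [4, 5], [6]]           # rows split across partitions
partials = [func1(agg, p) for p in partitions]  # [6, 9, 6]
result = func2(agg, partials)                   # 21
```

The shuffle on the random key sits between the two stages, so each func2 invocation sees only the (much smaller) partial buffers rather than the raw rows.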
> The implementation provides a Python Aggregator base class that users can 
> subclass:
> {code:python}
> from typing import Any
>
> class Aggregator:
>     def zero(self) -> Any:
>         """Return zero value for aggregation buffer"""
>         raise NotImplementedError
>
>     def reduce(self, buffer: Any, value: Any) -> Any:
>         """Combine input value into buffer"""
>         raise NotImplementedError
>
>     def merge(self, buffer1: Any, buffer2: Any) -> Any:
>         """Merge two intermediate buffers"""
>         raise NotImplementedError
>
>     def finish(self, reduction: Any) -> Any:
>         """Produce final result from buffer"""
>         raise NotImplementedError
> {code}
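For example, a sum over long values might be written as the following subclass. This MySum is a hypothetical sketch of what such an aggregator could look like; the base class is reproduced minimally so the example is self-contained:

```python
from typing import Any

class Aggregator:
    # Minimal copy of the proposed base class so this example runs standalone.
    def zero(self) -> Any: raise NotImplementedError
    def reduce(self, buffer: Any, value: Any) -> Any: raise NotImplementedError
    def merge(self, buffer1: Any, buffer2: Any) -> Any: raise NotImplementedError
    def finish(self, reduction: Any) -> Any: raise NotImplementedError

class MySum(Aggregator):
    """Sums input values; the aggregation buffer is a plain int."""
    def zero(self) -> int:
        return 0
    def reduce(self, buffer: int, value: int) -> int:
        return buffer + value
    def merge(self, buffer1: int, buffer2: int) -> int:
        return buffer1 + buffer2
    def finish(self, reduction: int) -> int:
        return reduction

# Exercising the contract directly, outside Spark:
agg = MySum()
buf = agg.zero()
for v in [10, 20, 30]:
    buf = agg.reduce(buf, v)
# buf is now 60; agg.finish(buf) would also yield 60
```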
> Users can create UDAF instances using the udaf() function and use them with 
> DataFrame.agg():
> {code:python}
> sum_udaf = udaf(MySum(), "bigint")
> df.agg(sum_udaf(df.value))
> df.groupBy("group").agg(sum_udaf(df.value))
> {code}
> The implementation uses the iterator-based mapInArrow and applyInArrow APIs 
> to reduce the memory footprint, serializes Aggregator instances with 
> cloudpickle for distribution to worker nodes, and handles both grouped and 
> non-grouped aggregations. The range of the random key is determined by 
> spark.sql.shuffle.partitions or the DataFrame's partition count to ensure 
> even data distribution.
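The serialization step can be sketched as a simple round-trip. The snippet below uses the standard-library pickle as a stand-in (the change itself uses cloudpickle, which additionally handles lambdas, closures, and classes defined interactively); SumAgg is an illustrative toy aggregator:

```python
import pickle  # stand-in for cloudpickle, which the change actually uses

class SumAgg:
    """Toy aggregator; real UDAFs would subclass the proposed Aggregator."""
    def zero(self): return 0
    def reduce(self, buf, value): return buf + value
    def merge(self, b1, b2): return b1 + b2
    def finish(self, buf): return buf

# Driver side: serialize the aggregator instance for shipping to workers.
blob = pickle.dumps(SumAgg())

# Worker side: reconstruct the aggregator and apply it to incoming values.
agg = pickle.loads(blob)
partial = agg.reduce(agg.zero(), 5)  # 5
```

Because the instance (including any configuration stored on it) travels as bytes, each worker gets an identical, independent copy to run reduce() against its own partition.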



--
This message was sent by Atlassian Jira
(v8.20.10#820010)
