[
https://issues.apache.org/jira/browse/SPARK-54647?page=com.atlassian.jira.plugin.system.issuetabpanels:all-tabpanel
]
ASF GitHub Bot updated SPARK-54647:
-----------------------------------
Labels: pull-request-available (was: )
> Support User-Defined Aggregate Functions (UDAF) in PySpark
> ----------------------------------------------------------
>
> Key: SPARK-54647
> URL: https://issues.apache.org/jira/browse/SPARK-54647
> Project: Spark
> Issue Type: New Feature
> Components: PySpark
> Affects Versions: 4.2.0
> Reporter: Yicong Huang
> Priority: Major
> Labels: pull-request-available
>
> Currently, PySpark supports User-Defined Functions (UDF) and User-Defined
> Table Functions (UDTF), but lacks support for User-Defined Aggregate
> Functions (UDAF). Users must write custom aggregation logic in Scala/Java or
> fall back to less efficient workarounds, which limits their ability to
> express complex aggregation logic directly in Python.
> This change adds UDAF support in PySpark using a two-stage aggregation
> pattern with mapInArrow and applyInArrow. The basic idea is to implement
> aggregation (and partial aggregation) by:
> {code}
> df.selectExpr("*", "rand() as key").mapInArrow(func1).groupBy("key").applyInArrow(func2)
> {code}
> Here func1 calls Aggregator.reduce() to partially aggregate within each
> partition, and func2 calls Aggregator.merge() to combine the partial results
> and Aggregator.finish() to produce the final result.
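> As a rough sketch of that decomposition (assuming the iterator-based
> variants described below, a single ungrouped aggregate over a column named
> value, and the MySum aggregator sketched later in this description), func1
> and func2 might look like:
> {code:python}
> import random
> from typing import Iterator
>
> import pyarrow as pa
>
> agg = MySum()    # illustrative Aggregator instance, sketched below
> NUM_KEYS = 200   # e.g. spark.sql.shuffle.partitions
>
> def func1(batches: Iterator[pa.RecordBatch]) -> Iterator[pa.RecordBatch]:
>     # Partial aggregation: reduce every row of one partition into a buffer.
>     buf = agg.zero()
>     for batch in batches:
>         for value in batch.to_pydict()["value"]:
>             buf = agg.reduce(buf, value)
>     # Tag the partial with a random key to spread the merge work.
>     key = random.randrange(NUM_KEYS)
>     yield pa.RecordBatch.from_pydict({"key": [key], "partial": [buf]})
>
> def func2(batches: Iterator[pa.RecordBatch]) -> Iterator[pa.RecordBatch]:
>     # Merge the partial buffers that share a key, then finish.
>     # NB: with several keys a real implementation needs a final merge
>     # across the per-key results; this sketch omits that step.
>     buf = agg.zero()
>     for batch in batches:
>         for partial in batch.to_pydict()["partial"]:
>             buf = agg.merge(buf, partial)
>     yield pa.RecordBatch.from_pydict({"result": [agg.finish(buf)]})
> {code}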
> The implementation provides a Python Aggregator base class that users can
> subclass:
> {code:python}
> from typing import Any
>
> class Aggregator:
>     def zero(self) -> Any:
>         """Return the zero value for the aggregation buffer."""
>         raise NotImplementedError
>
>     def reduce(self, buffer: Any, value: Any) -> Any:
>         """Combine an input value into the buffer."""
>         raise NotImplementedError
>
>     def merge(self, buffer1: Any, buffer2: Any) -> Any:
>         """Merge two intermediate buffers."""
>         raise NotImplementedError
>
>     def finish(self, reduction: Any) -> Any:
>         """Produce the final result from the buffer."""
>         raise NotImplementedError
> {code}
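> For illustration, the MySum referenced below could be a minimal sum
> aggregator (a sketch; only the name MySum appears in this proposal):
> {code:python}
> class MySum(Aggregator):
>     """Illustrative sum over numeric input values."""
>
>     def zero(self) -> int:
>         return 0
>
>     def reduce(self, buffer: int, value: int) -> int:
>         return buffer + value
>
>     def merge(self, buffer1: int, buffer2: int) -> int:
>         return buffer1 + buffer2
>
>     def finish(self, reduction: int) -> int:
>         return reduction
> {code}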
> Users can create UDAF instances using the udaf() function and use them with
> DataFrame.agg():
> {code:python}
> sum_udaf = udaf(MySum(), "bigint")
> df.agg(sum_udaf(df.value))
> df.groupBy("group").agg(sum_udaf(df.value))
> {code}
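> Put together, a hypothetical end-to-end run (assuming the proposed udaf()
> function and the MySum sketch above) would be:
> {code:python}
> from pyspark.sql import SparkSession
>
> spark = SparkSession.builder.getOrCreate()
> df = spark.createDataFrame([("a", 1), ("a", 2), ("b", 3)], ["group", "value"])
>
> sum_udaf = udaf(MySum(), "bigint")
> df.groupBy("group").agg(sum_udaf(df.value)).show()
> # expected rows: ("a", 3) and ("b", 3)
> {code}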
> The implementation uses the iterator-based mapInArrow and applyInArrow APIs
> to reduce the memory footprint, serializes Aggregator instances with
> cloudpickle for distribution to worker nodes, and handles both grouped and
> non-grouped aggregations. The random key range is determined by
> spark.sql.shuffle.partitions or the DataFrame's partition count to ensure
> even data distribution.
--
This message was sent by Atlassian Jira
(v8.20.10#820010)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]