Yicong Huang created SPARK-54647:
------------------------------------
Summary: Support User-Defined Aggregate Functions (UDAF) in PySpark
Key: SPARK-54647
URL: https://issues.apache.org/jira/browse/SPARK-54647
Project: Spark
Issue Type: New Feature
Components: PySpark
Affects Versions: 4.2.0
Reporter: Yicong Huang
Currently PySpark supports User-Defined Functions (UDF) and User-Defined Table
Functions (UDTF), but lacks support for User-Defined Aggregate Functions
(UDAF). Users need to write custom aggregation logic in Scala/Java or use less
efficient workarounds. This limits the ability to express complex aggregation
logic directly in Python.
This change adds UDAF support in PySpark using a two-stage aggregation pattern
with mapInArrow and applyInArrow. The basic idea is to implement aggregation
(and partial aggregation) by:
{code}
df.selectExpr("rand() as
key").mapInArrow(func1).groupBy(key).applyInArrow(func2)
{code}
Here, func1 calls Aggregator.reduce() to perform partial aggregation within each
partition, and func2 calls Aggregator.merge() to combine the partial results,
followed by Aggregator.finish() to produce the final output.
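As a rough sketch (not the actual implementation), the two worker-side functions
could be structured roughly as follows, assuming a single input column named
"value", the random "key" column added upstream, and an Aggregator instance
{{agg}} captured in the closure; iterating value-by-value is shown here only for
clarity:
{code:python}
import pickle
import pyarrow as pa

def func1(batches):
    """mapInArrow stage: partial aggregation (reduce) within one partition."""
    buffers = {}  # random key -> partial aggregation buffer
    for batch in batches:
        cols = batch.to_pydict()
        for key, value in zip(cols["key"], cols["value"]):
            buffers[key] = agg.reduce(buffers.get(key, agg.zero()), value)
    if buffers:
        # Emit one row per key: the key plus the pickled partial buffer.
        yield pa.RecordBatch.from_pydict({
            "key": list(buffers.keys()),
            "buffer": [pickle.dumps(b) for b in buffers.values()],
        })

def func2(table):
    """applyInArrow stage: merge the partial buffers for one key, then finish."""
    merged = agg.zero()
    for raw in table.column("buffer").to_pylist():
        merged = agg.merge(merged, pickle.loads(raw))
    return pa.table({"result": [agg.finish(merged)]})
{code}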
The implementation provides a Python Aggregator base class that users can
subclass:
{code:python}
class Aggregator:
    def zero(self) -> Any:
        """Return zero value for aggregation buffer"""
        raise NotImplementedError

    def reduce(self, buffer: Any, value: Any) -> Any:
        """Combine input value into buffer"""
        raise NotImplementedError

    def merge(self, buffer1: Any, buffer2: Any) -> Any:
        """Merge two intermediate buffers"""
        raise NotImplementedError

    def finish(self, reduction: Any) -> Any:
        """Produce final result from buffer"""
        raise NotImplementedError
{code}
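For example, the MySum aggregator referenced below could be a simple subclass
along these lines (illustrative only; the buffer here is a plain Python int):
{code:python}
class MySum(Aggregator):
    """Sums a numeric column; the buffer is a plain Python int."""

    def zero(self):
        return 0

    def reduce(self, buffer, value):
        # Ignore NULL input values.
        return buffer if value is None else buffer + value

    def merge(self, buffer1, buffer2):
        return buffer1 + buffer2

    def finish(self, reduction):
        return reduction
{code}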
Users can create UDAF instances using the udaf() function and use them with
DataFrame.agg():
{code:python}
sum_udaf = udaf(MySum(), "bigint")
df.agg(sum_udaf(df.value))
df.groupBy("group").agg(sum_udaf(df.value))
{code}
The implementation uses the iterator-based mapInArrow and applyInArrow APIs to
keep the memory footprint low, serializes Aggregator instances with cloudpickle
for distribution to worker nodes, and handles both grouped and non-grouped
aggregations. The random key range is determined by
spark.sql.shuffle.partitions or the DataFrame's partition count to ensure the
data is well distributed.
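As an illustration of that last point, the key range could be derived roughly
as follows (the fallback logic and column wiring shown here are only a sketch,
assuming an existing SparkSession {{spark}} and input DataFrame {{df}}):
{code:python}
# Illustrative only: pick the random key range from the shuffle-partition
# setting, falling back to the DataFrame's current partition count.
num_keys = int(spark.conf.get("spark.sql.shuffle.partitions", "200"))
if num_keys <= 0:
    num_keys = df.rdd.getNumPartitions()

# Tag every row with a key in [0, num_keys) before the mapInArrow stage.
keyed = df.selectExpr("*", f"cast(rand() * {num_keys} as int) as key")
{code}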
--
This message was sent by Atlassian Jira
(v8.20.10#820010)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]