Yicong Huang created SPARK-54647:
------------------------------------

             Summary: Support User-Defined Aggregate Functions (UDAF) in PySpark
                 Key: SPARK-54647
                 URL: https://issues.apache.org/jira/browse/SPARK-54647
             Project: Spark
          Issue Type: New Feature
          Components: PySpark
    Affects Versions: 4.2.0
            Reporter: Yicong Huang


Currently, PySpark supports User-Defined Functions (UDF) and User-Defined Table 
Functions (UDTF), but lacks support for User-Defined Aggregate Functions 
(UDAF). Users must write custom aggregation logic in Scala/Java or rely on less 
efficient workarounds, which limits the ability to express complex aggregation 
logic directly in Python.

This change adds UDAF support to PySpark using a two-stage aggregation pattern 
built on mapInArrow and applyInArrow. The basic idea is to implement aggregation 
(including partial aggregation) as:

{code}
df.selectExpr("*", "rand() as key").mapInArrow(func1).groupBy("key").applyInArrow(func2)
{code}

Here, func1 calls Aggregator.reduce() to perform partial aggregation within each 
partition, and func2 calls Aggregator.merge() to combine the partial results, 
followed by Aggregator.finish() to produce the final result.
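
A rough sketch of what the two stage functions could look like, assuming a 
single numeric input column named "value" and an Aggregator instance agg (the 
helper, column, and function names below are hypothetical, and the random-key 
plumbing is omitted):

{code:python}
import pyarrow as pa


def make_stage_funcs(agg):
    """Hypothetical helper: build the two stage functions from an Aggregator."""

    def func1(batches):
        # Stage 1 (mapInArrow): fold each partition's batches into one
        # partial buffer and emit it as a one-row Arrow batch.
        buffer = agg.zero()
        for batch in batches:
            for value in batch.column("value").to_pylist():
                buffer = agg.reduce(buffer, value)
        yield pa.RecordBatch.from_pydict({"partial": [buffer]})

    def func2(table):
        # Stage 2 (applyInArrow): merge the partial buffers from all
        # partitions, then apply finish() for the final result.
        merged = agg.zero()
        for partial in table.column("partial").to_pylist():
            merged = agg.merge(merged, partial)
        return pa.table({"result": [agg.finish(merged)]})

    return func1, func2
{code}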

The implementation provides a Python Aggregator base class that users can 
subclass:

{code:python}
from typing import Any


class Aggregator:
    def zero(self) -> Any:
        """Return the zero value for the aggregation buffer."""
        raise NotImplementedError

    def reduce(self, buffer: Any, value: Any) -> Any:
        """Combine an input value into the buffer."""
        raise NotImplementedError

    def merge(self, buffer1: Any, buffer2: Any) -> Any:
        """Merge two intermediate buffers."""
        raise NotImplementedError

    def finish(self, reduction: Any) -> Any:
        """Produce the final result from the buffer."""
        raise NotImplementedError
{code}
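
For example, a simple sum aggregator (a hypothetical sketch; MySum is only used 
to illustrate the API below) could look like:

{code:python}
class MySum(Aggregator):
    """Hypothetical sum aggregator over a numeric column."""

    def zero(self) -> int:
        return 0

    def reduce(self, buffer: int, value: int) -> int:
        return buffer + value

    def merge(self, buffer1: int, buffer2: int) -> int:
        return buffer1 + buffer2

    def finish(self, reduction: int) -> int:
        return reduction
{code}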

Users can create UDAF instances with the udaf() function, specifying the result 
type, and apply them via DataFrame.agg() or after groupBy():

{code:python}
sum_udaf = udaf(MySum(), "bigint")
df.agg(sum_udaf(df.value))
df.groupBy("group").agg(sum_udaf(df.value))
{code}

The implementation uses the iterator-based mapInArrow and applyInArrow APIs to 
reduce the memory footprint, serializes Aggregator instances with cloudpickle 
for distribution to worker nodes, and handles both grouped and non-grouped 
aggregations. The random key range is derived from spark.sql.shuffle.partitions 
(or the DataFrame's partition count) to ensure proper data distribution.
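
As a rough illustration of that key-range logic (hypothetical; the actual 
implementation may differ), the key could be derived and assigned like this:

{code:python}
from pyspark.sql import functions as F

# Hypothetical sketch, assuming a SparkSession `spark` and an input DataFrame `df`:
# derive the key range from spark.sql.shuffle.partitions, falling back to the
# DataFrame's partition count, then tag each row with a random integer key.
shuffle_partitions = spark.conf.get("spark.sql.shuffle.partitions", None)
num_keys = int(shuffle_partitions) if shuffle_partitions else df.rdd.getNumPartitions()
keyed = df.withColumn("key", (F.rand() * num_keys).cast("int"))
{code}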




--
This message was sent by Atlassian Jira
(v8.20.10#820010)
