zhengruifeng commented on code in PR #53400:
URL: https://github.com/apache/spark/pull/53400#discussion_r2609010207


##########
python/pyspark/sql/udaf.py:
##########
@@ -0,0 +1,897 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+"""
+User-defined aggregate function related classes and functions
+"""
+from typing import Any, TYPE_CHECKING, Optional, List, Iterator, Tuple
+
+from pyspark.sql.column import Column
+from pyspark.sql.types import (
+    DataType,
+    _parse_datatype_string,
+)
+from pyspark.errors import PySparkTypeError, PySparkNotImplementedError
+
+if TYPE_CHECKING:
+    from pyspark.sql._typing import DataTypeOrString, ColumnOrName
+    from pyspark.sql.dataframe import DataFrame
+
__all__ = ["Aggregator", "UserDefinedAggregateFunction", "udaf"]
+
+
class Aggregator:
    """
    Interface for a Python user-defined aggregation.

    Subclass ``Aggregator`` and supply the four aggregation steps:

    * ``zero`` -- produce the initial, empty buffer;
    * ``reduce`` -- fold a single input value into a buffer;
    * ``merge`` -- combine two partially aggregated buffers;
    * ``finish`` -- turn the final buffer into the output value.

    Each of these four methods has to be declared with ``@staticmethod``; that is
    what allows the aggregator to be serialized and executed on Spark workers.

    .. versionadded:: 4.2.0

    Examples
    --------
    >>> class MySum(Aggregator):
    ...     @staticmethod
    ...     def zero():
    ...         return 0
    ...     @staticmethod
    ...     def reduce(buffer, value):
    ...         return buffer + value
    ...     @staticmethod
    ...     def merge(buffer1, buffer2):
    ...         return buffer1 + buffer2
    ...     @staticmethod
    ...     def finish(reduction):
    ...         return reduction
    """

    @staticmethod
    def zero() -> Any:
        """
        Build the identity buffer of the aggregation.

        Combining any buffer ``b`` with this value must give back ``b``
        (i.e. ``b + zero == b``). Subclasses must override this as a
        ``@staticmethod``.

        Returns
        -------
        Any
            The initial aggregation buffer.
        """
        raise NotImplementedError()

    @staticmethod
    def reduce(buffer: Any, value: Any) -> Any:
        """
        Fold one input value into an intermediate buffer.

        Implementations are allowed to mutate ``buffer`` in place and return it,
        which avoids allocating a fresh object per row. Subclasses must override
        this as a ``@staticmethod``.

        Parameters
        ----------
        buffer : Any
            The intermediate buffer accumulated so far.
        value : Any
            The next input value.

        Returns
        -------
        Any
            The buffer after ``value`` has been folded in.
        """
        raise NotImplementedError()

    @staticmethod
    def merge(buffer1: Any, buffer2: Any) -> Any:
        """
        Combine two intermediate buffers into one.

        Subclasses must override this as a ``@staticmethod``.

        Parameters
        ----------
        buffer1 : Any
            One partially aggregated buffer.
        buffer2 : Any
            Another partially aggregated buffer.

        Returns
        -------
        Any
            The combined buffer.
        """
        raise NotImplementedError()

    @staticmethod
    def finish(reduction: Any) -> Any:
        """
        Convert the fully reduced buffer into the final result.

        Subclasses must override this as a ``@staticmethod``.

        Parameters
        ----------
        reduction : Any
            The buffer produced after all reduce/merge steps.

        Returns
        -------
        Any
            The value emitted by the aggregation.
        """
        raise NotImplementedError()
+
+
def _validate_aggregator_methods(aggregator: Aggregator) -> None:
    """
    Validate that ``aggregator`` implements every required method as a static method.

    The required methods are ``zero``, ``reduce``, ``merge`` and ``finish``. Each
    one must be overridden somewhere below :class:`Aggregator` in the class
    hierarchy, because the base-class versions only raise ``NotImplementedError``.
    Each overriding definition must also be a ``staticmethod``.

    Parameters
    ----------
    aggregator : Aggregator
        The aggregator instance to validate.

    Raises
    ------
    PySparkTypeError
        If a required method is not overridden (reported with ``arg_type``
        ``"missing"``), or if the definition found is not a static method.
    """
    required_methods = ("zero", "reduce", "merge", "finish")
    aggregator_class = type(aggregator)

    for method_name in required_methods:
        # Find the class in the MRO that actually defines the attribute. This
        # means methods inherited from an intermediate user subclass are
        # validated too, not only those on the concrete class.
        owner = next(
            (klass for klass in aggregator_class.__mro__ if method_name in vars(klass)),
            None,
        )

        # ``hasattr`` would always succeed here because Aggregator itself defines
        # every required method. Those base definitions just raise
        # NotImplementedError, so inheriting one counts as not implementing it.
        if owner is None or owner is Aggregator:
            raise PySparkTypeError(
                errorClass="NOT_CALLABLE",
                messageParameters={
                    "arg_name": f"aggregator.{method_name}",
                    "arg_type": "missing",
                },
            )

        # Read the raw class-dict entry: attribute access on the class would
        # unwrap the staticmethod descriptor and hide how it was declared.
        if not isinstance(vars(owner)[method_name], staticmethod):
            raise PySparkTypeError(
                errorClass="NOT_CALLABLE",
                messageParameters={
                    "arg_name": f"aggregator.{method_name}",
                    "arg_type": "non-static method (must use @staticmethod decorator)",
                },
            )
+
+
+class UserDefinedAggregateFunction:
+    """
+    User-defined aggregate function wrapper for Python Aggregator.
+
+    This class wraps an Aggregator instance and provides the functionality to 
use it
+    as an aggregate function in Spark SQL. The implementation uses mapInArrow 
and
+    applyInArrow to perform partial aggregation and final aggregation.
+
+    .. versionadded:: 4.2.0
+    """
+
+    def __init__(
+        self,
+        aggregator: Aggregator,
+        returnType: "DataTypeOrString",
+        name: Optional[str] = None,
+    ):
+        if not isinstance(aggregator, Aggregator):
+            raise PySparkTypeError(
+                errorClass="NOT_CALLABLE",
+                messageParameters={
+                    "arg_name": "aggregator",
+                    "arg_type": type(aggregator).__name__,
+                },
+            )
+
+        if not isinstance(returnType, (DataType, str)):
+            raise PySparkTypeError(
+                errorClass="NOT_DATATYPE_OR_STR",
+                messageParameters={
+                    "arg_name": "returnType",
+                    "arg_type": type(returnType).__name__,
+                },
+            )
+
+        # Validate that all required methods are static methods
+        _validate_aggregator_methods(aggregator)
+
+        self.aggregator = aggregator
+        self._returnType = returnType
+        self._name = name or (
+            aggregator.__class__.__name__
+            if hasattr(aggregator, "__class__")
+            else "UserDefinedAggregateFunction"
+        )
+        # Serialize aggregator for use in Arrow functions
+        # Use cloudpickle to ensure proper serialization of classes
+        try:
+            import cloudpickle
+        except ImportError:
+            import pickle as cloudpickle
+        self._serialized_aggregator = cloudpickle.dumps(aggregator)
+
+    @property
+    def returnType(self) -> DataType:
+        """Get the return type of this UDAF."""
+        if isinstance(self._returnType, DataType):
+            return self._returnType
+        else:
+            return _parse_datatype_string(self._returnType)
+
+    def __call__(self, *args: "ColumnOrName") -> Column:
+        """
+        Apply this UDAF to the given columns.
+
+        This creates a Column expression that can be used in DataFrame 
operations.
+        The actual aggregation is performed using mapInArrow and applyInArrow.
+
+        Parameters
+        ----------
+        *args : ColumnOrName
+            The columns to aggregate. Currently supports a single column.
+
+        Returns
+        -------
+        Column
+            A Column representing the aggregation result.
+
+        Notes
+        -----
+        This implementation uses mapInArrow and applyInArrow internally to 
perform
+        the aggregation. The approach follows:
+        1. mapInArrow: Performs partial aggregation (reduce) on each partition

Review Comment:
   If we want to support partial aggregation with the existing Arrow UDFs, I think 
we should use a modified `FlatMapGroupsInArrowExec` with 
`requiredChildDistribution = UnspecifiedDistribution`.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to