This is an automated email from the ASF dual-hosted git repository.

timsaucer pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-python.git


The following commit(s) were added to refs/heads/main by this push:
     new 2f52688d Add decorator for udwf (#1061)
2f52688d is described below

commit 2f52688d76e84794343c17ffaf3002534ecfd716
Author: kosiew <[email protected]>
AuthorDate: Sat Mar 15 19:00:50 2025 +0800

    Add decorator for udwf (#1061)
    
    * feat: Introduce create_udwf method for User-Defined Window Functions
    
    - Added `create_udwf` static method to `WindowUDF` class, allowing users to 
create User-Defined Window Functions (UDWF) as both a function and a decorator.
    - Updated type hinting for `_R` using `TypeAlias` for better clarity.
    - Enhanced documentation with usage examples for both function and 
decorator styles, improving usability and understanding.
    
    * refactor: Simplify UDWF test suite and introduce SimpleWindowCount 
evaluator
    
    - Removed multiple exponential smoothing classes to streamline the code.
    - Introduced SimpleWindowCount class for basic row counting functionality.
    - Updated test cases to validate the new SimpleWindowCount evaluator.
    - Refactored fixture and test functions for clarity and consistency.
    - Enhanced error handling in UDWF creation tests.
    
    * fix: Update type alias import to use typing_extensions for compatibility
    
    * Add udwf tests for multiple input types and decorator syntax
    
    * replace old def udwf
    
    * refactor: Simplify df fixture by passing ctx as an argument
    
    * refactor: Rename DataFrame fixtures and update test functions
    
    - Renamed `df` fixture to `complex_window_df` for clarity.
    - Renamed `simple_df` fixture to `count_window_df` to better reflect its 
purpose.
    - Updated test functions to use the new fixture names, enhancing 
readability and maintainability.
    
    * refactor: Update udwf calls in WindowUDF to use BiasedNumbers directly
    
    - Changed udwf1 to use BiasedNumbers instead of bias_10.
    - Added udwf2 to call udwf with bias_10.
    - Introduced udwf3 to demonstrate a lambda function returning 
BiasedNumbers(20).
    
    * feat: Add overloads for udwf function to support multiple input types and 
decorator syntax
    
    * refactor: Simplify udwf method signature by removing redundant type hints
    
    * refactor: Remove state_type from udwf method signature and update return 
type handling
    
    - Eliminated the state_type parameter from the udwf method to simplify the 
function signature.
    - Updated return type handling in the _function and _decorator methods to 
use a generic type _R for better type flexibility.
    - Enhanced the decorator to wrap the original function, allowing for 
improved argument handling and expression return.
    
    * refactor: Update volatility parameter type in udwf method signature to 
support Volatility enum
    
    * Fix ruff errors
    
    * fix C901 for def udwf
    
    * refactor: Update udwf method signature and simplify input handling
    
    - Changed the type hint for the return type in the 
_create_window_udf_decorator method to use pa.DataType directly instead of a 
TypeVar.
    - Simplified the handling of input types by removing redundant checks and 
directly using the input types list.
    - Removed unnecessary comments and cleaned up the code for better 
readability.
    - Updated the test for udwf to use parameterized tests for better coverage 
and maintainability.
    
    * refactor: Rename input_type to input_types in udwf method signature for 
clarity
    
    * refactor: Enhance typing in udf.py by introducing Protocol for 
WindowEvaluator and improving import organization
    
    * Revert "refactor: Enhance typing in udf.py by introducing Protocol for 
WindowEvaluator and improving import organization"
    
    This reverts commit 16dbe5f3fd88f42d0a304384b162009bd9e49a35.
---
 python/datafusion/udf.py  | 123 ++++++++++++++++++++++++++-------
 python/tests/test_udwf.py | 170 ++++++++++++++++++++++++++++++++++++++++++++--
 2 files changed, 264 insertions(+), 29 deletions(-)

diff --git a/python/datafusion/udf.py b/python/datafusion/udf.py
index 603b7063..e93a34ca 100644
--- a/python/datafusion/udf.py
+++ b/python/datafusion/udf.py
@@ -621,6 +621,16 @@ class WindowUDF:
         args_raw = [arg.expr for arg in args]
         return Expr(self._udwf.__call__(*args_raw))
 
+    @overload
+    @staticmethod
+    def udwf(
+        input_types: pa.DataType | list[pa.DataType],
+        return_type: pa.DataType,
+        volatility: Volatility | str,
+        name: Optional[str] = None,
+    ) -> Callable[..., WindowUDF]: ...
+
+    @overload
     @staticmethod
     def udwf(
         func: Callable[[], WindowEvaluator],
@@ -628,24 +638,31 @@ class WindowUDF:
         return_type: pa.DataType,
         volatility: Volatility | str,
         name: Optional[str] = None,
-    ) -> WindowUDF:
-        """Create a new User-Defined Window Function.
+    ) -> WindowUDF: ...
 
-        If your :py:class:`WindowEvaluator` can be instantiated with no arguments, you
-        can simply pass it's type as ``func``. If you need to pass additional arguments
-        to it's constructor, you can define a lambda or a factory method. During runtime
-        the :py:class:`WindowEvaluator` will be constructed for every instance in
-        which this UDWF is used. The following examples are all valid.
+    @staticmethod
+    def udwf(*args: Any, **kwargs: Any):  # noqa: D417
+        """Create a new User-Defined Window Function (UDWF).
 
-        .. code-block:: python
+        This class can be used both as a **function** and as a **decorator**.
+
+        Usage:
+            - **As a function**: Call `udwf(func, input_types, return_type, volatility,
+              name)`.
+            - **As a decorator**: Use `@udwf(input_types, return_type, volatility,
+              name)`. When using `udwf` as a decorator, **do not pass `func`
+              explicitly**.
 
+        **Function example:**
+            ```
             import pyarrow as pa
 
             class BiasedNumbers(WindowEvaluator):
                 def __init__(self, start: int = 0) -> None:
                     self.start = start
 
-                def evaluate_all(self, values: list[pa.Array], num_rows: int) -> pa.Array:
+                def evaluate_all(self, values: list[pa.Array],
+                    num_rows: int) -> pa.Array:
                     return pa.array([self.start + i for i in range(num_rows)])
 
             def bias_10() -> BiasedNumbers:
@@ -655,35 +672,93 @@ class WindowUDF:
             udwf2 = udwf(bias_10, pa.int64(), pa.int64(), "immutable")
             udwf3 = udwf(lambda: BiasedNumbers(20), pa.int64(), pa.int64(), "immutable")
 
+            ```
+
+        **Decorator example:**
+            ```
+            @udwf(pa.int64(), pa.int64(), "immutable")
+            def biased_numbers() -> BiasedNumbers:
+                return BiasedNumbers(10)
+            ```
+
         Args:
-            func: A callable to create the window function.
-            input_types: The data types of the arguments to ``func``.
+            func: **Only needed when calling as a function. Skip this argument when
+                using `udwf` as a decorator.**
+            input_types: The data types of the arguments.
             return_type: The data type of the return value.
             volatility: See :py:class:`Volatility` for allowed values.
-            arguments: A list of arguments to pass in to the __init__ method for accum.
             name: A descriptive name for the function.
 
         Returns:
-            A user-defined window function.
-        """  # noqa: W505, E501
+            A user-defined window function that can be used in window function calls.
+        """
+        if args and callable(args[0]):
+            # Case 1: Used as a function, require the first parameter to be callable
+            return WindowUDF._create_window_udf(*args, **kwargs)
+        # Case 2: Used as a decorator with parameters
+        return WindowUDF._create_window_udf_decorator(*args, **kwargs)
+
+    @staticmethod
+    def _create_window_udf(
+        func: Callable[[], WindowEvaluator],
+        input_types: pa.DataType | list[pa.DataType],
+        return_type: pa.DataType,
+        volatility: Volatility | str,
+        name: Optional[str] = None,
+    ) -> WindowUDF:
+        """Create a WindowUDF instance from function arguments."""
         if not callable(func):
             msg = "`func` must be callable."
             raise TypeError(msg)
         if not isinstance(func(), WindowEvaluator):
             msg = "`func` must implement the abstract base class WindowEvaluator"
             raise TypeError(msg)
-        if name is None:
-            name = func().__class__.__qualname__.lower()
-        if isinstance(input_types, pa.DataType):
-            input_types = [input_types]
-        return WindowUDF(
-            name=name,
-            func=func,
-            input_types=input_types,
-            return_type=return_type,
-            volatility=volatility,
+
+        name = name or func.__qualname__.lower()
+        input_types = (
+            [input_types] if isinstance(input_types, pa.DataType) else input_types
         )
 
+        return WindowUDF(name, func, input_types, return_type, volatility)
+
+    @staticmethod
+    def _get_default_name(func: Callable) -> str:
+        """Get the default name for a function based on its attributes."""
+        if hasattr(func, "__qualname__"):
+            return func.__qualname__.lower()
+        return func.__class__.__name__.lower()
+
+    @staticmethod
+    def _normalize_input_types(
+        input_types: pa.DataType | list[pa.DataType],
+    ) -> list[pa.DataType]:
+        """Convert a single DataType to a list if needed."""
+        if isinstance(input_types, pa.DataType):
+            return [input_types]
+        return input_types
+
+    @staticmethod
+    def _create_window_udf_decorator(
+        input_types: pa.DataType | list[pa.DataType],
+        return_type: pa.DataType,
+        volatility: Volatility | str,
+        name: Optional[str] = None,
+    ) -> Callable[[Callable[[], WindowEvaluator]], Callable[..., Expr]]:
+        """Create a decorator for a WindowUDF."""
+
+        def decorator(func: Callable[[], WindowEvaluator]) -> Callable[..., Expr]:
+            udwf_caller = WindowUDF._create_window_udf(
+                func, input_types, return_type, volatility, name
+            )
+
+            @functools.wraps(func)
+            def wrapper(*args: Any, **kwargs: Any) -> Expr:
+                return udwf_caller(*args, **kwargs)
+
+            return wrapper
+
+        return decorator
+
 
 # Convenience exports so we can import instead of treating as
 # variables at the package root
diff --git a/python/tests/test_udwf.py b/python/tests/test_udwf.py
index 3d6dcf9d..4190e7d6 100644
--- a/python/tests/test_udwf.py
+++ b/python/tests/test_udwf.py
@@ -162,14 +162,27 @@ class SmoothTwoColumn(WindowEvaluator):
         return pa.array(results)
 
 
+class SimpleWindowCount(WindowEvaluator):
+    """A simple window evaluator that counts rows."""
+
+    def __init__(self, base: int = 0) -> None:
+        self.base = base
+
+    def evaluate_all(self, values: list[pa.Array], num_rows: int) -> pa.Array:
+        return pa.array([self.base + i for i in range(num_rows)])
+
+
 class NotSubclassOfWindowEvaluator:
     pass
 
 
 @pytest.fixture
-def df():
-    ctx = SessionContext()
+def ctx():
+    return SessionContext()
+
 
+@pytest.fixture
+def complex_window_df(ctx):
     # create a RecordBatch and a new DataFrame from it
     batch = pa.RecordBatch.from_arrays(
         [
@@ -182,7 +195,17 @@ def df():
     return ctx.create_dataframe([[batch]])
 
 
-def test_udwf_errors(df):
+@pytest.fixture
+def count_window_df(ctx):
+    # create a RecordBatch and a new DataFrame from it
+    batch = pa.RecordBatch.from_arrays(
+        [pa.array([1, 2, 3]), pa.array([4, 4, 6])],
+        names=["a", "b"],
+    )
+    return ctx.create_dataframe([[batch]], name="test_table")
+
+
+def test_udwf_errors(complex_window_df):
     with pytest.raises(TypeError):
         udwf(
             NotSubclassOfWindowEvaluator,
@@ -192,6 +215,103 @@ def test_udwf_errors(df):
         )
 
 
+def test_udwf_errors_with_message():
+    """Test error cases for UDWF creation."""
+    with pytest.raises(
+        TypeError, match="`func` must implement the abstract base class WindowEvaluator"
+    ):
+        udwf(
+            NotSubclassOfWindowEvaluator, pa.int64(), pa.int64(), volatility="immutable"
+        )
+
+
+def test_udwf_basic_usage(count_window_df):
+    """Test basic UDWF usage with a simple counting window function."""
+    simple_count = udwf(
+        SimpleWindowCount, pa.int64(), pa.int64(), volatility="immutable"
+    )
+
+    df = count_window_df.select(
+        simple_count(column("a"))
+        .window_frame(WindowFrame("rows", None, None))
+        .build()
+        .alias("count")
+    )
+    result = df.collect()[0]
+    assert result.column(0) == pa.array([0, 1, 2])
+
+
+def test_udwf_with_args(count_window_df):
+    """Test UDWF with constructor arguments."""
+    count_base10 = udwf(
+        lambda: SimpleWindowCount(10), pa.int64(), pa.int64(), volatility="immutable"
+    )
+
+    df = count_window_df.select(
+        count_base10(column("a"))
+        .window_frame(WindowFrame("rows", None, None))
+        .build()
+        .alias("count")
+    )
+    result = df.collect()[0]
+    assert result.column(0) == pa.array([10, 11, 12])
+
+
+def test_udwf_decorator_basic(count_window_df):
+    """Test UDWF used as a decorator."""
+
+    @udwf([pa.int64()], pa.int64(), "immutable")
+    def window_count() -> WindowEvaluator:
+        return SimpleWindowCount()
+
+    df = count_window_df.select(
+        window_count(column("a"))
+        .window_frame(WindowFrame("rows", None, None))
+        .build()
+        .alias("count")
+    )
+    result = df.collect()[0]
+    assert result.column(0) == pa.array([0, 1, 2])
+
+
+def test_udwf_decorator_with_args(count_window_df):
+    """Test UDWF decorator with constructor arguments."""
+
+    @udwf([pa.int64()], pa.int64(), "immutable")
+    def window_count_base10() -> WindowEvaluator:
+        return SimpleWindowCount(10)
+
+    df = count_window_df.select(
+        window_count_base10(column("a"))
+        .window_frame(WindowFrame("rows", None, None))
+        .build()
+        .alias("count")
+    )
+    result = df.collect()[0]
+    assert result.column(0) == pa.array([10, 11, 12])
+
+
+def test_register_udwf(ctx, count_window_df):
+    """Test registering and using UDWF in SQL context."""
+    window_count = udwf(
+        SimpleWindowCount,
+        [pa.int64()],
+        pa.int64(),
+        volatility="immutable",
+        name="window_count",
+    )
+
+    ctx.register_udwf(window_count)
+    result = ctx.sql(
+        """
+        SELECT window_count(a)
+        OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED
+        FOLLOWING) FROM test_table
+        """
+    ).collect()[0]
+    assert result.column(0) == pa.array([0, 1, 2])
+
+
 smooth_default = udwf(
     ExponentialSmoothDefault,
     pa.float64(),
@@ -299,10 +419,50 @@ data_test_udwf_functions = [
 
 
 @pytest.mark.parametrize(("name", "expr", "expected"), data_test_udwf_functions)
-def test_udwf_functions(df, name, expr, expected):
-    df = df.select("a", "b", f.round(expr, lit(3)).alias(name))
+def test_udwf_functions(complex_window_df, name, expr, expected):
+    df = complex_window_df.select("a", "b", f.round(expr, lit(3)).alias(name))
 
     # execute and collect the first (and only) batch
     result = df.sort(column("a")).select(column(name)).collect()[0]
 
     assert result.column(0) == pa.array(expected)
+
+
+@pytest.mark.parametrize(
+    "udwf_func",
+    [
+        udwf(SimpleWindowCount, pa.int64(), pa.int64(), "immutable"),
+        udwf(SimpleWindowCount, [pa.int64()], pa.int64(), "immutable"),
+        udwf([pa.int64()], pa.int64(), "immutable")(lambda: SimpleWindowCount()),
+        udwf(pa.int64(), pa.int64(), "immutable")(lambda: SimpleWindowCount()),
+    ],
+)
+def test_udwf_overloads(udwf_func, count_window_df):
+    df = count_window_df.select(
+        udwf_func(column("a"))
+        .window_frame(WindowFrame("rows", None, None))
+        .build()
+        .alias("count")
+    )
+    result = df.collect()[0]
+    assert result.column(0) == pa.array([0, 1, 2])
+
+
+def test_udwf_named_function(ctx, count_window_df):
+    """Test UDWF with explicit name parameter."""
+    window_count = udwf(
+        SimpleWindowCount,
+        pa.int64(),
+        pa.int64(),
+        volatility="immutable",
+        name="my_custom_counter",
+    )
+
+    ctx.register_udwf(window_count)
+    result = ctx.sql(
+        """
+        SELECT my_custom_counter(a)
+        OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED
+        FOLLOWING) FROM test_table"""
+    ).collect()[0]
+    assert result.column(0) == pa.array([0, 1, 2])


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to