This is an automated email from the ASF dual-hosted git repository.
timsaucer pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-python.git
The following commit(s) were added to refs/heads/main by this push:
new 2f52688d Add decorator for udwf (#1061)
2f52688d is described below
commit 2f52688d76e84794343c17ffaf3002534ecfd716
Author: kosiew <[email protected]>
AuthorDate: Sat Mar 15 19:00:50 2025 +0800
Add decorator for udwf (#1061)
* feat: Introduce create_udwf method for User-Defined Window Functions
- Added `create_udwf` static method to `WindowUDF` class, allowing users to
create User-Defined Window Functions (UDWF) as both a function and a decorator.
- Updated type hinting for `_R` using `TypeAlias` for better clarity.
- Enhanced documentation with usage examples for both function and
decorator styles, improving usability and understanding.
* refactor: Simplify UDWF test suite and introduce SimpleWindowCount
evaluator
- Removed multiple exponential smoothing classes to streamline the code.
- Introduced SimpleWindowCount class for basic row counting functionality.
- Updated test cases to validate the new SimpleWindowCount evaluator.
- Refactored fixture and test functions for clarity and consistency.
- Enhanced error handling in UDWF creation tests.
* fix: Update type alias import to use typing_extensions for compatibility
* Add udwf tests for multiple input types and decorator syntax
* replace old def udwf
* refactor: Simplify df fixture by passing ctx as an argument
* refactor: Rename DataFrame fixtures and update test functions
- Renamed `df` fixture to `complex_window_df` for clarity.
- Renamed `simple_df` fixture to `count_window_df` to better reflect its
purpose.
- Updated test functions to use the new fixture names, enhancing
readability and maintainability.
* refactor: Update udwf calls in WindowUDF to use BiasedNumbers directly
- Changed udwf1 to use BiasedNumbers instead of bias_10.
- Added udwf2 to call udwf with bias_10.
- Introduced udwf3 to demonstrate a lambda function returning
BiasedNumbers(20).
* feat: Add overloads for udwf function to support multiple input types and
decorator syntax
* refactor: Simplify udwf method signature by removing redundant type hints
* refactor: Remove state_type from udwf method signature and update return
type handling
- Eliminated the state_type parameter from the udwf method to simplify the
function signature.
- Updated return type handling in the _function and _decorator methods to
use a generic type _R for better type flexibility.
- Enhanced the decorator to wrap the original function, allowing for
improved argument handling and expression return.
* refactor: Update volatility parameter type in udwf method signature to
support Volatility enum
* Fix ruff errors
* fix C901 for def udwf
* refactor: Update udwf method signature and simplify input handling
- Changed the type hint for the return type in the
_create_window_udf_decorator method to use pa.DataType directly instead of a
TypeVar.
- Simplified the handling of input types by removing redundant checks and
directly using the input types list.
- Removed unnecessary comments and cleaned up the code for better
readability.
- Updated the test for udwf to use parameterized tests for better coverage
and maintainability.
* refactor: Rename input_type to input_types in udwf method signature for
clarity
* refactor: Enhance typing in udf.py by introducing Protocol for
WindowEvaluator and improving import organization
* Revert "refactor: Enhance typing in udf.py by introducing Protocol for
WindowEvaluator and improving import organization"
This reverts commit 16dbe5f3fd88f42d0a304384b162009bd9e49a35.
---
python/datafusion/udf.py | 123 ++++++++++++++++++++++++++-------
python/tests/test_udwf.py | 170 ++++++++++++++++++++++++++++++++++++++++++++--
2 files changed, 264 insertions(+), 29 deletions(-)
diff --git a/python/datafusion/udf.py b/python/datafusion/udf.py
index 603b7063..e93a34ca 100644
--- a/python/datafusion/udf.py
+++ b/python/datafusion/udf.py
@@ -621,6 +621,16 @@ class WindowUDF:
args_raw = [arg.expr for arg in args]
return Expr(self._udwf.__call__(*args_raw))
+ @overload
+ @staticmethod
+ def udwf(
+ input_types: pa.DataType | list[pa.DataType],
+ return_type: pa.DataType,
+ volatility: Volatility | str,
+ name: Optional[str] = None,
+ ) -> Callable[..., WindowUDF]: ...
+
+ @overload
@staticmethod
def udwf(
func: Callable[[], WindowEvaluator],
@@ -628,24 +638,31 @@ class WindowUDF:
return_type: pa.DataType,
volatility: Volatility | str,
name: Optional[str] = None,
- ) -> WindowUDF:
- """Create a new User-Defined Window Function.
+ ) -> WindowUDF: ...
- If your :py:class:`WindowEvaluator` can be instantiated with no
arguments, you
- can simply pass it's type as ``func``. If you need to pass additional
arguments
- to it's constructor, you can define a lambda or a factory method.
During runtime
- the :py:class:`WindowEvaluator` will be constructed for every instance
in
- which this UDWF is used. The following examples are all valid.
+ @staticmethod
+ def udwf(*args: Any, **kwargs: Any): # noqa: D417
+ """Create a new User-Defined Window Function (UDWF).
- .. code-block:: python
+ This method can be used both as a **function** and as a **decorator**.
+
+ Usage:
+ - **As a function**: Call `udwf(func, input_types, return_type, volatility,
+ name)`.
+ - **As a decorator**: Use `@udwf(input_types, return_type, volatility,
+ name)`. When using `udwf` as a decorator, **do not pass `func`
+ explicitly**.
+ **Function example:**
+ ```
import pyarrow as pa
class BiasedNumbers(WindowEvaluator):
def __init__(self, start: int = 0) -> None:
self.start = start
- def evaluate_all(self, values: list[pa.Array], num_rows: int) -> pa.Array:
+ def evaluate_all(self, values: list[pa.Array],
+ num_rows: int) -> pa.Array:
return pa.array([self.start + i for i in range(num_rows)])
def bias_10() -> BiasedNumbers:
@@ -655,35 +672,93 @@ class WindowUDF:
udwf2 = udwf(bias_10, pa.int64(), pa.int64(), "immutable")
udwf3 = udwf(lambda: BiasedNumbers(20), pa.int64(), pa.int64(),
"immutable")
+ ```
+
+ **Decorator example:**
+ ```
+ @udwf(pa.int64(), pa.int64(), "immutable")
+ def biased_numbers() -> BiasedNumbers:
+ return BiasedNumbers(10)
+ ```
+
Args:
- func: A callable to create the window function.
- input_types: The data types of the arguments to ``func``.
+ func: **Only needed when calling as a function. Skip this argument when
+ using `udwf` as a decorator.**
+ input_types: The data types of the arguments.
return_type: The data type of the return value.
volatility: See :py:class:`Volatility` for allowed values.
- arguments: A list of arguments to pass in to the __init__ method
for accum.
name: A descriptive name for the function.
Returns:
- A user-defined window function.
- """ # noqa: W505, E501
+ A user-defined window function that can be used in window function calls.
+ """
+ if args and callable(args[0]):
+ # Case 1: Used as a function, require the first parameter to be callable
+ return WindowUDF._create_window_udf(*args, **kwargs)
+ # Case 2: Used as a decorator with parameters
+ return WindowUDF._create_window_udf_decorator(*args, **kwargs)
+
+ @staticmethod
+ def _create_window_udf(
+ func: Callable[[], WindowEvaluator],
+ input_types: pa.DataType | list[pa.DataType],
+ return_type: pa.DataType,
+ volatility: Volatility | str,
+ name: Optional[str] = None,
+ ) -> WindowUDF:
+ """Create a WindowUDF instance from function arguments."""
if not callable(func):
msg = "`func` must be callable."
raise TypeError(msg)
if not isinstance(func(), WindowEvaluator):
msg = "`func` must implement the abstract base class
WindowEvaluator"
raise TypeError(msg)
- if name is None:
- name = func().__class__.__qualname__.lower()
- if isinstance(input_types, pa.DataType):
- input_types = [input_types]
- return WindowUDF(
- name=name,
- func=func,
- input_types=input_types,
- return_type=return_type,
- volatility=volatility,
+
+ name = name or func.__qualname__.lower()
+ input_types = (
+ [input_types] if isinstance(input_types, pa.DataType) else
input_types
)
+ return WindowUDF(name, func, input_types, return_type, volatility)
+
+ @staticmethod
+ def _get_default_name(func: Callable) -> str:
+ """Get the default name for a function based on its attributes."""
+ if hasattr(func, "__qualname__"):
+ return func.__qualname__.lower()
+ return func.__class__.__name__.lower()
+
+ @staticmethod
+ def _normalize_input_types(
+ input_types: pa.DataType | list[pa.DataType],
+ ) -> list[pa.DataType]:
+ """Convert a single DataType to a list if needed."""
+ if isinstance(input_types, pa.DataType):
+ return [input_types]
+ return input_types
+
+ @staticmethod
+ def _create_window_udf_decorator(
+ input_types: pa.DataType | list[pa.DataType],
+ return_type: pa.DataType,
+ volatility: Volatility | str,
+ name: Optional[str] = None,
+ ) -> Callable[[Callable[[], WindowEvaluator]], Callable[..., Expr]]:
+ """Create a decorator for a WindowUDF."""
+
+ def decorator(func: Callable[[], WindowEvaluator]) -> Callable[...,
Expr]:
+ udwf_caller = WindowUDF._create_window_udf(
+ func, input_types, return_type, volatility, name
+ )
+
+ @functools.wraps(func)
+ def wrapper(*args: Any, **kwargs: Any) -> Expr:
+ return udwf_caller(*args, **kwargs)
+
+ return wrapper
+
+ return decorator
+
# Convenience exports so we can import instead of treating as
# variables at the package root
diff --git a/python/tests/test_udwf.py b/python/tests/test_udwf.py
index 3d6dcf9d..4190e7d6 100644
--- a/python/tests/test_udwf.py
+++ b/python/tests/test_udwf.py
@@ -162,14 +162,27 @@ class SmoothTwoColumn(WindowEvaluator):
return pa.array(results)
+class SimpleWindowCount(WindowEvaluator):
+ """A simple window evaluator that counts rows."""
+
+ def __init__(self, base: int = 0) -> None:
+ self.base = base
+
+ def evaluate_all(self, values: list[pa.Array], num_rows: int) -> pa.Array:
+ return pa.array([self.base + i for i in range(num_rows)])
+
+
class NotSubclassOfWindowEvaluator:
pass
@pytest.fixture
-def df():
- ctx = SessionContext()
+def ctx():
+ return SessionContext()
+
[email protected]
+def complex_window_df(ctx):
# create a RecordBatch and a new DataFrame from it
batch = pa.RecordBatch.from_arrays(
[
@@ -182,7 +195,17 @@ def df():
return ctx.create_dataframe([[batch]])
-def test_udwf_errors(df):
[email protected]
+def count_window_df(ctx):
+ # create a RecordBatch and a new DataFrame from it
+ batch = pa.RecordBatch.from_arrays(
+ [pa.array([1, 2, 3]), pa.array([4, 4, 6])],
+ names=["a", "b"],
+ )
+ return ctx.create_dataframe([[batch]], name="test_table")
+
+
+def test_udwf_errors(complex_window_df):
with pytest.raises(TypeError):
udwf(
NotSubclassOfWindowEvaluator,
@@ -192,6 +215,103 @@ def test_udwf_errors(df):
)
+def test_udwf_errors_with_message():
+ """Test error cases for UDWF creation."""
+ with pytest.raises(
+ TypeError, match="`func` must implement the abstract base class
WindowEvaluator"
+ ):
+ udwf(
+ NotSubclassOfWindowEvaluator, pa.int64(), pa.int64(),
volatility="immutable"
+ )
+
+
+def test_udwf_basic_usage(count_window_df):
+ """Test basic UDWF usage with a simple counting window function."""
+ simple_count = udwf(
+ SimpleWindowCount, pa.int64(), pa.int64(), volatility="immutable"
+ )
+
+ df = count_window_df.select(
+ simple_count(column("a"))
+ .window_frame(WindowFrame("rows", None, None))
+ .build()
+ .alias("count")
+ )
+ result = df.collect()[0]
+ assert result.column(0) == pa.array([0, 1, 2])
+
+
+def test_udwf_with_args(count_window_df):
+ """Test UDWF with constructor arguments."""
+ count_base10 = udwf(
+ lambda: SimpleWindowCount(10), pa.int64(), pa.int64(),
volatility="immutable"
+ )
+
+ df = count_window_df.select(
+ count_base10(column("a"))
+ .window_frame(WindowFrame("rows", None, None))
+ .build()
+ .alias("count")
+ )
+ result = df.collect()[0]
+ assert result.column(0) == pa.array([10, 11, 12])
+
+
+def test_udwf_decorator_basic(count_window_df):
+ """Test UDWF used as a decorator."""
+
+ @udwf([pa.int64()], pa.int64(), "immutable")
+ def window_count() -> WindowEvaluator:
+ return SimpleWindowCount()
+
+ df = count_window_df.select(
+ window_count(column("a"))
+ .window_frame(WindowFrame("rows", None, None))
+ .build()
+ .alias("count")
+ )
+ result = df.collect()[0]
+ assert result.column(0) == pa.array([0, 1, 2])
+
+
+def test_udwf_decorator_with_args(count_window_df):
+ """Test UDWF decorator with constructor arguments."""
+
+ @udwf([pa.int64()], pa.int64(), "immutable")
+ def window_count_base10() -> WindowEvaluator:
+ return SimpleWindowCount(10)
+
+ df = count_window_df.select(
+ window_count_base10(column("a"))
+ .window_frame(WindowFrame("rows", None, None))
+ .build()
+ .alias("count")
+ )
+ result = df.collect()[0]
+ assert result.column(0) == pa.array([10, 11, 12])
+
+
+def test_register_udwf(ctx, count_window_df):
+ """Test registering and using UDWF in SQL context."""
+ window_count = udwf(
+ SimpleWindowCount,
+ [pa.int64()],
+ pa.int64(),
+ volatility="immutable",
+ name="window_count",
+ )
+
+ ctx.register_udwf(window_count)
+ result = ctx.sql(
+ """
+ SELECT window_count(a)
+ OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED
+ FOLLOWING) FROM test_table
+ """
+ ).collect()[0]
+ assert result.column(0) == pa.array([0, 1, 2])
+
+
smooth_default = udwf(
ExponentialSmoothDefault,
pa.float64(),
@@ -299,10 +419,50 @@ data_test_udwf_functions = [
@pytest.mark.parametrize(("name", "expr", "expected"),
data_test_udwf_functions)
-def test_udwf_functions(df, name, expr, expected):
- df = df.select("a", "b", f.round(expr, lit(3)).alias(name))
+def test_udwf_functions(complex_window_df, name, expr, expected):
+ df = complex_window_df.select("a", "b", f.round(expr, lit(3)).alias(name))
# execute and collect the first (and only) batch
result = df.sort(column("a")).select(column(name)).collect()[0]
assert result.column(0) == pa.array(expected)
+
+
[email protected](
+ "udwf_func",
+ [
+ udwf(SimpleWindowCount, pa.int64(), pa.int64(), "immutable"),
+ udwf(SimpleWindowCount, [pa.int64()], pa.int64(), "immutable"),
+ udwf([pa.int64()], pa.int64(), "immutable")(lambda:
SimpleWindowCount()),
+ udwf(pa.int64(), pa.int64(), "immutable")(lambda: SimpleWindowCount()),
+ ],
+)
+def test_udwf_overloads(udwf_func, count_window_df):
+ df = count_window_df.select(
+ udwf_func(column("a"))
+ .window_frame(WindowFrame("rows", None, None))
+ .build()
+ .alias("count")
+ )
+ result = df.collect()[0]
+ assert result.column(0) == pa.array([0, 1, 2])
+
+
+def test_udwf_named_function(ctx, count_window_df):
+ """Test UDWF with explicit name parameter."""
+ window_count = udwf(
+ SimpleWindowCount,
+ pa.int64(),
+ pa.int64(),
+ volatility="immutable",
+ name="my_custom_counter",
+ )
+
+ ctx.register_udwf(window_count)
+ result = ctx.sql(
+ """
+ SELECT my_custom_counter(a)
+ OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED
+ FOLLOWING) FROM test_table"""
+ ).collect()[0]
+ assert result.column(0) == pa.array([0, 1, 2])
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]