This is an automated email from the ASF dual-hosted git repository.

timsaucer pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-python.git


The following commit(s) were added to refs/heads/main by this push:
     new 58925853 Allow users to pass a single expression instead of a list of expressions for partition_by and order_by (#1187)
58925853 is described below

commit 589258533b796dd24a8ca56bbd0a0e7ed3ea05cc
Author: Tim Saucer <timsau...@gmail.com>
AuthorDate: Thu Aug 21 13:40:11 2025 -0400

    Allow users to pass a single expression instead of a list of expressions for partition_by and order_by (#1187)
---
 python/datafusion/expr.py        | 12 +++--
 python/datafusion/functions.py   | 98 +++++++++++++++++-----------------------
 python/tests/test_aggregation.py | 29 ++++++++++++
 python/tests/test_dataframe.py   | 68 ++++++++++++++++++++++++++++
 4 files changed, 146 insertions(+), 61 deletions(-)

diff --git a/python/datafusion/expr.py b/python/datafusion/expr.py
index e785cab0..c0b49571 100644
--- a/python/datafusion/expr.py
+++ b/python/datafusion/expr.py
@@ -216,9 +216,11 @@ __all__ = [
 
 
 def expr_list_to_raw_expr_list(
-    expr_list: Optional[list[Expr]],
+    expr_list: Optional[list[Expr] | Expr],
 ) -> Optional[list[expr_internal.Expr]]:
     """Helper function to convert an optional list to raw expressions."""
+    if isinstance(expr_list, Expr):
+        expr_list = [expr_list]
     return [e.expr for e in expr_list] if expr_list is not None else None
 
 
@@ -230,9 +232,11 @@ def sort_or_default(e: Expr | SortExpr) -> expr_internal.SortExpr:
 
 
 def sort_list_to_raw_sort_list(
-    sort_list: Optional[list[Expr | SortExpr]],
+    sort_list: Optional[list[Expr | SortExpr] | Expr | SortExpr],
 ) -> Optional[list[expr_internal.SortExpr]]:
     """Helper function to return an optional sort list to raw variant."""
+    if isinstance(sort_list, (Expr, SortExpr)):
+        sort_list = [sort_list]
     return [sort_or_default(e) for e in sort_list] if sort_list is not None else None
 
 
@@ -1140,9 +1144,9 @@ class Window:
 
     def __init__(
         self,
-        partition_by: Optional[list[Expr]] = None,
+        partition_by: Optional[list[Expr] | Expr] = None,
         window_frame: Optional[WindowFrame] = None,
-        order_by: Optional[list[SortExpr | Expr]] = None,
+        order_by: Optional[list[SortExpr | Expr] | Expr | SortExpr] = None,
         null_treatment: Optional[NullTreatment] = None,
     ) -> None:
         """Construct a window definition.
diff --git a/python/datafusion/functions.py b/python/datafusion/functions.py
index f430cdf4..34068805 100644
--- a/python/datafusion/functions.py
+++ b/python/datafusion/functions.py
@@ -428,8 +428,8 @@ def when(when: Expr, then: Expr) -> CaseBuilder:
 def window(
     name: str,
     args: list[Expr],
-    partition_by: list[Expr] | None = None,
-    order_by: list[Expr | SortExpr] | None = None,
+    partition_by: list[Expr] | Expr | None = None,
+    order_by: list[Expr | SortExpr] | Expr | SortExpr | None = None,
     window_frame: WindowFrame | None = None,
     ctx: SessionContext | None = None,
 ) -> Expr:
@@ -442,11 +442,11 @@ def window(
         df.select(functions.lag(col("a")).partition_by(col("b")).build())
     """
     args = [a.expr for a in args]
-    partition_by = expr_list_to_raw_expr_list(partition_by)
+    partition_by_raw = expr_list_to_raw_expr_list(partition_by)
     order_by_raw = sort_list_to_raw_sort_list(order_by)
     window_frame = window_frame.window_frame if window_frame is not None else None
     ctx = ctx.ctx if ctx is not None else None
-    return Expr(f.window(name, args, partition_by, order_by_raw, window_frame, ctx))
+    return Expr(f.window(name, args, partition_by_raw, order_by_raw, window_frame, ctx))
 
 
 # scalar functions
@@ -1723,7 +1723,7 @@ def array_agg(
     expression: Expr,
     distinct: bool = False,
     filter: Optional[Expr] = None,
-    order_by: Optional[list[Expr | SortExpr]] = None,
+    order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None,
 ) -> Expr:
     """Aggregate values into an array.
 
@@ -2222,7 +2222,7 @@ def regr_syy(
 def first_value(
     expression: Expr,
     filter: Optional[Expr] = None,
-    order_by: Optional[list[Expr | SortExpr]] = None,
+    order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None,
     null_treatment: NullTreatment = NullTreatment.RESPECT_NULLS,
 ) -> Expr:
     """Returns the first value in a group of values.
@@ -2254,7 +2254,7 @@ def first_value(
 def last_value(
     expression: Expr,
     filter: Optional[Expr] = None,
-    order_by: Optional[list[Expr | SortExpr]] = None,
+    order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None,
     null_treatment: NullTreatment = NullTreatment.RESPECT_NULLS,
 ) -> Expr:
     """Returns the last value in a group of values.
@@ -2287,7 +2287,7 @@ def nth_value(
     expression: Expr,
     n: int,
     filter: Optional[Expr] = None,
-    order_by: Optional[list[Expr | SortExpr]] = None,
+    order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None,
     null_treatment: NullTreatment = NullTreatment.RESPECT_NULLS,
 ) -> Expr:
     """Returns the n-th value in a group of values.
@@ -2407,8 +2407,8 @@ def lead(
     arg: Expr,
     shift_offset: int = 1,
     default_value: Optional[Any] = None,
-    partition_by: Optional[list[Expr]] = None,
-    order_by: Optional[list[Expr | SortExpr]] = None,
+    partition_by: Optional[list[Expr] | Expr] = None,
+    order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None,
 ) -> Expr:
     """Create a lead window function.
 
@@ -2442,9 +2442,7 @@ def lead(
     if not isinstance(default_value, pa.Scalar) and default_value is not None:
         default_value = pa.scalar(default_value)
 
-    partition_cols = (
-        [col.expr for col in partition_by] if partition_by is not None else None
-    )
+    partition_by_raw = expr_list_to_raw_expr_list(partition_by)
     order_by_raw = sort_list_to_raw_sort_list(order_by)
 
     return Expr(
@@ -2452,7 +2450,7 @@ def lead(
             arg.expr,
             shift_offset,
             default_value,
-            partition_by=partition_cols,
+            partition_by=partition_by_raw,
             order_by=order_by_raw,
         )
     )
@@ -2462,8 +2460,8 @@ def lag(
     arg: Expr,
     shift_offset: int = 1,
     default_value: Optional[Any] = None,
-    partition_by: Optional[list[Expr]] = None,
-    order_by: Optional[list[Expr | SortExpr]] = None,
+    partition_by: Optional[list[Expr] | Expr] = None,
+    order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None,
 ) -> Expr:
     """Create a lag window function.
 
@@ -2494,9 +2492,7 @@ def lag(
     if not isinstance(default_value, pa.Scalar):
         default_value = pa.scalar(default_value)
 
-    partition_cols = (
-        [col.expr for col in partition_by] if partition_by is not None else None
-    )
+    partition_by_raw = expr_list_to_raw_expr_list(partition_by)
     order_by_raw = sort_list_to_raw_sort_list(order_by)
 
     return Expr(
@@ -2504,15 +2500,15 @@ def lag(
             arg.expr,
             shift_offset,
             default_value,
-            partition_by=partition_cols,
+            partition_by=partition_by_raw,
             order_by=order_by_raw,
         )
     )
 
 
 def row_number(
-    partition_by: Optional[list[Expr]] = None,
-    order_by: Optional[list[Expr | SortExpr]] = None,
+    partition_by: Optional[list[Expr] | Expr] = None,
+    order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None,
 ) -> Expr:
     """Create a row number window function.
 
@@ -2533,22 +2529,20 @@ def row_number(
         partition_by: Expressions to partition the window frame on.
         order_by: Set ordering within the window frame.
     """
-    partition_cols = (
-        [col.expr for col in partition_by] if partition_by is not None else None
-    )
+    partition_by_raw = expr_list_to_raw_expr_list(partition_by)
     order_by_raw = sort_list_to_raw_sort_list(order_by)
 
     return Expr(
         f.row_number(
-            partition_by=partition_cols,
+            partition_by=partition_by_raw,
             order_by=order_by_raw,
         )
     )
 
 
 def rank(
-    partition_by: Optional[list[Expr]] = None,
-    order_by: Optional[list[Expr | SortExpr]] = None,
+    partition_by: Optional[list[Expr] | Expr] = None,
+    order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None,
 ) -> Expr:
     """Create a rank window function.
 
@@ -2574,22 +2568,20 @@ def rank(
         partition_by: Expressions to partition the window frame on.
         order_by: Set ordering within the window frame.
     """
-    partition_cols = (
-        [col.expr for col in partition_by] if partition_by is not None else None
-    )
+    partition_by_raw = expr_list_to_raw_expr_list(partition_by)
     order_by_raw = sort_list_to_raw_sort_list(order_by)
 
     return Expr(
         f.rank(
-            partition_by=partition_cols,
+            partition_by=partition_by_raw,
             order_by=order_by_raw,
         )
     )
 
 
 def dense_rank(
-    partition_by: Optional[list[Expr]] = None,
-    order_by: Optional[list[Expr | SortExpr]] = None,
+    partition_by: Optional[list[Expr] | Expr] = None,
+    order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None,
 ) -> Expr:
     """Create a dense_rank window function.
 
@@ -2610,22 +2602,20 @@ def dense_rank(
         partition_by: Expressions to partition the window frame on.
         order_by: Set ordering within the window frame.
     """
-    partition_cols = (
-        [col.expr for col in partition_by] if partition_by is not None else None
-    )
+    partition_by_raw = expr_list_to_raw_expr_list(partition_by)
     order_by_raw = sort_list_to_raw_sort_list(order_by)
 
     return Expr(
         f.dense_rank(
-            partition_by=partition_cols,
+            partition_by=partition_by_raw,
             order_by=order_by_raw,
         )
     )
 
 
 def percent_rank(
-    partition_by: Optional[list[Expr]] = None,
-    order_by: Optional[list[Expr | SortExpr]] = None,
+    partition_by: Optional[list[Expr] | Expr] = None,
+    order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None,
 ) -> Expr:
     """Create a percent_rank window function.
 
@@ -2647,22 +2637,20 @@ def percent_rank(
         partition_by: Expressions to partition the window frame on.
         order_by: Set ordering within the window frame.
     """
-    partition_cols = (
-        [col.expr for col in partition_by] if partition_by is not None else None
-    )
+    partition_by_raw = expr_list_to_raw_expr_list(partition_by)
     order_by_raw = sort_list_to_raw_sort_list(order_by)
 
     return Expr(
         f.percent_rank(
-            partition_by=partition_cols,
+            partition_by=partition_by_raw,
             order_by=order_by_raw,
         )
     )
 
 
 def cume_dist(
-    partition_by: Optional[list[Expr]] = None,
-    order_by: Optional[list[Expr | SortExpr]] = None,
+    partition_by: Optional[list[Expr] | Expr] = None,
+    order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None,
 ) -> Expr:
     """Create a cumulative distribution window function.
 
@@ -2684,14 +2672,12 @@ def cume_dist(
         partition_by: Expressions to partition the window frame on.
         order_by: Set ordering within the window frame.
     """
-    partition_cols = (
-        [col.expr for col in partition_by] if partition_by is not None else None
-    )
+    partition_by_raw = expr_list_to_raw_expr_list(partition_by)
     order_by_raw = sort_list_to_raw_sort_list(order_by)
 
     return Expr(
         f.cume_dist(
-            partition_by=partition_cols,
+            partition_by=partition_by_raw,
             order_by=order_by_raw,
         )
     )
@@ -2699,8 +2685,8 @@ def cume_dist(
 
 def ntile(
     groups: int,
-    partition_by: Optional[list[Expr]] = None,
-    order_by: Optional[list[Expr | SortExpr]] = None,
+    partition_by: Optional[list[Expr] | Expr] = None,
+    order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None,
 ) -> Expr:
     """Create a n-tile window function.
 
@@ -2725,15 +2711,13 @@ def ntile(
         partition_by: Expressions to partition the window frame on.
         order_by: Set ordering within the window frame.
     """
-    partition_cols = (
-        [col.expr for col in partition_by] if partition_by is not None else None
-    )
+    partition_by_raw = expr_list_to_raw_expr_list(partition_by)
     order_by_raw = sort_list_to_raw_sort_list(order_by)
 
     return Expr(
         f.ntile(
             Expr.literal(groups).expr,
-            partition_by=partition_cols,
+            partition_by=partition_by_raw,
             order_by=order_by_raw,
         )
     )
@@ -2743,7 +2727,7 @@ def string_agg(
     expression: Expr,
     delimiter: str,
     filter: Optional[Expr] = None,
-    order_by: Optional[list[Expr | SortExpr]] = None,
+    order_by: Optional[list[Expr | SortExpr] | Expr | SortExpr] = None,
 ) -> Expr:
     """Concatenates the input strings.
 
diff --git a/python/tests/test_aggregation.py b/python/tests/test_aggregation.py
index 49dfb38c..96269b16 100644
--- a/python/tests/test_aggregation.py
+++ b/python/tests/test_aggregation.py
@@ -154,6 +154,11 @@ def test_aggregation_stats(df, agg_expr, calc_expected):
             pa.array([[6, 4, 4]]),
             False,
         ),
+        (
+            f.array_agg(column("b"), order_by=column("c")),
+            pa.array([[6, 4, 4]]),
+            False,
+        ),
         (f.avg(column("b"), filter=column("a") != lit(1)), pa.array([5.0]), False),
         (f.sum(column("b"), filter=column("a") != lit(1)), pa.array([10]), False),
         (f.count(column("b"), distinct=True), pa.array([2]), False),
@@ -329,6 +334,15 @@ def test_bit_and_bool_fns(df, name, expr, result):
             ),
             [None, None],
         ),
+        (
+            "first_value_no_list_order_by",
+            f.first_value(
+                column("b"),
+                order_by=column("b"),
+                null_treatment=NullTreatment.RESPECT_NULLS,
+            ),
+            [None, None],
+        ),
         (
             "first_value_ignore_null",
             f.first_value(
@@ -343,6 +357,11 @@ def test_bit_and_bool_fns(df, name, expr, result):
             f.last_value(column("a"), order_by=[column("a").sort(ascending=False)]),
             [0, 4],
         ),
+        (
+            "last_value_no_list_ordered",
+            f.last_value(column("a"), order_by=column("a")),
+            [3, 6],
+        ),
         (
             "last_value_with_null",
             f.last_value(
@@ -366,6 +385,11 @@ def test_bit_and_bool_fns(df, name, expr, result):
             f.nth_value(column("a"), 2, order_by=[column("a").sort(ascending=False)]),
             [2, 5],
         ),
+        (
+            "nth_value_no_list_ordered",
+            f.nth_value(column("a"), 2, order_by=column("a").sort(ascending=False)),
+            [2, 5],
+        ),
         (
             "nth_value_with_null",
             f.nth_value(
@@ -414,6 +438,11 @@ def test_first_last_value(df_partitioned, name, expr, result) -> None:
             f.string_agg(column("a"), ",", order_by=[column("b")]),
             "one,three,two,two",
         ),
+        (
+            "string_agg",
+            f.string_agg(column("a"), ",", order_by=column("b")),
+            "one,three,two,two",
+        ),
     ],
 )
 def test_string_agg(name, expr, result) -> None:
diff --git a/python/tests/test_dataframe.py b/python/tests/test_dataframe.py
index 590bad32..1fd99b33 100644
--- a/python/tests/test_dataframe.py
+++ b/python/tests/test_dataframe.py
@@ -551,12 +551,25 @@ data_test_window_functions = [
         ),
         [2, 1, 3, 4, 2, 1, 3],
     ),
+    (
+        "row_w_params_no_lists",
+        f.row_number(
+            order_by=column("b"),
+            partition_by=column("c"),
+        ),
+        [2, 1, 3, 4, 2, 1, 3],
+    ),
     ("rank", f.rank(order_by=[column("b")]), [3, 1, 3, 5, 6, 1, 6]),
     (
         "rank_w_params",
         f.rank(order_by=[column("b"), column("a")], partition_by=[column("c")]),
         [2, 1, 3, 4, 2, 1, 3],
     ),
+    (
+        "rank_w_params_no_lists",
+        f.rank(order_by=column("a"), partition_by=column("c")),
+        [1, 2, 3, 4, 1, 2, 3],
+    ),
     (
         "dense_rank",
         f.dense_rank(order_by=[column("b")]),
@@ -567,6 +580,11 @@ data_test_window_functions = [
         f.dense_rank(order_by=[column("b"), column("a")], partition_by=[column("c")]),
         [2, 1, 3, 4, 2, 1, 3],
     ),
+    (
+        "dense_rank_w_params_no_lists",
+        f.dense_rank(order_by=column("a"), partition_by=column("c")),
+        [1, 2, 3, 4, 1, 2, 3],
+    ),
     (
         "percent_rank",
         f.round(f.percent_rank(order_by=[column("b")]), literal(3)),
@@ -582,6 +600,14 @@ data_test_window_functions = [
         ),
         [0.333, 0.0, 0.667, 1.0, 0.5, 0.0, 1.0],
     ),
+    (
+        "percent_rank_w_params_no_lists",
+        f.round(
+            f.percent_rank(order_by=column("a"), partition_by=column("c")),
+            literal(3),
+        ),
+        [0.0, 0.333, 0.667, 1.0, 0.0, 0.5, 1.0],
+    ),
     (
         "cume_dist",
         f.round(f.cume_dist(order_by=[column("b")]), literal(3)),
@@ -597,6 +623,14 @@ data_test_window_functions = [
         ),
         [0.5, 0.25, 0.75, 1.0, 0.667, 0.333, 1.0],
     ),
+    (
+        "cume_dist_w_params_no_lists",
+        f.round(
+            f.cume_dist(order_by=column("a"), partition_by=column("c")),
+            literal(3),
+        ),
+        [0.25, 0.5, 0.75, 1.0, 0.333, 0.667, 1.0],
+    ),
     (
         "ntile",
         f.ntile(2, order_by=[column("b")]),
@@ -607,6 +641,11 @@ data_test_window_functions = [
         f.ntile(2, order_by=[column("b"), column("a")], partition_by=[column("c")]),
         [1, 1, 2, 2, 1, 1, 2],
     ),
+    (
+        "ntile_w_params_no_lists",
+        f.ntile(2, order_by=column("b"), partition_by=column("c")),
+        [1, 1, 2, 2, 1, 1, 2],
+    ),
     ("lead", f.lead(column("b"), order_by=[column("b")]), [7, None, 8, 9, 9, 7, None]),
     (
         "lead_w_params",
@@ -619,6 +658,17 @@ data_test_window_functions = [
         ),
         [8, 7, -1, -1, -1, 9, -1],
     ),
+    (
+        "lead_w_params_no_lists",
+        f.lead(
+            column("b"),
+            shift_offset=2,
+            default_value=-1,
+            order_by=column("b"),
+            partition_by=column("c"),
+        ),
+        [8, 7, -1, -1, -1, 9, -1],
+    ),
     ("lag", f.lag(column("b"), order_by=[column("b")]), [None, None, 7, 7, 8, None, 9]),
     (
         "lag_w_params",
@@ -631,6 +681,17 @@ data_test_window_functions = [
         ),
         [-1, -1, None, 7, -1, -1, None],
     ),
+    (
+        "lag_w_params_no_lists",
+        f.lag(
+            column("b"),
+            shift_offset=2,
+            default_value=-1,
+            order_by=column("b"),
+            partition_by=column("c"),
+        ),
+        [-1, -1, None, 7, -1, -1, None],
+    ),
     (
         "first_value",
         f.first_value(column("a")).over(
@@ -638,6 +699,13 @@ data_test_window_functions = [
         ),
         [1, 1, 1, 1, 5, 5, 5],
     ),
+    (
+        "first_value_without_list_args",
+        f.first_value(column("a")).over(
+            Window(partition_by=column("c"), order_by=column("b"))
+        ),
+        [1, 1, 1, 1, 5, 5, 5],
+    ),
     (
         "last_value",
         f.last_value(column("a")).over(


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@datafusion.apache.org
For additional commands, e-mail: commits-h...@datafusion.apache.org

Reply via email to