This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 25c550ef37bc [SPARK-53695][PYTHON][TESTS] Add tests for 0-arg grouped agg UDF
25c550ef37bc is described below

commit 25c550ef37bcf4658b2d05326bd8563de78da167
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Thu Sep 25 08:10:33 2025 +0900

    [SPARK-53695][PYTHON][TESTS] Add tests for 0-arg grouped agg UDF
    
    ### What changes were proposed in this pull request?
    Add tests for 0-arg grouped aggregate UDFs (both Arrow and pandas UDFs).
    
    ### Why are the changes needed?
    To guard the 0-arg cases:
    ```
    In [6]: @pandas_udf("double")
       ...: def mean_udf2() -> float:
       ...:     return 1.0
       ...:
    
    In [7]: spark.range(10).select(mean_udf2()).show()
    +-----------+
    |mean_udf2()|
    +-----------+
    |        1.0|
    +-----------+
    ```
    
    ### Does this PR introduce _any_ user-facing change?
    no, test-only
    
    ### How was this patch tested?
    ci
    
    ### Was this patch authored or co-authored using generative AI tooling?
    no
    
    Closes #52437 from zhengruifeng/grouped_agg_0_arg.
    
    Authored-by: Ruifeng Zheng <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 .../sql/tests/arrow/test_arrow_udf_grouped_agg.py  | 35 ++++++++++++++++++++++
 .../tests/pandas/test_pandas_udf_grouped_agg.py    | 35 +++++++++++++++++++++-
 2 files changed, 69 insertions(+), 1 deletion(-)

diff --git a/python/pyspark/sql/tests/arrow/test_arrow_udf_grouped_agg.py b/python/pyspark/sql/tests/arrow/test_arrow_udf_grouped_agg.py
index d49f341788be..3fe6d28c66a6 100644
--- a/python/pyspark/sql/tests/arrow/test_arrow_udf_grouped_agg.py
+++ b/python/pyspark/sql/tests/arrow/test_arrow_udf_grouped_agg.py
@@ -951,6 +951,41 @@ class GroupedAggArrowUDFTestsMixin:
                 def func_a(a: pa.Array) -> pa.Scalar:
                     return pa.compute.max(a)
 
+    def test_0_args(self):  # res1: 1-arg UDF result, res2: 0-arg UDF result
+        import pyarrow as pa
+
+        df = self.spark.range(10).withColumn("k", sf.col("id") % 3)
+
+        @arrow_udf("long", ArrowUDFType.GROUPED_AGG)
+        def arrow_max(v) -> int:
+            return pa.compute.max(v).as_py()
+
+        @arrow_udf("long", ArrowUDFType.GROUPED_AGG)
+        def arrow_lit_1() -> int:  # takes no input columns: the case under test
+            return 1
+
+        expected1 = df.select(sf.max("id").alias("res1"), sf.lit(1).alias("res2"))
+        result1 = df.select(arrow_max("id").alias("res1"), arrow_lit_1().alias("res2"))
+        self.assertEqual(expected1.collect(), result1.collect())
+
+        expected2 = (  # same comparison, computed per group of "k"
+            df.groupby("k")
+            .agg(
+                sf.max("id").alias("res1"),
+                sf.lit(1).alias("res2"),
+            )
+            .sort("k")
+        )
+        result2 = (
+            df.groupby("k")
+            .agg(
+                arrow_max("id").alias("res1"),
+                arrow_lit_1().alias("res2"),
+            )
+            .sort("k")
+        )
+        self.assertEqual(expected2.collect(), result2.collect())
+
 
 class GroupedAggArrowUDFTests(GroupedAggArrowUDFTestsMixin, ReusedSQLTestCase):
     pass
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py b/python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py
index 1059af59f4a8..65f842fa70ad 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py
@@ -19,7 +19,7 @@ import unittest
 from typing import cast
 
 from pyspark.util import PythonEvalType
-from pyspark.sql import Row
+from pyspark.sql import Row, functions as sf
 from pyspark.sql.functions import (
     array,
     explode,
@@ -761,6 +761,39 @@ class GroupedAggPandasUDFTestsMixin:
                 row = df.groupby("id").agg(test(df.id)).first()
                 self.assertEqual(row[1], 123)
 
+    def test_0_args(self):  # res1: 1-arg UDF result, res2: 0-arg UDF result
+        df = self.spark.range(10).withColumn("k", sf.col("id") % 3)
+
+        @pandas_udf("long", PandasUDFType.GROUPED_AGG)
+        def pandas_max(v) -> int:
+            return v.max()
+
+        @pandas_udf("long", PandasUDFType.GROUPED_AGG)
+        def pandas_lit_1() -> int:  # takes no input columns: the case under test
+            return 1
+
+        expected1 = df.select(sf.max("id").alias("res1"), sf.lit(1).alias("res2"))
+        result1 = df.select(pandas_max("id").alias("res1"), pandas_lit_1().alias("res2"))
+        self.assertEqual(expected1.collect(), result1.collect())
+
+        expected2 = (  # same comparison, computed per group of "k"
+            df.groupby("k")
+            .agg(
+                sf.max("id").alias("res1"),
+                sf.lit(1).alias("res2"),
+            )
+            .sort("k")
+        )
+        result2 = (
+            df.groupby("k")
+            .agg(
+                pandas_max("id").alias("res1"),
+                pandas_lit_1().alias("res2"),
+            )
+            .sort("k")
+        )
+        self.assertEqual(expected2.collect(), result2.collect())
+
 
 class GroupedAggPandasUDFTests(GroupedAggPandasUDFTestsMixin, 
ReusedSQLTestCase):
     pass


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to