This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 692f1b6265be [SPARK-52278][PYTHON] Scalar Arrow UDF support named arguments
692f1b6265be is described below

commit 692f1b6265be6167f331caaa9875a4a2e7215fae
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Fri May 23 14:42:19 2025 +0800

    [SPARK-52278][PYTHON] Scalar Arrow UDF support named arguments
    
    ### What changes were proposed in this pull request?
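    Add tests verifying that scalar Arrow UDFs support named arguments (positional, keyword, mixed, default values, and **kwargs), through both the DataFrame API and SQL.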
    
    ### Why are the changes needed?
    For feature parity with pandas UDFs.
    
    ### Does this PR introduce _any_ user-facing change?
    No. Arrow UDFs are not a public API yet.
    
    ### How was this patch tested?
    Added new unit tests to python/pyspark/sql/tests/arrow/test_arrow_udf_scalar.py.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    no
    
    Closes #50996 from zhengruifeng/py_arrow_udf_test_named_args.
    
    Authored-by: Ruifeng Zheng <[email protected]>
    Signed-off-by: Ruifeng Zheng <[email protected]>
---
 .../sql/tests/arrow/test_arrow_udf_scalar.py       | 110 ++++++++++++++++++++-
 1 file changed, 108 insertions(+), 2 deletions(-)

diff --git a/python/pyspark/sql/tests/arrow/test_arrow_udf_scalar.py b/python/pyspark/sql/tests/arrow/test_arrow_udf_scalar.py
index 707aff62e87c..37d985244df8 100644
--- a/python/pyspark/sql/tests/arrow/test_arrow_udf_scalar.py
+++ b/python/pyspark/sql/tests/arrow/test_arrow_udf_scalar.py
@@ -44,7 +44,7 @@ from pyspark.sql.types import (
     MapType,
     BinaryType,
 )
-from pyspark.errors import AnalysisException
+from pyspark.errors import AnalysisException, PythonException
 from pyspark.testing.sqlutils import (
     ReusedSQLTestCase,
     have_pyarrow,
@@ -726,7 +726,113 @@ class ScalarArrowUDFTestsMixin:
         res = df.select(scalar_g(scalar_f(F.col("id"))).alias("res"))
         self.assertEqual(expected, res.collect())
 
-    # TODO: add tests for named arguments
+    def test_arrow_udf_named_arguments(self):
+        import pyarrow as pa
+
+        @arrow_udf("int")
+        def test_udf(a, b):
+            return pa.compute.add(a, pa.compute.multiply(b, 10)).cast(pa.int32())
+
+        self.spark.sql("DROP TEMPORARY FUNCTION IF EXISTS test_udf")
+        self.spark.udf.register("test_udf", test_udf)
+
+        expected = [Row(0), Row(101)]
+        for i, df in enumerate(
+            [
+                self.spark.range(2).select(test_udf(F.col("id"), b=F.col("id") * 10)),
+                self.spark.range(2).select(test_udf(a=F.col("id"), b=F.col("id") * 10)),
+                self.spark.range(2).select(test_udf(b=F.col("id") * 10, a=F.col("id"))),
+                self.spark.sql("SELECT test_udf(id, b => id * 10) FROM range(2)"),
+                self.spark.sql("SELECT test_udf(a => id, b => id * 10) FROM range(2)"),
+                self.spark.sql("SELECT test_udf(b => id * 10, a => id) FROM range(2)"),
+            ]
+        ):
+            with self.subTest(query_no=i):
+                self.assertEqual(expected, df.collect())
+
+    def test_arrow_udf_named_arguments_negative(self):
+        import pyarrow as pa
+
+        @arrow_udf("int")
+        def test_udf(a, b):
+            return pa.compute.add(a, b).cast(pa.int32())
+
+        self.spark.sql("DROP TEMPORARY FUNCTION IF EXISTS test_udf")
+        self.spark.udf.register("test_udf", test_udf)
+
+        with self.assertRaisesRegex(
+            AnalysisException,
+            "DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE",
+        ):
+            self.spark.sql("SELECT test_udf(a => id, a => id * 10) FROM range(2)").show()
+
+        with self.assertRaisesRegex(AnalysisException, "UNEXPECTED_POSITIONAL_ARGUMENT"):
+            self.spark.sql("SELECT test_udf(a => id, id * 10) FROM range(2)").show()
+
+        with self.assertRaisesRegex(
+            PythonException, r"test_udf\(\) got an unexpected keyword argument 'c'"
+        ):
+            self.spark.sql("SELECT test_udf(c => 'x') FROM range(2)").show()
+
+    def test_arrow_udf_named_arguments_and_defaults(self):
+        import pyarrow as pa
+
+        @arrow_udf("int")
+        def test_udf(a, b=0):
+            return pa.compute.add(a, pa.compute.multiply(b, 10)).cast(pa.int32())
+
+        self.spark.sql("DROP TEMPORARY FUNCTION IF EXISTS test_udf")
+        self.spark.udf.register("test_udf", test_udf)
+
+        # without "b"
+        expected = [Row(0), Row(1)]
+        for i, df in enumerate(
+            [
+                self.spark.range(2).select(test_udf(F.col("id"))),
+                self.spark.range(2).select(test_udf(a=F.col("id"))),
+                self.spark.sql("SELECT test_udf(id) FROM range(2)"),
+                self.spark.sql("SELECT test_udf(a => id) FROM range(2)"),
+            ]
+        ):
+            with self.subTest(with_b=False, query_no=i):
+                self.assertEqual(expected, df.collect())
+
+        # with "b"
+        expected = [Row(0), Row(101)]
+        for i, df in enumerate(
+            [
+                self.spark.range(2).select(test_udf(F.col("id"), b=F.col("id") * 10)),
+                self.spark.range(2).select(test_udf(a=F.col("id"), b=F.col("id") * 10)),
+                self.spark.range(2).select(test_udf(b=F.col("id") * 10, a=F.col("id"))),
+                self.spark.sql("SELECT test_udf(id, b => id * 10) FROM range(2)"),
+                self.spark.sql("SELECT test_udf(a => id, b => id * 10) FROM range(2)"),
+                self.spark.sql("SELECT test_udf(b => id * 10, a => id) FROM range(2)"),
+            ]
+        ):
+            with self.subTest(with_b=True, query_no=i):
+                self.assertEqual(expected, df.collect())
+
+    def test_arrow_udf_kwargs(self):
+        import pyarrow as pa
+
+        @arrow_udf("int")
+        def test_udf(a, **kwargs):
+            return pa.compute.add(a, pa.compute.multiply(kwargs["b"], 10)).cast(pa.int32())
+
+        self.spark.sql("DROP TEMPORARY FUNCTION IF EXISTS test_udf")
+        self.spark.udf.register("test_udf", test_udf)
+
+        expected = [Row(0), Row(101)]
+        for i, df in enumerate(
+            [
+                self.spark.range(2).select(test_udf(a=F.col("id"), b=F.col("id") * 10)),
+                self.spark.range(2).select(test_udf(b=F.col("id") * 10, a=F.col("id"))),
+                self.spark.sql("SELECT test_udf(a => id, b => id * 10) FROM range(2)"),
+                self.spark.sql("SELECT test_udf(b => id * 10, a => id) FROM range(2)"),
+            ]
+        ):
+            with self.subTest(query_no=i):
+                self.assertEqual(expected, df.collect())
 
 
 class ScalarArrowUDFTests(ScalarArrowUDFTestsMixin, ReusedSQLTestCase):


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to