This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 692f1b6265be [SPARK-52278][PYTHON] Scalar Arrow UDF support named arguments
692f1b6265be is described below
commit 692f1b6265be6167f331caaa9875a4a2e7215fae
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Fri May 23 14:42:19 2025 +0800
[SPARK-52278][PYTHON] Scalar Arrow UDF support named arguments
### What changes were proposed in this pull request?
Scalar Arrow UDF support named arguments
### Why are the changes needed?
for feature parity with pandas UDF
### Does this PR introduce _any_ user-facing change?
no, Arrow UDF is not public now
### How was this patch tested?
added tests
### Was this patch authored or co-authored using generative AI tooling?
no
Closes #50996 from zhengruifeng/py_arrow_udf_test_named_args.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
.../sql/tests/arrow/test_arrow_udf_scalar.py | 110 ++++++++++++++++++++-
1 file changed, 108 insertions(+), 2 deletions(-)
diff --git a/python/pyspark/sql/tests/arrow/test_arrow_udf_scalar.py b/python/pyspark/sql/tests/arrow/test_arrow_udf_scalar.py
index 707aff62e87c..37d985244df8 100644
--- a/python/pyspark/sql/tests/arrow/test_arrow_udf_scalar.py
+++ b/python/pyspark/sql/tests/arrow/test_arrow_udf_scalar.py
@@ -44,7 +44,7 @@ from pyspark.sql.types import (
MapType,
BinaryType,
)
-from pyspark.errors import AnalysisException
+from pyspark.errors import AnalysisException, PythonException
from pyspark.testing.sqlutils import (
ReusedSQLTestCase,
have_pyarrow,
@@ -726,7 +726,113 @@ class ScalarArrowUDFTestsMixin:
res = df.select(scalar_g(scalar_f(F.col("id"))).alias("res"))
self.assertEqual(expected, res.collect())
- # TODO: add tests for named arguments
+ def test_arrow_udf_named_arguments(self):
+ import pyarrow as pa
+
+ @arrow_udf("int")
+ def test_udf(a, b):
+ return pa.compute.add(a, pa.compute.multiply(b, 10)).cast(pa.int32())
+
+ self.spark.sql("DROP TEMPORARY FUNCTION IF EXISTS test_udf")
+ self.spark.udf.register("test_udf", test_udf)
+
+ expected = [Row(0), Row(101)]
+ for i, df in enumerate(
+ [
+ self.spark.range(2).select(test_udf(F.col("id"), b=F.col("id") * 10)),
+ self.spark.range(2).select(test_udf(a=F.col("id"), b=F.col("id") * 10)),
+ self.spark.range(2).select(test_udf(b=F.col("id") * 10, a=F.col("id"))),
+ self.spark.sql("SELECT test_udf(id, b => id * 10) FROM range(2)"),
+ self.spark.sql("SELECT test_udf(a => id, b => id * 10) FROM range(2)"),
+ self.spark.sql("SELECT test_udf(b => id * 10, a => id) FROM range(2)"),
+ ]
+ ):
+ with self.subTest(query_no=i):
+ self.assertEqual(expected, df.collect())
+
+ def test_arrow_udf_named_arguments_negative(self):
+ import pyarrow as pa
+
+ @arrow_udf("int")
+ def test_udf(a, b):
+ return pa.compute.add(a, b).cast(pa.int32())
+
+ self.spark.sql("DROP TEMPORARY FUNCTION IF EXISTS test_udf")
+ self.spark.udf.register("test_udf", test_udf)
+
+ with self.assertRaisesRegex(
+ AnalysisException,
+ "DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE",
+ ):
+ self.spark.sql("SELECT test_udf(a => id, a => id * 10) FROM range(2)").show()
+
+ with self.assertRaisesRegex(AnalysisException, "UNEXPECTED_POSITIONAL_ARGUMENT"):
+ self.spark.sql("SELECT test_udf(a => id, id * 10) FROM range(2)").show()
+
+ with self.assertRaisesRegex(
+ PythonException, r"test_udf\(\) got an unexpected keyword argument 'c'"
+ ):
+ self.spark.sql("SELECT test_udf(c => 'x') FROM range(2)").show()
+
+ def test_arrow_udf_named_arguments_and_defaults(self):
+ import pyarrow as pa
+
+ @arrow_udf("int")
+ def test_udf(a, b=0):
+ return pa.compute.add(a, pa.compute.multiply(b, 10)).cast(pa.int32())
+
+ self.spark.sql("DROP TEMPORARY FUNCTION IF EXISTS test_udf")
+ self.spark.udf.register("test_udf", test_udf)
+
+ # without "b"
+ expected = [Row(0), Row(1)]
+ for i, df in enumerate(
+ [
+ self.spark.range(2).select(test_udf(F.col("id"))),
+ self.spark.range(2).select(test_udf(a=F.col("id"))),
+ self.spark.sql("SELECT test_udf(id) FROM range(2)"),
+ self.spark.sql("SELECT test_udf(a => id) FROM range(2)"),
+ ]
+ ):
+ with self.subTest(with_b=False, query_no=i):
+ self.assertEqual(expected, df.collect())
+
+ # with "b"
+ expected = [Row(0), Row(101)]
+ for i, df in enumerate(
+ [
+ self.spark.range(2).select(test_udf(F.col("id"), b=F.col("id") * 10)),
+ self.spark.range(2).select(test_udf(a=F.col("id"), b=F.col("id") * 10)),
+ self.spark.range(2).select(test_udf(b=F.col("id") * 10, a=F.col("id"))),
+ self.spark.sql("SELECT test_udf(id, b => id * 10) FROM range(2)"),
+ self.spark.sql("SELECT test_udf(a => id, b => id * 10) FROM range(2)"),
+ self.spark.sql("SELECT test_udf(b => id * 10, a => id) FROM range(2)"),
+ ]
+ ):
+ with self.subTest(with_b=True, query_no=i):
+ self.assertEqual(expected, df.collect())
+
+ def test_arrow_udf_kwargs(self):
+ import pyarrow as pa
+
+ @arrow_udf("int")
+ def test_udf(a, **kwargs):
+ return pa.compute.add(a, pa.compute.multiply(kwargs["b"], 10)).cast(pa.int32())
+
+ self.spark.sql("DROP TEMPORARY FUNCTION IF EXISTS test_udf")
+ self.spark.udf.register("test_udf", test_udf)
+
+ expected = [Row(0), Row(101)]
+ for i, df in enumerate(
+ [
+ self.spark.range(2).select(test_udf(a=F.col("id"), b=F.col("id") * 10)),
+ self.spark.range(2).select(test_udf(b=F.col("id") * 10, a=F.col("id"))),
+ self.spark.sql("SELECT test_udf(a => id, b => id * 10) FROM range(2)"),
+ self.spark.sql("SELECT test_udf(b => id * 10, a => id) FROM range(2)"),
+ ]
+ ):
+ with self.subTest(query_no=i):
+ self.assertEqual(expected, df.collect())
class ScalarArrowUDFTests(ScalarArrowUDFTestsMixin, ReusedSQLTestCase):
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]