This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 0c26d7a1a68c [SPARK-53395][PYTHON][CONNECT][TESTS] Add tests for combinations of different scalar UDFs
0c26d7a1a68c is described below

commit 0c26d7a1a68c315e90e6d70ffba147522c8656cf
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Wed Aug 27 18:47:15 2025 +0800

    [SPARK-53395][PYTHON][CONNECT][TESTS] Add tests for combinations of different scalar UDFs
    
    ### What changes were proposed in this pull request?
    Add tests for combinations of different scalar UDFs
    
    ### Why are the changes needed?
    To improve test coverage for complex combinations of different UDF types. We now have 6 kinds of scalar UDFs:
    
    - Python UDF
    - Arrow-optimized Python UDF
    - pandas UDF
    - pandas UDF with the iterator API
    - Arrow UDF
    - Arrow UDF with the iterator API
    
    We should make sure the different kinds of UDFs can be used together, as sketched below.
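    
    For illustration, a minimal sketch (not part of this patch) composing two of the flavors above; it assumes an active SparkSession bound to `spark`:
    
    ```python
    import pandas as pd
    from pyspark.sql.functions import udf, pandas_udf
    
    # Arrow-optimized Python UDF: receives plain Python scalars.
    @udf("long", useArrow=True)
    def py_add1(v):
        return v + 1
    
    # pandas UDF: receives a whole pd.Series per batch.
    @pandas_udf("long")
    def pd_add1(s: pd.Series) -> pd.Series:
        return s + 1
    
    # The two flavors compose like any other Column expressions.
    df = spark.range(3)
    df.select(pd_add1(py_add1("id")).alias("res")).show()  # res: 2, 3, 4
    ```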
    
    ### Does this PR introduce _any_ user-facing change?
    No, test-only
    
    ### How was this patch tested?
    CI; the new suites can also be run locally, as sketched below.
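    
    A hedged local-run sketch (assuming pyspark plus its test deps, pandas and pyarrow, are importable):
    
    ```python
    import unittest
    
    # Load and run the new non-Connect suite by module name.
    suite = unittest.defaultTestLoader.loadTestsFromName(
        "pyspark.sql.tests.test_udf_combinations"
    )
    unittest.TextTestRunner(verbosity=2).run(suite)
    ```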
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #52141 from zhengruifeng/test_udf_comb.
    
    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 dev/sparktestsupport/modules.py                    |   2 +
 .../tests/connect/test_parity_udf_combinations.py  |  40 ++++
 python/pyspark/sql/tests/test_udf_combinations.py  | 208 +++++++++++++++++++++
 3 files changed, 250 insertions(+)

diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 1e333ba6c246..5227ad9f75bc 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -585,6 +585,7 @@ pyspark_sql = Module(
         "pyspark.sql.tests.test_subquery",
         "pyspark.sql.tests.test_types",
         "pyspark.sql.tests.test_udf",
+        "pyspark.sql.tests.test_udf_combinations",
         "pyspark.sql.tests.test_udf_profiler",
         "pyspark.sql.tests.test_udtf",
         "pyspark.sql.tests.test_tvf",
@@ -1102,6 +1103,7 @@ pyspark_connect = Module(
         "pyspark.sql.tests.connect.test_parity_column",
         "pyspark.sql.tests.connect.test_parity_readwriter",
         "pyspark.sql.tests.connect.test_parity_udf",
+        "pyspark.sql.tests.connect.test_parity_udf_combinations",
         "pyspark.sql.tests.connect.test_parity_udf_profiler",
         "pyspark.sql.tests.connect.test_parity_memory_profiler",
         "pyspark.sql.tests.connect.test_parity_udtf",
diff --git a/python/pyspark/sql/tests/connect/test_parity_udf_combinations.py b/python/pyspark/sql/tests/connect/test_parity_udf_combinations.py
new file mode 100644
index 000000000000..bc63aa7aeb50
--- /dev/null
+++ b/python/pyspark/sql/tests/connect/test_parity_udf_combinations.py
@@ -0,0 +1,40 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+from pyspark.sql.tests.test_udf_combinations import UDFCombinationsTestsMixin
+from pyspark.testing.connectutils import ReusedConnectTestCase
+
+
+class UDFCombinationsParityTests(UDFCombinationsTestsMixin, ReusedConnectTestCase):
+    @classmethod
+    def setUpClass(cls):
+        ReusedConnectTestCase.setUpClass()
+        cls.spark.conf.set("spark.sql.execution.pythonUDF.arrow.enabled", "false")
+
+
+if __name__ == "__main__":
+    from pyspark.sql.tests.connect.test_parity_udf_combinations import *  # noqa: F401
+
+    try:
+        import xmlrunner  # type: ignore[import]
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/sql/tests/test_udf_combinations.py b/python/pyspark/sql/tests/test_udf_combinations.py
new file mode 100644
index 000000000000..8111bb79d3c7
--- /dev/null
+++ b/python/pyspark/sql/tests/test_udf_combinations.py
@@ -0,0 +1,208 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from typing import Iterator
+import itertools
+import unittest
+
+from pyspark.sql.functions import udf, arrow_udf, pandas_udf
+from pyspark.testing.sqlutils import ReusedSQLTestCase
+from pyspark.testing.utils import (
+    have_pandas,
+    have_pyarrow,
+    pandas_requirement_message,
+    pyarrow_requirement_message,
+)
+
+
+@unittest.skipIf(
+    not have_pandas or not have_pyarrow,
+    pandas_requirement_message or pyarrow_requirement_message,
+)
+class UDFCombinationsTestsMixin:
+    @property
+    def python_udf_add1(self):
+        @udf("long")
+        def py_add1(v):
+            assert isinstance(v, int)
+            return v + 1
+
+        return py_add1
+
+    @property
+    def arrow_opt_python_udf_add1(self):
+        @udf("long")
+        def py_arrow_opt_add1(v, useArrow=True):
+            assert isinstance(v, int)
+            return v + 1
+
+        return py_arrow_opt_add1
+
+    @property
+    def pandas_udf_add1(self):
+        import pandas as pd
+
+        @pandas_udf("long")
+        def pandas_add1(s):
+            assert isinstance(s, pd.Series)
+            return s + 1
+
+        return pandas_add1
+
+    @property
+    def pandas_iter_udf_add1(self):
+        import pandas as pd
+
+        @pandas_udf("long")
+        def pandas_iter_add1(it: Iterator[pd.Series]) -> Iterator[pd.Series]:
+            for s in it:
+                assert isinstance(s, pd.Series)
+                yield s + 1
+
+        return pandas_iter_add1
+
+    @property
+    def arrow_udf_add1(self):
+        import pyarrow as pa
+
+        @arrow_udf("long")
+        def arrow_add1(a):
+            assert isinstance(a, pa.Array)
+            return pa.compute.add(a, 1)
+
+        return arrow_add1
+
+    @property
+    def arrow_iter_udf_add1(self):
+        import pyarrow as pa
+
+        @arrow_udf("long")
+        def arrow_iter_add1(it: Iterator[pa.Array]) -> Iterator[pa.Array]:
+            for a in it:
+                assert isinstance(a, pa.Array)
+                yield pa.compute.add(a, 1)
+
+        return arrow_iter_add1
+
+    def all_scalar_functions(self):
+        return [
+            self.python_udf_add1,
+            self.arrow_opt_python_udf_add1,
+            self.pandas_udf_add1,
+            self.pandas_iter_udf_add1,
+            self.arrow_udf_add1,
+            self.arrow_iter_udf_add1,
+        ]
+
+    def test_combination_2(self):
+        df = self.spark.range(10)
+
+        expected = df.selectExpr("id + 2 AS res").collect()
+
+        combs = itertools.combinations(self.all_scalar_functions(), 2)
+        for f1, f2 in combs:
+            with self.subTest(
+                udf1=f1.__name__,
+                udf2=f2.__name__,
+            ):
+                result = df.select(f1(f2("id")).alias("res"))
+                self.assertEqual(expected, result.collect())
+
+    def test_combination_3(self):
+        df = self.spark.range(10)
+
+        expected = df.selectExpr("id + 3 AS res").collect()
+
+        combs = itertools.combinations(self.all_scalar_functions(), 3)
+        for f1, f2, f3 in combs:
+            with self.subTest(
+                udf1=f1.__name__,
+                udf2=f2.__name__,
+                udf3=f3.__name__,
+            ):
+                result = df.select(f1(f2(f3("id"))).alias("res"))
+                self.assertEqual(expected, result.collect())
+
+    def test_combination_4(self):
+        df = self.spark.range(10)
+
+        expected = df.selectExpr("id + 4 AS res").collect()
+
+        combs = itertools.combinations(self.all_scalar_functions(), 4)
+        for f1, f2, f3, f4 in combs:
+            with self.subTest(
+                udf1=f1.__name__,
+                udf2=f2.__name__,
+                udf3=f3.__name__,
+                udf4=f4.__name__,
+            ):
+                result = df.select(f1(f2(f3(f4("id")))).alias("res"))
+                self.assertEqual(expected, result.collect())
+
+    def test_combination_5(self):
+        df = self.spark.range(10)
+
+        expected = df.selectExpr("id + 5 AS res").collect()
+
+        combs = itertools.combinations(self.all_scalar_functions(), 5)
+        for f1, f2, f3, f4, f5 in combs:
+            with self.subTest(
+                udf1=f1.__name__,
+                udf2=f2.__name__,
+                udf3=f3.__name__,
+                udf4=f4.__name__,
+                udf5=f5.__name__,
+            ):
+                result = df.select(f1(f2(f3(f4(f5("id"))))).alias("res"))
+                self.assertEqual(expected, result.collect())
+
+    def test_combination_6(self):
+        df = self.spark.range(10)
+
+        expected = df.selectExpr("id + 6 AS res").collect()
+
+        combs = itertools.combinations(self.all_scalar_functions(), 6)
+        for f1, f2, f3, f4, f5, f6 in combs:
+            with self.subTest(
+                udf1=f1.__name__,
+                udf2=f2.__name__,
+                udf3=f3.__name__,
+                udf4=f4.__name__,
+                udf5=f5.__name__,
+                udf6=f6.__name__,
+            ):
+                result = df.select(f1(f2(f3(f4(f5(f6("id")))))).alias("res"))
+                self.assertEqual(expected, result.collect())
+
+
+class UDFCombinationsTests(UDFCombinationsTestsMixin, ReusedSQLTestCase):
+    @classmethod
+    def setUpClass(cls):
+        ReusedSQLTestCase.setUpClass()
+        cls.spark.conf.set("spark.sql.execution.pythonUDF.arrow.enabled", "false")
+
+
+if __name__ == "__main__":
+    from pyspark.sql.tests.test_udf_combinations import *  # noqa: F401
+
+    try:
+        import xmlrunner  # type: ignore
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
