This is an automated email from the ASF dual-hosted git repository. ruifengz pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 0c26d7a1a68c [SPARK-53395][PYTHON][CONNECT][TESTS] Add tests for combinations of different scalar UDFs
0c26d7a1a68c is described below

commit 0c26d7a1a68c315e90e6d70ffba147522c8656cf
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Wed Aug 27 18:47:15 2025 +0800

[SPARK-53395][PYTHON][CONNECT][TESTS] Add tests for combinations of different scalar UDFs

### What changes were proposed in this pull request?
Add tests for combinations of different scalar UDFs

### Why are the changes needed?
to improve test coverage for complex combinations of different UDF types, we have 6 scalar udf now:

- python udf
- arrow-optimized python udf
- pandas udf
- pandas udf with iter api
- arrow udf
- arrow udf with iter api

we should make sure different kinds of UDFs can be used together

### Does this PR introduce _any_ user-facing change?
No, test-only

### How was this patch tested?
ci

### Was this patch authored or co-authored using generative AI tooling?
no

Closes #52141 from zhengruifeng/test_udf_comb.
Authored-by: Ruifeng Zheng <ruife...@apache.org>
Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 dev/sparktestsupport/modules.py                    |   2 +
 .../tests/connect/test_parity_udf_combinations.py  |  40 ++++
 python/pyspark/sql/tests/test_udf_combinations.py  | 208 +++++++++++++++++++++
 3 files changed, 250 insertions(+)

diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 1e333ba6c246..5227ad9f75bc 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -585,6 +585,7 @@ pyspark_sql = Module(
         "pyspark.sql.tests.test_subquery",
         "pyspark.sql.tests.test_types",
         "pyspark.sql.tests.test_udf",
+        "pyspark.sql.tests.test_udf_combinations",
         "pyspark.sql.tests.test_udf_profiler",
         "pyspark.sql.tests.test_udtf",
         "pyspark.sql.tests.test_tvf",
@@ -1102,6 +1103,7 @@ pyspark_connect = Module(
         "pyspark.sql.tests.connect.test_parity_column",
         "pyspark.sql.tests.connect.test_parity_readwriter",
         "pyspark.sql.tests.connect.test_parity_udf",
+        "pyspark.sql.tests.connect.test_parity_udf_combinations",
         "pyspark.sql.tests.connect.test_parity_udf_profiler",
         "pyspark.sql.tests.connect.test_parity_memory_profiler",
         "pyspark.sql.tests.connect.test_parity_udtf",
diff --git a/python/pyspark/sql/tests/connect/test_parity_udf_combinations.py b/python/pyspark/sql/tests/connect/test_parity_udf_combinations.py
new file mode 100644
index 000000000000..bc63aa7aeb50
--- /dev/null
+++ b/python/pyspark/sql/tests/connect/test_parity_udf_combinations.py
@@ -0,0 +1,40 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.
You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+from pyspark.sql.tests.test_udf_combinations import UDFCombinationsTestsMixin
+from pyspark.testing.connectutils import ReusedConnectTestCase
+
+
+class UDFCombinationsParityTests(UDFCombinationsTestsMixin, ReusedConnectTestCase):
+    @classmethod
+    def setUpClass(cls):
+        super().setUpClass()  # run the inherited setup with cls bound to this test class
+        cls.spark.conf.set("spark.sql.execution.pythonUDF.arrow.enabled", "false")
+
+
+if __name__ == "__main__":
+    from pyspark.sql.tests.connect.test_parity_udf_combinations import *  # noqa: F401
+
+    try:
+        import xmlrunner  # type: ignore[import]
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/sql/tests/test_udf_combinations.py b/python/pyspark/sql/tests/test_udf_combinations.py
new file mode 100644
index 000000000000..8111bb79d3c7
--- /dev/null
+++ b/python/pyspark/sql/tests/test_udf_combinations.py
@@ -0,0 +1,208 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.
You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from typing import Iterator
+import itertools
+import unittest
+
+from pyspark.sql.functions import udf, arrow_udf, pandas_udf
+from pyspark.testing.sqlutils import ReusedSQLTestCase
+from pyspark.testing.utils import (
+    have_pandas,
+    have_pyarrow,
+    pandas_requirement_message,
+    pyarrow_requirement_message,
+)
+
+
+@unittest.skipIf(
+    not have_pandas or not have_pyarrow,
+    pandas_requirement_message or pyarrow_requirement_message,
+)
+class UDFCombinationsTestsMixin:
+    @property
+    def python_udf_add1(self):
+        @udf("long")
+        def py_add1(v):
+            assert isinstance(v, int)
+            return v + 1
+
+        return py_add1
+
+    @property
+    def arrow_opt_python_udf_add1(self):
+        @udf("long", useArrow=True)  # Arrow is requested on the decorator, not in the signature
+        def py_arrow_opt_add1(v):
+            assert isinstance(v, int)
+            return v + 1
+
+        return py_arrow_opt_add1
+
+    @property
+    def pandas_udf_add1(self):
+        import pandas as pd
+
+        @pandas_udf("long")
+        def pandas_add1(s):
+            assert isinstance(s, pd.Series)
+            return s + 1
+
+        return pandas_add1
+
+    @property
+    def pandas_iter_udf_add1(self):
+        import pandas as pd
+
+        @pandas_udf("long")
+        def pandas_iter_add1(it: Iterator[pd.Series]) -> Iterator[pd.Series]:
+            for s in it:
+                assert isinstance(s, pd.Series)
+                yield s + 1
+
+        return pandas_iter_add1
+
+    @property
+    def arrow_udf_add1(self):
+        import pyarrow as pa
+
+        @arrow_udf("long")
+        def arrow_add1(a):
+            assert isinstance(a, pa.Array)
+            return pa.compute.add(a, 1)
+
+        return arrow_add1
+
+    @property
+    def arrow_iter_udf_add1(self):
+        import pyarrow as pa
+
+        @arrow_udf("long")
+        def arrow_iter_add1(it: Iterator[pa.Array]) -> Iterator[pa.Array]:
+            for a in it:
+                assert isinstance(a, pa.Array)
+                yield pa.compute.add(a, 1)
+
+        return arrow_iter_add1
+
+    def all_scalar_functions(self):
+        return [
+            self.python_udf_add1,
+            self.arrow_opt_python_udf_add1,
+            self.pandas_udf_add1,
+            self.pandas_iter_udf_add1,
+            self.arrow_udf_add1,
+            self.arrow_iter_udf_add1,
+        ]
+
+    def test_combination_2(self):
+        df = self.spark.range(10)
+
+        expected = df.selectExpr("id + 2 AS res").collect()
+
+        combs = itertools.combinations(self.all_scalar_functions(), 2)
+        for f1, f2 in combs:
+            with self.subTest(
+                udf1=f1.__name__,
+                udf2=f2.__name__,
+            ):
+                result = df.select(f1(f2("id")).alias("res"))
+                self.assertEqual(expected, result.collect())
+
+    def test_combination_3(self):
+        df = self.spark.range(10)
+
+        expected = df.selectExpr("id + 3 AS res").collect()
+
+        combs = itertools.combinations(self.all_scalar_functions(), 3)
+        for f1, f2, f3 in combs:
+            with self.subTest(
+                udf1=f1.__name__,
+                udf2=f2.__name__,
+                udf3=f3.__name__,
+            ):
+                result = df.select(f1(f2(f3("id"))).alias("res"))
+                self.assertEqual(expected, result.collect())
+
+    def test_combination_4(self):
+        df = self.spark.range(10)
+
+        expected = df.selectExpr("id + 4 AS res").collect()
+
+        combs = itertools.combinations(self.all_scalar_functions(), 4)
+        for f1, f2, f3, f4 in combs:
+            with self.subTest(
+                udf1=f1.__name__,
+                udf2=f2.__name__,
+                udf3=f3.__name__,
+                udf4=f4.__name__,
+            ):
+                result = df.select(f1(f2(f3(f4("id")))).alias("res"))
+                self.assertEqual(expected, result.collect())
+
+    def test_combination_5(self):
+        df = self.spark.range(10)
+
+        expected = df.selectExpr("id + 5 AS res").collect()
+
+        combs = itertools.combinations(self.all_scalar_functions(), 5)
+        for f1, f2, f3, f4, f5 in combs:
+            with self.subTest(
+                udf1=f1.__name__,
+                udf2=f2.__name__,
+                udf3=f3.__name__,
+                udf4=f4.__name__,
+                udf5=f5.__name__,
+            ):
+                result = df.select(f1(f2(f3(f4(f5("id"))))).alias("res"))
+                self.assertEqual(expected, result.collect())
+
+    def test_combination_6(self):
+        df = self.spark.range(10)
+
+        expected = df.selectExpr("id + 6 AS res").collect()
+
+        combs = itertools.combinations(self.all_scalar_functions(), 6)
+        for f1, f2, f3, f4, f5, f6 in combs:
+            with self.subTest(
+                udf1=f1.__name__,
+                udf2=f2.__name__,
+                udf3=f3.__name__,
+                udf4=f4.__name__,
+                udf5=f5.__name__,
+                udf6=f6.__name__,
+            ):
+                result = df.select(f1(f2(f3(f4(f5(f6("id")))))).alias("res"))
+                self.assertEqual(expected, result.collect())
+
+
+class UDFCombinationsTests(UDFCombinationsTestsMixin, ReusedSQLTestCase):
+    @classmethod
+    def setUpClass(cls):
+        super().setUpClass()  # run the inherited setup with cls bound to this test class
+        cls.spark.conf.set("spark.sql.execution.pythonUDF.arrow.enabled", "false")
+
+
+if __name__ == "__main__":
+    from pyspark.sql.tests.test_udf_combinations import *  # noqa: F401
+
+    try:
+        import xmlrunner  # type: ignore
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org