This is an automated email from the ASF dual-hosted git repository. ruifengz pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new d00eb2b3272f [MINOR][PYTHON][TESTS] Skip some tests if numpy not installed d00eb2b3272f is described below commit d00eb2b3272f12c327660bc6ca47220bb6cd5f13 Author: Ruifeng Zheng <ruife...@apache.org> AuthorDate: Thu Sep 11 10:30:33 2025 +0800 [MINOR][PYTHON][TESTS] Skip some tests if numpy not installed ### What changes were proposed in this pull request? Skip some tests if numpy is not installed ### Why are the changes needed? these tests depend on numpy ### Does this PR introduce _any_ user-facing change? no, test-only ### How was this patch tested? ci ### Was this patch authored or co-authored using generative AI tooling? no Closes #52300 from zhengruifeng/test_skip_numpy. Authored-by: Ruifeng Zheng <ruife...@apache.org> Signed-off-by: Ruifeng Zheng <ruife...@apache.org> --- python/pyspark/sql/tests/arrow/test_arrow_udf_grouped_agg.py | 11 +++++++++-- python/pyspark/sql/tests/arrow/test_arrow_udf_scalar.py | 8 ++++++-- python/pyspark/sql/tests/arrow/test_arrow_udf_window.py | 9 +++++++-- 3 files changed, 22 insertions(+), 6 deletions(-) diff --git a/python/pyspark/sql/tests/arrow/test_arrow_udf_grouped_agg.py b/python/pyspark/sql/tests/arrow/test_arrow_udf_grouped_agg.py index 81a9c81ea671..d49f341788be 100644 --- a/python/pyspark/sql/tests/arrow/test_arrow_udf_grouped_agg.py +++ b/python/pyspark/sql/tests/arrow/test_arrow_udf_grouped_agg.py @@ -30,11 +30,13 @@ from pyspark.sql.types import ( ) from pyspark.sql import functions as sf from pyspark.errors import AnalysisException, PythonException -from pyspark.testing.sqlutils import ( - ReusedSQLTestCase, +from pyspark.testing.utils import ( + have_numpy, + numpy_requirement_message, have_pyarrow, pyarrow_requirement_message, ) +from pyspark.testing.sqlutils import ReusedSQLTestCase @unittest.skipIf(not have_pyarrow, pyarrow_requirement_message) @@ -146,6 +148,7 @@ class GroupedAggArrowUDFTestsMixin: self.assertEqual(expected, 
result.collect()) + @unittest.skipIf(not have_numpy, numpy_requirement_message) def test_basic(self): df = self.data weighted_mean_udf = self.arrow_agg_weighted_mean_udf @@ -268,6 +271,7 @@ class GroupedAggArrowUDFTestsMixin: self.assertEqual(expected5, result5.collect()) self.assertEqual(expected6, result6.collect()) + @unittest.skipIf(not have_numpy, numpy_requirement_message) def test_multiple_udfs(self): """ Test multiple group aggregate pandas UDFs in one agg function. @@ -537,6 +541,7 @@ class GroupedAggArrowUDFTestsMixin: assert filtered.collect()[0]["mean"] == 42.0 + @unittest.skipIf(not have_numpy, numpy_requirement_message) def test_named_arguments(self): df = self.data weighted_mean = self.arrow_agg_weighted_mean_udf @@ -565,6 +570,7 @@ class GroupedAggArrowUDFTestsMixin: df.groupby("id").agg(sf.mean(df.v).alias("wm")).collect(), ) + @unittest.skipIf(not have_numpy, numpy_requirement_message) def test_named_arguments_negative(self): df = self.data weighted_mean = self.arrow_agg_weighted_mean_udf @@ -886,6 +892,7 @@ class GroupedAggArrowUDFTestsMixin: # Integer value 2147483657 not in range: -2147483648 to 2147483647 result3.collect() + @unittest.skipIf(not have_numpy, numpy_requirement_message) def test_return_numpy_scalar(self): import numpy as np import pyarrow as pa diff --git a/python/pyspark/sql/tests/arrow/test_arrow_udf_scalar.py b/python/pyspark/sql/tests/arrow/test_arrow_udf_scalar.py index d6e010d8d2a9..3409ce953487 100644 --- a/python/pyspark/sql/tests/arrow/test_arrow_udf_scalar.py +++ b/python/pyspark/sql/tests/arrow/test_arrow_udf_scalar.py @@ -46,11 +46,13 @@ from pyspark.sql.types import ( YearMonthIntervalType, ) from pyspark.errors import AnalysisException, PythonException -from pyspark.testing.sqlutils import ( - ReusedSQLTestCase, +from pyspark.testing.utils import ( + have_numpy, + numpy_requirement_message, have_pyarrow, pyarrow_requirement_message, ) +from pyspark.testing.sqlutils import ReusedSQLTestCase @unittest.skipIf(not 
have_pyarrow, pyarrow_requirement_message) @@ -813,6 +815,7 @@ class ScalarArrowUDFTestsMixin: [row] = self.spark.sql("SELECT randomArrowUDF(1)").collect() self.assertEqual(row[0], 7) + @unittest.skipIf(not have_numpy, numpy_requirement_message) def test_nondeterministic_arrow_udf(self): import pyarrow as pa @@ -835,6 +838,7 @@ class ScalarArrowUDFTestsMixin: self.assertEqual(random_udf.deterministic, False) self.assertTrue(result1["plus_ten(rand)"].equals(result1["rand"] + 10)) + @unittest.skipIf(not have_numpy, numpy_requirement_message) def test_nondeterministic_arrow_udf_in_aggregate(self): with self.quiet(): df = self.spark.range(10) diff --git a/python/pyspark/sql/tests/arrow/test_arrow_udf_window.py b/python/pyspark/sql/tests/arrow/test_arrow_udf_window.py index b543c562f4a6..1d301597c21a 100644 --- a/python/pyspark/sql/tests/arrow/test_arrow_udf_window.py +++ b/python/pyspark/sql/tests/arrow/test_arrow_udf_window.py @@ -22,11 +22,13 @@ from pyspark.util import PythonEvalType from pyspark.sql import functions as sf from pyspark.sql.window import Window from pyspark.errors import AnalysisException, PythonException, PySparkTypeError -from pyspark.testing.sqlutils import ( - ReusedSQLTestCase, +from pyspark.testing.utils import ( + have_numpy, + numpy_requirement_message, have_pyarrow, pyarrow_requirement_message, ) +from pyspark.testing.sqlutils import ReusedSQLTestCase @unittest.skipIf(not have_pyarrow, pyarrow_requirement_message) @@ -384,6 +386,7 @@ class WindowArrowUDFTestsMixin: self.assertEqual(expected1.collect(), result1.collect()) + @unittest.skipIf(not have_numpy, numpy_requirement_message) def test_named_arguments(self): df = self.data weighted_mean = self.arrow_agg_weighted_mean_udf @@ -427,6 +430,7 @@ class WindowArrowUDFTestsMixin: ).collect(), ) + @unittest.skipIf(not have_numpy, numpy_requirement_message) def test_named_arguments_negative(self): df = self.data weighted_mean = self.arrow_agg_weighted_mean_udf @@ -718,6 +722,7 @@ class 
WindowArrowUDFTestsMixin: # Integer value 2147483657 not in range: -2147483648 to 2147483647 result3.collect() + @unittest.skipIf(not have_numpy, numpy_requirement_message) def test_return_numpy_scalar(self): import numpy as np import pyarrow as pa --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org