This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 19735de5b46e [SPARK-53534][PYTHON][TESTS] Add tests for arrow udf with numpy output
19735de5b46e is described below

commit 19735de5b46ed91dee55b77a243cc761bbeb3801
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Tue Sep 9 08:45:15 2025 -0700

    [SPARK-53534][PYTHON][TESTS] Add tests for arrow udf with numpy output

    ### What changes were proposed in this pull request?

    Add tests for arrow udf with numpy output

    ### Why are the changes needed?

    to improve test coverage

    ### Does this PR introduce _any_ user-facing change?

    no, test-only

    ### How was this patch tested?

    ci

    ### Was this patch authored or co-authored using generative AI tooling?

    no

    Closes #52285 from zhengruifeng/test_numpy_arrow_agg.

    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Dongjoon Hyun <dongj...@apache.org>
---
 .../sql/tests/arrow/test_arrow_udf_grouped_agg.py | 33 ++++++++++++++++++++
 .../sql/tests/arrow/test_arrow_udf_window.py      | 35 ++++++++++++++++++++++
 2 files changed, 68 insertions(+)

diff --git a/python/pyspark/sql/tests/arrow/test_arrow_udf_grouped_agg.py b/python/pyspark/sql/tests/arrow/test_arrow_udf_grouped_agg.py
index f6c3112f94ca..81a9c81ea671 100644
--- a/python/pyspark/sql/tests/arrow/test_arrow_udf_grouped_agg.py
+++ b/python/pyspark/sql/tests/arrow/test_arrow_udf_grouped_agg.py
@@ -886,6 +886,39 @@ class GroupedAggArrowUDFTestsMixin:
             # Integer value 2147483657 not in range: -2147483648 to 2147483647
             result3.collect()
 
+    def test_return_numpy_scalar(self):
+        import numpy as np
+        import pyarrow as pa
+
+        @arrow_udf("long")
+        def np_max_udf(v: pa.Array) -> np.int64:
+            assert isinstance(v, pa.Array)
+            return np.max(v)
+
+        @arrow_udf("long")
+        def np_min_udf(v: pa.Array) -> np.int64:
+            assert isinstance(v, pa.Array)
+            return np.min(v)
+
+        @arrow_udf("double")
+        def np_avg_udf(v: pa.Array) -> np.float64:
+            assert isinstance(v, pa.Array)
+            return np.mean(v)
+
+        df = self.spark.range(10)
+        expected = df.select(
+            sf.max("id").alias("max"),
+            sf.min("id").alias("min"),
+            sf.avg("id").alias("avg"),
+        )
+
+        result = df.select(
+            np_max_udf("id").alias("max"),
+            np_min_udf("id").alias("min"),
+            np_avg_udf("id").alias("avg"),
+        )
+        self.assertEqual(expected.collect(), result.collect())
+
     def test_unsupported_return_types(self):
         import pyarrow as pa
diff --git a/python/pyspark/sql/tests/arrow/test_arrow_udf_window.py b/python/pyspark/sql/tests/arrow/test_arrow_udf_window.py
index fde9d7243375..b543c562f4a6 100644
--- a/python/pyspark/sql/tests/arrow/test_arrow_udf_window.py
+++ b/python/pyspark/sql/tests/arrow/test_arrow_udf_window.py
@@ -718,6 +718,41 @@ class WindowArrowUDFTestsMixin:
             # Integer value 2147483657 not in range: -2147483648 to 2147483647
             result3.collect()
 
+    def test_return_numpy_scalar(self):
+        import numpy as np
+        import pyarrow as pa
+
+        df = self.spark.range(10).withColumn("v", sf.lit(1))
+        w = Window.partitionBy("id").orderBy("v")
+
+        @arrow_udf("long")
+        def np_max_udf(v: pa.Array) -> np.int64:
+            assert isinstance(v, pa.Array)
+            return np.max(v)
+
+        @arrow_udf("long")
+        def np_min_udf(v: pa.Array) -> np.int64:
+            assert isinstance(v, pa.Array)
+            return np.min(v)
+
+        @arrow_udf("double")
+        def np_avg_udf(v: pa.Array) -> np.float64:
+            assert isinstance(v, pa.Array)
+            return np.mean(v)
+
+        expected = df.select(
+            sf.max("id").over(w).alias("max"),
+            sf.min("id").over(w).alias("min"),
+            sf.avg("id").over(w).alias("avg"),
+        )
+
+        result = df.select(
+            np_max_udf("id").over(w).alias("max"),
+            np_min_udf("id").over(w).alias("min"),
+            np_avg_udf("id").over(w).alias("avg"),
+        )
+        self.assertEqual(expected.collect(), result.collect())
+
 
 class WindowArrowUDFTests(WindowArrowUDFTestsMixin, ReusedSQLTestCase):
     pass

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org