This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 19735de5b46e [SPARK-53534][PYTHON][TESTS] Add tests for arrow udf with numpy output
19735de5b46e is described below

commit 19735de5b46ed91dee55b77a243cc761bbeb3801
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Tue Sep 9 08:45:15 2025 -0700

    [SPARK-53534][PYTHON][TESTS] Add tests for arrow udf with numpy output
    
    ### What changes were proposed in this pull request?
    Add tests for Arrow UDFs that return NumPy scalars.
    
    ### Why are the changes needed?
    To improve test coverage.
    
    ### Does this PR introduce _any_ user-facing change?
    No, this change is test-only.
    
    ### How was this patch tested?
    CI.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #52285 from zhengruifeng/test_numpy_arrow_agg.
    
    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Dongjoon Hyun <dongj...@apache.org>
---
 .../sql/tests/arrow/test_arrow_udf_grouped_agg.py  | 33 ++++++++++++++++++++
 .../sql/tests/arrow/test_arrow_udf_window.py       | 35 ++++++++++++++++++++++
 2 files changed, 68 insertions(+)

diff --git a/python/pyspark/sql/tests/arrow/test_arrow_udf_grouped_agg.py b/python/pyspark/sql/tests/arrow/test_arrow_udf_grouped_agg.py
index f6c3112f94ca..81a9c81ea671 100644
--- a/python/pyspark/sql/tests/arrow/test_arrow_udf_grouped_agg.py
+++ b/python/pyspark/sql/tests/arrow/test_arrow_udf_grouped_agg.py
@@ -886,6 +886,39 @@ class GroupedAggArrowUDFTestsMixin:
             # Integer value 2147483657 not in range: -2147483648 to 2147483647
             result3.collect()
 
+    def test_return_numpy_scalar(self):
+        import numpy as np
+        import pyarrow as pa
+
+        @arrow_udf("long")
+        def np_max_udf(v: pa.Array) -> np.int64:
+            assert isinstance(v, pa.Array)
+            return np.max(v)
+
+        @arrow_udf("long")
+        def np_min_udf(v: pa.Array) -> np.int64:
+            assert isinstance(v, pa.Array)
+            return np.min(v)
+
+        @arrow_udf("double")
+        def np_avg_udf(v: pa.Array) -> np.float64:
+            assert isinstance(v, pa.Array)
+            return np.mean(v)
+
+        df = self.spark.range(10)
+        expected = df.select(
+            sf.max("id").alias("max"),
+            sf.min("id").alias("min"),
+            sf.avg("id").alias("avg"),
+        )
+
+        result = df.select(
+            np_max_udf("id").alias("max"),
+            np_min_udf("id").alias("min"),
+            np_avg_udf("id").alias("avg"),
+        )
+        self.assertEqual(expected.collect(), result.collect())
+
     def test_unsupported_return_types(self):
         import pyarrow as pa
 
diff --git a/python/pyspark/sql/tests/arrow/test_arrow_udf_window.py b/python/pyspark/sql/tests/arrow/test_arrow_udf_window.py
index fde9d7243375..b543c562f4a6 100644
--- a/python/pyspark/sql/tests/arrow/test_arrow_udf_window.py
+++ b/python/pyspark/sql/tests/arrow/test_arrow_udf_window.py
@@ -718,6 +718,41 @@ class WindowArrowUDFTestsMixin:
             # Integer value 2147483657 not in range: -2147483648 to 2147483647
             result3.collect()
 
+    def test_return_numpy_scalar(self):
+        import numpy as np
+        import pyarrow as pa
+
+        df = self.spark.range(10).withColumn("v", sf.lit(1))
+        w = Window.partitionBy("id").orderBy("v")
+
+        @arrow_udf("long")
+        def np_max_udf(v: pa.Array) -> np.int64:
+            assert isinstance(v, pa.Array)
+            return np.max(v)
+
+        @arrow_udf("long")
+        def np_min_udf(v: pa.Array) -> np.int64:
+            assert isinstance(v, pa.Array)
+            return np.min(v)
+
+        @arrow_udf("double")
+        def np_avg_udf(v: pa.Array) -> np.float64:
+            assert isinstance(v, pa.Array)
+            return np.mean(v)
+
+        expected = df.select(
+            sf.max("id").over(w).alias("max"),
+            sf.min("id").over(w).alias("min"),
+            sf.avg("id").over(w).alias("avg"),
+        )
+
+        result = df.select(
+            np_max_udf("id").over(w).alias("max"),
+            np_min_udf("id").over(w).alias("min"),
+            np_avg_udf("id").over(w).alias("avg"),
+        )
+        self.assertEqual(expected.collect(), result.collect())
+
 
 class WindowArrowUDFTests(WindowArrowUDFTestsMixin, ReusedSQLTestCase):
     pass
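
For context when reading the diff above, the pattern under test is an arrow UDF whose Python function returns a NumPy scalar (np.int64 or np.float64) rather than a pyarrow scalar. Below is a minimal standalone sketch of that usage; it assumes arrow_udf is importable from pyspark.sql.functions, as the test modules do, and that an active SparkSession is available.

import numpy as np
import pyarrow as pa

from pyspark.sql import SparkSession
from pyspark.sql.functions import arrow_udf  # import path assumed to mirror the test modules

spark = SparkSession.builder.getOrCreate()

@arrow_udf("double")
def np_avg_udf(v: pa.Array) -> np.float64:
    # The UDF receives the column as a pyarrow.Array and returns a NumPy
    # scalar, which Spark converts to the declared "double" result type.
    return np.mean(v)

# Used as a full-column aggregate, mirroring the grouped-agg test above.
df = spark.range(10)
df.select(np_avg_udf("id").alias("avg")).show()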


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
