This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 4b72478e0578 [SPARK-53282][PYTHON][TESTS] Add test for arrow udf type hints
4b72478e0578 is described below

commit 4b72478e05785ba841b139fac732d35023fe8659
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Fri Aug 15 12:50:34 2025 +0800

    [SPARK-53282][PYTHON][TESTS] Add test for arrow udf type hints
    
    ### What changes were proposed in this pull request?
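    Add tests for Arrow UDF type hints, and simplify the scalar `pa.Array` check in `infer_eval_type` (`python/pyspark/sql/pandas/typehints.py`) by removing its `check_union_annotation` branch.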
    
    ### Why are the changes needed?
    To improve test coverage of type-hint-based eval-type inference for Arrow UDFs.
    
    ### Does this PR introduce _any_ user-facing change?
    no, test-only
    
    ### How was this patch tested?
    New unit tests in `python/pyspark/sql/tests/arrow/test_arrow_udf_typehints.py`.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    no
    
    Closes #52024 from zhengruifeng/test_type_hints.
    
    Authored-by: Ruifeng Zheng <[email protected]>
    Signed-off-by: Ruifeng Zheng <[email protected]>
---
 dev/sparktestsupport/modules.py                    |   1 +
 python/pyspark/sql/pandas/typehints.py             |   5 +-
 .../sql/tests/arrow/test_arrow_udf_typehints.py    | 336 +++++++++++++++++++++
 3 files changed, 338 insertions(+), 4 deletions(-)

diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index c0645d7b5ba9..0141a1d3d9e2 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -558,6 +558,7 @@ pyspark_sql = Module(
         "pyspark.sql.tests.arrow.test_arrow_udf_grouped_agg",
         "pyspark.sql.tests.arrow.test_arrow_udf_scalar",
         "pyspark.sql.tests.arrow.test_arrow_udf_window",
+        "pyspark.sql.tests.arrow.test_arrow_udf_typehints",
         "pyspark.sql.tests.arrow.test_arrow_udtf",
         "pyspark.sql.tests.pandas.test_pandas_cogrouped_map",
         "pyspark.sql.tests.pandas.test_pandas_grouped_map",
diff --git a/python/pyspark/sql/pandas/typehints.py b/python/pyspark/sql/pandas/typehints.py
index 0d5c50ce8a3c..d7d6fca8ded3 100644
--- a/python/pyspark/sql/pandas/typehints.py
+++ b/python/pyspark/sql/pandas/typehints.py
@@ -87,10 +87,7 @@ def infer_eval_type(
     ) and (return_annotation == pd.Series or return_annotation == pd.DataFrame)
 
     # pa.Array, ... -> pa.Array
-    is_arrow_array = all(
-        a == pa.Array or check_union_annotation(a, parameter_check_func=lambda na: na == pa.Array)
-        for a in parameters_sig
-    ) and (return_annotation == pa.Array)
+    is_arrow_array = all(a == pa.Array for a in parameters_sig) and (return_annotation == pa.Array)
 
     # Iterator[Tuple[Series, Frame or Union[DataFrame, Series], ...] -> Iterator[Series or Frame]
     is_iterator_tuple_series_or_frame = (
diff --git a/python/pyspark/sql/tests/arrow/test_arrow_udf_typehints.py b/python/pyspark/sql/tests/arrow/test_arrow_udf_typehints.py
new file mode 100644
index 000000000000..3684dcd2779e
--- /dev/null
+++ b/python/pyspark/sql/tests/arrow/test_arrow_udf_typehints.py
@@ -0,0 +1,336 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import unittest
+from inspect import signature
+from typing import Union, Iterator, Tuple, get_type_hints
+
+from pyspark.sql import functions as sf
+from pyspark.testing.utils import (
+    have_pyarrow,
+    pyarrow_requirement_message,
+    have_numpy,
+    numpy_requirement_message,
+)
+from pyspark.testing.sqlutils import ReusedSQLTestCase
+from pyspark.sql.pandas.typehints import infer_eval_type
+from pyspark.sql.pandas.functions import arrow_udf, ArrowUDFType
+from pyspark.sql import Row
+
+if have_pyarrow:
+    import pyarrow as pa
+
+
[email protected](not have_pyarrow, pyarrow_requirement_message)
+class ArrowUDFTypeHintsTests(ReusedSQLTestCase):
+    def test_type_annotation_scalar(self):
+        def func(col: pa.Array) -> pa.Array:
+            pass
+
+        self.assertEqual(
+            infer_eval_type(signature(func), get_type_hints(func)), ArrowUDFType.SCALAR
+        )
+
+        def func(col: pa.Array, col1: pa.Array) -> pa.Array:
+            pass
+
+        self.assertEqual(
+            infer_eval_type(signature(func), get_type_hints(func)), ArrowUDFType.SCALAR
+        )
+
+        def func(col: pa.Array, *args: pa.Array) -> pa.Array:
+            pass
+
+        self.assertEqual(
+            infer_eval_type(signature(func), get_type_hints(func)), ArrowUDFType.SCALAR
+        )
+
+        def func(col: pa.Array, *args: pa.Array, **kwargs: pa.Array) -> pa.Array:
+            pass
+
+        self.assertEqual(
+            infer_eval_type(signature(func), get_type_hints(func)), ArrowUDFType.SCALAR
+        )
+
+        def func(col: pa.Array, *, col2: pa.Array) -> pa.Array:
+            pass
+
+        self.assertEqual(
+            infer_eval_type(signature(func), get_type_hints(func)), ArrowUDFType.SCALAR
+        )
+
+        # Union[pa.Array, pa.Array] equals to pa.Array
+        def func(col: Union[pa.Array, pa.Array], *, col2: pa.Array) -> pa.Array:
+            pass
+
+        self.assertEqual(
+            infer_eval_type(signature(func), get_type_hints(func)), ArrowUDFType.SCALAR
+        )
+
+    def test_type_annotation_scalar_iter(self):
+        def func(iter: Iterator[pa.Array]) -> Iterator[pa.Array]:
+            pass
+
+        self.assertEqual(
+            infer_eval_type(signature(func), get_type_hints(func)), ArrowUDFType.SCALAR_ITER
+        )
+
+        def func(iter: Iterator[Tuple[pa.Array, ...]]) -> Iterator[pa.Array]:
+            pass
+
+        self.assertEqual(
+            infer_eval_type(signature(func), get_type_hints(func)), ArrowUDFType.SCALAR_ITER
+        )
+
+    def test_type_annotation_tuple_generics(self):
+        def func(iter: Iterator[tuple[pa.Array, pa.Array]]) -> Iterator[pa.Array]:
+            pass
+
+        self.assertEqual(
+            infer_eval_type(signature(func), get_type_hints(func)), ArrowUDFType.SCALAR_ITER
+        )
+
+        def func(iter: Iterator[tuple[pa.Array, ...]]) -> Iterator[pa.Array]:
+            pass
+
+        self.assertEqual(
+            infer_eval_type(signature(func), get_type_hints(func)), ArrowUDFType.SCALAR_ITER
+        )
+
+        # Union[pa.Array, pa.Array] equals to pa.Array
+        def func(iter: Iterator[tuple[Union[pa.Array, pa.Array], ...]]) -> Iterator[pa.Array]:
+            pass
+
+        self.assertEqual(
+            infer_eval_type(signature(func), get_type_hints(func)), ArrowUDFType.SCALAR_ITER
+        )
+
+    def test_type_annotation_group_agg(self):
+        def func(col: pa.Array) -> str:
+            pass
+
+        self.assertEqual(
+            infer_eval_type(signature(func), get_type_hints(func)), ArrowUDFType.GROUPED_AGG
+        )
+
+        def func(col: pa.Array, col1: pa.Array) -> int:
+            pass
+
+        self.assertEqual(
+            infer_eval_type(signature(func), get_type_hints(func)), ArrowUDFType.GROUPED_AGG
+        )
+
+        def func(col: pa.Array, *args: pa.Array) -> Row:
+            pass
+
+        self.assertEqual(
+            infer_eval_type(signature(func), get_type_hints(func)), ArrowUDFType.GROUPED_AGG
+        )
+
+        def func(col: pa.Array, *args: pa.Array, **kwargs: pa.Array) -> str:
+            pass
+
+        self.assertEqual(
+            infer_eval_type(signature(func), get_type_hints(func)), ArrowUDFType.GROUPED_AGG
+        )
+
+        def func(col: pa.Array, *, col2: pa.Array) -> float:
+            pass
+
+        self.assertEqual(
+            infer_eval_type(signature(func), get_type_hints(func)), ArrowUDFType.GROUPED_AGG
+        )
+
+        # Union[pa.Array, pa.Array] equals to pa.Array
+        def func(col: Union[pa.Array, pa.Array], *, col2: pa.Array) -> float:
+            pass
+
+        self.assertEqual(
+            infer_eval_type(signature(func), get_type_hints(func)), ArrowUDFType.GROUPED_AGG
+        )
+
+    def test_type_annotation_negative(self):
+        def func(col: str) -> pa.Array:
+            pass
+
+        self.assertRaisesRegex(
+            NotImplementedError,
+            "Unsupported signature.*str",
+            infer_eval_type,
+            signature(func),
+            get_type_hints(func),
+        )
+
+        def func(col: pa.Array, col1: int) -> pa.Array:
+            pass
+
+        self.assertRaisesRegex(
+            NotImplementedError,
+            "Unsupported signature.*int",
+            infer_eval_type,
+            signature(func),
+            get_type_hints(func),
+        )
+
+        def func(col: Union[pa.Array, str], col1: int) -> pa.Array:
+            pass
+
+        self.assertRaisesRegex(
+            NotImplementedError,
+            "Unsupported signature.*str",
+            infer_eval_type,
+            signature(func),
+            get_type_hints(func),
+        )
+
+        def func(col: pa.Array) -> Tuple[pa.Array]:
+            pass
+
+        self.assertRaisesRegex(
+            NotImplementedError,
+            "Unsupported signature.*Tuple",
+            infer_eval_type,
+            signature(func),
+            get_type_hints(func),
+        )
+
+        def func(col, *args: pa.Array) -> pa.Array:
+            pass
+
+        self.assertRaisesRegex(
+            ValueError,
+            "should be specified.*Array",
+            infer_eval_type,
+            signature(func),
+            get_type_hints(func),
+        )
+
+        def func(col: pa.Array, *args: pa.Array, **kwargs: pa.Array):
+            pass
+
+        self.assertRaisesRegex(
+            ValueError,
+            "should be specified.*Array",
+            infer_eval_type,
+            signature(func),
+            get_type_hints(func),
+        )
+
+        def func(col: pa.Array, *, col2) -> pa.Array:
+            pass
+
+        self.assertRaisesRegex(
+            ValueError,
+            "should be specified.*Array",
+            infer_eval_type,
+            signature(func),
+            get_type_hints(func),
+        )
+
+    def test_scalar_udf_type_hint(self):
+        df = self.spark.range(10).selectExpr("id", "id as v")
+
+        def plus_one(v: pa.Array) -> pa.Array:
+            return pa.compute.add(v, 1)
+
+        plus_one = arrow_udf("long")(plus_one)
+        actual = df.select(plus_one(df.v).alias("plus_one"))
+        expected = df.selectExpr("(v + 1) as plus_one")
+        self.assertEqual(expected.collect(), actual.collect())
+
+    def test_scalar_iter_udf_type_hint(self):
+        df = self.spark.range(10).selectExpr("id", "id as v")
+
+        def plus_one(itr: Iterator[pa.Array]) -> Iterator[pa.Array]:
+            for s in itr:
+                yield pa.compute.add(s, 1)
+
+        plus_one = arrow_udf("long")(plus_one)
+
+        actual = df.select(plus_one(df.v).alias("plus_one"))
+        expected = df.selectExpr("(v + 1) as plus_one")
+        self.assertEqual(expected.collect(), actual.collect())
+
+    @unittest.skipIf(not have_numpy, numpy_requirement_message)
+    def test_group_agg_udf_type_hint(self):
+        import numpy as np
+
+        df = self.spark.range(10).selectExpr("id", "id as v")
+
+        def weighted_mean(v: pa.Array, w: pa.Array) -> np.float64:
+            return np.average(v, weights=w)
+
+        weighted_mean = arrow_udf("double")(weighted_mean)
+
+        actual = df.groupby("id").agg(weighted_mean(df.v, sf.lit(1.0))).sort("id")
+        expected = df.groupby("id").agg(sf.mean(df.v).alias("weighted_mean(v, 1.0)")).sort("id")
+        self.assertEqual(expected.collect(), actual.collect())
+
+    def test_string_type_annotation(self):
+        def func(col: "pa.Array") -> "pa.Array":
+            pass
+
+        self.assertEqual(
+            infer_eval_type(signature(func), get_type_hints(func)), ArrowUDFType.SCALAR
+        )
+
+        def func(col: "pa.Array", col1: "pa.Array") -> "pa.Array":
+            pass
+
+        self.assertEqual(
+            infer_eval_type(signature(func), get_type_hints(func)), ArrowUDFType.SCALAR
+        )
+
+        def func(col: "pa.Array", *args: "pa.Array") -> "pa.Array":
+            pass
+
+        self.assertEqual(
+            infer_eval_type(signature(func), get_type_hints(func)), ArrowUDFType.SCALAR
+        )
+
+        def func(col: "pa.Array", *args: "pa.Array", **kwargs: "pa.Array") -> "pa.Array":
+            pass
+
+        self.assertEqual(
+            infer_eval_type(signature(func), get_type_hints(func)), ArrowUDFType.SCALAR
+        )
+
+        def func(col: "pa.Array", *, col2: "pa.Array") -> "pa.Array":
+            pass
+
+        self.assertEqual(
+            infer_eval_type(signature(func), get_type_hints(func)), ArrowUDFType.SCALAR
+        )
+
+        # Union[pa.Array, pa.Array] equals to pa.Array
+        def func(col: Union["pa.Array", "pa.Array"], *, col2: "pa.Array") -> "pa.Array":
+            pass
+
+        self.assertEqual(
+            infer_eval_type(signature(func), get_type_hints(func)), ArrowUDFType.SCALAR
+        )
+
+
+if __name__ == "__main__":
+    from pyspark.sql.tests.arrow.test_arrow_udf_typehints import *  # noqa: #401
+
+    try:
+        import xmlrunner
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to