This is an automated email from the ASF dual-hosted git repository. gurwls223 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 8c7a6fc81cc [SPARK-43886][PYTHON] Accept generics tuple as typing hints of Pandas UDF 8c7a6fc81cc is described below commit 8c7a6fc81cceaba3d9c428baec336639b0d91205 Author: Xinrong Meng <xinr...@apache.org> AuthorDate: Wed May 31 09:19:21 2023 +0900 [SPARK-43886][PYTHON] Accept generics tuple as typing hints of Pandas UDF ### What changes were proposed in this pull request? Accept the built-in generic `tuple` (e.g. `tuple[pd.Series, pd.DataFrame]`) in type hints of Pandas UDFs, in addition to `typing.Tuple`. ### Why are the changes needed? To adapt to [PEP 585](https://peps.python.org/pep-0585/). Starting with Python 3.9, PEP 585 allows built-in collection types such as `tuple` to be used directly as generics in type hints. ### Does this PR introduce _any_ user-facing change? Yes. The built-in `tuple` is now accepted in type hints of Pandas UDFs. FROM ```py >>> @pandas_udf("long") ... def multiply(iterator: Iterator[tuple[pd.Series, pd.DataFrame]]) -> Iterator[pd.Series]: ... for s1, df in iterator: ... yield s1 * df.v ... Traceback (most recent call last): ... raise PySparkNotImplementedError( pyspark.errors.exceptions.base.PySparkNotImplementedError: [UNSUPPORTED_SIGNATURE] Unsupported signature: (iterator: Iterator[tuple[pandas.core.series.Series, pandas.core.frame.DataFrame]]) -> Iterator[pandas.core.series.Series]. ``` TO ```py >>> @pandas_udf("long") ... def multiply(iterator: Iterator[tuple[pd.Series, pd.DataFrame]]) -> Iterator[pd.Series]: ... for s1, df in iterator: ... yield s1 * df.v ... >>> multiply._unwrapped.evalType 204 # SQL_SCALAR_PANDAS_ITER_UDF ``` ### How was this patch tested? Unit tests. Closes #41388 from xinrong-meng/tuple. 
Authored-by: Xinrong Meng <xinr...@apache.org> Signed-off-by: Hyukjin Kwon <gurwls...@apache.org> --- python/pyspark/sql/pandas/typehints.py | 2 +- .../sql/tests/pandas/test_pandas_udf_typehints.py | 24 ++++++++++++++++++++++ 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/pandas/typehints.py b/python/pyspark/sql/pandas/typehints.py index 29ac81af944..f0c13e66a63 100644 --- a/python/pyspark/sql/pandas/typehints.py +++ b/python/pyspark/sql/pandas/typehints.py @@ -145,7 +145,7 @@ def check_tuple_annotation( # Tuple has _name but other types have __name__ # Check if the name is Tuple first. After that, check the generic types. name = getattr(annotation, "_name", getattr(annotation, "__name__", None)) - return name == "Tuple" and ( + return name in ("Tuple", "tuple") and ( parameter_check_func is None or all(map(parameter_check_func, annotation.__args__)) ) diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf_typehints.py b/python/pyspark/sql/tests/pandas/test_pandas_udf_typehints.py index 3cdf83e2d06..bfb874ffe53 100644 --- a/python/pyspark/sql/tests/pandas/test_pandas_udf_typehints.py +++ b/python/pyspark/sql/tests/pandas/test_pandas_udf_typehints.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# +import sys import unittest from inspect import signature from typing import Union, Iterator, Tuple, cast, get_type_hints @@ -113,6 +114,29 @@ class PandasUDFTypeHintsTests(ReusedSQLTestCase): infer_eval_type(signature(func), get_type_hints(func)), PandasUDFType.SCALAR_ITER ) + @unittest.skipIf(sys.version_info < (3, 9), "Type hinting generics require Python 3.9.") + def test_type_annotation_tuple_generics(self): + def func(iter: Iterator[tuple[pd.DataFrame, pd.Series]]) -> Iterator[pd.DataFrame]: + pass + + self.assertEqual( + infer_eval_type(signature(func), get_type_hints(func)), PandasUDFType.SCALAR_ITER + ) + + def func(iter: Iterator[tuple[pd.DataFrame, ...]]) -> Iterator[pd.Series]: + pass + + self.assertEqual( + infer_eval_type(signature(func), get_type_hints(func)), PandasUDFType.SCALAR_ITER + ) + + def func(iter: Iterator[tuple[Union[pd.DataFrame, pd.Series], ...]]) -> Iterator[pd.Series]: + pass + + self.assertEqual( + infer_eval_type(signature(func), get_type_hints(func)), PandasUDFType.SCALAR_ITER + ) + def test_type_annotation_group_agg(self): def func(col: pd.Series) -> str: pass --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org