This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 8c7a6fc81cc [SPARK-43886][PYTHON] Accept generics tuple as typing hints of Pandas UDF
8c7a6fc81cc is described below
commit 8c7a6fc81cceaba3d9c428baec336639b0d91205
Author: Xinrong Meng <[email protected]>
AuthorDate: Wed May 31 09:19:21 2023 +0900
[SPARK-43886][PYTHON] Accept generics tuple as typing hints of Pandas UDF
### What changes were proposed in this pull request?
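Accept the built-in generic `tuple` (e.g. `tuple[pd.Series, pd.DataFrame]`) in type hints of Pandas UDFs, in addition to the already supported `typing.Tuple`.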
### Why are the changes needed?
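Adapt to [PEP 585](https://peps.python.org/pep-0585/). Since Python 3.9, built-in collection types such as `tuple` can be subscripted directly in type hints, so `tuple[...]` should be treated the same as `typing.Tuple[...]` when inferring the Pandas UDF type.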
### Does this PR introduce _any_ user-facing change?
Yes. The built-in `tuple` is now accepted in type hints of Pandas UDFs.
FROM
```py
>>> @pandas_udf("long")
... def multiply(iterator: Iterator[tuple[pd.Series, pd.DataFrame]]) -> Iterator[pd.Series]:
...     for s1, df in iterator:
...         yield s1 * df.v
...
Traceback (most recent call last):
...
    raise PySparkNotImplementedError(
pyspark.errors.exceptions.base.PySparkNotImplementedError: [UNSUPPORTED_SIGNATURE] Unsupported signature: (iterator: Iterator[tuple[pandas.core.series.Series, pandas.core.frame.DataFrame]]) -> Iterator[pandas.core.series.Series].
```
TO
```py
>>> @pandas_udf("long")
... def multiply(iterator: Iterator[tuple[pd.Series, pd.DataFrame]]) -> Iterator[pd.Series]:
...     for s1, df in iterator:
...         yield s1 * df.v
...
>>> multiply._unwrapped.evalType
204  # SQL_SCALAR_PANDAS_ITER_UDF
```
### How was this patch tested?
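Unit tests: added `test_type_annotation_tuple_generics` to `test_pandas_udf_typehints.py`. It is skipped on Python < 3.9, because built-in generics require Python 3.9.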
Closes #41388 from xinrong-meng/tuple.
Authored-by: Xinrong Meng <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
python/pyspark/sql/pandas/typehints.py | 2 +-
.../sql/tests/pandas/test_pandas_udf_typehints.py | 24 ++++++++++++++++++++++
2 files changed, 25 insertions(+), 1 deletion(-)
diff --git a/python/pyspark/sql/pandas/typehints.py
b/python/pyspark/sql/pandas/typehints.py
index 29ac81af944..f0c13e66a63 100644
--- a/python/pyspark/sql/pandas/typehints.py
+++ b/python/pyspark/sql/pandas/typehints.py
@@ -145,7 +145,7 @@ def check_tuple_annotation(
# Tuple has _name but other types have __name__
# Check if the name is Tuple first. After that, check the generic types.
name = getattr(annotation, "_name", getattr(annotation, "__name__", None))
- return name == "Tuple" and (
+ return name in ("Tuple", "tuple") and (
parameter_check_func is None or all(map(parameter_check_func,
annotation.__args__))
)
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf_typehints.py
b/python/pyspark/sql/tests/pandas/test_pandas_udf_typehints.py
index 3cdf83e2d06..bfb874ffe53 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_udf_typehints.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_udf_typehints.py
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
+import sys
import unittest
from inspect import signature
from typing import Union, Iterator, Tuple, cast, get_type_hints
@@ -113,6 +114,29 @@ class PandasUDFTypeHintsTests(ReusedSQLTestCase):
infer_eval_type(signature(func), get_type_hints(func)),
PandasUDFType.SCALAR_ITER
)
+ @unittest.skipIf(sys.version_info < (3, 9), "Type hinting generics require
Python 3.9.")
+ def test_type_annotation_tuple_generics(self):
+ def func(iter: Iterator[tuple[pd.DataFrame, pd.Series]]) ->
Iterator[pd.DataFrame]:
+ pass
+
+ self.assertEqual(
+ infer_eval_type(signature(func), get_type_hints(func)),
PandasUDFType.SCALAR_ITER
+ )
+
+ def func(iter: Iterator[tuple[pd.DataFrame, ...]]) ->
Iterator[pd.Series]:
+ pass
+
+ self.assertEqual(
+ infer_eval_type(signature(func), get_type_hints(func)),
PandasUDFType.SCALAR_ITER
+ )
+
+ def func(iter: Iterator[tuple[Union[pd.DataFrame, pd.Series], ...]])
-> Iterator[pd.Series]:
+ pass
+
+ self.assertEqual(
+ infer_eval_type(signature(func), get_type_hints(func)),
PandasUDFType.SCALAR_ITER
+ )
+
def test_type_annotation_group_agg(self):
def func(col: pd.Series) -> str:
pass
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org