This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new f2a6c97d718 [SPARK-44876][PYTHON][FOLLOWUP] Fix Arrow-optimized Python
UDF to delay wrapping the function with fail_on_stopiteration
f2a6c97d718 is described below
commit f2a6c97d718839896343feaa520396f328f2f866
Author: Takuya UESHIN <[email protected]>
AuthorDate: Mon Sep 4 15:24:33 2023 +0800
[SPARK-44876][PYTHON][FOLLOWUP] Fix Arrow-optimized Python UDF to delay
wrapping the function with fail_on_stopiteration
### What changes were proposed in this pull request?
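Fixes the Arrow-optimized Python UDF to delay wrapping with
`fail_on_stopiteration`: the wrapper is now applied to the inner `evaluate`
function instead of to the chained user function in `read_single_udf`.
Also removes the unnecessary `verify_result_type` check.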
### Why are the changes needed?
For the Arrow-optimized Python UDF, `fail_on_stopiteration` only needs to be
applied to the wrapping `evaluate` function, which avoids unnecessary overhead.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Added a related test, `test_raise_stop_iteration`, to `test_udf.py`.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #42784 from ueshin/issues/SPARK-44876/fail_on_stopiteration.
Authored-by: Takuya UESHIN <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
python/pyspark/sql/tests/test_udf.py | 15 +++++++++++++++
python/pyspark/worker.py | 22 ++++++----------------
2 files changed, 21 insertions(+), 16 deletions(-)
diff --git a/python/pyspark/sql/tests/test_udf.py
b/python/pyspark/sql/tests/test_udf.py
index 32ea05bd00a..1f895b1780b 100644
--- a/python/pyspark/sql/tests/test_udf.py
+++ b/python/pyspark/sql/tests/test_udf.py
@@ -1005,6 +1005,21 @@ class BaseUDFTestsMixin(object):
with self.subTest(with_b=True, query_no=i):
assertDataFrameEqual(df, [Row(0), Row(101)])
+ def test_raise_stop_iteration(self):
+ @udf("int")
+ def test_udf(a):
+ if a < 5:
+ return a
+ else:
+ raise StopIteration()
+
+ assertDataFrameEqual(
+ self.spark.range(5).select(test_udf(col("id"))), [Row(i) for i in
range(5)]
+ )
+
+ with self.assertRaisesRegex(PythonException, "StopIteration"):
+ self.spark.range(10).select(test_udf(col("id"))).show()
+
class UDFTests(BaseUDFTestsMixin, ReusedSQLTestCase):
@classmethod
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index fff99f1de3d..92bc622775b 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -139,6 +139,7 @@ def wrap_arrow_batch_udf(f, return_type):
elif type(return_type) == BinaryType:
result_func = lambda r: bytes(r) if r is not None else r # noqa: E731
+ @fail_on_stopiteration
def evaluate(*args: pd.Series, **kwargs: pd.Series) -> pd.Series:
keys = list(kwargs.keys())
len_args = len(args)
@@ -151,18 +152,6 @@ def wrap_arrow_batch_udf(f, return_type):
]
)
- def verify_result_type(result):
- if not hasattr(result, "__len__"):
- pd_type = "pandas.DataFrame" if type(return_type) == StructType
else "pandas.Series"
- raise PySparkTypeError(
- error_class="UDF_RETURN_TYPE",
- message_parameters={
- "expected": pd_type,
- "actual": type(result).__name__,
- },
- )
- return result
-
def verify_result_length(result, length):
if len(result) != length:
raise PySparkRuntimeError(
@@ -175,9 +164,7 @@ def wrap_arrow_batch_udf(f, return_type):
return result
return lambda *a, **kw: (
- verify_result_length(
- verify_result_type(evaluate(*a, **kw)), len((list(a) +
list(kw.values()))[0])
- ),
+ verify_result_length(evaluate(*a, **kw), len((list(a) +
list(kw.values()))[0])),
arrow_return_type,
)
@@ -562,7 +549,10 @@ def read_single_udf(pickleSer, infile, eval_type,
runner_conf, udf_index):
else:
chained_func = chain(chained_func, f)
- if eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF:
+ if eval_type in (
+ PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF,
+ PythonEvalType.SQL_ARROW_BATCHED_UDF,
+ ):
func = chained_func
else:
# make sure StopIteration's raised in the user code are not ignored
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]