This is an automated email from the ASF dual-hosted git repository. gurwls223 pushed a commit to branch branch-3.5 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.5 by this push: new 646562027b8 [SPARK-44876][PYTHON] Fix Arrow-optimized Python UDF on Spark Connect 646562027b8 is described below commit 646562027b80e72c987ce9e710eea6c71cefdfbb Author: Takuya UESHIN <ues...@databricks.com> AuthorDate: Mon Aug 21 09:20:18 2023 +0900 [SPARK-44876][PYTHON] Fix Arrow-optimized Python UDF on Spark Connect ### What changes were proposed in this pull request? Fixes Arrow-optimized Python UDF on Spark Connect. Also enables the missing test `pyspark.sql.tests.connect.test_parity_arrow_python_udf`. ### Why are the changes needed? `pyspark.sql.tests.connect.test_parity_arrow_python_udf` is not listed in `dev/sparktestsupport/modules.py`, and it fails when run manually. ``` ====================================================================== ERROR [0.072s]: test_register (pyspark.sql.tests.connect.test_parity_arrow_python_udf.ArrowPythonUDFParityTests) ---------------------------------------------------------------------- Traceback (most recent call last): ... pyspark.errors.exceptions.base.PySparkRuntimeError: [SCHEMA_MISMATCH_FOR_PANDAS_UDF] Result vector from pandas_udf was not the required length: expected 1, got 38. ``` The failure had not been caught because the test was missing from the `modules.py` file. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Existing tests. Closes #42568 from ueshin/issues/SPARK-44876/test_parity_arrow_python_udf. 
Authored-by: Takuya UESHIN <ues...@databricks.com> Signed-off-by: Hyukjin Kwon <gurwls...@apache.org> (cherry picked from commit 75c0b8b61ca53c3763c8e43e83b93b34688ea246) Signed-off-by: Hyukjin Kwon <gurwls...@apache.org> --- dev/sparktestsupport/modules.py | 1 + python/pyspark/sql/connect/udf.py | 20 ++++++------- python/pyspark/sql/udf.py | 60 +++++++-------------------------------- python/pyspark/worker.py | 54 +++++++++++++++++++++++++++++++++-- 4 files changed, 73 insertions(+), 62 deletions(-) diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 4435d19810b..25309b75cce 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -842,6 +842,7 @@ pyspark_connect = Module( "pyspark.sql.tests.connect.test_connect_function", "pyspark.sql.tests.connect.test_connect_column", "pyspark.sql.tests.connect.test_parity_arrow", + "pyspark.sql.tests.connect.test_parity_arrow_python_udf", "pyspark.sql.tests.connect.test_parity_datasources", "pyspark.sql.tests.connect.test_parity_errors", "pyspark.sql.tests.connect.test_parity_catalog", diff --git a/python/pyspark/sql/connect/udf.py b/python/pyspark/sql/connect/udf.py index eb0541b9369..2636777e5f6 100644 --- a/python/pyspark/sql/connect/udf.py +++ b/python/pyspark/sql/connect/udf.py @@ -54,8 +54,6 @@ def _create_py_udf( returnType: "DataTypeOrString", useArrow: Optional[bool] = None, ) -> "UserDefinedFunctionLike": - from pyspark.sql.udf import _create_arrow_py_udf - if useArrow is None: is_arrow_enabled = False try: @@ -74,22 +72,22 @@ def _create_py_udf( else: is_arrow_enabled = useArrow - regular_udf = _create_udf(f, returnType, PythonEvalType.SQL_BATCHED_UDF) - try: - is_func_with_args = len(getfullargspec(f).args) > 0 - except TypeError: - is_func_with_args = False + eval_type: int = PythonEvalType.SQL_BATCHED_UDF + if is_arrow_enabled: + try: + is_func_with_args = len(getfullargspec(f).args) > 0 + except TypeError: + is_func_with_args = False if 
is_func_with_args: - return _create_arrow_py_udf(regular_udf) + eval_type = PythonEvalType.SQL_ARROW_BATCHED_UDF else: warnings.warn( "Arrow optimization for Python UDFs cannot be enabled.", UserWarning, ) - return regular_udf - else: - return regular_udf + + return _create_udf(f, returnType, eval_type) def _create_udf( diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index f25f525e33b..7d7784dd522 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -32,7 +32,6 @@ from pyspark.profiler import Profiler from pyspark.rdd import _prepare_for_python_RDD, PythonEvalType from pyspark.sql.column import Column, _to_java_column, _to_java_expr, _to_seq from pyspark.sql.types import ( - BinaryType, DataType, StringType, StructType, @@ -131,58 +130,24 @@ def _create_py_udf( else: is_arrow_enabled = useArrow - regular_udf = _create_udf(f, returnType, PythonEvalType.SQL_BATCHED_UDF) - try: - is_func_with_args = len(getfullargspec(f).args) > 0 - except TypeError: - is_func_with_args = False + eval_type: int = PythonEvalType.SQL_BATCHED_UDF + if is_arrow_enabled: + try: + is_func_with_args = len(getfullargspec(f).args) > 0 + except TypeError: + is_func_with_args = False if is_func_with_args: - return _create_arrow_py_udf(regular_udf) + require_minimum_pandas_version() + require_minimum_pyarrow_version() + eval_type = PythonEvalType.SQL_ARROW_BATCHED_UDF else: warnings.warn( "Arrow optimization for Python UDFs cannot be enabled.", UserWarning, ) - return regular_udf - else: - return regular_udf - - -def _create_arrow_py_udf(regular_udf): # type: ignore - """Create an Arrow-optimized Python UDF out of a regular Python UDF.""" - require_minimum_pandas_version() - require_minimum_pyarrow_version() - - import pandas as pd - from pyspark.sql.pandas.functions import _create_pandas_udf - f = regular_udf.func - return_type = regular_udf.returnType - - # "result_func" ensures the result of a Python UDF to be consistent with/without Arrow - # 
optimization. - # Otherwise, an Arrow-optimized Python UDF raises "pyarrow.lib.ArrowTypeError: Expected a - # string or bytes dtype, got ..." whereas a non-Arrow-optimized Python UDF returns - # successfully. - result_func = lambda pdf: pdf # noqa: E731 - if type(return_type) == StringType: - result_func = lambda r: str(r) if r is not None else r # noqa: E731 - elif type(return_type) == BinaryType: - result_func = lambda r: bytes(r) if r is not None else r # noqa: E731 - - def vectorized_udf(*args: pd.Series) -> pd.Series: - return pd.Series(result_func(f(*a)) for a in zip(*args)) - - # Regular UDFs can take callable instances too. - vectorized_udf.__name__ = f.__name__ if hasattr(f, "__name__") else f.__class__.__name__ - vectorized_udf.__module__ = f.__module__ if hasattr(f, "__module__") else f.__class__.__module__ - vectorized_udf.__doc__ = f.__doc__ - pudf = _create_pandas_udf(vectorized_udf, return_type, PythonEvalType.SQL_ARROW_BATCHED_UDF) - # Keep the attributes as if this is a regular Python UDF. 
- pudf.func = f - pudf.returnType = return_type - return pudf + return _create_udf(f, returnType, eval_type) class UserDefinedFunction: @@ -637,10 +602,7 @@ class UDFRegistration: evalType=f.evalType, deterministic=f.deterministic, ) - if f.evalType == PythonEvalType.SQL_ARROW_BATCHED_UDF: - register_udf = _create_arrow_py_udf(source_udf)._unwrapped - else: - register_udf = source_udf._unwrapped # type: ignore[attr-defined] + register_udf = source_udf._unwrapped # type: ignore[attr-defined] return_udf = register_udf else: if returnType is None: diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 9463e6d4484..edbfad4a5dc 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -63,7 +63,7 @@ from pyspark.sql.pandas.serializers import ( ApplyInPandasWithStateSerializer, ) from pyspark.sql.pandas.types import to_arrow_type -from pyspark.sql.types import StructType, _parse_datatype_json_string +from pyspark.sql.types import BinaryType, StringType, StructType, _parse_datatype_json_string from pyspark.util import fail_on_stopiteration, try_simplify_traceback from pyspark import shuffle from pyspark.errors import PySparkRuntimeError, PySparkTypeError @@ -138,6 +138,54 @@ def wrap_scalar_pandas_udf(f, return_type): ) +def wrap_arrow_batch_udf(f, return_type): + import pandas as pd + + arrow_return_type = to_arrow_type(return_type) + + # "result_func" ensures the result of a Python UDF to be consistent with/without Arrow + # optimization. + # Otherwise, an Arrow-optimized Python UDF raises "pyarrow.lib.ArrowTypeError: Expected a + # string or bytes dtype, got ..." whereas a non-Arrow-optimized Python UDF returns + # successfully. 
+ result_func = lambda pdf: pdf # noqa: E731 + if type(return_type) == StringType: + result_func = lambda r: str(r) if r is not None else r # noqa: E731 + elif type(return_type) == BinaryType: + result_func = lambda r: bytes(r) if r is not None else r # noqa: E731 + + def evaluate(*args: pd.Series) -> pd.Series: + return pd.Series(result_func(f(*a)) for a in zip(*args)) + + def verify_result_type(result): + if not hasattr(result, "__len__"): + pd_type = "pandas.DataFrame" if type(return_type) == StructType else "pandas.Series" + raise PySparkTypeError( + error_class="UDF_RETURN_TYPE", + message_parameters={ + "expected": pd_type, + "actual": type(result).__name__, + }, + ) + return result + + def verify_result_length(result, length): + if len(result) != length: + raise PySparkRuntimeError( + error_class="SCHEMA_MISMATCH_FOR_PANDAS_UDF", + message_parameters={ + "expected": str(length), + "actual": str(len(result)), + }, + ) + return result + + return lambda *a: ( + verify_result_length(verify_result_type(evaluate(*a)), len(a[0])), + arrow_return_type, + ) + + def wrap_pandas_batch_iter_udf(f, return_type): arrow_return_type = to_arrow_type(return_type) iter_type_label = "pandas.DataFrame" if type(return_type) == StructType else "pandas.Series" @@ -503,8 +551,10 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index): func = fail_on_stopiteration(chained_func) # the last returnType will be the return type of UDF - if eval_type in (PythonEvalType.SQL_SCALAR_PANDAS_UDF, PythonEvalType.SQL_ARROW_BATCHED_UDF): + if eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF: return arg_offsets, wrap_scalar_pandas_udf(func, return_type) + elif eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF: + return arg_offsets, wrap_arrow_batch_udf(func, return_type) elif eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF: return arg_offsets, wrap_pandas_batch_iter_udf(func, return_type) elif eval_type == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF: 
--------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org