This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.5 by this push:
     new 646562027b8 [SPARK-44876][PYTHON] Fix Arrow-optimized Python UDF on Spark Connect
646562027b8 is described below

commit 646562027b80e72c987ce9e710eea6c71cefdfbb
Author: Takuya UESHIN <ues...@databricks.com>
AuthorDate: Mon Aug 21 09:20:18 2023 +0900

    [SPARK-44876][PYTHON] Fix Arrow-optimized Python UDF on Spark Connect
    
    ### What changes were proposed in this pull request?
    
    Fixes Arrow-optimized Python UDF on Spark Connect.
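
    For context, a minimal sketch of the kind of UDF this touches (the function, data, and connection string are illustrative, not taken from the tests):

    ```
    from pyspark.sql import SparkSession
    from pyspark.sql.functions import udf

    # Illustrative Spark Connect session; the endpoint is a placeholder.
    spark = SparkSession.builder.remote("sc://localhost").getOrCreate()

    # An Arrow-optimized Python UDF; with this change it keeps the
    # SQL_ARROW_BATCHED_UDF eval type and is wrapped on the worker side.
    @udf(returnType="string", useArrow=True)
    def to_upper(s):
        return s.upper() if s is not None else None

    spark.range(3).selectExpr("CAST(id AS STRING) AS s").select(to_upper("s")).show()
    ```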
    
    Also enables the missing test `pyspark.sql.tests.connect.test_parity_arrow_python_udf`.
    
    ### Why are the changes needed?
    
    `pyspark.sql.tests.connect.test_parity_arrow_python_udf` is not listed in `dev/sparktestsupport/modules.py`, and it fails when run manually.
    
    ```
    ======================================================================
    ERROR [0.072s]: test_register (pyspark.sql.tests.connect.test_parity_arrow_python_udf.ArrowPythonUDFParityTests)
    ----------------------------------------------------------------------
    Traceback (most recent call last):
    ...
    pyspark.errors.exceptions.base.PySparkRuntimeError: [SCHEMA_MISMATCH_FOR_PANDAS_UDF] Result vector from pandas_udf was not the required length: expected 1, got 38.
    ```
    
    The failure had not been caught because the test was missing from the `modules.py` file.
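
    For reference, a rough sketch of the register-then-query path that `test_register` exercises (assumes an existing Spark Connect session `spark`; the exact test body may differ):

    ```
    from pyspark.sql.functions import udf

    # Register an Arrow-optimized Python UDF and call it from SQL.
    # Before this fix, a path like this could fail on Spark Connect with
    # SCHEMA_MISMATCH_FOR_PANDAS_UDF, as in the traceback above.
    add_one = udf(lambda x: x + 1, returnType="int", useArrow=True)
    spark.udf.register("add_one", add_one)
    spark.sql("SELECT add_one(id) FROM range(3)").show()
    ```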
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Existing tests.
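
    For manual verification, the re-enabled module can also be run directly along these lines (a sketch; assumes a PySpark dev environment with pandas and pyarrow installed):

    ```
    import unittest

    # Load and run the newly listed parity test module by name.
    suite = unittest.defaultTestLoader.loadTestsFromName(
        "pyspark.sql.tests.connect.test_parity_arrow_python_udf"
    )
    unittest.TextTestRunner(verbosity=2).run(suite)
    ```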
    
    Closes #42568 from ueshin/issues/SPARK-44876/test_parity_arrow_python_udf.
    
    Authored-by: Takuya UESHIN <ues...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
    (cherry picked from commit 75c0b8b61ca53c3763c8e43e83b93b34688ea246)
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 dev/sparktestsupport/modules.py   |  1 +
 python/pyspark/sql/connect/udf.py | 20 ++++++-------
 python/pyspark/sql/udf.py         | 60 +++++++--------------------------------
 python/pyspark/worker.py          | 54 +++++++++++++++++++++++++++++++++--
 4 files changed, 73 insertions(+), 62 deletions(-)

diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 4435d19810b..25309b75cce 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -842,6 +842,7 @@ pyspark_connect = Module(
         "pyspark.sql.tests.connect.test_connect_function",
         "pyspark.sql.tests.connect.test_connect_column",
         "pyspark.sql.tests.connect.test_parity_arrow",
+        "pyspark.sql.tests.connect.test_parity_arrow_python_udf",
         "pyspark.sql.tests.connect.test_parity_datasources",
         "pyspark.sql.tests.connect.test_parity_errors",
         "pyspark.sql.tests.connect.test_parity_catalog",
diff --git a/python/pyspark/sql/connect/udf.py b/python/pyspark/sql/connect/udf.py
index eb0541b9369..2636777e5f6 100644
--- a/python/pyspark/sql/connect/udf.py
+++ b/python/pyspark/sql/connect/udf.py
@@ -54,8 +54,6 @@ def _create_py_udf(
     returnType: "DataTypeOrString",
     useArrow: Optional[bool] = None,
 ) -> "UserDefinedFunctionLike":
-    from pyspark.sql.udf import _create_arrow_py_udf
-
     if useArrow is None:
         is_arrow_enabled = False
         try:
@@ -74,22 +72,22 @@ def _create_py_udf(
     else:
         is_arrow_enabled = useArrow
 
-    regular_udf = _create_udf(f, returnType, PythonEvalType.SQL_BATCHED_UDF)
-    try:
-        is_func_with_args = len(getfullargspec(f).args) > 0
-    except TypeError:
-        is_func_with_args = False
+    eval_type: int = PythonEvalType.SQL_BATCHED_UDF
+
     if is_arrow_enabled:
+        try:
+            is_func_with_args = len(getfullargspec(f).args) > 0
+        except TypeError:
+            is_func_with_args = False
         if is_func_with_args:
-            return _create_arrow_py_udf(regular_udf)
+            eval_type = PythonEvalType.SQL_ARROW_BATCHED_UDF
         else:
             warnings.warn(
                 "Arrow optimization for Python UDFs cannot be enabled.",
                 UserWarning,
             )
-            return regular_udf
-    else:
-        return regular_udf
+
+    return _create_udf(f, returnType, eval_type)
 
 
 def _create_udf(
diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py
index f25f525e33b..7d7784dd522 100644
--- a/python/pyspark/sql/udf.py
+++ b/python/pyspark/sql/udf.py
@@ -32,7 +32,6 @@ from pyspark.profiler import Profiler
 from pyspark.rdd import _prepare_for_python_RDD, PythonEvalType
 from pyspark.sql.column import Column, _to_java_column, _to_java_expr, _to_seq
 from pyspark.sql.types import (
-    BinaryType,
     DataType,
     StringType,
     StructType,
@@ -131,58 +130,24 @@ def _create_py_udf(
     else:
         is_arrow_enabled = useArrow
 
-    regular_udf = _create_udf(f, returnType, PythonEvalType.SQL_BATCHED_UDF)
-    try:
-        is_func_with_args = len(getfullargspec(f).args) > 0
-    except TypeError:
-        is_func_with_args = False
+    eval_type: int = PythonEvalType.SQL_BATCHED_UDF
+
     if is_arrow_enabled:
+        try:
+            is_func_with_args = len(getfullargspec(f).args) > 0
+        except TypeError:
+            is_func_with_args = False
         if is_func_with_args:
-            return _create_arrow_py_udf(regular_udf)
+            require_minimum_pandas_version()
+            require_minimum_pyarrow_version()
+            eval_type = PythonEvalType.SQL_ARROW_BATCHED_UDF
         else:
             warnings.warn(
                 "Arrow optimization for Python UDFs cannot be enabled.",
                 UserWarning,
             )
-            return regular_udf
-    else:
-        return regular_udf
-
-
-def _create_arrow_py_udf(regular_udf):  # type: ignore
-    """Create an Arrow-optimized Python UDF out of a regular Python UDF."""
-    require_minimum_pandas_version()
-    require_minimum_pyarrow_version()
-
-    import pandas as pd
-    from pyspark.sql.pandas.functions import _create_pandas_udf
 
-    f = regular_udf.func
-    return_type = regular_udf.returnType
-
-    # "result_func" ensures the result of a Python UDF to be consistent 
with/without Arrow
-    # optimization.
-    # Otherwise, an Arrow-optimized Python UDF raises 
"pyarrow.lib.ArrowTypeError: Expected a
-    # string or bytes dtype, got ..." whereas a non-Arrow-optimized Python UDF 
returns
-    # successfully.
-    result_func = lambda pdf: pdf  # noqa: E731
-    if type(return_type) == StringType:
-        result_func = lambda r: str(r) if r is not None else r  # noqa: E731
-    elif type(return_type) == BinaryType:
-        result_func = lambda r: bytes(r) if r is not None else r  # noqa: E731
-
-    def vectorized_udf(*args: pd.Series) -> pd.Series:
-        return pd.Series(result_func(f(*a)) for a in zip(*args))
-
-    # Regular UDFs can take callable instances too.
-    vectorized_udf.__name__ = f.__name__ if hasattr(f, "__name__") else f.__class__.__name__
-    vectorized_udf.__module__ = f.__module__ if hasattr(f, "__module__") else f.__class__.__module__
-    vectorized_udf.__doc__ = f.__doc__
-    pudf = _create_pandas_udf(vectorized_udf, return_type, PythonEvalType.SQL_ARROW_BATCHED_UDF)
-    # Keep the attributes as if this is a regular Python UDF.
-    pudf.func = f
-    pudf.returnType = return_type
-    return pudf
+    return _create_udf(f, returnType, eval_type)
 
 
 class UserDefinedFunction:
@@ -637,10 +602,7 @@ class UDFRegistration:
                 evalType=f.evalType,
                 deterministic=f.deterministic,
             )
-            if f.evalType == PythonEvalType.SQL_ARROW_BATCHED_UDF:
-                register_udf = _create_arrow_py_udf(source_udf)._unwrapped
-            else:
-            register_udf = source_udf._unwrapped  # type: ignore[attr-defined]
+            register_udf = source_udf._unwrapped  # type: ignore[attr-defined]
             return_udf = register_udf
         else:
             if returnType is None:
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 9463e6d4484..edbfad4a5dc 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -63,7 +63,7 @@ from pyspark.sql.pandas.serializers import (
     ApplyInPandasWithStateSerializer,
 )
 from pyspark.sql.pandas.types import to_arrow_type
-from pyspark.sql.types import StructType, _parse_datatype_json_string
+from pyspark.sql.types import BinaryType, StringType, StructType, _parse_datatype_json_string
 from pyspark.util import fail_on_stopiteration, try_simplify_traceback
 from pyspark import shuffle
 from pyspark.errors import PySparkRuntimeError, PySparkTypeError
@@ -138,6 +138,54 @@ def wrap_scalar_pandas_udf(f, return_type):
     )
 
 
+def wrap_arrow_batch_udf(f, return_type):
+    import pandas as pd
+
+    arrow_return_type = to_arrow_type(return_type)
+
+    # "result_func" ensures the result of a Python UDF to be consistent 
with/without Arrow
+    # optimization.
+    # Otherwise, an Arrow-optimized Python UDF raises 
"pyarrow.lib.ArrowTypeError: Expected a
+    # string or bytes dtype, got ..." whereas a non-Arrow-optimized Python UDF 
returns
+    # successfully.
+    result_func = lambda pdf: pdf  # noqa: E731
+    if type(return_type) == StringType:
+        result_func = lambda r: str(r) if r is not None else r  # noqa: E731
+    elif type(return_type) == BinaryType:
+        result_func = lambda r: bytes(r) if r is not None else r  # noqa: E731
+
+    def evaluate(*args: pd.Series) -> pd.Series:
+        return pd.Series(result_func(f(*a)) for a in zip(*args))
+
+    def verify_result_type(result):
+        if not hasattr(result, "__len__"):
+            pd_type = "pandas.DataFrame" if type(return_type) == StructType else "pandas.Series"
+            raise PySparkTypeError(
+                error_class="UDF_RETURN_TYPE",
+                message_parameters={
+                    "expected": pd_type,
+                    "actual": type(result).__name__,
+                },
+            )
+        return result
+
+    def verify_result_length(result, length):
+        if len(result) != length:
+            raise PySparkRuntimeError(
+                error_class="SCHEMA_MISMATCH_FOR_PANDAS_UDF",
+                message_parameters={
+                    "expected": str(length),
+                    "actual": str(len(result)),
+                },
+            )
+        return result
+
+    return lambda *a: (
+        verify_result_length(verify_result_type(evaluate(*a)), len(a[0])),
+        arrow_return_type,
+    )
+
+
 def wrap_pandas_batch_iter_udf(f, return_type):
     arrow_return_type = to_arrow_type(return_type)
     iter_type_label = "pandas.DataFrame" if type(return_type) == StructType else "pandas.Series"
@@ -503,8 +551,10 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index):
         func = fail_on_stopiteration(chained_func)
 
     # the last returnType will be the return type of UDF
-    if eval_type in (PythonEvalType.SQL_SCALAR_PANDAS_UDF, PythonEvalType.SQL_ARROW_BATCHED_UDF):
+    if eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF:
         return arg_offsets, wrap_scalar_pandas_udf(func, return_type)
+    elif eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF:
+        return arg_offsets, wrap_arrow_batch_udf(func, return_type)
     elif eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF:
         return arg_offsets, wrap_pandas_batch_iter_udf(func, return_type)
     elif eval_type == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF:


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
