This is an automated email from the ASF dual-hosted git repository.

xinrong pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 32ab341071a [SPARK-43412][PYTHON][CONNECT] Introduce `SQL_ARROW_BATCHED_UDF` EvalType for Arrow-optimized Python UDFs
32ab341071a is described below

commit 32ab341071aa69917f820baf5f61668c2455f1db
Author: Xinrong Meng <xinr...@apache.org>
AuthorDate: Wed May 10 13:09:15 2023 -0700

    [SPARK-43412][PYTHON][CONNECT] Introduce `SQL_ARROW_BATCHED_UDF` EvalType for Arrow-optimized Python UDFs
    
    ### What changes were proposed in this pull request?
    Introduce `SQL_ARROW_BATCHED_UDF` EvalType for Arrow-optimized Python UDFs.
    
    An EvalType is used to uniquely identify a UDF type in PySpark.
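
    For context, here is a minimal sketch of the relevant `pyspark.rdd.PythonEvalType` constants after this change (values taken from the diff below):

    ```py
    class PythonEvalType:
        NON_UDF = 0

        SQL_BATCHED_UDF = 100        # pickled Python UDF
        SQL_ARROW_BATCHED_UDF = 101  # Arrow-optimized Python UDF (new in this PR)

        SQL_SCALAR_PANDAS_UDF = 200  # Pandas UDF
        SQL_GROUPED_MAP_PANDAS_UDF = 201
    ```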
    
    ### Why are the changes needed?
    We are about to improve support for nested, non-atomic input/output in Arrow-optimized Python UDFs.
    
    Currently, however, an Arrow-optimized Python UDF shares its EvalType with a pickled Python UDF, while sharing its implementation with a Pandas UDF.
    
    Introducing a dedicated EvalType lets us isolate those changes to Arrow-optimized Python UDFs.
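
    As a simplified sketch, mirroring the `python/pyspark/worker.py` hunk below, the worker can now route Arrow-optimized Python UDFs through the scalar Pandas UDF wrapper while the pickled-UDF path stays untouched:

    ```py
    # Arrow-optimized Python UDFs reuse the scalar Pandas UDF wrapper
    # but are selected by their own, dedicated EvalType.
    if eval_type in (PythonEvalType.SQL_SCALAR_PANDAS_UDF, PythonEvalType.SQL_ARROW_BATCHED_UDF):
        return arg_offsets, wrap_scalar_pandas_udf(func, return_type)
    ```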
    
    This PR is also a prerequisite for registering an Arrow-optimized Python UDF.
    
    ### Does this PR introduce _any_ user-facing change?
    No. There are no user-facing behavior or result changes for Arrow-optimized Python UDFs.
    
    The `evalType` attribute, designed mainly for internal use, changes as shown below:
    
    ```py
    >>> udf(lambda x: str(x), useArrow=True).evalType == PythonEvalType.SQL_ARROW_BATCHED_UDF
    True
    
    # whereas
    
    >>> udf(lambda x: str(x), useArrow=False).evalType == PythonEvalType.SQL_BATCHED_UDF
    True
    ```
    
    ### How was this patch tested?
    A new unit test `test_eval_type` and existing tests.
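
    As a hedged usage sketch (assuming an active SparkSession bound to `spark` and PyArrow installed), the user-facing API is unchanged; only the internal `evalType` differs:

    ```py
    from pyspark.sql.functions import udf

    to_str = udf(lambda x: str(x), useArrow=True)  # Arrow-optimized Python UDF
    spark.range(3).select(to_str("id")).show()
    ```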
    
    Closes #41053 from xinrong-meng/evalTypeArrowPyUDF.
    
    Authored-by: Xinrong Meng <xinr...@apache.org>
    Signed-off-by: Xinrong Meng <xinr...@apache.org>
---
 .../main/scala/org/apache/spark/api/python/PythonRunner.scala    | 2 ++
 python/pyspark/rdd.py                                            | 3 ++-
 python/pyspark/sql/_typing.pyi                                   | 1 +
 python/pyspark/sql/connect/functions.py                          | 7 +------
 python/pyspark/sql/connect/udf.py                                | 3 +--
 python/pyspark/sql/functions.py                                  | 6 +-----
 python/pyspark/sql/pandas/functions.py                           | 3 +++
 python/pyspark/sql/tests/test_arrow_python_udf.py                | 9 +++++++++
 python/pyspark/sql/udf.py                                        | 8 +++-----
 python/pyspark/worker.py                                         | 9 ++++++---
 .../org/apache/spark/sql/catalyst/expressions/PythonUDF.scala    | 1 +
 .../apache/spark/sql/execution/python/ExtractPythonUDFs.scala    | 3 ++-
 12 files changed, 32 insertions(+), 23 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
index 0b420f268ee..912e76005f0 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
@@ -44,6 +44,7 @@ private[spark] object PythonEvalType {
   val NON_UDF = 0
 
   val SQL_BATCHED_UDF = 100
+  val SQL_ARROW_BATCHED_UDF = 101
 
   val SQL_SCALAR_PANDAS_UDF = 200
   val SQL_GROUPED_MAP_PANDAS_UDF = 201
@@ -58,6 +59,7 @@ private[spark] object PythonEvalType {
   def toString(pythonEvalType: Int): String = pythonEvalType match {
     case NON_UDF => "NON_UDF"
     case SQL_BATCHED_UDF => "SQL_BATCHED_UDF"
+    case SQL_ARROW_BATCHED_UDF => "SQL_ARROW_BATCHED_UDF"
     case SQL_SCALAR_PANDAS_UDF => "SQL_SCALAR_PANDAS_UDF"
     case SQL_GROUPED_MAP_PANDAS_UDF => "SQL_GROUPED_MAP_PANDAS_UDF"
     case SQL_GROUPED_AGG_PANDAS_UDF => "SQL_GROUPED_AGG_PANDAS_UDF"
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 13f93fbdad6..e6ef7f6108e 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -110,7 +110,7 @@ if TYPE_CHECKING:
     )
     from pyspark.sql.dataframe import DataFrame
     from pyspark.sql.types import AtomicType, StructType
-    from pyspark.sql._typing import AtomicValue, RowLike, SQLBatchedUDFType
+    from pyspark.sql._typing import AtomicValue, RowLike, SQLArrowBatchedUDFType, SQLBatchedUDFType
 
     from py4j.java_gateway import JavaObject
     from py4j.java_collections import JavaArray
@@ -140,6 +140,7 @@ class PythonEvalType:
     NON_UDF: "NonUDFType" = 0
 
     SQL_BATCHED_UDF: "SQLBatchedUDFType" = 100
+    SQL_ARROW_BATCHED_UDF: "SQLArrowBatchedUDFType" = 101
 
     SQL_SCALAR_PANDAS_UDF: "PandasScalarUDFType" = 200
     SQL_GROUPED_MAP_PANDAS_UDF: "PandasGroupedMapUDFType" = 201
diff --git a/python/pyspark/sql/_typing.pyi b/python/pyspark/sql/_typing.pyi
index 209bb70fadd..aafd916dbb7 100644
--- a/python/pyspark/sql/_typing.pyi
+++ b/python/pyspark/sql/_typing.pyi
@@ -57,6 +57,7 @@ AtomicValue = TypeVar(
 RowLike = TypeVar("RowLike", List[Any], Tuple[Any, ...], pyspark.sql.types.Row)
 
 SQLBatchedUDFType = Literal[100]
+SQLArrowBatchedUDFType = Literal[101]
 
 class SupportsOpen(Protocol):
     def open(self, partition_id: int, epoch_id: int) -> bool: ...
diff --git a/python/pyspark/sql/connect/functions.py b/python/pyspark/sql/connect/functions.py
index 60aa8ae14de..b7d7bc937cf 100644
--- a/python/pyspark/sql/connect/functions.py
+++ b/python/pyspark/sql/connect/functions.py
@@ -2504,8 +2504,6 @@ def udf(
     returnType: "DataTypeOrString" = StringType(),
     useArrow: Optional[bool] = None,
 ) -> Union["UserDefinedFunctionLike", Callable[[Callable[..., Any]], "UserDefinedFunctionLike"]]:
-    from pyspark.rdd import PythonEvalType
-
     if f is None or isinstance(f, (str, DataType)):
         # If DataType has been passed as a positional argument
         # for decorator use it as a returnType
@@ -2513,13 +2511,10 @@ def udf(
         return functools.partial(
             _create_py_udf,
             returnType=return_type,
-            evalType=PythonEvalType.SQL_BATCHED_UDF,
             useArrow=useArrow,
         )
     else:
-        return _create_py_udf(
-            f=f, returnType=returnType, evalType=PythonEvalType.SQL_BATCHED_UDF, useArrow=useArrow
-        )
+        return _create_py_udf(f=f, returnType=returnType, useArrow=useArrow)
 
 
 udf.__doc__ = pysparkfuncs.udf.__doc__
diff --git a/python/pyspark/sql/connect/udf.py b/python/pyspark/sql/connect/udf.py
index 546a8e3bcbd..012c6c0d2d5 100644
--- a/python/pyspark/sql/connect/udf.py
+++ b/python/pyspark/sql/connect/udf.py
@@ -53,7 +53,6 @@ if TYPE_CHECKING:
 def _create_py_udf(
     f: Callable[..., Any],
     returnType: "DataTypeOrString",
-    evalType: int,
     useArrow: Optional[bool] = None,
 ) -> "UserDefinedFunctionLike":
     from pyspark.sql.udf import _create_arrow_py_udf
@@ -68,7 +67,7 @@ def _create_py_udf(
             else useArrow
         )
 
-    regular_udf = _create_udf(f, returnType, evalType)
+    regular_udf = _create_udf(f, returnType, PythonEvalType.SQL_BATCHED_UDF)
     return_type = regular_udf.returnType
     try:
         is_func_with_args = len(getfullargspec(f).args) > 0
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index e035ff5f0a3..e9b71f7d617 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -41,7 +41,6 @@ from py4j.java_gateway import JVMView
 
 from pyspark import SparkContext
 from pyspark.errors import PySparkTypeError, PySparkValueError
-from pyspark.rdd import PythonEvalType
 from pyspark.sql.column import Column, _to_java_column, _to_seq, _create_column_from_literal
 from pyspark.sql.dataframe import DataFrame
 from pyspark.sql.types import ArrayType, DataType, StringType, StructType, _from_numpy_type
@@ -10398,13 +10397,10 @@ def udf(
         return functools.partial(
             _create_py_udf,
             returnType=return_type,
-            evalType=PythonEvalType.SQL_BATCHED_UDF,
             useArrow=useArrow,
         )
     else:
-        return _create_py_udf(
-            f=f, returnType=returnType, evalType=PythonEvalType.SQL_BATCHED_UDF, useArrow=useArrow
-        )
+        return _create_py_udf(f=f, returnType=returnType, useArrow=useArrow)
 
 
 def _test() -> None:
diff --git a/python/pyspark/sql/pandas/functions.py b/python/pyspark/sql/pandas/functions.py
index 09da310979e..b7d381f04c7 100644
--- a/python/pyspark/sql/pandas/functions.py
+++ b/python/pyspark/sql/pandas/functions.py
@@ -416,11 +416,14 @@ def _create_pandas_udf(f, returnType, evalType):
         PythonEvalType.SQL_MAP_ARROW_ITER_UDF,
         PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
         PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE,
+        PythonEvalType.SQL_ARROW_BATCHED_UDF,
     ]:
         # In case of 'SQL_GROUPED_MAP_PANDAS_UDF', deprecation warning is being triggered
         # at `apply` instead.
         # In case of 'SQL_MAP_PANDAS_ITER_UDF', 'SQL_MAP_ARROW_ITER_UDF' and
         # 'SQL_COGROUPED_MAP_PANDAS_UDF', the evaluation type will always be set.
+        # In case of 'SQL_ARROW_BATCHED_UDF', no deprecation warning is required since it is not
+        # exposed to users.
         pass
     elif len(argspec.annotations) > 0:
         try:
diff --git a/python/pyspark/sql/tests/test_arrow_python_udf.py b/python/pyspark/sql/tests/test_arrow_python_udf.py
index 755f5c7d2ab..51112beadec 100644
--- a/python/pyspark/sql/tests/test_arrow_python_udf.py
+++ b/python/pyspark/sql/tests/test_arrow_python_udf.py
@@ -26,6 +26,7 @@ from pyspark.testing.sqlutils import (
     pyarrow_requirement_message,
     ReusedSQLTestCase,
 )
+from pyspark.rdd import PythonEvalType
 
 
 @unittest.skipIf(
@@ -110,6 +111,14 @@ class PythonUDFArrowTestsMixin(BaseUDFTestsMixin):
         )
         self.assertEquals(row_false[0], "[1, 2, 3]")
 
+    def test_eval_type(self):
+        self.assertEquals(
+            udf(lambda x: str(x), useArrow=True).evalType, PythonEvalType.SQL_ARROW_BATCHED_UDF
+        )
+        self.assertEquals(
+            udf(lambda x: str(x), useArrow=False).evalType, PythonEvalType.SQL_BATCHED_UDF
+        )
+
 
 class PythonUDFArrowTests(PythonUDFArrowTestsMixin, ReusedSQLTestCase):
     @classmethod
diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py
index e01e479516c..374e8c1bcbb 100644
--- a/python/pyspark/sql/udf.py
+++ b/python/pyspark/sql/udf.py
@@ -87,7 +87,6 @@ def _create_udf(
 def _create_py_udf(
     f: Callable[..., Any],
     returnType: "DataTypeOrString",
-    evalType: int,
     useArrow: Optional[bool] = None,
 ) -> "UserDefinedFunctionLike":
     """Create a regular/Arrow-optimized Python UDF."""
@@ -129,8 +128,7 @@ def _create_py_udf(
             if useArrow is None
             else useArrow
         )
-
-    regular_udf = _create_udf(f, returnType, evalType)
+    regular_udf = _create_udf(f, returnType, PythonEvalType.SQL_BATCHED_UDF)
     return_type = regular_udf.returnType
     try:
         is_func_with_args = len(getfullargspec(f).args) > 0
@@ -188,11 +186,10 @@ def _create_arrow_py_udf(regular_udf):  # type: ignore
     vectorized_udf.__name__ = f.__name__ if hasattr(f, "__name__") else f.__class__.__name__
     vectorized_udf.__module__ = f.__module__ if hasattr(f, "__module__") else f.__class__.__module__
     vectorized_udf.__doc__ = f.__doc__
-    pudf = _create_pandas_udf(vectorized_udf, return_type, None)
+    pudf = _create_pandas_udf(vectorized_udf, return_type, PythonEvalType.SQL_ARROW_BATCHED_UDF)
     # Keep the attributes as if this is a regular Python UDF.
     pudf.func = f
     pudf.returnType = return_type
-    pudf.evalType = regular_udf.evalType
     return pudf
 
 
@@ -253,6 +250,7 @@ class UserDefinedFunction:
     def returnType(self) -> DataType:
         # This makes sure this is called after SparkContext is initialized.
         # ``_parse_datatype_string`` accesses to JVM for parsing a DDL formatted string.
+        # TODO: PythonEvalType.SQL_BATCHED_UDF
         if self._returnType_placeholder is None:
             if isinstance(self._returnType, DataType):
                 self._returnType_placeholder = self._returnType
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index bc40e5fc4ef..9bd8df077b6 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -419,7 +419,7 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index):
         func = fail_on_stopiteration(chained_func)
 
     # the last returnType will be the return type of UDF
-    if eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF:
+    if eval_type in (PythonEvalType.SQL_SCALAR_PANDAS_UDF, PythonEvalType.SQL_ARROW_BATCHED_UDF):
         return arg_offsets, wrap_scalar_pandas_udf(func, return_type)
     elif eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF:
         return arg_offsets, wrap_batch_iter_udf(func, return_type)
@@ -445,7 +445,8 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index):
         raise ValueError("Unknown eval type: {}".format(eval_type))
 
 
-# Used by SQL_GROUPED_MAP_PANDAS_UDF and SQL_SCALAR_PANDAS_UDF when returning StructType
+# Used by SQL_GROUPED_MAP_PANDAS_UDF and SQL_SCALAR_PANDAS_UDF and SQL_ARROW_BATCHED_UDF when
+# returning StructType
 def assign_cols_by_name(runner_conf):
     return (
         runner_conf.get(
@@ -459,6 +460,7 @@ def read_udfs(pickleSer, infile, eval_type):
     runner_conf = {}
 
     if eval_type in (
+        PythonEvalType.SQL_ARROW_BATCHED_UDF,
         PythonEvalType.SQL_SCALAR_PANDAS_UDF,
         PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
         PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF,
@@ -509,7 +511,8 @@ def read_udfs(pickleSer, infile, eval_type):
             # Scalar Pandas UDF handles struct type arguments as pandas DataFrames instead of
             # pandas Series. See SPARK-27240.
             df_for_struct = (
-                eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF
+                eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF
+                or eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF
                 or eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF
                 or eval_type == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF
             )
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala
index 9533f142ab5..08ffbea5510 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala
@@ -33,6 +33,7 @@ import org.apache.spark.sql.types.{DataType, StructType}
 object PythonUDF {
   private[this] val SCALAR_TYPES = Set(
     PythonEvalType.SQL_BATCHED_UDF,
+    PythonEvalType.SQL_ARROW_BATCHED_UDF,
     PythonEvalType.SQL_SCALAR_PANDAS_UDF,
     PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF
   )
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
index e41f3095d9f..57c3e1ad88e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
@@ -270,7 +270,8 @@ object ExtractPythonUDFs extends Rule[LogicalPlan] {
           val evaluation = evalType match {
             case PythonEvalType.SQL_BATCHED_UDF =>
               BatchEvalPython(validUdfs, resultAttrs, child)
-            case PythonEvalType.SQL_SCALAR_PANDAS_UDF | PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF =>
+            case PythonEvalType.SQL_SCALAR_PANDAS_UDF | PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF
+                 | PythonEvalType.SQL_ARROW_BATCHED_UDF =>
               ArrowEvalPython(validUdfs, resultAttrs, child, evalType)
             case _ =>
               throw new IllegalStateException("Unexpected UDF evalType")

