This is an automated email from the ASF dual-hosted git repository.

xinrong pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 32ab341071a [SPARK-43412][PYTHON][CONNECT] Introduce `SQL_ARROW_BATCHED_UDF` EvalType for Arrow-optimized Python UDFs
32ab341071a is described below

commit 32ab341071aa69917f820baf5f61668c2455f1db
Author: Xinrong Meng <xinr...@apache.org>
AuthorDate: Wed May 10 13:09:15 2023 -0700

    [SPARK-43412][PYTHON][CONNECT] Introduce `SQL_ARROW_BATCHED_UDF` EvalType for Arrow-optimized Python UDFs

    ### What changes were proposed in this pull request?
    Introduce the `SQL_ARROW_BATCHED_UDF` EvalType for Arrow-optimized Python UDFs. An EvalType uniquely identifies a UDF type in PySpark.

    ### Why are the changes needed?
    We are about to improve support for nested, non-atomic input/output types in Arrow-optimized Python UDFs. Currently, however, an Arrow-optimized Python UDF shares its EvalType with a pickled Python UDF while sharing its implementation with a Pandas UDF. A dedicated EvalType lets us isolate the upcoming changes to Arrow-optimized Python UDFs. The PR is also a prerequisite for registering an Arrow-optimized Python UDF.

    ### Does this PR introduce _any_ user-facing change?
    There are no user-facing behavior or result changes for Arrow-optimized Python UDFs. The `evalType` attribute, which is mainly intended for internal use, changes as shown below:

    ```py
    >>> udf(lambda x: str(x), useArrow=True).evalType == PythonEvalType.SQL_ARROW_BATCHED_UDF
    True
    >>> # whereas
    >>> udf(lambda x: str(x), useArrow=False).evalType == PythonEvalType.SQL_BATCHED_UDF
    True
    ```

    ### How was this patch tested?
    A new unit test, `test_eval_type`, plus existing tests.

    Closes #41053 from xinrong-meng/evalTypeArrowPyUDF.

    Authored-by: Xinrong Meng <xinr...@apache.org>
    Signed-off-by: Xinrong Meng <xinr...@apache.org>
---
 .../main/scala/org/apache/spark/api/python/PythonRunner.scala |  2 ++
 python/pyspark/rdd.py                                         |  3 ++-
 python/pyspark/sql/_typing.pyi                                |  1 +
 python/pyspark/sql/connect/functions.py                       |  7 +------
 python/pyspark/sql/connect/udf.py                             |  3 +--
 python/pyspark/sql/functions.py                               |  6 +-----
 python/pyspark/sql/pandas/functions.py                        |  3 +++
 python/pyspark/sql/tests/test_arrow_python_udf.py             |  9 +++++++++
 python/pyspark/sql/udf.py                                     |  8 +++-----
 python/pyspark/worker.py                                      |  9 ++++++---
 .../org/apache/spark/sql/catalyst/expressions/PythonUDF.scala |  1 +
 .../apache/spark/sql/execution/python/ExtractPythonUDFs.scala |  3 ++-
 12 files changed, 32 insertions(+), 23 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
index 0b420f268ee..912e76005f0 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
@@ -44,6 +44,7 @@ private[spark] object PythonEvalType {
   val NON_UDF = 0
 
   val SQL_BATCHED_UDF = 100
+  val SQL_ARROW_BATCHED_UDF = 101
 
   val SQL_SCALAR_PANDAS_UDF = 200
   val SQL_GROUPED_MAP_PANDAS_UDF = 201
@@ -58,6 +59,7 @@ private[spark] object PythonEvalType {
   def toString(pythonEvalType: Int): String = pythonEvalType match {
     case NON_UDF => "NON_UDF"
     case SQL_BATCHED_UDF => "SQL_BATCHED_UDF"
+    case SQL_ARROW_BATCHED_UDF => "SQL_ARROW_BATCHED_UDF"
     case SQL_SCALAR_PANDAS_UDF => "SQL_SCALAR_PANDAS_UDF"
     case SQL_GROUPED_MAP_PANDAS_UDF => "SQL_GROUPED_MAP_PANDAS_UDF"
     case SQL_GROUPED_AGG_PANDAS_UDF => "SQL_GROUPED_AGG_PANDAS_UDF"
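The Python side mirrors these Scala constants one-for-one, as the `python/pyspark/rdd.py` hunk below shows. A minimal illustrative sketch of the registry idea (the constant values are copied from this patch, but the lookup helper is hypothetical, not PySpark API):

```py
# Eval type constants as defined in this patch; the batched-UDF family
# lives in the 100s, the Pandas family in the 200s.
NON_UDF = 0
SQL_BATCHED_UDF = 100        # pickled, row-at-a-time Python UDF
SQL_ARROW_BATCHED_UDF = 101  # Arrow-optimized Python UDF (new here)
SQL_SCALAR_PANDAS_UDF = 200  # scalar Pandas UDF

_NAMES = {
    NON_UDF: "NON_UDF",
    SQL_BATCHED_UDF: "SQL_BATCHED_UDF",
    SQL_ARROW_BATCHED_UDF: "SQL_ARROW_BATCHED_UDF",
    SQL_SCALAR_PANDAS_UDF: "SQL_SCALAR_PANDAS_UDF",
}

def eval_type_name(eval_type: int) -> str:
    # Plays the same role as PythonEvalType.toString on the Scala side.
    return _NAMES.get(eval_type, f"UNKNOWN({eval_type})")

assert eval_type_name(101) == "SQL_ARROW_BATCHED_UDF"
```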
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 13f93fbdad6..e6ef7f6108e 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -110,7 +110,7 @@ if TYPE_CHECKING:
     )
     from pyspark.sql.dataframe import DataFrame
     from pyspark.sql.types import AtomicType, StructType
-    from pyspark.sql._typing import AtomicValue, RowLike, SQLBatchedUDFType
+    from pyspark.sql._typing import AtomicValue, RowLike, SQLArrowBatchedUDFType, SQLBatchedUDFType
 
     from py4j.java_gateway import JavaObject
     from py4j.java_collections import JavaArray
@@ -140,6 +140,7 @@ class PythonEvalType:
     NON_UDF: "NonUDFType" = 0
 
     SQL_BATCHED_UDF: "SQLBatchedUDFType" = 100
+    SQL_ARROW_BATCHED_UDF: "SQLArrowBatchedUDFType" = 101
 
     SQL_SCALAR_PANDAS_UDF: "PandasScalarUDFType" = 200
     SQL_GROUPED_MAP_PANDAS_UDF: "PandasGroupedMapUDFType" = 201
diff --git a/python/pyspark/sql/_typing.pyi b/python/pyspark/sql/_typing.pyi
index 209bb70fadd..aafd916dbb7 100644
--- a/python/pyspark/sql/_typing.pyi
+++ b/python/pyspark/sql/_typing.pyi
@@ -57,6 +57,7 @@ AtomicValue = TypeVar(
 
 RowLike = TypeVar("RowLike", List[Any], Tuple[Any, ...], pyspark.sql.types.Row)
 SQLBatchedUDFType = Literal[100]
+SQLArrowBatchedUDFType = Literal[101]
 
 class SupportsOpen(Protocol):
     def open(self, partition_id: int, epoch_id: int) -> bool: ...
diff --git a/python/pyspark/sql/connect/functions.py b/python/pyspark/sql/connect/functions.py
index 60aa8ae14de..b7d7bc937cf 100644
--- a/python/pyspark/sql/connect/functions.py
+++ b/python/pyspark/sql/connect/functions.py
@@ -2504,8 +2504,6 @@ def udf(
     returnType: "DataTypeOrString" = StringType(),
     useArrow: Optional[bool] = None,
 ) -> Union["UserDefinedFunctionLike", Callable[[Callable[..., Any]], "UserDefinedFunctionLike"]]:
-    from pyspark.rdd import PythonEvalType
-
     if f is None or isinstance(f, (str, DataType)):
         # If DataType has been passed as a positional argument
         # for decorator use it as a returnType
@@ -2513,13 +2511,10 @@
         return functools.partial(
             _create_py_udf,
             returnType=return_type,
-            evalType=PythonEvalType.SQL_BATCHED_UDF,
             useArrow=useArrow,
         )
     else:
-        return _create_py_udf(
-            f=f, returnType=returnType, evalType=PythonEvalType.SQL_BATCHED_UDF, useArrow=useArrow
-        )
+        return _create_py_udf(f=f, returnType=returnType, useArrow=useArrow)
 
 
 udf.__doc__ = pysparkfuncs.udf.__doc__
diff --git a/python/pyspark/sql/connect/udf.py b/python/pyspark/sql/connect/udf.py
index 546a8e3bcbd..012c6c0d2d5 100644
--- a/python/pyspark/sql/connect/udf.py
+++ b/python/pyspark/sql/connect/udf.py
@@ -53,7 +53,6 @@ if TYPE_CHECKING:
 def _create_py_udf(
     f: Callable[..., Any],
     returnType: "DataTypeOrString",
-    evalType: int,
     useArrow: Optional[bool] = None,
 ) -> "UserDefinedFunctionLike":
     from pyspark.sql.udf import _create_arrow_py_udf
@@ -68,7 +67,7 @@ def _create_py_udf(
         else useArrow
     )
 
-    regular_udf = _create_udf(f, returnType, evalType)
+    regular_udf = _create_udf(f, returnType, PythonEvalType.SQL_BATCHED_UDF)
     return_type = regular_udf.returnType
     try:
         is_func_with_args = len(getfullargspec(f).args) > 0
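With `evalType` dropped from `_create_py_udf`'s signature above, `udf()` callers no longer choose an eval type; it is derived from `useArrow` during UDF creation. A usage sketch of the observable result, assuming a PySpark build at this commit with pandas and PyArrow installed:

```py
from pyspark.sql.functions import udf
from pyspark.rdd import PythonEvalType

# The only user-visible difference is the (mostly internal) evalType attribute.
arrow_udf = udf(lambda x: str(x), useArrow=True)
plain_udf = udf(lambda x: str(x), useArrow=False)

assert arrow_udf.evalType == PythonEvalType.SQL_ARROW_BATCHED_UDF  # 101
assert plain_udf.evalType == PythonEvalType.SQL_BATCHED_UDF        # 100
```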
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index e035ff5f0a3..e9b71f7d617 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -41,7 +41,6 @@ from py4j.java_gateway import JVMView
 
 from pyspark import SparkContext
 from pyspark.errors import PySparkTypeError, PySparkValueError
-from pyspark.rdd import PythonEvalType
 from pyspark.sql.column import Column, _to_java_column, _to_seq, _create_column_from_literal
 from pyspark.sql.dataframe import DataFrame
 from pyspark.sql.types import ArrayType, DataType, StringType, StructType, _from_numpy_type
@@ -10398,13 +10397,10 @@
         return functools.partial(
             _create_py_udf,
             returnType=return_type,
-            evalType=PythonEvalType.SQL_BATCHED_UDF,
             useArrow=useArrow,
         )
     else:
-        return _create_py_udf(
-            f=f, returnType=returnType, evalType=PythonEvalType.SQL_BATCHED_UDF, useArrow=useArrow
-        )
+        return _create_py_udf(f=f, returnType=returnType, useArrow=useArrow)
 
 
 def _test() -> None:
diff --git a/python/pyspark/sql/pandas/functions.py b/python/pyspark/sql/pandas/functions.py
index 09da310979e..b7d381f04c7 100644
--- a/python/pyspark/sql/pandas/functions.py
+++ b/python/pyspark/sql/pandas/functions.py
@@ -416,11 +416,14 @@ def _create_pandas_udf(f, returnType, evalType):
         PythonEvalType.SQL_MAP_ARROW_ITER_UDF,
         PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
         PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE,
+        PythonEvalType.SQL_ARROW_BATCHED_UDF,
     ]:
         # In case of 'SQL_GROUPED_MAP_PANDAS_UDF', deprecation warning is being triggered
         # at `apply` instead.
         # In case of 'SQL_MAP_PANDAS_ITER_UDF', 'SQL_MAP_ARROW_ITER_UDF' and
         # 'SQL_COGROUPED_MAP_PANDAS_UDF', the evaluation type will always be set.
+        # In case of 'SQL_ARROW_BATCHED_UDF', no deprecation warning is required since it is not
+        # exposed to users.
         pass
     elif len(argspec.annotations) > 0:
         try:
diff --git a/python/pyspark/sql/tests/test_arrow_python_udf.py b/python/pyspark/sql/tests/test_arrow_python_udf.py
index 755f5c7d2ab..51112beadec 100644
--- a/python/pyspark/sql/tests/test_arrow_python_udf.py
+++ b/python/pyspark/sql/tests/test_arrow_python_udf.py
@@ -26,6 +26,7 @@ from pyspark.testing.sqlutils import (
     pyarrow_requirement_message,
     ReusedSQLTestCase,
 )
+from pyspark.rdd import PythonEvalType
 
 
 @unittest.skipIf(
@@ -110,6 +111,14 @@ class PythonUDFArrowTestsMixin(BaseUDFTestsMixin):
         )
         self.assertEquals(row_false[0], "[1, 2, 3]")
 
+    def test_eval_type(self):
+        self.assertEquals(
+            udf(lambda x: str(x), useArrow=True).evalType, PythonEvalType.SQL_ARROW_BATCHED_UDF
+        )
+        self.assertEquals(
+            udf(lambda x: str(x), useArrow=False).evalType, PythonEvalType.SQL_BATCHED_UDF
+        )
+
 
 class PythonUDFArrowTests(PythonUDFArrowTestsMixin, ReusedSQLTestCase):
     @classmethod
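The `udf.py` hunk below is where the new eval type is actually attached: `_create_arrow_py_udf` wraps the row-at-a-time function into a vectorized one and now tags it `SQL_ARROW_BATCHED_UDF` instead of copying the regular UDF's eval type. A simplified sketch of the wrapping idea, where `vectorize` is a stand-in name rather than the real helper:

```py
from typing import Any, Callable

import pandas as pd

def vectorize(f: Callable[[Any], Any]) -> Callable[[pd.Series], pd.Series]:
    # Adapt a scalar function so it can be evaluated over a whole
    # Arrow batch, which reaches the worker as a pandas.Series.
    def evaluate(batch: pd.Series) -> pd.Series:
        return batch.map(f)
    return evaluate

to_str = vectorize(str)
assert to_str(pd.Series([1, 2, 3])).tolist() == ["1", "2", "3"]
```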
diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py
index e01e479516c..374e8c1bcbb 100644
--- a/python/pyspark/sql/udf.py
+++ b/python/pyspark/sql/udf.py
@@ -87,7 +87,6 @@ def _create_udf(
 def _create_py_udf(
     f: Callable[..., Any],
     returnType: "DataTypeOrString",
-    evalType: int,
     useArrow: Optional[bool] = None,
 ) -> "UserDefinedFunctionLike":
     """Create a regular/Arrow-optimized Python UDF."""
@@ -129,8 +128,7 @@ def _create_py_udf(
         if useArrow is None
         else useArrow
     )
-
-    regular_udf = _create_udf(f, returnType, evalType)
+    regular_udf = _create_udf(f, returnType, PythonEvalType.SQL_BATCHED_UDF)
     return_type = regular_udf.returnType
     try:
         is_func_with_args = len(getfullargspec(f).args) > 0
@@ -188,11 +186,10 @@ def _create_arrow_py_udf(regular_udf):  # type: ignore
     vectorized_udf.__name__ = f.__name__ if hasattr(f, "__name__") else f.__class__.__name__
     vectorized_udf.__module__ = f.__module__ if hasattr(f, "__module__") else f.__class__.__module__
     vectorized_udf.__doc__ = f.__doc__
-    pudf = _create_pandas_udf(vectorized_udf, return_type, None)
+    pudf = _create_pandas_udf(vectorized_udf, return_type, PythonEvalType.SQL_ARROW_BATCHED_UDF)
     # Keep the attributes as if this is a regular Python UDF.
     pudf.func = f
     pudf.returnType = return_type
-    pudf.evalType = regular_udf.evalType
 
     return pudf
 
@@ -253,6 +250,7 @@ class UserDefinedFunction:
     def returnType(self) -> DataType:
         # This makes sure this is called after SparkContext is initialized.
         # ``_parse_datatype_string`` accesses to JVM for parsing a DDL formatted string.
+        # TODO: PythonEvalType.SQL_BATCHED_UDF
        if self._returnType_placeholder is None:
             if isinstance(self._returnType, DataType):
                 self._returnType_placeholder = self._returnType
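On the worker side (next hunk), `SQL_ARROW_BATCHED_UDF` simply joins the existing scalar-Pandas code path: `read_single_udf` wraps it with `wrap_scalar_pandas_udf`, and struct-type arguments are delivered as pandas DataFrames via `df_for_struct`. A toy sketch of that dispatch shape (the tag strings are illustrative; the real worker also wires in Arrow serializers):

```py
SQL_BATCHED_UDF = 100
SQL_ARROW_BATCHED_UDF = 101
SQL_SCALAR_PANDAS_UDF = 200

def route(eval_type: int) -> str:
    # Arrow-optimized UDFs reuse the scalar-Pandas evaluation path.
    if eval_type in (SQL_SCALAR_PANDAS_UDF, SQL_ARROW_BATCHED_UDF):
        return "wrap_scalar_pandas_udf"
    elif eval_type == SQL_BATCHED_UDF:
        return "wrap_udf"  # plain pickled, row-at-a-time path
    raise ValueError(f"Unknown eval type: {eval_type}")

assert route(SQL_ARROW_BATCHED_UDF) == "wrap_scalar_pandas_udf"
```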
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index bc40e5fc4ef..9bd8df077b6 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -419,7 +419,7 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index):
     func = fail_on_stopiteration(chained_func)
 
     # the last returnType will be the return type of UDF
-    if eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF:
+    if eval_type in (PythonEvalType.SQL_SCALAR_PANDAS_UDF, PythonEvalType.SQL_ARROW_BATCHED_UDF):
         return arg_offsets, wrap_scalar_pandas_udf(func, return_type)
     elif eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF:
         return arg_offsets, wrap_batch_iter_udf(func, return_type)
@@ -445,7 +445,8 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index):
         raise ValueError("Unknown eval type: {}".format(eval_type))
 
 
-# Used by SQL_GROUPED_MAP_PANDAS_UDF and SQL_SCALAR_PANDAS_UDF when returning StructType
+# Used by SQL_GROUPED_MAP_PANDAS_UDF and SQL_SCALAR_PANDAS_UDF and SQL_ARROW_BATCHED_UDF when
+# returning StructType
 def assign_cols_by_name(runner_conf):
     return (
         runner_conf.get(
@@ -459,6 +460,7 @@ def read_udfs(pickleSer, infile, eval_type):
     runner_conf = {}
 
     if eval_type in (
+        PythonEvalType.SQL_ARROW_BATCHED_UDF,
         PythonEvalType.SQL_SCALAR_PANDAS_UDF,
         PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
         PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF,
@@ -509,7 +511,8 @@ def read_udfs(pickleSer, infile, eval_type):
         # Scalar Pandas UDF handles struct type arguments as pandas DataFrames instead of
         # pandas Series. See SPARK-27240.
         df_for_struct = (
-            eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF
+            eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF
+            or eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF
             or eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF
             or eval_type == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF
         )
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala
index 9533f142ab5..08ffbea5510 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala
@@ -33,6 +33,7 @@ import org.apache.spark.sql.types.{DataType, StructType}
 object PythonUDF {
   private[this] val SCALAR_TYPES = Set(
     PythonEvalType.SQL_BATCHED_UDF,
+    PythonEvalType.SQL_ARROW_BATCHED_UDF,
     PythonEvalType.SQL_SCALAR_PANDAS_UDF,
     PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF
   )
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
index e41f3095d9f..57c3e1ad88e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
@@ -270,7 +270,8 @@ object ExtractPythonUDFs extends Rule[LogicalPlan] {
       val evaluation = evalType match {
         case PythonEvalType.SQL_BATCHED_UDF =>
           BatchEvalPython(validUdfs, resultAttrs, child)
-        case PythonEvalType.SQL_SCALAR_PANDAS_UDF | PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF =>
+        case PythonEvalType.SQL_SCALAR_PANDAS_UDF | PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF
+          | PythonEvalType.SQL_ARROW_BATCHED_UDF =>
           ArrowEvalPython(validUdfs, resultAttrs, child, evalType)
         case _ =>
           throw new IllegalStateException("Unexpected UDF evalType")

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org