This is an automated email from the ASF dual-hosted git repository.

allisonwang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 83e7b4faf731 [SPARK-53012][PYTHON] Support Arrow Python UDTF in Spark Connect
83e7b4faf731 is described below

commit 83e7b4faf731d79af68cf8905bfa7c40c8a99435
Author: Allison Wang <allison.w...@databricks.com>
AuthorDate: Mon Aug 18 15:28:54 2025 -0700

    [SPARK-53012][PYTHON] Support Arrow Python UDTF in Spark Connect

    ### What changes were proposed in this pull request?

    This PR adds support for Arrow Python UDTFs in Spark Connect. After this
    PR, users can create and use `arrow_udtf` when Spark Connect is enabled.

    ### Why are the changes needed?

    To support Arrow Python UDTFs when running against Spark Connect.

    ### Does this PR introduce _any_ user-facing change?

    No.

    ### How was this patch tested?

    New unit tests.

    ### Was this patch authored or co-authored using generative AI tooling?

    No.

    Closes #51998 from allisonwang-db/SPARK-53012-arrow-udtf-connect.

    Authored-by: Allison Wang <allison.w...@databricks.com>
    Signed-off-by: Allison Wang <allison.w...@databricks.com>
---
 dev/sparktestsupport/modules.py                    |  1 +
 python/pyspark/sql/connect/functions/builtin.py    | 16 ++++++-
 python/pyspark/sql/connect/udtf.py                 | 55 ++++++++++++++++++----
 python/pyspark/sql/functions/builtin.py            |  1 +
 python/pyspark/sql/tests/arrow/test_arrow_udtf.py  |  8 +++-
 .../tests/connect/arrow/test_parity_arrow_udtf.py  | 45 ++++++++++++++++++
 .../sql/tests/connect/test_connect_function.py     |  3 +-
 .../sql/tests/test_connect_compatibility.py        |  3 +-
 python/pyspark/sql/udtf.py                         |  1 +
 9 files changed, 117 insertions(+), 16 deletions(-)
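[Editor's note] Before the diff, a minimal sketch of the user-facing API this patch enables. The handler name, schema, and data below are invented for illustration; it assumes an active SparkSession (classic or Spark Connect) and a PyArrow installation meeting PySpark's minimum version.

```python
import pyarrow as pa

from pyspark.sql.functions import arrow_udtf

# Hypothetical handler: an Arrow UDTF's eval yields pyarrow Tables whose
# schema must match the declared return type ("int" maps to Arrow int32).
@arrow_udtf(returnType="id int, value string")
class SampleUDTF:
    def eval(self):
        yield pa.table(
            {
                "id": pa.array([1, 2], type=pa.int32()),
                "value": pa.array(["a", "b"]),
            }
        )

# Invoking the UDTF returns a DataFrame; with this patch the same code
# works when the session is a Spark Connect session.
SampleUDTF().show()
```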
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 0141a1d3d9e2..1e333ba6c246 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -1131,6 +1131,7 @@ pyspark_connect = Module(
         "pyspark.sql.tests.connect.arrow.test_parity_arrow_udf_scalar",
         "pyspark.sql.tests.connect.arrow.test_parity_arrow_udf_grouped_agg",
         "pyspark.sql.tests.connect.arrow.test_parity_arrow_udf_window",
+        "pyspark.sql.tests.connect.arrow.test_parity_arrow_udtf",
         "pyspark.sql.tests.connect.pandas.test_parity_pandas_map",
         "pyspark.sql.tests.connect.pandas.test_parity_pandas_grouped_map",
         "pyspark.sql.tests.connect.pandas.test_parity_pandas_grouped_map_with_state",
diff --git a/python/pyspark/sql/connect/functions/builtin.py b/python/pyspark/sql/connect/functions/builtin.py
index f9f8e0118147..0380b517e6e5 100644
--- a/python/pyspark/sql/connect/functions/builtin.py
+++ b/python/pyspark/sql/connect/functions/builtin.py
@@ -60,7 +60,7 @@ from pyspark.sql.connect.expressions import (
 )
 from pyspark.sql.connect.udf import _create_py_udf
 from pyspark.sql.connect.udtf import AnalyzeArgument, AnalyzeResult  # noqa: F401
-from pyspark.sql.connect.udtf import _create_py_udtf
+from pyspark.sql.connect.udtf import _create_py_udtf, _create_pyarrow_udtf
 from pyspark.sql import functions as pysparkfuncs
 from pyspark.sql.types import (
     _from_numpy_type,
@@ -4500,6 +4500,20 @@ def udtf(
 udtf.__doc__ = pysparkfuncs.udtf.__doc__


+def arrow_udtf(
+    cls: Optional[Type] = None,
+    *,
+    returnType: Optional[Union[StructType, str]] = None,
+) -> Union["UserDefinedTableFunction", Callable[[Type], "UserDefinedTableFunction"]]:
+    if cls is None:
+        return functools.partial(_create_pyarrow_udtf, returnType=returnType)
+    else:
+        return _create_pyarrow_udtf(cls=cls, returnType=returnType)
+
+
+arrow_udtf.__doc__ = pysparkfuncs.arrow_udtf.__doc__
+
+
 def call_function(funcName: str, *cols: "ColumnOrName") -> Column:
     from pyspark.sql.connect.column import Column as ConnectColumn
diff --git a/python/pyspark/sql/connect/udtf.py b/python/pyspark/sql/connect/udtf.py
index ed9ab26788f7..f04993207167 100644
--- a/python/pyspark/sql/connect/udtf.py
+++ b/python/pyspark/sql/connect/udtf.py
@@ -22,7 +22,7 @@ from pyspark.sql.connect.utils import check_dependencies
 check_dependencies(__name__)

 import warnings
-from typing import List, Type, TYPE_CHECKING, Optional, Union
+from typing import List, Type, TYPE_CHECKING, Optional, Union, Any

 from pyspark.util import PythonEvalType
 from pyspark.sql.connect.column import Column
@@ -34,10 +34,11 @@ from pyspark.sql.connect.plan import (
 )
 from pyspark.sql.connect.table_arg import TableArg
 from pyspark.sql.connect.types import UnparsedDataType
 from pyspark.sql.connect.utils import get_python_ver
+from pyspark.sql.pandas.utils import require_minimum_pyarrow_version, require_minimum_pandas_version
 from pyspark.sql.udtf import AnalyzeArgument, AnalyzeResult  # noqa: F401
 from pyspark.sql.udtf import UDTFRegistration as PySparkUDTFRegistration, _validate_udtf_handler
 from pyspark.sql.types import DataType, StructType
-from pyspark.errors import PySparkRuntimeError, PySparkTypeError
+from pyspark.errors import PySparkRuntimeError, PySparkTypeError, PySparkAttributeError


 if TYPE_CHECKING:
@@ -87,11 +88,6 @@ def _create_py_udtf(
     eval_type: int = PythonEvalType.SQL_TABLE_UDF

     if arrow_enabled:
-        from pyspark.sql.pandas.utils import (
-            require_minimum_pandas_version,
-            require_minimum_pyarrow_version,
-        )
-
         try:
             require_minimum_pandas_version()
             require_minimum_pyarrow_version()
@@ -106,6 +102,43 @@ def _create_py_udtf(
     return _create_udtf(cls, returnType, name, eval_type, deterministic)


+def _create_pyarrow_udtf(
+    cls: Type,
+    returnType: Optional[Union[StructType, str]],
+    name: Optional[str] = None,
+    deterministic: bool = False,
+) -> "UserDefinedTableFunction":
+    """Create a PyArrow-native Python UDTF."""
+    # Validate PyArrow dependencies
+    require_minimum_pyarrow_version()
+
+    # Validate the handler class with PyArrow-specific checks
+    _validate_arrow_udtf_handler(cls, returnType)
+
+    return _create_udtf(
+        cls=cls,
+        returnType=returnType,
+        name=name,
+        evalType=PythonEvalType.SQL_ARROW_UDTF,
+        deterministic=deterministic,
+    )
+
+
+def _validate_arrow_udtf_handler(cls: Any, returnType: Optional[Union[StructType, str]]) -> None:
+    """Validate the handler class of a PyArrow UDTF."""
+    # First run standard UDTF validation
+    _validate_udtf_handler(cls, returnType)
+
+    # Block analyze method usage in arrow UDTFs
+    # TODO(SPARK-53286): Support analyze method for Arrow UDTFs to enable dynamic return types
+    has_analyze = hasattr(cls, "analyze")
+    if has_analyze:
+        raise PySparkAttributeError(
+            errorClass="INVALID_ARROW_UDTF_WITH_ANALYZE",
+            messageParameters={"name": cls.__name__},
+        )
+
+
 class UserDefinedTableFunction:
     """
     User defined function in Python
@@ -203,12 +236,16 @@ class UDTFRegistration:
                 },
             )

-        if f.evalType not in [PythonEvalType.SQL_TABLE_UDF, PythonEvalType.SQL_ARROW_TABLE_UDF]:
+        if f.evalType not in [
+            PythonEvalType.SQL_TABLE_UDF,
+            PythonEvalType.SQL_ARROW_TABLE_UDF,
+            PythonEvalType.SQL_ARROW_UDTF,
+        ]:
             raise PySparkTypeError(
                 errorClass="INVALID_UDTF_EVAL_TYPE",
                 messageParameters={
                     "name": name,
-                    "eval_type": "SQL_TABLE_UDF, SQL_ARROW_TABLE_UDF",
+                    "eval_type": "SQL_TABLE_UDF, SQL_ARROW_TABLE_UDF, SQL_ARROW_UDTF",
                 },
             )
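[Editor's note] One consequence of `_validate_arrow_udtf_handler` above: the check runs when the decorator is applied, so an Arrow UDTF handler that defines `analyze` fails at decoration time rather than at query execution. A hedged sketch (the handler below is invented, and it assumes the shared UDTF validation accepts a handler that omits `returnType` in favor of `analyze`, as it does for classic UDTFs):

```python
from pyspark.errors import PySparkAttributeError
from pyspark.sql.functions import arrow_udtf

try:

    @arrow_udtf  # no returnType: for classic UDTFs, analyze would supply it
    class BadUDTF:  # invented name; any Arrow handler with analyze is rejected
        @staticmethod
        def analyze(*args):  # dynamic return types not yet supported (SPARK-53286)
            ...

        def eval(self):
            ...

except PySparkAttributeError as err:
    # Expected error class: INVALID_ARROW_UDTF_WITH_ANALYZE
    print(err)
```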
diff --git a/python/pyspark/sql/functions/builtin.py b/python/pyspark/sql/functions/builtin.py
index 15d032b95614..0bec14d10d44 100644
--- a/python/pyspark/sql/functions/builtin.py
+++ b/python/pyspark/sql/functions/builtin.py
@@ -27257,6 +27257,7 @@ def udtf(
     return _create_py_udtf(cls=cls, returnType=returnType, useArrow=useArrow)


+@_try_remote_functions
 def arrow_udtf(
     cls: Optional[Type] = None,
     *,
diff --git a/python/pyspark/sql/tests/arrow/test_arrow_udtf.py b/python/pyspark/sql/tests/arrow/test_arrow_udtf.py
index c89273e8099d..50fe6588eb92 100644
--- a/python/pyspark/sql/tests/arrow/test_arrow_udtf.py
+++ b/python/pyspark/sql/tests/arrow/test_arrow_udtf.py
@@ -18,7 +18,7 @@ import unittest
 from typing import Iterator

 from pyspark.errors import PySparkAttributeError
-from pyspark.errors.exceptions.captured import PythonException
+from pyspark.errors import PythonException
 from pyspark.sql.functions import arrow_udtf, lit
 from pyspark.sql.types import Row, StructType, StructField, IntegerType
 from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pyarrow, pyarrow_requirement_message
@@ -29,7 +29,7 @@ if have_pyarrow:


 @unittest.skipIf(not have_pyarrow, pyarrow_requirement_message)
-class ArrowUDTFTests(ReusedSQLTestCase):
+class ArrowUDTFTestsMixin:
     def test_arrow_udtf_zero_args(self):
         @arrow_udtf(returnType="id int, value string")
         class TestUDTF:
@@ -461,6 +461,10 @@ class ArrowUDTFTests(ReusedSQLTestCase):
         assertDataFrameEqual(sql_result_df, expected_df)


+class ArrowUDTFTests(ArrowUDTFTestsMixin, ReusedSQLTestCase):
+    pass
+
+
 if __name__ == "__main__":
     from pyspark.sql.tests.arrow.test_arrow_udtf import *  # noqa: F401
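[Editor's note] The conversion of `ArrowUDTFTests` into a mixin follows the parity-test convention used across PySpark: test bodies are written once in a mixin, and thin subclasses bind them to either a classic or a Connect test base, as the new parity file below does. Schematically (the `Feature*` names are placeholders, not classes from this patch):

```python
import unittest

from pyspark.testing.sqlutils import ReusedSQLTestCase
from pyspark.testing.connectutils import ReusedConnectTestCase


class FeatureTestsMixin:
    # Test bodies are written once against self.spark, which either
    # base class provides.
    def test_roundtrip(self):
        df = self.spark.range(3)
        self.assertEqual(df.count(), 3)


class FeatureTests(FeatureTestsMixin, ReusedSQLTestCase):  # classic Spark
    pass


class FeatureParityTests(FeatureTestsMixin, ReusedConnectTestCase):  # Spark Connect
    pass


if __name__ == "__main__":
    unittest.main()
```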
diff --git a/python/pyspark/sql/tests/connect/arrow/test_parity_arrow_udtf.py b/python/pyspark/sql/tests/connect/arrow/test_parity_arrow_udtf.py
new file mode 100644
index 000000000000..18227f493a0b
--- /dev/null
+++ b/python/pyspark/sql/tests/connect/arrow/test_parity_arrow_udtf.py
@@ -0,0 +1,45 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+from pyspark.sql.tests.arrow.test_arrow_udtf import ArrowUDTFTestsMixin
+from pyspark.testing.connectutils import ReusedConnectTestCase
+
+
+class ArrowUDTFParityTests(ArrowUDTFTestsMixin, ReusedConnectTestCase):
+    # TODO(SPARK-53323): Support table arguments in Spark Connect Arrow UDTFs
+    @unittest.skip("asTable() is not supported in Spark Connect")
+    def test_arrow_udtf_with_table_argument_basic(self):
+        super().test_arrow_udtf_with_table_argument_basic()
+
+    # TODO(SPARK-53323): Support table arguments in Spark Connect Arrow UDTFs
+    @unittest.skip("asTable() is not supported in Spark Connect")
+    def test_arrow_udtf_with_table_argument_and_scalar(self):
+        super().test_arrow_udtf_with_table_argument_and_scalar()
+
+
+if __name__ == "__main__":
+    from pyspark.sql.tests.connect.arrow.test_parity_arrow_udtf import *  # noqa: F401
+
+    try:
+        import xmlrunner  # type: ignore[import]
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/sql/tests/connect/test_connect_function.py b/python/pyspark/sql/tests/connect/test_connect_function.py
index 847326da416b..b906f5c5cef4 100644
--- a/python/pyspark/sql/tests/connect/test_connect_function.py
+++ b/python/pyspark/sql/tests/connect/test_connect_function.py
@@ -2559,8 +2559,7 @@ class SparkConnectFunctionTests(ReusedMixedTestCase, PandasOnSparkTestUtils):
         cf_fn = {name for (name, value) in getmembers(CF, isfunction) if name[0] != "_"}

         # Functions in classic PySpark we do not expect to be available in Spark Connect
-        # TODO: SPARK-53012 - Implement arrow_udtf in Spark Connect
-        sf_excluded_fn = {"arrow_udtf"}
+        sf_excluded_fn = set()

         self.assertEqual(
             sf_fn - cf_fn,
diff --git a/python/pyspark/sql/tests/test_connect_compatibility.py b/python/pyspark/sql/tests/test_connect_compatibility.py
index e3b61f65dafb..b2e0cc6229c4 100644
--- a/python/pyspark/sql/tests/test_connect_compatibility.py
+++ b/python/pyspark/sql/tests/test_connect_compatibility.py
@@ -379,8 +379,7 @@ class ConnectCompatibilityTestsMixin:
         """Test Functions compatibility between classic and connect."""
         expected_missing_connect_properties = set()
         expected_missing_classic_properties = set()
-        # TODO(SPARK-53012): support arrow_udtf in Spark Connect
-        expected_missing_connect_methods = {"arrow_udtf"}
+        expected_missing_connect_methods = set()
         expected_missing_classic_methods = {"check_dependencies"}
         self.check_compatibility(
             ClassicFunctions,
diff --git a/python/pyspark/sql/udtf.py b/python/pyspark/sql/udtf.py
index 8e15631fcca5..faaa80b4011e 100644
--- a/python/pyspark/sql/udtf.py
+++ b/python/pyspark/sql/udtf.py
@@ -278,6 +278,7 @@ def _validate_arrow_udtf_handler(cls: Any, returnType: Optional[Union[StructType
     _validate_udtf_handler(cls, returnType)

     # Block analyze method usage in arrow UDTFs
+    # TODO(SPARK-53286): Support analyze method for Arrow UDTFs to enable dynamic return types
     has_analyze = hasattr(cls, "analyze")
     if has_analyze:
         raise PySparkAttributeError(

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org