This is an automated email from the ASF dual-hosted git repository.

allisonwang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 83e7b4faf731 [SPARK-53012][PYTHON] Support Arrow Python UDTF in Spark Connect
83e7b4faf731 is described below

commit 83e7b4faf731d79af68cf8905bfa7c40c8a99435
Author: Allison Wang <allison.w...@databricks.com>
AuthorDate: Mon Aug 18 15:28:54 2025 -0700

    [SPARK-53012][PYTHON] Support Arrow Python UDTF in Spark Connect
    
    ### What changes were proposed in this pull request?
    
    This PR adds support for Arrow Python UDTFs in Spark Connect. After
    this PR, users can create and use `arrow_udtf` when Spark Connect is
    enabled.
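
    A minimal usage sketch (names and values here are illustrative; the
    decorator and schema string follow the tests in this PR, and the
    handler is assumed to yield `pyarrow.Table`s matching the declared
    return type):

        import pyarrow as pa
        from pyspark.sql.functions import arrow_udtf

        @arrow_udtf(returnType="id int, value string")
        class MyUDTF:
            def eval(self):
                # Yield Arrow tables whose schema matches returnType.
                yield pa.table({
                    "id": pa.array([1, 2], type=pa.int32()),
                    "value": pa.array(["a", "b"], type=pa.string()),
                })

        # With this change, the same code also runs against a
        # Spark Connect session.
        MyUDTF().show()

    Note that, as in classic PySpark, defining an `analyze` method on an
    Arrow UDTF is rejected with `INVALID_ARROW_UDTF_WITH_ANALYZE` until
    SPARK-53286 adds dynamic return type support.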
    
    ### Why are the changes needed?
    
    To make `arrow_udtf` available to users of Spark Connect.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    New unit tests, including a Spark Connect parity suite for the
    existing Arrow UDTF tests.
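
    The parity suite reuses a shared mixin extracted from the classic
    tests (a sketch mirroring the new file in this diff; the two
    table-argument tests are skipped pending SPARK-53323):

        from pyspark.sql.tests.arrow.test_arrow_udtf import ArrowUDTFTestsMixin
        from pyspark.testing.connectutils import ReusedConnectTestCase

        class ArrowUDTFParityTests(ArrowUDTFTestsMixin, ReusedConnectTestCase):
            ...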
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #51998 from allisonwang-db/SPARK-53012-arrow-udtf-connect.
    
    Authored-by: Allison Wang <allison.w...@databricks.com>
    Signed-off-by: Allison Wang <allison.w...@databricks.com>
---
 dev/sparktestsupport/modules.py                    |  1 +
 python/pyspark/sql/connect/functions/builtin.py    | 16 ++++++-
 python/pyspark/sql/connect/udtf.py                 | 55 ++++++++++++++++++----
 python/pyspark/sql/functions/builtin.py            |  1 +
 python/pyspark/sql/tests/arrow/test_arrow_udtf.py  |  8 +++-
 .../tests/connect/arrow/test_parity_arrow_udtf.py  | 45 ++++++++++++++++++
 .../sql/tests/connect/test_connect_function.py     |  3 +-
 .../sql/tests/test_connect_compatibility.py        |  3 +-
 python/pyspark/sql/udtf.py                         |  1 +
 9 files changed, 117 insertions(+), 16 deletions(-)

diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 0141a1d3d9e2..1e333ba6c246 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -1131,6 +1131,7 @@ pyspark_connect = Module(
         "pyspark.sql.tests.connect.arrow.test_parity_arrow_udf_scalar",
         "pyspark.sql.tests.connect.arrow.test_parity_arrow_udf_grouped_agg",
         "pyspark.sql.tests.connect.arrow.test_parity_arrow_udf_window",
+        "pyspark.sql.tests.connect.arrow.test_parity_arrow_udtf",
         "pyspark.sql.tests.connect.pandas.test_parity_pandas_map",
         "pyspark.sql.tests.connect.pandas.test_parity_pandas_grouped_map",
         
"pyspark.sql.tests.connect.pandas.test_parity_pandas_grouped_map_with_state",
diff --git a/python/pyspark/sql/connect/functions/builtin.py b/python/pyspark/sql/connect/functions/builtin.py
index f9f8e0118147..0380b517e6e5 100644
--- a/python/pyspark/sql/connect/functions/builtin.py
+++ b/python/pyspark/sql/connect/functions/builtin.py
@@ -60,7 +60,7 @@ from pyspark.sql.connect.expressions import (
 )
 from pyspark.sql.connect.udf import _create_py_udf
 from pyspark.sql.connect.udtf import AnalyzeArgument, AnalyzeResult  # noqa: F401
-from pyspark.sql.connect.udtf import _create_py_udtf
+from pyspark.sql.connect.udtf import _create_py_udtf, _create_pyarrow_udtf
 from pyspark.sql import functions as pysparkfuncs
 from pyspark.sql.types import (
     _from_numpy_type,
@@ -4500,6 +4500,20 @@ def udtf(
 udtf.__doc__ = pysparkfuncs.udtf.__doc__
 
 
+def arrow_udtf(
+    cls: Optional[Type] = None,
+    *,
+    returnType: Optional[Union[StructType, str]] = None,
+) -> Union["UserDefinedTableFunction", Callable[[Type], "UserDefinedTableFunction"]]:
+    if cls is None:
+        return functools.partial(_create_pyarrow_udtf, returnType=returnType)
+    else:
+        return _create_pyarrow_udtf(cls=cls, returnType=returnType)
+
+
+arrow_udtf.__doc__ = pysparkfuncs.arrow_udtf.__doc__
+
+
 def call_function(funcName: str, *cols: "ColumnOrName") -> Column:
     from pyspark.sql.connect.column import Column as ConnectColumn
 
diff --git a/python/pyspark/sql/connect/udtf.py b/python/pyspark/sql/connect/udtf.py
index ed9ab26788f7..f04993207167 100644
--- a/python/pyspark/sql/connect/udtf.py
+++ b/python/pyspark/sql/connect/udtf.py
@@ -22,7 +22,7 @@ from pyspark.sql.connect.utils import check_dependencies
 check_dependencies(__name__)
 
 import warnings
-from typing import List, Type, TYPE_CHECKING, Optional, Union
+from typing import List, Type, TYPE_CHECKING, Optional, Union, Any
 
 from pyspark.util import PythonEvalType
 from pyspark.sql.connect.column import Column
@@ -34,10 +34,11 @@ from pyspark.sql.connect.plan import (
 from pyspark.sql.connect.table_arg import TableArg
 from pyspark.sql.connect.types import UnparsedDataType
 from pyspark.sql.connect.utils import get_python_ver
+from pyspark.sql.pandas.utils import require_minimum_pyarrow_version, require_minimum_pandas_version
 from pyspark.sql.udtf import AnalyzeArgument, AnalyzeResult  # noqa: F401
 from pyspark.sql.udtf import UDTFRegistration as PySparkUDTFRegistration, _validate_udtf_handler
 from pyspark.sql.types import DataType, StructType
-from pyspark.errors import PySparkRuntimeError, PySparkTypeError
+from pyspark.errors import PySparkRuntimeError, PySparkTypeError, PySparkAttributeError
 
 
 if TYPE_CHECKING:
@@ -87,11 +88,6 @@ def _create_py_udtf(
     eval_type: int = PythonEvalType.SQL_TABLE_UDF
 
     if arrow_enabled:
-        from pyspark.sql.pandas.utils import (
-            require_minimum_pandas_version,
-            require_minimum_pyarrow_version,
-        )
-
         try:
             require_minimum_pandas_version()
             require_minimum_pyarrow_version()
@@ -106,6 +102,43 @@ def _create_py_udtf(
     return _create_udtf(cls, returnType, name, eval_type, deterministic)
 
 
+def _create_pyarrow_udtf(
+    cls: Type,
+    returnType: Optional[Union[StructType, str]],
+    name: Optional[str] = None,
+    deterministic: bool = False,
+) -> "UserDefinedTableFunction":
+    """Create a PyArrow-native Python UDTF."""
+    # Validate PyArrow dependencies
+    require_minimum_pyarrow_version()
+
+    # Validate the handler class with PyArrow-specific checks
+    _validate_arrow_udtf_handler(cls, returnType)
+
+    return _create_udtf(
+        cls=cls,
+        returnType=returnType,
+        name=name,
+        evalType=PythonEvalType.SQL_ARROW_UDTF,
+        deterministic=deterministic,
+    )
+
+
+def _validate_arrow_udtf_handler(cls: Any, returnType: Optional[Union[StructType, str]]) -> None:
+    """Validate the handler class of a PyArrow UDTF."""
+    # First run standard UDTF validation
+    _validate_udtf_handler(cls, returnType)
+
+    # Block analyze method usage in arrow UDTFs
+    # TODO(SPARK-53286): Support analyze method for Arrow UDTFs to enable dynamic return types
+    has_analyze = hasattr(cls, "analyze")
+    if has_analyze:
+        raise PySparkAttributeError(
+            errorClass="INVALID_ARROW_UDTF_WITH_ANALYZE",
+            messageParameters={"name": cls.__name__},
+        )
+
+
 class UserDefinedTableFunction:
     """
     User defined function in Python
@@ -203,12 +236,16 @@ class UDTFRegistration:
                 },
             )
 
-        if f.evalType not in [PythonEvalType.SQL_TABLE_UDF, PythonEvalType.SQL_ARROW_TABLE_UDF]:
+        if f.evalType not in [
+            PythonEvalType.SQL_TABLE_UDF,
+            PythonEvalType.SQL_ARROW_TABLE_UDF,
+            PythonEvalType.SQL_ARROW_UDTF,
+        ]:
             raise PySparkTypeError(
                 errorClass="INVALID_UDTF_EVAL_TYPE",
                 messageParameters={
                     "name": name,
-                    "eval_type": "SQL_TABLE_UDF, SQL_ARROW_TABLE_UDF",
+                    "eval_type": "SQL_TABLE_UDF, SQL_ARROW_TABLE_UDF, SQL_ARROW_UDTF",
                 },
             )
 
diff --git a/python/pyspark/sql/functions/builtin.py b/python/pyspark/sql/functions/builtin.py
index 15d032b95614..0bec14d10d44 100644
--- a/python/pyspark/sql/functions/builtin.py
+++ b/python/pyspark/sql/functions/builtin.py
@@ -27257,6 +27257,7 @@ def udtf(
         return _create_py_udtf(cls=cls, returnType=returnType, useArrow=useArrow)
 
 
+@_try_remote_functions
 def arrow_udtf(
     cls: Optional[Type] = None,
     *,
diff --git a/python/pyspark/sql/tests/arrow/test_arrow_udtf.py b/python/pyspark/sql/tests/arrow/test_arrow_udtf.py
index c89273e8099d..50fe6588eb92 100644
--- a/python/pyspark/sql/tests/arrow/test_arrow_udtf.py
+++ b/python/pyspark/sql/tests/arrow/test_arrow_udtf.py
@@ -18,7 +18,7 @@ import unittest
 from typing import Iterator
 
 from pyspark.errors import PySparkAttributeError
-from pyspark.errors.exceptions.captured import PythonException
+from pyspark.errors import PythonException
 from pyspark.sql.functions import arrow_udtf, lit
 from pyspark.sql.types import Row, StructType, StructField, IntegerType
 from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pyarrow, pyarrow_requirement_message
@@ -29,7 +29,7 @@ if have_pyarrow:
 
 
 @unittest.skipIf(not have_pyarrow, pyarrow_requirement_message)
-class ArrowUDTFTests(ReusedSQLTestCase):
+class ArrowUDTFTestsMixin:
     def test_arrow_udtf_zero_args(self):
         @arrow_udtf(returnType="id int, value string")
         class TestUDTF:
@@ -461,6 +461,10 @@ class ArrowUDTFTests(ReusedSQLTestCase):
         assertDataFrameEqual(sql_result_df, expected_df)
 
 
+class ArrowUDTFTests(ArrowUDTFTestsMixin, ReusedSQLTestCase):
+    pass
+
+
 if __name__ == "__main__":
     from pyspark.sql.tests.arrow.test_arrow_udtf import *  # noqa: F401
 
diff --git a/python/pyspark/sql/tests/connect/arrow/test_parity_arrow_udtf.py b/python/pyspark/sql/tests/connect/arrow/test_parity_arrow_udtf.py
new file mode 100644
index 000000000000..18227f493a0b
--- /dev/null
+++ b/python/pyspark/sql/tests/connect/arrow/test_parity_arrow_udtf.py
@@ -0,0 +1,45 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+from pyspark.sql.tests.arrow.test_arrow_udtf import ArrowUDTFTestsMixin
+from pyspark.testing.connectutils import ReusedConnectTestCase
+
+
+class ArrowUDTFParityTests(ArrowUDTFTestsMixin, ReusedConnectTestCase):
+    # TODO(SPARK-53323): Support table arguments in Spark Connect Arrow UDTFs
+    @unittest.skip("asTable() is not supported in Spark Connect")
+    def test_arrow_udtf_with_table_argument_basic(self):
+        super().test_arrow_udtf_with_table_argument_basic()
+
+    # TODO(SPARK-53323): Support table arguments in Spark Connect Arrow UDTFs
+    @unittest.skip("asTable() is not supported in Spark Connect")
+    def test_arrow_udtf_with_table_argument_and_scalar(self):
+        super().test_arrow_udtf_with_table_argument_and_scalar()
+
+
+if __name__ == "__main__":
+    from pyspark.sql.tests.connect.arrow.test_parity_arrow_udtf import *  # noqa: F401
+
+    try:
+        import xmlrunner  # type: ignore[import]
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/sql/tests/connect/test_connect_function.py b/python/pyspark/sql/tests/connect/test_connect_function.py
index 847326da416b..b906f5c5cef4 100644
--- a/python/pyspark/sql/tests/connect/test_connect_function.py
+++ b/python/pyspark/sql/tests/connect/test_connect_function.py
@@ -2559,8 +2559,7 @@ class SparkConnectFunctionTests(ReusedMixedTestCase, PandasOnSparkTestUtils):
         cf_fn = {name for (name, value) in getmembers(CF, isfunction) if name[0] != "_"}
 
         # Functions in classic PySpark we do not expect to be available in Spark Connect
-        # TODO: SPARK-53012 - Implement arrow_udtf in Spark Connect
-        sf_excluded_fn = {"arrow_udtf"}
+        sf_excluded_fn = set()
 
         self.assertEqual(
             sf_fn - cf_fn,
diff --git a/python/pyspark/sql/tests/test_connect_compatibility.py b/python/pyspark/sql/tests/test_connect_compatibility.py
index e3b61f65dafb..b2e0cc6229c4 100644
--- a/python/pyspark/sql/tests/test_connect_compatibility.py
+++ b/python/pyspark/sql/tests/test_connect_compatibility.py
@@ -379,8 +379,7 @@ class ConnectCompatibilityTestsMixin:
         """Test Functions compatibility between classic and connect."""
         expected_missing_connect_properties = set()
         expected_missing_classic_properties = set()
-        # TODO(SPARK-53012): support arrow_udtf in Spark Connect
-        expected_missing_connect_methods = {"arrow_udtf"}
+        expected_missing_connect_methods = set()
         expected_missing_classic_methods = {"check_dependencies"}
         self.check_compatibility(
             ClassicFunctions,
diff --git a/python/pyspark/sql/udtf.py b/python/pyspark/sql/udtf.py
index 8e15631fcca5..faaa80b4011e 100644
--- a/python/pyspark/sql/udtf.py
+++ b/python/pyspark/sql/udtf.py
@@ -278,6 +278,7 @@ def _validate_arrow_udtf_handler(cls: Any, returnType: Optional[Union[StructType
     _validate_udtf_handler(cls, returnType)
 
     # Block analyze method usage in arrow UDTFs
+    # TODO(SPARK-53286): Support analyze method for Arrow UDTFs to enable dynamic return types
     has_analyze = hasattr(cls, "analyze")
     if has_analyze:
         raise PySparkAttributeError(


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
