This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 1dccc5341d98 [SPARK-52959][PYTHON] Support UDT in Arrow-optimized Python UDTF
1dccc5341d98 is described below

commit 1dccc5341d987f297dcdf5bcf6069fec9caa57dc
Author: Takuya Ueshin <ues...@databricks.com>
AuthorDate: Tue Jul 29 08:10:49 2025 -0700

    [SPARK-52959][PYTHON] Support UDT in Arrow-optimized Python UDTF
    
    ### What changes were proposed in this pull request?
    
    Adds support for user-defined types (UDTs) in Arrow-optimized Python UDTFs, both as input arguments and as output columns.
    
    ### Why are the changes needed?
    
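    Arrow-optimized Python UDTFs currently fail on UDTs in both directions. A UDT argument reaches `eval` as a plain Python list instead of the UDT's Python object. A UDT value yielded from `eval` is rejected with an Arrow type mismatch. The examples below show both failures.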
    
    ```py
    >>> @udtf(returnType="x: double, y: double", useArrow=True)
    ... class TestUDTFTakingUDT:
    ...   def eval(self, point: ExamplePoint):
    ...     yield point.x * 10, point.y * 10
    ...
    >>> df = spark.createDataFrame([(ExamplePoint(x=1.0, y=2.0),)], schema=StructType().add("point", ExamplePointUDT()))
    >>> df.lateralJoin(TestUDTFTakingUDT(col("point").outer())).show()
    ...
    AttributeError: 'list' object has no attribute 'x'
    
    >>> @udtf(returnType=StructType().add("point", ExamplePointUDT()), useArrow=True)
    ... class TestUDTFReturningUDT:
    ...   def eval(self, x: float, y: float):
    ...     yield ExamplePoint(x=x, y=y),
    ...
    >>> df = spark.range(2).select((col("id") + 10.0).alias("x"), (col("id") * 10.0).alias("y"))
    >>> df.lateralJoin(TestUDTFReturningUDT(col("x").outer(), col("y").outer())).show()
    ...
    org.apache.spark.SparkException: [ARROW_TYPE_MISMATCH] Invalid schema from Python UDTF: expected org.apache.spark.sql.test.ExamplePointUDT@7f511308, got ArrayType(DoubleType,false). SQLSTATE: 42K0G
    ```
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. UDTs can now be used as input and output types of Arrow-optimized Python UDTFs.
    
    ### How was this patch tested?
    
    Added tests for UDT input and output to `python/pyspark/sql/tests/test_udtf.py`.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #51694 from ueshin/issues/SPARK-52959/udtf_udt.
    
    Authored-by: Takuya Ueshin <ues...@databricks.com>
    Signed-off-by: Dongjoon Hyun <dongj...@apache.org>
---
 python/pyspark/sql/pandas/serializers.py           | 28 ++++++----
 python/pyspark/sql/tests/test_udtf.py              | 40 ++++++++++++++
 python/pyspark/worker.py                           | 61 +++++++++++++++-------
 .../execution/python/ArrowEvalPythonUDTFExec.scala |  6 ++-
 .../execution/python/ArrowPythonUDTFRunner.scala   |  1 +
 5 files changed, 107 insertions(+), 29 deletions(-)

diff --git a/python/pyspark/sql/pandas/serializers.py 
b/python/pyspark/sql/pandas/serializers.py
index 5943524db433..03d3974113ba 100644
--- a/python/pyspark/sql/pandas/serializers.py
+++ b/python/pyspark/sql/pandas/serializers.py
@@ -841,7 +841,7 @@ class 
ArrowStreamPandasUDTFSerializer(ArrowStreamPandasUDFSerializer):
     Serializer used by Python worker to evaluate Arrow-optimized Python UDTFs.
     """
 
-    def __init__(self, timezone, safecheck, int_to_decimal_coercion_enabled):
+    def __init__(self, timezone, safecheck, input_types, 
int_to_decimal_coercion_enabled):
         super(ArrowStreamPandasUDTFSerializer, self).__init__(
             timezone=timezone,
             safecheck=safecheck,
@@ -861,6 +861,7 @@ class 
ArrowStreamPandasUDTFSerializer(ArrowStreamPandasUDFSerializer):
             ndarray_as_list=True,
             # Enables explicit casting for mismatched return types of Arrow 
Python UDTFs.
             arrow_cast=True,
+            input_types=input_types,
             # Enable additional coercions for UDTF serialization
             int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
         )
@@ -885,35 +886,44 @@ class 
ArrowStreamPandasUDTFSerializer(ArrowStreamPandasUDFSerializer):
         import pandas as pd
         import pyarrow as pa
 
-        # Make input conform to [(series1, type1), (series2, type2), ...]
-        if not isinstance(series, (list, tuple)) or (
-            len(series) == 2 and isinstance(series[1], pa.DataType)
+        # Make input conform to
+        # [(series1, arrow_type1, spark_type1), (series2, arrow_type2, 
spark_type2), ...]
+        if (
+            not isinstance(series, (list, tuple))
+            or (len(series) == 2 and isinstance(series[1], pa.DataType))
+            or (
+                len(series) == 3
+                and isinstance(series[1], pa.DataType)
+                and isinstance(series[2], DataType)
+            )
         ):
             series = [series]
         series = ((s, None) if not isinstance(s, (list, tuple)) else s for s 
in series)
+        series = ((s[0], s[1], None) if len(s) == 2 else s for s in series)
 
         arrs = []
-        for s, t in series:
+        for s, arrow_type, spark_type in series:
             if not isinstance(s, pd.DataFrame):
                 raise PySparkValueError(
                     "Output of an arrow-optimized Python UDTFs expects "
                     f"a pandas.DataFrame but got: {type(s)}"
                 )
 
-            arrs.append(self._create_struct_array(s, t))
+            arrs.append(self._create_struct_array(s, arrow_type, spark_type))
 
         return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in 
range(len(arrs))])
 
     def _get_or_create_converter_from_pandas(self, dt):
-        if dt not in self._converter_map:
+        key = dt.json()
+        if key not in self._converter_map:
             conv = _create_converter_from_pandas(
                 dt,
                 timezone=self._timezone,
                 error_on_duplicated_field_names=False,
                 ignore_unexpected_complex_type_values=True,
             )
-            self._converter_map[dt] = conv
-        return self._converter_map[dt]
+            self._converter_map[key] = conv
+        return self._converter_map[key]
 
     def _create_array(self, series, arrow_type, spark_type=None, 
arrow_cast=False):
         """
diff --git a/python/pyspark/sql/tests/test_udtf.py 
b/python/pyspark/sql/tests/test_udtf.py
index 2bb7c6d1f176..7f812ad20f59 100644
--- a/python/pyspark/sql/tests/test_udtf.py
+++ b/python/pyspark/sql/tests/test_udtf.py
@@ -63,6 +63,7 @@ from pyspark.sql.types import (
     VariantVal,
 )
 from pyspark.testing import assertDataFrameEqual, assertSchemaEqual
+from pyspark.testing.objects import ExamplePoint, ExamplePointUDT
 from pyspark.testing.sqlutils import (
     have_pandas,
     have_pyarrow,
@@ -230,6 +231,14 @@ class BaseUDTFTestsMixin:
         with self.assertRaisesRegex(PythonException, 
"UDTF_INVALID_OUTPUT_ROW_TYPE"):
             TestUDTF(lit(1)).collect()
 
+        @udtf(returnType=StructType().add("point", ExamplePointUDT()))
+        class TestUDTF:
+            def eval(self, x: float, y: float):
+                yield ExamplePoint(x=x * 10, y=y * 10)
+
+        with self.assertRaisesRegex(PythonException, 
"UDTF_INVALID_OUTPUT_ROW_TYPE"):
+            TestUDTF(lit(1.0), lit(2.0)).collect()
+
     def test_udtf_eval_returning_tuple_with_struct_type(self):
         @udtf(returnType="a: struct<b: int, c: int>")
         class TestUDTF:
@@ -246,6 +255,30 @@ class BaseUDTFTestsMixin:
         with self.assertRaisesRegex(PythonException, 
"UDTF_RETURN_SCHEMA_MISMATCH"):
             TestUDTF(lit(1)).collect()
 
+    def test_udtf_eval_returning_udt(self):
+        @udtf(returnType=StructType().add("point", ExamplePointUDT()))
+        class TestUDTF:
+            def eval(self, x: float, y: float):
+                yield ExamplePoint(x=x * 10, y=y * 10),
+
+        assertDataFrameEqual(
+            TestUDTF(lit(1.0), lit(2.0)), [Row(point=ExamplePoint(x=10.0, 
y=20.0))]
+        )
+
+    def test_udtf_eval_taking_udt(self):
+        @udtf(returnType="x: double, y: double")
+        class TestUDTF:
+            def eval(self, point: ExamplePoint):
+                yield point.x * 10, point.y * 10
+
+        df = self.spark.createDataFrame(
+            [(ExamplePoint(x=1.0, y=2.0),)], schema=StructType().add("point", 
ExamplePointUDT())
+        )
+        assertDataFrameEqual(
+            df.lateralJoin(TestUDTF(col("point").outer())),
+            [Row(point=ExamplePoint(x=1.0, y=2.0), x=10.0, y=20.0)],
+        )
+
     def test_udtf_with_invalid_return_value(self):
         @udtf(returnType="x: int")
         class TestUDTF:
@@ -2955,6 +2988,13 @@ class LegacyUDTFArrowTestsMixin(BaseUDTFTestsMixin):
 
         assertDataFrameEqual(TestUDTF(lit(1)), [Row(a=1)])
 
+        @udtf(returnType=StructType().add("udt", ExamplePointUDT()))
+        class TestUDTF:
+            def eval(self, x: float, y: float):
+                yield ExamplePoint(x=x * 10, y=y * 10)
+
+        assertDataFrameEqual(TestUDTF(lit(1.0), lit(2.0)), 
[Row(udt=ExamplePoint(x=10.0, y=20.0))])
+
     def test_udtf_use_large_var_types(self):
         for use_large_var_types in [True, False]:
             with self.subTest(use_large_var_types=use_large_var_types):
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 76c043405986..be49e527664f 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -66,7 +66,7 @@ from pyspark.sql.pandas.serializers import (
     ArrowBatchUDFSerializer,
     ArrowStreamUDTFSerializer,
 )
-from pyspark.sql.pandas.types import to_arrow_type, from_arrow_schema
+from pyspark.sql.pandas.types import to_arrow_type
 from pyspark.sql.types import (
     ArrayType,
     BinaryType,
@@ -1289,6 +1289,7 @@ def use_legacy_pandas_udf_conversion(runner_conf):
 def read_udtf(pickleSer, infile, eval_type):
     prefers_large_var_types = False
     legacy_pandas_conversion = False
+    input_schema = None
 
     if eval_type == PythonEvalType.SQL_ARROW_TABLE_UDF:
         runner_conf = {}
@@ -1305,6 +1306,7 @@ def read_udtf(pickleSer, infile, eval_type):
             ).lower()
             == "true"
         )
+        input_schema = 
_parse_datatype_json_string(utf8_deserializer.loads(infile))
         if legacy_pandas_conversion:
             # NOTE: if timezone is set here, that implies 
respectSessionTimeZone is True
             safecheck = (
@@ -1320,8 +1322,12 @@ def read_udtf(pickleSer, infile, eval_type):
                 == "true"
             )
             timezone = runner_conf.get("spark.sql.session.timeZone", None)
+            input_types = [field.dataType for field in input_schema]
             ser = ArrowStreamPandasUDTFSerializer(
-                timezone, safecheck, 
int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled
+                timezone,
+                safecheck,
+                input_types=input_types,
+                
int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
             )
         else:
             ser = ArrowStreamUDTFSerializer()
@@ -1709,17 +1715,21 @@ def read_udtf(pickleSer, infile, eval_type):
             def evaluate(*args: pd.Series, num_rows=1):
                 if len(args) == 0:
                     for _ in range(num_rows):
-                        yield verify_result(
-                            pd.DataFrame(check_return_value(func()))
-                        ), arrow_return_type
+                        yield (
+                            
verify_result(pd.DataFrame(check_return_value(func()))),
+                            arrow_return_type,
+                            return_type,
+                        )
                 else:
                     # Create tuples from the input pandas Series, each tuple
                     # represents a row across all Series.
                     row_tuples = zip(*args)
                     for row in row_tuples:
-                        yield verify_result(
-                            pd.DataFrame(check_return_value(func(*row)))
-                        ), arrow_return_type
+                        yield (
+                            
verify_result(pd.DataFrame(check_return_value(func(*row)))),
+                            arrow_return_type,
+                            return_type,
+                        )
 
             return evaluate
 
@@ -1868,21 +1878,14 @@ def read_udtf(pickleSer, infile, eval_type):
                 except Exception as e:
                     raise_conversion_error(e)
 
-            def evaluate(*args: pa.ChunkedArray, num_rows=1):
+            def evaluate(*args: list, num_rows=1):
                 if len(args) == 0:
                     for _ in range(num_rows):
                         for batch in 
verify_result(convert_to_arrow(func())).to_batches():
                             yield batch, arrow_return_type
 
                 else:
-                    list_args = list(args)
-                    names = [f"_{n}" for n in range(len(list_args))]
-                    t = pa.Table.from_arrays(list_args, names=names)
-                    schema = from_arrow_schema(t.schema, 
prefers_large_var_types)
-                    rows = ArrowTableToRowsConversion.convert(
-                        t, schema=schema, return_as_tuples=True
-                    )
-                    for row in rows:
+                    for row in zip(*args):
                         for batch in 
verify_result(convert_to_arrow(func(*row))).to_batches():
                             yield batch, arrow_return_type
 
@@ -1902,10 +1905,22 @@ def read_udtf(pickleSer, infile, eval_type):
 
         def mapper(_, it):
             try:
+                converters = [
+                    ArrowTableToRowsConversion._create_converter(
+                        field.dataType, none_on_identity=True
+                    )
+                    for field in input_schema
+                ]
                 for a in it:
+                    pylist = [
+                        [conv(v) for v in column.to_pylist()]
+                        if conv is not None
+                        else column.to_pylist()
+                        for column, conv in zip(a.columns, converters)
+                    ]
                     # The eval function yields an iterator. Each element 
produced by this
                     # iterator is a tuple in the form of (pyarrow.RecordBatch, 
arrow_return_type).
-                    yield from eval(*[a[o] for o in args_kwargs_offsets], 
num_rows=a.num_rows)
+                    yield from eval(*[pylist[o] for o in args_kwargs_offsets], 
num_rows=a.num_rows)
                 if terminate is not None:
                     yield from terminate()
             except SkipRestOfInputTableException:
@@ -1925,6 +1940,16 @@ def read_udtf(pickleSer, infile, eval_type):
 
             def verify_and_convert_result(result):
                 if result is not None:
+                    if hasattr(result, "__UDT__"):
+                        # UDT object should not be returned directly.
+                        raise PySparkRuntimeError(
+                            errorClass="UDTF_INVALID_OUTPUT_ROW_TYPE",
+                            messageParameters={
+                                "type": type(result).__name__,
+                                "func": f.__name__,
+                            },
+                        )
+
                     if hasattr(result, "__len__") and len(result) != 
return_type_size:
                         raise PySparkRuntimeError(
                             errorClass="UDTF_RETURN_SCHEMA_MISMATCH",
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonUDTFExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonUDTFExec.scala
index d7106403a388..6a6b08a97330 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonUDTFExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonUDTFExec.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.python.EvalPythonExec.ArgumentMetadata
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{StructType, UserDefinedType}
 import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch}
 
 /**
@@ -61,7 +61,9 @@ case class ArrowEvalPythonUDTFExec(
 
     val batchIter = if (batchSize > 0) new BatchIterator(iter, batchSize) else 
Iterator(iter)
 
-    val outputTypes = resultAttrs.map(_.dataType)
+    val outputTypes = resultAttrs.map(_.dataType.transformRecursively {
+      case udt: UserDefinedType[_] => udt.sqlType
+    })
 
     val columnarBatchIter = new ArrowPythonUDTFRunner(
       udtf,
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala
index 86136e444d43..660c886a3823 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala
@@ -48,6 +48,7 @@ class ArrowPythonUDTFRunner(
   with BasicPythonArrowOutput {
 
   override protected def writeUDF(dataOut: DataOutputStream): Unit = {
+    PythonWorkerUtils.writeUTF(schema.json, dataOut)
     PythonUDTFRunner.writeUDTF(dataOut, udtf, argMetas)
   }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to