This is an automated email from the ASF dual-hosted git repository.
dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 1dccc5341d98 [SPARK-52959][PYTHON] Support UDT in Arrow-optimized Python UDTF
1dccc5341d98 is described below
commit 1dccc5341d987f297dcdf5bcf6069fec9caa57dc
Author: Takuya Ueshin <[email protected]>
AuthorDate: Tue Jul 29 08:10:49 2025 -0700
[SPARK-52959][PYTHON] Support UDT in Arrow-optimized Python UDTF
### What changes were proposed in this pull request?
Supports UDTs as input and output types in Arrow-optimized Python UDTFs.
### Why are the changes needed?
Arrow-optimized Python UDTFs don't work with UDTs, either when a UDT is passed in as an argument or when one is returned:
```py
>>> @udtf(returnType="x: double, y: double", useArrow=True)
... class TestUDTFTakingUDT:
...     def eval(self, point: ExamplePoint):
...         yield point.x, point.y
...
>>> df = spark.createDataFrame([(ExamplePoint(x=1.0, y=2.0),)], schema=StructType().add("point", ExamplePointUDT()))
>>> df.lateralJoin(TestUDTFTakingUDT(col("point").outer())).show()
...
AttributeError: 'list' object has no attribute 'x'
>>> @udtf(returnType=StructType().add("point", ExamplePointUDT()), useArrow=True)
... class TestUDTFReturningUDT:
...     def eval(self, x: float, y: float):
...         yield ExamplePoint(x=x, y=y),
...
>>> df = spark.range(2).select((col("id") + 10.0).alias("x"), (col("id") * 10.0).alias("y"))
>>> df.lateralJoin(TestUDTFReturningUDT(col("x").outer(), col("y").outer())).show()
...
org.apache.spark.SparkException: [ARROW_TYPE_MISMATCH] Invalid schema from Python UDTF: expected org.apache.spark.sql.test.ExamplePointUDT@7f511308, got ArrayType(DoubleType,false). SQLSTATE: 42K0G
```
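The root cause is that the Arrow path exchanges UDT columns in their underlying `sqlType` representation (for `ExamplePointUDT`, `array<double>`) without going through the UDT's `serialize`/`deserialize`, so `eval` saw a plain list and returned UDT objects failed the Arrow schema check. A minimal sketch of the per-value conversion the worker now applies to UDT input columns (hypothetical values):

```py
from pyspark.testing.objects import ExamplePoint, ExamplePointUDT

udt = ExamplePointUDT()
# A UDT value arrives from Arrow in its sqlType form, here array<double>.
raw = [1.0, 2.0]
# Before this change eval() received the raw list (hence the AttributeError above);
# now the worker deserializes it back into the Python UDT object first.
point = udt.deserialize(raw)
assert (point.x, point.y) == (1.0, 2.0)
```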
### Does this PR introduce _any_ user-facing change?
Yes, UDTs are now supported as input and output types for Arrow-optimized Python UDTFs.
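For example, both directions now work; a sketch mirroring the tests added in this PR (assumes an active SparkSession named `spark`; `ExamplePoint`/`ExamplePointUDT` come from `pyspark.testing.objects`):

```py
from pyspark.sql import Row
from pyspark.sql.functions import col, lit, udtf
from pyspark.sql.types import StructType
from pyspark.testing import assertDataFrameEqual
from pyspark.testing.objects import ExamplePoint, ExamplePointUDT

# Returning a UDT from an Arrow-optimized Python UDTF.
@udtf(returnType=StructType().add("point", ExamplePointUDT()), useArrow=True)
class TestUDTFReturningUDT:
    def eval(self, x: float, y: float):
        yield ExamplePoint(x=x * 10, y=y * 10),

assertDataFrameEqual(
    TestUDTFReturningUDT(lit(1.0), lit(2.0)),
    [Row(point=ExamplePoint(x=10.0, y=20.0))],
)

# Taking a UDT as an argument.
@udtf(returnType="x: double, y: double", useArrow=True)
class TestUDTFTakingUDT:
    def eval(self, point: ExamplePoint):
        yield point.x * 10, point.y * 10

df = spark.createDataFrame(
    [(ExamplePoint(x=1.0, y=2.0),)], schema=StructType().add("point", ExamplePointUDT())
)
assertDataFrameEqual(
    df.lateralJoin(TestUDTFTakingUDT(col("point").outer())),
    [Row(point=ExamplePoint(x=1.0, y=2.0), x=10.0, y=20.0)],
)
```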
### How was this patch tested?
Added the related tests.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #51694 from ueshin/issues/SPARK-52959/udtf_udt.
Authored-by: Takuya Ueshin <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
---
python/pyspark/sql/pandas/serializers.py | 28 ++++++----
python/pyspark/sql/tests/test_udtf.py | 40 ++++++++++++++
python/pyspark/worker.py | 61 +++++++++++++++-------
.../execution/python/ArrowEvalPythonUDTFExec.scala | 6 ++-
.../execution/python/ArrowPythonUDTFRunner.scala | 1 +
5 files changed, 107 insertions(+), 29 deletions(-)
diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py
index 5943524db433..03d3974113ba 100644
--- a/python/pyspark/sql/pandas/serializers.py
+++ b/python/pyspark/sql/pandas/serializers.py
@@ -841,7 +841,7 @@ class ArrowStreamPandasUDTFSerializer(ArrowStreamPandasUDFSerializer):
Serializer used by Python worker to evaluate Arrow-optimized Python UDTFs.
"""
- def __init__(self, timezone, safecheck, int_to_decimal_coercion_enabled):
+ def __init__(self, timezone, safecheck, input_types, int_to_decimal_coercion_enabled):
super(ArrowStreamPandasUDTFSerializer, self).__init__(
timezone=timezone,
safecheck=safecheck,
@@ -861,6 +861,7 @@ class ArrowStreamPandasUDTFSerializer(ArrowStreamPandasUDFSerializer):
ndarray_as_list=True,
# Enables explicit casting for mismatched return types of Arrow Python UDTFs.
arrow_cast=True,
+ input_types=input_types,
# Enable additional coercions for UDTF serialization
int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
)
@@ -885,35 +886,44 @@ class ArrowStreamPandasUDTFSerializer(ArrowStreamPandasUDFSerializer):
import pandas as pd
import pyarrow as pa
- # Make input conform to [(series1, type1), (series2, type2), ...]
- if not isinstance(series, (list, tuple)) or (
- len(series) == 2 and isinstance(series[1], pa.DataType)
+ # Make input conform to
+ # [(series1, arrow_type1, spark_type1), (series2, arrow_type2, spark_type2), ...]
+ if (
+ not isinstance(series, (list, tuple))
+ or (len(series) == 2 and isinstance(series[1], pa.DataType))
+ or (
+ len(series) == 3
+ and isinstance(series[1], pa.DataType)
+ and isinstance(series[2], DataType)
+ )
):
series = [series]
series = ((s, None) if not isinstance(s, (list, tuple)) else s for s in series)
+ series = ((s[0], s[1], None) if len(s) == 2 else s for s in series)
arrs = []
- for s, t in series:
+ for s, arrow_type, spark_type in series:
if not isinstance(s, pd.DataFrame):
raise PySparkValueError(
"Output of an arrow-optimized Python UDTFs expects "
f"a pandas.DataFrame but got: {type(s)}"
)
- arrs.append(self._create_struct_array(s, t))
+ arrs.append(self._create_struct_array(s, arrow_type, spark_type))
return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in range(len(arrs))])
def _get_or_create_converter_from_pandas(self, dt):
- if dt not in self._converter_map:
+ key = dt.json()
+ if key not in self._converter_map:
conv = _create_converter_from_pandas(
dt,
timezone=self._timezone,
error_on_duplicated_field_names=False,
ignore_unexpected_complex_type_values=True,
)
- self._converter_map[dt] = conv
- return self._converter_map[dt]
+ self._converter_map[key] = conv
+ return self._converter_map[key]
def _create_array(self, series, arrow_type, spark_type=None, arrow_cast=False):
"""
diff --git a/python/pyspark/sql/tests/test_udtf.py b/python/pyspark/sql/tests/test_udtf.py
index 2bb7c6d1f176..7f812ad20f59 100644
--- a/python/pyspark/sql/tests/test_udtf.py
+++ b/python/pyspark/sql/tests/test_udtf.py
@@ -63,6 +63,7 @@ from pyspark.sql.types import (
VariantVal,
)
from pyspark.testing import assertDataFrameEqual, assertSchemaEqual
+from pyspark.testing.objects import ExamplePoint, ExamplePointUDT
from pyspark.testing.sqlutils import (
have_pandas,
have_pyarrow,
@@ -230,6 +231,14 @@ class BaseUDTFTestsMixin:
with self.assertRaisesRegex(PythonException, "UDTF_INVALID_OUTPUT_ROW_TYPE"):
TestUDTF(lit(1)).collect()
+ @udtf(returnType=StructType().add("point", ExamplePointUDT()))
+ class TestUDTF:
+ def eval(self, x: float, y: float):
+ yield ExamplePoint(x=x * 10, y=y * 10)
+
+ with self.assertRaisesRegex(PythonException, "UDTF_INVALID_OUTPUT_ROW_TYPE"):
+ TestUDTF(lit(1.0), lit(2.0)).collect()
+
def test_udtf_eval_returning_tuple_with_struct_type(self):
@udtf(returnType="a: struct<b: int, c: int>")
class TestUDTF:
@@ -246,6 +255,30 @@ class BaseUDTFTestsMixin:
with self.assertRaisesRegex(PythonException, "UDTF_RETURN_SCHEMA_MISMATCH"):
TestUDTF(lit(1)).collect()
+ def test_udtf_eval_returning_udt(self):
+ @udtf(returnType=StructType().add("point", ExamplePointUDT()))
+ class TestUDTF:
+ def eval(self, x: float, y: float):
+ yield ExamplePoint(x=x * 10, y=y * 10),
+
+ assertDataFrameEqual(
+ TestUDTF(lit(1.0), lit(2.0)), [Row(point=ExamplePoint(x=10.0, y=20.0))]
+ )
+
+ def test_udtf_eval_taking_udt(self):
+ @udtf(returnType="x: double, y: double")
+ class TestUDTF:
+ def eval(self, point: ExamplePoint):
+ yield point.x * 10, point.y * 10
+
+ df = self.spark.createDataFrame(
+ [(ExamplePoint(x=1.0, y=2.0),)], schema=StructType().add("point", ExamplePointUDT())
+ )
+ assertDataFrameEqual(
+ df.lateralJoin(TestUDTF(col("point").outer())),
+ [Row(point=ExamplePoint(x=1.0, y=2.0), x=10.0, y=20.0)],
+ )
+
def test_udtf_with_invalid_return_value(self):
@udtf(returnType="x: int")
class TestUDTF:
@@ -2955,6 +2988,13 @@ class LegacyUDTFArrowTestsMixin(BaseUDTFTestsMixin):
assertDataFrameEqual(TestUDTF(lit(1)), [Row(a=1)])
+ @udtf(returnType=StructType().add("udt", ExamplePointUDT()))
+ class TestUDTF:
+ def eval(self, x: float, y: float):
+ yield ExamplePoint(x=x * 10, y=y * 10)
+
+ assertDataFrameEqual(TestUDTF(lit(1.0), lit(2.0)), [Row(udt=ExamplePoint(x=10.0, y=20.0))])
+
def test_udtf_use_large_var_types(self):
for use_large_var_types in [True, False]:
with self.subTest(use_large_var_types=use_large_var_types):
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 76c043405986..be49e527664f 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -66,7 +66,7 @@ from pyspark.sql.pandas.serializers import (
ArrowBatchUDFSerializer,
ArrowStreamUDTFSerializer,
)
-from pyspark.sql.pandas.types import to_arrow_type, from_arrow_schema
+from pyspark.sql.pandas.types import to_arrow_type
from pyspark.sql.types import (
ArrayType,
BinaryType,
@@ -1289,6 +1289,7 @@ def use_legacy_pandas_udf_conversion(runner_conf):
def read_udtf(pickleSer, infile, eval_type):
prefers_large_var_types = False
legacy_pandas_conversion = False
+ input_schema = None
if eval_type == PythonEvalType.SQL_ARROW_TABLE_UDF:
runner_conf = {}
@@ -1305,6 +1306,7 @@ def read_udtf(pickleSer, infile, eval_type):
).lower()
== "true"
)
+ input_schema = _parse_datatype_json_string(utf8_deserializer.loads(infile))
if legacy_pandas_conversion:
# NOTE: if timezone is set here, that implies respectSessionTimeZone is True
safecheck = (
@@ -1320,8 +1322,12 @@ def read_udtf(pickleSer, infile, eval_type):
== "true"
)
timezone = runner_conf.get("spark.sql.session.timeZone", None)
+ input_types = [field.dataType for field in input_schema]
ser = ArrowStreamPandasUDTFSerializer(
- timezone, safecheck, int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled
+ timezone,
+ safecheck,
+ input_types=input_types,
+ int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
)
else:
ser = ArrowStreamUDTFSerializer()
@@ -1709,17 +1715,21 @@ def read_udtf(pickleSer, infile, eval_type):
def evaluate(*args: pd.Series, num_rows=1):
if len(args) == 0:
for _ in range(num_rows):
- yield verify_result(
- pd.DataFrame(check_return_value(func()))
- ), arrow_return_type
+ yield (
+ verify_result(pd.DataFrame(check_return_value(func()))),
+ arrow_return_type,
+ return_type,
+ )
else:
# Create tuples from the input pandas Series, each tuple
# represents a row across all Series.
row_tuples = zip(*args)
for row in row_tuples:
- yield verify_result(
- pd.DataFrame(check_return_value(func(*row)))
- ), arrow_return_type
+ yield (
+ verify_result(pd.DataFrame(check_return_value(func(*row)))),
+ arrow_return_type,
+ return_type,
+ )
return evaluate
@@ -1868,21 +1878,14 @@ def read_udtf(pickleSer, infile, eval_type):
except Exception as e:
raise_conversion_error(e)
- def evaluate(*args: pa.ChunkedArray, num_rows=1):
+ def evaluate(*args: list, num_rows=1):
if len(args) == 0:
for _ in range(num_rows):
for batch in verify_result(convert_to_arrow(func())).to_batches():
yield batch, arrow_return_type
else:
- list_args = list(args)
- names = [f"_{n}" for n in range(len(list_args))]
- t = pa.Table.from_arrays(list_args, names=names)
- schema = from_arrow_schema(t.schema, prefers_large_var_types)
- rows = ArrowTableToRowsConversion.convert(
- t, schema=schema, return_as_tuples=True
- )
- for row in rows:
+ for row in zip(*args):
for batch in verify_result(convert_to_arrow(func(*row))).to_batches():
yield batch, arrow_return_type
@@ -1902,10 +1905,22 @@ def read_udtf(pickleSer, infile, eval_type):
def mapper(_, it):
try:
+ converters = [
+ ArrowTableToRowsConversion._create_converter(
+ field.dataType, none_on_identity=True
+ )
+ for field in input_schema
+ ]
for a in it:
+ pylist = [
+ [conv(v) for v in column.to_pylist()]
+ if conv is not None
+ else column.to_pylist()
+ for column, conv in zip(a.columns, converters)
+ ]
# The eval function yields an iterator. Each element produced by this
# iterator is a tuple in the form of (pyarrow.RecordBatch, arrow_return_type).
- yield from eval(*[a[o] for o in args_kwargs_offsets], num_rows=a.num_rows)
+ yield from eval(*[pylist[o] for o in args_kwargs_offsets], num_rows=a.num_rows)
if terminate is not None:
yield from terminate()
except SkipRestOfInputTableException:
@@ -1925,6 +1940,16 @@ def read_udtf(pickleSer, infile, eval_type):
def verify_and_convert_result(result):
if result is not None:
+ if hasattr(result, "__UDT__"):
+ # UDT object should not be returned directly.
+ raise PySparkRuntimeError(
+ errorClass="UDTF_INVALID_OUTPUT_ROW_TYPE",
+ messageParameters={
+ "type": type(result).__name__,
+ "func": f.__name__,
+ },
+ )
+
if hasattr(result, "__len__") and len(result) != return_type_size:
raise PySparkRuntimeError(
errorClass="UDTF_RETURN_SCHEMA_MISMATCH",
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonUDTFExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonUDTFExec.scala
index d7106403a388..6a6b08a97330 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonUDTFExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonUDTFExec.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.python.EvalPythonExec.ArgumentMetadata
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{StructType, UserDefinedType}
import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch}
/**
@@ -61,7 +61,9 @@ case class ArrowEvalPythonUDTFExec(
val batchIter = if (batchSize > 0) new BatchIterator(iter, batchSize) else Iterator(iter)
- val outputTypes = resultAttrs.map(_.dataType)
+ val outputTypes = resultAttrs.map(_.dataType.transformRecursively {
+ case udt: UserDefinedType[_] => udt.sqlType
+ })
val columnarBatchIter = new ArrowPythonUDTFRunner(
udtf,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala
index 86136e444d43..660c886a3823 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala
@@ -48,6 +48,7 @@ class ArrowPythonUDTFRunner(
with BasicPythonArrowOutput {
override protected def writeUDF(dataOut: DataOutputStream): Unit = {
+ PythonWorkerUtils.writeUTF(schema.json, dataOut)
PythonUDTFRunner.writeUDTF(dataOut, udtf, argMetas)
}