This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 1dccc5341d98  [SPARK-52959][PYTHON] Support UDT in Arrow-optimized Python UDTF
1dccc5341d98 is described below

commit 1dccc5341d987f297dcdf5bcf6069fec9caa57dc
Author: Takuya Ueshin <ues...@databricks.com>
AuthorDate: Tue Jul 29 08:10:49 2025 -0700

    [SPARK-52959][PYTHON] Support UDT in Arrow-optimized Python UDTF

    ### What changes were proposed in this pull request?

    Supports UDTs in Arrow-optimized Python UDTFs.

    ### Why are the changes needed?

    Arrow-optimized Python UDTFs don't work with UDTs.

    ```py
    >>> @udtf(returnType="x: double, y: double", useArrow=True)
    ... class TestUDTFTakingUDT:
    ...     def eval(self, point: ExamplePoint):
    ...         yield point.x, point.y
    ...
    >>> df = spark.createDataFrame([(ExamplePoint(x=1.0, y=2.0),)], schema=StructType().add("point", ExamplePointUDT()))
    >>> df.lateralJoin(TestUDTFTakingUDT(col("point").outer())).show()
    ...
    AttributeError: 'list' object has no attribute 'x'

    >>> @udtf(returnType=StructType().add("point", ExamplePointUDT()), useArrow=True)
    ... class TestUDTFReturningUDT:
    ...     def eval(self, x: float, y: float):
    ...         yield ExamplePoint(x=x, y=y),
    ...
    >>> df = spark.range(2).select((col("id") + 10.0).alias("x"), (col("id") * 10.0).alias("y"))
    >>> df.lateralJoin(TestUDTFReturningUDT(col("x").outer(), col("y").outer())).show()
    ...
    org.apache.spark.SparkException: [ARROW_TYPE_MISMATCH] Invalid schema from Python UDTF: expected org.apache.spark.sql.test.ExamplePointUDT@7f511308, got ArrayType(DoubleType,false). SQLSTATE: 42K0G
    ```

    ### Does this PR introduce _any_ user-facing change?

    Yes, UDTs are available as input / output for Arrow-optimized Python UDTFs.

    ### How was this patch tested?

    Added the related tests.

    ### Was this patch authored or co-authored using generative AI tooling?

    No.

    Closes #51694 from ueshin/issues/SPARK-52959/udtf_udt.

    Authored-by: Takuya Ueshin <ues...@databricks.com>
    Signed-off-by: Dongjoon Hyun <dongj...@apache.org>
---
 python/pyspark/sql/pandas/serializers.py           | 28 ++++++----
 python/pyspark/sql/tests/test_udtf.py              | 40 ++++++++++++++
 python/pyspark/worker.py                           | 61 +++++++++++++++-------
 .../execution/python/ArrowEvalPythonUDTFExec.scala |  6 ++-
 .../execution/python/ArrowPythonUDTFRunner.scala   |  1 +
 5 files changed, 107 insertions(+), 29 deletions(-)

diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py
index 5943524db433..03d3974113ba 100644
--- a/python/pyspark/sql/pandas/serializers.py
+++ b/python/pyspark/sql/pandas/serializers.py
@@ -841,7 +841,7 @@ class ArrowStreamPandasUDTFSerializer(ArrowStreamPandasUDFSerializer):
     Serializer used by Python worker to evaluate Arrow-optimized Python UDTFs.
     """
 
-    def __init__(self, timezone, safecheck, int_to_decimal_coercion_enabled):
+    def __init__(self, timezone, safecheck, input_types, int_to_decimal_coercion_enabled):
         super(ArrowStreamPandasUDTFSerializer, self).__init__(
             timezone=timezone,
             safecheck=safecheck,
@@ -861,6 +861,7 @@ class ArrowStreamPandasUDTFSerializer(ArrowStreamPandasUDFSerializer):
             ndarray_as_list=True,
             # Enables explicit casting for mismatched return types of Arrow Python UDTFs.
             arrow_cast=True,
+            input_types=input_types,
             # Enable additional coercions for UDTF serialization
             int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
         )
@@ -885,35 +886,44 @@ class ArrowStreamPandasUDTFSerializer(ArrowStreamPandasUDFSerializer):
         import pandas as pd
         import pyarrow as pa
 
-        # Make input conform to [(series1, type1), (series2, type2), ...]
-        if not isinstance(series, (list, tuple)) or (
-            len(series) == 2 and isinstance(series[1], pa.DataType)
+        # Make input conform to
+        # [(series1, arrow_type1, spark_type1), (series2, arrow_type2, spark_type2), ...]
+        if (
+            not isinstance(series, (list, tuple))
+            or (len(series) == 2 and isinstance(series[1], pa.DataType))
+            or (
+                len(series) == 3
+                and isinstance(series[1], pa.DataType)
+                and isinstance(series[2], DataType)
+            )
         ):
             series = [series]
         series = ((s, None) if not isinstance(s, (list, tuple)) else s for s in series)
+        series = ((s[0], s[1], None) if len(s) == 2 else s for s in series)
 
         arrs = []
-        for s, t in series:
+        for s, arrow_type, spark_type in series:
             if not isinstance(s, pd.DataFrame):
                 raise PySparkValueError(
                     "Output of an arrow-optimized Python UDTFs expects "
                     f"a pandas.DataFrame but got: {type(s)}"
                 )
-            arrs.append(self._create_struct_array(s, t))
+            arrs.append(self._create_struct_array(s, arrow_type, spark_type))
         return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in range(len(arrs))])
 
     def _get_or_create_converter_from_pandas(self, dt):
-        if dt not in self._converter_map:
+        key = dt.json()
+        if key not in self._converter_map:
             conv = _create_converter_from_pandas(
                 dt,
                 timezone=self._timezone,
                 error_on_duplicated_field_names=False,
                 ignore_unexpected_complex_type_values=True,
             )
-            self._converter_map[dt] = conv
-        return self._converter_map[dt]
+            self._converter_map[key] = conv
+        return self._converter_map[key]
 
     def _create_array(self, series, arrow_type, spark_type=None, arrow_cast=False):
         """
diff --git a/python/pyspark/sql/tests/test_udtf.py b/python/pyspark/sql/tests/test_udtf.py
index 2bb7c6d1f176..7f812ad20f59 100644
--- a/python/pyspark/sql/tests/test_udtf.py
+++ b/python/pyspark/sql/tests/test_udtf.py
@@ -63,6 +63,7 @@ from pyspark.sql.types import (
     VariantVal,
 )
 from pyspark.testing import assertDataFrameEqual, assertSchemaEqual
+from pyspark.testing.objects import ExamplePoint, ExamplePointUDT
 from pyspark.testing.sqlutils import (
     have_pandas,
     have_pyarrow,
@@ -230,6 +231,14 @@ class BaseUDTFTestsMixin:
         with self.assertRaisesRegex(PythonException, "UDTF_INVALID_OUTPUT_ROW_TYPE"):
             TestUDTF(lit(1)).collect()
 
+        @udtf(returnType=StructType().add("point", ExamplePointUDT()))
+        class TestUDTF:
+            def eval(self, x: float, y: float):
+                yield ExamplePoint(x=x * 10, y=y * 10)
+
+        with self.assertRaisesRegex(PythonException, "UDTF_INVALID_OUTPUT_ROW_TYPE"):
+            TestUDTF(lit(1.0), lit(2.0)).collect()
+
     def test_udtf_eval_returning_tuple_with_struct_type(self):
         @udtf(returnType="a: struct<b: int, c: int>")
         class TestUDTF:
@@ -246,6 +255,30 @@ class BaseUDTFTestsMixin:
         with self.assertRaisesRegex(PythonException, "UDTF_RETURN_SCHEMA_MISMATCH"):
             TestUDTF(lit(1)).collect()
 
+    def test_udtf_eval_returning_udt(self):
+        @udtf(returnType=StructType().add("point", ExamplePointUDT()))
+        class TestUDTF:
+            def eval(self, x: float, y: float):
+                yield ExamplePoint(x=x * 10, y=y * 10),
+
+        assertDataFrameEqual(
+            TestUDTF(lit(1.0), lit(2.0)), [Row(point=ExamplePoint(x=10.0, y=20.0))]
+        )
+
+    def test_udtf_eval_taking_udt(self):
+        @udtf(returnType="x: double, y: double")
+        class TestUDTF:
+            def eval(self, point: ExamplePoint):
+                yield point.x * 10, point.y * 10
+
+        df = self.spark.createDataFrame(
+            [(ExamplePoint(x=1.0, y=2.0),)], schema=StructType().add("point", ExamplePointUDT())
+        )
+        assertDataFrameEqual(
+            df.lateralJoin(TestUDTF(col("point").outer())),
+            [Row(point=ExamplePoint(x=1.0, y=2.0), x=10.0, y=20.0)],
+        )
+
     def test_udtf_with_invalid_return_value(self):
         @udtf(returnType="x: int")
         class TestUDTF:
@@ -2955,6 +2988,13 @@ class LegacyUDTFArrowTestsMixin(BaseUDTFTestsMixin):
 
         assertDataFrameEqual(TestUDTF(lit(1)), [Row(a=1)])
 
+        @udtf(returnType=StructType().add("udt", ExamplePointUDT()))
+        class TestUDTF:
+            def eval(self, x: float, y: float):
+                yield ExamplePoint(x=x * 10, y=y * 10)
+
+        assertDataFrameEqual(TestUDTF(lit(1.0), lit(2.0)), [Row(udt=ExamplePoint(x=10.0, y=20.0))])
+
     def test_udtf_use_large_var_types(self):
         for use_large_var_types in [True, False]:
             with self.subTest(use_large_var_types=use_large_var_types):
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 76c043405986..be49e527664f 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -66,7 +66,7 @@ from pyspark.sql.pandas.serializers import (
     ArrowBatchUDFSerializer,
     ArrowStreamUDTFSerializer,
 )
-from pyspark.sql.pandas.types import to_arrow_type, from_arrow_schema
+from pyspark.sql.pandas.types import to_arrow_type
 from pyspark.sql.types import (
     ArrayType,
     BinaryType,
@@ -1289,6 +1289,7 @@ def use_legacy_pandas_udf_conversion(runner_conf):
 def read_udtf(pickleSer, infile, eval_type):
     prefers_large_var_types = False
     legacy_pandas_conversion = False
+    input_schema = None
 
     if eval_type == PythonEvalType.SQL_ARROW_TABLE_UDF:
         runner_conf = {}
@@ -1305,6 +1306,7 @@ def read_udtf(pickleSer, infile, eval_type):
            ).lower()
            == "true"
        )
+        input_schema = _parse_datatype_json_string(utf8_deserializer.loads(infile))
         if legacy_pandas_conversion:
             # NOTE: if timezone is set here, that implies respectSessionTimeZone is True
             safecheck = (
@@ -1320,8 +1322,12 @@ def read_udtf(pickleSer, infile, eval_type):
                 == "true"
             )
             timezone = runner_conf.get("spark.sql.session.timeZone", None)
+            input_types = [field.dataType for field in input_schema]
             ser = ArrowStreamPandasUDTFSerializer(
-                timezone, safecheck, int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled
+                timezone,
+                safecheck,
+                input_types=input_types,
+                int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
             )
         else:
             ser = ArrowStreamUDTFSerializer()
@@ -1709,17 +1715,21 @@ def read_udtf(pickleSer, infile, eval_type):
         def evaluate(*args: pd.Series, num_rows=1):
             if len(args) == 0:
                 for _ in range(num_rows):
-                    yield verify_result(
-                        pd.DataFrame(check_return_value(func()))
-                    ), arrow_return_type
+                    yield (
+                        verify_result(pd.DataFrame(check_return_value(func()))),
+                        arrow_return_type,
+                        return_type,
+                    )
             else:
                 # Create tuples from the input pandas Series, each tuple
                 # represents a row across all Series.
                 row_tuples = zip(*args)
                 for row in row_tuples:
-                    yield verify_result(
-                        pd.DataFrame(check_return_value(func(*row)))
-                    ), arrow_return_type
+                    yield (
+                        verify_result(pd.DataFrame(check_return_value(func(*row)))),
+                        arrow_return_type,
+                        return_type,
+                    )
 
         return evaluate
@@ -1868,21 +1878,14 @@ def read_udtf(pickleSer, infile, eval_type):
             except Exception as e:
                 raise_conversion_error(e)
 
-        def evaluate(*args: pa.ChunkedArray, num_rows=1):
+        def evaluate(*args: list, num_rows=1):
             if len(args) == 0:
                 for _ in range(num_rows):
                     for batch in verify_result(convert_to_arrow(func())).to_batches():
                         yield batch, arrow_return_type
             else:
-                list_args = list(args)
-                names = [f"_{n}" for n in range(len(list_args))]
-                t = pa.Table.from_arrays(list_args, names=names)
-                schema = from_arrow_schema(t.schema, prefers_large_var_types)
-                rows = ArrowTableToRowsConversion.convert(
-                    t, schema=schema, return_as_tuples=True
-                )
-                for row in rows:
+                for row in zip(*args):
                     for batch in verify_result(convert_to_arrow(func(*row))).to_batches():
                         yield batch, arrow_return_type
 
@@ -1902,10 +1905,22 @@ def read_udtf(pickleSer, infile, eval_type):
     def mapper(_, it):
         try:
+            converters = [
+                ArrowTableToRowsConversion._create_converter(
+                    field.dataType, none_on_identity=True
+                )
+                for field in input_schema
+            ]
             for a in it:
+                pylist = [
+                    [conv(v) for v in column.to_pylist()]
+                    if conv is not None
+                    else column.to_pylist()
+                    for column, conv in zip(a.columns, converters)
+                ]
                 # The eval function yields an iterator. Each element produced by this
                 # iterator is a tuple in the form of (pyarrow.RecordBatch, arrow_return_type).
-                yield from eval(*[a[o] for o in args_kwargs_offsets], num_rows=a.num_rows)
+                yield from eval(*[pylist[o] for o in args_kwargs_offsets], num_rows=a.num_rows)
             if terminate is not None:
                 yield from terminate()
         except SkipRestOfInputTableException:
@@ -1925,6 +1940,16 @@ def read_udtf(pickleSer, infile, eval_type):
 
     def verify_and_convert_result(result):
         if result is not None:
+            if hasattr(result, "__UDT__"):
+                # UDT object should not be returned directly.
+                raise PySparkRuntimeError(
+                    errorClass="UDTF_INVALID_OUTPUT_ROW_TYPE",
+                    messageParameters={
+                        "type": type(result).__name__,
+                        "func": f.__name__,
+                    },
+                )
+
             if hasattr(result, "__len__") and len(result) != return_type_size:
                 raise PySparkRuntimeError(
                     errorClass="UDTF_RETURN_SCHEMA_MISMATCH",
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonUDTFExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonUDTFExec.scala
index d7106403a388..6a6b08a97330 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonUDTFExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonUDTFExec.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.python.EvalPythonExec.ArgumentMetadata
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{StructType, UserDefinedType}
 import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch}
 
 /**
@@ -61,7 +61,9 @@ case class ArrowEvalPythonUDTFExec(
 
     val batchIter = if (batchSize > 0) new BatchIterator(iter, batchSize) else Iterator(iter)
 
-    val outputTypes = resultAttrs.map(_.dataType)
+    val outputTypes = resultAttrs.map(_.dataType.transformRecursively {
+      case udt: UserDefinedType[_] => udt.sqlType
+    })
 
     val columnarBatchIter = new ArrowPythonUDTFRunner(
       udtf,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala
index 86136e444d43..660c886a3823 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala
@@ -48,6 +48,7 @@ class ArrowPythonUDTFRunner(
     with BasicPythonArrowOutput {
 
   override protected def writeUDF(dataOut: DataOutputStream): Unit = {
+    PythonWorkerUtils.writeUTF(schema.json, dataOut)
    PythonUDTFRunner.writeUDTF(dataOut, udtf, argMetas)
   }
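A minimal usage sketch of what this patch enables, modeled on the tests added in python/pyspark/sql/tests/test_udtf.py above. It assumes an active SparkSession bound to `spark`; the class names `ReturnPoint` and `TakePoint` are illustrative, while `ExamplePoint`/`ExamplePointUDT` are the existing test helpers from `pyspark.testing.objects`:

```py
from pyspark.sql.functions import col, lit, udtf
from pyspark.sql.types import StructType
from pyspark.testing.objects import ExamplePoint, ExamplePointUDT

# Arrow-optimized UDTF that returns a UDT column.
@udtf(returnType=StructType().add("point", ExamplePointUDT()), useArrow=True)
class ReturnPoint:
    def eval(self, x: float, y: float):
        # Yield a 1-tuple whose only field is the UDT value.
        yield ExamplePoint(x=x, y=y),

# Arrow-optimized UDTF that takes a UDT value as input.
@udtf(returnType="x: double, y: double", useArrow=True)
class TakePoint:
    def eval(self, point: ExamplePoint):
        yield point.x, point.y

# Assumes an active SparkSession named `spark`.
ReturnPoint(lit(1.0), lit(2.0)).show()

df = spark.createDataFrame(
    [(ExamplePoint(x=1.0, y=2.0),)],
    schema=StructType().add("point", ExamplePointUDT()),
)
df.lateralJoin(TakePoint(col("point").outer())).show()
```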