This is an automated email from the ASF dual-hosted git repository.
xinrong pushed a commit to branch branch-4.0
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-4.0 by this push:
new 89c44a319463 [SPARK-51076][PYTHON][CONNECT] Arrow Python UDF fallback
for UDT input and output types
89c44a319463 is described below
commit 89c44a31946358892b3c3863e3f4d46e0e2d7dcf
Author: Xinrong Meng <[email protected]>
AuthorDate: Tue Feb 4 17:54:13 2025 -0800
[SPARK-51076][PYTHON][CONNECT] Arrow Python UDF fallback for UDT input and
output types
### What changes were proposed in this pull request?
Introduce a fallback mechanism for Arrow-optimized Python UDFs when either
the input or return types contain User-Defined Types (UDTs). If UDTs are
detected, the system logs a warning and switches to the currently default,
non-Arrow-optimized UDF execution.
### Why are the changes needed?
To unblock enabling Arrow-optimized Python UDFs by default, see
[pr](https://github.com/apache/spark/pull/49482)
### Does this PR introduce _any_ user-facing change?
Yes. UDT input and output types will not fail Arrow Python UDF anymore, as
shown below:
```py
>>> import pyspark.sql.functions as F
>>> from pyspark.sql import Row
>>> from pyspark.testing.sqlutils import ExamplePoint, ExamplePointUDT
# UDT input
>>> from pyspark.sql.types import *
>>> row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
>>> df = spark.createDataFrame([row])
>>>
>>> udf1 = F.udf(lambda p: p.y, DoubleType(), useArrow=True)
>>> df.select(udf1(df.point)).show()
25/02/03 17:49:57 WARN ExtractPythonUDFs: Arrow optimization disabled due
to UDT input or return type. Falling back to non-Arrow-optimized UDF execution.
+---------------+
|<lambda>(point)|
+---------------+
| 2.0|
+---------------+
# UDT output
>>> row = Row(value=3.0)
>>> df = spark.createDataFrame([row])
>>> udf_with_udt_output = F.udf(lambda v: ExamplePoint(v, v + 1),
ExamplePointUDT(), useArrow=True)
>>> df.select(udf_with_udt_output(df.value)).show()
25/02/03 17:51:43 WARN ExtractPythonUDFs: Arrow optimization disabled due
to UDT input or return type. Falling back to non-Arrow-optimized UDF execution.
+---------------+
|<lambda>(value)|
+---------------+
| (3.0, 4.0)|
+---------------+
```
### How was this patch tested?
Unit tests.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #49786 from xinrong-meng/udt_arrow_udf.
Authored-by: Xinrong Meng <[email protected]>
Signed-off-by: Xinrong Meng <[email protected]>
(cherry picked from commit ac2f3a4dc814518882cafc012af7b2ad95b45613)
Signed-off-by: Xinrong Meng <[email protected]>
---
python/pyspark/sql/tests/test_types.py | 18 +++++++++++++
.../sql/execution/python/ExtractPythonUDFs.scala | 31 +++++++++++++++++++---
2 files changed, 46 insertions(+), 3 deletions(-)
diff --git a/python/pyspark/sql/tests/test_types.py
b/python/pyspark/sql/tests/test_types.py
index 00fee71156ef..9577fe359857 100644
--- a/python/pyspark/sql/tests/test_types.py
+++ b/python/pyspark/sql/tests/test_types.py
@@ -950,6 +950,9 @@ class TypesTestsMixin:
udf = F.udf(lambda k, v: [(k, v[0])], ArrayType(df.schema))
gd.select(udf(*gd)).collect()
+ arrow_udf = F.udf(lambda k, v: [(k, v[0])], ArrayType(df.schema),
useArrow=True)
+ gd.select(arrow_udf(*gd)).collect()
+
def test_udt_with_none(self):
df = self.spark.range(0, 10, 1, 1)
@@ -1054,19 +1057,34 @@ class TypesTestsMixin:
self.assertEqual(points, [PythonOnlyPoint(1.0, 2.0), None])
def test_udf_with_udt(self):
+ # UDT input
row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
df = self.spark.createDataFrame([row])
udf = F.udf(lambda p: p.y, DoubleType())
self.assertEqual(2.0, df.select(udf(df.point)).first()[0])
+ arrow_udf = F.udf(lambda p: p.y, DoubleType(), useArrow=True)
+ self.assertEqual(2.0, df.select(arrow_udf(df.point)).first()[0])
+
udf2 = F.udf(lambda p: ExamplePoint(p.x + 1, p.y + 1),
ExamplePointUDT())
self.assertEqual(ExamplePoint(2.0, 3.0),
df.select(udf2(df.point)).first()[0])
+ arrow_udf2 = F.udf(
+ lambda p: ExamplePoint(p.x + 1, p.y + 1), ExamplePointUDT(),
useArrow=True
+ )
+ self.assertEqual(ExamplePoint(2.0, 3.0),
df.select(arrow_udf2(df.point)).first()[0])
row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0))
df = self.spark.createDataFrame([row])
udf = F.udf(lambda p: p.y, DoubleType())
self.assertEqual(2.0, df.select(udf(df.point)).first()[0])
+ arrow_udf = F.udf(lambda p: p.y, DoubleType(), useArrow=True)
+ self.assertEqual(2.0, df.select(udf(df.point)).first()[0])
+
udf2 = F.udf(lambda p: PythonOnlyPoint(p.x + 1, p.y + 1),
PythonOnlyUDT())
self.assertEqual(PythonOnlyPoint(2.0, 3.0),
df.select(udf2(df.point)).first()[0])
+ arrow_udf2 = F.udf(
+ lambda p: PythonOnlyPoint(p.x + 1, p.y + 1), PythonOnlyUDT(),
useArrow=True
+ )
+ self.assertEqual(PythonOnlyPoint(2.0, 3.0),
df.select(arrow_udf2(df.point)).first()[0])
def test_rdd_with_udt(self):
row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
index dcd6603f6490..99bcbfd9eb24 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
@@ -22,11 +22,14 @@ import scala.collection.mutable.ArrayBuffer
import org.apache.spark.SparkException
import org.apache.spark.api.python.PythonEvalType
+import org.apache.spark.internal.{Logging, MDC}
+import org.apache.spark.internal.LogKeys.REASON
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreePattern._
+import org.apache.spark.sql.types._
/**
@@ -157,7 +160,7 @@ object ExtractGroupingPythonUDFFromAggregate extends
Rule[LogicalPlan] {
* This has the limitation that the input to the Python UDF is not allowed
include attributes from
* multiple child operators.
*/
-object ExtractPythonUDFs extends Rule[LogicalPlan] {
+object ExtractPythonUDFs extends Rule[LogicalPlan] with Logging {
private type EvalType = Int
private type EvalTypeChecker = EvalType => Boolean
@@ -232,6 +235,14 @@ object ExtractPythonUDFs extends Rule[LogicalPlan] {
}
}
+ private def containsUDT(dataType: DataType): Boolean = dataType match {
+ case _: UserDefinedType[_] => true
+ case ArrayType(elementType, _) => containsUDT(elementType)
+ case StructType(fields) => fields.exists(field =>
containsUDT(field.dataType))
+ case MapType(keyType, valueType, _) => containsUDT(keyType) ||
containsUDT(valueType)
+ case _ => false
+ }
+
/**
* Extract all the PythonUDFs from the current operator and evaluate them
before the operator.
*/
@@ -268,12 +279,26 @@ object ExtractPythonUDFs extends Rule[LogicalPlan] {
evalTypes.mkString(","))
}
val evalType = evalTypes.head
+
+ val hasUDTInput = validUdfs.exists(_.children.exists(expr =>
containsUDT(expr.dataType)))
+ val hasUDTReturn = validUdfs.exists(udf => containsUDT(udf.dataType))
+
val evaluation = evalType match {
case PythonEvalType.SQL_BATCHED_UDF =>
BatchEvalPython(validUdfs, resultAttrs, child)
- case PythonEvalType.SQL_SCALAR_PANDAS_UDF |
PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF
- | PythonEvalType.SQL_ARROW_BATCHED_UDF =>
+ case PythonEvalType.SQL_SCALAR_PANDAS_UDF |
PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF =>
ArrowEvalPython(validUdfs, resultAttrs, child, evalType)
+ case PythonEvalType.SQL_ARROW_BATCHED_UDF =>
+
+ if (hasUDTInput || hasUDTReturn) {
+ // Use BatchEvalPython if UDT is detected
+ logWarning(log"Arrow optimization disabled due to " +
+ log"${MDC(REASON, "UDT input or return type")}. " +
+ log"Falling back to non-Arrow-optimized UDF execution.")
+ BatchEvalPython(validUdfs, resultAttrs, child)
+ } else {
+ ArrowEvalPython(validUdfs, resultAttrs, child, evalType)
+ }
case _ =>
throw SparkException.internalError("Unexpected UDF evalType")
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]