This is an automated email from the ASF dual-hosted git repository.

xinrong pushed a commit to branch branch-4.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-4.0 by this push:
     new 89c44a319463 [SPARK-51076][PYTHON][CONNECT] Arrow Python UDF fallback 
for UDT input and output types
89c44a319463 is described below

commit 89c44a31946358892b3c3863e3f4d46e0e2d7dcf
Author: Xinrong Meng <[email protected]>
AuthorDate: Tue Feb 4 17:54:13 2025 -0800

    [SPARK-51076][PYTHON][CONNECT] Arrow Python UDF fallback for UDT input and 
output types
    
    ### What changes were proposed in this pull request?
    Introduce a fallback mechanism for Arrow-optimized Python UDFs when either 
the input or return types contain User-Defined Types (UDTs). If UDTs are 
detected, the system logs a warning and falls back to the current default, 
non-Arrow-optimized UDF execution.
    
    ### Why are the changes needed?
    To unblock enabling Arrow-optimized Python UDFs by default, see 
[pr](https://github.com/apache/spark/pull/49482)
    
    ### Does this PR introduce _any_ user-facing change?
    Yes. Arrow Python UDFs with UDT input or output types no longer fail; they 
fall back to non-Arrow-optimized execution, as shown below:
    
    ```py
    >>> import pyspark.sql.functions as F
    >>> from pyspark.sql import Row
    >>> from pyspark.testing.sqlutils import ExamplePoint, ExamplePointUDT
    
    # UDT input
    >>> from pyspark.sql.types import *
    >>> row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
    >>> df = spark.createDataFrame([row])
    >>>
    >>> udf1 = F.udf(lambda p: p.y, DoubleType(), useArrow=True)
    >>> df.select(udf1(df.point)).show()
    25/02/03 17:49:57 WARN ExtractPythonUDFs: Arrow optimization disabled due 
to UDT input or return type. Falling back to non-Arrow-optimized UDF execution.
    +---------------+
    |<lambda>(point)|
    +---------------+
    |            2.0|
    +---------------+
    
    # UDT output
    >>> row = Row(value=3.0)
    >>> df = spark.createDataFrame([row])
    >>> udf_with_udt_output = F.udf(lambda v: ExamplePoint(v, v + 1), 
ExamplePointUDT(), useArrow=True)
    >>> df.select(udf_with_udt_output(df.value)).show()
    25/02/03 17:51:43 WARN ExtractPythonUDFs: Arrow optimization disabled due 
to UDT input or return type. Falling back to non-Arrow-optimized UDF execution.
    +---------------+
    |<lambda>(value)|
    +---------------+
    |     (3.0, 4.0)|
    +---------------+
    ```
    
    ### How was this patch tested?
    Unit tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #49786 from xinrong-meng/udt_arrow_udf.
    
    Authored-by: Xinrong Meng <[email protected]>
    Signed-off-by: Xinrong Meng <[email protected]>
    (cherry picked from commit ac2f3a4dc814518882cafc012af7b2ad95b45613)
    Signed-off-by: Xinrong Meng <[email protected]>
---
 python/pyspark/sql/tests/test_types.py             | 18 +++++++++++++
 .../sql/execution/python/ExtractPythonUDFs.scala   | 31 +++++++++++++++++++---
 2 files changed, 46 insertions(+), 3 deletions(-)

diff --git a/python/pyspark/sql/tests/test_types.py 
b/python/pyspark/sql/tests/test_types.py
index 00fee71156ef..9577fe359857 100644
--- a/python/pyspark/sql/tests/test_types.py
+++ b/python/pyspark/sql/tests/test_types.py
@@ -950,6 +950,9 @@ class TypesTestsMixin:
         udf = F.udf(lambda k, v: [(k, v[0])], ArrayType(df.schema))
         gd.select(udf(*gd)).collect()
 
+        arrow_udf = F.udf(lambda k, v: [(k, v[0])], ArrayType(df.schema), 
useArrow=True)
+        gd.select(arrow_udf(*gd)).collect()
+
     def test_udt_with_none(self):
         df = self.spark.range(0, 10, 1, 1)
 
@@ -1054,19 +1057,34 @@ class TypesTestsMixin:
         self.assertEqual(points, [PythonOnlyPoint(1.0, 2.0), None])
 
     def test_udf_with_udt(self):
+        # UDT input
         row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
         df = self.spark.createDataFrame([row])
         udf = F.udf(lambda p: p.y, DoubleType())
         self.assertEqual(2.0, df.select(udf(df.point)).first()[0])
+        arrow_udf = F.udf(lambda p: p.y, DoubleType(), useArrow=True)
+        self.assertEqual(2.0, df.select(arrow_udf(df.point)).first()[0])
+        # UDT output
         udf2 = F.udf(lambda p: ExamplePoint(p.x + 1, p.y + 1), ExamplePointUDT())
         self.assertEqual(ExamplePoint(2.0, 3.0), df.select(udf2(df.point)).first()[0])
+        arrow_udf2 = F.udf(
+            lambda p: ExamplePoint(p.x + 1, p.y + 1), ExamplePointUDT(), useArrow=True
+        )
+        self.assertEqual(ExamplePoint(2.0, 3.0), df.select(arrow_udf2(df.point)).first()[0])
 
         row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0))
         df = self.spark.createDataFrame([row])
         udf = F.udf(lambda p: p.y, DoubleType())
         self.assertEqual(2.0, df.select(udf(df.point)).first()[0])
+        arrow_udf = F.udf(lambda p: p.y, DoubleType(), useArrow=True)
+        self.assertEqual(2.0, df.select(arrow_udf(df.point)).first()[0])
+        # UDT output
         udf2 = F.udf(lambda p: PythonOnlyPoint(p.x + 1, p.y + 1), PythonOnlyUDT())
         self.assertEqual(PythonOnlyPoint(2.0, 3.0), df.select(udf2(df.point)).first()[0])
+        arrow_udf2 = F.udf(
+            lambda p: PythonOnlyPoint(p.x + 1, p.y + 1), PythonOnlyUDT(), useArrow=True
+        )
+        self.assertEqual(PythonOnlyPoint(2.0, 3.0), df.select(arrow_udf2(df.point)).first()[0])
 
     def test_rdd_with_udt(self):
         row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
index dcd6603f6490..99bcbfd9eb24 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
@@ -22,11 +22,14 @@ import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark.SparkException
 import org.apache.spark.api.python.PythonEvalType
+import org.apache.spark.internal.{Logging, MDC}
+import org.apache.spark.internal.LogKeys.REASON
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.trees.TreePattern._
+import org.apache.spark.sql.types._
 
 
 /**
@@ -157,7 +160,7 @@ object ExtractGroupingPythonUDFFromAggregate extends 
Rule[LogicalPlan] {
  * This has the limitation that the input to the Python UDF is not allowed 
include attributes from
  * multiple child operators.
  */
-object ExtractPythonUDFs extends Rule[LogicalPlan] {
+object ExtractPythonUDFs extends Rule[LogicalPlan] with Logging {
 
   private type EvalType = Int
   private type EvalTypeChecker = EvalType => Boolean
@@ -232,6 +235,14 @@ object ExtractPythonUDFs extends Rule[LogicalPlan] {
     }
   }
 
+  private def containsUDT(dataType: DataType): Boolean = dataType match {
+    case _: UserDefinedType[_] => true
+    case ArrayType(elementType, _) => containsUDT(elementType)
+    case StructType(fields) => fields.exists(field => 
containsUDT(field.dataType))
+    case MapType(keyType, valueType, _) => containsUDT(keyType) || 
containsUDT(valueType)
+    case _ => false
+  }
+
   /**
    * Extract all the PythonUDFs from the current operator and evaluate them 
before the operator.
    */
@@ -268,12 +279,26 @@ object ExtractPythonUDFs extends Rule[LogicalPlan] {
               evalTypes.mkString(","))
           }
           val evalType = evalTypes.head
+
+          val hasUDTInput = validUdfs.exists(_.children.exists(expr => 
containsUDT(expr.dataType)))
+          val hasUDTReturn = validUdfs.exists(udf => containsUDT(udf.dataType))
+
           val evaluation = evalType match {
             case PythonEvalType.SQL_BATCHED_UDF =>
               BatchEvalPython(validUdfs, resultAttrs, child)
-            case PythonEvalType.SQL_SCALAR_PANDAS_UDF | 
PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF
-                 | PythonEvalType.SQL_ARROW_BATCHED_UDF =>
+            case PythonEvalType.SQL_SCALAR_PANDAS_UDF | 
PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF =>
               ArrowEvalPython(validUdfs, resultAttrs, child, evalType)
+            case PythonEvalType.SQL_ARROW_BATCHED_UDF =>
+
+              if (hasUDTInput || hasUDTReturn) {
+                // Use BatchEvalPython if UDT is detected
+                logWarning(log"Arrow optimization disabled due to " +
+                  log"${MDC(REASON, "UDT input or return type")}. " +
+                  log"Falling back to non-Arrow-optimized UDF execution.")
+                BatchEvalPython(validUdfs, resultAttrs, child)
+              } else {
+                ArrowEvalPython(validUdfs, resultAttrs, child, evalType)
+              }
             case _ =>
               throw SparkException.internalError("Unexpected UDF evalType")
           }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to