This is an automated email from the ASF dual-hosted git repository.

yao pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.5 by this push:
     new 596f680ea37c [SPARK-48845][SQL] GenericUDF catch exceptions from 
children
596f680ea37c is described below

commit 596f680ea37c8fae77d2ba29d79cbc9339a04ca9
Author: jackylee-ch <[email protected]>
AuthorDate: Fri Jul 12 18:15:56 2024 +0800

    [SPARK-48845][SQL] GenericUDF catch exceptions from children
    
    ### What changes were proposed in this pull request?
    This PR fixes a semantic regression in GenericUDF present since 3.5.0. The 
problem arises because DeferredObject currently holds an already-evaluated 
value instead of a function, which prevents users from catching exceptions 
thrown by child expressions inside a GenericUDF, resulting in semantic differences.
    
    Here is an example case we encountered. Originally, the semantics were that 
udf_exception would throw an exception, while udf_catch_exception could catch 
the exception and return a null value. However, currently, any exception 
encountered by udf_exception will cause the program to fail.
    ```
    select udf_catch_exception(udf_exception(col1)) from table
    ```
    
    ### Why are the changes needed?
    Before Spark 3.5, the GenericUDF's DeferredObject was lazy and the children 
were evaluated inside `function.evaluate(deferredObjects)`.
    Now, the children's code runs first. If it throws an exception, the exception 
is stored in the GenericUDF's DeferredObject and rethrown lazily when the UDF 
calls `get()` on it.
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    Newly added UT.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #47268 from jackylee-ch/generic_udf_catch_exception_from_child_func.
    
    Lead-authored-by: jackylee-ch <[email protected]>
    Co-authored-by: Kent Yao <[email protected]>
    Signed-off-by: Kent Yao <[email protected]>
    (cherry picked from commit 236d95738b6e50bc9ec54955e86d01b6dcf11c0e)
    Signed-off-by: Kent Yao <[email protected]>
---
 .../apache/spark/sql/hive/hiveUDFEvaluators.scala  | 12 +++--
 .../scala/org/apache/spark/sql/hive/hiveUDFs.scala | 22 +++++++---
 .../sql/hive/execution/UDFCatchException.java      | 51 ++++++++++++++++++++++
 .../sql/hive/execution/UDFThrowException.java      | 26 +++++++++++
 .../spark/sql/hive/execution/HiveUDFSuite.scala    | 23 ++++++++++
 5 files changed, 124 insertions(+), 10 deletions(-)

diff --git 
a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFEvaluators.scala 
b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFEvaluators.scala
index 094f8ba7a0f8..fc1c795a1aa1 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFEvaluators.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFEvaluators.scala
@@ -129,7 +129,11 @@ class HiveGenericUDFEvaluator(
   override def returnType: DataType = inspectorToDataType(returnInspector)
 
   def setArg(index: Int, arg: Any): Unit =
-    deferredObjects(index).asInstanceOf[DeferredObjectAdapter].set(arg)
+    deferredObjects(index).asInstanceOf[DeferredObjectAdapter].set(() => arg)
+
+  def setException(index: Int, exp: Throwable): Unit = {
+    deferredObjects(index).asInstanceOf[DeferredObjectAdapter].set(() => throw 
exp)
+  }
 
   override def doEvaluate(): Any = 
unwrapper(function.evaluate(deferredObjects))
 }
@@ -139,10 +143,10 @@ private[hive] class DeferredObjectAdapter(oi: 
ObjectInspector, dataType: DataTyp
   extends DeferredObject with HiveInspectors {
 
   private val wrapper = wrapperFor(oi, dataType)
-  private var func: Any = _
-  def set(func: Any): Unit = {
+  private var func: () => Any = _
+  def set(func: () => Any): Unit = {
     this.func = func
   }
   override def prepare(i: Int): Unit = {}
-  override def get(): AnyRef = wrapper(func).asInstanceOf[AnyRef]
+  override def get(): AnyRef = wrapper(func()).asInstanceOf[AnyRef]
 }
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala 
b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
index 01684f52ab82..0c8305b3ccb2 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
@@ -136,7 +136,13 @@ private[hive] case class HiveGenericUDF(
 
   override def eval(input: InternalRow): Any = {
     children.zipWithIndex.foreach {
-      case (child, idx) => evaluator.setArg(idx, child.eval(input))
+      case (child, idx) =>
+        try {
+          evaluator.setArg(idx, child.eval(input))
+        } catch {
+          case t: Throwable =>
+            evaluator.setException(idx, t)
+        }
     }
     evaluator.evaluate()
   }
@@ -157,10 +163,15 @@ private[hive] case class HiveGenericUDF(
     val setValues = evals.zipWithIndex.map {
       case (eval, i) =>
         s"""
-           |if (${eval.isNull}) {
-           |  $refEvaluator.setArg($i, null);
-           |} else {
-           |  $refEvaluator.setArg($i, ${eval.value});
+           |try {
+           |  ${eval.code}
+           |  if (${eval.isNull}) {
+           |    $refEvaluator.setArg($i, null);
+           |  } else {
+           |    $refEvaluator.setArg($i, ${eval.value});
+           |  }
+           |} catch (Throwable t) {
+           |  $refEvaluator.setException($i, t);
            |}
            |""".stripMargin
     }
@@ -169,7 +180,6 @@ private[hive] case class HiveGenericUDF(
     val resultTerm = ctx.freshName("result")
     ev.copy(code =
       code"""
-         |${evals.map(_.code).mkString("\n")}
          |${setValues.mkString("\n")}
          |$resultType $resultTerm = ($resultType) $refEvaluator.evaluate();
          |boolean ${ev.isNull} = $resultTerm == null;
diff --git 
a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFCatchException.java
 
b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFCatchException.java
new file mode 100644
index 000000000000..242dbeaa63c9
--- /dev/null
+++ 
b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFCatchException.java
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.execution;
+
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import 
org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+
+public class UDFCatchException extends GenericUDF {
+
+  @Override
+  public ObjectInspector initialize(ObjectInspector[] args) throws 
UDFArgumentException {
+    if (args.length != 1) {
+      throw new UDFArgumentException("Exactly one argument is expected.");
+    }
+    return PrimitiveObjectInspectorFactory.javaStringObjectInspector;
+  }
+
+  @Override
+  public Object evaluate(GenericUDF.DeferredObject[] args) {
+    if (args == null) {
+      return null;
+    }
+    try {
+      return args[0].get();
+    } catch (Exception e) {
+      return null;
+    }
+  }
+
+  @Override
+  public String getDisplayString(String[] children) {
+    return null;
+  }
+}
diff --git 
a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFThrowException.java
 
b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFThrowException.java
new file mode 100644
index 000000000000..5d6ff6ca40ae
--- /dev/null
+++ 
b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFThrowException.java
@@ -0,0 +1,26 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.execution;
+
+import org.apache.hadoop.hive.ql.exec.UDF;
+
+public class UDFThrowException extends UDF {
+  public String evaluate(String data) {
+    return Integer.valueOf(data).toString();
+  }
+}
diff --git 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
index d12ebae0f5fc..f3be79f90229 100644
--- 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
+++ 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
@@ -35,6 +35,7 @@ import org.apache.hadoop.io.{LongWritable, Writable}
 
 import org.apache.spark.{SparkException, SparkFiles, TestUtils}
 import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
+import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode
 import org.apache.spark.sql.catalyst.plans.logical.Project
 import org.apache.spark.sql.execution.WholeStageCodegenExec
 import org.apache.spark.sql.functions.{call_function, max}
@@ -791,6 +792,28 @@ class HiveUDFSuite extends QueryTest with 
TestHiveSingleton with SQLTestUtils {
       }
     }
   }
+
+  test("SPARK-48845: GenericUDF catch exceptions from child UDFs") {
+    withTable("test_catch_exception") {
+      withUserDefinedFunction("udf_throw_exception" -> true, 
"udf_catch_exception" -> true) {
+        Seq("9", "9-1").toDF("a").write.saveAsTable("test_catch_exception")
+        sql("CREATE TEMPORARY FUNCTION udf_throw_exception AS " +
+          s"'${classOf[UDFThrowException].getName}'")
+        sql("CREATE TEMPORARY FUNCTION udf_catch_exception AS " +
+          s"'${classOf[UDFCatchException].getName}'")
+        Seq(
+          CodegenObjectFactoryMode.FALLBACK.toString,
+          CodegenObjectFactoryMode.NO_CODEGEN.toString
+        ).foreach { codegenMode =>
+          withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> codegenMode) {
+            val df = sql(
+              "SELECT udf_catch_exception(udf_throw_exception(a)) FROM 
test_catch_exception")
+            checkAnswer(df, Seq(Row("9"), Row(null)))
+          }
+        }
+      }
+    }
+  }
 }
 
 class TestPair(x: Int, y: Int) extends Writable with Serializable {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to