This is an automated email from the ASF dual-hosted git repository.
yao pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.5 by this push:
new 596f680ea37c [SPARK-48845][SQL] GenericUDF catch exceptions from
children
596f680ea37c is described below
commit 596f680ea37c8fae77d2ba29d79cbc9339a04ca9
Author: jackylee-ch <[email protected]>
AuthorDate: Fri Jul 12 18:15:56 2024 +0800
[SPARK-48845][SQL] GenericUDF catch exceptions from children
### What changes were proposed in this pull request?
This PR fixes a semantic issue with GenericUDF that has existed since 3.5.0.
The problem arises because DeferredObject is currently passed an
already-evaluated value instead of a function, which prevents users from
catching exceptions from child expressions inside a GenericUDF, resulting in
semantic differences.
Here is an example case we encountered. Originally, the semantics were that
udf_exception would throw an exception, while udf_catch_exception could catch
the exception and return a null value. However, currently, any exception
encountered by udf_exception will cause the program to fail.
```
select udf_catch_exception(udf_exception(col1)) from table
```
### Why are the changes needed?
Before Spark 3.5, the GenericUDF's DeferredObject was lazy, and the children
were evaluated inside `function.evaluate(deferredObjects)`.
Now, we run the children's code first. If a child throws an exception, we
store it in the GenericUDF's DeferredObject so that it is rethrown lazily when
the UDF calls `get()`.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Newly added UT.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #47268 from jackylee-ch/generic_udf_catch_exception_from_child_func.
Lead-authored-by: jackylee-ch <[email protected]>
Co-authored-by: Kent Yao <[email protected]>
Signed-off-by: Kent Yao <[email protected]>
(cherry picked from commit 236d95738b6e50bc9ec54955e86d01b6dcf11c0e)
Signed-off-by: Kent Yao <[email protected]>
---
.../apache/spark/sql/hive/hiveUDFEvaluators.scala | 12 +++--
.../scala/org/apache/spark/sql/hive/hiveUDFs.scala | 22 +++++++---
.../sql/hive/execution/UDFCatchException.java | 51 ++++++++++++++++++++++
.../sql/hive/execution/UDFThrowException.java | 26 +++++++++++
.../spark/sql/hive/execution/HiveUDFSuite.scala | 23 ++++++++++
5 files changed, 124 insertions(+), 10 deletions(-)
diff --git
a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFEvaluators.scala
b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFEvaluators.scala
index 094f8ba7a0f8..fc1c795a1aa1 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFEvaluators.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFEvaluators.scala
@@ -129,7 +129,11 @@ class HiveGenericUDFEvaluator(
override def returnType: DataType = inspectorToDataType(returnInspector)
def setArg(index: Int, arg: Any): Unit =
- deferredObjects(index).asInstanceOf[DeferredObjectAdapter].set(arg)
+ deferredObjects(index).asInstanceOf[DeferredObjectAdapter].set(() => arg)
+
+ def setException(index: Int, exp: Throwable): Unit = {
+ deferredObjects(index).asInstanceOf[DeferredObjectAdapter].set(() => throw
exp)
+ }
override def doEvaluate(): Any =
unwrapper(function.evaluate(deferredObjects))
}
@@ -139,10 +143,10 @@ private[hive] class DeferredObjectAdapter(oi:
ObjectInspector, dataType: DataTyp
extends DeferredObject with HiveInspectors {
private val wrapper = wrapperFor(oi, dataType)
- private var func: Any = _
- def set(func: Any): Unit = {
+ private var func: () => Any = _
+ def set(func: () => Any): Unit = {
this.func = func
}
override def prepare(i: Int): Unit = {}
- override def get(): AnyRef = wrapper(func).asInstanceOf[AnyRef]
+ override def get(): AnyRef = wrapper(func()).asInstanceOf[AnyRef]
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
index 01684f52ab82..0c8305b3ccb2 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
@@ -136,7 +136,13 @@ private[hive] case class HiveGenericUDF(
override def eval(input: InternalRow): Any = {
children.zipWithIndex.foreach {
- case (child, idx) => evaluator.setArg(idx, child.eval(input))
+ case (child, idx) =>
+ try {
+ evaluator.setArg(idx, child.eval(input))
+ } catch {
+ case t: Throwable =>
+ evaluator.setException(idx, t)
+ }
}
evaluator.evaluate()
}
@@ -157,10 +163,15 @@ private[hive] case class HiveGenericUDF(
val setValues = evals.zipWithIndex.map {
case (eval, i) =>
s"""
- |if (${eval.isNull}) {
- | $refEvaluator.setArg($i, null);
- |} else {
- | $refEvaluator.setArg($i, ${eval.value});
+ |try {
+ | ${eval.code}
+ | if (${eval.isNull}) {
+ | $refEvaluator.setArg($i, null);
+ | } else {
+ | $refEvaluator.setArg($i, ${eval.value});
+ | }
+ |} catch (Throwable t) {
+ | $refEvaluator.setException($i, t);
|}
|""".stripMargin
}
@@ -169,7 +180,6 @@ private[hive] case class HiveGenericUDF(
val resultTerm = ctx.freshName("result")
ev.copy(code =
code"""
- |${evals.map(_.code).mkString("\n")}
|${setValues.mkString("\n")}
|$resultType $resultTerm = ($resultType) $refEvaluator.evaluate();
|boolean ${ev.isNull} = $resultTerm == null;
diff --git
a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFCatchException.java
b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFCatchException.java
new file mode 100644
index 000000000000..242dbeaa63c9
--- /dev/null
+++
b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFCatchException.java
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.execution;
+
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import
org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+
+public class UDFCatchException extends GenericUDF {
+
+ @Override
+ public ObjectInspector initialize(ObjectInspector[] args) throws
UDFArgumentException {
+ if (args.length != 1) {
+ throw new UDFArgumentException("Exactly one argument is expected.");
+ }
+ return PrimitiveObjectInspectorFactory.javaStringObjectInspector;
+ }
+
+ @Override
+ public Object evaluate(GenericUDF.DeferredObject[] args) {
+ if (args == null) {
+ return null;
+ }
+ try {
+ return args[0].get();
+ } catch (Exception e) {
+ return null;
+ }
+ }
+
+ @Override
+ public String getDisplayString(String[] children) {
+ return null;
+ }
+}
diff --git
a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFThrowException.java
b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFThrowException.java
new file mode 100644
index 000000000000..5d6ff6ca40ae
--- /dev/null
+++
b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFThrowException.java
@@ -0,0 +1,26 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.execution;
+
+import org.apache.hadoop.hive.ql.exec.UDF;
+
+public class UDFThrowException extends UDF {
+ public String evaluate(String data) {
+ return Integer.valueOf(data).toString();
+ }
+}
diff --git
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
index d12ebae0f5fc..f3be79f90229 100644
---
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
+++
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
@@ -35,6 +35,7 @@ import org.apache.hadoop.io.{LongWritable, Writable}
import org.apache.spark.{SparkException, SparkFiles, TestUtils}
import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
+import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode
import org.apache.spark.sql.catalyst.plans.logical.Project
import org.apache.spark.sql.execution.WholeStageCodegenExec
import org.apache.spark.sql.functions.{call_function, max}
@@ -791,6 +792,28 @@ class HiveUDFSuite extends QueryTest with
TestHiveSingleton with SQLTestUtils {
}
}
}
+
+ test("SPARK-48845: GenericUDF catch exceptions from child UDFs") {
+ withTable("test_catch_exception") {
+ withUserDefinedFunction("udf_throw_exception" -> true,
"udf_catch_exception" -> true) {
+ Seq("9", "9-1").toDF("a").write.saveAsTable("test_catch_exception")
+ sql("CREATE TEMPORARY FUNCTION udf_throw_exception AS " +
+ s"'${classOf[UDFThrowException].getName}'")
+ sql("CREATE TEMPORARY FUNCTION udf_catch_exception AS " +
+ s"'${classOf[UDFCatchException].getName}'")
+ Seq(
+ CodegenObjectFactoryMode.FALLBACK.toString,
+ CodegenObjectFactoryMode.NO_CODEGEN.toString
+ ).foreach { codegenMode =>
+ withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> codegenMode) {
+ val df = sql(
+ "SELECT udf_catch_exception(udf_throw_exception(a)) FROM
test_catch_exception")
+ checkAnswer(df, Seq(Row("9"), Row(null)))
+ }
+ }
+ }
+ }
+ }
}
class TestPair(x: Int, y: Int) extends Writable with Serializable {
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]