This is an automated email from the ASF dual-hosted git repository.
maxgekk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 91dd5e3d74b7 [SPARK-50148][SQL] Make `StaticInvoke` compatible with
methods that declare thrown exceptions
91dd5e3d74b7 is described below
commit 91dd5e3d74b73d0d5d1e51387352cb61e1e3b5d9
Author: panbingkun <[email protected]>
AuthorDate: Wed Oct 30 17:53:06 2024 +0100
[SPARK-50148][SQL] Make `StaticInvoke` compatible with methods that
declare thrown exceptions
### What changes were proposed in this pull request?
This PR makes `StaticInvoke` compatible with methods that declare thrown
exceptions (i.e., methods with a `throws` clause).
### Why are the changes needed?
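Currently, `StaticInvoke` does not support calling methods that declare
thrown exceptions, whereas `Invoke` does. `Invoke` wraps such calls in a
try/catch that rethrows via `Platform.throwException`, so the generated
Java code compiles even when the callee declares checked exceptions. This
PR aligns `StaticInvoke` with `Invoke`.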
- Invoke
https://github.com/apache/spark/blob/bb15eb7b91ab775bdb84b6b17353a706794b122d/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala#L457-L469
- StaticInvoke
https://github.com/apache/spark/blob/bb15eb7b91ab775bdb84b6b17353a706794b122d/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala#L313
### Does this PR introduce _any_ user-facing change?
No, this change only affects Spark developers.
### How was this patch tested?
- Added new unit tests.
- Passed GitHub Actions (GA) CI.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #48679 from panbingkun/SPARK-50148.
Authored-by: panbingkun <[email protected]>
Signed-off-by: Max Gekk <[email protected]>
---
.../sql/catalyst/expressions/objects/objects.scala | 51 +++++++++++++---------
.../expressions/TestThrowExceptionMethod.java | 39 +++++++++++++++++
.../expressions/ObjectExpressionsSuite.scala | 24 ++++++++++
3 files changed, 94 insertions(+), 20 deletions(-)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index 7c198f05cf49..5c786bc5ddbf 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -211,6 +211,25 @@ trait InvokeLike extends Expression with NonSQLExpression
with ImplicitCastInput
method
}
}
+
+ final def getFuncResult(
+ needTryCatch: Boolean,
+ resultVal: String,
+ funcCall: String,
+ returnType: Option[String] = None): String = {
+ val castFuncCall = if (returnType.isEmpty) funcCall else
s"(${returnType.get}) $funcCall"
+ if (needTryCatch) {
+ s"""
+ try {
+ $resultVal = $castFuncCall;
+ } catch (Exception e) {
+ org.apache.spark.unsafe.Platform.throwException(e);
+ }
+ """
+ } else {
+ s"$resultVal = $castFuncCall;"
+ }
+ }
}
/**
@@ -328,6 +347,8 @@ case class StaticInvoke(
val (argCode, argString, resultIsNull) = prepareArguments(ctx)
+ val needTryCatch = method.getExceptionTypes.nonEmpty
+
val callFunc = s"$objectName.$functionName($argString)"
val prepareIsNull = if (nullable) {
@@ -340,14 +361,15 @@ case class StaticInvoke(
val evaluate = if (returnNullable && !method.getReturnType.isPrimitive) {
if (CodeGenerator.defaultValue(dataType) == "null") {
s"""
- ${ev.value} = ($javaType) $callFunc;
+ ${getFuncResult(needTryCatch, ev.value, callFunc, Some(javaType))}
${ev.isNull} = ${ev.value} == null;
"""
} else {
val boxedResult = ctx.freshName("boxedResult")
val boxedJavaType = CodeGenerator.boxedType(dataType)
s"""
- $boxedJavaType $boxedResult = ($boxedJavaType) $callFunc;
+ $boxedJavaType $boxedResult = null;
+ ${getFuncResult(needTryCatch, boxedResult, callFunc,
Some(boxedJavaType))}
${ev.isNull} = $boxedResult == null;
if (!${ev.isNull}) {
${ev.value} = $boxedResult;
@@ -355,7 +377,7 @@ case class StaticInvoke(
"""
}
} else {
- s"${ev.value} = ($javaType) $callFunc;"
+ getFuncResult(needTryCatch, ev.value, callFunc, Some(javaType))
}
val code = code"""
@@ -474,38 +496,27 @@ case class Invoke(
val returnPrimitive = method.isDefined &&
method.get.getReturnType.isPrimitive
val needTryCatch = method.isDefined &&
method.get.getExceptionTypes.nonEmpty
- def getFuncResult(resultVal: String, funcCall: String): String = if
(needTryCatch) {
- s"""
- try {
- $resultVal = $funcCall;
- } catch (Exception e) {
- org.apache.spark.unsafe.Platform.throwException(e);
- }
- """
- } else {
- s"$resultVal = $funcCall;"
- }
-
val evaluate = if (returnPrimitive) {
- getFuncResult(ev.value, s"${obj.value}.$encodedFunctionName($argString)")
+ getFuncResult(needTryCatch, ev.value,
s"${obj.value}.$encodedFunctionName($argString)")
} else {
val funcResult = ctx.freshName("funcResult")
// If the function can return null, we do an extra check to make sure
our null bit is still
// set correctly.
val assignResult = if (!returnNullable) {
- s"${ev.value} = (${CodeGenerator.boxedType(javaType)}) $funcResult;"
+ s"${ev.value} = $funcResult;"
} else {
s"""
if ($funcResult != null) {
- ${ev.value} = (${CodeGenerator.boxedType(javaType)}) $funcResult;
+ ${ev.value} = $funcResult;
} else {
${ev.isNull} = true;
}
"""
}
s"""
- Object $funcResult = null;
- ${getFuncResult(funcResult,
s"${obj.value}.$encodedFunctionName($argString)")}
+ ${CodeGenerator.boxedType(javaType)} $funcResult = null;
+ ${getFuncResult(needTryCatch, funcResult,
s"${obj.value}.$encodedFunctionName($argString)",
+ Some(CodeGenerator.boxedType(javaType)))}
$assignResult
"""
}
diff --git
a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/TestThrowExceptionMethod.java
b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/TestThrowExceptionMethod.java
new file mode 100644
index 000000000000..e74989021ea5
--- /dev/null
+++
b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/TestThrowExceptionMethod.java
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.expressions;
+
+import java.io.IOException;
+import java.io.Serializable;
+
+public class TestThrowExceptionMethod implements Serializable {
+
+ public int invoke(int i) throws IOException {
+ if (i != 0) {
+ return i * 2;
+ } else {
+ throw new IOException("Invoke the method that throw IOException");
+ }
+ }
+
+ public static int staticInvoke(int i) throws IOException {
+ if (i != 0) {
+ return i * 2;
+ } else {
+ throw new IOException("StaticInvoke the method that throw IOException");
+ }
+ }
+}
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
index d31e76469f53..215362c47b94 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
@@ -780,6 +780,30 @@ class ObjectExpressionsSuite extends SparkFunSuite with
ExpressionEvalHelper {
inputRow)
}
}
+
+ test("Invoke call the method that throw Exception") {
+ val targetObject = new TestThrowExceptionMethod
+ val funcClass = classOf[TestThrowExceptionMethod]
+ val funcObj = Literal.create(targetObject, ObjectType(funcClass))
+
+ val inputInt = Seq(BoundReference(0, IntegerType, nullable = true))
+
+ checkObjectExprEvaluation(
+ Invoke(funcObj, "invoke", IntegerType, inputInt),
+ 2,
+ InternalRow.fromSeq(Seq(Integer.valueOf(1))))
+ }
+
+ test("StaticInvoke call the method that throw Exception") {
+ val funcClass = classOf[TestThrowExceptionMethod]
+
+ val inputInt = Seq(BoundReference(0, IntegerType, nullable = true))
+
+ checkObjectExprEvaluation(
+ StaticInvoke(funcClass, IntegerType, "staticInvoke", inputInt),
+ 2,
+ InternalRow.fromSeq(Seq(Integer.valueOf(1))))
+ }
}
class TestBean extends Serializable {
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]