This is an automated email from the ASF dual-hosted git repository.

maxgekk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 91dd5e3d74b7 [SPARK-50148][SQL] Make `StaticInvoke` compatible with 
the method that declare throw exception
91dd5e3d74b7 is described below

commit 91dd5e3d74b73d0d5d1e51387352cb61e1e3b5d9
Author: panbingkun <[email protected]>
AuthorDate: Wed Oct 30 17:53:06 2024 +0100

    [SPARK-50148][SQL] Make `StaticInvoke` compatible with the method that 
declare throw exception
    
    ### What changes were proposed in this pull request?
    This PR makes `StaticInvoke` compatible with methods that declare
    thrown exceptions (i.e. methods with a `throws` clause).
    
    ### Why are the changes needed?
    Currently, `StaticInvoke` does not support calling methods that declare
    thrown exceptions, while `Invoke` does (it wraps the call in a
    `try`/`catch` in the generated code). Let's align the two.
    
    - Invoke
    
https://github.com/apache/spark/blob/bb15eb7b91ab775bdb84b6b17353a706794b122d/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala#L457-L469
    
    - StaticInvoke
    
https://github.com/apache/spark/blob/bb15eb7b91ab775bdb84b6b17353a706794b122d/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala#L313
    
    ### Does this PR introduce _any_ user-facing change?
    No, this only affects Spark developers.
    
    ### How was this patch tested?
    - Added new unit tests (UT).
    - Pass GA.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #48679 from panbingkun/SPARK-50148.
    
    Authored-by: panbingkun <[email protected]>
    Signed-off-by: Max Gekk <[email protected]>
---
 .../sql/catalyst/expressions/objects/objects.scala | 51 +++++++++++++---------
 .../expressions/TestThrowExceptionMethod.java      | 39 +++++++++++++++++
 .../expressions/ObjectExpressionsSuite.scala       | 24 ++++++++++
 3 files changed, 94 insertions(+), 20 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index 7c198f05cf49..5c786bc5ddbf 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -211,6 +211,25 @@ trait InvokeLike extends Expression with NonSQLExpression 
with ImplicitCastInput
       method
     }
   }
+
+  final def getFuncResult(
+      needTryCatch: Boolean,
+      resultVal: String,
+      funcCall: String,
+      returnType: Option[String] = None): String = {
+    val castFuncCall = if (returnType.isEmpty) funcCall else 
s"(${returnType.get}) $funcCall"
+    if (needTryCatch) {
+      s"""
+        try {
+          $resultVal = $castFuncCall;
+        } catch (Exception e) {
+          org.apache.spark.unsafe.Platform.throwException(e);
+        }
+      """
+    } else {
+      s"$resultVal = $castFuncCall;"
+    }
+  }
 }
 
 /**
@@ -328,6 +347,8 @@ case class StaticInvoke(
 
     val (argCode, argString, resultIsNull) = prepareArguments(ctx)
 
+    val needTryCatch = method.getExceptionTypes.nonEmpty
+
     val callFunc = s"$objectName.$functionName($argString)"
 
     val prepareIsNull = if (nullable) {
@@ -340,14 +361,15 @@ case class StaticInvoke(
     val evaluate = if (returnNullable && !method.getReturnType.isPrimitive) {
       if (CodeGenerator.defaultValue(dataType) == "null") {
         s"""
-          ${ev.value} = ($javaType) $callFunc;
+          ${getFuncResult(needTryCatch, ev.value, callFunc, Some(javaType))}
           ${ev.isNull} = ${ev.value} == null;
         """
       } else {
         val boxedResult = ctx.freshName("boxedResult")
         val boxedJavaType = CodeGenerator.boxedType(dataType)
         s"""
-          $boxedJavaType $boxedResult = ($boxedJavaType) $callFunc;
+          $boxedJavaType $boxedResult = null;
+          ${getFuncResult(needTryCatch, boxedResult, callFunc, 
Some(boxedJavaType))}
           ${ev.isNull} = $boxedResult == null;
           if (!${ev.isNull}) {
             ${ev.value} = $boxedResult;
@@ -355,7 +377,7 @@ case class StaticInvoke(
         """
       }
     } else {
-      s"${ev.value} = ($javaType) $callFunc;"
+      getFuncResult(needTryCatch, ev.value, callFunc, Some(javaType))
     }
 
     val code = code"""
@@ -474,38 +496,27 @@ case class Invoke(
     val returnPrimitive = method.isDefined && 
method.get.getReturnType.isPrimitive
     val needTryCatch = method.isDefined && 
method.get.getExceptionTypes.nonEmpty
 
-    def getFuncResult(resultVal: String, funcCall: String): String = if 
(needTryCatch) {
-      s"""
-        try {
-          $resultVal = $funcCall;
-        } catch (Exception e) {
-          org.apache.spark.unsafe.Platform.throwException(e);
-        }
-      """
-    } else {
-      s"$resultVal = $funcCall;"
-    }
-
     val evaluate = if (returnPrimitive) {
-      getFuncResult(ev.value, s"${obj.value}.$encodedFunctionName($argString)")
+      getFuncResult(needTryCatch, ev.value, 
s"${obj.value}.$encodedFunctionName($argString)")
     } else {
       val funcResult = ctx.freshName("funcResult")
       // If the function can return null, we do an extra check to make sure 
our null bit is still
       // set correctly.
       val assignResult = if (!returnNullable) {
-        s"${ev.value} = (${CodeGenerator.boxedType(javaType)}) $funcResult;"
+        s"${ev.value} = $funcResult;"
       } else {
         s"""
           if ($funcResult != null) {
-            ${ev.value} = (${CodeGenerator.boxedType(javaType)}) $funcResult;
+            ${ev.value} = $funcResult;
           } else {
             ${ev.isNull} = true;
           }
         """
       }
       s"""
-        Object $funcResult = null;
-        ${getFuncResult(funcResult, 
s"${obj.value}.$encodedFunctionName($argString)")}
+        ${CodeGenerator.boxedType(javaType)} $funcResult = null;
+        ${getFuncResult(needTryCatch, funcResult, 
s"${obj.value}.$encodedFunctionName($argString)",
+          Some(CodeGenerator.boxedType(javaType)))}
         $assignResult
       """
     }
diff --git 
a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/TestThrowExceptionMethod.java
 
b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/TestThrowExceptionMethod.java
new file mode 100644
index 000000000000..e74989021ea5
--- /dev/null
+++ 
b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/TestThrowExceptionMethod.java
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.expressions;
+
+import java.io.IOException;
+import java.io.Serializable;
+
+public class TestThrowExceptionMethod implements Serializable {
+
+  public int invoke(int i) throws IOException {
+    if (i != 0) {
+      return i * 2;
+    } else {
+      throw new IOException("Invoke the method that throw IOException");
+    }
+  }
+
+  public static int staticInvoke(int i) throws IOException {
+    if (i != 0) {
+      return i * 2;
+    } else {
+      throw new IOException("StaticInvoke the method that throw IOException");
+    }
+  }
+}
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
index d31e76469f53..215362c47b94 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
@@ -780,6 +780,30 @@ class ObjectExpressionsSuite extends SparkFunSuite with 
ExpressionEvalHelper {
           inputRow)
     }
   }
+
+  test("Invoke call the method that throw Exception") {
+    val targetObject = new TestThrowExceptionMethod
+    val funcClass = classOf[TestThrowExceptionMethod]
+    val funcObj = Literal.create(targetObject, ObjectType(funcClass))
+
+    val inputInt = Seq(BoundReference(0, IntegerType, nullable = true))
+
+    checkObjectExprEvaluation(
+      Invoke(funcObj, "invoke", IntegerType, inputInt),
+      2,
+      InternalRow.fromSeq(Seq(Integer.valueOf(1))))
+  }
+
+  test("StaticInvoke call the method that throw Exception") {
+    val funcClass = classOf[TestThrowExceptionMethod]
+
+    val inputInt = Seq(BoundReference(0, IntegerType, nullable = true))
+
+    checkObjectExprEvaluation(
+      StaticInvoke(funcClass, IntegerType, "staticInvoke", inputInt),
+      2,
+      InternalRow.fromSeq(Seq(Integer.valueOf(1))))
+  }
 }
 
 class TestBean extends Serializable {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to