This is an automated email from the ASF dual-hosted git repository.

viirya pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.0 by this push:
     new b190e2b  [SPARK-35288][SQL] StaticInvoke should find the method 
without exact argument classes match
b190e2b is described below

commit b190e2b36dba9ad85c2b30a2b693aa9defd6de02
Author: Liang-Chi Hsieh <vii...@gmail.com>
AuthorDate: Fri May 7 09:07:57 2021 -0700

    [SPARK-35288][SQL] StaticInvoke should find the method without exact 
argument classes match
    
    ### What changes were proposed in this pull request?
    
    This patch proposes to make StaticInvoke able to find a method with the given 
method name even if the parameter types do not exactly match the argument classes.
    
    ### Why are the changes needed?
    
    Unlike `Invoke`, `StaticInvoke` only tries to get the method with exact 
argument classes. If the called method's parameter types do not exactly 
match the argument classes, `StaticInvoke` cannot find the method.
    
    `StaticInvoke` should be able to find the method in these cases too.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. `StaticInvoke` can find a method even if the argument classes do not 
exactly match the parameter types.
    
    ### How was this patch tested?
    
    Unit test.
    
    Closes #32413 from viirya/static-invoke.
    
    Authored-by: Liang-Chi Hsieh <vii...@gmail.com>
    Signed-off-by: Liang-Chi Hsieh <vii...@gmail.com>
    (cherry picked from commit 33fbf5647b4a5587c78ac51339c0cbc9d70547a4)
    Signed-off-by: Liang-Chi Hsieh <vii...@gmail.com>
---
 .../sql/catalyst/expressions/objects/objects.scala | 56 ++++++++++++----------
 .../expressions/ObjectExpressionsSuite.scala       | 34 +++++++++++--
 2 files changed, 60 insertions(+), 30 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index 066188a..e5e9999 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -141,6 +141,34 @@ trait InvokeLike extends Expression with NonSQLExpression {
       }
     }
   }
+
+  final def findMethod(cls: Class[_], functionName: String, argClasses: 
Seq[Class[_]]): Method = {
+    // Looking with function name + argument classes first.
+    try {
+      cls.getMethod(functionName, argClasses: _*)
+    } catch {
+      case _: NoSuchMethodException =>
+        // For some cases, e.g. arg class is Object, `getMethod` cannot find 
the method.
+        // We look at function name + argument length
+        val m = cls.getMethods.filter { m =>
+          m.getName == functionName && m.getParameterCount == arguments.length
+        }
+        if (m.isEmpty) {
+          sys.error(s"Couldn't find $functionName on $cls")
+        } else if (m.length > 1) {
+          // More than one matched method signature. Exclude synthetic one, 
e.g. generic one.
+          val realMethods = m.filter(!_.isSynthetic)
+          if (realMethods.length > 1) {
+            // Ambiguous case, we don't know which method to choose, just fail 
it.
+            sys.error(s"Found ${realMethods.length} $functionName on $cls")
+          } else {
+            realMethods.head
+          }
+        } else {
+          m.head
+        }
+    }
+  }
 }
 
 /**
@@ -232,7 +260,7 @@ case class StaticInvoke(
   override def children: Seq[Expression] = arguments
 
   lazy val argClasses = ScalaReflection.expressionJavaClasses(arguments)
-  @transient lazy val method = cls.getDeclaredMethod(functionName, argClasses 
: _*)
+  @transient lazy val method = findMethod(cls, functionName, argClasses)
 
   override def eval(input: InternalRow): Any = {
     invoke(null, method, arguments, input, dataType)
@@ -319,31 +347,7 @@ case class Invoke(
 
   @transient lazy val method = targetObject.dataType match {
     case ObjectType(cls) =>
-      // Looking with function name + argument classes first.
-      try {
-        Some(cls.getMethod(encodedFunctionName, argClasses: _*))
-      } catch {
-        case _: NoSuchMethodException =>
-          // For some cases, e.g. arg class is Object, `getMethod` cannot find 
the method.
-          // We look at function name + argument length
-          val m = cls.getMethods.filter { m =>
-            m.getName == encodedFunctionName && m.getParameterCount == 
arguments.length
-          }
-          if (m.isEmpty) {
-            sys.error(s"Couldn't find $encodedFunctionName on $cls")
-          } else if (m.length > 1) {
-            // More than one matched method signature. Exclude synthetic one, 
e.g. generic one.
-            val realMethods = m.filter(!_.isSynthetic)
-            if (realMethods.length > 1) {
-              // Ambiguous case, we don't know which method to choose, just 
fail it.
-              sys.error(s"Found ${realMethods.length} $encodedFunctionName on 
$cls")
-            } else {
-              Some(realMethods.head)
-            }
-          } else {
-            Some(m.head)
-          }
-      }
+      Some(findMethod(cls, encodedFunctionName, argClasses))
     case _ => None
   }
 
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
index 1307a24..1907e09 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
@@ -634,8 +634,22 @@ class ObjectExpressionsSuite extends SparkFunSuite with 
ExpressionEvalHelper {
     val clsType = ObjectType(classOf[ConcreteClass])
     val obj = new ConcreteClass
 
+    val input = (1, 2)
     checkObjectExprEvaluation(
-      Invoke(Literal(obj, clsType), "testFunc", IntegerType, Seq(Literal(1))), 
0)
+      Invoke(Literal(obj, clsType), "testFunc", IntegerType,
+        Seq(Literal(input, ObjectType(input.getClass)))), 2)
+  }
+
+  test("SPARK-35288: static invoke should find method without exact param type 
match") {
+    val input = (1, 2)
+
+    checkObjectExprEvaluation(
+      StaticInvoke(TestStaticInvoke.getClass, IntegerType, "func",
+        Seq(Literal(input, ObjectType(input.getClass)))), 3)
+
+    checkObjectExprEvaluation(
+      StaticInvoke(TestStaticInvoke.getClass, IntegerType, "func",
+        Seq(Literal(1, IntegerType))), -1)
   }
 }
 
@@ -648,10 +662,22 @@ class TestBean extends Serializable {
     assert(i != null, "this setter should not be called with null.")
 }
 
+object TestStaticInvoke {
+  def func(param: Any): Int = param match {
+    case pair: Tuple2[_, _] =>
+      pair.asInstanceOf[Tuple2[Int, Int]]._1 + pair.asInstanceOf[Tuple2[Int, 
Int]]._2
+    case _ => -1
+  }
+}
+
 abstract class BaseClass[T] {
-  def testFunc(param: T): T
+  def testFunc(param: T): Int
 }
 
-class ConcreteClass extends BaseClass[Int] with Serializable {
-  override def testFunc(param: Int): Int = param - 1
+class ConcreteClass extends BaseClass[Product] with Serializable {
+  override def testFunc(param: Product): Int = param match {
+    case _: Tuple2[_, _] => 2
+    case _: Tuple3[_, _, _] => 3
+    case _ => 4
+  }
 }

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to