This is an automated email from the ASF dual-hosted git repository.

jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/sedona.git


The following commit(s) were added to refs/heads/master by this push:
     new 3a3b5b33 [SEDONA-311] Support overloaded functions in InferredExpression (#874)
3a3b5b33 is described below

commit 3a3b5b33bed7bdf34ae928f95af7a56c8065f4ff
Author: Kristin Cowalcijk <[email protected]>
AuthorDate: Tue Jun 27 13:23:28 2023 +0800

    [SEDONA-311] Support overloaded functions in InferredExpression (#874)
---
 .../java/org/apache/sedona/common/Functions.java   |  6 +--
 .../scala/org/apache/sedona/sql/UDF/Catalog.scala  |  2 +-
 .../sql/sedona_sql/expressions/Functions.scala     |  2 +-
 .../expressions/InferredExpression.scala           | 63 +++++++---------------
 .../sql/sedona_sql/expressions/st_functions.scala  |  4 +-
 5 files changed, 25 insertions(+), 52 deletions(-)

diff --git a/common/src/main/java/org/apache/sedona/common/Functions.java b/common/src/main/java/org/apache/sedona/common/Functions.java
index 5b8ea3e6..2d3c3f6c 100644
--- a/common/src/main/java/org/apache/sedona/common/Functions.java
+++ b/common/src/main/java/org/apache/sedona/common/Functions.java
@@ -905,15 +905,15 @@ public class Functions {
         return geometry;
     }
 
-    public static Geometry affine(Geometry geometry, Double a, Double b, 
Double d, Double e, Double xOff, Double yOff, Double c,
-                                  Double f, Double g, Double h, Double i, 
Double zOff) {
+    public static Geometry affine(Geometry geometry, double a, double b, 
double d, double e, double xOff, double yOff, double c,
+                                  double f, double g, double h, double i, 
double zOff) {
         if (!geometry.isEmpty()) {
             GeomUtils.affineGeom(geometry, a, b, d, e, xOff, yOff, c, f, g, h, 
i, zOff);
         }
         return geometry;
     }
 
-    public static Geometry affine(Geometry geometry, Double a, Double b, 
Double d, Double e, Double xOff, Double yOff) {
+    public static Geometry affine(Geometry geometry, double a, double b, 
double d, double e, double xOff, double yOff) {
         if (!geometry.isEmpty()) {
             GeomUtils.affineGeom(geometry, a, b, d, e, xOff, yOff, null, null, 
null, null, null, null);
         }
diff --git a/sql/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala b/sql/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
index bec5a76e..ccb36175 100644
--- a/sql/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
+++ b/sql/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
@@ -153,7 +153,7 @@ object Catalog {
     function[ST_NRings](),
     function[ST_Translate](0.0),
     function[ST_FrechetDistance](),
-    function[ST_Affine](null, null, null, null, null, null),
+    function[ST_Affine](),
     function[ST_BoundingDiagonal](),
     function[ST_HausdorffDistance](-1),
     // Expression for rasters
diff --git a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala
index 904920c9..1ef150d3 100644
--- a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala
+++ b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala
@@ -1013,7 +1013,7 @@ case class ST_FrechetDistance(inputExpressions: 
Seq[Expression])
 }
 
 case class ST_Affine(inputExpressions: Seq[Expression])
-  extends 
InferredExpression(InferrableFunction.allowSixRightNull(Functions.affine _)) 
with FoldableExpression {
+  extends InferredExpression(inferrableFunction13(Functions.affine), 
inferrableFunction7(Functions.affine)) {
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = 
{
     copy(inputExpressions = newChildren)
   }
diff --git a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/InferredExpression.scala b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/InferredExpression.scala
index a943765c..41f2aede 100644
--- a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/InferredExpression.scala
+++ b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/InferredExpression.scala
@@ -34,22 +34,36 @@ import scala.reflect.runtime.universe.typeOf
 
 /**
  * This is the base class for wrapping Java/Scala functions as a catalyst 
expression in Spark SQL.
- * @param f The function to be wrapped. Subclasses can simply pass a function 
to this constructor,
+ * @param fSeq The functions to be wrapped. Subclasses can simply pass a 
function to this constructor,
  *          and the function will be converted to [[InferrableFunction]] by 
[[InferrableFunctionConverter]]
  *          automatically.
  */
-abstract class InferredExpression(f: InferrableFunction)
+abstract class InferredExpression(fSeq: InferrableFunction *)
   extends Expression with ImplicitCastInputTypes with SerdeAware with 
CodegenFallback with FoldableExpression
     with Serializable {
+
   def inputExpressions: Seq[Expression]
+
+  lazy val f: InferrableFunction = fSeq match {
+    // If there is only one function, simply use it and let 
org.apache.sedona.sql.UDF.Catalog handle default arguments.
+    case Seq(f) => f
+    // If there are multiple overloaded functions, find the one with the same 
number of arguments as the input
+    // expressions. Please note that the Catalog won't be able to handle 
default arguments in this case. We'll
+    // move default argument handling from Catalog to this class in the future.
+    case _ => fSeq.find(f => f.sparkInputTypes.size == inputExpressions.size) 
match {
+        case Some(f) => f
+        case None => throw new IllegalArgumentException(s"No overloaded 
function ${getClass.getName} has ${inputExpressions.size} arguments")
+      }
+  }
+
   override def children: Seq[Expression] = inputExpressions
   override def toString: String = s" **${getClass.getName}**  "
   override def nullable: Boolean = true
   override def inputTypes: Seq[AbstractDataType] = f.sparkInputTypes
   override def dataType: DataType = f.sparkReturnType
 
-  private val argExtractors: Array[InternalRow => Any] = 
f.buildExtractors(inputExpressions)
-  private val evaluator: InternalRow => Any = f.evaluatorBuilder(argExtractors)
+  private lazy val argExtractors: Array[InternalRow => Any] = 
f.buildExtractors(inputExpressions)
+  private lazy val evaluator: InternalRow => Any = 
f.evaluatorBuilder(argExtractors)
 
   override def eval(input: InternalRow): Any = f.serializer(evaluator(input))
   override def evalWithoutSerialization(input: InternalRow): Any = 
evaluator(input)
@@ -220,45 +234,4 @@ object InferrableFunction {
       }
     })
   }
-
-  def allowSixRightNull[R, A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, 
A13](f: (A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13) => R)
-                                                                               
   (implicit typeTag: TypeTag[(A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, 
A12, A13) => R]): InferrableFunction = {
-    apply(typeTag, extractors => {
-      val func = f.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, 
Any, Any, Any, Any) => Any]
-      val extractor1 = extractors(0)
-      val extractor2 = extractors(1)
-      val extractor3 = extractors(2)
-      val extractor4 = extractors(3)
-      val extractor5 = extractors(4)
-      val extractor6 = extractors(5)
-      val extractor7 = extractors(6)
-      val extractor8 = extractors(7)
-      val extractor9 = extractors(8)
-      val extractor10 = extractors(9)
-      val extractor11 = extractors(10)
-      val extractor12 = extractors(11)
-      val extractor13 = extractors(12)
-      input => {
-        val arg1 = extractor1(input)
-        val arg2 = extractor2(input)
-        val arg3 = extractor3(input)
-        val arg4 = extractor4(input)
-        val arg5 = extractor5(input)
-        val arg6 = extractor6(input)
-        val arg7 = extractor7(input)
-        val arg8 = extractor8(input)
-        val arg9 = extractor9(input)
-        val arg10 = extractor10(input)
-        val arg11 = extractor11(input)
-        val arg12 = extractor12(input)
-        val arg13 = extractor13(input)
-        if (arg1 != null && arg2 != null && arg3 != null && arg4 != null && 
arg5 != null && arg6 != null && arg7 != null) {
-          func(arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10, 
arg11, arg12, arg13)
-        } else {
-          null
-        }
-      }
-    })
-  }
-
 }
diff --git a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_functions.scala b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_functions.scala
index fadde48c..9c5a64a9 100644
--- a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_functions.scala
+++ b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_functions.scala
@@ -339,10 +339,10 @@ object st_functions extends DataFrameAPI {
     wrapExpression[ST_Affine](geometry, a, b, d, e, xOff, yOff, c, f, g, h, i, 
zOff)
 
   def ST_Affine(geometry: Column, a: Column, b: Column, d: Column, e: Column, 
xOff: Column, yOff: Column) =
-    wrapExpression[ST_Affine](geometry, a, b, d, e, xOff, yOff, null, null, 
null, null, null, null)
+    wrapExpression[ST_Affine](geometry, a, b, d, e, xOff, yOff)
 
   def ST_Affine(geometry: String, a: Double, b: Double, d: Double, e: Double, 
xOff: Double, yOff: Double) =
-    wrapExpression[ST_Affine](geometry, a, b, d, e, xOff, yOff, null, null, 
null, null, null, null)
+    wrapExpression[ST_Affine](geometry, a, b, d, e, xOff, yOff)
 
   def ST_BoundingDiagonal(geometry: Column) =
     wrapExpression[ST_BoundingDiagonal](geometry)

Reply via email to