This is an automated email from the ASF dual-hosted git repository.
jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/sedona.git
The following commit(s) were added to refs/heads/master by this push:
new 3a3b5b33 [SEDONA-311] Support overloaded functions in InferredExpression (#874)
3a3b5b33 is described below
commit 3a3b5b33bed7bdf34ae928f95af7a56c8065f4ff
Author: Kristin Cowalcijk <[email protected]>
AuthorDate: Tue Jun 27 13:23:28 2023 +0800
[SEDONA-311] Support overloaded functions in InferredExpression (#874)
---
.../java/org/apache/sedona/common/Functions.java | 6 +--
.../scala/org/apache/sedona/sql/UDF/Catalog.scala | 2 +-
.../sql/sedona_sql/expressions/Functions.scala | 2 +-
.../expressions/InferredExpression.scala | 63 +++++++---------------
.../sql/sedona_sql/expressions/st_functions.scala | 4 +-
5 files changed, 25 insertions(+), 52 deletions(-)
diff --git a/common/src/main/java/org/apache/sedona/common/Functions.java b/common/src/main/java/org/apache/sedona/common/Functions.java
index 5b8ea3e6..2d3c3f6c 100644
--- a/common/src/main/java/org/apache/sedona/common/Functions.java
+++ b/common/src/main/java/org/apache/sedona/common/Functions.java
@@ -905,15 +905,15 @@ public class Functions {
return geometry;
}
- public static Geometry affine(Geometry geometry, Double a, Double b,
Double d, Double e, Double xOff, Double yOff, Double c,
- Double f, Double g, Double h, Double i,
Double zOff) {
+ public static Geometry affine(Geometry geometry, double a, double b,
double d, double e, double xOff, double yOff, double c,
+ double f, double g, double h, double i,
double zOff) {
if (!geometry.isEmpty()) {
GeomUtils.affineGeom(geometry, a, b, d, e, xOff, yOff, c, f, g, h,
i, zOff);
}
return geometry;
}
- public static Geometry affine(Geometry geometry, Double a, Double b,
Double d, Double e, Double xOff, Double yOff) {
+ public static Geometry affine(Geometry geometry, double a, double b,
double d, double e, double xOff, double yOff) {
if (!geometry.isEmpty()) {
GeomUtils.affineGeom(geometry, a, b, d, e, xOff, yOff, null, null,
null, null, null, null);
}
diff --git a/sql/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala b/sql/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
index bec5a76e..ccb36175 100644
--- a/sql/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
+++ b/sql/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
@@ -153,7 +153,7 @@ object Catalog {
function[ST_NRings](),
function[ST_Translate](0.0),
function[ST_FrechetDistance](),
- function[ST_Affine](null, null, null, null, null, null),
+ function[ST_Affine](),
function[ST_BoundingDiagonal](),
function[ST_HausdorffDistance](-1),
// Expression for rasters
diff --git a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala
index 904920c9..1ef150d3 100644
--- a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala
+++ b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala
@@ -1013,7 +1013,7 @@ case class ST_FrechetDistance(inputExpressions: Seq[Expression])
}
case class ST_Affine(inputExpressions: Seq[Expression])
- extends
InferredExpression(InferrableFunction.allowSixRightNull(Functions.affine _))
with FoldableExpression {
+ extends InferredExpression(inferrableFunction13(Functions.affine),
inferrableFunction7(Functions.affine)) {
protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) =
{
copy(inputExpressions = newChildren)
}
diff --git a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/InferredExpression.scala b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/InferredExpression.scala
index a943765c..41f2aede 100644
--- a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/InferredExpression.scala
+++ b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/InferredExpression.scala
@@ -34,22 +34,36 @@ import scala.reflect.runtime.universe.typeOf
/**
* This is the base class for wrapping Java/Scala functions as a catalyst
expression in Spark SQL.
- * @param f The function to be wrapped. Subclasses can simply pass a function
to this constructor,
+ * @param fSeq The functions to be wrapped. Subclasses can simply pass a
function to this constructor,
* and the function will be converted to [[InferrableFunction]] by
[[InferrableFunctionConverter]]
* automatically.
*/
-abstract class InferredExpression(f: InferrableFunction)
+abstract class InferredExpression(fSeq: InferrableFunction *)
extends Expression with ImplicitCastInputTypes with SerdeAware with
CodegenFallback with FoldableExpression
with Serializable {
+
def inputExpressions: Seq[Expression]
+
+ lazy val f: InferrableFunction = fSeq match {
+ // If there is only one function, simply use it and let
org.apache.sedona.sql.UDF.Catalog handle default arguments.
+ case Seq(f) => f
+ // If there are multiple overloaded functions, find the one with the same
number of arguments as the input
+ // expressions. Please note that the Catalog won't be able to handle
default arguments in this case. We'll
+ // move default argument handling from Catalog to this class in the future.
+ case _ => fSeq.find(f => f.sparkInputTypes.size == inputExpressions.size)
match {
+ case Some(f) => f
+ case None => throw new IllegalArgumentException(s"No overloaded
function ${getClass.getName} has ${inputExpressions.size} arguments")
+ }
+ }
+
override def children: Seq[Expression] = inputExpressions
override def toString: String = s" **${getClass.getName}** "
override def nullable: Boolean = true
override def inputTypes: Seq[AbstractDataType] = f.sparkInputTypes
override def dataType: DataType = f.sparkReturnType
- private val argExtractors: Array[InternalRow => Any] =
f.buildExtractors(inputExpressions)
- private val evaluator: InternalRow => Any = f.evaluatorBuilder(argExtractors)
+ private lazy val argExtractors: Array[InternalRow => Any] =
f.buildExtractors(inputExpressions)
+ private lazy val evaluator: InternalRow => Any =
f.evaluatorBuilder(argExtractors)
override def eval(input: InternalRow): Any = f.serializer(evaluator(input))
override def evalWithoutSerialization(input: InternalRow): Any =
evaluator(input)
@@ -220,45 +234,4 @@ object InferrableFunction {
}
})
}
-
- def allowSixRightNull[R, A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12,
A13](f: (A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13) => R)
-
(implicit typeTag: TypeTag[(A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11,
A12, A13) => R]): InferrableFunction = {
- apply(typeTag, extractors => {
- val func = f.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any,
Any, Any, Any, Any) => Any]
- val extractor1 = extractors(0)
- val extractor2 = extractors(1)
- val extractor3 = extractors(2)
- val extractor4 = extractors(3)
- val extractor5 = extractors(4)
- val extractor6 = extractors(5)
- val extractor7 = extractors(6)
- val extractor8 = extractors(7)
- val extractor9 = extractors(8)
- val extractor10 = extractors(9)
- val extractor11 = extractors(10)
- val extractor12 = extractors(11)
- val extractor13 = extractors(12)
- input => {
- val arg1 = extractor1(input)
- val arg2 = extractor2(input)
- val arg3 = extractor3(input)
- val arg4 = extractor4(input)
- val arg5 = extractor5(input)
- val arg6 = extractor6(input)
- val arg7 = extractor7(input)
- val arg8 = extractor8(input)
- val arg9 = extractor9(input)
- val arg10 = extractor10(input)
- val arg11 = extractor11(input)
- val arg12 = extractor12(input)
- val arg13 = extractor13(input)
- if (arg1 != null && arg2 != null && arg3 != null && arg4 != null &&
arg5 != null && arg6 != null && arg7 != null) {
- func(arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10,
arg11, arg12, arg13)
- } else {
- null
- }
- }
- })
- }
-
}
diff --git a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_functions.scala b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_functions.scala
index fadde48c..9c5a64a9 100644
--- a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_functions.scala
+++ b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_functions.scala
@@ -339,10 +339,10 @@ object st_functions extends DataFrameAPI {
wrapExpression[ST_Affine](geometry, a, b, d, e, xOff, yOff, c, f, g, h, i,
zOff)
def ST_Affine(geometry: Column, a: Column, b: Column, d: Column, e: Column,
xOff: Column, yOff: Column) =
- wrapExpression[ST_Affine](geometry, a, b, d, e, xOff, yOff, null, null,
null, null, null, null)
+ wrapExpression[ST_Affine](geometry, a, b, d, e, xOff, yOff)
def ST_Affine(geometry: String, a: Double, b: Double, d: Double, e: Double,
xOff: Double, yOff: Double) =
- wrapExpression[ST_Affine](geometry, a, b, d, e, xOff, yOff, null, null,
null, null, null, null)
+ wrapExpression[ST_Affine](geometry, a, b, d, e, xOff, yOff)
def ST_BoundingDiagonal(geometry: Column) =
wrapExpression[ST_BoundingDiagonal](geometry)