This is an automated email from the ASF dual-hosted git repository.

jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/sedona.git


The following commit(s) were added to refs/heads/master by this push:
     new 3a3b8d35ec [SEDONA 710] Rename Geostats SQL classes to generic name; merge UdfRegistrator into AbstractCatalog (#1809)
3a3b8d35ec is described below

commit 3a3b8d35ec3a132804ef60d53567ee89d67940c6
Author: James Willis <[email protected]>
AuthorDate: Wed Feb 12 16:23:39 2025 -0800

    [SEDONA 710] Rename Geostats SQL classes to generic name; merge UdfRegistrator into AbstractCatalog (#1809)
    
    Co-authored-by: jameswillis <[email protected]>
---
 .../org/apache/sedona/spark/SedonaContext.scala    |  19 ++--
 .../apache/sedona/sql/UDF/AbstractCatalog.scala    |  25 +++++
 .../scala/org/apache/sedona/sql/UDF/Catalog.scala  |  10 +-
 .../org/apache/sedona/sql/UDF/UdfRegistrator.scala |  54 ---------
 .../sedona/sql/utils/SedonaSQLRegistrator.scala    |   4 +-
 .../sedona_sql/expressions/GeoStatsFunctions.scala | 103 ++++-------------
 .../sedona_sql/expressions/PhysicalFunction.scala  | 109 ++++++++++++++++++
 .../optimization/ExtractGeoStatsFunctions.scala    | 120 --------------------
 .../optimization/ExtractPhysicalFunctions.scala    | 122 +++++++++++++++++++++
 ...tsFunction.scala => EvalPhysicalFunction.scala} |   2 +-
 .../function/EvalPhysicalFunctionExec.scala}       |   8 +-
 .../function/EvalPhysicalFunctionStrategy.scala}   |  12 +-
 12 files changed, 305 insertions(+), 283 deletions(-)

diff --git 
a/spark/common/src/main/scala/org/apache/sedona/spark/SedonaContext.scala 
b/spark/common/src/main/scala/org/apache/sedona/spark/SedonaContext.scala
index fe2926fc51..7cfb8670be 100644
--- a/spark/common/src/main/scala/org/apache/sedona/spark/SedonaContext.scala
+++ b/spark/common/src/main/scala/org/apache/sedona/spark/SedonaContext.scala
@@ -21,12 +21,12 @@ package org.apache.sedona.spark
 import org.apache.sedona.common.utils.TelemetryCollector
 import org.apache.sedona.core.serde.SedonaKryoRegistrator
 import org.apache.sedona.sql.RasterRegistrator
-import org.apache.sedona.sql.UDF.UdfRegistrator
+import org.apache.sedona.sql.UDF.Catalog
 import org.apache.sedona.sql.UDT.UdtRegistrator
 import org.apache.spark.serializer.KryoSerializer
-import org.apache.spark.sql.sedona_sql.optimization.{ExtractGeoStatsFunctions, 
SpatialFilterPushDownForGeoParquet, SpatialTemporalFilterPushDownForStacScan}
-import 
org.apache.spark.sql.sedona_sql.strategy.geostats.EvalGeoStatsFunctionStrategy
+import org.apache.spark.sql.sedona_sql.optimization._
 import org.apache.spark.sql.sedona_sql.strategy.join.JoinQueryDetector
+import 
org.apache.spark.sql.sedona_sql.strategy.physical.function.EvalPhysicalFunctionStrategy
 import org.apache.spark.sql.{SQLContext, SparkSession}
 
 import scala.annotation.StaticAnnotation
@@ -73,20 +73,21 @@ object SedonaContext {
       }
     }
 
-    // Support geostats functions
-    if 
(!sparkSession.experimental.extraOptimizations.contains(ExtractGeoStatsFunctions))
 {
-      sparkSession.experimental.extraOptimizations ++= 
Seq(ExtractGeoStatsFunctions)
+    // Support physical functions
+    if 
(!sparkSession.experimental.extraOptimizations.contains(ExtractPhysicalFunctions))
 {
+      sparkSession.experimental.extraOptimizations ++= 
Seq(ExtractPhysicalFunctions)
     }
+
     if (!sparkSession.experimental.extraStrategies.exists(
-        _.isInstanceOf[EvalGeoStatsFunctionStrategy])) {
+        _.isInstanceOf[EvalPhysicalFunctionStrategy])) {
       sparkSession.experimental.extraStrategies ++= Seq(
-        new EvalGeoStatsFunctionStrategy(sparkSession))
+        new EvalPhysicalFunctionStrategy(sparkSession))
     }
 
     addGeoParquetToSupportNestedFilterSources(sparkSession)
     RasterRegistrator.registerAll(sparkSession)
     UdtRegistrator.registerAll()
-    UdfRegistrator.registerAll(sparkSession)
+    Catalog.registerAll(sparkSession)
     sparkSession
   }
 
diff --git 
a/spark/common/src/main/scala/org/apache/sedona/sql/UDF/AbstractCatalog.scala 
b/spark/common/src/main/scala/org/apache/sedona/sql/UDF/AbstractCatalog.scala
index 3ad579c38c..f8bb0ac5fe 100644
--- 
a/spark/common/src/main/scala/org/apache/sedona/sql/UDF/AbstractCatalog.scala
+++ 
b/spark/common/src/main/scala/org/apache/sedona/sql/UDF/AbstractCatalog.scala
@@ -18,6 +18,7 @@
  */
 package org.apache.sedona.sql.UDF
 
+import org.apache.spark.sql.{SQLContext, SparkSession, functions}
 import org.apache.spark.sql.catalyst.FunctionIdentifier
 import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
 import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, 
Expression, ExpressionInfo, Literal}
@@ -74,4 +75,28 @@ abstract class AbstractCatalog {
 
     (functionIdentifier, expressionInfo, functionBuilder)
   }
+
+  def registerAll(sqlContext: SQLContext): Unit = {
+    registerAll(sqlContext.sparkSession)
+  }
+
+  def registerAll(sparkSession: SparkSession): Unit = {
+    Catalog.expressions.foreach { case (functionIdentifier, expressionInfo, 
functionBuilder) =>
+      sparkSession.sessionState.functionRegistry.registerFunction(
+        functionIdentifier,
+        expressionInfo,
+        functionBuilder)
+    }
+    Catalog.aggregateExpressions.foreach(f =>
+      sparkSession.udf.register(f.getClass.getSimpleName, functions.udaf(f)))
+  }
+
+  def dropAll(sparkSession: SparkSession): Unit = {
+    Catalog.expressions.foreach { case (functionIdentifier, _, _) =>
+      
sparkSession.sessionState.functionRegistry.dropFunction(functionIdentifier)
+    }
+    Catalog.aggregateExpressions.foreach(f =>
+      sparkSession.sessionState.functionRegistry.dropFunction(
+        FunctionIdentifier(f.getClass.getSimpleName)))
+  }
 }
diff --git 
a/spark/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala 
b/spark/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
index 16c393cdbc..af51a825f8 100644
--- a/spark/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
+++ b/spark/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
@@ -19,14 +19,12 @@
 package org.apache.sedona.sql.UDF
 
 import org.apache.spark.sql.expressions.Aggregator
-import org.apache.spark.sql.sedona_sql.expressions.{ST_InterpolatePoint, _}
 import org.apache.spark.sql.sedona_sql.expressions.collect.ST_Collect
 import org.apache.spark.sql.sedona_sql.expressions.raster._
+import org.apache.spark.sql.sedona_sql.expressions._
 import org.locationtech.jts.geom.Geometry
 import org.locationtech.jts.operation.buffer.BufferParameters
 
-import scala.collection.mutable.ListBuffer
-
 object Catalog extends AbstractCatalog {
 
   override val expressions: Seq[FunctionDescription] = Seq(
@@ -344,9 +342,5 @@ object Catalog extends AbstractCatalog {
     function[ST_WeightedDistanceBandColumn]())
 
   val aggregateExpressions: Seq[Aggregator[Geometry, _, _]] =
-    Seq(new ST_Envelope_Aggr, new ST_Intersection_Aggr)
-
-  // Aggregate functions with List as buffer
-  val aggregateExpressions2: Seq[Aggregator[Geometry, ListBuffer[Geometry], 
Geometry]] =
-    Seq(new ST_Union_Aggr())
+    Seq(new ST_Envelope_Aggr, new ST_Intersection_Aggr, new ST_Union_Aggr())
 }
diff --git 
a/spark/common/src/main/scala/org/apache/sedona/sql/UDF/UdfRegistrator.scala 
b/spark/common/src/main/scala/org/apache/sedona/sql/UDF/UdfRegistrator.scala
deleted file mode 100644
index 30c3cb2e3b..0000000000
--- a/spark/common/src/main/scala/org/apache/sedona/sql/UDF/UdfRegistrator.scala
+++ /dev/null
@@ -1,54 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.sedona.sql.UDF
-
-import org.apache.spark.sql.catalyst.FunctionIdentifier
-import org.apache.spark.sql.{SQLContext, SparkSession, functions}
-
-object UdfRegistrator {
-
-  def registerAll(sqlContext: SQLContext): Unit = {
-    registerAll(sqlContext.sparkSession)
-  }
-
-  def registerAll(sparkSession: SparkSession): Unit = {
-    Catalog.expressions.foreach { case (functionIdentifier, expressionInfo, 
functionBuilder) =>
-      sparkSession.sessionState.functionRegistry.registerFunction(
-        functionIdentifier,
-        expressionInfo,
-        functionBuilder)
-    }
-    Catalog.aggregateExpressions.foreach(f =>
-      sparkSession.udf.register(f.getClass.getSimpleName, functions.udaf(f))) 
// SPARK3 anchor
-
-    Catalog.aggregateExpressions2.foreach(f =>
-      sparkSession.udf.register(f.getClass.getSimpleName, functions.udaf(f))) 
// SPARK3 anchor
-  }
-
-  def dropAll(sparkSession: SparkSession): Unit = {
-    Catalog.expressions.foreach { case (functionIdentifier, _, _) =>
-      
sparkSession.sessionState.functionRegistry.dropFunction(functionIdentifier)
-    }
-    Catalog.aggregateExpressions.foreach(f =>
-      sparkSession.sessionState.functionRegistry.dropFunction(
-        FunctionIdentifier(f.getClass.getSimpleName)
-      )) // SPARK3 anchor
-//Catalog.aggregateExpressions_UDAF.foreach(f => 
sparkSession.sessionState.functionRegistry.dropFunction(FunctionIdentifier(f.getClass.getSimpleName)))
 // SPARK2 anchor
-  }
-}
diff --git 
a/spark/common/src/main/scala/org/apache/sedona/sql/utils/SedonaSQLRegistrator.scala
 
b/spark/common/src/main/scala/org/apache/sedona/sql/utils/SedonaSQLRegistrator.scala
index 52f7ceb1cd..a679db084d 100644
--- 
a/spark/common/src/main/scala/org/apache/sedona/sql/utils/SedonaSQLRegistrator.scala
+++ 
b/spark/common/src/main/scala/org/apache/sedona/sql/utils/SedonaSQLRegistrator.scala
@@ -20,7 +20,7 @@ package org.apache.sedona.sql.utils
 
 import org.apache.sedona.spark.SedonaContext
 import org.apache.sedona.sql.RasterRegistrator
-import org.apache.sedona.sql.UDF.UdfRegistrator
+import org.apache.sedona.sql.UDF.Catalog
 import org.apache.spark.sql.{SQLContext, SparkSession}
 
 @deprecated("Use SedonaContext instead", "1.4.1")
@@ -44,7 +44,7 @@ object SedonaSQLRegistrator {
     SedonaContext.create(sparkSession, language)
 
   def dropAll(sparkSession: SparkSession): Unit = {
-    UdfRegistrator.dropAll(sparkSession)
+    Catalog.dropAll(sparkSession)
     RasterRegistrator.dropAll(sparkSession)
   }
 }
diff --git 
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/GeoStatsFunctions.scala
 
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/GeoStatsFunctions.scala
index 8c6b645daf..75e86510ab 100644
--- 
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/GeoStatsFunctions.scala
+++ 
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/GeoStatsFunctions.scala
@@ -23,80 +23,13 @@ import 
org.apache.sedona.stats.Weighting.{addBinaryDistanceBandColumn, addWeight
 import org.apache.sedona.stats.clustering.DBSCAN.dbscan
 import org.apache.sedona.stats.hotspotDetection.GetisOrd.gLocal
 import 
org.apache.sedona.stats.outlierDetection.LocalOutlierFactor.localOutlierFactor
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Attribute, 
AttributeReference, Expression, ImplicitCastInputTypes, Literal, 
ScalarSubquery, Unevaluable}
-import org.apache.spark.sql.execution.{LogicalRDD, SparkPlan}
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
 import org.apache.spark.sql.functions.{col, struct}
 import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT
 import org.apache.spark.sql.types._
-import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
 
-import scala.reflect.ClassTag
-
-// We mark ST_GeoStatsFunction as non-deterministic to avoid the filter 
push-down optimization pass
-// duplicates the ST_GeoStatsFunction when pushing down aliased 
ST_GeoStatsFunction through a
-// Project operator. This will make ST_GeoStatsFunction being evaluated twice.
-trait ST_GeoStatsFunction
-    extends Expression
-    with ImplicitCastInputTypes
-    with Unevaluable
-    with Serializable {
-
-  final override lazy val deterministic: Boolean = false
-
-  override def nullable: Boolean = true
-
-  private final lazy val sparkSession = SparkSession.getActiveSession.get
-
-  protected final lazy val geometryColumnName = getInputName(0, "geometry")
-
-  protected def getInputName(i: Int, fieldName: String): String = children(i) 
match {
-    case ref: AttributeReference => ref.name
-    case _ =>
-      throw new IllegalArgumentException(
-        f"$fieldName argument must be a named reference to an existing column")
-  }
-
-  protected def getInputNames(i: Int, fieldName: String): Seq[String] = 
children(
-    i).dataType match {
-    case StructType(fields) => fields.map(_.name)
-    case _ => throw new IllegalArgumentException(f"$fieldName argument must be 
a struct")
-  }
-
-  protected def getResultName(resultAttrs: Seq[Attribute]): String = 
resultAttrs match {
-    case Seq(attr) => attr.name
-    case _ => throw new IllegalArgumentException("resultAttrs must have 
exactly one attribute")
-  }
-
-  protected def doExecute(dataframe: DataFrame, resultAttrs: Seq[Attribute]): 
DataFrame
-
-  protected def getScalarValue[T](i: Int, name: String)(implicit ct: 
ClassTag[T]): T = {
-    children(i) match {
-      case Literal(l: T, _) => l
-      case _: Literal =>
-        throw new IllegalArgumentException(f"$name must be an instance of  
${ct.runtimeClass}")
-      case s: ScalarSubquery =>
-        s.eval() match {
-          case t: T => t
-          case _ =>
-            throw new IllegalArgumentException(
-              f"$name must be an instance of  ${ct.runtimeClass}")
-        }
-      case _ => throw new IllegalArgumentException(f"$name must be a scalar 
value")
-    }
-  }
-
-  def execute(plan: SparkPlan, resultAttrs: Seq[Attribute]): RDD[InternalRow] 
= {
-    val df = doExecute(
-      Dataset.ofRows(sparkSession, LogicalRDD(plan.output, 
plan.execute())(sparkSession)),
-      resultAttrs)
-    df.queryExecution.toRdd
-  }
-
-}
-
-case class ST_DBSCAN(children: Seq[Expression]) extends ST_GeoStatsFunction {
+case class ST_DBSCAN(children: Seq[Expression]) extends 
DataframePhysicalFunction {
 
   override def dataType: DataType = StructType(
     Seq(StructField("isCore", BooleanType), StructField("cluster", LongType)))
@@ -107,7 +40,9 @@ case class ST_DBSCAN(children: Seq[Expression]) extends 
ST_GeoStatsFunction {
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): 
Expression =
     copy(children = newChildren)
 
-  override def doExecute(dataframe: DataFrame, resultAttrs: Seq[Attribute]): 
DataFrame = {
+  override def transformDataframe(
+      dataframe: DataFrame,
+      resultAttrs: Seq[Attribute]): DataFrame = {
     require(
       !dataframe.columns.contains("__isCore"),
       "__isCore is a  reserved name by the dbscan algorithm. Please rename the 
columns before calling the ST_DBSCAN function.")
@@ -129,7 +64,7 @@ case class ST_DBSCAN(children: Seq[Expression]) extends 
ST_GeoStatsFunction {
   }
 }
 
-case class ST_LocalOutlierFactor(children: Seq[Expression]) extends 
ST_GeoStatsFunction {
+case class ST_LocalOutlierFactor(children: Seq[Expression]) extends 
DataframePhysicalFunction {
 
   override def dataType: DataType = DoubleType
 
@@ -139,7 +74,9 @@ case class ST_LocalOutlierFactor(children: Seq[Expression]) 
extends ST_GeoStatsF
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): 
Expression =
     copy(children = newChildren)
 
-  override def doExecute(dataframe: DataFrame, resultAttrs: Seq[Attribute]): 
DataFrame = {
+  override def transformDataframe(
+      dataframe: DataFrame,
+      resultAttrs: Seq[Attribute]): DataFrame = {
     localOutlierFactor(
       dataframe,
       getScalarValue[Int](1, "k"),
@@ -150,7 +87,7 @@ case class ST_LocalOutlierFactor(children: Seq[Expression]) 
extends ST_GeoStatsF
   }
 }
 
-case class ST_GLocal(children: Seq[Expression]) extends ST_GeoStatsFunction {
+case class ST_GLocal(children: Seq[Expression]) extends 
DataframePhysicalFunction {
 
   override def dataType: DataType = StructType(
     Seq(
@@ -172,7 +109,9 @@ case class ST_GLocal(children: Seq[Expression]) extends 
ST_GeoStatsFunction {
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): 
Expression =
     copy(children = newChildren)
 
-  override def doExecute(dataframe: DataFrame, resultAttrs: Seq[Attribute]): 
DataFrame = {
+  override def transformDataframe(
+      dataframe: DataFrame,
+      resultAttrs: Seq[Attribute]): DataFrame = {
     gLocal(
       dataframe,
       getInputName(0, "x"),
@@ -187,7 +126,8 @@ case class ST_GLocal(children: Seq[Expression]) extends 
ST_GeoStatsFunction {
   }
 }
 
-case class ST_BinaryDistanceBandColumn(children: Seq[Expression]) extends 
ST_GeoStatsFunction {
+case class ST_BinaryDistanceBandColumn(children: Seq[Expression])
+    extends DataframePhysicalFunction {
   override def dataType: DataType = ArrayType(
     StructType(
       Seq(StructField("neighbor", children(5).dataType), StructField("value", 
DoubleType))))
@@ -198,7 +138,9 @@ case class ST_BinaryDistanceBandColumn(children: 
Seq[Expression]) extends ST_Geo
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): 
Expression =
     copy(children = newChildren)
 
-  override def doExecute(dataframe: DataFrame, resultAttrs: Seq[Attribute]): 
DataFrame = {
+  override def transformDataframe(
+      dataframe: DataFrame,
+      resultAttrs: Seq[Attribute]): DataFrame = {
     val attributeNames = getInputNames(5, "attributes")
     require(attributeNames.nonEmpty, "attributes must have at least one 
column")
     require(
@@ -217,7 +159,8 @@ case class ST_BinaryDistanceBandColumn(children: 
Seq[Expression]) extends ST_Geo
   }
 }
 
-case class ST_WeightedDistanceBandColumn(children: Seq[Expression]) extends 
ST_GeoStatsFunction {
+case class ST_WeightedDistanceBandColumn(children: Seq[Expression])
+    extends DataframePhysicalFunction {
 
   override def dataType: DataType = ArrayType(
     StructType(
@@ -237,7 +180,9 @@ case class ST_WeightedDistanceBandColumn(children: 
Seq[Expression]) extends ST_G
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): 
Expression =
     copy(children = newChildren)
 
-  override def doExecute(dataframe: DataFrame, resultAttrs: Seq[Attribute]): 
DataFrame = {
+  override def transformDataframe(
+      dataframe: DataFrame,
+      resultAttrs: Seq[Attribute]): DataFrame = {
     val attributeNames = getInputNames(7, "attributes")
     require(attributeNames.nonEmpty, "attributes must have at least one 
column")
     require(
diff --git 
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/PhysicalFunction.scala
 
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/PhysicalFunction.scala
new file mode 100644
index 0000000000..253dfe2cfa
--- /dev/null
+++ 
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/PhysicalFunction.scala
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.spark.sql.sedona_sql.expressions
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Attribute, 
AttributeReference, Expression, ImplicitCastInputTypes, Literal, 
ScalarSubquery, Unevaluable}
+import org.apache.spark.sql.execution.{LogicalRDD, SparkPlan}
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
+
+import scala.reflect.ClassTag
+
+/**
+ * PhysicalFunctions are Functions that will be replaced with a Physical Node 
for their
+ * evaluation.
+ *
+ * execute is the method that will be called in order to evaluate the 
function. PhysicalFunction
+ * is marked non-deterministic to avoid the filter push-down optimization pass 
which duplicates
+ * the PhysicalFunction when pushing down aliased PhysicalFunction calls 
through a Project
+ * operator. Otherwise the PhysicalFunction would be evaluated twice.
+ */
+trait PhysicalFunction
+    extends Expression
+    with ImplicitCastInputTypes
+    with Unevaluable
+    with Serializable {
+  final override lazy val deterministic: Boolean = false
+
+  override def nullable: Boolean = true
+
+  protected final lazy val sparkSession = SparkSession.getActiveSession.get
+
+  protected final lazy val geometryColumnName = getInputName(0, "geometry")
+
+  protected def getInputName(i: Int, fieldName: String): String = children(i) 
match {
+    case ref: AttributeReference => ref.name
+    case _ =>
+      throw new IllegalArgumentException(
+        f"$fieldName argument must be a named reference to an existing column")
+  }
+
+  protected def getScalarValue[T](i: Int, name: String)(implicit ct: 
ClassTag[T]): T = {
+    children(i) match {
+      case Literal(l: T, _) => l
+      case _: Literal =>
+        throw new IllegalArgumentException(f"$name must be an instance of  
${ct.runtimeClass}")
+      case s: ScalarSubquery =>
+        s.eval() match {
+          case t: T => t
+          case _ =>
+            throw new IllegalArgumentException(
+              f"$name must be an instance of  ${ct.runtimeClass}")
+        }
+      case _ => throw new IllegalArgumentException(f"$name must be a scalar 
value")
+    }
+  }
+
+  protected def getInputNames(i: Int, fieldName: String): Seq[String] = 
children(
+    i).dataType match {
+    case StructType(fields) => fields.map(_.name)
+    case _ => throw new IllegalArgumentException(f"$fieldName argument must be 
a struct")
+  }
+
+  protected def getResultName(resultAttrs: Seq[Attribute]): String = 
resultAttrs match {
+    case Seq(attr) => attr.name
+    case _ => throw new IllegalArgumentException("resultAttrs must have 
exactly one attribute")
+  }
+
+  def execute(plan: SparkPlan, resultAttrs: Seq[Attribute]): RDD[InternalRow]
+}
+
+/**
+ * DataframePhysicalFunctions are Functions that will be replaced with a 
Physical Node for their
+ * evaluation.
+ *
+ * The physical node will transform the input dataframe into the output 
dataframe. execute handles
+ * conversion of the RDD[InternalRow] to a DataFrame and back. Each 
DataframePhysicalFunction
+ * should implement transformDataframe. The output dataframe should have the 
same schema as the
+ * input dataframe, except for the resultAttrs which should be added to the 
output dataframe.
+ */
+trait DataframePhysicalFunction extends PhysicalFunction {
+
+  protected def transformDataframe(dataframe: DataFrame, resultAttrs: 
Seq[Attribute]): DataFrame
+
+  override def execute(plan: SparkPlan, resultAttrs: Seq[Attribute]): 
RDD[InternalRow] = {
+    val df = transformDataframe(
+      Dataset.ofRows(sparkSession, LogicalRDD(plan.output, 
plan.execute())(sparkSession)),
+      resultAttrs)
+    df.queryExecution.toRdd
+  }
+
+}
diff --git 
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/optimization/ExtractGeoStatsFunctions.scala
 
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/optimization/ExtractGeoStatsFunctions.scala
deleted file mode 100644
index 6b4cf9ccea..0000000000
--- 
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/optimization/ExtractGeoStatsFunctions.scala
+++ /dev/null
@@ -1,120 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.spark.sql.sedona_sql.optimization
-
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.sedona_sql.expressions.ST_GeoStatsFunction
-import org.apache.spark.sql.sedona_sql.plans.logical.EvalGeoStatsFunction
-
-import scala.collection.mutable
-
-/**
- * Extracts GeoStats functions from operators, rewriting the query plan so 
that the geo-stats
- * functions can be evaluated alone in its own physical executors.
- */
-object ExtractGeoStatsFunctions extends Rule[LogicalPlan] {
-  var geoStatsResultCount = 0
-
-  private def collectGeoStatsFunctionsFromExpressions(
-      expressions: Seq[Expression]): Seq[ST_GeoStatsFunction] = {
-    def collectGeoStatsFunctions(expr: Expression): Seq[ST_GeoStatsFunction] = 
expr match {
-      case expr: ST_GeoStatsFunction => Seq(expr)
-      case e => e.children.flatMap(collectGeoStatsFunctions)
-    }
-    expressions.flatMap(collectGeoStatsFunctions)
-  }
-
-  def apply(plan: LogicalPlan): LogicalPlan = plan match {
-    // SPARK-26293: A subquery will be rewritten into join later, and will go 
through this rule
-    // eventually. Here we skip subquery, as geo-stats functions only needs to 
be extracted once.
-    case s: Subquery if s.correlated => plan
-    case _ =>
-      plan.transformUp {
-        case p: EvalGeoStatsFunction => p
-        case plan: LogicalPlan => extract(plan)
-      }
-  }
-
-  private def canonicalizeDeterministic(u: ST_GeoStatsFunction) = {
-    if (u.deterministic) {
-      u.canonicalized.asInstanceOf[ST_GeoStatsFunction]
-    } else {
-      u
-    }
-  }
-
-  /**
-   * Extract all the geo-stats functions from the current operator and 
evaluate them before the
-   * operator.
-   */
-  private def extract(plan: LogicalPlan): LogicalPlan = {
-    val geoStatsFuncs = plan match {
-      case e: EvalGeoStatsFunction =>
-        collectGeoStatsFunctionsFromExpressions(e.function.children)
-      case _ =>
-        
ExpressionSet(collectGeoStatsFunctionsFromExpressions(plan.expressions))
-          // ignore the ST_GeoStatsFunction that come from second/third 
aggregate, which is not used
-          .filter(func => func.references.subsetOf(plan.inputSet))
-          .filter(func =>
-            plan.children.exists(child => 
func.references.subsetOf(child.outputSet)))
-          .toSeq
-          .asInstanceOf[Seq[ST_GeoStatsFunction]]
-    }
-
-    if (geoStatsFuncs.isEmpty) {
-      // If there aren't any, we are done.
-      plan
-    } else {
-      // Transform the first geo-stats function we have found. We'll call 
extract recursively later
-      // to transform the rest.
-      val geoStatsFunc = geoStatsFuncs.head
-
-      val attributeMap = mutable.HashMap[ST_GeoStatsFunction, Expression]()
-      // Rewrite the child that has the input required for the UDF
-      val newChildren = plan.children.map { child =>
-        if (geoStatsFunc.references.subsetOf(child.outputSet)) {
-          geoStatsResultCount += 1
-          val resultAttr =
-            AttributeReference(f"geoStatsResult$geoStatsResultCount", 
geoStatsFunc.dataType)()
-          val evaluation = EvalGeoStatsFunction(geoStatsFunc, Seq(resultAttr), 
child)
-          attributeMap += (canonicalizeDeterministic(geoStatsFunc) -> 
resultAttr)
-          extract(evaluation) // handle nested geo-stats functions
-        } else {
-          child
-        }
-      }
-
-      // Replace the geo stats function call with the newly created 
geoStatsResult attribute
-      val rewritten = plan.withNewChildren(newChildren).transformExpressions {
-        case p: ST_GeoStatsFunction => 
attributeMap.getOrElse(canonicalizeDeterministic(p), p)
-      }
-
-      // extract remaining geo-stats functions recursively
-      val newPlan = extract(rewritten)
-      if (newPlan.output != plan.output) {
-        // Trim away the new UDF value if it was only used for filtering or 
something.
-        Project(plan.output, newPlan)
-      } else {
-        newPlan
-      }
-    }
-  }
-}
diff --git 
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/optimization/ExtractPhysicalFunctions.scala
 
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/optimization/ExtractPhysicalFunctions.scala
new file mode 100644
index 0000000000..9aac6db2b8
--- /dev/null
+++ 
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/optimization/ExtractPhysicalFunctions.scala
@@ -0,0 +1,122 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.spark.sql.sedona_sql.optimization
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.sedona_sql.expressions.PhysicalFunction
+import org.apache.spark.sql.sedona_sql.plans.logical.EvalPhysicalFunction
+
+import scala.collection.mutable
+
+/**
+ * Extracts physical functions from operators, rewriting the query plan so that each function is
+ * evaluated on its own by an [[EvalPhysicalFunction]] node inserted below the operator that uses
+ * it. The operator then reads the function's result through a generated attribute.
+ */
+object ExtractPhysicalFunctions extends Rule[LogicalPlan] {
+
+  // Numeric suffix for generated result attribute names. This object is a singleton shared by
+  // every session, and the optimizer may apply the rule from several threads at once, so an
+  // atomic counter is used instead of an unsynchronized `var`.
+  private val physicalFunctionResultCount = new java.util.concurrent.atomic.AtomicLong(0L)
+
+  /** Collects the outermost [[PhysicalFunction]]s in `expressions` (their arguments are not searched). */
+  private def collectPhysicalFunctionsFromExpressions(
+      expressions: Seq[Expression]): Seq[PhysicalFunction] = {
+    def collectPhysicalFunctions(expr: Expression): Seq[PhysicalFunction] = expr match {
+      case func: PhysicalFunction => Seq(func)
+      case other => other.children.flatMap(collectPhysicalFunctions)
+    }
+    expressions.flatMap(collectPhysicalFunctions)
+  }
+
+  def apply(plan: LogicalPlan): LogicalPlan = plan match {
+    // SPARK-26293: A subquery will be rewritten into join later, and will go through this rule
+    // eventually. Here we skip subquery, as physical functions only need to be extracted once.
+    case s: Subquery if s.correlated => plan
+    case _ =>
+      plan.transformUp {
+        // Evaluation nodes are returned unchanged (presumably produced by an earlier run of this rule).
+        case eval: EvalPhysicalFunction => eval
+        case other: LogicalPlan => extract(other)
+      }
+  }
+
+  /** Canonicalizes deterministic functions so that equivalent calls map to the same result attribute. */
+  private def canonicalizeDeterministic(u: PhysicalFunction): PhysicalFunction =
+    if (u.deterministic) u.canonicalized.asInstanceOf[PhysicalFunction] else u
+
+  /**
+   * Extract all the physical functions from the current operator and evaluate them before the
+   * operator.
+   */
+  private def extract(plan: LogicalPlan): LogicalPlan =
+    extractableFunctions(plan).headOption match {
+      // If there aren't any, we are done.
+      case None => plan
+      // Transform the first physical function we have found; `extractOne` calls `extract`
+      // recursively to transform the rest.
+      case Some(physicalFunction) => extractOne(plan, physicalFunction)
+    }
+
+  /** Finds the physical functions of `plan` whose inputs are all produced by one of its children. */
+  private def extractableFunctions(plan: LogicalPlan): Seq[PhysicalFunction] = plan match {
+    case e: EvalPhysicalFunction =>
+      collectPhysicalFunctionsFromExpressions(e.function.children)
+    case _ =>
+      ExpressionSet(collectPhysicalFunctionsFromExpressions(plan.expressions))
+        // ignore the PhysicalFunction that come from second/third aggregate, which is not used
+        .filter(func => func.references.subsetOf(plan.inputSet))
+        .filter(func => plan.children.exists(child => func.references.subsetOf(child.outputSet)))
+        .toSeq
+        // ExpressionSet is typed as a set of Expression; narrow it with a checked pattern
+        // rather than an unchecked cast on an erased generic type.
+        .collect { case func: PhysicalFunction => func }
+  }
+
+  /** Evaluates `physicalFunction` below `plan`, then extracts the remaining functions. */
+  private def extractOne(plan: LogicalPlan, physicalFunction: PhysicalFunction): LogicalPlan = {
+    val attributeMap = mutable.HashMap[PhysicalFunction, Expression]()
+    // Rewrite the child that has the input required for the function
+    val newChildren = plan.children.map { child =>
+      if (physicalFunction.references.subsetOf(child.outputSet)) {
+        val resultAttr = AttributeReference(
+          s"physicalFunctionResult${physicalFunctionResultCount.incrementAndGet()}",
+          physicalFunction.dataType)()
+        val evaluation = EvalPhysicalFunction(physicalFunction, Seq(resultAttr), child)
+        attributeMap += (canonicalizeDeterministic(physicalFunction) -> resultAttr)
+        extract(evaluation) // handle nested functions
+      } else {
+        child
+      }
+    }
+
+    // Replace the physical function call with the newly created attribute
+    val rewritten = plan.withNewChildren(newChildren).transformExpressions {
+      case p: PhysicalFunction => attributeMap.getOrElse(canonicalizeDeterministic(p), p)
+    }
+
+    // extract remaining physical functions recursively
+    val newPlan = extract(rewritten)
+    if (newPlan.output != plan.output) {
+      // Trim away the new result attribute if it was only used for filtering or something.
+      Project(plan.output, newPlan)
+    } else {
+      newPlan
+    }
+  }
+}
diff --git 
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/plans/logical/EvalGeoStatsFunction.scala
 
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/plans/logical/EvalPhysicalFunction.scala
similarity index 97%
rename from 
spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/plans/logical/EvalGeoStatsFunction.scala
rename to 
spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/plans/logical/EvalPhysicalFunction.scala
index 8daeb0c304..9371d0c12d 100644
--- 
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/plans/logical/EvalGeoStatsFunction.scala
+++ 
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/plans/logical/EvalPhysicalFunction.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.plans.logical.UnaryNode
 
-case class EvalGeoStatsFunction(
+case class EvalPhysicalFunction(
     function: Expression,
     resultAttrs: Seq[Attribute],
     child: LogicalPlan)
diff --git 
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/geostats/EvalGeoStatsFunctionExec.scala
 
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/physical/function/EvalPhysicalFunctionExec.scala
similarity index 87%
rename from 
spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/geostats/EvalGeoStatsFunctionExec.scala
rename to 
spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/physical/function/EvalPhysicalFunctionExec.scala
index fbecb69ec4..a99630dc0c 100644
--- 
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/geostats/EvalGeoStatsFunctionExec.scala
+++ 
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/physical/function/EvalPhysicalFunctionExec.scala
@@ -16,16 +16,16 @@
  * specific language governing permissions and limitations
  * under the License.
  */
-package org.apache.spark.sql.sedona_sql.strategy.geostats
+package org.apache.spark.sql.sedona_sql.strategy.physical.function
 
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet}
 import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
-import org.apache.spark.sql.sedona_sql.expressions.ST_GeoStatsFunction
+import org.apache.spark.sql.sedona_sql.expressions.PhysicalFunction
 
-case class EvalGeoStatsFunctionExec(
-    function: ST_GeoStatsFunction,
+case class EvalPhysicalFunctionExec(
+    function: PhysicalFunction,
     child: SparkPlan,
     resultAttrs: Seq[Attribute])
     extends UnaryExecNode {
diff --git 
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/geostats/EvalGeoStatsFunctionStrategy.scala
 
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/physical/function/EvalPhysicalFunctionStrategy.scala
similarity index 73%
rename from 
spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/geostats/EvalGeoStatsFunctionStrategy.scala
rename to 
spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/physical/function/EvalPhysicalFunctionStrategy.scala
index 4c10b747a6..a159badd38 100644
--- 
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/geostats/EvalGeoStatsFunctionStrategy.scala
+++ 
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/physical/function/EvalPhysicalFunctionStrategy.scala
@@ -16,21 +16,21 @@
  * specific language governing permissions and limitations
  * under the License.
  */
-package org.apache.spark.sql.sedona_sql.strategy.geostats
+package org.apache.spark.sql.sedona_sql.strategy.physical.function
 
 import org.apache.spark.sql.Strategy
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.sql.sedona_sql.plans.logical.EvalGeoStatsFunction
+import org.apache.spark.sql.sedona_sql.plans.logical.EvalPhysicalFunction
 import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.sedona_sql.expressions.ST_GeoStatsFunction
+import org.apache.spark.sql.sedona_sql.expressions.PhysicalFunction
 
-class EvalGeoStatsFunctionStrategy(spark: SparkSession) extends Strategy {
+class EvalPhysicalFunctionStrategy(spark: SparkSession) extends Strategy {
 
   override def apply(plan: LogicalPlan): Seq[SparkPlan] = {
     plan match {
-      case EvalGeoStatsFunction(function: ST_GeoStatsFunction, resultAttrs, 
child) =>
-        EvalGeoStatsFunctionExec(function, planLater(child), resultAttrs) :: 
Nil
+      case EvalPhysicalFunction(function: PhysicalFunction, resultAttrs, 
child) =>
+        EvalPhysicalFunctionExec(function, planLater(child), resultAttrs) :: 
Nil
       case _ => Nil
     }
   }


Reply via email to