This is an automated email from the ASF dual-hosted git repository.

imbruced pushed a commit to branch sedona-arrow-udf-example
in repository https://gitbox.apache.org/repos/asf/sedona.git

commit 98728457796c01f61cc5f9fd38ffa9d90d54ddb1
Author: pawelkocinski <[email protected]>
AuthorDate: Sun Feb 23 21:53:47 2025 +0100

    SEDONA-721 Add Sedona vectorized udf.
---
 pom.xml                                            |  12 +-
 python/sedona/sql/udf.py                           |  19 +++
 python/sedona/utils/geoarrow.py                    |   3 +-
 python/tests/utils/test_pandas_arrow_udf.py        |  46 ++++++
 .../org/apache/sedona/sql/RasterRegistrator.scala  |   7 +
 .../sedona_sql/strategies/ExtractSedonaUDF.scala   | 165 +++++++++++++++++++++
 .../sql/sedona_sql/strategies/PythonEvalType.scala |  28 ++++
 .../strategies/SedonaArrowEvalPython.scala         |  32 ++++
 .../strategies/SedonaArrowStrategy.scala           |  82 ++++++++++
 .../sql/sedona_sql/strategies/StrategySuite.scala  |  67 +++++++++
 .../strategies/TestScalarPandasUDF.scala           | 121 +++++++++++++++
 11 files changed, 574 insertions(+), 8 deletions(-)
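
For orientation before the diff: a minimal usage sketch of the decorator this commit adds. The session setup and column names below are illustrative, not part of the patch; SedonaContext.create is what wires in the new strategy and optimizer rule (see RasterRegistrator below).

    import pyspark.sql.functions as f
    import shapely.geometry.base as b
    from sedona.spark import SedonaContext
    from sedona.sql.udf import sedona_vectorized_udf

    # Registers SedonaArrowStrategy / ExtractSedonaUDF via the registrators.
    spark = SedonaContext.create(SedonaContext.builder().master("local[*]").getOrCreate())

    @sedona_vectorized_udf
    def buffered(geom: b.BaseGeometry) -> b.BaseGeometry:
        return geom.buffer(0.001)

    df = spark.sql("SELECT ST_Point(21.0, 52.0) AS geom")
    df.withColumn("buffer", buffered(f.col("geom"))).show()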

diff --git a/pom.xml b/pom.xml
index 08d4ff646a..da87637d32 100644
--- a/pom.xml
+++ b/pom.xml
@@ -18,12 +18,12 @@
   -->
 <project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
     <modelVersion>4.0.0</modelVersion>
-    <parent>
-        <groupId>org.apache</groupId>
-        <artifactId>apache</artifactId>
-        <version>23</version>
-        <relativePath />
-    </parent>
+<!--    <parent>-->
+<!--        <groupId>org.apache</groupId>-->
+<!--        <artifactId>apache</artifactId>-->
+<!--        <version>23</version>-->
+<!--        <relativePath />-->
+<!--    </parent>-->
     <groupId>org.apache.sedona</groupId>
     <artifactId>sedona-parent</artifactId>
     <version>1.8.0-SNAPSHOT</version>
diff --git a/python/sedona/sql/udf.py b/python/sedona/sql/udf.py
new file mode 100644
index 0000000000..6e723d0e10
--- /dev/null
+++ b/python/sedona/sql/udf.py
@@ -0,0 +1,19 @@
+import pandas as pd
+
+from sedona.sql.types import GeometryType
+from sedona.utils import geometry_serde
+from pyspark.sql.udf import UserDefinedFunction
+
+
+SEDONA_SCALAR_EVAL_TYPE = 5200
+
+
+def sedona_vectorized_udf(fn):
+    def apply(series: pd.Series) -> pd.Series:
+        geo_series = series.apply(lambda x: fn(geometry_serde.deserialize(x)[0]))
+
+        return geo_series.apply(lambda x: geometry_serde.serialize(x))
+
+    return UserDefinedFunction(
+        apply, GeometryType(), "SedonaPandasArrowUDF", evalType=SEDONA_SCALAR_EVAL_TYPE
+    )
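
The `deserialize(x)[0]` above is because geometry_serde.deserialize returns a pair: the shapely geometry plus an offset (bytes consumed). A quick round-trip sketch:

    from shapely.geometry import Point
    from sedona.utils import geometry_serde

    buf = geometry_serde.serialize(Point(21.0, 52.0))  # serialized bytes
    geom, offset = geometry_serde.deserialize(buf)     # geometry + bytes consumed
    assert geom.equals(Point(21.0, 52.0))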
diff --git a/python/sedona/utils/geoarrow.py b/python/sedona/utils/geoarrow.py
index b4a539dfa4..b0f708af9c 100644
--- a/python/sedona/utils/geoarrow.py
+++ b/python/sedona/utils/geoarrow.py
@@ -323,12 +323,11 @@ def create_spatial_dataframe(spark: SparkSession, gdf: gpd.GeoDataFrame) -> Data
     step = spark._jconf.arrowMaxRecordsPerBatch()
     step = step if step > 0 else len(gdf)
     pdf_slices = (gdf.iloc[start : start + step] for start in range(0, len(gdf), step))
-    spark_types = [_deduplicate_field_names(f.dataType) for f in schema.fields]
 
     arrow_data = [
         [
             (c, to_arrow_type(t) if t is not None else None, t)
-            for (_, c), t in zip(pdf_slice.items(), spark_types)
+            for (_, c), t in zip(pdf_slice.items(), schema.fields)
         ]
         for pdf_slice in pdf_slices
     ]
diff --git a/python/tests/utils/test_pandas_arrow_udf.py b/python/tests/utils/test_pandas_arrow_udf.py
new file mode 100644
index 0000000000..7d59a93283
--- /dev/null
+++ b/python/tests/utils/test_pandas_arrow_udf.py
@@ -0,0 +1,46 @@
+from sedona.sql.types import GeometryType
+from sedona.sql.udf import sedona_vectorized_udf
+from tests import chicago_crimes_input_location
+from tests.test_base import TestBase
+import pyspark.sql.functions as f
+import shapely.geometry.base as b
+from time import time
+
+
+def non_vectorized_buffer_udf(geom: b.BaseGeometry) -> b.BaseGeometry:
+    return geom.buffer(0.001)
+
+
+@sedona_vectorized_udf
+def vectorized_buffer(geom: b.BaseGeometry) -> b.BaseGeometry:
+    return geom.buffer(0.001)
+
+
+buffer_distanced_udf = f.udf(non_vectorized_buffer_udf, GeometryType())
+
+
+class TestSedonaArrowUDF(TestBase):
+
+    def test_pandas_arrow_udf(self):
+        df = (
+            self.spark.read.option("header", "true")
+            .format("csv")
+            .load(chicago_crimes_input_location)
+            .selectExpr("ST_Point(y, x) AS geom")
+        )
+
+        vectorized_times = []
+        non_vectorized_times = []
+
+        for i in range(10):
+            start = time()
+            df = df.withColumn("buffer", vectorized_buffer(f.col("geom")))
+            df.count()
+            vectorized_times.append(time() - start)
+            start = time()
+            df = df.withColumn("buffer", buffer_distanced_udf(f.col("geom")))
+            df.count()
+            non_vectorized_times.append(time() - start)
+
+        for v, nv in zip(vectorized_times, non_vectorized_times):
+            assert v < nv, "Vectorized UDF is slower than non-vectorized UDF"
diff --git a/spark/common/src/main/scala/org/apache/sedona/sql/RasterRegistrator.scala b/spark/common/src/main/scala/org/apache/sedona/sql/RasterRegistrator.scala
index ee7aa8b0be..bcacb2ab29 100644
--- a/spark/common/src/main/scala/org/apache/sedona/sql/RasterRegistrator.scala
+++ b/spark/common/src/main/scala/org/apache/sedona/sql/RasterRegistrator.scala
@@ -22,6 +22,7 @@ import org.apache.sedona.sql.UDF.RasterUdafCatalog
 import org.apache.sedona.sql.utils.GeoToolsCoverageAvailability.{gridClassName, isGeoToolsAvailable}
 import org.apache.spark.sql.catalyst.FunctionIdentifier
 import org.apache.spark.sql.sedona_sql.UDT.RasterUdtRegistratorWrapper
+import org.apache.spark.sql.sedona_sql.strategies.{ExtractSedonaUDF, SedonaArrowStrategy}
 import org.apache.spark.sql.{SparkSession, functions}
 import org.slf4j.{Logger, LoggerFactory}
 
@@ -29,6 +30,12 @@ object RasterRegistrator {
   val logger: Logger = LoggerFactory.getLogger(getClass)
 
   def registerAll(sparkSession: SparkSession): Unit = {
+
+    sparkSession.experimental.extraStrategies =
+      sparkSession.experimental.extraStrategies :+ new SedonaArrowStrategy()
+    sparkSession.experimental.extraOptimizations =
+      sparkSession.experimental.extraOptimizations :+ ExtractSedonaUDF
+
     if (isGeoToolsAvailable) {
       RasterUdtRegistratorWrapper.registerAll(gridClassName)
       sparkSession.udf.register(
diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategies/ExtractSedonaUDF.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategies/ExtractSedonaUDF.scala
new file mode 100644
index 0000000000..be34fa5fcc
--- /dev/null
+++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategies/ExtractSedonaUDF.scala
@@ -0,0 +1,165 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.spark.sql.sedona_sql.strategies
+
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, ExpressionSet, PythonUDF}
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Subquery}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.trees.TreePattern.PYTHON_UDF
+
+import scala.collection.mutable
+
+object ExtractSedonaUDF extends Rule[LogicalPlan] {
+
+  private def hasScalarPythonUDF(e: Expression): Boolean = {
+    e.exists(PythonUDF.isScalarPythonUDF)
+  }
+
+  @scala.annotation.tailrec
+  private def canEvaluateInPython(e: PythonUDF): Boolean = {
+    e.children match {
+      case Seq(u: PythonUDF) => e.evalType == u.evalType && canEvaluateInPython(u)
+      case children => !children.exists(hasScalarPythonUDF)
+    }
+  }
+
+  def isScalarPythonUDF(e: Expression): Boolean = {
+    e.isInstanceOf[PythonUDF] && e
+      .asInstanceOf[PythonUDF]
+      .evalType == PythonEvalType.SQL_SCALAR_SEDONA_UDF
+  }
+
+  private def collectEvaluableUDFsFromExpressions(
+      expressions: Seq[Expression]): Seq[PythonUDF] = {
+
+    var firstVisitedScalarUDFEvalType: Option[Int] = None
+
+    def canChainUDF(evalType: Int): Boolean = {
+      evalType == firstVisitedScalarUDFEvalType.get
+    }
+
+    def collectEvaluableUDFs(expr: Expression): Seq[PythonUDF] = expr match {
+      case udf: PythonUDF
+          if isScalarPythonUDF(udf) && canEvaluateInPython(udf)
+            && firstVisitedScalarUDFEvalType.isEmpty =>
+        firstVisitedScalarUDFEvalType = Some(udf.evalType)
+        Seq(udf)
+      case udf: PythonUDF
+          if isScalarPythonUDF(udf) && canEvaluateInPython(udf)
+            && canChainUDF(udf.evalType) =>
+        Seq(udf)
+      case e => e.children.flatMap(collectEvaluableUDFs)
+    }
+
+    expressions.flatMap(collectEvaluableUDFs)
+  }
+
+  def apply(plan: LogicalPlan): LogicalPlan = plan match {
+    case s: Subquery if s.correlated => plan
+
+    case _ =>
+      plan.transformUpWithPruning(_.containsPattern(PYTHON_UDF)) {
+        case p: SedonaArrowEvalPython => p
+
+        case plan: LogicalPlan => extract(plan)
+      }
+  }
+
+  private def canonicalizeDeterministic(u: PythonUDF) = {
+    if (u.deterministic) {
+      u.canonicalized.asInstanceOf[PythonUDF]
+    } else {
+      u
+    }
+  }
+
+  private def extract(plan: LogicalPlan): LogicalPlan = {
+    val udfs = ExpressionSet(collectEvaluableUDFsFromExpressions(plan.expressions))
+      .filter(udf => udf.references.subsetOf(plan.inputSet))
+      .toSeq
+      .asInstanceOf[Seq[PythonUDF]]
+
+    udfs match {
+      case Seq() => plan
+      case _ => resolveUDFs(plan, udfs)
+    }
+  }
+
+  def resolveUDFs(plan: LogicalPlan, udfs: Seq[PythonUDF]): LogicalPlan = {
+    val attributeMap = mutable.HashMap[PythonUDF, Expression]()
+
+    val newChildren = adjustAttributeMap(plan, udfs, attributeMap)
+
+    udfs.map(canonicalizeDeterministic).filterNot(attributeMap.contains).foreach { udf =>
+      throw new IllegalStateException(
+        s"Invalid PythonUDF $udf, requires attributes from more than one 
child.")
+    }
+
+    val rewritten = plan.withNewChildren(newChildren).transformExpressions { case p: PythonUDF =>
+      attributeMap.getOrElse(canonicalizeDeterministic(p), p)
+    }
+
+    val newPlan = extract(rewritten)
+    if (newPlan.output != plan.output) {
+      Project(plan.output, newPlan)
+    } else {
+      newPlan
+    }
+  }
+
+  def adjustAttributeMap(
+      plan: LogicalPlan,
+      udfs: Seq[PythonUDF],
+      attributeMap: mutable.HashMap[PythonUDF, Expression]): Seq[LogicalPlan] = {
+    plan.children.map { child =>
+      val validUdfs = udfs.filter { udf =>
+        udf.references.subsetOf(child.outputSet)
+      }
+
+      if (validUdfs.nonEmpty) {
+        require(
+          validUdfs.forall(isScalarPythonUDF),
+          "Can only extract scalar vectorized udf or sql batch udf")
+
+        val resultAttrs = validUdfs.zipWithIndex.map { case (u, i) =>
+          AttributeReference(s"pythonUDF$i", u.dataType)()
+        }
+
+        val evalTypes = validUdfs.map(_.evalType).toSet
+        if (evalTypes.size != 1) {
+          throw new IllegalStateException(
+            "Expected udfs have the same evalType but got different evalTypes: 
" +
+              evalTypes.mkString(","))
+        }
+        val evalType = evalTypes.head
+        val evaluation = evalType match {
+          case PythonEvalType.SQL_SCALAR_SEDONA_UDF =>
+            SedonaArrowEvalPython(validUdfs, resultAttrs, child, evalType)
+          case _ =>
+            throw new IllegalStateException("Unexpected UDF evalType")
+        }
+
+        attributeMap ++= validUdfs.map(canonicalizeDeterministic).zip(resultAttrs)
+        evaluation
+      } else {
+        child
+      }
+    }
+  }
+}
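
Roughly, the rule stages eligible Sedona UDFs into a dedicated eval node so the strategy below can plan them; an illustrative before/after (hand-drawn sketch, not actual explain output):

    Project [udf(geom) AS buffer]          Project [pythonUDF0 AS buffer]
    +- Relation [geom]              ==>    +- SedonaArrowEvalPython [udf(geom)], [pythonUDF0]
                                              +- Relation [geom]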
diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategies/PythonEvalType.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategies/PythonEvalType.scala
new file mode 100644
index 0000000000..0a8904edb4
--- /dev/null
+++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategies/PythonEvalType.scala
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.spark.sql.sedona_sql.strategies
+
+object PythonEvalType {
+  val SQL_SCALAR_SEDONA_UDF = 5200
+  val SEDONA_UDF_TYPE_CONSTANT = 5000
+
+  def toString(pythonEvalType: Int): String = pythonEvalType match {
+    case SQL_SCALAR_SEDONA_UDF => "SQL_SCALAR_SEDONA_UDF"
+  }
+}
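
A note on the constants: SedonaArrowEvalPythonExec (below) hands ArrowPythonRunner `evalType - SEDONA_UDF_TYPE_CONSTANT`, so 5200 maps back to 200, PySpark's SQL_SCALAR_PANDAS_UDF eval type, and the existing Arrow runner can serialize batches for Sedona UDFs unchanged. A sketch of the arithmetic (constants mirrored from this patch):

    SEDONA_UDF_TYPE_CONSTANT = 5000  # Scala constant above
    SQL_SCALAR_SEDONA_UDF = 5200     # SEDONA_SCALAR_EVAL_TYPE in python/sedona/sql/udf.py

    # PySpark's scalar pandas UDF eval type is 200 (SQL_SCALAR_PANDAS_UDF).
    assert SQL_SCALAR_SEDONA_UDF - SEDONA_UDF_TYPE_CONSTANT == 200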
diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategies/SedonaArrowEvalPython.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategies/SedonaArrowEvalPython.scala
new file mode 100644
index 0000000000..78e00871ad
--- /dev/null
+++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategies/SedonaArrowEvalPython.scala
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.spark.sql.sedona_sql.strategies
+
+import org.apache.spark.sql.catalyst.expressions.{Attribute, PythonUDF}
+import org.apache.spark.sql.catalyst.plans.logical.{BaseEvalPython, LogicalPlan}
+
+case class SedonaArrowEvalPython(
+    udfs: Seq[PythonUDF],
+    resultAttrs: Seq[Attribute],
+    child: LogicalPlan,
+    evalType: Int)
+    extends BaseEvalPython {
+  override protected def withNewChildInternal(newChild: LogicalPlan): SedonaArrowEvalPython =
+    copy(child = newChild)
+}
diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategies/SedonaArrowStrategy.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategies/SedonaArrowStrategy.scala
new file mode 100644
index 0000000000..f5a0d1c95f
--- /dev/null
+++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategies/SedonaArrowStrategy.scala
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.spark.sql.sedona_sql.strategies
+
+import org.apache.spark.api.python.ChainedPythonFunctions
+import org.apache.spark.{JobArtifactSet, TaskContext}
+import org.apache.spark.sql.Strategy
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Attribute, PythonUDF}
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.python.{ArrowPythonRunner, BatchIterator, EvalPythonExec, PythonSQLMetrics}
+import org.apache.spark.sql.types.StructType
+
+import scala.jdk.CollectionConverters.asScalaIteratorConverter
+
+class SedonaArrowStrategy extends Strategy {
+  override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
+    case SedonaArrowEvalPython(udfs, output, child, evalType) =>
+      SedonaArrowEvalPythonExec(udfs, output, planLater(child), evalType) :: Nil
+    case _ => Nil
+  }
+}
+
+case class SedonaArrowEvalPythonExec(
+    udfs: Seq[PythonUDF],
+    resultAttrs: Seq[Attribute],
+    child: SparkPlan,
+    evalType: Int)
+    extends EvalPythonExec
+    with PythonSQLMetrics {
+
+  private val batchSize = conf.arrowMaxRecordsPerBatch
+  private val sessionLocalTimeZone = conf.sessionLocalTimeZone
+  private val largeVarTypes = conf.arrowUseLargeVarTypes
+  private val pythonRunnerConf = ArrowPythonRunner.getPythonRunnerConfMap(conf)
+  private[this] val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid)
+
+  protected override def evaluate(
+      funcs: Seq[ChainedPythonFunctions],
+      argOffsets: Array[Array[Int]],
+      iter: Iterator[InternalRow],
+      schema: StructType,
+      context: TaskContext): Iterator[InternalRow] = {
+
+    val batchIter = if (batchSize > 0) new BatchIterator(iter, batchSize) else Iterator(iter)
+
+    val columnarBatchIter = new ArrowPythonRunner(
+      funcs,
+      evalType - PythonEvalType.SEDONA_UDF_TYPE_CONSTANT,
+      argOffsets,
+      schema,
+      sessionLocalTimeZone,
+      largeVarTypes,
+      pythonRunnerConf,
+      pythonMetrics,
+      jobArtifactUUID).compute(batchIter, context.partitionId(), context)
+
+    columnarBatchIter.flatMap { batch =>
+      batch.rowIterator.asScala
+    }
+  }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
+    copy(child = newChild)
+}
diff --git a/spark/common/src/test/scala/org/apache/spark/sql/sedona_sql/strategies/StrategySuite.scala b/spark/common/src/test/scala/org/apache/spark/sql/sedona_sql/strategies/StrategySuite.scala
new file mode 100644
index 0000000000..52d8ea8bac
--- /dev/null
+++ b/spark/common/src/test/scala/org/apache/spark/sql/sedona_sql/strategies/StrategySuite.scala
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.spark.sql.sedona_sql.strategies
+
+import org.apache.sedona.spark.SedonaContext
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.functions.col
+import org.apache.spark.sql.sedona_sql.strategies.ScalarUDF.geoPandasScalaFunction
+import org.locationtech.jts.io.WKTReader
+import org.scalatest.funsuite.AnyFunSuite
+import org.scalatest.matchers.should.Matchers
+
+class StrategySuite extends AnyFunSuite with Matchers {
+  val wktReader = new WKTReader()
+
+  val spark: SparkSession = {
+    val builder = SedonaContext
+      .builder()
+      .master("local[*]")
+      .appName("sedonasqlScalaTest")
+
+    val spark = SedonaContext.create(builder.getOrCreate())
+
+    spark.sparkContext.setLogLevel("ALL")
+    spark
+  }
+
+  import spark.implicits._
+
+  test("Chained Scalar Pandas UDFs should be combined to a single physical 
node") {
+    val df = Seq(
+      (1, "value", wktReader.read("POINT(21 52)")),
+      (2, "value1", wktReader.read("POINT(20 50)")),
+      (3, "value2", wktReader.read("POINT(20 49)")),
+      (4, "value3", wktReader.read("POINT(20 48)")),
+      (5, "value4", wktReader.read("POINT(20 47)")))
+      .toDF("id", "value", "geom")
+      .withColumn("geom_buffer", geoPandasScalaFunction(col("geom")))
+
+    df.count shouldEqual 5
+
+    df.selectExpr("ST_AsText(ST_ReducePrecision(geom_buffer, 2))")
+      .as[String]
+      .collect() should contain theSameElementsAs Seq(
+      "POLYGON ((20 51, 20 53, 22 53, 22 51, 20 51))",
+      "POLYGON ((19 49, 19 51, 21 51, 21 49, 19 49))",
+      "POLYGON ((19 48, 19 50, 21 50, 21 48, 19 48))",
+      "POLYGON ((19 47, 19 49, 21 49, 21 47, 19 47))",
+      "POLYGON ((19 46, 19 48, 21 48, 21 46, 19 46))")
+  }
+}
diff --git a/spark/common/src/test/scala/org/apache/spark/sql/sedona_sql/strategies/TestScalarPandasUDF.scala b/spark/common/src/test/scala/org/apache/spark/sql/sedona_sql/strategies/TestScalarPandasUDF.scala
new file mode 100644
index 0000000000..925115b5e8
--- /dev/null
+++ b/spark/common/src/test/scala/org/apache/spark/sql/sedona_sql/strategies/TestScalarPandasUDF.scala
@@ -0,0 +1,121 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.spark.sql.sedona_sql.strategies
+
+import org.apache.spark.TestUtils
+import org.apache.spark.api.python._
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.sql.execution.python.UserDefinedPythonFunction
+import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT
+import org.apache.spark.util.Utils
+
+import java.io.File
+import java.nio.file.{Files, Paths}
+import scala.jdk.CollectionConverters.seqAsJavaListConverter
+import scala.sys.process.Process
+
+object ScalarUDF {
+
+  val pythonExec: String = {
+    val pythonExec =
+      sys.env.getOrElse("PYSPARK_DRIVER_PYTHON", sys.env.getOrElse("PYSPARK_PYTHON", "python3"))
+    if (TestUtils.testCommandAvailable(pythonExec)) {
+      pythonExec
+    } else {
+      "python"
+    }
+  }
+
+  private[spark] lazy val pythonPath = sys.env.getOrElse("PYTHONPATH", "")
+  protected lazy val sparkHome: String = {
+    sys.props.getOrElse("spark.test.home", sys.env("SPARK_HOME"))
+  }
+
+  private lazy val py4jPath =
+    Paths.get(sparkHome, "python", "lib", PythonUtils.PY4J_ZIP_NAME).toAbsolutePath
+  private[spark] lazy val pysparkPythonPath = s"$py4jPath"
+
+  private lazy val isPythonAvailable: Boolean = TestUtils.testCommandAvailable(pythonExec)
+
+  lazy val pythonVer: String = if (isPythonAvailable) {
+    Process(
+      Seq(pythonExec, "-c", "import sys; print('%d.%d' % sys.version_info[:2])"),
+      None,
+      "PYTHONPATH" -> s"$pysparkPythonPath:$pythonPath").!!.trim()
+  } else {
+    throw new RuntimeException(s"Python executable [$pythonExec] is unavailable.")
+  }
+
+  protected def withTempPath(f: File => Unit): Unit = {
+    val path = Utils.createTempDir()
+    path.delete()
+    try f(path)
+    finally Utils.deleteRecursively(path)
+  }
+
+  val pandasFunc: Array[Byte] = {
+    var binaryPandasFunc: Array[Byte] = null
+    withTempPath { path =>
+      println(path)
+      Process(
+        Seq(
+          pythonExec,
+          "-c",
+          f"""
+            |from pyspark.sql.types import IntegerType
+            |from shapely.geometry import Point
+            |from sedona.sql.types import GeometryType
+            |from pyspark.serializers import CloudPickleSerializer
+            |from sedona.utils import geometry_serde
+            |from shapely import box
+            |f = open('$path', 'wb');
+            |def w(x):
+            |    def apply_function(w):
+            |        geom, offset = geometry_serde.deserialize(w)
+            |        bounds = geom.buffer(1).bounds
+            |        x = box(*bounds)
+            |        return geometry_serde.serialize(x)
+            |    return x.apply(apply_function)
+            |f.write(CloudPickleSerializer().dumps((w, GeometryType())))
+            |""".stripMargin),
+        None,
+        "PYTHONPATH" -> s"$pysparkPythonPath:$pythonPath").!!
+      binaryPandasFunc = Files.readAllBytes(path.toPath)
+    }
+    assert(binaryPandasFunc != null)
+    binaryPandasFunc
+  }
+
+  private val workerEnv = new java.util.HashMap[String, String]()
+  workerEnv.put("PYTHONPATH", s"$pysparkPythonPath:$pythonPath")
+
+  val geoPandasScalaFunction: UserDefinedPythonFunction = UserDefinedPythonFunction(
+    name = "geospatial_udf",
+    func = SimplePythonFunction(
+      command = pandasFunc,
+      envVars = workerEnv.clone().asInstanceOf[java.util.Map[String, String]],
+      pythonIncludes = List.empty[String].asJava,
+      pythonExec = pythonExec,
+      pythonVer = pythonVer,
+      broadcastVars = List.empty[Broadcast[PythonBroadcast]].asJava,
+      accumulator = null),
+    dataType = GeometryUDT,
+    pythonEvalType = PythonEvalType.SQL_SCALAR_SEDONA_UDF,
+    udfDeterministic = true)
+}
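
For orientation, the function pickled by the inlined script has the shape a scalar pandas UDF worker expects: a pandas Series in, a pandas Series of serialized geometries out. An equivalent standalone sketch (same logic as the subprocess script):

    import pandas as pd
    from shapely import box
    from sedona.utils import geometry_serde

    def w(x: pd.Series) -> pd.Series:
        def apply_function(payload):
            # deserialize returns (geometry, bytes consumed)
            geom, offset = geometry_serde.deserialize(payload)
            # bounding box of a 1-degree buffer, re-serialized for Spark
            return geometry_serde.serialize(box(*geom.buffer(1).bounds))
        return x.apply(apply_function)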
