This is an automated email from the ASF dual-hosted git repository.

jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/sedona.git


The following commit(s) were added to refs/heads/master by this push:
     new 1798df23fa [SEDONA-721] Add Sedona vectorized udf for Python (#1859)
1798df23fa is described below

commit 1798df23fa0cbe8460979a41df09e6129a87d8a9
Author: Paweł Tokaj <[email protected]>
AuthorDate: Wed Apr 2 03:35:19 2025 +0200

    [SEDONA-721] Add Sedona vectorized udf for Python (#1859)
    
    * SEDONA-721 Add Sedona vectorized udf.
    
    * SEDONA-721 Add documentation
    
    * SEDONA-721 Add documentation
    
    * SEDONA-721 Add documentation
    
    * Update .github/workflows/java.yml
    
    Co-authored-by: Kristin Cowalcijk <[email protected]>
    
    * SEDONA-721 Apply requested changes.
    
    * SEDONA-721 Apply requested changes.
    
    * SEDONA-721 Apply requested changes.
    
    * SEDONA-721 Apply requested changes.
    
    * SEDONA-721 Apply requested changes.
    
    * SEDONA-721 Apply requested changes.
    
    * SEDONA-721 Apply requested changes.
    
    * SEDONA-721 Apply requested changes.
    
    ---------
    
    Co-authored-by: Kristin Cowalcijk <[email protected]>
---
 .github/workflows/java.yml                         |   8 +-
 docs/tutorial/sql.md                               |  67 ++++++
 python/sedona/sql/functions.py                     | 144 +++++++++++++
 python/tests/utils/test_pandas_arrow_udf.py        | 231 +++++++++++++++++++++
 .../org/apache/sedona/spark/SedonaContext.scala    |  28 ++-
 .../org/apache/sedona/sql/UDF/PythonEvalType.scala |  29 +++
 .../strategies/SedonaArrowEvalPython.scala         |  32 +++
 .../spark/sql/udf/ExtractSedonaUDFRule.scala       | 168 +++++++++++++++
 .../spark/sql/udf/SedonaArrowEvalPython.scala      |  32 +++
 .../apache/spark/sql/udf/SedonaArrowStrategy.scala |  89 ++++++++
 .../org/apache/spark/sql/udf/StrategySuite.scala   |  67 ++++++
 .../apache/spark/sql/udf/TestScalarPandasUDF.scala | 122 +++++++++++
 12 files changed, 1015 insertions(+), 2 deletions(-)

diff --git a/.github/workflows/java.yml b/.github/workflows/java.yml
index f8e4f6b204..921102cb36 100644
--- a/.github/workflows/java.yml
+++ b/.github/workflows/java.yml
@@ -97,7 +97,7 @@ jobs:
           java-version: ${{ matrix.jdk }}
       - uses: actions/setup-python@v5
         with:
-          python-version: '3.7'
+          python-version: '3.10'
       - name: Cache Maven packages
         uses: actions/cache@v3
         with:
@@ -110,6 +110,12 @@ jobs:
           SKIP_TESTS: ${{ matrix.skipTests }}
         run: |
           SPARK_COMPAT_VERSION=${SPARK_VERSION:0:3}
+
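+          # The Spark 3.5 build exercises the new vectorized Python UDF tests, which need pyspark and its Python dependencies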
+          if [ "${SPARK_VERSION}" == "3.5.0" ]; then
+              pip install pyspark==3.5.0 pandas shapely apache-sedona pyarrow
+              export SPARK_HOME=$(python -c "import pyspark; print(pyspark.__path__[0])")
+          fi
+
          mvn -q clean install -Dspark=${SPARK_COMPAT_VERSION} -Dscala=${SCALA_VERSION:0:4} -Dspark.version=${SPARK_VERSION} ${SKIP_TESTS}
       - run: mkdir staging
       - run: cp spark-shaded/target/sedona-*.jar staging
diff --git a/docs/tutorial/sql.md b/docs/tutorial/sql.md
index 821ea37d9b..b835084757 100644
--- a/docs/tutorial/sql.md
+++ b/docs/tutorial/sql.md
@@ -1195,6 +1195,73 @@ Output:
 
+------------------------------+--------+--------------------------------------------------+-----------------+
 ```
 
+## Spatial vectorized UDFs (Python only)
+
+By default, user-defined functions created in Python are not vectorized: they are called row by row, which can be slow.
+To speed them up, you can use a vectorized UDF, which is called in batch mode using Apache Arrow.
+
+To create a vectorized UDF, use the `sedona_vectorized_udf` decorator.
+Currently only scalar UDFs are supported. Vectorized UDFs are considerably faster than regular UDFs, often up to 2x.
+
+!!!note
+       When you use a geometry as an input type, annotate it with a `BaseGeometry` type (such as `Point` from shapely), or with a geopandas `GeoSeries` when you use the GEO_SERIES vectorized UDF.
+       That is how Sedona infers the type and knows whether the data should be cast.
+
+The decorator signature looks as follows:
+
+```python
+def sedona_vectorized_udf(return_type: DataType, udf_type: SedonaUDFType = SedonaUDFType.SHAPELY_SCALAR)
+```
+
+where `udf_type` is the type of the UDF; the currently supported values are:
+
+- SHAPELY_SCALAR
+- GEO_SERIES
+
+The main difference between them is the input data your function receives.
+Let's analyze the two examples below, each of which creates a buffer from a given geometry.
+
+### Shapely scalar UDF
+
+```python
+import shapely.geometry.base as b
+from sedona.sql.functions import sedona_vectorized_udf
+from sedona.sql.types import GeometryType
+
+@sedona_vectorized_udf(return_type=GeometryType())
+def vectorized_buffer(geom: b.BaseGeometry) -> b.BaseGeometry:
+    return geom.buffer(0.1)
+```
+
+### GeoSeries UDF
+
+```python
+import geopandas as gpd
+from sedona.sql.functions import sedona_vectorized_udf, SedonaUDFType
+from sedona.sql.types import GeometryType
+
+
+@sedona_vectorized_udf(udf_type=SedonaUDFType.GEO_SERIES, return_type=GeometryType())
+def vectorized_geo_series_buffer(series: gpd.GeoSeries) -> gpd.GeoSeries:
+    buffered = series.buffer(0.1)
+
+    return buffered
+```
+
+To call the UDFs you can use the following code:
+
+```python
+# Shapely scalar UDF
+df.withColumn("buffered", vectorized_buffer(df.geom)).show()
+
+# GeoSeries UDF
+df.withColumn("buffered", vectorized_geo_series_buffer(df.geom)).show()
+```
+
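+For a complete end-to-end run, here is a minimal sketch (the one-row DataFrame and WKT literal are illustrative) that applies the Shapely scalar UDF defined above:
+
+```python
+from sedona.spark import SedonaContext
+
+config = SedonaContext.builder().getOrCreate()
+sedona = SedonaContext.create(config)
+
+# One-row DataFrame with a point geometry
+df = sedona.sql("SELECT ST_GeomFromText('POINT (21 52)') AS geom")
+
+# buffered holds polygons approximating a 0.1-degree buffer around each point
+df.withColumn("buffered", vectorized_buffer(df.geom)).show(truncate=False)
+```
+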
 ## Save to permanent storage
 
 To save a Spatial DataFrame to some permanent storage such as Hive tables and HDFS, you can simply convert each geometry in the Geometry type column back to a plain String and save the plain DataFrame to wherever you want.
diff --git a/python/sedona/sql/functions.py b/python/sedona/sql/functions.py
new file mode 100644
index 0000000000..83648ff53f
--- /dev/null
+++ b/python/sedona/sql/functions.py
@@ -0,0 +1,144 @@
+#  Licensed to the Apache Software Foundation (ASF) under one
+#  or more contributor license agreements.  See the NOTICE file
+#  distributed with this work for additional information
+#  regarding copyright ownership.  The ASF licenses this file
+#  to you under the Apache License, Version 2.0 (the
+#  "License"); you may not use this file except in compliance
+#  with the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing,
+#  software distributed under the License is distributed on an
+#  "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+#  KIND, either express or implied.  See the License for the
+#  specific language governing permissions and limitations
+#  under the License.
+
+
+import inspect
+from enum import Enum
+
+import pandas as pd
+
+from sedona.sql.types import GeometryType
+from sedona.utils import geometry_serde
+from pyspark.sql.udf import UserDefinedFunction
+from pyspark.sql.types import DataType
+from shapely.geometry.base import BaseGeometry
+
+
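+# 5200 = 5000 (Sedona eval-type offset) + 200 (Spark's scalar pandas UDF eval type); see PythonEvalType.scala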
+SEDONA_SCALAR_EVAL_TYPE = 5200
+SEDONA_PANDAS_ARROW_NAME = "SedonaPandasArrowUDF"
+
+
+class SedonaUDFType(Enum):
+    SHAPELY_SCALAR = "ShapelyScalar"
+    GEO_SERIES = "GeoSeries"
+
+
+class InvalidSedonaUDFType(Exception):
+    pass
+
+
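+# Both UDF flavors currently map to the same Sedona scalar eval type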
+sedona_udf_to_eval_type = {
+    SedonaUDFType.SHAPELY_SCALAR: SEDONA_SCALAR_EVAL_TYPE,
+    SedonaUDFType.GEO_SERIES: SEDONA_SCALAR_EVAL_TYPE,
+}
+
+
+def sedona_vectorized_udf(
+    return_type: DataType, udf_type: SedonaUDFType = SedonaUDFType.SHAPELY_SCALAR
+):
+    import geopandas as gpd
+
+    def apply_fn(fn):
+        function_signature = inspect.signature(fn)
+        serialize_geom = False
+        deserialize_geom = False
+
+        if isinstance(return_type, GeometryType):
+            serialize_geom = True
+
+        if issubclass(function_signature.return_annotation, BaseGeometry):
+            serialize_geom = True
+
+        if issubclass(function_signature.return_annotation, gpd.GeoSeries):
+            serialize_geom = True
+
+        for param in function_signature.parameters.values():
+            if issubclass(param.annotation, BaseGeometry):
+                deserialize_geom = True
+
+            if issubclass(param.annotation, gpd.GeoSeries):
+                deserialize_geom = True
+
+        if udf_type == SedonaUDFType.SHAPELY_SCALAR:
+            return _apply_shapely_series_udf(
+                fn, return_type, serialize_geom, deserialize_geom
+            )
+
+        if udf_type == SedonaUDFType.GEO_SERIES:
+            return _apply_geo_series_udf(
+                fn, return_type, serialize_geom, deserialize_geom
+            )
+
+        raise InvalidSedonaUDFType(f"Invalid UDF type: {udf_type}")
+
+    return apply_fn
+
+
+def _apply_shapely_series_udf(
+    fn, return_type: DataType, serialize_geom: bool, deserialize_geom: bool
+):
+    def apply(series: pd.Series) -> pd.Series:
+        applied = series.apply(
+            lambda x: (
+                fn(geometry_serde.deserialize(x)[0]) if deserialize_geom else fn(x)
+            )
+        )
+
+        return applied.apply(
+            lambda x: geometry_serde.serialize(x) if serialize_geom else x
+        )
+
+    udf = UserDefinedFunction(
+        apply, return_type, SEDONA_PANDAS_ARROW_NAME, evalType=SEDONA_SCALAR_EVAL_TYPE
+    )
+
+    return udf
+
+
+def _apply_geo_series_udf(
+    fn, return_type: DataType, serialize_geom: bool, deserialize_geom: bool
+):
+    import geopandas as gpd
+
+    def apply(series: pd.Series) -> pd.Series:
+        series_data = series
+        if deserialize_geom:
+            series_data = gpd.GeoSeries(
+                series.apply(lambda x: geometry_serde.deserialize(x)[0])
+            )
+
+        return fn(series_data).apply(
+            lambda x: geometry_serde.serialize(x) if serialize_geom else x
+        )
+
+    return UserDefinedFunction(
+        apply, return_type, SEDONA_PANDAS_ARROW_NAME, evalType=SEDONA_SCALAR_EVAL_TYPE
+    )
+
+
+def deserialize_geometry_if_geom(data):
+    if isinstance(data, BaseGeometry):
+        return geometry_serde.deserialize(data)[0]
+
+    return data
+
+
+def serialize_to_geometry_if_geom(data, return_type: DataType):
+    if isinstance(return_type, GeometryType):
+        return geometry_serde.serialize(data)
+
+    return data
diff --git a/python/tests/utils/test_pandas_arrow_udf.py b/python/tests/utils/test_pandas_arrow_udf.py
new file mode 100644
index 0000000000..c0d723d214
--- /dev/null
+++ b/python/tests/utils/test_pandas_arrow_udf.py
@@ -0,0 +1,231 @@
+#  Licensed to the Apache Software Foundation (ASF) under one
+#  or more contributor license agreements.  See the NOTICE file
+#  distributed with this work for additional information
+#  regarding copyright ownership.  The ASF licenses this file
+#  to you under the Apache License, Version 2.0 (the
+#  "License"); you may not use this file except in compliance
+#  with the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing,
+#  software distributed under the License is distributed on an
+#  "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+#  KIND, either express or implied.  See the License for the
+#  specific language governing permissions and limitations
+#  under the License.
+
+
+from sedona.sql.types import GeometryType
+from sedona.sql.functions import sedona_vectorized_udf, SedonaUDFType
+from tests import chicago_crimes_input_location
+from tests.test_base import TestBase
+import pyspark.sql.functions as f
+import shapely.geometry.base as b
+import geopandas as gpd
+import pytest
+import pyspark
+import pandas as pd
+from pyspark.sql.functions import pandas_udf
+from pyspark.sql.types import IntegerType, FloatType
+from shapely.geometry import Point
+from shapely.wkt import loads
+
+
+def non_vectorized_buffer_udf(geom: b.BaseGeometry) -> b.BaseGeometry:
+    return geom.buffer(0.1)
+
+
+@sedona_vectorized_udf(return_type=GeometryType())
+def vectorized_buffer_udf(geom: b.BaseGeometry) -> b.BaseGeometry:
+    return geom.buffer(0.1)
+
+
+@sedona_vectorized_udf(return_type=FloatType())
+def vectorized_geom_to_numeric_udf(geom: b.BaseGeometry) -> float:
+    return geom.area
+
+
+@sedona_vectorized_udf(return_type=FloatType())
+def vectorized_geom_to_numeric_udf_child_geom(geom: Point) -> float:
+    return geom.x
+
+
+@sedona_vectorized_udf(return_type=GeometryType())
+def vectorized_numeric_to_geom(x: float) -> b.BaseGeometry:
+    return Point(x, x)
+
+
+@sedona_vectorized_udf(udf_type=SedonaUDFType.GEO_SERIES, return_type=FloatType())
+def vectorized_series_to_numeric_udf(series: gpd.GeoSeries) -> pd.Series:
+    buffered = series.x
+
+    return buffered
+
+
+@sedona_vectorized_udf(udf_type=SedonaUDFType.GEO_SERIES, return_type=GeometryType())
+def vectorized_series_string_to_geom(x: pd.Series) -> b.BaseGeometry:
+    return x.apply(lambda x: loads(str(x)))
+
+
+@sedona_vectorized_udf(udf_type=SedonaUDFType.GEO_SERIES, return_type=GeometryType())
+def vectorized_series_string_to_geom_2(x: pd.Series):
+    return x.apply(lambda x: loads(str(x)))
+
+
+@sedona_vectorized_udf(udf_type=SedonaUDFType.GEO_SERIES, return_type=GeometryType())
+def vectorized_series_buffer_udf(series: gpd.GeoSeries) -> gpd.GeoSeries:
+    buffered = series.buffer(0.1)
+
+    return buffered
+
+
+@pandas_udf(IntegerType())
+def squared_udf(s: pd.Series) -> pd.Series:
+    return s**2  # Perform vectorized operation
+
+
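+# Plain (row-by-row) Spark UDF wrapping the same buffer function, as a non-vectorized counterpart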
+buffer_distanced_udf = f.udf(non_vectorized_buffer_udf, GeometryType())
+
+
+class TestSedonaArrowUDF(TestBase):
+
+    def get_area(self, df, udf_fn):
+        return (
+            df.select(udf_fn(f.col("geom")).alias("buffer"))
+            .selectExpr("SUM(ST_Area(buffer))")
+            .collect()[0][0]
+        )
+
+    @pytest.mark.skipif(
+        pyspark.__version__ < "3.5", reason="requires Spark 3.5 or higher"
+    )
+    def test_pandas_arrow_udf(self):
+        df = (
+            self.spark.read.option("header", "true")
+            .format("csv")
+            .load(chicago_crimes_input_location)
+            .selectExpr("ST_Point(x, y) AS geom")
+        )
+
+        area1 = self.get_area(df, vectorized_buffer_udf)
+        assert area1 > 478
+
+    @pytest.mark.skipif(
+        pyspark.__version__ < "3.5", reason="requires Spark 3.5 or higher"
+    )
+    def test_pandas_udf_shapely_geometry_and_numeric(self):
+        df = (
+            self.spark.read.option("header", "true")
+            .format("csv")
+            .load(chicago_crimes_input_location)
+            .selectExpr("ST_Point(x, y) AS geom", "x")
+            .select(
+                vectorized_geom_to_numeric_udf(f.col("geom")).alias("area"),
+                vectorized_geom_to_numeric_udf_child_geom(f.col("geom")).alias(
+                    "x_coordinate"
+                ),
+                vectorized_numeric_to_geom(f.col("x").cast("float")).alias(
+                    "geom_second"
+                ),
+            )
+        )
+
+        assert df.select(f.sum("area")).collect()[0][0] == 0.0
+        assert -1339276 > df.select(f.sum("x_coordinate")).collect()[0][0] > -1339277
+        assert (
+            -1339276
+            > df.selectExpr("ST_X(geom_second) AS x_coordinate")
+            .select(f.sum("x_coordinate"))
+            .collect()[0][0]
+            > -1339277
+        )
+
+    @pytest.mark.skipif(
+        pyspark.__version__ < "3.5", reason="requires Spark 3.5 or higher"
+    )
+    def test_pandas_udf_geoseries_geometry_and_numeric(self):
+        df = (
+            self.spark.read.option("header", "true")
+            .format("csv")
+            .load(chicago_crimes_input_location)
+            .selectExpr(
+                "ST_Point(x, y) AS geom",
+                "CONCAT('POINT(', x, ' ', y, ')') AS wkt",
+            )
+            .select(
+                vectorized_series_to_numeric_udf(f.col("geom")).alias("x_coordinate"),
+                vectorized_series_string_to_geom(f.col("wkt")).alias("geom"),
+                vectorized_series_string_to_geom_2(f.col("wkt")).alias("geom_2"),
+            )
+        )
+
+        assert -1339276 > df.select(f.sum("x_coordinate")).collect()[0][0] > -1339277
+        assert (
+            -1339276
+            > df.selectExpr("ST_X(geom) AS x_coordinate")
+            .select(f.sum("x_coordinate"))
+            .collect()[0][0]
+            > -1339277
+        )
+        assert (
+            -1339276
+            > df.selectExpr("ST_X(geom_2) AS x_coordinate")
+            .select(f.sum("x_coordinate"))
+            .collect()[0][0]
+            > -1339277
+        )
+
+    @pytest.mark.skipif(
+        pyspark.__version__ < "3.5", reason="requires Spark 3.5 or higher"
+    )
+    def test_pandas_udf_numeric_to_geometry(self):
+        df = (
+            self.spark.read.option("header", "true")
+            .format("csv")
+            .load(chicago_crimes_input_location)
+            .selectExpr("ST_Point(y, x) AS geom")
+        )
+
+        area1 = self.get_area(df, vectorized_buffer_udf)
+        assert area1 > 478
+
+    @pytest.mark.skipif(
+        pyspark.__version__ < "3.5", reason="requires Spark 3.5 or higher"
+    )
+    def test_pandas_udf_numeric_and_numeric_to_geometry(self):
+        df = (
+            self.spark.read.option("header", "true")
+            .format("csv")
+            .load(chicago_crimes_input_location)
+            .selectExpr("ST_Point(y, x) AS geom")
+        )
+
+        area1 = self.get_area(df, vectorized_buffer_udf)
+        assert area1 > 478
+
+    @pytest.mark.skipif(
+        pyspark.__version__ < "3.5", reason="requires Spark 3.5 or higher"
+    )
+    def test_geo_series_udf(self):
+        df = (
+            self.spark.read.option("header", "true")
+            .format("csv")
+            .load(chicago_crimes_input_location)
+            .selectExpr("ST_Point(y, x) AS geom")
+        )
+
+        area = self.get_area(df, vectorized_series_buffer_udf)
+
+        assert area > 478
+
+    def test_pandas_arrow_udf_compatibility(self):
+        df = (
+            self.spark.read.option("header", "true")
+            .format("csv")
+            .load(chicago_crimes_input_location)
+            .selectExpr("CAST(x AS INT) AS x")
+        )
+
+        sum_value = df.select(f.sum(squared_udf(f.col("x")))).collect()[0][0]
+        assert sum_value == 115578630
diff --git a/spark/common/src/main/scala/org/apache/sedona/spark/SedonaContext.scala b/spark/common/src/main/scala/org/apache/sedona/spark/SedonaContext.scala
index 7cfb8670be..d38ad5e1b6 100644
--- a/spark/common/src/main/scala/org/apache/sedona/spark/SedonaContext.scala
+++ b/spark/common/src/main/scala/org/apache/sedona/spark/SedonaContext.scala
@@ -24,10 +24,12 @@ import org.apache.sedona.sql.RasterRegistrator
 import org.apache.sedona.sql.UDF.Catalog
 import org.apache.sedona.sql.UDT.UdtRegistrator
 import org.apache.spark.serializer.KryoSerializer
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.sedona_sql.optimization._
 import org.apache.spark.sql.sedona_sql.strategy.join.JoinQueryDetector
 import org.apache.spark.sql.sedona_sql.strategy.physical.function.EvalPhysicalFunctionStrategy
-import org.apache.spark.sql.{SQLContext, SparkSession}
+import org.apache.spark.sql.{SQLContext, SparkSession, Strategy}
 
 import scala.annotation.StaticAnnotation
 import scala.util.Try
@@ -50,6 +52,7 @@ object SedonaContext {
 
   /**
    * This is the entry point of the entire Sedona system
+   *
    * @param sparkSession
    * @return
    */
@@ -64,6 +67,28 @@ object SedonaContext {
      sparkSession.experimental.extraStrategies ++= Seq(new JoinQueryDetector(sparkSession))
     }
 
+    val sedonaArrowStrategy = Try(
+      Class
+        .forName("org.apache.spark.sql.udf.SedonaArrowStrategy")
+        .getDeclaredConstructor()
+        .newInstance()
+        .asInstanceOf[Strategy])
+
+    val extractSedonaUDFRule =
+      Try(
+        Class
+          .forName("org.apache.spark.sql.udf.ExtractSedonaUDFRule")
+          .getDeclaredConstructor()
+          .newInstance()
+          .asInstanceOf[Rule[LogicalPlan]])
+
+    if (sedonaArrowStrategy.isSuccess && extractSedonaUDFRule.isSuccess) {
+      sparkSession.experimental.extraStrategies =
+        sparkSession.experimental.extraStrategies :+ sedonaArrowStrategy.get
+      sparkSession.experimental.extraOptimizations =
+        sparkSession.experimental.extraOptimizations :+ extractSedonaUDFRule.get
+    }
+
     customOptimizationsWithSession(sparkSession).foreach { opt =>
       if (!sparkSession.experimental.extraOptimizations.exists {
           case _: opt.type => true
@@ -95,6 +120,7 @@ object SedonaContext {
   * This method adds the basic Sedona configurations to the SparkSession Usually the user does
   * not need to call this method directly This is only needed when the user needs to manually
   * configure Sedona
+   *
    * @return
    */
   def builder(): SparkSession.Builder = {
diff --git a/spark/common/src/main/scala/org/apache/sedona/sql/UDF/PythonEvalType.scala b/spark/common/src/main/scala/org/apache/sedona/sql/UDF/PythonEvalType.scala
new file mode 100644
index 0000000000..aece26267d
--- /dev/null
+++ b/spark/common/src/main/scala/org/apache/sedona/sql/UDF/PythonEvalType.scala
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sedona.sql.UDF
+
+// We use the constant 5000 as an offset for Sedona UDFs; 200 is Apache Spark's scalar UDF eval type
+object PythonEvalType {
+  val SQL_SCALAR_SEDONA_UDF = 5200
+  val SEDONA_UDF_TYPE_CONSTANT = 5000
+
+  def toString(pythonEvalType: Int): String = pythonEvalType match {
+    case SQL_SCALAR_SEDONA_UDF => "SQL_SCALAR_GEO_UDF"
+  }
+}
diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategies/SedonaArrowEvalPython.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategies/SedonaArrowEvalPython.scala
new file mode 100644
index 0000000000..78e00871ad
--- /dev/null
+++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategies/SedonaArrowEvalPython.scala
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.spark.sql.sedona_sql.strategies
+
+import org.apache.spark.sql.catalyst.expressions.{Attribute, PythonUDF}
+import org.apache.spark.sql.catalyst.plans.logical.{BaseEvalPython, LogicalPlan}
+
+case class SedonaArrowEvalPython(
+    udfs: Seq[PythonUDF],
+    resultAttrs: Seq[Attribute],
+    child: LogicalPlan,
+    evalType: Int)
+    extends BaseEvalPython {
+  override protected def withNewChildInternal(newChild: LogicalPlan): SedonaArrowEvalPython =
+    copy(child = newChild)
+}
diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/udf/ExtractSedonaUDFRule.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/udf/ExtractSedonaUDFRule.scala
new file mode 100644
index 0000000000..03e10a1602
--- /dev/null
+++ b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/udf/ExtractSedonaUDFRule.scala
@@ -0,0 +1,168 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.spark.sql.udf
+
+import org.apache.sedona.sql.UDF.PythonEvalType
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, ExpressionSet, PythonUDF}
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Subquery}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.trees.TreePattern.PYTHON_UDF
+
+import scala.collection.mutable
+
+// This rule extracts scalar Sedona Python UDFs; Apache Spark currently has an
+// assert on types which blocks using vectorized UDFs with the geometry type
+class ExtractSedonaUDFRule extends Rule[LogicalPlan] {
+
+  private def hasScalarPythonUDF(e: Expression): Boolean = {
+    e.exists(PythonUDF.isScalarPythonUDF)
+  }
+
+  @scala.annotation.tailrec
+  private def canEvaluateInPython(e: PythonUDF): Boolean = {
+    e.children match {
+      case Seq(u: PythonUDF) => e.evalType == u.evalType && canEvaluateInPython(u)
+      case children => !children.exists(hasScalarPythonUDF)
+    }
+  }
+
+  def isScalarPythonUDF(e: Expression): Boolean = {
+    e.isInstanceOf[PythonUDF] && e
+      .asInstanceOf[PythonUDF]
+      .evalType == PythonEvalType.SQL_SCALAR_SEDONA_UDF
+  }
+
+  private def collectEvaluableUDFsFromExpressions(
+      expressions: Seq[Expression]): Seq[PythonUDF] = {
+
+    var firstVisitedScalarUDFEvalType: Option[Int] = None
+
+    def canChainUDF(evalType: Int): Boolean = {
+      evalType == firstVisitedScalarUDFEvalType.get
+    }
+
+    def collectEvaluableUDFs(expr: Expression): Seq[PythonUDF] = expr match {
+      case udf: PythonUDF
+          if isScalarPythonUDF(udf) && canEvaluateInPython(udf)
+            && firstVisitedScalarUDFEvalType.isEmpty =>
+        firstVisitedScalarUDFEvalType = Some(udf.evalType)
+        Seq(udf)
+      case udf: PythonUDF
+          if isScalarPythonUDF(udf) && canEvaluateInPython(udf)
+            && canChainUDF(udf.evalType) =>
+        Seq(udf)
+      case e => e.children.flatMap(collectEvaluableUDFs)
+    }
+
+    expressions.flatMap(collectEvaluableUDFs)
+  }
+
+  def apply(plan: LogicalPlan): LogicalPlan = plan match {
+    case s: Subquery if s.correlated => plan
+
+    case _ =>
+      plan.transformUpWithPruning(_.containsPattern(PYTHON_UDF)) {
+        case p: SedonaArrowEvalPython => p
+
+        case plan: LogicalPlan => extract(plan)
+      }
+  }
+
+  private def canonicalizeDeterministic(u: PythonUDF) = {
+    if (u.deterministic) {
+      u.canonicalized.asInstanceOf[PythonUDF]
+    } else {
+      u
+    }
+  }
+
+  private def extract(plan: LogicalPlan): LogicalPlan = {
+    val udfs = ExpressionSet(collectEvaluableUDFsFromExpressions(plan.expressions))
+      .filter(udf => udf.references.subsetOf(plan.inputSet))
+      .toSeq
+      .asInstanceOf[Seq[PythonUDF]]
+
+    udfs match {
+      case Seq() => plan
+      case _ => resolveUDFs(plan, udfs)
+    }
+  }
+
+  def resolveUDFs(plan: LogicalPlan, udfs: Seq[PythonUDF]): LogicalPlan = {
+    val attributeMap = mutable.HashMap[PythonUDF, Expression]()
+
+    val newChildren = adjustAttributeMap(plan, udfs, attributeMap)
+
+    udfs.map(canonicalizeDeterministic).filterNot(attributeMap.contains).foreach { udf =>
+      throw new IllegalStateException(
+        s"Invalid PythonUDF $udf, requires attributes from more than one 
child.")
+    }
+
+    val rewritten = plan.withNewChildren(newChildren).transformExpressions { case p: PythonUDF =>
+      attributeMap.getOrElse(canonicalizeDeterministic(p), p)
+    }
+
+    val newPlan = extract(rewritten)
+    if (newPlan.output != plan.output) {
+      Project(plan.output, newPlan)
+    } else {
+      newPlan
+    }
+  }
+
+  def adjustAttributeMap(
+      plan: LogicalPlan,
+      udfs: Seq[PythonUDF],
+      attributeMap: mutable.HashMap[PythonUDF, Expression]): Seq[LogicalPlan] = {
+    plan.children.map { child =>
+      val validUdfs = udfs.filter { udf =>
+        udf.references.subsetOf(child.outputSet)
+      }
+
+      if (validUdfs.nonEmpty) {
+        require(
+          validUdfs.forall(isScalarPythonUDF),
+          "Can only extract scalar vectorized udf or sql batch udf")
+
+        val resultAttrs = validUdfs.zipWithIndex.map { case (u, i) =>
+          AttributeReference(s"pythonUDF$i", u.dataType)()
+        }
+
+        val evalTypes = validUdfs.map(_.evalType).toSet
+        if (evalTypes.size != 1) {
+          throw new IllegalStateException(
+            "Expected udfs have the same evalType but got different evalTypes: 
" +
+              evalTypes.mkString(","))
+        }
+        val evalType = evalTypes.head
+        val evaluation = evalType match {
+          case PythonEvalType.SQL_SCALAR_SEDONA_UDF =>
+            SedonaArrowEvalPython(validUdfs, resultAttrs, child, evalType)
+          case _ =>
+            throw new IllegalStateException("Unexpected UDF evalType")
+        }
+
+        attributeMap ++= validUdfs.map(canonicalizeDeterministic).zip(resultAttrs)
+        evaluation
+      } else {
+        child
+      }
+    }
+  }
+}
diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/udf/SedonaArrowEvalPython.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/udf/SedonaArrowEvalPython.scala
new file mode 100644
index 0000000000..7600ece507
--- /dev/null
+++ b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/udf/SedonaArrowEvalPython.scala
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.spark.sql.udf
+
+import org.apache.spark.sql.catalyst.expressions.{Attribute, PythonUDF}
+import org.apache.spark.sql.catalyst.plans.logical.{BaseEvalPython, LogicalPlan}
+
+case class SedonaArrowEvalPython(
+    udfs: Seq[PythonUDF],
+    resultAttrs: Seq[Attribute],
+    child: LogicalPlan,
+    evalType: Int)
+    extends BaseEvalPython {
+  override protected def withNewChildInternal(newChild: LogicalPlan): SedonaArrowEvalPython =
+    copy(child = newChild)
+}
diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/udf/SedonaArrowStrategy.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/udf/SedonaArrowStrategy.scala
new file mode 100644
index 0000000000..a403fa6b9e
--- /dev/null
+++ b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/udf/SedonaArrowStrategy.scala
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.spark.sql.udf
+
+import org.apache.sedona.sql.UDF.PythonEvalType
+import org.apache.spark.api.python.ChainedPythonFunctions
+import org.apache.spark.{JobArtifactSet, TaskContext}
+import org.apache.spark.sql.Strategy
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Attribute, PythonUDF}
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.python.{ArrowPythonRunner, BatchIterator, EvalPythonExec, PythonSQLMetrics}
+import org.apache.spark.sql.types.StructType
+
+import scala.collection.JavaConverters.asScalaIteratorConverter
+
+// We use a custom Strategy to avoid Apache Spark's assert on types; we
+// can consider extending this to support other engines working with
+// Arrow data
+class SedonaArrowStrategy extends Strategy {
+  override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
+    case SedonaArrowEvalPython(udfs, output, child, evalType) =>
+      SedonaArrowEvalPythonExec(udfs, output, planLater(child), evalType) :: Nil
+    case _ => Nil
+  }
+}
+
+// This is a modification of Apache Spark's ArrowEvalPythonExec: we remove the check on the types to allow geometry types.
+// It is an initial version enabling vectorized UDFs for Sedona geometry types. We can consider extending this
+// to support other engines working with Arrow data.
+case class SedonaArrowEvalPythonExec(
+    udfs: Seq[PythonUDF],
+    resultAttrs: Seq[Attribute],
+    child: SparkPlan,
+    evalType: Int)
+    extends EvalPythonExec
+    with PythonSQLMetrics {
+
+  private val batchSize = conf.arrowMaxRecordsPerBatch
+  private val sessionLocalTimeZone = conf.sessionLocalTimeZone
+  private val largeVarTypes = conf.arrowUseLargeVarTypes
+  private val pythonRunnerConf = ArrowPythonRunner.getPythonRunnerConfMap(conf)
+  private[this] val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid)
+
+  protected override def evaluate(
+      funcs: Seq[ChainedPythonFunctions],
+      argOffsets: Array[Array[Int]],
+      iter: Iterator[InternalRow],
+      schema: StructType,
+      context: TaskContext): Iterator[InternalRow] = {
+
+    val batchIter = if (batchSize > 0) new BatchIterator(iter, batchSize) else Iterator(iter)
+
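+    // Map the Sedona eval type back to Spark's scalar pandas UDF eval type (e.g. 5200 - 5000 = 200)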
+    val columnarBatchIter = new ArrowPythonRunner(
+      funcs,
+      evalType - PythonEvalType.SEDONA_UDF_TYPE_CONSTANT,
+      argOffsets,
+      schema,
+      sessionLocalTimeZone,
+      largeVarTypes,
+      pythonRunnerConf,
+      pythonMetrics,
+      jobArtifactUUID).compute(batchIter, context.partitionId(), context)
+
+    columnarBatchIter.flatMap { batch =>
+      batch.rowIterator.asScala
+    }
+  }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
+    copy(child = newChild)
+}
diff --git a/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/StrategySuite.scala b/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/StrategySuite.scala
new file mode 100644
index 0000000000..adbb97819f
--- /dev/null
+++ b/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/StrategySuite.scala
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.spark.sql.udf
+
+import org.apache.sedona.spark.SedonaContext
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.functions.col
+import org.apache.spark.sql.udf.ScalarUDF.geoPandasScalaFunction
+import org.locationtech.jts.io.WKTReader
+import org.scalatest.funsuite.AnyFunSuite
+import org.scalatest.matchers.should.Matchers
+
+class StrategySuite extends AnyFunSuite with Matchers {
+  val wktReader = new WKTReader()
+
+  val spark: SparkSession = {
+    val builder = SedonaContext
+      .builder()
+      .master("local[*]")
+      .appName("sedonasqlScalaTest")
+
+    val spark = SedonaContext.create(builder.getOrCreate())
+
+    spark.sparkContext.setLogLevel("ALL")
+    spark
+  }
+
+  import spark.implicits._
+
+  test("sedona geospatial UDF") {
+    val df = Seq(
+      (1, "value", wktReader.read("POINT(21 52)")),
+      (2, "value1", wktReader.read("POINT(20 50)")),
+      (3, "value2", wktReader.read("POINT(20 49)")),
+      (4, "value3", wktReader.read("POINT(20 48)")),
+      (5, "value4", wktReader.read("POINT(20 47)")))
+      .toDF("id", "value", "geom")
+      .withColumn("geom_buffer", geoPandasScalaFunction(col("geom")))
+
+    df.count shouldEqual 5
+
+    df.selectExpr("ST_AsText(ST_ReducePrecision(geom_buffer, 2))")
+      .as[String]
+      .collect() should contain theSameElementsAs Seq(
+      "POLYGON ((20 51, 20 53, 22 53, 22 51, 20 51))",
+      "POLYGON ((19 49, 19 51, 21 51, 21 49, 19 49))",
+      "POLYGON ((19 48, 19 50, 21 50, 21 48, 19 48))",
+      "POLYGON ((19 47, 19 49, 21 49, 21 47, 19 47))",
+      "POLYGON ((19 46, 19 48, 21 48, 21 46, 19 46))")
+  }
+}
diff --git a/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/TestScalarPandasUDF.scala b/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/TestScalarPandasUDF.scala
new file mode 100644
index 0000000000..c0a2d8f260
--- /dev/null
+++ b/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/TestScalarPandasUDF.scala
@@ -0,0 +1,122 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.spark.sql.udf
+
+import org.apache.sedona.sql.UDF
+import org.apache.spark.TestUtils
+import org.apache.spark.api.python._
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.sql.execution.python.UserDefinedPythonFunction
+import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT
+import org.apache.spark.util.Utils
+
+import java.io.File
+import java.nio.file.{Files, Paths}
+import scala.sys.process.Process
+import scala.jdk.CollectionConverters._
+
+object ScalarUDF {
+
+  val pythonExec: String = {
+    val pythonExec =
+      sys.env.getOrElse("PYSPARK_DRIVER_PYTHON", 
sys.env.getOrElse("PYSPARK_PYTHON", "python3"))
+    if (TestUtils.testCommandAvailable(pythonExec)) {
+      pythonExec
+    } else {
+      "python"
+    }
+  }
+
+  private[spark] lazy val pythonPath = sys.env.getOrElse("PYTHONPATH", "")
+  protected lazy val sparkHome: String = {
+    sys.props.getOrElse("spark.test.home", sys.env("SPARK_HOME"))
+  }
+
+  private lazy val py4jPath =
+    Paths.get(sparkHome, "python", "lib", PythonUtils.PY4J_ZIP_NAME).toAbsolutePath
+  private[spark] lazy val pysparkPythonPath = s"$py4jPath"
+
+  private lazy val isPythonAvailable: Boolean = TestUtils.testCommandAvailable(pythonExec)
+
+  lazy val pythonVer: String = if (isPythonAvailable) {
+    Process(
+      Seq(pythonExec, "-c", "import sys; print('%d.%d' % sys.version_info[:2])"),
+      None,
+      "PYTHONPATH" -> s"$pysparkPythonPath:$pythonPath").!!.trim()
+  } else {
+    throw new RuntimeException(s"Python executable [$pythonExec] is unavailable.")
+  }
+
+  protected def withTempPath(f: File => Unit): Unit = {
+    val path = Utils.createTempDir()
+    path.delete()
+    try f(path)
+    finally Utils.deleteRecursively(path)
+  }
+
+  val pandasFunc: Array[Byte] = {
+    var binaryPandasFunc: Array[Byte] = null
+    withTempPath { path =>
+      println(path)
+      Process(
+        Seq(
+          pythonExec,
+          "-c",
+          f"""
+            |from pyspark.sql.types import IntegerType
+            |from shapely.geometry import Point
+            |from sedona.sql.types import GeometryType
+            |from pyspark.serializers import CloudPickleSerializer
+            |from sedona.utils import geometry_serde
+            |from shapely import box
+            |f = open('$path', 'wb');
+            |def w(x):
+            |    def apply_function(w):
+            |        geom, offset = geometry_serde.deserialize(w)
+            |        bounds = geom.buffer(1).bounds
+            |        x = box(*bounds)
+            |        return geometry_serde.serialize(x)
+            |    return x.apply(apply_function)
+            |f.write(CloudPickleSerializer().dumps((w, GeometryType())))
+            |""".stripMargin),
+        None,
+        "PYTHONPATH" -> s"$pysparkPythonPath:$pythonPath").!!
+      binaryPandasFunc = Files.readAllBytes(path.toPath)
+    }
+    assert(binaryPandasFunc != null)
+    binaryPandasFunc
+  }
+
+  private val workerEnv = new java.util.HashMap[String, String]()
+  workerEnv.put("PYTHONPATH", s"$pysparkPythonPath:$pythonPath")
+
+  val geoPandasScalaFunction: UserDefinedPythonFunction = UserDefinedPythonFunction(
+    name = "geospatial_udf",
+    func = SimplePythonFunction(
+      command = pandasFunc,
+      envVars = workerEnv.clone().asInstanceOf[java.util.Map[String, String]],
+      pythonIncludes = List.empty[String].asJava,
+      pythonExec = pythonExec,
+      pythonVer = pythonVer,
+      broadcastVars = List.empty[Broadcast[PythonBroadcast]].asJava,
+      accumulator = null),
+    dataType = GeometryUDT,
+    pythonEvalType = UDF.PythonEvalType.SQL_SCALAR_SEDONA_UDF,
+    udfDeterministic = true)
+}

