This is an automated email from the ASF dual-hosted git repository.

imbruced pushed a commit to branch SEDONA-714-add-geopandas-to-spark-arrow-conversion
in repository https://gitbox.apache.org/repos/asf/sedona.git

commit ebdee33ee138f10e90057926f9c7998f517d78ff
Author: pawelkocinski <[email protected]>
AuthorDate: Sun Feb 23 21:53:47 2025 +0100

    SEDONA-714 Add geopandas to spark arrow conversion.
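
    Example usage (a minimal sketch, assuming a running SparkSession named
    spark with Sedona registered):

        import geopandas as gpd
        from sedona.utils.geoarrow import create_spatial_dataframe

        gdf = gpd.GeoDataFrame(
            {
                "name": ["Sedona", "Apache"],
                "geometry": gpd.points_from_xy([0, 1], [0, 1]),
            }
        )
        df = create_spatial_dataframe(spark, gdf)
        df.show()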
---
 python/sedona/utils/geoarrow.py                    | 92 +++++++++++++++++++++-
 python/tests/test_base.py                          |  2 +-
 .../test_arrow_conversion_geopandas_to_sedona.py   | 79 ++++++++++++++++++++
 3 files changed, 171 insertions(+), 2 deletions(-)

diff --git a/python/sedona/utils/geoarrow.py b/python/sedona/utils/geoarrow.py
index b8ade8528b..8c730d9a39 100644
--- a/python/sedona/utils/geoarrow.py
+++ b/python/sedona/utils/geoarrow.py
@@ -19,8 +19,15 @@
 # with the ArrowStreamSerializer (instead of the ArrowCollectSerializer)
 
 
-from sedona.sql.types import GeometryType
 from sedona.sql.st_functions import ST_AsEWKB
+import geopandas as gpd
+import pyarrow as pa
+from pyspark.sql import DataFrame, SparkSession
+from pyspark.sql.pandas.serializers import ArrowStreamPandasSerializer
+from pyspark.sql.pandas.types import from_arrow_type
+from pyspark.sql.types import StructField, StructType
+
+from sedona.sql.types import GeometryType
 
 
 def dataframe_to_arrow(df, crs=None):
@@ -186,3 +196,86 @@ def unique_srid_from_ewkb(obj):
     import pyproj
 
     return pyproj.CRS(f"EPSG:{epsg_code}")
+
+
+def infer_schema(gdf: gpd.GeoDataFrame) -> StructType:
+    fields = gdf.dtypes.reset_index().values.tolist()
+    geom_fields = []
+    # Track each geometry column together with its original position so the
+    # field can be re-inserted at the same place in the final schema.
+    for index, (name, dtype) in enumerate(fields):
+        if dtype == "geometry":
+            geom_fields.append((index, name))
+
+    if not geom_fields:
+        raise ValueError("No geometry field found in the GeoDataFrame")
+
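+    # Infer Arrow types for the non-geometry columns, then map them to Spark types.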
+    pa_schema = pa.Schema.from_pandas(
+        gdf.drop([name for _, name in geom_fields], axis=1)
+    )
+    spark_schema = []
+
+    for field in pa_schema:
+        field_type = field.type
+        spark_type = from_arrow_type(field_type)
+        spark_schema.append(StructField(field.name, spark_type, True))
+
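+    # Re-insert the geometry fields at their original positions so the schema
+    # matches the column order of the GeoDataFrame.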
+    for index, geom_field in geom_fields:
+        spark_schema.insert(index, StructField(geom_field, GeometryType(), True))
+
+    return StructType(spark_schema)
+
+
+def create_spatial_dataframe(spark: SparkSession, gdf: gpd.GeoDataFrame) -> DataFrame:
+    from pyspark.sql.pandas.types import (
+        to_arrow_type,
+        _deduplicate_field_names,
+    )
+
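+    # Helpers passed to _serialize_to_jvm for materializing the serialized
+    # Arrow stream on the JVM side.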
+    def reader_func(temp_filename):
+        return spark._jvm.PythonSQLUtils.readArrowStreamFromFile(temp_filename)
+
+    def create_iter_server():
+        return spark._jvm.ArrowIteratorServer()
+
+    schema = infer_schema(gdf)
+    timezone = spark._jconf.sessionLocalTimeZone()
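+    # Slice the GeoDataFrame into batches of at most
+    # spark.sql.execution.arrow.maxRecordsPerBatch rows each.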
+    step = spark._jconf.arrowMaxRecordsPerBatch()
+    step = step if step > 0 else len(gdf)
+    pdf_slices = (
+        gdf.iloc[start : start + step] for start in range(0, len(gdf), step)
+    )
+    spark_types = [_deduplicate_field_names(f.dataType) for f in schema.fields]
+
+    arrow_data = [
+        [
+            (c, to_arrow_type(t) if t is not None else None, t)
+            for (_, c), t in zip(pdf_slice.items(), spark_types)
+        ]
+        for pdf_slice in pdf_slices
+    ]
+
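+    # Stream the Arrow batches to the JVM, mirroring the private code path
+    # PySpark uses for createDataFrame on a pandas DataFrame.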
+    safecheck = spark._jconf.arrowSafeTypeConversion()
+    ser = ArrowStreamPandasSerializer(timezone, safecheck)
+    jiter = spark._sc._serialize_to_jvm(
+        arrow_data, ser, reader_func, create_iter_server
+    )
+
+    jsparkSession = spark._jsparkSession
+    jdf = spark._jvm.PythonSQLUtils.toDataFrame(jiter, schema.json(), jsparkSession)
+
+    df = DataFrame(jdf, spark)
+
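+    # Cache the inferred schema so accessing df.schema avoids a JVM round trip.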
+    df._schema = schema
+
+    return df
diff --git a/python/tests/test_base.py b/python/tests/test_base.py
index 2769a93cdd..710b7a5564 100644
--- a/python/tests/test_base.py
+++ b/python/tests/test_base.py
@@ -16,7 +16,7 @@
 #  under the License.
 import os
 from tempfile import mkdtemp
-from typing import Iterable, Union
+from typing import Iterable
 
 
 import pyspark
diff --git a/python/tests/utils/test_arrow_conversion_geopandas_to_sedona.py b/python/tests/utils/test_arrow_conversion_geopandas_to_sedona.py
new file mode 100644
index 0000000000..e2c8344bc3
--- /dev/null
+++ b/python/tests/utils/test_arrow_conversion_geopandas_to_sedona.py
@@ -0,0 +1,79 @@
+import pytest
+
+from sedona.sql.types import GeometryType
+from sedona.utils.geoarrow import create_spatial_dataframe
+from tests.test_base import TestBase
+import geopandas as gpd
+
+
+class TestGeopandasToSedonaWithArrow(TestBase):
+
+    def test_conversion_dataframe(self):
+        gdf = gpd.GeoDataFrame(
+            {
+                "name": ["Sedona", "Apache"],
+                "geometry": gpd.points_from_xy([0, 1], [0, 1]),
+            }
+        )
+
+        df = create_spatial_dataframe(self.spark, gdf)
+
+        assert df.count() == 2
+        assert df.columns == ["name", "geometry"]
+        assert df.schema["geometry"].dataType == GeometryType()
+
+    def test_different_geometry_positions(self):
+        gdf = gpd.GeoDataFrame(
+            {
+                "geometry": gpd.points_from_xy([0, 1], [0, 1]),
+                "name": ["Sedona", "Apache"],
+            }
+        )
+
+        gdf2 = gpd.GeoDataFrame(
+            {
+                "name": ["Sedona", "Apache"],
+                "name1": ["Sedona", "Apache"],
+                "name2": ["Sedona", "Apache"],
+                "geometry": gpd.points_from_xy([0, 1], [0, 1]),
+            }
+        )
+
+        df1 = create_spatial_dataframe(self.spark, gdf)
+        df2 = create_spatial_dataframe(self.spark, gdf2)
+
+        assert df1.count() == 2
+        assert df1.columns == ["geometry", "name"]
+        assert df1.schema["geometry"].dataType == GeometryType()
+
+        assert df2.count() == 2
+        assert df2.columns == ["name", "name1", "name2", "geometry"]
+        assert df2.schema["geometry"].dataType == GeometryType()
+
+    def test_multiple_geometry_columns(self):
+        gdf = gpd.GeoDataFrame(
+            {
+                "name": ["Sedona", "Apache"],
+                "geometry": gpd.points_from_xy([0, 1], [0, 1]),
+                "geometry2": gpd.points_from_xy([0, 1], [0, 1]),
+            }
+        )
+
+        df = create_spatial_dataframe(self.spark, gdf)
+
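+        # The inferred schema preserves the original column order, including
+        # multiple geometry columns.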
+        assert df.count() == 2
+        assert df.columns == ["name", "geometry", "geometry2"]
+        assert df.schema["geometry"].dataType == GeometryType()
+        assert df.schema["geometry2"].dataType == GeometryType()
+
+    def test_missing_geometry_column(self):
+        gdf = gpd.GeoDataFrame(
+            {
+                "name": ["Sedona", "Apache"],
+            },
+        )
+
+        with pytest.raises(ValueError):
+            create_spatial_dataframe(self.spark, gdf)
