paleolimbot commented on code in PR #1825:
URL: https://github.com/apache/sedona/pull/1825#discussion_r1967308179
##########
python/sedona/utils/geoarrow.py:
##########
@@ -14,13 +14,25 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+import itertools
+from typing import List, Callable
# We may be able to achieve streaming rather than complete materialization by
using
# with the ArrowStreamSerializer (instead of the ArrowCollectSerializer)
-from sedona.sql.types import GeometryType
from sedona.sql.st_functions import ST_AsEWKB
+from pyspark.sql import SparkSession
+from pyspark.sql import DataFrame
+from pyspark.sql.types import StructType, StructField, DataType, ArrayType,
MapType
+import pyarrow as pa
Review Comment:
I am not sure what the dependency situation is like for Spark, but it may be
worth making this a lazy import (e.g., as in `dataframe_to_arrow`) so that
when `sedona/spark/__init__.py` imports from `sedona.utils.geoarrow` we
don't necessarily require pyarrow to be installed. Alternatively, we could add
pyarrow to the `apache-sedona[spark]` extras to match the runtime requirement.
##########
python/sedona/utils/geoarrow.py:
##########
@@ -186,3 +198,123 @@ def unique_srid_from_ewkb(obj):
import pyproj
return pyproj.CRS(f"EPSG:{epsg_code}")
+
+
+def _dedup_names(names: List[str]) -> List[str]:
+ if len(set(names)) == len(names):
+ return names
+ else:
+
+ def _gen_dedup(_name: str) -> Callable[[], str]:
+ _i = itertools.count()
+ return lambda: f"{_name}_{next(_i)}"
+
+ def _gen_identity(_name: str) -> Callable[[], str]:
+ return lambda: _name
+
+ gen_new_name = {
+ name: _gen_dedup(name) if len(list(group)) > 1 else
_gen_identity(name)
+ for name, group in itertools.groupby(sorted(names))
+ }
+ return [gen_new_name[name]() for name in names]
+
+
+def _deduplicate_field_names(dt: DataType) -> DataType:
Review Comment:
```suggestion
# Backport from Spark 4.0
# https://github.com/apache/spark/blob/3515b207c41d78194d11933cd04bddc21f8418dd/python/pyspark/sql/pandas/types.py#L1385
def _deduplicate_field_names(dt: DataType) -> DataType:
```
##########
python/sedona/utils/geoarrow.py:
##########
@@ -186,3 +198,123 @@ def unique_srid_from_ewkb(obj):
import pyproj
return pyproj.CRS(f"EPSG:{epsg_code}")
+
+
+def _dedup_names(names: List[str]) -> List[str]:
+ if len(set(names)) == len(names):
+ return names
+ else:
+
+ def _gen_dedup(_name: str) -> Callable[[], str]:
+ _i = itertools.count()
+ return lambda: f"{_name}_{next(_i)}"
+
+ def _gen_identity(_name: str) -> Callable[[], str]:
+ return lambda: _name
+
+ gen_new_name = {
+ name: _gen_dedup(name) if len(list(group)) > 1 else
_gen_identity(name)
+ for name, group in itertools.groupby(sorted(names))
+ }
+ return [gen_new_name[name]() for name in names]
+
+
def _deduplicate_field_names(dt: DataType) -> DataType:
    """Recursively make struct field names unique within a Spark data type.

    Backport from Spark 4.0 (``pyspark/sql/pandas/types.py``).

    Struct fields are renamed with ``_dedup_names`` and their types are
    processed recursively; array element types and map key/value types are
    processed recursively as well. Any other type is returned unchanged.

    Note: rebuilt ``StructField`` objects keep ``nullable`` but are created
    without the original field ``metadata``.

    :param dt: Spark data type to normalize.
    :return: an equivalent data type with no duplicate struct field names.
    """
    if isinstance(dt, StructType):
        unique_names = _dedup_names(dt.names)
        rebuilt_fields = [
            StructField(
                new_name,
                _deduplicate_field_names(old_field.dataType),
                nullable=old_field.nullable,
            )
            for new_name, old_field in zip(unique_names, dt.fields)
        ]
        return StructType(rebuilt_fields)

    if isinstance(dt, ArrayType):
        element_type = _deduplicate_field_names(dt.elementType)
        return ArrayType(element_type, containsNull=dt.containsNull)

    if isinstance(dt, MapType):
        key_type = _deduplicate_field_names(dt.keyType)
        value_type = _deduplicate_field_names(dt.valueType)
        return MapType(
            key_type,
            value_type,
            valueContainsNull=dt.valueContainsNull,
        )

    return dt
+
+
def infer_schema(gdf: gpd.GeoDataFrame) -> StructType:
    """Infer a Spark schema for a GeoDataFrame.

    Non-geometry columns are typed by pyarrow's pandas schema inference and
    converted with ``from_arrow_type``. Every column whose dtype compares
    equal to ``"geometry"`` becomes a ``GeometryType`` field. All fields are
    nullable and appear in the same order as the columns of ``gdf``.

    :param gdf: GeoDataFrame to describe.
    :return: a StructType with one field per column of ``gdf``, in column order.
    :raises ValueError: if ``gdf`` has no geometry column.
    """
    # Local imports: pyarrow (and the Arrow->Spark type mapping) are only
    # needed when a schema is actually inferred.
    import pyarrow as pa
    from pyspark.sql.pandas.types import from_arrow_type
    from sedona.sql.types import GeometryType

    # Absolute column position of every geometry column. Positions must count
    # *all* columns: they are used below as insertion indices in ascending
    # order, so each geometry field lands exactly where it was in ``gdf``
    # (counting only non-geometry columns reverses adjacent geometry columns).
    geom_fields = [
        (position, name)
        for position, (name, dtype) in enumerate(gdf.dtypes.items())
        if dtype == "geometry"
    ]

    if not geom_fields:
        raise ValueError("No geometry field found in the GeoDataFrame")

    # preserve_index=False: the DataFrame index is not part of the column
    # data, so it must not show up as an extra "__index_level_0__" field
    # (pyarrow adds one for any non-RangeIndex, e.g. after filtering).
    pa_schema = pa.Schema.from_pandas(
        gdf.drop([name for _, name in geom_fields], axis=1),
        preserve_index=False,
    )

    spark_schema = [
        StructField(field.name, from_arrow_type(field.type), True)
        for field in pa_schema
    ]

    for position, name in geom_fields:
        spark_schema.insert(position, StructField(name, GeometryType(), True))

    return StructType(spark_schema)
+
+
+def create_spatial_dataframe(spark: SparkSession, gdf: gpd.GeoDataFrame) ->
DataFrame:
Review Comment:
```suggestion
# Modified backport from Spark 4.0
# https://github.com/apache/spark/blob/3515b207c41d78194d11933cd04bddc21f8418dd/python/pyspark/sql/pandas/conversion.py#L632
def create_spatial_dataframe(spark: SparkSession, gdf: gpd.GeoDataFrame) -> DataFrame:
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]