This is an automated email from the ASF dual-hosted git repository. jiayu pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/sedona.git
The following commit(s) were added to refs/heads/master by this push: new 629d47958c [GH-2066] Implement spatial index sindex query() and unit tests (#2067) 629d47958c is described below commit 629d47958cadcdfa82a41d1825f2bb410de90c37 Author: Feng Zhang <f...@wherobots.com> AuthorDate: Wed Jul 9 11:48:02 2025 -0700 [GH-2066] Implement spatial index sindex query() and unit tests (#2067) * [GH-] Implement GeoSeries sindex and tests * fix test_spatial_index_with_shapely_array * fix NumPy comparisons in the test * do not test internal objects * address the copilot comments and refactor __init__ * fix spark mode count * add index building logic * add query implementation for spatial index * remove unused code * fix test * shapely only support strtree * fix results and add log_advice to query() * switch to use StructuredAdapter --- python/sedona/geopandas/base.py | 2 +- python/sedona/geopandas/geodataframe.py | 18 ++- python/sedona/geopandas/geoindex.py | 28 ---- python/sedona/geopandas/geoseries.py | 29 +++- python/sedona/geopandas/sindex.py | 225 ++++++++++++++++++++++++++++++++ python/tests/geopandas/test_sindex.py | 145 ++++++++++++++++++++ 6 files changed, 410 insertions(+), 37 deletions(-) diff --git a/python/sedona/geopandas/base.py b/python/sedona/geopandas/base.py index f99a19cff8..8bc1ad4dd3 100644 --- a/python/sedona/geopandas/base.py +++ b/python/sedona/geopandas/base.py @@ -75,7 +75,7 @@ class GeoFrame(metaclass=ABCMeta): @property @abstractmethod - def geoindex(self) -> "GeoIndex": + def sindex(self) -> "SpatialIndex": raise NotImplementedError("This method is not implemented yet.") @abstractmethod diff --git a/python/sedona/geopandas/geodataframe.py b/python/sedona/geopandas/geodataframe.py index a2f90dff4b..b627d06a7d 100644 --- a/python/sedona/geopandas/geodataframe.py +++ b/python/sedona/geopandas/geodataframe.py @@ -28,7 +28,7 @@ from pyspark.pandas.internal import InternalFrame from sedona.geopandas._typing import Label from sedona.geopandas.base import GeoFrame -from sedona.geopandas.geoindex import GeoIndex +from sedona.geopandas.sindex import SpatialIndex class GeoDataFrame(GeoFrame, pspd.DataFrame): @@ -250,9 +250,19 @@ class GeoDataFrame(GeoFrame, pspd.DataFrame): raise NotImplementedError("This method is not implemented yet.") @property - def geoindex(self) -> GeoIndex: - # Implementation of the abstract method - raise NotImplementedError("This method is not implemented yet.") + def sindex(self) -> SpatialIndex | None: + """ + Returns a spatial index for the GeoDataFrame. + The spatial index allows for efficient spatial queries. If the spatial + index cannot be created (e.g., no geometry column is present), this + property will return None. + Returns: + - SpatialIndex: The spatial index for the GeoDataFrame. + - None: If the spatial index is not supported. + """ + if "geometry" in self.columns: + return SpatialIndex(self._internal.spark_frame, column_name="geometry") + return None def copy(self, deep=False): """ diff --git a/python/sedona/geopandas/geoindex.py b/python/sedona/geopandas/geoindex.py deleted file mode 100644 index 4dbc04b742..0000000000 --- a/python/sedona/geopandas/geoindex.py +++ /dev/null @@ -1,28 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - - -class GeoIndex: - """ - A placeholder class for GeoIndex. - """ - - def __init__(self): - raise NotImplementedError("This method is not implemented yet.") - - def some_method(self): - raise NotImplementedError("This method is not implemented yet.") diff --git a/python/sedona/geopandas/geoseries.py b/python/sedona/geopandas/geoseries.py index 2e4dc6bb2a..6160438a60 100644 --- a/python/sedona/geopandas/geoseries.py +++ b/python/sedona/geopandas/geoseries.py @@ -37,7 +37,7 @@ from shapely.geometry.base import BaseGeometry from sedona.geopandas._typing import Label from sedona.geopandas.base import GeoFrame from sedona.geopandas.geodataframe import GeoDataFrame -from sedona.geopandas.geoindex import GeoIndex +from sedona.geopandas.sindex import SpatialIndex from pyspark.pandas.internal import ( SPARK_DEFAULT_INDEX_NAME, # __index_level_0__ @@ -506,9 +506,30 @@ class GeoSeries(GeoFrame, pspd.Series): return self @property - def geoindex(self) -> "GeoIndex": - # Implementation of the abstract method - raise NotImplementedError("This method is not implemented yet.") + def sindex(self) -> SpatialIndex: + """ + Returns a spatial index built from the geometries. + + Returns + ------- + SpatialIndex + The spatial index for this GeoDataFrame. + + Examples + -------- + >>> from shapely.geometry import Point + >>> from sedona.geopandas import GeoDataFrame + >>> + >>> gdf = GeoDataFrame([{"geometry": Point(1, 1), "value": 1}, + ... {"geometry": Point(2, 2), "value": 2}]) + >>> index = gdf.sindex + >>> index.size + 2 + """ + geometry_column = self.get_first_geometry_column() + if geometry_column is None: + raise ValueError("No geometry column found in GeoSeries") + return SpatialIndex(self._internal.spark_frame, column_name=geometry_column) def copy(self, deep=False): """ diff --git a/python/sedona/geopandas/sindex.py b/python/sedona/geopandas/sindex.py new file mode 100644 index 0000000000..426b2d12ab --- /dev/null +++ b/python/sedona/geopandas/sindex.py @@ -0,0 +1,225 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import numpy as np +from pyspark.pandas.utils import log_advice +from pyspark.sql import DataFrame as PySparkDataFrame + +from sedona.spark import StructuredAdapter +from sedona.spark.core.enums import IndexType + + +class SpatialIndex: + """ + A wrapper around Sedona's spatial index functionality. + """ + + def __init__(self, geometry, index_type="strtree", column_name=None): + """ + Initialize the SpatialIndex with geometry data. + + Parameters + ---------- + geometry : np.array of Shapely geometries, PySparkDataFrame column, or PySparkDataFrame + index_type : str, default "strtree" + The type of spatial index to use. + column_name : str, optional + The column name to extract geometry from if `geometry` is a PySparkDataFrame. + """ + + if isinstance(geometry, np.ndarray): + self.geometry = geometry + self.index_type = index_type + self._dataframe = None + self._is_spark = False + # Build local index for numpy array + self._build_local_index() + elif isinstance(geometry, PySparkDataFrame): + if column_name is None: + raise ValueError( + "column_name must be specified when geometry is a PySparkDataFrame" + ) + self.geometry = geometry[column_name] + self.index_type = index_type + self._dataframe = geometry + self._is_spark = True + # Build distributed spatial index + self._build_spark_index(column_name) + else: + raise TypeError( + "Invalid type for `geometry`. Expected np.array or PySparkDataFrame." + ) + + def query(self, geometry, predicate=None, sort=False): + """ + Query the spatial index for geometries that intersect the given geometry. + + Parameters + ---------- + geometry : Shapely geometry + The geometry to query against the spatial index. + predicate : str, optional + Spatial predicate to filter results (e.g., 'intersects', 'contains'). + sort : bool, optional, default False + Whether to sort the results. + + Returns + ------- + list + List of indices of matching geometries. + """ + log_advice( + "`query` returns local list of indices of matching geometries onto driver's memory. " + "It should only be used if the resulting collection is expected to be small." + ) + + if self.is_empty: + return [] + + if self._is_spark: + # For Spark-based spatial index + from sedona.spark.core.spatialOperator import RangeQuery + + # Execute the spatial range query + if predicate == "contains": + result_rdd = RangeQuery.SpatialRangeQuery( + self._indexed_rdd, geometry, True, True + ) + else: # Default to intersects + result_rdd = RangeQuery.SpatialRangeQuery( + self._indexed_rdd, geometry, False, True + ) + + results = result_rdd.collect() + return results + else: + # For local spatial index based on Shapely STRtree + if predicate == "contains": + # STRtree doesn't directly support contains predicate + # We need to filter results after querying + candidate_indices = self._index.query(geometry) + results = [ + i for i in candidate_indices if geometry.contains(self.geometry[i]) + ] + else: + # Default is intersects + results = self._index.query(geometry) + + if sort and results: + # Sort by distance to the query geometry if requested + results = sorted( + results, key=lambda i: self.geometry[i].distance(geometry) + ) + + return results + + def nearest(self, geometry, k=1, return_distance=False): + """ + Find the nearest geometry in the spatial index. + + Parameters + ---------- + geometry : Shapely geometry + The geometry to find the nearest neighbor for. + k : int, optional, default 1 + Number of nearest neighbors to find. + return_distance : bool, optional, default False + Whether to return distances along with indices. + + Returns + ------- + list or tuple + List of indices of nearest geometries, optionally with distances. + """ + # Placeholder for KNN query using Sedona + raise NotImplementedError("This method is not implemented yet.") + + def intersection(self, bounds): + """ + Find geometries that intersect the given bounding box. + + Parameters + ---------- + bounds : tuple + Bounding box as (min_x, min_y, max_x, max_y). + + Returns + ------- + list + List of indices of matching geometries. + """ + raise NotImplementedError("This method is not implemented yet.") + + @property + def size(self): + """ + Get the size of the spatial index. + + Returns + ------- + int + Number of geometries in the index. + """ + if self._is_spark: + return self._dataframe.count() + return len(self.geometry) + + @property + def is_empty(self): + """ + Check if the spatial index is empty. + + Returns + ------- + bool + True if the index is empty, False otherwise. + """ + return self.size == 0 + + def _build_spark_index(self, column_name): + """ + Build a distributed spatial index on the geometry column of the DataFrame. + + This uses Sedona's built-in indexing functionality. + """ + + # Convert index_type string to Sedona IndexType enum + index_type_map = {"strtree": IndexType.RTREE, "quadtree": IndexType.QUADTREE} + sedona_index_type = index_type_map.get(self.index_type.lower(), IndexType.RTREE) + + # Create a SpatialRDD from the DataFrame + spatial_rdd = StructuredAdapter.toSpatialRdd(self._dataframe, column_name) + + # Build spatial index + spatial_rdd.buildIndex(sedona_index_type, False) + + # Store the indexed RDD + self._indexed_rdd = spatial_rdd + + def _build_local_index(self): + """ + Build a local spatial index for numpy array of geometries. + """ + from shapely.strtree import STRtree + + if len(self.geometry) > 0: + if self.index_type.lower() == "strtree": + self._index = STRtree(self.geometry) + else: + raise ValueError( + f"Unsupported index type: {self.index_type}. Only 'strtree' is supported for local indexing." + ) diff --git a/python/tests/geopandas/test_sindex.py b/python/tests/geopandas/test_sindex.py new file mode 100644 index 0000000000..a6e76b65d0 --- /dev/null +++ b/python/tests/geopandas/test_sindex.py @@ -0,0 +1,145 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import numpy as np +from pyspark.sql.functions import expr +from shapely.geometry import Point, Polygon, LineString + +from tests.test_base import TestBase +from sedona.geopandas import GeoSeries +from sedona.geopandas.sindex import SpatialIndex + + +class TestSpatialIndex(TestBase): + """Tests for the spatial index functionality in GeoSeries.""" + + def setup_method(self): + """Set up test data.""" + # Create a GeoSeries with point geometries + self.points = GeoSeries( + [Point(0, 0), Point(1, 1), Point(2, 2), Point(3, 3), Point(4, 4)] + ) + + # Create a GeoSeries with polygon geometries + self.polygons = GeoSeries( + [ + Polygon([(0, 0), (1, 0), (1, 1), (0, 1)]), + Polygon([(1, 1), (2, 1), (2, 2), (1, 2)]), + Polygon([(2, 2), (3, 2), (3, 3), (2, 3)]), + Polygon([(3, 3), (4, 3), (4, 4), (3, 4)]), + Polygon([(4, 4), (5, 4), (5, 5), (4, 5)]), + ] + ) + + # Create a GeoSeries with line geometries + self.lines = GeoSeries( + [ + LineString([(0, 0), (1, 1)]), + LineString([(1, 1), (2, 2)]), + LineString([(2, 2), (3, 3)]), + LineString([(3, 3), (4, 4)]), + LineString([(4, 4), (5, 5)]), + ] + ) + + def test_sindex_property_exists(self): + """Test that the sindex property exists on GeoSeries.""" + assert hasattr(self.points, "sindex") + assert hasattr(self.polygons, "sindex") + assert hasattr(self.lines, "sindex") + + def test_query_with_point(self): + """Test querying the spatial index with a point geometry.""" + # Create a list of Shapely geometries - squares around points (0,0), (1,1), etc. + geometries = [ + Polygon( + [ + (i - 0.5, j - 0.5), + (i + 0.5, j - 0.5), + (i + 0.5, j + 0.5), + (i - 0.5, j + 0.5), + ] + ) + for i, j in [(0, 0), (1, 1), (2, 2), (3, 3), (4, 4)] + ] + + # Create a spatial index from the geometries + geom_array = np.array(geometries, dtype=object) + sindex = SpatialIndex(geom_array) + + # Test query with a point that should intersect with one polygon + query_point = Point(2.2, 2.2) + result_indices = sindex.query(query_point) + assert len(result_indices) == 1 + + # Test query with a point that intersects no polygons + empty_point = Point(10, 10) + empty_results = sindex.query(empty_point) + assert len(empty_results) == 0 + + def test_query_with_spark_dataframe(self): + """Test querying the spatial index with a Spark DataFrame.""" + # Create a spatial DataFrame with polygons + polygons_data = [ + (1, "POLYGON((0 0, 1 0, 1 1, 0 1, 0 0))"), + (2, "POLYGON((1 1, 2 1, 2 2, 1 2, 1 1))"), + (3, "POLYGON((2 2, 3 2, 3 3, 2 3, 2 2))"), + (4, "POLYGON((3 3, 4 3, 4 4, 3 4, 3 3))"), + (5, "POLYGON((4 4, 5 4, 5 5, 4 5, 4 4))"), + ] + + df = self.spark.createDataFrame(polygons_data, ["id", "wkt"]) + spatial_df = df.withColumn("geometry", expr("ST_GeomFromWKT(wkt)")) + + # Create a SpatialIndex from the DataFrame + sindex = SpatialIndex(spatial_df, index_type="strtree", column_name="geometry") + + # Test query with a point that should intersect with one polygon + from shapely.geometry import Point + + query_point = Point(2.2, 2.2) + + # Execute query + result_indices = sindex.query(query_point, "contains") + + # Verify results - should find at least one result (polygon containing the point) + assert len(result_indices) > 0 + + # Test query with a polygon that should intersect multiple polygons + from shapely.geometry import box + + query_box = box(1.5, 1.5, 3.5, 3.5) + + # Execute query + box_results = sindex.query(query_box, predicate="contains") + + # Verify results - should find multiple polygons + assert len(box_results) > 1 + + # Test with contains predicate + # The query box fully contains polygon at index 2 (POLYGON((2 2, 3 2, 3 3, 2 3, 2 2))) + contains_results = sindex.query(query_box, predicate="contains") + + # Verify contains results + assert len(contains_results) >= 1 + + # Test with a point outside any polygon + outside_point = Point(10, 10) + outside_results = sindex.query(outside_point) + + # Verify no results for point outside + assert len(outside_results) == 0