This is an automated email from the ASF dual-hosted git repository.

jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/sedona.git


The following commit(s) were added to refs/heads/master by this push:
     new 629d47958c [GH-2066] Implement spatial index sindex query() and unit tests (#2067)
629d47958c is described below

commit 629d47958cadcdfa82a41d1825f2bb410de90c37
Author: Feng Zhang <f...@wherobots.com>
AuthorDate: Wed Jul 9 11:48:02 2025 -0700

    [GH-2066] Implement spatial index sindex query() and unit tests (#2067)
    
    * [GH-2066] Implement GeoSeries sindex and tests
    
    * fix test_spatial_index_with_shapely_array
    
    * fix NumPy comparisons in the test
    
    * do not test internal objects
    
    * address the copilot comments and refactor __init__
    
    * fix spark mode count
    
    * add index building logic
    
    * add query implementation for spatial index
    
    * remove unused code
    
    * fix test
    
    * Shapely only supports STRtree
    
    * fix results and add log_advice to query()
    
    * switch to use StructuredAdapter
---
 python/sedona/geopandas/base.py         |   2 +-
 python/sedona/geopandas/geodataframe.py |  18 ++-
 python/sedona/geopandas/geoindex.py     |  28 ----
 python/sedona/geopandas/geoseries.py    |  29 +++-
 python/sedona/geopandas/sindex.py       | 225 ++++++++++++++++++++++++++++++++
 python/tests/geopandas/test_sindex.py   | 145 ++++++++++++++++++++
 6 files changed, 410 insertions(+), 37 deletions(-)

diff --git a/python/sedona/geopandas/base.py b/python/sedona/geopandas/base.py
index f99a19cff8..8bc1ad4dd3 100644
--- a/python/sedona/geopandas/base.py
+++ b/python/sedona/geopandas/base.py
@@ -75,7 +75,7 @@ class GeoFrame(metaclass=ABCMeta):
 
     @property
     @abstractmethod
-    def geoindex(self) -> "GeoIndex":
+    def sindex(self) -> "SpatialIndex":
         raise NotImplementedError("This method is not implemented yet.")
 
     @abstractmethod
diff --git a/python/sedona/geopandas/geodataframe.py b/python/sedona/geopandas/geodataframe.py
index a2f90dff4b..b627d06a7d 100644
--- a/python/sedona/geopandas/geodataframe.py
+++ b/python/sedona/geopandas/geodataframe.py
@@ -28,7 +28,7 @@ from pyspark.pandas.internal import InternalFrame
 
 from sedona.geopandas._typing import Label
 from sedona.geopandas.base import GeoFrame
-from sedona.geopandas.geoindex import GeoIndex
+from sedona.geopandas.sindex import SpatialIndex
 
 
 class GeoDataFrame(GeoFrame, pspd.DataFrame):
@@ -250,9 +250,23 @@ class GeoDataFrame(GeoFrame, pspd.DataFrame):
         raise NotImplementedError("This method is not implemented yet.")
 
     @property
-    def geoindex(self) -> GeoIndex:
-        # Implementation of the abstract method
-        raise NotImplementedError("This method is not implemented yet.")
+    def sindex(self) -> "SpatialIndex | None":
+        """
+        Returns a spatial index for the GeoDataFrame.
+
+        The spatial index allows for efficient spatial queries. If the spatial
+        index cannot be created (the frame has no column named ``geometry``),
+        this property returns None.
+
+        Returns
+        -------
+        SpatialIndex or None
+            The spatial index built on the ``geometry`` column, or None if
+            that column is not present.
+        """
+        if "geometry" in self.columns:
+            return SpatialIndex(self._internal.spark_frame, column_name="geometry")
+        return None
 
     def copy(self, deep=False):
         """
diff --git a/python/sedona/geopandas/geoindex.py b/python/sedona/geopandas/geoindex.py
deleted file mode 100644
index 4dbc04b742..0000000000
--- a/python/sedona/geopandas/geoindex.py
+++ /dev/null
@@ -1,28 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-
-class GeoIndex:
-    """
-    A placeholder class for GeoIndex.
-    """
-
-    def __init__(self):
-        raise NotImplementedError("This method is not implemented yet.")
-
-    def some_method(self):
-        raise NotImplementedError("This method is not implemented yet.")
diff --git a/python/sedona/geopandas/geoseries.py b/python/sedona/geopandas/geoseries.py
index 2e4dc6bb2a..6160438a60 100644
--- a/python/sedona/geopandas/geoseries.py
+++ b/python/sedona/geopandas/geoseries.py
@@ -37,7 +37,7 @@ from shapely.geometry.base import BaseGeometry
 from sedona.geopandas._typing import Label
 from sedona.geopandas.base import GeoFrame
 from sedona.geopandas.geodataframe import GeoDataFrame
-from sedona.geopandas.geoindex import GeoIndex
+from sedona.geopandas.sindex import SpatialIndex
 
 from pyspark.pandas.internal import (
     SPARK_DEFAULT_INDEX_NAME,  # __index_level_0__
@@ -506,9 +506,34 @@ class GeoSeries(GeoFrame, pspd.Series):
         return self
 
     @property
-    def geoindex(self) -> "GeoIndex":
-        # Implementation of the abstract method
-        raise NotImplementedError("This method is not implemented yet.")
+    def sindex(self) -> SpatialIndex:
+        """
+        Returns a spatial index built from the geometries.
+
+        Returns
+        -------
+        SpatialIndex
+            The spatial index for this GeoSeries.
+
+        Raises
+        ------
+        ValueError
+            If no geometry column can be found.
+
+        Examples
+        --------
+        >>> from shapely.geometry import Point
+        >>> from sedona.geopandas import GeoSeries
+        >>>
+        >>> gs = GeoSeries([Point(1, 1), Point(2, 2)])
+        >>> index = gs.sindex
+        >>> index.size
+        2
+        """
+        geometry_column = self.get_first_geometry_column()
+        if geometry_column is None:
+            raise ValueError("No geometry column found in GeoSeries")
+        return SpatialIndex(self._internal.spark_frame, column_name=geometry_column)
 
     def copy(self, deep=False):
         """
diff --git a/python/sedona/geopandas/sindex.py b/python/sedona/geopandas/sindex.py
new file mode 100644
index 0000000000..426b2d12ab
--- /dev/null
+++ b/python/sedona/geopandas/sindex.py
@@ -0,0 +1,272 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import numpy as np
+from pyspark.pandas.utils import log_advice
+from pyspark.sql import DataFrame as PySparkDataFrame
+
+from sedona.spark import StructuredAdapter
+from sedona.spark.core.enums import IndexType
+
+
+class SpatialIndex:
+    """
+    A wrapper around Sedona's spatial index functionality.
+
+    Two backends are supported: a local Shapely ``STRtree`` built from a
+    NumPy array of geometries, and a distributed Sedona index built from a
+    geometry column of a PySpark DataFrame.
+    """
+
+    # Predicates accepted by the Spark backend. None and "intersects" are
+    # equivalent; "contains" keeps only geometries fully covered by the query.
+    _SPARK_PREDICATES = (None, "intersects", "contains")
+
+    def __init__(self, geometry, index_type="strtree", column_name=None):
+        """
+        Initialize the SpatialIndex with geometry data.
+
+        Parameters
+        ----------
+        geometry : np.array of Shapely geometries or PySparkDataFrame
+            The geometries to index.
+        index_type : str, default "strtree"
+            The type of spatial index to use. The local backend supports
+            "strtree"; the Spark backend supports "strtree" and "quadtree".
+        column_name : str, optional
+            The column name to extract geometry from if `geometry` is a
+            PySparkDataFrame.
+
+        Raises
+        ------
+        ValueError
+            If `column_name` is missing for a PySparkDataFrame, or if
+            `index_type` is not supported by the selected backend.
+        TypeError
+            If `geometry` is neither a NumPy array nor a PySparkDataFrame.
+        """
+
+        if isinstance(geometry, np.ndarray):
+            self.geometry = geometry
+            self.index_type = index_type
+            self._dataframe = None
+            self._is_spark = False
+            # Build local index for numpy array
+            self._build_local_index()
+        elif isinstance(geometry, PySparkDataFrame):
+            if column_name is None:
+                raise ValueError(
+                    "column_name must be specified when geometry is a PySparkDataFrame"
+                )
+            self.geometry = geometry[column_name]
+            self.index_type = index_type
+            self._dataframe = geometry
+            self._is_spark = True
+            # Build distributed spatial index
+            self._build_spark_index(column_name)
+        else:
+            raise TypeError(
+                "Invalid type for `geometry`. Expected np.array or PySparkDataFrame."
+            )
+
+    def query(self, geometry, predicate=None, sort=False):
+        """
+        Query the spatial index for geometries that match the given geometry.
+
+        Parameters
+        ----------
+        geometry : Shapely geometry
+            The geometry to query against the spatial index.
+        predicate : str, optional
+            Spatial predicate, evaluated as ``predicate(geometry, indexed)``.
+            With None, the local backend returns every geometry whose
+            bounding box intersects the bounding box of `geometry`, and the
+            Spark backend behaves like "intersects". The local backend
+            accepts any predicate supported by ``shapely.STRtree.query``;
+            the Spark backend accepts None, "intersects" and "contains".
+        sort : bool, optional, default False
+            Whether to sort the results by distance to `geometry`. Only
+            applied by the local backend.
+
+        Returns
+        -------
+        list
+            Local backend: integer positions of the matching geometries.
+            Spark backend: the matching geometries collected from Sedona's
+            range query (not positional indices).
+
+        Raises
+        ------
+        ValueError
+            If `predicate` is not supported by the selected backend.
+        """
+        log_advice(
+            "`query` returns local list of indices of matching geometries onto driver's memory. "
+            "It should only be used if the resulting collection is expected to be small."
+        )
+
+        if self._is_spark:
+            return self._query_spark(geometry, predicate)
+        return self._query_local(geometry, predicate, sort)
+
+    def _query_spark(self, geometry, predicate):
+        """Run a Sedona range query against the distributed index."""
+        from sedona.spark.core.spatialOperator import RangeQuery
+
+        if predicate not in self._SPARK_PREDICATES:
+            raise ValueError(
+                f"Unsupported predicate for Spark-based spatial index: {predicate!r}. "
+                "Supported predicates are None, 'intersects' and 'contains'."
+            )
+
+        # Sedona's considerBoundaryIntersection flag selects the predicate:
+        # True returns every geometry intersecting the query window, False
+        # returns only the geometries fully covered by it ("contains").
+        consider_boundary_intersection = predicate != "contains"
+        result_rdd = RangeQuery.SpatialRangeQuery(
+            self._indexed_rdd, geometry, consider_boundary_intersection, True
+        )
+        # TODO: map matches back to row positions so that both backends
+        # return indices.
+        return result_rdd.collect()
+
+    def _query_local(self, geometry, predicate, sort):
+        """Query the local Shapely STRtree and return a list of positions."""
+        if len(self.geometry) == 0:
+            # _build_local_index does not build a tree for an empty array.
+            return []
+
+        # STRtree evaluates predicate(geometry, tree_geometry), so "contains"
+        # keeps the indexed geometries that lie inside the query geometry.
+        matches = self._index.query(geometry, predicate=predicate)
+
+        # Use a plain list: testing the truthiness of a NumPy array with more
+        # than one element raises ValueError.
+        results = [int(i) for i in matches]
+        if sort:
+            # Sort by distance to the query geometry if requested
+            results.sort(key=lambda i: self.geometry[i].distance(geometry))
+        return results
+
+    def nearest(self, geometry, k=1, return_distance=False):
+        """
+        Find the nearest geometry in the spatial index.
+
+        Parameters
+        ----------
+        geometry : Shapely geometry
+            The geometry to find the nearest neighbor for.
+        k : int, optional, default 1
+            Number of nearest neighbors to find.
+        return_distance : bool, optional, default False
+            Whether to return distances along with indices.
+
+        Returns
+        -------
+        list or tuple
+            List of indices of nearest geometries, optionally with distances.
+        """
+        # Placeholder for KNN query using Sedona
+        raise NotImplementedError("This method is not implemented yet.")
+
+    def intersection(self, bounds):
+        """
+        Find geometries that intersect the given bounding box.
+
+        Parameters
+        ----------
+        bounds : tuple
+            Bounding box as (min_x, min_y, max_x, max_y).
+
+        Returns
+        -------
+        list
+            List of indices of matching geometries.
+        """
+        raise NotImplementedError("This method is not implemented yet.")
+
+    @property
+    def size(self):
+        """
+        Get the size of the spatial index.
+
+        Returns
+        -------
+        int
+            Number of geometries in the index.
+        """
+        if self._is_spark:
+            return self._dataframe.count()
+        return len(self.geometry)
+
+    @property
+    def is_empty(self):
+        """
+        Check if the spatial index is empty.
+
+        Returns
+        -------
+        bool
+            True if the index is empty, False otherwise.
+        """
+        return self.size == 0
+
+    def _build_spark_index(self, column_name):
+        """
+        Build a distributed spatial index on the geometry column of the DataFrame.
+
+        This uses Sedona's built-in indexing functionality.
+
+        Raises
+        ------
+        ValueError
+            If `index_type` is neither "strtree" nor "quadtree".
+        """
+
+        # Convert index_type string to Sedona IndexType enum
+        index_type_map = {"strtree": IndexType.RTREE, "quadtree": IndexType.QUADTREE}
+        sedona_index_type = index_type_map.get(self.index_type.lower())
+        if sedona_index_type is None:
+            raise ValueError(
+                f"Unsupported index type: {self.index_type}. "
+                "Only 'strtree' and 'quadtree' are supported for Spark indexing."
+            )
+
+        # Create a SpatialRDD from the DataFrame
+        spatial_rdd = StructuredAdapter.toSpatialRdd(self._dataframe, column_name)
+
+        # Build spatial index on the raw (not spatially partitioned) RDD
+        spatial_rdd.buildIndex(sedona_index_type, False)
+
+        # Store the indexed RDD
+        self._indexed_rdd = spatial_rdd
+
+    def _build_local_index(self):
+        """
+        Build a local spatial index for numpy array of geometries.
+
+        No tree is built for an empty array; queries short-circuit instead.
+        """
+        from shapely.strtree import STRtree
+
+        if len(self.geometry) > 0:
+            if self.index_type.lower() == "strtree":
+                self._index = STRtree(self.geometry)
+            else:
+                raise ValueError(
+                    f"Unsupported index type: {self.index_type}. Only 'strtree' is supported for local indexing."
+                )
diff --git a/python/tests/geopandas/test_sindex.py b/python/tests/geopandas/test_sindex.py
new file mode 100644
index 0000000000..a6e76b65d0
--- /dev/null
+++ b/python/tests/geopandas/test_sindex.py
@@ -0,0 +1,153 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import numpy as np
+import pytest
+from pyspark.sql.functions import expr
+from shapely.geometry import LineString, Point, Polygon, box
+
+from tests.test_base import TestBase
+from sedona.geopandas import GeoSeries
+from sedona.geopandas.sindex import SpatialIndex
+
+
+class TestSpatialIndex(TestBase):
+    """Tests for the spatial index functionality in GeoSeries."""
+
+    def setup_method(self):
+        """Set up test data."""
+        # Create a GeoSeries with point geometries
+        self.points = GeoSeries(
+            [Point(0, 0), Point(1, 1), Point(2, 2), Point(3, 3), Point(4, 4)]
+        )
+
+        # Create a GeoSeries with polygon geometries
+        self.polygons = GeoSeries(
+            [
+                Polygon([(0, 0), (1, 0), (1, 1), (0, 1)]),
+                Polygon([(1, 1), (2, 1), (2, 2), (1, 2)]),
+                Polygon([(2, 2), (3, 2), (3, 3), (2, 3)]),
+                Polygon([(3, 3), (4, 3), (4, 4), (3, 4)]),
+                Polygon([(4, 4), (5, 4), (5, 5), (4, 5)]),
+            ]
+        )
+
+        # Create a GeoSeries with line geometries
+        self.lines = GeoSeries(
+            [
+                LineString([(0, 0), (1, 1)]),
+                LineString([(1, 1), (2, 2)]),
+                LineString([(2, 2), (3, 3)]),
+                LineString([(3, 3), (4, 4)]),
+                LineString([(4, 4), (5, 5)]),
+            ]
+        )
+
+    def test_sindex_property_exists(self):
+        """Test that the sindex property exists on GeoSeries."""
+        assert hasattr(self.points, "sindex")
+        assert hasattr(self.polygons, "sindex")
+        assert hasattr(self.lines, "sindex")
+
+    def test_query_with_point(self):
+        """Test querying the spatial index with a point geometry."""
+        # Create a list of Shapely geometries - squares around points (0,0), (1,1), etc.
+        geometries = [
+            Polygon(
+                [
+                    (i - 0.5, j - 0.5),
+                    (i + 0.5, j - 0.5),
+                    (i + 0.5, j + 0.5),
+                    (i - 0.5, j + 0.5),
+                ]
+            )
+            for i, j in [(0, 0), (1, 1), (2, 2), (3, 3), (4, 4)]
+        ]
+
+        # Create a spatial index from the geometries
+        geom_array = np.array(geometries, dtype=object)
+        sindex = SpatialIndex(geom_array)
+
+        # Test query with a point that should intersect only the square around (2, 2)
+        query_point = Point(2.2, 2.2)
+        result_indices = sindex.query(query_point)
+        assert result_indices == [2]
+
+        # Test query with a point that intersects no polygons
+        empty_point = Point(10, 10)
+        empty_results = sindex.query(empty_point)
+        assert len(empty_results) == 0
+
+    def test_query_local_predicates_and_sort(self):
+        """Test predicates and sorting on the local (Shapely) spatial index."""
+        # Unit squares centred on (0, 0), (1, 1), ..., (4, 4)
+        geometries = [box(i - 0.5, i - 0.5, i + 0.5, i + 0.5) for i in range(5)]
+        sindex = SpatialIndex(np.array(geometries, dtype=object))
+
+        query_box = box(0.4, 0.4, 2.6, 2.6)
+
+        # The query box intersects the squares around (0, 0) through (3, 3) ...
+        assert sorted(sindex.query(query_box, predicate="intersects")) == [0, 1, 2, 3]
+
+        # ... but fully contains only the squares around (1, 1) and (2, 2)
+        assert sorted(sindex.query(query_box, predicate="contains")) == [1, 2]
+
+        # sort=True must not fail when more than one geometry matches
+        sorted_results = sindex.query(query_box, sort=True)
+        assert sorted(sorted_results) == [0, 1, 2, 3]
+
+    def test_query_with_spark_dataframe(self):
+        """Test querying the spatial index with a Spark DataFrame."""
+        # Create a spatial DataFrame with polygons
+        polygons_data = [
+            (1, "POLYGON((0 0, 1 0, 1 1, 0 1, 0 0))"),
+            (2, "POLYGON((1 1, 2 1, 2 2, 1 2, 1 1))"),
+            (3, "POLYGON((2 2, 3 2, 3 3, 2 3, 2 2))"),
+            (4, "POLYGON((3 3, 4 3, 4 4, 3 4, 3 3))"),
+            (5, "POLYGON((4 4, 5 4, 5 5, 4 5, 4 4))"),
+        ]
+
+        df = self.spark.createDataFrame(polygons_data, ["id", "wkt"])
+        spatial_df = df.withColumn("geometry", expr("ST_GeomFromWKT(wkt)"))
+
+        # Create a SpatialIndex from the DataFrame
+        sindex = SpatialIndex(spatial_df, index_type="strtree", column_name="geometry")
+
+        # A point inside polygon 3 intersects exactly one polygon ...
+        query_point = Point(2.2, 2.2)
+        assert len(sindex.query(query_point, predicate="intersects")) == 1
+
+        # ... but a point cannot contain a polygon
+        assert len(sindex.query(query_point, predicate="contains")) == 0
+
+        # The query box intersects polygons 2, 3 and 4 ...
+        query_box = box(1.5, 1.5, 3.5, 3.5)
+        box_results = sindex.query(query_box, predicate="intersects")
+        assert len(box_results) == 3
+
+        # ... but fully contains only POLYGON((2 2, 3 2, 3 3, 2 3, 2 2))
+        contains_results = sindex.query(query_box, predicate="contains")
+        assert len(contains_results) == 1
+
+        # Test with a point outside any polygon
+        outside_point = Point(10, 10)
+        outside_results = sindex.query(outside_point)
+        assert len(outside_results) == 0
+
+        # Predicates the Spark backend cannot evaluate are rejected
+        with pytest.raises(ValueError):
+            sindex.query(query_box, predicate="touches")

Reply via email to