This is an automated email from the ASF dual-hosted git repository.

kontinuation pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/sedona-db.git


The following commit(s) were added to refs/heads/main by this push:
     new 7d50b710 feat(rust/sedona-spatial-join) Spatial index supports async 
batch query and parallel refinement (#523)
7d50b710 is described below

commit 7d50b710f88214f7d6055a15f6675f3f1937bdd2
Author: Kristin Cowalcijk <[email protected]>
AuthorDate: Sat Jan 17 12:40:29 2026 +0800

    feat(rust/sedona-spatial-join) Spatial index supports async batch query and 
parallel refinement (#523)
    
    This PR addresses performance bottlenecks (stragglers) observed during the 
candidate refinement phase of SpatialBench Q10 and Q11, particularly at higher 
scale factors (SF=100 and SF=1000).
    
    When executing queries with large windows on dense datasets, a single 
R-Tree index query can retrieve millions of candidates. The probe partition 
becomes a "straggler" because it must sequentially evaluate spatial predicates 
for these millions of geometries. Since this bottleneck occurs within a single 
partition, DataFusion’s partition-level parallelism is unable to distribute the 
load.
    
    This patch introduces an async batch query interface for SpatialIndex.
    Querying a whole range of probe rows per async call amortizes the
    scheduling cost of async function calls. When a single probe row yields
    a large candidate set (at least twice the new
    `parallel_refinement_chunk_size` option of `SpatialJoinOptions`), its
    refinement is split into chunks of that size which run as parallel tasks
    on the async runtime, eliminating the single-partition bottleneck.
---
 rust/sedona-common/src/option.rs                   |   7 +
 rust/sedona-spatial-join/src/exec.rs               |  18 +
 .../sedona-spatial-join/src/index/spatial_index.rs | 482 ++++++++++++++++++++-
 .../src/index/spatial_index_builder.rs             |   1 +
 rust/sedona-spatial-join/src/stream.rs             | 327 ++++++++------
 5 files changed, 682 insertions(+), 153 deletions(-)

diff --git a/rust/sedona-common/src/option.rs b/rust/sedona-common/src/option.rs
index fcd692fb..bc74acf7 100644
--- a/rust/sedona-common/src/option.rs
+++ b/rust/sedona-common/src/option.rs
@@ -70,6 +70,13 @@ config_namespace! {
 
         /// Include tie-breakers in KNN join results when there are tied 
distances
         pub knn_include_tie_breakers: bool, default = false
+
+        /// Number of candidate geometry pairs per chunk when refinement is parallelized.
+        /// A probe geometry whose candidate set has fewer than twice this many pairs is
+        /// refined sequentially; larger candidate sets are split into chunks of this size
+        /// that are refined as parallel tasks. Set to 0 to always refine sequentially.
+        /// Higher values reduce task overhead; lower values allow finer-grained parallelism.
+        pub parallel_refinement_chunk_size: usize, default = 8192
     }
 }
 
diff --git a/rust/sedona-spatial-join/src/exec.rs 
b/rust/sedona-spatial-join/src/exec.rs
index 5cdea16d..43b73290 100644
--- a/rust/sedona-spatial-join/src/exec.rs
+++ b/rust/sedona-spatial-join/src/exec.rs
@@ -1135,6 +1135,24 @@ mod tests {
         Ok(())
     }
 
+    #[tokio::test]
+    async fn test_parallel_refinement_for_large_candidate_set() -> Result<()> {
+        let ((left_schema, left_partitions), (right_schema, right_partitions)) 
=
+            create_test_data_with_size_range((1.0, 50.0), WKB_GEOMETRY)?;
+
+        for max_batch_size in [10, 30, 100] {
+            let options = SpatialJoinOptions {
+                execution_mode: ExecutionMode::PrepareNone,
+                parallel_refinement_chunk_size: 10,
+                ..Default::default()
+            };
+            test_spatial_join_query(&left_schema, &right_schema, 
left_partitions.clone(), right_partitions.clone(), &options, max_batch_size,
+                "SELECT * FROM L JOIN R ON ST_Intersects(L.geometry, 
R.geometry) AND L.dist < R.dist ORDER BY L.id, R.id").await?;
+        }
+
+        Ok(())
+    }
+
     async fn test_with_join_types(join_type: JoinType) -> Result<RecordBatch> {
         let ((left_schema, left_partitions), (right_schema, right_partitions)) 
=
             create_test_data_with_empty_partitions()?;
diff --git a/rust/sedona-spatial-join/src/index/spatial_index.rs 
b/rust/sedona-spatial-join/src/index/spatial_index.rs
index 83a1a754..6f3e00d0 100644
--- a/rust/sedona-spatial-join/src/index/spatial_index.rs
+++ b/rust/sedona-spatial-join/src/index/spatial_index.rs
@@ -15,14 +15,18 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use std::sync::{
-    atomic::{AtomicUsize, Ordering},
-    Arc,
+use std::{
+    ops::Range,
+    sync::{
+        atomic::{AtomicUsize, Ordering},
+        Arc,
+    },
 };
 
 use arrow_array::RecordBatch;
 use arrow_schema::SchemaRef;
-use datafusion_common::Result;
+use datafusion_common::{DataFusionError, Result};
+use datafusion_common_runtime::JoinSet;
 use datafusion_execution::memory_pool::{MemoryPool, MemoryReservation};
 use float_next_after::NextAfter;
 use geo::BoundingRect;
@@ -44,7 +48,7 @@ use crate::{
         knn_adapter::{KnnComponents, SedonaKnnAdapter},
         IndexQueryResult, QueryResultMetrics,
     },
-    operand_evaluator::{create_operand_evaluator, OperandEvaluator},
+    operand_evaluator::{create_operand_evaluator, distance_value_at, 
OperandEvaluator},
     refine::{create_refiner, IndexQueryResultRefiner},
     spatial_predicate::SpatialPredicate,
     utils::concurrent_reservation::ConcurrentReservation,
@@ -54,6 +58,7 @@ use sedona_common::{option::SpatialJoinOptions, 
sedona_internal_err, ExecutionMo
 
 pub struct SpatialIndex {
     pub(crate) schema: SchemaRef,
+    pub(crate) options: SpatialJoinOptions,
 
     /// The spatial predicate evaluator for the spatial predicate.
     pub(crate) evaluator: Arc<dyn OperandEvaluator>,
@@ -125,6 +130,7 @@ impl SpatialIndex {
             .then(|| KnnComponents::new(0, &[], memory_pool.clone()).unwrap());
         Self {
             schema,
+            options,
             evaluator,
             refiner,
             refiner_reservation,
@@ -178,6 +184,7 @@ impl SpatialIndex {
     /// # Returns
     /// * `JoinResultMetrics` containing the number of actual matches 
(`count`) and the number
     ///   of candidates from the filter phase (`candidate_count`)
+    #[allow(unused)]
     pub(crate) fn query(
         &self,
         probe_wkb: &Wkb,
@@ -409,6 +416,179 @@ impl SpatialIndex {
         })
     }
 
+    /// Query the spatial index with a batch of probe geometries to find 
matching build-side geometries.
+    ///
+    /// This method iterates over the probe geometries in the given range of 
the evaluated batch.
+    /// For each probe geometry, it performs the two-phase spatial join query:
+    /// 1. **Filter phase**: Uses the R-tree index with the probe geometry's 
bounding rectangle
+    ///    to quickly identify candidate geometries.
+    /// 2. **Refinement phase**: Evaluates the exact spatial predicate on 
candidates to determine
+    ///    actual matches.
+    ///
+    /// # Arguments
+    /// * `evaluated_batch` - The batch containing probe geometries and their 
bounding rectangles
+    /// * `range` - The range of rows in the evaluated batch to process.
+    /// * `max_result_size` - Soft cap on the number of results. Processing stops once the total
+    ///   reaches this value; all matches of the last probe row are kept, so it may be exceeded.
+    /// * `build_batch_positions` - Output vector that will be populated with 
(batch_idx, row_idx)
+    ///   pairs for each matching build-side geometry.
+    /// * `probe_indices` - Output vector that will be populated with the 
probe row index (in
+    ///   `evaluated_batch`) for each match appended to 
`build_batch_positions`.
+    ///   This means the probe index is repeated `N` times when a probe 
geometry produces `N` matches,
+    ///   keeping `probe_indices.len()` in sync with 
`build_batch_positions.len()`.
+    ///
+    /// # Returns
+    /// * A tuple containing:
+    ///   - `QueryResultMetrics`: Aggregated metrics (total matches and 
candidates) for the processed rows
+    ///   - `usize`: The index of the next row to process (exclusive end of 
the processed range)
+    pub(crate) async fn query_batch(
+        self: &Arc<Self>,
+        evaluated_batch: &Arc<EvaluatedBatch>,
+        range: Range<usize>,
+        max_result_size: usize,
+        build_batch_positions: &mut Vec<(i32, i32)>,
+        probe_indices: &mut Vec<u32>,
+    ) -> Result<(QueryResultMetrics, usize)> {
+        if range.is_empty() {
+            return Ok((
+                QueryResultMetrics {
+                    count: 0,
+                    candidate_count: 0,
+                },
+                range.start,
+            ));
+        }
+
+        let rects = evaluated_batch.rects();
+        let dist = evaluated_batch.distance();
+        let mut total_candidates_count = 0;
+        let mut total_count = 0;
+        let mut current_row_idx = range.start;
+        for row_idx in range {
+            current_row_idx = row_idx;
+            let Some(probe_rect) = rects[row_idx] else {
+                continue;
+            };
+
+            let min = probe_rect.min();
+            let max = probe_rect.max();
+            let mut candidates = self.rtree.search(min.x, min.y, max.x, max.y);
+            if candidates.is_empty() {
+                continue;
+            }
+
+            let Some(probe_wkb) = evaluated_batch.wkb(row_idx) else {
+                return sedona_internal_err!(
+                    "Failed to get WKB for row {} in evaluated batch",
+                    row_idx
+                );
+            };
+
+            // Sort and dedup candidates to avoid duplicate results when we 
index one geometry
+            // using several boxes.
+            candidates.sort_unstable();
+            candidates.dedup();
+
+            let distance = match dist {
+                Some(dist_array) => distance_value_at(dist_array, row_idx)?,
+                None => None,
+            };
+
+            // Refine the candidates retrieved from the r-tree index by 
evaluating the actual spatial predicate
+            let refine_chunk_size = 
self.options.parallel_refinement_chunk_size;
+            if refine_chunk_size == 0 || candidates.len() < refine_chunk_size 
* 2 {
+                // For small candidate sets, use refine synchronously
+                let metrics =
+                    self.refine(probe_wkb, &candidates, &distance, 
build_batch_positions)?;
+                probe_indices.extend(std::iter::repeat_n(row_idx as u32, 
metrics.count));
+                total_count += metrics.count;
+                total_candidates_count += metrics.candidate_count;
+            } else {
+                // For large candidate sets, spawn several tasks to 
parallelize refinement
+                let (metrics, positions) = self
+                    .refine_concurrently(
+                        evaluated_batch,
+                        row_idx,
+                        &candidates,
+                        distance,
+                        refine_chunk_size,
+                    )
+                    .await?;
+                build_batch_positions.extend(positions);
+                probe_indices.extend(std::iter::repeat_n(row_idx as u32, 
metrics.count));
+                total_count += metrics.count;
+                total_candidates_count += metrics.candidate_count;
+            }
+
+            if total_count >= max_result_size {
+                break;
+            }
+        }
+
+        let end_idx = current_row_idx + 1;
+        Ok((
+            QueryResultMetrics {
+                count: total_count,
+                candidate_count: total_candidates_count,
+            },
+            end_idx,
+        ))
+    }
+
+    async fn refine_concurrently(
+        self: &Arc<Self>,
+        evaluated_batch: &Arc<EvaluatedBatch>,
+        row_idx: usize,
+        candidates: &[u32],
+        distance: Option<f64>,
+        refine_chunk_size: usize,
+    ) -> Result<(QueryResultMetrics, Vec<(i32, i32)>)> {
+        let mut join_set = JoinSet::new();
+        for (i, chunk) in candidates.chunks(refine_chunk_size).enumerate() {
+            let cloned_evaluated_batch = Arc::clone(evaluated_batch);
+            let chunk = chunk.to_vec();
+            let index_ref = Arc::clone(self);
+            join_set.spawn(async move {
+                let Some(probe_wkb) = cloned_evaluated_batch.wkb(row_idx) else 
{
+                    return (
+                        i,
+                        sedona_internal_err!(
+                            "Failed to get WKB for row {} in evaluated batch",
+                            row_idx
+                        ),
+                    );
+                };
+                let mut local_positions: Vec<(i32, i32)> = 
Vec::with_capacity(chunk.len());
+                let res = index_ref.refine(probe_wkb, &chunk, &distance, &mut 
local_positions);
+                (i, res.map(|r| (r, local_positions)))
+            });
+        }
+
+        // Tasks finish in any order; store each result at its chunk index to keep candidate order
+        let mut refine_results = Vec::with_capacity(join_set.len());
+        refine_results.resize_with(join_set.len(), || None);
+        while let Some(res) = join_set.join_next().await {
+            let (chunk_idx, refine_res) =
+                res.map_err(|e| DataFusionError::External(Box::new(e)))?;
+            let (metrics, positions) = refine_res?;
+            refine_results[chunk_idx] = Some((metrics, positions));
+        }
+
+        let mut total_metrics = QueryResultMetrics {
+            count: 0,
+            candidate_count: 0,
+        };
+        let mut all_positions = Vec::with_capacity(candidates.len());
+        for res in refine_results {
+            let (metrics, positions) = res.expect("All chunks should be 
processed");
+            total_metrics.count += metrics.count;
+            total_metrics.candidate_count += metrics.candidate_count;
+            all_positions.extend(positions);
+        }
+
+        Ok((total_metrics, all_positions))
+    }
+
     fn refine(
         &self,
         probe_wkb: &Wkb,
@@ -1232,9 +1412,6 @@ mod tests {
         assert!(build_positions.len() <= 3);
         assert!(result.count > 0);
         assert!(result.count <= 3);
-
-        println!("KNN Geometry test - found {} results", result.count);
-        println!("Result positions: {build_positions:?}");
     }
 
     #[test]
@@ -1316,8 +1493,6 @@ mod tests {
         // Should return results
         assert!(!build_positions.is_empty());
 
-        println!("KNN with mixed geometries: {build_positions:?}");
-
         // Should work with mixed geometry types
         assert!(result.count > 0);
     }
@@ -1518,4 +1693,291 @@ mod tests {
         assert_eq!(result.candidate_count, 0);
         assert!(build_positions.is_empty());
     }
+
+    async fn setup_index_for_batch_test(
+        build_geoms: &[Option<&str>],
+        options: SpatialJoinOptions,
+    ) -> Arc<SpatialIndex> {
+        let memory_pool = Arc::new(GreedyMemoryPool::new(100 * 1024 * 1024));
+        let metrics = SpatialJoinBuildMetrics::default();
+        let spatial_predicate = 
SpatialPredicate::Relation(RelationPredicate::new(
+            Arc::new(Column::new("left", 0)),
+            Arc::new(Column::new("right", 0)),
+            SpatialRelationType::Intersects,
+        ));
+        let schema = Arc::new(arrow_schema::Schema::new(vec![Field::new(
+            "geom",
+            DataType::Binary,
+            true,
+        )]));
+
+        let mut builder = SpatialIndexBuilder::new(
+            schema,
+            spatial_predicate,
+            options,
+            JoinType::Inner,
+            1,
+            memory_pool,
+            metrics,
+        )
+        .unwrap();
+
+        let geom_array = create_array(build_geoms, &WKB_GEOMETRY);
+        let batch = RecordBatch::try_new(
+            Arc::new(arrow_schema::Schema::new(vec![Field::new(
+                "geom",
+                DataType::Binary,
+                true,
+            )])),
+            vec![Arc::new(geom_array.clone())],
+        )
+        .unwrap();
+        let evaluated_batch = EvaluatedBatch {
+            batch,
+            geom_array: EvaluatedGeometryArray::try_new(geom_array, 
&WKB_GEOMETRY).unwrap(),
+        };
+
+        builder.add_batch(evaluated_batch).unwrap();
+        Arc::new(builder.finish().unwrap())
+    }
+
+    fn create_probe_batch(probe_geoms: &[Option<&str>]) -> Arc<EvaluatedBatch> 
{
+        let geom_array = create_array(probe_geoms, &WKB_GEOMETRY);
+        let batch = RecordBatch::try_new(
+            Arc::new(arrow_schema::Schema::new(vec![Field::new(
+                "geom",
+                DataType::Binary,
+                true,
+            )])),
+            vec![Arc::new(geom_array.clone())],
+        )
+        .unwrap();
+        Arc::new(EvaluatedBatch {
+            batch,
+            geom_array: EvaluatedGeometryArray::try_new(geom_array, 
&WKB_GEOMETRY).unwrap(),
+        })
+    }
+
+    #[tokio::test]
+    async fn test_query_batch_empty_results() {
+        let build_geoms = &[Some("POINT (0 0)"), Some("POINT (1 1)")];
+        let index = setup_index_for_batch_test(build_geoms, 
SpatialJoinOptions::default()).await;
+
+        // Probe with geometries that don't intersect
+        let probe_geoms = &[Some("POINT (10 10)"), Some("POINT (20 20)")];
+        let probe_batch = create_probe_batch(probe_geoms);
+
+        let mut build_batch_positions = Vec::new();
+        let mut probe_indices = Vec::new();
+        let (metrics, next_idx) = index
+            .query_batch(
+                &probe_batch,
+                0..2,
+                usize::MAX,
+                &mut build_batch_positions,
+                &mut probe_indices,
+            )
+            .await
+            .unwrap();
+
+        assert_eq!(metrics.count, 0);
+        assert_eq!(build_batch_positions.len(), 0);
+        assert_eq!(probe_indices.len(), 0);
+        assert_eq!(next_idx, 2);
+    }
+
+    #[tokio::test]
+    async fn test_query_batch_max_result_size() {
+        let build_geoms = &[
+            Some("POINT (0 0)"),
+            Some("POINT (0 0)"),
+            Some("POINT (0 0)"),
+        ];
+        let index = setup_index_for_batch_test(build_geoms, 
SpatialJoinOptions::default()).await;
+
+        // Probe with geometry that intersects all 3
+        let probe_geoms = &[Some("POINT (0 0)"), Some("POINT (0 0)")];
+        let probe_batch = create_probe_batch(probe_geoms);
+
+        // Case 1: Max result size is large enough
+        let mut build_batch_positions = Vec::new();
+        let mut probe_indices = Vec::new();
+        let (metrics, next_idx) = index
+            .query_batch(
+                &probe_batch,
+                0..2,
+                10,
+                &mut build_batch_positions,
+                &mut probe_indices,
+            )
+            .await
+            .unwrap();
+        assert_eq!(metrics.count, 6); // 2 probes * 3 matches
+        assert_eq!(next_idx, 2);
+        assert_eq!(probe_indices, vec![0, 0, 0, 1, 1, 1]);
+
+        // Case 2: Max result size is small (stops after first probe)
+        let mut build_batch_positions = Vec::new();
+        let mut probe_indices = Vec::new();
+        let (metrics, next_idx) = index
+            .query_batch(
+                &probe_batch,
+                0..2,
+                2, // Stop after 2 results
+                &mut build_batch_positions,
+                &mut probe_indices,
+            )
+            .await
+            .unwrap();
+
+        // It should process the first probe, find 3 matches.
+        // Since 3 >= 2, it should stop.
+        assert_eq!(metrics.count, 3);
+        assert_eq!(next_idx, 1); // Only processed 1 probe
+        assert_eq!(probe_indices, vec![0, 0, 0]);
+    }
+
+    #[tokio::test]
+    async fn test_query_batch_parallel_refinement() {
+        // Create enough build geometries to trigger parallel refinement
+        // We need candidates.len() >= chunk_size * 2
+        // Let's set chunk_size = 2, so we need >= 4 candidates.
+        let build_geoms = vec![Some("POINT (0 0)"); 10];
+        let options = SpatialJoinOptions {
+            parallel_refinement_chunk_size: 2,
+            ..Default::default()
+        };
+
+        let index = setup_index_for_batch_test(&build_geoms, options).await;
+
+        // Probe with a geometry that intersects all build geometries
+        let probe_geoms = &[Some("POLYGON ((-1 -1, 1 -1, 1 1, -1 1, -1 -1))")];
+        let probe_batch = create_probe_batch(probe_geoms);
+
+        let mut build_batch_positions = Vec::new();
+        let mut probe_indices = Vec::new();
+        let (metrics, next_idx) = index
+            .query_batch(
+                &probe_batch,
+                0..1,
+                usize::MAX,
+                &mut build_batch_positions,
+                &mut probe_indices,
+            )
+            .await
+            .unwrap();
+
+        assert_eq!(metrics.count, 10);
+        assert_eq!(build_batch_positions.len(), 10);
+        assert_eq!(probe_indices, vec![0; 10]);
+        assert_eq!(next_idx, 1);
+    }
+
+    #[tokio::test]
+    async fn test_query_batch_empty_range() {
+        let build_geoms = &[Some("POINT (0 0)")];
+        let index = setup_index_for_batch_test(build_geoms, 
SpatialJoinOptions::default()).await;
+        let probe_geoms = &[Some("POINT (0 0)"), Some("POINT (0 0)")];
+        let probe_batch = create_probe_batch(probe_geoms);
+
+        let mut build_batch_positions = Vec::new();
+        let mut probe_indices = Vec::new();
+
+        // Query with empty range
+        for empty_ranges in [0..0, 1..1, 2..2] {
+            let (metrics, next_idx) = index
+                .query_batch(
+                    &probe_batch,
+                    empty_ranges.clone(),
+                    usize::MAX,
+                    &mut build_batch_positions,
+                    &mut probe_indices,
+                )
+                .await
+                .unwrap();
+
+            assert_eq!(metrics.count, 0);
+            assert_eq!(next_idx, empty_ranges.end);
+        }
+    }
+
+    #[tokio::test]
+    async fn test_query_batch_range_offset() {
+        let build_geoms = &[Some("POINT (0 0)"), Some("POINT (1 1)")];
+        let index = setup_index_for_batch_test(build_geoms, 
SpatialJoinOptions::default()).await;
+
+        // Probe with 3 geometries:
+        // 0: POINT (0 0) - matches build[0] (should be skipped)
+        // 1: POINT (0 0) - matches build[0]
+        // 2: POINT (1 1) - matches build[1]
+        let probe_geoms = &[
+            Some("POINT (0 0)"),
+            Some("POINT (0 0)"),
+            Some("POINT (1 1)"),
+        ];
+        let probe_batch = create_probe_batch(probe_geoms);
+
+        let mut build_batch_positions = Vec::new();
+        let mut probe_indices = Vec::new();
+
+        // Query with range 1..3 (skipping the first probe)
+        let (metrics, next_idx) = index
+            .query_batch(
+                &probe_batch,
+                1..3,
+                usize::MAX,
+                &mut build_batch_positions,
+                &mut probe_indices,
+            )
+            .await
+            .unwrap();
+
+        assert_eq!(metrics.count, 2);
+        assert_eq!(next_idx, 3);
+
+        // probe_indices holds absolute row indices in the batch (1 and 2), not range offsets
+        assert_eq!(probe_indices, vec![1, 2]);
+
+        // build_batch_positions should contain matches for probe 1 and probe 2
+        // probe 1 matches build 0 (0, 0)
+        // probe 2 matches build 1 (0, 1)
+        // Note: build_batch_positions contains (batch_idx, row_idx)
+        // Since we have 1 batch, batch_idx is 0.
+        assert_eq!(build_batch_positions, vec![(0, 0), (0, 1)]);
+    }
+
+    #[tokio::test]
+    async fn test_query_batch_zero_parallel_refinement_chunk_size() {
+        let build_geoms = &[
+            Some("POINT (0 0)"),
+            Some("POINT (0 0)"),
+            Some("POINT (0 0)"),
+        ];
+        let options = SpatialJoinOptions {
+            // force synchronous refinement
+            parallel_refinement_chunk_size: 0,
+            ..Default::default()
+        };
+
+        let index = setup_index_for_batch_test(build_geoms, options).await;
+        let probe_geoms = &[Some("POINT (0 0)")];
+        let probe_batch = create_probe_batch(probe_geoms);
+
+        let mut build_batch_positions = Vec::new();
+        let mut probe_indices = Vec::new();
+
+        let result = index
+            .query_batch(
+                &probe_batch,
+                0..1,
+                10,
+                &mut build_batch_positions,
+                &mut probe_indices,
+            )
+            .await;
+
+        assert!(result.is_ok());
+        let (metrics, _) = result.unwrap();
+        assert_eq!(metrics.count, 3);
+    }
 }
diff --git a/rust/sedona-spatial-join/src/index/spatial_index_builder.rs 
b/rust/sedona-spatial-join/src/index/spatial_index_builder.rs
index 41d7fbd6..a9b08d7a 100644
--- a/rust/sedona-spatial-join/src/index/spatial_index_builder.rs
+++ b/rust/sedona-spatial-join/src/index/spatial_index_builder.rs
@@ -274,6 +274,7 @@ impl SpatialIndexBuilder {
 
         Ok(SpatialIndex {
             schema: self.schema,
+            options: self.options,
             evaluator,
             refiner,
             refiner_reservation,
diff --git a/rust/sedona-spatial-join/src/stream.rs 
b/rust/sedona-spatial-join/src/stream.rs
index f4b18244..4a01e6ef 100644
--- a/rust/sedona-spatial-join/src/stream.rs
+++ b/rust/sedona-spatial-join/src/stream.rs
@@ -23,7 +23,9 @@ use 
datafusion_physical_plan::joins::utils::StatefulStreamResult;
 use datafusion_physical_plan::joins::utils::{ColumnIndex, JoinFilter};
 use datafusion_physical_plan::metrics::{self, ExecutionPlanMetricsSet, 
MetricBuilder};
 use datafusion_physical_plan::{handle_state, RecordBatchStream, 
SendableRecordBatchStream};
+use futures::future::BoxFuture;
 use futures::stream::StreamExt;
+use futures::FutureExt;
 use futures::{ready, task::Poll};
 use parking_lot::Mutex;
 use sedona_common::sedona_internal_err;
@@ -37,7 +39,7 @@ use 
crate::evaluated_batch::evaluated_batch_stream::evaluate::create_evaluated_p
 use 
crate::evaluated_batch::evaluated_batch_stream::SendableEvaluatedBatchStream;
 use crate::evaluated_batch::EvaluatedBatch;
 use crate::index::SpatialIndex;
-use crate::operand_evaluator::{create_operand_evaluator, distance_value_at};
+use crate::operand_evaluator::create_operand_evaluator;
 use crate::spatial_predicate::SpatialPredicate;
 use crate::utils::join_utils::{
     adjust_indices_by_join_type, apply_join_filter_to_indices, 
build_batch_from_indices,
@@ -50,7 +52,7 @@ use sedona_common::option::SpatialJoinOptions;
 
 /// Stream for producing spatial join result batches.
 pub(crate) struct SpatialJoinStream {
-    /// Input schema
+    /// Schema of joined results
     schema: Arc<Schema>,
     /// join filter
     filter: Option<JoinFilter>,
@@ -165,7 +167,6 @@ impl SpatialJoinProbeMetrics {
 }
 
 /// This enumeration represents various states of the nested loop join 
algorithm.
-#[derive(Debug)]
 #[allow(clippy::large_enum_variant)]
 pub(crate) enum SpatialJoinStreamState {
     /// The initial mode: waiting for the spatial index to be built
@@ -174,7 +175,9 @@ pub(crate) enum SpatialJoinStreamState {
     /// fetching probe-side
     FetchProbeBatch,
     /// Indicates that we're processing a probe batch using the batch iterator
-    ProcessProbeBatch(SpatialJoinBatchIterator),
+    ProcessProbeBatch(
+        BoxFuture<'static, (Box<SpatialJoinBatchIterator>, 
Result<Option<RecordBatch>>)>,
+    ),
     /// Indicates that probe-side has been fully processed
     ExhaustedProbeSide,
     /// Indicates that we're processing unmatched build-side batches using an 
iterator
@@ -197,7 +200,7 @@ impl SpatialJoinStream {
                     handle_state!(ready!(self.fetch_probe_batch(cx)))
                 }
                 SpatialJoinStreamState::ProcessProbeBatch(_) => {
-                    handle_state!(ready!(self.process_probe_batch()))
+                    handle_state!(ready!(self.process_probe_batch(cx)))
                 }
                 SpatialJoinStreamState::ExhaustedProbeSide => {
                     
handle_state!(ready!(self.setup_unmatched_build_batch_processing()))
@@ -227,8 +230,13 @@ impl SpatialJoinStream {
         let result = self.probe_stream.poll_next_unpin(cx);
         match result {
             Poll::Ready(Some(Ok(batch))) => match 
self.create_spatial_join_iterator(batch) {
-                Ok(iterator) => {
-                    self.state = 
SpatialJoinStreamState::ProcessProbeBatch(iterator);
+                Ok(mut iterator) => {
+                    let future = async move {
+                        let result = iterator.next_batch().await;
+                        (iterator, result)
+                    }
+                    .boxed();
+                    self.state = 
SpatialJoinStreamState::ProcessProbeBatch(future);
                     Poll::Ready(Ok(StatefulStreamResult::Continue))
                 }
                 Err(e) => Poll::Ready(Err(e)),
@@ -242,54 +250,51 @@ impl SpatialJoinStream {
         }
     }
 
-    fn process_probe_batch(&mut self) -> 
Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
-        let timer = self.join_metrics.join_time.timer();
+    fn process_probe_batch(
+        &mut self,
+        cx: &mut std::task::Context<'_>,
+    ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
+        let _timer = self.join_metrics.join_time.timer();
 
         // Extract the necessary data first to avoid borrowing conflicts
-        let (batch_opt, is_complete) = match &mut self.state {
-            SpatialJoinStreamState::ProcessProbeBatch(iterator) => {
-                // For KNN joins, we may have swapped build/probe sides, so 
build_side might be Right;
-                // For regular joins, build_side is always Left.
-                let build_side = match &self.spatial_predicate {
-                    SpatialPredicate::KNearestNeighbors(knn) => 
knn.probe_side.negate(),
-                    _ => JoinSide::Left,
-                };
-
-                let batch_opt = match iterator.next_batch(
-                    &self.schema,
-                    self.filter.as_ref(),
-                    self.join_type,
-                    &self.column_indices,
-                    build_side,
-                ) {
-                    Ok(opt) => opt,
-                    Err(e) => {
-                        return Poll::Ready(Err(e));
-                    }
-                };
-                let is_complete = iterator.is_complete();
-                (batch_opt, is_complete)
-            }
+        let (mut iterator, batch_opt) = match &mut self.state {
+            SpatialJoinStreamState::ProcessProbeBatch(future) => match 
future.poll_unpin(cx) {
+                Poll::Ready((iterator, result)) => {
+                    let batch_opt = match result {
+                        Ok(opt) => opt,
+                        Err(e) => {
+                            return Poll::Ready(Err(e));
+                        }
+                    };
+                    (iterator, batch_opt)
+                }
+                Poll::Pending => return Poll::Pending,
+            },
             _ => unreachable!(),
         };
 
-        let result = match batch_opt {
+        match batch_opt {
             Some(batch) => {
                 // Check if iterator is complete
-                if is_complete {
+                if iterator.is_complete() {
                     self.state = SpatialJoinStreamState::FetchProbeBatch;
+                } else {
+                    // Iterator is not complete, continue processing the 
current probe batch
+                    let future = async move {
+                        let result = iterator.next_batch().await;
+                        (iterator, result)
+                    }
+                    .boxed();
+                    self.state = 
SpatialJoinStreamState::ProcessProbeBatch(future);
                 }
-                batch
+                Poll::Ready(Ok(StatefulStreamResult::Ready(Some(batch))))
             }
             None => {
                 // Iterator finished, move to next probe batch
                 self.state = SpatialJoinStreamState::FetchProbeBatch;
-                return Poll::Ready(Ok(StatefulStreamResult::Continue));
+                Poll::Ready(Ok(StatefulStreamResult::Continue))
             }
-        };
-
-        timer.done();
-        Poll::Ready(Ok(StatefulStreamResult::Ready(Some(result))))
+        }
     }
 
     fn setup_unmatched_build_batch_processing(
@@ -391,7 +396,7 @@ impl SpatialJoinStream {
     fn create_spatial_join_iterator(
         &self,
         probe_evaluated_batch: EvaluatedBatch,
-    ) -> Result<SpatialJoinBatchIterator> {
+    ) -> Result<Box<SpatialJoinBatchIterator>> {
         let num_rows = probe_evaluated_batch.num_rows();
         self.join_metrics.probe_input_batches.add(1);
         self.join_metrics.probe_input_rows.add(num_rows);
@@ -414,15 +419,28 @@ impl SpatialJoinStream {
             spatial_index.merge_probe_stats(stats);
         }
 
-        SpatialJoinBatchIterator::new(SpatialJoinBatchIteratorParams {
+        // For KNN joins, we may have swapped build/probe sides, so build_side 
might be Right;
+        // For regular joins, build_side is always Left.
+        let build_side = match &self.spatial_predicate {
+            SpatialPredicate::KNearestNeighbors(knn) => 
knn.probe_side.negate(),
+            _ => JoinSide::Left,
+        };
+
+        let iterator = 
SpatialJoinBatchIterator::new(SpatialJoinBatchIteratorParams {
+            schema: self.schema.clone(),
+            filter: self.filter.clone(),
+            join_type: self.join_type,
+            column_indices: self.column_indices.clone(),
+            build_side,
             spatial_index: spatial_index.clone(),
-            probe_evaluated_batch,
+            probe_evaluated_batch: Arc::new(probe_evaluated_batch),
             join_metrics: self.join_metrics.clone(),
             max_batch_size: self.target_output_batch_size,
             probe_side_ordered: self.probe_side_ordered,
             spatial_predicate: self.spatial_predicate.clone(),
             options: self.options.clone(),
-        })
+        })?;
+        Ok(Box::new(iterator))
     }
 }
 
@@ -454,10 +472,20 @@ struct PartialBuildBatch {
 
 /// Iterator that processes spatial join results in configurable batch sizes
 pub(crate) struct SpatialJoinBatchIterator {
+    /// Schema of the output record batches
+    schema: SchemaRef,
+    /// Optional join filter to be applied to the join results
+    filter: Option<JoinFilter>,
+    /// Type of the join operation
+    join_type: JoinType,
+    /// Information of index and left / right placement of columns
+    column_indices: Vec<ColumnIndex>,
+    /// The side of the build stream, either Left or Right
+    build_side: JoinSide,
     /// The spatial index reference
     spatial_index: Arc<SpatialIndex>,
     /// The probe side batch being processed
-    probe_evaluated_batch: EvaluatedBatch,
+    probe_evaluated_batch: Arc<EvaluatedBatch>,
     /// Current probe row index being processed
     current_probe_idx: usize,
     /// Join metrics for tracking performance
@@ -480,8 +508,13 @@ pub(crate) struct SpatialJoinBatchIterator {
 
 /// Parameters for creating a SpatialJoinBatchIterator
 pub(crate) struct SpatialJoinBatchIteratorParams {
+    pub schema: SchemaRef,
+    pub filter: Option<JoinFilter>,
+    pub join_type: JoinType,
+    pub column_indices: Vec<ColumnIndex>,
+    pub build_side: JoinSide,
     pub spatial_index: Arc<SpatialIndex>,
-    pub probe_evaluated_batch: EvaluatedBatch,
+    pub probe_evaluated_batch: Arc<EvaluatedBatch>,
     pub join_metrics: SpatialJoinProbeMetrics,
     pub max_batch_size: usize,
     pub probe_side_ordered: bool,
@@ -492,6 +525,11 @@ pub(crate) struct SpatialJoinBatchIteratorParams {
 impl SpatialJoinBatchIterator {
     pub(crate) fn new(params: SpatialJoinBatchIteratorParams) -> Result<Self> {
         Ok(Self {
+            schema: params.schema,
+            filter: params.filter,
+            join_type: params.join_type,
+            column_indices: params.column_indices,
+            build_side: params.build_side,
             spatial_index: params.spatial_index,
             probe_evaluated_batch: params.probe_evaluated_batch,
             current_probe_idx: 0,
@@ -506,28 +544,50 @@ impl SpatialJoinBatchIterator {
         })
     }
 
-    pub fn next_batch(
-        &mut self,
-        schema: &Schema,
-        filter: Option<&JoinFilter>,
-        join_type: JoinType,
-        column_indices: &[ColumnIndex],
-        build_side: JoinSide,
-    ) -> Result<Option<RecordBatch>> {
-        // Process probe rows incrementally until we have enough results or 
finish
-        let initial_size = self.build_batch_positions.len();
+    pub async fn next_batch(&mut self) -> Result<Option<RecordBatch>> {
+        if self.is_complete {
+            return Ok(None);
+        }
 
-        let geom_array = &self.probe_evaluated_batch.geom_array;
-        let wkbs = geom_array.wkbs();
-        let rects = &geom_array.rects;
-        let distance = &geom_array.distance;
+        let last_probe_idx = self.current_probe_idx;
+        match &self.spatial_predicate {
+            SpatialPredicate::KNearestNeighbors(_) => self.probe_knn()?,
+            _ => self.probe_range().await?,
+        };
 
-        let num_rows = wkbs.len();
+        // Check if we've finished processing all probe rows
+        if self.current_probe_idx >= self.probe_evaluated_batch.num_rows() {
+            self.is_complete = true;
+        }
 
-        let last_probe_idx = self.current_probe_idx;
+        if self.current_probe_idx > last_probe_idx {
+            // Process the joined indices to create a RecordBatch
+            let probe_indices = std::mem::take(&mut self.probe_indices);
+            let batch = self.process_joined_indices_to_batch(
+                &self.build_batch_positions,
+                probe_indices,
+                &self.schema,
+                self.filter.as_ref(),
+                self.join_type,
+                &self.column_indices,
+                self.build_side,
+                last_probe_idx..self.current_probe_idx,
+            )?;
+
+            self.build_batch_positions.clear();
+            Ok(Some(batch))
+        } else {
+            Ok(None)
+        }
+    }
+
+    fn probe_knn(&mut self) -> Result<()> {
+        let geom_array = &self.probe_evaluated_batch.geom_array;
+        let wkbs = geom_array.wkbs();
 
         // Process from current position until we hit batch size limit or 
complete
-        while self.current_probe_idx < num_rows && !self.is_complete {
+        let num_rows = wkbs.len();
+        while self.current_probe_idx < num_rows {
             // Get WKB for current probe index
             let wkb_opt = &wkbs[self.current_probe_idx];
 
@@ -537,65 +597,40 @@ impl SpatialJoinBatchIterator {
                 continue;
             };
 
-            let dist = match distance {
-                Some(dist) => distance_value_at(dist, self.current_probe_idx)?,
-                None => None,
-            };
-
             // Handle KNN queries differently from regular spatial joins
-            match &self.spatial_predicate {
-                SpatialPredicate::KNearestNeighbors(knn_predicate) => {
-                    // For KNN, call query_knn only once per probe geometry 
(not per rect)
-                    let k = knn_predicate.k;
-                    let use_spheroid = knn_predicate.use_spheroid;
-                    let include_tie_breakers = 
self.options.knn_include_tie_breakers;
-
-                    let join_result_metrics = self.spatial_index.query_knn(
-                        wkb,
-                        k,
-                        use_spheroid,
-                        include_tie_breakers,
-                        &mut self.build_batch_positions,
-                    )?;
-
-                    self.probe_indices.extend(std::iter::repeat_n(
-                        self.current_probe_idx as u32,
-                        join_result_metrics.count,
-                    ));
-
-                    self.join_metrics
-                        .join_result_candidates
-                        .add(join_result_metrics.candidate_count);
-                    self.join_metrics
-                        .join_result_count
-                        .add(join_result_metrics.count);
-                }
-                _ => {
-                    // Regular spatial join: process all rects for this probe 
index
-                    let rect_opt = &rects[self.current_probe_idx];
-                    if let Some(rect) = rect_opt {
-                        let join_result_metrics = self.spatial_index.query(
-                            wkb,
-                            rect,
-                            &dist,
-                            &mut self.build_batch_positions,
-                        )?;
-
-                        self.probe_indices.extend(std::iter::repeat_n(
-                            self.current_probe_idx as u32,
-                            join_result_metrics.count,
-                        ));
-
-                        self.join_metrics
-                            .join_result_candidates
-                            .add(join_result_metrics.candidate_count);
-                        self.join_metrics
-                            .join_result_count
-                            .add(join_result_metrics.count);
-                    }
-                }
+            if let SpatialPredicate::KNearestNeighbors(knn_predicate) = 
&self.spatial_predicate {
+                // For KNN, call query_knn only once per probe geometry (not 
per rect)
+                let k = knn_predicate.k;
+                let use_spheroid = knn_predicate.use_spheroid;
+                let include_tie_breakers = 
self.options.knn_include_tie_breakers;
+
+                let join_result_metrics = self.spatial_index.query_knn(
+                    wkb,
+                    k,
+                    use_spheroid,
+                    include_tie_breakers,
+                    &mut self.build_batch_positions,
+                )?;
+
+                self.probe_indices.extend(std::iter::repeat_n(
+                    self.current_probe_idx as u32,
+                    join_result_metrics.count,
+                ));
+
+                self.join_metrics
+                    .join_result_candidates
+                    .add(join_result_metrics.candidate_count);
+                self.join_metrics
+                    .join_result_count
+                    .add(join_result_metrics.count);
+            } else {
+                unreachable!("probe_knn called for non-KNN predicate");
             }
 
+            assert!(
+                self.probe_indices.len() == self.build_batch_positions.len(),
+                "Probe indices and build batch positions length should match"
+            );
             self.current_probe_idx += 1;
 
             // Early exit if we have enough results
@@ -604,31 +639,37 @@ impl SpatialJoinBatchIterator {
             }
         }
 
-        // Check if we've finished processing all probe rows
-        if self.current_probe_idx >= num_rows {
-            self.is_complete = true;
-        }
+        Ok(())
+    }
 
-        // Return accumulated results if we have any new ones or if we're 
complete
-        if self.build_batch_positions.len() > initial_size || self.is_complete 
{
-            // Process the joined indices to create a RecordBatch
-            let probe_indices = std::mem::take(&mut self.probe_indices);
-            let batch = self.process_joined_indices_to_batch(
-                &self.build_batch_positions,
-                probe_indices,
-                schema,
-                filter,
-                join_type,
-                column_indices,
-                build_side,
-                last_probe_idx..self.current_probe_idx,
-            )?;
+    async fn probe_range(&mut self) -> Result<()> {
+        let num_rows = self.probe_evaluated_batch.num_rows();
+        let range = self.current_probe_idx..num_rows;
 
-            self.build_batch_positions.clear();
-            Ok(Some(batch))
-        } else {
-            Ok(None)
-        }
+        let (metrics, next_row_idx) = self
+            .spatial_index
+            .query_batch(
+                &self.probe_evaluated_batch,
+                range,
+                self.max_batch_size,
+                &mut self.build_batch_positions,
+                &mut self.probe_indices,
+            )
+            .await?;
+
+        self.current_probe_idx = next_row_idx;
+
+        self.join_metrics
+            .join_result_candidates
+            .add(metrics.candidate_count);
+        self.join_metrics.join_result_count.add(metrics.count);
+
+        assert!(
+            self.probe_indices.len() == self.build_batch_positions.len(),
+            "Probe indices and build batch positions length should match"
+        );
+
+        Ok(())
     }
 
     /// Check if the iterator has finished processing


Reply via email to