This is an automated email from the ASF dual-hosted git repository.

kontinuation pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/sedona-db.git


The following commit(s) were added to refs/heads/main by this push:
     new 93a3c0a5 chore(rust/sedona-spatial-join): Split large join result 
batches into smaller ones (#525)
93a3c0a5 is described below

commit 93a3c0a5c6e623af286e2cfae9a7bbf683496446
Author: Kristin Cowalcijk <[email protected]>
AuthorDate: Tue Jan 20 11:17:38 2026 +0800

    chore(rust/sedona-spatial-join): Split large join result batches into 
smaller ones (#525)
    
    This is a follow-up of https://github.com/apache/sedona-db/pull/523. When 
executing queries with large windows on dense datasets, each probe row may be 
matched with millions of indexed rows. If we don't break large result batches 
generated by such index probing, we'll easily overshoot the memory limit when 
assembling join result batches.
    
    This patch splits large joined build-probe side indices into smaller pieces 
and gradually assembles result batches. This will greatly reduce the amount of 
memory required for producing join results for "cover all" probe rows. The code 
for properly slicing join result indices for various join types is a bit 
complicated. We have added fuzz tests to verify that it works correctly.
    
    Co-authored-by: Copilot <[email protected]>
---
 .../sedona-spatial-join/src/index/spatial_index.rs |  10 +-
 .../src/index/spatial_index_builder.rs             |   4 +-
 rust/sedona-spatial-join/src/stream.rs             | 567 ++++++++++++++++-----
 3 files changed, 459 insertions(+), 122 deletions(-)

diff --git a/rust/sedona-spatial-join/src/index/spatial_index.rs 
b/rust/sedona-spatial-join/src/index/spatial_index.rs
index 6f3e00d0..e5e69dd8 100644
--- a/rust/sedona-spatial-join/src/index/spatial_index.rs
+++ b/rust/sedona-spatial-join/src/index/spatial_index.rs
@@ -88,8 +88,8 @@ pub struct SpatialIndex {
     /// prepared geometries.
     pub(crate) geom_idx_vec: Vec<usize>,
 
-    /// Shared bitmap builders for visited left indices, one per batch
-    pub(crate) visited_left_side: Option<Mutex<Vec<BooleanBufferBuilder>>>,
+    /// Shared bitmap builders for visited build side indices, one per batch
+    pub(crate) visited_build_side: Option<Mutex<Vec<BooleanBufferBuilder>>>,
 
     /// Counter of running probe-threads, potentially able to update `bitmap`.
     /// Each time a probe thread finished probing the index, it will decrement 
the counter.
@@ -138,7 +138,7 @@ impl SpatialIndex {
             data_id_to_batch_pos: Vec::new(),
             indexed_batches: Vec::new(),
             geom_idx_vec: Vec::new(),
-            visited_left_side: None,
+            visited_build_side: None,
             probe_threads_counter,
             knn_components,
             reservation,
@@ -659,8 +659,8 @@ impl SpatialIndex {
 
     /// Get the bitmaps for tracking visited left-side indices. The bitmaps 
will be updated
     /// by the spatial join stream when producing output batches during index 
probing phase.
-    pub(crate) fn visited_left_side(&self) -> 
Option<&Mutex<Vec<BooleanBufferBuilder>>> {
-        self.visited_left_side.as_ref()
+    pub(crate) fn visited_build_side(&self) -> 
Option<&Mutex<Vec<BooleanBufferBuilder>>> {
+        self.visited_build_side.as_ref()
     }
 
     /// Decrements counter of running threads, and returns `true`
diff --git a/rust/sedona-spatial-join/src/index/spatial_index_builder.rs 
b/rust/sedona-spatial-join/src/index/spatial_index_builder.rs
index a9b08d7a..49e0d8c6 100644
--- a/rust/sedona-spatial-join/src/index/spatial_index_builder.rs
+++ b/rust/sedona-spatial-join/src/index/spatial_index_builder.rs
@@ -249,7 +249,7 @@ impl SpatialIndexBuilder {
 
         let (rtree, batch_pos_vec) = self.build_rtree()?;
         let geom_idx_vec = self.build_geom_idx_vec(&batch_pos_vec);
-        let visited_left_side = self.build_visited_bitmaps()?;
+        let visited_build_side = self.build_visited_bitmaps()?;
 
         let refiner = create_refiner(
             self.options.spatial_library,
@@ -282,7 +282,7 @@ impl SpatialIndexBuilder {
             data_id_to_batch_pos: batch_pos_vec,
             indexed_batches: self.indexed_batches,
             geom_idx_vec,
-            visited_left_side,
+            visited_build_side,
             probe_threads_counter: AtomicUsize::new(self.probe_threads_count),
             knn_components,
             reservation: self.reservation,
diff --git a/rust/sedona-spatial-join/src/stream.rs 
b/rust/sedona-spatial-join/src/stream.rs
index 4a01e6ef..6cf175c2 100644
--- a/rust/sedona-spatial-join/src/stream.rs
+++ b/rust/sedona-spatial-join/src/stream.rs
@@ -470,7 +470,7 @@ struct PartialBuildBatch {
     interleave_indices_map: HashMap<(i32, i32), usize>,
 }
 
-/// Iterator that processes spatial join results in configurable batch sizes
+/// Iterator that produces spatial join results for one probe batch
 pub(crate) struct SpatialJoinBatchIterator {
     /// Schema of the output record batches
     schema: SchemaRef,
@@ -486,24 +486,121 @@ pub(crate) struct SpatialJoinBatchIterator {
     spatial_index: Arc<SpatialIndex>,
     /// The probe side batch being processed
     probe_evaluated_batch: Arc<EvaluatedBatch>,
-    /// Current probe row index being processed
-    current_probe_idx: usize,
     /// Join metrics for tracking performance
     join_metrics: SpatialJoinProbeMetrics,
     /// Maximum batch size before yielding a result
     max_batch_size: usize,
     /// Maintains the order of the probe side
     probe_side_ordered: bool,
-    /// Current accumulated build batch positions
-    build_batch_positions: Vec<(i32, i32)>,
-    /// Current accumulated probe indices
-    probe_indices: Vec<u32>,
-    /// Whether iteration is complete
-    is_complete: bool,
     /// The spatial predicate being evaluated
     spatial_predicate: SpatialPredicate,
     /// The spatial join options
     options: SpatialJoinOptions,
+    /// Progress of probing
+    progress: Option<ProbeProgress>,
+}
+
+struct ProbeProgress {
+    /// Index of the probe row to be probed by 
[SpatialJoinBatchIterator::probe_range] or
+    /// [SpatialJoinBatchIterator::probe_knn].
+    current_probe_idx: usize,
+    /// Index of the lastly produced probe row. This field uses `-1` as a 
sentinel value
+    /// to represent "nothing produced yet" and is stored as `i64` instead of
+    /// `Option<usize>` to keep the layout compact and avoid extra branching 
and
+    /// wrapping/unwrapping in the hot probe loop. There are three cases:
+    /// - `-1` means nothing was produced yet.
+    /// - `>= num_rows` means we have produced all probe rows. The iterator is 
complete.
+    /// - within `[0, num_rows)` means we have produced up to this probe index 
(inclusive).
+    ///   The value is the largest probe row index that has matching build 
rows so far.
+    last_produced_probe_idx: i64,
+    /// Current accumulated build batch positions
+    build_batch_positions: Vec<(i32, i32)>,
+    /// Current accumulated probe indices. Should have the same length as 
`build_batch_positions`
+    probe_indices: Vec<u32>,
+    /// Cursor of the position in the `build_batch_positions` and 
`probe_indices` vectors
+    /// for tracking the progress of producing joined batches
+    pos: usize,
+}
+
+/// Type alias for a tuple of build and probe indices slices
+type BuildAndProbeIndices<'a> = (&'a [(i32, i32)], &'a [u32]);
+
+impl ProbeProgress {
+    fn indices_for_next_batch(
+        &mut self,
+        build_side: JoinSide,
+        join_type: JoinType,
+        max_batch_size: usize,
+    ) -> Option<BuildAndProbeIndices<'_>> {
+        let end = self.probe_indices.len();
+
+        // Advance the produced probe end index to skip already hit probe side 
rows
+        // when running probe-semi, probe-anti or probe-mark joins. This is 
because
+        // semi/anti/mark joins only care about whether a probe row has 
matches,
+        // and we don't want to produce duplicate unmatched probe rows when 
the same
+        // probe row P has multiple matches and we split probe_indices range 
into
+        // multiple pieces containing P.
+        let should_skip_lastly_produced_probe_rows = matches!(
+            (build_side, join_type),
+            (
+                JoinSide::Left,
+                JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark
+            ) | (
+                JoinSide::Right,
+                JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark
+            )
+        );
+        if should_skip_lastly_produced_probe_rows {
+            while self.pos < end
+                && self.probe_indices[self.pos] as i64 == 
self.last_produced_probe_idx
+            {
+                self.pos += 1;
+            }
+        }
+
+        if self.pos >= end {
+            // No more results to produce. Should switch to Probing or 
Complete state.
+            return None;
+        }
+
+        // Take a slice of the accumulated results to produce
+        let slice_end = (self.pos + max_batch_size).min(end);
+        let build_indices = &self.build_batch_positions[self.pos..slice_end];
+        let probe_indices = &self.probe_indices[self.pos..slice_end];
+        self.pos = slice_end;
+
+        Some((build_indices, probe_indices))
+    }
+
+    fn next_probe_range(&mut self, probe_indices: &[u32]) -> Range<usize> {
+        let last_produced_probe_idx = self.last_produced_probe_idx;
+        let start_probe_idx = if probe_indices[0] as i64 == 
last_produced_probe_idx {
+            last_produced_probe_idx as usize
+        } else {
+            (last_produced_probe_idx + 1) as usize
+        };
+        let end_probe_idx = {
+            let last_probe_idx = probe_indices[probe_indices.len() - 1] as 
usize;
+            self.last_produced_probe_idx = last_probe_idx as i64;
+            last_probe_idx + 1
+        };
+        start_probe_idx..end_probe_idx
+    }
+
+    fn last_probe_range(&mut self, num_rows: usize) -> Option<Range<usize>> {
+        // Check if we have already produced all probe rows. There are 2 cases:
+        // 1. The last produced probe index is at the end (the last row had 
matches)
+        // 2. We have already called produce_last_result_batch before. Ignore 
this call.
+        if self.last_produced_probe_idx + 1 >= num_rows as i64 {
+            self.last_produced_probe_idx = num_rows as i64;
+            return None;
+        }
+
+        let start_probe_idx = (self.last_produced_probe_idx + 1) as usize;
+        let end_probe_idx = num_rows;
+        self.last_produced_probe_idx = end_probe_idx as i64;
+        Some(start_probe_idx..end_probe_idx)
+    }
 }
 
 /// Parameters for creating a SpatialJoinBatchIterator
@@ -532,68 +629,117 @@ impl SpatialJoinBatchIterator {
             build_side: params.build_side,
             spatial_index: params.spatial_index,
             probe_evaluated_batch: params.probe_evaluated_batch,
-            current_probe_idx: 0,
             join_metrics: params.join_metrics,
             max_batch_size: params.max_batch_size,
             probe_side_ordered: params.probe_side_ordered,
-            build_batch_positions: Vec::new(),
-            probe_indices: Vec::new(),
-            is_complete: false,
             spatial_predicate: params.spatial_predicate,
             options: params.options,
+            progress: Some(ProbeProgress {
+                current_probe_idx: 0,
+                last_produced_probe_idx: -1,
+                build_batch_positions: Vec::new(),
+                probe_indices: Vec::new(),
+                pos: 0,
+            }),
         })
     }
 
     pub async fn next_batch(&mut self) -> Result<Option<RecordBatch>> {
-        if self.is_complete {
-            return Ok(None);
-        }
+        let progress_opt = std::mem::take(&mut self.progress);
+        let mut progress = progress_opt.expect("Progress should be available");
+        let res = self.next_batch_inner(&mut progress).await;
+        self.progress = Some(progress);
+        res
+    }
 
-        let last_probe_idx = self.current_probe_idx;
-        match &self.spatial_predicate {
-            SpatialPredicate::KNearestNeighbors(_) => self.probe_knn()?,
-            _ => self.probe_range().await?,
-        };
+    async fn next_batch_inner(&self, progress: &mut ProbeProgress) -> 
Result<Option<RecordBatch>> {
+        let num_rows = self.probe_evaluated_batch.num_rows();
+        loop {
+            // Check if we have produced results for the entire probe batch
+            if self.is_complete_inner(progress) {
+                return Ok(None);
+            }
 
-        // Check if we've finished processing all probe rows
-        if self.current_probe_idx >= self.probe_evaluated_batch.num_rows() {
-            self.is_complete = true;
-        }
+            // Check if we need to probe more rows
+            if progress.current_probe_idx < num_rows
+                && progress.probe_indices.len() < self.max_batch_size
+            {
+                match &self.spatial_predicate {
+                    SpatialPredicate::KNearestNeighbors(_) => 
self.probe_knn(progress)?,
+                    _ => self.probe_range(progress).await?,
+                }
+            }
 
-        if self.current_probe_idx > last_probe_idx {
-            // Process the joined indices to create a RecordBatch
-            let probe_indices = std::mem::take(&mut self.probe_indices);
-            let batch = self.process_joined_indices_to_batch(
-                &self.build_batch_positions,
-                probe_indices,
-                &self.schema,
-                self.filter.as_ref(),
-                self.join_type,
-                &self.column_indices,
-                self.build_side,
-                last_probe_idx..self.current_probe_idx,
-            )?;
+            // Produce result batch from accumulated results
+            let joined_batch_opt = if progress.pos < 
progress.probe_indices.len() {
+                let joined_batch_opt = self.produce_result_batch(progress)?;
+                if progress.probe_indices.len() - progress.pos < 
self.max_batch_size {
+                    // Drain produced portion of probe_indices to make it 
shorter, so that we can
+                    // probe more rows using self.probe() in the next 
iteration.
+                    self.drain_produced_indices(progress);
+                }
+                joined_batch_opt
+            } else {
+                // No more accumulated results even after probing, we must 
have reached the end
+                self.produce_last_result_batch(progress)?
+            };
 
-            self.build_batch_positions.clear();
-            Ok(Some(batch))
-        } else {
-            Ok(None)
+            if let Some(batch) = joined_batch_opt {
+                return Ok(Some(batch));
+            }
         }
     }
 
-    fn probe_knn(&mut self) -> Result<()> {
+    async fn probe_range(&self, progress: &mut ProbeProgress) -> Result<()> {
+        let num_rows = self.probe_evaluated_batch.num_rows();
+        let range = progress.current_probe_idx..num_rows;
+
+        // Calculate remaining capacity in the progress buffer to respect 
max_batch_size
+        let max_result_size = self
+            .max_batch_size
+            .saturating_sub(progress.probe_indices.len());
+
+        let (metrics, next_row_idx) = self
+            .spatial_index
+            .query_batch(
+                &self.probe_evaluated_batch,
+                range,
+                max_result_size,
+                &mut progress.build_batch_positions,
+                &mut progress.probe_indices,
+            )
+            .await?;
+
+        progress.current_probe_idx = next_row_idx;
+
+        self.join_metrics
+            .join_result_candidates
+            .add(metrics.candidate_count);
+        self.join_metrics.join_result_count.add(metrics.count);
+
+        assert!(
+            progress.probe_indices.len() == 
progress.build_batch_positions.len(),
+            "Probe indices and build batch positions length should match"
+        );
+
+        Ok(())
+    }
+
+    /// Process more probe rows and fill in the build_batch_positions and 
probe_indices
+    /// until we have filled in enough results or processed all probe rows.
+    fn probe_knn(&self, progress: &mut ProbeProgress) -> Result<()> {
         let geom_array = &self.probe_evaluated_batch.geom_array;
         let wkbs = geom_array.wkbs();
 
         // Process from current position until we hit batch size limit or 
complete
         let num_rows = wkbs.len();
-        while self.current_probe_idx < num_rows {
+        while progress.current_probe_idx < num_rows {
             // Get WKB for current probe index
-            let wkb_opt = &wkbs[self.current_probe_idx];
+            let wkb_opt = &wkbs[progress.current_probe_idx];
 
             let Some(wkb) = wkb_opt else {
                 // Move to next probe index
-                self.current_probe_idx += 1;
+                progress.current_probe_idx += 1;
                 continue;
             };
 
@@ -609,11 +755,11 @@ impl SpatialJoinBatchIterator {
                     k,
                     use_spheroid,
                     include_tie_breakers,
-                    &mut self.build_batch_positions,
+                    &mut progress.build_batch_positions,
                 )?;
 
-                self.probe_indices.extend(std::iter::repeat_n(
-                    self.current_probe_idx as u32,
+                progress.probe_indices.extend(std::iter::repeat_n(
+                    progress.current_probe_idx as u32,
                     join_result_metrics.count,
                 ));
 
@@ -628,13 +774,13 @@ impl SpatialJoinBatchIterator {
             }
 
             assert!(
-                self.probe_indices.len() == self.build_batch_positions.len(),
+                progress.probe_indices.len() == 
progress.build_batch_positions.len(),
                 "Probe indices and build batch positions length should match"
             );
-            self.current_probe_idx += 1;
+            progress.current_probe_idx += 1;
 
             // Early exit if we have enough results
-            if self.build_batch_positions.len() >= self.max_batch_size {
+            if progress.build_batch_positions.len() >= self.max_batch_size {
                 break;
             }
         }
@@ -642,54 +788,92 @@ impl SpatialJoinBatchIterator {
         Ok(())
     }
 
-    async fn probe_range(&mut self) -> Result<()> {
-        let num_rows = self.probe_evaluated_batch.num_rows();
-        let range = self.current_probe_idx..num_rows;
+    fn produce_result_batch(&self, progress: &mut ProbeProgress) -> 
Result<Option<RecordBatch>> {
+        let Some((build_indices, probe_indices)) =
+            progress.indices_for_next_batch(self.build_side, self.join_type, 
self.max_batch_size)
+        else {
+            // No more results to produce
+            return Ok(None);
+        };
 
-        let (metrics, next_row_idx) = self
-            .spatial_index
-            .query_batch(
-                &self.probe_evaluated_batch,
-                range,
-                self.max_batch_size,
-                &mut self.build_batch_positions,
-                &mut self.probe_indices,
-            )
-            .await?;
+        let (build_partial_batch, build_indices_array, probe_indices_array) =
+            self.produce_filtered_indices(build_indices, 
probe_indices.to_vec())?;
+
+        // Produce the final joined batch
+        if probe_indices_array.is_empty() {
+            return Ok(None);
+        }
+        let probe_indices = probe_indices_array.values().as_ref();
+        let probe_range = progress.next_probe_range(probe_indices);
+        let batch = self.build_joined_batch(
+            &build_partial_batch,
+            build_indices_array,
+            probe_indices_array.clone(),
+            probe_range,
+        )?;
 
-        self.current_probe_idx = next_row_idx;
+        if batch.num_rows() > 0 {
+            Ok(Some(batch))
+        } else {
+            Ok(None)
+        }
+    }
 
-        self.join_metrics
-            .join_result_candidates
-            .add(metrics.candidate_count);
-        self.join_metrics.join_result_count.add(metrics.count);
+    /// There might be unmatched results at the tail of the probe row range 
that has not been produced,
+    /// even after all matched build/probe row indices have been produced. 
This function produces
+    /// those unmatched results as a final batch.
+    fn produce_last_result_batch(
+        &self,
+        progress: &mut ProbeProgress,
+    ) -> Result<Option<RecordBatch>> {
+        // Ensure all probe rows have been probed, and all pending results 
have been produced
+        let num_rows = self.probe_evaluated_batch.num_rows();
+        assert_eq!(progress.current_probe_idx, num_rows);
+        assert_eq!(progress.pos, progress.probe_indices.len());
 
-        assert!(
-            self.probe_indices.len() == self.build_batch_positions.len(),
-            "Probe indices and build batch positions length should match"
-        );
+        let Some(probe_range) = progress.last_probe_range(num_rows) else {
+            return Ok(None);
+        };
 
-        Ok(())
+        // Produce unmatched results in range [last_produced_probe_idx + 1, 
num_rows)
+        let build_schema = self.spatial_index.schema();
+        let build_empty_batch = RecordBatch::new_empty(build_schema);
+        let build_indices_array = UInt64Array::from(Vec::<u64>::new());
+        let probe_indices_array = UInt32Array::from(Vec::<u32>::new());
+        let batch = self.build_joined_batch(
+            &build_empty_batch,
+            build_indices_array,
+            probe_indices_array,
+            probe_range,
+        )?;
+        Ok(Some(batch))
+    }
+
+    fn drain_produced_indices(&self, progress: &mut ProbeProgress) {
+        // Move everything after `pos` to the front
+        progress.build_batch_positions.drain(0..progress.pos);
+        progress.probe_indices.drain(0..progress.pos);
+        progress.pos = 0;
     }
 
     /// Check if the iterator has finished processing
     pub fn is_complete(&self) -> bool {
-        self.is_complete
+        let progress = self
+            .progress
+            .as_ref()
+            .expect("Progress should be available");
+        self.is_complete_inner(progress)
     }
 
-    /// Process joined indices and create a RecordBatch
-    #[allow(clippy::too_many_arguments)]
-    fn process_joined_indices_to_batch(
+    fn is_complete_inner(&self, progress: &ProbeProgress) -> bool {
+        progress.last_produced_probe_idx >= 
self.probe_evaluated_batch.batch.num_rows() as i64
+    }
+
+    fn produce_filtered_indices(
         &self,
         build_indices: &[(i32, i32)],
         probe_indices: Vec<u32>,
-        schema: &Schema,
-        filter: Option<&JoinFilter>,
-        join_type: JoinType,
-        column_indices: &[ColumnIndex],
-        build_side: JoinSide,
-        probe_range: Range<usize>,
-    ) -> Result<RecordBatch> {
+    ) -> Result<(RecordBatch, UInt64Array, UInt32Array)> {
         let PartialBuildBatch {
             batch: partial_build_batch,
             indices: build_indices,
@@ -697,21 +881,21 @@ impl SpatialJoinBatchIterator {
         } = self.assemble_partial_build_batch(build_indices)?;
         let probe_indices = UInt32Array::from(probe_indices);
 
-        let (build_indices, probe_indices) = match filter {
+        let (build_indices, probe_indices) = match &self.filter {
             Some(filter) => apply_join_filter_to_indices(
                 &partial_build_batch,
                 &self.probe_evaluated_batch.batch,
                 build_indices,
                 probe_indices,
                 filter,
-                build_side,
+                self.build_side,
             )?,
             None => (build_indices, probe_indices),
         };
 
-        // set the left bitmap
-        if need_produce_result_in_final(join_type) {
-            if let Some(visited_bitmaps) = 
self.spatial_index.visited_left_side() {
+        // set the build side bitmap
+        if need_produce_result_in_final(self.join_type) {
+            if let Some(visited_bitmaps) = 
self.spatial_index.visited_build_side() {
                 mark_build_side_rows_as_visited(
                     &build_indices,
                     &interleave_indices_map,
@@ -720,32 +904,36 @@ impl SpatialJoinBatchIterator {
             }
         }
 
-        // adjust the two side indices base on the join type
+        Ok((partial_build_batch, build_indices, probe_indices))
+    }
+
+    fn build_joined_batch(
+        &self,
+        partial_build_batch: &RecordBatch,
+        build_indices: UInt64Array,
+        probe_indices: UInt32Array,
+        probe_range: Range<usize>,
+    ) -> Result<RecordBatch> {
+        // adjust the two side indices based on the join type
         let (build_indices, probe_indices) = adjust_indices_by_join_type(
             build_indices,
             probe_indices,
             probe_range,
-            join_type,
+            self.join_type,
             self.probe_side_ordered,
         )?;
 
         // Build the final result batch
-        let result_batch = build_batch_from_indices(
-            schema,
-            &partial_build_batch,
+        build_batch_from_indices(
+            &self.schema,
+            partial_build_batch,
             &self.probe_evaluated_batch.batch,
             &build_indices,
             &probe_indices,
-            column_indices,
-            build_side,
-            join_type,
-        )?;
-
-        // Update metrics with actual output
-        self.join_metrics.output_batches.add(1);
-        self.join_metrics.output_rows.add(result_batch.num_rows());
-
-        Ok(result_batch)
+            &self.column_indices,
+            self.build_side,
+            self.join_type,
+        )
     }
 
     fn assemble_partial_build_batch(
@@ -861,13 +1049,6 @@ impl std::fmt::Debug for SpatialJoinBatchIterator {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
         f.debug_struct("SpatialJoinBatchIterator")
             .field("max_batch_size", &self.max_batch_size)
-            .field("current_probe_idx", &self.current_probe_idx)
-            .field("is_complete", &self.is_complete)
-            .field(
-                "build_batch_positions_len",
-                &self.build_batch_positions.len(),
-            )
-            .field("probe_indices_len", &self.probe_indices.len())
             .finish()
     }
 }
@@ -891,7 +1072,7 @@ impl UnmatchedBuildBatchIterator {
         spatial_index: Arc<SpatialIndex>,
         empty_right_batch: RecordBatch,
     ) -> Result<Self> {
-        let visited_left_side = spatial_index.visited_left_side();
+        let visited_left_side = spatial_index.visited_build_side();
         let Some(vec_visited_left_side) = visited_left_side else {
             return sedona_internal_err!("The bitmap for visited left side is 
not created");
         };
@@ -918,7 +1099,7 @@ impl UnmatchedBuildBatchIterator {
         build_side: JoinSide,
     ) -> Result<Option<RecordBatch>> {
         while self.current_batch_idx < self.total_batches && !self.is_complete 
{
-            let visited_left_side = self.spatial_index.visited_left_side();
+            let visited_left_side = self.spatial_index.visited_build_side();
             let Some(vec_visited_left_side) = visited_left_side else {
                 return sedona_internal_err!("The bitmap for visited left side 
is not created");
             };
@@ -982,6 +1163,8 @@ mod tests {
     use arrow::array::Int32Array;
     use arrow::datatypes::{DataType, Field, Schema};
     use arrow_array::cast::AsArray;
+    use rand::rngs::StdRng;
+    use rand::{Rng, SeedableRng};
 
     fn create_test_batches(
         num_batches: usize,
@@ -1202,4 +1385,158 @@ mod tests {
                 "Data mismatch when mapping back from assembled batch row {i} 
to original batch {original_batch_idx} row {original_row_idx}");
         }
     }
+
+    #[test]
+    fn test_produce_joined_indices() {
+        for max_batch_size in 1..20 {
+            verify_produce_probe_indices(&[], 0, max_batch_size);
+            verify_produce_probe_indices(&[0, 0, 0, 0], 1, max_batch_size);
+            verify_produce_probe_indices(&[0, 0, 0, 0], 10, max_batch_size);
+            verify_produce_probe_indices(&[3, 3, 3], 10, max_batch_size);
+            verify_produce_probe_indices(&[0, 0, 3, 3, 3, 6, 7], 10, 
max_batch_size);
+            verify_produce_probe_indices(&[0, 3, 3, 3, 4, 5, 5, 9], 10, 
max_batch_size);
+            verify_produce_probe_indices(&[0, 3, 3, 4, 5, 5, 9, 9], 10, 
max_batch_size);
+        }
+    }
+
+    #[test]
+    fn test_fuzz_produce_probe_indices() {
+        let num_rows_range = 0..100;
+        let max_batch_size_range = 1..100;
+        let match_probability = 0.5;
+        let num_matches_range = 1..100;
+        for seed in 0..1000 {
+            fuzz_produce_probe_indices(
+                num_rows_range.clone(),
+                max_batch_size_range.clone(),
+                match_probability,
+                num_matches_range.clone(),
+                seed,
+            );
+        }
+    }
+
+    fn fuzz_produce_probe_indices(
+        num_rows_range: Range<usize>,
+        max_batch_size_range: Range<usize>,
+        match_probability: f64,
+        num_matches_range: Range<usize>,
+        seed: u64,
+    ) {
+        let mut rng = StdRng::seed_from_u64(seed);
+        let num_rows = rng.random_range(num_rows_range);
+        let max_batch_size = rng.random_range(max_batch_size_range);
+        let mut probe_indices = Vec::with_capacity(num_rows);
+        for row in 0..num_rows {
+            let has_matches = rng.random_bool(match_probability);
+            if has_matches {
+                let num_matches = rng.random_range(num_matches_range.clone());
+                probe_indices.extend(std::iter::repeat_n(row as u32, 
num_matches));
+            }
+        }
+        verify_produce_probe_indices(&probe_indices, num_rows, max_batch_size);
+    }
+
+    fn verify_produce_probe_indices(probe_indices: &[u32], num_rows: usize, 
max_batch_size: usize) {
+        for join_type in [
+            JoinType::Inner,
+            JoinType::Left,
+            JoinType::Right,
+            JoinType::Full,
+            JoinType::LeftSemi,
+            JoinType::LeftAnti,
+            JoinType::LeftMark,
+            JoinType::RightSemi,
+            JoinType::RightAnti,
+            JoinType::RightMark,
+        ] {
+            let expected_probe_indices =
+                produce_probe_indices_once(probe_indices, num_rows, join_type);
+            let produced_probe_indices = produce_probe_indices_incrementally(
+                probe_indices,
+                num_rows,
+                max_batch_size,
+                join_type,
+            );
+            assert_eq!(
+                expected_probe_indices, produced_probe_indices,
+                "Fuzz test failed for num_rows: {}, max_batch_size: {}, 
probe_indices: {:?}",
+                num_rows, max_batch_size, probe_indices
+            );
+        }
+    }
+
+    fn produce_probe_indices_once(
+        probe_indices: &[u32],
+        num_rows: usize,
+        join_type: JoinType,
+    ) -> Vec<u32> {
+        let build_indices = UInt64Array::from(vec![0; probe_indices.len()]);
+        let probe_indices_array = UInt32Array::from(probe_indices.to_vec());
+        let probe_range = 0..num_rows;
+        let (_, result_probe_indices) = adjust_indices_by_join_type(
+            build_indices,
+            probe_indices_array,
+            probe_range,
+            join_type,
+            false,
+        )
+        .unwrap();
+        let mut expected_probe_indices = 
result_probe_indices.values().to_vec();
+        expected_probe_indices.sort();
+        expected_probe_indices
+    }
+
+    fn produce_probe_indices_incrementally(
+        probe_indices: &[u32],
+        num_rows: usize,
+        max_batch_size: usize,
+        join_type: JoinType,
+    ) -> Vec<u32> {
+        let build_batch_positions = vec![(0, 0); probe_indices.len()];
+        let mut progress = ProbeProgress {
+            current_probe_idx: 0,
+            last_produced_probe_idx: -1,
+            build_batch_positions,
+            probe_indices: probe_indices.to_vec(),
+            pos: 0,
+        };
+        let mut produced_probe_indices: Vec<u32> = Vec::new();
+        loop {
+            let Some((_, probe_indices)) =
+                progress.indices_for_next_batch(JoinSide::Left, join_type, 
max_batch_size)
+            else {
+                break;
+            };
+            let probe_indices = probe_indices.to_vec();
+            let adjust_range = progress.next_probe_range(&probe_indices);
+            let build_indices = UInt64Array::from(vec![0; 
probe_indices.len()]);
+            let probe_indices = UInt32Array::from(probe_indices);
+            let (_, result_probe_indices) = adjust_indices_by_join_type(
+                build_indices,
+                probe_indices,
+                adjust_range,
+                join_type,
+                false,
+            )
+            .unwrap();
+            
produced_probe_indices.extend(result_probe_indices.values().as_ref());
+        }
+        if let Some(last_range) = progress.last_probe_range(num_rows) {
+            let build_indices = UInt64Array::from(Vec::<u64>::new());
+            let probe_indices = UInt32Array::from(Vec::<u32>::new());
+            let (_, result_probe_indices) = adjust_indices_by_join_type(
+                build_indices,
+                probe_indices,
+                last_range,
+                join_type,
+                false,
+            )
+            .unwrap();
+            
produced_probe_indices.extend(result_probe_indices.values().as_ref());
+        }
+
+        produced_probe_indices.sort();
+        produced_probe_indices
+    }
 }

Reply via email to