pwrliang commented on code in PR #465:
URL: https://github.com/apache/sedona-db/pull/465#discussion_r2710640183
##########
rust/sedona-spatial-join/src/index/spatial_index.rs:
##########
@@ -441,1543 +138,78 @@ impl SpatialIndex {
/// * A tuple containing:
/// - `QueryResultMetrics`: Aggregated metrics (total matches and candidates) for the processed rows
/// - `usize`: The index of the next row to process (exclusive end of the processed range)
- pub(crate) async fn query_batch(
- self: &Arc<Self>,
+ async fn query_batch(
+ &self,
evaluated_batch: &Arc<EvaluatedBatch>,
range: Range<usize>,
max_result_size: usize,
build_batch_positions: &mut Vec<(i32, i32)>,
probe_indices: &mut Vec<u32>,
- ) -> Result<(QueryResultMetrics, usize)> {
- if range.is_empty() {
- return Ok((
- QueryResultMetrics {
- count: 0,
- candidate_count: 0,
- },
- range.start,
- ));
- }
-
- let rects = evaluated_batch.rects();
- let dist = evaluated_batch.distance();
- let mut total_candidates_count = 0;
- let mut total_count = 0;
- let mut current_row_idx = range.start;
- for row_idx in range {
- current_row_idx = row_idx;
- let Some(probe_rect) = rects[row_idx] else {
- continue;
- };
-
- let min = probe_rect.min();
- let max = probe_rect.max();
- let mut candidates = self.rtree.search(min.x, min.y, max.x, max.y);
- if candidates.is_empty() {
- continue;
- }
-
- let Some(probe_wkb) = evaluated_batch.wkb(row_idx) else {
- return sedona_internal_err!(
- "Failed to get WKB for row {} in evaluated batch",
- row_idx
- );
- };
-
- // Sort and dedup candidates to avoid duplicate results when we
index one geometry
- // using several boxes.
- candidates.sort_unstable();
- candidates.dedup();
-
- let distance = match dist {
- Some(dist_array) => distance_value_at(dist_array, row_idx)?,
- None => None,
- };
-
- // Refine the candidates retrieved from the r-tree index by
evaluating the actual spatial predicate
- let refine_chunk_size =
self.options.parallel_refinement_chunk_size;
- if refine_chunk_size == 0 || candidates.len() < refine_chunk_size
* 2 {
- // For small candidate sets, use refine synchronously
- let metrics =
- self.refine(probe_wkb, &candidates, &distance,
build_batch_positions)?;
- probe_indices.extend(std::iter::repeat_n(row_idx as u32,
metrics.count));
- total_count += metrics.count;
- total_candidates_count += metrics.candidate_count;
- } else {
- // For large candidate sets, spawn several tasks to
parallelize refinement
- let (metrics, positions) = self
- .refine_concurrently(
- evaluated_batch,
- row_idx,
- &candidates,
- distance,
- refine_chunk_size,
- )
- .await?;
- build_batch_positions.extend(positions);
- probe_indices.extend(std::iter::repeat_n(row_idx as u32,
metrics.count));
- total_count += metrics.count;
- total_candidates_count += metrics.candidate_count;
- }
-
- if total_count >= max_result_size {
- break;
- }
- }
-
- let end_idx = current_row_idx + 1;
- Ok((
- QueryResultMetrics {
- count: total_count,
- candidate_count: total_candidates_count,
- },
- end_idx,
- ))
- }
-
- async fn refine_concurrently(
- self: &Arc<Self>,
- evaluated_batch: &Arc<EvaluatedBatch>,
- row_idx: usize,
- candidates: &[u32],
- distance: Option<f64>,
- refine_chunk_size: usize,
- ) -> Result<(QueryResultMetrics, Vec<(i32, i32)>)> {
- let mut join_set = JoinSet::new();
- for (i, chunk) in candidates.chunks(refine_chunk_size).enumerate() {
- let cloned_evaluated_batch = Arc::clone(evaluated_batch);
- let chunk = chunk.to_vec();
- let index_ref = Arc::clone(self);
- join_set.spawn(async move {
- let Some(probe_wkb) = cloned_evaluated_batch.wkb(row_idx) else
{
- return (
- i,
- sedona_internal_err!(
- "Failed to get WKB for row {} in evaluated batch",
- row_idx
- ),
- );
- };
- let mut local_positions: Vec<(i32, i32)> =
Vec::with_capacity(chunk.len());
- let res = index_ref.refine(probe_wkb, &chunk, &distance, &mut
local_positions);
- (i, res.map(|r| (r, local_positions)))
- });
- }
-
- // Collect the results in order
- let mut refine_results = Vec::with_capacity(join_set.len());
- refine_results.resize_with(join_set.len(), || None);
- while let Some(res) = join_set.join_next().await {
- let (chunk_idx, refine_res) =
- res.map_err(|e| DataFusionError::External(Box::new(e)))?;
- let (metrics, positions) = refine_res?;
- refine_results[chunk_idx] = Some((metrics, positions));
- }
-
- let mut total_metrics = QueryResultMetrics {
- count: 0,
- candidate_count: 0,
- };
- let mut all_positions = Vec::with_capacity(candidates.len());
- for res in refine_results {
- let (metrics, positions) = res.expect("All chunks should be
processed");
- total_metrics.count += metrics.count;
- total_metrics.candidate_count += metrics.candidate_count;
- all_positions.extend(positions);
- }
-
- Ok((total_metrics, all_positions))
- }
-
- fn refine(
- &self,
- probe_wkb: &Wkb,
- candidates: &[u32],
- distance: &Option<f64>,
- build_batch_positions: &mut Vec<(i32, i32)>,
- ) -> Result<QueryResultMetrics> {
- let candidate_count = candidates.len();
-
- let mut index_query_results = Vec::with_capacity(candidate_count);
- for data_idx in candidates {
- let pos = self.data_id_to_batch_pos[*data_idx as usize];
- let (batch_idx, row_idx) = pos;
- let indexed_batch = &self.indexed_batches[batch_idx as usize];
- let build_wkb = indexed_batch.wkb(row_idx as usize);
- let Some(build_wkb) = build_wkb else {
- continue;
- };
- let distance = self.evaluator.resolve_distance(
- indexed_batch.distance(),
- row_idx as usize,
- distance,
- )?;
- let geom_idx = self.geom_idx_vec[*data_idx as usize];
- index_query_results.push(IndexQueryResult {
- wkb: build_wkb,
- distance,
- geom_idx,
- position: pos,
- });
- }
-
- if index_query_results.is_empty() {
- return Ok(QueryResultMetrics {
- count: 0,
- candidate_count,
- });
- }
-
- let results = self.refiner.refine(probe_wkb, &index_query_results)?;
- let num_results = results.len();
- build_batch_positions.extend(results);
-
- // Update refiner memory reservation
- self.refiner_reservation.resize(self.refiner.mem_usage())?;
-
- Ok(QueryResultMetrics {
- count: num_results,
- candidate_count,
- })
- }
+ ) -> Result<(QueryResultMetrics, usize)>;
/// Check if the index needs more probe statistics to determine the optimal execution mode.
///
/// # Returns
/// * `bool` - `true` if the index needs more probe statistics, `false` otherwise.
- pub(crate) fn need_more_probe_stats(&self) -> bool {
- self.refiner.need_more_probe_stats()
- }
-
+ fn need_more_probe_stats(&self) -> bool;
/// Merge the probe statistics into the index.
///
/// # Arguments
/// * `stats` - The probe statistics to merge.
- pub(crate) fn merge_probe_stats(&self, stats: GeoStatistics) {
- self.refiner.merge_probe_stats(stats);
- }
+ fn merge_probe_stats(&self, stats: GeoStatistics);
/// Get the bitmaps for tracking visited left-side indices. The bitmaps will be updated
/// by the spatial join stream when producing output batches during index probing phase.
- pub(crate) fn visited_left_side(&self) ->
Option<&Mutex<Vec<BooleanBufferBuilder>>> {
- self.visited_left_side.as_ref()
- }
-
+ fn visited_left_side(&self) -> Option<&Mutex<Vec<BooleanBufferBuilder>>>;
/// Decrements counter of running threads, and returns `true`
/// if caller is the last running thread
- pub(crate) fn report_probe_completed(&self) -> bool {
- self.probe_threads_counter.fetch_sub(1, Ordering::Relaxed) == 1
- }
-
+ fn report_probe_completed(&self) -> bool;
/// Get the memory usage of the refiner in bytes.
- pub(crate) fn get_refiner_mem_usage(&self) -> usize {
- self.refiner.mem_usage()
- }
-
+ fn get_refiner_mem_usage(&self) -> usize;
/// Get the actual execution mode used by the refiner
- pub(crate) fn get_actual_execution_mode(&self) -> ExecutionMode {
- self.refiner.actual_execution_mode()
- }
+ fn get_actual_execution_mode(&self) -> ExecutionMode;
}
-#[cfg(test)]
-mod tests {
- use crate::{
- index::{SpatialIndexBuilder, SpatialJoinBuildMetrics},
- operand_evaluator::EvaluatedGeometryArray,
- spatial_predicate::{KNNPredicate, RelationPredicate,
SpatialRelationType},
- };
-
- use super::*;
- use arrow_array::RecordBatch;
- use arrow_schema::{DataType, Field};
- use datafusion_common::JoinSide;
- use datafusion_execution::memory_pool::GreedyMemoryPool;
- use datafusion_expr::JoinType;
- use datafusion_physical_expr::expressions::Column;
- use geo_traits::Dimensions;
- use sedona_common::option::{ExecutionMode, SpatialJoinOptions};
- use sedona_geometry::wkb_factory::write_wkb_empty_point;
- use sedona_schema::datatypes::WKB_GEOMETRY;
- use sedona_testing::create::create_array;
-
- #[test]
- fn test_spatial_index_builder_empty() {
- let memory_pool = Arc::new(GreedyMemoryPool::new(1024 * 1024));
- let options = SpatialJoinOptions {
- execution_mode: ExecutionMode::PrepareBuild,
- ..Default::default()
- };
- let metrics = SpatialJoinBuildMetrics::default();
- let schema = Arc::new(arrow_schema::Schema::empty());
- let spatial_predicate =
SpatialPredicate::Relation(RelationPredicate::new(
- Arc::new(Column::new("geom", 0)),
- Arc::new(Column::new("geom", 1)),
- SpatialRelationType::Intersects,
- ));
-
- let builder = SpatialIndexBuilder::new(
- schema.clone(),
- spatial_predicate,
- options,
- JoinType::Inner,
- 4,
- memory_pool,
- metrics,
- )
- .unwrap();
-
- // Test finishing with empty data
- let index = builder.finish().unwrap();
- assert_eq!(index.schema(), schema);
- assert_eq!(index.indexed_batches.len(), 0);
- }
-
- #[test]
- fn test_spatial_index_builder_add_batch() {
- let memory_pool = Arc::new(GreedyMemoryPool::new(1024 * 1024));
- let options = SpatialJoinOptions {
- execution_mode: ExecutionMode::PrepareBuild,
- ..Default::default()
- };
- let metrics = SpatialJoinBuildMetrics::default();
-
- let spatial_predicate =
SpatialPredicate::Relation(RelationPredicate::new(
- Arc::new(Column::new("geom", 0)),
- Arc::new(Column::new("geom", 1)),
- SpatialRelationType::Intersects,
- ));
-
- // Create a simple test geometry batch
- let schema = Arc::new(arrow_schema::Schema::new(vec![Field::new(
- "geom",
- DataType::Binary,
- true,
- )]));
-
- let mut builder = SpatialIndexBuilder::new(
- schema.clone(),
- spatial_predicate,
- options,
- JoinType::Inner,
- 4,
- memory_pool,
- metrics,
- )
- .unwrap();
-
- let batch = RecordBatch::new_empty(schema.clone());
- let geom_batch = create_array(
- &[
- Some("POINT (0.25 0.25)"),
- Some("POINT (10 10)"),
- None,
- Some("POINT (0.25 0.25)"),
- ],
- &WKB_GEOMETRY,
- );
- let indexed_batch = EvaluatedBatch {
- batch,
- geom_array: EvaluatedGeometryArray::try_new(geom_batch,
&WKB_GEOMETRY).unwrap(),
- };
- builder.add_batch(indexed_batch).unwrap();
-
- let index = builder.finish().unwrap();
- assert_eq!(index.schema(), schema);
- assert_eq!(index.indexed_batches.len(), 1);
- }
-
- #[test]
- fn test_knn_query_execution_with_sample_data() {
- // Create a spatial index with sample geometry data
- let memory_pool = Arc::new(GreedyMemoryPool::new(1024 * 1024));
- let options = SpatialJoinOptions {
- execution_mode: ExecutionMode::PrepareBuild,
- ..Default::default()
- };
- let metrics = SpatialJoinBuildMetrics::default();
-
- let spatial_predicate =
SpatialPredicate::KNearestNeighbors(KNNPredicate::new(
- Arc::new(Column::new("geom", 0)),
- Arc::new(Column::new("geom", 1)),
- 5,
- false,
- JoinSide::Left,
- ));
-
- // Create sample geometry data - points at known locations
- let schema = Arc::new(arrow_schema::Schema::new(vec![Field::new(
- "geom",
- DataType::Binary,
- true,
- )]));
-
- let mut builder = SpatialIndexBuilder::new(
- schema.clone(),
- spatial_predicate,
- options,
- JoinType::Inner,
- 4,
- memory_pool,
- metrics,
- )
- .unwrap();
-
- let batch = RecordBatch::new_empty(schema.clone());
-
- // Create geometries at different distances from the query point (0, 0)
- let geom_batch = create_array(
- &[
- Some("POINT (1 0)"), // Distance: 1.0
- Some("POINT (0 2)"), // Distance: 2.0
- Some("POINT (3 0)"), // Distance: 3.0
- Some("POINT (0 4)"), // Distance: 4.0
- Some("POINT (5 0)"), // Distance: 5.0
- Some("POINT (2 2)"), // Distance: ~2.83
- Some("POINT (1 1)"), // Distance: ~1.41
- ],
- &WKB_GEOMETRY,
- );
-
- let indexed_batch = EvaluatedBatch {
- batch,
- geom_array: EvaluatedGeometryArray::try_new(geom_batch,
&WKB_GEOMETRY).unwrap(),
- };
- builder.add_batch(indexed_batch).unwrap();
-
- let index = builder.finish().unwrap();
-
- // Create a query geometry at origin (0, 0)
- let query_geom = create_array(&[Some("POINT (0 0)")], &WKB_GEOMETRY);
- let query_array = EvaluatedGeometryArray::try_new(query_geom,
&WKB_GEOMETRY).unwrap();
- let query_wkb = &query_array.wkbs()[0].as_ref().unwrap();
-
- // Test KNN query with k=3
- let mut build_positions = Vec::new();
- let result = index
- .query_knn(
- query_wkb,
- 3, // k=3
- false, // use_spheroid=false
- false, // include_tie_breakers=false
- &mut build_positions,
- )
- .unwrap();
-
- // Verify we got 3 results
- assert_eq!(build_positions.len(), 3);
- assert_eq!(result.count, 3);
- assert!(result.candidate_count >= 3);
-
- // Create a mapping of positions to verify correct ordering
- // We expect the 3 closest points: (1,0), (1,1), (0,2)
- let expected_closest_indices = vec![0, 6, 1]; // Based on our sample
data ordering
- let mut found_indices = Vec::new();
-
- for (_batch_idx, row_idx) in &build_positions {
- found_indices.push(*row_idx as usize);
- }
-
- // Sort to compare sets (order might vary due to implementation)
- found_indices.sort();
- let mut expected_sorted = expected_closest_indices;
- expected_sorted.sort();
-
- assert_eq!(found_indices, expected_sorted);
- }
-
- #[test]
- fn test_knn_query_execution_with_different_k_values() {
- // Create spatial index with more data points
- let memory_pool = Arc::new(GreedyMemoryPool::new(1024 * 1024));
- let options = SpatialJoinOptions {
- execution_mode: ExecutionMode::PrepareBuild,
- ..Default::default()
- };
- let metrics = SpatialJoinBuildMetrics::default();
-
- let spatial_predicate =
SpatialPredicate::KNearestNeighbors(KNNPredicate::new(
- Arc::new(Column::new("geom", 0)),
- Arc::new(Column::new("geom", 1)),
- 5,
- false,
- JoinSide::Left,
- ));
-
- let schema = Arc::new(arrow_schema::Schema::new(vec![Field::new(
- "geom",
- DataType::Binary,
- true,
- )]));
-
- let mut builder = SpatialIndexBuilder::new(
- schema.clone(),
- spatial_predicate,
- options,
- JoinType::Inner,
- 4,
- memory_pool,
- metrics,
- )
- .unwrap();
-
- let batch = RecordBatch::new_empty(schema.clone());
-
- // Create 10 points at regular intervals
- let geom_batch = create_array(
- &[
- Some("POINT (1 0)"), // 0: Distance 1
- Some("POINT (2 0)"), // 1: Distance 2
- Some("POINT (3 0)"), // 2: Distance 3
- Some("POINT (4 0)"), // 3: Distance 4
- Some("POINT (5 0)"), // 4: Distance 5
- Some("POINT (6 0)"), // 5: Distance 6
- Some("POINT (7 0)"), // 6: Distance 7
- Some("POINT (8 0)"), // 7: Distance 8
- Some("POINT (9 0)"), // 8: Distance 9
- Some("POINT (10 0)"), // 9: Distance 10
- ],
- &WKB_GEOMETRY,
- );
-
- let indexed_batch = EvaluatedBatch {
- batch,
- geom_array: EvaluatedGeometryArray::try_new(geom_batch,
&WKB_GEOMETRY).unwrap(),
- };
- builder.add_batch(indexed_batch).unwrap();
-
- let index = builder.finish().unwrap();
-
- // Query point at origin
- let query_geom = create_array(&[Some("POINT (0 0)")], &WKB_GEOMETRY);
- let query_array = EvaluatedGeometryArray::try_new(query_geom,
&WKB_GEOMETRY).unwrap();
- let query_wkb = &query_array.wkbs()[0].as_ref().unwrap();
-
- // Test different k values
- for k in [1, 3, 5, 7, 10] {
- let mut build_positions = Vec::new();
- let result = index
- .query_knn(query_wkb, k, false, false, &mut build_positions)
- .unwrap();
-
- // Verify we got exactly k results (or all available if k > total)
- let expected_results = std::cmp::min(k as usize, 10);
- assert_eq!(build_positions.len(), expected_results);
- assert_eq!(result.count, expected_results);
-
- // Verify the results are the k closest points
- let mut row_indices: Vec<usize> = build_positions
- .iter()
- .map(|(_, row_idx)| *row_idx as usize)
- .collect();
- row_indices.sort();
-
- let expected_indices: Vec<usize> = (0..expected_results).collect();
- assert_eq!(row_indices, expected_indices);
- }
- }
-
- #[test]
- fn test_knn_query_execution_with_spheroid_distance() {
- // Create spatial index
- let memory_pool = Arc::new(GreedyMemoryPool::new(1024 * 1024));
- let options = SpatialJoinOptions {
- execution_mode: ExecutionMode::PrepareBuild,
- ..Default::default()
- };
- let metrics = SpatialJoinBuildMetrics::default();
-
- let spatial_predicate =
SpatialPredicate::KNearestNeighbors(KNNPredicate::new(
- Arc::new(Column::new("geom", 0)),
- Arc::new(Column::new("geom", 1)),
- 5,
- true,
- JoinSide::Left,
- ));
-
- let schema = Arc::new(arrow_schema::Schema::new(vec![Field::new(
- "geom",
- DataType::Binary,
- true,
- )]));
-
- let mut builder = SpatialIndexBuilder::new(
- schema.clone(),
- spatial_predicate,
- options,
- JoinType::Inner,
- 4,
- memory_pool,
- metrics,
- )
- .unwrap();
-
- let batch = RecordBatch::new_empty(schema.clone());
-
- // Create points with geographic coordinates (longitude, latitude)
- let geom_batch = create_array(
- &[
- Some("POINT (-74.0 40.7)"), // NYC area
- Some("POINT (-73.9 40.7)"), // Slightly east
- Some("POINT (-74.1 40.7)"), // Slightly west
- Some("POINT (-74.0 40.8)"), // Slightly north
- Some("POINT (-74.0 40.6)"), // Slightly south
- ],
- &WKB_GEOMETRY,
- );
-
- let indexed_batch = EvaluatedBatch {
- batch,
- geom_array: EvaluatedGeometryArray::try_new(geom_batch,
&WKB_GEOMETRY).unwrap(),
- };
- builder.add_batch(indexed_batch).unwrap();
+pub(crate) trait SpatialIndexFull: SpatialIndex + SpatialIndexInternal {}
- let index = builder.finish().unwrap();
+impl<T> SpatialIndexFull for T where T: SpatialIndex + SpatialIndexInternal {}
Review Comment:
I have fixed this. There is now only a single, unified interface.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]