This is an automated email from the ASF dual-hosted git repository.
kontinuation pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/sedona-db.git
The following commit(s) were added to refs/heads/main by this push:
new 5102eedc feat(rust/sedona-spatial-join) Add partitioned index provider
(#555)
5102eedc is described below
commit 5102eedc6f9f48d68f49208794ccafb5eceb0213
Author: Kristin Cowalcijk <[email protected]>
AuthorDate: Thu Jan 29 22:17:46 2026 +0800
feat(rust/sedona-spatial-join) Add partitioned index provider (#555)
This patch adds an index provider for coordinating the creation of spatial
indexes for specified partitions. It is also integrated into `SpatialJoinExec` so
we use it to create the spatial index even when there's only one spatial
partition (the degenerate case). The handling for multiple spatial partitions
will be added in a subsequent PR.
The memory reservations grown in the build side collection phase will be
held by `PartitionedIndexProvider`. Spatial indexes created by the provider
do not need to hold memory reservations.
The next step is to support a partitioned probe side by adding a
`PartitionedProbeStreamProvider` and modifying the state machine of
`SpatialJoinStream` to process multiple spatial partitions sequentially.
---
rust/sedona-spatial-join/Cargo.toml | 1 +
rust/sedona-spatial-join/src/build_index.rs | 1 -
rust/sedona-spatial-join/src/exec.rs | 50 +-
rust/sedona-spatial-join/src/index.rs | 2 +
.../src/index/build_side_collector.rs | 29 +-
rust/sedona-spatial-join/src/index/memory_plan.rs | 191 +++++++
.../src/index/partitioned_index_provider.rs | 602 +++++++++++++++++++++
.../sedona-spatial-join/src/index/spatial_index.rs | 35 --
.../src/index/spatial_index_builder.rs | 38 +-
rust/sedona-spatial-join/src/lib.rs | 3 +-
rust/sedona-spatial-join/src/partitioning/kdb.rs | 35 +-
.../src/partitioning/stream_repartitioner.rs | 7 +
rust/sedona-spatial-join/src/prepare.rs | 514 ++++++++++++++++++
rust/sedona-spatial-join/src/stream.rs | 135 ++++-
rust/sedona-spatial-join/src/utils.rs | 1 +
rust/sedona-spatial-join/src/utils/bbox_sampler.rs | 1 -
.../src/utils/disposable_async_cell.rs | 204 +++++++
17 files changed, 1734 insertions(+), 115 deletions(-)
diff --git a/rust/sedona-spatial-join/Cargo.toml
b/rust/sedona-spatial-join/Cargo.toml
index 322ec572..d34f7a6c 100644
--- a/rust/sedona-spatial-join/Cargo.toml
+++ b/rust/sedona-spatial-join/Cargo.toml
@@ -48,6 +48,7 @@ futures = { workspace = true }
pin-project-lite = { workspace = true }
once_cell = { workspace = true }
parking_lot = { workspace = true }
+tokio = { workspace = true }
geo = { workspace = true }
sedona-geo-generic-alg = { workspace = true }
geo-traits = { workspace = true, features = ["geo-types"] }
diff --git a/rust/sedona-spatial-join/src/build_index.rs
b/rust/sedona-spatial-join/src/build_index.rs
index f3cbb34b..0e292369 100644
--- a/rust/sedona-spatial-join/src/build_index.rs
+++ b/rust/sedona-spatial-join/src/build_index.rs
@@ -105,7 +105,6 @@ pub async fn build_index(
sedona_options.spatial_join,
join_type,
probe_threads_count,
- Arc::clone(memory_pool),
SpatialJoinBuildMetrics::new(0, &metrics),
)?;
index_builder.add_partitions(build_partitions).await?;
diff --git a/rust/sedona-spatial-join/src/exec.rs
b/rust/sedona-spatial-join/src/exec.rs
index 50cbd171..495518ea 100644
--- a/rust/sedona-spatial-join/src/exec.rs
+++ b/rust/sedona-spatial-join/src/exec.rs
@@ -36,12 +36,13 @@ use parking_lot::Mutex;
use sedona_common::SpatialJoinOptions;
use crate::{
- build_index::build_index,
- index::SpatialIndex,
+ prepare::{SpatialJoinComponents, SpatialJoinComponentsBuilder},
spatial_predicate::{KNNPredicate, SpatialPredicate},
stream::{SpatialJoinProbeMetrics, SpatialJoinStream},
- utils::join_utils::{asymmetric_join_output_partitioning,
boundedness_from_children},
- utils::once_fut::OnceAsync,
+ utils::{
+ join_utils::{asymmetric_join_output_partitioning,
boundedness_from_children},
+ once_fut::OnceAsync,
+ },
SedonaOptions,
};
@@ -132,9 +133,10 @@ pub struct SpatialJoinExec {
column_indices: Vec<ColumnIndex>,
/// Cache holding plan properties like equivalences, output partitioning
etc.
cache: PlanProperties,
- /// Spatial index built asynchronously on first execute() call and shared
across all partitions.
- /// Uses OnceAsync for lazy initialization coordinated via async runtime.
- once_async_spatial_index: Arc<Mutex<Option<OnceAsync<SpatialIndex>>>>,
+ /// Once future for creating the partitioned index provider shared by all
probe partitions.
+ /// This future runs only once before probing starts, and can be disposed
by the last finished
+ /// stream so the provider does not outlive the execution plan
unnecessarily.
+ once_async_spatial_join_components:
Arc<Mutex<Option<OnceAsync<SpatialJoinComponents>>>>,
/// Indicates if this SpatialJoin was converted from a HashJoin
/// When true, we preserve HashJoin's equivalence properties and
partitioning
converted_from_hash_join: bool,
@@ -203,7 +205,7 @@ impl SpatialJoinExec {
projection,
metrics: Default::default(),
cache,
- once_async_spatial_index: Arc::new(Mutex::new(None)),
+ once_async_spatial_join_components: Arc::new(Mutex::new(None)),
converted_from_hash_join,
seed,
})
@@ -431,7 +433,7 @@ impl ExecutionPlan for SpatialJoinExec {
projection: self.projection.clone(),
metrics: Default::default(),
cache: self.cache.clone(),
- once_async_spatial_index: Arc::new(Mutex::new(None)),
+ once_async_spatial_join_components: Arc::new(Mutex::new(None)),
converted_from_hash_join: self.converted_from_hash_join,
seed: self.seed,
}))
@@ -463,8 +465,8 @@ impl ExecutionPlan for SpatialJoinExec {
let (build_plan, probe_plan) = (&self.left, &self.right);
// Build the spatial index using shared OnceAsync
- let once_fut_spatial_index = {
- let mut once_async = self.once_async_spatial_index.lock();
+ let once_fut_spatial_join_components = {
+ let mut once_async =
self.once_async_spatial_join_components.lock();
once_async
.get_or_insert(OnceAsync::default())
.try_once(|| {
@@ -479,16 +481,16 @@ impl ExecutionPlan for SpatialJoinExec {
let probe_thread_count =
self.right.output_partitioning().partition_count();
- Ok(build_index(
+ let spatial_join_components_builder =
SpatialJoinComponentsBuilder::new(
Arc::clone(&context),
build_side.schema(),
- build_streams,
self.on.clone(),
self.join_type,
probe_thread_count,
self.metrics.clone(),
self.seed,
- ))
+ );
+
Ok(spatial_join_components_builder.build(build_streams))
})?
};
@@ -508,6 +510,7 @@ impl ExecutionPlan for SpatialJoinExec {
self.maintains_input_order()[1] &&
self.right.output_ordering().is_some();
Ok(Box::pin(SpatialJoinStream::new(
+ partition,
self.schema(),
&self.on,
self.filter.clone(),
@@ -518,8 +521,8 @@ impl ExecutionPlan for SpatialJoinExec {
join_metrics,
sedona_options.spatial_join,
target_output_batch_size,
- once_fut_spatial_index,
- Arc::clone(&self.once_async_spatial_index),
+ once_fut_spatial_join_components,
+ Arc::clone(&self.once_async_spatial_join_components),
)))
}
}
@@ -556,8 +559,8 @@ impl SpatialJoinExec {
let actual_probe_plan_is_left = std::ptr::eq(probe_plan.as_ref(),
self.left.as_ref());
// Build the spatial index
- let once_fut_spatial_index = {
- let mut once_async = self.once_async_spatial_index.lock();
+ let once_fut_spatial_join_components = {
+ let mut once_async =
self.once_async_spatial_join_components.lock();
once_async
.get_or_insert(OnceAsync::default())
.try_once(|| {
@@ -571,16 +574,16 @@ impl SpatialJoinExec {
}
let probe_thread_count =
probe_plan.output_partitioning().partition_count();
- Ok(build_index(
+ let spatial_join_components_builder =
SpatialJoinComponentsBuilder::new(
Arc::clone(&context),
build_side.schema(),
- build_streams,
self.on.clone(),
self.join_type,
probe_thread_count,
self.metrics.clone(),
self.seed,
- ))
+ );
+ Ok(spatial_join_components_builder.build(build_streams))
})?
};
@@ -605,6 +608,7 @@ impl SpatialJoinExec {
};
Ok(Box::pin(SpatialJoinStream::new(
+ partition,
self.schema(),
&self.on,
self.filter.clone(),
@@ -615,8 +619,8 @@ impl SpatialJoinExec {
join_metrics,
sedona_options.spatial_join,
target_output_batch_size,
- once_fut_spatial_index,
- Arc::clone(&self.once_async_spatial_index),
+ once_fut_spatial_join_components,
+ Arc::clone(&self.once_async_spatial_join_components),
)))
}
}
diff --git a/rust/sedona-spatial-join/src/index.rs
b/rust/sedona-spatial-join/src/index.rs
index 55df23d5..af31b8af 100644
--- a/rust/sedona-spatial-join/src/index.rs
+++ b/rust/sedona-spatial-join/src/index.rs
@@ -17,6 +17,8 @@
pub(crate) mod build_side_collector;
mod knn_adapter;
+pub(crate) mod memory_plan;
+pub(crate) mod partitioned_index_provider;
pub(crate) mod spatial_index;
pub(crate) mod spatial_index_builder;
diff --git a/rust/sedona-spatial-join/src/index/build_side_collector.rs
b/rust/sedona-spatial-join/src/index/build_side_collector.rs
index 646c6be2..d888680f 100644
--- a/rust/sedona-spatial-join/src/index/build_side_collector.rs
+++ b/rust/sedona-spatial-join/src/index/build_side_collector.rs
@@ -68,6 +68,9 @@ pub(crate) struct BuildPartition {
/// The size of this reservation will be used to determine the maximum
size of
/// each spatial partition, as well as how many spatial partitions to
create.
pub reservation: MemoryReservation,
+
+ /// Metrics collected during the build side collection phase
+ pub metrics: CollectBuildSideMetrics,
}
/// A collector for evaluating the spatial expression on build side batches
and collect
@@ -112,6 +115,10 @@ impl CollectBuildSideMetrics {
spill_metrics: SpillMetrics::new(metrics, partition),
}
}
+
+ pub fn spill_metrics(&self) -> SpillMetrics {
+ self.spill_metrics.clone()
+ }
}
impl BuildSideBatchesCollector {
@@ -147,7 +154,7 @@ impl BuildSideBatchesCollector {
mut stream: SendableEvaluatedBatchStream,
mut reservation: MemoryReservation,
mut bbox_sampler: BoundingBoxSampler,
- metrics: &CollectBuildSideMetrics,
+ metrics: CollectBuildSideMetrics,
) -> Result<BuildPartition> {
let mut spill_writer_opt = None;
let mut in_mem_batches: Vec<EvaluatedBatch> = Vec::new();
@@ -200,7 +207,7 @@ impl BuildSideBatchesCollector {
e,
);
spill_writer_opt =
- self.spill_in_mem_batches(&mut in_mem_batches,
metrics)?;
+ self.spill_in_mem_batches(&mut in_mem_batches,
&metrics)?;
}
}
Some(spill_writer) => {
@@ -236,7 +243,7 @@ impl BuildSideBatchesCollector {
"Force spilling enabled. Spilling {} in-memory batches to
disk.",
in_mem_batches.len()
);
- spill_writer_opt = self.spill_in_mem_batches(&mut in_mem_batches,
metrics)?;
+ spill_writer_opt = self.spill_in_mem_batches(&mut in_mem_batches,
&metrics)?;
}
let build_side_batch_stream: SendableEvaluatedBatchStream = match
spill_writer_opt {
@@ -266,6 +273,7 @@ impl BuildSideBatchesCollector {
bbox_samples: bbox_sampler.into_samples(),
estimated_spatial_index_memory_usage,
reservation,
+ metrics,
})
}
@@ -329,7 +337,7 @@ impl BuildSideBatchesCollector {
let evaluated_stream =
create_evaluated_build_stream(stream, evaluator,
metrics.time_taken.clone());
let result = collector
- .collect(evaluated_stream, reservation, bbox_sampler,
&metrics)
+ .collect(evaluated_stream, reservation, bbox_sampler,
metrics)
.await;
(partition_id, result)
});
@@ -378,7 +386,7 @@ impl BuildSideBatchesCollector {
let evaluated_stream =
create_evaluated_build_stream(stream, evaluator,
metrics.time_taken.clone());
let result = self
- .collect(evaluated_stream, reservation, bbox_sampler, &metrics)
+ .collect(evaluated_stream, reservation, bbox_sampler, metrics)
.await?;
results.push(result);
}
@@ -534,11 +542,12 @@ mod tests {
let metrics = CollectBuildSideMetrics::new(0, &metrics_set);
let partition = collector
- .collect(stream, reservation, sampler, &metrics)
+ .collect(stream, reservation, sampler, metrics)
.await?;
let stream = partition.build_side_batch_stream;
let is_external = stream.is_external();
let batches: Vec<EvaluatedBatch> = stream.try_collect().await?;
+ let metrics = &partition.metrics;
assert!(!is_external, "Expected in-memory batches");
assert_eq!(collect_ids(&batches), vec![0, 1, 2]);
assert_eq!(partition.num_rows, 3);
@@ -564,14 +573,15 @@ mod tests {
let metrics = CollectBuildSideMetrics::new(0, &metrics_set);
let partition = collector
- .collect(stream, reservation, sampler, &metrics)
+ .collect(stream, reservation, sampler, metrics)
.await?;
let stream = partition.build_side_batch_stream;
let is_external = stream.is_external();
let batches: Vec<EvaluatedBatch> = stream.try_collect().await?;
+ let metrics = &partition.metrics;
assert!(is_external, "Expected batches to spill to disk");
assert_eq!(collect_ids(&batches), vec![10, 11, 12]);
- let spill_metrics = metrics.spill_metrics;
+ let spill_metrics = metrics.spill_metrics();
assert!(spill_metrics.spill_file_count.value() >= 1);
assert!(spill_metrics.spilled_rows.value() >= 1);
Ok(())
@@ -587,12 +597,13 @@ mod tests {
let metrics = CollectBuildSideMetrics::new(0, &metrics_set);
let partition = collector
- .collect(stream, reservation, sampler, &metrics)
+ .collect(stream, reservation, sampler, metrics)
.await?;
assert_eq!(partition.num_rows, 0);
let stream = partition.build_side_batch_stream;
let is_external = stream.is_external();
let batches: Vec<EvaluatedBatch> = stream.try_collect().await?;
+ let metrics = &partition.metrics;
assert!(!is_external);
assert!(batches.is_empty());
assert_eq!(metrics.num_batches.value(), 0);
diff --git a/rust/sedona-spatial-join/src/index/memory_plan.rs
b/rust/sedona-spatial-join/src/index/memory_plan.rs
new file mode 100644
index 00000000..24a25c89
--- /dev/null
+++ b/rust/sedona-spatial-join/src/index/memory_plan.rs
@@ -0,0 +1,191 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::cmp::max;
+
+use datafusion_common::{DataFusionError, Result};
+
+use super::BuildPartition;
+
+/// The memory accounting summary of a build side partition. This is collected
+/// during the build side collection phase and used to estimate the memory
usage for
+/// running spatial join.
+#[derive(Clone, Copy, Debug, PartialEq, Eq)]
+pub(crate) struct PartitionMemorySummary {
+ /// Number of rows in the partition.
+ pub num_rows: usize,
+ /// The total memory reserved when collecting this build side partition.
+ pub reserved_memory: usize,
+ /// The estimated memory usage for building the spatial index for all the
data in
+ /// this build side partition.
+ pub estimated_index_memory_usage: usize,
+}
+
+impl From<&BuildPartition> for PartitionMemorySummary {
+ fn from(partition: &BuildPartition) -> Self {
+ Self {
+ num_rows: partition.num_rows,
+ reserved_memory: partition.reservation.size(),
+ estimated_index_memory_usage:
partition.estimated_spatial_index_memory_usage,
+ }
+ }
+}
+
+/// A detailed plan for memory usage during spatial join execution. The
spatial join
+/// could be spatial-partitioned if the reserved memory is not sufficient to
hold the
+/// entire spatial index.
+#[derive(Debug, PartialEq, Eq)]
+pub(crate) struct MemoryPlan {
+ /// The total number of rows in the build side.
+ pub num_rows: usize,
+ /// The total memory reserved for the build side.
+ pub reserved_memory: usize,
+ /// The estimated memory usage for building the spatial index for the
entire build side.
+ /// It could be larger than [`Self::reserved_memory`], and in that case we
need to
+ /// partition the build side using spatial partitioning.
+ pub estimated_index_memory_usage: usize,
+ /// The memory budget for holding the spatial index. If the spatial join
is partitioned,
+ /// this is the memory budget for holding the spatial index of a single
partition.
+ pub memory_for_spatial_index: usize,
+ /// The memory budget for intermittent usage, such as buffering data
during repartitioning.
+ pub memory_for_intermittent_usage: usize,
+ /// The number of spatial partitions to split the build side into.
+ pub num_partitions: usize,
+}
+
+/// Compute the memory plan for running spatial join based on the memory
summaries of
+/// build side partitions.
+pub(crate) fn compute_memory_plan<I>(partition_summaries: I) ->
Result<MemoryPlan>
+where
+ I: IntoIterator<Item = PartitionMemorySummary>,
+{
+ let mut num_rows = 0;
+ let mut reserved_memory = 0;
+ let mut estimated_index_memory_usage = 0;
+
+ for summary in partition_summaries {
+ num_rows += summary.num_rows;
+ reserved_memory += summary.reserved_memory;
+ estimated_index_memory_usage += summary.estimated_index_memory_usage;
+ }
+
+ if reserved_memory == 0 && num_rows > 0 {
+ return Err(DataFusionError::ResourcesExhausted(
+ "Insufficient memory for spatial join".to_string(),
+ ));
+ }
+
+ // Use 80% of reserved memory for holding the spatial index. The other 20%
are reserved for
+ // intermittent usage like repartitioning buffers.
+ let memory_for_spatial_index =
+ calculate_memory_for_spatial_index(reserved_memory,
estimated_index_memory_usage);
+ let memory_for_intermittent_usage = reserved_memory -
memory_for_spatial_index;
+
+ let num_partitions = if num_rows > 0 {
+ max(
+ 1,
+ estimated_index_memory_usage.div_ceil(memory_for_spatial_index),
+ )
+ } else {
+ 1
+ };
+
+ Ok(MemoryPlan {
+ num_rows,
+ reserved_memory,
+ estimated_index_memory_usage,
+ memory_for_spatial_index,
+ memory_for_intermittent_usage,
+ num_partitions,
+ })
+}
+
+fn calculate_memory_for_spatial_index(
+ reserved_memory: usize,
+ estimated_index_memory_usage: usize,
+) -> usize {
+ if reserved_memory >= estimated_index_memory_usage {
+ // Reserved memory is sufficient to hold the entire spatial index.
Make sure that
+ // the memory for spatial index is enough for holding the entire
index. The rest
+ // can be used for intermittent usage.
+ estimated_index_memory_usage
+ } else {
+ // Reserved memory is not sufficient to hold the entire spatial index,
We need to
+ // partition the dataset using spatial partitioning. Use 80% of
reserved memory
+ // for holding the partitioned spatial index. The rest is used for
intermittent usage.
+ let reserved_portion = reserved_memory.saturating_mul(80) / 100;
+ if reserved_portion == 0 {
+ reserved_memory
+ } else {
+ reserved_portion
+ }
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ fn summary(
+ num_rows: usize,
+ reserved_memory: usize,
+ estimated_usage: usize,
+ ) -> PartitionMemorySummary {
+ PartitionMemorySummary {
+ num_rows,
+ reserved_memory,
+ estimated_index_memory_usage: estimated_usage,
+ }
+ }
+
+ #[test]
+ fn memory_plan_errors_when_no_memory_but_rows_exist() {
+ let err = compute_memory_plan(vec![summary(10, 0, 512)]).unwrap_err();
+ assert!(matches!(
+ err,
+ DataFusionError::ResourcesExhausted(msg) if
msg.contains("Insufficient memory")
+ ));
+ }
+
+ #[test]
+ fn memory_plan_partitions_large_jobs() {
+ let plan =
+ compute_memory_plan(vec![summary(100, 2_000, 1_500), summary(150,
1_000, 3_500)])
+ .expect("plan should succeed");
+
+ assert_eq!(plan.num_rows, 250);
+ assert_eq!(plan.reserved_memory, 3_000);
+ assert_eq!(plan.memory_for_spatial_index, 2_400);
+ assert_eq!(plan.memory_for_intermittent_usage, 600);
+ assert_eq!(plan.num_partitions, 3);
+ }
+
+ #[test]
+ fn memory_plan_handles_zero_rows() {
+ let plan = compute_memory_plan(vec![summary(0, 0, 0)]).expect("plan
should succeed");
+ assert_eq!(plan.num_partitions, 1);
+ assert_eq!(plan.memory_for_spatial_index, 0);
+ assert_eq!(plan.memory_for_intermittent_usage, 0);
+ }
+
+ #[test]
+ fn memory_plan_uses_entire_reservation_when_fraction_rounds_down() {
+ let plan = compute_memory_plan(vec![summary(10, 1, 1)]).expect("plan
should succeed");
+ assert_eq!(plan.memory_for_spatial_index, 1);
+ assert_eq!(plan.memory_for_intermittent_usage, 0);
+ }
+}
diff --git a/rust/sedona-spatial-join/src/index/partitioned_index_provider.rs
b/rust/sedona-spatial-join/src/index/partitioned_index_provider.rs
new file mode 100644
index 00000000..f9aeb893
--- /dev/null
+++ b/rust/sedona-spatial-join/src/index/partitioned_index_provider.rs
@@ -0,0 +1,602 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow_schema::SchemaRef;
+use datafusion_common::{DataFusionError, Result, SharedResult};
+use datafusion_common_runtime::JoinSet;
+use datafusion_execution::memory_pool::MemoryReservation;
+use datafusion_expr::JoinType;
+use futures::StreamExt;
+use parking_lot::Mutex;
+use sedona_common::{sedona_internal_err, SpatialJoinOptions};
+use std::ops::DerefMut;
+use std::sync::Arc;
+use tokio::sync::mpsc;
+
+use
crate::evaluated_batch::evaluated_batch_stream::external::ExternalEvaluatedBatchStream;
+use crate::index::BuildPartition;
+use crate::partitioning::stream_repartitioner::{SpilledPartition,
SpilledPartitions};
+use crate::utils::disposable_async_cell::DisposableAsyncCell;
+use crate::{
+ index::{SpatialIndex, SpatialIndexBuilder, SpatialJoinBuildMetrics},
+ partitioning::SpatialPartition,
+ spatial_predicate::SpatialPredicate,
+};
+
+pub(crate) struct PartitionedIndexProvider {
+ schema: SchemaRef,
+ spatial_predicate: SpatialPredicate,
+ options: SpatialJoinOptions,
+ join_type: JoinType,
+ probe_threads_count: usize,
+ metrics: SpatialJoinBuildMetrics,
+
+ /// Data on the build side to build index for
+ data: BuildSideData,
+
+ /// Async cells for indexes, one per regular partition
+ index_cells: Vec<DisposableAsyncCell<SharedResult<Arc<SpatialIndex>>>>,
+
+ /// The memory reserved in the build side collection phase. We'll hold
them until
+ /// we don't need to build spatial indexes.
+ _reservations: Vec<MemoryReservation>,
+}
+
+pub(crate) enum BuildSideData {
+ SinglePartition(Mutex<Option<Vec<BuildPartition>>>),
+ MultiPartition(Mutex<SpilledPartitions>),
+}
+
+impl PartitionedIndexProvider {
+ #[allow(clippy::too_many_arguments)]
+ pub fn new_multi_partition(
+ schema: SchemaRef,
+ spatial_predicate: SpatialPredicate,
+ options: SpatialJoinOptions,
+ join_type: JoinType,
+ probe_threads_count: usize,
+ partitioned_spill_files: SpilledPartitions,
+ metrics: SpatialJoinBuildMetrics,
+ reservations: Vec<MemoryReservation>,
+ ) -> Self {
+ let num_partitions = partitioned_spill_files.num_regular_partitions();
+ let index_cells = (0..num_partitions)
+ .map(|_| DisposableAsyncCell::new())
+ .collect();
+ Self {
+ schema,
+ spatial_predicate,
+ options,
+ join_type,
+ probe_threads_count,
+ metrics,
+ data:
BuildSideData::MultiPartition(Mutex::new(partitioned_spill_files)),
+ index_cells,
+ _reservations: reservations,
+ }
+ }
+
+ #[allow(clippy::too_many_arguments)]
+ pub fn new_single_partition(
+ schema: SchemaRef,
+ spatial_predicate: SpatialPredicate,
+ options: SpatialJoinOptions,
+ join_type: JoinType,
+ probe_threads_count: usize,
+ mut build_partitions: Vec<BuildPartition>,
+ metrics: SpatialJoinBuildMetrics,
+ ) -> Self {
+ let reservations = build_partitions
+ .iter_mut()
+ .map(|p| p.reservation.take())
+ .collect();
+ let index_cells = vec![DisposableAsyncCell::new()];
+ Self {
+ schema,
+ spatial_predicate,
+ options,
+ join_type,
+ probe_threads_count,
+ metrics,
+ data:
BuildSideData::SinglePartition(Mutex::new(Some(build_partitions))),
+ index_cells,
+ _reservations: reservations,
+ }
+ }
+
+ pub fn new_empty(
+ schema: SchemaRef,
+ spatial_predicate: SpatialPredicate,
+ options: SpatialJoinOptions,
+ join_type: JoinType,
+ probe_threads_count: usize,
+ metrics: SpatialJoinBuildMetrics,
+ ) -> Self {
+ let build_partitions = Vec::new();
+ Self::new_single_partition(
+ schema,
+ spatial_predicate,
+ options,
+ join_type,
+ probe_threads_count,
+ build_partitions,
+ metrics,
+ )
+ }
+
+ pub fn num_regular_partitions(&self) -> usize {
+ self.index_cells.len()
+ }
+
+ pub async fn build_or_wait_for_index(
+ &self,
+ partition_id: u32,
+ ) -> Option<Result<Arc<SpatialIndex>>> {
+ let cell = match self.index_cells.get(partition_id as usize) {
+ Some(cell) => cell,
+ None => {
+ return Some(sedona_internal_err!(
+ "partition_id {} exceeds {} partitions",
+ partition_id,
+ self.index_cells.len()
+ ))
+ }
+ };
+ if !cell.is_empty() {
+ return get_index_from_cell(cell).await;
+ }
+
+ let res_index = {
+ let opt_res_index = self.maybe_build_index(partition_id).await;
+ match opt_res_index {
+ Some(res_index) => res_index,
+ None => {
+ // The build side data for building the index has already
been consumed by someone else,
+ // we just need to wait for the task consumed the data to
finish building the index.
+ return get_index_from_cell(cell).await;
+ }
+ }
+ };
+
+ match res_index {
+ Ok(idx) => {
+ if let Err(e) = cell.set(Ok(Arc::clone(&idx))) {
+ // This is probably because the cell has been disposed. No
one
+ // will get the index from the cell so this failure is not
a big deal.
+ log::debug!("Cannot set the index into the async cell:
{:?}", e);
+ }
+ Some(Ok(idx))
+ }
+ Err(err) => {
+ let err_arc = Arc::new(err);
+ if let Err(e) = cell.set(Err(Arc::clone(&err_arc))) {
+ log::debug!(
+ "Cannot set the index build error into the async cell:
{:?}",
+ e
+ );
+ }
+ Some(Err(DataFusionError::Shared(err_arc)))
+ }
+ }
+ }
+
+ async fn maybe_build_index(&self, partition_id: u32) ->
Option<Result<Arc<SpatialIndex>>> {
+ match &self.data {
+ BuildSideData::SinglePartition(build_partition_opt) => {
+ if partition_id != 0 {
+ return Some(sedona_internal_err!(
+ "partition_id for single-partition index is not 0"
+ ));
+ }
+
+ // consume the build side data for building the index
+ let build_partition_opt = {
+ let mut locked = build_partition_opt.lock();
+ std::mem::take(locked.deref_mut())
+ };
+
+ let Some(build_partition) = build_partition_opt else {
+ // already consumed by previous attempts, the result
should be present in the channel.
+ return None;
+ };
+
Some(self.build_index_for_single_partition(build_partition).await)
+ }
+ BuildSideData::MultiPartition(partitioned_spill_files) => {
+ // consume this partition of build side data for building index
+ let spilled_partition = {
+ let mut locked = partitioned_spill_files.lock();
+ let partition = SpatialPartition::Regular(partition_id);
+ if !locked.can_take_spilled_partition(partition) {
+ // already consumed by previous attempts, the result
should be present in the channel.
+ return None;
+ }
+ match locked.take_spilled_partition(partition) {
+ Ok(spilled_partition) => spilled_partition,
+ Err(e) => return Some(Err(e)),
+ }
+ };
+ Some(
+ self.build_index_for_spilled_partition(spilled_partition)
+ .await,
+ )
+ }
+ }
+ }
+
+ #[cfg(test)]
+ pub async fn wait_for_index(&self, partition_id: u32) ->
Option<Result<Arc<SpatialIndex>>> {
+ let cell = match self.index_cells.get(partition_id as usize) {
+ Some(cell) => cell,
+ None => {
+ return Some(sedona_internal_err!(
+ "partition_id {} exceeds {} partitions",
+ partition_id,
+ self.index_cells.len()
+ ))
+ }
+ };
+
+ get_index_from_cell(cell).await
+ }
+
+ pub fn dispose_index(&self, partition_id: u32) {
+ if let Some(cell) = self.index_cells.get(partition_id as usize) {
+ cell.dispose();
+ }
+ }
+
+ pub fn num_loaded_indexes(&self) -> usize {
+ self.index_cells
+ .iter()
+ .filter(|index_cell| index_cell.is_set())
+ .count()
+ }
+
+ async fn build_index_for_single_partition(
+ &self,
+ build_partitions: Vec<BuildPartition>,
+ ) -> Result<Arc<SpatialIndex>> {
+ let mut index_builder = SpatialIndexBuilder::new(
+ Arc::clone(&self.schema),
+ self.spatial_predicate.clone(),
+ self.options.clone(),
+ self.join_type,
+ self.probe_threads_count,
+ self.metrics.clone(),
+ )?;
+
+ for build_partition in build_partitions {
+ let stream = build_partition.build_side_batch_stream;
+ let geo_statistics = build_partition.geo_statistics;
+ index_builder.add_stream(stream, geo_statistics).await?;
+ }
+
+ let index = index_builder.finish()?;
+ Ok(Arc::new(index))
+ }
+
+ async fn build_index_for_spilled_partition(
+ &self,
+ spilled_partition: SpilledPartition,
+ ) -> Result<Arc<SpatialIndex>> {
+ let mut index_builder = SpatialIndexBuilder::new(
+ Arc::clone(&self.schema),
+ self.spatial_predicate.clone(),
+ self.options.clone(),
+ self.join_type,
+ self.probe_threads_count,
+ self.metrics.clone(),
+ )?;
+
+ // Spawn tasks to load indexed batches from spilled files concurrently
+ let (spill_files, geo_statistics, _) = spilled_partition.into_inner();
+ let mut join_set: JoinSet<Result<(), DataFusionError>> =
JoinSet::new();
+ let (tx, mut rx) = mpsc::channel(spill_files.len() * 2 + 1);
+ for spill_file in spill_files {
+ let tx = tx.clone();
+ join_set.spawn(async move {
+ let result = async {
+ let mut stream =
ExternalEvaluatedBatchStream::try_from_spill_file(spill_file)?;
+ while let Some(batch) = stream.next().await {
+ let indexed_batch = batch?;
+ if tx.send(Ok(indexed_batch)).await.is_err() {
+ return Ok(());
+ }
+ }
+ Ok::<(), DataFusionError>(())
+ }
+ .await;
+ if let Err(e) = result {
+ let _ = tx.send(Err(e)).await;
+ }
+ Ok(())
+ });
+ }
+ drop(tx);
+
+ // Collect the loaded indexed batches and add them to the index builder
+ while let Some(res) = rx.recv().await {
+ let batch = res?;
+ index_builder.add_batch(batch)?;
+ }
+
+ // Ensure all tasks completed successfully
+ while let Some(res) = join_set.join_next().await {
+ if let Err(e) = res {
+ if e.is_panic() {
+ std::panic::resume_unwind(e.into_panic());
+ }
+ return Err(DataFusionError::External(Box::new(e)));
+ }
+ }
+
+ index_builder.merge_stats(geo_statistics);
+
+ let index = index_builder.finish()?;
+ Ok(Arc::new(index))
+ }
+}
+
+async fn get_index_from_cell(
+ cell: &DisposableAsyncCell<SharedResult<Arc<SpatialIndex>>>,
+) -> Option<Result<Arc<SpatialIndex>>> {
+ match cell.get().await {
+ Some(Ok(index)) => Some(Ok(index)),
+ Some(Err(shared_err)) =>
Some(Err(DataFusionError::Shared(shared_err))),
+ None => None,
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::operand_evaluator::EvaluatedGeometryArray;
+ use crate::partitioning::partition_slots::PartitionSlots;
+ use crate::utils::bbox_sampler::BoundingBoxSamples;
+ use crate::{
+ evaluated_batch::{
+ evaluated_batch_stream::{
+ in_mem::InMemoryEvaluatedBatchStream,
SendableEvaluatedBatchStream,
+ },
+ EvaluatedBatch,
+ },
+ index::CollectBuildSideMetrics,
+ };
+ use arrow_array::{ArrayRef, BinaryArray, Int32Array, RecordBatch};
+ use arrow_schema::{DataType, Field, Schema, SchemaRef};
+ use datafusion::config::SpillCompression;
+ use datafusion_common::{DataFusionError, Result};
+ use datafusion_execution::{
+ memory_pool::{GreedyMemoryPool, MemoryConsumer, MemoryPool},
+ runtime_env::RuntimeEnv,
+ };
+ use datafusion_expr::JoinType;
+ use datafusion_physical_expr::expressions::Column;
+ use datafusion_physical_plan::metrics::{ExecutionPlanMetricsSet,
SpillMetrics};
+ use sedona_expr::statistics::GeoStatistics;
+ use sedona_functions::st_analyze_agg::AnalyzeAccumulator;
+ use sedona_geometry::analyze::analyze_geometry;
+ use sedona_schema::datatypes::WKB_GEOMETRY;
+
+ use crate::evaluated_batch::spill::EvaluatedBatchSpillWriter;
+ use crate::partitioning::stream_repartitioner::{SpilledPartition,
SpilledPartitions};
+ use crate::spatial_predicate::{RelationPredicate, SpatialRelationType};
+
+ fn sample_schema() -> SchemaRef {
+ Arc::new(Schema::new(vec![
+ Field::new("geom", DataType::Binary, true),
+ Field::new("id", DataType::Int32, false),
+ ]))
+ }
+
+ fn point_wkb(x: f64, y: f64) -> Vec<u8> {
+ let mut buf = vec![1u8, 1, 0, 0, 0];
+ buf.extend_from_slice(&x.to_le_bytes());
+ buf.extend_from_slice(&y.to_le_bytes());
+ buf
+ }
+
+ fn sample_batch(ids: &[i32], wkbs: Vec<Option<Vec<u8>>>) ->
Result<EvaluatedBatch> {
+ assert_eq!(ids.len(), wkbs.len());
+ let geom_values: Vec<Option<&[u8]>> = wkbs
+ .iter()
+ .map(|opt| opt.as_ref().map(|wkb| wkb.as_slice()))
+ .collect();
+ let geom_array: ArrayRef =
Arc::new(BinaryArray::from_opt_vec(geom_values));
+ let id_array: ArrayRef = Arc::new(Int32Array::from(ids.to_vec()));
+ let batch = RecordBatch::try_new(sample_schema(),
vec![geom_array.clone(), id_array])?;
+ let geom = EvaluatedGeometryArray::try_new(geom_array, &WKB_GEOMETRY)?;
+ Ok(EvaluatedBatch {
+ batch,
+ geom_array: geom,
+ })
+ }
+
+ fn predicate() -> SpatialPredicate {
+ SpatialPredicate::Relation(RelationPredicate::new(
+ Arc::new(Column::new("geom", 0)),
+ Arc::new(Column::new("geom", 0)),
+ SpatialRelationType::Intersects,
+ ))
+ }
+
+ fn geo_stats_from_batches(batches: &[EvaluatedBatch]) ->
Result<GeoStatistics> {
+ let mut analyzer = AnalyzeAccumulator::new(WKB_GEOMETRY, WKB_GEOMETRY);
+ for batch in batches {
+ for wkb in batch.geom_array.wkbs().iter().flatten() {
+ let summary =
+ analyze_geometry(wkb).map_err(|e|
DataFusionError::External(Box::new(e)))?;
+ analyzer.ingest_geometry_summary(&summary);
+ }
+ }
+ Ok(analyzer.finish())
+ }
+
+ fn new_reservation(memory_pool: Arc<dyn MemoryPool>) -> MemoryReservation {
+ let consumer = MemoryConsumer::new("PartitionedIndexProviderTest");
+ consumer.register(&memory_pool)
+ }
+
+ fn build_partition_from_batches(
+ memory_pool: Arc<dyn MemoryPool>,
+ batches: Vec<EvaluatedBatch>,
+ ) -> Result<BuildPartition> {
+ let schema = batches
+ .first()
+ .map(|batch| batch.schema())
+ .unwrap_or_else(|| Arc::new(Schema::empty()));
+ let geo_statistics = geo_stats_from_batches(&batches)?;
+ let num_rows = batches.iter().map(|batch| batch.num_rows()).sum();
+ let mut estimated_usage = 0;
+ for batch in &batches {
+ estimated_usage += batch.in_mem_size()?;
+ }
+ let stream: SendableEvaluatedBatchStream =
+ Box::pin(InMemoryEvaluatedBatchStream::new(schema, batches));
+ Ok(BuildPartition {
+ num_rows,
+ build_side_batch_stream: stream,
+ geo_statistics,
+ bbox_samples: BoundingBoxSamples::empty(),
+ estimated_spatial_index_memory_usage: estimated_usage,
+ reservation: new_reservation(memory_pool),
+ metrics: CollectBuildSideMetrics::new(0,
&ExecutionPlanMetricsSet::new()),
+ })
+ }
+
+ fn spill_partition_from_batches(
+ runtime_env: Arc<RuntimeEnv>,
+ batches: Vec<EvaluatedBatch>,
+ ) -> Result<SpilledPartition> {
+ if batches.is_empty() {
+ return Ok(SpilledPartition::empty());
+ }
+ let schema = batches[0].schema();
+ let sedona_type = batches[0].geom_array.sedona_type.clone();
+ let mut writer = EvaluatedBatchSpillWriter::try_new(
+ runtime_env,
+ schema,
+ &sedona_type,
+ "partitioned-index-provider-test",
+ SpillCompression::Uncompressed,
+ SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0),
+ None,
+ )?;
+ let mut num_rows = 0;
+ for batch in &batches {
+ num_rows += batch.num_rows();
+ writer.append(batch)?;
+ }
+ let geo_statistics = geo_stats_from_batches(&batches)?;
+ let spill_file = writer.finish()?;
+ Ok(SpilledPartition::new(
+ vec![Arc::new(spill_file)],
+ geo_statistics,
+ num_rows,
+ ))
+ }
+
+ fn make_spilled_partitions(
+ runtime_env: Arc<RuntimeEnv>,
+ partitions: Vec<Vec<EvaluatedBatch>>,
+ ) -> Result<SpilledPartitions> {
+ let slots = PartitionSlots::new(partitions.len());
+ let mut spilled = Vec::with_capacity(slots.total_slots());
+ for partition_batches in partitions {
+ spilled.push(spill_partition_from_batches(
+ Arc::clone(&runtime_env),
+ partition_batches,
+ )?);
+ }
+ spilled.push(SpilledPartition::empty());
+ spilled.push(SpilledPartition::empty());
+ Ok(SpilledPartitions::new(slots, spilled))
+ }
+
+ #[tokio::test]
+ async fn single_partition_builds_once_and_is_cached() -> Result<()> {
+ let memory_pool: Arc<dyn MemoryPool> =
Arc::new(GreedyMemoryPool::new(1 << 20));
+ let batches = vec![sample_batch(
+ &[1, 2],
+ vec![Some(point_wkb(10.0, 10.0)), Some(point_wkb(20.0, 20.0))],
+ )?];
+ let build_partition =
build_partition_from_batches(Arc::clone(&memory_pool), batches)?;
+ let metrics = ExecutionPlanMetricsSet::new();
+ let provider = PartitionedIndexProvider::new_single_partition(
+ sample_schema(),
+ predicate(),
+ SpatialJoinOptions::default(),
+ JoinType::Inner,
+ 1,
+ vec![build_partition],
+ SpatialJoinBuildMetrics::new(0, &metrics),
+ );
+
+ let first_index = provider
+ .build_or_wait_for_index(0)
+ .await
+ .expect("partition exists")?;
+ assert_eq!(first_index.indexed_batches.len(), 1);
+ assert_eq!(provider.num_loaded_indexes(), 1);
+
+ let cached_index = provider
+ .wait_for_index(0)
+ .await
+ .expect("cached value must remain accessible")?;
+ assert!(Arc::ptr_eq(&first_index, &cached_index));
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn multi_partition_concurrent_requests_share_indexes() -> Result<()>
{
+ let memory_pool: Arc<dyn MemoryPool> =
Arc::new(GreedyMemoryPool::new(1 << 20));
+ let runtime_env = Arc::new(RuntimeEnv::default());
+ let partition_batches = vec![
+ vec![sample_batch(&[10], vec![Some(point_wkb(0.0, 0.0))])?],
+ vec![sample_batch(&[20], vec![Some(point_wkb(50.0, 50.0))])?],
+ ];
+ let spilled_partitions = make_spilled_partitions(runtime_env,
partition_batches)?;
+ let metrics = ExecutionPlanMetricsSet::new();
+ let provider = Arc::new(PartitionedIndexProvider::new_multi_partition(
+ sample_schema(),
+ predicate(),
+ SpatialJoinOptions::default(),
+ JoinType::Inner,
+ 1,
+ spilled_partitions,
+ SpatialJoinBuildMetrics::new(0, &metrics),
+ vec![new_reservation(Arc::clone(&memory_pool))],
+ ));
+
+ let (idx_one, idx_two) = tokio::join!(
+ provider.build_or_wait_for_index(0),
+ provider.build_or_wait_for_index(0)
+ );
+ let idx_one = idx_one.expect("partition exists")?;
+ let idx_two = idx_two.expect("partition exists")?;
+ assert!(Arc::ptr_eq(&idx_one, &idx_two));
+ assert_eq!(idx_one.indexed_batches.len(), 1);
+
+ let second_partition = provider
+ .build_or_wait_for_index(1)
+ .await
+ .expect("second partition exists")?;
+ assert_eq!(second_partition.indexed_batches.len(), 1);
+ assert_eq!(provider.num_loaded_indexes(), 2);
+ Ok(())
+ }
+}
diff --git a/rust/sedona-spatial-join/src/index/spatial_index.rs
b/rust/sedona-spatial-join/src/index/spatial_index.rs
index 9364920a..bff7895d 100644
--- a/rust/sedona-spatial-join/src/index/spatial_index.rs
+++ b/rust/sedona-spatial-join/src/index/spatial_index.rs
@@ -27,7 +27,6 @@ use arrow_array::RecordBatch;
use arrow_schema::SchemaRef;
use datafusion_common::{DataFusionError, Result};
use datafusion_common_runtime::JoinSet;
-use datafusion_execution::memory_pool::MemoryReservation;
use float_next_after::NextAfter;
use geo::BoundingRect;
use geo_index::rtree::{
@@ -95,11 +94,6 @@ pub struct SpatialIndex {
/// Shared KNN components (distance metrics and geometry cache) for
efficient KNN queries
pub(crate) knn_components: Option<KnnComponents>,
-
- /// Memory reservation for tracking the memory usage of the spatial index
- /// Cleared on `SpatialIndex` drop
- #[expect(dead_code)]
- pub(crate) reservation: MemoryReservation,
}
impl SpatialIndex {
@@ -108,7 +102,6 @@ impl SpatialIndex {
schema: SchemaRef,
options: SpatialJoinOptions,
probe_threads_counter: AtomicUsize,
- reservation: MemoryReservation,
) -> Self {
let evaluator = create_operand_evaluator(&spatial_predicate,
options.clone());
let refiner = create_refiner(
@@ -133,7 +126,6 @@ impl SpatialIndex {
visited_build_side: None,
probe_threads_counter,
knn_components,
- reservation,
}
}
@@ -681,7 +673,6 @@ mod tests {
use arrow_array::RecordBatch;
use arrow_schema::{DataType, Field};
use datafusion_common::JoinSide;
- use datafusion_execution::memory_pool::GreedyMemoryPool;
use datafusion_expr::JoinType;
use datafusion_physical_expr::expressions::Column;
use geo_traits::Dimensions;
@@ -692,7 +683,6 @@ mod tests {
#[test]
fn test_spatial_index_builder_empty() {
- let memory_pool = Arc::new(GreedyMemoryPool::new(1024 * 1024));
let options = SpatialJoinOptions {
execution_mode: ExecutionMode::PrepareBuild,
..Default::default()
@@ -711,7 +701,6 @@ mod tests {
options,
JoinType::Inner,
4,
- memory_pool,
metrics,
)
.unwrap();
@@ -724,7 +713,6 @@ mod tests {
#[test]
fn test_spatial_index_builder_add_batch() {
- let memory_pool = Arc::new(GreedyMemoryPool::new(1024 * 1024));
let options = SpatialJoinOptions {
execution_mode: ExecutionMode::PrepareBuild,
..Default::default()
@@ -750,7 +738,6 @@ mod tests {
options,
JoinType::Inner,
4,
- memory_pool,
metrics,
)
.unwrap();
@@ -779,7 +766,6 @@ mod tests {
#[test]
fn test_knn_query_execution_with_sample_data() {
// Create a spatial index with sample geometry data
- let memory_pool = Arc::new(GreedyMemoryPool::new(1024 * 1024));
let options = SpatialJoinOptions {
execution_mode: ExecutionMode::PrepareBuild,
..Default::default()
@@ -807,7 +793,6 @@ mod tests {
options,
JoinType::Inner,
4,
- memory_pool,
metrics,
)
.unwrap();
@@ -878,7 +863,6 @@ mod tests {
#[test]
fn test_knn_query_execution_with_different_k_values() {
// Create spatial index with more data points
- let memory_pool = Arc::new(GreedyMemoryPool::new(1024 * 1024));
let options = SpatialJoinOptions {
execution_mode: ExecutionMode::PrepareBuild,
..Default::default()
@@ -905,7 +889,6 @@ mod tests {
options,
JoinType::Inner,
4,
- memory_pool,
metrics,
)
.unwrap();
@@ -969,7 +952,6 @@ mod tests {
#[test]
fn test_knn_query_execution_with_spheroid_distance() {
// Create spatial index
- let memory_pool = Arc::new(GreedyMemoryPool::new(1024 * 1024));
let options = SpatialJoinOptions {
execution_mode: ExecutionMode::PrepareBuild,
..Default::default()
@@ -996,7 +978,6 @@ mod tests {
options,
JoinType::Inner,
4,
- memory_pool,
metrics,
)
.unwrap();
@@ -1066,7 +1047,6 @@ mod tests {
#[test]
fn test_knn_query_execution_edge_cases() {
// Create spatial index
- let memory_pool = Arc::new(GreedyMemoryPool::new(1024 * 1024));
let options = SpatialJoinOptions {
execution_mode: ExecutionMode::PrepareBuild,
..Default::default()
@@ -1093,7 +1073,6 @@ mod tests {
options,
JoinType::Inner,
4,
- memory_pool,
metrics,
)
.unwrap();
@@ -1159,7 +1138,6 @@ mod tests {
#[test]
fn test_knn_query_execution_empty_index() {
// Create empty spatial index
- let memory_pool = Arc::new(GreedyMemoryPool::new(1024 * 1024));
let options = SpatialJoinOptions {
execution_mode: ExecutionMode::PrepareBuild,
..Default::default()
@@ -1181,7 +1159,6 @@ mod tests {
options,
JoinType::Inner,
4,
- memory_pool,
metrics,
)
.unwrap();
@@ -1207,7 +1184,6 @@ mod tests {
#[test]
fn test_knn_query_execution_with_tie_breakers() {
// Create a spatial index with sample geometry data
- let memory_pool = Arc::new(GreedyMemoryPool::new(1024 * 1024));
let options = SpatialJoinOptions {
execution_mode: ExecutionMode::PrepareBuild,
..Default::default()
@@ -1234,7 +1210,6 @@ mod tests {
options,
JoinType::Inner,
1, // probe_threads_count
- memory_pool.clone(),
metrics,
)
.unwrap();
@@ -1322,7 +1297,6 @@ mod tests {
#[test]
fn test_query_knn_with_geometry_distance() {
// Create a spatial index with sample geometry data
- let memory_pool = Arc::new(GreedyMemoryPool::new(1024 * 1024));
let options = SpatialJoinOptions {
execution_mode: ExecutionMode::PrepareBuild,
..Default::default()
@@ -1350,7 +1324,6 @@ mod tests {
options,
JoinType::Inner,
4,
- memory_pool,
metrics,
)
.unwrap();
@@ -1407,7 +1380,6 @@ mod tests {
fn test_query_knn_with_mixed_geometries() {
// Create a spatial index with complex geometries where geometry-based
// distance should differ from centroid-based distance
- let memory_pool = Arc::new(GreedyMemoryPool::new(1024 * 1024));
let options = SpatialJoinOptions {
execution_mode: ExecutionMode::PrepareBuild,
..Default::default()
@@ -1435,7 +1407,6 @@ mod tests {
options,
JoinType::Inner,
4,
- memory_pool,
metrics,
)
.unwrap();
@@ -1489,7 +1460,6 @@ mod tests {
#[test]
fn test_query_knn_with_tie_breakers_geometry_distance() {
// Create a spatial index with geometries that have identical
distances for tie-breaker testing
- let memory_pool = Arc::new(GreedyMemoryPool::new(1024 * 1024));
let options = SpatialJoinOptions {
execution_mode: ExecutionMode::PrepareBuild,
..Default::default()
@@ -1516,7 +1486,6 @@ mod tests {
options,
JoinType::Inner,
4,
- memory_pool,
metrics,
)
.unwrap();
@@ -1610,7 +1579,6 @@ mod tests {
#[test]
fn test_knn_query_with_empty_geometry() {
// Create a spatial index with sample geometry data like other tests
- let memory_pool = Arc::new(GreedyMemoryPool::new(1024 * 1024));
let options = SpatialJoinOptions {
execution_mode: ExecutionMode::PrepareBuild,
..Default::default()
@@ -1638,7 +1606,6 @@ mod tests {
options,
JoinType::Inner,
1, // probe_threads_count
- memory_pool.clone(),
metrics,
)
.unwrap();
@@ -1687,7 +1654,6 @@ mod tests {
build_geoms: &[Option<&str>],
options: SpatialJoinOptions,
) -> Arc<SpatialIndex> {
- let memory_pool = Arc::new(GreedyMemoryPool::new(100 * 1024 * 1024));
let metrics = SpatialJoinBuildMetrics::default();
let spatial_predicate =
SpatialPredicate::Relation(RelationPredicate::new(
Arc::new(Column::new("left", 0)),
@@ -1706,7 +1672,6 @@ mod tests {
options,
JoinType::Inner,
1,
- memory_pool,
metrics,
)
.unwrap();
diff --git a/rust/sedona-spatial-join/src/index/spatial_index_builder.rs
b/rust/sedona-spatial-join/src/index/spatial_index_builder.rs
index 9d97b539..ca2b0088 100644
--- a/rust/sedona-spatial-join/src/index/spatial_index_builder.rs
+++ b/rust/sedona-spatial-join/src/index/spatial_index_builder.rs
@@ -22,16 +22,15 @@ use sedona_common::SpatialJoinOptions;
use sedona_expr::statistics::GeoStatistics;
use datafusion_common::{utils::proxy::VecAllocExt, Result};
-use datafusion_execution::memory_pool::{MemoryConsumer, MemoryPool,
MemoryReservation};
use datafusion_expr::JoinType;
use futures::StreamExt;
use geo_index::rtree::{sort::HilbertSort, RTree, RTreeBuilder, RTreeIndex};
use parking_lot::Mutex;
-use std::sync::{atomic::AtomicUsize, Arc};
+use std::sync::atomic::AtomicUsize;
use crate::{
- evaluated_batch::EvaluatedBatch,
- index::{knn_adapter::KnnComponents, spatial_index::SpatialIndex,
BuildPartition},
+ evaluated_batch::{evaluated_batch_stream::SendableEvaluatedBatchStream,
EvaluatedBatch},
+ index::{knn_adapter::KnnComponents, spatial_index::SpatialIndex},
operand_evaluator::create_operand_evaluator,
refine::create_refiner,
spatial_predicate::SpatialPredicate,
@@ -63,8 +62,6 @@ pub struct SpatialIndexBuilder {
/// Batches to be indexed
indexed_batches: Vec<EvaluatedBatch>,
- /// Memory reservation for tracking the memory usage of the spatial index
- reservation: MemoryReservation,
/// Statistics for indexed geometries
stats: GeoStatistics,
@@ -99,12 +96,8 @@ impl SpatialIndexBuilder {
options: SpatialJoinOptions,
join_type: JoinType,
probe_threads_count: usize,
- memory_pool: Arc<dyn MemoryPool>,
metrics: SpatialJoinBuildMetrics,
) -> Result<Self> {
- let consumer = MemoryConsumer::new("SpatialJoinIndex");
- let reservation = consumer.register(&memory_pool);
-
Ok(Self {
schema,
spatial_predicate,
@@ -113,7 +106,6 @@ impl SpatialIndexBuilder {
probe_threads_count,
metrics,
indexed_batches: Vec::new(),
- reservation,
stats: GeoStatistics::empty(),
memory_used: 0,
})
@@ -258,7 +250,6 @@ impl SpatialIndexBuilder {
self.schema,
self.options,
AtomicUsize::new(self.probe_threads_count),
- self.reservation,
));
}
@@ -297,6 +288,10 @@ impl SpatialIndexBuilder {
}
};
+ log::debug!(
+ "Estimated memory used by spatial index: {}",
+ self.memory_used
+ );
Ok(SpatialIndex {
schema: self.schema,
options: self.options,
@@ -309,26 +304,19 @@ impl SpatialIndexBuilder {
visited_build_side,
probe_threads_counter: AtomicUsize::new(self.probe_threads_count),
knn_components: knn_components_opt,
- reservation: self.reservation,
})
}
- pub async fn add_partitions(&mut self, partitions: Vec<BuildPartition>) ->
Result<()> {
- for partition in partitions {
- self.add_partition(partition).await?;
- }
- Ok(())
- }
-
- pub async fn add_partition(&mut self, mut partition: BuildPartition) ->
Result<()> {
- let mut stream = partition.build_side_batch_stream;
+ pub async fn add_stream(
+ &mut self,
+ mut stream: SendableEvaluatedBatchStream,
+ geo_statistics: GeoStatistics,
+ ) -> Result<()> {
while let Some(batch) = stream.next().await {
let indexed_batch = batch?;
self.add_batch(indexed_batch)?;
}
- self.merge_stats(partition.geo_statistics);
- let mem_bytes = partition.reservation.free();
- self.reservation.try_grow(mem_bytes)?;
+ self.merge_stats(geo_statistics);
Ok(())
}
diff --git a/rust/sedona-spatial-join/src/lib.rs
b/rust/sedona-spatial-join/src/lib.rs
index 94af3f22..2abaf3c4 100644
--- a/rust/sedona-spatial-join/src/lib.rs
+++ b/rust/sedona-spatial-join/src/lib.rs
@@ -15,13 +15,13 @@
// specific language governing permissions and limitations
// under the License.
-mod build_index;
pub mod evaluated_batch;
pub mod exec;
mod index;
pub mod operand_evaluator;
pub mod optimizer;
pub mod partitioning;
+mod prepare;
pub mod refine;
pub mod spatial_predicate;
mod stream;
@@ -31,7 +31,6 @@ pub use exec::SpatialJoinExec;
pub use optimizer::register_spatial_join_optimizer;
// Re-export types needed for external usage (e.g., in Comet)
-pub use build_index::build_index;
pub use index::{SpatialIndex, SpatialJoinBuildMetrics};
pub use spatial_predicate::SpatialPredicate;
diff --git a/rust/sedona-spatial-join/src/partitioning/kdb.rs
b/rust/sedona-spatial-join/src/partitioning/kdb.rs
index 32ac3a4c..c09e98ff 100644
--- a/rust/sedona-spatial-join/src/partitioning/kdb.rs
+++ b/rust/sedona-spatial-join/src/partitioning/kdb.rs
@@ -43,7 +43,9 @@
use std::sync::Arc;
use crate::partitioning::{
- util::{bbox_to_geo_rect, rect_contains_point, rect_intersection_area,
rects_intersect},
+ util::{
+ bbox_to_geo_rect, make_rect, rect_contains_point,
rect_intersection_area, rects_intersect,
+ },
SpatialPartition, SpatialPartitioner,
};
use datafusion_common::Result;
@@ -126,9 +128,12 @@ impl KDBTree {
if max_items_per_node == 0 {
return sedona_internal_err!("max_items_per_node must be greater
than 0");
}
- let Some(extent_rect) = bbox_to_geo_rect(&extent)? else {
- return sedona_internal_err!("KDBTree extent cannot be empty");
- };
+
+ // extent_rect is a sentinel rect if the bounding box is empty. In
that case,
+ // almost all insertions will be ignored. We are free to partition the
data
+ // arbitrarily when the extent is empty.
+ let extent_rect = bbox_to_geo_rect(&extent)?.unwrap_or(make_rect(0.0,
0.0, 0.0, 0.0));
+
Ok(Self::new_with_level(
max_items_per_node,
max_levels,
@@ -507,6 +512,13 @@ impl KDBPartitioner {
}
Ok(())
}
+
+ /// Return the tree structure in human-readable format for debugging
purposes.
+ pub fn debug_str(&self) -> String {
+ let mut output = String::new();
+ let _ = self.debug_print(&mut output);
+ output
+ }
}
impl SpatialPartitioner for KDBPartitioner {
@@ -966,4 +978,19 @@ mod tests {
SpatialPartition::None
);
}
+
+ #[test]
+ fn test_kdb_partitioner_empty_extent() {
+ let extent = BoundingBox::empty();
+ let bboxes = vec![
+ BoundingBox::xy((0.0, 10.0), (0.0, 10.0)),
+ BoundingBox::xy((1.0, 10.0), (1.0, 10.0)),
+ ];
+ let partitioner = KDBPartitioner::build(bboxes.clone().into_iter(),
10, 4, extent).unwrap();
+
+ // Partition calls should succeed
+ for test_bbox in bboxes {
+ assert!(partitioner.partition(&test_bbox).is_ok());
+ }
+ }
}
diff --git a/rust/sedona-spatial-join/src/partitioning/stream_repartitioner.rs
b/rust/sedona-spatial-join/src/partitioning/stream_repartitioner.rs
index 44591107..038530b1 100644
--- a/rust/sedona-spatial-join/src/partitioning/stream_repartitioner.rs
+++ b/rust/sedona-spatial-join/src/partitioning/stream_repartitioner.rs
@@ -280,6 +280,13 @@ impl SpilledPartitions {
}
Ok(())
}
+
+    /// Return debug info for these spilled partitions as a string.
+ pub fn debug_str(&self) -> String {
+ let mut output = String::new();
+ let _ = self.debug_print(&mut output);
+ output
+ }
}
/// Incremental (stateful) repartitioner for an [`EvaluatedBatch`] stream.
diff --git a/rust/sedona-spatial-join/src/prepare.rs
b/rust/sedona-spatial-join/src/prepare.rs
new file mode 100644
index 00000000..76e825b3
--- /dev/null
+++ b/rust/sedona-spatial-join/src/prepare.rs
@@ -0,0 +1,514 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::{mem, sync::Arc};
+
+use arrow_schema::SchemaRef;
+use datafusion_common::Result;
+use datafusion_common_runtime::JoinSet;
+use datafusion_execution::{
+ disk_manager::RefCountedTempFile, memory_pool::MemoryConsumer,
SendableRecordBatchStream,
+ TaskContext,
+};
+use datafusion_expr::JoinType;
+use datafusion_physical_plan::metrics::ExecutionPlanMetricsSet;
+use fastrand::Rng;
+use sedona_common::{sedona_internal_err, NumSpatialPartitionsConfig,
SedonaOptions};
+use sedona_expr::statistics::GeoStatistics;
+use sedona_geometry::bounding_box::BoundingBox;
+
+use crate::{
+ index::{
+ memory_plan::{compute_memory_plan, MemoryPlan, PartitionMemorySummary},
+ partitioned_index_provider::PartitionedIndexProvider,
+ BuildPartition, BuildSideBatchesCollector, CollectBuildSideMetrics,
+ SpatialJoinBuildMetrics,
+ },
+ partitioning::{
+ kdb::KDBPartitioner,
+ stream_repartitioner::{SpilledPartition, SpilledPartitions,
StreamRepartitioner},
+ PartitionedSide, SpatialPartition, SpatialPartitioner,
+ },
+ spatial_predicate::SpatialPredicate,
+ utils::bbox_sampler::BoundingBoxSamples,
+};
+
+pub(crate) struct SpatialJoinComponents {
+ pub partitioned_index_provider: Arc<PartitionedIndexProvider>,
+}
+
+/// Builder for constructing `SpatialJoinComponents` from build-side streams.
+///
+/// Calling `build(...)` performs the full preparation flow:
+/// - collect (and spill if needed) build-side batches,
+/// - compute memory plan and pick single- or multi-partition mode,
+/// - repartition the build side into spatial partitions in multi-partition
mode,
+/// - create the appropriate `PartitionedIndexProvider` used to build spatial
indexes.
+pub(crate) struct SpatialJoinComponentsBuilder {
+ context: Arc<TaskContext>,
+ build_schema: SchemaRef,
+ spatial_predicate: SpatialPredicate,
+ join_type: JoinType,
+ probe_threads_count: usize,
+ metrics: ExecutionPlanMetricsSet,
+ seed: u64,
+ sedona_options: SedonaOptions,
+}
+
+impl SpatialJoinComponentsBuilder {
+ /// Create a new builder capturing the execution context and configuration
+ /// required to produce `SpatialJoinComponents` from build-side streams.
+ pub fn new(
+ context: Arc<TaskContext>,
+ build_schema: SchemaRef,
+ spatial_predicate: SpatialPredicate,
+ join_type: JoinType,
+ probe_threads_count: usize,
+ metrics: ExecutionPlanMetricsSet,
+ seed: u64,
+ ) -> Self {
+ let session_config = context.session_config();
+ let sedona_options = session_config
+ .options()
+ .extensions
+ .get::<SedonaOptions>()
+ .cloned()
+ .unwrap_or_default();
+ Self {
+ context,
+ build_schema,
+ spatial_predicate,
+ join_type,
+ probe_threads_count,
+ metrics,
+ seed,
+ sedona_options,
+ }
+ }
+
+ /// Prepare and return `SpatialJoinComponents` for the given build-side
+ /// streams. This drives the end-to-end preparation flow and returns a
+ /// ready-to-use `SpatialJoinComponents` for the spatial join operator.
+ pub async fn build(
+ mut self,
+ build_streams: Vec<SendableRecordBatchStream>,
+ ) -> Result<SpatialJoinComponents> {
+ let num_partitions = build_streams.len();
+ if num_partitions == 0 {
+ log::debug!("Build side has no data. Creating empty spatial
index.");
+ let partitioned_index_provider =
PartitionedIndexProvider::new_empty(
+ self.build_schema,
+ self.spatial_predicate,
+ self.sedona_options.spatial_join,
+ self.join_type,
+ self.probe_threads_count,
+ SpatialJoinBuildMetrics::new(0, &self.metrics),
+ );
+ return Ok(SpatialJoinComponents {
+ partitioned_index_provider:
Arc::new(partitioned_index_provider),
+ });
+ }
+
+ let mut rng = Rng::with_seed(self.seed);
+ let mut build_partitions = self
+ .collect_build_partitions(build_streams, rng.u64(0..0xFFFF))
+ .await?;
+
+ // Determine the number of spatial partitions based on the memory
reserved and the estimated amount of
+ // memory required for loading the entire build side into a spatial
index
+ let memory_plan =
+
compute_memory_plan(build_partitions.iter().map(PartitionMemorySummary::from))?;
+ log::debug!("Computed memory plan for spatial join:\n{:#?}",
memory_plan);
+ let num_partitions = match self
+ .sedona_options
+ .spatial_join
+ .debug
+ .num_spatial_partitions
+ {
+ NumSpatialPartitionsConfig::Auto => memory_plan.num_partitions,
+ NumSpatialPartitionsConfig::Fixed(n) => {
+ log::debug!("Override number of spatial partitions to {}", n);
+ n
+ }
+ };
+
+ if num_partitions == 1 {
+ log::debug!("Running single-partitioned in-memory spatial join");
+ let partitioned_index_provider =
PartitionedIndexProvider::new_single_partition(
+ self.build_schema,
+ self.spatial_predicate,
+ self.sedona_options.spatial_join,
+ self.join_type,
+ self.probe_threads_count,
+ build_partitions,
+ SpatialJoinBuildMetrics::new(0, &self.metrics),
+ );
+ Ok(SpatialJoinComponents {
+ partitioned_index_provider:
Arc::new(partitioned_index_provider),
+ })
+ } else {
+ // Collect all memory reservations grown during build side
collection
+ let mut reservations = Vec::with_capacity(build_partitions.len());
+ for partition in &mut build_partitions {
+ reservations.push(partition.reservation.take());
+ }
+
+ // Partition the build side into multiple spatial partitions, each
partition can be fully
+ // loaded into an in-memory spatial index
+ let build_partitioner = self.build_spatial_partitioner(
+ num_partitions,
+ &mut build_partitions,
+ rng.u64(0..0xFFFF),
+ )?;
+ let partitioned_spill_files_vec = self
+ .repartition_build_side(build_partitions, build_partitioner,
&memory_plan)
+ .await?;
+
+ let merged_spilled_partitions =
merge_spilled_partitions(partitioned_spill_files_vec)?;
+ log::debug!(
+ "Build side spatial partitions:\n{}",
+ merged_spilled_partitions.debug_str()
+ );
+
+ // Sanity check: Multi and None partitions must be empty. All the
geometries in the build side
+ // should fall into regular partitions
+ for partition in [SpatialPartition::None, SpatialPartition::Multi]
{
+ let spilled_partition =
merged_spilled_partitions.spilled_partition(partition)?;
+ if !spilled_partition.spill_files().is_empty() {
+ return sedona_internal_err!(
+ "Build side spatial partitions {:?} should be empty",
+ partition
+ );
+ }
+ }
+
+ let partitioned_index_provider =
PartitionedIndexProvider::new_multi_partition(
+ self.build_schema,
+ self.spatial_predicate,
+ self.sedona_options.spatial_join,
+ self.join_type,
+ self.probe_threads_count,
+ merged_spilled_partitions,
+ SpatialJoinBuildMetrics::new(0, &self.metrics),
+ reservations,
+ );
+
+ Ok(SpatialJoinComponents {
+ partitioned_index_provider:
Arc::new(partitioned_index_provider),
+ })
+ }
+ }
+
+ /// Collect build-side batches from the provided streams and return a
+ /// vector of `BuildPartition` entries representing the collected data.
+ /// The collector may spill to disk according to the configured options.
+ async fn collect_build_partitions(
+ &mut self,
+ build_streams: Vec<SendableRecordBatchStream>,
+ seed: u64,
+ ) -> Result<Vec<BuildPartition>> {
+ let runtime_env = self.context.runtime_env();
+ let session_config = self.context.session_config();
+ let spill_compression = session_config.spill_compression();
+
+ let num_partitions = build_streams.len();
+ let mut collect_metrics_vec = Vec::with_capacity(num_partitions);
+ let mut reservations = Vec::with_capacity(num_partitions);
+ let memory_pool = self.context.memory_pool();
+ for k in 0..num_partitions {
+ let consumer =
MemoryConsumer::new(format!("SpatialJoinCollectBuildSide[{k}]"))
+ .with_can_spill(true);
+ let reservation = consumer.register(memory_pool);
+ reservations.push(reservation);
+ collect_metrics_vec.push(CollectBuildSideMetrics::new(k,
&self.metrics));
+ }
+
+ let collector = BuildSideBatchesCollector::new(
+ self.spatial_predicate.clone(),
+ self.sedona_options.spatial_join.clone(),
+ Arc::clone(&runtime_env),
+ spill_compression,
+ );
+ let build_partitions = collector
+ .collect_all(
+ build_streams,
+ reservations,
+ collect_metrics_vec.clone(),
+ self.sedona_options
+ .spatial_join
+ .concurrent_build_side_collection,
+ seed,
+ )
+ .await?;
+
+ Ok(build_partitions)
+ }
+
+ /// Construct a `SpatialPartitioner` (e.g. KDB) from collected samples so
+ /// the build and probe sides can be partitioned spatially across
+ /// `num_partitions`.
+ fn build_spatial_partitioner(
+ &self,
+ num_partitions: usize,
+ build_partitions: &mut Vec<BuildPartition>,
+ seed: u64,
+ ) -> Result<Arc<dyn SpatialPartitioner>> {
+ if matches!(
+ self.spatial_predicate,
+ SpatialPredicate::KNearestNeighbors(..)
+ ) {
+ return sedona_internal_err!("Partitioned KNN join is not supported
yet");
+ }
+
+ let build_partitioner: Arc<dyn SpatialPartitioner> = {
+ // Use spatial partitioners to partition the build side and the
probe side; this will
+ // reduce the amount of work needed for probing each partitioned
index.
+ // The KDB partitioner is built using the collected bounding box
samples.
+ let mut bbox_samples = BoundingBoxSamples::empty();
+ let mut geo_stats = GeoStatistics::empty();
+ let mut rng = Rng::with_seed(seed);
+ for partition in build_partitions {
+ let samples = mem::take(&mut partition.bbox_samples);
+ bbox_samples = bbox_samples.combine(samples, &mut rng);
+ geo_stats.merge(&partition.geo_statistics);
+ }
+
+ let extent =
geo_stats.bbox().cloned().unwrap_or(BoundingBox::empty());
+ let mut samples = bbox_samples.take_samples();
+ let max_items_per_node = 1.max(samples.len() / num_partitions);
+ let max_levels = num_partitions;
+
+ log::debug!(
+ "Number of samples: {}, max_items_per_node: {}, max_levels:
{}",
+ samples.len(),
+ max_items_per_node,
+ max_levels
+ );
+ rng.shuffle(&mut samples);
+ let kdb_partitioner =
+ KDBPartitioner::build(samples.into_iter(), max_items_per_node,
max_levels, extent)?;
+ log::debug!(
+ "Built KDB spatial partitioner with {} partitions",
+ num_partitions
+ );
+ log::debug!(
+ "KDB partitioner debug info:\n{}",
+ kdb_partitioner.debug_str()
+ );
+
+ Arc::new(kdb_partitioner)
+ };
+
+ Ok(build_partitioner)
+ }
+
+ /// Repartition the collected build-side partitions using the provided
+ /// `SpatialPartitioner`. Returns the spilled partitions for each spatial
partition.
+ async fn repartition_build_side(
+ &self,
+ build_partitions: Vec<BuildPartition>,
+ build_partitioner: Arc<dyn SpatialPartitioner>,
+ memory_plan: &MemoryPlan,
+ ) -> Result<Vec<SpilledPartitions>> {
+        // Spawn one task per build partition to repartition its data using
+        // the spatial partitioner for the build/indexed side.
+ let runtime_env = self.context.runtime_env();
+ let session_config = self.context.session_config();
+ let target_batch_size = session_config.batch_size();
+ let spill_compression = session_config.spill_compression();
+ let spilled_batch_in_memory_size_threshold = if self
+ .sedona_options
+ .spatial_join
+ .spilled_batch_in_memory_size_threshold
+ == 0
+ {
+ None
+ } else {
+ Some(
+ self.sedona_options
+ .spatial_join
+ .spilled_batch_in_memory_size_threshold,
+ )
+ };
+
+ let memory_for_intermittent_usage = match self
+ .sedona_options
+ .spatial_join
+ .debug
+ .memory_for_intermittent_usage
+ {
+ Some(value) => {
+ log::debug!("Override memory for intermittent usage to {}",
value);
+ value
+ }
+ None => memory_plan.memory_for_intermittent_usage,
+ };
+
+ let mut join_set = JoinSet::new();
+ let buffer_bytes_threshold = memory_for_intermittent_usage /
build_partitions.len();
+ for partition in build_partitions {
+ let stream = partition.build_side_batch_stream;
+ let metrics = &partition.metrics;
+ let spill_metrics = metrics.spill_metrics();
+ let runtime_env = Arc::clone(&runtime_env);
+ let partitioner = Arc::clone(&build_partitioner);
+ join_set.spawn(async move {
+ let partitioned_spill_files = StreamRepartitioner::builder(
+ runtime_env,
+ partitioner,
+ PartitionedSide::BuildSide,
+ spill_metrics,
+ )
+ .spill_compression(spill_compression)
+ .buffer_bytes_threshold(buffer_bytes_threshold)
+ .target_batch_size(target_batch_size)
+
.spilled_batch_in_memory_size_threshold(spilled_batch_in_memory_size_threshold)
+ .build()
+ .repartition_stream(stream)
+ .await;
+ partitioned_spill_files
+ });
+ }
+
+ let results = join_set.join_all().await;
+ let partitioned_spill_files_vec =
results.into_iter().collect::<Result<Vec<_>>>()?;
+ Ok(partitioned_spill_files_vec)
+ }
+}
+
+/// Aggregate the spill files and bounds of each spatial partition collected
from all build partitions
+fn merge_spilled_partitions(
+ spilled_partitions_vec: Vec<SpilledPartitions>,
+) -> Result<SpilledPartitions> {
+ let Some(first) = spilled_partitions_vec.first() else {
+ return sedona_internal_err!("spilled_partitions_vec cannot be empty");
+ };
+
+ let slots = first.slots();
+ let total_slots = slots.total_slots();
+ let mut merged_spill_files: Vec<Vec<Arc<RefCountedTempFile>>> =
+ (0..total_slots).map(|_| Vec::new()).collect();
+ let mut partition_geo_stats: Vec<GeoStatistics> =
+ (0..total_slots).map(|_| GeoStatistics::empty()).collect();
+ let mut partition_num_rows: Vec<usize> = (0..total_slots).map(|_|
0).collect();
+
+ for spilled_partitions in spilled_partitions_vec {
+ let partitions = spilled_partitions.into_spilled_partitions()?;
+ for (slot_idx, partition) in partitions.into_iter().enumerate() {
+ let (spill_files, geo_stats, num_rows) = partition.into_inner();
+ partition_geo_stats[slot_idx].merge(&geo_stats);
+ merged_spill_files[slot_idx].extend(spill_files);
+ partition_num_rows[slot_idx] += num_rows;
+ }
+ }
+
+ let merged_partitions = merged_spill_files
+ .into_iter()
+ .zip(partition_geo_stats)
+ .zip(partition_num_rows)
+ .map(|((spill_files, geo_stats), num_rows)| {
+ SpilledPartition::new(spill_files, geo_stats, num_rows)
+ })
+ .collect();
+
+ Ok(SpilledPartitions::new(slots, merged_partitions))
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::partitioning::partition_slots::PartitionSlots;
+ use datafusion_execution::runtime_env::RuntimeEnv;
+ use sedona_geometry::interval::IntervalTrait;
+
+ fn sample_geo_stats(bbox: (f64, f64, f64, f64), total_geometries: i64) ->
GeoStatistics {
+ GeoStatistics::empty()
+ .with_bbox(Some(BoundingBox::xy((bbox.0, bbox.1), (bbox.2,
bbox.3))))
+ .with_total_geometries(total_geometries)
+ }
+
+ fn sample_partition(
+ env: &Arc<RuntimeEnv>,
+ labels: &[&str],
+ bbox: (f64, f64, f64, f64),
+ total_geometries: i64,
+ ) -> Result<SpilledPartition> {
+ let mut files = Vec::with_capacity(labels.len());
+ for label in labels {
+ files.push(Arc::new(env.disk_manager.create_tmp_file(label)?));
+ }
+ Ok(SpilledPartition::new(
+ files,
+ sample_geo_stats(bbox, total_geometries),
+ total_geometries as usize,
+ ))
+ }
+
+ #[test]
+ fn merge_spilled_partitions_combines_files_and_stats() -> Result<()> {
+ let runtime_env = Arc::new(RuntimeEnv::default());
+ let slots = PartitionSlots::new(2);
+
+ let partitions_a = vec![
+ sample_partition(&runtime_env, &["r0_a"], (0.0, 1.0, 0.0, 1.0),
10)?,
+ sample_partition(&runtime_env, &["r1_a"], (10.0, 11.0, -1.0, 1.0),
5)?,
+ sample_partition(&runtime_env, &["none_a"], (-5.0, -4.0, -5.0,
-4.0), 2)?,
+ SpilledPartition::empty(),
+ ];
+ let first = SpilledPartitions::new(slots, partitions_a);
+
+ let partitions_b = vec![
+ sample_partition(&runtime_env, &["r0_b1", "r0_b2"], (5.0, 6.0,
5.0, 6.0), 20)?,
+ sample_partition(&runtime_env, &[], (12.0, 13.0, 2.0, 3.0), 8)?,
+ SpilledPartition::empty(),
+ sample_partition(&runtime_env, &["multi_b"], (50.0, 51.0, 50.0,
51.0), 1)?,
+ ];
+ let second = SpilledPartitions::new(slots, partitions_b);
+
+ let merged = merge_spilled_partitions(vec![first, second])?;
+
+ assert_eq!(merged.spill_file_count(), 6);
+
+ let regular0 = merged.spilled_partition(SpatialPartition::Regular(0))?;
+ assert_eq!(regular0.spill_files().len(), 3);
+ assert_eq!(regular0.geo_statistics().total_geometries(), Some(30));
+ let bbox0 = regular0.geo_statistics().bbox().unwrap();
+ assert_eq!(bbox0.x().lo(), 0.0);
+ assert_eq!(bbox0.x().hi(), 6.0);
+ assert_eq!(bbox0.y().lo(), 0.0);
+ assert_eq!(bbox0.y().hi(), 6.0);
+
+ let regular1 = merged.spilled_partition(SpatialPartition::Regular(1))?;
+ assert_eq!(regular1.spill_files().len(), 1);
+ assert_eq!(regular1.geo_statistics().total_geometries(), Some(13));
+ let bbox1 = regular1.geo_statistics().bbox().unwrap();
+ assert_eq!(bbox1.x().lo(), 10.0);
+ assert_eq!(bbox1.x().hi(), 13.0);
+ assert_eq!(bbox1.y().lo(), -1.0);
+ assert_eq!(bbox1.y().hi(), 3.0);
+
+ let none_partition = merged.spilled_partition(SpatialPartition::None)?;
+ assert_eq!(none_partition.spill_files().len(), 1);
+ assert_eq!(none_partition.geo_statistics().total_geometries(),
Some(2));
+
+ let multi_partition =
merged.spilled_partition(SpatialPartition::Multi)?;
+ assert_eq!(multi_partition.spill_files().len(), 1);
+ assert_eq!(multi_partition.geo_statistics().total_geometries(),
Some(1));
+
+ Ok(())
+ }
+}
diff --git a/rust/sedona-spatial-join/src/stream.rs
b/rust/sedona-spatial-join/src/stream.rs
index 8451ff2d..edbb41dd 100644
--- a/rust/sedona-spatial-join/src/stream.rs
+++ b/rust/sedona-spatial-join/src/stream.rs
@@ -38,8 +38,10 @@ use std::sync::Arc;
use
crate::evaluated_batch::evaluated_batch_stream::evaluate::create_evaluated_probe_stream;
use
crate::evaluated_batch::evaluated_batch_stream::SendableEvaluatedBatchStream;
use crate::evaluated_batch::EvaluatedBatch;
+use crate::index::partitioned_index_provider::PartitionedIndexProvider;
use crate::index::SpatialIndex;
use crate::operand_evaluator::create_operand_evaluator;
+use crate::prepare::SpatialJoinComponents;
use crate::spatial_predicate::SpatialPredicate;
use crate::utils::join_utils::{
adjust_indices_by_join_type, apply_join_filter_to_indices,
build_batch_from_indices,
@@ -52,6 +54,8 @@ use sedona_common::option::SpatialJoinOptions;
/// Stream for producing spatial join result batches.
pub(crate) struct SpatialJoinStream {
+ /// The partition id of the probe side stream
+ probe_partition_id: usize,
/// Schema of joined results
schema: Arc<Schema>,
/// join filter
@@ -73,13 +77,18 @@ pub(crate) struct SpatialJoinStream {
options: SpatialJoinOptions,
/// Target output batch size
target_output_batch_size: usize,
- /// Once future for the spatial index
- once_fut_spatial_index: OnceFut<SpatialIndex>,
- /// Once async for the spatial index, will be manually disposed by the
last finished stream
- /// to avoid unnecessary memory usage.
- once_async_spatial_index: Arc<Mutex<Option<OnceAsync<SpatialIndex>>>>,
+ /// Once future for the shared partitioned index provider
+ once_fut_spatial_join_components: OnceFut<SpatialJoinComponents>,
+ /// Once async for the provider, disposed by the last finished stream
+ once_async_spatial_join_components:
Arc<Mutex<Option<OnceAsync<SpatialJoinComponents>>>>,
+ /// Cached index provider reference after it becomes available
+ index_provider: Option<Arc<PartitionedIndexProvider>>,
/// The spatial index
spatial_index: Option<Arc<SpatialIndex>>,
+ /// Pending future for building or waiting on a partitioned index
+ pending_index_future: Option<BoxFuture<'static,
Option<Result<Arc<SpatialIndex>>>>>,
+ /// Total number of regular partitions produced by the provider
+ num_regular_partitions: Option<u32>,
/// The spatial predicate being evaluated
spatial_predicate: SpatialPredicate,
}
@@ -87,6 +96,7 @@ pub(crate) struct SpatialJoinStream {
impl SpatialJoinStream {
#[allow(clippy::too_many_arguments)]
pub(crate) fn new(
+ probe_partition_id: usize,
schema: Arc<Schema>,
on: &SpatialPredicate,
filter: Option<JoinFilter>,
@@ -97,8 +107,8 @@ impl SpatialJoinStream {
join_metrics: SpatialJoinProbeMetrics,
options: SpatialJoinOptions,
target_output_batch_size: usize,
- once_fut_spatial_index: OnceFut<SpatialIndex>,
- once_async_spatial_index: Arc<Mutex<Option<OnceAsync<SpatialIndex>>>>,
+ once_fut_spatial_join_components: OnceFut<SpatialJoinComponents>,
+ once_async_spatial_join_components:
Arc<Mutex<Option<OnceAsync<SpatialJoinComponents>>>>,
) -> Self {
let evaluator = create_operand_evaluator(on, options.clone());
let probe_stream = create_evaluated_probe_stream(
@@ -107,6 +117,7 @@ impl SpatialJoinStream {
join_metrics.join_time.clone(),
);
Self {
+ probe_partition_id,
schema,
filter,
join_type,
@@ -114,12 +125,15 @@ impl SpatialJoinStream {
column_indices,
probe_side_ordered,
join_metrics,
- state: SpatialJoinStreamState::WaitBuildIndex,
+ state: SpatialJoinStreamState::WaitPrepareSpatialJoinComponents,
options,
target_output_batch_size,
- once_fut_spatial_index,
- once_async_spatial_index,
+ once_fut_spatial_join_components,
+ once_async_spatial_join_components,
+ index_provider: None,
spatial_index: None,
+ pending_index_future: None,
+ num_regular_partitions: None,
spatial_predicate: on.clone(),
}
}
@@ -169,6 +183,8 @@ impl SpatialJoinProbeMetrics {
/// This enumeration represents various states of the nested loop join
algorithm.
#[allow(clippy::large_enum_variant)]
pub(crate) enum SpatialJoinStreamState {
+ /// The initial mode: waiting for the spatial join components to become
available
+ WaitPrepareSpatialJoinComponents,
/// The initial mode: waiting for the spatial index to be built
WaitBuildIndex,
/// Indicates that build-side has been collected, and stream is ready for
@@ -193,6 +209,9 @@ impl SpatialJoinStream {
) -> Poll<Option<Result<RecordBatch>>> {
loop {
return match &mut self.state {
+ SpatialJoinStreamState::WaitPrepareSpatialJoinComponents => {
+
handle_state!(ready!(self.wait_create_spatial_join_components(cx)))
+ }
SpatialJoinStreamState::WaitBuildIndex => {
handle_state!(ready!(self.wait_build_index(cx)))
}
@@ -213,16 +232,97 @@ impl SpatialJoinStream {
}
}
- fn wait_build_index(
+ fn wait_create_spatial_join_components(
&mut self,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
- let index = ready!(self.once_fut_spatial_index.get_shared(cx))?;
- self.spatial_index = Some(index);
- self.state = SpatialJoinStreamState::FetchProbeBatch;
+ if self.index_provider.is_none() {
+ let spatial_join_components =
+ ready!(self.once_fut_spatial_join_components.get_shared(cx))?;
+ let provider =
Arc::clone(&spatial_join_components.partitioned_index_provider);
+ self.num_regular_partitions =
Some(provider.num_regular_partitions() as u32);
+ self.index_provider = Some(provider);
+ }
+
+ let num_partitions = self
+ .num_regular_partitions
+ .expect("num_regular_partitions should be available");
+ if num_partitions == 0 {
+ // Usually does not happen. The indexed side should have at least
1 partition.
+ self.state = SpatialJoinStreamState::Completed;
+ return Poll::Ready(Ok(StatefulStreamResult::Continue));
+ }
+
+ if num_partitions > 1 {
+ return Poll::Ready(sedona_internal_err!(
+ "Multi-partitioned spatial join is not supported yet"
+ ));
+ }
+
+ self.state = SpatialJoinStreamState::WaitBuildIndex;
Poll::Ready(Ok(StatefulStreamResult::Continue))
}
+ fn wait_build_index(
+ &mut self,
+ cx: &mut std::task::Context<'_>,
+ ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
+ let num_partitions = self
+ .num_regular_partitions
+ .expect("num_regular_partitions should be available");
+ let partition_id = 0;
+ if partition_id >= num_partitions {
+ self.state = SpatialJoinStreamState::Completed;
+ return Poll::Ready(Ok(StatefulStreamResult::Continue));
+ }
+
+ if self.pending_index_future.is_none() {
+ let provider = Arc::clone(
+ self.index_provider
+ .as_ref()
+ .expect("Partitioned index provider should be available"),
+ );
+ let future = {
+ log::debug!(
+ "[Partition {}] Building index for spatial partition {}",
+ self.probe_partition_id,
+ partition_id
+ );
+ async move {
provider.build_or_wait_for_index(partition_id).await }.boxed()
+ };
+ self.pending_index_future = Some(future);
+ }
+
+ let future = self
+ .pending_index_future
+ .as_mut()
+ .expect("pending future must exist");
+
+ match future.poll_unpin(cx) {
+ Poll::Ready(Some(Ok(index))) => {
+ self.pending_index_future = None;
+ self.spatial_index = Some(index);
+ log::debug!(
+ "[Partition {}] Start probing spatial partition {}",
+ self.probe_partition_id,
+ partition_id
+ );
+ self.state = SpatialJoinStreamState::FetchProbeBatch;
+ Poll::Ready(Ok(StatefulStreamResult::Continue))
+ }
+ Poll::Ready(Some(Err(err))) => {
+ self.pending_index_future = None;
+ Poll::Ready(Err(err))
+ }
+ Poll::Ready(None) => {
+ self.pending_index_future = None;
+ self.state = SpatialJoinStreamState::Completed;
+ Poll::Ready(Ok(StatefulStreamResult::Continue))
+ }
+ Poll::Pending => Poll::Pending,
+ }
+ }
+
fn fetch_probe_batch(
&mut self,
cx: &mut std::task::Context<'_>,
@@ -318,8 +418,13 @@ impl SpatialJoinStream {
// Drop the once async to avoid holding a long-living reference to
the spatial index.
// The spatial index will be dropped when this stream is dropped.
- let mut once_async = self.once_async_spatial_index.lock();
+ let mut once_async =
self.once_async_spatial_join_components.lock();
once_async.take();
+
+ if let Some(provider) = self.index_provider.as_ref() {
+ provider.dispose_index(0);
+ assert!(provider.num_loaded_indexes() == 0);
+ }
}
// Initial setup for processing unmatched build batches
diff --git a/rust/sedona-spatial-join/src/utils.rs
b/rust/sedona-spatial-join/src/utils.rs
index 42a257f0..4d73a002 100644
--- a/rust/sedona-spatial-join/src/utils.rs
+++ b/rust/sedona-spatial-join/src/utils.rs
@@ -17,6 +17,7 @@
pub(crate) mod arrow_utils;
pub(crate) mod bbox_sampler;
+pub(crate) mod disposable_async_cell;
pub(crate) mod init_once_array;
pub(crate) mod join_utils;
pub(crate) mod once_fut;
diff --git a/rust/sedona-spatial-join/src/utils/bbox_sampler.rs
b/rust/sedona-spatial-join/src/utils/bbox_sampler.rs
index 498f3863..99280162 100644
--- a/rust/sedona-spatial-join/src/utils/bbox_sampler.rs
+++ b/rust/sedona-spatial-join/src/utils/bbox_sampler.rs
@@ -15,7 +15,6 @@
// specific language governing permissions and limitations
// under the License.
-#![allow(unused)]
use datafusion_common::{DataFusionError, Result};
use fastrand::Rng;
use sedona_geometry::bounding_box::BoundingBox;
diff --git a/rust/sedona-spatial-join/src/utils/disposable_async_cell.rs
b/rust/sedona-spatial-join/src/utils/disposable_async_cell.rs
new file mode 100644
index 00000000..e738e034
--- /dev/null
+++ b/rust/sedona-spatial-join/src/utils/disposable_async_cell.rs
@@ -0,0 +1,204 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::fmt;
+
+use parking_lot::Mutex;
+use tokio::sync::Notify;
+
+/// Error returned when writing to a [`DisposableAsyncCell`] fails.
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub(crate) enum CellSetError {
+ /// The cell has already been disposed, so new values are rejected.
+ Disposed,
+
+ /// The cell already has a value.
+ AlreadySet,
+}
+
+/// An asynchronous cell whose value can be set at most once and then read by
+/// any number of waiters, or disposed to reject further reads and writes.
+///
+/// This is used as a lightweight one-shot coordination primitive in the
spatial
+/// join implementation. For example, `PartitionedIndexProvider` keeps one
+/// `DisposableAsyncCell` per regular partition to publish either a
+/// successfully built `SpatialIndex` or the build error, exactly once.
+/// Concurrent `SpatialJoinStream`s racing to probe the same partition can
+/// then await the same shared result instead of building duplicate indexes.
+///
+/// When an index is no longer needed (e.g. the last stream finishes a
+/// partition), the cell can be disposed to free resources.
+///
+/// Awaiters calling [`DisposableAsyncCell::get`] will park until a value is
set
+/// or the cell is disposed. Once disposed, `get` returns `None` and `set`
+/// returns [`CellSetError::Disposed`].
+pub(crate) struct DisposableAsyncCell<T> {
+ state: Mutex<CellState<T>>,
+ notify: Notify,
+}
+
+impl<T> fmt::Debug for DisposableAsyncCell<T> {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ write!(f, "DisposableAsyncCell")
+ }
+}
+
+impl<T> Default for DisposableAsyncCell<T> {
+ fn default() -> Self {
+ Self::new()
+ }
+}
+
+impl<T> DisposableAsyncCell<T> {
+ /// Creates a new empty cell with no stored value.
+ pub(crate) fn new() -> Self {
+ Self {
+ state: Mutex::new(CellState::Empty),
+ notify: Notify::new(),
+ }
+ }
+
+ /// Marks the cell as disposed and wakes every waiter.
+ pub(crate) fn dispose(&self) {
+ {
+ let mut state = self.state.lock();
+ *state = CellState::Disposed;
+ }
+ self.notify.notify_waiters();
+ }
+
+    /// Check whether the cell currently holds a value.
+ pub(crate) fn is_set(&self) -> bool {
+ let state = self.state.lock();
+ matches!(*state, CellState::Value(_))
+ }
+
+    /// Check whether the cell is empty (i.e. neither set nor disposed).
+ pub(crate) fn is_empty(&self) -> bool {
+ let state = self.state.lock();
+ matches!(*state, CellState::Empty)
+ }
+}
+
+impl<T: Clone> DisposableAsyncCell<T> {
+ /// Waits until a value is set or the cell is disposed.
+ /// Returns `None` if the cell is disposed without a value.
+ pub(crate) async fn get(&self) -> Option<T> {
+ loop {
+ let notified = self.notify.notified();
+ {
+ let state = self.state.lock();
+ match &*state {
+ CellState::Value(val) => return Some(val.clone()),
+ CellState::Disposed => return None,
+ CellState::Empty => {}
+ }
+ }
+ notified.await;
+ }
+ }
+
+ /// Stores the provided value if the cell is still empty.
+ /// Fails if a value already exists or the cell has been disposed.
+ pub(crate) fn set(&self, value: T) -> std::result::Result<(),
CellSetError> {
+ {
+ let mut state = self.state.lock();
+ match &mut *state {
+ CellState::Empty => *state = CellState::Value(value),
+ CellState::Disposed => return Err(CellSetError::Disposed),
+ CellState::Value(_) => return Err(CellSetError::AlreadySet),
+ }
+ }
+
+ self.notify.notify_waiters();
+ Ok(())
+ }
+}
+
+enum CellState<T> {
+ Empty,
+ Value(T),
+ Disposed,
+}
+
+#[cfg(test)]
+mod tests {
+ use super::{CellSetError, DisposableAsyncCell};
+ use std::sync::Arc;
+ use tokio::task;
+ use tokio::time::{sleep, Duration};
+
+ #[tokio::test]
+ async fn get_returns_value_once_set() {
+ let cell = DisposableAsyncCell::new();
+ cell.set(42).expect("set succeeds");
+ assert_eq!(Some(42), cell.get().await);
+ }
+
+ #[tokio::test]
+ async fn multiple_waiters_receive_same_value() {
+ let cell = Arc::new(DisposableAsyncCell::new());
+ let cloned = Arc::clone(&cell);
+ let waiter_one = task::spawn(async move { cloned.get().await });
+ let cloned = Arc::clone(&cell);
+ let waiter_two = task::spawn(async move { cloned.get().await });
+
+ cell.set(String::from("value")).expect("set succeeds");
+ assert_eq!(Some("value".to_string()), waiter_one.await.unwrap());
+ assert_eq!(Some("value".to_string()), waiter_two.await.unwrap());
+ }
+
+ #[tokio::test]
+ async fn dispose_unblocks_waiters() {
+ let cell = Arc::new(DisposableAsyncCell::<i32>::new());
+ let waiter = tokio::spawn({
+ let cloned = Arc::clone(&cell);
+ async move { cloned.get().await }
+ });
+
+ cell.dispose();
+ assert_eq!(None, waiter.await.unwrap());
+ }
+
+ #[tokio::test]
+ async fn set_after_dispose_fails() {
+ let cell = DisposableAsyncCell::new();
+ cell.dispose();
+ assert_eq!(Err(CellSetError::Disposed), cell.set(5));
+ }
+
+ #[tokio::test]
+ async fn set_twice_rejects_second_value() {
+ let cell = DisposableAsyncCell::new();
+ cell.set("first").expect("initial set succeeds");
+ assert_eq!(Err(CellSetError::AlreadySet), cell.set("second"));
+ assert_eq!(Some("first"), cell.get().await);
+ }
+
+ #[tokio::test]
+ async fn get_waits_until_value_is_set() {
+ let cell = Arc::new(DisposableAsyncCell::new());
+ let cloned = Arc::clone(&cell);
+ let waiter = tokio::spawn(async move { cloned.get().await });
+
+ sleep(Duration::from_millis(20)).await;
+ assert!(!waiter.is_finished());
+
+ cell.set(99).expect("set succeeds");
+ assert_eq!(Some(99), waiter.await.unwrap());
+ }
+}