This is an automated email from the ASF dual-hosted git repository.
kontinuation pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/sedona-db.git
The following commit(s) were added to refs/heads/main by this push:
new 76ae8fc1 chore(rust/sedona-spatial-join): Make partitioner not sync,
create dedicated partitioner for each task (#592)
76ae8fc1 is described below
commit 76ae8fc143f6d3caaabc6f58b2d316572c394853
Author: Kristin Cowalcijk <[email protected]>
AuthorDate: Wed Feb 11 20:37:18 2026 +0800
chore(rust/sedona-spatial-join): Make partitioner not sync, create
dedicated partitioner for each task (#592)
This is a follow-up to the review comment at
https://github.com/apache/sedona-db/pull/573#discussion_r2763587238. The round
robin partitioner used an internal atomic counter shared across tasks, so its
partitioning results were nondeterministic: they depended on task scheduling
order. This patch is a large refactoring that makes each task own its own
partitioner instance, eliminating the randomness caused by racing. We have also
removed the `Sync` constraint from the partitioner trait to enforce this design.
---
.../bench/partitioning/stream_repartitioner.rs | 9 ++-
rust/sedona-spatial-join/src/partitioning.rs | 5 +-
.../src/partitioning/broadcast.rs | 5 ++
rust/sedona-spatial-join/src/partitioning/flat.rs | 5 ++
rust/sedona-spatial-join/src/partitioning/kdb.rs | 5 ++
.../src/partitioning/round_robin.rs | 14 ++--
rust/sedona-spatial-join/src/partitioning/rtree.rs | 75 +++++++++++++++-----
.../src/partitioning/stream_repartitioner.rs | 16 ++---
rust/sedona-spatial-join/src/prepare.rs | 30 ++++----
.../src/probe/first_pass_stream.rs | 4 +-
.../src/probe/partitioned_stream_provider.rs | 79 ++++++++++++++++------
11 files changed, 172 insertions(+), 75 deletions(-)
diff --git
a/rust/sedona-spatial-join/bench/partitioning/stream_repartitioner.rs
b/rust/sedona-spatial-join/bench/partitioning/stream_repartitioner.rs
index aa4f8a16..8ea7e2a0 100644
--- a/rust/sedona-spatial-join/bench/partitioning/stream_repartitioner.rs
+++ b/rust/sedona-spatial-join/bench/partitioning/stream_repartitioner.rs
@@ -68,10 +68,9 @@ fn bench_stream_partitioner(c: &mut Criterion) {
let seed_counter = Arc::clone(&seed_counter);
let schema = Arc::clone(&schema);
let runtime_env = Arc::clone(&runtime_env);
- let partitioner = Arc::clone(&partitioner);
let spill_metrics = spill_metrics.clone();
let extent = Arc::clone(&extent);
-
+ let partitioner = partitioner.box_clone();
b.iter_batched(
move || {
let seed = seed_counter.fetch_add(1, Ordering::Relaxed);
@@ -81,7 +80,7 @@ fn bench_stream_partitioner(c: &mut Criterion) {
block_on(async {
StreamRepartitioner::builder(
runtime_env.clone(),
- partitioner.clone(),
+ partitioner.box_clone(),
PartitionedSide::BuildSide,
spill_metrics.clone(),
)
@@ -187,7 +186,7 @@ fn build_schema() -> Schema {
])
}
-fn build_partitioner(extent: &BoundingBox) -> Arc<dyn SpatialPartitioner +
Send + Sync> {
+fn build_partitioner(extent: &BoundingBox) -> Box<dyn SpatialPartitioner> {
let mut rng = StdRng::seed_from_u64(RNG_SEED ^ 0x00FF_FFFF);
let samples = (0..SAMPLE_FOR_PARTITIONER)
.map(|_| random_bbox(extent, &mut rng))
@@ -201,7 +200,7 @@ fn build_partitioner(extent: &BoundingBox) -> Arc<dyn
SpatialPartitioner + Send
)
.expect("kdb builder should succeed");
- Arc::new(partitioner)
+ Box::new(partitioner)
}
fn random_bbox(extent: &BoundingBox, rng: &mut impl RngExt) -> BoundingBox {
diff --git a/rust/sedona-spatial-join/src/partitioning.rs
b/rust/sedona-spatial-join/src/partitioning.rs
index 60028974..df0914d2 100644
--- a/rust/sedona-spatial-join/src/partitioning.rs
+++ b/rust/sedona-spatial-join/src/partitioning.rs
@@ -57,7 +57,7 @@ pub enum SpatialPartition {
}
/// Partitioning larger-than-memory indexed side to support out-of-core
spatial join.
-pub trait SpatialPartitioner: Send + Sync {
+pub trait SpatialPartitioner: Send {
/// Get the total number of spatial partitions, excluding the None
partition and Multi partition.
fn num_regular_partitions(&self) -> usize;
@@ -68,6 +68,9 @@ pub trait SpatialPartitioner: Send + Sync {
/// Multi partition. If `bbox` intersects with multiple partitions, only
one of them will be
/// selected as regular partition.
fn partition_no_multi(&self, bbox: &BoundingBox) ->
Result<SpatialPartition>;
+
+ /// Clone the partitioner as a boxed trait object.
+ fn box_clone(&self) -> Box<dyn SpatialPartitioner>;
}
/// Indicates for which side of the spatial join the partitioning is being
performed.
diff --git a/rust/sedona-spatial-join/src/partitioning/broadcast.rs
b/rust/sedona-spatial-join/src/partitioning/broadcast.rs
index 308dccc0..b7addef2 100644
--- a/rust/sedona-spatial-join/src/partitioning/broadcast.rs
+++ b/rust/sedona-spatial-join/src/partitioning/broadcast.rs
@@ -26,6 +26,7 @@ use crate::partitioning::{SpatialPartition,
SpatialPartitioner};
/// This partitioner is useful when we want to broadcast the data to all
partitions.
/// Currently it is used for KNN join where regular spatial partitioning is
hard because
/// it is hard to know in advance how far away a given number of neighbours
will be to assign it.
+#[derive(Clone)]
pub struct BroadcastPartitioner {
num_partitions: usize,
}
@@ -48,6 +49,10 @@ impl SpatialPartitioner for BroadcastPartitioner {
fn partition_no_multi(&self, _bbox: &BoundingBox) ->
Result<SpatialPartition> {
sedona_internal_err!("BroadcastPartitioner does not support
partition_no_multi")
}
+
+ fn box_clone(&self) -> Box<dyn SpatialPartitioner> {
+ Box::new(self.clone())
+ }
}
#[cfg(test)]
diff --git a/rust/sedona-spatial-join/src/partitioning/flat.rs
b/rust/sedona-spatial-join/src/partitioning/flat.rs
index 29e33b44..3e6bce25 100644
--- a/rust/sedona-spatial-join/src/partitioning/flat.rs
+++ b/rust/sedona-spatial-join/src/partitioning/flat.rs
@@ -40,6 +40,7 @@ use sedona_geometry::interval::IntervalTrait;
use crate::partitioning::{SpatialPartition, SpatialPartitioner};
/// Spatial partitioner that linearly scans partition boundaries.
+#[derive(Clone)]
pub struct FlatPartitioner {
boundaries: Vec<BoundingBox>,
}
@@ -106,6 +107,10 @@ impl SpatialPartitioner for FlatPartitioner {
None => SpatialPartition::None,
})
}
+
+ fn box_clone(&self) -> Box<dyn SpatialPartitioner> {
+ Box::new(self.clone())
+ }
}
#[cfg(test)]
diff --git a/rust/sedona-spatial-join/src/partitioning/kdb.rs
b/rust/sedona-spatial-join/src/partitioning/kdb.rs
index c09e98ff..02197102 100644
--- a/rust/sedona-spatial-join/src/partitioning/kdb.rs
+++ b/rust/sedona-spatial-join/src/partitioning/kdb.rs
@@ -455,6 +455,7 @@ impl KDBTree {
/// let query_bbox = BoundingBox::xy((5.0, 15.0), (5.0, 15.0));
/// let partition = partitioner.partition(&query_bbox).unwrap();
/// ```
+#[derive(Clone)]
pub struct KDBPartitioner {
tree: Arc<KDBTree>,
}
@@ -566,6 +567,10 @@ impl SpatialPartitioner for KDBPartitioner {
None => Ok(SpatialPartition::None),
}
}
+
+ fn box_clone(&self) -> Box<dyn SpatialPartitioner> {
+ Box::new(self.clone())
+ }
}
#[cfg(test)]
diff --git a/rust/sedona-spatial-join/src/partitioning/round_robin.rs
b/rust/sedona-spatial-join/src/partitioning/round_robin.rs
index a5d73117..71251653 100644
--- a/rust/sedona-spatial-join/src/partitioning/round_robin.rs
+++ b/rust/sedona-spatial-join/src/partitioning/round_robin.rs
@@ -15,7 +15,7 @@
// specific language governing permissions and limitations
// under the License.
-use std::sync::atomic::{AtomicUsize, Ordering};
+use std::cell::Cell;
use datafusion_common::Result;
use sedona_geometry::bounding_box::BoundingBox;
@@ -27,16 +27,17 @@ use crate::partitioning::{SpatialPartition,
SpatialPartitioner};
/// This partitioner is used for KNN join, where the build side is partitioned
/// into `num_partitions` partitions, and the probe side is assigned to the
/// `Multi` partition (i.e., broadcast to all partitions).
+#[derive(Clone)]
pub struct RoundRobinPartitioner {
num_partitions: usize,
- counter: AtomicUsize,
+ counter: Cell<usize>,
}
impl RoundRobinPartitioner {
pub fn new(num_partitions: usize) -> Self {
Self {
num_partitions,
- counter: AtomicUsize::new(0),
+ counter: Cell::new(0),
}
}
}
@@ -51,11 +52,16 @@ impl SpatialPartitioner for RoundRobinPartitioner {
}
fn partition_no_multi(&self, _bbox: &BoundingBox) ->
Result<SpatialPartition> {
- let idx = self.counter.fetch_add(1, Ordering::Relaxed);
+ let idx = self.counter.get();
+ self.counter.set(idx.wrapping_add(1));
Ok(SpatialPartition::Regular(
(idx % self.num_partitions) as u32,
))
}
+
+ fn box_clone(&self) -> Box<dyn SpatialPartitioner> {
+ Box::new(self.clone())
+ }
}
#[cfg(test)]
diff --git a/rust/sedona-spatial-join/src/partitioning/rtree.rs
b/rust/sedona-spatial-join/src/partitioning/rtree.rs
index ffb43032..760b359b 100644
--- a/rust/sedona-spatial-join/src/partitioning/rtree.rs
+++ b/rust/sedona-spatial-join/src/partitioning/rtree.rs
@@ -35,6 +35,8 @@
//! 4. **None-partition Handling**: If a bbox doesn't intersect any partition
boundary, it's assigned
//! to [`SpatialPartition::None`].
+use std::sync::Arc;
+
use datafusion_common::Result;
use geo::Rect;
use geo_index::rtree::{sort::HilbertSort, RTree, RTreeBuilder, RTreeIndex};
@@ -50,15 +52,9 @@ use crate::partitioning::{SpatialPartition,
SpatialPartitioner};
/// This partitioner constructs an RTree index over a set of partition
boundaries
/// (rectangles) and uses it to efficiently determine which partition a given
/// bounding box belongs to based on spatial intersection.
+#[derive(Clone)]
pub struct RTreePartitioner {
- /// The RTree index storing partition boundaries as f32 rectangles
- rtree: RTree<f32>,
- /// Flat representation of partition boundaries for overlap calculations
- boundaries: Vec<Rect<f32>>,
- /// Number of partitions (excluding None and Multi)
- num_partitions: usize,
- /// Map from RTree index to original partition index
- partition_map: Vec<usize>,
+ inner: Arc<RawRTreePartitioner>,
}
impl RTreePartitioner {
@@ -84,12 +80,58 @@ impl RTreePartitioner {
/// let partitioner = RTreePartitioner::try_new(boundaries).unwrap();
/// ```
pub fn try_new(boundaries: Vec<BoundingBox>) -> Result<Self> {
- Self::build(boundaries, None)
+ let inner = RawRTreePartitioner::try_new(boundaries)?;
+ Ok(Self {
+ inner: Arc::new(inner),
+ })
}
/// Create a new RTree partitioner with a custom node size.
pub fn try_new_with_node_size(boundaries: Vec<BoundingBox>, node_size:
u16) -> Result<Self> {
- Self::build(boundaries, Some(node_size))
+ let inner = RawRTreePartitioner::build(boundaries, Some(node_size))?;
+ Ok(Self {
+ inner: Arc::new(inner),
+ })
+ }
+
+ /// Return the number of levels in the underlying RTree.
+ pub fn depth(&self) -> usize {
+ self.inner.depth()
+ }
+}
+
+impl SpatialPartitioner for RTreePartitioner {
+ fn num_regular_partitions(&self) -> usize {
+ self.inner.num_regular_partitions()
+ }
+
+ fn partition(&self, bbox: &BoundingBox) -> Result<SpatialPartition> {
+ self.inner.partition(bbox)
+ }
+
+ fn partition_no_multi(&self, bbox: &BoundingBox) ->
Result<SpatialPartition> {
+ self.inner.partition_no_multi(bbox)
+ }
+
+ fn box_clone(&self) -> Box<dyn SpatialPartitioner> {
+ Box::new(self.clone())
+ }
+}
+
+struct RawRTreePartitioner {
+ /// The RTree index storing partition boundaries as f32 rectangles
+ rtree: RTree<f32>,
+ /// Flat representation of partition boundaries for overlap calculations
+ boundaries: Vec<Rect<f32>>,
+ /// Number of partitions (excluding None and Multi)
+ num_partitions: usize,
+ /// Map from RTree index to original partition index
+ partition_map: Vec<usize>,
+}
+
+impl RawRTreePartitioner {
+ fn try_new(boundaries: Vec<BoundingBox>) -> Result<Self> {
+ Self::build(boundaries, None)
}
fn build(boundaries: Vec<BoundingBox>, node_size: Option<u16>) ->
Result<Self> {
@@ -122,7 +164,7 @@ impl RTreePartitioner {
let rtree = rtree_builder.finish::<HilbertSort>();
- Ok(RTreePartitioner {
+ Ok(RawRTreePartitioner {
rtree,
boundaries: rects,
num_partitions,
@@ -130,17 +172,14 @@ impl RTreePartitioner {
})
}
- /// Return the number of levels in the underlying RTree.
- pub fn depth(&self) -> usize {
- self.rtree.num_levels()
- }
-}
-
-impl SpatialPartitioner for RTreePartitioner {
fn num_regular_partitions(&self) -> usize {
self.num_partitions
}
+ fn depth(&self) -> usize {
+ self.rtree.num_levels()
+ }
+
fn partition(&self, bbox: &BoundingBox) -> Result<SpatialPartition> {
// Convert bbox to f32 for RTree query with proper bounds handling
let (min_x, min_y, max_x, max_y) = match bbox_to_f32_rect(bbox)? {
diff --git a/rust/sedona-spatial-join/src/partitioning/stream_repartitioner.rs
b/rust/sedona-spatial-join/src/partitioning/stream_repartitioner.rs
index 0a326cce..cc8e6d7e 100644
--- a/rust/sedona-spatial-join/src/partitioning/stream_repartitioner.rs
+++ b/rust/sedona-spatial-join/src/partitioning/stream_repartitioner.rs
@@ -300,7 +300,7 @@ impl SpilledPartitions {
/// `target_batch_size` rows per partition batch.
pub struct StreamRepartitioner {
runtime_env: Arc<RuntimeEnv>,
- partitioner: Arc<dyn SpatialPartitioner>,
+ partitioner: Box<dyn SpatialPartitioner>,
partitioned_side: PartitionedSide,
slots: PartitionSlots,
/// Spill files for each spatial partition.
@@ -330,7 +330,7 @@ pub struct StreamRepartitioner {
/// - `spilled_batch_in_memory_size_threshold`: `None`
pub struct StreamRepartitionerBuilder {
runtime_env: Arc<RuntimeEnv>,
- partitioner: Arc<dyn SpatialPartitioner>,
+ partitioner: Box<dyn SpatialPartitioner>,
partitioned_side: PartitionedSide,
spill_compression: SpillCompression,
spill_metrics: SpillMetrics,
@@ -407,7 +407,7 @@ impl StreamRepartitioner {
/// spill metrics). Optional parameters can then be set on the returned
builder.
pub fn builder(
runtime_env: Arc<RuntimeEnv>,
- partitioner: Arc<dyn SpatialPartitioner>,
+ partitioner: Box<dyn SpatialPartitioner>,
partitioned_side: PartitionedSide,
spill_metrics: SpillMetrics,
) -> StreamRepartitionerBuilder {
@@ -840,7 +840,7 @@ mod tests {
BoundingBox::xy((0.0, 50.0), (0.0, 50.0)),
BoundingBox::xy((50.0, 100.0), (0.0, 50.0)),
];
- let partitioner = Arc::new(FlatPartitioner::try_new(partitions)?);
+ let partitioner = Box::new(FlatPartitioner::try_new(partitions)?);
let runtime_env = Arc::new(RuntimeEnv::default());
let metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0);
@@ -926,7 +926,7 @@ mod tests {
BoundingBox::xy((0.0, 50.0), (0.0, 50.0)),
BoundingBox::xy((50.0, 100.0), (0.0, 50.0)),
];
- let partitioner = Arc::new(FlatPartitioner::try_new(partitions)?);
+ let partitioner = Box::new(FlatPartitioner::try_new(partitions)?);
let runtime_env = Arc::new(RuntimeEnv::default());
let metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0);
@@ -990,7 +990,7 @@ mod tests {
BoundingBox::xy((0.0, 50.0), (0.0, 50.0)),
BoundingBox::xy((50.0, 100.0), (0.0, 50.0)),
];
- let partitioner = Arc::new(FlatPartitioner::try_new(partitions)?);
+ let partitioner = Box::new(FlatPartitioner::try_new(partitions)?);
let runtime_env = Arc::new(RuntimeEnv::default());
let spill_metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(),
0);
let mut repartitioner = StreamRepartitioner::builder(
@@ -1035,7 +1035,7 @@ mod tests {
let batch_a = sample_batch(&[0], vec![Some(wkb_point((10.0,
10.0)).unwrap())])?;
let batch_b = sample_batch(&[1], vec![Some(wkb_point((20.0,
10.0)).unwrap())])?;
let partitions = vec![BoundingBox::xy((0.0, 50.0), (0.0, 50.0))];
- let partitioner = Arc::new(FlatPartitioner::try_new(partitions)?);
+ let partitioner = Box::new(FlatPartitioner::try_new(partitions)?);
let runtime_env = Arc::new(RuntimeEnv::default());
let spill_metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(),
0);
let mut repartitioner = StreamRepartitioner::builder(
@@ -1069,7 +1069,7 @@ mod tests {
let batch_a = sample_batch(&[0], vec![Some(wkb_point((10.0,
10.0)).unwrap())])?;
let batch_b = sample_batch(&[1], vec![Some(wkb_point((20.0,
10.0)).unwrap())])?;
let partitions = vec![BoundingBox::xy((0.0, 50.0), (0.0, 50.0))];
- let partitioner = Arc::new(FlatPartitioner::try_new(partitions)?);
+ let partitioner = Box::new(FlatPartitioner::try_new(partitions)?);
let runtime_env = Arc::new(RuntimeEnv::default());
let spill_metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(),
0);
let mut repartitioner = StreamRepartitioner::builder(
diff --git a/rust/sedona-spatial-join/src/prepare.rs
b/rust/sedona-spatial-join/src/prepare.rs
index 309eca12..9e9a2053 100644
--- a/rust/sedona-spatial-join/src/prepare.rs
+++ b/rust/sedona-spatial-join/src/prepare.rs
@@ -244,8 +244,8 @@ impl SpatialJoinComponentsBuilder {
num_partitions: usize,
build_partitions: &mut Vec<BuildPartition>,
seed: u64,
- ) -> Result<Arc<dyn SpatialPartitioner>> {
- let build_partitioner: Arc<dyn SpatialPartitioner> = if matches!(
+ ) -> Result<Box<dyn SpatialPartitioner>> {
+ let build_partitioner: Box<dyn SpatialPartitioner> = if matches!(
self.spatial_predicate,
SpatialPredicate::KNearestNeighbors(_)
) {
@@ -253,7 +253,7 @@ impl SpatialJoinComponentsBuilder {
// partitioning to spread the indexed data evenly to make each
index fit in memory, and
// the probe side will be broadcasted to all partitions by
partitioning all of them to
// the Multi partition.
- Arc::new(RoundRobinPartitioner::new(num_partitions))
+ Box::new(RoundRobinPartitioner::new(num_partitions))
} else {
// Use spatial partitioners to partition the build side and the
probe side, this will
// reduce the amount of work needed for probing each partitioned
index.
@@ -290,7 +290,7 @@ impl SpatialJoinComponentsBuilder {
kdb_partitioner.debug_str()
);
- Arc::new(kdb_partitioner)
+ Box::new(kdb_partitioner)
};
Ok(build_partitioner)
@@ -302,12 +302,12 @@ impl SpatialJoinComponentsBuilder {
&self,
num_partitions: usize,
merged_spilled_partitions: &SpilledPartitions,
- ) -> Result<Arc<dyn SpatialPartitioner>> {
- let probe_partitioner: Arc<dyn SpatialPartitioner> = if matches!(
+ ) -> Result<Box<dyn SpatialPartitioner>> {
+ let probe_partitioner: Box<dyn SpatialPartitioner> = if matches!(
self.spatial_predicate,
SpatialPredicate::KNearestNeighbors(_)
) {
- Arc::new(BroadcastPartitioner::new(num_partitions))
+ Box::new(BroadcastPartitioner::new(num_partitions))
} else {
// Build a flat partitioner using these partitions
let mut partition_bounds = Vec::with_capacity(num_partitions);
@@ -320,7 +320,7 @@ impl SpatialJoinComponentsBuilder {
.unwrap_or(BoundingBox::empty());
partition_bounds.push(partition_bound);
}
- Arc::new(FlatPartitioner::try_new(partition_bounds)?)
+ Box::new(FlatPartitioner::try_new(partition_bounds)?)
};
Ok(probe_partitioner)
}
@@ -330,7 +330,7 @@ impl SpatialJoinComponentsBuilder {
async fn repartition_build_side(
&self,
build_partitions: Vec<BuildPartition>,
- build_partitioner: Arc<dyn SpatialPartitioner>,
+ build_partitioner: Box<dyn SpatialPartitioner>,
memory_plan: &MemoryPlan,
) -> Result<Vec<SpilledPartitions>> {
// Spawn each task for each build partition to repartition the data
using the spatial partitioner for
@@ -349,7 +349,7 @@ impl SpatialJoinComponentsBuilder {
let metrics = &partition.metrics;
let spill_metrics = metrics.spill_metrics();
let runtime_env = Arc::clone(&runtime_env);
- let partitioner = Arc::clone(&build_partitioner);
+ let partitioner = build_partitioner.box_clone();
join_set.spawn(async move {
let partitioned_spill_files = StreamRepartitioner::builder(
runtime_env,
@@ -437,7 +437,7 @@ impl SpatialJoinComponentsBuilder {
fn create_multi_partitioned_spatial_join_components(
self,
merged_spilled_partitions: SpilledPartitions,
- probe_partitioner: Arc<dyn SpatialPartitioner>,
+ probe_partitioner: Box<dyn SpatialPartitioner>,
reservations: Vec<MemoryReservation>,
memory_plan: &MemoryPlan,
) -> Result<SpatialJoinComponents> {
@@ -459,13 +459,13 @@ impl SpatialJoinComponentsBuilder {
);
let buffer_bytes_threshold = memory_for_intermittent_usage /
self.probe_threads_count;
- let probe_stream_options = ProbeStreamOptions {
- partitioner: Some(probe_partitioner),
- target_batch_rows: target_batch_size,
+ let probe_stream_options = ProbeStreamOptions::new(
+ Some(probe_partitioner),
+ target_batch_size,
spill_compression,
buffer_bytes_threshold,
spilled_batch_in_memory_size_threshold,
- };
+ );
Ok(SpatialJoinComponents {
partitioned_index_provider: Arc::new(partitioned_index_provider),
diff --git a/rust/sedona-spatial-join/src/probe/first_pass_stream.rs
b/rust/sedona-spatial-join/src/probe/first_pass_stream.rs
index 7e75b344..9bc691cf 100644
--- a/rust/sedona-spatial-join/src/probe/first_pass_stream.rs
+++ b/rust/sedona-spatial-join/src/probe/first_pass_stream.rs
@@ -51,7 +51,7 @@ use crate::{
pub(crate) struct FirstPassStream<C: FirstPassStreamCallback> {
source: SendableEvaluatedBatchStream,
repartitioner: Option<StreamRepartitioner>,
- partitioner: Arc<dyn SpatialPartitioner>,
+ partitioner: Box<dyn SpatialPartitioner>,
pending_output: VecDeque<Result<EvaluatedBatch>>,
metrics: ProbeStreamMetrics,
callback: Option<C>,
@@ -71,7 +71,7 @@ impl<C: FirstPassStreamCallback> FirstPassStream<C> {
pub fn new(
source: SendableEvaluatedBatchStream,
repartitioner: StreamRepartitioner,
- partitioner: Arc<dyn SpatialPartitioner>,
+ partitioner: Box<dyn SpatialPartitioner>,
metrics: ProbeStreamMetrics,
callback: C,
) -> Self {
diff --git a/rust/sedona-spatial-join/src/probe/partitioned_stream_provider.rs
b/rust/sedona-spatial-join/src/probe/partitioned_stream_provider.rs
index 14c45b31..ef4e91c6 100644
--- a/rust/sedona-spatial-join/src/probe/partitioned_stream_provider.rs
+++ b/rust/sedona-spatial-join/src/probe/partitioned_stream_provider.rs
@@ -38,7 +38,6 @@ use crate::{
},
};
-#[derive(Clone)]
/// Configuration options for creating a probe-side stream provider.
///
/// When a `partitioner` is provided, the provider performs an initial first
pass that
@@ -50,7 +49,12 @@ pub(crate) struct ProbeStreamOptions {
/// - `None` means the probe side is treated as a single, non-partitioned
stream and only
/// [`SpatialPartition::Regular(0)`] is supported.
/// - `Some(_)` enables partitioned streaming with a warm-up (first) pass.
- pub partitioner: Option<Arc<dyn SpatialPartitioner>>,
+ ///
+ /// The `Mutex` is used here to make [`ProbeStreamOptions`] (and its
contained options)
+ /// `Send + Sync` so it can be shared/cloned into `SpatialJoinExec` and
across tasks.
+ /// The partitioner itself is treated as a clonable prototype and is not
intended to be
+ /// used by multiple tasks concurrently via this shared `Mutex`.
+ pub partitioner: Option<Mutex<Box<dyn SpatialPartitioner>>>,
/// Target number of rows per output batch produced by the partitioning
stream.
pub target_batch_rows: usize,
/// Spill compression to use when writing partition spill files.
@@ -62,6 +66,41 @@ pub(crate) struct ProbeStreamOptions {
pub spilled_batch_in_memory_size_threshold: Option<usize>,
}
+impl ProbeStreamOptions {
+ pub fn new(
+ partitioner: Option<Box<dyn SpatialPartitioner>>,
+ target_batch_rows: usize,
+ spill_compression: SpillCompression,
+ buffer_bytes_threshold: usize,
+ spilled_batch_in_memory_size_threshold: Option<usize>,
+ ) -> Self {
+ let partitioner = partitioner.map(Mutex::new);
+ Self {
+ partitioner,
+ target_batch_rows,
+ spill_compression,
+ buffer_bytes_threshold,
+ spilled_batch_in_memory_size_threshold,
+ }
+ }
+}
+
+impl Clone for ProbeStreamOptions {
+ fn clone(&self) -> Self {
+ let cloned_partitioner = self
+ .partitioner
+ .as_ref()
+ .map(|p| Mutex::new(p.lock().box_clone()));
+ Self {
+ partitioner: cloned_partitioner,
+ target_batch_rows: self.target_batch_rows,
+ spill_compression: self.spill_compression,
+ buffer_bytes_threshold: self.buffer_bytes_threshold,
+ spilled_batch_in_memory_size_threshold:
self.spilled_batch_in_memory_size_threshold,
+ }
+ }
+}
+
/// Provides probe-side streams for a given [`SpatialPartition`].
///
/// For partitioned joins this provider is a small state machine:
@@ -144,15 +183,17 @@ impl PartitionedProbeStreamProvider {
let mut state_guard = self.state.lock();
match std::mem::replace(&mut *state_guard,
ProbeStreamState::FirstPass) {
ProbeStreamState::Pending { source } => {
- let partitioner = Arc::clone(
- self.options
- .partitioner
- .as_ref()
- .expect("Partitioned first pass requires a
partitioner"),
- );
+ let partitioner_for_stream = self
+ .options
+ .partitioner
+ .as_ref()
+ .expect("Partitioned first pass requires a partitioner")
+ .lock()
+ .box_clone();
+ let partitioner_for_repartitioner =
partitioner_for_stream.box_clone();
let repartitioner = StreamRepartitioner::builder(
Arc::clone(&self.runtime_env),
- Arc::clone(&partitioner),
+ partitioner_for_repartitioner,
PartitionedSide::ProbeSide,
self.metrics.spill_metrics.clone(),
)
@@ -206,7 +247,7 @@ impl PartitionedProbeStreamProvider {
let first_pass = FirstPassStream::new(
source,
repartitioner,
- partitioner,
+ partitioner_for_stream,
self.metrics.clone(),
callback,
);
@@ -384,7 +425,7 @@ mod tests {
fn create_probe_stream(
batches: Vec<EvaluatedBatch>,
- partitioner: Option<Arc<dyn SpatialPartitioner>>,
+ partitioner: Option<Box<dyn SpatialPartitioner>>,
) -> PartitionedProbeStreamProvider {
let runtime_env = Arc::new(RuntimeEnv::default());
assert!(!batches.is_empty(), "test batches should not be empty");
@@ -393,24 +434,18 @@ mod tests {
Box::pin(InMemoryEvaluatedBatchStream::new(schema, batches));
PartitionedProbeStreamProvider::new(
runtime_env,
- ProbeStreamOptions {
- partitioner,
- target_batch_rows: 1024,
- spill_compression: SpillCompression::Uncompressed,
- buffer_bytes_threshold: 0,
- spilled_batch_in_memory_size_threshold: None,
- },
+ ProbeStreamOptions::new(partitioner, 1024,
SpillCompression::Uncompressed, 0, None),
stream,
ProbeStreamMetrics::new(0, &ExecutionPlanMetricsSet::new()),
)
}
- fn sample_partitioner() -> Result<Arc<dyn SpatialPartitioner>> {
+ fn sample_partitioner() -> Result<Box<dyn SpatialPartitioner>> {
let partitions = vec![
BoundingBox::xy((0.0, 50.0), (0.0, 50.0)),
BoundingBox::xy((50.0, 100.0), (0.0, 50.0)),
];
- Ok(Arc::new(FlatPartitioner::try_new(partitions)?))
+ Ok(Box::new(FlatPartitioner::try_new(partitions)?))
}
#[tokio::test]
@@ -425,7 +460,7 @@ mod tests {
Some(wkb_point((200.0, 200.0)).unwrap()),
],
)?;
- let probe_stream = create_probe_stream(vec![batch],
Some(Arc::clone(&partitioner)));
+ let probe_stream = create_probe_stream(vec![batch], Some(partitioner));
let first_pass =
probe_stream.stream_for(SpatialPartition::Regular(0))?;
let batches = first_pass.try_collect::<Vec<_>>().await?;
@@ -448,7 +483,7 @@ mod tests {
async fn requesting_regular_partition_before_first_pass_fails() ->
Result<()> {
let partitioner = sample_partitioner()?;
let batch = sample_batch(&[0], vec![Some(wkb_point((60.0,
10.0)).unwrap())])?;
- let probe_stream = create_probe_stream(vec![batch],
Some(Arc::clone(&partitioner)));
+ let probe_stream = create_probe_stream(vec![batch], Some(partitioner));
assert!(probe_stream
.stream_for(SpatialPartition::Regular(1))
.is_err());