This is an automated email from the ASF dual-hosted git repository.

kontinuation pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/sedona-db.git


The following commit(s) were added to refs/heads/main by this push:
     new 76ae8fc1 chore(rust/sedona-spatial-join): Make partitioner not sync, 
create dedicated partitioner for each task (#592)
76ae8fc1 is described below

commit 76ae8fc143f6d3caaabc6f58b2d316572c394853
Author: Kristin Cowalcijk <[email protected]>
AuthorDate: Wed Feb 11 20:37:18 2026 +0800

    chore(rust/sedona-spatial-join): Make partitioner not sync, create 
dedicated partitioner for each task (#592)
    
    This is a follow-up to the review comment 
https://github.com/apache/sedona-db/pull/573#discussion_r2763587238. The round-
robin partitioner uses an internal atomic counter and produces nondeterministic 
partitioning results that depend on task scheduling order. This patch is a 
large refactoring that makes each task own its own partitioner instance, thus 
eliminating the randomness caused by racing. We have also removed the `Sync` 
constraint from the partitioner trait to enforce this design.
---
 .../bench/partitioning/stream_repartitioner.rs     |  9 ++-
 rust/sedona-spatial-join/src/partitioning.rs       |  5 +-
 .../src/partitioning/broadcast.rs                  |  5 ++
 rust/sedona-spatial-join/src/partitioning/flat.rs  |  5 ++
 rust/sedona-spatial-join/src/partitioning/kdb.rs   |  5 ++
 .../src/partitioning/round_robin.rs                | 14 ++--
 rust/sedona-spatial-join/src/partitioning/rtree.rs | 75 +++++++++++++++-----
 .../src/partitioning/stream_repartitioner.rs       | 16 ++---
 rust/sedona-spatial-join/src/prepare.rs            | 30 ++++----
 .../src/probe/first_pass_stream.rs                 |  4 +-
 .../src/probe/partitioned_stream_provider.rs       | 79 ++++++++++++++++------
 11 files changed, 172 insertions(+), 75 deletions(-)

diff --git 
a/rust/sedona-spatial-join/bench/partitioning/stream_repartitioner.rs 
b/rust/sedona-spatial-join/bench/partitioning/stream_repartitioner.rs
index aa4f8a16..8ea7e2a0 100644
--- a/rust/sedona-spatial-join/bench/partitioning/stream_repartitioner.rs
+++ b/rust/sedona-spatial-join/bench/partitioning/stream_repartitioner.rs
@@ -68,10 +68,9 @@ fn bench_stream_partitioner(c: &mut Criterion) {
         let seed_counter = Arc::clone(&seed_counter);
         let schema = Arc::clone(&schema);
         let runtime_env = Arc::clone(&runtime_env);
-        let partitioner = Arc::clone(&partitioner);
         let spill_metrics = spill_metrics.clone();
         let extent = Arc::clone(&extent);
-
+        let partitioner = partitioner.box_clone();
         b.iter_batched(
             move || {
                 let seed = seed_counter.fetch_add(1, Ordering::Relaxed);
@@ -81,7 +80,7 @@ fn bench_stream_partitioner(c: &mut Criterion) {
                 block_on(async {
                     StreamRepartitioner::builder(
                         runtime_env.clone(),
-                        partitioner.clone(),
+                        partitioner.box_clone(),
                         PartitionedSide::BuildSide,
                         spill_metrics.clone(),
                     )
@@ -187,7 +186,7 @@ fn build_schema() -> Schema {
     ])
 }
 
-fn build_partitioner(extent: &BoundingBox) -> Arc<dyn SpatialPartitioner + 
Send + Sync> {
+fn build_partitioner(extent: &BoundingBox) -> Box<dyn SpatialPartitioner> {
     let mut rng = StdRng::seed_from_u64(RNG_SEED ^ 0x00FF_FFFF);
     let samples = (0..SAMPLE_FOR_PARTITIONER)
         .map(|_| random_bbox(extent, &mut rng))
@@ -201,7 +200,7 @@ fn build_partitioner(extent: &BoundingBox) -> Arc<dyn 
SpatialPartitioner + Send
     )
     .expect("kdb builder should succeed");
 
-    Arc::new(partitioner)
+    Box::new(partitioner)
 }
 
 fn random_bbox(extent: &BoundingBox, rng: &mut impl RngExt) -> BoundingBox {
diff --git a/rust/sedona-spatial-join/src/partitioning.rs 
b/rust/sedona-spatial-join/src/partitioning.rs
index 60028974..df0914d2 100644
--- a/rust/sedona-spatial-join/src/partitioning.rs
+++ b/rust/sedona-spatial-join/src/partitioning.rs
@@ -57,7 +57,7 @@ pub enum SpatialPartition {
 }
 
 /// Partitioning larger-than-memory indexed side to support out-of-core 
spatial join.
-pub trait SpatialPartitioner: Send + Sync {
+pub trait SpatialPartitioner: Send {
     /// Get the total number of spatial partitions, excluding the None 
partition and Multi partition.
     fn num_regular_partitions(&self) -> usize;
 
@@ -68,6 +68,9 @@ pub trait SpatialPartitioner: Send + Sync {
     /// Multi partition. If `bbox` intersects with multiple partitions, only 
one of them will be
     /// selected as regular partition.
     fn partition_no_multi(&self, bbox: &BoundingBox) -> 
Result<SpatialPartition>;
+
+    /// Clone the partitioner as a boxed trait object.
+    fn box_clone(&self) -> Box<dyn SpatialPartitioner>;
 }
 
 /// Indicates for which side of the spatial join the partitioning is being 
performed.
diff --git a/rust/sedona-spatial-join/src/partitioning/broadcast.rs 
b/rust/sedona-spatial-join/src/partitioning/broadcast.rs
index 308dccc0..b7addef2 100644
--- a/rust/sedona-spatial-join/src/partitioning/broadcast.rs
+++ b/rust/sedona-spatial-join/src/partitioning/broadcast.rs
@@ -26,6 +26,7 @@ use crate::partitioning::{SpatialPartition, 
SpatialPartitioner};
 /// This partitioner is useful when we want to broadcast the data to all 
partitions.
 /// Currently it is used for KNN join where regular spatial partitioning is 
hard because
 /// it is hard to know in advance how far away a given number of neighbours 
will be to assign it.
+#[derive(Clone)]
 pub struct BroadcastPartitioner {
     num_partitions: usize,
 }
@@ -48,6 +49,10 @@ impl SpatialPartitioner for BroadcastPartitioner {
     fn partition_no_multi(&self, _bbox: &BoundingBox) -> 
Result<SpatialPartition> {
         sedona_internal_err!("BroadcastPartitioner does not support 
partition_no_multi")
     }
+
+    fn box_clone(&self) -> Box<dyn SpatialPartitioner> {
+        Box::new(self.clone())
+    }
 }
 
 #[cfg(test)]
diff --git a/rust/sedona-spatial-join/src/partitioning/flat.rs 
b/rust/sedona-spatial-join/src/partitioning/flat.rs
index 29e33b44..3e6bce25 100644
--- a/rust/sedona-spatial-join/src/partitioning/flat.rs
+++ b/rust/sedona-spatial-join/src/partitioning/flat.rs
@@ -40,6 +40,7 @@ use sedona_geometry::interval::IntervalTrait;
 use crate::partitioning::{SpatialPartition, SpatialPartitioner};
 
 /// Spatial partitioner that linearly scans partition boundaries.
+#[derive(Clone)]
 pub struct FlatPartitioner {
     boundaries: Vec<BoundingBox>,
 }
@@ -106,6 +107,10 @@ impl SpatialPartitioner for FlatPartitioner {
             None => SpatialPartition::None,
         })
     }
+
+    fn box_clone(&self) -> Box<dyn SpatialPartitioner> {
+        Box::new(self.clone())
+    }
 }
 
 #[cfg(test)]
diff --git a/rust/sedona-spatial-join/src/partitioning/kdb.rs 
b/rust/sedona-spatial-join/src/partitioning/kdb.rs
index c09e98ff..02197102 100644
--- a/rust/sedona-spatial-join/src/partitioning/kdb.rs
+++ b/rust/sedona-spatial-join/src/partitioning/kdb.rs
@@ -455,6 +455,7 @@ impl KDBTree {
 /// let query_bbox = BoundingBox::xy((5.0, 15.0), (5.0, 15.0));
 /// let partition = partitioner.partition(&query_bbox).unwrap();
 /// ```
+#[derive(Clone)]
 pub struct KDBPartitioner {
     tree: Arc<KDBTree>,
 }
@@ -566,6 +567,10 @@ impl SpatialPartitioner for KDBPartitioner {
             None => Ok(SpatialPartition::None),
         }
     }
+
+    fn box_clone(&self) -> Box<dyn SpatialPartitioner> {
+        Box::new(self.clone())
+    }
 }
 
 #[cfg(test)]
diff --git a/rust/sedona-spatial-join/src/partitioning/round_robin.rs 
b/rust/sedona-spatial-join/src/partitioning/round_robin.rs
index a5d73117..71251653 100644
--- a/rust/sedona-spatial-join/src/partitioning/round_robin.rs
+++ b/rust/sedona-spatial-join/src/partitioning/round_robin.rs
@@ -15,7 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use std::sync::atomic::{AtomicUsize, Ordering};
+use std::cell::Cell;
 
 use datafusion_common::Result;
 use sedona_geometry::bounding_box::BoundingBox;
@@ -27,16 +27,17 @@ use crate::partitioning::{SpatialPartition, 
SpatialPartitioner};
 /// This partitioner is used for KNN join, where the build side is partitioned
 /// into `num_partitions` partitions, and the probe side is assigned to the
 /// `Multi` partition (i.e., broadcast to all partitions).
+#[derive(Clone)]
 pub struct RoundRobinPartitioner {
     num_partitions: usize,
-    counter: AtomicUsize,
+    counter: Cell<usize>,
 }
 
 impl RoundRobinPartitioner {
     pub fn new(num_partitions: usize) -> Self {
         Self {
             num_partitions,
-            counter: AtomicUsize::new(0),
+            counter: Cell::new(0),
         }
     }
 }
@@ -51,11 +52,16 @@ impl SpatialPartitioner for RoundRobinPartitioner {
     }
 
     fn partition_no_multi(&self, _bbox: &BoundingBox) -> 
Result<SpatialPartition> {
-        let idx = self.counter.fetch_add(1, Ordering::Relaxed);
+        let idx = self.counter.get();
+        self.counter.set(idx.wrapping_add(1));
         Ok(SpatialPartition::Regular(
             (idx % self.num_partitions) as u32,
         ))
     }
+
+    fn box_clone(&self) -> Box<dyn SpatialPartitioner> {
+        Box::new(self.clone())
+    }
 }
 
 #[cfg(test)]
diff --git a/rust/sedona-spatial-join/src/partitioning/rtree.rs 
b/rust/sedona-spatial-join/src/partitioning/rtree.rs
index ffb43032..760b359b 100644
--- a/rust/sedona-spatial-join/src/partitioning/rtree.rs
+++ b/rust/sedona-spatial-join/src/partitioning/rtree.rs
@@ -35,6 +35,8 @@
 //! 4. **None-partition Handling**: If a bbox doesn't intersect any partition 
boundary, it's assigned
 //!    to [`SpatialPartition::None`].
 
+use std::sync::Arc;
+
 use datafusion_common::Result;
 use geo::Rect;
 use geo_index::rtree::{sort::HilbertSort, RTree, RTreeBuilder, RTreeIndex};
@@ -50,15 +52,9 @@ use crate::partitioning::{SpatialPartition, 
SpatialPartitioner};
 /// This partitioner constructs an RTree index over a set of partition 
boundaries
 /// (rectangles) and uses it to efficiently determine which partition a given
 /// bounding box belongs to based on spatial intersection.
+#[derive(Clone)]
 pub struct RTreePartitioner {
-    /// The RTree index storing partition boundaries as f32 rectangles
-    rtree: RTree<f32>,
-    /// Flat representation of partition boundaries for overlap calculations
-    boundaries: Vec<Rect<f32>>,
-    /// Number of partitions (excluding None and Multi)
-    num_partitions: usize,
-    /// Map from RTree index to original partition index
-    partition_map: Vec<usize>,
+    inner: Arc<RawRTreePartitioner>,
 }
 
 impl RTreePartitioner {
@@ -84,12 +80,58 @@ impl RTreePartitioner {
     /// let partitioner = RTreePartitioner::try_new(boundaries).unwrap();
     /// ```
     pub fn try_new(boundaries: Vec<BoundingBox>) -> Result<Self> {
-        Self::build(boundaries, None)
+        let inner = RawRTreePartitioner::try_new(boundaries)?;
+        Ok(Self {
+            inner: Arc::new(inner),
+        })
     }
 
     /// Create a new RTree partitioner with a custom node size.
     pub fn try_new_with_node_size(boundaries: Vec<BoundingBox>, node_size: 
u16) -> Result<Self> {
-        Self::build(boundaries, Some(node_size))
+        let inner = RawRTreePartitioner::build(boundaries, Some(node_size))?;
+        Ok(Self {
+            inner: Arc::new(inner),
+        })
+    }
+
+    /// Return the number of levels in the underlying RTree.
+    pub fn depth(&self) -> usize {
+        self.inner.depth()
+    }
+}
+
+impl SpatialPartitioner for RTreePartitioner {
+    fn num_regular_partitions(&self) -> usize {
+        self.inner.num_regular_partitions()
+    }
+
+    fn partition(&self, bbox: &BoundingBox) -> Result<SpatialPartition> {
+        self.inner.partition(bbox)
+    }
+
+    fn partition_no_multi(&self, bbox: &BoundingBox) -> 
Result<SpatialPartition> {
+        self.inner.partition_no_multi(bbox)
+    }
+
+    fn box_clone(&self) -> Box<dyn SpatialPartitioner> {
+        Box::new(self.clone())
+    }
+}
+
+struct RawRTreePartitioner {
+    /// The RTree index storing partition boundaries as f32 rectangles
+    rtree: RTree<f32>,
+    /// Flat representation of partition boundaries for overlap calculations
+    boundaries: Vec<Rect<f32>>,
+    /// Number of partitions (excluding None and Multi)
+    num_partitions: usize,
+    /// Map from RTree index to original partition index
+    partition_map: Vec<usize>,
+}
+
+impl RawRTreePartitioner {
+    fn try_new(boundaries: Vec<BoundingBox>) -> Result<Self> {
+        Self::build(boundaries, None)
     }
 
     fn build(boundaries: Vec<BoundingBox>, node_size: Option<u16>) -> 
Result<Self> {
@@ -122,7 +164,7 @@ impl RTreePartitioner {
 
         let rtree = rtree_builder.finish::<HilbertSort>();
 
-        Ok(RTreePartitioner {
+        Ok(RawRTreePartitioner {
             rtree,
             boundaries: rects,
             num_partitions,
@@ -130,17 +172,14 @@ impl RTreePartitioner {
         })
     }
 
-    /// Return the number of levels in the underlying RTree.
-    pub fn depth(&self) -> usize {
-        self.rtree.num_levels()
-    }
-}
-
-impl SpatialPartitioner for RTreePartitioner {
     fn num_regular_partitions(&self) -> usize {
         self.num_partitions
     }
 
+    fn depth(&self) -> usize {
+        self.rtree.num_levels()
+    }
+
     fn partition(&self, bbox: &BoundingBox) -> Result<SpatialPartition> {
         // Convert bbox to f32 for RTree query with proper bounds handling
         let (min_x, min_y, max_x, max_y) = match bbox_to_f32_rect(bbox)? {
diff --git a/rust/sedona-spatial-join/src/partitioning/stream_repartitioner.rs 
b/rust/sedona-spatial-join/src/partitioning/stream_repartitioner.rs
index 0a326cce..cc8e6d7e 100644
--- a/rust/sedona-spatial-join/src/partitioning/stream_repartitioner.rs
+++ b/rust/sedona-spatial-join/src/partitioning/stream_repartitioner.rs
@@ -300,7 +300,7 @@ impl SpilledPartitions {
 /// `target_batch_size` rows per partition batch.
 pub struct StreamRepartitioner {
     runtime_env: Arc<RuntimeEnv>,
-    partitioner: Arc<dyn SpatialPartitioner>,
+    partitioner: Box<dyn SpatialPartitioner>,
     partitioned_side: PartitionedSide,
     slots: PartitionSlots,
     /// Spill files for each spatial partition.
@@ -330,7 +330,7 @@ pub struct StreamRepartitioner {
 /// - `spilled_batch_in_memory_size_threshold`: `None`
 pub struct StreamRepartitionerBuilder {
     runtime_env: Arc<RuntimeEnv>,
-    partitioner: Arc<dyn SpatialPartitioner>,
+    partitioner: Box<dyn SpatialPartitioner>,
     partitioned_side: PartitionedSide,
     spill_compression: SpillCompression,
     spill_metrics: SpillMetrics,
@@ -407,7 +407,7 @@ impl StreamRepartitioner {
     /// spill metrics). Optional parameters can then be set on the returned 
builder.
     pub fn builder(
         runtime_env: Arc<RuntimeEnv>,
-        partitioner: Arc<dyn SpatialPartitioner>,
+        partitioner: Box<dyn SpatialPartitioner>,
         partitioned_side: PartitionedSide,
         spill_metrics: SpillMetrics,
     ) -> StreamRepartitionerBuilder {
@@ -840,7 +840,7 @@ mod tests {
             BoundingBox::xy((0.0, 50.0), (0.0, 50.0)),
             BoundingBox::xy((50.0, 100.0), (0.0, 50.0)),
         ];
-        let partitioner = Arc::new(FlatPartitioner::try_new(partitions)?);
+        let partitioner = Box::new(FlatPartitioner::try_new(partitions)?);
         let runtime_env = Arc::new(RuntimeEnv::default());
         let metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0);
 
@@ -926,7 +926,7 @@ mod tests {
             BoundingBox::xy((0.0, 50.0), (0.0, 50.0)),
             BoundingBox::xy((50.0, 100.0), (0.0, 50.0)),
         ];
-        let partitioner = Arc::new(FlatPartitioner::try_new(partitions)?);
+        let partitioner = Box::new(FlatPartitioner::try_new(partitions)?);
         let runtime_env = Arc::new(RuntimeEnv::default());
         let metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0);
 
@@ -990,7 +990,7 @@ mod tests {
             BoundingBox::xy((0.0, 50.0), (0.0, 50.0)),
             BoundingBox::xy((50.0, 100.0), (0.0, 50.0)),
         ];
-        let partitioner = Arc::new(FlatPartitioner::try_new(partitions)?);
+        let partitioner = Box::new(FlatPartitioner::try_new(partitions)?);
         let runtime_env = Arc::new(RuntimeEnv::default());
         let spill_metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 
0);
         let mut repartitioner = StreamRepartitioner::builder(
@@ -1035,7 +1035,7 @@ mod tests {
         let batch_a = sample_batch(&[0], vec![Some(wkb_point((10.0, 
10.0)).unwrap())])?;
         let batch_b = sample_batch(&[1], vec![Some(wkb_point((20.0, 
10.0)).unwrap())])?;
         let partitions = vec![BoundingBox::xy((0.0, 50.0), (0.0, 50.0))];
-        let partitioner = Arc::new(FlatPartitioner::try_new(partitions)?);
+        let partitioner = Box::new(FlatPartitioner::try_new(partitions)?);
         let runtime_env = Arc::new(RuntimeEnv::default());
         let spill_metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 
0);
         let mut repartitioner = StreamRepartitioner::builder(
@@ -1069,7 +1069,7 @@ mod tests {
         let batch_a = sample_batch(&[0], vec![Some(wkb_point((10.0, 
10.0)).unwrap())])?;
         let batch_b = sample_batch(&[1], vec![Some(wkb_point((20.0, 
10.0)).unwrap())])?;
         let partitions = vec![BoundingBox::xy((0.0, 50.0), (0.0, 50.0))];
-        let partitioner = Arc::new(FlatPartitioner::try_new(partitions)?);
+        let partitioner = Box::new(FlatPartitioner::try_new(partitions)?);
         let runtime_env = Arc::new(RuntimeEnv::default());
         let spill_metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 
0);
         let mut repartitioner = StreamRepartitioner::builder(
diff --git a/rust/sedona-spatial-join/src/prepare.rs 
b/rust/sedona-spatial-join/src/prepare.rs
index 309eca12..9e9a2053 100644
--- a/rust/sedona-spatial-join/src/prepare.rs
+++ b/rust/sedona-spatial-join/src/prepare.rs
@@ -244,8 +244,8 @@ impl SpatialJoinComponentsBuilder {
         num_partitions: usize,
         build_partitions: &mut Vec<BuildPartition>,
         seed: u64,
-    ) -> Result<Arc<dyn SpatialPartitioner>> {
-        let build_partitioner: Arc<dyn SpatialPartitioner> = if matches!(
+    ) -> Result<Box<dyn SpatialPartitioner>> {
+        let build_partitioner: Box<dyn SpatialPartitioner> = if matches!(
             self.spatial_predicate,
             SpatialPredicate::KNearestNeighbors(_)
         ) {
@@ -253,7 +253,7 @@ impl SpatialJoinComponentsBuilder {
             // partitioning to spread the indexed data evenly to make each 
index fit in memory, and
             // the probe side will be broadcasted to all partitions by 
partitioning all of them to
             // the Multi partition.
-            Arc::new(RoundRobinPartitioner::new(num_partitions))
+            Box::new(RoundRobinPartitioner::new(num_partitions))
         } else {
             // Use spatial partitioners to partition the build side and the 
probe side, this will
             // reduce the amount of work needed for probing each partitioned 
index.
@@ -290,7 +290,7 @@ impl SpatialJoinComponentsBuilder {
                 kdb_partitioner.debug_str()
             );
 
-            Arc::new(kdb_partitioner)
+            Box::new(kdb_partitioner)
         };
 
         Ok(build_partitioner)
@@ -302,12 +302,12 @@ impl SpatialJoinComponentsBuilder {
         &self,
         num_partitions: usize,
         merged_spilled_partitions: &SpilledPartitions,
-    ) -> Result<Arc<dyn SpatialPartitioner>> {
-        let probe_partitioner: Arc<dyn SpatialPartitioner> = if matches!(
+    ) -> Result<Box<dyn SpatialPartitioner>> {
+        let probe_partitioner: Box<dyn SpatialPartitioner> = if matches!(
             self.spatial_predicate,
             SpatialPredicate::KNearestNeighbors(_)
         ) {
-            Arc::new(BroadcastPartitioner::new(num_partitions))
+            Box::new(BroadcastPartitioner::new(num_partitions))
         } else {
             // Build a flat partitioner using these partitions
             let mut partition_bounds = Vec::with_capacity(num_partitions);
@@ -320,7 +320,7 @@ impl SpatialJoinComponentsBuilder {
                     .unwrap_or(BoundingBox::empty());
                 partition_bounds.push(partition_bound);
             }
-            Arc::new(FlatPartitioner::try_new(partition_bounds)?)
+            Box::new(FlatPartitioner::try_new(partition_bounds)?)
         };
         Ok(probe_partitioner)
     }
@@ -330,7 +330,7 @@ impl SpatialJoinComponentsBuilder {
     async fn repartition_build_side(
         &self,
         build_partitions: Vec<BuildPartition>,
-        build_partitioner: Arc<dyn SpatialPartitioner>,
+        build_partitioner: Box<dyn SpatialPartitioner>,
         memory_plan: &MemoryPlan,
     ) -> Result<Vec<SpilledPartitions>> {
         // Spawn each task for each build partition to repartition the data 
using the spatial partitioner for
@@ -349,7 +349,7 @@ impl SpatialJoinComponentsBuilder {
             let metrics = &partition.metrics;
             let spill_metrics = metrics.spill_metrics();
             let runtime_env = Arc::clone(&runtime_env);
-            let partitioner = Arc::clone(&build_partitioner);
+            let partitioner = build_partitioner.box_clone();
             join_set.spawn(async move {
                 let partitioned_spill_files = StreamRepartitioner::builder(
                     runtime_env,
@@ -437,7 +437,7 @@ impl SpatialJoinComponentsBuilder {
     fn create_multi_partitioned_spatial_join_components(
         self,
         merged_spilled_partitions: SpilledPartitions,
-        probe_partitioner: Arc<dyn SpatialPartitioner>,
+        probe_partitioner: Box<dyn SpatialPartitioner>,
         reservations: Vec<MemoryReservation>,
         memory_plan: &MemoryPlan,
     ) -> Result<SpatialJoinComponents> {
@@ -459,13 +459,13 @@ impl SpatialJoinComponentsBuilder {
         );
 
         let buffer_bytes_threshold = memory_for_intermittent_usage / 
self.probe_threads_count;
-        let probe_stream_options = ProbeStreamOptions {
-            partitioner: Some(probe_partitioner),
-            target_batch_rows: target_batch_size,
+        let probe_stream_options = ProbeStreamOptions::new(
+            Some(probe_partitioner),
+            target_batch_size,
             spill_compression,
             buffer_bytes_threshold,
             spilled_batch_in_memory_size_threshold,
-        };
+        );
 
         Ok(SpatialJoinComponents {
             partitioned_index_provider: Arc::new(partitioned_index_provider),
diff --git a/rust/sedona-spatial-join/src/probe/first_pass_stream.rs 
b/rust/sedona-spatial-join/src/probe/first_pass_stream.rs
index 7e75b344..9bc691cf 100644
--- a/rust/sedona-spatial-join/src/probe/first_pass_stream.rs
+++ b/rust/sedona-spatial-join/src/probe/first_pass_stream.rs
@@ -51,7 +51,7 @@ use crate::{
 pub(crate) struct FirstPassStream<C: FirstPassStreamCallback> {
     source: SendableEvaluatedBatchStream,
     repartitioner: Option<StreamRepartitioner>,
-    partitioner: Arc<dyn SpatialPartitioner>,
+    partitioner: Box<dyn SpatialPartitioner>,
     pending_output: VecDeque<Result<EvaluatedBatch>>,
     metrics: ProbeStreamMetrics,
     callback: Option<C>,
@@ -71,7 +71,7 @@ impl<C: FirstPassStreamCallback> FirstPassStream<C> {
     pub fn new(
         source: SendableEvaluatedBatchStream,
         repartitioner: StreamRepartitioner,
-        partitioner: Arc<dyn SpatialPartitioner>,
+        partitioner: Box<dyn SpatialPartitioner>,
         metrics: ProbeStreamMetrics,
         callback: C,
     ) -> Self {
diff --git a/rust/sedona-spatial-join/src/probe/partitioned_stream_provider.rs 
b/rust/sedona-spatial-join/src/probe/partitioned_stream_provider.rs
index 14c45b31..ef4e91c6 100644
--- a/rust/sedona-spatial-join/src/probe/partitioned_stream_provider.rs
+++ b/rust/sedona-spatial-join/src/probe/partitioned_stream_provider.rs
@@ -38,7 +38,6 @@ use crate::{
     },
 };
 
-#[derive(Clone)]
 /// Configuration options for creating a probe-side stream provider.
 ///
 /// When a `partitioner` is provided, the provider performs an initial first 
pass that
@@ -50,7 +49,12 @@ pub(crate) struct ProbeStreamOptions {
     /// - `None` means the probe side is treated as a single, non-partitioned 
stream and only
     ///   [`SpatialPartition::Regular(0)`] is supported.
     /// - `Some(_)` enables partitioned streaming with a warm-up (first) pass.
-    pub partitioner: Option<Arc<dyn SpatialPartitioner>>,
+    ///
+    /// The `Mutex` is used here to make [`ProbeStreamOptions`] (and its 
contained options)
+    /// `Send + Sync` so it can be shared/cloned into `SpatialJoinExec` and 
across tasks.
+    /// The partitioner itself is treated as a clonable prototype and is not 
intended to be
+    /// used by multiple tasks concurrently via this shared `Mutex`.
+    pub partitioner: Option<Mutex<Box<dyn SpatialPartitioner>>>,
     /// Target number of rows per output batch produced by the partitioning 
stream.
     pub target_batch_rows: usize,
     /// Spill compression to use when writing partition spill files.
@@ -62,6 +66,41 @@ pub(crate) struct ProbeStreamOptions {
     pub spilled_batch_in_memory_size_threshold: Option<usize>,
 }
 
+impl ProbeStreamOptions {
+    pub fn new(
+        partitioner: Option<Box<dyn SpatialPartitioner>>,
+        target_batch_rows: usize,
+        spill_compression: SpillCompression,
+        buffer_bytes_threshold: usize,
+        spilled_batch_in_memory_size_threshold: Option<usize>,
+    ) -> Self {
+        let partitioner = partitioner.map(Mutex::new);
+        Self {
+            partitioner,
+            target_batch_rows,
+            spill_compression,
+            buffer_bytes_threshold,
+            spilled_batch_in_memory_size_threshold,
+        }
+    }
+}
+
+impl Clone for ProbeStreamOptions {
+    fn clone(&self) -> Self {
+        let cloned_partitioner = self
+            .partitioner
+            .as_ref()
+            .map(|p| Mutex::new(p.lock().box_clone()));
+        Self {
+            partitioner: cloned_partitioner,
+            target_batch_rows: self.target_batch_rows,
+            spill_compression: self.spill_compression,
+            buffer_bytes_threshold: self.buffer_bytes_threshold,
+            spilled_batch_in_memory_size_threshold: 
self.spilled_batch_in_memory_size_threshold,
+        }
+    }
+}
+
 /// Provides probe-side streams for a given [`SpatialPartition`].
 ///
 /// For partitioned joins this provider is a small state machine:
@@ -144,15 +183,17 @@ impl PartitionedProbeStreamProvider {
         let mut state_guard = self.state.lock();
         match std::mem::replace(&mut *state_guard, 
ProbeStreamState::FirstPass) {
             ProbeStreamState::Pending { source } => {
-                let partitioner = Arc::clone(
-                    self.options
-                        .partitioner
-                        .as_ref()
-                        .expect("Partitioned first pass requires a 
partitioner"),
-                );
+                let partitioner_for_stream = self
+                    .options
+                    .partitioner
+                    .as_ref()
+                    .expect("Partitioned first pass requires a partitioner")
+                    .lock()
+                    .box_clone();
+                let partitioner_for_repartitioner = 
partitioner_for_stream.box_clone();
                 let repartitioner = StreamRepartitioner::builder(
                     Arc::clone(&self.runtime_env),
-                    Arc::clone(&partitioner),
+                    partitioner_for_repartitioner,
                     PartitionedSide::ProbeSide,
                     self.metrics.spill_metrics.clone(),
                 )
@@ -206,7 +247,7 @@ impl PartitionedProbeStreamProvider {
                 let first_pass = FirstPassStream::new(
                     source,
                     repartitioner,
-                    partitioner,
+                    partitioner_for_stream,
                     self.metrics.clone(),
                     callback,
                 );
@@ -384,7 +425,7 @@ mod tests {
 
     fn create_probe_stream(
         batches: Vec<EvaluatedBatch>,
-        partitioner: Option<Arc<dyn SpatialPartitioner>>,
+        partitioner: Option<Box<dyn SpatialPartitioner>>,
     ) -> PartitionedProbeStreamProvider {
         let runtime_env = Arc::new(RuntimeEnv::default());
         assert!(!batches.is_empty(), "test batches should not be empty");
@@ -393,24 +434,18 @@ mod tests {
             Box::pin(InMemoryEvaluatedBatchStream::new(schema, batches));
         PartitionedProbeStreamProvider::new(
             runtime_env,
-            ProbeStreamOptions {
-                partitioner,
-                target_batch_rows: 1024,
-                spill_compression: SpillCompression::Uncompressed,
-                buffer_bytes_threshold: 0,
-                spilled_batch_in_memory_size_threshold: None,
-            },
+            ProbeStreamOptions::new(partitioner, 1024, 
SpillCompression::Uncompressed, 0, None),
             stream,
             ProbeStreamMetrics::new(0, &ExecutionPlanMetricsSet::new()),
         )
     }
 
-    fn sample_partitioner() -> Result<Arc<dyn SpatialPartitioner>> {
+    fn sample_partitioner() -> Result<Box<dyn SpatialPartitioner>> {
         let partitions = vec![
             BoundingBox::xy((0.0, 50.0), (0.0, 50.0)),
             BoundingBox::xy((50.0, 100.0), (0.0, 50.0)),
         ];
-        Ok(Arc::new(FlatPartitioner::try_new(partitions)?))
+        Ok(Box::new(FlatPartitioner::try_new(partitions)?))
     }
 
     #[tokio::test]
@@ -425,7 +460,7 @@ mod tests {
                 Some(wkb_point((200.0, 200.0)).unwrap()),
             ],
         )?;
-        let probe_stream = create_probe_stream(vec![batch], 
Some(Arc::clone(&partitioner)));
+        let probe_stream = create_probe_stream(vec![batch], Some(partitioner));
 
         let first_pass = 
probe_stream.stream_for(SpatialPartition::Regular(0))?;
         let batches = first_pass.try_collect::<Vec<_>>().await?;
@@ -448,7 +483,7 @@ mod tests {
     async fn requesting_regular_partition_before_first_pass_fails() -> 
Result<()> {
         let partitioner = sample_partitioner()?;
         let batch = sample_batch(&[0], vec![Some(wkb_point((60.0, 
10.0)).unwrap())])?;
-        let probe_stream = create_probe_stream(vec![batch], 
Some(Arc::clone(&partitioner)));
+        let probe_stream = create_probe_stream(vec![batch], Some(partitioner));
         assert!(probe_stream
             .stream_for(SpatialPartition::Regular(1))
             .is_err());

Reply via email to