This is an automated email from the ASF dual-hosted git repository.

paleolimbot pushed a commit to branch branch-0.3.0
in repository https://gitbox.apache.org/repos/asf/sedona-db.git

commit 7a3ab1c2f9f163aa1194c509924d9de98447f9b4
Author: Kristin Cowalcijk <[email protected]>
AuthorDate: Wed Mar 4 23:40:07 2026 +0800

    fix(rust/sedona-spatial-join): wrap probe-side repartition in ProbeShuffleExec to prevent optimizer stripping (#677)
---
 rust/sedona-spatial-join/src/lib.rs                |   4 +
 rust/sedona-spatial-join/src/planner.rs            |   1 +
 .../src/planner/physical_planner.rs                |  10 +-
 .../src/planner/probe_shuffle_exec.rs              | 225 +++++++++++++++++++++
 .../tests/spatial_join_integration.rs              |  37 +++-
 5 files changed, 267 insertions(+), 10 deletions(-)

diff --git a/rust/sedona-spatial-join/src/lib.rs b/rust/sedona-spatial-join/src/lib.rs
index 3a044825..4e4b5b08 100644
--- a/rust/sedona-spatial-join/src/lib.rs
+++ b/rust/sedona-spatial-join/src/lib.rs
@@ -33,6 +33,10 @@ pub use exec::SpatialJoinExec;
 // Re-export function for register the spatial join planner
 pub use planner::register_planner;
 
+// Re-export ProbeShuffleExec so that integration tests (and other crates) can verify
+// its presence in optimized physical plans.
+pub use planner::probe_shuffle_exec::ProbeShuffleExec;
+
 // Re-export types needed for external usage (e.g., in Comet)
 pub use index::{SpatialIndex, SpatialJoinBuildMetrics};
 pub use spatial_predicate::SpatialPredicate;
diff --git a/rust/sedona-spatial-join/src/planner.rs b/rust/sedona-spatial-join/src/planner.rs
index e40303bf..b663b293 100644
--- a/rust/sedona-spatial-join/src/planner.rs
+++ b/rust/sedona-spatial-join/src/planner.rs
@@ -26,6 +26,7 @@ use datafusion_common::Result;
 mod logical_plan_node;
 mod optimizer;
 mod physical_planner;
+pub mod probe_shuffle_exec;
 mod spatial_expr_utils;
 
 /// Register Sedona spatial join planning hooks.
diff --git a/rust/sedona-spatial-join/src/planner/physical_planner.rs b/rust/sedona-spatial-join/src/planner/physical_planner.rs
index d0341154..33fa3a44 100644
--- a/rust/sedona-spatial-join/src/planner/physical_planner.rs
+++ b/rust/sedona-spatial-join/src/planner/physical_planner.rs
@@ -31,15 +31,13 @@ use datafusion_common::{plan_err, DFSchema, JoinSide, Result};
 use datafusion_expr::logical_plan::UserDefinedLogicalNode;
 use datafusion_expr::LogicalPlan;
 use datafusion_physical_expr::create_physical_expr;
-use datafusion_physical_expr::Partitioning;
 use datafusion_physical_plan::joins::utils::JoinFilter;
 use datafusion_physical_plan::joins::NestedLoopJoinExec;
-use datafusion_physical_plan::repartition::RepartitionExec;
-use datafusion_physical_plan::ExecutionPlanProperties;
 use sedona_common::sedona_internal_err;
 
 use crate::exec::SpatialJoinExec;
 use crate::planner::logical_plan_node::SpatialJoinPlanNode;
+use crate::planner::probe_shuffle_exec::ProbeShuffleExec;
 use crate::planner::spatial_expr_utils::{is_spatial_predicate_supported, 
transform_join_filter};
 use crate::spatial_predicate::SpatialPredicate;
 use sedona_common::option::SedonaOptions;
@@ -325,11 +323,7 @@ fn repartition_probe_side(
         }
     };
 
-    let num_partitions = probe_plan.output_partitioning().partition_count();
-    *probe_plan = Arc::new(RepartitionExec::try_new(
-        Arc::clone(probe_plan),
-        Partitioning::RoundRobinBatch(num_partitions),
-    )?);
+    *probe_plan = Arc::new(ProbeShuffleExec::try_new(Arc::clone(probe_plan))?);
 
     Ok((physical_left, physical_right))
 }
diff --git a/rust/sedona-spatial-join/src/planner/probe_shuffle_exec.rs b/rust/sedona-spatial-join/src/planner/probe_shuffle_exec.rs
new file mode 100644
index 00000000..65d73469
--- /dev/null
+++ b/rust/sedona-spatial-join/src/planner/probe_shuffle_exec.rs
@@ -0,0 +1,225 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! [`ProbeShuffleExec`] — a round-robin repartitioning wrapper that is invisible
+//! to DataFusion's `EnforceDistribution` / `EnforceSorting` optimizer passes.
+//!
+//! Those passes unconditionally strip every [`RepartitionExec`] before
+//! re-evaluating distribution requirements.  Because `SpatialJoinExec` reports
+//! `UnspecifiedDistribution` for its inputs, a bare `RepartitionExec` that was
+//! inserted by the extension planner is removed and never re-added.
+//!
+//! `ProbeShuffleExec` wraps a hidden, internal `RepartitionExec` so that:
+//! * **Optimizer passes** see an opaque node (not a `RepartitionExec`) and leave
+//!   it alone.
+//! * **`children()` / `with_new_children()`** expose the *original* input so
+//!   the rest of the optimizer tree can still be rewritten normally.
+//! * **`execute()`** delegates to the internal `RepartitionExec` which performs
+//!   the actual round-robin shuffle.
+
+use std::any::Any;
+use std::fmt;
+use std::sync::Arc;
+
+use datafusion_common::config::ConfigOptions;
+use datafusion_common::{internal_err, plan_err, Result, Statistics};
+use datafusion_execution::{SendableRecordBatchStream, TaskContext};
+use datafusion_physical_expr::PhysicalExpr;
+use datafusion_physical_plan::execution_plan::CardinalityEffect;
+use datafusion_physical_plan::filter_pushdown::{
+    ChildPushdownResult, FilterDescription, FilterPushdownPhase, 
FilterPushdownPropagation,
+};
+use datafusion_physical_plan::metrics::MetricsSet;
+use datafusion_physical_plan::projection::ProjectionExec;
+use datafusion_physical_plan::repartition::RepartitionExec;
+use datafusion_physical_plan::{
+    DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, 
Partitioning,
+    PlanProperties,
+};
+
+/// A round-robin repartitioning node that is invisible to DataFusion's
+/// physical optimizer passes.
+///
+/// See [module-level documentation](self) for motivation and design.
+#[derive(Debug)]
+pub struct ProbeShuffleExec {
+    inner_repartition: RepartitionExec,
+}
+
+impl ProbeShuffleExec {
+    /// Create a new [`ProbeShuffleExec`] that round-robin repartitions `input`
+    /// into the same number of output partitions as `input`. This will ensure
+    /// that the probe workload of a spatial join will be evenly distributed.
+    /// More importantly, shuffled probe side data will be less likely to
+    /// cause skew issues when out-of-core, spatial partitioned spatial join is enabled,
+    /// especially when the input probe data is sorted by their spatial locations.
+    pub fn try_new(input: Arc<dyn ExecutionPlan>) -> Result<Self> {
+        let num_partitions = input.output_partitioning().partition_count();
+        let inner_repartition = RepartitionExec::try_new(
+            Arc::clone(&input),
+            Partitioning::RoundRobinBatch(num_partitions),
+        )?;
+        Ok(Self { inner_repartition })
+    }
+
+    /// Try to wrap the given [`RepartitionExec`] `plan` with [`ProbeShuffleExec`].
+    pub fn try_wrap_repartition(plan: Arc<dyn ExecutionPlan>) -> Result<Self> {
+        let Some(repartition_exec) = 
plan.as_any().downcast_ref::<RepartitionExec>() else {
+            return plan_err!(
+                "ProbeShuffleExec can only wrap RepartitionExec, but got {}",
+                plan.name()
+            );
+        };
+        Ok(Self {
+            inner_repartition: repartition_exec.clone(),
+        })
+    }
+
+    /// Number of output partitions.
+    pub fn num_partitions(&self) -> usize {
+        self.inner_repartition
+            .properties()
+            .output_partitioning()
+            .partition_count()
+    }
+}
+
+impl DisplayAs for ProbeShuffleExec {
+    fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter) -> 
fmt::Result {
+        match t {
+            DisplayFormatType::Default | DisplayFormatType::Verbose => {
+                write!(
+                    f,
+                    "ProbeShuffleExec: partitioning=RoundRobinBatch({})",
+                    self.num_partitions()
+                )
+            }
+            DisplayFormatType::TreeRender => {
+                write!(f, "partitioning=RoundRobinBatch({})", 
self.num_partitions())
+            }
+        }
+    }
+}
+
+impl ExecutionPlan for ProbeShuffleExec {
+    fn name(&self) -> &str {
+        "ProbeShuffleExec"
+    }
+
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn properties(&self) -> &PlanProperties {
+        self.inner_repartition.properties()
+    }
+
+    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
+        vec![self.inner_repartition.input()]
+    }
+
+    fn with_new_children(
+        self: Arc<Self>,
+        mut children: Vec<Arc<dyn ExecutionPlan>>,
+    ) -> Result<Arc<dyn ExecutionPlan>> {
+        if children.len() != 1 {
+            return internal_err!(
+                "ProbeShuffleExec expects exactly 1 child, got {}",
+                children.len()
+            );
+        }
+        let child = children.remove(0);
+        Ok(Arc::new(Self::try_new(child)?))
+    }
+
+    fn execute(
+        &self,
+        partition: usize,
+        context: Arc<TaskContext>,
+    ) -> Result<SendableRecordBatchStream> {
+        self.inner_repartition.execute(partition, context)
+    }
+
+    fn maintains_input_order(&self) -> Vec<bool> {
+        self.inner_repartition.maintains_input_order()
+    }
+
+    fn benefits_from_input_partitioning(&self) -> Vec<bool> {
+        self.inner_repartition.benefits_from_input_partitioning()
+    }
+
+    fn cardinality_effect(&self) -> CardinalityEffect {
+        self.inner_repartition.cardinality_effect()
+    }
+
+    fn metrics(&self) -> Option<MetricsSet> {
+        self.inner_repartition.metrics()
+    }
+
+    fn partition_statistics(&self, partition: Option<usize>) -> 
Result<Statistics> {
+        self.inner_repartition.partition_statistics(partition)
+    }
+
+    fn try_swapping_with_projection(
+        &self,
+        projection: &ProjectionExec,
+    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
+        let Some(new_repartition) = self
+            .inner_repartition
+            .try_swapping_with_projection(projection)?
+        else {
+            return Ok(None);
+        };
+        let new_plan = Self::try_wrap_repartition(new_repartition)?;
+        Ok(Some(Arc::new(new_plan)))
+    }
+
+    fn gather_filters_for_pushdown(
+        &self,
+        phase: FilterPushdownPhase,
+        parent_filters: Vec<Arc<dyn PhysicalExpr>>,
+        config: &ConfigOptions,
+    ) -> Result<FilterDescription> {
+        self.inner_repartition
+            .gather_filters_for_pushdown(phase, parent_filters, config)
+    }
+
+    fn handle_child_pushdown_result(
+        &self,
+        phase: FilterPushdownPhase,
+        child_pushdown_result: ChildPushdownResult,
+        config: &ConfigOptions,
+    ) -> Result<FilterPushdownPropagation<Arc<dyn ExecutionPlan>>> {
+        self.inner_repartition
+            .handle_child_pushdown_result(phase, child_pushdown_result, config)
+    }
+
+    fn repartitioned(
+        &self,
+        target_partitions: usize,
+        config: &ConfigOptions,
+    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
+        let Some(plan) = self
+            .inner_repartition
+            .repartitioned(target_partitions, config)?
+        else {
+            return Ok(None);
+        };
+        let new_plan = Self::try_wrap_repartition(plan)?;
+        Ok(Some(Arc::new(new_plan)))
+    }
+}
diff --git a/rust/sedona-spatial-join/tests/spatial_join_integration.rs b/rust/sedona-spatial-join/tests/spatial_join_integration.rs
index 00047255..35338017 100644
--- a/rust/sedona-spatial-join/tests/spatial_join_integration.rs
+++ b/rust/sedona-spatial-join/tests/spatial_join_integration.rs
@@ -26,7 +26,7 @@ use datafusion::{
     prelude::{SessionConfig, SessionContext},
 };
 use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion};
-use datafusion_common::Result;
+use datafusion_common::{JoinSide, Result};
 use datafusion_expr::{ColumnarValue, JoinType};
 use datafusion_physical_plan::filter::FilterExec;
 use datafusion_physical_plan::joins::NestedLoopJoinExec;
@@ -43,7 +43,8 @@ use sedona_schema::{
     matchers::ArgMatcher,
 };
 use sedona_spatial_join::{
-    register_planner, spatial_predicate::RelationPredicate, SpatialJoinExec, 
SpatialPredicate,
+    register_planner, spatial_predicate::RelationPredicate, ProbeShuffleExec, 
SpatialJoinExec,
+    SpatialPredicate,
 };
 use sedona_testing::datagen::RandomPartitionedDataBuilder;
 use tokio::sync::OnceCell;
@@ -801,6 +802,7 @@ async fn run_spatial_join_query(
     )?);
 
     let is_optimized_spatial_join = options.is_some();
+    let repartition_probe_side = options.as_ref().is_some_and(|o| 
o.repartition_probe_side);
     let ctx = setup_context(options, batch_size)?;
     ctx.register_table("L", Arc::clone(&mem_table_left))?;
     ctx.register_table("R", Arc::clone(&mem_table_right))?;
@@ -810,6 +812,9 @@ async fn run_spatial_join_query(
     let spatial_join_execs = collect_spatial_join_exec(&plan)?;
     if is_optimized_spatial_join {
         assert_eq!(spatial_join_execs.len(), 1);
+        if repartition_probe_side {
+            
probe_side_of_spatial_join_exec_should_be_shuffled(spatial_join_execs[0]);
+        }
     } else {
         assert!(spatial_join_execs.is_empty());
     }
@@ -829,6 +834,20 @@ fn collect_spatial_join_exec(plan: &Arc<dyn ExecutionPlan>) -> Result<Vec<&Spati
     Ok(spatial_join_execs)
 }
 
+fn probe_side_of_spatial_join_exec_should_be_shuffled(sj: &SpatialJoinExec) {
+    let probe_child = match &sj.on {
+        SpatialPredicate::KNearestNeighbors(knn) => match knn.probe_side {
+            JoinSide::Left => &sj.left,
+            _ => &sj.right,
+        },
+        _ => &sj.right, // non-KNN: probe is always right after swap
+    };
+    assert!(
+        subtree_contains_probe_shuffle_exec(probe_child),
+        "ProbeShuffleExec should be present on the probe side of 
SpatialJoinExec"
+    );
+}
+
 async fn test_mark_join(
     join_type: JoinType,
     options: SpatialJoinOptions,
@@ -1613,6 +1632,20 @@ fn subtree_contains_filter_exec(plan: &Arc<dyn ExecutionPlan>) -> bool {
     found
 }
 
+/// Recursively check whether any node in the physical plan tree is a `ProbeShuffleExec`.
+fn subtree_contains_probe_shuffle_exec(plan: &Arc<dyn ExecutionPlan>) -> bool {
+    let mut found = false;
+    plan.apply(|node| {
+        if node.as_any().downcast_ref::<ProbeShuffleExec>().is_some() {
+            found = true;
+            return Ok(TreeNodeRecursion::Stop);
+        }
+        Ok(TreeNodeRecursion::Continue)
+    })
+    .expect("failed to walk plan");
+    found
+}
+
 /// Create a session context with two small tables for filter-pushdown tests.
 ///
 /// L(id INT, x DOUBLE) and R(id INT, x DOUBLE) are all empty, this is just for exercising the

Reply via email to