This is an automated email from the ASF dual-hosted git repository. paleolimbot pushed a commit to branch branch-0.3.0 in repository https://gitbox.apache.org/repos/asf/sedona-db.git
commit 7a3ab1c2f9f163aa1194c509924d9de98447f9b4 Author: Kristin Cowalcijk <[email protected]> AuthorDate: Wed Mar 4 23:40:07 2026 +0800 fix(rust/sedona-spatial-join): wrap probe-side repartition in ProbeShuffleExec to prevent optimizer stripping (#677) --- rust/sedona-spatial-join/src/lib.rs | 4 + rust/sedona-spatial-join/src/planner.rs | 1 + .../src/planner/physical_planner.rs | 10 +- .../src/planner/probe_shuffle_exec.rs | 225 +++++++++++++++++++++ .../tests/spatial_join_integration.rs | 37 +++- 5 files changed, 267 insertions(+), 10 deletions(-) diff --git a/rust/sedona-spatial-join/src/lib.rs b/rust/sedona-spatial-join/src/lib.rs index 3a044825..4e4b5b08 100644 --- a/rust/sedona-spatial-join/src/lib.rs +++ b/rust/sedona-spatial-join/src/lib.rs @@ -33,6 +33,10 @@ pub use exec::SpatialJoinExec; // Re-export function for register the spatial join planner pub use planner::register_planner; +// Re-export ProbeShuffleExec so that integration tests (and other crates) can verify +// its presence in optimized physical plans. +pub use planner::probe_shuffle_exec::ProbeShuffleExec; + // Re-export types needed for external usage (e.g., in Comet) pub use index::{SpatialIndex, SpatialJoinBuildMetrics}; pub use spatial_predicate::SpatialPredicate; diff --git a/rust/sedona-spatial-join/src/planner.rs b/rust/sedona-spatial-join/src/planner.rs index e40303bf..b663b293 100644 --- a/rust/sedona-spatial-join/src/planner.rs +++ b/rust/sedona-spatial-join/src/planner.rs @@ -26,6 +26,7 @@ use datafusion_common::Result; mod logical_plan_node; mod optimizer; mod physical_planner; +pub mod probe_shuffle_exec; mod spatial_expr_utils; /// Register Sedona spatial join planning hooks. diff --git a/rust/sedona-spatial-join/src/planner/physical_planner.rs b/rust/sedona-spatial-join/src/planner/physical_planner.rs index d0341154..33fa3a44 100644 --- a/rust/sedona-spatial-join/src/planner/physical_planner.rs +++ b/rust/sedona-spatial-join/src/planner/physical_planner.rs @@ -31,15 +31,13 @@ use datafusion_common::{plan_err, DFSchema, JoinSide, Result}; use datafusion_expr::logical_plan::UserDefinedLogicalNode; use datafusion_expr::LogicalPlan; use datafusion_physical_expr::create_physical_expr; -use datafusion_physical_expr::Partitioning; use datafusion_physical_plan::joins::utils::JoinFilter; use datafusion_physical_plan::joins::NestedLoopJoinExec; -use datafusion_physical_plan::repartition::RepartitionExec; -use datafusion_physical_plan::ExecutionPlanProperties; use sedona_common::sedona_internal_err; use crate::exec::SpatialJoinExec; use crate::planner::logical_plan_node::SpatialJoinPlanNode; +use crate::planner::probe_shuffle_exec::ProbeShuffleExec; use crate::planner::spatial_expr_utils::{is_spatial_predicate_supported, transform_join_filter}; use crate::spatial_predicate::SpatialPredicate; use sedona_common::option::SedonaOptions; @@ -325,11 +323,7 @@ fn repartition_probe_side( } }; - let num_partitions = probe_plan.output_partitioning().partition_count(); - *probe_plan = Arc::new(RepartitionExec::try_new( - Arc::clone(probe_plan), - Partitioning::RoundRobinBatch(num_partitions), - )?); + *probe_plan = Arc::new(ProbeShuffleExec::try_new(Arc::clone(probe_plan))?); Ok((physical_left, physical_right)) } diff --git a/rust/sedona-spatial-join/src/planner/probe_shuffle_exec.rs b/rust/sedona-spatial-join/src/planner/probe_shuffle_exec.rs new file mode 100644 index 00000000..65d73469 --- /dev/null +++ b/rust/sedona-spatial-join/src/planner/probe_shuffle_exec.rs @@ -0,0 +1,225 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`ProbeShuffleExec`] — a round-robin repartitioning wrapper that is invisible +//! to DataFusion's `EnforceDistribution` / `EnforceSorting` optimizer passes. +//! +//! Those passes unconditionally strip every [`RepartitionExec`] before +//! re-evaluating distribution requirements. Because `SpatialJoinExec` reports +//! `UnspecifiedDistribution` for its inputs, a bare `RepartitionExec` that was +//! inserted by the extension planner is removed and never re-added. +//! +//! `ProbeShuffleExec` wraps a hidden, internal `RepartitionExec` so that: +//! * **Optimizer passes** see an opaque node (not a `RepartitionExec`) and leave +//! it alone. +//! * **`children()` / `with_new_children()`** expose the *original* input so +//! the rest of the optimizer tree can still be rewritten normally. +//! * **`execute()`** delegates to the internal `RepartitionExec` which performs +//! the actual round-robin shuffle. + +use std::any::Any; +use std::fmt; +use std::sync::Arc; + +use datafusion_common::config::ConfigOptions; +use datafusion_common::{internal_err, plan_err, Result, Statistics}; +use datafusion_execution::{SendableRecordBatchStream, TaskContext}; +use datafusion_physical_expr::PhysicalExpr; +use datafusion_physical_plan::execution_plan::CardinalityEffect; +use datafusion_physical_plan::filter_pushdown::{ + ChildPushdownResult, FilterDescription, FilterPushdownPhase, FilterPushdownPropagation, +}; +use datafusion_physical_plan::metrics::MetricsSet; +use datafusion_physical_plan::projection::ProjectionExec; +use datafusion_physical_plan::repartition::RepartitionExec; +use datafusion_physical_plan::{ + DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, Partitioning, + PlanProperties, +}; + +/// A round-robin repartitioning node that is invisible to DataFusion's +/// physical optimizer passes. +/// +/// See [module-level documentation](self) for motivation and design. +#[derive(Debug)] +pub struct ProbeShuffleExec { + inner_repartition: RepartitionExec, +} + +impl ProbeShuffleExec { + /// Create a new [`ProbeShuffleExec`] that round-robin repartitions `input` + /// into the same number of output partitions as `input`. This will ensure + /// that the probe workload of a spatial join will be evenly distributed. + /// More importantly, shuffled probe side data will be less likely to + /// cause skew issues when out-of-core, spatial partitioned spatial join is enabled, + /// especially when the input probe data is sorted by their spatial locations. + pub fn try_new(input: Arc<dyn ExecutionPlan>) -> Result<Self> { + let num_partitions = input.output_partitioning().partition_count(); + let inner_repartition = RepartitionExec::try_new( + Arc::clone(&input), + Partitioning::RoundRobinBatch(num_partitions), + )?; + Ok(Self { inner_repartition }) + } + + /// Try to wrap the given [`RepartitionExec`] `plan` with [`ProbeShuffleExec`]. + pub fn try_wrap_repartition(plan: Arc<dyn ExecutionPlan>) -> Result<Self> { + let Some(repartition_exec) = plan.as_any().downcast_ref::<RepartitionExec>() else { + return plan_err!( + "ProbeShuffleExec can only wrap RepartitionExec, but got {}", + plan.name() + ); + }; + Ok(Self { + inner_repartition: repartition_exec.clone(), + }) + } + + /// Number of output partitions. + pub fn num_partitions(&self) -> usize { + self.inner_repartition + .properties() + .output_partitioning() + .partition_count() + } +} + +impl DisplayAs for ProbeShuffleExec { + fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!( + f, + "ProbeShuffleExec: partitioning=RoundRobinBatch({})", + self.num_partitions() + ) + } + DisplayFormatType::TreeRender => { + write!(f, "partitioning=RoundRobinBatch({})", self.num_partitions()) + } + } + } +} + +impl ExecutionPlan for ProbeShuffleExec { + fn name(&self) -> &str { + "ProbeShuffleExec" + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn properties(&self) -> &PlanProperties { + self.inner_repartition.properties() + } + + fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> { + vec![self.inner_repartition.input()] + } + + fn with_new_children( + self: Arc<Self>, + mut children: Vec<Arc<dyn ExecutionPlan>>, + ) -> Result<Arc<dyn ExecutionPlan>> { + if children.len() != 1 { + return internal_err!( + "ProbeShuffleExec expects exactly 1 child, got {}", + children.len() + ); + } + let child = children.remove(0); + Ok(Arc::new(Self::try_new(child)?)) + } + + fn execute( + &self, + partition: usize, + context: Arc<TaskContext>, + ) -> Result<SendableRecordBatchStream> { + self.inner_repartition.execute(partition, context) + } + + fn maintains_input_order(&self) -> Vec<bool> { + self.inner_repartition.maintains_input_order() + } + + fn benefits_from_input_partitioning(&self) -> Vec<bool> { + self.inner_repartition.benefits_from_input_partitioning() + } + + fn cardinality_effect(&self) -> CardinalityEffect { + self.inner_repartition.cardinality_effect() + } + + fn metrics(&self) -> Option<MetricsSet> { + self.inner_repartition.metrics() + } + + fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> { + self.inner_repartition.partition_statistics(partition) + } + + fn try_swapping_with_projection( + &self, + projection: &ProjectionExec, + ) -> Result<Option<Arc<dyn ExecutionPlan>>> { + let Some(new_repartition) = self + .inner_repartition + .try_swapping_with_projection(projection)? + else { + return Ok(None); + }; + let new_plan = Self::try_wrap_repartition(new_repartition)?; + Ok(Some(Arc::new(new_plan))) + } + + fn gather_filters_for_pushdown( + &self, + phase: FilterPushdownPhase, + parent_filters: Vec<Arc<dyn PhysicalExpr>>, + config: &ConfigOptions, + ) -> Result<FilterDescription> { + self.inner_repartition + .gather_filters_for_pushdown(phase, parent_filters, config) + } + + fn handle_child_pushdown_result( + &self, + phase: FilterPushdownPhase, + child_pushdown_result: ChildPushdownResult, + config: &ConfigOptions, + ) -> Result<FilterPushdownPropagation<Arc<dyn ExecutionPlan>>> { + self.inner_repartition + .handle_child_pushdown_result(phase, child_pushdown_result, config) + } + + fn repartitioned( + &self, + target_partitions: usize, + config: &ConfigOptions, + ) -> Result<Option<Arc<dyn ExecutionPlan>>> { + let Some(plan) = self + .inner_repartition + .repartitioned(target_partitions, config)? + else { + return Ok(None); + }; + let new_plan = Self::try_wrap_repartition(plan)?; + Ok(Some(Arc::new(new_plan))) + } +} diff --git a/rust/sedona-spatial-join/tests/spatial_join_integration.rs b/rust/sedona-spatial-join/tests/spatial_join_integration.rs index 00047255..35338017 100644 --- a/rust/sedona-spatial-join/tests/spatial_join_integration.rs +++ b/rust/sedona-spatial-join/tests/spatial_join_integration.rs @@ -26,7 +26,7 @@ use datafusion::{ prelude::{SessionConfig, SessionContext}, }; use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; -use datafusion_common::Result; +use datafusion_common::{JoinSide, Result}; use datafusion_expr::{ColumnarValue, JoinType}; use datafusion_physical_plan::filter::FilterExec; use datafusion_physical_plan::joins::NestedLoopJoinExec; @@ -43,7 +43,8 @@ use sedona_schema::{ matchers::ArgMatcher, }; use sedona_spatial_join::{ - register_planner, spatial_predicate::RelationPredicate, SpatialJoinExec, SpatialPredicate, + register_planner, spatial_predicate::RelationPredicate, ProbeShuffleExec, SpatialJoinExec, + SpatialPredicate, }; use sedona_testing::datagen::RandomPartitionedDataBuilder; use tokio::sync::OnceCell; @@ -801,6 +802,7 @@ async fn run_spatial_join_query( )?); let is_optimized_spatial_join = options.is_some(); + let repartition_probe_side = options.as_ref().is_some_and(|o| o.repartition_probe_side); let ctx = setup_context(options, batch_size)?; ctx.register_table("L", Arc::clone(&mem_table_left))?; ctx.register_table("R", Arc::clone(&mem_table_right))?; @@ -810,6 +812,9 @@ async fn run_spatial_join_query( let spatial_join_execs = collect_spatial_join_exec(&plan)?; if is_optimized_spatial_join { assert_eq!(spatial_join_execs.len(), 1); + if repartition_probe_side { + probe_side_of_spatial_join_exec_should_be_shuffled(spatial_join_execs[0]); + } } else { assert!(spatial_join_execs.is_empty()); } @@ -829,6 +834,20 @@ fn collect_spatial_join_exec(plan: &Arc<dyn ExecutionPlan>) -> Result<Vec<&Spati Ok(spatial_join_execs) } +fn probe_side_of_spatial_join_exec_should_be_shuffled(sj: &SpatialJoinExec) { + let probe_child = match &sj.on { + SpatialPredicate::KNearestNeighbors(knn) => match knn.probe_side { + JoinSide::Left => &sj.left, + _ => &sj.right, + }, + _ => &sj.right, // non-KNN: probe is always right after swap + }; + assert!( + subtree_contains_probe_shuffle_exec(probe_child), + "ProbeShuffleExec should be present on the probe side of SpatialJoinExec" + ); +} + async fn test_mark_join( join_type: JoinType, options: SpatialJoinOptions, @@ -1613,6 +1632,20 @@ fn subtree_contains_filter_exec(plan: &Arc<dyn ExecutionPlan>) -> bool { found } +/// Recursively check whether any node in the physical plan tree is a `ProbeShuffleExec`. +fn subtree_contains_probe_shuffle_exec(plan: &Arc<dyn ExecutionPlan>) -> bool { + let mut found = false; + plan.apply(|node| { + if node.as_any().downcast_ref::<ProbeShuffleExec>().is_some() { + found = true; + return Ok(TreeNodeRecursion::Stop); + } + Ok(TreeNodeRecursion::Continue) + }) + .expect("failed to walk plan"); + found +} + /// Create a session context with two small tables for filter-pushdown tests. /// /// L(id INT, x DOUBLE) and R(id INT, x DOUBLE) are all empty, this is just for exercising the
