This is an automated email from the ASF dual-hosted git repository.
kontinuation pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/sedona-db.git
The following commit(s) were added to refs/heads/main by this push:
new 55b822f5 fix(rust/sedona-spatial-join): prevent filter pushdown past KNN joins (#611)
55b822f5 is described below
commit 55b822f5f1ef587d938a44df34a7b3c973b1c5ea
Author: Kristin Cowalcijk <[email protected]>
AuthorDate: Wed Feb 18 22:18:36 2026 +0800
fix(rust/sedona-spatial-join): prevent filter pushdown past KNN joins (#611)
## Summary
KNN joins have different semantics from regular spatial joins: pushing
filters to the object (build) side changes which objects are the k nearest
neighbors, producing incorrect results. DataFusion's built-in `PushDownFilter`
optimizer rule is unaware of this and incorrectly pushes filters through KNN
joins.
This PR adds a `KnnJoinEarlyRewrite` optimizer rule that converts KNN joins
to `SpatialJoinPlanNode` extension nodes **before** DataFusion's
`PushDownFilter` rule runs. Extension nodes naturally block filter pushdown
because their default `prevent_predicate_push_down_columns()` returns all
columns.
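To make the failure mode concrete, here is a minimal sketch under a hypothetical 1-D model (objects are plain `(id, x)` tuples and distance is `|x - query_x|`; none of these helpers exist in SedonaDB). `filter_after_knn` applies a `WHERE R.id > 5`-style filter after the k nearest objects are found, which is the KNN join semantics, while `filter_before_knn` is what an object-side pushdown effectively computes.

```rust
/// Hypothetical 1-D model: objects are (id, x) pairs and the distance from a
/// query at `query_x` is |x - query_x|. Ties are broken by id.
fn knn(query_x: f64, objects: &[(i32, f64)], k: usize) -> Vec<i32> {
    let mut sorted = objects.to_vec();
    sorted.sort_by(|a, b| {
        let da = (a.1 - query_x).abs();
        let db = (b.1 - query_x).abs();
        da.partial_cmp(&db).unwrap().then(a.0.cmp(&b.0))
    });
    sorted.into_iter().take(k).map(|(id, _)| id).collect()
}

/// KNN join semantics: find the k nearest objects first, then apply the
/// WHERE filter to the joined rows.
fn filter_after_knn(
    query_x: f64,
    objects: &[(i32, f64)],
    k: usize,
    keep: impl Fn(i32) -> bool,
) -> Vec<i32> {
    knn(query_x, objects, k)
        .into_iter()
        .filter(|&id| keep(id))
        .collect()
}

/// What an object-side pushdown computes: drop objects first, then search
/// for the k nearest among the survivors.
fn filter_before_knn(
    query_x: f64,
    objects: &[(i32, f64)],
    k: usize,
    keep: impl Fn(i32) -> bool,
) -> Vec<i32> {
    let survivors: Vec<(i32, f64)> = objects
        .iter()
        .copied()
        .filter(|&(id, _)| keep(id))
        .collect();
    knn(query_x, &survivors, k)
}

fn main() {
    // Objects 1, 2, 3 are near the query at x = 0; objects 6 and 7 are far away.
    let objects = [(1, 0.0), (2, 1.0), (3, 2.0), (6, 10.0), (7, 11.0)];
    let keep = |id: i32| id > 5; // analogous to `WHERE R.id > 5`
    let correct = filter_after_knn(0.0, &objects, 2, keep);
    let pushed_down = filter_before_knn(0.0, &objects, 2, keep);
    // Neither of the 2 nearest objects has id > 5, so the correct answer is empty.
    assert!(correct.is_empty());
    // Pushing the filter down "finds" the far-away objects 6 and 7 instead.
    assert_eq!(pushed_down, vec![6, 7]);
}
```

The two functions differ only in where the filter sits relative to the top-k selection, which is exactly the reordering `PushDownFilter` performs.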
## Changes
- **New `KnnJoinEarlyRewrite` optimizer rule**, which handles two patterns:
1. `Join(filter=ST_KNN(...))`, when the ON clause has only the spatial
predicate
2. `Filter(ST_KNN(...), Join(on=[...]))`, when the ON clause also has
equi-join conditions (DataFusion's SQL planner separates these); this form
is first merged into the first one by `MergeSpatialFilterIntoJoin`
- **Positional rule insertion**: `MergeSpatialFilterIntoJoin` (renamed from
`MergeSpatialProjectionIntoJoin`) and `KnnJoinEarlyRewrite` are inserted
before `PushDownFilter`, while `SpatialJoinLogicalRewrite` (for non-KNN joins)
is appended after it so that non-KNN joins still benefit from filter pushdown
- **Updated `SpatialJoinLogicalRewrite`**: no longer handles KNN joins (they
are already converted by the early rewrite); encountering one there is
reported as an internal error
- **Integration tests** verifying that object-side filters are NOT pushed
down for KNN joins, but ARE pushed down for non-KNN spatial joins
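Both patterns end up as a single join filter. The sketch below models the two steps from the diff with a hypothetical, simplified expression type (not DataFusion's `Expr`): `MergeSpatialFilterIntoJoin` AND-ing a `Filter` predicate with any existing join filter, and `rewrite_join_to_spatial_join_plan_node` folding the `on` pairs into that filter with `=` (or `IS NOT DISTINCT FROM` when NULLs compare equal). The function names here are illustrative only.

```rust
/// Hypothetical, simplified expression tree (not DataFusion's `Expr`).
#[derive(Clone, Debug, PartialEq)]
enum Expr {
    Col(&'static str),
    Lit(i64),
    Call(&'static str, Vec<Expr>),
    Eq(Box<Expr>, Box<Expr>),
    IsNotDistinctFrom(Box<Expr>, Box<Expr>),
    And(Box<Expr>, Box<Expr>),
}

fn and(l: Expr, r: Expr) -> Expr {
    Expr::And(Box::new(l), Box::new(r))
}

/// Pattern 2 -> pattern 1: move a predicate from a `Filter` above the join
/// into the join's own filter, AND-ing it with any filter already there.
fn merge_filter_into_join(predicate: Expr, existing: Option<Expr>) -> Expr {
    match existing {
        Some(f) => and(predicate, f),
        None => predicate,
    }
}

/// Fold equi-join `on` pairs into the filter so a single expression carries
/// the whole join condition. NULL-safe equality is used when NULLs compare equal.
fn fold_on_into_filter(filter: Expr, on: &[(Expr, Expr)], nulls_equal: bool) -> Expr {
    on.iter().fold(filter, |acc, (l, r)| {
        let (l, r) = (Box::new(l.clone()), Box::new(r.clone()));
        let eq = if nulls_equal {
            Expr::IsNotDistinctFrom(l, r)
        } else {
            Expr::Eq(l, r)
        };
        and(acc, eq)
    })
}

fn main() {
    // ... JOIN R ON ST_KNN(L.geom, R.geom, 5) AND L.id = R.id
    // The SQL planner puts `L.id = R.id` into `on` and ST_KNN into a Filter.
    let knn = Expr::Call("st_knn", vec![Expr::Col("l.geom"), Expr::Col("r.geom"), Expr::Lit(5)]);
    let on = vec![(Expr::Col("l.id"), Expr::Col("r.id"))];
    let join_filter = merge_filter_into_join(knn, None);
    let combined = fold_on_into_filter(join_filter, &on, false);
    println!("{combined:?}");
}
```

Keeping the equi-join conditions inside the extension node's filter means the KNN predicate and the equality are evaluated together by the spatial join, rather than being split up by later rules.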
## Rule ordering
```
... → MergeSpatialFilterIntoJoin → KnnJoinEarlyRewrite → PushDownFilter
→ ... → SpatialJoinLogicalRewrite
```
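The ordering comes from inserting both early rules at the index of `push_down_filter` and appending the late rule. Two inserts at the same index push the earlier one back, so the rules are inserted in reverse order. This sketch models that with a stand-in `Rule` trait holding only a name, not DataFusion's `OptimizerRule`; `simplify_expressions`, `push_down_limit`, and `spatial_join_logical_rewrite` are placeholder names for illustration.

```rust
use std::sync::Arc;

/// Stand-in for an optimizer rule; only the name matters for ordering.
trait Rule {
    fn name(&self) -> &str;
}

/// Hypothetical rule type used only for this illustration.
struct Named(&'static str);

impl Rule for Named {
    fn name(&self) -> &str {
        self.0
    }
}

/// Insert the two early rules in front of `push_down_filter` and append the
/// late rule, following the same strategy as the registration code in the diff.
fn insert_spatial_rules(rules: &mut Vec<Arc<dyn Rule>>) -> Result<(), String> {
    let pos = rules
        .iter()
        .position(|r| r.name() == "push_down_filter")
        .ok_or_else(|| "push_down_filter rule not found".to_string())?;
    // Both inserts use the same index, so the rule inserted second ends up
    // first: insert in reverse of the desired order.
    rules.insert(pos, Arc::new(Named("knn_join_early_rewrite")));
    rules.insert(pos, Arc::new(Named("merge_spatial_filter_into_join")));
    // The non-KNN rewrite runs after every existing rule, including pushdown.
    rules.push(Arc::new(Named("spatial_join_logical_rewrite")));
    Ok(())
}

fn rule_names(rules: &[Arc<dyn Rule>]) -> Vec<String> {
    rules.iter().map(|r| r.name().to_string()).collect()
}

fn main() {
    let mut rules: Vec<Arc<dyn Rule>> = Vec::new();
    for name in ["simplify_expressions", "push_down_filter", "push_down_limit"] {
        rules.push(Arc::new(Named(name)));
    }
    insert_spatial_rules(&mut rules).unwrap();
    println!("{:?}", rule_names(&rules));
}
```

Looking the rule up by name and failing loudly when it is missing (as the diff does with `sedona_internal_datafusion_err!`) guards against a DataFusion upgrade silently renaming or removing `PushDownFilter`.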
## Follow-ups
1. We don't require `ST_KNN` to appear first in a chain of AND-ed
expressions. For instance, `ST_KNN(L.geom, R.geom, 5) AND L.id = R.id` has the
same semantics as `L.id = R.id AND ST_KNN(L.geom, R.geom, 5)`, which seems
unnatural. An optimizer rule does not seem to be the right place to enforce
this, so we leave it to future patches that work on the SQL parser and AST.
2. For now, we don't allow any filter pushdown through ST_KNN joins. Filters
applied to the query (probe) side could in fact be pushed down safely, but we
need to implement such rules ourselves in future patches.
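Follow-up 2 relies on the fact that each query row's k nearest neighbors depend only on that row and the object side, never on other query rows. A minimal sketch under the same kind of hypothetical 1-D model (tuples and absolute distance, not SedonaDB code) checks that filtering query rows before or after the join yields the same pairs.

```rust
/// Hypothetical 1-D KNN join: for each query (id, x), emit
/// (query_id, object_id) pairs for its k nearest objects, ties broken by id.
fn knn_join(queries: &[(i32, f64)], objects: &[(i32, f64)], k: usize) -> Vec<(i32, i32)> {
    let mut out = Vec::new();
    for &(qid, qx) in queries {
        let mut cands = objects.to_vec();
        cands.sort_by(|a, b| {
            let da = (a.1 - qx).abs();
            let db = (b.1 - qx).abs();
            da.partial_cmp(&db).unwrap().then(a.0.cmp(&b.0))
        });
        out.extend(cands.into_iter().take(k).map(|(oid, _)| (qid, oid)));
    }
    out
}

/// Reference semantics: run the KNN join, then filter on the query id.
fn filter_after_join(
    queries: &[(i32, f64)],
    objects: &[(i32, f64)],
    k: usize,
    keep: impl Fn(i32) -> bool,
) -> Vec<(i32, i32)> {
    knn_join(queries, objects, k)
        .into_iter()
        .filter(|&(qid, _)| keep(qid))
        .collect()
}

/// Query-side pushdown: drop query rows first, then run the KNN join.
fn filter_pushed_to_queries(
    queries: &[(i32, f64)],
    objects: &[(i32, f64)],
    k: usize,
    keep: impl Fn(i32) -> bool,
) -> Vec<(i32, i32)> {
    let kept: Vec<(i32, f64)> = queries
        .iter()
        .copied()
        .filter(|&(qid, _)| keep(qid))
        .collect();
    knn_join(&kept, objects, k)
}

fn main() {
    let queries = [(1, 0.0), (2, 5.0), (3, 9.0)];
    let objects = [(10, 1.0), (11, 4.0), (12, 8.0), (13, 20.0)];
    let keep = |qid: i32| qid != 2; // analogous to `WHERE L.id <> 2`
    // Both orders produce the same pairs, unlike an object-side filter.
    assert_eq!(
        filter_after_join(&queries, &objects, 2, keep),
        filter_pushed_to_queries(&queries, &objects, 2, keep)
    );
}
```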
## TODO
The `prevent_predicate_push_down_columns` method seems to do the trick. I'll
experiment with it; hopefully we can implement query-side filter pushdown
easily.
**UPDATE**: No, it is a terrible idea. There's no shortcut; we have to
implement the optimizer rule ourselves.
Closes https://github.com/apache/sedona-db/issues/605
---
rust/sedona-spatial-join/src/planner.rs | 9 +-
.../src/planner/logical_plan_node.rs | 10 +
rust/sedona-spatial-join/src/planner/optimizer.rs | 227 +++++++++++++++-----
.../src/planner/spatial_expr_utils.rs | 40 ++--
.../tests/spatial_join_integration.rs | 233 +++++++++++++++------
rust/sedona/src/context.rs | 2 +-
6 files changed, 384 insertions(+), 137 deletions(-)
diff --git a/rust/sedona-spatial-join/src/planner.rs b/rust/sedona-spatial-join/src/planner.rs
index 11312aec..e40303bf 100644
--- a/rust/sedona-spatial-join/src/planner.rs
+++ b/rust/sedona-spatial-join/src/planner.rs
@@ -21,6 +21,7 @@
//! can produce `SpatialJoinExec`.
use datafusion::execution::SessionStateBuilder;
+use datafusion_common::Result;
mod logical_plan_node;
mod optimizer;
@@ -34,10 +35,12 @@ mod spatial_expr_utils;
/// implementation provided by this crate and ensures joins created by SQL or using
/// a DataFrame API that meet certain conditions (e.g. contain a spatial predicate as
/// a join condition) are executed using the `SpatialJoinExec`.
-pub fn register_planner(state_builder: SessionStateBuilder) -> SessionStateBuilder {
+pub fn register_planner(state_builder: SessionStateBuilder) -> Result<SessionStateBuilder> {
// Enable the logical rewrite that turns Filter(CrossJoin) into Join(filter=...)
- let state_builder = optimizer::register_spatial_join_logical_optimizer(state_builder);
+ let state_builder = optimizer::register_spatial_join_logical_optimizer(state_builder)?;
// Enable planning SpatialJoinExec via an extension node during logical->physical planning.
- physical_planner::register_spatial_join_planner(state_builder)
+ Ok(physical_planner::register_spatial_join_planner(
+ state_builder,
+ ))
}
diff --git a/rust/sedona-spatial-join/src/planner/logical_plan_node.rs b/rust/sedona-spatial-join/src/planner/logical_plan_node.rs
index e93e228d..23f35d09 100644
--- a/rust/sedona-spatial-join/src/planner/logical_plan_node.rs
+++ b/rust/sedona-spatial-join/src/planner/logical_plan_node.rs
@@ -106,6 +106,16 @@ impl UserDefinedLogicalNodeCore for SpatialJoinPlanNode {
)
}
+ fn necessary_children_exprs(&self, _output_columns: &[usize]) -> Option<Vec<Vec<usize>>> {
+ // Request all columns from both children. The default implementation returns None, which
+ // should also be fine, but we need to return the columns indices explicitly to workaround
+ // a bug in DataFusion's handling of None projection indices in FFI table provider.
+ // See https://github.com/apache/datafusion/pull/20393
+ let left_indices: Vec<usize> = (0..self.left.schema().fields().len()).collect();
+ let right_indices: Vec<usize> = (0..self.right.schema().fields().len()).collect();
+ Some(vec![left_indices, right_indices])
+ }
+
fn with_exprs_and_inputs(
&self,
mut exprs: Vec<Expr>,
diff --git a/rust/sedona-spatial-join/src/planner/optimizer.rs b/rust/sedona-spatial-join/src/planner/optimizer.rs
index 1edbc2d3..811c866b 100644
--- a/rust/sedona-spatial-join/src/planner/optimizer.rs
+++ b/rust/sedona-spatial-join/src/planner/optimizer.rs
@@ -18,9 +18,8 @@ use std::sync::Arc;
use crate::planner::logical_plan_node::SpatialJoinPlanNode;
use crate::planner::spatial_expr_utils::collect_spatial_predicate_names;
-use crate::planner::spatial_expr_utils::is_spatial_predicate;
use datafusion::execution::session_state::SessionStateBuilder;
-use datafusion::optimizer::{ApplyOrder, OptimizerConfig, OptimizerRule};
+use datafusion::optimizer::{ApplyOrder, Optimizer, OptimizerConfig, OptimizerRule};
use datafusion_common::tree_node::Transformed;
use datafusion_common::NullEquality;
use datafusion_common::Result;
@@ -28,21 +27,118 @@ use datafusion_expr::logical_plan::Extension;
use datafusion_expr::{BinaryExpr, Expr, Operator};
use datafusion_expr::{Filter, Join, JoinType, LogicalPlan};
use sedona_common::option::SedonaOptions;
+use sedona_common::{sedona_internal_datafusion_err, sedona_internal_err};
-/// Register only the logical spatial join optimizer rule.
+/// Register the logical spatial join optimizer rules.
///
-/// This enables building `Join(filter=...)` from patterns like `Filter(CrossJoin)`.
-/// It intentionally does not register any physical plan rewrite rules.
+/// This inserts rules at specific positions relative to DataFusion's built-in `PushDownFilter`
+/// rule to ensure correct semantics for KNN joins:
+///
+/// - `MergeSpatialFilterIntoJoin` and `KnnJoinEarlyRewrite` are inserted *before*
+/// `PushDownFilter` so that KNN joins are converted to `SpatialJoinPlanNode` extension nodes
+/// before filter pushdown runs. Extension nodes naturally block filter pushdown via
+/// `prevent_predicate_push_down_columns()`, preventing incorrect pushdown to the build side
+/// of KNN joins.
+///
+/// - `SpatialJoinLogicalRewrite` is appended at the end so that non-KNN spatial joins still
+/// benefit from filter pushdown before being converted to extension nodes.
pub(crate) fn register_spatial_join_logical_optimizer(
- session_state_builder: SessionStateBuilder,
-) -> SessionStateBuilder {
- session_state_builder
- .with_optimizer_rule(Arc::new(MergeSpatialProjectionIntoJoin))
- .with_optimizer_rule(Arc::new(SpatialJoinLogicalRewrite))
+ mut session_state_builder: SessionStateBuilder,
+) -> Result<SessionStateBuilder> {
+ let optimizer = session_state_builder
+ .optimizer()
+ .get_or_insert_with(Optimizer::new);
+
+ // Find PushDownFilter position by name
+ let push_down_pos = optimizer
+ .rules
+ .iter()
+ .position(|r| r.name() == "push_down_filter")
+ .ok_or_else(|| {
+ sedona_internal_datafusion_err!(
+ "PushDownFilter rule not found in default optimizer rules"
+ )
+ })?;
+
+ // Insert KNN-specific rules BEFORE PushDownFilter.
+ // MergeSpatialFilterIntoJoin must come first because it creates the Join(filter=...)
+ // nodes that KnnJoinEarlyRewrite then converts to SpatialJoinPlanNode.
+ optimizer
+ .rules
+ .insert(push_down_pos, Arc::new(KnnJoinEarlyRewrite));
+ optimizer
+ .rules
+ .insert(push_down_pos, Arc::new(MergeSpatialFilterIntoJoin));
+
+ // Append SpatialJoinLogicalRewrite at the end so non-KNN joins benefit from filter pushdown.
+ optimizer.rules.push(Arc::new(SpatialJoinLogicalRewrite));
+
+ Ok(session_state_builder)
}
-/// Logical optimizer rule that enables spatial join planning.
+
+/// Early optimizer rule that converts KNN joins to `SpatialJoinPlanNode` extension nodes
+/// *before* DataFusion's `PushDownFilter` runs.
///
-/// This rule turns eligible `Join(filter=...)` nodes into a
`SpatialJoinPlanNode` extension.
+/// This prevents `PushDownFilter` from pushing filters to the build (object) side of KNN joins,
+/// which would change which objects are the K nearest neighbors and produce incorrect results.
+///
+/// Extension nodes naturally block filter pushdown because their default
+/// `prevent_predicate_push_down_columns()` returns all columns.
+///
+/// Handles two patterns that DataFusion's SQL planner creates:
+///
+/// 1. `Join(filter=ST_KNN(...))` — when the ON clause has only the spatial predicate
+/// 2. `Filter(ST_KNN(...), Join(on=[...]))` — when the ON clause also has equi-join conditions,
+/// the SQL planner separates equi-conditions into `on` and the spatial predicate into a Filter
+#[derive(Default, Debug)]
+struct KnnJoinEarlyRewrite;
+
+impl OptimizerRule for KnnJoinEarlyRewrite {
+ fn name(&self) -> &str {
+ "knn_join_early_rewrite"
+ }
+
+ fn apply_order(&self) -> Option<ApplyOrder> {
+ Some(ApplyOrder::BottomUp)
+ }
+
+ fn supports_rewrite(&self) -> bool {
+ true
+ }
+
+ fn rewrite(
+ &self,
+ plan: LogicalPlan,
+ config: &dyn OptimizerConfig,
+ ) -> Result<Transformed<LogicalPlan>> {
+ let options = config.options();
+ let Some(ext) = options.extensions.get::<SedonaOptions>() else {
+ return Ok(Transformed::no(plan));
+ };
+ if !ext.spatial_join.enable {
+ return Ok(Transformed::no(plan));
+ }
+
+ // Join(filter=ST_KNN(...))
+ if let LogicalPlan::Join(join) = &plan {
+ if let Some(filter) = join.filter.as_ref() {
+ let names = collect_spatial_predicate_names(filter);
+ if names.contains("st_knn") {
+ return rewrite_join_to_spatial_join_plan_node(join);
+ }
+ }
+ }
+
+ Ok(Transformed::no(plan))
+ }
+}
+
+/// Logical optimizer rule that converts non-KNN spatial joins to `SpatialJoinPlanNode`.
+///
+/// This rule runs *after* `PushDownFilter` so that non-KNN spatial joins benefit from
+/// filter pushdown before being converted to extension nodes.
+///
+/// KNN joins are skipped here because they are already handled by [[KnnJoinEarlyRewrite]].
#[derive(Default, Debug)]
struct SpatialJoinLogicalRewrite;
@@ -86,54 +182,74 @@ impl OptimizerRule for SpatialJoinLogicalRewrite {
return Ok(Transformed::no(plan));
}
- // Join with with equi-join condition and spatial join condition. Only handle it
- // when the join condition contains ST_KNN. KNN join is not a regular join and
- // ST_KNN is also not a regular predicate. It must be handled by our spatial join exec.
- if !join.on.is_empty() && !spatial_predicate_names.contains("st_knn") {
- return Ok(Transformed::no(plan));
+ if spatial_predicate_names.contains("st_knn") {
+ // KNN joins should have already been rewritten by KnnJoinEarlyRewrite, so we shouldn't
+ // see them here.
+ return sedona_internal_err!(
+ "Found KNN predicate in SpatialJoinLogicalRewrite, which should have been handled by KnnJoinEarlyRewrite");
}
- // Build new filter expression including equi-join conditions
- let filter = filter.clone();
- let eq_op = if join.null_equality == NullEquality::NullEqualsNothing {
- Operator::Eq
- } else {
- Operator::IsNotDistinctFrom
- };
- let filter = join.on.iter().fold(filter, |acc, (l, r)| {
- let eq_expr = Expr::BinaryExpr(BinaryExpr::new(
- Box::new(l.clone()),
- eq_op,
- Box::new(r.clone()),
- ));
- Expr::and(acc, eq_expr)
- });
-
- let schema = Arc::clone(&join.schema);
- let node = SpatialJoinPlanNode {
- left: join.left.as_ref().clone(),
- right: join.right.as_ref().clone(),
- join_type: join.join_type,
- filter,
- schema,
- join_constraint: join.join_constraint,
- null_equality: join.null_equality,
- };
+ // Join with with equi-join condition should be planned as a regular HashJoin
+ // or SortMergeJoin.
+ if !join.on.is_empty() {
+ return Ok(Transformed::no(plan));
+ }
- Ok(Transformed::yes(LogicalPlan::Extension(Extension {
- node: Arc::new(node),
- })))
+ rewrite_join_to_spatial_join_plan_node(join)
}
}
+/// Shared helper: convert a `Join` node (with spatial predicate in `filter`) to a
+/// `SpatialJoinPlanNode`, folding any equi-join `on` conditions into the filter expression.
+fn rewrite_join_to_spatial_join_plan_node(join: &Join) -> Result<Transformed<LogicalPlan>> {
+ let filter = join
+ .filter
+ .as_ref()
+ .ok_or_else(|| {
+ datafusion_common::DataFusionError::Internal(
+ "join filter must be present for spatial join rewrite".to_string(),
+ )
+ })?
+ .clone();
+
+ let eq_op = if join.null_equality == NullEquality::NullEqualsNothing {
+ Operator::Eq
+ } else {
+ Operator::IsNotDistinctFrom
+ };
+ let filter = join.on.iter().fold(filter, |acc, (l, r)| {
+ let eq_expr = Expr::BinaryExpr(BinaryExpr::new(
+ Box::new(l.clone()),
+ eq_op,
+ Box::new(r.clone()),
+ ));
+ Expr::and(acc, eq_expr)
+ });
+
+ let schema = Arc::clone(&join.schema);
+ let node = SpatialJoinPlanNode {
+ left: join.left.as_ref().clone(),
+ right: join.right.as_ref().clone(),
+ join_type: join.join_type,
+ filter,
+ schema,
+ join_constraint: join.join_constraint,
+ null_equality: join.null_equality,
+ };
+
+ Ok(Transformed::yes(LogicalPlan::Extension(Extension {
+ node: Arc::new(node),
+ })))
+}
+
/// Logical optimizer rule that enables spatial join planning.
///
/// This rule turns eligible `Filter(Join(filter=...))` nodes into a `Join(filter=...)` node,
/// so that the spatial join can be rewritten later by [SpatialJoinLogicalRewrite].
#[derive(Debug, Default)]
-struct MergeSpatialProjectionIntoJoin;
+struct MergeSpatialFilterIntoJoin;
-impl OptimizerRule for MergeSpatialProjectionIntoJoin {
+impl OptimizerRule for MergeSpatialFilterIntoJoin {
fn name(&self) -> &str {
"merge_spatial_filter_into_join"
}
@@ -188,7 +304,9 @@ impl OptimizerRule for MergeSpatialProjectionIntoJoin {
else {
return Ok(Transformed::no(plan));
};
- if !is_spatial_predicate(predicate) {
+
+ let spatial_predicates = collect_spatial_predicate_names(predicate);
+ if spatial_predicates.is_empty() {
return Ok(Transformed::no(plan));
}
@@ -207,20 +325,25 @@ impl OptimizerRule for MergeSpatialProjectionIntoJoin {
};
// Check if this is a suitable join for rewriting
+ let is_equi_join = !on.is_empty() && !spatial_predicates.contains("st_knn");
if !matches!(
join_type,
JoinType::Inner | JoinType::Left | JoinType::Right
- ) || !on.is_empty()
- || filter.is_some()
+ ) || is_equi_join
{
return Ok(Transformed::no(plan));
}
+ let new_filter = match filter {
+ Some(existing_filter) => Expr::and(predicate.clone(), existing_filter.clone()),
+ None => predicate.clone(),
+ };
+
let rewritten_plan = Join::try_new(
Arc::clone(left),
Arc::clone(right),
on.clone(),
- Some(predicate.clone()),
+ Some(new_filter),
JoinType::Inner,
*join_constraint,
*null_equality,
diff --git a/rust/sedona-spatial-join/src/planner/spatial_expr_utils.rs b/rust/sedona-spatial-join/src/planner/spatial_expr_utils.rs
index 5fd85a50..1d858d45 100644
--- a/rust/sedona-spatial-join/src/planner/spatial_expr_utils.rs
+++ b/rust/sedona-spatial-join/src/planner/spatial_expr_utils.rs
@@ -98,13 +98,6 @@ pub(crate) fn collect_spatial_predicate_names(expr: &Expr) -> HashSet<String> {
acc
}
-/// Check if a given logical expression contains a spatial predicate component or not. We assume that the given
-/// `expr` evaluates to a boolean value and originates from a filter logical node.
-pub(crate) fn is_spatial_predicate(expr: &Expr) -> bool {
- let pred_names = collect_spatial_predicate_names(expr);
- !pred_names.is_empty()
-}
-
/// Transform the join filter to a spatial predicate and a remainder.
///
/// * The spatial predicate is a spatial predicate that is extracted from the join filter.
@@ -2244,16 +2237,17 @@ mod tests {
}
#[test]
- fn test_is_spatial_predicate() {
- // Test 1: ST_ functions should return true
+ fn test_collect_spatial_predicate_names() {
+ // ST_Intersects should be collected
let st_intersects_udf = create_dummy_st_intersects_udf();
let st_intersects_expr = Expr::ScalarFunction(datafusion_expr::expr::ScalarFunction {
func: st_intersects_udf,
args: vec![col("geom1"), col("geom2")],
});
- assert!(is_spatial_predicate(&st_intersects_expr));
+ let names = collect_spatial_predicate_names(&st_intersects_expr);
+ assert_eq!(names, HashSet::from(["st_intersects".to_string()]));
- // ST_Distance(geom1, geom2) < 100 should return true
+ // ST_Distance(geom1, geom2) < 100 should be collected as st_dwithin
let st_distance_udf = create_dummy_st_distance_udf();
let st_distance_expr = Expr::ScalarFunction(datafusion_expr::expr::ScalarFunction {
func: st_distance_udf,
@@ -2264,29 +2258,33 @@ mod tests {
op: Operator::Lt,
right: Box::new(lit(100.0)),
});
- assert!(is_spatial_predicate(&distance_lt_expr));
+ let names = collect_spatial_predicate_names(&distance_lt_expr);
+ assert_eq!(names, HashSet::from(["st_dwithin".to_string()]));
- // ST_Distance(geom1, geom2) > 100 should return false
+ // ST_Distance(geom1, geom2) > 100 should not be collected (wrong comparison direction)
let distance_gt_expr = Expr::BinaryExpr(datafusion_expr::expr::BinaryExpr {
left: Box::new(st_distance_expr.clone()),
op: Operator::Gt,
right: Box::new(lit(100.0)),
});
- assert!(!is_spatial_predicate(&distance_gt_expr));
+ let names = collect_spatial_predicate_names(&distance_gt_expr);
+ assert!(names.is_empty());
- // AND expressions with spatial predicates should return true
+ // AND expression: spatial predicate should be collected through conjunction
let and_expr = Expr::BinaryExpr(datafusion_expr::expr::BinaryExpr {
left: Box::new(st_intersects_expr.clone()),
op: Operator::And,
right: Box::new(col("id").eq(lit(1))),
});
- assert!(is_spatial_predicate(&and_expr));
+ let names = collect_spatial_predicate_names(&and_expr);
+ assert_eq!(names, HashSet::from(["st_intersects".to_string()]));
- // Non-spatial expressions should return false
+ // Non-spatial expressions should return empty set
// Simple column comparison
let non_spatial_expr = col("id").eq(lit(1));
- assert!(!is_spatial_predicate(&non_spatial_expr));
+ let names = collect_spatial_predicate_names(&non_spatial_expr);
+ assert!(names.is_empty());
// Not a spatial relationship function
let non_st_func = Expr::ScalarFunction(datafusion_expr::expr::ScalarFunction {
@@ -2299,7 +2297,8 @@ mod tests {
))),
args: vec![col("id")],
});
- assert!(!is_spatial_predicate(&non_st_func));
+ let names = collect_spatial_predicate_names(&non_st_func);
+ assert!(names.is_empty());
// AND expression with no spatial predicates
let non_spatial_and = Expr::BinaryExpr(datafusion_expr::expr::BinaryExpr {
@@ -2307,6 +2306,7 @@ mod tests {
op: Operator::And,
right: Box::new(col("name").eq(lit("test"))),
});
- assert!(!is_spatial_predicate(&non_spatial_and));
+ let names = collect_spatial_predicate_names(&non_spatial_and);
+ assert!(names.is_empty());
}
}
diff --git a/rust/sedona-spatial-join/tests/spatial_join_integration.rs b/rust/sedona-spatial-join/tests/spatial_join_integration.rs
index 671ce2c8..2add149d 100644
--- a/rust/sedona-spatial-join/tests/spatial_join_integration.rs
+++ b/rust/sedona-spatial-join/tests/spatial_join_integration.rs
@@ -21,12 +21,14 @@ use arrow_array::{Array, RecordBatch};
use arrow_schema::{DataType, Field, Schema, SchemaRef};
use datafusion::{
catalog::{MemTable, TableProvider},
+ datasource::empty::EmptyTable,
execution::SessionStateBuilder,
prelude::{SessionConfig, SessionContext},
};
use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion};
use datafusion_common::Result;
use datafusion_expr::{ColumnarValue, JoinType};
+use datafusion_physical_plan::filter::FilterExec;
use datafusion_physical_plan::joins::NestedLoopJoinExec;
use datafusion_physical_plan::ExecutionPlan;
use geo::{Distance, Euclidean};
@@ -142,7 +144,7 @@ fn setup_context(options: Option<SpatialJoinOptions>, batch_size: usize) -> Resu
session_config = add_sedona_option_extension(session_config);
let mut state_builder = SessionStateBuilder::new();
if let Some(options) = options {
- state_builder = register_planner(state_builder);
+ state_builder = register_planner(state_builder)?;
let opts = session_config
.options_mut()
.extensions
@@ -1088,72 +1090,90 @@ async fn test_knn_join_with_filter_correctness(
};
let k = 3;
- let sql = format!(
- "SELECT L.id AS l_id, R.id AS r_id FROM L JOIN R ON ST_KNN(L.geometry, R.geometry, {}, false) AND (L.id % 7) = (R.id % 7)",
- k
- );
+ let sqls = [
+ format!(
+ "SELECT L.id AS l_id, R.id AS r_id FROM L JOIN R ON ST_KNN(L.geometry, R.geometry, {}, false) AND (L.id % 7) = (R.id % 7)",
+ k
+ ),
+ format!(
+ "SELECT L.id AS l_id, R.id AS r_id FROM L JOIN R ON ST_KNN(L.geometry, R.geometry, {}, false) AND L.id % 7 = 0",
+ k
+ ),
+ format!(
+ "SELECT L.id AS l_id, R.id AS r_id FROM L JOIN R ON ST_KNN(L.geometry, R.geometry, {}, false) AND R.id % 7 = 0",
+ k
+ ),
+ ];
- let batches = run_spatial_join_query(
- &left_schema,
- &right_schema,
- left_partitions.clone(),
- right_partitions.clone(),
- Some(options),
- max_batch_size,
- &sql,
- )
- .await?;
+ for (idx, sql) in sqls.iter().enumerate() {
+ let batches = run_spatial_join_query(
+ &left_schema,
+ &right_schema,
+ left_partitions.clone(),
+ right_partitions.clone(),
+ Some(options.clone()),
+ max_batch_size,
+ sql,
+ )
+ .await?;
- let mut actual_results = Vec::new();
- let combined_batch = arrow::compute::concat_batches(&batches.schema(), &[batches])?;
- let l_ids = combined_batch
- .column(0)
- .as_any()
- .downcast_ref::<arrow_array::Int32Array>()
- .unwrap();
- let r_ids = combined_batch
- .column(1)
- .as_any()
- .downcast_ref::<arrow_array::Int32Array>()
- .unwrap();
+ let mut actual_results = Vec::new();
+ let combined_batch = arrow::compute::concat_batches(&batches.schema(), &[batches])?;
+ let l_ids = combined_batch
+ .column(0)
+ .as_any()
+ .downcast_ref::<arrow_array::Int32Array>()
+ .unwrap();
+ let r_ids = combined_batch
+ .column(1)
+ .as_any()
+ .downcast_ref::<arrow_array::Int32Array>()
+ .unwrap();
- for i in 0..combined_batch.num_rows() {
- actual_results.push((l_ids.value(i), r_ids.value(i)));
- }
- actual_results.sort_by(|a, b| a.0.cmp(&b.0).then_with(|| a.1.cmp(&b.1)));
+ for i in 0..combined_batch.num_rows() {
+ actual_results.push((l_ids.value(i), r_ids.value(i)));
+ }
+ actual_results.sort_by(|a, b| a.0.cmp(&b.0).then_with(|| a.1.cmp(&b.1)));
- // Prove the test actually exercises the "< K rows after filtering" case.
- // Build a list of all probe-side IDs and count how many results each has.
- let all_left_ids: Vec<i32> = extract_geoms_and_ids(&left_partitions)
- .into_iter()
- .map(|(id, _)| id)
- .collect();
- let mut per_left_counts: std::collections::HashMap<i32, usize> =
- std::collections::HashMap::new();
- for (l_id, _) in &actual_results {
- *per_left_counts.entry(*l_id).or_default() += 1;
- }
- let min_count = all_left_ids
- .iter()
- .map(|l_id| *per_left_counts.get(l_id).unwrap_or(&0))
- .min()
- .unwrap_or(0);
- assert!(
- min_count < k,
- "expected at least one probe row to produce < K rows after filtering; min_count={min_count}, k={k}"
- );
+ // Prove the test actually exercises the "< K rows after filtering" case.
+ // Build a list of all probe-side IDs and count how many results each has.
+ let all_left_ids: Vec<i32> = extract_geoms_and_ids(&left_partitions)
+ .into_iter()
+ .map(|(id, _)| id)
+ .collect();
+ let mut per_left_counts: std::collections::HashMap<i32, usize> =
+ std::collections::HashMap::new();
+ for (l_id, _) in &actual_results {
+ *per_left_counts.entry(*l_id).or_default() += 1;
+ }
+ let min_count = all_left_ids
+ .iter()
+ .map(|l_id| *per_left_counts.get(l_id).unwrap_or(&0))
+ .min()
+ .unwrap_or(0);
+ assert!(
+ min_count < k,
+ "expected at least one probe row to produce < K rows after filtering; min_count={min_count}, k={k}"
+ );
- let expected_results = compute_knn_ground_truth_with_pair_filter(
- &left_partitions,
- &right_partitions,
- k,
- |l_id, r_id| (l_id.rem_euclid(7)) == (r_id.rem_euclid(7)),
- )
- .into_iter()
- .map(|(l, r, _)| (l, r))
- .collect::<Vec<_>>();
+ let filter_closure = match idx {
+ 0 => |l_id: i32, r_id: i32| (l_id.rem_euclid(7)) == (r_id.rem_euclid(7)),
+ 1 => |l_id: i32, _r_id: i32| l_id.rem_euclid(7) == 0,
+ 2 => |_l_id: i32, r_id: i32| r_id.rem_euclid(7) == 0,
+ _ => unreachable!(),
+ };
+ let expected_results = compute_knn_ground_truth_with_pair_filter(
+ &left_partitions,
+ &right_partitions,
+ k,
+ filter_closure,
+ )
+ .into_iter()
+ .map(|(l, r, _)| (l, r))
+ .collect::<Vec<_>>();
- assert_eq!(actual_results, expected_results);
+ assert_eq!(actual_results, expected_results);
+ }
Ok(())
}
@@ -1368,3 +1388,94 @@ async fn test_knn_join_include_tie_breakers(
Ok(())
}
+
+/// Verify that a filter on the *object* (build / right) side of a KNN join is NOT pushed down
+/// into the build side subtree.
+///
+/// If `PushDownFilter` incorrectly pushes `R.id > 5` below the spatial join, the set of objects
+/// considered for the KNN search changes, yielding wrong nearest-neighbor results.
+#[tokio::test]
+async fn test_knn_join_object_side_filter_not_pushed_down() -> Result<()> {
+ let sql = "SELECT L.id, R.id \
+ FROM L JOIN R ON ST_KNN(ST_Point(L.x, 0), ST_Point(R.x, 1), 3, false) \
+ WHERE R.id > 5";
+ let plan = plan_for_filter_pushdown_test(sql).await?;
+
+ let spatial_joins = collect_spatial_join_exec(&plan)?;
+ assert_eq!(
+ spatial_joins.len(),
+ 1,
+ "expected exactly one SpatialJoinExec"
+ );
+ let sj = spatial_joins[0];
+
+ // The build (right / object) side must NOT have a FilterExec pushed into it.
+ assert!(
+ !subtree_contains_filter_exec(&sj.right),
+ "FilterExec should NOT be pushed into the object (right/build) side of a KNN join"
+ );
+
+ Ok(())
+}
+
+/// Verify that for a *non-KNN* spatial join, a filter on the build side IS pushed down
+/// (the normal, desirable behaviour).
+#[tokio::test]
+async fn test_non_knn_join_object_side_filter_is_pushed_down() -> Result<()> {
+ let sql = "SELECT L.id, R.id \
+ FROM L JOIN R ON ST_Intersects(ST_Buffer(ST_Point(L.x, 0), 1.5), ST_Point(R.x, 1)) \
+ WHERE R.id > 5";
+ let plan = plan_for_filter_pushdown_test(sql).await?;
+
+ let spatial_joins = collect_spatial_join_exec(&plan)?;
+ assert_eq!(
+ spatial_joins.len(),
+ 1,
+ "expected exactly one SpatialJoinExec"
+ );
+ let sj = spatial_joins[0];
+
+ // For non-KNN joins, the filter SHOULD be pushed down to the build side.
+ assert!(
+ subtree_contains_filter_exec(&sj.right),
+ "FilterExec should be pushed into the object (right/build) side of a non-KNN spatial join"
+ );
+
+ Ok(())
+}
+
+/// Recursively check whether any node in the physical plan tree is a `FilterExec`.
+fn subtree_contains_filter_exec(plan: &Arc<dyn ExecutionPlan>) -> bool {
+ let mut found = false;
+ plan.apply(|node| {
+ if node.as_any().downcast_ref::<FilterExec>().is_some() {
+ found = true;
+ return Ok(TreeNodeRecursion::Stop);
+ }
+ Ok(TreeNodeRecursion::Continue)
+ })
+ .expect("failed to walk plan");
+ found
+}
+
+/// Create a session context with two small tables for filter-pushdown tests.
+///
+/// L(id INT, x DOUBLE) and R(id INT, x DOUBLE) are all empty, this is just for exercising the
+/// plan optimizer and physical planner.
+/// Geometry is constructed in SQL via ST_Point so no geometry column exists on the table itself.
+async fn plan_for_filter_pushdown_test(sql: &str) -> Result<Arc<dyn ExecutionPlan>> {
+ let schema = Arc::new(Schema::new(vec![
+ Field::new("id", DataType::Int32, false),
+ Field::new("x", DataType::Float64, false),
+ ]));
+
+ let options = SpatialJoinOptions::default();
+ let ctx = setup_context(Some(options), 100)?;
+ let empty_l: Arc<dyn TableProvider> = Arc::new(EmptyTable::new(schema.clone()));
+ let empty_r: Arc<dyn TableProvider> = Arc::new(EmptyTable::new(schema.clone()));
+ ctx.register_table("L", empty_l)?;
+ ctx.register_table("R", empty_r)?;
+
+ let df = ctx.sql(sql).await?;
+ df.create_physical_plan().await
+}
diff --git a/rust/sedona/src/context.rs b/rust/sedona/src/context.rs
index 2ea9cf03..fa659f82 100644
--- a/rust/sedona/src/context.rs
+++ b/rust/sedona/src/context.rs
@@ -117,7 +117,7 @@ impl SedonaContext {
// Register the spatial join planner extension
#[cfg(feature = "spatial-join")]
{
- state_builder = sedona_spatial_join::register_planner(state_builder);
+ state_builder = sedona_spatial_join::register_planner(state_builder)?;
}
let mut state = state_builder.build();