This is an automated email from the ASF dual-hosted git repository.

kontinuation pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/sedona-db.git


The following commit(s) were added to refs/heads/main by this push:
     new 55b822f5 fix(rust/sedona-spatial-join): prevent filter pushdown past 
KNN joins (#611)
55b822f5 is described below

commit 55b822f5f1ef587d938a44df34a7b3c973b1c5ea
Author: Kristin Cowalcijk <[email protected]>
AuthorDate: Wed Feb 18 22:18:36 2026 +0800

    fix(rust/sedona-spatial-join): prevent filter pushdown past KNN joins (#611)
    
    ## Summary
    
    KNN joins have different semantics than regular spatial joins — pushing 
filters to the object (build) side changes which objects are the k nearest 
neighbors, producing incorrect results. DataFusion's builtin `PushDownFilter` 
optimizer rule doesn't know this and incorrectly pushes filters through KNN 
joins.
    
    This PR adds a `KnnJoinEarlyRewrite` optimizer rule that converts KNN joins 
to `SpatialJoinPlanNode` extension nodes **before** DataFusion's 
`PushDownFilter` rule runs. Extension nodes naturally block filter pushdown via 
`prevent_predicate_push_down_columns()` returning all columns.
    
    ## Changes
    
    - **New `KnnJoinEarlyRewrite` optimizer rule** — handles two patterns:
      1. `Join(filter=ST_KNN(...))` — when the ON clause has only the spatial 
predicate
      2. `Filter(ST_KNN(...), Join(on=[...]))` — when the ON clause also has 
equi-join conditions (DataFusion's SQL planner separates these); 
`MergeSpatialFilterIntoJoin` first folds the filter back into the join
    - **Positional rule insertion** — `MergeSpatialFilterIntoJoin` (renamed 
from `MergeSpatialProjectionIntoJoin`) and `KnnJoinEarlyRewrite` are inserted 
before `PushDownFilter`, while `SpatialJoinLogicalRewrite` (for non-KNN joins) 
remains after it so non-KNN joins still benefit from filter pushdown
    - **Updated `SpatialJoinLogicalRewrite`** — skips KNN joins (already 
handled by the early rewrite)
    - **Integration tests** verifying that object-side filters are NOT pushed 
down for KNN joins, but ARE pushed down for non-KNN spatial joins
    
    ## Rule ordering
    
    ```
    ... → MergeSpatialFilterIntoJoin → KnnJoinEarlyRewrite → PushDownFilter 
→ ... → SpatialJoinLogicalRewrite
    ```
    
    ## Follow-ups
    
    1. We don't require `ST_KNN` to appear first in a chain of AND 
expressions. For instance, `ST_KNN(L.geom, R.geom, 5) AND L.id = R.id` has the 
same semantics as `L.id = R.id AND ST_KNN(L.geom, R.geom, 5)`. This seems 
unnatural, but an optimizer rule is not a good place to enforce the ordering, 
so we leave it to future patches that work on the SQL parser and AST.
    2. We don't allow any filter pushdown through ST_KNN joins for now. Filters 
applied to the query side could actually be pushed down safely; we need to 
implement such rules ourselves in future patches.
    
    ## TODO
    
    `prevent_predicate_push_down_columns` method seems to do the trick. I'll 
experiment with it. Hopefully we can implement query side filter pushdown 
easily.
    **UPDATE**: No. It is a terrible idea. There's no shortcut. We have to 
implement the optimization rule ourselves.
    
    Closes https://github.com/apache/sedona-db/issues/605
---
 rust/sedona-spatial-join/src/planner.rs            |   9 +-
 .../src/planner/logical_plan_node.rs               |  10 +
 rust/sedona-spatial-join/src/planner/optimizer.rs  | 227 +++++++++++++++-----
 .../src/planner/spatial_expr_utils.rs              |  40 ++--
 .../tests/spatial_join_integration.rs              | 233 +++++++++++++++------
 rust/sedona/src/context.rs                         |   2 +-
 6 files changed, 384 insertions(+), 137 deletions(-)

diff --git a/rust/sedona-spatial-join/src/planner.rs 
b/rust/sedona-spatial-join/src/planner.rs
index 11312aec..e40303bf 100644
--- a/rust/sedona-spatial-join/src/planner.rs
+++ b/rust/sedona-spatial-join/src/planner.rs
@@ -21,6 +21,7 @@
 //! can produce `SpatialJoinExec`.
 
 use datafusion::execution::SessionStateBuilder;
+use datafusion_common::Result;
 
 mod logical_plan_node;
 mod optimizer;
@@ -34,10 +35,12 @@ mod spatial_expr_utils;
 /// implementation provided by this crate and ensures joins created by SQL or 
using
 /// a DataFrame API that meet certain conditions (e.g. contain a spatial 
predicate as
 /// a join condition) are executed using the `SpatialJoinExec`.
-pub fn register_planner(state_builder: SessionStateBuilder) -> 
SessionStateBuilder {
+pub fn register_planner(state_builder: SessionStateBuilder) -> 
Result<SessionStateBuilder> {
     // Enable the logical rewrite that turns Filter(CrossJoin) into 
Join(filter=...)
-    let state_builder = 
optimizer::register_spatial_join_logical_optimizer(state_builder);
+    let state_builder = 
optimizer::register_spatial_join_logical_optimizer(state_builder)?;
 
     // Enable planning SpatialJoinExec via an extension node during 
logical->physical planning.
-    physical_planner::register_spatial_join_planner(state_builder)
+    Ok(physical_planner::register_spatial_join_planner(
+        state_builder,
+    ))
 }
diff --git a/rust/sedona-spatial-join/src/planner/logical_plan_node.rs 
b/rust/sedona-spatial-join/src/planner/logical_plan_node.rs
index e93e228d..23f35d09 100644
--- a/rust/sedona-spatial-join/src/planner/logical_plan_node.rs
+++ b/rust/sedona-spatial-join/src/planner/logical_plan_node.rs
@@ -106,6 +106,16 @@ impl UserDefinedLogicalNodeCore for SpatialJoinPlanNode {
         )
     }
 
+    fn necessary_children_exprs(&self, _output_columns: &[usize]) -> 
Option<Vec<Vec<usize>>> {
+        // Request all columns from both children. The default implementation 
returns None, which
+        // should also be fine, but we need to return the columns indices 
explicitly to workaround
+        // a bug in DataFusion's handling of None projection indices in FFI 
table provider.
+        // See https://github.com/apache/datafusion/pull/20393
+        let left_indices: Vec<usize> = 
(0..self.left.schema().fields().len()).collect();
+        let right_indices: Vec<usize> = 
(0..self.right.schema().fields().len()).collect();
+        Some(vec![left_indices, right_indices])
+    }
+
     fn with_exprs_and_inputs(
         &self,
         mut exprs: Vec<Expr>,
diff --git a/rust/sedona-spatial-join/src/planner/optimizer.rs 
b/rust/sedona-spatial-join/src/planner/optimizer.rs
index 1edbc2d3..811c866b 100644
--- a/rust/sedona-spatial-join/src/planner/optimizer.rs
+++ b/rust/sedona-spatial-join/src/planner/optimizer.rs
@@ -18,9 +18,8 @@ use std::sync::Arc;
 
 use crate::planner::logical_plan_node::SpatialJoinPlanNode;
 use crate::planner::spatial_expr_utils::collect_spatial_predicate_names;
-use crate::planner::spatial_expr_utils::is_spatial_predicate;
 use datafusion::execution::session_state::SessionStateBuilder;
-use datafusion::optimizer::{ApplyOrder, OptimizerConfig, OptimizerRule};
+use datafusion::optimizer::{ApplyOrder, Optimizer, OptimizerConfig, 
OptimizerRule};
 use datafusion_common::tree_node::Transformed;
 use datafusion_common::NullEquality;
 use datafusion_common::Result;
@@ -28,21 +27,118 @@ use datafusion_expr::logical_plan::Extension;
 use datafusion_expr::{BinaryExpr, Expr, Operator};
 use datafusion_expr::{Filter, Join, JoinType, LogicalPlan};
 use sedona_common::option::SedonaOptions;
+use sedona_common::{sedona_internal_datafusion_err, sedona_internal_err};
 
-/// Register only the logical spatial join optimizer rule.
+/// Register the logical spatial join optimizer rules.
 ///
-/// This enables building `Join(filter=...)` from patterns like 
`Filter(CrossJoin)`.
-/// It intentionally does not register any physical plan rewrite rules.
+/// This inserts rules at specific positions relative to DataFusion's built-in 
`PushDownFilter`
+/// rule to ensure correct semantics for KNN joins:
+///
+/// - `MergeSpatialFilterIntoJoin` and `KnnJoinEarlyRewrite` are inserted 
*before*
+///   `PushDownFilter` so that KNN joins are converted to 
`SpatialJoinPlanNode` extension nodes
+///   before filter pushdown runs. Extension nodes naturally block filter 
pushdown via
+///   `prevent_predicate_push_down_columns()`, preventing incorrect pushdown 
to the build side
+///   of KNN joins.
+///
+/// - `SpatialJoinLogicalRewrite` is appended at the end so that non-KNN 
spatial joins still
+///   benefit from filter pushdown before being converted to extension nodes.
 pub(crate) fn register_spatial_join_logical_optimizer(
-    session_state_builder: SessionStateBuilder,
-) -> SessionStateBuilder {
-    session_state_builder
-        .with_optimizer_rule(Arc::new(MergeSpatialProjectionIntoJoin))
-        .with_optimizer_rule(Arc::new(SpatialJoinLogicalRewrite))
+    mut session_state_builder: SessionStateBuilder,
+) -> Result<SessionStateBuilder> {
+    let optimizer = session_state_builder
+        .optimizer()
+        .get_or_insert_with(Optimizer::new);
+
+    // Find PushDownFilter position by name
+    let push_down_pos = optimizer
+        .rules
+        .iter()
+        .position(|r| r.name() == "push_down_filter")
+        .ok_or_else(|| {
+            sedona_internal_datafusion_err!(
+                "PushDownFilter rule not found in default optimizer rules"
+            )
+        })?;
+
+    // Insert KNN-specific rules BEFORE PushDownFilter.
+    // MergeSpatialFilterIntoJoin must come first because it creates the 
Join(filter=...)
+    // nodes that KnnJoinEarlyRewrite then converts to SpatialJoinPlanNode.
+    optimizer
+        .rules
+        .insert(push_down_pos, Arc::new(KnnJoinEarlyRewrite));
+    optimizer
+        .rules
+        .insert(push_down_pos, Arc::new(MergeSpatialFilterIntoJoin));
+
+    // Append SpatialJoinLogicalRewrite at the end so non-KNN joins benefit 
from filter pushdown.
+    optimizer.rules.push(Arc::new(SpatialJoinLogicalRewrite));
+
+    Ok(session_state_builder)
 }
-/// Logical optimizer rule that enables spatial join planning.
+
+/// Early optimizer rule that converts KNN joins to `SpatialJoinPlanNode` 
extension nodes
+/// *before* DataFusion's `PushDownFilter` runs.
 ///
-/// This rule turns eligible `Join(filter=...)` nodes into a 
`SpatialJoinPlanNode` extension.
+/// This prevents `PushDownFilter` from pushing filters to the build (object) 
side of KNN joins,
+/// which would change which objects are the K nearest neighbors and produce 
incorrect results.
+///
+/// Extension nodes naturally block filter pushdown because their default
+/// `prevent_predicate_push_down_columns()` returns all columns.
+///
+/// Handles two patterns that DataFusion's SQL planner creates:
+///
+/// 1. `Join(filter=ST_KNN(...))` — when the ON clause has only the spatial 
predicate
+/// 2. `Filter(ST_KNN(...), Join(on=[...]))` — when the ON clause also has 
equi-join conditions,
+///    the SQL planner separates equi-conditions into `on` and the spatial 
predicate into a Filter
+#[derive(Default, Debug)]
+struct KnnJoinEarlyRewrite;
+
+impl OptimizerRule for KnnJoinEarlyRewrite {
+    fn name(&self) -> &str {
+        "knn_join_early_rewrite"
+    }
+
+    fn apply_order(&self) -> Option<ApplyOrder> {
+        Some(ApplyOrder::BottomUp)
+    }
+
+    fn supports_rewrite(&self) -> bool {
+        true
+    }
+
+    fn rewrite(
+        &self,
+        plan: LogicalPlan,
+        config: &dyn OptimizerConfig,
+    ) -> Result<Transformed<LogicalPlan>> {
+        let options = config.options();
+        let Some(ext) = options.extensions.get::<SedonaOptions>() else {
+            return Ok(Transformed::no(plan));
+        };
+        if !ext.spatial_join.enable {
+            return Ok(Transformed::no(plan));
+        }
+
+        // Join(filter=ST_KNN(...))
+        if let LogicalPlan::Join(join) = &plan {
+            if let Some(filter) = join.filter.as_ref() {
+                let names = collect_spatial_predicate_names(filter);
+                if names.contains("st_knn") {
+                    return rewrite_join_to_spatial_join_plan_node(join);
+                }
+            }
+        }
+
+        Ok(Transformed::no(plan))
+    }
+}
+
+/// Logical optimizer rule that converts non-KNN spatial joins to 
`SpatialJoinPlanNode`.
+///
+/// This rule runs *after* `PushDownFilter` so that non-KNN spatial joins 
benefit from
+/// filter pushdown before being converted to extension nodes.
+///
+/// KNN joins are skipped here because they are already handled by 
[KnnJoinEarlyRewrite].
 #[derive(Default, Debug)]
 struct SpatialJoinLogicalRewrite;
 
@@ -86,54 +182,74 @@ impl OptimizerRule for SpatialJoinLogicalRewrite {
             return Ok(Transformed::no(plan));
         }
 
-        // Join with with equi-join condition and spatial join condition. Only 
handle it
-        // when the join condition contains ST_KNN. KNN join is not a regular 
join and
-        // ST_KNN is also not a regular predicate. It must be handled by our 
spatial join exec.
-        if !join.on.is_empty() && !spatial_predicate_names.contains("st_knn") {
-            return Ok(Transformed::no(plan));
+        if spatial_predicate_names.contains("st_knn") {
+            // KNN joins should have already been rewritten by 
KnnJoinEarlyRewrite, so we shouldn't
+            // see them here.
+            return sedona_internal_err!(
+                "Found KNN predicate in SpatialJoinLogicalRewrite, which 
should have been handled by KnnJoinEarlyRewrite");
         }
 
-        // Build new filter expression including equi-join conditions
-        let filter = filter.clone();
-        let eq_op = if join.null_equality == NullEquality::NullEqualsNothing {
-            Operator::Eq
-        } else {
-            Operator::IsNotDistinctFrom
-        };
-        let filter = join.on.iter().fold(filter, |acc, (l, r)| {
-            let eq_expr = Expr::BinaryExpr(BinaryExpr::new(
-                Box::new(l.clone()),
-                eq_op,
-                Box::new(r.clone()),
-            ));
-            Expr::and(acc, eq_expr)
-        });
-
-        let schema = Arc::clone(&join.schema);
-        let node = SpatialJoinPlanNode {
-            left: join.left.as_ref().clone(),
-            right: join.right.as_ref().clone(),
-            join_type: join.join_type,
-            filter,
-            schema,
-            join_constraint: join.join_constraint,
-            null_equality: join.null_equality,
-        };
+        // A join with an equi-join condition should be planned as a regular 
HashJoin
+        // or SortMergeJoin.
+        if !join.on.is_empty() {
+            return Ok(Transformed::no(plan));
+        }
 
-        Ok(Transformed::yes(LogicalPlan::Extension(Extension {
-            node: Arc::new(node),
-        })))
+        rewrite_join_to_spatial_join_plan_node(join)
     }
 }
 
+/// Shared helper: convert a `Join` node (with spatial predicate in `filter`) 
to a
+/// `SpatialJoinPlanNode`, folding any equi-join `on` conditions into the 
filter expression.
+fn rewrite_join_to_spatial_join_plan_node(join: &Join) -> 
Result<Transformed<LogicalPlan>> {
+    let filter = join
+        .filter
+        .as_ref()
+        .ok_or_else(|| {
+            datafusion_common::DataFusionError::Internal(
+                "join filter must be present for spatial join 
rewrite".to_string(),
+            )
+        })?
+        .clone();
+
+    let eq_op = if join.null_equality == NullEquality::NullEqualsNothing {
+        Operator::Eq
+    } else {
+        Operator::IsNotDistinctFrom
+    };
+    let filter = join.on.iter().fold(filter, |acc, (l, r)| {
+        let eq_expr = Expr::BinaryExpr(BinaryExpr::new(
+            Box::new(l.clone()),
+            eq_op,
+            Box::new(r.clone()),
+        ));
+        Expr::and(acc, eq_expr)
+    });
+
+    let schema = Arc::clone(&join.schema);
+    let node = SpatialJoinPlanNode {
+        left: join.left.as_ref().clone(),
+        right: join.right.as_ref().clone(),
+        join_type: join.join_type,
+        filter,
+        schema,
+        join_constraint: join.join_constraint,
+        null_equality: join.null_equality,
+    };
+
+    Ok(Transformed::yes(LogicalPlan::Extension(Extension {
+        node: Arc::new(node),
+    })))
+}
+
 /// Logical optimizer rule that enables spatial join planning.
 ///
 /// This rule turns eligible `Filter(Join(filter=...))` nodes into a 
`Join(filter=...)` node,
 /// so that the spatial join can be rewritten later by 
[SpatialJoinLogicalRewrite].
 #[derive(Debug, Default)]
-struct MergeSpatialProjectionIntoJoin;
+struct MergeSpatialFilterIntoJoin;
 
-impl OptimizerRule for MergeSpatialProjectionIntoJoin {
+impl OptimizerRule for MergeSpatialFilterIntoJoin {
     fn name(&self) -> &str {
         "merge_spatial_filter_into_join"
     }
@@ -188,7 +304,9 @@ impl OptimizerRule for MergeSpatialProjectionIntoJoin {
         else {
             return Ok(Transformed::no(plan));
         };
-        if !is_spatial_predicate(predicate) {
+
+        let spatial_predicates = collect_spatial_predicate_names(predicate);
+        if spatial_predicates.is_empty() {
             return Ok(Transformed::no(plan));
         }
 
@@ -207,20 +325,25 @@ impl OptimizerRule for MergeSpatialProjectionIntoJoin {
         };
 
         // Check if this is a suitable join for rewriting
+        let is_equi_join = !on.is_empty() && 
!spatial_predicates.contains("st_knn");
         if !matches!(
             join_type,
             JoinType::Inner | JoinType::Left | JoinType::Right
-        ) || !on.is_empty()
-            || filter.is_some()
+        ) || is_equi_join
         {
             return Ok(Transformed::no(plan));
         }
 
+        let new_filter = match filter {
+            Some(existing_filter) => Expr::and(predicate.clone(), 
existing_filter.clone()),
+            None => predicate.clone(),
+        };
+
         let rewritten_plan = Join::try_new(
             Arc::clone(left),
             Arc::clone(right),
             on.clone(),
-            Some(predicate.clone()),
+            Some(new_filter),
             JoinType::Inner,
             *join_constraint,
             *null_equality,
diff --git a/rust/sedona-spatial-join/src/planner/spatial_expr_utils.rs 
b/rust/sedona-spatial-join/src/planner/spatial_expr_utils.rs
index 5fd85a50..1d858d45 100644
--- a/rust/sedona-spatial-join/src/planner/spatial_expr_utils.rs
+++ b/rust/sedona-spatial-join/src/planner/spatial_expr_utils.rs
@@ -98,13 +98,6 @@ pub(crate) fn collect_spatial_predicate_names(expr: &Expr) 
-> HashSet<String> {
     acc
 }
 
-/// Check if a given logical expression contains a spatial predicate component 
or not. We assume that the given
-/// `expr` evaluates to a boolean value and originates from a filter logical 
node.
-pub(crate) fn is_spatial_predicate(expr: &Expr) -> bool {
-    let pred_names = collect_spatial_predicate_names(expr);
-    !pred_names.is_empty()
-}
-
 /// Transform the join filter to a spatial predicate and a remainder.
 ///
 ///   * The spatial predicate is a spatial predicate that is extracted from 
the join filter.
@@ -2244,16 +2237,17 @@ mod tests {
     }
 
     #[test]
-    fn test_is_spatial_predicate() {
-        // Test 1: ST_ functions should return true
+    fn test_collect_spatial_predicate_names() {
+        // ST_Intersects should be collected
         let st_intersects_udf = create_dummy_st_intersects_udf();
         let st_intersects_expr = 
Expr::ScalarFunction(datafusion_expr::expr::ScalarFunction {
             func: st_intersects_udf,
             args: vec![col("geom1"), col("geom2")],
         });
-        assert!(is_spatial_predicate(&st_intersects_expr));
+        let names = collect_spatial_predicate_names(&st_intersects_expr);
+        assert_eq!(names, HashSet::from(["st_intersects".to_string()]));
 
-        // ST_Distance(geom1, geom2) < 100 should return true
+        // ST_Distance(geom1, geom2) < 100 should be collected as st_dwithin
         let st_distance_udf = create_dummy_st_distance_udf();
         let st_distance_expr = 
Expr::ScalarFunction(datafusion_expr::expr::ScalarFunction {
             func: st_distance_udf,
@@ -2264,29 +2258,33 @@ mod tests {
             op: Operator::Lt,
             right: Box::new(lit(100.0)),
         });
-        assert!(is_spatial_predicate(&distance_lt_expr));
+        let names = collect_spatial_predicate_names(&distance_lt_expr);
+        assert_eq!(names, HashSet::from(["st_dwithin".to_string()]));
 
-        // ST_Distance(geom1, geom2) > 100 should return false
+        // ST_Distance(geom1, geom2) > 100 should not be collected (wrong 
comparison direction)
         let distance_gt_expr = 
Expr::BinaryExpr(datafusion_expr::expr::BinaryExpr {
             left: Box::new(st_distance_expr.clone()),
             op: Operator::Gt,
             right: Box::new(lit(100.0)),
         });
-        assert!(!is_spatial_predicate(&distance_gt_expr));
+        let names = collect_spatial_predicate_names(&distance_gt_expr);
+        assert!(names.is_empty());
 
-        // AND expressions with spatial predicates should return true
+        // AND expression: spatial predicate should be collected through 
conjunction
         let and_expr = Expr::BinaryExpr(datafusion_expr::expr::BinaryExpr {
             left: Box::new(st_intersects_expr.clone()),
             op: Operator::And,
             right: Box::new(col("id").eq(lit(1))),
         });
-        assert!(is_spatial_predicate(&and_expr));
+        let names = collect_spatial_predicate_names(&and_expr);
+        assert_eq!(names, HashSet::from(["st_intersects".to_string()]));
 
-        // Non-spatial expressions should return false
+        // Non-spatial expressions should return empty set
 
         // Simple column comparison
         let non_spatial_expr = col("id").eq(lit(1));
-        assert!(!is_spatial_predicate(&non_spatial_expr));
+        let names = collect_spatial_predicate_names(&non_spatial_expr);
+        assert!(names.is_empty());
 
         // Not a spatial relationship function
         let non_st_func = 
Expr::ScalarFunction(datafusion_expr::expr::ScalarFunction {
@@ -2299,7 +2297,8 @@ mod tests {
             ))),
             args: vec![col("id")],
         });
-        assert!(!is_spatial_predicate(&non_st_func));
+        let names = collect_spatial_predicate_names(&non_st_func);
+        assert!(names.is_empty());
 
         // AND expression with no spatial predicates
         let non_spatial_and = 
Expr::BinaryExpr(datafusion_expr::expr::BinaryExpr {
@@ -2307,6 +2306,7 @@ mod tests {
             op: Operator::And,
             right: Box::new(col("name").eq(lit("test"))),
         });
-        assert!(!is_spatial_predicate(&non_spatial_and));
+        let names = collect_spatial_predicate_names(&non_spatial_and);
+        assert!(names.is_empty());
     }
 }
diff --git a/rust/sedona-spatial-join/tests/spatial_join_integration.rs 
b/rust/sedona-spatial-join/tests/spatial_join_integration.rs
index 671ce2c8..2add149d 100644
--- a/rust/sedona-spatial-join/tests/spatial_join_integration.rs
+++ b/rust/sedona-spatial-join/tests/spatial_join_integration.rs
@@ -21,12 +21,14 @@ use arrow_array::{Array, RecordBatch};
 use arrow_schema::{DataType, Field, Schema, SchemaRef};
 use datafusion::{
     catalog::{MemTable, TableProvider},
+    datasource::empty::EmptyTable,
     execution::SessionStateBuilder,
     prelude::{SessionConfig, SessionContext},
 };
 use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion};
 use datafusion_common::Result;
 use datafusion_expr::{ColumnarValue, JoinType};
+use datafusion_physical_plan::filter::FilterExec;
 use datafusion_physical_plan::joins::NestedLoopJoinExec;
 use datafusion_physical_plan::ExecutionPlan;
 use geo::{Distance, Euclidean};
@@ -142,7 +144,7 @@ fn setup_context(options: Option<SpatialJoinOptions>, 
batch_size: usize) -> Resu
     session_config = add_sedona_option_extension(session_config);
     let mut state_builder = SessionStateBuilder::new();
     if let Some(options) = options {
-        state_builder = register_planner(state_builder);
+        state_builder = register_planner(state_builder)?;
         let opts = session_config
             .options_mut()
             .extensions
@@ -1088,72 +1090,90 @@ async fn test_knn_join_with_filter_correctness(
     };
 
     let k = 3;
-    let sql = format!(
-        "SELECT L.id AS l_id, R.id AS r_id FROM L JOIN R ON ST_KNN(L.geometry, 
R.geometry, {}, false) AND (L.id % 7) = (R.id % 7)",
-        k
-    );
+    let sqls = [
+        format!(
+            "SELECT L.id AS l_id, R.id AS r_id FROM L JOIN R ON 
ST_KNN(L.geometry, R.geometry, {}, false) AND (L.id % 7) = (R.id % 7)",
+            k
+        ),
+        format!(
+            "SELECT L.id AS l_id, R.id AS r_id FROM L JOIN R ON 
ST_KNN(L.geometry, R.geometry, {}, false) AND L.id % 7 = 0",
+            k
+        ),
+        format!(
+            "SELECT L.id AS l_id, R.id AS r_id FROM L JOIN R ON 
ST_KNN(L.geometry, R.geometry, {}, false) AND R.id % 7 = 0",
+            k
+        ),
+    ];
 
-    let batches = run_spatial_join_query(
-        &left_schema,
-        &right_schema,
-        left_partitions.clone(),
-        right_partitions.clone(),
-        Some(options),
-        max_batch_size,
-        &sql,
-    )
-    .await?;
+    for (idx, sql) in sqls.iter().enumerate() {
+        let batches = run_spatial_join_query(
+            &left_schema,
+            &right_schema,
+            left_partitions.clone(),
+            right_partitions.clone(),
+            Some(options.clone()),
+            max_batch_size,
+            sql,
+        )
+        .await?;
 
-    let mut actual_results = Vec::new();
-    let combined_batch = arrow::compute::concat_batches(&batches.schema(), 
&[batches])?;
-    let l_ids = combined_batch
-        .column(0)
-        .as_any()
-        .downcast_ref::<arrow_array::Int32Array>()
-        .unwrap();
-    let r_ids = combined_batch
-        .column(1)
-        .as_any()
-        .downcast_ref::<arrow_array::Int32Array>()
-        .unwrap();
+        let mut actual_results = Vec::new();
+        let combined_batch = arrow::compute::concat_batches(&batches.schema(), 
&[batches])?;
+        let l_ids = combined_batch
+            .column(0)
+            .as_any()
+            .downcast_ref::<arrow_array::Int32Array>()
+            .unwrap();
+        let r_ids = combined_batch
+            .column(1)
+            .as_any()
+            .downcast_ref::<arrow_array::Int32Array>()
+            .unwrap();
 
-    for i in 0..combined_batch.num_rows() {
-        actual_results.push((l_ids.value(i), r_ids.value(i)));
-    }
-    actual_results.sort_by(|a, b| a.0.cmp(&b.0).then_with(|| a.1.cmp(&b.1)));
+        for i in 0..combined_batch.num_rows() {
+            actual_results.push((l_ids.value(i), r_ids.value(i)));
+        }
+        actual_results.sort_by(|a, b| a.0.cmp(&b.0).then_with(|| 
a.1.cmp(&b.1)));
 
-    // Prove the test actually exercises the "< K rows after filtering" case.
-    // Build a list of all probe-side IDs and count how many results each has.
-    let all_left_ids: Vec<i32> = extract_geoms_and_ids(&left_partitions)
-        .into_iter()
-        .map(|(id, _)| id)
-        .collect();
-    let mut per_left_counts: std::collections::HashMap<i32, usize> =
-        std::collections::HashMap::new();
-    for (l_id, _) in &actual_results {
-        *per_left_counts.entry(*l_id).or_default() += 1;
-    }
-    let min_count = all_left_ids
-        .iter()
-        .map(|l_id| *per_left_counts.get(l_id).unwrap_or(&0))
-        .min()
-        .unwrap_or(0);
-    assert!(
-        min_count < k,
-        "expected at least one probe row to produce < K rows after filtering; 
min_count={min_count}, k={k}"
-    );
+        // Prove the test actually exercises the "< K rows after filtering" 
case.
+        // Build a list of all probe-side IDs and count how many results each 
has.
+        let all_left_ids: Vec<i32> = extract_geoms_and_ids(&left_partitions)
+            .into_iter()
+            .map(|(id, _)| id)
+            .collect();
+        let mut per_left_counts: std::collections::HashMap<i32, usize> =
+            std::collections::HashMap::new();
+        for (l_id, _) in &actual_results {
+            *per_left_counts.entry(*l_id).or_default() += 1;
+        }
+        let min_count = all_left_ids
+            .iter()
+            .map(|l_id| *per_left_counts.get(l_id).unwrap_or(&0))
+            .min()
+            .unwrap_or(0);
+        assert!(
+            min_count < k,
+            "expected at least one probe row to produce < K rows after 
filtering; min_count={min_count}, k={k}"
+        );
 
-    let expected_results = compute_knn_ground_truth_with_pair_filter(
-        &left_partitions,
-        &right_partitions,
-        k,
-        |l_id, r_id| (l_id.rem_euclid(7)) == (r_id.rem_euclid(7)),
-    )
-    .into_iter()
-    .map(|(l, r, _)| (l, r))
-    .collect::<Vec<_>>();
+        let filter_closure = match idx {
+            0 => |l_id: i32, r_id: i32| (l_id.rem_euclid(7)) == 
(r_id.rem_euclid(7)),
+            1 => |l_id: i32, _r_id: i32| l_id.rem_euclid(7) == 0,
+            2 => |_l_id: i32, r_id: i32| r_id.rem_euclid(7) == 0,
+            _ => unreachable!(),
+        };
+        let expected_results = compute_knn_ground_truth_with_pair_filter(
+            &left_partitions,
+            &right_partitions,
+            k,
+            filter_closure,
+        )
+        .into_iter()
+        .map(|(l, r, _)| (l, r))
+        .collect::<Vec<_>>();
 
-    assert_eq!(actual_results, expected_results);
+        assert_eq!(actual_results, expected_results);
+    }
 
     Ok(())
 }
@@ -1368,3 +1388,94 @@ async fn test_knn_join_include_tie_breakers(
 
     Ok(())
 }
+
+/// Verify that a filter on the *object* (build / right) side of a KNN join is 
NOT pushed down
+/// into the build side subtree.
+///
+/// If `PushDownFilter` incorrectly pushes `R.id > 5` below the spatial join, 
the set of objects
+/// considered for the KNN search changes, yielding wrong nearest-neighbor 
results.
+#[tokio::test]
+async fn test_knn_join_object_side_filter_not_pushed_down() -> Result<()> {
+    let sql = "SELECT L.id, R.id \
+               FROM L JOIN R ON ST_KNN(ST_Point(L.x, 0), ST_Point(R.x, 1), 3, 
false) \
+               WHERE R.id > 5";
+    let plan = plan_for_filter_pushdown_test(sql).await?;
+
+    let spatial_joins = collect_spatial_join_exec(&plan)?;
+    assert_eq!(
+        spatial_joins.len(),
+        1,
+        "expected exactly one SpatialJoinExec"
+    );
+    let sj = spatial_joins[0];
+
+    // The build (right / object) side must NOT have a FilterExec pushed into 
it.
+    assert!(
+        !subtree_contains_filter_exec(&sj.right),
+        "FilterExec should NOT be pushed into the object (right/build) side of 
a KNN join"
+    );
+
+    Ok(())
+}
+
+/// Verify that for a *non-KNN* spatial join, a filter on the build side IS 
pushed down
+/// (the normal, desirable behaviour).
+#[tokio::test]
+async fn test_non_knn_join_object_side_filter_is_pushed_down() -> Result<()> {
+    let sql = "SELECT L.id, R.id \
+               FROM L JOIN R ON ST_Intersects(ST_Buffer(ST_Point(L.x, 0), 
1.5), ST_Point(R.x, 1)) \
+               WHERE R.id > 5";
+    let plan = plan_for_filter_pushdown_test(sql).await?;
+
+    let spatial_joins = collect_spatial_join_exec(&plan)?;
+    assert_eq!(
+        spatial_joins.len(),
+        1,
+        "expected exactly one SpatialJoinExec"
+    );
+    let sj = spatial_joins[0];
+
+    // For non-KNN joins, the filter SHOULD be pushed down to the build side.
+    assert!(
+        subtree_contains_filter_exec(&sj.right),
+        "FilterExec should be pushed into the object (right/build) side of a 
non-KNN spatial join"
+    );
+
+    Ok(())
+}
+
+/// Recursively check whether any node in the physical plan tree is a 
`FilterExec`.
+fn subtree_contains_filter_exec(plan: &Arc<dyn ExecutionPlan>) -> bool {
+    let mut found = false;
+    plan.apply(|node| {
+        if node.as_any().downcast_ref::<FilterExec>().is_some() {
+            found = true;
+            return Ok(TreeNodeRecursion::Stop);
+        }
+        Ok(TreeNodeRecursion::Continue)
+    })
+    .expect("failed to walk plan");
+    found
+}
+
+/// Create a session context with two small tables for filter-pushdown tests.
+///
+/// L(id INT, x DOUBLE) and R(id INT, x DOUBLE) are all empty, this is just 
for exercising the
+/// plan optimizer and physical planner.
+/// Geometry is constructed in SQL via ST_Point so no geometry column exists 
on the table itself.
+async fn plan_for_filter_pushdown_test(sql: &str) -> Result<Arc<dyn 
ExecutionPlan>> {
+    let schema = Arc::new(Schema::new(vec![
+        Field::new("id", DataType::Int32, false),
+        Field::new("x", DataType::Float64, false),
+    ]));
+
+    let options = SpatialJoinOptions::default();
+    let ctx = setup_context(Some(options), 100)?;
+    let empty_l: Arc<dyn TableProvider> = 
Arc::new(EmptyTable::new(schema.clone()));
+    let empty_r: Arc<dyn TableProvider> = 
Arc::new(EmptyTable::new(schema.clone()));
+    ctx.register_table("L", empty_l)?;
+    ctx.register_table("R", empty_r)?;
+
+    let df = ctx.sql(sql).await?;
+    df.create_physical_plan().await
+}
diff --git a/rust/sedona/src/context.rs b/rust/sedona/src/context.rs
index 2ea9cf03..fa659f82 100644
--- a/rust/sedona/src/context.rs
+++ b/rust/sedona/src/context.rs
@@ -117,7 +117,7 @@ impl SedonaContext {
         // Register the spatial join planner extension
         #[cfg(feature = "spatial-join")]
         {
-            state_builder = 
sedona_spatial_join::register_planner(state_builder);
+            state_builder = 
sedona_spatial_join::register_planner(state_builder)?;
         }
 
         let mut state = state_builder.build();

Reply via email to