Kontinuation commented on code in PR #611:
URL: https://github.com/apache/sedona-db/pull/611#discussion_r2815280907


##########
rust/sedona-spatial-join/src/planner/optimizer.rs:
##########
@@ -93,47 +184,57 @@ impl OptimizerRule for SpatialJoinLogicalRewrite {
             return Ok(Transformed::no(plan));
         }

Review Comment:
   Good idea. Changed to return an internal error here.



##########
rust/sedona-spatial-join/tests/spatial_join_integration.rs:
##########
@@ -1088,72 +1090,90 @@ async fn test_knn_join_with_filter_correctness(
     };
 
     let k = 3;
-    let sql = format!(
-        "SELECT L.id AS l_id, R.id AS r_id FROM L JOIN R ON ST_KNN(L.geometry, 
R.geometry, {}, false) AND (L.id % 7) = (R.id % 7)",
-        k
-    );
+    let sqls = [
+        format!(
+            "SELECT L.id AS l_id, R.id AS r_id FROM L JOIN R ON 
ST_KNN(L.geometry, R.geometry, {}, false) AND (L.id % 7) = (R.id % 7)",
+            k
+        ),
+        format!(
+            "SELECT L.id AS l_id, R.id AS r_id FROM L JOIN R ON 
ST_KNN(L.geometry, R.geometry, {}, false) AND L.id % 7 = 0",
+            k
+        ),
+        format!(
+            "SELECT L.id AS l_id, R.id AS r_id FROM L JOIN R ON 
ST_KNN(L.geometry, R.geometry, {}, false) AND R.id % 7 = 0",
+            k
+        ),
+    ];
 
-    let batches = run_spatial_join_query(
-        &left_schema,
-        &right_schema,
-        left_partitions.clone(),
-        right_partitions.clone(),
-        Some(options),
-        max_batch_size,
-        &sql,
-    )
-    .await?;
+    for (idx, sql) in sqls.iter().enumerate() {
+        let batches = run_spatial_join_query(
+            &left_schema,
+            &right_schema,
+            left_partitions.clone(),
+            right_partitions.clone(),
+            Some(options.clone()),
+            max_batch_size,
+            sql,
+        )
+        .await?;
 
-    let mut actual_results = Vec::new();
-    let combined_batch = arrow::compute::concat_batches(&batches.schema(), 
&[batches])?;
-    let l_ids = combined_batch
-        .column(0)
-        .as_any()
-        .downcast_ref::<arrow_array::Int32Array>()
-        .unwrap();
-    let r_ids = combined_batch
-        .column(1)
-        .as_any()
-        .downcast_ref::<arrow_array::Int32Array>()
-        .unwrap();
+        let mut actual_results = Vec::new();
+        let combined_batch = arrow::compute::concat_batches(&batches.schema(), 
&[batches])?;
+        let l_ids = combined_batch
+            .column(0)
+            .as_any()
+            .downcast_ref::<arrow_array::Int32Array>()
+            .unwrap();
+        let r_ids = combined_batch
+            .column(1)
+            .as_any()
+            .downcast_ref::<arrow_array::Int32Array>()
+            .unwrap();
 
-    for i in 0..combined_batch.num_rows() {
-        actual_results.push((l_ids.value(i), r_ids.value(i)));
-    }
-    actual_results.sort_by(|a, b| a.0.cmp(&b.0).then_with(|| a.1.cmp(&b.1)));
+        for i in 0..combined_batch.num_rows() {
+            actual_results.push((l_ids.value(i), r_ids.value(i)));
+        }
+        actual_results.sort_by(|a, b| a.0.cmp(&b.0).then_with(|| 
a.1.cmp(&b.1)));
 
-    // Prove the test actually exercises the "< K rows after filtering" case.
-    // Build a list of all probe-side IDs and count how many results each has.
-    let all_left_ids: Vec<i32> = extract_geoms_and_ids(&left_partitions)
-        .into_iter()
-        .map(|(id, _)| id)
-        .collect();
-    let mut per_left_counts: std::collections::HashMap<i32, usize> =
-        std::collections::HashMap::new();
-    for (l_id, _) in &actual_results {
-        *per_left_counts.entry(*l_id).or_default() += 1;
-    }
-    let min_count = all_left_ids
-        .iter()
-        .map(|l_id| *per_left_counts.get(l_id).unwrap_or(&0))
-        .min()
-        .unwrap_or(0);
-    assert!(
-        min_count < k,
-        "expected at least one probe row to produce < K rows after filtering; 
min_count={min_count}, k={k}"
-    );
+        // Prove the test actually exercises the "< K rows after filtering" 
case.
+        // Build a list of all probe-side IDs and count how many results each 
has.
+        let all_left_ids: Vec<i32> = extract_geoms_and_ids(&left_partitions)
+            .into_iter()
+            .map(|(id, _)| id)
+            .collect();
+        let mut per_left_counts: std::collections::HashMap<i32, usize> =
+            std::collections::HashMap::new();
+        for (l_id, _) in &actual_results {
+            *per_left_counts.entry(*l_id).or_default() += 1;
+        }
+        let min_count = all_left_ids
+            .iter()
+            .map(|l_id| *per_left_counts.get(l_id).unwrap_or(&0))
+            .min()
+            .unwrap_or(0);
+        assert!(
+            min_count < k,
+            "expected at least one probe row to produce < K rows after 
filtering; min_count={min_count}, k={k}"
+        );
 
-    let expected_results = compute_knn_ground_truth_with_pair_filter(
-        &left_partitions,
-        &right_partitions,
-        k,
-        |l_id, r_id| (l_id.rem_euclid(7)) == (r_id.rem_euclid(7)),
-    )
-    .into_iter()
-    .map(|(l, r, _)| (l, r))
-    .collect::<Vec<_>>();
+        let filter_closure = match idx {
+            0 => |l_id: i32, r_id: i32| (l_id.rem_euclid(7)) == 
(r_id.rem_euclid(7)),
+            1 => |l_id: i32, _r_id: i32| l_id.rem_euclid(7) == 0,
+            2 => |_l_id: i32, r_id: i32| r_id.rem_euclid(7) == 0,
+            _ => unreachable!(),
+        };

Review Comment:
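   It compiles, so there's no problem: although each match arm is a distinct 
closure type, none of the closures capture anything, so when unifying the arms 
the compiler coerces them to a common function-pointer type 
(`fn(i32, i32) -> bool`), and the `unreachable!()` arm has type `!`, which 
unifies with anything.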



##########
rust/sedona-spatial-join/tests/spatial_join_integration.rs:
##########
@@ -1368,3 +1388,93 @@ async fn test_knn_join_include_tie_breakers(
 
     Ok(())
 }
+
+/// Verify that a filter on the *object* (build / right) side of a KNN join is 
NOT pushed down
+/// into the build side subtree.
+///
+/// If `PushDownFilter` incorrectly pushes `R.id > 5` below the spatial join, 
the set of objects
+/// considered for KNN changes, yielding wrong nearest-neighbor results.
+#[tokio::test]
+async fn test_knn_join_object_side_filter_not_pushed_down() -> Result<()> {
+    let sql = "SELECT L.id, R.id \
+               FROM L JOIN R ON ST_KNN(ST_Point(L.x, 0), ST_Point(R.x, 1), 3, 
false) \
+               WHERE R.id > 5";
+    let plan = plan_for_filter_pushdown_test(sql).await?;
+
+    let spatial_joins = collect_spatial_join_exec(&plan)?;
+    assert_eq!(
+        spatial_joins.len(),
+        1,
+        "expected exactly one SpatialJoinExec"
+    );
+    let sj = spatial_joins[0];
+
+    // The build (right / object) side must NOT have a FilterExec pushed into 
it.
+    assert!(
+        !subtree_contains_filter_exec(&sj.right),
+        "FilterExec should NOT be pushed into the object (right/build) side of 
a KNN join"
+    );
+
+    Ok(())
+}
+
+/// Verify that for a *non-KNN* spatial join, a filter on the build side IS 
pushed down
+/// (the normal, desirable behaviour).
+#[tokio::test]
+async fn test_non_knn_join_object_side_filter_is_pushed_down() -> Result<()> {
+    let sql = "SELECT L.id, R.id \
+               FROM L JOIN R ON ST_Intersects(ST_Buffer(ST_Point(L.x, 0), 
1.5), ST_Point(R.x, 1)) \
+               WHERE R.id > 5";
+    let plan = plan_for_filter_pushdown_test(sql).await?;
+
+    let spatial_joins = collect_spatial_join_exec(&plan)?;
+    assert_eq!(
+        spatial_joins.len(),
+        1,
+        "expected exactly one SpatialJoinExec"
+    );
+    let sj = spatial_joins[0];
+
+    // For non-KNN joins, the filter SHOULD be pushed down to the build side.
+    assert!(
+        subtree_contains_filter_exec(&sj.right),
+        "FilterExec should be pushed into the object (right/build) side of a 
non-KNN spatial join"
+    );
+
+    Ok(())
+}
+
+/// Recursively check whether any node in the physical plan tree is a 
`FilterExec`.
+fn subtree_contains_filter_exec(plan: &Arc<dyn ExecutionPlan>) -> bool {
+    let mut found = false;
+    plan.apply(|node| {
+        if node.as_any().downcast_ref::<FilterExec>().is_some() {
+            found = true;
+            return Ok(TreeNodeRecursion::Stop);
+        }
+        Ok(TreeNodeRecursion::Continue)
+    })
+    .expect("failed to walk plan");
+    found
+}
+
+/// Create a session context with two small tables for filter-pushdown tests.
+///
+/// L(id INT, x DOUBLE) and R(id INT, x DOUBLE) each with 10 rows.
+/// Geometry is constructed in SQL via ST_Point so no geometry column exists 
on the table itself.
+async fn plan_for_filter_pushdown_test(sql: &str) -> Result<Arc<dyn 
ExecutionPlan>> {
+    let schema = Arc::new(Schema::new(vec![
+        Field::new("id", DataType::Int32, false),
+        Field::new("x", DataType::Float64, false),
+    ]));
+
+    let options = SpatialJoinOptions::default();
+    let ctx = setup_context(Some(options), 100)?;
+    let empty_l: Arc<dyn TableProvider> = 
Arc::new(EmptyTable::new(schema.clone()));
+    let empty_r: Arc<dyn TableProvider> = 
Arc::new(EmptyTable::new(schema.clone()));
+    ctx.register_table("L", empty_l)?;
+    ctx.register_table("R", empty_r)?;

Review Comment:
   Fixed the comment



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to