This is an automated email from the ASF dual-hosted git repository.

kontinuation pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/sedona-db.git


The following commit(s) were added to refs/heads/main by this push:
     new 51b2bc95 refactor(rust/sedona-spatial-join): Unify execute and 
execute_knn, move end-to-end tests in exec.rs to tests directory (#593)
51b2bc95 is described below

commit 51b2bc95303a48495eb6eb6d73545dca9582adf9
Author: Kristin Cowalcijk <[email protected]>
AuthorDate: Wed Feb 11 20:38:15 2026 +0800

    refactor(rust/sedona-spatial-join): Unify execute and execute_knn, move 
end-to-end tests in exec.rs to tests directory (#593)
    
    ## Summary
    - relocate spatial join integration tests from `exec.rs` into 
`rust/sedona-spatial-join/tests/`
    - keep test helpers and data builders intact while switching to public 
crate imports
    - preserve existing behavior while enabling integration-style execution
---
 rust/sedona-spatial-join/src/exec.rs               | 1401 +-------------------
 .../tests/spatial_join_integration.rs              | 1291 ++++++++++++++++++
 2 files changed, 1312 insertions(+), 1380 deletions(-)

diff --git a/rust/sedona-spatial-join/src/exec.rs 
b/rust/sedona-spatial-join/src/exec.rs
index 053b2ffb..6cbe8900 100644
--- a/rust/sedona-spatial-join/src/exec.rs
+++ b/rust/sedona-spatial-join/src/exec.rs
@@ -443,101 +443,27 @@ impl ExecutionPlan for SpatialJoinExec {
         &self,
         partition: usize,
         context: Arc<TaskContext>,
-    ) -> Result<SendableRecordBatchStream> {
-        match &self.on {
-            SpatialPredicate::KNearestNeighbors(_) => 
self.execute_knn(partition, context),
-            _ => {
-                // Regular spatial join logic - standard left=build, 
right=probe semantics
-                let session_config = context.session_config();
-
-                // Regular join semantics: left is build, right is probe
-                let (build_plan, probe_plan) = (&self.left, &self.right);
-
-                // A OnceFut for preparing the spatial join components once.
-                let once_fut_spatial_join_components = {
-                    let mut once_async = 
self.once_async_spatial_join_components.lock();
-                    once_async
-                        .get_or_insert(OnceAsync::default())
-                        .try_once(|| {
-                            let build_side = build_plan;
-
-                            let num_partitions = 
build_side.output_partitioning().partition_count();
-                            let mut build_streams = 
Vec::with_capacity(num_partitions);
-                            for k in 0..num_partitions {
-                                let stream = build_side.execute(k, 
Arc::clone(&context))?;
-                                build_streams.push(stream);
-                            }
-
-                            let probe_thread_count =
-                                
self.right.output_partitioning().partition_count();
-                            let spatial_join_components_builder = 
SpatialJoinComponentsBuilder::new(
-                                Arc::clone(&context),
-                                build_side.schema(),
-                                self.on.clone(),
-                                self.join_type,
-                                probe_thread_count,
-                                self.metrics.clone(),
-                                self.seed,
-                            );
-                            
Ok(spatial_join_components_builder.build(build_streams))
-                        })?
-                };
-
-                let column_indices_after_projection = match &self.projection {
-                    Some(projection) => projection
-                        .iter()
-                        .map(|i| self.column_indices[*i].clone())
-                        .collect(),
-                    None => self.column_indices.clone(),
-                };
-
-                let probe_stream = probe_plan.execute(partition, 
Arc::clone(&context))?;
-
-                // For regular joins: probe is right side (index 1)
-                let probe_side_ordered =
-                    self.maintains_input_order()[1] && 
self.right.output_ordering().is_some();
-
-                Ok(Box::pin(SpatialJoinStream::new(
-                    partition,
-                    self.schema(),
-                    &self.on,
-                    self.filter.clone(),
-                    self.join_type,
-                    probe_stream,
-                    column_indices_after_projection,
-                    probe_side_ordered,
-                    session_config,
-                    context.runtime_env(),
-                    &self.metrics,
-                    once_fut_spatial_join_components,
-                    Arc::clone(&self.once_async_spatial_join_components),
-                )))
-            }
-        }
-    }
-}
-
-impl SpatialJoinExec {
-    /// Execute KNN (K-Nearest Neighbors) spatial join with specialized logic 
for asymmetric KNN semantics
-    fn execute_knn(
-        &self,
-        partition: usize,
-        context: Arc<TaskContext>,
     ) -> Result<SendableRecordBatchStream> {
         let session_config = context.session_config();
 
-        // Extract KNN predicate for type safety
-        let knn_pred = match &self.on {
-            SpatialPredicate::KNearestNeighbors(knn_pred) => knn_pred,
-            _ => unreachable!("execute_knn called with non-KNN predicate"),
+        // Determine build/probe plans based on predicate type.
+        // For KNN joins, the probe/build assignment is dynamic based on the 
KNN predicate's
+        // probe_side. For regular spatial joins, left is always build and 
right is always probe.
+        let (build_plan, probe_plan, probe_side) = match &self.on {
+            SpatialPredicate::KNearestNeighbors(knn_pred) => {
+                let (build_plan, probe_plan) = determine_knn_build_probe_plans(
+                    knn_pred,
+                    &self.left,
+                    &self.right,
+                    &self.join_schema,
+                )?;
+                (build_plan, probe_plan, knn_pred.probe_side)
+            }
+            _ => (&self.left, &self.right, JoinSide::Right),
         };
 
-        // Determine which execution plan should be build vs probe using join 
schema analysis
-        let (build_plan, probe_plan) =
-            determine_knn_build_probe_plans(knn_pred, &self.left, &self.right, 
&self.join_schema)?;
-
-        // Determine if probe plan is the left execution plan (for column 
index swapping logic)
-        let actual_probe_plan_is_left = std::ptr::eq(probe_plan.as_ref(), 
self.left.as_ref());
+        // Determine which input index corresponds to the probe side for 
ordering checks
+        let probe_input_index = if probe_side == JoinSide::Left { 0 } else { 1 
};
 
         // A OnceFut for preparing the spatial join components once.
         let once_fut_spatial_join_components = {
@@ -545,19 +471,17 @@ impl SpatialJoinExec {
             once_async
                 .get_or_insert(OnceAsync::default())
                 .try_once(|| {
-                    let build_side = build_plan;
-
-                    let num_partitions = 
build_side.output_partitioning().partition_count();
+                    let num_partitions = 
build_plan.output_partitioning().partition_count();
                     let mut build_streams = Vec::with_capacity(num_partitions);
                     for k in 0..num_partitions {
-                        let stream = build_side.execute(k, 
Arc::clone(&context))?;
+                        let stream = build_plan.execute(k, 
Arc::clone(&context))?;
                         build_streams.push(stream);
                     }
 
                     let probe_thread_count = 
probe_plan.output_partitioning().partition_count();
                     let spatial_join_components_builder = 
SpatialJoinComponentsBuilder::new(
                         Arc::clone(&context),
-                        build_side.schema(),
+                        build_plan.schema(),
                         self.on.clone(),
                         self.join_type,
                         probe_thread_count,
@@ -578,14 +502,8 @@ impl SpatialJoinExec {
 
         let probe_stream = probe_plan.execute(partition, 
Arc::clone(&context))?;
 
-        // Determine if probe side ordering is maintained for KNN
-        let probe_side_ordered = if actual_probe_plan_is_left {
-            // Actual probe is left plan
-            self.maintains_input_order()[0] && 
self.left.output_ordering().is_some()
-        } else {
-            // Actual probe is right plan
-            self.maintains_input_order()[1] && 
self.right.output_ordering().is_some()
-        };
+        let probe_side_ordered = 
self.maintains_input_order()[probe_input_index]
+            && probe_plan.output_ordering().is_some();
 
         Ok(Box::pin(SpatialJoinStream::new(
             partition,
@@ -604,1280 +522,3 @@ impl SpatialJoinExec {
         )))
     }
 }
-
-#[cfg(test)]
-mod tests {
-    use arrow_array::{Array, RecordBatch};
-    use arrow_schema::{DataType, Field, Schema};
-    use datafusion::{
-        catalog::{MemTable, TableProvider},
-        execution::SessionStateBuilder,
-        prelude::{SessionConfig, SessionContext},
-    };
-    use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion};
-    use datafusion_expr::ColumnarValue;
-    use datafusion_physical_plan::joins::NestedLoopJoinExec;
-    use geo::{Distance, Euclidean};
-    use geo_types::{Coord, Rect};
-    use rstest::rstest;
-    use sedona_common::SedonaOptions;
-    use sedona_geo::to_geo::GeoTypesExecutor;
-    use sedona_geometry::types::GeometryTypeId;
-    use sedona_schema::datatypes::{SedonaType, WKB_GEOGRAPHY, WKB_GEOMETRY};
-    use sedona_testing::datagen::RandomPartitionedDataBuilder;
-    use tokio::sync::OnceCell;
-
-    use crate::register_spatial_join_optimizer;
-    use sedona_common::{
-        option::{add_sedona_option_extension, ExecutionMode, 
SpatialJoinOptions},
-        NumSpatialPartitionsConfig, SpatialJoinDebugOptions, SpatialLibrary,
-    };
-
-    use super::*;
-
-    type TestPartitions = (SchemaRef, Vec<Vec<RecordBatch>>);
-
-    /// Creates standard test data with left (Polygon) and right (Point) 
partitions
-    fn create_default_test_data() -> Result<(TestPartitions, TestPartitions)> {
-        create_test_data_with_size_range((1.0, 10.0), WKB_GEOMETRY)
-    }
-
-    /// Creates test data with custom size range
-    fn create_test_data_with_size_range(
-        size_range: (f64, f64),
-        sedona_type: SedonaType,
-    ) -> Result<(TestPartitions, TestPartitions)> {
-        let bounds = Rect::new(Coord { x: 0.0, y: 0.0 }, Coord { x: 100.0, y: 
100.0 });
-
-        let left_data = RandomPartitionedDataBuilder::new()
-            .seed(11584)
-            .num_partitions(2)
-            .batches_per_partition(2)
-            .rows_per_batch(30)
-            .geometry_type(GeometryTypeId::Polygon)
-            .sedona_type(sedona_type.clone())
-            .bounds(bounds)
-            .size_range(size_range)
-            .null_rate(0.1)
-            .build()?;
-
-        let right_data = RandomPartitionedDataBuilder::new()
-            .seed(54843)
-            .num_partitions(4)
-            .batches_per_partition(4)
-            .rows_per_batch(30)
-            .geometry_type(GeometryTypeId::Point)
-            .sedona_type(sedona_type)
-            .bounds(bounds)
-            .size_range(size_range)
-            .null_rate(0.1)
-            .build()?;
-
-        Ok((left_data, right_data))
-    }
-
-    /// Creates test data with empty partitions inserted at beginning and end
-    fn create_test_data_with_empty_partitions() -> Result<(TestPartitions, 
TestPartitions)> {
-        let (mut left_data, mut right_data) = create_default_test_data()?;
-
-        // Add empty partitions
-        left_data.1.insert(0, vec![]);
-        left_data.1.push(vec![]);
-        right_data.1.insert(0, vec![]);
-        right_data.1.push(vec![]);
-
-        Ok((left_data, right_data))
-    }
-
-    /// Creates test data for KNN join (Point-Point)
-    fn create_knn_test_data(
-        size_range: (f64, f64),
-        sedona_type: SedonaType,
-    ) -> Result<(TestPartitions, TestPartitions)> {
-        let bounds = Rect::new(Coord { x: 0.0, y: 0.0 }, Coord { x: 100.0, y: 
100.0 });
-
-        let left_data = RandomPartitionedDataBuilder::new()
-            .seed(1)
-            .num_partitions(2)
-            .batches_per_partition(2)
-            .rows_per_batch(30)
-            .geometry_type(GeometryTypeId::Point)
-            .sedona_type(sedona_type.clone())
-            .bounds(bounds)
-            .size_range(size_range)
-            .null_rate(0.1)
-            .build()?;
-
-        let right_data = RandomPartitionedDataBuilder::new()
-            .seed(2)
-            .num_partitions(4)
-            .batches_per_partition(4)
-            .rows_per_batch(30)
-            .geometry_type(GeometryTypeId::Point)
-            .sedona_type(sedona_type)
-            .bounds(bounds)
-            .size_range(size_range)
-            .null_rate(0.1)
-            .build()?;
-
-        Ok((left_data, right_data))
-    }
-
-    fn setup_context(
-        options: Option<SpatialJoinOptions>,
-        batch_size: usize,
-    ) -> Result<SessionContext> {
-        let mut session_config = SessionConfig::from_env()?
-            .with_information_schema(true)
-            .with_batch_size(batch_size);
-        session_config = add_sedona_option_extension(session_config);
-        let mut state_builder = SessionStateBuilder::new();
-        if let Some(options) = options {
-            state_builder = register_spatial_join_optimizer(state_builder);
-            let opts = session_config
-                .options_mut()
-                .extensions
-                .get_mut::<SedonaOptions>()
-                .unwrap();
-            opts.spatial_join = options;
-        }
-        let state = state_builder.with_config(session_config).build();
-        let ctx = SessionContext::new_with_state(state);
-
-        let mut function_set = 
sedona_functions::register::default_function_set();
-        let scalar_kernels = sedona_geos::register::scalar_kernels();
-
-        function_set.scalar_udfs().for_each(|udf| {
-            ctx.register_udf(udf.clone().into());
-        });
-
-        for (name, kernel) in scalar_kernels.into_iter() {
-            let udf = function_set.add_scalar_udf_impl(name, kernel)?;
-            ctx.register_udf(udf.clone().into());
-        }
-
-        Ok(ctx)
-    }
-
-    #[tokio::test]
-    async fn test_empty_data() -> Result<()> {
-        let schema = Arc::new(Schema::new(vec![
-            Field::new("id", DataType::Int32, false),
-            Field::new("dist", DataType::Float64, false),
-            WKB_GEOMETRY.to_storage_field("geometry", true).unwrap(),
-        ]));
-
-        let test_data_vec = vec![vec![vec![]], vec![vec![], vec![]]];
-
-        let options = SpatialJoinOptions::default();
-        let ctx = setup_context(Some(options.clone()), 10)?;
-        for test_data in test_data_vec {
-            let left_partitions = test_data.clone();
-            let right_partitions = test_data;
-
-            let mem_table_left: Arc<dyn TableProvider> = 
Arc::new(MemTable::try_new(
-                Arc::clone(&schema),
-                left_partitions.clone(),
-            )?);
-            let mem_table_right: Arc<dyn TableProvider> = 
Arc::new(MemTable::try_new(
-                Arc::clone(&schema),
-                right_partitions.clone(),
-            )?);
-
-            ctx.deregister_table("L")?;
-            ctx.deregister_table("R")?;
-            ctx.register_table("L", Arc::clone(&mem_table_left))?;
-            ctx.register_table("R", Arc::clone(&mem_table_right))?;
-
-            let sql = "SELECT L.id l_id, R.id r_id FROM L JOIN R ON 
ST_Intersects(L.geometry, R.geometry) ORDER BY l_id, r_id";
-            let df = ctx.sql(sql).await?;
-            let result_batches = df.collect().await?;
-            for result_batch in result_batches {
-                assert_eq!(result_batch.num_rows(), 0);
-            }
-        }
-
-        Ok(())
-    }
-
-    // Shared test data and expected results - computed only once across all 
parameterized test cases
-    // Using tokio::sync::OnceCell for async lazy initialization to avoid 
recomputing expensive
-    // test data generation and nested loop join results for each test 
parameter combination
-    static TEST_DATA: OnceCell<(TestPartitions, TestPartitions)> = 
OnceCell::const_new();
-    static RANGE_JOIN_EXPECTED_RESULTS: OnceCell<Vec<RecordBatch>> = 
OnceCell::const_new();
-    static DIST_JOIN_EXPECTED_RESULTS: OnceCell<Vec<RecordBatch>> = 
OnceCell::const_new();
-
-    const RANGE_JOIN_SQL1: &str = "SELECT L.id l_id, R.id r_id FROM L JOIN R 
ON ST_Intersects(L.geometry, R.geometry) ORDER BY l_id, r_id";
-    const RANGE_JOIN_SQL2: &str =
-        "SELECT * FROM L JOIN R ON ST_Intersects(L.geometry, R.geometry) ORDER 
BY L.id, R.id";
-    const RANGE_JOIN_SQLS: &[&str] = &[RANGE_JOIN_SQL1, RANGE_JOIN_SQL2];
-
-    const DIST_JOIN_SQL1: &str = "SELECT L.id l_id, R.id r_id FROM L JOIN R ON 
ST_Distance(L.geometry, R.geometry) < 1.0 ORDER BY l_id, r_id";
-    const DIST_JOIN_SQL2: &str = "SELECT L.id l_id, R.id r_id FROM L JOIN R ON 
ST_Distance(L.geometry, R.geometry) < L.dist / 10.0 ORDER BY l_id, r_id";
-    const DIST_JOIN_SQL3: &str = "SELECT L.id l_id, R.id r_id FROM L JOIN R ON 
ST_Distance(L.geometry, R.geometry) < R.dist / 10.0 ORDER BY l_id, r_id";
-    const DIST_JOIN_SQL4: &str = "SELECT L.id l_id, R.id r_id FROM L JOIN R ON 
ST_DWithin(L.geometry, R.geometry, 1.0) ORDER BY l_id, r_id";
-    const DIST_JOIN_SQLS: &[&str] = &[
-        DIST_JOIN_SQL1,
-        DIST_JOIN_SQL2,
-        DIST_JOIN_SQL3,
-        DIST_JOIN_SQL4,
-    ];
-
-    /// Get test data, computing it only once
-    async fn get_default_test_data() -> &'static (TestPartitions, 
TestPartitions) {
-        TEST_DATA
-            .get_or_init(|| async {
-                create_default_test_data().expect("Failed to create test data")
-            })
-            .await
-    }
-
-    /// Get expected results, computing them only once
-    async fn get_expected_range_join_results() -> &'static Vec<RecordBatch> {
-        get_or_init_expected_join_results(&RANGE_JOIN_EXPECTED_RESULTS, 
RANGE_JOIN_SQLS).await
-    }
-
-    async fn get_expected_distance_join_results() -> &'static Vec<RecordBatch> 
{
-        get_or_init_expected_join_results(&DIST_JOIN_EXPECTED_RESULTS, 
DIST_JOIN_SQLS).await
-    }
-
-    async fn get_or_init_expected_join_results<'a>(
-        lazy_init_results: &'a OnceCell<Vec<RecordBatch>>,
-        sql_queries: &[&str],
-    ) -> &'a Vec<RecordBatch> {
-        lazy_init_results
-            .get_or_init(|| async {
-                let test_data = get_default_test_data().await;
-                let ((left_schema, left_partitions), (right_schema, 
right_partitions)) = test_data;
-
-                let batch_size = 10;
-
-                // Run nested loop join to get expected results
-                let mut expected_results = 
Vec::with_capacity(sql_queries.len());
-
-                for (i, sql) in sql_queries.iter().enumerate() {
-                    let result = run_spatial_join_query(
-                        left_schema,
-                        right_schema,
-                        left_partitions.clone(),
-                        right_partitions.clone(),
-                        None,
-                        batch_size,
-                        sql,
-                    )
-                    .await
-                    .unwrap_or_else(|_| panic!("Failed to generate expected 
result {}", i + 1));
-                    expected_results.push(result);
-                }
-
-                expected_results
-            })
-            .await
-    }
-
-    #[rstest]
-    #[tokio::test]
-    async fn test_range_join_with_conf(
-        #[values(10, 30, 1000)] max_batch_size: usize,
-        #[values(
-            ExecutionMode::PrepareNone,
-            ExecutionMode::PrepareBuild,
-            ExecutionMode::PrepareProbe,
-            ExecutionMode::Speculative(20)
-        )]
-        execution_mode: ExecutionMode,
-        #[values(SpatialLibrary::Geo, SpatialLibrary::Geos, 
SpatialLibrary::Tg)]
-        spatial_library: SpatialLibrary,
-    ) -> Result<()> {
-        let test_data = get_default_test_data().await;
-        let expected_results = get_expected_range_join_results().await;
-        let ((left_schema, left_partitions), (right_schema, right_partitions)) 
= test_data;
-
-        let options = SpatialJoinOptions {
-            spatial_library,
-            execution_mode,
-            ..Default::default()
-        };
-        for (idx, sql) in RANGE_JOIN_SQLS.iter().enumerate() {
-            let actual_result = run_spatial_join_query(
-                left_schema,
-                right_schema,
-                left_partitions.clone(),
-                right_partitions.clone(),
-                Some(options.clone()),
-                max_batch_size,
-                sql,
-            )
-            .await?;
-            assert_eq!(&actual_result, &expected_results[idx]);
-        }
-
-        Ok(())
-    }
-
-    #[rstest]
-    #[tokio::test]
-    async fn test_distance_join_with_conf(
-        #[values(30, 1000)] max_batch_size: usize,
-        #[values(SpatialLibrary::Geo, SpatialLibrary::Geos, 
SpatialLibrary::Tg)]
-        spatial_library: SpatialLibrary,
-    ) -> Result<()> {
-        let test_data = get_default_test_data().await;
-        let expected_results = get_expected_distance_join_results().await;
-        let ((left_schema, left_partitions), (right_schema, right_partitions)) 
= test_data;
-
-        let options = SpatialJoinOptions {
-            spatial_library,
-            ..Default::default()
-        };
-        for (idx, sql) in DIST_JOIN_SQLS.iter().enumerate() {
-            let actual_result = run_spatial_join_query(
-                left_schema,
-                right_schema,
-                left_partitions.clone(),
-                right_partitions.clone(),
-                Some(options.clone()),
-                max_batch_size,
-                sql,
-            )
-            .await?;
-            assert_eq!(&actual_result, &expected_results[idx]);
-        }
-
-        Ok(())
-    }
-
-    #[tokio::test]
-    async fn test_spatial_join_with_filter() -> Result<()> {
-        let ((left_schema, left_partitions), (right_schema, right_partitions)) 
=
-            create_test_data_with_size_range((0.1, 10.0), WKB_GEOMETRY)?;
-
-        for max_batch_size in [10, 30, 100] {
-            let options = SpatialJoinOptions::default();
-            test_spatial_join_query(&left_schema, &right_schema, 
left_partitions.clone(), right_partitions.clone(), &options, max_batch_size,
-                "SELECT * FROM L JOIN R ON ST_Intersects(L.geometry, 
R.geometry) AND L.dist < R.dist ORDER BY L.id, R.id").await?;
-            test_spatial_join_query(&left_schema, &right_schema, 
left_partitions.clone(), right_partitions.clone(), &options, max_batch_size,
-                "SELECT L.id l_id, R.id r_id FROM L JOIN R ON 
ST_Intersects(L.geometry, R.geometry) AND L.dist < R.dist ORDER BY l_id, 
r_id").await?;
-            test_spatial_join_query(&left_schema, &right_schema, 
left_partitions.clone(), right_partitions.clone(), &options, max_batch_size,
-                "SELECT L.id l_id, R.id r_id, L.dist l_dist, R.dist r_dist 
FROM L JOIN R ON ST_Intersects(L.geometry, R.geometry) AND L.dist < R.dist 
ORDER BY l_id, r_id").await?;
-        }
-
-        Ok(())
-    }
-
-    #[tokio::test]
-    async fn test_range_join_with_empty_partitions() -> Result<()> {
-        let ((left_schema, left_partitions), (right_schema, right_partitions)) 
=
-            create_test_data_with_empty_partitions()?;
-
-        for max_batch_size in [10, 30, 1000] {
-            let options = SpatialJoinOptions::default();
-            test_spatial_join_query(&left_schema, &right_schema, 
left_partitions.clone(), right_partitions.clone(), &options, max_batch_size,
-                "SELECT L.id l_id, R.id r_id FROM L JOIN R ON 
ST_Intersects(L.geometry, R.geometry) ORDER BY l_id, r_id").await?;
-            test_spatial_join_query(&left_schema, &right_schema, 
left_partitions.clone(), right_partitions.clone(), &options, max_batch_size,
-                "SELECT * FROM L JOIN R ON ST_Intersects(L.geometry, 
R.geometry) ORDER BY L.id, R.id").await?;
-        }
-
-        Ok(())
-    }
-
-    #[tokio::test]
-    async fn test_inner_join() -> Result<()> {
-        let options = SpatialJoinOptions::default();
-        test_with_join_types(JoinType::Inner, options, 30).await?;
-        Ok(())
-    }
-
-    #[rstest]
-    #[tokio::test]
-    async fn test_left_joins(
-        #[values(JoinType::Left, JoinType::LeftSemi, JoinType::LeftAnti)] 
join_type: JoinType,
-    ) -> Result<()> {
-        let options = SpatialJoinOptions::default();
-        test_with_join_types(join_type, options, 30).await?;
-        Ok(())
-    }
-
-    #[rstest]
-    #[tokio::test]
-    async fn test_right_joins(
-        #[values(JoinType::Right, JoinType::RightSemi, JoinType::RightAnti)] 
join_type: JoinType,
-    ) -> Result<()> {
-        let options = SpatialJoinOptions::default();
-        test_with_join_types(join_type, options, 30).await?;
-        Ok(())
-    }
-
-    #[tokio::test]
-    async fn test_full_outer_join() -> Result<()> {
-        let options = SpatialJoinOptions::default();
-        test_with_join_types(JoinType::Full, options, 30).await?;
-        Ok(())
-    }
-
-    #[tokio::test]
-    async fn test_geography_join_is_not_optimized() -> Result<()> {
-        let options = SpatialJoinOptions::default();
-        let ctx = setup_context(Some(options), 10)?;
-
-        // Prepare geography tables
-        let ((left_schema, left_partitions), (right_schema, right_partitions)) 
=
-            create_test_data_with_size_range((0.1, 10.0), WKB_GEOGRAPHY)?;
-        let mem_table_left: Arc<dyn TableProvider> =
-            Arc::new(MemTable::try_new(left_schema, left_partitions)?);
-        let mem_table_right: Arc<dyn TableProvider> =
-            Arc::new(MemTable::try_new(right_schema, right_partitions)?);
-        ctx.register_table("L", mem_table_left)?;
-        ctx.register_table("R", mem_table_right)?;
-
-        // Execute geography join query
-        let df = ctx
-            .sql("SELECT * FROM L JOIN R ON ST_Intersects(L.geometry, 
R.geometry)")
-            .await?;
-        let plan = df.create_physical_plan().await?;
-
-        // Verify that no SpatialJoinExec is present (geography join should 
not be optimized)
-        let spatial_joins = collect_spatial_join_exec(&plan)?;
-        assert!(
-            spatial_joins.is_empty(),
-            "Geography joins should not be optimized to SpatialJoinExec"
-        );
-
-        Ok(())
-    }
-
-    #[tokio::test]
-    async fn test_query_window_in_subquery() -> Result<()> {
-        let ((left_schema, left_partitions), (right_schema, right_partitions)) 
=
-            create_test_data_with_size_range((50.0, 60.0), WKB_GEOMETRY)?;
-        let options = SpatialJoinOptions::default();
-        test_spatial_join_query(&left_schema, &right_schema, 
left_partitions.clone(), right_partitions.clone(), &options, 10,
-                "SELECT id FROM L WHERE ST_Intersects(L.geometry, (SELECT 
R.geometry FROM R WHERE R.id = 1))").await?;
-        Ok(())
-    }
-
-    #[tokio::test]
-    async fn test_parallel_refinement_for_large_candidate_set() -> Result<()> {
-        let ((left_schema, left_partitions), (right_schema, right_partitions)) 
=
-            create_test_data_with_size_range((1.0, 50.0), WKB_GEOMETRY)?;
-
-        for max_batch_size in [10, 30, 100] {
-            let options = SpatialJoinOptions {
-                parallel_refinement_chunk_size: 10,
-                ..Default::default()
-            };
-            test_spatial_join_query(&left_schema, &right_schema, 
left_partitions.clone(), right_partitions.clone(), &options, max_batch_size,
-                "SELECT * FROM L JOIN R ON ST_Intersects(L.geometry, 
R.geometry) AND L.dist < R.dist ORDER BY L.id, R.id").await?;
-        }
-
-        Ok(())
-    }
-
-    #[rstest]
-    #[tokio::test]
-    async fn test_spatial_partitioned_range_join(
-        #[values(10, 30, 1000)] max_batch_size: usize,
-        #[values(
-            ExecutionMode::PrepareNone,
-            ExecutionMode::PrepareBuild,
-            ExecutionMode::PrepareProbe,
-            ExecutionMode::Speculative(20)
-        )]
-        execution_mode: ExecutionMode,
-        #[values(SpatialLibrary::Geo, SpatialLibrary::Geos, 
SpatialLibrary::Tg)]
-        spatial_library: SpatialLibrary,
-    ) -> Result<()> {
-        let test_data = get_default_test_data().await;
-        let expected_results = get_expected_range_join_results().await;
-        let ((left_schema, left_partitions), (right_schema, right_partitions)) 
= test_data;
-
-        let debug = SpatialJoinDebugOptions {
-            num_spatial_partitions: NumSpatialPartitionsConfig::Fixed(4),
-            force_spill: true,
-            memory_for_intermittent_usage: None,
-            ..Default::default()
-        };
-        let options = SpatialJoinOptions {
-            spatial_library,
-            execution_mode,
-            debug,
-            ..Default::default()
-        };
-
-        for (idx, sql) in RANGE_JOIN_SQLS.iter().enumerate() {
-            let actual_result = run_spatial_join_query(
-                left_schema,
-                right_schema,
-                left_partitions.clone(),
-                right_partitions.clone(),
-                Some(options.clone()),
-                max_batch_size,
-                sql,
-            )
-            .await?;
-            assert_eq!(&actual_result, &expected_results[idx]);
-        }
-
-        Ok(())
-    }
-
-    #[rstest]
-    #[tokio::test]
-    async fn test_spatial_partitioned_outer_join(
-        #[values(10, 30, 1000)] batch_size: usize,
-        #[values(
-            JoinType::Left,
-            JoinType::Right,
-            JoinType::Full,
-            JoinType::LeftSemi,
-            JoinType::LeftAnti,
-            JoinType::RightSemi,
-            JoinType::RightAnti
-        )]
-        join_type: JoinType,
-    ) -> Result<()> {
-        let debug = SpatialJoinDebugOptions {
-            num_spatial_partitions: NumSpatialPartitionsConfig::Fixed(4),
-            force_spill: true,
-            memory_for_intermittent_usage: None,
-            ..Default::default()
-        };
-        let options = SpatialJoinOptions {
-            debug,
-            ..Default::default()
-        };
-
-        test_with_join_types(join_type, options, batch_size).await?;
-        Ok(())
-    }
-
-    #[rstest]
-    #[tokio::test]
-    async fn test_mark_joins(
-        #[values(JoinType::LeftMark, JoinType::RightMark)] join_type: JoinType,
-    ) -> Result<()> {
-        let options = SpatialJoinOptions::default();
-        test_mark_join(join_type, options, 10).await?;
-        Ok(())
-    }
-
-    #[rstest]
-    #[tokio::test]
-    async fn test_spatial_partitioned_mark_joins(
-        #[values(JoinType::LeftMark, JoinType::RightMark)] join_type: JoinType,
-    ) -> Result<()> {
-        let debug = SpatialJoinDebugOptions {
-            num_spatial_partitions: NumSpatialPartitionsConfig::Fixed(4),
-            force_spill: true,
-            memory_for_intermittent_usage: None,
-            ..Default::default()
-        };
-        let options = SpatialJoinOptions {
-            debug,
-            ..Default::default()
-        };
-        test_mark_join(join_type, options, 10).await?;
-        Ok(())
-    }
-
-    async fn test_with_join_types(
-        join_type: JoinType,
-        options: SpatialJoinOptions,
-        batch_size: usize,
-    ) -> Result<RecordBatch> {
-        let ((left_schema, left_partitions), (right_schema, right_partitions)) 
=
-            create_test_data_with_empty_partitions()?;
-
-        let inner_sql = "SELECT L.id l_id, R.id r_id FROM L INNER JOIN R ON 
ST_Intersects(L.geometry, R.geometry) ORDER BY l_id, r_id";
-        let sql = match join_type {
-            JoinType::Inner => inner_sql,
-            JoinType::Left => "SELECT L.id l_id, R.id r_id FROM L LEFT JOIN R 
ON ST_Intersects(L.geometry, R.geometry) ORDER BY l_id, r_id",
-            JoinType::Right => "SELECT L.id l_id, R.id r_id FROM L RIGHT JOIN 
R ON ST_Intersects(L.geometry, R.geometry) ORDER BY l_id, r_id",
-            JoinType::Full => "SELECT L.id l_id, R.id r_id FROM L FULL OUTER 
JOIN R ON ST_Intersects(L.geometry, R.geometry) ORDER BY l_id, r_id",
-            JoinType::LeftSemi => "SELECT L.id l_id FROM L LEFT SEMI JOIN R ON 
ST_Intersects(L.geometry, R.geometry) ORDER BY l_id",
-            JoinType::RightSemi => "SELECT R.id r_id FROM L RIGHT SEMI JOIN R 
ON ST_Intersects(L.geometry, R.geometry) ORDER BY r_id",
-            JoinType::LeftAnti => "SELECT L.id l_id FROM L LEFT ANTI JOIN R ON 
ST_Intersects(L.geometry, R.geometry) ORDER BY l_id",
-            JoinType::RightAnti => "SELECT R.id r_id FROM L RIGHT ANTI JOIN R 
ON ST_Intersects(L.geometry, R.geometry) ORDER BY r_id",
-            JoinType::LeftMark => {
-                unreachable!("LeftMark is not directly supported in SQL, will 
be tested in other tests");
-            }
-            JoinType::RightMark => {
-                unreachable!("RightMark is not directly supported in SQL, will 
be tested in other tests");
-            }
-        };
-
-        let batches = test_spatial_join_query(
-            &left_schema,
-            &right_schema,
-            left_partitions.clone(),
-            right_partitions.clone(),
-            &options,
-            batch_size,
-            sql,
-        )
-        .await?;
-
-        if matches!(join_type, JoinType::Left | JoinType::Right | 
JoinType::Full) {
-            // Make sure that we are effectively testing outer joins. If outer 
joins produces the same result as inner join,
-            // it means that the test data is not suitable for testing outer 
joins.
-            let inner_batches = run_spatial_join_query(
-                &left_schema,
-                &right_schema,
-                left_partitions,
-                right_partitions,
-                Some(options),
-                batch_size,
-                inner_sql,
-            )
-            .await?;
-            assert!(inner_batches.num_rows() < batches.num_rows());
-        }
-
-        Ok(batches)
-    }
-
-    async fn test_spatial_join_query(
-        left_schema: &SchemaRef,
-        right_schema: &SchemaRef,
-        left_partitions: Vec<Vec<RecordBatch>>,
-        right_partitions: Vec<Vec<RecordBatch>>,
-        options: &SpatialJoinOptions,
-        batch_size: usize,
-        sql: &str,
-    ) -> Result<RecordBatch> {
-        // Run spatial join using SpatialJoinExec
-        let actual = run_spatial_join_query(
-            left_schema,
-            right_schema,
-            left_partitions.clone(),
-            right_partitions.clone(),
-            Some(options.clone()),
-            batch_size,
-            sql,
-        )
-        .await?;
-
-        // Run spatial join using NestedLoopJoinExec
-        let expected = run_spatial_join_query(
-            left_schema,
-            right_schema,
-            left_partitions.clone(),
-            right_partitions.clone(),
-            None,
-            batch_size,
-            sql,
-        )
-        .await?;
-
-        // Should produce the same result
-        assert!(expected.num_rows() > 0);
-        assert_eq!(expected, actual);
-
-        Ok(actual)
-    }
-
-    async fn run_spatial_join_query(
-        left_schema: &SchemaRef,
-        right_schema: &SchemaRef,
-        left_partitions: Vec<Vec<RecordBatch>>,
-        right_partitions: Vec<Vec<RecordBatch>>,
-        options: Option<SpatialJoinOptions>,
-        batch_size: usize,
-        sql: &str,
-    ) -> Result<RecordBatch> {
-        let mem_table_left: Arc<dyn TableProvider> =
-            Arc::new(MemTable::try_new(left_schema.to_owned(), 
left_partitions)?);
-        let mem_table_right: Arc<dyn TableProvider> = 
Arc::new(MemTable::try_new(
-            right_schema.to_owned(),
-            right_partitions,
-        )?);
-
-        let is_optimized_spatial_join = options.is_some();
-        let ctx = setup_context(options, batch_size)?;
-        ctx.register_table("L", Arc::clone(&mem_table_left))?;
-        ctx.register_table("R", Arc::clone(&mem_table_right))?;
-        let df = ctx.sql(sql).await?;
-        let actual_schema = df.schema().as_arrow().clone();
-        let plan = df.clone().create_physical_plan().await?;
-        let spatial_join_execs = collect_spatial_join_exec(&plan)?;
-        if is_optimized_spatial_join {
-            assert_eq!(spatial_join_execs.len(), 1);
-        } else {
-            assert!(spatial_join_execs.is_empty());
-        }
-        let result_batches = df.collect().await?;
-        let result_batch =
-            arrow::compute::concat_batches(&Arc::new(actual_schema), 
&result_batches)?;
-        Ok(result_batch)
-    }
-
-    fn collect_spatial_join_exec(plan: &Arc<dyn ExecutionPlan>) -> 
Result<Vec<&SpatialJoinExec>> {
-        let mut spatial_join_execs = Vec::new();
-        plan.apply(|node| {
-            if let Some(spatial_join_exec) = 
node.as_any().downcast_ref::<SpatialJoinExec>() {
-                spatial_join_execs.push(spatial_join_exec);
-            }
-            Ok(TreeNodeRecursion::Continue)
-        })?;
-        Ok(spatial_join_execs)
-    }
-
-    async fn test_mark_join(
-        join_type: JoinType,
-        options: SpatialJoinOptions,
-        batch_size: usize,
-    ) -> Result<()> {
-        let ((left_schema, left_partitions), (right_schema, right_partitions)) 
=
-            create_test_data_with_size_range((0.1, 10.0), WKB_GEOMETRY)?;
-        let mem_table_left: Arc<dyn TableProvider> = 
Arc::new(MemTable::try_new(
-            left_schema.clone(),
-            left_partitions.clone(),
-        )?);
-        let mem_table_right: Arc<dyn TableProvider> = 
Arc::new(MemTable::try_new(
-            right_schema.clone(),
-            right_partitions.clone(),
-        )?);
-
-        // We use a Left Join as a template to create the plan, then modify it 
to Mark Join
-        let sql = "SELECT * FROM L LEFT JOIN R ON ST_Intersects(L.geometry, 
R.geometry)";
-
-        // Create SpatialJoinExec plan
-        let ctx = setup_context(Some(options.clone()), batch_size)?;
-        ctx.register_table("L", mem_table_left.clone())?;
-        ctx.register_table("R", mem_table_right.clone())?;
-        let df = ctx.sql(sql).await?;
-        let plan = df.create_physical_plan().await?;
-        let spatial_join_execs = collect_spatial_join_exec(&plan)?;
-        assert_eq!(spatial_join_execs.len(), 1);
-        let original_exec = spatial_join_execs[0];
-        let mark_exec = SpatialJoinExec::try_new(
-            original_exec.left.clone(),
-            original_exec.right.clone(),
-            original_exec.on.clone(),
-            original_exec.filter.clone(),
-            &join_type,
-            None,
-            &options,
-        )?;
-
-        // Create NestedLoopJoinExec plan for comparison
-        let ctx_no_opt = setup_context(None, batch_size)?;
-        ctx_no_opt.register_table("L", mem_table_left)?;
-        ctx_no_opt.register_table("R", mem_table_right)?;
-        let df_no_opt = ctx_no_opt.sql(sql).await?;
-        let plan_no_opt = df_no_opt.create_physical_plan().await?;
-        fn collect_nlj_exec(plan: &Arc<dyn ExecutionPlan>) -> 
Result<Vec<&NestedLoopJoinExec>> {
-            let mut execs = Vec::new();
-            plan.apply(|node| {
-                if let Some(exec) = 
node.as_any().downcast_ref::<NestedLoopJoinExec>() {
-                    execs.push(exec);
-                }
-                Ok(TreeNodeRecursion::Continue)
-            })?;
-            Ok(execs)
-        }
-        let nlj_execs = collect_nlj_exec(&plan_no_opt)?;
-        assert_eq!(nlj_execs.len(), 1);
-        let original_nlj = nlj_execs[0];
-        let mark_nlj = NestedLoopJoinExec::try_new(
-            original_nlj.children()[0].clone(),
-            original_nlj.children()[1].clone(),
-            original_nlj.filter().cloned(),
-            &join_type,
-            None,
-        )?;
-
-        async fn run_and_sort(
-            plan: Arc<dyn ExecutionPlan>,
-            ctx: &SessionContext,
-        ) -> Result<RecordBatch> {
-            let results = datafusion_physical_plan::collect(plan, 
ctx.task_ctx()).await?;
-            let batch = arrow::compute::concat_batches(&results[0].schema(), 
&results)?;
-            let sort_col = batch.column(0);
-            let indices = arrow::compute::sort_to_indices(sort_col, None, 
None)?;
-            let sorted_batch = arrow::compute::take_record_batch(&batch, 
&indices)?;
-            Ok(sorted_batch)
-        }
-
-        // Run both Mark Join plans and compare results
-        let mark_batch = run_and_sort(Arc::new(mark_exec), &ctx).await?;
-        let mark_nlj_batch = run_and_sort(Arc::new(mark_nlj), 
&ctx_no_opt).await?;
-        assert_eq!(mark_batch, mark_nlj_batch);
-
-        Ok(())
-    }
-
-    fn extract_geoms_and_ids(partitions: &[Vec<RecordBatch>]) -> Vec<(i32, 
geo::Geometry<f64>)> {
-        let mut result = Vec::new();
-        for partition in partitions {
-            for batch in partition {
-                let id_idx = batch.schema().index_of("id").expect("Id column 
not found");
-                let ids = batch
-                    .column(id_idx)
-                    .as_any()
-                    .downcast_ref::<arrow_array::Int32Array>()
-                    .expect("Column 'id' should be Int32");
-
-                let geom_idx = batch
-                    .schema()
-                    .index_of("geometry")
-                    .expect("Geometry column not found");
-
-                let geoms_col = batch.column(geom_idx);
-                let geom_type = 
SedonaType::from_storage_field(batch.schema().field(geom_idx))
-                    .expect("Failed to get SedonaType from geometry field");
-                let arg_types = [geom_type];
-                let arg_values = [ColumnarValue::Array(Arc::clone(geoms_col))];
-
-                let executor = GeoTypesExecutor::new(&arg_types, &arg_values);
-                let mut id_iter = ids.iter();
-                executor
-                    .execute_wkb_void(|maybe_geom| {
-                        if let Some(id_opt) = id_iter.next() {
-                            if let (Some(id), Some(geom)) = (id_opt, 
maybe_geom) {
-                                result.push((id, geom))
-                            }
-                        }
-                        Ok(())
-                    })
-                    .expect("Failed to extract geoms and ids from 
RecordBatch");
-            }
-        }
-        result
-    }
-
-    fn compute_knn_ground_truth_with_pair_filter<F>(
-        left_partitions: &[Vec<RecordBatch>],
-        right_partitions: &[Vec<RecordBatch>],
-        k: usize,
-        keep_pair: F,
-    ) -> Vec<(i32, i32, f64)>
-    where
-        F: Fn(i32, i32) -> bool,
-    {
-        // NOTE: This helper mirrors our KNN semantics used in execution:
-        // - select top-K unfiltered candidates by distance (stable by r_id)
-        // - then apply a cross-side predicate to decide which pairs to keep
-        //   (can yield < K results per probe row)
-        //
-        // The predicate is intentionally *post* top-K selection.
-        // (See `test_knn_join_with_filter_correctness`.)
-        let left_data = extract_geoms_and_ids(left_partitions);
-        let right_data = extract_geoms_and_ids(right_partitions);
-
-        let mut results = Vec::new();
-
-        for (l_id, l_geom) in left_data {
-            let mut distances: Vec<(i32, f64)> = right_data
-                .iter()
-                .map(|(r_id, r_geom)| (*r_id, Euclidean.distance(&l_geom, 
r_geom)))
-                .collect();
-
-            // Sort by distance, then by ID for stability
-            distances.sort_by(|a, b| a.1.total_cmp(&b.1).then_with(|| 
a.0.cmp(&b.0)));
-
-            // KNN semantics: pick top-K unfiltered, then optionally 
post-filter.
-            for (r_id, dist) in distances.iter().take(k.min(distances.len())) {
-                if keep_pair(l_id, *r_id) {
-                    results.push((l_id, *r_id, *dist));
-                }
-            }
-        }
-
-        // Sort results by L.id, R.id
-        results.sort_by(|a, b| a.0.cmp(&b.0).then_with(|| a.1.cmp(&b.1)));
-        results
-    }
-
-    #[rstest]
-    #[tokio::test]
-    async fn test_knn_join_correctness(
-        #[values(true, false)] point_only: bool,
-        #[values(1, 2, 3, 4)] num_partitions: usize,
-        #[values(10, 30, 1000)] max_batch_size: usize,
-    ) -> Result<()> {
-        // Generate slightly larger data
-        let ((left_schema, left_partitions), (right_schema, right_partitions)) 
= if point_only {
-            create_knn_test_data((0.1, 10.0), WKB_GEOMETRY)?
-        } else {
-            create_default_test_data()?
-        };
-
-        // Use single partition to verify algorithm correctness first, 
avoiding partitioning issues
-        let options = SpatialJoinOptions {
-            debug: SpatialJoinDebugOptions {
-                num_spatial_partitions: 
NumSpatialPartitionsConfig::Fixed(num_partitions),
-                ..Default::default()
-            },
-            ..Default::default()
-        };
-        let k = 6;
-
-        let sql1 = format!(
-            "SELECT L.id, R.id, ST_Distance(L.geometry, R.geometry) FROM L 
JOIN R ON ST_KNN(L.geometry, R.geometry, {}, false) ORDER BY L.id, R.id",
-            k
-        );
-        let expected1 = compute_knn_ground_truth_with_pair_filter(
-            &left_partitions,
-            &right_partitions,
-            k,
-            |_l_id, _r_id| true,
-        )
-        .into_iter()
-        .map(|(l, r, _)| (l, r))
-        .collect::<Vec<_>>();
-        let sql2 = format!(
-            "SELECT R.id, L.id, ST_Distance(L.geometry, R.geometry) FROM L 
JOIN R ON ST_KNN(R.geometry, L.geometry, {}, false) ORDER BY R.id, L.id",
-            k
-        );
-        let expected2 = compute_knn_ground_truth_with_pair_filter(
-            &right_partitions,
-            &left_partitions,
-            k,
-            |_l_id, _r_id| true,
-        )
-        .into_iter()
-        .map(|(l, r, _)| (l, r))
-        .collect::<Vec<_>>();
-
-        let sqls = [(&sql1, &expected1), (&sql2, &expected2)];
-
-        for (sql, expected_results) in sqls {
-            let batches = run_spatial_join_query(
-                &left_schema,
-                &right_schema,
-                left_partitions.clone(),
-                right_partitions.clone(),
-                Some(options.clone()),
-                max_batch_size,
-                sql,
-            )
-            .await?;
-
-            // Collect actual results
-            let mut actual_results = Vec::new();
-            let combined_batch = 
arrow::compute::concat_batches(&batches.schema(), &[batches])?;
-            let l_ids = combined_batch
-                .column(0)
-                .as_any()
-                .downcast_ref::<arrow_array::Int32Array>()
-                .unwrap();
-            let r_ids = combined_batch
-                .column(1)
-                .as_any()
-                .downcast_ref::<arrow_array::Int32Array>()
-                .unwrap();
-
-            for i in 0..combined_batch.num_rows() {
-                actual_results.push((l_ids.value(i), r_ids.value(i)));
-            }
-            actual_results.sort_by(|a, b| a.0.cmp(&b.0).then_with(|| 
a.1.cmp(&b.1)));
-
-            assert_eq!(actual_results, *expected_results);
-        }
-
-        Ok(())
-    }
-
-    #[rstest]
-    #[tokio::test]
-    async fn test_knn_join_with_filter_correctness(
-        #[values(1, 2, 3, 4)] num_partitions: usize,
-        #[values(10, 30, 1000)] max_batch_size: usize,
-    ) -> Result<()> {
-        let ((left_schema, left_partitions), (right_schema, right_partitions)) 
=
-            create_knn_test_data((0.1, 10.0), WKB_GEOMETRY)?;
-
-        let options = SpatialJoinOptions {
-            debug: SpatialJoinDebugOptions {
-                num_spatial_partitions: 
NumSpatialPartitionsConfig::Fixed(num_partitions),
-                ..Default::default()
-            },
-            ..Default::default()
-        };
-
-        let k = 3;
-        let sql = format!(
-            "SELECT L.id AS l_id, R.id AS r_id FROM L JOIN R ON 
ST_KNN(L.geometry, R.geometry, {}, false) AND (L.id % 7) = (R.id % 7)",
-            k
-        );
-
-        let batches = run_spatial_join_query(
-            &left_schema,
-            &right_schema,
-            left_partitions.clone(),
-            right_partitions.clone(),
-            Some(options),
-            max_batch_size,
-            &sql,
-        )
-        .await?;
-
-        let mut actual_results = Vec::new();
-        let combined_batch = arrow::compute::concat_batches(&batches.schema(), 
&[batches])?;
-        let l_ids = combined_batch
-            .column(0)
-            .as_any()
-            .downcast_ref::<arrow_array::Int32Array>()
-            .unwrap();
-        let r_ids = combined_batch
-            .column(1)
-            .as_any()
-            .downcast_ref::<arrow_array::Int32Array>()
-            .unwrap();
-
-        for i in 0..combined_batch.num_rows() {
-            actual_results.push((l_ids.value(i), r_ids.value(i)));
-        }
-        actual_results.sort_by(|a, b| a.0.cmp(&b.0).then_with(|| 
a.1.cmp(&b.1)));
-
-        // Prove the test actually exercises the "< K rows after filtering" 
case.
-        // Build a list of all probe-side IDs and count how many results each 
has.
-        let all_left_ids: Vec<i32> = extract_geoms_and_ids(&left_partitions)
-            .into_iter()
-            .map(|(id, _)| id)
-            .collect();
-        let mut per_left_counts: std::collections::HashMap<i32, usize> =
-            std::collections::HashMap::new();
-        for (l_id, _) in &actual_results {
-            *per_left_counts.entry(*l_id).or_default() += 1;
-        }
-        let min_count = all_left_ids
-            .iter()
-            .map(|l_id| *per_left_counts.get(l_id).unwrap_or(&0))
-            .min()
-            .unwrap_or(0);
-        assert!(
-            min_count < k,
-            "expected at least one probe row to produce < K rows after 
filtering; min_count={min_count}, k={k}"
-        );
-
-        let expected_results = compute_knn_ground_truth_with_pair_filter(
-            &left_partitions,
-            &right_partitions,
-            k,
-            |l_id, r_id| (l_id.rem_euclid(7)) == (r_id.rem_euclid(7)),
-        )
-        .into_iter()
-        .map(|(l, r, _)| (l, r))
-        .collect::<Vec<_>>();
-
-        assert_eq!(actual_results, expected_results);
-
-        Ok(())
-    }
-
-    #[rstest]
-    #[tokio::test]
-    async fn test_knn_join_include_tie_breakers(
-        #[values(1, 2, 3, 4)] num_partitions: usize,
-        #[values(10, 100)] max_batch_size: usize,
-    ) -> Result<()> {
-        // Construct a larger dataset with *guaranteed* exact ties at the kth 
distance.
-        //
-        // For each probe point at (10*i, 0), we create two candidate points 
at (10*i-1, 0)
-        // and (10*i+1, 0). Those two candidates are tied (distance = 1).
-        // A third candidate at (10*i+2, 0) ensures there are also non-tied 
options.
-        // Spacing by 10 keeps other probes' candidates far enough away that 
they never interfere.
-        //
-        // With k=1:
-        // - knn_include_tie_breakers=false should return exactly 1 match per 
probe row.
-        // - knn_include_tie_breakers=true should return 2 matches per probe 
row (both ties).
-        //
-        // The exact choice of which tied row is returned when tie-breakers 
are disabled is not
-        // asserted (it is allowed to be either tied candidate).
-
-        let schema = Arc::new(Schema::new(vec![
-            Field::new("id", DataType::Int32, false),
-            Field::new("wkt", DataType::Utf8, false),
-        ]));
-
-        let num_probe_rows: i32 = 120;
-        let k = 1;
-
-        let input_batches_left = 6;
-        let input_batches_right = 6;
-
-        fn make_batches(
-            schema: SchemaRef,
-            ids: Vec<i32>,
-            wkts: Vec<String>,
-            num_batches: usize,
-        ) -> Result<Vec<RecordBatch>> {
-            assert_eq!(ids.len(), wkts.len());
-            let total = ids.len();
-            let chunk = total.div_ceil(num_batches);
-
-            let mut batches = Vec::new();
-            for b in 0..num_batches {
-                let start = b * chunk;
-                if start >= total {
-                    break;
-                }
-                let end = ((b + 1) * chunk).min(total);
-                let batch = RecordBatch::try_new(
-                    schema.clone(),
-                    vec![
-                        
Arc::new(arrow_array::Int32Array::from(ids[start..end].to_vec())),
-                        Arc::new(arrow_array::StringArray::from(
-                            wkts[start..end]
-                                .iter()
-                                .map(|s| s.as_str())
-                                .collect::<Vec<_>>(),
-                        )),
-                    ],
-                )?;
-                batches.push(batch);
-            }
-            Ok(batches)
-        }
-
-        let mut left_ids = Vec::with_capacity(num_probe_rows as usize);
-        let mut left_wkts = Vec::with_capacity(num_probe_rows as usize);
-
-        let mut right_ids = Vec::with_capacity((num_probe_rows as usize) * 3);
-        let mut right_wkts = Vec::with_capacity((num_probe_rows as usize) * 3);
-
-        for i in 0..num_probe_rows {
-            let cx = (i as i64) * 10;
-            left_ids.push(i);
-            left_wkts.push(format!("POINT ({cx} 0)"));
-
-            // Two tied candidates at distance 1.
-            let base = i * 10;
-            right_ids.push(base + 1);
-            right_wkts.push(format!("POINT ({x} 0)", x = cx - 1));
-
-            right_ids.push(base + 2);
-            right_wkts.push(format!("POINT ({x} 0)", x = cx + 1));
-
-            // One non-tied candidate.
-            right_ids.push(base + 3);
-            right_wkts.push(format!("POINT ({x} 0)", x = cx + 2));
-        }
-
-        let left_batches = make_batches(schema.clone(), left_ids, left_wkts, 
input_batches_left)?;
-        let right_batches =
-            make_batches(schema.clone(), right_ids, right_wkts, 
input_batches_right)?;
-
-        // Put each side into a single MemTable partition, but with multiple 
batches.
-        // This ensures the build/probe collectors see 4–8 batches and the 
round-robin batch
-        // partitioner has something to distribute.
-        let left_partitions = vec![left_batches];
-        let right_partitions = vec![right_batches];
-
-        let sql = format!(
-            "SELECT L.id AS l_id, R.id AS r_id \
-             FROM L JOIN R \
-             ON ST_KNN(ST_GeomFromWKT(L.wkt), ST_GeomFromWKT(R.wkt), {k}, 
false)"
-        );
-
-        let base_options = SpatialJoinOptions {
-            debug: SpatialJoinDebugOptions {
-                num_spatial_partitions: 
NumSpatialPartitionsConfig::Fixed(num_partitions),
-                ..Default::default()
-            },
-            ..Default::default()
-        };
-
-        // Without tie-breakers: exactly 1 match per probe row.
-        let out_no_ties = run_spatial_join_query(
-            &schema,
-            &schema,
-            left_partitions.clone(),
-            right_partitions.clone(),
-            Some(SpatialJoinOptions {
-                knn_include_tie_breakers: false,
-                ..base_options.clone()
-            }),
-            max_batch_size,
-            &sql,
-        )
-        .await?;
-        let combined = arrow::compute::concat_batches(&out_no_ties.schema(), 
&[out_no_ties])?;
-
-        let l_ids = combined
-            .column(0)
-            .as_any()
-            .downcast_ref::<arrow_array::Int32Array>()
-            .unwrap();
-        let r_ids = combined
-            .column(1)
-            .as_any()
-            .downcast_ref::<arrow_array::Int32Array>()
-            .unwrap();
-
-        let mut per_left: std::collections::HashMap<i32, Vec<i32>> =
-            std::collections::HashMap::new();
-        for i in 0..combined.num_rows() {
-            per_left
-                .entry(l_ids.value(i))
-                .or_default()
-                .push(r_ids.value(i));
-        }
-
-        assert_eq!(per_left.len() as i32, num_probe_rows);
-        for l_id in 0..num_probe_rows {
-            let r_list = per_left.get(&l_id).unwrap();
-            assert_eq!(
-                r_list.len(),
-                1,
-                "expected exactly 1 match for l_id={l_id} when tie-breakers 
are disabled"
-            );
-            let base = l_id * 10;
-            let r_id = r_list[0];
-            assert!(
-                r_id == base + 1 || r_id == base + 2,
-                "expected a tied nearest neighbor for l_id={l_id}, got 
r_id={r_id}"
-            );
-        }
-
-        // With tie-breakers: exactly 2 matches per probe row (both tied 
candidates).
-        let out_with_ties = run_spatial_join_query(
-            &schema,
-            &schema,
-            left_partitions.clone(),
-            right_partitions.clone(),
-            Some(SpatialJoinOptions {
-                knn_include_tie_breakers: true,
-                ..base_options
-            }),
-            max_batch_size,
-            &sql,
-        )
-        .await?;
-        let combined = arrow::compute::concat_batches(&out_with_ties.schema(), 
&[out_with_ties])?;
-        let l_ids = combined
-            .column(0)
-            .as_any()
-            .downcast_ref::<arrow_array::Int32Array>()
-            .unwrap();
-        let r_ids = combined
-            .column(1)
-            .as_any()
-            .downcast_ref::<arrow_array::Int32Array>()
-            .unwrap();
-
-        let mut per_left: std::collections::HashMap<i32, Vec<i32>> =
-            std::collections::HashMap::new();
-        for i in 0..combined.num_rows() {
-            per_left
-                .entry(l_ids.value(i))
-                .or_default()
-                .push(r_ids.value(i));
-        }
-        assert_eq!(per_left.len() as i32, num_probe_rows);
-        for l_id in 0..num_probe_rows {
-            let mut r_list = per_left.get(&l_id).unwrap().clone();
-            r_list.sort();
-            let base = l_id * 10;
-            assert_eq!(
-                r_list,
-                vec![base + 1, base + 2],
-                "expected both tied nearest neighbors for l_id={l_id}"
-            );
-        }
-
-        Ok(())
-    }
-}
diff --git a/rust/sedona-spatial-join/tests/spatial_join_integration.rs 
b/rust/sedona-spatial-join/tests/spatial_join_integration.rs
new file mode 100644
index 00000000..c6a2ae86
--- /dev/null
+++ b/rust/sedona-spatial-join/tests/spatial_join_integration.rs
@@ -0,0 +1,1291 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::sync::Arc;
+
+use arrow_array::{Array, RecordBatch};
+use arrow_schema::{DataType, Field, Schema, SchemaRef};
+use datafusion::{
+    catalog::{MemTable, TableProvider},
+    execution::SessionStateBuilder,
+    prelude::{SessionConfig, SessionContext},
+};
+use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion};
+use datafusion_common::Result;
+use datafusion_expr::{ColumnarValue, JoinType};
+use datafusion_physical_plan::joins::NestedLoopJoinExec;
+use datafusion_physical_plan::ExecutionPlan;
+use geo::{Distance, Euclidean};
+use geo_types::{Coord, Rect};
+use rstest::rstest;
+use sedona_common::SedonaOptions;
+use sedona_geo::to_geo::GeoTypesExecutor;
+use sedona_geometry::types::GeometryTypeId;
+use sedona_schema::datatypes::{SedonaType, WKB_GEOGRAPHY, WKB_GEOMETRY};
+use sedona_spatial_join::{register_spatial_join_optimizer, SpatialJoinExec};
+use sedona_testing::datagen::RandomPartitionedDataBuilder;
+use tokio::sync::OnceCell;
+
+use sedona_common::{
+    option::{add_sedona_option_extension, ExecutionMode, SpatialJoinOptions},
+    NumSpatialPartitionsConfig, SpatialJoinDebugOptions, SpatialLibrary,
+};
+
+type TestPartitions = (SchemaRef, Vec<Vec<RecordBatch>>);
+
+/// Creates standard test data with left (Polygon) and right (Point) partitions
+fn create_default_test_data() -> Result<(TestPartitions, TestPartitions)> {
+    create_test_data_with_size_range((1.0, 10.0), WKB_GEOMETRY)
+}
+
+/// Creates test data with custom size range
+fn create_test_data_with_size_range(
+    size_range: (f64, f64),
+    sedona_type: SedonaType,
+) -> Result<(TestPartitions, TestPartitions)> {
+    let bounds = Rect::new(Coord { x: 0.0, y: 0.0 }, Coord { x: 100.0, y: 
100.0 });
+
+    let left_data = RandomPartitionedDataBuilder::new()
+        .seed(11584)
+        .num_partitions(2)
+        .batches_per_partition(2)
+        .rows_per_batch(30)
+        .geometry_type(GeometryTypeId::Polygon)
+        .sedona_type(sedona_type.clone())
+        .bounds(bounds)
+        .size_range(size_range)
+        .null_rate(0.1)
+        .build()?;
+
+    let right_data = RandomPartitionedDataBuilder::new()
+        .seed(54843)
+        .num_partitions(4)
+        .batches_per_partition(4)
+        .rows_per_batch(30)
+        .geometry_type(GeometryTypeId::Point)
+        .sedona_type(sedona_type)
+        .bounds(bounds)
+        .size_range(size_range)
+        .null_rate(0.1)
+        .build()?;
+
+    Ok((left_data, right_data))
+}
+
+/// Creates test data with empty partitions inserted at beginning and end
+fn create_test_data_with_empty_partitions() -> Result<(TestPartitions, 
TestPartitions)> {
+    let (mut left_data, mut right_data) = create_default_test_data()?;
+
+    // Add empty partitions
+    left_data.1.insert(0, vec![]);
+    left_data.1.push(vec![]);
+    right_data.1.insert(0, vec![]);
+    right_data.1.push(vec![]);
+
+    Ok((left_data, right_data))
+}
+
+/// Creates test data for KNN join (Point-Point)
+fn create_knn_test_data(
+    size_range: (f64, f64),
+    sedona_type: SedonaType,
+) -> Result<(TestPartitions, TestPartitions)> {
+    let bounds = Rect::new(Coord { x: 0.0, y: 0.0 }, Coord { x: 100.0, y: 
100.0 });
+
+    let left_data = RandomPartitionedDataBuilder::new()
+        .seed(1)
+        .num_partitions(2)
+        .batches_per_partition(2)
+        .rows_per_batch(30)
+        .geometry_type(GeometryTypeId::Point)
+        .sedona_type(sedona_type.clone())
+        .bounds(bounds)
+        .size_range(size_range)
+        .null_rate(0.1)
+        .build()?;
+
+    let right_data = RandomPartitionedDataBuilder::new()
+        .seed(2)
+        .num_partitions(4)
+        .batches_per_partition(4)
+        .rows_per_batch(30)
+        .geometry_type(GeometryTypeId::Point)
+        .sedona_type(sedona_type)
+        .bounds(bounds)
+        .size_range(size_range)
+        .null_rate(0.1)
+        .build()?;
+
+    Ok((left_data, right_data))
+}
+
+fn setup_context(options: Option<SpatialJoinOptions>, batch_size: usize) -> 
Result<SessionContext> {
+    let mut session_config = SessionConfig::from_env()?
+        .with_information_schema(true)
+        .with_batch_size(batch_size);
+    session_config = add_sedona_option_extension(session_config);
+    let mut state_builder = SessionStateBuilder::new();
+    if let Some(options) = options {
+        state_builder = register_spatial_join_optimizer(state_builder);
+        let opts = session_config
+            .options_mut()
+            .extensions
+            .get_mut::<SedonaOptions>()
+            .unwrap();
+        opts.spatial_join = options;
+    }
+    let state = state_builder.with_config(session_config).build();
+    let ctx = SessionContext::new_with_state(state);
+
+    let mut function_set = sedona_functions::register::default_function_set();
+    let scalar_kernels = sedona_geos::register::scalar_kernels();
+
+    function_set.scalar_udfs().for_each(|udf| {
+        ctx.register_udf(udf.clone().into());
+    });
+
+    for (name, kernel) in scalar_kernels.into_iter() {
+        let udf = function_set.add_scalar_udf_impl(name, kernel)?;
+        ctx.register_udf(udf.clone().into());
+    }
+
+    Ok(ctx)
+}
+
+#[tokio::test]
+async fn test_empty_data() -> Result<()> {
+    let schema = Arc::new(Schema::new(vec![
+        Field::new("id", DataType::Int32, false),
+        Field::new("dist", DataType::Float64, false),
+        WKB_GEOMETRY.to_storage_field("geometry", true).unwrap(),
+    ]));
+
+    let test_data_vec = vec![vec![vec![]], vec![vec![], vec![]]];
+
+    let options = SpatialJoinOptions::default();
+    let ctx = setup_context(Some(options.clone()), 10)?;
+    for test_data in test_data_vec {
+        let left_partitions = test_data.clone();
+        let right_partitions = test_data;
+
+        let mem_table_left: Arc<dyn TableProvider> = 
Arc::new(MemTable::try_new(
+            Arc::clone(&schema),
+            left_partitions.clone(),
+        )?);
+        let mem_table_right: Arc<dyn TableProvider> = 
Arc::new(MemTable::try_new(
+            Arc::clone(&schema),
+            right_partitions.clone(),
+        )?);
+
+        ctx.deregister_table("L")?;
+        ctx.deregister_table("R")?;
+        ctx.register_table("L", Arc::clone(&mem_table_left))?;
+        ctx.register_table("R", Arc::clone(&mem_table_right))?;
+
+        let sql = "SELECT L.id l_id, R.id r_id FROM L JOIN R ON 
ST_Intersects(L.geometry, R.geometry) ORDER BY l_id, r_id";
+        let df = ctx.sql(sql).await?;
+        let result_batches = df.collect().await?;
+        for result_batch in result_batches {
+            assert_eq!(result_batch.num_rows(), 0);
+        }
+    }
+
+    Ok(())
+}
+
+// Shared test data and expected results - computed only once across all 
parameterized test cases
+// Using tokio::sync::OnceCell for async lazy initialization to avoid 
recomputing expensive
+// test data generation and nested loop join results for each test parameter 
combination
+static TEST_DATA: OnceCell<(TestPartitions, TestPartitions)> = 
OnceCell::const_new();
+static RANGE_JOIN_EXPECTED_RESULTS: OnceCell<Vec<RecordBatch>> = 
OnceCell::const_new();
+static DIST_JOIN_EXPECTED_RESULTS: OnceCell<Vec<RecordBatch>> = 
OnceCell::const_new();
+
+const RANGE_JOIN_SQL1: &str = "SELECT L.id l_id, R.id r_id FROM L JOIN R ON 
ST_Intersects(L.geometry, R.geometry) ORDER BY l_id, r_id";
+const RANGE_JOIN_SQL2: &str =
+    "SELECT * FROM L JOIN R ON ST_Intersects(L.geometry, R.geometry) ORDER BY 
L.id, R.id";
+const RANGE_JOIN_SQLS: &[&str] = &[RANGE_JOIN_SQL1, RANGE_JOIN_SQL2];
+
+const DIST_JOIN_SQL1: &str = "SELECT L.id l_id, R.id r_id FROM L JOIN R ON 
ST_Distance(L.geometry, R.geometry) < 1.0 ORDER BY l_id, r_id";
+const DIST_JOIN_SQL2: &str = "SELECT L.id l_id, R.id r_id FROM L JOIN R ON 
ST_Distance(L.geometry, R.geometry) < L.dist / 10.0 ORDER BY l_id, r_id";
+const DIST_JOIN_SQL3: &str = "SELECT L.id l_id, R.id r_id FROM L JOIN R ON 
ST_Distance(L.geometry, R.geometry) < R.dist / 10.0 ORDER BY l_id, r_id";
+const DIST_JOIN_SQL4: &str = "SELECT L.id l_id, R.id r_id FROM L JOIN R ON 
ST_DWithin(L.geometry, R.geometry, 1.0) ORDER BY l_id, r_id";
+const DIST_JOIN_SQLS: &[&str] = &[
+    DIST_JOIN_SQL1,
+    DIST_JOIN_SQL2,
+    DIST_JOIN_SQL3,
+    DIST_JOIN_SQL4,
+];
+
+/// Get test data, computing it only once
+async fn get_default_test_data() -> &'static (TestPartitions, TestPartitions) {
+    TEST_DATA
+        .get_or_init(|| async { create_default_test_data().expect("Failed to 
create test data") })
+        .await
+}
+
+/// Get expected results, computing them only once
+async fn get_expected_range_join_results() -> &'static Vec<RecordBatch> {
+    get_or_init_expected_join_results(&RANGE_JOIN_EXPECTED_RESULTS, 
RANGE_JOIN_SQLS).await
+}
+
+async fn get_expected_distance_join_results() -> &'static Vec<RecordBatch> {
+    get_or_init_expected_join_results(&DIST_JOIN_EXPECTED_RESULTS, 
DIST_JOIN_SQLS).await
+}
+
+async fn get_or_init_expected_join_results<'a>(
+    lazy_init_results: &'a OnceCell<Vec<RecordBatch>>,
+    sql_queries: &[&str],
+) -> &'a Vec<RecordBatch> {
+    lazy_init_results
+        .get_or_init(|| async {
+            let test_data = get_default_test_data().await;
+            let ((left_schema, left_partitions), (right_schema, 
right_partitions)) = test_data;
+
+            let batch_size = 10;
+
+            // Run nested loop join to get expected results
+            let mut expected_results = Vec::with_capacity(sql_queries.len());
+
+            for (i, sql) in sql_queries.iter().enumerate() {
+                let result = run_spatial_join_query(
+                    left_schema,
+                    right_schema,
+                    left_partitions.clone(),
+                    right_partitions.clone(),
+                    None,
+                    batch_size,
+                    sql,
+                )
+                .await
+                .unwrap_or_else(|e| panic!("Failed to generate expected result 
{}: {}", i + 1, e));
+                expected_results.push(result);
+            }
+
+            expected_results
+        })
+        .await
+}
+
+#[rstest]
+#[tokio::test]
+async fn test_range_join_with_conf(
+    #[values(10, 30, 1000)] max_batch_size: usize,
+    #[values(
+        ExecutionMode::PrepareNone,
+        ExecutionMode::PrepareBuild,
+        ExecutionMode::PrepareProbe,
+        ExecutionMode::Speculative(20)
+    )]
+    execution_mode: ExecutionMode,
+    #[values(SpatialLibrary::Geo, SpatialLibrary::Geos, SpatialLibrary::Tg)]
+    spatial_library: SpatialLibrary,
+) -> Result<()> {
+    let test_data = get_default_test_data().await;
+    let expected_results = get_expected_range_join_results().await;
+    let ((left_schema, left_partitions), (right_schema, right_partitions)) = 
test_data;
+
+    let options = SpatialJoinOptions {
+        spatial_library,
+        execution_mode,
+        ..Default::default()
+    };
+    for (idx, sql) in RANGE_JOIN_SQLS.iter().enumerate() {
+        let actual_result = run_spatial_join_query(
+            left_schema,
+            right_schema,
+            left_partitions.clone(),
+            right_partitions.clone(),
+            Some(options.clone()),
+            max_batch_size,
+            sql,
+        )
+        .await?;
+        assert_eq!(&actual_result, &expected_results[idx]);
+    }
+
+    Ok(())
+}
+
+#[rstest]
+#[tokio::test]
+async fn test_distance_join_with_conf(
+    #[values(30, 1000)] max_batch_size: usize,
+    #[values(SpatialLibrary::Geo, SpatialLibrary::Geos, SpatialLibrary::Tg)]
+    spatial_library: SpatialLibrary,
+) -> Result<()> {
+    let test_data = get_default_test_data().await;
+    let expected_results = get_expected_distance_join_results().await;
+    let ((left_schema, left_partitions), (right_schema, right_partitions)) = 
test_data;
+
+    let options = SpatialJoinOptions {
+        spatial_library,
+        ..Default::default()
+    };
+    for (idx, sql) in DIST_JOIN_SQLS.iter().enumerate() {
+        let actual_result = run_spatial_join_query(
+            left_schema,
+            right_schema,
+            left_partitions.clone(),
+            right_partitions.clone(),
+            Some(options.clone()),
+            max_batch_size,
+            sql,
+        )
+        .await?;
+        assert_eq!(&actual_result, &expected_results[idx]);
+    }
+
+    Ok(())
+}
+
+#[tokio::test]
+async fn test_spatial_join_with_filter() -> Result<()> {
+    let ((left_schema, left_partitions), (right_schema, right_partitions)) =
+        create_test_data_with_size_range((0.1, 10.0), WKB_GEOMETRY)?;
+
+    for max_batch_size in [10, 30, 100] {
+        let options = SpatialJoinOptions::default();
+        test_spatial_join_query(&left_schema, &right_schema, 
left_partitions.clone(), right_partitions.clone(), &options, max_batch_size,
+            "SELECT * FROM L JOIN R ON ST_Intersects(L.geometry, R.geometry) 
AND L.dist < R.dist ORDER BY L.id, R.id").await?;
+        test_spatial_join_query(&left_schema, &right_schema, 
left_partitions.clone(), right_partitions.clone(), &options, max_batch_size,
+            "SELECT L.id l_id, R.id r_id FROM L JOIN R ON 
ST_Intersects(L.geometry, R.geometry) AND L.dist < R.dist ORDER BY l_id, 
r_id").await?;
+        test_spatial_join_query(&left_schema, &right_schema, 
left_partitions.clone(), right_partitions.clone(), &options, max_batch_size,
+            "SELECT L.id l_id, R.id r_id, L.dist l_dist, R.dist r_dist FROM L 
JOIN R ON ST_Intersects(L.geometry, R.geometry) AND L.dist < R.dist ORDER BY 
l_id, r_id").await?;
+    }
+
+    Ok(())
+}
+
+#[tokio::test]
+async fn test_range_join_with_empty_partitions() -> Result<()> {
+    let ((left_schema, left_partitions), (right_schema, right_partitions)) =
+        create_test_data_with_empty_partitions()?;
+
+    for max_batch_size in [10, 30, 1000] {
+        let options = SpatialJoinOptions::default();
+        test_spatial_join_query(&left_schema, &right_schema, 
left_partitions.clone(), right_partitions.clone(), &options, max_batch_size,
+            "SELECT L.id l_id, R.id r_id FROM L JOIN R ON 
ST_Intersects(L.geometry, R.geometry) ORDER BY l_id, r_id").await?;
+        test_spatial_join_query(
+            &left_schema,
+            &right_schema,
+            left_partitions.clone(),
+            right_partitions.clone(),
+            &options,
+            max_batch_size,
+            "SELECT * FROM L JOIN R ON ST_Intersects(L.geometry, R.geometry) 
ORDER BY L.id, R.id",
+        )
+        .await?;
+    }
+
+    Ok(())
+}
+
+#[tokio::test]
+async fn test_inner_join() -> Result<()> {
+    let options = SpatialJoinOptions::default();
+    test_with_join_types(JoinType::Inner, options, 30).await?;
+    Ok(())
+}
+
+#[rstest]
+#[tokio::test]
+async fn test_left_joins(
+    #[values(JoinType::Left, JoinType::LeftSemi, JoinType::LeftAnti)] 
join_type: JoinType,
+) -> Result<()> {
+    let options = SpatialJoinOptions::default();
+    test_with_join_types(join_type, options, 30).await?;
+    Ok(())
+}
+
+#[rstest]
+#[tokio::test]
+async fn test_right_joins(
+    #[values(JoinType::Right, JoinType::RightSemi, JoinType::RightAnti)] 
join_type: JoinType,
+) -> Result<()> {
+    let options = SpatialJoinOptions::default();
+    test_with_join_types(join_type, options, 30).await?;
+    Ok(())
+}
+
+#[tokio::test]
+async fn test_full_outer_join() -> Result<()> {
+    let options = SpatialJoinOptions::default();
+    test_with_join_types(JoinType::Full, options, 30).await?;
+    Ok(())
+}
+
+#[tokio::test]
+async fn test_geography_join_is_not_optimized() -> Result<()> {
+    let options = SpatialJoinOptions::default();
+    let ctx = setup_context(Some(options), 10)?;
+
+    // Prepare geography tables
+    let ((left_schema, left_partitions), (right_schema, right_partitions)) =
+        create_test_data_with_size_range((0.1, 10.0), WKB_GEOGRAPHY)?;
+    let mem_table_left: Arc<dyn TableProvider> =
+        Arc::new(MemTable::try_new(left_schema, left_partitions)?);
+    let mem_table_right: Arc<dyn TableProvider> =
+        Arc::new(MemTable::try_new(right_schema, right_partitions)?);
+    ctx.register_table("L", mem_table_left)?;
+    ctx.register_table("R", mem_table_right)?;
+
+    // Execute geography join query
+    let df = ctx
+        .sql("SELECT * FROM L JOIN R ON ST_Intersects(L.geometry, R.geometry)")
+        .await?;
+    let plan = df.create_physical_plan().await?;
+
+    // Verify that no SpatialJoinExec is present (geography join should not be 
optimized)
+    let spatial_joins = collect_spatial_join_exec(&plan)?;
+    assert!(
+        spatial_joins.is_empty(),
+        "Geography joins should not be optimized to SpatialJoinExec"
+    );
+
+    Ok(())
+}
+
+#[tokio::test]
+async fn test_query_window_in_subquery() -> Result<()> {
+    let ((left_schema, left_partitions), (right_schema, right_partitions)) =
+        create_test_data_with_size_range((50.0, 60.0), WKB_GEOMETRY)?;
+    let options = SpatialJoinOptions::default();
+    test_spatial_join_query(&left_schema, &right_schema, 
left_partitions.clone(), right_partitions.clone(), &options, 10,
+            "SELECT id FROM L WHERE ST_Intersects(L.geometry, (SELECT 
R.geometry FROM R WHERE R.id = 1))").await?;
+    Ok(())
+}
+
+#[tokio::test]
+async fn test_parallel_refinement_for_large_candidate_set() -> Result<()> {
+    let ((left_schema, left_partitions), (right_schema, right_partitions)) =
+        create_test_data_with_size_range((1.0, 50.0), WKB_GEOMETRY)?;
+
+    for max_batch_size in [10, 30, 100] {
+        let options = SpatialJoinOptions {
+            parallel_refinement_chunk_size: 10,
+            ..Default::default()
+        };
+        test_spatial_join_query(&left_schema, &right_schema, 
left_partitions.clone(), right_partitions.clone(), &options, max_batch_size,
+            "SELECT * FROM L JOIN R ON ST_Intersects(L.geometry, R.geometry) 
AND L.dist < R.dist ORDER BY L.id, R.id").await?;
+    }
+
+    Ok(())
+}
+
+#[rstest]
+#[tokio::test]
+async fn test_spatial_partitioned_range_join(
+    #[values(10, 30, 1000)] max_batch_size: usize,
+    #[values(
+        ExecutionMode::PrepareNone,
+        ExecutionMode::PrepareBuild,
+        ExecutionMode::PrepareProbe,
+        ExecutionMode::Speculative(20)
+    )]
+    execution_mode: ExecutionMode,
+    #[values(SpatialLibrary::Geo, SpatialLibrary::Geos, SpatialLibrary::Tg)]
+    spatial_library: SpatialLibrary,
+) -> Result<()> {
+    let test_data = get_default_test_data().await;
+    let expected_results = get_expected_range_join_results().await;
+    let ((left_schema, left_partitions), (right_schema, right_partitions)) = 
test_data;
+
+    let debug = SpatialJoinDebugOptions {
+        num_spatial_partitions: NumSpatialPartitionsConfig::Fixed(4),
+        force_spill: true,
+        memory_for_intermittent_usage: None,
+        ..Default::default()
+    };
+    let options = SpatialJoinOptions {
+        spatial_library,
+        execution_mode,
+        debug,
+        ..Default::default()
+    };
+
+    for (idx, sql) in RANGE_JOIN_SQLS.iter().enumerate() {
+        let actual_result = run_spatial_join_query(
+            left_schema,
+            right_schema,
+            left_partitions.clone(),
+            right_partitions.clone(),
+            Some(options.clone()),
+            max_batch_size,
+            sql,
+        )
+        .await?;
+        assert_eq!(&actual_result, &expected_results[idx]);
+    }
+
+    Ok(())
+}
+
+#[rstest]
+#[tokio::test]
+async fn test_spatial_partitioned_outer_join(
+    #[values(10, 30, 1000)] batch_size: usize,
+    #[values(
+        JoinType::Left,
+        JoinType::Right,
+        JoinType::Full,
+        JoinType::LeftSemi,
+        JoinType::LeftAnti,
+        JoinType::RightSemi,
+        JoinType::RightAnti
+    )]
+    join_type: JoinType,
+) -> Result<()> {
+    let debug = SpatialJoinDebugOptions {
+        num_spatial_partitions: NumSpatialPartitionsConfig::Fixed(4),
+        force_spill: true,
+        memory_for_intermittent_usage: None,
+        ..Default::default()
+    };
+    let options = SpatialJoinOptions {
+        debug,
+        ..Default::default()
+    };
+
+    test_with_join_types(join_type, options, batch_size).await?;
+    Ok(())
+}
+
+#[rstest]
+#[tokio::test]
+async fn test_mark_joins(
+    #[values(JoinType::LeftMark, JoinType::RightMark)] join_type: JoinType,
+) -> Result<()> {
+    let options = SpatialJoinOptions::default();
+    test_mark_join(join_type, options, 10).await?;
+    Ok(())
+}
+
+#[rstest]
+#[tokio::test]
+async fn test_spatial_partitioned_mark_joins(
+    #[values(JoinType::LeftMark, JoinType::RightMark)] join_type: JoinType,
+) -> Result<()> {
+    let debug = SpatialJoinDebugOptions {
+        num_spatial_partitions: NumSpatialPartitionsConfig::Fixed(4),
+        force_spill: true,
+        memory_for_intermittent_usage: None,
+        ..Default::default()
+    };
+    let options = SpatialJoinOptions {
+        debug,
+        ..Default::default()
+    };
+    test_mark_join(join_type, options, 10).await?;
+    Ok(())
+}
+
+async fn test_with_join_types(
+    join_type: JoinType,
+    options: SpatialJoinOptions,
+    batch_size: usize,
+) -> Result<RecordBatch> {
+    let ((left_schema, left_partitions), (right_schema, right_partitions)) =
+        create_test_data_with_empty_partitions()?;
+
+    let inner_sql = "SELECT L.id l_id, R.id r_id FROM L INNER JOIN R ON 
ST_Intersects(L.geometry, R.geometry) ORDER BY l_id, r_id";
+    let sql = match join_type {
+        JoinType::Inner => inner_sql,
+        JoinType::Left => "SELECT L.id l_id, R.id r_id FROM L LEFT JOIN R ON 
ST_Intersects(L.geometry, R.geometry) ORDER BY l_id, r_id",
+        JoinType::Right => "SELECT L.id l_id, R.id r_id FROM L RIGHT JOIN R ON 
ST_Intersects(L.geometry, R.geometry) ORDER BY l_id, r_id",
+        JoinType::Full => "SELECT L.id l_id, R.id r_id FROM L FULL OUTER JOIN 
R ON ST_Intersects(L.geometry, R.geometry) ORDER BY l_id, r_id",
+        JoinType::LeftSemi => "SELECT L.id l_id FROM L LEFT SEMI JOIN R ON 
ST_Intersects(L.geometry, R.geometry) ORDER BY l_id",
+        JoinType::RightSemi => "SELECT R.id r_id FROM L RIGHT SEMI JOIN R ON 
ST_Intersects(L.geometry, R.geometry) ORDER BY r_id",
+        JoinType::LeftAnti => "SELECT L.id l_id FROM L LEFT ANTI JOIN R ON 
ST_Intersects(L.geometry, R.geometry) ORDER BY l_id",
+        JoinType::RightAnti => "SELECT R.id r_id FROM L RIGHT ANTI JOIN R ON 
ST_Intersects(L.geometry, R.geometry) ORDER BY r_id",
+        JoinType::LeftMark => {
+            unreachable!("LeftMark is not directly supported in SQL, will be 
tested in other tests");
+        }
+        JoinType::RightMark => {
+            unreachable!("RightMark is not directly supported in SQL, will be 
tested in other tests");
+        }
+    };
+
+    let batches = test_spatial_join_query(
+        &left_schema,
+        &right_schema,
+        left_partitions.clone(),
+        right_partitions.clone(),
+        &options,
+        batch_size,
+        sql,
+    )
+    .await?;
+
+    if matches!(join_type, JoinType::Left | JoinType::Right | JoinType::Full) {
+        // Make sure that we are effectively testing outer joins. If outer 
joins produces the same result as inner join,
+        // it means that the test data is not suitable for testing outer joins.
+        let inner_batches = run_spatial_join_query(
+            &left_schema,
+            &right_schema,
+            left_partitions,
+            right_partitions,
+            Some(options),
+            batch_size,
+            inner_sql,
+        )
+        .await?;
+        assert!(inner_batches.num_rows() < batches.num_rows());
+    }
+
+    Ok(batches)
+}
+
+async fn test_spatial_join_query(
+    left_schema: &SchemaRef,
+    right_schema: &SchemaRef,
+    left_partitions: Vec<Vec<RecordBatch>>,
+    right_partitions: Vec<Vec<RecordBatch>>,
+    options: &SpatialJoinOptions,
+    batch_size: usize,
+    sql: &str,
+) -> Result<RecordBatch> {
+    // Run spatial join using SpatialJoinExec
+    let actual = run_spatial_join_query(
+        left_schema,
+        right_schema,
+        left_partitions.clone(),
+        right_partitions.clone(),
+        Some(options.clone()),
+        batch_size,
+        sql,
+    )
+    .await?;
+
+    // Run spatial join using NestedLoopJoinExec
+    let expected = run_spatial_join_query(
+        left_schema,
+        right_schema,
+        left_partitions.clone(),
+        right_partitions.clone(),
+        None,
+        batch_size,
+        sql,
+    )
+    .await?;
+
+    // Should produce the same result
+    assert!(expected.num_rows() > 0);
+    assert_eq!(expected, actual);
+
+    Ok(actual)
+}
+
+async fn run_spatial_join_query(
+    left_schema: &SchemaRef,
+    right_schema: &SchemaRef,
+    left_partitions: Vec<Vec<RecordBatch>>,
+    right_partitions: Vec<Vec<RecordBatch>>,
+    options: Option<SpatialJoinOptions>,
+    batch_size: usize,
+    sql: &str,
+) -> Result<RecordBatch> {
+    let mem_table_left: Arc<dyn TableProvider> =
+        Arc::new(MemTable::try_new(left_schema.to_owned(), left_partitions)?);
+    let mem_table_right: Arc<dyn TableProvider> = Arc::new(MemTable::try_new(
+        right_schema.to_owned(),
+        right_partitions,
+    )?);
+
+    let is_optimized_spatial_join = options.is_some();
+    let ctx = setup_context(options, batch_size)?;
+    ctx.register_table("L", Arc::clone(&mem_table_left))?;
+    ctx.register_table("R", Arc::clone(&mem_table_right))?;
+    let df = ctx.sql(sql).await?;
+    let actual_schema = df.schema().as_arrow().clone();
+    let plan = df.clone().create_physical_plan().await?;
+    let spatial_join_execs = collect_spatial_join_exec(&plan)?;
+    if is_optimized_spatial_join {
+        assert_eq!(spatial_join_execs.len(), 1);
+    } else {
+        assert!(spatial_join_execs.is_empty());
+    }
+    let result_batches = df.collect().await?;
+    let result_batch = 
arrow::compute::concat_batches(&Arc::new(actual_schema), &result_batches)?;
+    Ok(result_batch)
+}
+
+fn collect_spatial_join_exec(plan: &Arc<dyn ExecutionPlan>) -> 
Result<Vec<&SpatialJoinExec>> {
+    let mut spatial_join_execs = Vec::new();
+    plan.apply(|node| {
+        if let Some(spatial_join_exec) = 
node.as_any().downcast_ref::<SpatialJoinExec>() {
+            spatial_join_execs.push(spatial_join_exec);
+        }
+        Ok(TreeNodeRecursion::Continue)
+    })?;
+    Ok(spatial_join_execs)
+}
+
+async fn test_mark_join(
+    join_type: JoinType,
+    options: SpatialJoinOptions,
+    batch_size: usize,
+) -> Result<()> {
+    let ((left_schema, left_partitions), (right_schema, right_partitions)) =
+        create_test_data_with_size_range((0.1, 10.0), WKB_GEOMETRY)?;
+    let mem_table_left: Arc<dyn TableProvider> = Arc::new(MemTable::try_new(
+        left_schema.clone(),
+        left_partitions.clone(),
+    )?);
+    let mem_table_right: Arc<dyn TableProvider> = Arc::new(MemTable::try_new(
+        right_schema.clone(),
+        right_partitions.clone(),
+    )?);
+
+    // We use a Left Join as a template to create the plan, then modify it to 
Mark Join
+    let sql = "SELECT * FROM L LEFT JOIN R ON ST_Intersects(L.geometry, 
R.geometry)";
+
+    // Create SpatialJoinExec plan
+    let ctx = setup_context(Some(options.clone()), batch_size)?;
+    ctx.register_table("L", mem_table_left.clone())?;
+    ctx.register_table("R", mem_table_right.clone())?;
+    let df = ctx.sql(sql).await?;
+    let plan = df.create_physical_plan().await?;
+    let spatial_join_execs = collect_spatial_join_exec(&plan)?;
+    assert_eq!(spatial_join_execs.len(), 1);
+    let original_exec = spatial_join_execs[0];
+    let mark_exec = SpatialJoinExec::try_new(
+        original_exec.left.clone(),
+        original_exec.right.clone(),
+        original_exec.on.clone(),
+        original_exec.filter.clone(),
+        &join_type,
+        None,
+        &options,
+    )?;
+
+    // Create NestedLoopJoinExec plan for comparison
+    let ctx_no_opt = setup_context(None, batch_size)?;
+    ctx_no_opt.register_table("L", mem_table_left)?;
+    ctx_no_opt.register_table("R", mem_table_right)?;
+    let df_no_opt = ctx_no_opt.sql(sql).await?;
+    let plan_no_opt = df_no_opt.create_physical_plan().await?;
+    fn collect_nlj_exec(plan: &Arc<dyn ExecutionPlan>) -> 
Result<Vec<&NestedLoopJoinExec>> {
+        let mut execs = Vec::new();
+        plan.apply(|node| {
+            if let Some(exec) = 
node.as_any().downcast_ref::<NestedLoopJoinExec>() {
+                execs.push(exec);
+            }
+            Ok(TreeNodeRecursion::Continue)
+        })?;
+        Ok(execs)
+    }
+    let nlj_execs = collect_nlj_exec(&plan_no_opt)?;
+    assert_eq!(nlj_execs.len(), 1);
+    let original_nlj = nlj_execs[0];
+    let mark_nlj = NestedLoopJoinExec::try_new(
+        original_nlj.children()[0].clone(),
+        original_nlj.children()[1].clone(),
+        original_nlj.filter().cloned(),
+        &join_type,
+        None,
+    )?;
+
+    async fn run_and_sort(
+        plan: Arc<dyn ExecutionPlan>,
+        ctx: &SessionContext,
+    ) -> Result<RecordBatch> {
+        let results = datafusion_physical_plan::collect(plan, 
ctx.task_ctx()).await?;
+        let batch = arrow::compute::concat_batches(&results[0].schema(), 
&results)?;
+        let sort_col = batch.column(0);
+        let indices = arrow::compute::sort_to_indices(sort_col, None, None)?;
+        let sorted_batch = arrow::compute::take_record_batch(&batch, 
&indices)?;
+        Ok(sorted_batch)
+    }
+
+    // Run both Mark Join plans and compare results
+    let mark_batch = run_and_sort(Arc::new(mark_exec), &ctx).await?;
+    let mark_nlj_batch = run_and_sort(Arc::new(mark_nlj), &ctx_no_opt).await?;
+    assert_eq!(mark_batch, mark_nlj_batch);
+
+    Ok(())
+}
+
+fn extract_geoms_and_ids(partitions: &[Vec<RecordBatch>]) -> Vec<(i32, 
geo::Geometry<f64>)> {
+    let mut result = Vec::new();
+    for partition in partitions {
+        for batch in partition {
+            let id_idx = batch.schema().index_of("id").expect("Id column not 
found");
+            let ids = batch
+                .column(id_idx)
+                .as_any()
+                .downcast_ref::<arrow_array::Int32Array>()
+                .expect("Column 'id' should be Int32");
+
+            let geom_idx = batch
+                .schema()
+                .index_of("geometry")
+                .expect("Geometry column not found");
+
+            let geoms_col = batch.column(geom_idx);
+            let geom_type = 
SedonaType::from_storage_field(batch.schema().field(geom_idx))
+                .expect("Failed to get SedonaType from geometry field");
+            let arg_types = [geom_type];
+            let arg_values = [ColumnarValue::Array(Arc::clone(geoms_col))];
+
+            let executor = GeoTypesExecutor::new(&arg_types, &arg_values);
+            let mut id_iter = ids.iter();
+            executor
+                .execute_wkb_void(|maybe_geom| {
+                    if let Some(id_opt) = id_iter.next() {
+                        if let (Some(id), Some(geom)) = (id_opt, maybe_geom) {
+                            result.push((id, geom))
+                        }
+                    }
+                    Ok(())
+                })
+                .expect("Failed to extract geoms and ids from RecordBatch");
+        }
+    }
+    result
+}
+
+fn compute_knn_ground_truth_with_pair_filter<F>(
+    left_partitions: &[Vec<RecordBatch>],
+    right_partitions: &[Vec<RecordBatch>],
+    k: usize,
+    keep_pair: F,
+) -> Vec<(i32, i32, f64)>
+where
+    F: Fn(i32, i32) -> bool,
+{
+    // NOTE: This helper mirrors our KNN semantics used in execution:
+    // - select top-K unfiltered candidates by distance (stable by r_id)
+    // - then apply a cross-side predicate to decide which pairs to keep
+    //   (can yield < K results per probe row)
+    //
+    // The predicate is intentionally *post* top-K selection.
+    // (See `test_knn_join_with_filter_correctness`.)
+    let left_data = extract_geoms_and_ids(left_partitions);
+    let right_data = extract_geoms_and_ids(right_partitions);
+
+    let mut results = Vec::new();
+
+    for (l_id, l_geom) in left_data {
+        let mut distances: Vec<(i32, f64)> = right_data
+            .iter()
+            .map(|(r_id, r_geom)| (*r_id, Euclidean.distance(&l_geom, r_geom)))
+            .collect();
+
+        // Sort by distance, then by ID for stability
+        distances.sort_by(|a, b| a.1.total_cmp(&b.1).then_with(|| 
a.0.cmp(&b.0)));
+
+        // KNN semantics: pick top-K unfiltered, then optionally post-filter.
+        for (r_id, dist) in distances.iter().take(k.min(distances.len())) {
+            if keep_pair(l_id, *r_id) {
+                results.push((l_id, *r_id, *dist));
+            }
+        }
+    }
+
+    // Sort results by L.id, R.id
+    results.sort_by(|a, b| a.0.cmp(&b.0).then_with(|| a.1.cmp(&b.1)));
+    results
+}
+
+#[rstest]
+#[tokio::test]
+async fn test_knn_join_correctness(
+    #[values(true, false)] point_only: bool,
+    #[values(1, 2, 3, 4)] num_partitions: usize,
+    #[values(10, 30, 1000)] max_batch_size: usize,
+) -> Result<()> {
+    // Generate slightly larger data
+    let ((left_schema, left_partitions), (right_schema, right_partitions)) = 
if point_only {
+        create_knn_test_data((0.1, 10.0), WKB_GEOMETRY)?
+    } else {
+        create_default_test_data()?
+    };
+
+    // Use single partition to verify algorithm correctness first, avoiding 
partitioning issues
+    let options = SpatialJoinOptions {
+        debug: SpatialJoinDebugOptions {
+            num_spatial_partitions: 
NumSpatialPartitionsConfig::Fixed(num_partitions),
+            ..Default::default()
+        },
+        ..Default::default()
+    };
+    let k = 6;
+
+    let sql1 = format!(
+        "SELECT L.id, R.id, ST_Distance(L.geometry, R.geometry) FROM L JOIN R 
ON ST_KNN(L.geometry, R.geometry, {}, false) ORDER BY L.id, R.id",
+        k
+    );
+    let expected1 = compute_knn_ground_truth_with_pair_filter(
+        &left_partitions,
+        &right_partitions,
+        k,
+        |_l_id, _r_id| true,
+    )
+    .into_iter()
+    .map(|(l, r, _)| (l, r))
+    .collect::<Vec<_>>();
+    let sql2 = format!(
+        "SELECT R.id, L.id, ST_Distance(L.geometry, R.geometry) FROM L JOIN R 
ON ST_KNN(R.geometry, L.geometry, {}, false) ORDER BY R.id, L.id",
+        k
+    );
+    let expected2 = compute_knn_ground_truth_with_pair_filter(
+        &right_partitions,
+        &left_partitions,
+        k,
+        |_l_id, _r_id| true,
+    )
+    .into_iter()
+    .map(|(l, r, _)| (l, r))
+    .collect::<Vec<_>>();
+
+    let sqls = [(&sql1, &expected1), (&sql2, &expected2)];
+
+    for (sql, expected_results) in sqls {
+        let batches = run_spatial_join_query(
+            &left_schema,
+            &right_schema,
+            left_partitions.clone(),
+            right_partitions.clone(),
+            Some(options.clone()),
+            max_batch_size,
+            sql,
+        )
+        .await?;
+
+        // Collect actual results
+        let mut actual_results = Vec::new();
+        let combined_batch = arrow::compute::concat_batches(&batches.schema(), 
&[batches])?;
+        let l_ids = combined_batch
+            .column(0)
+            .as_any()
+            .downcast_ref::<arrow_array::Int32Array>()
+            .unwrap();
+        let r_ids = combined_batch
+            .column(1)
+            .as_any()
+            .downcast_ref::<arrow_array::Int32Array>()
+            .unwrap();
+
+        for i in 0..combined_batch.num_rows() {
+            actual_results.push((l_ids.value(i), r_ids.value(i)));
+        }
+        actual_results.sort_by(|a, b| a.0.cmp(&b.0).then_with(|| 
a.1.cmp(&b.1)));
+
+        assert_eq!(actual_results, *expected_results);
+    }
+
+    Ok(())
+}
+
+#[rstest]
+#[tokio::test]
+async fn test_knn_join_with_filter_correctness(
+    #[values(1, 2, 3, 4)] num_partitions: usize,
+    #[values(10, 30, 1000)] max_batch_size: usize,
+) -> Result<()> {
+    let ((left_schema, left_partitions), (right_schema, right_partitions)) =
+        create_knn_test_data((0.1, 10.0), WKB_GEOMETRY)?;
+
+    let options = SpatialJoinOptions {
+        debug: SpatialJoinDebugOptions {
+            num_spatial_partitions: 
NumSpatialPartitionsConfig::Fixed(num_partitions),
+            ..Default::default()
+        },
+        ..Default::default()
+    };
+
+    let k = 3;
+    let sql = format!(
+        "SELECT L.id AS l_id, R.id AS r_id FROM L JOIN R ON ST_KNN(L.geometry, 
R.geometry, {}, false) AND (L.id % 7) = (R.id % 7)",
+        k
+    );
+
+    let batches = run_spatial_join_query(
+        &left_schema,
+        &right_schema,
+        left_partitions.clone(),
+        right_partitions.clone(),
+        Some(options),
+        max_batch_size,
+        &sql,
+    )
+    .await?;
+
+    let mut actual_results = Vec::new();
+    let combined_batch = arrow::compute::concat_batches(&batches.schema(), 
&[batches])?;
+    let l_ids = combined_batch
+        .column(0)
+        .as_any()
+        .downcast_ref::<arrow_array::Int32Array>()
+        .unwrap();
+    let r_ids = combined_batch
+        .column(1)
+        .as_any()
+        .downcast_ref::<arrow_array::Int32Array>()
+        .unwrap();
+
+    for i in 0..combined_batch.num_rows() {
+        actual_results.push((l_ids.value(i), r_ids.value(i)));
+    }
+    actual_results.sort_by(|a, b| a.0.cmp(&b.0).then_with(|| a.1.cmp(&b.1)));
+
+    // Prove the test actually exercises the "< K rows after filtering" case.
+    // Build a list of all probe-side IDs and count how many results each has.
+    let all_left_ids: Vec<i32> = extract_geoms_and_ids(&left_partitions)
+        .into_iter()
+        .map(|(id, _)| id)
+        .collect();
+    let mut per_left_counts: std::collections::HashMap<i32, usize> =
+        std::collections::HashMap::new();
+    for (l_id, _) in &actual_results {
+        *per_left_counts.entry(*l_id).or_default() += 1;
+    }
+    let min_count = all_left_ids
+        .iter()
+        .map(|l_id| *per_left_counts.get(l_id).unwrap_or(&0))
+        .min()
+        .unwrap_or(0);
+    assert!(
+        min_count < k,
+        "expected at least one probe row to produce < K rows after filtering; 
min_count={min_count}, k={k}"
+    );
+
+    let expected_results = compute_knn_ground_truth_with_pair_filter(
+        &left_partitions,
+        &right_partitions,
+        k,
+        |l_id, r_id| (l_id.rem_euclid(7)) == (r_id.rem_euclid(7)),
+    )
+    .into_iter()
+    .map(|(l, r, _)| (l, r))
+    .collect::<Vec<_>>();
+
+    assert_eq!(actual_results, expected_results);
+
+    Ok(())
+}
+
+#[rstest]
+#[tokio::test]
+async fn test_knn_join_include_tie_breakers(
+    #[values(1, 2, 3, 4)] num_partitions: usize,
+    #[values(10, 100)] max_batch_size: usize,
+) -> Result<()> {
+    // Construct a larger dataset with *guaranteed* exact ties at the kth 
distance.
+    //
+    // For each probe point at (10*i, 0), we create two candidate points at 
(10*i-1, 0)
+    // and (10*i+1, 0). Those two candidates are tied (distance = 1).
+    // A third candidate at (10*i+2, 0) ensures there are also non-tied 
options.
+    // Spacing by 10 keeps other probes' candidates far enough away that they 
never interfere.
+    //
+    // With k=1:
+    // - knn_include_tie_breakers=false should return exactly 1 match per 
probe row.
+    // - knn_include_tie_breakers=true should return 2 matches per probe row 
(both ties).
+    //
+    // The exact choice of which tied row is returned when tie-breakers are 
disabled is not
+    // asserted (it is allowed to be either tied candidate).
+
+    let schema = Arc::new(Schema::new(vec![
+        Field::new("id", DataType::Int32, false),
+        Field::new("wkt", DataType::Utf8, false),
+    ]));
+
+    let num_probe_rows: i32 = 120;
+    let k = 1;
+
+    let input_batches_left = 6;
+    let input_batches_right = 6;
+
+    fn make_batches(
+        schema: SchemaRef,
+        ids: Vec<i32>,
+        wkts: Vec<String>,
+        num_batches: usize,
+    ) -> Result<Vec<RecordBatch>> {
+        assert_eq!(ids.len(), wkts.len());
+        let total = ids.len();
+        let chunk = total.div_ceil(num_batches);
+
+        let mut batches = Vec::new();
+        for b in 0..num_batches {
+            let start = b * chunk;
+            if start >= total {
+                break;
+            }
+            let end = ((b + 1) * chunk).min(total);
+            let batch = RecordBatch::try_new(
+                schema.clone(),
+                vec![
+                    
Arc::new(arrow_array::Int32Array::from(ids[start..end].to_vec())),
+                    Arc::new(arrow_array::StringArray::from(
+                        wkts[start..end]
+                            .iter()
+                            .map(|s| s.as_str())
+                            .collect::<Vec<_>>(),
+                    )),
+                ],
+            )?;
+            batches.push(batch);
+        }
+        Ok(batches)
+    }
+
+    let mut left_ids = Vec::with_capacity(num_probe_rows as usize);
+    let mut left_wkts = Vec::with_capacity(num_probe_rows as usize);
+
+    let mut right_ids = Vec::with_capacity((num_probe_rows as usize) * 3);
+    let mut right_wkts = Vec::with_capacity((num_probe_rows as usize) * 3);
+
+    for i in 0..num_probe_rows {
+        let cx = (i as i64) * 10;
+        left_ids.push(i);
+        left_wkts.push(format!("POINT ({cx} 0)"));
+
+        // Two tied candidates at distance 1.
+        let base = i * 10;
+        right_ids.push(base + 1);
+        right_wkts.push(format!("POINT ({x} 0)", x = cx - 1));
+
+        right_ids.push(base + 2);
+        right_wkts.push(format!("POINT ({x} 0)", x = cx + 1));
+
+        // One non-tied candidate.
+        right_ids.push(base + 3);
+        right_wkts.push(format!("POINT ({x} 0)", x = cx + 2));
+    }
+
+    let left_batches = make_batches(schema.clone(), left_ids, left_wkts, 
input_batches_left)?;
+    let right_batches = make_batches(schema.clone(), right_ids, right_wkts, 
input_batches_right)?;
+
+    // Put each side into a single MemTable partition, but with multiple 
batches.
+    // This ensures the build/probe collectors see 4–8 batches and the 
round-robin batch
+    // partitioner has something to distribute.
+    let left_partitions = vec![left_batches];
+    let right_partitions = vec![right_batches];
+
+    let sql = format!(
+        "SELECT L.id AS l_id, R.id AS r_id \
+            FROM L JOIN R \
+            ON ST_KNN(ST_GeomFromWKT(L.wkt), ST_GeomFromWKT(R.wkt), {k}, 
false)"
+    );
+
+    let base_options = SpatialJoinOptions {
+        debug: SpatialJoinDebugOptions {
+            num_spatial_partitions: 
NumSpatialPartitionsConfig::Fixed(num_partitions),
+            ..Default::default()
+        },
+        ..Default::default()
+    };
+
+    // Without tie-breakers: exactly 1 match per probe row.
+    let out_no_ties = run_spatial_join_query(
+        &schema,
+        &schema,
+        left_partitions.clone(),
+        right_partitions.clone(),
+        Some(SpatialJoinOptions {
+            knn_include_tie_breakers: false,
+            ..base_options.clone()
+        }),
+        max_batch_size,
+        &sql,
+    )
+    .await?;
+    let combined = arrow::compute::concat_batches(&out_no_ties.schema(), 
&[out_no_ties])?;
+
+    let l_ids = combined
+        .column(0)
+        .as_any()
+        .downcast_ref::<arrow_array::Int32Array>()
+        .unwrap();
+    let r_ids = combined
+        .column(1)
+        .as_any()
+        .downcast_ref::<arrow_array::Int32Array>()
+        .unwrap();
+
+    let mut per_left: std::collections::HashMap<i32, Vec<i32>> = 
std::collections::HashMap::new();
+    for i in 0..combined.num_rows() {
+        per_left
+            .entry(l_ids.value(i))
+            .or_default()
+            .push(r_ids.value(i));
+    }
+
+    assert_eq!(per_left.len() as i32, num_probe_rows);
+    for l_id in 0..num_probe_rows {
+        let r_list = per_left.get(&l_id).unwrap();
+        assert_eq!(
+            r_list.len(),
+            1,
+            "expected exactly 1 match for l_id={l_id} when tie-breakers are 
disabled"
+        );
+        let base = l_id * 10;
+        let r_id = r_list[0];
+        assert!(
+            r_id == base + 1 || r_id == base + 2,
+            "expected a tied nearest neighbor for l_id={l_id}, got r_id={r_id}"
+        );
+    }
+
+    // With tie-breakers: exactly 2 matches per probe row (both tied 
candidates).
+    let out_with_ties = run_spatial_join_query(
+        &schema,
+        &schema,
+        left_partitions.clone(),
+        right_partitions.clone(),
+        Some(SpatialJoinOptions {
+            knn_include_tie_breakers: true,
+            ..base_options
+        }),
+        max_batch_size,
+        &sql,
+    )
+    .await?;
+    let combined = arrow::compute::concat_batches(&out_with_ties.schema(), 
&[out_with_ties])?;
+    let l_ids = combined
+        .column(0)
+        .as_any()
+        .downcast_ref::<arrow_array::Int32Array>()
+        .unwrap();
+    let r_ids = combined
+        .column(1)
+        .as_any()
+        .downcast_ref::<arrow_array::Int32Array>()
+        .unwrap();
+
+    let mut per_left: std::collections::HashMap<i32, Vec<i32>> = 
std::collections::HashMap::new();
+    for i in 0..combined.num_rows() {
+        per_left
+            .entry(l_ids.value(i))
+            .or_default()
+            .push(r_ids.value(i));
+    }
+    assert_eq!(per_left.len() as i32, num_probe_rows);
+    for l_id in 0..num_probe_rows {
+        let mut r_list = per_left.get(&l_id).unwrap().clone();
+        r_list.sort();
+        let base = l_id * 10;
+        assert_eq!(
+            r_list,
+            vec![base + 1, base + 2],
+            "expected both tied nearest neighbors for l_id={l_id}"
+        );
+    }
+
+    Ok(())
+}

Reply via email to