This is an automated email from the ASF dual-hosted git repository.

paleolimbot pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/sedona-db.git


The following commit(s) were added to refs/heads/main by this push:
     new 6f1d2704 feat(rust/sedona-functions): Simplify ST_KNN() call for most 
common case (#667)
6f1d2704 is described below

commit 6f1d27044775efdfd082b0ab48d396b619101221
Author: Dewey Dunnington <[email protected]>
AuthorDate: Fri Feb 27 16:18:24 2026 -0600

    feat(rust/sedona-functions): Simplify ST_KNN() call for most common case 
(#667)
---
 c/sedona-extension/src/scalar_kernel.rs            |  2 +-
 c/sedona-proj/src/st_transform.rs                  | 10 +--
 docs/reference/sql/st_knn.qmd                      | 22 ++++-
 python/sedonadb/tests/test_knnjoin.py              | 29 +++++--
 rust/sedona-expr/src/aggregate_udf.rs              | 18 +++--
 rust/sedona-expr/src/item_crs.rs                   |  2 +-
 rust/sedona-expr/src/scalar_udf.rs                 | 18 +++--
 rust/sedona-functions/src/st_knn.rs                |  4 +-
 .../src/planner/spatial_expr_utils.rs              | 93 +++++++++++++++++-----
 9 files changed, 150 insertions(+), 48 deletions(-)

diff --git a/c/sedona-extension/src/scalar_kernel.rs 
b/c/sedona-extension/src/scalar_kernel.rs
index 2928421c..66f0f787 100644
--- a/c/sedona-extension/src/scalar_kernel.rs
+++ b/c/sedona-extension/src/scalar_kernel.rs
@@ -735,7 +735,7 @@ mod test {
         let err = ffi_tester.return_type().unwrap_err();
         assert_eq!(
             err.message(),
-            "simple_udf_from_ffi([]): No kernel matching arguments"
+            "simple_udf_from_ffi(): No kernel matching arguments"
         );
     }
 
diff --git a/c/sedona-proj/src/st_transform.rs 
b/c/sedona-proj/src/st_transform.rs
index ccd7839f..efeb3b02 100644
--- a/c/sedona-proj/src/st_transform.rs
+++ b/c/sedona-proj/src/st_transform.rs
@@ -740,7 +740,7 @@ mod tests {
         let err = tester.return_type().unwrap_err();
         assert_eq!(
             err.message(),
-            "st_transform([]): No kernel matching arguments"
+            "st_transform(): No kernel matching arguments"
         );
 
         // Too many args
@@ -756,7 +756,7 @@ mod tests {
         let err = tester.return_type().unwrap_err();
         assert_eq!(
             err.message(),
-            "st_transform([Arrow(Utf8), Arrow(Utf8), Arrow(Utf8), 
Arrow(Utf8)]): No kernel matching arguments"
+            "st_transform(utf8, utf8, utf8, utf8): No kernel matching 
arguments"
         );
 
         // First arg not geometry
@@ -770,7 +770,7 @@ mod tests {
         let err = tester.return_type().unwrap_err();
         assert_eq!(
             err.message(),
-            "st_transform([Arrow(Utf8), Arrow(Utf8)]): No kernel matching 
arguments"
+            "st_transform(utf8, utf8): No kernel matching arguments"
         );
 
         // Second arg not string or numeric
@@ -781,7 +781,7 @@ mod tests {
         let err = tester.return_type().unwrap_err();
         assert_eq!(
             err.message(),
-            "st_transform([Wkb(Planar, None), Arrow(Boolean)]): No kernel 
matching arguments"
+            "st_transform(geometry, boolean): No kernel matching arguments"
         );
 
         // third arg not string or numeric
@@ -796,7 +796,7 @@ mod tests {
         let err = tester.return_type().unwrap_err();
         assert_eq!(
             err.message(),
-            "st_transform([Wkb(Planar, None), Arrow(Utf8), Arrow(Boolean)]): 
No kernel matching arguments"
+            "st_transform(geometry, utf8, boolean): No kernel matching 
arguments"
         );
     }
 
diff --git a/docs/reference/sql/st_knn.qmd b/docs/reference/sql/st_knn.qmd
index f83e6fd2..264b54bd 100644
--- a/docs/reference/sql/st_knn.qmd
+++ b/docs/reference/sql/st_knn.qmd
@@ -24,15 +24,29 @@ kernels:
     - name: geomA
       type: geometry
       description: The geometry around which to search.
-    - name: geomb
+    - name: geomB
       type: geometry
       description: Column containing candidate geometries.
     - name: k
       type: integer
-      description: The number of nearest neighbours to return.
+      description: The number of nearest neighbours to return. Defaults to 1.
     - name: use_spheroid
       type: boolean
-      description: true to use spherical distance, false for Euclidean
+      description: true to use spherical distance, false for Euclidean. 
Defaults to false.
+  - returns: boolean
+    args:
+    - name: geomA
+      type: geometry
+    - name: geomB
+      type: geometry
+    - name: k
+      type: integer
+  - returns: boolean
+    args:
+    - name: geomA
+      type: geometry
+    - name: geomB
+      type: geometry
 ---
 
 ## Description
@@ -54,5 +68,5 @@ WITH table2 AS (
 table1 AS (
     SELECT ST_Point(0, 0) AS geom2
 )
-SELECT * FROM table1 JOIN table2 ON ST_KNN(geom1, geom2, 3, false);
+SELECT * FROM table1 JOIN table2 ON ST_KNN(geom1, geom2, 3);
 ```
diff --git a/python/sedonadb/tests/test_knnjoin.py 
b/python/sedonadb/tests/test_knnjoin.py
index 331bcf1f..e23b1c56 100644
--- a/python/sedonadb/tests/test_knnjoin.py
+++ b/python/sedonadb/tests/test_knnjoin.py
@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 
+import pandas as pd
 import pytest
 import json
 from sedonadb.testing import PostGIS, SedonaDB, random_geometry
@@ -47,7 +48,7 @@ def test_knn_join_basic(k):
                 t.id as target_id,
                 ST_Distance(q.geometry, t.geometry) as distance
             FROM knn_query_points q
-            JOIN knn_target_points t ON ST_KNN(q.geometry, t.geometry, {k}, 
FALSE)
+            JOIN knn_target_points t ON ST_KNN(q.geometry, t.geometry, {k})
             ORDER BY query_id, distance
         """
 
@@ -117,7 +118,7 @@ def test_knn_join_with_polygons():
                 pol.id as polygon_id,
                 ST_Distance(p.geometry, pol.geometry) as distance
             FROM knn_points p
-            JOIN knn_polygons pol ON ST_KNN(p.geometry, pol.geometry, {k}, 
FALSE)
+            JOIN knn_polygons pol ON ST_KNN(p.geometry, pol.geometry, {k})
             ORDER BY point_id, distance
         """
 
@@ -182,7 +183,7 @@ def test_knn_join_edge_cases():
                 t.id as target_id,
                 ST_Distance(q.geometry, t.geometry) as distance
             FROM knn_query_small q
-            JOIN knn_target_small t ON ST_KNN(q.geometry, t.geometry, {k}, 
FALSE)
+            JOIN knn_target_small t ON ST_KNN(q.geometry, t.geometry, {k})
             ORDER BY query_id, distance
         """
 
@@ -272,7 +273,7 @@ def test_knn_join_with_attributes():
                 t.target_value,
                 ST_Distance(q.geometry, t.geometry) as distance
             FROM knn_points_attr q
-            JOIN knn_targets_attr t ON ST_KNN(q.geometry, t.geometry, {k}, 
FALSE)
+            JOIN knn_targets_attr t ON ST_KNN(q.geometry, t.geometry, {k})
             ORDER BY query_id, distance
         """
 
@@ -338,7 +339,7 @@ def test_knn_join_correctness_known_points():
                 t.id as target_id,
                 ST_Distance(q.geometry, t.geometry) as distance
             FROM knn_known q
-            JOIN knn_target_known t ON ST_KNN(q.geometry, t.geometry, {k}, 
FALSE)
+            JOIN knn_target_known t ON ST_KNN(q.geometry, t.geometry, {k})
             WHERE q.id = 0  -- Query from first point (synthetic data uses 
0-based IDs)
             ORDER BY distance
         """
@@ -374,3 +375,21 @@ def test_knn_join_correctness_known_points():
         """
 
         eng_postgis.assert_query_result(postgis_sql, sedonadb_results)
+
+
+def test_knn_join_default_k(con):
+    con.funcs.table.sd_random_geometry("Point", 100, seed=9483).to_view(
+        "l", overwrite=True
+    )
+    con.funcs.table.sd_random_geometry("Point", 100, seed=2983).to_view(
+        "r", overwrite=True
+    )
+
+    nearest_default_k = con.sql(
+        "SELECT l.id as lid, r.id as rid FROM l JOIN r ON ST_KNN(l.geometry, 
r.geometry) ORDER BY lid, rid"
+    ).to_pandas()
+    nearest_explicit_k = con.sql(
+        "SELECT l.id as lid, r.id as rid FROM l JOIN r ON ST_KNN(l.geometry, 
r.geometry, 1) ORDER BY lid, rid"
+    ).to_pandas()
+
+    pd.testing.assert_frame_equal(nearest_default_k, nearest_explicit_k)
diff --git a/rust/sedona-expr/src/aggregate_udf.rs 
b/rust/sedona-expr/src/aggregate_udf.rs
index e2d6c076..6dc8d88b 100644
--- a/rust/sedona-expr/src/aggregate_udf.rs
+++ b/rust/sedona-expr/src/aggregate_udf.rs
@@ -139,7 +139,16 @@ impl SedonaAggregateUDF {
             }
         }
 
-        not_impl_err!("{}({:?}): No kernel matching arguments", self.name, 
args)
+        let args_display = args
+            .iter()
+            .map(|arg| arg.logical_type_name())
+            .collect::<Vec<_>>()
+            .join(", ");
+
+        not_impl_err!(
+            "{}({args_display}): No kernel matching arguments",
+            self.name
+        )
     }
 }
 
@@ -268,15 +277,12 @@ mod test {
         );
         assert_eq!(udf.name(), "empty");
         let err = udf.return_field(&[]).unwrap_err();
-        assert_eq!(err.message(), "empty([]): No kernel matching arguments");
+        assert_eq!(err.message(), "empty(): No kernel matching arguments");
         assert!(udf.kernels().is_empty());
         assert_eq!(udf.coerce_types(&[])?, vec![]);
 
         let batch_err = udf.return_field(&[]).unwrap_err();
-        assert_eq!(
-            batch_err.message(),
-            "empty([]): No kernel matching arguments"
-        );
+        assert_eq!(batch_err.message(), "empty(): No kernel matching 
arguments");
 
         Ok(())
     }
diff --git a/rust/sedona-expr/src/item_crs.rs b/rust/sedona-expr/src/item_crs.rs
index d584cb22..8abf99c7 100644
--- a/rust/sedona-expr/src/item_crs.rs
+++ b/rust/sedona-expr/src/item_crs.rs
@@ -727,7 +727,7 @@ mod test {
         let err = tester.return_type().unwrap_err();
         assert_eq!(
             err.message(),
-            "fun([Wkb(Planar, None), Wkb(Planar, None)]): No kernel matching 
arguments"
+            "fun(geometry, geometry): No kernel matching arguments"
         );
     }
 
diff --git a/rust/sedona-expr/src/scalar_udf.rs 
b/rust/sedona-expr/src/scalar_udf.rs
index 37c6cff0..07d2b0ea 100644
--- a/rust/sedona-expr/src/scalar_udf.rs
+++ b/rust/sedona-expr/src/scalar_udf.rs
@@ -238,7 +238,16 @@ impl SedonaScalarUDF {
             }
         }
 
-        not_impl_err!("{}({:?}): No kernel matching arguments", self.name, 
args)
+        let args_display = args
+            .iter()
+            .map(|arg| arg.logical_type_name())
+            .collect::<Vec<_>>()
+            .join(", ");
+
+        not_impl_err!(
+            "{}({args_display}): No kernel matching arguments",
+            self.name
+        )
     }
 }
 
@@ -335,13 +344,10 @@ mod tests {
         let tester = ScalarUdfTester::new(udf.into(), vec![]);
 
         let err = tester.return_type().unwrap_err();
-        assert_eq!(err.message(), "empty([]): No kernel matching arguments");
+        assert_eq!(err.message(), "empty(): No kernel matching arguments");
 
         let batch_err = tester.invoke_arrays(vec![]).unwrap_err();
-        assert_eq!(
-            batch_err.message(),
-            "empty([]): No kernel matching arguments"
-        );
+        assert_eq!(batch_err.message(), "empty(): No kernel matching 
arguments");
 
         Ok(())
     }
diff --git a/rust/sedona-functions/src/st_knn.rs 
b/rust/sedona-functions/src/st_knn.rs
index 6ca3563b..3b9c67ca 100644
--- a/rust/sedona-functions/src/st_knn.rs
+++ b/rust/sedona-functions/src/st_knn.rs
@@ -33,8 +33,8 @@ pub fn st_knn_udf() -> SedonaScalarUDF {
             vec![
                 ArgMatcher::is_geometry_or_geography(),
                 ArgMatcher::is_geometry_or_geography(),
-                ArgMatcher::is_numeric(),
-                ArgMatcher::is_boolean(),
+                ArgMatcher::optional(ArgMatcher::is_numeric()),
+                ArgMatcher::optional(ArgMatcher::is_boolean()),
             ],
             SedonaType::Arrow(DataType::Boolean),
         ),
diff --git a/rust/sedona-spatial-join/src/planner/spatial_expr_utils.rs 
b/rust/sedona-spatial-join/src/planner/spatial_expr_utils.rs
index 3602e631..700b51bc 100644
--- a/rust/sedona-spatial-join/src/planner/spatial_expr_utils.rs
+++ b/rust/sedona-spatial-join/src/planner/spatial_expr_utils.rs
@@ -340,18 +340,36 @@ fn match_knn_predicate(
     }
 
     let args = scalar_fn.args();
-    if args.len() < 4 {
-        return None; // ST_KNN requires 4 arguments: (queries_geom, 
objects_geom, k, use_spheroid)
+
+    if args.len() < 2 {
+        return None;
     }
 
     let queries_geom = &args[0];
     let objects_geom = &args[1];
-    let k_expr = &args[2];
-    let use_spheroid_expr = &args[3];
 
-    // Extract literal values for k and use_spheroid
-    let k = extract_literal_u32(k_expr)?;
-    let use_spheroid = extract_literal_bool(use_spheroid_expr)?;
+    let (k, use_spheroid) = match args.len() {
+        2 => {
+            // Apply default k (1) and use_spheroid (false)
+            (1, false)
+        }
+        3 => {
+            // Extract literal values for k and apply default use_spheroid 
(false)
+            let k_expr = &args[2];
+            let k = extract_literal_u32(k_expr)?;
+            (k, false)
+        }
+        4 => {
+            // Extract literal values for k and use_spheroid
+            let k_expr = &args[2];
+            let k = extract_literal_u32(k_expr)?;
+
+            let use_spheroid_expr = &args[3];
+            let use_spheroid = extract_literal_bool(use_spheroid_expr)?;
+            (k, use_spheroid)
+        }
+        _ => return None,
+    };
 
     // Collect column references for geometry arguments
     let queries_refs = collect_column_references(queries_geom, column_indices);
@@ -1663,22 +1681,20 @@ mod tests {
     fn test_match_knn_predicate_basic() {
         let column_indices = create_test_column_indices();
 
-        // Create ST_KNN(left_geom, right_geom, 5, false)
+        // Create ST_KNN(left_geom, right_geom, 5, true)
         let left_geom = Arc::new(Column::new("left_geom", 1)) as Arc<dyn 
PhysicalExpr>;
         let right_geom = Arc::new(Column::new("right_geom", 2)) as Arc<dyn 
PhysicalExpr>;
         let k_literal =
             Arc::new(Literal::new(ScalarValue::Int32(Some(5)))) as Arc<dyn 
PhysicalExpr>;
         let use_spheroid_literal =
-            Arc::new(Literal::new(ScalarValue::Boolean(Some(false)))) as 
Arc<dyn PhysicalExpr>;
+            Arc::new(Literal::new(ScalarValue::Boolean(Some(true)))) as 
Arc<dyn PhysicalExpr>;
 
         let st_knn_udf = create_dummy_st_knn_udf();
         let args = vec![left_geom, right_geom, k_literal, 
use_spheroid_literal];
         let st_knn = create_spatial_function_expr(st_knn_udf, args);
 
-        let predicate = match_knn_predicate(&st_knn, &column_indices);
-        assert!(predicate.is_some());
+        let pred = match_knn_predicate(&st_knn, &column_indices).unwrap();
 
-        let pred = predicate.unwrap();
         // Verify left argument is reprojected to left side
         let left_arg_col = 
pred.left.as_any().downcast_ref::<Column>().unwrap();
         assert_eq!(left_arg_col.index(), 1);
@@ -1692,7 +1708,51 @@ mod tests {
         // Verify k is literal value 5
         assert_eq!(pred.k, 5);
 
-        // Verify use_spheroid is literal value false
+        // Verify use_spheroid is literal value true
+        assert!(pred.use_spheroid);
+    }
+
+    #[test]
+    fn test_match_knn_predicate_default_use_spheroid() {
+        let column_indices = create_test_column_indices();
+
+        // Create ST_KNN(left_geom, right_geom, 5)
+        let left_geom = Arc::new(Column::new("left_geom", 1)) as Arc<dyn 
PhysicalExpr>;
+        let right_geom = Arc::new(Column::new("right_geom", 2)) as Arc<dyn 
PhysicalExpr>;
+        let k_literal =
+            Arc::new(Literal::new(ScalarValue::Int32(Some(5)))) as Arc<dyn 
PhysicalExpr>;
+
+        let st_knn_udf = create_dummy_st_knn_udf();
+        let args = vec![left_geom, right_geom, k_literal];
+        let st_knn = create_spatial_function_expr(st_knn_udf, args);
+
+        let pred = match_knn_predicate(&st_knn, &column_indices).unwrap();
+
+        // Verify k is literal value 5
+        assert_eq!(pred.k, 5);
+
+        // Verify use_spheroid is literal value false (default)
+        assert!(!pred.use_spheroid);
+    }
+
+    #[test]
+    fn test_match_knn_predicate_default_k() {
+        let column_indices = create_test_column_indices();
+
+        // Create ST_KNN(left_geom, right_geom)
+        let left_geom = Arc::new(Column::new("left_geom", 1)) as Arc<dyn 
PhysicalExpr>;
+        let right_geom = Arc::new(Column::new("right_geom", 2)) as Arc<dyn 
PhysicalExpr>;
+
+        let st_knn_udf = create_dummy_st_knn_udf();
+        let args = vec![left_geom, right_geom];
+        let st_knn = create_spatial_function_expr(st_knn_udf, args);
+
+        let pred = match_knn_predicate(&st_knn, &column_indices).unwrap();
+
+        // Verify k is literal value 1 (default)
+        assert_eq!(pred.k, 1);
+
+        // Verify use_spheroid is literal value false (default)
         assert!(!pred.use_spheroid);
     }
 
@@ -1924,14 +1984,11 @@ mod tests {
     fn test_match_knn_predicate_insufficient_args() {
         let column_indices = create_test_column_indices();
 
-        // Create ST_KNN with only 3 arguments (insufficient - needs 4)
+        // Create ST_KNN with only 1 argument (insufficient - needs at least 2)
         let left_geom = Arc::new(Column::new("left_geom", 1)) as Arc<dyn 
PhysicalExpr>;
-        let right_geom = Arc::new(Column::new("right_geom", 2)) as Arc<dyn 
PhysicalExpr>;
-        let k_literal =
-            Arc::new(Literal::new(ScalarValue::Int32(Some(5)))) as Arc<dyn 
PhysicalExpr>;
 
         let st_knn_udf = create_dummy_st_knn_udf();
-        let args = vec![left_geom, right_geom, k_literal]; // Missing 
use_spheroid arg
+        let args = vec![left_geom];
         let st_knn = create_spatial_function_expr(st_knn_udf, args);
 
         let predicate = match_knn_predicate(&st_knn, &column_indices);

Reply via email to