This is an automated email from the ASF dual-hosted git repository.
paleolimbot pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/sedona-db.git
The following commit(s) were added to refs/heads/main by this push:
new 6f1d2704 feat(rust/sedona-functions): Simplify ST_KNN() call for most
common case (#667)
6f1d2704 is described below
commit 6f1d27044775efdfd082b0ab48d396b619101221
Author: Dewey Dunnington <[email protected]>
AuthorDate: Fri Feb 27 16:18:24 2026 -0600
feat(rust/sedona-functions): Simplify ST_KNN() call for most common case
(#667)
---
c/sedona-extension/src/scalar_kernel.rs | 2 +-
c/sedona-proj/src/st_transform.rs | 10 +--
docs/reference/sql/st_knn.qmd | 22 ++++-
python/sedonadb/tests/test_knnjoin.py | 29 +++++--
rust/sedona-expr/src/aggregate_udf.rs | 18 +++--
rust/sedona-expr/src/item_crs.rs | 2 +-
rust/sedona-expr/src/scalar_udf.rs | 18 +++--
rust/sedona-functions/src/st_knn.rs | 4 +-
.../src/planner/spatial_expr_utils.rs | 93 +++++++++++++++++-----
9 files changed, 150 insertions(+), 48 deletions(-)
diff --git a/c/sedona-extension/src/scalar_kernel.rs
b/c/sedona-extension/src/scalar_kernel.rs
index 2928421c..66f0f787 100644
--- a/c/sedona-extension/src/scalar_kernel.rs
+++ b/c/sedona-extension/src/scalar_kernel.rs
@@ -735,7 +735,7 @@ mod test {
let err = ffi_tester.return_type().unwrap_err();
assert_eq!(
err.message(),
- "simple_udf_from_ffi([]): No kernel matching arguments"
+ "simple_udf_from_ffi(): No kernel matching arguments"
);
}
diff --git a/c/sedona-proj/src/st_transform.rs
b/c/sedona-proj/src/st_transform.rs
index ccd7839f..efeb3b02 100644
--- a/c/sedona-proj/src/st_transform.rs
+++ b/c/sedona-proj/src/st_transform.rs
@@ -740,7 +740,7 @@ mod tests {
let err = tester.return_type().unwrap_err();
assert_eq!(
err.message(),
- "st_transform([]): No kernel matching arguments"
+ "st_transform(): No kernel matching arguments"
);
// Too many args
@@ -756,7 +756,7 @@ mod tests {
let err = tester.return_type().unwrap_err();
assert_eq!(
err.message(),
- "st_transform([Arrow(Utf8), Arrow(Utf8), Arrow(Utf8),
Arrow(Utf8)]): No kernel matching arguments"
+ "st_transform(utf8, utf8, utf8, utf8): No kernel matching
arguments"
);
// First arg not geometry
@@ -770,7 +770,7 @@ mod tests {
let err = tester.return_type().unwrap_err();
assert_eq!(
err.message(),
- "st_transform([Arrow(Utf8), Arrow(Utf8)]): No kernel matching
arguments"
+ "st_transform(utf8, utf8): No kernel matching arguments"
);
// Second arg not string or numeric
@@ -781,7 +781,7 @@ mod tests {
let err = tester.return_type().unwrap_err();
assert_eq!(
err.message(),
- "st_transform([Wkb(Planar, None), Arrow(Boolean)]): No kernel
matching arguments"
+ "st_transform(geometry, boolean): No kernel matching arguments"
);
// third arg not string or numeric
@@ -796,7 +796,7 @@ mod tests {
let err = tester.return_type().unwrap_err();
assert_eq!(
err.message(),
- "st_transform([Wkb(Planar, None), Arrow(Utf8), Arrow(Boolean)]):
No kernel matching arguments"
+ "st_transform(geometry, utf8, boolean): No kernel matching
arguments"
);
}
diff --git a/docs/reference/sql/st_knn.qmd b/docs/reference/sql/st_knn.qmd
index f83e6fd2..264b54bd 100644
--- a/docs/reference/sql/st_knn.qmd
+++ b/docs/reference/sql/st_knn.qmd
@@ -24,15 +24,29 @@ kernels:
- name: geomA
type: geometry
description: The geometry around which to search.
- - name: geomb
+ - name: geomB
type: geometry
description: Column containing candidate geometries.
- name: k
type: integer
- description: The number of nearest neighbours to return.
+ description: The number of nearest neighbours to return. Defaults to 1.
- name: use_spheroid
type: boolean
- description: true to use spherical distance, false for Euclidean
+ description: true to use spherical distance, false for Euclidean.
Defaults to false.
+ - returns: boolean
+ args:
+ - name: geomA
+ type: geometry
+ - name: geomB
+ type: geometry
+ - name: k
+ type: integer
+ - returns: boolean
+ args:
+ - name: geomA
+ type: geometry
+ - name: geomB
+ type: geometry
---
## Description
@@ -54,5 +68,5 @@ WITH table2 AS (
table1 AS (
SELECT ST_Point(0, 0) AS geom2
)
-SELECT * FROM table1 JOIN table2 ON ST_KNN(geom1, geom2, 3, false);
+SELECT * FROM table1 JOIN table2 ON ST_KNN(geom1, geom2, 3);
```
diff --git a/python/sedonadb/tests/test_knnjoin.py
b/python/sedonadb/tests/test_knnjoin.py
index 331bcf1f..e23b1c56 100644
--- a/python/sedonadb/tests/test_knnjoin.py
+++ b/python/sedonadb/tests/test_knnjoin.py
@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
+import pandas as pd
import pytest
import json
from sedonadb.testing import PostGIS, SedonaDB, random_geometry
@@ -47,7 +48,7 @@ def test_knn_join_basic(k):
t.id as target_id,
ST_Distance(q.geometry, t.geometry) as distance
FROM knn_query_points q
- JOIN knn_target_points t ON ST_KNN(q.geometry, t.geometry, {k},
FALSE)
+ JOIN knn_target_points t ON ST_KNN(q.geometry, t.geometry, {k})
ORDER BY query_id, distance
"""
@@ -117,7 +118,7 @@ def test_knn_join_with_polygons():
pol.id as polygon_id,
ST_Distance(p.geometry, pol.geometry) as distance
FROM knn_points p
- JOIN knn_polygons pol ON ST_KNN(p.geometry, pol.geometry, {k},
FALSE)
+ JOIN knn_polygons pol ON ST_KNN(p.geometry, pol.geometry, {k})
ORDER BY point_id, distance
"""
@@ -182,7 +183,7 @@ def test_knn_join_edge_cases():
t.id as target_id,
ST_Distance(q.geometry, t.geometry) as distance
FROM knn_query_small q
- JOIN knn_target_small t ON ST_KNN(q.geometry, t.geometry, {k},
FALSE)
+ JOIN knn_target_small t ON ST_KNN(q.geometry, t.geometry, {k})
ORDER BY query_id, distance
"""
@@ -272,7 +273,7 @@ def test_knn_join_with_attributes():
t.target_value,
ST_Distance(q.geometry, t.geometry) as distance
FROM knn_points_attr q
- JOIN knn_targets_attr t ON ST_KNN(q.geometry, t.geometry, {k},
FALSE)
+ JOIN knn_targets_attr t ON ST_KNN(q.geometry, t.geometry, {k})
ORDER BY query_id, distance
"""
@@ -338,7 +339,7 @@ def test_knn_join_correctness_known_points():
t.id as target_id,
ST_Distance(q.geometry, t.geometry) as distance
FROM knn_known q
- JOIN knn_target_known t ON ST_KNN(q.geometry, t.geometry, {k},
FALSE)
+ JOIN knn_target_known t ON ST_KNN(q.geometry, t.geometry, {k})
WHERE q.id = 0 -- Query from first point (synthetic data uses
0-based IDs)
ORDER BY distance
"""
@@ -374,3 +375,21 @@ def test_knn_join_correctness_known_points():
"""
eng_postgis.assert_query_result(postgis_sql, sedonadb_results)
+
+
+def test_knn_join_default_k(con):
+ con.funcs.table.sd_random_geometry("Point", 100, seed=9483).to_view(
+ "l", overwrite=True
+ )
+ con.funcs.table.sd_random_geometry("Point", 100, seed=2983).to_view(
+ "r", overwrite=True
+ )
+
+ nearest_default_k = con.sql(
+ "SELECT l.id as lid, r.id as rid FROM l JOIN r ON ST_KNN(l.geometry,
r.geometry) ORDER BY lid, rid"
+ ).to_pandas()
+ nearest_explicit_k = con.sql(
+ "SELECT l.id as lid, r.id as rid FROM l JOIN r ON ST_KNN(l.geometry,
r.geometry, 1) ORDER BY lid, rid"
+ ).to_pandas()
+
+ pd.testing.assert_frame_equal(nearest_default_k, nearest_explicit_k)
diff --git a/rust/sedona-expr/src/aggregate_udf.rs
b/rust/sedona-expr/src/aggregate_udf.rs
index e2d6c076..6dc8d88b 100644
--- a/rust/sedona-expr/src/aggregate_udf.rs
+++ b/rust/sedona-expr/src/aggregate_udf.rs
@@ -139,7 +139,16 @@ impl SedonaAggregateUDF {
}
}
- not_impl_err!("{}({:?}): No kernel matching arguments", self.name,
args)
+ let args_display = args
+ .iter()
+ .map(|arg| arg.logical_type_name())
+ .collect::<Vec<_>>()
+ .join(", ");
+
+ not_impl_err!(
+ "{}({args_display}): No kernel matching arguments",
+ self.name
+ )
}
}
@@ -268,15 +277,12 @@ mod test {
);
assert_eq!(udf.name(), "empty");
let err = udf.return_field(&[]).unwrap_err();
- assert_eq!(err.message(), "empty([]): No kernel matching arguments");
+ assert_eq!(err.message(), "empty(): No kernel matching arguments");
assert!(udf.kernels().is_empty());
assert_eq!(udf.coerce_types(&[])?, vec![]);
let batch_err = udf.return_field(&[]).unwrap_err();
- assert_eq!(
- batch_err.message(),
- "empty([]): No kernel matching arguments"
- );
+ assert_eq!(batch_err.message(), "empty(): No kernel matching
arguments");
Ok(())
}
diff --git a/rust/sedona-expr/src/item_crs.rs b/rust/sedona-expr/src/item_crs.rs
index d584cb22..8abf99c7 100644
--- a/rust/sedona-expr/src/item_crs.rs
+++ b/rust/sedona-expr/src/item_crs.rs
@@ -727,7 +727,7 @@ mod test {
let err = tester.return_type().unwrap_err();
assert_eq!(
err.message(),
- "fun([Wkb(Planar, None), Wkb(Planar, None)]): No kernel matching
arguments"
+ "fun(geometry, geometry): No kernel matching arguments"
);
}
diff --git a/rust/sedona-expr/src/scalar_udf.rs
b/rust/sedona-expr/src/scalar_udf.rs
index 37c6cff0..07d2b0ea 100644
--- a/rust/sedona-expr/src/scalar_udf.rs
+++ b/rust/sedona-expr/src/scalar_udf.rs
@@ -238,7 +238,16 @@ impl SedonaScalarUDF {
}
}
- not_impl_err!("{}({:?}): No kernel matching arguments", self.name,
args)
+ let args_display = args
+ .iter()
+ .map(|arg| arg.logical_type_name())
+ .collect::<Vec<_>>()
+ .join(", ");
+
+ not_impl_err!(
+ "{}({args_display}): No kernel matching arguments",
+ self.name
+ )
}
}
@@ -335,13 +344,10 @@ mod tests {
let tester = ScalarUdfTester::new(udf.into(), vec![]);
let err = tester.return_type().unwrap_err();
- assert_eq!(err.message(), "empty([]): No kernel matching arguments");
+ assert_eq!(err.message(), "empty(): No kernel matching arguments");
let batch_err = tester.invoke_arrays(vec![]).unwrap_err();
- assert_eq!(
- batch_err.message(),
- "empty([]): No kernel matching arguments"
- );
+ assert_eq!(batch_err.message(), "empty(): No kernel matching
arguments");
Ok(())
}
diff --git a/rust/sedona-functions/src/st_knn.rs
b/rust/sedona-functions/src/st_knn.rs
index 6ca3563b..3b9c67ca 100644
--- a/rust/sedona-functions/src/st_knn.rs
+++ b/rust/sedona-functions/src/st_knn.rs
@@ -33,8 +33,8 @@ pub fn st_knn_udf() -> SedonaScalarUDF {
vec![
ArgMatcher::is_geometry_or_geography(),
ArgMatcher::is_geometry_or_geography(),
- ArgMatcher::is_numeric(),
- ArgMatcher::is_boolean(),
+ ArgMatcher::optional(ArgMatcher::is_numeric()),
+ ArgMatcher::optional(ArgMatcher::is_boolean()),
],
SedonaType::Arrow(DataType::Boolean),
),
diff --git a/rust/sedona-spatial-join/src/planner/spatial_expr_utils.rs
b/rust/sedona-spatial-join/src/planner/spatial_expr_utils.rs
index 3602e631..700b51bc 100644
--- a/rust/sedona-spatial-join/src/planner/spatial_expr_utils.rs
+++ b/rust/sedona-spatial-join/src/planner/spatial_expr_utils.rs
@@ -340,18 +340,36 @@ fn match_knn_predicate(
}
let args = scalar_fn.args();
- if args.len() < 4 {
- return None; // ST_KNN requires 4 arguments: (queries_geom,
objects_geom, k, use_spheroid)
+
+ if args.len() < 2 {
+ return None;
}
let queries_geom = &args[0];
let objects_geom = &args[1];
- let k_expr = &args[2];
- let use_spheroid_expr = &args[3];
- // Extract literal values for k and use_spheroid
- let k = extract_literal_u32(k_expr)?;
- let use_spheroid = extract_literal_bool(use_spheroid_expr)?;
+ let (k, use_spheroid) = match args.len() {
+ 2 => {
+ // Apply default k (1) and use_spheroid (false)
+ (1, false)
+ }
+ 3 => {
+ // Extract literal values for k and apply default use_spheroid
(false)
+ let k_expr = &args[2];
+ let k = extract_literal_u32(k_expr)?;
+ (k, false)
+ }
+ 4 => {
+ // Extract literal values for k and use_spheroid
+ let k_expr = &args[2];
+ let k = extract_literal_u32(k_expr)?;
+
+ let use_spheroid_expr = &args[3];
+ let use_spheroid = extract_literal_bool(use_spheroid_expr)?;
+ (k, use_spheroid)
+ }
+ _ => return None,
+ };
// Collect column references for geometry arguments
let queries_refs = collect_column_references(queries_geom, column_indices);
@@ -1663,22 +1681,20 @@ mod tests {
fn test_match_knn_predicate_basic() {
let column_indices = create_test_column_indices();
- // Create ST_KNN(left_geom, right_geom, 5, false)
+ // Create ST_KNN(left_geom, right_geom, 5, true)
let left_geom = Arc::new(Column::new("left_geom", 1)) as Arc<dyn
PhysicalExpr>;
let right_geom = Arc::new(Column::new("right_geom", 2)) as Arc<dyn
PhysicalExpr>;
let k_literal =
Arc::new(Literal::new(ScalarValue::Int32(Some(5)))) as Arc<dyn
PhysicalExpr>;
let use_spheroid_literal =
- Arc::new(Literal::new(ScalarValue::Boolean(Some(false)))) as
Arc<dyn PhysicalExpr>;
+ Arc::new(Literal::new(ScalarValue::Boolean(Some(true)))) as
Arc<dyn PhysicalExpr>;
let st_knn_udf = create_dummy_st_knn_udf();
let args = vec![left_geom, right_geom, k_literal,
use_spheroid_literal];
let st_knn = create_spatial_function_expr(st_knn_udf, args);
- let predicate = match_knn_predicate(&st_knn, &column_indices);
- assert!(predicate.is_some());
+ let pred = match_knn_predicate(&st_knn, &column_indices).unwrap();
- let pred = predicate.unwrap();
// Verify left argument is reprojected to left side
let left_arg_col =
pred.left.as_any().downcast_ref::<Column>().unwrap();
assert_eq!(left_arg_col.index(), 1);
@@ -1692,7 +1708,51 @@ mod tests {
// Verify k is literal value 5
assert_eq!(pred.k, 5);
- // Verify use_spheroid is literal value false
+ // Verify use_spheroid is literal value true
+ assert!(pred.use_spheroid);
+ }
+
+ #[test]
+ fn test_match_knn_predicate_default_use_spheroid() {
+ let column_indices = create_test_column_indices();
+
+ // Create ST_KNN(left_geom, right_geom, 5)
+ let left_geom = Arc::new(Column::new("left_geom", 1)) as Arc<dyn
PhysicalExpr>;
+ let right_geom = Arc::new(Column::new("right_geom", 2)) as Arc<dyn
PhysicalExpr>;
+ let k_literal =
+ Arc::new(Literal::new(ScalarValue::Int32(Some(5)))) as Arc<dyn
PhysicalExpr>;
+
+ let st_knn_udf = create_dummy_st_knn_udf();
+ let args = vec![left_geom, right_geom, k_literal];
+ let st_knn = create_spatial_function_expr(st_knn_udf, args);
+
+ let pred = match_knn_predicate(&st_knn, &column_indices).unwrap();
+
+ // Verify k is literal value 5
+ assert_eq!(pred.k, 5);
+
+ // Verify use_spheroid falls back to its default value (false)
+ assert!(!pred.use_spheroid);
+ }
+
+ #[test]
+ fn test_match_knn_predicate_default_k() {
+ let column_indices = create_test_column_indices();
+
+ // Create ST_KNN(left_geom, right_geom)
+ let left_geom = Arc::new(Column::new("left_geom", 1)) as Arc<dyn
PhysicalExpr>;
+ let right_geom = Arc::new(Column::new("right_geom", 2)) as Arc<dyn
PhysicalExpr>;
+
+ let st_knn_udf = create_dummy_st_knn_udf();
+ let args = vec![left_geom, right_geom];
+ let st_knn = create_spatial_function_expr(st_knn_udf, args);
+
+ let pred = match_knn_predicate(&st_knn, &column_indices).unwrap();
+
+ // Verify k falls back to its default value (1)
+ assert_eq!(pred.k, 1);
+
+ // Verify use_spheroid falls back to its default value (false)
assert!(!pred.use_spheroid);
}
@@ -1924,14 +1984,11 @@ mod tests {
fn test_match_knn_predicate_insufficient_args() {
let column_indices = create_test_column_indices();
- // Create ST_KNN with only 3 arguments (insufficient - needs 4)
+ // Create ST_KNN with only 1 argument (insufficient - needs at least 2)
let left_geom = Arc::new(Column::new("left_geom", 1)) as Arc<dyn
PhysicalExpr>;
- let right_geom = Arc::new(Column::new("right_geom", 2)) as Arc<dyn
PhysicalExpr>;
- let k_literal =
- Arc::new(Literal::new(ScalarValue::Int32(Some(5)))) as Arc<dyn
PhysicalExpr>;
let st_knn_udf = create_dummy_st_knn_udf();
- let args = vec![left_geom, right_geom, k_literal]; // Missing
use_spheroid arg
+ let args = vec![left_geom];
let st_knn = create_spatial_function_expr(st_knn_udf, args);
let predicate = match_knn_predicate(&st_knn, &column_indices);