This is an automated email from the ASF dual-hosted git repository.
jiayu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/sedona-db.git
The following commit(s) were added to refs/heads/main by this push:
new 9efa952 ST_Distance and ST_DWithin based on georust/geo (#73)
9efa952 is described below
commit 9efa952c45eefe54c9e9f998dd8f2b204705c3f4
Author: Kristin Cowalcijk <[email protected]>
AuthorDate: Sat Sep 13 11:05:42 2025 +0800
ST_Distance and ST_DWithin based on georust/geo (#73)
---
Cargo.lock | 62 +++++------
c/sedona-geos/src/st_dwithin.rs | 31 +++---
rust/sedona-geo/benches/geo-functions.rs | 24 ++++
rust/sedona-geo/src/lib.rs | 2 +
rust/sedona-geo/src/register.rs | 6 +-
rust/sedona-geo/src/st_distance.rs | 123 +++++++++++++++++++++
{c/sedona-geos => rust/sedona-geo}/src/st_dwithin.rs | 114 +++++++++++--------
7 files changed, 264 insertions(+), 98 deletions(-)
diff --git a/Cargo.lock b/Cargo.lock
index 5a975d9..2337df6 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -794,9 +794,9 @@ dependencies = [
[[package]]
name = "aws-smithy-runtime"
-version = "1.9.1"
+version = "1.9.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "d3946acbe1ead1301ba6862e712c7903ca9bb230bdf1fbd1b5ac54158ef2ab1f"
+checksum = "4fa63ad37685ceb7762fa4d73d06f1d5493feb88e3f27259b9ed277f4c01b185"
dependencies = [
"aws-smithy-async",
"aws-smithy-http",
@@ -1108,9 +1108,9 @@ checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5"
[[package]]
name = "cc"
-version = "1.2.36"
+version = "1.2.37"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "5252b3d2648e5eedbc1a6f501e3c795e07025c1e93bbf8bbdd6eef7f447a6d54"
+checksum = "65193589c6404eb80b450d618eaf9a2cafaaafd57ecce47370519ef674a7bd44"
dependencies = [
"find-msvc-tools",
"jobserver",
@@ -1265,9 +1265,9 @@ checksum = "b05b61dc5112cbb17e4b6cd61790d9845d13888356391624cbe7e41efeac1e75"
[[package]]
name = "comfy-table"
-version = "7.2.0"
+version = "7.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "3f8e18d0dca9578507f13f9803add0df13362b02c501c1c17734f0dbb52eaf0b"
+checksum = "b03b7db8e0b4b2fdad6c551e634134e99ec000e5c8c3b6856c65e8bbaded7a3b"
dependencies = [
"crossterm",
"unicode-segmentation",
@@ -1296,9 +1296,9 @@ dependencies = [
[[package]]
name = "const_panic"
-version = "0.2.14"
+version = "0.2.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "bb8a602185c3c95b52f86dc78e55a6df9a287a7a93ddbcf012509930880cf879"
+checksum = "e262cdaac42494e3ae34c43969f9cdeb7da178bdb4b66fa6a1ea2edb4c8ae652"
dependencies = [
"typewit",
]
@@ -2617,7 +2617,7 @@ dependencies = [
[[package]]
name = "geo-generic-alg"
version = "0.1.0"
-source = "git+https://github.com/wherobots/geo.git?branch=generic-alg#86f5f3aba769d998336a3c36ac0a3d9f5536596f"
+source = "git+https://github.com/wherobots/geo.git?branch=generic-alg#567d1da8e094e74d0fdc5ce90f9bdd8f90dae40b"
dependencies = [
"earcutr",
"float_next_after",
@@ -2652,7 +2652,7 @@ dependencies = [
[[package]]
name = "geo-traits"
version = "0.2.0"
-source = "git+https://github.com/wherobots/geo.git?branch=generic-alg#86f5f3aba769d998336a3c36ac0a3d9f5536596f"
+source = "git+https://github.com/wherobots/geo.git?branch=generic-alg#567d1da8e094e74d0fdc5ce90f9bdd8f90dae40b"
dependencies = [
"geo-types",
]
@@ -2669,7 +2669,7 @@ dependencies = [
[[package]]
name = "geo-traits-ext"
version = "0.1.0"
-source = "git+https://github.com/wherobots/geo.git?branch=generic-alg#86f5f3aba769d998336a3c36ac0a3d9f5536596f"
+source = "git+https://github.com/wherobots/geo.git?branch=generic-alg#567d1da8e094e74d0fdc5ce90f9bdd8f90dae40b"
dependencies = [
"approx",
"geo-traits 0.2.0",
@@ -2681,7 +2681,7 @@ dependencies = [
[[package]]
name = "geo-types"
version = "0.7.16"
-source = "git+https://github.com/wherobots/geo.git?branch=generic-alg#86f5f3aba769d998336a3c36ac0a3d9f5536596f"
+source = "git+https://github.com/wherobots/geo.git?branch=generic-alg#567d1da8e094e74d0fdc5ce90f9bdd8f90dae40b"
dependencies = [
"approx",
"num-traits",
@@ -2933,9 +2933,9 @@ checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87"
[[package]]
name = "humantime"
-version = "2.2.0"
+version = "2.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "9b112acc8b3adf4b107a8ec20977da0273a8c386765a3ec0229bd500a1443f9f"
+checksum = "135b12329e5e3ce057a9f972339ea52bc954fe1e9358ef27f95e89716fbc5424"
[[package]]
name = "hyper"
@@ -3046,9 +3046,9 @@ checksum = "155181bc97d770181cf9477da51218a19ee92a8e5be642e796661aee2b601139"
[[package]]
name = "iana-time-zone"
-version = "0.1.63"
+version = "0.1.64"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "b0c919e5debc312ad217002b8048a17b7d83f80703865bbfcfebb0458b0b27d8"
+checksum = "33e57f83510bb73707521ebaffa789ec8caf86f9657cad665b092b581d40e9fb"
dependencies = [
"android_system_properties",
"core-foundation-sys",
@@ -4641,9 +4641,9 @@ dependencies = [
[[package]]
name = "rustls-webpki"
-version = "0.103.4"
+version = "0.103.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "0a17884ae0c1b773f1ccd2bd4a8c72f16da897310a98b0e84bf349ad5ead92fc"
+checksum = "b5a37813727b78798e53c2bec3f5e8fe12a6d6f8389bf9ca7802add4c9905ad8"
dependencies = [
"aws-lc-rs",
"ring",
@@ -5855,15 +5855,15 @@ checksum = "1dccffe3ce07af9386bfd29e80c0ab1a8205a2fc34e4bcd40364df902cfa8f3f"
[[package]]
name = "typewit"
-version = "1.14.1"
+version = "1.14.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "4c98488b93df24b7c794d6a58c4198d7a2abde676324beaca84f7fb5b39c0811"
+checksum = "f8c1ae7cc0fdb8b842d65d127cb981574b0d2b249b74d1c7a2986863dc134f71"
[[package]]
name = "unicode-ident"
-version = "1.0.18"
+version = "1.0.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "5a5f39404a5da50712a4c1eecf25e90dd62b613502b7e925fd4e4d19b5c96512"
+checksum = "f63a545481291138910575129486daeaf8ac54aee4387fe7906919f7830c7d9d"
[[package]]
name = "unicode-segmentation"
@@ -6130,13 +6130,13 @@ checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f"
[[package]]
name = "windows-core"
-version = "0.61.2"
+version = "0.62.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "c0fdd3ddb90610c7638aa2b3a3ab2904fb9e5cdbecc643ddb3647212781c4ae3"
+checksum = "57fe7168f7de578d2d8a05b07fd61870d2e73b4020e9f49aa00da8471723497c"
dependencies = [
"windows-implement",
"windows-interface",
- "windows-link 0.1.3",
+ "windows-link 0.2.0",
"windows-result",
"windows-strings",
]
@@ -6177,20 +6177,20 @@ checksum = "45e46c0661abb7180e7b9c281db115305d49ca1709ab8242adf09666d2173c65"
[[package]]
name = "windows-result"
-version = "0.3.4"
+version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "56f42bd332cc6c8eac5af113fc0c1fd6a8fd2aa08a0119358686e5160d0586c6"
+checksum = "7084dcc306f89883455a206237404d3eaf961e5bd7e0f312f7c91f57eb44167f"
dependencies = [
- "windows-link 0.1.3",
+ "windows-link 0.2.0",
]
[[package]]
name = "windows-strings"
-version = "0.4.2"
+version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "56e6c93f3a0c3b36176cb1327a4958a0353d5d166c2a35cb268ace15e91d3b57"
+checksum = "7218c655a553b0bed4426cf54b20d7ba363ef543b52d515b3e48d7fd55318dda"
dependencies = [
- "windows-link 0.1.3",
+ "windows-link 0.2.0",
]
[[package]]
diff --git a/c/sedona-geos/src/st_dwithin.rs b/c/sedona-geos/src/st_dwithin.rs
index 03bbaa1..bc7bbd1 100644
--- a/c/sedona-geos/src/st_dwithin.rs
+++ b/c/sedona-geos/src/st_dwithin.rs
@@ -18,7 +18,7 @@ use std::sync::Arc;
use arrow_array::builder::BooleanBuilder;
use arrow_schema::DataType;
-use datafusion_common::{error::Result, DataFusionError};
+use datafusion_common::{cast::as_float64_array, error::Result,
DataFusionError};
use datafusion_expr::ColumnarValue;
use geos::Geom;
use sedona_expr::scalar_udf::{ScalarKernelRef, SedonaScalarKernel};
@@ -53,26 +53,14 @@ impl SedonaScalarKernel for STDWithin {
arg_types: &[SedonaType],
args: &[ColumnarValue],
) -> Result<ColumnarValue> {
- // Extract the constant scalar value before looping over the input
geometries
- let distance: Option<f64>;
let arg2 = args[2].cast_to(&DataType::Float64, None)?;
- if let ColumnarValue::Scalar(scalar_arg) = &arg2 {
- if scalar_arg.is_null() {
- distance = None;
- } else {
- distance = Some(f64::try_from(scalar_arg.clone())?);
- }
- } else {
- return Err(DataFusionError::Execution(format!(
- "Invalid distance: {:?}",
- args[2]
- )));
- }
-
let executor = GeosExecutor::new(arg_types, args);
+ let arg2_array = arg2.to_array(executor.num_iterations())?;
+ let arg2_f64_array = as_float64_array(&arg2_array)?;
+ let mut arg2_iter = arg2_f64_array.iter();
let mut builder =
BooleanBuilder::with_capacity(executor.num_iterations());
executor.execute_wkb_wkb_void(|lhs, rhs| {
- match (lhs, rhs, distance) {
+ match (lhs, rhs, arg2_iter.next().unwrap()) {
(Some(lhs), Some(rhs), Some(distance)) => {
builder.append_value(invoke_scalar(lhs, rhs, distance)?);
}
@@ -151,9 +139,16 @@ mod tests {
let expected: ArrayRef = arrow_array!(Boolean, [Some(true),
Some(false), None, Some(true)]);
assert_array_equal(
&tester
- .invoke_array_array_scalar(arg1, arg2, distance)
+ .invoke_array_array_scalar(Arc::clone(&arg1),
Arc::clone(&arg2), distance)
.unwrap(),
&expected,
);
+
+ let distance = arrow_array!(Int32, [Some(1), Some(1), Some(1),
Some(1)]);
+ let expected: ArrayRef = arrow_array!(Boolean, [Some(true),
Some(false), None, Some(true)]);
+ assert_array_equal(
+ &tester.invoke_arrays(vec![arg1, arg2, distance]).unwrap(),
+ &expected,
+ );
}
}
diff --git a/rust/sedona-geo/benches/geo-functions.rs b/rust/sedona-geo/benches/geo-functions.rs
index 3576cbe..18d288c 100644
--- a/rust/sedona-geo/benches/geo-functions.rs
+++ b/rust/sedona-geo/benches/geo-functions.rs
@@ -41,6 +41,30 @@ fn criterion_benchmark(c: &mut Criterion) {
"st_intersects",
ArrayScalar(Point, Polygon(500)),
);
+
+ benchmark::scalar(c, &f, "geo", "st_distance", ArrayScalar(Point,
Polygon(10)));
+ benchmark::scalar(
+ c,
+ &f,
+ "geo",
+ "st_distance",
+ ArrayScalar(Point, Polygon(500)),
+ );
+
+ benchmark::scalar(
+ c,
+ &f,
+ "geo",
+ "st_dwithin",
+ ArrayArrayScalar(Polygon(10), Polygon(10), Float64(1.0, 2.0)),
+ );
+ benchmark::scalar(
+ c,
+ &f,
+ "geo",
+ "st_dwithin",
+ ArrayArrayScalar(Polygon(10), Polygon(500), Float64(1.0, 2.0)),
+ );
}
fn criterion_benchmark_aggr(c: &mut Criterion) {
diff --git a/rust/sedona-geo/src/lib.rs b/rust/sedona-geo/src/lib.rs
index 24c6408..6af5dd2 100644
--- a/rust/sedona-geo/src/lib.rs
+++ b/rust/sedona-geo/src/lib.rs
@@ -18,6 +18,8 @@ pub mod centroid;
pub mod register;
mod st_area;
mod st_centroid;
+mod st_distance;
+mod st_dwithin;
mod st_intersection_aggr;
mod st_intersects;
mod st_length;
diff --git a/rust/sedona-geo/src/register.rs b/rust/sedona-geo/src/register.rs
index 9db2c32..5dfb25d 100644
--- a/rust/sedona-geo/src/register.rs
+++ b/rust/sedona-geo/src/register.rs
@@ -21,8 +21,8 @@ use crate::st_intersection_aggr::st_intersection_aggr_impl;
use crate::st_line_interpolate_point::st_line_interpolate_point_impl;
use crate::st_union_aggr::st_union_aggr_impl;
use crate::{
- st_area::st_area_impl, st_centroid::st_centroid_impl,
st_intersects::st_intersects_impl,
- st_length::st_length_impl,
+ st_area::st_area_impl, st_centroid::st_centroid_impl,
st_distance::st_distance_impl,
+ st_dwithin::st_dwithin_impl, st_intersects::st_intersects_impl,
st_length::st_length_impl,
};
pub fn scalar_kernels() -> Vec<(&'static str, ScalarKernelRef)> {
@@ -30,6 +30,8 @@ pub fn scalar_kernels() -> Vec<(&'static str, ScalarKernelRef)> {
("st_intersects", st_intersects_impl()),
("st_area", st_area_impl()),
("st_centroid", st_centroid_impl()),
+ ("st_distance", st_distance_impl()),
+ ("st_dwithin", st_dwithin_impl()),
("st_length", st_length_impl()),
("st_lineinterpolatepoint", st_line_interpolate_point_impl()),
]
diff --git a/rust/sedona-geo/src/st_distance.rs b/rust/sedona-geo/src/st_distance.rs
new file mode 100644
index 0000000..4900690
--- /dev/null
+++ b/rust/sedona-geo/src/st_distance.rs
@@ -0,0 +1,123 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+use std::sync::Arc;
+
+use arrow_array::builder::Float64Builder;
+use arrow_schema::DataType;
+use datafusion_common::error::Result;
+use datafusion_expr::ColumnarValue;
+use geo_generic_alg::line_measures::DistanceExt;
+use sedona_expr::scalar_udf::{ScalarKernelRef, SedonaScalarKernel};
+use sedona_functions::executor::WkbExecutor;
+use sedona_schema::{datatypes::SedonaType, matchers::ArgMatcher};
+use wkb::reader::Wkb;
+
+/// ST_Distance() implementation using [DistanceExt]
+pub fn st_distance_impl() -> ScalarKernelRef {
+ Arc::new(STDistance {})
+}
+
+#[derive(Debug)]
+struct STDistance {}
+
+impl SedonaScalarKernel for STDistance {
+ fn return_type(&self, args: &[SedonaType]) -> Result<Option<SedonaType>> {
+ let matcher = ArgMatcher::new(
+ vec![ArgMatcher::is_geometry(), ArgMatcher::is_geometry()],
+ SedonaType::Arrow(DataType::Float64),
+ );
+
+ matcher.match_args(args)
+ }
+
+ fn invoke_batch(
+ &self,
+ arg_types: &[SedonaType],
+ args: &[ColumnarValue],
+ ) -> Result<ColumnarValue> {
+ let executor = WkbExecutor::new(arg_types, args);
+ let mut builder =
Float64Builder::with_capacity(executor.num_iterations());
+ executor.execute_wkb_wkb_void(|maybe_wkb0, maybe_wkb1| {
+ match (maybe_wkb0, maybe_wkb1) {
+ (Some(wkb0), Some(wkb1)) => {
+ builder.append_value(invoke_scalar(wkb0, wkb1)?);
+ }
+ _ => builder.append_null(),
+ }
+
+ Ok(())
+ })?;
+
+ executor.finish(Arc::new(builder.finish()))
+ }
+}
+
+fn invoke_scalar(wkb_a: &Wkb, wkb_b: &Wkb) -> Result<f64> {
+ Ok(wkb_a.distance_ext(wkb_b))
+}
+
+#[cfg(test)]
+mod tests {
+ use datafusion_common::scalar::ScalarValue;
+ use rstest::rstest;
+ use sedona_expr::scalar_udf::SedonaScalarUDF;
+ use sedona_schema::datatypes::{WKB_GEOMETRY, WKB_VIEW_GEOMETRY};
+ use sedona_testing::create::create_scalar;
+ use sedona_testing::testers::ScalarUdfTester;
+
+ use super::*;
+
+ #[rstest]
+ fn udf(
+ #[values(WKB_GEOMETRY, WKB_VIEW_GEOMETRY)] left_sedona_type:
SedonaType,
+ #[values(WKB_GEOMETRY, WKB_VIEW_GEOMETRY)] right_sedona_type:
SedonaType,
+ ) {
+ let udf = SedonaScalarUDF::from_kernel("st_distance",
st_distance_impl());
+ let tester = ScalarUdfTester::new(
+ udf.into(),
+ vec![left_sedona_type.clone(), right_sedona_type.clone()],
+ );
+
+ assert_eq!(
+ tester.return_type().unwrap(),
+ SedonaType::Arrow(DataType::Float64)
+ );
+
+ // Test distance between two points (3-4-5 triangle)
+ let point_0_0 = create_scalar(Some("POINT (0 0)"), &left_sedona_type);
+ let point_3_4 = create_scalar(Some("POINT (3 4)"), &right_sedona_type);
+
+ let result = tester
+ .invoke_scalar_scalar(point_0_0.clone(), point_3_4.clone())
+ .unwrap();
+ if let ScalarValue::Float64(Some(distance)) = result {
+ assert!((distance - 5.0).abs() < 1e-10);
+ } else {
+ panic!("Expected Float64 result");
+ }
+
+ // Test with null values
+ let result = tester
+ .invoke_scalar_scalar(ScalarValue::Null, point_3_4.clone())
+ .unwrap();
+ assert!(result.is_null());
+ let result = tester
+ .invoke_scalar_scalar(point_0_0.clone(), ScalarValue::Null)
+ .unwrap();
+ assert!(result.is_null());
+ }
+}
diff --git a/c/sedona-geos/src/st_dwithin.rs b/rust/sedona-geo/src/st_dwithin.rs
similarity index 52%
copy from c/sedona-geos/src/st_dwithin.rs
copy to rust/sedona-geo/src/st_dwithin.rs
index 03bbaa1..2ba3f5d 100644
--- a/c/sedona-geos/src/st_dwithin.rs
+++ b/rust/sedona-geo/src/st_dwithin.rs
@@ -18,15 +18,15 @@ use std::sync::Arc;
use arrow_array::builder::BooleanBuilder;
use arrow_schema::DataType;
-use datafusion_common::{error::Result, DataFusionError};
+use datafusion_common::{cast::as_float64_array, error::Result};
use datafusion_expr::ColumnarValue;
-use geos::Geom;
+use geo_generic_alg::line_measures::DistanceExt;
use sedona_expr::scalar_udf::{ScalarKernelRef, SedonaScalarKernel};
+use sedona_functions::executor::WkbExecutor;
use sedona_schema::{datatypes::SedonaType, matchers::ArgMatcher};
+use wkb::reader::Wkb;
-use crate::executor::GeosExecutor;
-
-/// Implementation of ST_DWithin using the geos crate
+/// ST_DWithin() implementation using [DistanceExt]
pub fn st_dwithin_impl() -> ScalarKernelRef {
Arc::new(STDWithin {})
}
@@ -36,7 +36,7 @@ struct STDWithin {}
impl SedonaScalarKernel for STDWithin {
fn return_type(&self, args: &[SedonaType]) -> Result<Option<SedonaType>> {
- let matcher: ArgMatcher = ArgMatcher::new(
+ let matcher = ArgMatcher::new(
vec![
ArgMatcher::is_geometry(),
ArgMatcher::is_geometry(),
@@ -53,81 +53,104 @@ impl SedonaScalarKernel for STDWithin {
arg_types: &[SedonaType],
args: &[ColumnarValue],
) -> Result<ColumnarValue> {
- // Extract the constant scalar value before looping over the input
geometries
- let distance: Option<f64>;
let arg2 = args[2].cast_to(&DataType::Float64, None)?;
- if let ColumnarValue::Scalar(scalar_arg) = &arg2 {
- if scalar_arg.is_null() {
- distance = None;
- } else {
- distance = Some(f64::try_from(scalar_arg.clone())?);
- }
- } else {
- return Err(DataFusionError::Execution(format!(
- "Invalid distance: {:?}",
- args[2]
- )));
- }
-
- let executor = GeosExecutor::new(arg_types, args);
+ let executor = WkbExecutor::new(arg_types, args);
+ let arg2_array = arg2.to_array(executor.num_iterations())?;
+ let arg2_f64_array = as_float64_array(&arg2_array)?;
+ let mut arg2_iter = arg2_f64_array.iter();
let mut builder =
BooleanBuilder::with_capacity(executor.num_iterations());
- executor.execute_wkb_wkb_void(|lhs, rhs| {
- match (lhs, rhs, distance) {
- (Some(lhs), Some(rhs), Some(distance)) => {
- builder.append_value(invoke_scalar(lhs, rhs, distance)?);
+ executor.execute_wkb_wkb_void(|maybe_wkb0, maybe_wkb1| {
+ match (maybe_wkb0, maybe_wkb1, arg2_iter.next().unwrap()) {
+ (Some(wkb0), Some(wkb1), Some(distance)) => {
+ builder.append_value(invoke_scalar(wkb0, wkb1, distance)?);
}
_ => builder.append_null(),
- };
+ }
+
Ok(())
})?;
+
executor.finish(Arc::new(builder.finish()))
}
}
-fn invoke_scalar(lhs: &geos::Geometry, rhs: &geos::Geometry, distance: f64) ->
Result<bool> {
- let dist_between = lhs
- .distance(rhs)
- .map_err(|e| DataFusionError::Execution(format!("Failed to calculate
dwithin: {e}")))?;
- Ok(dist_between <= distance)
+fn invoke_scalar(wkb_a: &Wkb, wkb_b: &Wkb, distance_bound: f64) ->
Result<bool> {
+ let actual_distance = wkb_a.distance_ext(wkb_b);
+ Ok(actual_distance <= distance_bound)
}
#[cfg(test)]
mod tests {
use arrow_array::{create_array as arrow_array, ArrayRef};
+ use datafusion_common::scalar::ScalarValue;
use rstest::rstest;
use sedona_expr::scalar_udf::SedonaScalarUDF;
use sedona_schema::datatypes::{WKB_GEOMETRY, WKB_VIEW_GEOMETRY};
- use sedona_testing::compare::assert_array_equal;
- use sedona_testing::create::create_array;
+ use sedona_testing::create::create_scalar;
use sedona_testing::testers::ScalarUdfTester;
use super::*;
#[rstest]
- fn udf(#[values(WKB_GEOMETRY, WKB_VIEW_GEOMETRY)] sedona_type: SedonaType)
{
- use datafusion_common::ScalarValue;
+ fn udf(
+ #[values(WKB_GEOMETRY, WKB_VIEW_GEOMETRY)] left_sedona_type:
SedonaType,
+ #[values(WKB_GEOMETRY, WKB_VIEW_GEOMETRY)] right_sedona_type:
SedonaType,
+ ) {
+ use sedona_testing::{compare::assert_array_equal,
create::create_array};
let udf = SedonaScalarUDF::from_kernel("st_dwithin",
st_dwithin_impl());
let tester = ScalarUdfTester::new(
udf.into(),
vec![
- sedona_type.clone(),
- sedona_type,
+ left_sedona_type.clone(),
+ right_sedona_type.clone(),
SedonaType::Arrow(DataType::Float64),
],
);
- tester.assert_return_type(DataType::Boolean);
+
+ assert_eq!(
+ tester.return_type().unwrap(),
+ SedonaType::Arrow(DataType::Boolean)
+ );
+
+ // Test points within distance (3-4-5 triangle, distance = 5.0)
+ let point_0_0 = create_scalar(Some("POINT (0 0)"), &left_sedona_type);
+ let point_3_4 = create_scalar(Some("POINT (3 4)"), &right_sedona_type);
+ let distance_5 = ScalarValue::Float64(Some(5.0));
+ let distance_4 = ScalarValue::Float64(Some(4.0));
let result = tester
- .invoke_scalar_scalar_scalar("POINT (0 0)", "POINT (0 0)", 0.0)
+ .invoke_scalar_scalar_scalar(point_0_0.clone(), point_3_4.clone(),
distance_5.clone())
.unwrap();
- tester.assert_scalar_result_equals(result, true);
+ assert_eq!(result, ScalarValue::Boolean(Some(true)));
+ // Test points outside distance
let result = tester
- .invoke_scalar_scalar_scalar(ScalarValue::Null, ScalarValue::Null,
ScalarValue::Null)
+ .invoke_scalar_scalar_scalar(point_0_0.clone(), point_3_4.clone(),
distance_4.clone())
+ .unwrap();
+ assert_eq!(result, ScalarValue::Boolean(Some(false)));
+
+ // Test with null values
+ let result = tester
+ .invoke_scalar_scalar_scalar(ScalarValue::Null, point_3_4.clone(),
distance_5.clone())
+ .unwrap();
+ assert!(result.is_null());
+ let result = tester
+ .invoke_scalar_scalar_scalar(point_0_0.clone(), ScalarValue::Null,
distance_5.clone())
.unwrap();
assert!(result.is_null());
+ // Test with null distance
+ let result = tester
+ .invoke_scalar_scalar_scalar(
+ point_0_0.clone(),
+ point_3_4.clone(),
+ ScalarValue::Float64(None),
+ )
+ .unwrap();
+ assert!(result.is_null());
+
+ // Test with array args
let arg1 = create_array(
&[
Some("POLYGON ((0 0, 0 1, 1 1, 1 0, 0 0))"),
@@ -146,13 +169,10 @@ mod tests {
],
&WKB_GEOMETRY,
);
- let distance = 1;
-
+ let distance = arrow_array!(Int32, [Some(1), Some(1), Some(1),
Some(1)]);
let expected: ArrayRef = arrow_array!(Boolean, [Some(true),
Some(false), None, Some(true)]);
assert_array_equal(
- &tester
- .invoke_array_array_scalar(arg1, arg2, distance)
- .unwrap(),
+ &tester.invoke_arrays(vec![arg1, arg2, distance]).unwrap(),
&expected,
);
}