This is an automated email from the ASF dual-hosted git repository.
jiayu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/sedona-db.git
The following commit(s) were added to refs/heads/main by this push:
new b380e30 [#61] Accept SRIDs as args to ST_Transform (#62)
b380e30 is described below
commit b380e304287b99e313b1f12c00ac15bd615e2164
Author: jp <[email protected]>
AuthorDate: Fri Sep 12 11:16:51 2025 -0700
[#61] Accept SRIDs as args to ST_Transform (#62)
---
c/sedona-proj/src/st_transform.rs | 194 +++++++++++++++++++++++++++++++++++---
1 file changed, 181 insertions(+), 13 deletions(-)
diff --git a/c/sedona-proj/src/st_transform.rs b/c/sedona-proj/src/st_transform.rs
index cefcce8..51a2682 100644
--- a/c/sedona-proj/src/st_transform.rs
+++ b/c/sedona-proj/src/st_transform.rs
@@ -26,6 +26,7 @@ use sedona_geometry::transform::{transform, CachingCrsEngine, CrsEngine, CrsTran
use sedona_geometry::wkb_factory::WKB_MIN_PROBABLE_BYTES;
use sedona_schema::crs::deserialize_crs;
use sedona_schema::datatypes::{Edges, SedonaType};
+use sedona_schema::matchers::ArgMatcher;
use std::cell::OnceCell;
use std::rc::Rc;
use std::sync::{Arc, RwLock};
@@ -135,7 +136,9 @@ fn define_arg_indexes(arg_types: &[SedonaType], indexes: &mut TransformArgIndexe
indexes.first_crs = 1;
for (i, arg_type) in arg_types.iter().enumerate().skip(2) {
- if *arg_type == SedonaType::Arrow(DataType::Utf8) {
+ if ArgMatcher::is_numeric().match_type(arg_type)
+ || ArgMatcher::is_string().match_type(arg_type)
+ {
indexes.second_crs = Some(i);
} else if *arg_type == SedonaType::Arrow(DataType::Boolean) {
indexes.lenient = Some(i);
@@ -154,17 +157,41 @@ impl SedonaScalarKernel for STTransform {
arg_types: &[SedonaType],
scalar_args: &[Option<&ScalarValue>],
) -> Result<Option<SedonaType>> {
+ let matcher = ArgMatcher::new(
+ vec![
+ ArgMatcher::is_geometry_or_geography(),
+ ArgMatcher::or(vec![ArgMatcher::is_numeric(),
ArgMatcher::is_string()]),
+ ArgMatcher::optional(ArgMatcher::or(vec![
+ ArgMatcher::is_numeric(),
+ ArgMatcher::is_string(),
+ ])),
+ ArgMatcher::optional(ArgMatcher::is_boolean()),
+ ],
+ SedonaType::Wkb(Edges::Planar, None),
+ );
+
+ if !matcher.matches(arg_types) {
+ return Ok(None);
+ }
+
let mut indexes = TransformArgIndexes::new();
define_arg_indexes(arg_types, &mut indexes);
- let to_crs_opt = if let Some(second_crs_index) = indexes.second_crs {
+ let scalar_arg_opt = if let Some(second_crs_index) =
indexes.second_crs {
scalar_args.get(second_crs_index).unwrap()
} else {
scalar_args.get(indexes.first_crs).unwrap()
};
- match to_crs_opt {
- Some(ScalarValue::Utf8(Some(to_crs))) => {
+ let crs_str_opt = if let Some(scalar_crs) = scalar_arg_opt {
+ to_crs_str(scalar_crs)
+ } else {
+ None
+ };
+
+ // If there is no CRS argument, we cannot determine the return type.
+ match crs_str_opt {
+ Some(to_crs) => {
let val = serde_json::Value::String(to_crs.to_string());
let crs = deserialize_crs(&val)?;
Ok(Some(SedonaType::Wkb(Edges::Planar, crs)))
@@ -187,8 +214,10 @@ impl SedonaScalarKernel for STTransform {
let mut indexes = TransformArgIndexes::new();
define_arg_indexes(arg_types, &mut indexes);
- let first_crs = get_scalar_str(args, indexes.first_crs).ok_or_else(|| {
- DataFusionError::Execution("First argument must be a scalar
string".into())
+ let first_crs = get_crs_str(args, indexes.first_crs).ok_or_else(|| {
+ DataFusionError::Execution(
+ "First CRS argument must be a string or numeric
scalar".to_string(),
+ )
})?;
let lenient = indexes
@@ -196,7 +225,7 @@ impl SedonaScalarKernel for STTransform {
.is_some_and(|i| get_scalar_bool(args, i).unwrap_or(false));
let second_crs = if let Some(second_crs_index) = indexes.second_crs {
- get_scalar_str(args, second_crs_index)
+ get_crs_str(args, second_crs_index)
} else {
None
};
@@ -270,12 +299,23 @@ fn parse_source_crs(source_type: &SedonaType) -> Result<Option<String>> {
}
}
-fn get_scalar_str(args: &[ColumnarValue], index: usize) -> Option<String> {
- if let Some(ColumnarValue::Scalar(ScalarValue::Utf8(opt_str))) =
args.get(index) {
- opt_str.clone()
- } else {
- None
+fn to_crs_str(scalar_arg: &ScalarValue) -> Option<String> {
+ if let Ok(ScalarValue::Utf8(Some(crs))) =
scalar_arg.cast_to(&DataType::Utf8) {
+ if crs.chars().all(|c| c.is_ascii_digit()) {
+ return Some(format!("EPSG:{crs}"));
+ } else {
+ return Some(crs);
+ }
+ }
+
+ None
+}
+
+fn get_crs_str(args: &[ColumnarValue], index: usize) -> Option<String> {
+ if let ColumnarValue::Scalar(scalar_crs) = &args[index] {
+ return to_crs_str(scalar_crs);
}
+ None
}
fn get_scalar_bool(args: &[ColumnarValue], index: usize) -> Option<bool> {
@@ -303,6 +343,88 @@ mod tests {
const NAD83ZONE6PROJ: &str = "EPSG:2230";
const WGS84: &str = "EPSG:4326";
+ #[rstest]
+ fn invalid_arg_checks() {
+ let udf: SedonaScalarUDF =
+ SedonaScalarUDF::from_kernel("st_transform", st_transform_impl());
+
+ // No args
+ let result = udf.return_field_from_args(ReturnFieldArgs {
+ arg_fields: &[],
+ scalar_arguments: &[],
+ });
+ assert!(
+ result.is_err()
+ && result
+ .unwrap_err()
+ .to_string()
+ .contains("No kernel matching arguments")
+ );
+
+ // Too many args
+ let arg_types = [
+ WKB_GEOMETRY,
+ SedonaType::Arrow(DataType::Utf8),
+ SedonaType::Arrow(DataType::Utf8),
+ SedonaType::Arrow(DataType::Boolean),
+ SedonaType::Arrow(DataType::Int32),
+ ];
+ let arg_fields: Vec<Arc<Field>> = arg_types
+ .iter()
+ .map(|arg_type| Arc::new(arg_type.to_storage_field("",
true).unwrap()))
+ .collect();
+ let result = udf.return_field_from_args(ReturnFieldArgs {
+ arg_fields: &arg_fields,
+ scalar_arguments: &[None, None, None, None, None],
+ });
+ assert!(
+ result.is_err()
+ && result
+ .unwrap_err()
+ .to_string()
+ .contains("No kernel matching arguments")
+ );
+
+ // First arg not geometry
+ let arg_types = [
+ SedonaType::Arrow(DataType::Utf8),
+ SedonaType::Arrow(DataType::Utf8),
+ ];
+ let arg_fields: Vec<Arc<Field>> = arg_types
+ .iter()
+ .map(|arg_type| Arc::new(arg_type.to_storage_field("",
true).unwrap()))
+ .collect();
+ let result = udf.return_field_from_args(ReturnFieldArgs {
+ arg_fields: &arg_fields,
+ scalar_arguments: &[None, None],
+ });
+ assert!(
+ result.is_err()
+ && result
+ .unwrap_err()
+ .to_string()
+ .contains("No kernel matching arguments")
+ );
+
+ // Second arg not string or numeric
+ let arg_types = [WKB_GEOMETRY, SedonaType::Arrow(DataType::Boolean)];
+ let arg_fields: Vec<Arc<Field>> = arg_types
+ .iter()
+ .map(|arg_type| Arc::new(arg_type.to_storage_field("",
true).unwrap()))
+ .collect();
+ let result = udf.return_field_from_args(ReturnFieldArgs {
+ arg_fields: &arg_fields,
+ scalar_arguments: &[None, None],
+ });
+ assert!(
+ result.is_err()
+ && result
+ .unwrap_err()
+ .to_string()
+ .contains("No kernel matching arguments")
+ );
+ }
+
#[rstest]
fn test_invoke_batch_with_geo_crs() {
// From-CRS pulled from sedona type
@@ -329,6 +451,32 @@ mod tests {
);
}
+ #[rstest]
+ fn test_invoke_with_srids() {
+ // Use an integer SRID for the to CRS
+ let arg_types = [
+ SedonaType::Wkb(Edges::Planar, lnglat()),
+ SedonaType::Arrow(DataType::UInt32),
+ ];
+
+ let wkb = create_array(&[None, Some("POINT (79.3871 43.6426)")],
&arg_types[0]);
+
+ let scalar_args = vec![ScalarValue::UInt32(Some(2230))];
+
+ let expected = create_array_value(
+ &[None, Some("POINT (-21508577.363421552 34067918.06097863)")],
+ &SedonaType::Wkb(Edges::Planar, get_crs(NAD83ZONE6PROJ)),
+ );
+
+ let (result_type, result_col) =
+ invoke_udf_test(wkb, scalar_args, arg_types.to_vec()).unwrap();
+ assert_value_equal(&result_col, &expected);
+ assert_eq!(
+ result_type,
+ SedonaType::Wkb(Edges::Planar, get_crs(NAD83ZONE6PROJ))
+ );
+ }
+
#[rstest]
fn test_invoke_batch_with_lenient() {
let arg_types = [
@@ -372,7 +520,7 @@ mod tests {
}
#[rstest]
- fn test_invoke_batch_with_string_source() {
+ fn test_invoke_batch_with_source_arg() {
let arg_types = [
WKB_GEOMETRY,
SedonaType::Arrow(DataType::Utf8),
@@ -392,6 +540,26 @@ mod tests {
&SedonaType::Wkb(Edges::Planar,
Some(get_crs(NAD83ZONE6PROJ).unwrap())),
);
+ let (result_type, result_col) =
+ invoke_udf_test(wkb.clone(), scalar_args,
arg_types.to_vec()).unwrap();
+ assert_value_equal(&result_col, &expected);
+ assert_eq!(
+ result_type,
+ SedonaType::Wkb(Edges::Planar,
Some(get_crs(NAD83ZONE6PROJ).unwrap()))
+ );
+
+ // Test with integer SRIDs
+ let arg_types = [
+ WKB_GEOMETRY,
+ SedonaType::Arrow(DataType::Int32),
+ SedonaType::Arrow(DataType::Int32),
+ ];
+
+ let scalar_args = vec![
+ ScalarValue::Int32(Some(4326)),
+ ScalarValue::Int32(Some(2230)),
+ ];
+
let (result_type, result_col) =
invoke_udf_test(wkb, scalar_args, arg_types.to_vec()).unwrap();
assert_value_equal(&result_col, &expected);