This is an automated email from the ASF dual-hosted git repository.

jiayu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/sedona-db.git


The following commit(s) were added to refs/heads/main by this push:
     new b380e30  [#61] Accept SRIDs as args to ST_Transform  (#62)
b380e30 is described below

commit b380e304287b99e313b1f12c00ac15bd615e2164
Author: jp <[email protected]>
AuthorDate: Fri Sep 12 11:16:51 2025 -0700

    [#61] Accept SRIDs as args to ST_Transform  (#62)
---
 c/sedona-proj/src/st_transform.rs | 194 +++++++++++++++++++++++++++++++++++---
 1 file changed, 181 insertions(+), 13 deletions(-)

diff --git a/c/sedona-proj/src/st_transform.rs b/c/sedona-proj/src/st_transform.rs
index cefcce8..51a2682 100644
--- a/c/sedona-proj/src/st_transform.rs
+++ b/c/sedona-proj/src/st_transform.rs
@@ -26,6 +26,7 @@ use sedona_geometry::transform::{transform, CachingCrsEngine, CrsEngine, CrsTran
 use sedona_geometry::wkb_factory::WKB_MIN_PROBABLE_BYTES;
 use sedona_schema::crs::deserialize_crs;
 use sedona_schema::datatypes::{Edges, SedonaType};
+use sedona_schema::matchers::ArgMatcher;
 use std::cell::OnceCell;
 use std::rc::Rc;
 use std::sync::{Arc, RwLock};
@@ -135,7 +136,9 @@ fn define_arg_indexes(arg_types: &[SedonaType], indexes: &mut TransformArgIndexe
     indexes.first_crs = 1;
 
     for (i, arg_type) in arg_types.iter().enumerate().skip(2) {
-        if *arg_type == SedonaType::Arrow(DataType::Utf8) {
+        if ArgMatcher::is_numeric().match_type(arg_type)
+            || ArgMatcher::is_string().match_type(arg_type)
+        {
             indexes.second_crs = Some(i);
         } else if *arg_type == SedonaType::Arrow(DataType::Boolean) {
             indexes.lenient = Some(i);
@@ -154,17 +157,41 @@ impl SedonaScalarKernel for STTransform {
         arg_types: &[SedonaType],
         scalar_args: &[Option<&ScalarValue>],
     ) -> Result<Option<SedonaType>> {
+        let matcher = ArgMatcher::new(
+            vec![
+                ArgMatcher::is_geometry_or_geography(),
+                ArgMatcher::or(vec![ArgMatcher::is_numeric(), ArgMatcher::is_string()]),
+                ArgMatcher::optional(ArgMatcher::or(vec![
+                    ArgMatcher::is_numeric(),
+                    ArgMatcher::is_string(),
+                ])),
+                ArgMatcher::optional(ArgMatcher::is_boolean()),
+            ],
+            SedonaType::Wkb(Edges::Planar, None),
+        );
+
+        if !matcher.matches(arg_types) {
+            return Ok(None);
+        }
+
         let mut indexes = TransformArgIndexes::new();
         define_arg_indexes(arg_types, &mut indexes);
 
-        let to_crs_opt = if let Some(second_crs_index) = indexes.second_crs {
+        let scalar_arg_opt = if let Some(second_crs_index) = indexes.second_crs {
             scalar_args.get(second_crs_index).unwrap()
         } else {
             scalar_args.get(indexes.first_crs).unwrap()
         };
 
-        match to_crs_opt {
-            Some(ScalarValue::Utf8(Some(to_crs))) => {
+        let crs_str_opt = if let Some(scalar_crs) = scalar_arg_opt {
+            to_crs_str(scalar_crs)
+        } else {
+            None
+        };
+
+        // If there is no CRS argument, we cannot determine the return type.
+        match crs_str_opt {
+            Some(to_crs) => {
                 let val = serde_json::Value::String(to_crs.to_string());
                 let crs = deserialize_crs(&val)?;
                 Ok(Some(SedonaType::Wkb(Edges::Planar, crs)))
@@ -187,8 +214,10 @@ impl SedonaScalarKernel for STTransform {
         let mut indexes = TransformArgIndexes::new();
         define_arg_indexes(arg_types, &mut indexes);
 
-        let first_crs = get_scalar_str(args, indexes.first_crs).ok_or_else(|| {
-            DataFusionError::Execution("First argument must be a scalar string".into())
+        let first_crs = get_crs_str(args, indexes.first_crs).ok_or_else(|| {
+            DataFusionError::Execution(
+                "First CRS argument must be a string or numeric scalar".to_string(),
+            )
         })?;
 
         let lenient = indexes
@@ -196,7 +225,7 @@ impl SedonaScalarKernel for STTransform {
             .is_some_and(|i| get_scalar_bool(args, i).unwrap_or(false));
 
         let second_crs = if let Some(second_crs_index) = indexes.second_crs {
-            get_scalar_str(args, second_crs_index)
+            get_crs_str(args, second_crs_index)
         } else {
             None
         };
@@ -270,12 +299,23 @@ fn parse_source_crs(source_type: &SedonaType) -> Result<Option<String>> {
     }
 }
 
-fn get_scalar_str(args: &[ColumnarValue], index: usize) -> Option<String> {
-    if let Some(ColumnarValue::Scalar(ScalarValue::Utf8(opt_str))) = args.get(index) {
-        opt_str.clone()
-    } else {
-        None
+fn to_crs_str(scalar_arg: &ScalarValue) -> Option<String> {
+    if let Ok(ScalarValue::Utf8(Some(crs))) = scalar_arg.cast_to(&DataType::Utf8) {
+        if crs.chars().all(|c| c.is_ascii_digit()) {
+            return Some(format!("EPSG:{crs}"));
+        } else {
+            return Some(crs);
+        }
+    }
+
+    None
+}
+
+fn get_crs_str(args: &[ColumnarValue], index: usize) -> Option<String> {
+    if let ColumnarValue::Scalar(scalar_crs) = &args[index] {
+        return to_crs_str(scalar_crs);
     }
+    None
 }
 
 fn get_scalar_bool(args: &[ColumnarValue], index: usize) -> Option<bool> {
@@ -303,6 +343,88 @@ mod tests {
     const NAD83ZONE6PROJ: &str = "EPSG:2230";
     const WGS84: &str = "EPSG:4326";
 
+    #[rstest]
+    fn invalid_arg_checks() {
+        let udf: SedonaScalarUDF =
+            SedonaScalarUDF::from_kernel("st_transform", st_transform_impl());
+
+        // No args
+        let result = udf.return_field_from_args(ReturnFieldArgs {
+            arg_fields: &[],
+            scalar_arguments: &[],
+        });
+        assert!(
+            result.is_err()
+                && result
+                    .unwrap_err()
+                    .to_string()
+                    .contains("No kernel matching arguments")
+        );
+
+        // Too many args
+        let arg_types = [
+            WKB_GEOMETRY,
+            SedonaType::Arrow(DataType::Utf8),
+            SedonaType::Arrow(DataType::Utf8),
+            SedonaType::Arrow(DataType::Boolean),
+            SedonaType::Arrow(DataType::Int32),
+        ];
+        let arg_fields: Vec<Arc<Field>> = arg_types
+            .iter()
+            .map(|arg_type| Arc::new(arg_type.to_storage_field("", true).unwrap()))
+            .collect();
+        let result = udf.return_field_from_args(ReturnFieldArgs {
+            arg_fields: &arg_fields,
+            scalar_arguments: &[None, None, None, None, None],
+        });
+        assert!(
+            result.is_err()
+                && result
+                    .unwrap_err()
+                    .to_string()
+                    .contains("No kernel matching arguments")
+        );
+
+        // First arg not geometry
+        let arg_types = [
+            SedonaType::Arrow(DataType::Utf8),
+            SedonaType::Arrow(DataType::Utf8),
+        ];
+        let arg_fields: Vec<Arc<Field>> = arg_types
+            .iter()
+            .map(|arg_type| Arc::new(arg_type.to_storage_field("", true).unwrap()))
+            .collect();
+        let result = udf.return_field_from_args(ReturnFieldArgs {
+            arg_fields: &arg_fields,
+            scalar_arguments: &[None, None],
+        });
+        assert!(
+            result.is_err()
+                && result
+                    .unwrap_err()
+                    .to_string()
+                    .contains("No kernel matching arguments")
+        );
+
+        // Second arg not string or numeric
+        let arg_types = [WKB_GEOMETRY, SedonaType::Arrow(DataType::Boolean)];
+        let arg_fields: Vec<Arc<Field>> = arg_types
+            .iter()
+            .map(|arg_type| Arc::new(arg_type.to_storage_field("", true).unwrap()))
+            .collect();
+        let result = udf.return_field_from_args(ReturnFieldArgs {
+            arg_fields: &arg_fields,
+            scalar_arguments: &[None, None],
+        });
+        assert!(
+            result.is_err()
+                && result
+                    .unwrap_err()
+                    .to_string()
+                    .contains("No kernel matching arguments")
+        );
+    }
+
     #[rstest]
     fn test_invoke_batch_with_geo_crs() {
         // From-CRS pulled from sedona type
@@ -329,6 +451,32 @@ mod tests {
         );
     }
 
+    #[rstest]
+    fn test_invoke_with_srids() {
+        // Use an integer SRID for the to CRS
+        let arg_types = [
+            SedonaType::Wkb(Edges::Planar, lnglat()),
+            SedonaType::Arrow(DataType::UInt32),
+        ];
+
+        let wkb = create_array(&[None, Some("POINT (79.3871 43.6426)")], &arg_types[0]);
+
+        let scalar_args = vec![ScalarValue::UInt32(Some(2230))];
+
+        let expected = create_array_value(
+            &[None, Some("POINT (-21508577.363421552 34067918.06097863)")],
+            &SedonaType::Wkb(Edges::Planar, get_crs(NAD83ZONE6PROJ)),
+        );
+
+        let (result_type, result_col) =
+            invoke_udf_test(wkb, scalar_args, arg_types.to_vec()).unwrap();
+        assert_value_equal(&result_col, &expected);
+        assert_eq!(
+            result_type,
+            SedonaType::Wkb(Edges::Planar, get_crs(NAD83ZONE6PROJ))
+        );
+    }
+
     #[rstest]
     fn test_invoke_batch_with_lenient() {
         let arg_types = [
@@ -372,7 +520,7 @@ mod tests {
     }
 
     #[rstest]
-    fn test_invoke_batch_with_string_source() {
+    fn test_invoke_batch_with_source_arg() {
         let arg_types = [
             WKB_GEOMETRY,
             SedonaType::Arrow(DataType::Utf8),
@@ -392,6 +540,26 @@ mod tests {
             &SedonaType::Wkb(Edges::Planar, Some(get_crs(NAD83ZONE6PROJ).unwrap())),
         );
 
+        let (result_type, result_col) =
+            invoke_udf_test(wkb.clone(), scalar_args, arg_types.to_vec()).unwrap();
+        assert_value_equal(&result_col, &expected);
+        assert_eq!(
+            result_type,
+            SedonaType::Wkb(Edges::Planar, Some(get_crs(NAD83ZONE6PROJ).unwrap()))
+        );
+
+        // Test with integer SRIDs
+        let arg_types = [
+            WKB_GEOMETRY,
+            SedonaType::Arrow(DataType::Int32),
+            SedonaType::Arrow(DataType::Int32),
+        ];
+
+        let scalar_args = vec![
+            ScalarValue::Int32(Some(4326)),
+            ScalarValue::Int32(Some(2230)),
+        ];
+
         let (result_type, result_col) =
             invoke_udf_test(wkb, scalar_args, arg_types.to_vec()).unwrap();
         assert_value_equal(&result_col, &expected);

Reply via email to