This is an automated email from the ASF dual-hosted git repository.

jiayu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/sedona-db.git


The following commit(s) were added to refs/heads/main by this push:
     new 9657a5e  Add ST_SetCrs (#37)
9657a5e is described below

commit 9657a5e4e2bee4ba4f094299a6ec6131d94edc72
Author: jp <[email protected]>
AuthorDate: Mon Sep 8 12:21:35 2025 -0700

    Add ST_SetCrs (#37)
    
    * Split into ST_SetSRID and ST_SetCRS
    
    * Fix doc, update integration tests and allow null args
    
    * updates parquet integration tests to use st_setcrs
---
 python/sedonadb/tests/functions/test_transforms.py |  10 +-
 python/sedonadb/tests/io/test_parquet.py           |   4 +-
 rust/sedona-expr/src/scalar_udf.rs                 |  16 ++
 rust/sedona-functions/src/register.rs              |   2 +
 rust/sedona-functions/src/st_setsrid.rs            | 232 +++++++++++++++------
 5 files changed, 197 insertions(+), 67 deletions(-)

diff --git a/python/sedonadb/tests/functions/test_transforms.py b/python/sedonadb/tests/functions/test_transforms.py
index 1e11ecf..a5bd64d 100644
--- a/python/sedonadb/tests/functions/test_transforms.py
+++ b/python/sedonadb/tests/functions/test_transforms.py
@@ -50,19 +50,17 @@ def test_st_setsrid(eng, geom, srid, expected_srid):
         assert df.crs == pyproj.CRS(expected_srid)
 
 
-# PostGIS does not handle String CRS input to ST_SetSrid
+# PostGIS does not provide an ST_SetCrs function
 @pytest.mark.parametrize("eng", [SedonaDB])
 @pytest.mark.parametrize(
-    ("geom", "srid", "expected_srid"),
+    ("geom", "crs", "expected_srid"),
     [
         ("POINT (1 1)", "EPSG:26920", 26920),
         ("POINT (1 1)", pyproj.CRS("EPSG:26920").to_json(), 26920),
     ],
 )
-def test_st_setsrid_sedonadb(eng, geom, srid, expected_srid):
+def test_st_setcrs_sedonadb(eng, geom, crs, expected_srid):
     eng = eng.create_or_skip()
-    result = eng.execute_and_collect(
-        f"SELECT ST_SetSrid({geom_or_null(geom)}, '{srid}')"
-    )
+    result = eng.execute_and_collect(f"SELECT ST_SetCrs({geom_or_null(geom)}, '{crs}')")
     df = eng.result_to_pandas(result)
     assert df.crs.to_epsg() == expected_srid
diff --git a/python/sedonadb/tests/io/test_parquet.py b/python/sedonadb/tests/io/test_parquet.py
index 669334d..2af3577 100644
--- a/python/sedonadb/tests/io/test_parquet.py
+++ b/python/sedonadb/tests/io/test_parquet.py
@@ -102,7 +102,7 @@ def test_read_geoparquet_pruned(geoarrow_data, name):
         result = eng.execute_and_collect(
             f"""
             SELECT "OBJECTID", geometry FROM tab
-            WHERE ST_Intersects(geometry, ST_SetSRID({geom_or_null(wkt_filter)}, '{gdf.crs.to_json()}'))
+            WHERE ST_Intersects(geometry, ST_SetCRS({geom_or_null(wkt_filter)}, '{gdf.crs.to_json()}'))
             ORDER BY "OBJECTID";
         """
         )
@@ -127,7 +127,7 @@ def test_read_geoparquet_pruned(geoarrow_data, name):
         result = eng.execute_and_collect(
             f"""
             SELECT * FROM tab_dataset
-            WHERE ST_Intersects(geometry, ST_SetSRID({geom_or_null(wkt_filter)}, '{gdf.crs.to_json()}'))
+            WHERE ST_Intersects(geometry, ST_SetCRS({geom_or_null(wkt_filter)}, '{gdf.crs.to_json()}'))
             ORDER BY "OBJECTID";
         """
         )
diff --git a/rust/sedona-expr/src/scalar_udf.rs b/rust/sedona-expr/src/scalar_udf.rs
index 427abb9..ab3491a 100644
--- a/rust/sedona-expr/src/scalar_udf.rs
+++ b/rust/sedona-expr/src/scalar_udf.rs
@@ -207,6 +207,11 @@ impl ArgMatcher {
         Arc::new(IsGeography {})
     }
 
+    /// Matches a null argument
+    pub fn is_null() -> Arc<dyn TypeMatcher + Send + Sync> {
+        Arc::new(IsNull {})
+    }
+
     /// Matches any numeric argument
     pub fn is_numeric() -> Arc<dyn TypeMatcher + Send + Sync> {
         Arc::new(IsNumeric {})
@@ -371,6 +376,14 @@ impl TypeMatcher for IsBoolean {
     }
 }
 
+#[derive(Debug)]
+struct IsNull {}
+impl TypeMatcher for IsNull {
+    fn match_type(&self, arg: &SedonaType) -> bool {
+        matches!(arg, SedonaType::Arrow(DataType::Null))
+    }
+}
+
 /// Type definition for a Scalar kernel implementation function
 pub type SedonaScalarKernelImpl =
     Arc<dyn Fn(&[SedonaType], &[ColumnarValue]) -> Result<ColumnarValue> + Send + Sync>;
@@ -602,6 +615,9 @@ mod tests {
 
         assert!(ArgMatcher::is_boolean().match_type(&SedonaType::Arrow(DataType::Boolean)));
         assert!(!ArgMatcher::is_boolean().match_type(&SedonaType::Arrow(DataType::Int32)));
+
+        assert!(ArgMatcher::is_null().match_type(&SedonaType::Arrow(DataType::Null)));
+        assert!(!ArgMatcher::is_null().match_type(&SedonaType::Arrow(DataType::Int32)));
     }
 
     #[test]
diff --git a/rust/sedona-functions/src/register.rs b/rust/sedona-functions/src/register.rs
index b5208ac..4db04da 100644
--- a/rust/sedona-functions/src/register.rs
+++ b/rust/sedona-functions/src/register.rs
@@ -85,6 +85,7 @@ pub fn default_function_set() -> FunctionSet {
         crate::st_pointzm::st_pointm_udf,
         crate::st_pointzm::st_pointzm_udf,
         crate::st_transform::st_transform_udf,
+        crate::st_setsrid::st_set_crs_udf,
         crate::st_setsrid::st_set_srid_udf,
         crate::st_srid::st_srid_udf,
         crate::st_xyzm::st_m_udf,
@@ -124,6 +125,7 @@ pub mod stubs {
     pub use crate::st_area::st_area_udf;
     pub use crate::st_length::st_length_udf;
     pub use crate::st_perimeter::st_perimeter_udf;
+    pub use crate::st_setsrid::st_set_crs_with_engine_udf;
     pub use crate::st_setsrid::st_set_srid_with_engine_udf;
     pub use crate::st_transform::st_transform_udf;
 }
diff --git a/rust/sedona-functions/src/st_setsrid.rs b/rust/sedona-functions/src/st_setsrid.rs
index b4e4c18..d797933 100644
--- a/rust/sedona-functions/src/st_setsrid.rs
+++ b/rust/sedona-functions/src/st_setsrid.rs
@@ -30,7 +30,7 @@ use sedona_schema::{crs::deserialize_crs, datatypes::SedonaType};
 ///
 /// An implementation of ST_SetSRID providing a scalar function implementation
 /// based on an optional [CrsEngine]. If provided, it will be used to validate
-/// the provided CRS (otherwise, all CRS input is applied without validation).
+/// the provided SRID (otherwise, all SRID input is applied without validation).
 pub fn st_set_srid_with_engine_udf(
     engine: Option<Arc<dyn CrsEngine + Send + Sync>>,
 ) -> SedonaScalarUDF {
@@ -42,6 +42,22 @@ pub fn st_set_srid_with_engine_udf(
     )
 }
 
+/// ST_SetCRS() scalar UDF implementation with optional CRS validation
+///
+/// An implementation of ST_SetCRS providing a scalar function implementation
+/// based on an optional [CrsEngine]. If provided, it will be used to validate
+/// the provided CRS (otherwise, all CRS input is applied without validation).
+pub fn st_set_crs_with_engine_udf(
+    engine: Option<Arc<dyn CrsEngine + Send + Sync>>,
+) -> SedonaScalarUDF {
+    SedonaScalarUDF::new(
+        "st_setcrs",
+        vec![Arc::new(STSetCRS { engine })],
+        Volatility::Immutable,
+        Some(set_crs_doc()),
+    )
+}
+
 /// ST_SetSRID() scalar UDF implementation without CRS validation
 ///
 /// See [st_set_srid_with_engine_udf] for a validating version of this function
@@ -49,18 +65,40 @@ pub fn st_set_srid_udf() -> SedonaScalarUDF {
     st_set_srid_with_engine_udf(None)
 }
 
+/// ST_SetCRS() scalar UDF implementation without CRS validation
+///
+/// See [st_set_crs_with_engine_udf] for a validating version of this function
+pub fn st_set_crs_udf() -> SedonaScalarUDF {
+    st_set_crs_with_engine_udf(None)
+}
+
 fn set_srid_doc() -> Documentation {
+    Documentation::builder(
+        DOC_SECTION_OTHER,
+        "Sets the spatial reference system identifier (SRID) of the geometry.",
+        "ST_SetSRID (geom: Geometry, srid: Integer)",
+    )
+    .with_argument("geom", "geometry: Input geometry or geography")
+    .with_argument("srid", "srid: EPSG code to set (e.g., 4326)")
+    .with_sql_example(
+        "SELECT ST_SetSRID(ST_GeomFromWKT('POINT (-64.363049 45.091501)'), 4326)".to_string(),
+    )
+    .build()
+}
+
+fn set_crs_doc() -> Documentation {
     Documentation::builder(
         DOC_SECTION_OTHER,
         "Set CRS information for a geometry or geography",
-        "ST_SetSRID (geom: Geometry, crs: String)",
+        "ST_SetCrs (geom: Geometry, crs: String)",
     )
     .with_argument("geom", "geometry: Input geometry or geography")
-    .with_argument("crs", "string: Coordinate reference system identifier")
-    .with_argument("crs", "string: CRS identifier to set (e.g., 'OGC:CRS84')")
+    .with_argument(
+        "crs",
+        "string: Coordinate reference system identifier (e.g., 'OGC:CRS84')",
+    )
     .with_sql_example(
-        "SELECT ST_SetSRID(ST_GeomFromWKT('POINT (-64.363049 45.091501)'), 'OGC:CRS84')"
-            .to_string(),
+        "SELECT ST_SetCrs(ST_GeomFromWKT('POINT (-64.363049 45.091501)'), 'OGC:CRS84')".to_string(),
     )
     .build()
 }
@@ -76,39 +114,48 @@ impl SedonaScalarKernel for STSetSRID {
         args: &[SedonaType],
         scalar_args: &[Option<&ScalarValue>],
     ) -> Result<Option<SedonaType>> {
-        if args.len() != 2 {
+        if args.len() != 2
+            || !(ArgMatcher::is_numeric().match_type(&args[1])
+                || ArgMatcher::is_null().match_type(&args[1]))
+        {
             return Ok(None);
         }
+        determine_return_type(args, scalar_args, self.engine.as_ref())
+    }
 
-        if !ArgMatcher::is_geometry_or_geography().match_type(&args[0]) {
-            return Ok(None);
-        }
+    fn invoke_batch(
+        &self,
+        _arg_types: &[SedonaType],
+        args: &[ColumnarValue],
+    ) -> Result<ColumnarValue> {
+        Ok(args[0].clone())
+    }
 
-        if let Some(scalar_crs) = scalar_args[1] {
-            if let ScalarValue::Utf8(maybe_crs) = scalar_crs.cast_to(&DataType::Utf8)? {
-                let new_crs = match maybe_crs {
-                    Some(crs) => {
-                        if crs == "0" {
-                            None
-                        } else {
-                            validate_crs(&crs, self.engine.as_ref())?;
-                            deserialize_crs(&serde_json::Value::String(crs))?
-                        }
-                    }
-                    None => None,
-                };
+    fn return_type(&self, _args: &[SedonaType]) -> Result<Option<SedonaType>> {
+        sedona_internal_err!(
+            "Should not be called because return_type_from_args_and_scalars() is implemented"
+        )
+    }
+}
 
-                match args[0] {
-                    SedonaType::Wkb(edges, _) => return Ok(Some(SedonaType::Wkb(edges, new_crs))),
-                    SedonaType::WkbView(edges, _) => {
-                        return Ok(Some(SedonaType::WkbView(edges, new_crs)))
-                    }
-                    _ => {}
-                }
-            }
-        }
+#[derive(Debug)]
+struct STSetCRS {
+    engine: Option<Arc<dyn CrsEngine + Send + Sync>>,
+}
 
-        sedona_internal_err!("Unexpected argument types: {}, {}", args[0], args[1])
+impl SedonaScalarKernel for STSetCRS {
+    fn return_type_from_args_and_scalars(
+        &self,
+        args: &[SedonaType],
+        scalar_args: &[Option<&ScalarValue>],
+    ) -> Result<Option<SedonaType>> {
+        if args.len() != 2
+            || !(ArgMatcher::is_string().match_type(&args[1])
+                || ArgMatcher::is_null().match_type(&args[1]))
+        {
+            return Ok(None);
+        }
+        determine_return_type(args, scalar_args, self.engine.as_ref())
     }
 
     fn invoke_batch(
@@ -144,6 +191,42 @@ pub fn validate_crs(
     Ok(())
 }
 
+fn determine_return_type(
+    args: &[SedonaType],
+    scalar_args: &[Option<&ScalarValue>],
+    maybe_engine: Option<&Arc<dyn CrsEngine + Send + Sync>>,
+) -> Result<Option<SedonaType>> {
+    if !ArgMatcher::is_geometry_or_geography().match_type(&args[0]) {
+        return Ok(None);
+    }
+
+    if let Some(scalar_crs) = scalar_args[1] {
+        if let ScalarValue::Utf8(maybe_crs) = scalar_crs.cast_to(&DataType::Utf8)? {
+            let new_crs = match maybe_crs {
+                Some(crs) => {
+                    if crs == "0" {
+                        None
+                    } else {
+                        validate_crs(&crs, maybe_engine)?;
+                        deserialize_crs(&serde_json::Value::String(crs))?
+                    }
+                }
+                None => None,
+            };
+
+            match args[0] {
+                SedonaType::Wkb(edges, _) => return Ok(Some(SedonaType::Wkb(edges, new_crs))),
+                SedonaType::WkbView(edges, _) => {
+                    return Ok(Some(SedonaType::WkbView(edges, new_crs)))
+                }
+                _ => {}
+            }
+        }
+    }
+
+    sedona_internal_err!("Unexpected argument types: {}, {}", args[0], args[1])
+}
+
 #[cfg(test)]
 mod test {
     use std::rc::Rc;
@@ -163,62 +246,88 @@ mod test {
     fn udf_metadata() {
         let udf: ScalarUDF = st_set_srid_udf().into();
         assert_eq!(udf.name(), "st_setsrid");
-        assert!(udf.documentation().is_some())
+        assert!(udf.documentation().is_some());
+
+        let udf: ScalarUDF = st_set_crs_udf().into();
+        assert_eq!(udf.name(), "st_setcrs");
+        assert!(udf.documentation().is_some());
     }
 
     #[test]
-    fn udf() {
+    fn udf_srid() {
         let udf: ScalarUDF = st_set_srid_udf().into();
 
         let wkb_lnglat = SedonaType::Wkb(Edges::Planar, lnglat());
         let geom_arg = create_scalar_value(Some("POINT (0 1)"), &WKB_GEOMETRY);
         let geom_lnglat = create_scalar_value(Some("POINT (0 1)"), &wkb_lnglat);
 
-        let good_crs_scalar = ScalarValue::Utf8(Some("EPSG:4326".to_string()));
-        let null_crs_scalar = ScalarValue::Utf8(None);
-        let epsg_code_scalar = ScalarValue::Int32(Some(4326));
-        let unset_scalar = ScalarValue::Int32(Some(0));
-        let questionable_crs_scalar = ScalarValue::Utf8(Some("gazornenplat".to_string()));
+        let srid_scalar = ScalarValue::UInt32(Some(4326));
+        let unset_scalar = ScalarValue::UInt32(Some(0));
+        let null_srid_scalar = ScalarValue::UInt32(None);
 
-        // Call with a string scalar destination
+        // Call with an integer code destination (should result in a lnglat crs)
         let (return_type, result) = call_udf(
             &udf,
             geom_arg.clone(),
-            WKB_GEOMETRY,
-            good_crs_scalar.clone(),
+            &[WKB_GEOMETRY, SedonaType::Arrow(DataType::UInt32)],
+            srid_scalar.clone(),
         )
         .unwrap();
         assert_eq!(return_type, wkb_lnglat);
         assert_value_equal(&result, &geom_lnglat);
 
-        // Call with a null scalar destination (should *not* set the output crs)
+        // Call with an integer code of 0 (should unset the output crs)
+        let (return_type, result) = call_udf(
+            &udf,
+            geom_lnglat.clone(),
+            &[WKB_GEOMETRY, SedonaType::Arrow(DataType::UInt32)],
+            unset_scalar.clone(),
+        )
+        .unwrap();
+        assert_eq!(return_type, WKB_GEOMETRY);
+        assert_value_equal(&result, &geom_arg);
+
+        // Call with a null srid (should *not* set the output crs)
         let (return_type, result) = call_udf(
             &udf,
             geom_arg.clone(),
-            WKB_GEOMETRY,
-            null_crs_scalar.clone(),
+            &[WKB_GEOMETRY, SedonaType::Arrow(DataType::UInt32)],
+            null_srid_scalar.clone(),
         )
         .unwrap();
         assert_eq!(return_type, WKB_GEOMETRY);
         assert_value_equal(&result, &geom_arg);
+    }
 
-        // Call with an integer code destination (should result in a lnglat crs)
+    #[test]
+    fn udf_crs() {
+        let udf: ScalarUDF = st_set_crs_udf().into();
+
+        let wkb_lnglat = SedonaType::Wkb(Edges::Planar, lnglat());
+        let geom_arg = create_scalar_value(Some("POINT (0 1)"), &WKB_GEOMETRY);
+        let geom_lnglat = create_scalar_value(Some("POINT (0 1)"), &wkb_lnglat);
+
+        let good_crs_scalar = ScalarValue::Utf8(Some("EPSG:4326".to_string()));
+        let null_crs_scalar = ScalarValue::Utf8(None);
+        let questionable_crs_scalar = ScalarValue::Utf8(Some("gazornenplat".to_string()));
+
+        // Call with a string scalar destination
         let (return_type, result) = call_udf(
             &udf,
             geom_arg.clone(),
-            WKB_GEOMETRY,
-            epsg_code_scalar.clone(),
+            &[WKB_GEOMETRY, SedonaType::Arrow(DataType::Utf8)],
+            good_crs_scalar.clone(),
         )
         .unwrap();
         assert_eq!(return_type, wkb_lnglat);
         assert_value_equal(&result, &geom_lnglat);
 
-        // Call with an integer code of 0 (should unset the output crs)
+        // Call with a null scalar destination (should *not* set the output crs)
         let (return_type, result) = call_udf(
             &udf,
-            geom_lnglat.clone(),
-            WKB_GEOMETRY,
-            unset_scalar.clone(),
+            geom_arg.clone(),
+            &[WKB_GEOMETRY, SedonaType::Arrow(DataType::Utf8)],
+            null_crs_scalar.clone(),
         )
         .unwrap();
         assert_eq!(return_type, WKB_GEOMETRY);
@@ -226,26 +335,31 @@ mod test {
 
         // Ensure that an engine can reject a CRS if the UDF was constructed with one
         let udf_with_validation: ScalarUDF =
-            st_set_srid_with_engine_udf(Some(Arc::new(ExtremelyUnusefulEngine {}))).into();
+            st_set_crs_with_engine_udf(Some(Arc::new(ExtremelyUnusefulEngine {}))).into();
         let err = call_udf(
             &udf_with_validation,
             geom_arg.clone(),
-            WKB_GEOMETRY,
+            &[WKB_GEOMETRY, SedonaType::Arrow(DataType::Utf8)],
             questionable_crs_scalar.clone(),
         )
         .unwrap_err();
-        assert_eq!(err.message(), "Unknown geometry error");
+        assert_eq!(err.message(), "Unknown geometry error")
     }
 
     fn call_udf(
         udf: &ScalarUDF,
         arg: ColumnarValue,
-        arg_type: SedonaType,
+        arg_type: &[SedonaType],
         to: ScalarValue,
     ) -> Result<(SedonaType, ColumnarValue)> {
+        let SedonaType::Arrow(datatype) = &arg_type[1] else {
+            return Err(DataFusionError::Internal(
+                "Expected SedonaType::Arrow, but found a different variant".to_string(),
+            ));
+        };
         let arg_fields = vec![
-            Arc::new(arg_type.to_storage_field("", true)?),
-            Field::new("", DataType::Utf8, true).into(),
+            Arc::new(arg_type[0].to_storage_field("", true)?),
+            Field::new("", datatype.clone(), true).into(),
         ];
         let return_field_args = ReturnFieldArgs {
             arg_fields: &arg_fields,

Reply via email to