This is an automated email from the ASF dual-hosted git repository.

paleolimbot pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/sedona-db.git


The following commit(s) were added to refs/heads/main by this push:
     new 00393da6 feat(rust/sedona-functions): Add ST_GeomFromEWKB and support 
item crs inputs to ST_SetSRID (#547)
00393da6 is described below

commit 00393da61cc906f7cb1e70a19f36dfe4529e1b99
Author: Dewey Dunnington <[email protected]>
AuthorDate: Fri Jan 30 14:31:33 2026 -0600

    feat(rust/sedona-functions): Add ST_GeomFromEWKB and support item crs 
inputs to ST_SetSRID (#547)
---
 python/sedonadb/python/sedonadb/testing.py   |  32 ++++-
 python/sedonadb/tests/functions/test_wkb.py  |  31 +++-
 rust/sedona-expr/src/item_crs.rs             |  11 +-
 rust/sedona-functions/src/lib.rs             |   1 +
 rust/sedona-functions/src/register.rs        |   5 +-
 rust/sedona-functions/src/st_geomfromewkb.rs | 206 +++++++++++++++++++++++++++
 rust/sedona-functions/src/st_setsrid.rs      |  73 +++++-----
 7 files changed, 312 insertions(+), 47 deletions(-)

diff --git a/python/sedonadb/python/sedonadb/testing.py 
b/python/sedonadb/python/sedonadb/testing.py
index eb271060..3c7d2e8a 100644
--- a/python/sedonadb/python/sedonadb/testing.py
+++ b/python/sedonadb/python/sedonadb/testing.py
@@ -18,7 +18,7 @@ import math
 import os
 import warnings
 from pathlib import Path
-from typing import TYPE_CHECKING, List, Tuple
+from typing import TYPE_CHECKING, List, Tuple, Any
 
 import geoarrow.pyarrow as ga
 import pyarrow as pa
@@ -125,6 +125,10 @@ class DBEngine:
                 f"Failed to create engine tester {cls.name()}: 
{e}\n{cls.install_hint()}"
             )
 
+    def val_or_null(self, arg: Any) -> str:
+        """Format SQL expression for a value or NULL
+
+        Engine subclasses override this to render values whose SQL literal
+        syntax is engine-specific (e.g., bytes). This base implementation
+        delegates to the module-level ``val_or_null()``, which returns
+        ``"NULL"`` for ``None`` and raises ``NotImplementedError`` for bytes.
+        """
+        return val_or_null(arg)
+
     def assert_query_result(self, query: str, expected, **kwargs) -> 
"DBEngine":
         """Assert a SQL query result matches an expected target
 
@@ -334,6 +338,12 @@ class SedonaDB(DBEngine):
         # Don't allow this to fail with a skip
         return cls(*args, **kwargs)
 
+    def val_or_null(self, arg):
+        if isinstance(arg, bytes):
+            return f"X'{arg.hex()}'"
+        else:
+            return super().val_or_null(arg)
+
     def create_table_parquet(self, name, paths) -> "SedonaDB":
         self.con.read_parquet(paths).to_memtable().to_view(name, 
overwrite=True)
         return self
@@ -454,6 +464,12 @@ class PostGIS(DBEngine):
             "- Run `docker compose up postgis` to start a test PostGIS runtime"
         )
 
+    def val_or_null(self, arg):
+        if isinstance(arg, bytes):
+            return f"'\\x{arg.hex()}'::bytea"
+        else:
+            return super().val_or_null(arg)
+
     def create_table_parquet(self, name, paths) -> "PostGIS":
         import json
 
@@ -654,10 +670,20 @@ def geog_or_null(arg):
 
 
 def val_or_null(arg):
-    """Format SQL expression for a value or NULL"""
+    """Format SQL expression for a value or NULL
+
+    Use an engine-specific method when formatting bytes as there is no
+    engine-agnostic way to represent bytes as a SQL literal.
+
+    Returns ``"NULL"`` for ``None`` and ``arg`` unchanged for any other
+    non-bytes value (no quoting or escaping is applied).
+
+    This is not secure (i.e., does not prevent SQL injection of any kind)
+    and should only be used for testing.
+    """
     if arg is None:
         return "NULL"
-    return arg
+    elif isinstance(arg, bytes):
+        raise NotImplementedError("Use eng.val_or_null() to format bytes to 
SQL")
+    else:
+        return arg
 
 
 def _geometry_columns(schema):
diff --git a/python/sedonadb/tests/functions/test_wkb.py 
b/python/sedonadb/tests/functions/test_wkb.py
index 424d9a36..4ebb7b17 100644
--- a/python/sedonadb/tests/functions/test_wkb.py
+++ b/python/sedonadb/tests/functions/test_wkb.py
@@ -21,7 +21,7 @@ from sedonadb.testing import PostGIS, SedonaDB, geom_or_null
 
 
 @pytest.mark.parametrize("eng", [SedonaDB, PostGIS])
[email protected]("srid", [None, 4326])
[email protected]("srid", [0, 4326])
 @pytest.mark.parametrize(
     "geom",
     [
@@ -29,7 +29,7 @@ from sedonadb.testing import PostGIS, SedonaDB, geom_or_null
         "POINT (1 2)",
         "LINESTRING (1 2, 3 4, 5 6)",
         "POLYGON ((0 1, 2 0, 2 3, 0 3, 0 1))",
-        "MULTIPOINT ((1 2), (3 4))",
+        "MULTIPOINT (1 2, 3 4)",
         "MULTILINESTRING ((1 2, 3 4), (5 6, 7 8))",
         "MULTIPOLYGON (((0 1, 2 0, 2 3, 0 3, 0 1)))",
         "GEOMETRYCOLLECTION (POINT (1 2), LINESTRING (3 4, 5 6))",
@@ -37,7 +37,7 @@ from sedonadb.testing import PostGIS, SedonaDB, geom_or_null
         "POINT Z (1 2 3)",
         "LINESTRING Z (1 2 3, 4 5 6)",
         "POLYGON Z ((0 1 2, 3 0 2, 3 4 2, 0 4 2, 0 1 2))",
-        "MULTIPOINT Z ((1 2 3), (4 5 6))",
+        "MULTIPOINT Z (1 2 3, 4 5 6)",
         "MULTILINESTRING Z ((1 2 3, 4 5 6), (7 8 9, 10 11 12))",
         "MULTIPOLYGON Z (((0 1 2, 3 0 2, 3 4 2, 0 4 2, 0 1 2)))",
         "GEOMETRYCOLLECTION Z (POINT Z (1 2 3))",
@@ -45,7 +45,7 @@ from sedonadb.testing import PostGIS, SedonaDB, geom_or_null
         "POINT M (1 2 3)",
         "LINESTRING M (1 2 3, 4 5 6)",
         "POLYGON M ((0 1 2, 3 0 2, 3 4 2, 0 4 2, 0 1 2))",
-        "MULTIPOINT M ((1 2 3), (4 5 6))",
+        "MULTIPOINT M (1 2 3, 4 5 6)",
         "MULTILINESTRING M ((1 2 3, 4 5 6), (7 8 9, 10 11 12))",
         "MULTIPOLYGON M (((0 1 2, 3 0 2, 3 4 2, 0 4 2, 0 1 2)))",
         "GEOMETRYCOLLECTION M (POINT M (1 2 3))",
@@ -53,7 +53,7 @@ from sedonadb.testing import PostGIS, SedonaDB, geom_or_null
         "POINT ZM (1 2 3 4)",
         "LINESTRING ZM (1 2 3 4, 5 6 7 8)",
         "POLYGON ZM ((0 1 2 3, 4 0 2 3, 4 5 2 3, 0 5 2 3, 0 1 2 3))",
-        "MULTIPOINT ZM ((1 2 3 4), (5 6 7 8))",
+        "MULTIPOINT ZM (1 2 3 4, 5 6 7 8)",
         "MULTILINESTRING ZM ((1 2 3 4, 5 6 7 8), (9 10 11 12, 13 14 15 16))",
         "MULTIPOLYGON ZM (((0 1 2 3, 4 0 2 3, 4 5 2 3, 0 5 2 3, 0 1 2 3)))",
         "GEOMETRYCOLLECTION ZM (POINT ZM (1 2 3 4))",
@@ -70,11 +70,14 @@ from sedonadb.testing import PostGIS, SedonaDB, geom_or_null
     ],
 )
 def test_st_asewkb(eng, srid, geom):
+    if shapely.geos_version < (3, 12, 0):
+        pytest.skip("GEOS version 3.12+ required for EWKB tests")
+
     eng = eng.create_or_skip()
 
     if geom is not None:
         shapely_geom = shapely.from_wkt(geom)
-        if srid is not None:
+        if srid:
             shapely_geom = shapely.set_srid(shapely_geom, srid)
             write_srid = True
         else:
@@ -90,4 +93,20 @@ def test_st_asewkb(eng, srid, geom):
     else:
         expected = None
 
+    # Check rendering of WKB against shapely
     eng.assert_query_result(f"SELECT ST_AsEWKB({geom_or_null(geom, srid)})", 
expected)
+
+    # Check read of EWKB against read SRID
+    if expected is None:
+        srid = None
+    eng.assert_query_result(
+        f"SELECT ST_SRID(ST_GeomFromEWKB({eng.val_or_null(expected)}))", srid
+    )
+
+    # Check read of EWKB against read geometry content
+    # Workaround bug in geoarrow-c
+    if geom == "POINT EMPTY":
+        geom = "POINT (nan nan)"
+    eng.assert_query_result(
+        f"SELECT ST_SetSRID(ST_GeomFromEWKB({eng.val_or_null(expected)}), 0)", 
geom
+    )
diff --git a/rust/sedona-expr/src/item_crs.rs b/rust/sedona-expr/src/item_crs.rs
index 3bb73a16..d584cb22 100644
--- a/rust/sedona-expr/src/item_crs.rs
+++ b/rust/sedona-expr/src/item_crs.rs
@@ -549,8 +549,11 @@ pub fn make_item_crs(
 }
 
 /// Given an input type, separate it into an item and crs type (if the input
-/// is an item_crs type). Otherwise, just return the item type as is.
-fn parse_item_crs_arg_type(sedona_type: &SedonaType) -> Result<(SedonaType, 
Option<SedonaType>)> {
+/// is an item_crs type). Otherwise, just return the item type as is and 
return a
+/// CRS type of None.
+pub fn parse_item_crs_arg_type(
+    sedona_type: &SedonaType,
+) -> Result<(SedonaType, Option<SedonaType>)> {
     if let SedonaType::Arrow(DataType::Struct(fields)) = sedona_type {
         let field_names = fields.iter().map(|f| f.name()).collect::<Vec<_>>();
         if field_names != ["item", "crs"] {
@@ -569,7 +572,7 @@ fn parse_item_crs_arg_type(sedona_type: &SedonaType) -> 
Result<(SedonaType, Opti
 /// is an item_crs type). Otherwise, just return the item type as is. This
 /// version strips the CRS, which we need to do here before passing it to the
 /// underlying kernel (which expects all input CRSes to match).
-fn parse_item_crs_arg_type_strip_crs(
+pub fn parse_item_crs_arg_type_strip_crs(
     sedona_type: &SedonaType,
 ) -> Result<(SedonaType, Option<SedonaType>)> {
     match sedona_type {
@@ -588,7 +591,7 @@ fn parse_item_crs_arg_type_strip_crs(
 
 /// Separate an argument into the item and its crs (if applicable). This
 /// operates on the result of parse_item_crs_arg_type().
-fn parse_item_crs_arg(
+pub fn parse_item_crs_arg(
     item_type: &SedonaType,
     crs_type: &Option<SedonaType>,
     arg: &ColumnarValue,
diff --git a/rust/sedona-functions/src/lib.rs b/rust/sedona-functions/src/lib.rs
index 6e8f884b..e8037a06 100644
--- a/rust/sedona-functions/src/lib.rs
+++ b/rust/sedona-functions/src/lib.rs
@@ -44,6 +44,7 @@ pub mod st_envelope_agg;
 pub mod st_flipcoordinates;
 mod st_geometryn;
 mod st_geometrytype;
+mod st_geomfromewkb;
 mod st_geomfromwkb;
 mod st_geomfromwkt;
 mod st_haszm;
diff --git a/rust/sedona-functions/src/register.rs 
b/rust/sedona-functions/src/register.rs
index 14405409..883f5a5a 100644
--- a/rust/sedona-functions/src/register.rs
+++ b/rust/sedona-functions/src/register.rs
@@ -58,7 +58,6 @@ pub fn default_function_set() -> FunctionSet {
         crate::predicates::st_knn_udf,
         crate::predicates::st_touches_udf,
         crate::predicates::st_within_udf,
-        crate::st_line_merge::st_line_merge_udf,
         crate::referencing::st_line_interpolate_point_udf,
         crate::referencing::st_line_locate_point_udf,
         crate::sd_format::sd_format_udf,
@@ -80,12 +79,13 @@ pub fn default_function_set() -> FunctionSet {
         crate::st_flipcoordinates::st_flipcoordinates_udf,
         crate::st_geometryn::st_geometryn_udf,
         crate::st_geometrytype::st_geometry_type_udf,
+        crate::st_geomfromewkb::st_geomfromewkb_udf,
         crate::st_geomfromwkb::st_geogfromwkb_udf,
         crate::st_geomfromwkb::st_geomfromwkb_udf,
         crate::st_geomfromwkb::st_geomfromwkbunchecked_udf,
         crate::st_geomfromwkt::st_geogfromwkt_udf,
-        crate::st_geomfromwkt::st_geomfromwkt_udf,
         crate::st_geomfromwkt::st_geomfromewkt_udf,
+        crate::st_geomfromwkt::st_geomfromwkt_udf,
         crate::st_haszm::st_hasm_udf,
         crate::st_haszm::st_hasz_udf,
         crate::st_interiorringn::st_interiorringn_udf,
@@ -93,6 +93,7 @@ pub fn default_function_set() -> FunctionSet {
         crate::st_iscollection::st_iscollection_udf,
         crate::st_isempty::st_isempty_udf,
         crate::st_length::st_length_udf,
+        crate::st_line_merge::st_line_merge_udf,
         crate::st_makeline::st_makeline_udf,
         crate::st_numgeometries::st_numgeometries_udf,
         crate::st_perimeter::st_perimeter_udf,
diff --git a/rust/sedona-functions/src/st_geomfromewkb.rs 
b/rust/sedona-functions/src/st_geomfromewkb.rs
new file mode 100644
index 00000000..7634655b
--- /dev/null
+++ b/rust/sedona-functions/src/st_geomfromewkb.rs
@@ -0,0 +1,206 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+use std::{sync::Arc, vec};
+
+use arrow_array::builder::{BinaryBuilder, StringViewBuilder};
+use arrow_schema::DataType;
+use datafusion_common::{error::Result, exec_datafusion_err, ScalarValue};
+use datafusion_expr::{
+    scalar_doc_sections::DOC_SECTION_OTHER, ColumnarValue, Documentation, 
Volatility,
+};
+use sedona_common::sedona_internal_err;
+use sedona_expr::{
+    item_crs::make_item_crs,
+    scalar_udf::{SedonaScalarKernel, SedonaScalarUDF},
+};
+use sedona_geometry::{wkb_factory::WKB_MIN_PROBABLE_BYTES, 
wkb_header::WkbHeader};
+use sedona_schema::{
+    datatypes::{SedonaType, WKB_GEOMETRY, WKB_GEOMETRY_ITEM_CRS, 
WKB_VIEW_GEOGRAPHY},
+    matchers::ArgMatcher,
+};
+
+use crate::executor::WkbExecutor;
+
+/// ST_GeomFromEWKB() scalar UDF implementation
+///
+/// Reads extended well-known binary (EWKB) using GeoRust's wkb crate and our
+/// internal WkbHeader utility. The result is an item_crs geometry: a non-zero
+/// SRID embedded in each EWKB value becomes that item's CRS (as `EPSG:<SRID>`),
+/// while an SRID of 0 yields a NULL CRS.
+pub fn st_geomfromewkb_udf() -> SedonaScalarUDF {
+    SedonaScalarUDF::new(
+        "st_geomfromewkb",
+        vec![Arc::new(STGeomFromEWKB {})],
+        Volatility::Immutable,
+        Some(doc()),
+    )
+}
+
+/// User-facing documentation for `ST_GeomFromEWKB`.
+///
+/// The SQL example is a hex binary literal holding the little-endian EWKB
+/// encoding of `SRID=4326;POINT (1 2)`: byte order `01`, geometry type
+/// `0x20000001` (Point with the EWKB SRID flag set), SRID `4326`, then x=1, y=2.
+/// The argument name in the signature matches the documented argument (`EWKB`).
+fn doc() -> Documentation {
+    Documentation::builder(
+        DOC_SECTION_OTHER,
+        "Construct a geometry from EWKB, using any embedded SRID as its CRS".to_string(),
+        "ST_GeomFromEWKB (EWKB: Binary)".to_string(),
+    )
+    .with_argument(
+        "EWKB",
+        "binary: Extended well-known binary (EWKB) representation of the geometry".to_string(),
+    )
+    .with_sql_example(
+        "SELECT ST_GeomFromEWKB(X'0101000020E6100000000000000000F03F0000000000000040')",
+    )
+    .build()
+}
+
+/// Kernel backing `ST_GeomFromEWKB`: one binary argument in, an item_crs
+/// struct of (WKB geometry, CRS string) out.
+#[derive(Debug)]
+struct STGeomFromEWKB {}
+
+impl SedonaScalarKernel for STGeomFromEWKB {
+    fn return_type(&self, args: &[SedonaType]) -> Result<Option<SedonaType>> {
+        // Accept exactly one binary argument; always produce an item_crs geometry
+        ArgMatcher::new(vec![ArgMatcher::is_binary()], WKB_GEOMETRY_ITEM_CRS.clone())
+            .match_args(args)
+    }
+
+    fn invoke_batch(
+        &self,
+        arg_types: &[SedonaType],
+        args: &[ColumnarValue],
+    ) -> Result<ColumnarValue> {
+        // Pick the type the WKB executor uses to iterate over the raw input.
+        // NOTE(review): BinaryView input is iterated as WKB_VIEW_GEOGRAPHY even
+        // though Binary uses WKB_GEOMETRY and the output is WKB_GEOMETRY. This
+        // presumably affects iteration only, not the bytes read -- confirm, and
+        // consider WKB_VIEW_GEOMETRY for consistency.
+        let iter_type = match &arg_types[0] {
+            SedonaType::Arrow(DataType::Binary) => WKB_GEOMETRY,
+            SedonaType::Arrow(DataType::BinaryView) => WKB_VIEW_GEOGRAPHY,
+            SedonaType::Arrow(DataType::Null) => SedonaType::Arrow(DataType::Null),
+            _ => {
+                return sedona_internal_err!(
+                    "Unexpected arguments to invoke_batch: {arg_types:?}"
+                )
+            }
+        };
+
+        let iter_types = [iter_type];
+        let executor = WkbExecutor::new(&iter_types, args);
+        let num_rows = executor.num_iterations();
+        let mut geom_builder =
+            BinaryBuilder::with_capacity(num_rows, WKB_MIN_PROBABLE_BYTES * num_rows);
+        let mut srid_builder = StringViewBuilder::with_capacity(num_rows);
+
+        executor.execute_wkb_void(|maybe_item| {
+            let Some(item) = maybe_item else {
+                // NULL input: NULL geometry and NULL CRS
+                geom_builder.append_null();
+                srid_builder.append_null();
+                return Ok(());
+            };
+
+            // The EWKB header carries the SRID; 0 is treated as "no CRS"
+            let header =
+                WkbHeader::try_new(item.buf()).map_err(|e| exec_datafusion_err!("{e}"))?;
+            let maybe_crs = match header.srid() {
+                0 => None,
+                srid => Some(format!("EPSG:{srid}")),
+            };
+
+            // Write the geometry with the wkb crate's default writer options,
+            // then close the current binary value
+            wkb::writer::write_geometry(&mut geom_builder, &item, &Default::default())
+                .map_err(|e| exec_datafusion_err!("{e}"))?;
+            geom_builder.append_value([]);
+            srid_builder.append_option(maybe_crs);
+            Ok(())
+        })?;
+
+        let item_result = executor.finish(Arc::new(geom_builder.finish()))?;
+        let srid_array = srid_builder.finish();
+
+        // Shape the CRS column like the item column (scalar vs. array)
+        let crs_value = match &item_result {
+            ColumnarValue::Scalar(_) => {
+                ColumnarValue::Scalar(ScalarValue::try_from_array(&srid_array, 0)?)
+            }
+            ColumnarValue::Array(_) => ColumnarValue::Array(Arc::new(srid_array)),
+        };
+
+        make_item_crs(&WKB_GEOMETRY, item_result, &crs_value, None)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use arrow_array::BinaryArray;
+    use datafusion_common::scalar::ScalarValue;
+    use datafusion_expr::ScalarUDF;
+    use rstest::rstest;
+    use sedona_testing::{
+        compare::{assert_array_equal, assert_scalar_equal},
+        create::{create_array_item_crs, create_scalar, create_scalar_item_crs},
+        fixtures::POINT_WITH_SRID_4326_EWKB,
+        testers::ScalarUdfTester,
+    };
+
+    use super::*;
+
+    /// Little-endian WKB for POINT (1 2) with no SRID (plain WKB is valid EWKB)
+    const POINT12: [u8; 21] = [
+        0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xf0, 0x3f, 0x00, 0x00,
+        0x00, 0x00, 0x00, 0x00, 0x00, 0x40,
+    ];
+
+    #[test]
+    fn udf_metadata() {
+        let st_geomfromewkb: ScalarUDF = st_geomfromewkb_udf().into();
+        assert_eq!(st_geomfromewkb.name(), "st_geomfromewkb");
+        assert!(st_geomfromewkb.documentation().is_some());
+    }
+
+    #[rstest]
+    fn udf(#[values(DataType::Binary, DataType::BinaryView)] data_type: DataType) {
+        let tester =
+            ScalarUdfTester::new(st_geomfromewkb_udf().into(), vec![SedonaType::Arrow(data_type)]);
+        assert_eq!(tester.return_type().unwrap(), WKB_GEOMETRY_ITEM_CRS.clone());
+
+        // Scalar EWKB carrying SRID 4326 -> item CRS "EPSG:4326"
+        let from_ewkb = tester
+            .invoke_scalar(POINT_WITH_SRID_4326_EWKB.to_vec())
+            .unwrap();
+        let expected_ewkb =
+            create_scalar_item_crs(Some("POINT (1 2)"), Some("EPSG:4326"), &WKB_GEOMETRY);
+        assert_scalar_equal(&from_ewkb, &expected_ewkb);
+
+        // Scalar NULL -> NULL item_crs scalar
+        let from_null = tester.invoke_scalar(ScalarValue::Null).unwrap();
+        assert_scalar_equal(&from_null, &create_scalar(None, &WKB_GEOMETRY_ITEM_CRS));
+
+        // Array mixing plain WKB (no SRID), NULL, and EWKB with an SRID
+        let input: BinaryArray = [
+            Some(POINT12.to_vec()),
+            None,
+            Some(POINT_WITH_SRID_4326_EWKB.to_vec()),
+        ]
+        .iter()
+        .collect();
+        let from_array = tester.invoke_array(Arc::new(input)).unwrap();
+        let expected_array = create_array_item_crs(
+            &[Some("POINT (1 2)"), None, Some("POINT (1 2)")],
+            [None, None, Some("EPSG:4326")],
+            &WKB_GEOMETRY,
+        );
+        assert_array_equal(&from_array, &expected_array);
+    }
+}
diff --git a/rust/sedona-functions/src/st_setsrid.rs 
b/rust/sedona-functions/src/st_setsrid.rs
index 7c5b9713..ddef11d4 100644
--- a/rust/sedona-functions/src/st_setsrid.rs
+++ b/rust/sedona-functions/src/st_setsrid.rs
@@ -36,7 +36,7 @@ use datafusion_expr::{
 };
 use sedona_common::sedona_internal_err;
 use sedona_expr::{
-    item_crs::make_item_crs,
+    item_crs::{make_item_crs, parse_item_crs_arg, 
parse_item_crs_arg_type_strip_crs},
     scalar_udf::{ScalarKernelRef, SedonaScalarKernel, SedonaScalarUDF},
 };
 use sedona_geometry::transform::CrsEngine;
@@ -129,7 +129,7 @@ impl SedonaScalarKernel for STSetSRID {
         scalar_args: &[Option<&ScalarValue>],
     ) -> Result<Option<SedonaType>> {
         if args.len() != 2
-            || !(ArgMatcher::is_numeric().match_type(&args[1])
+            || !(ArgMatcher::is_integer().match_type(&args[1])
                 || ArgMatcher::is_null().match_type(&args[1]))
         {
             return Ok(None);
@@ -145,17 +145,20 @@ impl SedonaScalarKernel for STSetSRID {
         _num_rows: usize,
         _config_options: Option<&ConfigOptions>,
     ) -> Result<ColumnarValue> {
+        let (item_type, maybe_crs_type) = 
parse_item_crs_arg_type_strip_crs(&arg_types[0])?;
+        let (item_arg, _) = parse_item_crs_arg(&item_type, &maybe_crs_type, 
&args[0])?;
+
         let item_crs_matcher = ArgMatcher::is_item_crs();
         if item_crs_matcher.match_type(return_type) {
             let normalized_crs_value = normalize_crs_array(&args[1], 
self.engine.as_ref())?;
             make_item_crs(
-                &arg_types[0],
-                args[0].clone(),
+                &item_type,
+                item_arg,
                 &ColumnarValue::Array(normalized_crs_value),
                 crs_input_nulls(&args[1]),
             )
         } else {
-            Ok(args[0].clone())
+            Ok(item_arg)
         }
     }
 
@@ -202,17 +205,20 @@ impl SedonaScalarKernel for STSetCRS {
         _num_rows: usize,
         _config_options: Option<&ConfigOptions>,
     ) -> Result<ColumnarValue> {
+        let (item_type, maybe_crs_type) = 
parse_item_crs_arg_type_strip_crs(&arg_types[0])?;
+        let (item_arg, _) = parse_item_crs_arg(&item_type, &maybe_crs_type, 
&args[0])?;
+
         let item_crs_matcher = ArgMatcher::is_item_crs();
         if item_crs_matcher.match_type(return_type) {
             let normalized_crs_value = normalize_crs_array(&args[1], 
self.engine.as_ref())?;
             make_item_crs(
-                &arg_types[0],
-                args[0].clone(),
+                &item_type,
+                item_arg,
                 &ColumnarValue::Array(normalized_crs_value),
                 crs_input_nulls(&args[1]),
             )
         } else {
-            Ok(args[0].clone())
+            Ok(item_arg)
         }
     }
 
@@ -236,7 +242,11 @@ fn determine_return_type(
     scalar_args: &[Option<&ScalarValue>],
     maybe_engine: Option<&Arc<dyn CrsEngine + Send + Sync>>,
 ) -> Result<Option<SedonaType>> {
-    if !ArgMatcher::is_geometry_or_geography().match_type(&args[0]) {
+    let (item_type, _) = parse_item_crs_arg_type_strip_crs(&args[0])?;
+
+    // If this is not geometry or geography and/or this is not an item_crs of 
one,
+    // this kernel does not apply.
+    if !ArgMatcher::is_geometry_or_geography().match_type(&item_type) {
         return Ok(None);
     }
 
@@ -244,29 +254,23 @@ fn determine_return_type(
         if let ScalarValue::Utf8(maybe_crs) = 
scalar_crs.cast_to(&DataType::Utf8)? {
             let new_crs = match maybe_crs {
                 Some(crs) => {
-                    if crs == "0" {
-                        None
-                    } else {
-                        validate_crs(&crs, maybe_engine)?;
-                        deserialize_crs(&crs)?
-                    }
+                    validate_crs(&crs, maybe_engine)?;
+                    deserialize_crs(&crs)?
                 }
                 None => None,
             };
 
-            match args[0] {
-                SedonaType::Wkb(edges, _) => return 
Ok(Some(SedonaType::Wkb(edges, new_crs))),
-                SedonaType::WkbView(edges, _) => {
-                    return Ok(Some(SedonaType::WkbView(edges, new_crs)))
-                }
-                _ => {}
+            match item_type {
+                SedonaType::Wkb(edges, _) => Ok(Some(SedonaType::Wkb(edges, 
new_crs))),
+                SedonaType::WkbView(edges, _) => 
Ok(Some(SedonaType::WkbView(edges, new_crs))),
+                _ => sedona_internal_err!("Unexpected argument types: {}, {}", 
args[0], args[1]),
             }
+        } else {
+            sedona_internal_err!("Unexpected return type of cast to string")
         }
     } else {
-        return Ok(Some(SedonaType::new_item_crs(&args[0])?));
+        Ok(Some(SedonaType::new_item_crs(&item_type)?))
     }
-
-    sedona_internal_err!("Unexpected argument types: {}, {}", args[0], args[1])
 }
 
 /// [SedonaScalarKernel] wrapper that handles the SRID argument for 
constructors like ST_Point
@@ -523,6 +527,7 @@ mod test {
     use arrow_schema::Field;
     use datafusion_common::config::ConfigOptions;
     use datafusion_expr::{ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF};
+    use rstest::rstest;
     use sedona_geometry::{error::SedonaGeometryError, transform::CrsTransform};
     use sedona_schema::{
         crs::lnglat,
@@ -640,11 +645,13 @@ mod test {
         assert_eq!(err.message(), "Unknown geometry error")
     }
 
-    #[test]
-    fn udf_item_srid() {
+    #[rstest]
+    fn udf_item_srid_output(
+        #[values(WKB_GEOMETRY, WKB_GEOMETRY_ITEM_CRS.clone())] sedona_type: 
SedonaType,
+    ) {
         let tester = ScalarUdfTester::new(
             st_set_srid_udf().into(),
-            vec![WKB_GEOMETRY, SedonaType::Arrow(DataType::Int32)],
+            vec![sedona_type.clone(), SedonaType::Arrow(DataType::Int32)],
         );
         tester.assert_return_type(WKB_GEOMETRY_ITEM_CRS.clone());
 
@@ -656,7 +663,7 @@ mod test {
                 Some("POINT (6 7)"),
                 Some("POINT (8 9)"),
             ],
-            &WKB_GEOMETRY,
+            &sedona_type,
         );
         let crs_array =
             create_array!(Int32, [Some(4326), Some(3857), Some(3857), Some(0), 
None]) as ArrayRef;
@@ -686,11 +693,13 @@ mod test {
         );
     }
 
-    #[test]
-    fn udf_item_crs() {
+    #[rstest]
+    fn udf_item_crs_output(
+        #[values(WKB_GEOMETRY, WKB_GEOMETRY_ITEM_CRS.clone())] sedona_type: 
SedonaType,
+    ) {
         let tester = ScalarUdfTester::new(
             st_set_crs_udf().into(),
-            vec![WKB_GEOMETRY, SedonaType::Arrow(DataType::Utf8)],
+            vec![sedona_type.clone(), SedonaType::Arrow(DataType::Utf8)],
         );
         tester.assert_return_type(WKB_GEOMETRY_ITEM_CRS.clone());
 
@@ -702,7 +711,7 @@ mod test {
                 Some("POINT (6 7)"),
                 Some("POINT (8 9)"),
             ],
-            &WKB_GEOMETRY,
+            &sedona_type,
         );
         let crs_array = create_array!(
             Utf8,

Reply via email to