This is an automated email from the ASF dual-hosted git repository.
paleolimbot pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/sedona-db.git
The following commit(s) were added to refs/heads/main by this push:
new 527c1aa0 feat(rust/sedona-raster-functions): add RS_SetSRID/RS_SetCRS
with batch-local cache refactoring (#630)
527c1aa0 is described below
commit 527c1aa051c54c084c1df3e4d29a551cef55fe28
Author: Kristin Cowalcijk <[email protected]>
AuthorDate: Wed Feb 18 22:57:29 2026 +0800
feat(rust/sedona-raster-functions): add RS_SetSRID/RS_SetCRS with
batch-local cache refactoring (#630)
Co-authored-by: Copilot <[email protected]>
Co-authored-by: Dewey Dunnington <[email protected]>
---
rust/sedona-functions/src/st_setsrid.rs | 50 +-
.../benches/native-raster-functions.rs | 14 +
rust/sedona-raster-functions/src/lib.rs | 1 +
rust/sedona-raster-functions/src/register.rs | 2 +
rust/sedona-raster-functions/src/rs_setsrid.rs | 722 +++++++++++++++++++++
rust/sedona-schema/src/crs.rs | 114 +++-
6 files changed, 856 insertions(+), 47 deletions(-)
diff --git a/rust/sedona-functions/src/st_setsrid.rs
b/rust/sedona-functions/src/st_setsrid.rs
index 3a11e778..e5d7cb77 100644
--- a/rust/sedona-functions/src/st_setsrid.rs
+++ b/rust/sedona-functions/src/st_setsrid.rs
@@ -14,10 +14,7 @@
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
-use std::{
- collections::{HashMap, HashSet},
- sync::{Arc, OnceLock},
-};
+use std::sync::{Arc, OnceLock};
use arrow_array::{
builder::{BinaryBuilder, NullBufferBuilder},
@@ -41,7 +38,11 @@ use sedona_expr::{
scalar_udf::{ScalarKernelRef, SedonaScalarKernel, SedonaScalarUDF},
};
use sedona_geometry::transform::CrsEngine;
-use sedona_schema::{crs::deserialize_crs, datatypes::SedonaType,
matchers::ArgMatcher};
+use sedona_schema::{
+ crs::{deserialize_crs, CachedCrsNormalization, CachedSRIDToCrs},
+ datatypes::SedonaType,
+ matchers::ArgMatcher,
+};
/// ST_SetSRID() scalar UDF implementation
///
@@ -473,8 +474,7 @@ fn normalize_crs_array(
| DataType::UInt16
| DataType::UInt32
| DataType::UInt64 => {
- // Local cache to avoid re-validating inputs
- let mut known_valid = HashSet::new();
+ let mut srid_to_crs = CachedSRIDToCrs::new();
let int_value = crs_value.cast_to(&DataType::Int64, None)?;
let int_array_ref = ColumnarValue::values_to_arrays(&[int_value])?;
@@ -483,18 +483,10 @@ fn normalize_crs_array(
.iter()
.map(|maybe_srid| -> Result<Option<String>> {
if let Some(srid) = maybe_srid {
- if srid == 0 {
+ let Some(auth_code) = srid_to_crs.get_crs(srid)? else {
return Ok(None);
- } else if srid == 4326 {
- return Ok(Some("OGC:CRS84".to_string()));
- }
-
- let auth_code = format!("EPSG:{srid}");
- if !known_valid.contains(&srid) {
- validate_crs(&auth_code, maybe_engine)?;
- known_valid.insert(srid);
- }
-
+ };
+ validate_crs(&auth_code, maybe_engine)?;
Ok(Some(auth_code))
} else {
Ok(None)
@@ -505,7 +497,7 @@ fn normalize_crs_array(
Ok(Arc::new(utf8_view_array))
}
_ => {
- let mut known_abbreviated = HashMap::<String, String>::new();
+ let mut crs_norm = CachedCrsNormalization::new();
let string_value = crs_value.cast_to(&DataType::Utf8View, None)?;
let string_array_ref =
ColumnarValue::values_to_arrays(&[string_value])?;
@@ -514,25 +506,7 @@ fn normalize_crs_array(
.iter()
.map(|maybe_crs| -> Result<Option<String>> {
if let Some(crs_str) = maybe_crs {
- if crs_str == "0" {
- return Ok(None);
- }
-
- if let Some(abbreviated_crs) =
known_abbreviated.get(crs_str) {
- Ok(Some(abbreviated_crs.clone()))
- } else if let Some(crs) = deserialize_crs(crs_str)? {
- let abbreviated_crs =
- if let Some(auth_code) =
crs.to_authority_code()? {
- auth_code
- } else {
- crs_str.to_string()
- };
-
- known_abbreviated.insert(crs.to_string(),
abbreviated_crs.clone());
- Ok(Some(abbreviated_crs))
- } else {
- Ok(None)
- }
+ crs_norm.normalize(crs_str)
} else {
Ok(None)
}
diff --git a/rust/sedona-raster-functions/benches/native-raster-functions.rs
b/rust/sedona-raster-functions/benches/native-raster-functions.rs
index 5eead05d..eae081ac 100644
--- a/rust/sedona-raster-functions/benches/native-raster-functions.rs
+++ b/rust/sedona-raster-functions/benches/native-raster-functions.rs
@@ -55,6 +55,20 @@ fn criterion_benchmark(c: &mut Criterion) {
BenchmarkArgs::ArrayScalarScalar(Raster(64, 64), Int32(0, 63),
Int32(0, 63)),
);
benchmark::scalar(c, &f, "native-raster", "rs_rotation", Raster(64, 64));
+ benchmark::scalar(
+ c,
+ &f,
+ "native-raster",
+ "rs_setcrs",
+ BenchmarkArgs::ArrayScalar(Raster(64, 64),
String("EPSG:3857".to_string())),
+ );
+ benchmark::scalar(
+ c,
+ &f,
+ "native-raster",
+ "rs_setsrid",
+ BenchmarkArgs::ArrayScalar(Raster(64, 64), Int32(3857, 3858)),
+ );
benchmark::scalar(c, &f, "native-raster", "rs_scalex", Raster(64, 64));
benchmark::scalar(c, &f, "native-raster", "rs_scaley", Raster(64, 64));
benchmark::scalar(c, &f, "native-raster", "rs_skewx", Raster(64, 64));
diff --git a/rust/sedona-raster-functions/src/lib.rs
b/rust/sedona-raster-functions/src/lib.rs
index e7c63b03..55325b1a 100644
--- a/rust/sedona-raster-functions/src/lib.rs
+++ b/rust/sedona-raster-functions/src/lib.rs
@@ -24,6 +24,7 @@ pub mod rs_georeference;
pub mod rs_geotransform;
pub mod rs_numbands;
pub mod rs_rastercoordinate;
+pub mod rs_setsrid;
pub mod rs_size;
pub mod rs_srid;
pub mod rs_worldcoordinate;
diff --git a/rust/sedona-raster-functions/src/register.rs
b/rust/sedona-raster-functions/src/register.rs
index 6f5e2baa..fc687e1b 100644
--- a/rust/sedona-raster-functions/src/register.rs
+++ b/rust/sedona-raster-functions/src/register.rs
@@ -55,6 +55,8 @@ pub fn default_function_set() -> FunctionSet {
crate::rs_rastercoordinate::rs_worldtorastercoordy_udf,
crate::rs_size::rs_height_udf,
crate::rs_size::rs_width_udf,
+ crate::rs_setsrid::rs_set_crs_udf,
+ crate::rs_setsrid::rs_set_srid_udf,
crate::rs_srid::rs_crs_udf,
crate::rs_srid::rs_srid_udf,
crate::rs_worldcoordinate::rs_rastertoworldcoord_udf,
diff --git a/rust/sedona-raster-functions/src/rs_setsrid.rs
b/rust/sedona-raster-functions/src/rs_setsrid.rs
new file mode 100644
index 00000000..a58409d9
--- /dev/null
+++ b/rust/sedona-raster-functions/src/rs_setsrid.rs
@@ -0,0 +1,722 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::sync::Arc;
+
+use arrow_array::{Array, ArrayRef, StringViewArray, StructArray};
+use arrow_buffer::NullBuffer;
+use arrow_schema::DataType;
+use datafusion_common::cast::{as_int64_array, as_string_view_array};
+use datafusion_common::error::Result;
+use datafusion_common::{exec_err, DataFusionError, ScalarValue};
+use datafusion_expr::{
+ scalar_doc_sections::DOC_SECTION_OTHER, ColumnarValue, Documentation,
Volatility,
+};
+use sedona_expr::scalar_udf::{SedonaScalarKernel, SedonaScalarUDF};
+use sedona_geometry::transform::CrsEngine;
+use sedona_schema::crs::{CachedCrsNormalization, CachedSRIDToCrs};
+use sedona_schema::datatypes::SedonaType;
+use sedona_schema::matchers::ArgMatcher;
+use sedona_schema::raster::{raster_indices, RasterSchema};
+
+/// RS_SetSRID() scalar UDF implementation
+///
+/// An implementation of RS_SetSRID providing a scalar function implementation
+/// based on an optional [CrsEngine]. If provided, it will be used to validate
+/// the provided SRID (otherwise, all SRID input is applied without
validation).
+pub fn rs_set_srid_with_engine_udf(
+ engine: Option<Arc<dyn CrsEngine + Send + Sync>>,
+) -> SedonaScalarUDF {
+ SedonaScalarUDF::new(
+ "rs_setsrid",
+ vec![Arc::new(RsSetSrid { engine })],
+ Volatility::Immutable,
+ Some(rs_set_srid_doc()),
+ )
+}
+
+/// Build the `rs_setcrs` scalar UDF, optionally backed by a [CrsEngine].
+///
+/// The function assigns a CRS string to a raster. The optional `engine` is
+/// stored on the kernel and passed to the CRS normalization step for
+/// validation; with `None`, CRS input is applied without engine validation.
+pub fn rs_set_crs_with_engine_udf(
+    engine: Option<Arc<dyn CrsEngine + Send + Sync>>,
+) -> SedonaScalarUDF {
+    let documentation = rs_set_crs_doc();
+    SedonaScalarUDF::new(
+        "rs_setcrs",
+        vec![Arc::new(RsSetCrs { engine })],
+        Volatility::Immutable,
+        Some(documentation),
+    )
+}
+
+/// `rs_setsrid` UDF that applies SRIDs without any engine validation.
+///
+/// Use [rs_set_srid_with_engine_udf] to obtain a validating variant.
+pub fn rs_set_srid_udf() -> SedonaScalarUDF {
+    let no_engine = None;
+    rs_set_srid_with_engine_udf(no_engine)
+}
+
+/// `rs_setcrs` UDF that applies CRS strings without any engine validation.
+///
+/// Use [rs_set_crs_with_engine_udf] to obtain a validating variant.
+pub fn rs_set_crs_udf() -> SedonaScalarUDF {
+    let no_engine = None;
+    rs_set_crs_with_engine_udf(no_engine)
+}
+
+/// User-facing documentation for `RS_SetSRID`.
+fn rs_set_srid_doc() -> Documentation {
+    let description =
+        "Set the spatial reference system identifier (SRID) of the raster".to_string();
+    let syntax = "RS_SetSRID(raster: Raster, srid: Integer)".to_string();
+    let example = "SELECT RS_SetSRID(RS_Example(), 3857)".to_string();
+
+    Documentation::builder(DOC_SECTION_OTHER, description, syntax)
+        .with_argument("raster", "Raster: Input raster")
+        .with_argument("srid", "Integer: EPSG code to set (e.g., 4326)")
+        .with_sql_example(example)
+        .build()
+}
+
+/// User-facing documentation for `RS_SetCRS`.
+fn rs_set_crs_doc() -> Documentation {
+    let description = "Set the coordinate reference system (CRS) of the raster".to_string();
+    let syntax = "RS_SetCRS(raster: Raster, crs: String)".to_string();
+    let example = "SELECT RS_SetCRS(RS_Example(), 'EPSG:3857')".to_string();
+
+    Documentation::builder(DOC_SECTION_OTHER, description, syntax)
+        .with_argument("raster", "Raster: Input raster")
+        .with_argument(
+            "crs",
+            "String: Coordinate reference system identifier (e.g., 'OGC:CRS84', 'EPSG:4326')",
+        )
+        .with_sql_example(example)
+        .build()
+}
+
+// ---------------------------------------------------------------------------
+// RS_SetSRID kernel
+// ---------------------------------------------------------------------------
+
+/// Kernel behind `rs_setsrid`: `(raster, integer SRID) -> raster`.
+#[derive(Debug)]
+struct RsSetSrid {
+    /// Optional engine used to validate SRIDs before they are applied.
+    engine: Option<Arc<dyn CrsEngine + Send + Sync>>,
+}
+
+impl SedonaScalarKernel for RsSetSrid {
+    fn return_type(&self, args: &[SedonaType]) -> Result<Option<SedonaType>> {
+        ArgMatcher::new(
+            vec![ArgMatcher::is_raster(), ArgMatcher::is_integer()],
+            SedonaType::Raster,
+        )
+        .match_args(args)
+    }
+
+    fn invoke_batch(
+        &self,
+        _arg_types: &[SedonaType],
+        args: &[ColumnarValue],
+    ) -> Result<ColumnarValue> {
+        let (raster_arg, srid_arg) = (&args[0], &args[1]);
+
+        // Capture null SRIDs before conversion: a null SRID nulls the whole
+        // raster, whereas SRID 0 merely clears the CRS.
+        let srid_nulls = extract_input_nulls(srid_arg);
+        let crs_values = srid_to_crs_columnar(srid_arg, self.engine.as_ref())?;
+        replace_raster_crs(raster_arg, &crs_values, srid_nulls)
+    }
+}
+
+// ---------------------------------------------------------------------------
+// RS_SetCRS kernel
+// ---------------------------------------------------------------------------
+
+/// Kernel behind `rs_setcrs`: `(raster, CRS string) -> raster`.
+#[derive(Debug)]
+struct RsSetCrs {
+    /// Optional engine forwarded to the CRS normalization step.
+    engine: Option<Arc<dyn CrsEngine + Send + Sync>>,
+}
+
+impl SedonaScalarKernel for RsSetCrs {
+    fn return_type(&self, args: &[SedonaType]) -> Result<Option<SedonaType>> {
+        ArgMatcher::new(
+            vec![ArgMatcher::is_raster(), ArgMatcher::is_string()],
+            SedonaType::Raster,
+        )
+        .match_args(args)
+    }
+
+    fn invoke_batch(
+        &self,
+        _arg_types: &[SedonaType],
+        args: &[ColumnarValue],
+    ) -> Result<ColumnarValue> {
+        let (raster_arg, crs_arg) = (&args[0], &args[1]);
+
+        // Capture null CRS inputs before normalization: a null input nulls the
+        // whole raster, whereas "0" normalizes to a cleared (null) CRS only.
+        let crs_nulls = extract_input_nulls(crs_arg);
+        let crs_values = normalize_crs_columnar(crs_arg, self.engine.as_ref())?;
+        replace_raster_crs(raster_arg, &crs_values, crs_nulls)
+    }
+}
+
+// ---------------------------------------------------------------------------
+// Core: zero-copy CRS column swap
+// ---------------------------------------------------------------------------
+
+/// Replace the CRS column of a raster argument with the values in `crs_array`.
+///
+/// Only the CRS child column is rebuilt: the metadata and bands children are
+/// reused by cloning their `ArrayRef`s, so no raster data is copied.
+///
+/// Supported shapes:
+/// - raster array: `crs_array` must hold one value per row, or a single value
+///   that is broadcast to every row;
+/// - raster scalar with a single CRS value: the result is a scalar;
+/// - raster scalar with several CRS values (e.g. a literal raster combined
+///   with an SRID column): the raster is broadcast to the CRS length and the
+///   result is an array;
+/// - `ScalarValue::Null` raster: returned unchanged.
+///
+/// `input_nulls` is the validity of the original SRID/CRS argument. Rows where
+/// that argument was null become null rasters (not just a null CRS).
+///
+/// # Errors
+///
+/// Returns an error if the raster argument is not a struct, or if the number
+/// of CRS values cannot be aligned with the number of raster rows.
+fn replace_raster_crs(
+    raster_arg: &ColumnarValue,
+    crs_array: &StringViewArray,
+    input_nulls: Option<NullBuffer>,
+) -> Result<ColumnarValue> {
+    match raster_arg {
+        ColumnarValue::Array(raster_array) => {
+            let new_array =
+                replace_raster_array_crs(raster_array.as_ref(), crs_array, input_nulls)?;
+            Ok(ColumnarValue::Array(new_array))
+        }
+        ColumnarValue::Scalar(ScalarValue::Struct(arc_struct)) if crs_array.len() == 1 => {
+            let new_crs: ArrayRef = Arc::new(crs_array.clone());
+            let new_struct =
+                with_crs_and_input_nulls(arc_struct.as_ref(), new_crs, input_nulls.as_ref())?;
+            Ok(ColumnarValue::Scalar(ScalarValue::Struct(Arc::new(
+                new_struct,
+            ))))
+        }
+        ColumnarValue::Scalar(ScalarValue::Struct(arc_struct)) => {
+            // A scalar raster paired with an array of SRID/CRS values: expand the
+            // raster so every row gets its own CRS. Swapping the column directly
+            // would build a struct whose children have mismatched lengths.
+            let raster_array = ScalarValue::Struct(Arc::clone(arc_struct))
+                .to_array_of_size(crs_array.len())?;
+            let new_array =
+                replace_raster_array_crs(raster_array.as_ref(), crs_array, input_nulls)?;
+            Ok(ColumnarValue::Array(new_array))
+        }
+        ColumnarValue::Scalar(ScalarValue::Null) => Ok(ColumnarValue::Scalar(ScalarValue::Null)),
+        _ => exec_err!("Expected raster (Struct) input for RS_SetSRID/RS_SetCRS"),
+    }
+}
+
+/// Apply a CRS column (one value per raster row, or one value to broadcast)
+/// and the SRID/CRS input validity to a raster array.
+fn replace_raster_array_crs(
+    raster_array: &dyn Array,
+    crs_array: &StringViewArray,
+    input_nulls: Option<NullBuffer>,
+) -> Result<ArrayRef> {
+    let raster_struct = raster_array
+        .as_any()
+        .downcast_ref::<StructArray>()
+        .ok_or_else(|| {
+            DataFusionError::Internal("Expected StructArray for raster data".to_string())
+        })?;
+    let num_rows = raster_struct.len();
+
+    // `broadcast_string_view` only debug-asserts this, so reject misaligned
+    // inputs here instead of building an invalid struct in release builds.
+    if crs_array.len() != num_rows && crs_array.len() != 1 {
+        return exec_err!(
+            "RS_SetSRID/RS_SetCRS expected {num_rows} or 1 CRS values, got {}",
+            crs_array.len()
+        );
+    }
+    let new_crs = broadcast_string_view(crs_array, num_rows);
+
+    // A null scalar SRID/CRS arrives as a one-element null buffer; broadcast it.
+    let input_nulls = input_nulls.map(|nulls| {
+        if nulls.len() == 1 && num_rows != 1 {
+            if nulls.is_valid(0) {
+                NullBuffer::new_valid(num_rows)
+            } else {
+                NullBuffer::new_null(num_rows)
+            }
+        } else {
+            nulls
+        }
+    });
+
+    let new_struct = with_crs_and_input_nulls(raster_struct, new_crs, input_nulls.as_ref())?;
+    Ok(Arc::new(new_struct))
+}
+
+/// Swap in `new_crs` as the CRS column and null out rows whose SRID/CRS input
+/// was null (union with the raster's existing validity).
+fn with_crs_and_input_nulls(
+    raster_struct: &StructArray,
+    new_crs: ArrayRef,
+    input_nulls: Option<&NullBuffer>,
+) -> Result<StructArray> {
+    let swapped = swap_crs_column(raster_struct, new_crs)?;
+    let merged_nulls = NullBuffer::union(swapped.nulls(), input_nulls);
+    Ok(StructArray::new(
+        RasterSchema::fields(),
+        swapped.columns().to_vec(),
+        merged_nulls,
+    ))
+}
+
+/// Broadcast a `StringViewArray` to a target length.
+///
+/// If the array already has the target length, it is returned as-is (clone of
Arc).
+/// Otherwise the array must have length 1, and its single value is repeated.
+fn broadcast_string_view(array: &StringViewArray, len: usize) -> ArrayRef {
+ if array.len() == len {
+ return Arc::new(array.clone());
+ }
+ debug_assert_eq!(array.len(), 1);
+ if array.is_null(0) {
+ Arc::new(StringViewArray::new_null(len))
+ } else {
+ let value = array.value(0);
+ Arc::new(std::iter::repeat_n(Some(value),
len).collect::<StringViewArray>())
+ }
+}
+
+/// Return a copy of `raster_struct` whose CRS child is `new_crs_array`.
+///
+/// Every other child column and the struct's own validity are reused as-is.
+fn swap_crs_column(raster_struct: &StructArray, new_crs_array: ArrayRef) -> Result<StructArray> {
+    let validity = raster_struct.nulls().cloned();
+    let mut children = raster_struct.columns().to_vec();
+    children[raster_indices::CRS] = new_crs_array;
+    Ok(StructArray::new(RasterSchema::fields(), children, validity))
+}
+
+/// Capture the validity of the original SRID/CRS argument as a [NullBuffer].
+///
+/// Array inputs yield their own null buffer, if any. A null scalar yields a
+/// one-element all-null buffer; a non-null scalar yields `None`.
+///
+/// Callers use this to tell "input was null" (the raster becomes null) apart
+/// from "input mapped to no CRS" (SRID 0 or CRS "0": only the CRS is cleared).
+fn extract_input_nulls(input: &ColumnarValue) -> Option<NullBuffer> {
+    match input {
+        ColumnarValue::Array(array) => array.nulls().cloned(),
+        ColumnarValue::Scalar(scalar) => scalar.is_null().then(|| NullBuffer::new_null(1)),
+    }
+}
+
+// ---------------------------------------------------------------------------
+// SRID-to-CRS conversion
+// ---------------------------------------------------------------------------
+
+/// Convert an integer SRID argument into a `StringViewArray` of CRS strings.
+///
+/// Mapping (via [CachedSRIDToCrs]):
+/// - null  -> null
+/// - 0     -> null (no CRS)
+/// - 4326  -> "OGC:CRS84"
+/// - other -> "EPSG:{srid}"
+///
+/// When `maybe_engine` is provided, each distinct SRID in the batch is
+/// validated once through [validate_crs]; repeated SRIDs reuse that result.
+///
+/// # Errors
+///
+/// Returns an error if the argument cannot be cast to Int64 or if the engine
+/// rejects one of the resulting CRS strings.
+fn srid_to_crs_columnar(
+    srid_arg: &ColumnarValue,
+    maybe_engine: Option<&Arc<dyn CrsEngine + Send + Sync>>,
+) -> Result<StringViewArray> {
+    use std::collections::HashSet;
+
+    let mut srid_to_crs = CachedSRIDToCrs::new();
+    // CachedSRIDToCrs only memoizes the SRID -> string mapping and performs no
+    // validation, so track validated SRIDs here to avoid one engine round-trip
+    // per row.
+    let mut validated: HashSet<i64> = HashSet::new();
+
+    // Cast to Int64 for uniform handling of every integer input type
+    let int_value = srid_arg.cast_to(&DataType::Int64, None)?;
+    let int_array_ref = ColumnarValue::values_to_arrays(&[int_value])?;
+    let int_array = as_int64_array(&int_array_ref[0])?;
+
+    int_array
+        .iter()
+        .map(|maybe_srid| -> Result<Option<String>> {
+            let Some(srid) = maybe_srid else {
+                return Ok(None);
+            };
+            let Some(auth_code) = srid_to_crs.get_crs(srid)? else {
+                return Ok(None);
+            };
+            if maybe_engine.is_some() && validated.insert(srid) {
+                validate_crs(&auth_code, maybe_engine)?;
+            }
+            Ok(Some(auth_code))
+        })
+        .collect()
+}
+
+// ---------------------------------------------------------------------------
+// CRS string normalization
+// ---------------------------------------------------------------------------
+
+/// Normalize a CRS string argument into a `StringViewArray`.
+///
+/// Each non-null value goes through [CachedCrsNormalization]: PROJJSON and
+/// similar inputs are abbreviated to `authority:code` where possible, and "0"
+/// (or an empty string) maps to null. Null inputs stay null.
+///
+/// When `maybe_engine` is provided, each distinct normalized CRS in the batch
+/// is validated once through [validate_crs]. This honours the documented
+/// contract of `rs_set_crs_with_engine_udf`; previously the engine was ignored.
+///
+/// # Errors
+///
+/// Returns an error if a CRS string cannot be deserialized or if the engine
+/// rejects one of the normalized CRS values.
+fn normalize_crs_columnar(
+    crs_arg: &ColumnarValue,
+    maybe_engine: Option<&Arc<dyn CrsEngine + Send + Sync>>,
+) -> Result<StringViewArray> {
+    use std::collections::HashSet;
+
+    let mut crs_norm = CachedCrsNormalization::new();
+    // Normalized CRS strings already accepted by the engine in this batch.
+    let mut validated: HashSet<String> = HashSet::new();
+
+    let string_value = crs_arg.cast_to(&DataType::Utf8View, None)?;
+    let string_array_ref = ColumnarValue::values_to_arrays(&[string_value])?;
+    let string_view_array = as_string_view_array(&string_array_ref[0])?;
+
+    string_view_array
+        .iter()
+        .map(|maybe_crs| -> Result<Option<String>> {
+            let Some(crs_str) = maybe_crs else {
+                return Ok(None);
+            };
+            let Some(normalized) = crs_norm.normalize(crs_str)? else {
+                return Ok(None);
+            };
+            if maybe_engine.is_some() && !validated.contains(&normalized) {
+                validate_crs(&normalized, maybe_engine)?;
+                validated.insert(normalized.clone());
+            }
+            Ok(Some(normalized))
+        })
+        .collect()
+}
+
+/// Validate a CRS string against an optional [CrsEngine].
+///
+/// Without an engine this is a no-op; basic deserialization checks happen in
+/// the cache structs instead. With an engine, a `crs -> crs` transform is
+/// requested and any engine error is returned as `DataFusionError::External`.
+fn validate_crs(crs: &str, maybe_engine: Option<&Arc<dyn CrsEngine + Send + Sync>>) -> Result<()> {
+    let Some(engine) = maybe_engine else {
+        return Ok(());
+    };
+    engine
+        .as_ref()
+        .get_transform_crs_to_crs(crs, crs, None, "")
+        .map(|_| ())
+        .map_err(|e| DataFusionError::External(Box::new(e)))
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use arrow_array::StructArray;
+ use datafusion_common::ScalarValue;
+ use datafusion_expr::ScalarUDF;
+ use sedona_raster::array::RasterStructArray;
+ use sedona_raster::traits::RasterRef;
+ use sedona_schema::datatypes::RASTER;
+ use sedona_testing::rasters::generate_test_rasters;
+ use sedona_testing::testers::ScalarUdfTester;
+
+ #[test]
+ fn udf_metadata() {
+ let udf: ScalarUDF = rs_set_srid_udf().into();
+ assert_eq!(udf.name(), "rs_setsrid");
+ assert!(udf.documentation().is_some());
+
+ let udf: ScalarUDF = rs_set_crs_udf().into();
+ assert_eq!(udf.name(), "rs_setcrs");
+ assert!(udf.documentation().is_some());
+ }
+
+ #[test]
+ fn set_srid_array() {
+ let udf: ScalarUDF = rs_set_srid_udf().into();
+ let tester = ScalarUdfTester::new(udf, vec![RASTER,
SedonaType::Arrow(DataType::UInt32)]);
+
+ tester.assert_return_type(RASTER);
+
+ // Generate rasters with OGC:CRS84 and set SRID to 3857
+ let rasters = generate_test_rasters(3, Some(1)).unwrap();
+ let result = tester
+ .invoke_array_scalar(Arc::new(rasters), 3857u32)
+ .unwrap();
+
+ // Verify CRS was changed to EPSG:3857
+ let result_struct =
result.as_any().downcast_ref::<StructArray>().unwrap();
+ let raster_array = RasterStructArray::new(result_struct);
+ assert_eq!(raster_array.len(), 3);
+
+ let raster0 = raster_array.get(0).unwrap();
+ assert_eq!(raster0.crs(), Some("EPSG:3857"));
+
+ // Null raster at index 1 should remain null
+ assert!(raster_array.is_null(1));
+
+ let raster2 = raster_array.get(2).unwrap();
+ assert_eq!(raster2.crs(), Some("EPSG:3857"));
+ }
+
+ #[test]
+ fn set_srid_4326_maps_to_ogc_crs84() {
+ let udf: ScalarUDF = rs_set_srid_udf().into();
+ let tester = ScalarUdfTester::new(udf, vec![RASTER,
SedonaType::Arrow(DataType::UInt32)]);
+
+ let rasters = generate_test_rasters(1, None).unwrap();
+ let result = tester
+ .invoke_array_scalar(Arc::new(rasters), 4326u32)
+ .unwrap();
+
+ let result_struct =
result.as_any().downcast_ref::<StructArray>().unwrap();
+ let raster_array = RasterStructArray::new(result_struct);
+ let raster = raster_array.get(0).unwrap();
+ assert_eq!(raster.crs(), Some("OGC:CRS84"));
+ }
+
+ #[test]
+ fn set_srid_zero_clears_crs() {
+ let udf: ScalarUDF = rs_set_srid_udf().into();
+ let tester = ScalarUdfTester::new(udf, vec![RASTER,
SedonaType::Arrow(DataType::UInt32)]);
+
+ let rasters = generate_test_rasters(1, None).unwrap();
+ let result = tester.invoke_array_scalar(Arc::new(rasters),
0u32).unwrap();
+
+ let result_struct =
result.as_any().downcast_ref::<StructArray>().unwrap();
+ let raster_array = RasterStructArray::new(result_struct);
+ let raster = raster_array.get(0).unwrap();
+ // CRS should be None (null) for SRID 0
+ assert_eq!(raster.crs(), None);
+ }
+
+ #[test]
+ fn set_crs_array() {
+ let udf: ScalarUDF = rs_set_crs_udf().into();
+ let tester = ScalarUdfTester::new(udf, vec![RASTER,
SedonaType::Arrow(DataType::Utf8)]);
+
+ tester.assert_return_type(RASTER);
+
+ let rasters = generate_test_rasters(3, Some(1)).unwrap();
+ let result = tester
+ .invoke_array_scalar(Arc::new(rasters), "EPSG:3857")
+ .unwrap();
+
+ let result_struct =
result.as_any().downcast_ref::<StructArray>().unwrap();
+ let raster_array = RasterStructArray::new(result_struct);
+ assert_eq!(raster_array.len(), 3);
+
+ let raster0 = raster_array.get(0).unwrap();
+ assert_eq!(raster0.crs(), Some("EPSG:3857"));
+
+ assert!(raster_array.is_null(1));
+
+ let raster2 = raster_array.get(2).unwrap();
+ assert_eq!(raster2.crs(), Some("EPSG:3857"));
+ }
+
+ #[test]
+ fn set_crs_epsg_4326_normalizes_to_ogc_crs84() {
+ let udf: ScalarUDF = rs_set_crs_udf().into();
+ let tester = ScalarUdfTester::new(udf, vec![RASTER,
SedonaType::Arrow(DataType::Utf8)]);
+
+ let rasters = generate_test_rasters(1, None).unwrap();
+ let result = tester
+ .invoke_array_scalar(Arc::new(rasters), "EPSG:4326")
+ .unwrap();
+
+ let result_struct =
result.as_any().downcast_ref::<StructArray>().unwrap();
+ let raster_array = RasterStructArray::new(result_struct);
+ let raster = raster_array.get(0).unwrap();
+ // EPSG:4326 should normalize to OGC:CRS84
+ assert_eq!(raster.crs(), Some("OGC:CRS84"));
+ }
+
+ #[test]
+ fn set_crs_zero_clears_crs() {
+ let udf: ScalarUDF = rs_set_crs_udf().into();
+ let tester = ScalarUdfTester::new(udf, vec![RASTER,
SedonaType::Arrow(DataType::Utf8)]);
+
+ let rasters = generate_test_rasters(1, None).unwrap();
+ let result = tester.invoke_array_scalar(Arc::new(rasters),
"0").unwrap();
+
+ let result_struct =
result.as_any().downcast_ref::<StructArray>().unwrap();
+ let raster_array = RasterStructArray::new(result_struct);
+ let raster = raster_array.get(0).unwrap();
+ assert_eq!(raster.crs(), None);
+ }
+
+ #[test]
+ fn set_srid_preserves_metadata_and_bands() {
+ let udf: ScalarUDF = rs_set_srid_udf().into();
+ let tester = ScalarUdfTester::new(udf, vec![RASTER,
SedonaType::Arrow(DataType::UInt32)]);
+
+ let rasters = generate_test_rasters(3, Some(1)).unwrap();
+ let original_array = RasterStructArray::new(&rasters);
+
+ let result = tester
+ .invoke_array_scalar(Arc::new(rasters.clone()), 3857u32)
+ .unwrap();
+ let result_struct =
result.as_any().downcast_ref::<StructArray>().unwrap();
+ let result_array = RasterStructArray::new(result_struct);
+
+ // Verify non-null rasters have same metadata and band data
+ for i in [0, 2] {
+ let original = original_array.get(i).unwrap();
+ let modified = result_array.get(i).unwrap();
+
+ // Metadata preserved
+ assert_eq!(original.metadata().width(),
modified.metadata().width());
+ assert_eq!(original.metadata().height(),
modified.metadata().height());
+ assert_eq!(
+ original.metadata().upper_left_x(),
+ modified.metadata().upper_left_x()
+ );
+ assert_eq!(
+ original.metadata().upper_left_y(),
+ modified.metadata().upper_left_y()
+ );
+
+ // Band data preserved
+ let orig_bands = original.bands();
+ let mod_bands = modified.bands();
+ assert_eq!(orig_bands.len(), mod_bands.len());
+ for band_idx in 0..orig_bands.len() {
+ let orig_band = orig_bands.band(band_idx + 1).unwrap();
+ let mod_band = mod_bands.band(band_idx + 1).unwrap();
+ assert_eq!(orig_band.data(), mod_band.data());
+ assert_eq!(
+ orig_band.metadata().data_type().unwrap(),
+ mod_band.metadata().data_type().unwrap()
+ );
+ }
+
+ // CRS changed
+ assert_eq!(modified.crs(), Some("EPSG:3857"));
+ assert_ne!(original.crs(), modified.crs());
+ }
+ }
+
+ #[test]
+ fn set_srid_scalar_null_raster() {
+ let udf: ScalarUDF = rs_set_srid_udf().into();
+ let tester = ScalarUdfTester::new(udf, vec![RASTER,
SedonaType::Arrow(DataType::Int32)]);
+
+ let result = tester
+ .invoke_scalar_scalar(ScalarValue::Null, 3857)
+ .unwrap();
+ // ScalarValue::Null gets cast to a typed null raster struct by the
tester,
+ // so the result is a null struct entry (not ScalarValue::Null).
+ match result {
+ ScalarValue::Struct(s) => assert!(s.is_null(0), "Expected null
raster at index 0"),
+ other => panic!("Expected struct scalar, got {other:?}"),
+ }
+ }
+
+ #[test]
+ fn set_crs_scalar_null_raster() {
+ let udf: ScalarUDF = rs_set_crs_udf().into();
+ let tester = ScalarUdfTester::new(udf, vec![RASTER,
SedonaType::Arrow(DataType::Utf8)]);
+
+ let result = tester
+ .invoke_scalar_scalar(ScalarValue::Null, "EPSG:4326")
+ .unwrap();
+ // ScalarValue::Null gets cast to a typed null raster struct by the
tester,
+ // so the result is a null struct entry (not ScalarValue::Null).
+ match result {
+ ScalarValue::Struct(s) => assert!(s.is_null(0), "Expected null
raster at index 0"),
+ other => panic!("Expected struct scalar, got {other:?}"),
+ }
+ }
+
+ #[test]
+ fn set_srid_with_null_srid() {
+ let udf: ScalarUDF = rs_set_srid_udf().into();
+ let tester = ScalarUdfTester::new(udf, vec![RASTER,
SedonaType::Arrow(DataType::Int32)]);
+
+ let rasters = generate_test_rasters(3, Some(1)).unwrap();
+ let null_srid = ScalarValue::Int32(None);
+
+ let result = tester
+ .invoke_array_scalar(Arc::new(rasters), null_srid)
+ .unwrap();
+ let raster_array =
+
RasterStructArray::new(result.as_any().downcast_ref::<StructArray>().unwrap());
+ for i in 0..raster_array.len() {
+ assert!(
+ raster_array.is_null(i),
+ "Expected null raster at index {i} for null SRID input"
+ );
+ }
+ }
+
+ #[test]
+ fn set_crs_with_null_crs() {
+ let udf: ScalarUDF = rs_set_crs_udf().into();
+ let tester = ScalarUdfTester::new(udf, vec![RASTER,
SedonaType::Arrow(DataType::Utf8)]);
+
+ let rasters = generate_test_rasters(3, Some(1)).unwrap();
+ let null_crs = ScalarValue::Utf8(None);
+
+ let result = tester
+ .invoke_array_scalar(Arc::new(rasters), null_crs)
+ .unwrap();
+ let raster_array =
+
RasterStructArray::new(result.as_any().downcast_ref::<StructArray>().unwrap());
+ for i in 0..raster_array.len() {
+ assert!(
+ raster_array.is_null(i),
+ "Expected null raster at index {i} for null SRID input"
+ );
+ }
+ }
+
+ #[test]
+ fn set_srid_array_with_null_srid_per_row() {
+ let udf: ScalarUDF = rs_set_srid_udf().into();
+ let tester = ScalarUdfTester::new(udf, vec![RASTER,
SedonaType::Arrow(DataType::Int32)]);
+
+ // 3 rasters (null at index 1), SRIDs: [Some(3857), Some(4326), None]
+ let rasters = generate_test_rasters(3, Some(1)).unwrap();
+ let srid_array: ArrayRef = Arc::new(arrow_array::Int32Array::from(vec![
+ Some(3857),
+ Some(4326),
+ None,
+ ]));
+
+ let result = tester
+ .invoke_array_array(Arc::new(rasters), srid_array)
+ .unwrap();
+ let result_struct =
result.as_any().downcast_ref::<StructArray>().unwrap();
+ let raster_array = RasterStructArray::new(result_struct);
+
+ // Row 0: valid raster + valid SRID -> EPSG:3857
+ let raster0 = raster_array.get(0).unwrap();
+ assert_eq!(raster0.crs(), Some("EPSG:3857"));
+
+ // Row 1: null raster (from input) -> still null
+ assert!(raster_array.is_null(1));
+
+ // Row 2: valid raster + null SRID -> null raster
+ assert!(
+ raster_array.is_null(2),
+ "Expected null raster at index 2 (null SRID input)"
+ );
+ }
+
+ #[test]
+ fn set_crs_array_with_null_crs_per_row() {
+ let udf: ScalarUDF = rs_set_crs_udf().into();
+ let tester = ScalarUdfTester::new(udf, vec![RASTER,
SedonaType::Arrow(DataType::Utf8)]);
+
+ // 3 rasters (null at index 1), CRS strings: [Some("EPSG:3857"),
Some("OGC:CRS84"), None]
+ let rasters = generate_test_rasters(3, Some(1)).unwrap();
+ let crs_array: ArrayRef = Arc::new(arrow_array::StringArray::from(vec![
+ Some("EPSG:3857"),
+ Some("OGC:CRS84"),
+ None,
+ ]));
+
+ let result = tester
+ .invoke_array_array(Arc::new(rasters), crs_array)
+ .unwrap();
+ let result_struct =
result.as_any().downcast_ref::<StructArray>().unwrap();
+ let raster_array = RasterStructArray::new(result_struct);
+
+ // Row 0: valid raster + valid CRS -> EPSG:3857
+ let raster0 = raster_array.get(0).unwrap();
+ assert_eq!(raster0.crs(), Some("EPSG:3857"));
+
+ // Row 1: null raster (from input) -> still null
+ assert!(raster_array.is_null(1));
+
+ // Row 2: valid raster + null CRS -> null raster
+ assert!(
+ raster_array.is_null(2),
+ "Expected null raster at index 2 (null CRS input)"
+ );
+ }
+}
diff --git a/rust/sedona-schema/src/crs.rs b/rust/sedona-schema/src/crs.rs
index 13835538..54f1f386 100644
--- a/rust/sedona-schema/src/crs.rs
+++ b/rust/sedona-schema/src/crs.rs
@@ -18,7 +18,6 @@ use datafusion_common::{
exec_err, plan_datafusion_err, plan_err, DataFusionError, HashMap, Result,
};
use lru::LruCache;
-use std::borrow::Cow;
use std::cell::RefCell;
use std::fmt::{Debug, Display};
use std::num::NonZeroUsize;
@@ -109,14 +108,9 @@ pub fn deserialize_crs_from_obj(crs_value:
&serde_json::Value) -> Result<Crs> {
}
/// Translating CRS into integer SRID with a cache to avoid expensive CRS
deserialization.
+#[derive(Default)]
pub struct CachedCrsToSRIDMapping {
- cache: HashMap<Cow<'static, str>, u32>,
-}
-
-impl Default for CachedCrsToSRIDMapping {
- fn default() -> Self {
- Self::new()
- }
+ cache: HashMap<String, u32>,
}
impl CachedCrsToSRIDMapping {
@@ -143,7 +137,7 @@ impl CachedCrsToSRIDMapping {
Ok(*srid)
} else if let Some(crs) = deserialize_crs(crs_str)? {
if let Some(srid) = crs.srid()? {
- self.cache.insert(Cow::Owned(crs_str.to_string()), srid);
+ self.cache.insert(crs_str.to_string(), srid);
Ok(srid)
} else {
exec_err!("Can't extract SRID from item-level CRS
'{crs_str}'")
@@ -157,6 +151,108 @@ impl CachedCrsToSRIDMapping {
}
}
+/// Cache for converting integer SRIDs to CRS strings.
+///
+/// Maps SRID integers to their CRS string representation, memoizing the result
+/// per SRID so repeated SRIDs do not re-build the string:
+/// - `0` → `None` (no CRS)
+/// - `4326` → `Some("OGC:CRS84")`
+/// - other → `Some("EPSG:{srid}")`
+///
+/// No validation is performed here. Callers that need to check the resulting
+/// CRS (for example against a CRS engine) must do so themselves, and should
+/// track already-validated SRIDs if they want to validate each one only once.
+#[derive(Default)]
+pub struct CachedSRIDToCrs {
+    cache: HashMap<i64, Option<String>>,
+}
+
+impl CachedSRIDToCrs {
+    /// Create a new CachedSRIDToCrs with an empty cache.
+    pub fn new() -> Self {
+        Self {
+            cache: HashMap::new(),
+        }
+    }
+
+    /// Create a new CachedSRIDToCrs with a pre-allocated cache.
+    pub fn with_capacity(capacity: usize) -> Self {
+        Self {
+            cache: HashMap::with_capacity(capacity),
+        }
+    }
+
+    /// Get the CRS string for a given SRID, computing it at most once per SRID.
+    ///
+    /// Returns `None` for SRID `0`. The current implementation never returns an
+    /// error; the `Result` is part of the public signature.
+    pub fn get_crs(&mut self, srid: i64) -> Result<Option<String>> {
+        // Single hash lookup: compute and insert only on a cache miss.
+        let crs = self.cache.entry(srid).or_insert_with(|| match srid {
+            0 => None,
+            4326 => Some("OGC:CRS84".to_string()),
+            other => Some(format!("EPSG:{other}")),
+        });
+        Ok(crs.clone())
+    }
+}
+
+/// Memoizing normalizer that turns CRS strings into their shortest usable form.
+///
+/// Each distinct input string is deserialized at most once:
+/// - `"0"` or `""` → `None` (no CRS; never cached)
+/// - a string that deserializes to no CRS → `None`
+/// - otherwise → the `authority:code` abbreviation if one exists, else the
+///   original string unchanged
+#[derive(Default)]
+pub struct CachedCrsNormalization {
+    cache: HashMap<String, Option<String>>,
+}
+
+impl CachedCrsNormalization {
+    /// Create a normalizer with an empty cache.
+    pub fn new() -> Self {
+        Self {
+            cache: HashMap::new(),
+        }
+    }
+
+    /// Create a normalizer whose cache has room for `capacity` entries.
+    pub fn with_capacity(capacity: usize) -> Self {
+        Self {
+            cache: HashMap::with_capacity(capacity),
+        }
+    }
+
+    /// Normalize `crs_str`, consulting the cache before deserializing.
+    ///
+    /// Returns the `authority:code` form when available, the original string
+    /// when the CRS has no authority code, and `None` for `"0"`, `""`, or
+    /// strings that deserialize to no CRS. Deserialization errors are
+    /// propagated and are not cached.
+    pub fn normalize(&mut self, crs_str: &str) -> Result<Option<String>> {
+        if matches!(crs_str, "0" | "") {
+            return Ok(None);
+        }
+
+        if let Some(hit) = self.cache.get(crs_str) {
+            return Ok(hit.clone());
+        }
+
+        let normalized = match deserialize_crs(crs_str)? {
+            Some(crs) => Some(
+                crs.to_authority_code()?
+                    .unwrap_or_else(|| crs_str.to_string()),
+            ),
+            None => None,
+        };
+
+        self.cache.insert(crs_str.to_string(), normalized.clone());
+        Ok(normalized)
+    }
+}
+
/// Longitude/latitude CRS (WGS84)
///
/// A [`Crs`] that matches EPSG:4326 or OGC:CRS84.