This is an automated email from the ASF dual-hosted git repository.
paleolimbot pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/sedona-db.git
The following commit(s) were added to refs/heads/main by this push:
new ad6d474a feat(rust/sedona-functions): Add sd_simplifystorage utility
(#650)
ad6d474a is described below
commit ad6d474a8a9b70742fb8b6aaf0ed8629e2d1d5fd
Author: Dewey Dunnington <[email protected]>
AuthorDate: Mon Feb 23 16:21:54 2026 -0600
feat(rust/sedona-functions): Add sd_simplifystorage utility (#650)
Co-authored-by: Copilot <[email protected]>
---
python/sedonadb/python/sedonadb/dataframe.py | 15 +-
python/sedonadb/src/dataframe.rs | 73 ++++--
python/sedonadb/tests/io/test_pyogrio.py | 27 +++
rust/sedona-functions/src/lib.rs | 1 +
rust/sedona-functions/src/register.rs | 1 +
rust/sedona-functions/src/sd_simplifystorage.rs | 304 ++++++++++++++++++++++++
rust/sedona/src/lib.rs | 1 +
rust/sedona/src/projected_reader.rs | 162 +++++++++++++
8 files changed, 565 insertions(+), 19 deletions(-)
diff --git a/python/sedonadb/python/sedonadb/dataframe.py
b/python/sedonadb/python/sedonadb/dataframe.py
index 74941814..f599dc35 100644
--- a/python/sedonadb/python/sedonadb/dataframe.py
+++ b/python/sedonadb/python/sedonadb/dataframe.py
@@ -220,7 +220,9 @@ class DataFrame:
Args:
requested_schema: A PyCapsule representing the desired output
schema.
"""
- return self._impl.__arrow_c_stream__(requested_schema=requested_schema)
+ return self._impl.to_stream(self._ctx,
simplify=False).__arrow_c_stream__(
+ requested_schema=requested_schema
+ )
def to_view(self, name: str, overwrite: bool = False):
"""Create a view based on the query represented by this object
@@ -518,13 +520,17 @@ class DataFrame:
if driver is None and isinstance(path, str) and
path.endswith(".fgb.zip"):
driver = "FlatGeoBuf"
+ # GDAL does not support newer Arrow types like string views until 3.14,
so we export a
+ # reader with simpler types here
+ self_simplified = self._impl.to_stream(self._ctx, simplify=True)
+
# Writer: pyogrio.write_arrow() via Cython ogr_write_arrow()
#
https://github.com/geopandas/pyogrio/blob/3b2d40273b501c10ecf46cbd37c6e555754c89af/pyogrio/raw.py#L755-L897
#
https://github.com/geopandas/pyogrio/blob/3b2d40273b501c10ecf46cbd37c6e555754c89af/pyogrio/_io.pyx#L2858-L2980
import pyogrio.raw
pyogrio.raw.write_arrow(
- self,
+ self_simplified,
path,
driver=driver,
geometry_type=geometry_type,
@@ -616,6 +622,11 @@ class DataFrame:
else:
return super().__repr__()
+ def _simplify_storage_types(self):
+ return DataFrame(
+ self._ctx, self._impl.simplify_storage_types(self._ctx),
self._options
+ )
+
def _out_width(self, width=None) -> int:
if width is None:
width = self._options.width
diff --git a/python/sedonadb/src/dataframe.rs b/python/sedonadb/src/dataframe.rs
index efceca98..e0a46d5b 100644
--- a/python/sedonadb/src/dataframe.rs
+++ b/python/sedonadb/src/dataframe.rs
@@ -29,10 +29,12 @@ use datafusion::prelude::DataFrame;
use datafusion_common::{Column, DataFusionError, ParamValues};
use datafusion_expr::{ExplainFormat, ExplainOption, Expr};
use datafusion_ffi::table_provider::FFI_TableProvider;
+use futures::lock::Mutex;
use futures::TryStreamExt;
use pyo3::prelude::*;
use pyo3::types::{PyCapsule, PyDict, PyList};
use sedona::context::{SedonaDataFrame, SedonaWriteOptions};
+use sedona::projected_reader::simplify_record_batch_reader;
use sedona::show::{DisplayMode, DisplayTableOptions};
use sedona_geoparquet::options::TableGeoParquetOptions;
use sedona_schema::schema::SedonaSchema;
@@ -46,10 +48,12 @@ use crate::runtime::wait_for_future;
use crate::schema::PySedonaSchema;
#[pyclass]
+#[derive(Clone)]
pub struct InternalDataFrame {
pub inner: DataFrame,
pub runtime: Arc<Runtime>,
}
+
impl InternalDataFrame {
pub fn new(inner: DataFrame, runtime: Arc<Runtime>) -> Self {
Self { inner, runtime }
@@ -176,6 +180,25 @@ impl InternalDataFrame {
Ok(batches)
}
+ fn to_stream<'py>(
+ &self,
+ py: Python<'py>,
+ ctx: &InternalContext,
+ simplify: Option<bool>,
+ ) -> Result<StreamingResult, PySedonaError> {
+ let stream = wait_for_future(py, &self.runtime,
self.inner.clone().execute_stream())??;
+ let reader = PySedonaStreamReader::new(self.runtime.clone(), stream);
+ let mut reader: Box<dyn RecordBatchReader + Send> = Box::new(reader);
+
+ if simplify.unwrap_or(false) {
+ reader = simplify_record_batch_reader(&ctx.inner.ctx.state(),
reader)?;
+ }
+
+ Ok(StreamingResult {
+ inner: Some(reader).into(),
+ })
+ }
+
#[allow(clippy::too_many_arguments)]
fn to_parquet<'py>(
&self,
@@ -371,23 +394,6 @@ impl InternalDataFrame {
FFI_TableProvider::new(provider, true,
Some(self.runtime.handle().clone()));
Ok(PyCapsule::new(py, ffi_provider, Some(name))?)
}
-
- #[pyo3(signature = (requested_schema=None))]
- fn __arrow_c_stream__<'py>(
- &self,
- py: Python<'py>,
- requested_schema: Option<Bound<'py, PyAny>>,
- ) -> Result<Bound<'py, PyCapsule>, PySedonaError> {
- check_py_requested_schema(requested_schema,
self.inner.schema().as_arrow())?;
-
- let stream = wait_for_future(py, &self.runtime,
self.inner.clone().execute_stream())??;
- let reader = PySedonaStreamReader::new(self.runtime.clone(), stream);
- let reader: Box<dyn RecordBatchReader + Send> = Box::new(reader);
-
- let ffi_stream = FFI_ArrowArrayStream::new(reader);
- let stream_capsule_name = CString::new("arrow_array_stream").unwrap();
- Ok(PyCapsule::new(py, ffi_stream, Some(stream_capsule_name))?)
- }
}
#[pyclass]
@@ -418,6 +424,39 @@ impl Batches {
}
}
+#[pyclass]
+pub struct StreamingResult {
+ inner: Mutex<Option<Box<dyn RecordBatchReader + Send>>>,
+}
+
+#[pymethods]
+impl StreamingResult {
+ #[pyo3(signature = (requested_schema=None))]
+ fn __arrow_c_stream__<'py>(
+ &self,
+ py: Python<'py>,
+ requested_schema: Option<Bound<'py, PyAny>>,
+ ) -> Result<Bound<'py, PyCapsule>, PySedonaError> {
+ let Some(mut reader_opt) = self.inner.try_lock() else {
+ return Err(PySedonaError::SedonaPython(
+ "SedonaDB DataFrame streaming result may only be consumed from
a single thread"
+ .to_string(),
+ ));
+ };
+
+ if let Some(reader) = reader_opt.take() {
+ check_py_requested_schema(requested_schema, &reader.schema())?;
+ let ffi_stream = FFI_ArrowArrayStream::new(reader);
+ let stream_capsule_name =
CString::new("arrow_array_stream").unwrap();
+ Ok(PyCapsule::new(py, ffi_stream, Some(stream_capsule_name))?)
+ } else {
+ Err(PySedonaError::SedonaPython(
+ "SedonaDB DataFrame streaming result may only be consumed
once".to_string(),
+ ))
+ }
+ }
+}
+
fn check_py_requested_schema<'py>(
requested_schema: Option<Bound<'py, PyAny>>,
actual_schema: &Schema,
diff --git a/python/sedonadb/tests/io/test_pyogrio.py
b/python/sedonadb/tests/io/test_pyogrio.py
index 297c4198..8b3d8084 100644
--- a/python/sedonadb/tests/io/test_pyogrio.py
+++ b/python/sedonadb/tests/io/test_pyogrio.py
@@ -19,9 +19,11 @@ import io
import tempfile
from pathlib import Path
+import geoarrow.pyarrow as ga
import geopandas
import geopandas.testing
import pandas as pd
+import pyarrow as pa
import pytest
import sedonadb
import shapely
@@ -221,3 +223,28 @@ def test_write_ogr_many_batches(con):
geopandas.testing.assert_geodataframe_equal(
geopandas.read_file(f"{td}/foofy.gpkg"), expected
)
+
+
+def test_write_ogr_from_view_types(con):
+ # Check that we can write something with view types (even though it is
read back
+ # as the simplified type)
+ wkb_array = ga.with_crs(ga.as_wkb(["POINT (0 1)", "POINT (1 2)"]),
ga.OGC_CRS84)
+ wkb_view_array = (
+ ga.wkb_view()
+ .with_crs(ga.OGC_CRS84)
+ .wrap_array(wkb_array.storage.cast(pa.binary_view()))
+ )
+ tab_simple = pa.table(
+ {"string_col": pa.array(["one", "two"], pa.string()), "wkb_geometry":
wkb_array}
+ )
+ tab = pa.table(
+ {
+ "string_col": pa.array(["one", "two"], pa.string_view()),
+ "wkb_geometry": wkb_view_array,
+ }
+ )
+
+ with tempfile.TemporaryDirectory() as td:
+ con.create_data_frame(tab).to_pyogrio(f"{td}/foofy.fgb")
+ tab_roundtrip = con.read_pyogrio(f"{td}/foofy.fgb").to_arrow_table()
+ assert tab_roundtrip.sort_by("string_col") == tab_simple
diff --git a/rust/sedona-functions/src/lib.rs b/rust/sedona-functions/src/lib.rs
index 2c9ad0e1..43c77045 100644
--- a/rust/sedona-functions/src/lib.rs
+++ b/rust/sedona-functions/src/lib.rs
@@ -19,6 +19,7 @@ pub mod executor;
pub mod register;
mod sd_format;
pub mod sd_order;
+mod sd_simplifystorage;
mod st_affine;
mod st_affine_helpers;
pub mod st_analyze_agg;
diff --git a/rust/sedona-functions/src/register.rs
b/rust/sedona-functions/src/register.rs
index 7fdb1cb3..4ff9bcac 100644
--- a/rust/sedona-functions/src/register.rs
+++ b/rust/sedona-functions/src/register.rs
@@ -40,6 +40,7 @@ pub fn default_function_set() -> FunctionSet {
function_set,
crate::sd_format::sd_format_udf,
crate::sd_order::sd_order_udf,
+ crate::sd_simplifystorage::sd_simplifystorage_udf,
crate::st_affine::st_affine_udf,
crate::st_asbinary::st_asbinary_udf,
crate::st_asewkb::st_asewkb_udf,
diff --git a/rust/sedona-functions/src/sd_simplifystorage.rs
b/rust/sedona-functions/src/sd_simplifystorage.rs
new file mode 100644
index 00000000..ad107d5c
--- /dev/null
+++ b/rust/sedona-functions/src/sd_simplifystorage.rs
@@ -0,0 +1,304 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::sync::Arc;
+
+use arrow_array::ArrayRef;
+use arrow_schema::{DataType, FieldRef, UnionFields};
+use datafusion_common::{config::ConfigOptions, datatype::DataTypeExt, Result,
ScalarValue};
+use datafusion_expr::{ColumnarValue, Volatility};
+use sedona_common::sedona_internal_err;
+use sedona_expr::scalar_udf::{SedonaScalarKernel, SedonaScalarUDF};
+use sedona_schema::datatypes::SedonaType;
+
+/// SD_SimplifyStorage() scalar UDF implementation
+///
+/// This function is invoked to strip view, dictionary, or run-end
encoded
+/// types from storage if needed (or return the input otherwise). This is to
support
+/// integration with other libraries like GDAL that haven't yet supported these
+/// storage encodings.
+pub fn sd_simplifystorage_udf() -> SedonaScalarUDF {
+ SedonaScalarUDF::new(
+ "sd_simplifystorage",
+ vec![Arc::new(SDSimplifyStorage {})],
+ Volatility::Immutable,
+ )
+}
+
+#[derive(Debug)]
+struct SDSimplifyStorage {}
+
+impl SedonaScalarKernel for SDSimplifyStorage {
+ fn return_type(&self, args: &[SedonaType]) -> Result<Option<SedonaType>> {
+ let field = args[0].to_storage_field("", true)?;
+ let new_field = simplify_field(field.into());
+ Ok(Some(SedonaType::from_storage_field(&new_field)?))
+ }
+
+ fn invoke_batch_from_args(
+ &self,
+ _arg_types: &[SedonaType],
+ args: &[ColumnarValue],
+ return_type: &SedonaType,
+ _num_rows: usize,
+ _config_options: Option<&ConfigOptions>,
+ ) -> Result<ColumnarValue> {
+ let target = Arc::new(return_type.to_storage_field("", true)?);
+ match &args[0] {
+ ColumnarValue::Array(array) => {
+ Ok(ColumnarValue::Array(simplify_array(array, &target)?))
+ }
+ ColumnarValue::Scalar(scalar_value) => {
+ let array = simplify_array(&scalar_value.to_array()?,
&target)?;
+ Ok(ColumnarValue::Scalar(ScalarValue::try_from_array(
+ &array, 0,
+ )?))
+ }
+ }
+ }
+
+ fn invoke_batch(
+ &self,
+ _arg_types: &[SedonaType],
+ _args: &[ColumnarValue],
+ ) -> Result<ColumnarValue> {
+ sedona_internal_err!("Unexpected call to invoke_batch()")
+ }
+}
+
+fn simplify_field(field: FieldRef) -> FieldRef {
+ let new_type = match field.data_type() {
+ DataType::BinaryView => DataType::Binary,
+ DataType::Utf8View => DataType::Utf8,
+ DataType::Dictionary(_key_type, value_type) => {
+ simplify_field(value_type.clone().into_nullable_field_ref())
+ .data_type()
+ .clone()
+ }
+ DataType::RunEndEncoded(_run_ends, values) => {
+ simplify_field(values.clone()).data_type().clone()
+ }
+ DataType::ListView(field) | DataType::List(field) => {
+ DataType::List(simplify_field(field.clone()))
+ }
+ DataType::LargeListView(field) | DataType::LargeList(field) => {
+ DataType::LargeList(simplify_field(field.clone()))
+ }
+ DataType::FixedSizeList(field, list_size) => {
+ DataType::FixedSizeList(simplify_field(field.clone()), *list_size)
+ }
+ DataType::Struct(fields) => {
+
DataType::Struct(fields.into_iter().cloned().map(simplify_field).collect())
+ }
+ DataType::Union(union_fields, union_mode) => {
+ let new_fields = union_fields
+ .iter()
+ .map(|(_, field)| simplify_field(field.clone()))
+ .collect::<Vec<_>>();
+ let new_ids = union_fields.iter().map(|(idx, _)|
idx).collect::<Vec<_>>();
+ let new_union_fields = UnionFields::new(new_ids, new_fields);
+ DataType::Union(new_union_fields, *union_mode)
+ }
+ DataType::Map(field, is_ordered) => {
+ DataType::Map(simplify_field(field.clone()), *is_ordered)
+ }
+ _ => field.data_type().clone(),
+ };
+
+ let new_nullable = if let DataType::RunEndEncoded(_, values) =
field.data_type() {
+ field.is_nullable() || values.is_nullable()
+ } else {
+ field.is_nullable()
+ };
+
+ field
+ .as_ref()
+ .clone()
+ .with_data_type(new_type)
+ .with_nullable(new_nullable)
+ .into()
+}
+
+fn simplify_array(array: &ArrayRef, target: &FieldRef) -> Result<ArrayRef> {
+ Ok(datafusion_common::arrow::compute::cast(
+ array,
+ target.data_type(),
+ )?)
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use arrow_schema::{DataType, Field};
+ use datafusion_expr::ScalarUDF;
+ use sedona_schema::datatypes::{SedonaType, WKB_GEOMETRY,
WKB_VIEW_GEOMETRY};
+ use sedona_testing::testers::ScalarUdfTester;
+
+ #[test]
+ fn udf_metadata() {
+ let udf: ScalarUDF = sd_simplifystorage_udf().into();
+ assert_eq!(udf.name(), "sd_simplifystorage");
+ }
+
+ #[test]
+ fn simplify_identity() {
+ let udf = sd_simplifystorage_udf();
+ let types_that_dont_need_simplification = [
+ SedonaType::Arrow(DataType::Utf8),
+ SedonaType::Arrow(DataType::LargeUtf8),
+ SedonaType::Arrow(DataType::Binary),
+ SedonaType::Arrow(DataType::LargeBinary),
+ SedonaType::Arrow(DataType::Struct(
+ vec![Field::new("foofy", DataType::Utf8, false)].into(),
+ )),
+ SedonaType::Arrow(DataType::new_list(DataType::Utf8, true)),
+ SedonaType::Arrow(DataType::List(
+ WKB_GEOMETRY.to_storage_field("item", true).unwrap().into(),
+ )),
+ WKB_GEOMETRY,
+ ];
+
+ for sedona_type in types_that_dont_need_simplification {
+ let tester = ScalarUdfTester::new(udf.clone().into(),
vec![sedona_type.clone()]);
+ tester.assert_return_type(sedona_type.clone());
+
+ let initial_scalar = ScalarValue::Null
+ .cast_to(sedona_type.storage_type())
+ .unwrap();
+ let scalar_result =
tester.invoke_scalar(initial_scalar.clone()).unwrap();
+ assert_eq!(scalar_result, initial_scalar);
+
+ let initial_array = initial_scalar.to_array_of_size(10).unwrap();
+ let array_result =
tester.invoke_array(initial_array.clone()).unwrap();
+ assert_eq!(&array_result, &initial_array);
+ }
+ }
+
+ #[test]
+ fn simplify_actually() {
+ let udf = sd_simplifystorage_udf();
+
+ let cases = [
+ // Check primitive types that need to be simplified
+ (
+ SedonaType::Arrow(DataType::Utf8View),
+ SedonaType::Arrow(DataType::Utf8),
+ ),
+ (
+ SedonaType::Arrow(DataType::BinaryView),
+ SedonaType::Arrow(DataType::Binary),
+ ),
+ // Check nested types that need to be recursively simplified
+ (
+ SedonaType::Arrow(DataType::new_list(DataType::Utf8View,
true)),
+ SedonaType::Arrow(DataType::new_list(DataType::Utf8, true)),
+ ),
+ (
+ SedonaType::Arrow(DataType::new_large_list(DataType::Utf8View,
true)),
+ SedonaType::Arrow(DataType::new_large_list(DataType::Utf8,
true)),
+ ),
+ (
+
SedonaType::Arrow(DataType::new_fixed_size_list(DataType::Utf8View, 2, true)),
+
SedonaType::Arrow(DataType::new_fixed_size_list(DataType::Utf8, 2, true)),
+ ),
+ (
+ SedonaType::Arrow(DataType::Struct(
+ vec![Field::new("foofy", DataType::Utf8View,
false)].into(),
+ )),
+ SedonaType::Arrow(DataType::Struct(
+ vec![Field::new("foofy", DataType::Utf8, false)].into(),
+ )),
+ ),
+ // Check dictionary types
+ (
+ SedonaType::Arrow(DataType::Dictionary(
+ Box::new(DataType::Int8),
+ Box::new(DataType::Utf8View),
+ )),
+ SedonaType::Arrow(DataType::Utf8),
+ ),
+ // Check run-end encoded
+ (
+ SedonaType::Arrow(DataType::RunEndEncoded(
+ DataType::Int32.into_nullable_field_ref(),
+ DataType::Utf8View.into_nullable_field_ref(),
+ )),
+ SedonaType::Arrow(DataType::Utf8),
+ ),
+ // Check complex nested types that need to be recursively
simplified
+ (
+ SedonaType::Arrow(DataType::ListView(
+ Field::new("item", DataType::Utf8, true).into(),
+ )),
+ SedonaType::Arrow(DataType::new_list(DataType::Utf8, true)),
+ ),
+ (
+ SedonaType::Arrow(DataType::LargeListView(
+ Field::new("item", DataType::Utf8, true).into(),
+ )),
+ SedonaType::Arrow(DataType::new_large_list(DataType::Utf8,
true)),
+ ),
+ // Extension type metadata should be propagated
+ (WKB_VIEW_GEOMETRY, WKB_GEOMETRY),
+ (
+ SedonaType::Arrow(DataType::List(
+ WKB_VIEW_GEOMETRY
+ .to_storage_field("item", true)
+ .unwrap()
+ .into(),
+ )),
+ SedonaType::Arrow(DataType::List(
+ WKB_GEOMETRY.to_storage_field("item",
true).unwrap().into(),
+ )),
+ ),
+ ];
+
+ for (initial_type, simplified_type) in cases {
+ let tester = ScalarUdfTester::new(udf.clone().into(),
vec![initial_type.clone()]);
+ let return_type = tester.return_type().unwrap();
+ assert_eq!(
+ return_type,
+ simplified_type,
+ "expected {initial_type:?} to simplify to {simplified_type:?}
but got {return_type:?}"
+ );
+
+ // A few types aren't well supported by Arrow/DataFusion internals
which make it
+ // difficult to create test data.
+ if !matches!(
+ initial_type,
+ SedonaType::Arrow(DataType::RunEndEncoded(_, _))
+ | SedonaType::Arrow(DataType::ListView(_))
+ | SedonaType::Arrow(DataType::LargeListView(_))
+ ) {
+ let initial_scalar = ScalarValue::Null
+ .cast_to(initial_type.storage_type())
+ .unwrap();
+ let expected_scalar = ScalarValue::Null
+ .cast_to(simplified_type.storage_type())
+ .unwrap();
+
+ let scalar_result =
tester.invoke_scalar(initial_scalar.clone()).unwrap();
+ assert_eq!(scalar_result, expected_scalar);
+
+ let initial_array =
initial_scalar.to_array_of_size(10).unwrap();
+ let expected_array =
expected_scalar.to_array_of_size(10).unwrap();
+ let array_result = tester.invoke_array(initial_array).unwrap();
+ assert_eq!(&array_result, &expected_array);
+ }
+ }
+ }
+}
diff --git a/rust/sedona/src/lib.rs b/rust/sedona/src/lib.rs
index 6752c328..81c103f9 100644
--- a/rust/sedona/src/lib.rs
+++ b/rust/sedona/src/lib.rs
@@ -21,6 +21,7 @@ mod exec;
pub mod memory_pool;
mod object_storage;
pub mod pool_type;
+pub mod projected_reader;
pub mod random_geometry_provider;
pub mod reader;
pub mod record_batch_reader_provider;
diff --git a/rust/sedona/src/projected_reader.rs
b/rust/sedona/src/projected_reader.rs
new file mode 100644
index 00000000..3d6bc6ce
--- /dev/null
+++ b/rust/sedona/src/projected_reader.rs
@@ -0,0 +1,162 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::sync::Arc;
+
+use arrow_array::{RecordBatch, RecordBatchIterator, RecordBatchReader};
+use arrow_schema::{ArrowError, Schema};
+use datafusion::{
+ catalog::Session,
+ physical_expr::ScalarFunctionExpr,
+ physical_plan::{expressions::Column, PhysicalExpr},
+};
+use datafusion_common::Result;
+use datafusion_expr::ReturnFieldArgs;
+use sedona_common::sedona_internal_datafusion_err;
+
+/// Project a [RecordBatchReader] using a list of physical expressions
+///
+/// This is useful to wrap results with projections that don't change the
ordering
+/// of the input, as adding a projection node may otherwise trigger
DataFusion's
+/// optimizer to introduce round robin partitioning (which in turn can result
in
+/// non-deterministic ordering in some situations).
+pub fn projected_record_batch_reader(
+ reader: Box<dyn RecordBatchReader + Send>,
+ projection: Vec<(Arc<dyn PhysicalExpr>, String)>,
+) -> Result<Box<dyn RecordBatchReader + Send>> {
+ let existing_schema = reader.schema();
+ let new_fields = projection
+ .iter()
+ .map(|expr| {
+ Ok(expr
+ .0
+ .return_field(&existing_schema)?
+ .as_ref()
+ .clone()
+ .with_name(expr.1.clone()))
+ })
+ .collect::<Result<Vec<_>>>()?;
+
+ let new_schema = Arc::new(Schema::new_with_metadata(
+ new_fields,
+ existing_schema.metadata().clone(),
+ ));
+ let iter_schema = new_schema.clone();
+ let reader_iter = reader.map(move |maybe_batch| {
+ let batch = maybe_batch?;
+ let new_columns = projection
+ .iter()
+ .map(|expr| expr.0.evaluate(&batch)?.to_array(batch.num_rows()))
+ .collect::<Result<Vec<_>>>()
+ .map_err(|e| ArrowError::ExternalError(Box::new(e)))?;
+ RecordBatch::try_new(iter_schema.clone(), new_columns)
+ });
+
+ Ok(Box::new(RecordBatchIterator::new(reader_iter, new_schema)))
+}
+
+/// Project a [RecordBatchReader] with a projection that simplifies the
storage types of
+/// all columns
+///
+/// This is useful for wrapping a result reader when exporting a result to an
Arrow-consuming
+/// library that does not support newer Arrow types like String views. This is
similar
+/// in spirit to calling SD_SimplifyStorage(col) AS col for all columns except
this version
+/// does not cause non-deterministic row order.
+pub fn simplify_record_batch_reader(
+ ctx: &dyn Session,
+ reader: Box<dyn RecordBatchReader + Send>,
+) -> Result<Box<dyn RecordBatchReader + Send>> {
+ let existing_schema = reader.schema();
+ let config_options = Arc::new(ctx.config_options().clone());
+ let udf = ctx
+ .scalar_functions()
+ .get("sd_simplifystorage")
+ .ok_or_else(|| sedona_internal_datafusion_err!("Expected
sd_simplifystorage UDF"))?;
+ let projection = existing_schema
+ .fields()
+ .iter()
+ .enumerate()
+ .map(|(i, f)| {
+ let arg_field = Arc::new(f.clone());
+ let col_expr = Column::new(arg_field.name(), i);
+ let return_field = udf.return_field_from_args(ReturnFieldArgs {
+ arg_fields: std::slice::from_ref(&arg_field),
+ scalar_arguments: &[None],
+ })?;
+ let expr = Arc::new(ScalarFunctionExpr::new(
+ udf.name(),
+ Arc::clone(udf),
+ vec![Arc::new(col_expr)],
+ return_field,
+ config_options.clone(),
+ ));
+ Ok((expr as Arc<dyn PhysicalExpr>, arg_field.name().clone()))
+ })
+ .collect::<Result<Vec<_>>>()?;
+
+ projected_record_batch_reader(reader, projection)
+}
+
+#[cfg(test)]
+mod test {
+ use arrow_array::{create_array, ArrayRef};
+ use arrow_schema::{DataType, Field};
+
+ use crate::context::SedonaContext;
+
+ use super::*;
+
+ #[test]
+ fn simplify_storage_types() {
+ let ctx = SedonaContext::new();
+
+ let view_array = create_array!(
+ Utf8View,
+ [
+ Some("foofy1"),
+ None,
+ Some("foofy2 longer than twelve chars")
+ ]
+ ) as ArrayRef;
+ let simple_array = create_array!(
+ Utf8,
+ [
+ Some("foofy1"),
+ None,
+ Some("foofy2 longer than twelve chars")
+ ]
+ ) as ArrayRef;
+ let view_batch = RecordBatch::try_from_iter([("col1",
view_array)]).unwrap();
+ let simple_batch = RecordBatch::try_from_iter([("col1",
simple_array)]).unwrap();
+
+ let view_reader =
+ RecordBatchIterator::new([view_batch.clone()].map(Ok),
view_batch.schema());
+ let wrapped_reader =
+ simplify_record_batch_reader(&ctx.ctx.state(),
Box::new(view_reader)).unwrap();
+ assert_eq!(
+ wrapped_reader.schema().field(0),
+ &Field::new("col1", DataType::Utf8, true)
+ );
+
+ let wrapped_batches = wrapped_reader
+ .collect::<Vec<_>>()
+ .into_iter()
+ .map(|b| b.unwrap())
+ .collect::<Vec<_>>();
+ assert_eq!(wrapped_batches, vec![simple_batch])
+ }
+}