This is an automated email from the ASF dual-hosted git repository.
paleolimbot pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/sedona-db.git
The following commit(s) were added to refs/heads/main by this push:
new 90ba5b4 feat(python/sedonadb): Implement Python UDFs (#228)
90ba5b4 is described below
commit 90ba5b41505ba1a1ee2425dc84a0663e7911c396
Author: Dewey Dunnington <[email protected]>
AuthorDate: Fri Oct 24 20:34:00 2025 -0500
feat(python/sedonadb): Implement Python UDFs (#228)
Co-authored-by: Copilot <[email protected]>
---
Cargo.lock | 1 +
docs/reference/python.md | 2 +
python/sedonadb/Cargo.toml | 1 +
python/sedonadb/pyproject.toml | 1 +
python/sedonadb/python/sedonadb/context.py | 34 +++
python/sedonadb/python/sedonadb/udf.py | 316 ++++++++++++++++++++++++++
python/sedonadb/src/context.rs | 51 ++++-
python/sedonadb/src/error.rs | 6 +
python/sedonadb/src/import_from.rs | 76 ++++++-
python/sedonadb/src/lib.rs | 4 +-
python/sedonadb/src/schema.rs | 1 +
python/sedonadb/src/udf.rs | 341 +++++++++++++++++++++++++++++
python/sedonadb/tests/test_udf.py | 264 ++++++++++++++++++++++
rust/sedona-expr/src/scalar_udf.rs | 14 +-
rust/sedona-schema/src/matchers.rs | 9 +-
15 files changed, 1109 insertions(+), 12 deletions(-)
diff --git a/Cargo.lock b/Cargo.lock
index 94550ec..59ce6bb 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -5206,6 +5206,7 @@ dependencies = [
"pyo3",
"sedona",
"sedona-adbc",
+ "sedona-expr",
"sedona-geoparquet",
"sedona-proj",
"sedona-schema",
diff --git a/docs/reference/python.md b/docs/reference/python.md
index b1b6cc4..5a93ab5 100644
--- a/docs/reference/python.md
+++ b/docs/reference/python.md
@@ -25,3 +25,5 @@
::: sedonadb.testing
::: sedonadb.dbapi
+
+::: sedonadb.udf
diff --git a/python/sedonadb/Cargo.toml b/python/sedonadb/Cargo.toml
index 98379bd..939a48e 100644
--- a/python/sedonadb/Cargo.toml
+++ b/python/sedonadb/Cargo.toml
@@ -42,6 +42,7 @@ futures = { workspace = true }
pyo3 = { version = "0.25.1" }
sedona = { path = "../../rust/sedona" }
sedona-adbc = { path = "../../rust/sedona-adbc" }
+sedona-expr = { path = "../../rust/sedona-expr" }
sedona-geoparquet = { path = "../../rust/sedona-geoparquet" }
sedona-schema = { path = "../../rust/sedona-schema" }
sedona-proj = { path = "../../c/sedona-proj", default-features = false }
diff --git a/python/sedonadb/pyproject.toml b/python/sedonadb/pyproject.toml
index d857b3c..ff949df 100644
--- a/python/sedonadb/pyproject.toml
+++ b/python/sedonadb/pyproject.toml
@@ -33,6 +33,7 @@ dynamic = ["version"]
test = [
"adbc-driver-manager[dbapi]",
"adbc-driver-postgresql",
+ "datafusion",
"duckdb",
"geoarrow-pyarrow",
"geopandas",
diff --git a/python/sedonadb/python/sedonadb/context.py
b/python/sedonadb/python/sedonadb/context.py
index 6348695..f1c4827 100644
--- a/python/sedonadb/python/sedonadb/context.py
+++ b/python/sedonadb/python/sedonadb/context.py
@@ -170,6 +170,40 @@ class SedonaContext:
"""
return DataFrame(self._impl, self._impl.sql(sql), self.options)
+ def register_udf(self, udf: Any):
+ """Register a user-defined function
+
+ Args:
+ udf: An object implementing the DataFusion PyCapsule protocol
+ (i.e., `__datafusion_scalar_udf__`) or a function annotated
+ with [arrow_udf][sedonadb.udf.arrow_udf].
+
+ Examples:
+
+ >>> import pyarrow as pa
+ >>> from sedonadb import udf
+ >>> sd = sedona.db.connect()
+ >>> @udf.arrow_udf(pa.int64(), [udf.STRING])
+ ... def char_count(arg0):
+ ... arg0 = pa.array(arg0.to_array())
+ ...
+ ... return pa.array(
+ ... (len(item) for item in arg0.to_pylist()),
+ ... pa.int64()
+ ... )
+ ...
+ >>> sd.register_udf(char_count)
+ >>> sd.sql("SELECT char_count('abcde') as col").show()
+ ┌───────┐
+ │ col │
+ │ int64 │
+ ╞═══════╡
+ │ 5 │
+ └───────┘
+
+ """
+ self._impl.register_udf(udf)
+
def connect() -> SedonaContext:
"""Create a new [SedonaContext][sedonadb.context.SedonaContext]"""
diff --git a/python/sedonadb/python/sedonadb/udf.py
b/python/sedonadb/python/sedonadb/udf.py
new file mode 100644
index 0000000..236243c
--- /dev/null
+++ b/python/sedonadb/python/sedonadb/udf.py
@@ -0,0 +1,316 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import inspect
+from typing import Any, Literal, Optional, List, Union
+
+from sedonadb._lib import sedona_scalar_udf
+from sedonadb.utility import sedona # noqa: F401
+
+
+class TypeMatcher(str):
+ """Helper class to mark type matchers that can be used as the
`input_types` for
+ user-defined functions
+
+ Note that the internal storage of the type matcher (currently a string) is
+ arbitrary and may change in a future release. Use the constants provided by
+ the `udf` module.
+ """
+
+ pass
+
+
+def arrow_udf(
+ return_type: Any,
+ input_types: List[Union[TypeMatcher, Any]] = None,
+ volatility: Literal["immutable", "stable", "volatile"] = "immutable",
+ name: Optional[str] = None,
+):
+ """Generic Arrow-based user-defined scalar function decorator
+
+ This decorator may be used to annotate a function that accepts arguments as
+ Arrow array wrappers implementing the
+ [Arrow PyCapsule
Interface](https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html).
+ The annotated function must return a value of a consistent length of the
+ appropriate type.
+
+ !!! warning
+ SedonaDB will call the provided function from multiple threads.
Attempts
+ to modify shared state from the body of the function may crash or cause
+ unusual behaviour.
+
+ SedonaDB Python UDFs are experimental and this interface may change based
on
+ user feedback.
+
+ Args:
+ return_type: One of
+ - A data type (e.g., pyarrow.DataType, arro3.core.DataType,
nanoarrow.Schema)
+ if this function returns the same type regardless of its inputs.
+ - A function of `arg_types` (list of data types) and `scalar_args`
(list of
+ optional scalars) that returns a data type. This function is also
+ responsible for returning `None` if this function does not apply
to the
+ input types.
+ input_types: One of
+ - A list where each member is a data type or a `TypeMatcher`. The
+ `udf.GEOMETRY` and `udf.GEOGRAPHY` type matchers are the most
useful
+ because otherwise the function will only match spatial data
types whose
+ coordinate reference system (CRS) also matches (i.e., based on
simple
+ equality). Using these type matchers will also ensure input CRS
consistency
+ and will automatically propagate input CRSes into the output.
+ - `None`, indicating that this function can accept any number of
arguments
+ of any type. Usually this is paired with a functional
`return_type` that
+ dynamically computes a return type or returns `None` if the
number or
+ types of arguments do not match.
+ volatility: Use "immutable" for functions whose output is always
consistent
+ for the same inputs (even between queries); use "stable" for
functions
+ whose output is always consistent for the same inputs but only
within
+ the same query, and use "volatile" for functions that generate
random
+ or otherwise non-deterministic output.
+ name: An optional name for the UDF. If not given, it will be derived
from
+ the name of the provided function.
+
+ Examples:
+
+ >>> import pyarrow as pa
+ >>> from sedonadb import udf
+ >>> sd = sedona.db.connect()
+
+ The simplest scalar UDF only specifies return types. This implies that
+ the function can handle input of any type.
+
+ >>> @udf.arrow_udf(pa.string())
+ ... def some_udf(arg0, arg1):
+ ... arg0, arg1 = (
+ ... pa.array(arg0.to_array()).to_pylist(),
+ ... pa.array(arg1.to_array()).to_pylist(),
+ ... )
+ ... return pa.array(
+ ... (f"{item0} / {item1}" for item0, item1 in zip(arg0, arg1)),
+ ... pa.string(),
+ ... )
+ ...
+ >>> sd.register_udf(some_udf)
+ >>> sd.sql("SELECT some_udf(123, 'abc') as col").show()
+ ┌───────────┐
+ │ col │
+ │ utf8 │
+ ╞═══════════╡
+ │ 123 / abc │
+ └───────────┘
+
+ Use the `TypeMatcher` constants where possible to specify input.
+ This ensures that the function can handle the usual range of input
+ types that might exist for a given input.
+
+ >>> @udf.arrow_udf(pa.int64(), [udf.STRING])
+ ... def char_count(arg0):
+ ... arg0 = pa.array(arg0.to_array())
+ ...
+ ... return pa.array(
+ ... (len(item) for item in arg0.to_pylist()),
+ ... pa.int64()
+ ... )
+ ...
+ >>> sd.register_udf(char_count)
+ >>> sd.sql("SELECT char_count('abcde') as col").show()
+ ┌───────┐
+ │ col │
+ │ int64 │
+ ╞═══════╡
+ │ 5 │
+ └───────┘
+
+ In this case, the type matcher ensures we can also use the function
+ for string view input which is the usual type SedonaDB emits when
+ reading Parquet files.
+
+ >>> sd.sql("SELECT char_count(arrow_cast('abcde', 'Utf8View')) as
col").show()
+ ┌───────┐
+ │ col │
+ │ int64 │
+ ╞═══════╡
+ │ 5 │
+ └───────┘
+
+ Geometry UDFs are best written using Shapely because pyproj (including
its use
+ in GeoPandas) is not thread safe and can crash when attempting to look
up
+ CRSes when importing an Arrow array. The UDF framework supports
returning
+ geometry storage to make this possible. Coordinate reference system
metadata
+ is propagated automatically from the input.
+
+ >>> import shapely
+ >>> import geoarrow.pyarrow as ga
+ >>> @udf.arrow_udf(ga.wkb(), [udf.GEOMETRY, udf.NUMERIC])
+ ... def shapely_udf(geom, distance):
+ ... geom_wkb = pa.array(geom.storage.to_array())
+ ... distance = pa.array(distance.to_array())
+ ... geom = shapely.from_wkb(geom_wkb)
+ ... result_shapely = shapely.buffer(geom, distance)
+ ... return pa.array(shapely.to_wkb(result_shapely))
+ ...
+ >>>
+ >>> sd.register_udf(shapely_udf)
+ >>> sd.sql("SELECT ST_SRID(shapely_udf(ST_Point(0, 0), 2.0)) as
col").show()
+ ┌────────┐
+ │ col │
+ │ uint32 │
+ ╞════════╡
+ │ 0 │
+ └────────┘
+
+ >>> sd.sql("SELECT ST_SRID(shapely_udf(ST_SetSRID(ST_Point(0, 0),
3857), 2.0)) as col").show()
+ ┌────────┐
+ │ col │
+ │ uint32 │
+ ╞════════╡
+ │ 3857 │
+ └────────┘
+
+ Annotated functions may also declare keyword arguments `return_type`
and/or `num_rows`,
+ which will be passed the appropriate value by the UDF framework. This
facilitates writing
+ generic UDFs and/or UDFs with no arguments.
+
+ >>> import numpy as np
+ >>> def random_impl(return_type, num_rows):
+ ... pa_type = pa.field(return_type).type
+ ... return pa.array(np.random.random(num_rows), pa_type)
+ ...
+ >>> @udf.arrow_udf(pa.float32(), [])
+ ... def random_f32(*, return_type=None, num_rows=None):
+ ... return random_impl(return_type, num_rows)
+ ...
+ >>> @udf.arrow_udf(pa.float64(), [])
+ ... def random_f64(*, return_type=None, num_rows=None):
+ ... return random_impl(return_type, num_rows)
+ ...
+ >>> np.random.seed(487)
+ >>> sd.register_udf(random_f32)
+ >>> sd.register_udf(random_f64)
+ >>> sd.sql("SELECT random_f32() AS f32, random_f64() as f64;").show()
+ ┌────────────┬─────────────────────┐
+ │ f32 ┆ f64 │
+ │ float32 ┆ float64 │
+ ╞════════════╪═════════════════════╡
+ │ 0.35385555 ┆ 0.24793247139474195 │
+ └────────────┴─────────────────────┘
+
+ """
+
+ def decorator(func):
+ kwarg_names = _callable_kwarg_only_names(func)
+ if "return_type" in kwarg_names and "num_rows" in kwarg_names:
+
+ def func_wrapper(args, return_type, num_rows):
+ return func(*args, return_type=return_type, num_rows=num_rows)
+ elif "return_type" in kwarg_names:
+
+ def func_wrapper(args, return_type, num_rows):
+ return func(*args, return_type=return_type)
+ elif "num_rows" in kwarg_names:
+
+ def func_wrapper(args, return_type, num_rows):
+ return func(*args, num_rows=num_rows)
+ else:
+
+ def func_wrapper(args, return_type, num_rows):
+ return func(*args)
+
+ name_arg = func.__name__ if name is None and hasattr(func, "__name__")
else name
+ return ScalarUdfImpl(
+ func_wrapper, return_type, input_types, volatility, name_arg
+ )
+
+ return decorator
+
+
+BINARY: TypeMatcher = "binary"
+"""Match any binary argument (i.e., binary, binary view, large binary,
+fixed-size binary)"""
+
+BOOLEAN: TypeMatcher = "boolean"
+"""Match a boolean argument"""
+
+GEOGRAPHY: TypeMatcher = "geography"
+"""Match a geography argument"""
+
+GEOMETRY: TypeMatcher = "geometry"
+"""Match a geometry argument"""
+
+NUMERIC: TypeMatcher = "numeric"
+"""Match any numeric argument"""
+
+STRING: TypeMatcher = "string"
+"""Match any string argument (i.e., string, string view, large string)"""
+
+
+class ScalarUdfImpl:
+ """Scalar user-defined function wrapper
+
+ This class is a wrapper class used as the return value for user-defined
+ function constructors. This wrapper allows the UDF to be registered with
+ a SedonaDB context or any context that accepts DataFusion Python
+ Scalar UDFs. This object is not intended to be used to call a UDF.
+ """
+
+ def __init__(
+ self,
+ invoke_batch,
+ return_type,
+ input_types=None,
+ volatility: Literal["immutable", "stable", "volatile"] = "immutable",
+ name: Optional[str] = None,
+ ):
+ # If the input_types are None, the return_type must be callable when
passed
+ # to the internals. In the Python API we allow a data type as the
return type
+ # to make the argument easier to understand, which means we may have to wrap
+ # it in a callable here.
+ if input_types is None and not callable(return_type):
+
+ def return_type_impl(*args, **kwargs):
+ return return_type
+
+ self._return_type = return_type_impl
+ else:
+ self._return_type = return_type
+
+ self._invoke_batch = invoke_batch
+ self._input_types = input_types
+ if name is None and hasattr(invoke_batch, "__name__"):
+ self._name = invoke_batch.__name__
+ else:
+ self._name = name
+
+ self._volatility = volatility
+
+ def __sedona_internal_udf__(self):
+ return sedona_scalar_udf(
+ self._invoke_batch,
+ self._return_type,
+ self._input_types,
+ self._volatility,
+ self._name,
+ )
+
+ def __datafusion_scalar_udf__(self):
+ return self.__sedona_internal_udf__().__datafusion_scalar_udf__()
+
+
+def _callable_kwarg_only_names(f):
+ sig = inspect.signature(f)
+ return [
+ k for k, p in sig.parameters.items() if p.kind ==
inspect.Parameter.KEYWORD_ONLY
+ ]
diff --git a/python/sedonadb/src/context.rs b/python/sedonadb/src/context.rs
index 0e39a7c..4c48048 100644
--- a/python/sedonadb/src/context.rs
+++ b/python/sedonadb/src/context.rs
@@ -16,13 +16,17 @@
// under the License.
use std::{collections::HashMap, sync::Arc};
+use datafusion_expr::ScalarUDFImpl;
use pyo3::prelude::*;
use sedona::context::SedonaContext;
use tokio::runtime::Runtime;
use crate::{
- dataframe::InternalDataFrame, error::PySedonaError,
- import_from::import_table_provider_from_any, runtime::wait_for_future,
+ dataframe::InternalDataFrame,
+ error::PySedonaError,
+ import_from::{import_ffi_scalar_udf, import_table_provider_from_any},
+ runtime::wait_for_future,
+ udf::PySedonaScalarUdf,
};
#[pyclass]
@@ -116,4 +120,47 @@ impl InternalContext {
self.inner.ctx.deregister_table(table_ref)?;
Ok(())
}
+
+ pub fn scalar_udf(&self, name: &str) -> Result<PySedonaScalarUdf,
PySedonaError> {
+ if let Some(sedona_scalar_udf) = self.inner.functions.scalar_udf(name)
{
+ Ok(PySedonaScalarUdf {
+ inner: sedona_scalar_udf.clone(),
+ })
+ } else {
+ Err(PySedonaError::SedonaPython(format!(
+ "Sedona scalar UDF with name {name} was not found"
+ )))
+ }
+ }
+
+ pub fn register_udf(&mut self, udf: Bound<PyAny>) -> Result<(),
PySedonaError> {
+ if udf.hasattr("__sedona_internal_udf__")? {
+ let py_scalar_udf = udf
+ .getattr("__sedona_internal_udf__")?
+ .call0()?
+ .extract::<PySedonaScalarUdf>()?;
+ let name = py_scalar_udf.inner.name();
+ self.inner
+ .functions
+ .insert_scalar_udf(py_scalar_udf.inner.clone());
+ self.inner.ctx.register_udf(
+ self.inner
+ .functions
+ .scalar_udf(name)
+ .unwrap()
+ .clone()
+ .into(),
+ );
+ return Ok(());
+ } else if udf.hasattr("__datafusion_scalar_udf__")? {
+ let scalar_udf = import_ffi_scalar_udf(&udf)?;
+ self.inner.ctx.register_udf(scalar_udf);
+ return Ok(());
+ }
+
+ Err(PySedonaError::SedonaPython(
+ "Expected an object implementing __sedona_internal_udf__ or
__datafusion_scalar_udf__"
+ .to_string(),
+ ))
+ }
}
diff --git a/python/sedonadb/src/error.rs b/python/sedonadb/src/error.rs
index c274ed1..a3bc309 100644
--- a/python/sedonadb/src/error.rs
+++ b/python/sedonadb/src/error.rs
@@ -51,6 +51,12 @@ impl From<DataFusionError> for PySedonaError {
}
}
+impl From<PySedonaError> for DataFusionError {
+ fn from(other: PySedonaError) -> Self {
+ DataFusionError::External(Box::new(other))
+ }
+}
+
impl From<PyErr> for PySedonaError {
fn from(other: PyErr) -> Self {
PySedonaError::Py(other)
diff --git a/python/sedonadb/src/import_from.rs
b/python/sedonadb/src/import_from.rs
index b6c694e..e31b0be 100644
--- a/python/sedonadb/src/import_from.rs
+++ b/python/sedonadb/src/import_from.rs
@@ -20,18 +20,26 @@ use std::{
};
use arrow_array::{
- ffi::FFI_ArrowSchema,
+ ffi::{FFI_ArrowArray, FFI_ArrowSchema},
ffi_stream::{ArrowArrayStreamReader, FFI_ArrowArrayStream},
- RecordBatchReader,
+ make_array, ArrayRef, RecordBatchReader,
};
-use arrow_schema::Schema;
+use arrow_schema::{Field, Schema};
use datafusion::catalog::TableProvider;
-use datafusion_ffi::table_provider::{FFI_TableProvider, ForeignTableProvider};
+use datafusion_expr::ScalarUDF;
+use datafusion_ffi::{
+ table_provider::{FFI_TableProvider, ForeignTableProvider},
+ udf::{FFI_ScalarUDF, ForeignScalarUDF},
+};
use pyo3::{
types::{PyAnyMethods, PyCapsule, PyCapsuleMethods},
Bound, PyAny, Python,
};
use sedona::record_batch_reader_provider::RecordBatchReaderProvider;
+use sedona_schema::{
+ datatypes::SedonaType,
+ matchers::{ArgMatcher, TypeMatcher},
+};
use crate::error::PySedonaError;
@@ -63,6 +71,13 @@ pub fn import_ffi_table_provider(
Ok(Arc::new(provider))
}
+pub fn import_ffi_scalar_udf(obj: &Bound<PyAny>) -> Result<ScalarUDF,
PySedonaError> {
+ let capsule = obj.getattr("__datafusion_scalar_udf__")?.call0()?;
+ let udf_ptr = check_pycapsule(&capsule, "datafusion_scalar_udf")? as *mut
FFI_ScalarUDF;
+ let udf: ForeignScalarUDF = unsafe { udf_ptr.as_ref().unwrap().try_into()?
};
+ Ok(udf.into())
+}
+
pub fn import_arrow_array_stream<'py>(
py: Python<'py>,
obj: &Bound<PyAny>,
@@ -88,6 +103,59 @@ pub fn import_arrow_array_stream<'py>(
Ok(Box::new(stream_reader))
}
+pub fn import_arrow_array(obj: &Bound<PyAny>) -> Result<(Field, ArrayRef),
PySedonaError> {
+ let schema_and_array = obj.getattr("__arrow_c_array__")?.call0()?;
+ let (schema_capsule, array_capsule): (Bound<PyCapsule>, Bound<PyCapsule>) =
+ schema_and_array.extract()?;
+
+ let ffi_schema = unsafe {
+ FFI_ArrowSchema::from_raw(check_pycapsule(&schema_capsule,
"arrow_schema")? as _)
+ };
+ let ffi_array =
+ unsafe { FFI_ArrowArray::from_raw(check_pycapsule(&array_capsule,
"arrow_array")? as _) };
+
+ let result_field = Field::try_from(&ffi_schema)?;
+ let result_array_data = unsafe { arrow_array::ffi::from_ffi(ffi_array,
&ffi_schema)? };
+
+ Ok((result_field, make_array(result_array_data)))
+}
+
+pub fn import_arg_matcher(
+ obj: &Bound<PyAny>,
+) -> Result<Arc<dyn TypeMatcher + Send + Sync>, PySedonaError> {
+ if let Ok(string_value) = obj.extract::<String>() {
+ match string_value.as_str() {
+ "geometry" => return Ok(ArgMatcher::is_geometry()),
+ "geography" => return Ok(ArgMatcher::is_geography()),
+ "numeric" => return Ok(ArgMatcher::is_numeric()),
+ "string" => return Ok(ArgMatcher::is_string()),
+ "binary" => return Ok(ArgMatcher::is_binary()),
+ "boolean" => return Ok(ArgMatcher::is_boolean()),
+ v => {
+ return Err(PySedonaError::SedonaPython(format!(
+ "Can't interpret literal string '{v}' as ArgMatcher"
+ )))
+ }
+ }
+ }
+
+ let sedona_type = import_sedona_type(obj)?;
+ Ok(ArgMatcher::is_exact(sedona_type))
+}
+
+pub fn import_sedona_type(obj: &Bound<PyAny>) -> Result<SedonaType,
PySedonaError> {
+ let field = import_arrow_field(obj)?;
+ Ok(SedonaType::from_storage_field(&field)?)
+}
+
+pub fn import_arrow_field(obj: &Bound<PyAny>) -> Result<Field, PySedonaError> {
+ let capsule = obj.getattr("__arrow_c_schema__")?.call0()?;
+ let schema =
+ unsafe { FFI_ArrowSchema::from_raw(check_pycapsule(&capsule,
"arrow_schema")? as _) };
+
+ Ok(Field::try_from(&schema)?)
+}
+
pub fn import_arrow_schema(obj: &Bound<PyAny>) -> Result<Schema,
PySedonaError> {
let capsule = obj.getattr("__arrow_c_schema__")?.call0()?;
let schema =
diff --git a/python/sedonadb/src/lib.rs b/python/sedonadb/src/lib.rs
index ca09d87..62a0cab 100644
--- a/python/sedonadb/src/lib.rs
+++ b/python/sedonadb/src/lib.rs
@@ -14,7 +14,7 @@
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
-use crate::error::PySedonaError;
+use crate::{error::PySedonaError, udf::sedona_scalar_udf};
use pyo3::{ffi::Py_uintptr_t, prelude::*};
use sedona_adbc::AdbcSedonadbDriverInit;
use sedona_proj::register::{configure_global_proj_engine,
ProjCrsEngineBuilder};
@@ -27,6 +27,7 @@ mod import_from;
mod reader;
mod runtime;
mod schema;
+mod udf;
const VERSION: &str = env!("CARGO_PKG_VERSION");
@@ -89,6 +90,7 @@ fn _lib(py: Python<'_>, m: &Bound<'_, PyModule>) ->
PyResult<()> {
m.add_function(wrap_pyfunction!(configure_proj_shared, m)?)?;
m.add_function(wrap_pyfunction!(sedona_adbc_driver_init, m)?)?;
m.add_function(wrap_pyfunction!(sedona_python_version, m)?)?;
+ m.add_function(wrap_pyfunction!(sedona_scalar_udf, m)?)?;
m.add_class::<context::InternalContext>()?;
m.add_class::<dataframe::InternalDataFrame>()?;
diff --git a/python/sedonadb/src/schema.rs b/python/sedonadb/src/schema.rs
index d9466ea..d261043 100644
--- a/python/sedonadb/src/schema.rs
+++ b/python/sedonadb/src/schema.rs
@@ -171,6 +171,7 @@ impl PySedonaField {
}
#[pyclass]
+#[derive(Clone, Debug)]
pub struct PySedonaType {
pub inner: SedonaType,
}
diff --git a/python/sedonadb/src/udf.rs b/python/sedonadb/src/udf.rs
new file mode 100644
index 0000000..eeb7cdd
--- /dev/null
+++ b/python/sedonadb/src/udf.rs
@@ -0,0 +1,341 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::{ffi::CString, iter::zip, sync::Arc};
+
+use arrow_array::{
+ ffi::{FFI_ArrowArray, FFI_ArrowSchema},
+ ArrayRef,
+};
+use arrow_schema::Field;
+use datafusion_common::{Result, ScalarValue};
+use datafusion_expr::{ColumnarValue, ScalarUDF, Volatility};
+use datafusion_ffi::udf::FFI_ScalarUDF;
+use pyo3::{
+ pyclass, pyfunction, pymethods,
+ types::{PyAnyMethods, PyCapsule, PyTuple},
+ Bound, PyObject, Python,
+};
+use sedona_expr::scalar_udf::{SedonaScalarKernel, SedonaScalarUDF};
+use sedona_schema::{datatypes::SedonaType, matchers::ArgMatcher};
+
+use crate::{
+ error::PySedonaError,
+ import_from::{check_pycapsule, import_arg_matcher, import_arrow_array,
import_sedona_type},
+ schema::PySedonaType,
+};
+
+#[pyfunction]
+pub fn sedona_scalar_udf<'py>(
+ py: Python<'py>,
+ py_invoke_batch: PyObject,
+ py_return_type: PyObject,
+ py_input_types: Option<Vec<PyObject>>,
+ volatility: &str,
+ name: &str,
+) -> Result<PySedonaScalarUdf, PySedonaError> {
+ let volatility = match volatility {
+ "immutable" => Volatility::Immutable,
+ "stable" => Volatility::Stable,
+ "volatile" => Volatility::Volatile,
+ v => {
+ return Err(PySedonaError::SedonaPython(format!(
+ "Expected one of 'immutable', 'stable', or 'volatile' but got
'{v}'"
+ )));
+ }
+ };
+
+ let scalar_kernel = sedona_scalar_kernel(py, py_input_types,
py_return_type, py_invoke_batch)?;
+ let sedona_scalar_udf =
+ SedonaScalarUDF::new(name, vec![Arc::new(scalar_kernel)], volatility,
None);
+
+ Ok(PySedonaScalarUdf {
+ inner: sedona_scalar_udf,
+ })
+}
+
+fn sedona_scalar_kernel<'py>(
+ py: Python<'py>,
+ input_types: Option<Vec<PyObject>>,
+ py_return_field: PyObject,
+ py_invoke_batch: PyObject,
+) -> Result<PySedonaScalarKernel, PySedonaError> {
+ let matcher = if let Some(input_types) = input_types {
+ let arg_matchers = input_types
+ .iter()
+ .map(|obj| import_arg_matcher(obj.bind(py)))
+ .collect::<Result<Vec<_>, _>>()?;
+ let return_type = import_sedona_type(py_return_field.bind(py))?;
+ Some(ArgMatcher::new(arg_matchers, return_type))
+ } else {
+ None
+ };
+
+ let kernel_impl = PySedonaScalarKernel {
+ matcher,
+ py_return_field,
+ py_invoke_batch,
+ };
+
+ Ok(kernel_impl)
+}
+
+#[derive(Debug)]
+struct PySedonaScalarKernel {
+ matcher: Option<ArgMatcher>,
+ py_return_field: PyObject,
+ py_invoke_batch: PyObject,
+}
+
+impl SedonaScalarKernel for PySedonaScalarKernel {
+ fn return_type(&self, _args: &[SedonaType]) -> Result<Option<SedonaType>> {
+ Err(PySedonaError::SedonaPython("Unexpected call to
return_type()".to_string()).into())
+ }
+
+ fn invoke_batch(
+ &self,
+ _arg_types: &[SedonaType],
+ _args: &[ColumnarValue],
+ ) -> Result<ColumnarValue> {
+ Err(PySedonaError::SedonaPython("Unexpected call to
invoke_batch()".to_string()).into())
+ }
+
+ fn return_type_from_args_and_scalars(
+ &self,
+ args: &[SedonaType],
+ scalar_args: &[Option<&ScalarValue>],
+ ) -> Result<Option<SedonaType>> {
+ if let Some(matcher) = &self.matcher {
+ let return_type = matcher.match_args(args)?;
+ return Ok(return_type);
+ }
+
+ let return_type = Python::with_gil(|py| -> Result<Option<SedonaType>,
PySedonaError> {
+ let py_sedona_types = args
+ .iter()
+ .map(|arg| -> Result<_, PySedonaError> {
Ok(PySedonaType::new(arg.clone())) })
+ .collect::<Result<Vec<_>, _>>()?;
+ let py_scalar_values = zip(&py_sedona_types, scalar_args)
+ .map(|(sedona_type, maybe_arg)| {
+ maybe_arg.map(|arg| PySedonaValue {
+ sedona_type: sedona_type.clone(),
+ value: ColumnarValue::Scalar(arg.clone()),
+ num_rows: 1,
+ })
+ })
+ .collect::<Vec<_>>();
+
+ let py_return_field =
+ self.py_return_field
+ .call(py, (py_sedona_types, py_scalar_values), None)?;
+ if py_return_field.is_none(py) {
+ return Ok(None);
+ }
+
+ let return_type = import_sedona_type(py_return_field.bind(py))?;
+ Ok(Some(return_type))
+ })?;
+
+ Ok(return_type)
+ }
+
+ fn invoke_batch_from_args(
+ &self,
+ arg_types: &[SedonaType],
+ args: &[ColumnarValue],
+ return_type: &SedonaType,
+ num_rows: usize,
+ ) -> Result<ColumnarValue> {
+ let result = Python::with_gil(|py| -> Result<ArrayRef, PySedonaError> {
+ let py_values = zip(arg_types, args)
+ .map(|(sedona_type, arg)| PySedonaValue {
+ sedona_type: PySedonaType::new(sedona_type.clone()),
+ value: arg.clone(),
+ num_rows,
+ })
+ .collect::<Vec<_>>();
+
+ let py_return_type = PySedonaType::new(return_type.clone());
+ let py_args = PyTuple::new(py, py_values)?;
+
+ let result =
+ self.py_invoke_batch
+ .call(py, (py_args, py_return_type, num_rows), None)?;
+ let result_bound = result.bind(py);
+ if !result_bound.hasattr("__arrow_c_array__")? {
+ return Err(
+ PySedonaError::SedonaPython(
+ "Expected result of user-defined function to return an
object implementing __arrow_c_array__()".to_string()
+ )
+ );
+ }
+
+ let (result_field, result_array) =
import_arrow_array(result_bound)?;
+ let result_sedona_type =
SedonaType::from_storage_field(&result_field)?;
+
+ if return_type != &result_sedona_type {
+ let return_type_storage =
SedonaType::Arrow(return_type.storage_type().clone());
+ if return_type_storage != result_sedona_type {
+ return Err(PySedonaError::SedonaPython(format!(
+ "Expected result of user-defined function to return
array of type {return_type} or its storage but got {result_sedona_type}"
+ )));
+ }
+ }
+
+ if result_array.len() != num_rows {
+ return Err(PySedonaError::SedonaPython(format!(
+ "Expected result of user-defined function to return array
of length {num_rows} but got {}",
+ result_array.len()
+ )));
+ }
+
+ Ok(result_array)
+ })?;
+
+ if args.is_empty() {
+ return Ok(ColumnarValue::Array(result));
+ }
+
+ for arg in args {
+ match arg {
+ ColumnarValue::Array(_) => return
Ok(ColumnarValue::Array(result)),
+ ColumnarValue::Scalar(_) => {}
+ }
+ }
+
+ Ok(ColumnarValue::Scalar(ScalarValue::try_from_array(
+ &result, 0,
+ )?))
+ }
+}
+
+#[pyclass]
+#[derive(Clone)]
+pub struct PySedonaScalarUdf {
+ pub inner: SedonaScalarUDF,
+}
+
+#[pymethods]
+impl PySedonaScalarUdf {
+ fn __datafusion_scalar_udf__<'py>(
+ &self,
+ py: Python<'py>,
+ ) -> Result<Bound<'py, PyCapsule>, PySedonaError> {
+ let capsule_name = CString::new("datafusion_scalar_udf").unwrap();
+ let scalar_udf: ScalarUDF = self.inner.clone().into();
+ let ffi_scalar_udf = FFI_ScalarUDF::from(Arc::new(scalar_udf));
+ Ok(PyCapsule::new(py, ffi_scalar_udf, Some(capsule_name))?)
+ }
+}
+
+#[pyclass]
+#[derive(Debug)]
+pub struct PySedonaValue {
+ pub sedona_type: PySedonaType,
+ pub value: ColumnarValue,
+ pub num_rows: usize,
+}
+
+#[pymethods]
+impl PySedonaValue {
+ #[getter]
+ fn r#type(&self) -> Result<PySedonaType, PySedonaError> {
+ Ok(self.sedona_type.clone())
+ }
+
+ fn is_scalar(&self) -> bool {
+ matches!(&self.value, ColumnarValue::Scalar(_))
+ }
+
+ #[getter]
+ fn storage(&self) -> Result<Self, PySedonaError> {
+ Ok(PySedonaValue {
+ sedona_type: PySedonaType {
+ inner:
SedonaType::Arrow(self.sedona_type.inner.storage_type().clone()),
+ },
+ value: self.value.clone(),
+ num_rows: self.num_rows,
+ })
+ }
+
+ fn to_array(&self) -> Result<Self, PySedonaError> {
+ Ok(PySedonaValue {
+ sedona_type: self.sedona_type.clone(),
+ value: ColumnarValue::Array(self.value.to_array(self.num_rows)?),
+ num_rows: self.num_rows,
+ })
+ }
+
+ fn __arrow_c_schema__<'py>(
+ &self,
+ py: Python<'py>,
+ ) -> Result<Bound<'py, PyCapsule>, PySedonaError> {
+ let schema_capsule_name = CString::new("arrow_schema").unwrap();
+ let storage_field = self.sedona_type.inner.to_storage_field("", true)?;
+ let ffi_schema = FFI_ArrowSchema::try_from(storage_field)?;
+ Ok(PyCapsule::new(py, ffi_schema, Some(schema_capsule_name))?)
+ }
+
+ #[pyo3(signature = (requested_schema=None))]
+ fn __arrow_c_array__<'py>(
+ &self,
+ py: Python<'py>,
+ requested_schema: Option<Bound<PyCapsule>>,
+ ) -> Result<(Bound<'py, PyCapsule>, Bound<'py, PyCapsule>), PySedonaError>
{
+ if let Some(requested_schema) = requested_schema {
+ let ffi_requested_schema = unsafe {
+ FFI_ArrowSchema::from_raw(check_pycapsule(&requested_schema,
"arrow_schema")? as _)
+ };
+ let requested_type =
+
SedonaType::from_storage_field(&Field::try_from(&ffi_requested_schema)?)?;
+ if requested_type != self.sedona_type.inner {
+ return Err(PySedonaError::SedonaPython(
+ "requested type is not implemented for
PySedonaValue".to_string(),
+ ));
+ }
+ }
+
+ let schema_capsule_name = CString::new("arrow_schema").unwrap();
+ let field = self.sedona_type.inner.to_storage_field("", true)?;
+ let ffi_schema = FFI_ArrowSchema::try_from(&field)?;
+
+ let array_capsule_name = CString::new("arrow_array").unwrap();
+ let out_size = match &self.value {
+ ColumnarValue::Array(array) => array.len(),
+ ColumnarValue::Scalar(_) => 1,
+ };
+ let array = self.value.to_array(out_size)?;
+ let ffi_array = FFI_ArrowArray::new(&array.to_data());
+
+ Ok((
+ PyCapsule::new(py, ffi_schema, Some(schema_capsule_name))?,
+ PyCapsule::new(py, ffi_array, Some(array_capsule_name))?,
+ ))
+ }
+
+ fn __repr__(&self) -> String {
+ let label = match &self.value {
+ ColumnarValue::Array(_) => "Array",
+ ColumnarValue::Scalar(_) => "Scalar",
+ };
+
+ format!(
+ "PySedonaValue {label} {}[{}]",
+ self.sedona_type.inner, self.num_rows
+ )
+ }
+}
diff --git a/python/sedonadb/tests/test_udf.py
b/python/sedonadb/tests/test_udf.py
new file mode 100644
index 0000000..3dea726
--- /dev/null
+++ b/python/sedonadb/tests/test_udf.py
@@ -0,0 +1,264 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import pandas as pd
+import pyarrow as pa
+import pytest
+from sedonadb import udf
+
+
+def some_udf(arg0, arg1):
+ arg0, arg1 = (
+ pa.array(arg0.to_array()).to_pylist(),
+ pa.array(arg1.to_array()).to_pylist(),
+ )
+ return pa.array(
+ (f"{item0} / {item1}".encode() for item0, item1 in zip(arg0, arg1)),
+ pa.binary(),
+ )
+
+
+def test_udf_matchers(con):
+ udf_impl = udf.arrow_udf(pa.binary(), [udf.STRING, udf.NUMERIC])(some_udf)
+ assert udf_impl._name == "some_udf"
+
+ con.register_udf(udf_impl)
+ pd.testing.assert_frame_equal(
+ con.sql("SELECT some_udf('abcd', 123) as col").to_pandas(),
+ pd.DataFrame({"col": [b"abcd / 123"]}),
+ )
+
+
+def test_udf_types(con):
+ udf_impl = udf.arrow_udf(pa.binary(), [pa.string(), pa.int64()])(some_udf)
+ assert udf_impl._name == "some_udf"
+
+ con.register_udf(udf_impl)
+ pd.testing.assert_frame_equal(
+ con.sql("SELECT some_udf('abcd', 123) as col").to_pandas(),
+ pd.DataFrame({"col": [b"abcd / 123"]}),
+ )
+
+
+def test_udf_any_input(con):
+ udf_impl = udf.arrow_udf(pa.binary())(some_udf)
+ assert udf_impl._name == "some_udf"
+
+ con.register_udf(udf_impl)
+ pd.testing.assert_frame_equal(
+ con.sql("SELECT some_udf('abcd', 123) as col").to_pandas(),
+ pd.DataFrame({"col": [b"abcd / 123"]}),
+ )
+
+
+def test_udf_return_type_fn(con):
+    udf_impl = udf.arrow_udf(lambda arg_types, arg_scalars: arg_types[0])(some_udf)
+ assert udf_impl._name == "some_udf"
+
+ con.register_udf(udf_impl)
+ pd.testing.assert_frame_equal(
+ con.sql("SELECT some_udf('abcd'::BYTEA, 123) as col").to_pandas(),
+ pd.DataFrame({"col": [b"b'abcd' / 123"]}),
+ )
+
+
+def test_udf_array_input(con):
+ udf_impl = udf.arrow_udf(pa.binary(), [udf.STRING, udf.NUMERIC])(some_udf)
+ assert udf_impl._name == "some_udf"
+
+ con.register_udf(udf_impl)
+ pd.testing.assert_frame_equal(
+ con.sql(
+            "SELECT some_udf(x, 123) as col FROM (VALUES ('a'), ('b'), ('c')) as t(x)"
+ ).to_pandas(),
+ pd.DataFrame({"col": [b"a / 123", b"b / 123", b"c / 123"]}),
+ )
+
+
+def test_udf_name():
+ udf_impl = udf.arrow_udf(pa.binary(), name="foofy")(some_udf)
+ assert udf_impl._name == "foofy"
+
+
+def test_shapely_udf(con):
+ import shapely
+ import geoarrow.pyarrow as ga
+ import numpy as np
+
+ @udf.arrow_udf(ga.wkb(), [udf.GEOMETRY, udf.NUMERIC])
+ def shapely_udf(geom, distance):
+ geom_wkb = pa.array(geom.storage.to_array())
+ distance = pa.array(distance.to_array())
+ geom = shapely.from_wkb(geom_wkb)
+ result_shapely = shapely.buffer(geom, distance)
+ return pa.array(shapely.to_wkb(result_shapely))
+
+ con.register_udf(shapely_udf)
+
+ pd.testing.assert_frame_equal(
+        con.sql("SELECT ST_Area(shapely_udf(ST_Point(0, 0), 2.0)) as col").to_pandas(),
+ pd.DataFrame({"col": [12.485780609032208]}),
+ )
+
+ # Ensure we can propagate a crs
+ pd.testing.assert_frame_equal(
+ con.sql(
+            "SELECT ST_SRID(shapely_udf(ST_SetSRID(ST_Point(0, 0), 3857), 2.0)) as col"
+ ).to_pandas(),
+ pd.DataFrame({"col": [3857]}, dtype=np.uint32),
+ )
+
+
+def test_py_sedona_value(con):
+ @udf.arrow_udf(pa.int64())
+ def fn_arg_only(arg):
+ assert repr(arg) == "PySedonaValue Scalar Int64[1]"
+ assert arg.is_scalar() is True
+ assert repr(arg.type) == "SedonaType int64<Int64>"
+
+ return pa.array(range(len(pa.array(arg))))
+
+ con.register_udf(fn_arg_only)
+ con.sql("SELECT fn_arg_only(123)").to_arrow_table()
+
+
+def test_udf_kwargs(con):
+ @udf.arrow_udf(pa.int64())
+ def fn_return_type(arg, *, return_type=None):
+ assert repr(return_type) == "SedonaType int64<Int64>"
+ return pa.array(range(len(pa.array(arg))))
+
+ con.register_udf(fn_return_type)
+ con.sql("SELECT fn_return_type('123')").to_arrow_table()
+
+ @udf.arrow_udf(pa.int64())
+ def fn_num_rows(arg, *, num_rows=None):
+ assert num_rows == 1
+ return pa.array(range(len(pa.array(arg))))
+
+ con.register_udf(fn_num_rows)
+ con.sql("SELECT fn_num_rows('123')").to_arrow_table()
+
+ @udf.arrow_udf(pa.int64())
+ def fn_num_rows_and_return_type(arg, *, num_rows=None, return_type=None):
+ assert repr(return_type) == "SedonaType int64<Int64>"
+ assert num_rows == 1
+ return pa.array(range(len(pa.array(arg))))
+
+ con.register_udf(fn_num_rows_and_return_type)
+ con.sql("SELECT fn_num_rows_and_return_type('123')").to_arrow_table()
+
+
+def test_udf_bad_return_object(con):
+ @udf.arrow_udf(pa.binary())
+ def questionable_udf(arg):
+ return None
+
+ con.register_udf(questionable_udf)
+ with pytest.raises(
+ ValueError,
+        match="Expected result of user-defined function to return an object implementing __arrow_c_array__",
+ ):
+ con.sql("SELECT questionable_udf(123) as col").to_pandas()
+
+
+def test_udf_bad_return_type(con):
+ @udf.arrow_udf(pa.binary())
+ def questionable_udf(arg):
+ return pa.array(["abc"], pa.string())
+
+ con.register_udf(questionable_udf)
+ with pytest.raises(
+ ValueError,
+ match=(
+ "Expected result of user-defined function to "
+ "return array of type Binary or its storage "
+ "but got Utf8"
+ ),
+ ):
+ con.sql("SELECT questionable_udf(123) as col").to_pandas()
+
+
+def test_udf_bad_return_length(con):
+ @udf.arrow_udf(pa.binary())
+ def questionable_udf(arg):
+ return pa.array([b"abc", b"def"], pa.binary())
+
+ con.register_udf(questionable_udf)
+ with pytest.raises(
+ ValueError,
+        match="Expected result of user-defined function to return array of length 1 but got 2",
+ ):
+ con.sql("SELECT questionable_udf(123) as col").to_pandas()
+
+
+def test_udf_datafusion_to_sedonadb(con):
+ udf_impl = udf.arrow_udf(
+ pa.binary(), [udf.STRING, udf.NUMERIC], name="some_external_udf"
+ )(some_udf)
+
+ class UdfWrapper:
+ def __init__(self, obj):
+ self.obj = obj
+
+ def __datafusion_scalar_udf__(self):
+ return self.obj.__datafusion_scalar_udf__()
+
+ con.register_udf(UdfWrapper(udf_impl))
+ pd.testing.assert_frame_equal(
+ con.sql("SELECT some_external_udf('abcd', 123) as col").to_pandas(),
+ pd.DataFrame({"col": [b"abcd / 123"]}),
+ )
+
+
+def test_udf_sedonadb_registry_function_to_datafusion(con):
+ datafusion = pytest.importorskip("datafusion")
+ udf_impl = udf.arrow_udf(pa.binary(), [udf.STRING, udf.NUMERIC])(some_udf)
+
+ # Register with our session
+ con.register_udf(udf_impl)
+
+    # Create a datafusion session, fetch our udf and register with the other session
+ datafusion_ctx = datafusion.SessionContext()
+ datafusion_ctx.register_udf(
+ datafusion.ScalarUDF.from_pycapsule(con._impl.scalar_udf("some_udf"))
+ )
+
+    # Can't quite use to_pandas() because there is a schema/batch nullability mismatch
+    batches = datafusion_ctx.sql("SELECT some_udf('abcd', 123) as col").collect()
+ assert len(batches) == 1
+ pd.testing.assert_frame_equal(
+ batches[0].to_pandas(),
+ pd.DataFrame({"col": [b"abcd / 123"]}),
+ )
+
+
+def test_udf_sedonadb_to_datafusion():
+ datafusion = pytest.importorskip("datafusion")
+ udf_impl = udf.arrow_udf(pa.binary(), [udf.STRING, udf.NUMERIC])(some_udf)
+
+ # Create a datafusion session, register udf_impl directly
+ datafusion_ctx = datafusion.SessionContext()
+ datafusion_ctx.register_udf(datafusion.ScalarUDF.from_pycapsule(udf_impl))
+
+    # Can't quite use to_pandas() because there is a schema/batch nullability mismatch
+    batches = datafusion_ctx.sql("SELECT some_udf('abcd', 123) as col").collect()
+ assert len(batches) == 1
+ pd.testing.assert_frame_equal(
+ batches[0].to_pandas(),
+ pd.DataFrame({"col": [b"abcd / 123"]}),
+ )
diff --git a/rust/sedona-expr/src/scalar_udf.rs b/rust/sedona-expr/src/scalar_udf.rs
index 7f4c187..4fca4f1 100644
--- a/rust/sedona-expr/src/scalar_udf.rs
+++ b/rust/sedona-expr/src/scalar_udf.rs
@@ -83,6 +83,16 @@ pub trait SedonaScalarKernel: Debug {
arg_types: &[SedonaType],
args: &[ColumnarValue],
) -> Result<ColumnarValue>;
+
+ fn invoke_batch_from_args(
+ &self,
+ arg_types: &[SedonaType],
+ args: &[ColumnarValue],
+ _return_type: &SedonaType,
+ _num_rows: usize,
+ ) -> Result<ColumnarValue> {
+ self.invoke_batch(arg_types, args)
+ }
}
/// Type definition for a Scalar kernel implementation function
@@ -259,8 +269,8 @@ impl ScalarUDFImpl for SedonaScalarUDF {
})
.collect::<Vec<_>>();
- let (kernel, _) = self.return_type_impl(&arg_types, &arg_scalars)?;
- kernel.invoke_batch(&arg_types, &args.args)
+        let (kernel, return_type) = self.return_type_impl(&arg_types, &arg_scalars)?;
+        kernel.invoke_batch_from_args(&arg_types, &args.args, &return_type, args.number_rows)
}
fn aliases(&self) -> &[String] {
diff --git a/rust/sedona-schema/src/matchers.rs b/rust/sedona-schema/src/matchers.rs
index 57a74dd..2992b05 100644
--- a/rust/sedona-schema/src/matchers.rs
+++ b/rust/sedona-schema/src/matchers.rs
@@ -150,9 +150,12 @@ impl ArgMatcher {
/// Matches the given Arrow type using PartialEq
     pub fn is_arrow(data_type: DataType) -> Arc<dyn TypeMatcher + Send + Sync> {
- Arc::new(IsExact {
- exact_type: SedonaType::Arrow(data_type),
- })
+ Self::is_exact(SedonaType::Arrow(data_type))
+ }
+
+ /// Matches the given [SedonaType] using PartialEq
+    pub fn is_exact(exact_type: SedonaType) -> Arc<dyn TypeMatcher + Send + Sync> {
+ Arc::new(IsExact { exact_type })
}
/// Matches any geography or geometry argument without considering Crs