This is an automated email from the ASF dual-hosted git repository.
JingsongLi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/paimon-rust.git
The following commit(s) were added to refs/heads/main by this push:
new cd9b90d [feature] support python scalar udf:video_snapshot for video (#336)
cd9b90d is described below
commit cd9b90d2ef540a11820c7d6308e774f8e62b14a5
Author: jerry <[email protected]>
AuthorDate: Thu May 21 13:33:54 2026 +0800
[feature] support python scalar udf:video_snapshot for video (#336)
---
bindings/python/pyproject.toml | 8 +
.../python/python/pypaimon_rust/datafusion.pyi | 44 +-
bindings/python/python/pypaimon_rust/functions.py | 179 ++++++++
.../{datafusion.pyi => functions.pyi} | 22 +-
bindings/python/src/blob.rs | 232 ++++++++++
bindings/python/src/context.rs | 88 +++-
bindings/python/src/lib.rs | 2 +
bindings/python/src/udf.rs | 363 ++++++++++++++++
bindings/python/tests/test_datafusion.py | 475 ++++++++++++++++++++-
crates/integrations/datafusion/src/blob_reader.rs | 101 +++++
crates/integrations/datafusion/src/catalog.rs | 22 +-
crates/integrations/datafusion/src/lib.rs | 2 +
crates/integrations/datafusion/src/sql_context.rs | 19 +-
crates/integrations/datafusion/src/table/mod.rs | 10 +
14 files changed, 1530 insertions(+), 37 deletions(-)
diff --git a/bindings/python/pyproject.toml b/bindings/python/pyproject.toml
index 7ff28d0..cdefdfe 100644
--- a/bindings/python/pyproject.toml
+++ b/bindings/python/pyproject.toml
@@ -39,6 +39,12 @@ classifiers = [
"Programming Language :: Rust",
]
+[project.optional-dependencies]
+video = [
+ "av>=17.0,<18.0",
+ "pillow>=12.0,<13.0",
+]
+
[tool.maturin]
module-name = "pypaimon_rust.pypaimon_rust"
python-source = "python"
@@ -54,4 +60,6 @@ dev = [
"pytest>=8.0",
"pyarrow>=17.0,<24.0",
"datafusion==53.0.0",
+ "av>=17.0,<18.0",
+ "pillow>=12.0,<13.0",
]
diff --git a/bindings/python/python/pypaimon_rust/datafusion.pyi b/bindings/python/python/pypaimon_rust/datafusion.pyi
index 4d0e973..dfa63c7 100644
--- a/bindings/python/python/pypaimon_rust/datafusion.pyi
+++ b/bindings/python/python/pypaimon_rust/datafusion.pyi
@@ -15,14 +15,55 @@
# specific language governing permissions and limitations
# under the License.
-from typing import Any, Dict, List
+from typing import Any, Callable, Dict, List, Optional, Sequence, TypeAlias,
Union
import pyarrow
+ArrowTypeLike: TypeAlias = Union[pyarrow.DataType, pyarrow.Field, str]
+InputFieldsLike: TypeAlias = Union[ArrowTypeLike, Sequence[ArrowTypeLike]]
+VolatilityLike: TypeAlias = Union[str, Any]
+
class PaimonCatalog:
def __init__(self, catalog_options: Dict[str, str]) -> None: ...
def __datafusion_catalog_provider__(self, session: Any) -> object: ...
+class PythonScalarUDF:
+ def __init__(
+ self,
+ name: str,
+ func: Callable[..., pyarrow.Array],
+ input_fields: InputFieldsLike,
+ return_field: ArrowTypeLike,
+ volatility: VolatilityLike,
+ ) -> None: ...
+ @staticmethod
+ def udf(
+ func: Callable[..., pyarrow.Array],
+ input_fields: InputFieldsLike,
+ return_field: ArrowTypeLike,
+ volatility: VolatilityLike,
+ name: Optional[str] = None,
+ ) -> "PythonScalarUDF": ...
+ @property
+ def name(self) -> str: ...
+
+def udf(
+ func: Callable[..., pyarrow.Array],
+ input_fields: InputFieldsLike,
+ return_field: ArrowTypeLike,
+ volatility: VolatilityLike,
+ name: Optional[str] = None,
+) -> PythonScalarUDF:
+ """
+ Create a scalar UDF.
+
+ This mirrors DataFusion Python's function-style API:
+ ``udf(func, input_fields, return_field, volatility, name)``.
+ ``input_fields`` and ``return_field`` accept PyArrow DataType or Field
+ values. String type names remain accepted for compatibility.
+ """
+ ...
+
class SQLContext:
def __init__(self) -> None: ...
def register_catalog(
@@ -31,4 +72,5 @@ class SQLContext:
def set_current_catalog(self, catalog_name: str) -> None: ...
def set_current_database(self, database_name: str) -> None: ...
def register_batch(self, name: str, batch: pyarrow.RecordBatch) -> None:
...
+ def register_udf(self, udf: PythonScalarUDF) -> None: ...
def sql(self, sql: str) -> List[pyarrow.RecordBatch]: ...
diff --git a/bindings/python/python/pypaimon_rust/functions.py b/bindings/python/python/pypaimon_rust/functions.py
new file mode 100644
index 0000000..9a3b4cf
--- /dev/null
+++ b/bindings/python/python/pypaimon_rust/functions.py
@@ -0,0 +1,179 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import io
+import logging
+import struct
+from typing import Any, BinaryIO
+
+logger = logging.getLogger(__name__)
+_STILL_IMAGE_FORMATS = {
+ "apng",
+ "bmp_pipe",
+ "gif",
+ "ico",
+ "image2",
+ "image2pipe",
+ "jpeg_pipe",
+ "png_pipe",
+ "tiff_pipe",
+ "webp_pipe",
+}
+
+
+class _BlobDescriptorProbe:
+ CURRENT_VERSION = 2
+ MAGIC = 0x424C4F4244455343
+
+ @classmethod
+ def is_blob_descriptor(cls, data: Any) -> bool:
+ if not isinstance(data, (bytes, bytearray, memoryview)):
+ return False
+ raw = bytes(data)
+ if len(raw) < 9:
+ return False
+
+ version = raw[0]
+ # Version 1 has no magic header, so it cannot be distinguished safely
+ # from arbitrary inline video bytes in this heuristic.
+ if version == 1 or version > cls.CURRENT_VERSION:
+ return False
+
+ try:
+ return struct.unpack("<Q", raw[1:9])[0] == cls.MAGIC
+ except Exception:
+ return False
+
+
+def is_blob_descriptor(data: Any) -> bool:
+ return _BlobDescriptorProbe.is_blob_descriptor(data)
+
+
+def open_blob_descriptor_stream(
+ raw_value: bytes,
+ blob_reader_registry=None,
+) -> BinaryIO:
+ if blob_reader_registry is not None:
+ stream = blob_reader_registry.open_blob_descriptor_stream(raw_value)
+ if stream is not None:
+ return stream
+
+ if _BlobDescriptorProbe.is_blob_descriptor(raw_value):
+ raise RuntimeError(
+ "BlobDescriptor input requires a registered Paimon table FileIO"
+ )
+ return io.BytesIO(bytes(raw_value))
+
+
+def _decode_video_snapshot(
+ stream: BinaryIO,
+ image_format: str,
+ timestamp_ms: int = 0,
+) -> bytes | None:
+ try:
+ import av
+ except ImportError as e:
+ raise ImportError("PyAV is required to decode video snapshots") from e
+
+ with av.open(stream, mode="r") as container:
+ format_names = set((container.format.name or "").split(","))
+ if format_names & _STILL_IMAGE_FORMATS:
+ logger.debug(
+ "video_snapshot input is a still image format: %s",
+ container.format.name,
+ )
+ return None
+ if not container.streams.video:
+ return None
+
+ target_seconds = timestamp_ms / 1000
+ if timestamp_ms > 0:
+ container.seek(timestamp_ms * 1000, backward=True, any_frame=False)
+
+ candidate = None
+ for frame in container.decode(video=0):
+ if (
+ timestamp_ms > 0
+ and frame.time is not None
+ and frame.time < target_seconds
+ ):
+ candidate = frame
+ continue
+ candidate = frame
+ break
+
+ if candidate is not None:
+ try:
+ image = candidate.to_image()
+ except ImportError as e:
+ raise ImportError(
+ "Pillow is required to encode video_snapshot images"
+ ) from e
+ output = io.BytesIO()
+ image.save(output, format=image_format)
+ return output.getvalue()
+ return None
+
+
+def _make_video_snapshot(image_format: str = "PNG", blob_reader_registry=None):
+ image_format = image_format.upper()
+
+ def video_snapshot(values, timestamps_ms=None):
+ try:
+ import pyarrow as pa
+ except ImportError as e:
+ raise ImportError("pyarrow is required to return video_snapshot
results") from e
+
+ frames = []
+ raw_values = values.to_pylist()
+ if timestamps_ms is None:
+ timestamp_values = [0] * len(raw_values)
+ else:
+ timestamp_values = timestamps_ms.to_pylist()
+ if len(timestamp_values) != len(raw_values):
+ raise ValueError(
+ "video_snapshot timestamp argument must have the same row
count"
+ )
+
+ # v1 intentionally decodes rows serially; callers should filter or
limit
+ # large scans before applying video_snapshot.
+ for raw_value, timestamp_ms in zip(raw_values, timestamp_values):
+ if raw_value is None or timestamp_ms is None:
+ frames.append(None)
+ continue
+
+ try:
+ timestamp_ms = int(timestamp_ms)
+ if timestamp_ms < 0:
+ frames.append(None)
+ continue
+ stream = open_blob_descriptor_stream(raw_value,
blob_reader_registry)
+ try:
+ frames.append(
+ _decode_video_snapshot(stream, image_format,
timestamp_ms)
+ )
+ finally:
+ stream.close()
+ except ImportError:
+ raise
+ except Exception as e:
+ logger.warning("Failed to decode video snapshot: %s", e)
+ frames.append(None)
+
+ return pa.array(frames, type=pa.binary())
+
+ return video_snapshot
diff --git a/bindings/python/python/pypaimon_rust/datafusion.pyi b/bindings/python/python/pypaimon_rust/functions.pyi
similarity index 54%
copy from bindings/python/python/pypaimon_rust/datafusion.pyi
copy to bindings/python/python/pypaimon_rust/functions.pyi
index 4d0e973..aa91ba0 100644
--- a/bindings/python/python/pypaimon_rust/datafusion.pyi
+++ b/bindings/python/python/pypaimon_rust/functions.pyi
@@ -15,20 +15,10 @@
# specific language governing permissions and limitations
# under the License.
-from typing import Any, Dict, List
+from typing import Any, BinaryIO
-import pyarrow
-
-class PaimonCatalog:
- def __init__(self, catalog_options: Dict[str, str]) -> None: ...
- def __datafusion_catalog_provider__(self, session: Any) -> object: ...
-
-class SQLContext:
- def __init__(self) -> None: ...
- def register_catalog(
- self, catalog_name: str, catalog_options: Dict[str, str]
- ) -> None: ...
- def set_current_catalog(self, catalog_name: str) -> None: ...
- def set_current_database(self, database_name: str) -> None: ...
- def register_batch(self, name: str, batch: pyarrow.RecordBatch) -> None:
...
- def sql(self, sql: str) -> List[pyarrow.RecordBatch]: ...
+def is_blob_descriptor(data: Any) -> bool: ...
+def open_blob_descriptor_stream(
+ raw_value: bytes,
+ blob_reader_registry: Any | None = None,
+) -> BinaryIO: ...
diff --git a/bindings/python/src/blob.rs b/bindings/python/src/blob.rs
new file mode 100644
index 0000000..9e4ff2d
--- /dev/null
+++ b/bindings/python/src/blob.rs
@@ -0,0 +1,232 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::future::Future;
+
+use paimon::io::{FileIO, FileRead};
+use paimon::spec::BlobDescriptor;
+use paimon_datafusion::runtime::runtime;
+use paimon_datafusion::BlobReaderRegistry;
+use pyo3::exceptions::PyValueError;
+use pyo3::prelude::*;
+use pyo3::types::PyBytes;
+
+use crate::error::to_py_err;
+
+fn block_on_runtime<F>(future: F, panic_error: &'static str) -> F::Output
+where
+ F: Future + Send + 'static,
+ F::Output: Send + 'static,
+{
+ if tokio::runtime::Handle::try_current().is_ok() {
+ let handle = runtime();
+ std::thread::spawn(move || handle.block_on(future))
+ .join()
+ .expect(panic_error)
+ } else {
+ runtime().block_on(future)
+ }
+}
+
+#[pyclass(name = "BlobReaderRegistry", skip_from_py_object)]
+#[derive(Clone)]
+pub(crate) struct PyBlobReaderRegistry {
+ inner: BlobReaderRegistry,
+}
+
+impl PyBlobReaderRegistry {
+ pub(crate) fn new(inner: BlobReaderRegistry) -> Self {
+ Self { inner }
+ }
+}
+
+#[pymethods]
+impl PyBlobReaderRegistry {
+ fn open_blob_descriptor_stream(&self, raw_value: &[u8]) ->
PyResult<Option<PyBlobInputStream>> {
+ if !BlobDescriptor::is_blob_descriptor(raw_value) {
+ return Ok(None);
+ }
+
+ let descriptor =
BlobDescriptor::deserialize(raw_value).map_err(to_py_err)?;
+ let Some(file_io) = self.inner.resolve(descriptor.uri()) else {
+ return Ok(None);
+ };
+
+ Ok(Some(PyBlobInputStream::new(file_io, descriptor)?))
+ }
+}
+
+#[pyclass(name = "BlobInputStream")]
+struct PyBlobInputStream {
+ file_io: FileIO,
+ uri: String,
+ offset: u64,
+ length: Option<u64>,
+ position: u64,
+ closed: bool,
+}
+
+impl PyBlobInputStream {
+ fn new(file_io: FileIO, descriptor: BlobDescriptor) -> PyResult<Self> {
+ if descriptor.offset() < 0 {
+ return Err(PyValueError::new_err(format!(
+ "BlobDescriptor has negative offset: {}",
+ descriptor.offset()
+ )));
+ }
+ if descriptor.length() < -1 {
+ return Err(PyValueError::new_err(format!(
+ "BlobDescriptor has invalid length: {}",
+ descriptor.length()
+ )));
+ }
+
+ Ok(Self {
+ file_io,
+ uri: descriptor.uri().to_string(),
+ offset: descriptor.offset() as u64,
+ length: (descriptor.length() >= 0).then_some(descriptor.length()
as u64),
+ position: 0,
+ closed: false,
+ })
+ }
+
+ fn ensure_open(&self) -> PyResult<()> {
+ if self.closed {
+ Err(PyValueError::new_err("I/O operation on closed file."))
+ } else {
+ Ok(())
+ }
+ }
+
+ fn stream_length(&self, py: Python<'_>) -> PyResult<u64> {
+ if let Some(length) = self.length {
+ return Ok(length);
+ }
+
+ let file_io = self.file_io.clone();
+ let uri = self.uri.clone();
+ let offset = self.offset;
+ py.detach(|| {
+ block_on_runtime(
+ async move {
+ let input = file_io.new_input(&uri).map_err(to_py_err)?;
+ let metadata = input.metadata().await.map_err(to_py_err)?;
+ Ok(metadata.size.saturating_sub(offset))
+ },
+ "paimon blob metadata read thread panicked",
+ )
+ })
+ }
+
+ fn read_bytes(&mut self, py: Python<'_>, size: isize) -> PyResult<Vec<u8>>
{
+ self.ensure_open()?;
+ let stream_length = self.stream_length(py)?;
+ let remaining = stream_length.saturating_sub(self.position);
+ if remaining == 0 || size == 0 {
+ return Ok(Vec::new());
+ }
+
+ let to_read = if size < 0 {
+ remaining
+ } else {
+ remaining.min(size as u64)
+ };
+ let start = self.offset + self.position;
+ let end = start + to_read;
+ let file_io = self.file_io.clone();
+ let uri = self.uri.clone();
+ let bytes = py.detach(|| {
+ block_on_runtime(
+ async move {
+ let input = file_io.new_input(&uri).map_err(to_py_err)?;
+ let reader = input.reader().await.map_err(to_py_err)?;
+ let bytes =
reader.read(start..end).await.map_err(to_py_err)?;
+ Ok::<_, PyErr>(bytes.to_vec())
+ },
+ "paimon blob range read thread panicked",
+ )
+ })?;
+ self.position += bytes.len() as u64;
+ Ok(bytes)
+ }
+}
+
+#[pymethods]
+impl PyBlobInputStream {
+ fn readable(&self) -> bool {
+ true
+ }
+
+ fn seekable(&self) -> bool {
+ true
+ }
+
+ fn tell(&self) -> u64 {
+ self.position
+ }
+
+ #[getter]
+ fn closed(&self) -> bool {
+ self.closed
+ }
+
+ fn __enter__(slf: PyRefMut<'_, Self>) -> PyRefMut<'_, Self> {
+ slf
+ }
+
+ fn __exit__(
+ &mut self,
+ _exc_type: &Bound<'_, PyAny>,
+ _exc: &Bound<'_, PyAny>,
+ _traceback: &Bound<'_, PyAny>,
+ ) -> bool {
+ self.close();
+ false
+ }
+
+ #[pyo3(signature = (size = -1))]
+ fn read<'py>(&mut self, py: Python<'py>, size: isize) ->
PyResult<Bound<'py, PyBytes>> {
+ let bytes = self.read_bytes(py, size)?;
+ Ok(PyBytes::new(py, &bytes))
+ }
+
+ #[pyo3(signature = (pos, whence = 0))]
+ fn seek(&mut self, py: Python<'_>, pos: i64, whence: i32) -> PyResult<u64>
{
+ self.ensure_open()?;
+ let base = match whence {
+ 0 => 0,
+ 1 => self.position as i64,
+ 2 => self.stream_length(py)? as i64,
+ other => return Err(PyValueError::new_err(format!("Invalid whence:
{other}"))),
+ };
+ let target = base
+ .checked_add(pos)
+ .ok_or_else(|| PyValueError::new_err("Seek position overflow"))?;
+ if target < 0 {
+ return Err(PyValueError::new_err(format!(
+ "Negative seek position: {target}"
+ )));
+ }
+ self.position = target as u64;
+ Ok(self.position)
+ }
+
+ fn close(&mut self) {
+ self.closed = true;
+ }
+}
diff --git a/bindings/python/src/context.rs b/bindings/python/src/context.rs
index e1050d3..f65d6a1 100644
--- a/bindings/python/src/context.rs
+++ b/bindings/python/src/context.rs
@@ -18,16 +18,21 @@
use std::collections::HashMap;
use std::sync::Arc;
+use arrow::datatypes::DataType as ArrowDataType;
use arrow::pyarrow::{FromPyArrow, ToPyArrow};
use datafusion::catalog::CatalogProvider;
+use datafusion::logical_expr::{Signature, TypeSignature, Volatility};
use datafusion_ffi::catalog_provider::FFI_CatalogProvider;
use datafusion_ffi::proto::logical_extension_codec::FFI_LogicalExtensionCodec;
use paimon::{CatalogFactory, Options};
use paimon_datafusion::{PaimonCatalogProvider, SQLContext};
+use pyo3::exceptions::PyRuntimeWarning;
use pyo3::prelude::*;
use pyo3::types::PyCapsule;
+use crate::blob::PyBlobReaderRegistry;
use crate::error::{df_to_py_err, to_py_err};
+use crate::udf::{build_python_scalar_udf, udf, PyPythonScalarUDFObject};
use paimon_datafusion::runtime::runtime;
fn build_paimon_catalog_provider(
@@ -93,28 +98,78 @@ pub struct PySQLContext {
inner: SQLContext,
}
+impl PySQLContext {
+ fn register_video_snapshot_builtin(&self, py: Python<'_>) -> PyResult<()> {
+ let functions = py.import("pypaimon_rust.functions")?;
+ let blob_reader_registry = Py::new(
+ py,
+ PyBlobReaderRegistry::new(self.inner.blob_reader_registry()),
+ )?;
+ let func = functions
+ .getattr("_make_video_snapshot")?
+ .call1(("PNG", blob_reader_registry))?
+ .unbind();
+ let signature = Signature::one_of(
+ vec![
+ TypeSignature::Exact(vec![ArrowDataType::Binary]),
+ TypeSignature::Exact(vec![ArrowDataType::Binary,
ArrowDataType::Int32]),
+ TypeSignature::Exact(vec![ArrowDataType::Binary,
ArrowDataType::Int64]),
+ ],
+ Volatility::Volatile,
+ );
+ let udf = build_python_scalar_udf(
+ "video_snapshot".to_string(),
+ func,
+ ArrowDataType::Binary,
+ signature,
+ );
+ self.inner.ctx().register_udf(udf);
+ Ok(())
+ }
+
+ fn warn_video_snapshot_registration_failure(py: Python<'_>, err: PyErr) {
+ if let Ok(warnings) = py.import("warnings") {
+ let category = py.get_type::<PyRuntimeWarning>();
+ let _ = warnings.call_method1(
+ "warn",
+ (
+ format!("video_snapshot built-in could not be registered:
{err}"),
+ category,
+ ),
+ );
+ }
+ }
+}
+
#[pymethods]
impl PySQLContext {
#[new]
- fn new() -> Self {
- Self {
+ fn new(py: Python<'_>) -> PyResult<Self> {
+ let ctx = Self {
inner: SQLContext::new(),
+ };
+ if let Err(err) = ctx.register_video_snapshot_builtin(py) {
+ Self::warn_video_snapshot_registration_failure(py, err);
}
+ Ok(ctx)
}
fn register_catalog(
&mut self,
+ py: Python<'_>,
catalog_name: String,
catalog_options: HashMap<String, String>,
) -> PyResult<()> {
let rt = runtime();
- rt.block_on(async {
- let options = Options::from_map(catalog_options);
- let catalog =
CatalogFactory::create(options).await.map_err(to_py_err)?;
- self.inner
- .register_catalog(catalog_name, catalog)
- .await
- .map_err(df_to_py_err)
+ py.detach(|| {
+ rt.block_on(async {
+ let options = Options::from_map(catalog_options);
+ let catalog =
CatalogFactory::create(options).await.map_err(to_py_err)?;
+ self.inner
+ .register_catalog(catalog_name, catalog)
+ .await
+ .map_err(df_to_py_err)
+ })
})
}
@@ -148,11 +203,18 @@ impl PySQLContext {
.map_err(df_to_py_err)
}
+ fn register_udf(&self, udf: &PyPythonScalarUDFObject) -> PyResult<()> {
+ self.inner.ctx().register_udf(udf.datafusion_udf());
+ Ok(())
+ }
+
fn sql(&self, py: Python<'_>, sql: String) -> PyResult<Vec<Py<PyAny>>> {
let rt = runtime();
- let batches = rt.block_on(async {
- let df = self.inner.sql(&sql).await.map_err(df_to_py_err)?;
- df.collect().await.map_err(df_to_py_err)
+ let batches = py.detach(|| {
+ rt.block_on(async {
+ let df = self.inner.sql(&sql).await.map_err(df_to_py_err)?;
+ df.collect().await.map_err(df_to_py_err)
+ })
})?;
batches
.iter()
@@ -164,7 +226,9 @@ impl PySQLContext {
pub fn register_module(py: Python<'_>, m: &Bound<'_, PyModule>) ->
PyResult<()> {
let this = PyModule::new(py, "datafusion")?;
this.add_class::<PaimonCatalog>()?;
+ this.add_class::<PyPythonScalarUDFObject>()?;
this.add_class::<PySQLContext>()?;
+ this.add_function(wrap_pyfunction!(udf, &this)?)?;
m.add_submodule(&this)?;
py.import("sys")?
.getattr("modules")?
diff --git a/bindings/python/src/lib.rs b/bindings/python/src/lib.rs
index 326796d..5f8d17a 100644
--- a/bindings/python/src/lib.rs
+++ b/bindings/python/src/lib.rs
@@ -17,8 +17,10 @@
use pyo3::prelude::*;
+mod blob;
mod context;
mod error;
+mod udf;
#[pymodule]
fn pypaimon_rust(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
diff --git a/bindings/python/src/udf.rs b/bindings/python/src/udf.rs
new file mode 100644
index 0000000..3634028
--- /dev/null
+++ b/bindings/python/src/udf.rs
@@ -0,0 +1,363 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::any::Any;
+use std::fmt::{self, Debug};
+use std::hash::{Hash, Hasher};
+use std::sync::Arc;
+
+use arrow::array::{make_array, Array, ArrayData, ArrayRef};
+use arrow::datatypes::{DataType as ArrowDataType, Field as ArrowField};
+use arrow::pyarrow::{FromPyArrow, ToPyArrow};
+use datafusion::common::{DataFusionError, Result as DFResult};
+use datafusion::logical_expr::{
+ ColumnarValue, ScalarFunctionArgs, ScalarUDF as DFScalarUDF,
ScalarUDFImpl, Signature,
+ Volatility,
+};
+use pyo3::exceptions::PyTypeError;
+use pyo3::prelude::*;
+use pyo3::types::{PyList, PyTuple};
+
+fn parse_arrow_type(type_name: &str) -> PyResult<ArrowDataType> {
+ match type_name.to_ascii_lowercase().as_str() {
+ "bool" | "boolean" => Ok(ArrowDataType::Boolean),
+ "int8" => Ok(ArrowDataType::Int8),
+ "int16" => Ok(ArrowDataType::Int16),
+ "int" | "int32" | "integer" => Ok(ArrowDataType::Int32),
+ "bigint" | "int64" | "long" => Ok(ArrowDataType::Int64),
+ "float" | "float32" => Ok(ArrowDataType::Float32),
+ "double" | "float64" => Ok(ArrowDataType::Float64),
+ "string" | "utf8" => Ok(ArrowDataType::Utf8),
+ "large_string" | "large_utf8" => Ok(ArrowDataType::LargeUtf8),
+ "binary" => Ok(ArrowDataType::Binary),
+ "large_binary" => Ok(ArrowDataType::LargeBinary),
+ other => Err(PyTypeError::new_err(format!(
+ "Unsupported Arrow type for Python UDF: {other}"
+ ))),
+ }
+}
+
+fn parse_arrow_type_like(value: &Bound<'_, PyAny>) -> PyResult<ArrowDataType> {
+ if let Ok(field) = ArrowField::from_pyarrow_bound(value) {
+ return Ok(field.data_type().clone());
+ }
+ if let Ok(data_type) = ArrowDataType::from_pyarrow_bound(value) {
+ return Ok(data_type);
+ }
+ if let Ok(type_name) = value.extract::<String>() {
+ return parse_arrow_type(&type_name);
+ }
+
+ Err(PyTypeError::new_err(
+ "Expected a pyarrow.DataType, pyarrow.Field, or supported Arrow type
name",
+ ))
+}
+
+fn parse_input_types(input_fields: &Bound<'_, PyAny>) ->
PyResult<Vec<ArrowDataType>> {
+ if let Ok(fields) = input_fields.cast::<PyList>() {
+ return fields
+ .iter()
+ .map(|field| parse_arrow_type_like(&field))
+ .collect();
+ }
+ if let Ok(fields) = input_fields.cast::<PyTuple>() {
+ return fields
+ .iter()
+ .map(|field| parse_arrow_type_like(&field))
+ .collect();
+ }
+
+ Ok(vec![parse_arrow_type_like(input_fields)?])
+}
+
+fn parse_volatility(volatility: &Bound<'_, PyAny>) -> PyResult<Volatility> {
+ let value = if let Ok(value) = volatility.extract::<String>() {
+ value
+ } else if let Ok(name) = volatility.getattr("name") {
+ name.extract::<String>()?
+ } else {
+ volatility.str()?.to_str()?.to_string()
+ };
+
+ match value.to_ascii_lowercase().as_str() {
+ "immutable" => Ok(Volatility::Immutable),
+ "stable" => Ok(Volatility::Stable),
+ "volatile" => Ok(Volatility::Volatile),
+ other => Err(PyTypeError::new_err(format!(
+ "Unsupported UDF volatility: {other}. Expected immutable, stable,
or volatile"
+ ))),
+ }
+}
+
+fn sanitize_udf_name(name: &str) -> String {
+ let mut sanitized = name
+ .chars()
+ .map(|ch| {
+ if ch.is_ascii_alphanumeric() || ch == '_' {
+ ch.to_ascii_lowercase()
+ } else {
+ '_'
+ }
+ })
+ .collect::<String>()
+ .trim_matches('_')
+ .to_string();
+
+ if sanitized.is_empty() {
+ sanitized.push_str("python_udf");
+ }
+ if sanitized
+ .chars()
+ .next()
+ .is_some_and(|ch| ch.is_ascii_digit())
+ {
+ sanitized.insert(0, '_');
+ }
+ sanitized
+}
+
+fn default_udf_name(py: Python<'_>, func: &Py<PyAny>) -> PyResult<String> {
+ let func = func.bind(py);
+ if let Ok(name) = func.getattr("__name__") {
+ return Ok(sanitize_udf_name(&name.extract::<String>()?));
+ }
+ if let Ok(name) = func.getattr("__qualname__") {
+ return Ok(sanitize_udf_name(&name.extract::<String>()?));
+ }
+ let name = func
+ .getattr("__class__")?
+ .getattr("__name__")?
+ .extract::<String>()?;
+ Ok(sanitize_udf_name(&name))
+}
+
+fn df_execution_error(message: impl Into<String>) -> DataFusionError {
+ DataFusionError::Execution(message.into())
+}
+
+fn columnar_value_to_array(value: &ColumnarValue, num_rows: usize) ->
DFResult<ArrayRef> {
+ match value {
+ ColumnarValue::Array(array) => Ok(Arc::clone(array)),
+ ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(num_rows),
+ }
+}
+
+struct PyScalarUDF {
+ name: String,
+ func: Py<PyAny>,
+ return_type: ArrowDataType,
+ signature: Signature,
+}
+
+impl PyScalarUDF {
+ fn new(
+ name: String,
+ func: Py<PyAny>,
+ return_type: ArrowDataType,
+ signature: Signature,
+ ) -> Self {
+ Self {
+ name,
+ func,
+ return_type,
+ signature,
+ }
+ }
+}
+
+pub(crate) fn build_python_scalar_udf(
+ name: String,
+ func: Py<PyAny>,
+ return_type: ArrowDataType,
+ signature: Signature,
+) -> DFScalarUDF {
+ DFScalarUDF::new_from_impl(PyScalarUDF::new(name, func, return_type,
signature))
+}
+
+#[pyclass(name = "PythonScalarUDF")]
+pub struct PyPythonScalarUDFObject {
+ name: String,
+ udf: DFScalarUDF,
+}
+
+impl PyPythonScalarUDFObject {
+ fn create(
+ py: Python<'_>,
+ name: String,
+ func: Py<PyAny>,
+ input_fields: &Bound<'_, PyAny>,
+ return_field: &Bound<'_, PyAny>,
+ volatility: &Bound<'_, PyAny>,
+ ) -> PyResult<Self> {
+ if !func.bind(py).is_callable() {
+ return Err(PyTypeError::new_err("`func` argument must be
callable"));
+ }
+
+ let input_types = parse_input_types(input_fields)?;
+ let return_type = parse_arrow_type_like(return_field)?;
+ let volatility = parse_volatility(volatility)?;
+ let signature = Signature::exact(input_types, volatility);
+ let udf = PyScalarUDF::new(name.clone(), func, return_type, signature);
+ Ok(Self {
+ name,
+ udf: DFScalarUDF::new_from_impl(udf),
+ })
+ }
+
+ pub(crate) fn datafusion_udf(&self) -> DFScalarUDF {
+ self.udf.clone()
+ }
+}
+
+#[pymethods]
+impl PyPythonScalarUDFObject {
+ #[new]
+ fn new(
+ py: Python<'_>,
+ name: String,
+ func: Py<PyAny>,
+ input_fields: Bound<'_, PyAny>,
+ return_field: Bound<'_, PyAny>,
+ volatility: Bound<'_, PyAny>,
+ ) -> PyResult<Self> {
+ Self::create(py, name, func, &input_fields, &return_field, &volatility)
+ }
+
+ #[staticmethod]
+ #[pyo3(signature = (func, input_fields, return_field, volatility, name =
None))]
+ fn udf(
+ py: Python<'_>,
+ func: Py<PyAny>,
+ input_fields: Bound<'_, PyAny>,
+ return_field: Bound<'_, PyAny>,
+ volatility: Bound<'_, PyAny>,
+ name: Option<String>,
+ ) -> PyResult<Self> {
+ let name = match name {
+ Some(name) => name,
+ None => default_udf_name(py, &func)?,
+ };
+ Self::create(py, name, func, &input_fields, &return_field, &volatility)
+ }
+
+ #[getter]
+ fn name(&self) -> &str {
+ &self.name
+ }
+
+ fn __repr__(&self) -> String {
+ format!("PythonScalarUDF({})", self.name)
+ }
+}
+
+#[pyfunction]
+#[pyo3(signature = (func, input_fields, return_field, volatility, name =
None))]
+pub(crate) fn udf(
+ py: Python<'_>,
+ func: Py<PyAny>,
+ input_fields: Bound<'_, PyAny>,
+ return_field: Bound<'_, PyAny>,
+ volatility: Bound<'_, PyAny>,
+ name: Option<String>,
+) -> PyResult<PyPythonScalarUDFObject> {
+ PyPythonScalarUDFObject::udf(py, func, input_fields, return_field,
volatility, name)
+}
+
+impl Debug for PyScalarUDF {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ f.debug_struct("PyScalarUDF")
+ .field("name", &self.name)
+ .field("signature", &self.signature)
+ .field("return_type", &self.return_type)
+ .finish_non_exhaustive()
+ }
+}
+
+impl PartialEq for PyScalarUDF {
+ fn eq(&self, other: &Self) -> bool {
+ self.name == other.name
+ && self.return_type == other.return_type
+ && self.signature == other.signature
+ }
+}
+
+impl Eq for PyScalarUDF {}
+
+impl Hash for PyScalarUDF {
+ fn hash<H: Hasher>(&self, state: &mut H) {
+ self.name.hash(state);
+ self.return_type.hash(state);
+ self.signature.hash(state);
+ }
+}
+
+impl ScalarUDFImpl for PyScalarUDF {
+ fn as_any(&self) -> &dyn Any {
+ self
+ }
+
+ fn name(&self) -> &str {
+ &self.name
+ }
+
+ fn signature(&self) -> &Signature {
+ &self.signature
+ }
+
+ fn return_type(&self, _arg_types: &[ArrowDataType]) ->
DFResult<ArrowDataType> {
+ Ok(self.return_type.clone())
+ }
+
+ fn invoke_with_args(&self, args: ScalarFunctionArgs) ->
DFResult<ColumnarValue> {
+ let arrays = args
+ .args
+ .iter()
+ .map(|value| columnar_value_to_array(value, args.number_rows))
+ .collect::<DFResult<Vec<_>>>()?;
+
+ let output = Python::try_attach(|py| -> PyResult<ArrayRef> {
+ let py_args = arrays
+ .iter()
+ .map(|array| array.to_data().to_pyarrow(py))
+ .collect::<PyResult<Vec<_>>>()?;
+ let py_args = PyTuple::new(py, py_args)?;
+ let output = self.func.bind(py).call1(py_args)?;
+ Ok(make_array(ArrayData::from_pyarrow_bound(&output)?))
+ })
+ .ok_or_else(|| df_execution_error("Python interpreter is not
available"))?
+ .map_err(|err| df_execution_error(format!("Python UDF '{}' failed:
{err}", self.name)))?;
+
+ if output.len() != args.number_rows {
+ return Err(df_execution_error(format!(
+ "Python UDF '{}' returned {} rows, expected {}",
+ self.name,
+ output.len(),
+ args.number_rows
+ )));
+ }
+ if output.data_type() != &self.return_type {
+ return Err(df_execution_error(format!(
+ "Python UDF '{}' returned {:?}, expected {:?}",
+ self.name,
+ output.data_type(),
+ self.return_type
+ )));
+ }
+
+ Ok(ColumnarValue::Array(output))
+ }
+}
diff --git a/bindings/python/tests/test_datafusion.py b/bindings/python/tests/test_datafusion.py
index 2576b7c..d120a94 100644
--- a/bindings/python/tests/test_datafusion.py
+++ b/bindings/python/tests/test_datafusion.py
@@ -15,15 +15,63 @@
# specific language governing permissions and limitations
# under the License.
+import io
import os
+import struct
+import sys
import tempfile
+import types
+from pathlib import Path
import pyarrow as pa
+import pytest
from datafusion import SessionContext
-from pypaimon_rust.datafusion import PaimonCatalog, SQLContext
+from pypaimon_rust.datafusion import PaimonCatalog, PythonScalarUDF,
SQLContext, udf
WAREHOUSE = os.environ.get("PAIMON_TEST_WAREHOUSE", "/tmp/paimon-warehouse")
# Eight-byte signature at the start of every PNG file.
PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"
# Magic number written into serialized blob descriptors (ASCII "BLOBDESC").
BLOB_DESCRIPTOR_MAGIC = 0x424C4F4244455343


def serialize_blob_descriptor(uri: str, offset: int, length: int) -> bytes:
    """Build serialized blob-descriptor bytes for tests.

    Layout (little-endian, no padding): version byte ``2``, the 8-byte magic,
    a 4-byte unsigned URI length, the UTF-8 URI bytes, then the signed 8-byte
    ``offset`` and ``length``.
    """
    encoded_uri = uri.encode("utf-8")
    layout = f"<BQI{len(encoded_uri)}sqq"
    return struct.pack(
        layout,
        2,
        BLOB_DESCRIPTOR_MAGIC,
        len(encoded_uri),
        encoded_uri,
        offset,
        length,
    )
+
+
def write_sample_video(
    path: Path,
    colors: tuple[tuple[int, int, int], ...] = ((240, 40, 80),),
) -> None:
    """Encode a 32x32 MPEG-4 clip at 1 fps, one solid-colour frame per entry.

    Skips the calling test when PyAV or Pillow is not installed.
    """
    av = pytest.importorskip("av")
    pil_image = pytest.importorskip("PIL.Image")

    side = 32
    with av.open(str(path), mode="w") as output:
        video = output.add_stream("mpeg4", rate=1)
        video.width = side
        video.height = side
        video.pix_fmt = "yuv420p"

        frames = (
            av.VideoFrame.from_image(pil_image.new("RGB", (side, side), color=rgb))
            for rgb in colors
        )
        for frame in frames:
            output.mux(video.encode(frame))
        # Encoding with no frame flushes the packets still buffered in the encoder.
        output.mux(video.encode())
+
+
def sample_image_bytes() -> bytes:
    """Return the PNG encoding of a solid 32x32 RGB image (skips if Pillow is missing)."""
    pil_image = pytest.importorskip("PIL.Image")

    buffer = io.BytesIO()
    pil_image.new("RGB", (32, 32), color=(40, 120, 220)).save(buffer, format="PNG")
    return buffer.getvalue()
def extract_rows(batches):
@@ -31,6 +79,207 @@ def extract_rows(batches):
return sorted(zip(table["id"].to_pylist(), table["name"].to_pylist()))
def test_video_snapshot_builtin_registered_on_context_init():
    """video_snapshot exists on a fresh SQLContext and maps a NULL input to NULL."""
    query = "SELECT video_snapshot(CAST(NULL AS BYTEA)) AS cover_png"
    result = pa.Table.from_batches(SQLContext().sql(query))

    assert result["cover_png"].to_pylist() == [None]
+
+
def test_sql_context_survives_video_snapshot_registration_failure(monkeypatch):
    """A broken ``pypaimon_rust.functions`` module only produces a RuntimeWarning."""
    # An empty namespace stands in for the module, so the built-in cannot be found.
    monkeypatch.setitem(sys.modules, "pypaimon_rust.functions", types.SimpleNamespace())

    expected_warning = "video_snapshot built-in could not be registered"
    with pytest.warns(RuntimeWarning, match=expected_warning):
        ctx = SQLContext()

    values = pa.Table.from_batches(ctx.sql("SELECT 1 AS value"))["value"].to_pylist()
    assert values == [1]
+
+
def test_video_snapshot_builtin_auto_registered_for_sql():
    """video_snapshot decodes raw video bytes stored inline in a column into a PNG."""
    with tempfile.TemporaryDirectory() as warehouse:
        sample = Path(warehouse) / "sample.mp4"
        write_sample_video(sample)

        ctx = SQLContext()
        ctx.register_catalog("paimon", {"warehouse": warehouse})
        videos = pa.record_batch(
            [[1], pa.array([sample.read_bytes()], type=pa.binary())],
            names=["id", "video"],
        )
        ctx.register_batch("paimon.default.videos", videos)

        result = pa.Table.from_batches(
            ctx.sql(
                """
                SELECT id, video_snapshot(video) AS cover_png
                FROM paimon.default.videos
                """
            )
        )

        assert result["id"].to_pylist() == [1]
        assert result["cover_png"].to_pylist()[0].startswith(PNG_SIGNATURE)

        ctx.sql("DROP TEMPORARY TABLE paimon.default.videos")
+
+
def test_video_snapshot_descriptor_without_table_file_io_returns_null():
    """A blob descriptor held in a temporary (non-Paimon) table yields NULL.

    The video file lives directly under the warehouse, not under any Paimon
    table location, so no table FileIO is available to read it.
    """
    with tempfile.TemporaryDirectory() as warehouse:
        sample = Path(warehouse) / "sample.mp4"
        write_sample_video(sample)
        descriptor = serialize_blob_descriptor(str(sample), 0, sample.stat().st_size)

        ctx = SQLContext()
        ctx.register_catalog("paimon", {"warehouse": warehouse})
        videos = pa.record_batch(
            [[1], pa.array([descriptor], type=pa.binary())],
            names=["id", "video"],
        )
        ctx.register_batch("paimon.default.videos", videos)

        result = pa.Table.from_batches(
            ctx.sql(
                """
                SELECT id, video_snapshot(video) AS cover_png
                FROM paimon.default.videos
                """
            )
        )

        assert result["id"].to_pylist() == [1]
        assert result["cover_png"].to_pylist() == [None]

        ctx.sql("DROP TEMPORARY TABLE paimon.default.videos")
+
+
def test_video_snapshot_returns_null_for_image_bytes():
    """Non-video content (a PNG image) produces NULL instead of an error."""
    with tempfile.TemporaryDirectory() as warehouse:
        ctx = SQLContext()
        ctx.register_catalog("paimon", {"warehouse": warehouse})
        media = pa.record_batch(
            [[1], pa.array([sample_image_bytes()], type=pa.binary())],
            names=["id", "content"],
        )
        ctx.register_batch("paimon.default.media", media)

        result = pa.Table.from_batches(
            ctx.sql(
                """
                SELECT id, video_snapshot(content) AS cover_png
                FROM paimon.default.media
                """
            )
        )

        assert result["id"].to_pylist() == [1]
        assert result["cover_png"].to_pylist() == [None]

        ctx.sql("DROP TEMPORARY TABLE paimon.default.media")
+
+
def test_video_snapshot_reads_descriptor_with_table_file_io():
    """Descriptors stored in a Paimon table are dereferenced through that table's FileIO."""
    with tempfile.TemporaryDirectory() as warehouse:
        ctx = SQLContext()
        ctx.register_catalog("paimon", {"warehouse": warehouse})
        ctx.sql("CREATE TABLE paimon.default.videos (id INT, video BINARY)")

        # Put the video inside the table directory so the table's location
        # prefix covers the descriptor URI.
        table_dir = Path(warehouse) / "default.db" / "videos"
        table_dir.mkdir(parents=True, exist_ok=True)
        sample = table_dir / "sample.mp4"
        write_sample_video(sample)
        descriptor = serialize_blob_descriptor(str(sample), 0, sample.stat().st_size)

        source = pa.record_batch(
            [[1], pa.array([descriptor], type=pa.binary())],
            names=["id", "video"],
        )
        ctx.register_batch("source_videos", source)
        ctx.sql(
            """
            INSERT INTO paimon.default.videos
            SELECT id, video FROM paimon.default.source_videos
            """
        )

        result = pa.Table.from_batches(
            ctx.sql(
                """
                SELECT id, video_snapshot(video) AS cover_png
                FROM paimon.default.videos
                """
            )
        )

        assert result["id"].to_pylist() == [1]
        assert result["cover_png"].to_pylist()[0].startswith(PNG_SIGNATURE)

        for statement in (
            "DROP TEMPORARY TABLE paimon.default.source_videos",
            "DROP TABLE paimon.default.videos",
        ):
            ctx.sql(statement)
+
+
def test_video_snapshot_accepts_timestamp_ms():
    """An optional millisecond timestamp selects which frame is captured.

    The clip has two 1-second frames of different colours. A timestamp past
    the end of the clip is expected to give the same image as the final frame.
    """
    pil_image = pytest.importorskip("PIL.Image")

    def center_pixel(png):
        return pil_image.open(io.BytesIO(png)).convert("RGB").getpixel((16, 16))

    with tempfile.TemporaryDirectory() as warehouse:
        sample = Path(warehouse) / "sample.mp4"
        write_sample_video(sample, colors=((240, 40, 80), (40, 220, 80)))

        ctx = SQLContext()
        ctx.register_catalog("paimon", {"warehouse": warehouse})
        videos = pa.record_batch(
            [[1], pa.array([sample.read_bytes()], type=pa.binary())],
            names=["id", "video"],
        )
        ctx.register_batch("paimon.default.videos", videos)

        batches = ctx.sql(
            """
            SELECT
                video_snapshot(video) AS first_png,
                video_snapshot(video, CAST(1000 AS INT)) AS second_png,
                video_snapshot(video, 5000) AS beyond_duration_png
            FROM paimon.default.videos
            """
        )
        (row,) = pa.Table.from_batches(batches).to_pylist()[:1]

        for column in ("first_png", "second_png"):
            assert row[column].startswith(PNG_SIGNATURE)
        assert center_pixel(row["first_png"]) != center_pixel(row["second_png"])
        assert row["beyond_duration_png"].startswith(PNG_SIGNATURE)
        assert center_pixel(row["beyond_duration_png"]) == center_pixel(row["second_png"])

        ctx.sql("DROP TEMPORARY TABLE paimon.default.videos")
+
+
def test_query_simple_table_via_catalog_provider():
catalog = PaimonCatalog({"warehouse": WAREHOUSE})
ctx = SessionContext()
@@ -100,6 +349,226 @@ def test_register_batch_bare_name():
ctx.sql("DROP TEMPORARY TABLE paimon.default.my_temp")
def test_register_udf_from_python():
    """A Python callable wrapped with ``udf`` is callable from SQL and receives NULLs."""
    with tempfile.TemporaryDirectory() as warehouse:
        ctx = SQLContext()
        ctx.register_catalog("paimon", {"warehouse": warehouse})
        ctx.register_batch("my_temp", pa.record_batch([[1, None, 3]], names=["id"]))

        def plus_ten(values):
            shifted = [
                value + 10 if value is not None else None for value in values.to_pylist()
            ]
            return pa.array(shifted, type=pa.int64())

        ctx.register_udf(udf(plus_ten, [pa.int64()], pa.int64(), "volatile", "plus_ten"))

        query = "SELECT plus_ten(id) AS id FROM paimon.default.my_temp ORDER BY id"
        ids = pa.Table.from_batches(ctx.sql(query))["id"].to_pylist()
        assert ids == [11, 13, None]

        ctx.sql("DROP TEMPORARY TABLE paimon.default.my_temp")
+
+
def test_register_udf_default_name_is_sql_identifier_for_closure():
    """Without an explicit name, a nested function is registered under its plain name."""
    with tempfile.TemporaryDirectory() as warehouse:
        ctx = SQLContext()
        ctx.register_catalog("paimon", {"warehouse": warehouse})
        ctx.register_batch("my_temp", pa.record_batch([[1, 2]], names=["id"]))

        def build_udf():
            def plus_one(values):
                return pa.array([v + 1 for v in values.to_pylist()], type=pa.int64())

            return plus_one

        scalar_udf = udf(build_udf(), [pa.int64()], pa.int64(), "volatile")
        assert scalar_udf.name == "plus_one"
        ctx.register_udf(scalar_udf)

        query = "SELECT plus_one(id) AS id FROM paimon.default.my_temp ORDER BY id"
        ids = pa.Table.from_batches(ctx.sql(query))["id"].to_pylist()
        assert ids == [2, 3]

        ctx.sql("DROP TEMPORARY TABLE paimon.default.my_temp")
+
+
def test_register_udf_multiple_arguments():
    """A two-argument Python UDF receives one array per SQL argument."""
    with tempfile.TemporaryDirectory() as warehouse:
        ctx = SQLContext()
        ctx.register_catalog("paimon", {"warehouse": warehouse})

        ids = pa.array([1, 2, None], type=pa.int64())
        deltas = pa.array([10, 20, 30], type=pa.int64())
        ctx.register_batch("my_temp", pa.record_batch([ids, deltas], names=["id", "delta"]))

        def add_values(left, right):
            sums = [
                None if a is None or b is None else a + b
                for a, b in zip(left.to_pylist(), right.to_pylist())
            ]
            return pa.array(sums, type=pa.int64())

        scalar_udf = udf(
            add_values,
            [pa.int64(), pa.int64()],
            pa.int64(),
            "volatile",
            "add_values",
        )
        ctx.register_udf(scalar_udf)

        result = pa.Table.from_batches(
            ctx.sql(
                """
                SELECT add_values(id, delta) AS value
                FROM paimon.default.my_temp
                ORDER BY id
                """
            )
        )
        assert result["value"].to_pylist() == [11, 22, None]

        ctx.sql("DROP TEMPORARY TABLE paimon.default.my_temp")
+
+
def test_register_udf_multi_partition_union_plan():
    """The UDF is evaluated correctly in both branches of a UNION ALL plan."""
    with tempfile.TemporaryDirectory() as warehouse:
        ctx = SQLContext()
        ctx.register_catalog("paimon", {"warehouse": warehouse})
        ctx.register_batch("my_temp", pa.record_batch([[1, 2, 3]], names=["id"]))

        def plus_ten(values):
            return pa.array([v + 10 for v in values.to_pylist()], type=pa.int64())

        ctx.register_udf(udf(plus_ten, [pa.int64()], pa.int64(), "volatile", "plus_ten"))

        result = pa.Table.from_batches(
            ctx.sql(
                """
                SELECT plus_ten(id) AS id FROM paimon.default.my_temp
                UNION ALL
                SELECT plus_ten(id) AS id FROM paimon.default.my_temp
                ORDER BY id
                """
            )
        )
        assert result["id"].to_pylist() == [11, 11, 12, 12, 13, 13]

        ctx.sql("DROP TEMPORARY TABLE paimon.default.my_temp")
+
+
def test_udf_rejects_non_callable():
    """``udf`` raises TypeError when ``func`` is not callable."""
    with pytest.raises(TypeError, match="`func` argument must be callable"):
        udf(1, [pa.int64()], pa.int64(), "volatile")
+
+
def test_udf_rejects_unsupported_type():
    """``udf`` raises TypeError when an input type is not a ``pyarrow.DataType``."""

    def identity(values):
        return values

    # `match` is a regular expression, so the dot is escaped to match literally.
    with pytest.raises(TypeError, match=r"Expected a pyarrow\.DataType"):
        udf(identity, [object()], pa.int64(), "volatile", "identity")
+
+
def test_python_scalar_udf_constructor_matches_datafusion_shape():
    """PythonScalarUDF can be built positionally from a name, a callable, input
    fields, a return type and a volatility string, in that order."""

    def identity(values):
        return values

    scalar_udf = PythonScalarUDF(
        "identity",
        identity,
        [pa.field("value", pa.int64())],
        pa.int64(),
        "stable",
    )

    assert scalar_udf.name == "identity"
    assert repr(scalar_udf) == "PythonScalarUDF(identity)"
+
+
def test_python_udf_exception_surfaces():
    """An exception raised inside a Python UDF fails the query and its message surfaces."""
    with tempfile.TemporaryDirectory() as warehouse:
        ctx = SQLContext()
        ctx.register_catalog("paimon", {"warehouse": warehouse})
        ctx.register_batch("my_temp", pa.record_batch([[1]], names=["id"]))

        # A message distinct from the UDF name, so the assertion below really
        # proves the Python exception text propagates (the name "boom" would
        # match on the UDF name alone).
        python_message = "python side exploded"

        def boom(values):
            raise RuntimeError(python_message)

        ctx.register_udf(udf(boom, [pa.int64()], pa.int64(), "volatile", "boom"))

        with pytest.raises(Exception) as excinfo:
            ctx.sql("SELECT boom(id) AS id FROM paimon.default.my_temp")

        message = str(excinfo.value)
        assert "Python UDF 'boom' failed" in message
        assert python_message in message
+
+
def test_python_udf_rejects_wrong_length():
    """A UDF result whose length differs from the input batch fails the query."""
    with tempfile.TemporaryDirectory() as warehouse:
        ctx = SQLContext()
        ctx.register_catalog("paimon", {"warehouse": warehouse})
        ctx.register_batch("my_temp", pa.record_batch([[1, 2]], names=["id"]))

        def wrong_length(values):
            return pa.array([1], type=pa.int64())

        ctx.register_udf(
            udf(wrong_length, [pa.int64()], pa.int64(), "volatile", "wrong_length")
        )

        with pytest.raises(
            Exception, match="Python UDF 'wrong_length' returned 1 rows, expected 2"
        ):
            ctx.sql("SELECT wrong_length(id) AS id FROM paimon.default.my_temp")
+
+
def test_python_udf_rejects_wrong_type():
    """A UDF result whose Arrow type differs from the declared return type fails the query."""
    with tempfile.TemporaryDirectory() as warehouse:
        ctx = SQLContext()
        ctx.register_catalog("paimon", {"warehouse": warehouse})
        ctx.register_batch("my_temp", pa.record_batch([[1]], names=["id"]))

        def wrong_type(values):
            return pa.array(["not an int"], type=pa.string())

        ctx.register_udf(
            udf(wrong_type, [pa.int64()], pa.int64(), "volatile", "wrong_type")
        )

        with pytest.raises(
            Exception, match="Python UDF 'wrong_type' returned Utf8, expected Int64"
        ):
            ctx.sql("SELECT wrong_type(id) AS id FROM paimon.default.my_temp")
+
+
def test_temp_table_shadows_paimon_table():
with tempfile.TemporaryDirectory() as warehouse:
ctx = SQLContext()
@@ -174,7 +643,7 @@ def test_register_batch_invalid_catalog():
batch = pa.record_batch([[1]], names=["id"])
try:
ctx.register_batch("unknown_catalog.default.my_temp", batch)
- assert False, "Expected an error for unknown catalog"
+ pytest.fail("Expected an error for unknown catalog")
except Exception as e:
assert "unknown_catalog" in str(e).lower() or "not a paimon" in
str(e).lower() or "unknown" in str(e).lower()
@@ -191,6 +660,6 @@ def test_table_functions_registered_with_catalog():
for fn in ("vector_search", "full_text_search"):
try:
ctx.sql(f"SELECT * FROM {fn}('only_one_arg')")
- assert False, f"expected {fn} to reject a single argument"
+ pytest.fail(f"expected {fn} to reject a single argument")
except Exception as e:
assert "requires 4 arguments" in str(e), str(e)
diff --git a/crates/integrations/datafusion/src/blob_reader.rs
b/crates/integrations/datafusion/src/blob_reader.rs
new file mode 100644
index 0000000..fe7bd88
--- /dev/null
+++ b/crates/integrations/datafusion/src/blob_reader.rs
@@ -0,0 +1,101 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::sync::{Arc, RwLock};
+
+use paimon::io::FileIO;
+
+#[derive(Clone, Debug)]
+struct BlobFileIO {
+ prefix: String,
+ file_io: FileIO,
+}
+
+/// Session-scoped registry of Paimon [`FileIO`] instances for BlobDescriptor
reads.
+#[derive(Clone, Debug, Default)]
+pub struct BlobReaderRegistry {
+ readers: Arc<RwLock<Vec<BlobFileIO>>>,
+}
+
+impl BlobReaderRegistry {
+ pub fn register(&self, prefix: impl Into<String>, file_io: FileIO) {
+ let prefix = prefix.into();
+ let mut readers = self.readers.write().unwrap_or_else(|e|
e.into_inner());
+ if let Some(existing) = readers.iter_mut().find(|reader| reader.prefix
== prefix) {
+ existing.file_io = file_io;
+ return;
+ }
+ readers.push(BlobFileIO { prefix, file_io });
+ }
+
+ pub fn register_if_absent(&self, prefix: impl Into<String>, file_io:
FileIO) {
+ let prefix = prefix.into();
+ let mut readers = self.readers.write().unwrap_or_else(|e|
e.into_inner());
+ if readers.iter().any(|reader| reader.prefix == prefix) {
+ return;
+ }
+ readers.push(BlobFileIO { prefix, file_io });
+ }
+
+ pub fn resolve(&self, uri: &str) -> Option<FileIO> {
+ let readers = self.readers.read().unwrap_or_else(|e| e.into_inner());
+ readers
+ .iter()
+ .filter(|reader| uri.starts_with(&reader.prefix))
+ .max_by_key(|reader| reader.prefix.len())
+ .map(|reader| reader.file_io.clone())
+ }
+}
+
#[cfg(test)]
mod tests {
    use std::fs;

    use paimon::io::{FileIOBuilder, FileRead};
    use paimon::spec::BlobDescriptor;

    use super::*;

    /// A descriptor pointing into a local file resolves to the `FileIO`
    /// registered for the file's directory, and the described byte range can
    /// be read back through it.
    #[tokio::test]
    async fn resolves_file_blob_descriptor_with_file_io() {
        const CONTENTS: &[u8] = b"prefixpayloadsuffix";
        const PAYLOAD: &[u8] = b"payload";

        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("blob.bin");
        fs::write(&path, CONTENTS).unwrap();

        // Round-trip through the serialized form so deserialization is exercised too.
        let encoded =
            BlobDescriptor::new(path.to_string_lossy().to_string(), 6, PAYLOAD.len() as i64)
                .serialize();
        let descriptor = BlobDescriptor::deserialize(&encoded).unwrap();

        let registry = BlobReaderRegistry::default();
        registry.register(
            dir.path().to_string_lossy().to_string(),
            FileIOBuilder::new("file").build().unwrap(),
        );

        let file_io = registry
            .resolve(descriptor.uri())
            .expect("file blob descriptor should resolve to registered FileIO");
        let reader = file_io
            .new_input(descriptor.uri())
            .unwrap()
            .reader()
            .await
            .unwrap();

        let first = descriptor.offset() as u64;
        let last = first + descriptor.length() as u64;
        let bytes = reader.read(first..last).await.unwrap();

        assert_eq!(&bytes[..], PAYLOAD);
    }
}
diff --git a/crates/integrations/datafusion/src/catalog.rs
b/crates/integrations/datafusion/src/catalog.rs
index a93201a..018ef6b 100644
--- a/crates/integrations/datafusion/src/catalog.rs
+++ b/crates/integrations/datafusion/src/catalog.rs
@@ -34,7 +34,7 @@ use crate::error::to_datafusion_error;
use crate::runtime::{await_with_runtime, block_on_with_runtime};
use crate::system_tables;
use crate::table::PaimonTableProvider;
-use crate::DynamicOptions;
+use crate::{BlobReaderRegistry, DynamicOptions};
/// Provides an interface to manage and access multiple schemas (databases)
/// within a Paimon [`Catalog`].
@@ -54,6 +54,7 @@ pub struct PaimonCatalogProvider {
/// propagate the panic to all subsequent operations. The worst case is a
temp table
/// becoming invisible or stale, which is recoverable by re-registering it.
temp_tables: Arc<RwLock<HashMap<String, Arc<MemorySchemaProvider>>>>,
+ blob_reader_registry: BlobReaderRegistry,
}
impl Debug for PaimonCatalogProvider {
@@ -73,17 +74,20 @@ impl PaimonCatalogProvider {
catalog,
dynamic_options: Default::default(),
temp_tables: Arc::new(RwLock::new(HashMap::new())),
+ blob_reader_registry: BlobReaderRegistry::default(),
}
}
pub(crate) fn with_dynamic_options(
catalog: Arc<dyn Catalog>,
dynamic_options: DynamicOptions,
+ blob_reader_registry: BlobReaderRegistry,
) -> Self {
PaimonCatalogProvider {
catalog,
dynamic_options,
temp_tables: Arc::new(RwLock::new(HashMap::new())),
+ blob_reader_registry,
}
}
}
@@ -109,6 +113,7 @@ impl CatalogProvider for PaimonCatalogProvider {
fn schema(&self, name: &str) -> Option<Arc<dyn SchemaProvider>> {
let catalog = Arc::clone(&self.catalog);
let dynamic_options = Arc::clone(&self.dynamic_options);
+ let blob_reader_registry = self.blob_reader_registry.clone();
let name = name.to_string();
let temp_provider = {
@@ -124,6 +129,7 @@ impl CatalogProvider for PaimonCatalogProvider {
name,
dynamic_options,
temp_provider,
+ blob_reader_registry,
)) as Arc<dyn SchemaProvider>),
Err(paimon::Error::DatabaseNotExist { .. }) => {
if temp_provider.is_some() {
@@ -132,6 +138,7 @@ impl CatalogProvider for PaimonCatalogProvider {
name,
dynamic_options,
temp_provider,
+ blob_reader_registry,
)) as Arc<dyn SchemaProvider>)
} else {
None
@@ -154,6 +161,7 @@ impl CatalogProvider for PaimonCatalogProvider {
) -> DFResult<Option<Arc<dyn SchemaProvider>>> {
let catalog = Arc::clone(&self.catalog);
let dynamic_options = Arc::clone(&self.dynamic_options);
+ let blob_reader_registry = self.blob_reader_registry.clone();
let name = name.to_string();
block_on_with_runtime(
async move {
@@ -166,6 +174,7 @@ impl CatalogProvider for PaimonCatalogProvider {
name,
dynamic_options,
None,
+ blob_reader_registry,
)) as Arc<dyn SchemaProvider>))
},
"paimon catalog access thread panicked",
@@ -179,6 +188,7 @@ impl CatalogProvider for PaimonCatalogProvider {
) -> DFResult<Option<Arc<dyn SchemaProvider>>> {
let catalog = Arc::clone(&self.catalog);
let dynamic_options = Arc::clone(&self.dynamic_options);
+ let blob_reader_registry = self.blob_reader_registry.clone();
let name = name.to_string();
block_on_with_runtime(
async move {
@@ -191,6 +201,7 @@ impl CatalogProvider for PaimonCatalogProvider {
name,
dynamic_options,
None,
+ blob_reader_registry,
)) as Arc<dyn SchemaProvider>))
},
"paimon catalog access thread panicked",
@@ -289,6 +300,7 @@ pub struct PaimonSchemaProvider {
dynamic_options: DynamicOptions,
/// Optional temporary in-memory provider for temp tables and views.
temp_provider: Option<Arc<MemorySchemaProvider>>,
+ blob_reader_registry: BlobReaderRegistry,
}
impl Debug for PaimonSchemaProvider {
@@ -307,12 +319,14 @@ impl PaimonSchemaProvider {
database: String,
dynamic_options: DynamicOptions,
temp_provider: Option<Arc<MemorySchemaProvider>>,
+ blob_reader_registry: BlobReaderRegistry,
) -> Self {
PaimonSchemaProvider {
catalog,
database,
dynamic_options,
temp_provider,
+ blob_reader_registry,
}
}
}
@@ -372,6 +386,7 @@ impl SchemaProvider for PaimonSchemaProvider {
let catalog = Arc::clone(&self.catalog);
let dynamic_options = Arc::clone(&self.dynamic_options);
+ let blob_reader_registry = self.blob_reader_registry.clone();
let identifier = Identifier::new(self.database.clone(), base);
await_with_runtime(async move {
match catalog.get_table(&identifier).await {
@@ -382,7 +397,10 @@ impl SchemaProvider for PaimonSchemaProvider {
} else {
table.copy_with_options(opts)
};
- let provider = PaimonTableProvider::try_new(table)?;
+ let provider =
PaimonTableProvider::try_new_with_blob_reader_registry(
+ table,
+ blob_reader_registry,
+ )?;
Ok(Some(Arc::new(provider) as Arc<dyn TableProvider>))
}
Err(paimon::Error::TableNotExist { .. }) => Ok(None),
diff --git a/crates/integrations/datafusion/src/lib.rs
b/crates/integrations/datafusion/src/lib.rs
index 47f1bab..f11cfce 100644
--- a/crates/integrations/datafusion/src/lib.rs
+++ b/crates/integrations/datafusion/src/lib.rs
@@ -36,6 +36,7 @@
//! This version supports partition predicate pushdown by extracting
//! translatable partition-only conjuncts from DataFusion filters.
+mod blob_reader;
mod catalog;
mod delete;
mod error;
@@ -63,6 +64,7 @@ use std::sync::{Arc, RwLock};
/// so that SET/RESET mutations are visible to subsequent table scans.
pub(crate) type DynamicOptions = Arc<RwLock<HashMap<String, String>>>;
+pub use blob_reader::BlobReaderRegistry;
pub use catalog::{PaimonCatalogProvider, PaimonSchemaProvider};
pub use error::to_datafusion_error;
#[cfg(feature = "fulltext")]
diff --git a/crates/integrations/datafusion/src/sql_context.rs
b/crates/integrations/datafusion/src/sql_context.rs
index f77b79e..ec93741 100644
--- a/crates/integrations/datafusion/src/sql_context.rs
+++ b/crates/integrations/datafusion/src/sql_context.rs
@@ -65,7 +65,7 @@ use paimon::spec::{
};
use crate::error::to_datafusion_error;
-use crate::DynamicOptions;
+use crate::{BlobReaderRegistry, DynamicOptions};
/// A SQL context that supports registering multiple Paimon catalogs and
executing SQL.
///
@@ -81,6 +81,7 @@ pub struct SQLContext {
catalogs: HashMap<String, Arc<dyn Catalog>>,
/// Session-scoped dynamic options set via `SET 'paimon.key' = 'value'`.
dynamic_options: DynamicOptions,
+ blob_reader_registry: BlobReaderRegistry,
}
impl Default for SQLContext {
@@ -102,9 +103,14 @@ impl SQLContext {
ctx,
catalogs: HashMap::new(),
dynamic_options: Default::default(),
+ blob_reader_registry: BlobReaderRegistry::default(),
}
}
+ pub fn blob_reader_registry(&self) -> BlobReaderRegistry {
+ self.blob_reader_registry.clone()
+ }
+
/// Registers a Paimon catalog under the given name.
///
/// The first registered catalog automatically becomes the current catalog
@@ -134,6 +140,7 @@ impl SQLContext {
Arc::new(crate::catalog::PaimonCatalogProvider::with_dynamic_options(
catalog.clone(),
self.dynamic_options.clone(),
+ self.blob_reader_registry.clone(),
)),
);
register_table_functions(&self.ctx, &catalog, default_db);
@@ -471,7 +478,10 @@ impl SQLContext {
options.insert(SCAN_VERSION_OPTION.to_string(),
info.version.clone());
let table_with_options = paimon_table.copy_with_options(options);
- let provider =
Arc::new(PaimonTableProvider::try_new(table_with_options)?);
+ let provider =
Arc::new(PaimonTableProvider::try_new_with_blob_reader_registry(
+ table_with_options,
+ self.blob_reader_registry.clone(),
+ )?);
let uuid_name = format!("__paimon_tt_{}",
uuid::Uuid::new_v4().as_simple());
self.register_temp_table(uuid_name.as_str(), provider)?;
@@ -497,7 +507,10 @@ impl SQLContext {
options.insert(SCAN_TIMESTAMP_MILLIS_OPTION.to_string(),
millis.to_string());
let table_with_options = paimon_table.copy_with_options(options);
- let provider =
Arc::new(PaimonTableProvider::try_new(table_with_options)?);
+ let provider =
Arc::new(PaimonTableProvider::try_new_with_blob_reader_registry(
+ table_with_options,
+ self.blob_reader_registry.clone(),
+ )?);
let uuid_name = format!("__paimon_tt_{}",
uuid::Uuid::new_v4().as_simple());
self.register_temp_table(uuid_name.as_str(), provider)?;
diff --git a/crates/integrations/datafusion/src/table/mod.rs
b/crates/integrations/datafusion/src/table/mod.rs
index 5bae744..3508362 100644
--- a/crates/integrations/datafusion/src/table/mod.rs
+++ b/crates/integrations/datafusion/src/table/mod.rs
@@ -32,6 +32,7 @@ use datafusion::physical_plan::ExecutionPlan;
use paimon::table::Table;
use crate::physical_plan::PaimonDataSink;
+use crate::BlobReaderRegistry;
use crate::error::to_datafusion_error;
#[cfg(test)]
@@ -74,6 +75,15 @@ impl PaimonTableProvider {
Ok(Self { table, schema })
}
+ pub fn try_new_with_blob_reader_registry(
+ table: Table,
+ blob_reader_registry: BlobReaderRegistry,
+ ) -> DFResult<Self> {
+ blob_reader_registry
+ .register_if_absent(table.location().to_string(),
table.file_io().clone());
+ Self::try_new(table)
+ }
+
pub fn table(&self) -> &Table {
&self.table
}