This is an automated email from the ASF dual-hosted git repository.
timsaucer pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-python.git
The following commit(s) were added to refs/heads/main by this push:
new 8fc94362 feat: add CatalogProviderList support (#1363)
8fc94362 is described below
commit 8fc943629b93b342ef67bb8aea0aa581615a374d
Author: Tim Saucer <[email protected]>
AuthorDate: Mon Feb 9 09:17:26 2026 -0500
feat: add CatalogProviderList support (#1363)
* Implement catalog provider list
* Flesh out python side and add unit test
* Add FFI test for catalog provider list
* Update type hints
* Update unit test to add a different type of catalog to the catalog list
---
.../python/tests/_test_catalog_provider.py | 26 ++-
.../datafusion-ffi-example/src/catalog_provider.rs | 69 ++++++-
examples/datafusion-ffi-example/src/lib.rs | 3 +-
python/datafusion/catalog.py | 91 +++++++++
python/datafusion/context.py | 27 ++-
python/tests/test_catalog.py | 35 +++-
src/catalog.rs | 224 ++++++++++++++++++++-
src/context.rs | 41 +++-
8 files changed, 496 insertions(+), 20 deletions(-)
diff --git
a/examples/datafusion-ffi-example/python/tests/_test_catalog_provider.py
b/examples/datafusion-ffi-example/python/tests/_test_catalog_provider.py
index b26e1208..a862b23b 100644
--- a/examples/datafusion-ffi-example/python/tests/_test_catalog_provider.py
+++ b/examples/datafusion-ffi-example/python/tests/_test_catalog_provider.py
@@ -22,7 +22,7 @@ import pyarrow.dataset as ds
import pytest
from datafusion import SessionContext, Table
from datafusion.catalog import Schema
-from datafusion_ffi_example import MyCatalogProvider
+from datafusion_ffi_example import MyCatalogProvider, MyCatalogProviderList
def create_test_dataset() -> Table:
@@ -35,6 +35,30 @@ def create_test_dataset() -> Table:
return Table(dataset)
[email protected]("inner_capsule", [True, False])
+def test_ffi_catalog_provider_list(inner_capsule: bool) -> None:
+ """Test basic FFI CatalogProviderList functionality."""
+ ctx = SessionContext()
+
+ # Register FFI catalog
+ catalog_provider_list = MyCatalogProviderList()
+ if inner_capsule:
+ catalog_provider_list = (
+ catalog_provider_list.__datafusion_catalog_provider_list__(ctx)
+ )
+
+ ctx.register_catalog_provider_list(catalog_provider_list)
+
+ # Verify the catalog exists
+ catalog = ctx.catalog("auto_ffi_catalog")
+ schema_names = catalog.names()
+ assert "my_schema" in schema_names
+
+ ctx.register_catalog_provider("second", MyCatalogProvider())
+
+ assert ctx.catalog_names() == {"auto_ffi_catalog", "second"}
+
+
@pytest.mark.parametrize("inner_capsule", [True, False])
def test_ffi_catalog_provider_basic(inner_capsule: bool) -> None:
"""Test basic FFI CatalogProvider functionality."""
diff --git a/examples/datafusion-ffi-example/src/catalog_provider.rs
b/examples/datafusion-ffi-example/src/catalog_provider.rs
index 57022274..aee23602 100644
--- a/examples/datafusion-ffi-example/src/catalog_provider.rs
+++ b/examples/datafusion-ffi-example/src/catalog_provider.rs
@@ -22,11 +22,12 @@ use std::sync::Arc;
use arrow::datatypes::Schema;
use async_trait::async_trait;
use datafusion_catalog::{
- CatalogProvider, MemTable, MemoryCatalogProvider, MemorySchemaProvider,
SchemaProvider,
- TableProvider,
+ CatalogProvider, CatalogProviderList, MemTable, MemoryCatalogProvider,
+ MemoryCatalogProviderList, MemorySchemaProvider, SchemaProvider,
TableProvider,
};
use datafusion_common::error::{DataFusionError, Result};
use datafusion_ffi::catalog_provider::FFI_CatalogProvider;
+use datafusion_ffi::catalog_provider_list::FFI_CatalogProviderList;
use datafusion_ffi::schema_provider::FFI_SchemaProvider;
use pyo3::types::PyCapsule;
use pyo3::{pyclass, pymethods, Bound, PyAny, PyResult, Python};
@@ -203,3 +204,67 @@ impl MyCatalogProvider {
PyCapsule::new(py, provider, Some(name))
}
}
+
+/// This catalog provider list is intended only for unit tests.
+/// It pre-populates with a single catalog.
+///
+/// Exposed to Python as `datafusion_ffi_example.MyCatalogProviderList`.
+#[pyclass(
+    name = "MyCatalogProviderList",
+    module = "datafusion_ffi_example",
+    subclass
+)]
+#[derive(Debug, Clone)]
+pub(crate) struct MyCatalogProviderList {
+    /// Backing in-memory catalog list that holds the registered catalogs.
+    inner: Arc<MemoryCatalogProviderList>,
+}
+
+/// Every method forwards directly to the wrapped `MemoryCatalogProviderList`.
+impl CatalogProviderList for MyCatalogProviderList {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    /// Names of the catalogs held by the inner list.
+    fn catalog_names(&self) -> Vec<String> {
+        self.inner.catalog_names()
+    }
+
+    /// Look up a catalog by name in the inner list.
+    fn catalog(&self, name: &str) -> Option<Arc<dyn CatalogProvider>> {
+        self.inner.catalog(name)
+    }
+
+    /// Register `catalog` under `name` in the inner list and pass its return
+    /// value through unchanged.
+    fn register_catalog(
+        &self,
+        name: String,
+        catalog: Arc<dyn CatalogProvider>,
+    ) -> Option<Arc<dyn CatalogProvider>> {
+        self.inner.register_catalog(name, catalog)
+    }
+}
+
+#[pymethods]
+impl MyCatalogProviderList {
+    /// Build the list with a single pre-registered catalog named
+    /// `auto_ffi_catalog`.
+    ///
+    /// # Errors
+    /// Propagates any error raised while constructing `MyCatalogProvider`.
+    #[new]
+    pub fn new() -> PyResult<Self> {
+        let inner = Arc::new(MemoryCatalogProviderList::new());
+
+        inner.register_catalog(
+            "auto_ffi_catalog".to_owned(),
+            Arc::new(MyCatalogProvider::new()?),
+        );
+
+        Ok(Self { inner })
+    }
+
+    /// Export the list as a PyCapsule named `datafusion_catalog_provider_list`.
+    ///
+    /// `session` is the object from which the FFI logical extension codec is
+    /// extracted.
+    pub fn __datafusion_catalog_provider_list__<'py>(
+        &self,
+        py: Python<'py>,
+        session: Bound<PyAny>,
+    ) -> PyResult<Bound<'py, PyCapsule>> {
+        let name = cr"datafusion_catalog_provider_list".into();
+
+        // The FFI wrapper takes a `Send` trait object, hence the explicit cast.
+        let provider = Arc::clone(&self.inner) as Arc<dyn CatalogProviderList + Send>;
+
+        let codec = ffi_logical_codec_from_pycapsule(session)?;
+        let provider = FFI_CatalogProviderList::new_with_ffi_codec(provider, None, codec);
+
+        PyCapsule::new(py, provider, Some(name))
+    }
+}
diff --git a/examples/datafusion-ffi-example/src/lib.rs
b/examples/datafusion-ffi-example/src/lib.rs
index 005d8b80..6c64c9fe 100644
--- a/examples/datafusion-ffi-example/src/lib.rs
+++ b/examples/datafusion-ffi-example/src/lib.rs
@@ -18,7 +18,7 @@
use pyo3::prelude::*;
use crate::aggregate_udf::MySumUDF;
-use crate::catalog_provider::{FixedSchemaProvider, MyCatalogProvider};
+use crate::catalog_provider::{FixedSchemaProvider, MyCatalogProvider,
MyCatalogProviderList};
use crate::scalar_udf::IsNullUDF;
use crate::table_function::MyTableFunction;
use crate::table_provider::MyTableProvider;
@@ -37,6 +37,7 @@ fn datafusion_ffi_example(m: &Bound<'_, PyModule>) ->
PyResult<()> {
m.add_class::<MyTableProvider>()?;
m.add_class::<MyTableFunction>()?;
m.add_class::<MyCatalogProvider>()?;
+ m.add_class::<MyCatalogProviderList>()?;
m.add_class::<FixedSchemaProvider>()?;
m.add_class::<IsNullUDF>()?;
m.add_class::<MySumUDF>()?;
diff --git a/python/datafusion/catalog.py b/python/datafusion/catalog.py
index 16c3ccc2..bc43cf34 100644
--- a/python/datafusion/catalog.py
+++ b/python/datafusion/catalog.py
@@ -38,13 +38,61 @@ except ImportError:
__all__ = [
"Catalog",
+ "CatalogList",
"CatalogProvider",
+ "CatalogProviderList",
"Schema",
"SchemaProvider",
"Table",
]
+class CatalogList:
+ """DataFusion data catalog list."""
+
+ def __init__(self, catalog_list: df_internal.catalog.RawCatalogList) ->
None:
+ """This constructor is not typically called by the end user."""
+ self.catalog_list = catalog_list
+
+ def __repr__(self) -> str:
+ """Print a string representation of the catalog list."""
+ return self.catalog_list.__repr__()
+
+ def names(self) -> set[str]:
+ """This is an alias for `catalog_names`."""
+ return self.catalog_names()
+
+ def catalog_names(self) -> set[str]:
+ """Returns the list of schemas in this catalog."""
+ return self.catalog_list.catalog_names()
+
+ @staticmethod
+ def memory_catalog(ctx: SessionContext | None = None) -> CatalogList:
+ """Create an in-memory catalog provider list."""
+ catalog_list = df_internal.catalog.RawCatalogList.memory_catalog(ctx)
+ return CatalogList(catalog_list)
+
+ def catalog(self, name: str = "datafusion") -> Catalog:
+ """Returns the catalog with the given ``name`` from this catalog."""
+ catalog = self.catalog_list.catalog(name)
+
+ return (
+ Catalog(catalog)
+ if isinstance(catalog, df_internal.catalog.RawCatalog)
+ else catalog
+ )
+
+ def register_catalog(
+ self,
+ name: str,
+ catalog: Catalog | CatalogProvider | CatalogProviderExportable,
+ ) -> Catalog | None:
+ """Register a catalog with this catalog list."""
+ if isinstance(catalog, Catalog):
+ return self.catalog_list.register_catalog(name, catalog.catalog)
+ return self.catalog_list.register_catalog(name, catalog)
+
+
class Catalog:
"""DataFusion data catalog."""
@@ -195,6 +243,40 @@ class Table:
return self._inner.kind
+class CatalogProviderList(ABC):
+    """Abstract class for defining a Python based Catalog Provider List.
+
+    Subclasses must implement ``catalog_names`` and ``catalog``;
+    ``register_catalog`` is optional and does nothing by default.
+    """
+
+    @abstractmethod
+    def catalog_names(self) -> set[str]:
+        """Set of the names of all catalogs in this catalog list."""
+        ...
+
+    @abstractmethod
+    def catalog(
+        self, name: str
+    ) -> CatalogProviderExportable | CatalogProvider | Catalog | None:
+        """Retrieve a specific catalog from this catalog list.
+
+        Return ``None`` when there is no catalog with the given ``name``.
+        """
+        ...
+
+    def register_catalog(  # noqa: B027
+        self, name: str, catalog: CatalogProviderExportable | CatalogProvider | Catalog
+    ) -> None:
+        """Add a catalog to this catalog list.
+
+        This method is optional. If your catalog list provides a fixed list of
+        catalogs, you do not need to implement this method.
+        """
+
+
+class CatalogProviderListExportable(Protocol):
+    """Type hint for object that has __datafusion_catalog_provider_list__ PyCapsule.
+
+    The method is called with a single ``session`` argument and returns the
+    capsule object.
+
+    https://docs.rs/datafusion/latest/datafusion/catalog/trait.CatalogProviderList.html
+    """
+
+    def __datafusion_catalog_provider_list__(self, session: Any) -> object: ...
+
+
class CatalogProvider(ABC):
"""Abstract class for defining a Python based Catalog Provider."""
@@ -229,6 +311,15 @@ class CatalogProvider(ABC):
"""
+class CatalogProviderExportable(Protocol):
+    """Type hint for object that has __datafusion_catalog_provider__ PyCapsule.
+
+    The method is called with a single ``session`` argument and returns the
+    capsule object.
+
+    https://docs.rs/datafusion/latest/datafusion/catalog/trait.CatalogProvider.html
+    """
+
+    def __datafusion_catalog_provider__(self, session: Any) -> object: ...
+
+
class SchemaProvider(ABC):
"""Abstract class for defining a Python based Schema Provider."""
diff --git a/python/datafusion/context.py b/python/datafusion/context.py
index 7b92c082..0d825977 100644
--- a/python/datafusion/context.py
+++ b/python/datafusion/context.py
@@ -31,7 +31,13 @@ except ImportError:
import pyarrow as pa
-from datafusion.catalog import Catalog
+from datafusion.catalog import (
+ Catalog,
+ CatalogList,
+ CatalogProviderExportable,
+ CatalogProviderList,
+ CatalogProviderListExportable,
+)
from datafusion.dataframe import DataFrame
from datafusion.expr import sort_list_to_raw_sort_list
from datafusion.options import (
@@ -96,15 +102,6 @@ class TableProviderExportable(Protocol):
def __datafusion_table_provider__(self, session: Any) -> object: ... #
noqa: D105
-class CatalogProviderExportable(Protocol):
- """Type hint for object that has __datafusion_catalog_provider__ PyCapsule.
-
-
https://docs.rs/datafusion/latest/datafusion/catalog/trait.CatalogProvider.html
- """
-
- def __datafusion_catalog_provider__(self, session: Any) -> object: ... #
noqa: D105
-
-
class SessionConfig:
"""Session configuration options."""
@@ -837,6 +834,16 @@ class SessionContext:
"""Returns the list of catalogs in this context."""
return self.ctx.catalog_names()
+ def register_catalog_provider_list(
+ self,
+ provider: CatalogProviderListExportable | CatalogProviderList |
CatalogList,
+ ) -> None:
+ """Register a catalog provider list."""
+ if isinstance(provider, CatalogList):
+ self.ctx.register_catalog_provider_list(provider.catalog)
+ else:
+ self.ctx.register_catalog_provider_list(provider)
+
def register_catalog_provider(
self, name: str, provider: CatalogProviderExportable | CatalogProvider
| Catalog
) -> None:
diff --git a/python/tests/test_catalog.py b/python/tests/test_catalog.py
index 08f494de..dd4c8246 100644
--- a/python/tests/test_catalog.py
+++ b/python/tests/test_catalog.py
@@ -16,11 +16,16 @@
# under the License.
from __future__ import annotations
+from typing import TYPE_CHECKING
+
import datafusion as dfn
import pyarrow as pa
import pyarrow.dataset as ds
import pytest
-from datafusion import SessionContext, Table, udtf
+from datafusion import Catalog, SessionContext, Table, udtf
+
+if TYPE_CHECKING:
+ from datafusion.catalog import CatalogProvider, CatalogProviderExportable
# Note we take in `database` as a variable even though we don't use
@@ -93,6 +98,34 @@ class CustomCatalogProvider(dfn.catalog.CatalogProvider):
del self.schemas[name]
+class CustomCatalogProviderList(dfn.catalog.CatalogProviderList):
+ def __init__(self):
+ self.catalogs = {"my_catalog": CustomCatalogProvider()}
+
+ def catalog_names(self) -> set[str]:
+ return set(self.catalogs.keys())
+
+ def catalog(self, name: str) -> Catalog | None:
+ return self.catalogs[name]
+
+ def register_catalog(
+ self, name: str, catalog: CatalogProviderExportable | CatalogProvider
| Catalog
+ ) -> None:
+ self.catalogs[name] = catalog
+
+
+def test_python_catalog_provider_list(ctx: SessionContext):
+ ctx.register_catalog_provider_list(CustomCatalogProviderList())
+
+ # Ensure `datafusion` catalog does not exist since
+ # we replaced the catalog list
+ assert ctx.catalog_names() == {"my_catalog"}
+
+ # Ensure registering works
+ ctx.register_catalog_provider("second_catalog", Catalog.memory_catalog())
+ assert ctx.catalog_names() == {"my_catalog", "second_catalog"}
+
+
def test_python_catalog_provider(ctx: SessionContext):
ctx.register_catalog_provider("my_catalog", CustomCatalogProvider())
diff --git a/src/catalog.rs b/src/catalog.rs
index 10ca1dd1..b5b98397 100644
--- a/src/catalog.rs
+++ b/src/catalog.rs
@@ -21,10 +21,12 @@ use std::sync::Arc;
use async_trait::async_trait;
use datafusion::catalog::{
- CatalogProvider, MemoryCatalogProvider, MemorySchemaProvider,
SchemaProvider,
+ CatalogProvider, CatalogProviderList, MemoryCatalogProvider,
MemoryCatalogProviderList,
+ MemorySchemaProvider, SchemaProvider,
};
use datafusion::common::DataFusionError;
use datafusion::datasource::TableProvider;
+use datafusion_ffi::catalog_provider::FFI_CatalogProvider;
use datafusion_ffi::proto::logical_extension_codec::FFI_LogicalExtensionCodec;
use datafusion_ffi::schema_provider::FFI_SchemaProvider;
use pyo3::exceptions::PyKeyError;
@@ -40,6 +42,18 @@ use crate::utils::{
wait_for_future,
};
+/// Python-facing handle to a catalog provider list, exposed as
+/// `datafusion.catalog.RawCatalogList`.
+#[pyclass(
+    frozen,
+    name = "RawCatalogList",
+    module = "datafusion.catalog",
+    subclass
+)]
+#[derive(Clone)]
+pub struct PyCatalogList {
+    /// The wrapped catalog provider list.
+    pub catalog_list: Arc<dyn CatalogProviderList>,
+    /// Logical extension codec used when wrapping or extracting catalogs.
+    codec: Arc<FFI_LogicalExtensionCodec>,
+}
+
#[pyclass(frozen, name = "RawCatalog", module = "datafusion.catalog",
subclass)]
#[derive(Clone)]
pub struct PyCatalog {
@@ -72,6 +86,77 @@ impl PySchema {
}
}
+#[pymethods]
+impl PyCatalogList {
+    /// Wrap a Python object that provides the catalog provider list methods.
+    ///
+    /// `session` is optional and is only used to obtain the logical extension
+    /// codec.
+    #[new]
+    pub fn new(
+        py: Python,
+        catalog_list: Py<PyAny>,
+        session: Option<Bound<PyAny>>,
+    ) -> PyResult<Self> {
+        let codec = extract_logical_extension_codec(py, session)?;
+        let catalog_list = Arc::new(RustWrappedPyCatalogProviderList::new(
+            catalog_list,
+            codec.clone(),
+        )) as Arc<dyn CatalogProviderList>;
+        Ok(Self {
+            catalog_list,
+            codec,
+        })
+    }
+
+    /// Create a list backed by a default `MemoryCatalogProviderList`.
+    #[staticmethod]
+    pub fn memory_catalog_list(py: Python, session: Option<Bound<PyAny>>) -> PyResult<Self> {
+        let codec = extract_logical_extension_codec(py, session)?;
+        let catalog_list =
+            Arc::new(MemoryCatalogProviderList::default()) as Arc<dyn CatalogProviderList>;
+        Ok(Self {
+            catalog_list,
+            codec,
+        })
+    }
+
+    /// Names of all catalogs in the list, returned as a set.
+    pub fn catalog_names(&self) -> HashSet<String> {
+        self.catalog_list.catalog_names().into_iter().collect()
+    }
+
+    /// Look up a catalog by name.
+    ///
+    /// A catalog that wraps a Python object is returned as that original
+    /// object; any other catalog is wrapped in a `PyCatalog`.
+    ///
+    /// # Errors
+    /// Raises `KeyError` when the list has no catalog called `name`.
+    // NOTE(review): "public" looks copied from the schema lookup; the Python
+    // `CatalogList.catalog` wrapper defaults to "datafusion". Kept as-is to
+    // avoid changing the exposed default - confirm and align.
+    #[pyo3(signature = (name="public"))]
+    pub fn catalog(&self, name: &str) -> PyResult<Py<PyAny>> {
+        // Build the error lazily so the message is only formatted on a miss.
+        let catalog = self.catalog_list.catalog(name).ok_or_else(|| {
+            PyKeyError::new_err(format!("Catalog with name {name} doesn't exist."))
+        })?;
+
+        Python::attach(|py| {
+            match catalog
+                .as_any()
+                .downcast_ref::<RustWrappedPyCatalogProvider>()
+            {
+                Some(wrapped_catalog) => Ok(wrapped_catalog.catalog_provider.clone_ref(py)),
+                None => PyCatalog::new_from_parts(catalog, self.codec.clone()).into_py_any(py),
+            }
+        })
+    }
+
+    /// Register `catalog_provider` under `name`.
+    ///
+    /// Whatever the underlying list returns from the registration is
+    /// discarded.
+    pub fn register_catalog(&self, name: &str, catalog_provider: Bound<'_, PyAny>) -> PyResult<()> {
+        let provider = extract_catalog_provider_from_pyobj(catalog_provider, self.codec.as_ref())?;
+
+        let _ = self
+            .catalog_list
+            .register_catalog(name.to_owned(), provider);
+
+        Ok(())
+    }
+
+    /// Render as `CatalogList(catalog_names=[...])` with the names sorted so
+    /// the output is stable.
+    pub fn __repr__(&self) -> PyResult<String> {
+        let mut names: Vec<String> = self.catalog_names().into_iter().collect();
+        names.sort();
+        Ok(format!("CatalogList(catalog_names=[{}])", names.join(", ")))
+    }
+}
+
#[pymethods]
impl PyCatalog {
#[new]
@@ -373,8 +458,9 @@ impl CatalogProvider for RustWrappedPyCatalogProvider {
Python::attach(|py| {
let provider = self.catalog_provider.bind(py);
provider
- .getattr("schema_names")
- .and_then(|names| names.extract::<Vec<String>>())
+ .call_method0("schema_names")
+ .and_then(|names| names.extract::<HashSet<String>>())
+ .map(|names| names.into_iter().collect())
.unwrap_or_else(|err| {
log::error!("Unable to get schema_names: {err}");
Vec::default()
@@ -442,6 +528,138 @@ impl CatalogProvider for RustWrappedPyCatalogProvider {
}
}
+/// Adapter that lets a Python object act as a Rust `CatalogProviderList`.
+#[derive(Debug)]
+pub(crate) struct RustWrappedPyCatalogProviderList {
+    /// The Python object whose methods are called to serve trait requests.
+    pub(crate) catalog_provider_list: Py<PyAny>,
+    /// Logical extension codec passed along when converting catalogs.
+    codec: Arc<FFI_LogicalExtensionCodec>,
+}
+
+impl RustWrappedPyCatalogProviderList {
+    /// Pair a Python catalog provider list object with the codec used for
+    /// catalog conversions.
+    pub fn new(catalog_provider_list: Py<PyAny>, codec: Arc<FFI_LogicalExtensionCodec>) -> Self {
+        Self {
+            catalog_provider_list,
+            codec,
+        }
+    }
+
+    /// Call the Python object's `catalog(name)` method and convert the result.
+    ///
+    /// A Python `None` becomes `Ok(None)`; any Python exception or conversion
+    /// failure is returned as the error.
+    fn catalog_inner(&self, name: &str) -> PyResult<Option<Arc<dyn CatalogProvider>>> {
+        Python::attach(|py| {
+            let py_list = self.catalog_provider_list.bind(py);
+            let py_catalog = py_list.call_method1("catalog", (name,))?;
+
+            if py_catalog.is_none() {
+                Ok(None)
+            } else {
+                let converted =
+                    extract_catalog_provider_from_pyobj(py_catalog, self.codec.as_ref())?;
+                Ok(Some(converted))
+            }
+        })
+    }
+}
+
+/// Serves the `CatalogProviderList` trait by calling methods on the wrapped
+/// Python object. Failures are logged rather than propagated because the
+/// trait signatures have no error channel.
+#[async_trait]
+impl CatalogProviderList for RustWrappedPyCatalogProviderList {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    /// Call the Python `catalog_names()` method and collect the result.
+    ///
+    /// On any error the failure is logged and an empty list is returned.
+    fn catalog_names(&self) -> Vec<String> {
+        Python::attach(|py| {
+            let provider = self.catalog_provider_list.bind(py);
+            provider
+                .call_method0("catalog_names")
+                // The result is extracted as a set of strings before being
+                // turned into the Vec the trait requires.
+                .and_then(|names| names.extract::<HashSet<String>>())
+                .map(|names| names.into_iter().collect())
+                .unwrap_or_else(|err| {
+                    log::error!("Unable to get catalog_names: {err}");
+                    Vec::default()
+                })
+        })
+    }
+
+    /// Look up a catalog by name; errors are logged and reported as `None`.
+    fn catalog(&self, name: &str) -> Option<Arc<dyn CatalogProvider>> {
+        self.catalog_inner(name).unwrap_or_else(|err| {
+            log::error!("CatalogProvider catalog returned error: {err}");
+            None
+        })
+    }
+
+    /// Forward a registration to the Python `register_catalog(name, catalog)`.
+    ///
+    /// Returns `None` when conversion or the Python call fails (after
+    /// logging) or when the Python method returns `None`; otherwise the
+    /// returned Python object is wrapped as a catalog provider.
+    fn register_catalog(
+        &self,
+        name: String,
+        catalog: Arc<dyn CatalogProvider>,
+    ) -> Option<Arc<dyn CatalogProvider>> {
+        Python::attach(|py| {
+            // Hand Python its own object back when the catalog already wraps
+            // one; otherwise expose the Rust catalog through `PyCatalog`.
+            let py_catalog = match catalog
+                .as_any()
+                .downcast_ref::<RustWrappedPyCatalogProvider>()
+            {
+                Some(wrapped_schema) => wrapped_schema.catalog_provider.as_any().clone_ref(py),
+                None => {
+                    match PyCatalog::new_from_parts(catalog, self.codec.clone()).into_py_any(py) {
+                        Ok(c) => c,
+                        Err(err) => {
+                            log::error!(
+                                "register_catalog returned error during conversion to PyAny: {err}"
+                            );
+                            return None;
+                        }
+                    }
+                }
+            };
+
+            let provider = self.catalog_provider_list.bind(py);
+            let catalog = match provider.call_method1("register_catalog", (name, py_catalog)) {
+                Ok(c) => c,
+                Err(err) => {
+                    log::error!("register_catalog returned error: {err}");
+                    return None;
+                }
+            };
+            // A Python `None` result maps to `None` here.
+            if catalog.is_none() {
+                return None;
+            }
+
+            let catalog = Arc::new(RustWrappedPyCatalogProvider::new(
+                catalog.into(),
+                self.codec.clone(),
+            )) as Arc<dyn CatalogProvider>;
+
+            Some(catalog)
+        })
+    }
+}
+
+/// Convert an arbitrary Python object into a Rust `CatalogProvider`.
+///
+/// Conversion is attempted in this order:
+/// 1. if the object has `__datafusion_catalog_provider__`, that method is
+///    called with a logical extension codec capsule and its result replaces
+///    the object;
+/// 2. a `PyCapsule` is validated against the name
+///    `datafusion_catalog_provider` and read as an `FFI_CatalogProvider`;
+/// 3. a `PyCatalog` yields its inner catalog;
+/// 4. anything else is wrapped in `RustWrappedPyCatalogProvider`.
+fn extract_catalog_provider_from_pyobj(
+    mut catalog_provider: Bound<PyAny>,
+    codec: &FFI_LogicalExtensionCodec,
+) -> PyResult<Arc<dyn CatalogProvider>> {
+    if catalog_provider.hasattr("__datafusion_catalog_provider__")? {
+        let py = catalog_provider.py();
+        let codec_capsule = create_logical_extension_capsule(py, codec)?;
+        catalog_provider = catalog_provider
+            .getattr("__datafusion_catalog_provider__")?
+            .call1((codec_capsule,))?;
+    }
+
+    let provider = if let Ok(capsule) = catalog_provider.downcast::<PyCapsule>() {
+        validate_pycapsule(capsule, "datafusion_catalog_provider")?;
+
+        // SAFETY (assumed): relies on the validation above to ensure the
+        // capsule really carries an `FFI_CatalogProvider` - TODO confirm what
+        // `validate_pycapsule` guarantees.
+        let provider = unsafe { capsule.reference::<FFI_CatalogProvider>() };
+        let provider: Arc<dyn CatalogProvider + Send> = provider.into();
+        provider as Arc<dyn CatalogProvider>
+    } else {
+        match catalog_provider.extract::<PyCatalog>() {
+            Ok(py_catalog) => py_catalog.catalog,
+            // Not a native catalog: treat it as a Python implementation.
+            Err(_) => Arc::new(RustWrappedPyCatalogProvider::new(
+                catalog_provider.into(),
+                Arc::new(codec.clone()),
+            )) as Arc<dyn CatalogProvider>,
+        }
+    };
+
+    Ok(provider)
+}
+
fn extract_schema_provider_from_pyobj(
mut schema_provider: Bound<PyAny>,
codec: &FFI_LogicalExtensionCodec,
diff --git a/src/context.rs b/src/context.rs
index f28c5982..89bbe934 100644
--- a/src/context.rs
+++ b/src/context.rs
@@ -26,7 +26,7 @@ use arrow::pyarrow::FromPyArrow;
use datafusion::arrow::datatypes::{DataType, Schema, SchemaRef};
use datafusion::arrow::pyarrow::PyArrowType;
use datafusion::arrow::record_batch::RecordBatch;
-use datafusion::catalog::CatalogProvider;
+use datafusion::catalog::{CatalogProvider, CatalogProviderList};
use datafusion::common::{exec_err, ScalarValue, TableReference};
use
datafusion::datasource::file_format::file_compression_type::FileCompressionType;
use datafusion::datasource::file_format::parquet::ParquetFormat;
@@ -47,6 +47,7 @@ use datafusion::prelude::{
AvroReadOptions, CsvReadOptions, DataFrame, NdJsonReadOptions,
ParquetReadOptions,
};
use datafusion_ffi::catalog_provider::FFI_CatalogProvider;
+use datafusion_ffi::catalog_provider_list::FFI_CatalogProviderList;
use datafusion_ffi::execution::FFI_TaskContextProvider;
use datafusion_ffi::proto::logical_extension_codec::FFI_LogicalExtensionCodec;
use datafusion_proto::logical_plan::DefaultLogicalExtensionCodec;
@@ -58,7 +59,9 @@ use pyo3::IntoPyObjectExt;
use url::Url;
use uuid::Uuid;
-use crate::catalog::{PyCatalog, RustWrappedPyCatalogProvider};
+use crate::catalog::{
+ PyCatalog, PyCatalogList, RustWrappedPyCatalogProvider,
RustWrappedPyCatalogProviderList,
+};
use crate::common::data_type::PyScalarValue;
use crate::dataframe::PyDataFrame;
use crate::dataset::Dataset;
@@ -627,6 +630,40 @@ impl PySessionContext {
Ok(())
}
+    /// Replace the session's catalog list with the one described by `provider`.
+    ///
+    /// `provider` may be an object exposing
+    /// `__datafusion_catalog_provider_list__`, a PyCapsule named
+    /// `datafusion_catalog_provider_list`, a `PyCatalogList`, or any other
+    /// Python object, which is then wrapped in
+    /// `RustWrappedPyCatalogProviderList`.
+    pub fn register_catalog_provider_list(
+        &self,
+        mut provider: Bound<PyAny>,
+    ) -> PyDataFusionResult<()> {
+        if provider.hasattr("__datafusion_catalog_provider_list__")? {
+            let py = provider.py();
+            let codec_capsule = create_logical_extension_capsule(py, self.logical_codec.as_ref())?;
+            provider = provider
+                .getattr("__datafusion_catalog_provider_list__")?
+                .call1((codec_capsule,))?;
+        }
+
+        // Only the success case of the downcast matters, so the error is not
+        // converted.
+        let provider = if let Ok(capsule) = provider.downcast::<PyCapsule>() {
+            validate_pycapsule(capsule, "datafusion_catalog_provider_list")?;
+
+            // SAFETY (assumed): relies on the validation above to ensure the
+            // capsule really carries an `FFI_CatalogProviderList` - TODO
+            // confirm what `validate_pycapsule` guarantees.
+            let provider = unsafe { capsule.reference::<FFI_CatalogProviderList>() };
+            let provider: Arc<dyn CatalogProviderList + Send> = provider.into();
+            provider as Arc<dyn CatalogProviderList>
+        } else {
+            match provider.extract::<PyCatalogList>() {
+                Ok(py_catalog_list) => py_catalog_list.catalog_list,
+                // Not a native list: treat it as a Python implementation.
+                Err(_) => Arc::new(RustWrappedPyCatalogProviderList::new(
+                    provider.into(),
+                    Arc::clone(&self.logical_codec),
+                )) as Arc<dyn CatalogProviderList>,
+            }
+        };
+
+        self.ctx.register_catalog_list(provider);
+
+        Ok(())
+    }
+
pub fn register_catalog_provider(
&self,
name: &str,
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]