This is an automated email from the ASF dual-hosted git repository.

timsaucer pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-python.git


The following commit(s) were added to refs/heads/main by this push:
     new 8fc94362 feat: add CatalogProviderList support (#1363)
8fc94362 is described below

commit 8fc943629b93b342ef67bb8aea0aa581615a374d
Author: Tim Saucer <[email protected]>
AuthorDate: Mon Feb 9 09:17:26 2026 -0500

    feat: add CatalogProviderList support (#1363)
    
    * Implement catalog provider list
    
    * Flesh out python side and add unit test
    
    * Add FFI test for catalog provider list
    
    * Update type hints
    
    * Update unit test to add a different type of catalog to the catalog list
---
 .../python/tests/_test_catalog_provider.py         |  26 ++-
 .../datafusion-ffi-example/src/catalog_provider.rs |  69 ++++++-
 examples/datafusion-ffi-example/src/lib.rs         |   3 +-
 python/datafusion/catalog.py                       |  91 +++++++++
 python/datafusion/context.py                       |  27 ++-
 python/tests/test_catalog.py                       |  35 +++-
 src/catalog.rs                                     | 224 ++++++++++++++++++++-
 src/context.rs                                     |  41 +++-
 8 files changed, 496 insertions(+), 20 deletions(-)

diff --git a/examples/datafusion-ffi-example/python/tests/_test_catalog_provider.py b/examples/datafusion-ffi-example/python/tests/_test_catalog_provider.py
index b26e1208..a862b23b 100644
--- a/examples/datafusion-ffi-example/python/tests/_test_catalog_provider.py
+++ b/examples/datafusion-ffi-example/python/tests/_test_catalog_provider.py
@@ -22,7 +22,7 @@ import pyarrow.dataset as ds
 import pytest
 from datafusion import SessionContext, Table
 from datafusion.catalog import Schema
-from datafusion_ffi_example import MyCatalogProvider
+from datafusion_ffi_example import MyCatalogProvider, MyCatalogProviderList
 
 
 def create_test_dataset() -> Table:
@@ -35,6 +35,30 @@ def create_test_dataset() -> Table:
     return Table(dataset)
 
 
[email protected]("inner_capsule", [True, False])
+def test_ffi_catalog_provider_list(inner_capsule: bool) -> None:
+    """Test basic FFI CatalogProviderList functionality."""
+    ctx = SessionContext()
+
+    # Register FFI catalog
+    catalog_provider_list = MyCatalogProviderList()
+    if inner_capsule:
+        catalog_provider_list = (
+            catalog_provider_list.__datafusion_catalog_provider_list__(ctx)
+        )
+
+    ctx.register_catalog_provider_list(catalog_provider_list)
+
+    # Verify the catalog exists
+    catalog = ctx.catalog("auto_ffi_catalog")
+    schema_names = catalog.names()
+    assert "my_schema" in schema_names
+
+    ctx.register_catalog_provider("second", MyCatalogProvider())
+
+    assert ctx.catalog_names() == {"auto_ffi_catalog", "second"}
+
+
 @pytest.mark.parametrize("inner_capsule", [True, False])
 def test_ffi_catalog_provider_basic(inner_capsule: bool) -> None:
     """Test basic FFI CatalogProvider functionality."""
diff --git a/examples/datafusion-ffi-example/src/catalog_provider.rs b/examples/datafusion-ffi-example/src/catalog_provider.rs
index 57022274..aee23602 100644
--- a/examples/datafusion-ffi-example/src/catalog_provider.rs
+++ b/examples/datafusion-ffi-example/src/catalog_provider.rs
@@ -22,11 +22,12 @@ use std::sync::Arc;
 use arrow::datatypes::Schema;
 use async_trait::async_trait;
 use datafusion_catalog::{
-    CatalogProvider, MemTable, MemoryCatalogProvider, MemorySchemaProvider, SchemaProvider,
-    TableProvider,
+    CatalogProvider, CatalogProviderList, MemTable, MemoryCatalogProvider,
+    MemoryCatalogProviderList, MemorySchemaProvider, SchemaProvider, TableProvider,
 };
 use datafusion_common::error::{DataFusionError, Result};
 use datafusion_ffi::catalog_provider::FFI_CatalogProvider;
+use datafusion_ffi::catalog_provider_list::FFI_CatalogProviderList;
 use datafusion_ffi::schema_provider::FFI_SchemaProvider;
 use pyo3::types::PyCapsule;
 use pyo3::{pyclass, pymethods, Bound, PyAny, PyResult, Python};
@@ -203,3 +204,67 @@ impl MyCatalogProvider {
         PyCapsule::new(py, provider, Some(name))
     }
 }
+
+/// Catalog provider list used by the FFI example's unit tests.
+///
+/// `MyCatalogProviderList::new` registers a single catalog, `auto_ffi_catalog`,
+/// in the wrapped in-memory list.
+#[pyclass(name = "MyCatalogProviderList", module = "datafusion_ffi_example", subclass)]
+#[derive(Clone, Debug)]
+pub(crate) struct MyCatalogProviderList {
+    /// In-memory list that every trait method delegates to.
+    inner: Arc<MemoryCatalogProviderList>,
+}
+
+/// Every method forwards to the wrapped `MemoryCatalogProviderList`.
+impl CatalogProviderList for MyCatalogProviderList {
+    /// Exposes this value as `Any`.
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    /// Names of the catalogs held by the wrapped list.
+    fn catalog_names(&self) -> Vec<String> {
+        let Self { inner } = self;
+        inner.catalog_names()
+    }
+
+    /// Looks `name` up in the wrapped list.
+    fn catalog(&self, name: &str) -> Option<Arc<dyn CatalogProvider>> {
+        let Self { inner } = self;
+        inner.catalog(name)
+    }
+
+    /// Registers `catalog` under `name` in the wrapped list and passes its
+    /// return value through unchanged.
+    fn register_catalog(
+        &self,
+        name: String,
+        catalog: Arc<dyn CatalogProvider>,
+    ) -> Option<Arc<dyn CatalogProvider>> {
+        let Self { inner } = self;
+        inner.register_catalog(name, catalog)
+    }
+}
+
+#[pymethods]
+impl MyCatalogProviderList {
+    /// Builds the list and registers one `MyCatalogProvider` as `auto_ffi_catalog`.
+    ///
+    /// An error from `MyCatalogProvider::new` is returned to the caller.
+    #[new]
+    pub fn new() -> PyResult<Self> {
+        let catalog: Arc<dyn CatalogProvider> = Arc::new(MyCatalogProvider::new()?);
+        let inner = Arc::new(MemoryCatalogProviderList::new());
+        inner.register_catalog("auto_ffi_catalog".to_owned(), catalog);
+
+        Ok(Self { inner })
+    }
+
+    /// Exports the wrapped list as a `PyCapsule` named
+    /// `datafusion_catalog_provider_list`.
+    ///
+    /// `session` is handed to `ffi_logical_codec_from_pycapsule` to obtain the
+    /// codec given to the FFI wrapper.
+    pub fn __datafusion_catalog_provider_list__<'py>(
+        &self,
+        py: Python<'py>,
+        session: Bound<PyAny>,
+    ) -> PyResult<Bound<'py, PyCapsule>> {
+        let codec = ffi_logical_codec_from_pycapsule(session)?;
+
+        let list = Arc::clone(&self.inner) as Arc<dyn CatalogProviderList + Send>;
+        let ffi_list = FFI_CatalogProviderList::new_with_ffi_codec(list, None, codec);
+
+        let capsule_name = cr"datafusion_catalog_provider_list".into();
+        PyCapsule::new(py, ffi_list, Some(capsule_name))
+    }
+}
diff --git a/examples/datafusion-ffi-example/src/lib.rs b/examples/datafusion-ffi-example/src/lib.rs
index 005d8b80..6c64c9fe 100644
--- a/examples/datafusion-ffi-example/src/lib.rs
+++ b/examples/datafusion-ffi-example/src/lib.rs
@@ -18,7 +18,7 @@
 use pyo3::prelude::*;
 
 use crate::aggregate_udf::MySumUDF;
-use crate::catalog_provider::{FixedSchemaProvider, MyCatalogProvider};
+use crate::catalog_provider::{FixedSchemaProvider, MyCatalogProvider, MyCatalogProviderList};
 use crate::scalar_udf::IsNullUDF;
 use crate::table_function::MyTableFunction;
 use crate::table_provider::MyTableProvider;
@@ -37,6 +37,7 @@ fn datafusion_ffi_example(m: &Bound<'_, PyModule>) -> PyResult<()> {
     m.add_class::<MyTableProvider>()?;
     m.add_class::<MyTableFunction>()?;
     m.add_class::<MyCatalogProvider>()?;
+    m.add_class::<MyCatalogProviderList>()?;
     m.add_class::<FixedSchemaProvider>()?;
     m.add_class::<IsNullUDF>()?;
     m.add_class::<MySumUDF>()?;
diff --git a/python/datafusion/catalog.py b/python/datafusion/catalog.py
index 16c3ccc2..bc43cf34 100644
--- a/python/datafusion/catalog.py
+++ b/python/datafusion/catalog.py
@@ -38,13 +38,61 @@ except ImportError:
 
 __all__ = [
     "Catalog",
+    "CatalogList",
     "CatalogProvider",
+    "CatalogProviderList",
     "Schema",
     "SchemaProvider",
     "Table",
 ]
 
 
+class CatalogList:
+    """DataFusion data catalog list."""
+
+    def __init__(self, catalog_list: df_internal.catalog.RawCatalogList) -> 
None:
+        """This constructor is not typically called by the end user."""
+        self.catalog_list = catalog_list
+
+    def __repr__(self) -> str:
+        """Print a string representation of the catalog list."""
+        return self.catalog_list.__repr__()
+
+    def names(self) -> set[str]:
+        """This is an alias for `catalog_names`."""
+        return self.catalog_names()
+
+    def catalog_names(self) -> set[str]:
+        """Returns the list of schemas in this catalog."""
+        return self.catalog_list.catalog_names()
+
+    @staticmethod
+    def memory_catalog(ctx: SessionContext | None = None) -> CatalogList:
+        """Create an in-memory catalog provider list."""
+        catalog_list = df_internal.catalog.RawCatalogList.memory_catalog(ctx)
+        return CatalogList(catalog_list)
+
+    def catalog(self, name: str = "datafusion") -> Catalog:
+        """Returns the catalog with the given ``name`` from this catalog."""
+        catalog = self.catalog_list.catalog(name)
+
+        return (
+            Catalog(catalog)
+            if isinstance(catalog, df_internal.catalog.RawCatalog)
+            else catalog
+        )
+
+    def register_catalog(
+        self,
+        name: str,
+        catalog: Catalog | CatalogProvider | CatalogProviderExportable,
+    ) -> Catalog | None:
+        """Register a catalog with this catalog list."""
+        if isinstance(catalog, Catalog):
+            return self.catalog_list.register_catalog(name, catalog.catalog)
+        return self.catalog_list.register_catalog(name, catalog)
+
+
 class Catalog:
     """DataFusion data catalog."""
 
@@ -195,6 +243,40 @@ class Table:
         return self._inner.kind
 
 
+class CatalogProviderList(ABC):
+    """Abstract class for defining a Python based Catalog Provider List."""
+
+    @abstractmethod
+    def catalog_names(self) -> set[str]:
+        """Set of the names of all catalogs in this catalog list."""
+        ...
+
+    @abstractmethod
+    def catalog(
+        self, name: str
+    ) -> CatalogProviderExportable | CatalogProvider | Catalog | None:
+        """Retrieve a specific catalog from this catalog list."""
+        ...
+
+    def register_catalog(  # noqa: B027
+        self, name: str, catalog: CatalogProviderExportable | CatalogProvider 
| Catalog
+    ) -> None:
+        """Add a catalog to this catalog list.
+
+        This method is optional. If your catalog provides a fixed list of 
catalogs, you
+        do not need to implement this method.
+        """
+
+
+class CatalogProviderListExportable(Protocol):
+    """Type hint for object that has __datafusion_catalog_provider_list__ 
PyCapsule.
+
+    
https://docs.rs/datafusion/latest/datafusion/catalog/trait.CatalogProviderList.html
+    """
+
+    def __datafusion_catalog_provider_list__(self, session: Any) -> object: ...
+
+
 class CatalogProvider(ABC):
     """Abstract class for defining a Python based Catalog Provider."""
 
@@ -229,6 +311,15 @@ class CatalogProvider(ABC):
         """
 
 
+class CatalogProviderExportable(Protocol):
+    """Type hint for object that has __datafusion_catalog_provider__ PyCapsule.
+
+    
https://docs.rs/datafusion/latest/datafusion/catalog/trait.CatalogProvider.html
+    """
+
+    def __datafusion_catalog_provider__(self, session: Any) -> object: ...
+
+
 class SchemaProvider(ABC):
     """Abstract class for defining a Python based Schema Provider."""
 
diff --git a/python/datafusion/context.py b/python/datafusion/context.py
index 7b92c082..0d825977 100644
--- a/python/datafusion/context.py
+++ b/python/datafusion/context.py
@@ -31,7 +31,13 @@ except ImportError:
 
 import pyarrow as pa
 
-from datafusion.catalog import Catalog
+from datafusion.catalog import (
+    Catalog,
+    CatalogList,
+    CatalogProviderExportable,
+    CatalogProviderList,
+    CatalogProviderListExportable,
+)
 from datafusion.dataframe import DataFrame
 from datafusion.expr import sort_list_to_raw_sort_list
 from datafusion.options import (
@@ -96,15 +102,6 @@ class TableProviderExportable(Protocol):
     def __datafusion_table_provider__(self, session: Any) -> object: ...  # noqa: D105
 
 
-class CatalogProviderExportable(Protocol):
-    """Type hint for object that has __datafusion_catalog_provider__ PyCapsule.
-
-    https://docs.rs/datafusion/latest/datafusion/catalog/trait.CatalogProvider.html
-    """
-
-    def __datafusion_catalog_provider__(self, session: Any) -> object: ...  # noqa: D105
-
-
 class SessionConfig:
     """Session configuration options."""
 
@@ -837,6 +834,16 @@ class SessionContext:
         """Returns the list of catalogs in this context."""
         return self.ctx.catalog_names()
 
+    def register_catalog_provider_list(
+        self,
+        provider: CatalogProviderListExportable | CatalogProviderList | 
CatalogList,
+    ) -> None:
+        """Register a catalog provider list."""
+        if isinstance(provider, CatalogList):
+            self.ctx.register_catalog_provider_list(provider.catalog)
+        else:
+            self.ctx.register_catalog_provider_list(provider)
+
     def register_catalog_provider(
         self, name: str, provider: CatalogProviderExportable | CatalogProvider | Catalog
     ) -> None:
diff --git a/python/tests/test_catalog.py b/python/tests/test_catalog.py
index 08f494de..dd4c8246 100644
--- a/python/tests/test_catalog.py
+++ b/python/tests/test_catalog.py
@@ -16,11 +16,16 @@
 # under the License.
 from __future__ import annotations
 
+from typing import TYPE_CHECKING
+
 import datafusion as dfn
 import pyarrow as pa
 import pyarrow.dataset as ds
 import pytest
-from datafusion import SessionContext, Table, udtf
+from datafusion import Catalog, SessionContext, Table, udtf
+
+if TYPE_CHECKING:
+    from datafusion.catalog import CatalogProvider, CatalogProviderExportable
 
 
 # Note we take in `database` as a variable even though we don't use
@@ -93,6 +98,34 @@ class CustomCatalogProvider(dfn.catalog.CatalogProvider):
         del self.schemas[name]
 
 
+class CustomCatalogProviderList(dfn.catalog.CatalogProviderList):
+    def __init__(self):
+        self.catalogs = {"my_catalog": CustomCatalogProvider()}
+
+    def catalog_names(self) -> set[str]:
+        return set(self.catalogs.keys())
+
+    def catalog(self, name: str) -> Catalog | None:
+        return self.catalogs[name]
+
+    def register_catalog(
+        self, name: str, catalog: CatalogProviderExportable | CatalogProvider 
| Catalog
+    ) -> None:
+        self.catalogs[name] = catalog
+
+
+def test_python_catalog_provider_list(ctx: SessionContext):
+    ctx.register_catalog_provider_list(CustomCatalogProviderList())
+
+    # Ensure `datafusion` catalog does not exist since
+    # we replaced the catalog list
+    assert ctx.catalog_names() == {"my_catalog"}
+
+    # Ensure registering works
+    ctx.register_catalog_provider("second_catalog", Catalog.memory_catalog())
+    assert ctx.catalog_names() == {"my_catalog", "second_catalog"}
+
+
 def test_python_catalog_provider(ctx: SessionContext):
     ctx.register_catalog_provider("my_catalog", CustomCatalogProvider())
 
diff --git a/src/catalog.rs b/src/catalog.rs
index 10ca1dd1..b5b98397 100644
--- a/src/catalog.rs
+++ b/src/catalog.rs
@@ -21,10 +21,12 @@ use std::sync::Arc;
 
 use async_trait::async_trait;
 use datafusion::catalog::{
-    CatalogProvider, MemoryCatalogProvider, MemorySchemaProvider, SchemaProvider,
+    CatalogProvider, CatalogProviderList, MemoryCatalogProvider, MemoryCatalogProviderList,
+    MemorySchemaProvider, SchemaProvider,
 };
 use datafusion::common::DataFusionError;
 use datafusion::datasource::TableProvider;
+use datafusion_ffi::catalog_provider::FFI_CatalogProvider;
 use datafusion_ffi::proto::logical_extension_codec::FFI_LogicalExtensionCodec;
 use datafusion_ffi::schema_provider::FFI_SchemaProvider;
 use pyo3::exceptions::PyKeyError;
@@ -40,6 +42,18 @@ use crate::utils::{
     wait_for_future,
 };
 
+/// Python-facing handle (`RawCatalogList`) around a catalog provider list.
+#[pyclass(frozen, name = "RawCatalogList", module = "datafusion.catalog", subclass)]
+#[derive(Clone)]
+pub struct PyCatalogList {
+    /// The wrapped catalog provider list.
+    pub catalog_list: Arc<dyn CatalogProviderList>,
+    /// Codec used when catalogs are extracted from or returned to Python.
+    codec: Arc<FFI_LogicalExtensionCodec>,
+}
+
 #[pyclass(frozen, name = "RawCatalog", module = "datafusion.catalog", subclass)]
 #[derive(Clone)]
 pub struct PyCatalog {
@@ -72,6 +86,77 @@ impl PySchema {
     }
 }
 
+#[pymethods]
+impl PyCatalogList {
+    /// Wraps a Python catalog provider list object.
+    ///
+    /// `session` is forwarded to `extract_logical_extension_codec` to obtain the
+    /// codec that is stored next to the wrapped list.
+    #[new]
+    pub fn new(
+        py: Python,
+        catalog_list: Py<PyAny>,
+        session: Option<Bound<PyAny>>,
+    ) -> PyResult<Self> {
+        let codec = extract_logical_extension_codec(py, session)?;
+        let catalog_list = Arc::new(RustWrappedPyCatalogProviderList::new(
+            catalog_list,
+            codec.clone(),
+        )) as Arc<dyn CatalogProviderList>;
+        Ok(Self {
+            catalog_list,
+            codec,
+        })
+    }
+
+    /// Creates a catalog list backed by a default `MemoryCatalogProviderList`.
+    #[staticmethod]
+    pub fn memory_catalog_list(py: Python, session: Option<Bound<PyAny>>) -> PyResult<Self> {
+        let codec = extract_logical_extension_codec(py, session)?;
+        let catalog_list =
+            Arc::new(MemoryCatalogProviderList::default()) as Arc<dyn CatalogProviderList>;
+        Ok(Self {
+            catalog_list,
+            codec,
+        })
+    }
+
+    /// Names of the catalogs in the wrapped list, as a set.
+    pub fn catalog_names(&self) -> HashSet<String> {
+        self.catalog_list.catalog_names().into_iter().collect()
+    }
+
+    /// Returns the catalog registered under `name`.
+    ///
+    /// A catalog that wraps a Python object is returned as that object; any
+    /// other catalog is returned as a new `PyCatalog`.
+    ///
+    /// # Errors
+    ///
+    /// Raises `KeyError` when the list has no catalog called `name`.
+    // NOTE(review): the default here is "public" while the Python
+    // `CatalogList.catalog` wrapper defaults to "datafusion" - confirm which
+    // one is intended before changing it.
+    #[pyo3(signature = (name="public"))]
+    pub fn catalog(&self, name: &str) -> PyResult<Py<PyAny>> {
+        // `ok_or_else` builds the error message only on a failed lookup.
+        let catalog = self.catalog_list.catalog(name).ok_or_else(|| {
+            PyKeyError::new_err(format!("Catalog with name {name} doesn't exist."))
+        })?;
+
+        Python::attach(|py| {
+            match catalog
+                .as_any()
+                .downcast_ref::<RustWrappedPyCatalogProvider>()
+            {
+                Some(wrapped_catalog) => Ok(wrapped_catalog.catalog_provider.clone_ref(py)),
+                None => PyCatalog::new_from_parts(catalog, self.codec.clone()).into_py_any(py),
+            }
+        })
+    }
+
+    /// Converts `catalog_provider` with `extract_catalog_provider_from_pyobj` and
+    /// registers it under `name`.
+    pub fn register_catalog(&self, name: &str, catalog_provider: Bound<'_, PyAny>) -> PyResult<()> {
+        let provider = extract_catalog_provider_from_pyobj(catalog_provider, self.codec.as_ref())?;
+
+        // The value returned by the list's `register_catalog` is discarded.
+        let _ = self
+            .catalog_list
+            .register_catalog(name.to_owned(), provider);
+
+        Ok(())
+    }
+
+    /// Renders the catalog names, sorted, as `CatalogList(catalog_names=[...])`.
+    pub fn __repr__(&self) -> PyResult<String> {
+        let mut names: Vec<String> = self.catalog_names().into_iter().collect();
+        names.sort();
+        Ok(format!("CatalogList(catalog_names=[{}])", names.join(", ")))
+    }
+}
+
 #[pymethods]
 impl PyCatalog {
     #[new]
@@ -373,8 +458,9 @@ impl CatalogProvider for RustWrappedPyCatalogProvider {
         Python::attach(|py| {
             let provider = self.catalog_provider.bind(py);
             provider
-                .getattr("schema_names")
-                .and_then(|names| names.extract::<Vec<String>>())
+                .call_method0("schema_names")
+                .and_then(|names| names.extract::<HashSet<String>>())
+                .map(|names| names.into_iter().collect())
                 .unwrap_or_else(|err| {
                     log::error!("Unable to get schema_names: {err}");
                     Vec::default()
@@ -442,6 +528,138 @@ impl CatalogProvider for RustWrappedPyCatalogProvider {
     }
 }
 
+/// Adapter that exposes a Python catalog provider list object to Rust.
+#[derive(Debug)]
+pub(crate) struct RustWrappedPyCatalogProviderList {
+    /// The wrapped Python object.
+    pub(crate) catalog_provider_list: Py<PyAny>,
+    /// Codec passed along when catalogs cross the Python boundary.
+    codec: Arc<FFI_LogicalExtensionCodec>,
+}
+
+impl RustWrappedPyCatalogProviderList {
+    /// Stores the Python object together with the codec.
+    pub fn new(catalog_provider_list: Py<PyAny>, codec: Arc<FFI_LogicalExtensionCodec>) -> Self {
+        Self {
+            catalog_provider_list,
+            codec,
+        }
+    }
+
+    /// Calls the Python object's `catalog(name)` method.
+    ///
+    /// A Python `None` becomes `Ok(None)`; any other value is converted with
+    /// `extract_catalog_provider_from_pyobj`. A Python exception is returned as `Err`.
+    fn catalog_inner(&self, name: &str) -> PyResult<Option<Arc<dyn CatalogProvider>>> {
+        Python::attach(|py| {
+            let py_catalog = self
+                .catalog_provider_list
+                .bind(py)
+                .call_method1("catalog", (name,))?;
+
+            if py_catalog.is_none() {
+                Ok(None)
+            } else {
+                extract_catalog_provider_from_pyobj(py_catalog, self.codec.as_ref()).map(Some)
+            }
+        })
+    }
+}
+
+/// Forwards each `CatalogProviderList` call to the wrapped Python object.
+///
+/// The trait methods cannot return errors, so a failing Python call is logged
+/// and turned into an empty or `None` result.
+impl CatalogProviderList for RustWrappedPyCatalogProviderList {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    /// Calls Python `catalog_names()` and extracts the result as a set of strings.
+    ///
+    /// Returns an empty vector, after logging, if the call or extraction fails.
+    fn catalog_names(&self) -> Vec<String> {
+        Python::attach(|py| {
+            let provider = self.catalog_provider_list.bind(py);
+            provider
+                .call_method0("catalog_names")
+                .and_then(|names| names.extract::<HashSet<String>>())
+                .map(|names| names.into_iter().collect())
+                .unwrap_or_else(|err| {
+                    log::error!("Unable to get catalog_names: {err}");
+                    Vec::default()
+                })
+        })
+    }
+
+    /// Looks up `name` through `catalog_inner`; an error is logged and reported as `None`.
+    fn catalog(&self, name: &str) -> Option<Arc<dyn CatalogProvider>> {
+        self.catalog_inner(name).unwrap_or_else(|err| {
+            log::error!("CatalogProviderList catalog returned error: {err}");
+            None
+        })
+    }
+
+    /// Hands `catalog` to the Python object's `register_catalog(name, catalog)`.
+    ///
+    /// A catalog that already wraps a Python object is passed as that object;
+    /// any other catalog is passed as a new `PyCatalog`. Returns `None` when a
+    /// step fails (the error is logged) or when Python returns `None`;
+    /// otherwise the returned Python object is wrapped as a catalog provider.
+    fn register_catalog(
+        &self,
+        name: String,
+        catalog: Arc<dyn CatalogProvider>,
+    ) -> Option<Arc<dyn CatalogProvider>> {
+        Python::attach(|py| {
+            let py_catalog = match catalog
+                .as_any()
+                .downcast_ref::<RustWrappedPyCatalogProvider>()
+            {
+                Some(wrapped_catalog) => wrapped_catalog.catalog_provider.clone_ref(py),
+                None => {
+                    match PyCatalog::new_from_parts(catalog, self.codec.clone()).into_py_any(py) {
+                        Ok(c) => c,
+                        Err(err) => {
+                            log::error!(
+                                "register_catalog returned error during conversion to PyAny: {err}"
+                            );
+                            return None;
+                        }
+                    }
+                }
+            };
+
+            let provider = self.catalog_provider_list.bind(py);
+            let previous = match provider.call_method1("register_catalog", (name, py_catalog)) {
+                Ok(value) => value,
+                Err(err) => {
+                    log::error!("register_catalog returned error: {err}");
+                    return None;
+                }
+            };
+            if previous.is_none() {
+                return None;
+            }
+
+            let previous = Arc::new(RustWrappedPyCatalogProvider::new(
+                previous.into(),
+                self.codec.clone(),
+            )) as Arc<dyn CatalogProvider>;
+
+            Some(previous)
+        })
+    }
+}
+
+/// Converts a Python object into an `Arc<dyn CatalogProvider>`.
+///
+/// The object is tried as, in order:
+/// 1. something with a `__datafusion_catalog_provider__` method, which is called
+///    with a codec capsule and whose result replaces the object,
+/// 2. a `PyCapsule`, validated against the name `datafusion_catalog_provider`,
+/// 3. a `PyCatalog`, whose inner catalog is used directly,
+/// 4. anything else, which is wrapped in `RustWrappedPyCatalogProvider`.
+fn extract_catalog_provider_from_pyobj(
+    mut catalog_provider: Bound<PyAny>,
+    codec: &FFI_LogicalExtensionCodec,
+) -> PyResult<Arc<dyn CatalogProvider>> {
+    if catalog_provider.hasattr("__datafusion_catalog_provider__")? {
+        let codec_capsule = create_logical_extension_capsule(catalog_provider.py(), codec)?;
+        catalog_provider = catalog_provider
+            .getattr("__datafusion_catalog_provider__")?
+            .call1((codec_capsule,))?;
+    }
+
+    if let Ok(capsule) = catalog_provider.downcast::<PyCapsule>() {
+        validate_pycapsule(capsule, "datafusion_catalog_provider")?;
+
+        // NOTE(review): reading the capsule as `FFI_CatalogProvider` presumably relies
+        // on the name check above - confirm that is the full safety argument.
+        let ffi_provider = unsafe { capsule.reference::<FFI_CatalogProvider>() };
+        let shared: Arc<dyn CatalogProvider + Send> = ffi_provider.into();
+        return Ok(shared as Arc<dyn CatalogProvider>);
+    }
+
+    let provider = match catalog_provider.extract::<PyCatalog>() {
+        Ok(py_catalog) => py_catalog.catalog,
+        Err(_) => Arc::new(RustWrappedPyCatalogProvider::new(
+            catalog_provider.into(),
+            Arc::new(codec.clone()),
+        )) as Arc<dyn CatalogProvider>,
+    };
+
+    Ok(provider)
+}
+
 fn extract_schema_provider_from_pyobj(
     mut schema_provider: Bound<PyAny>,
     codec: &FFI_LogicalExtensionCodec,
diff --git a/src/context.rs b/src/context.rs
index f28c5982..89bbe934 100644
--- a/src/context.rs
+++ b/src/context.rs
@@ -26,7 +26,7 @@ use arrow::pyarrow::FromPyArrow;
 use datafusion::arrow::datatypes::{DataType, Schema, SchemaRef};
 use datafusion::arrow::pyarrow::PyArrowType;
 use datafusion::arrow::record_batch::RecordBatch;
-use datafusion::catalog::CatalogProvider;
+use datafusion::catalog::{CatalogProvider, CatalogProviderList};
 use datafusion::common::{exec_err, ScalarValue, TableReference};
 use datafusion::datasource::file_format::file_compression_type::FileCompressionType;
 use datafusion::datasource::file_format::parquet::ParquetFormat;
@@ -47,6 +47,7 @@ use datafusion::prelude::{
     AvroReadOptions, CsvReadOptions, DataFrame, NdJsonReadOptions, ParquetReadOptions,
 };
 use datafusion_ffi::catalog_provider::FFI_CatalogProvider;
+use datafusion_ffi::catalog_provider_list::FFI_CatalogProviderList;
 use datafusion_ffi::execution::FFI_TaskContextProvider;
 use datafusion_ffi::proto::logical_extension_codec::FFI_LogicalExtensionCodec;
 use datafusion_proto::logical_plan::DefaultLogicalExtensionCodec;
@@ -58,7 +59,9 @@ use pyo3::IntoPyObjectExt;
 use url::Url;
 use uuid::Uuid;
 
-use crate::catalog::{PyCatalog, RustWrappedPyCatalogProvider};
+use crate::catalog::{
+    PyCatalog, PyCatalogList, RustWrappedPyCatalogProvider, RustWrappedPyCatalogProviderList,
+};
 use crate::common::data_type::PyScalarValue;
 use crate::dataframe::PyDataFrame;
 use crate::dataset::Dataset;
@@ -627,6 +630,40 @@ impl PySessionContext {
         Ok(())
     }
 
+    /// Registers `provider` as the catalog list of the wrapped session context.
+    ///
+    /// The object is tried as, in order:
+    /// 1. something with a `__datafusion_catalog_provider_list__` method, which is
+    ///    called with a codec capsule and whose result replaces the object,
+    /// 2. a `PyCapsule`, validated against the name `datafusion_catalog_provider_list`,
+    /// 3. a `PyCatalogList`, whose inner list is used directly,
+    /// 4. anything else, which is wrapped in `RustWrappedPyCatalogProviderList`.
+    pub fn register_catalog_provider_list(
+        &self,
+        mut provider: Bound<PyAny>,
+    ) -> PyDataFusionResult<()> {
+        if provider.hasattr("__datafusion_catalog_provider_list__")? {
+            let codec_capsule =
+                create_logical_extension_capsule(provider.py(), self.logical_codec.as_ref())?;
+            provider = provider
+                .getattr("__datafusion_catalog_provider_list__")?
+                .call1((codec_capsule,))?;
+        }
+
+        let catalog_list = if let Ok(capsule) =
+            provider.downcast::<PyCapsule>().map_err(py_datafusion_err)
+        {
+            validate_pycapsule(capsule, "datafusion_catalog_provider_list")?;
+
+            // NOTE(review): reading the capsule as `FFI_CatalogProviderList` presumably
+            // relies on the name check above - confirm that is the full safety argument.
+            let ffi_list = unsafe { capsule.reference::<FFI_CatalogProviderList>() };
+            let shared: Arc<dyn CatalogProviderList + Send> = ffi_list.into();
+            shared as Arc<dyn CatalogProviderList>
+        } else if let Ok(py_catalog_list) = provider.extract::<PyCatalogList>() {
+            py_catalog_list.catalog_list
+        } else {
+            Arc::new(RustWrappedPyCatalogProviderList::new(
+                provider.into(),
+                Arc::clone(&self.logical_codec),
+            )) as Arc<dyn CatalogProviderList>
+        };
+
+        self.ctx.register_catalog_list(catalog_list);
+
+        Ok(())
+    }
+
     pub fn register_catalog_provider(
         &self,
         name: &str,


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to