This is an automated email from the ASF dual-hosted git repository.

timsaucer pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-python.git


The following commit(s) were added to refs/heads/main by this push:
     new 63b13da  chore: set validation and typehint (#983)
63b13da is described below

commit 63b13da4bccd66cb474186ebc2c4a1f8ba82230f
Author: Ion Koutsouris <[email protected]>
AuthorDate: Tue Jan 7 14:34:44 2025 +0100

    chore: set validation and typehint (#983)
---
 python/datafusion/context.py | 13 ++++++++++++-
 src/context.rs               |  4 ++--
 src/dataframe.rs             | 21 +--------------------
 src/utils.rs                 | 21 +++++++++++++++++++++
 4 files changed, 36 insertions(+), 23 deletions(-)

diff --git a/python/datafusion/context.py b/python/datafusion/context.py
index a07b5d1..3fa1333 100644
--- a/python/datafusion/context.py
+++ b/python/datafusion/context.py
@@ -63,6 +63,15 @@ class ArrowArrayExportable(Protocol):
     ) -> tuple[object, object]: ...
 
 
+class TableProviderExportable(Protocol):
+    """Type hint for objects exposing a ``__datafusion_table_provider__`` PyCapsule.
+
+    https://datafusion.apache.org/python/user-guide/io/table_provider.html
+    """
+
+    def __datafusion_table_provider__(self) -> object: ...  # noqa: D105
+
+
 class SessionConfig:
     """Session configuration options."""
 
@@ -685,7 +694,9 @@ class SessionContext:
         """Remove a table from the session."""
         self.ctx.deregister_table(name)
 
-    def register_table_provider(self, name: str, provider: Any) -> None:
+    def register_table_provider(
+        self, name: str, provider: TableProviderExportable
+    ) -> None:
         """Register a table provider.
 
         This table provider must have a method called ``__datafusion_table_provider__``
diff --git a/src/context.rs b/src/context.rs
index 8675e97..0512285 100644
--- a/src/context.rs
+++ b/src/context.rs
@@ -43,7 +43,7 @@ use crate::store::StorageContexts;
 use crate::udaf::PyAggregateUDF;
 use crate::udf::PyScalarUDF;
 use crate::udwf::PyWindowUDF;
-use crate::utils::{get_tokio_runtime, wait_for_future};
+use crate::utils::{get_tokio_runtime, validate_pycapsule, wait_for_future};
 use datafusion::arrow::datatypes::{DataType, Schema, SchemaRef};
 use datafusion::arrow::pyarrow::PyArrowType;
 use datafusion::arrow::record_batch::RecordBatch;
@@ -576,7 +576,7 @@ impl PySessionContext {
         if provider.hasattr("__datafusion_table_provider__")? {
             let capsule = provider.getattr("__datafusion_table_provider__")?.call0()?;
             let capsule = capsule.downcast::<PyCapsule>()?;
-            // validate_pycapsule(capsule, "arrow_array_stream")?;
+            validate_pycapsule(capsule, "datafusion_table_provider")?;
 
             let provider = unsafe { capsule.reference::<FFI_TableProvider>() };
             let provider: ForeignTableProvider = provider.into();
diff --git a/src/dataframe.rs b/src/dataframe.rs
index e7d6ca6..fcb46a7 100644
--- a/src/dataframe.rs
+++ b/src/dataframe.rs
@@ -44,7 +44,7 @@ use crate::expr::sort_expr::to_sort_expressions;
 use crate::physical_plan::PyExecutionPlan;
 use crate::record_batch::PyRecordBatchStream;
 use crate::sql::logical::PyLogicalPlan;
-use crate::utils::{get_tokio_runtime, wait_for_future};
+use crate::utils::{get_tokio_runtime, validate_pycapsule, wait_for_future};
 use crate::{
     errors::DataFusionError,
     expr::{sort_expr::PySortExpr, PyExpr},
@@ -724,22 +724,3 @@ fn record_batch_into_schema(
 
     RecordBatch::try_new(schema, data_arrays)
 }
-
-fn validate_pycapsule(capsule: &Bound<PyCapsule>, name: &str) -> PyResult<()> {
-    let capsule_name = capsule.name()?;
-    if capsule_name.is_none() {
-        return Err(PyValueError::new_err(
-            "Expected schema PyCapsule to have name set.",
-        ));
-    }
-
-    let capsule_name = capsule_name.unwrap().to_str()?;
-    if capsule_name != name {
-        return Err(PyValueError::new_err(format!(
-            "Expected name '{}' in PyCapsule, instead got '{}'",
-            name, capsule_name
-        )));
-    }
-
-    Ok(())
-}
diff --git a/src/utils.rs b/src/utils.rs
index 7fb23ca..7955897 100644
--- a/src/utils.rs
+++ b/src/utils.rs
@@ -18,7 +18,9 @@
 use crate::errors::DataFusionError;
 use crate::TokioRuntime;
 use datafusion::logical_expr::Volatility;
+use pyo3::exceptions::PyValueError;
 use pyo3::prelude::*;
+use pyo3::types::PyCapsule;
 use std::future::Future;
 use std::sync::OnceLock;
 use tokio::runtime::Runtime;
@@ -58,3 +60,22 @@ pub(crate) fn parse_volatility(value: &str) -> Result<Volatility, DataFusionErro
         }
     })
 }
+
+/// Checks that `capsule` is named exactly `name`, returning a Python `ValueError`
+/// if the capsule is unnamed or carries a different name.
+pub(crate) fn validate_pycapsule(capsule: &Bound<PyCapsule>, name: &str) -> PyResult<()> {
+    let Some(capsule_name) = capsule.name()? else {
+        // Report the expected name: this helper is shared by several capsule kinds.
+        return Err(PyValueError::new_err(format!(
+            "Expected {name} PyCapsule to have name set."
+        )));
+    };
+    let capsule_name = capsule_name.to_str()?;
+    if capsule_name != name {
+        return Err(PyValueError::new_err(format!(
+            "Expected name '{name}' in PyCapsule, instead got '{capsule_name}'"
+        )));
+    }
+
+    Ok(())
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to