This is an automated email from the ASF dual-hosted git repository.
timsaucer pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-python.git
The following commit(s) were added to refs/heads/main by this push:
new 63b13da chore: set validation and typehint (#983)
63b13da is described below
commit 63b13da4bccd66cb474186ebc2c4a1f8ba82230f
Author: Ion Koutsouris <[email protected]>
AuthorDate: Tue Jan 7 14:34:44 2025 +0100
chore: set validation and typehint (#983)
---
python/datafusion/context.py | 13 ++++++++++++-
src/context.rs | 4 ++--
src/dataframe.rs | 21 +--------------------
src/utils.rs | 21 +++++++++++++++++++++
4 files changed, 36 insertions(+), 23 deletions(-)
diff --git a/python/datafusion/context.py b/python/datafusion/context.py
index a07b5d1..3fa1333 100644
--- a/python/datafusion/context.py
+++ b/python/datafusion/context.py
@@ -63,6 +63,15 @@ class ArrowArrayExportable(Protocol):
) -> tuple[object, object]: ...
+class TableProviderExportable(Protocol):
+ """Type hint for an object exposing a ``__datafusion_table_provider__`` PyCapsule.
+
+ The method is called with no arguments and must return a PyCapsule named
+ ``datafusion_table_provider`` that wraps an FFI table provider. See
+ https://datafusion.apache.org/python/user-guide/io/table_provider.html
+ """
+
+ def __datafusion_table_provider__(self) -> object: ... # noqa: D105
+
+
class SessionConfig:
"""Session configuration options."""
@@ -685,7 +694,9 @@ class SessionContext:
"""Remove a table from the session."""
self.ctx.deregister_table(name)
- def register_table_provider(self, name: str, provider: Any) -> None:
+ def register_table_provider(
+ self, name: str, provider: TableProviderExportable
+ ) -> None:
"""Register a table provider.
This table provider must have a method called
``__datafusion_table_provider__``
diff --git a/src/context.rs b/src/context.rs
index 8675e97..0512285 100644
--- a/src/context.rs
+++ b/src/context.rs
@@ -43,7 +43,7 @@ use crate::store::StorageContexts;
use crate::udaf::PyAggregateUDF;
use crate::udf::PyScalarUDF;
use crate::udwf::PyWindowUDF;
-use crate::utils::{get_tokio_runtime, wait_for_future};
+use crate::utils::{get_tokio_runtime, validate_pycapsule, wait_for_future};
use datafusion::arrow::datatypes::{DataType, Schema, SchemaRef};
use datafusion::arrow::pyarrow::PyArrowType;
use datafusion::arrow::record_batch::RecordBatch;
@@ -576,7 +576,7 @@ impl PySessionContext {
if provider.hasattr("__datafusion_table_provider__")? {
let capsule =
provider.getattr("__datafusion_table_provider__")?.call0()?;
let capsule = capsule.downcast::<PyCapsule>()?;
- // validate_pycapsule(capsule, "arrow_array_stream")?;
+ validate_pycapsule(capsule, "datafusion_table_provider")?;
let provider = unsafe { capsule.reference::<FFI_TableProvider>() };
let provider: ForeignTableProvider = provider.into();
diff --git a/src/dataframe.rs b/src/dataframe.rs
index e7d6ca6..fcb46a7 100644
--- a/src/dataframe.rs
+++ b/src/dataframe.rs
@@ -44,7 +44,7 @@ use crate::expr::sort_expr::to_sort_expressions;
use crate::physical_plan::PyExecutionPlan;
use crate::record_batch::PyRecordBatchStream;
use crate::sql::logical::PyLogicalPlan;
-use crate::utils::{get_tokio_runtime, wait_for_future};
+use crate::utils::{get_tokio_runtime, validate_pycapsule, wait_for_future};
use crate::{
errors::DataFusionError,
expr::{sort_expr::PySortExpr, PyExpr},
@@ -724,22 +724,3 @@ fn record_batch_into_schema(
RecordBatch::try_new(schema, data_arrays)
}
-
-fn validate_pycapsule(capsule: &Bound<PyCapsule>, name: &str) -> PyResult<()> {
- let capsule_name = capsule.name()?;
- if capsule_name.is_none() {
- return Err(PyValueError::new_err(
- "Expected schema PyCapsule to have name set.",
- ));
- }
-
- let capsule_name = capsule_name.unwrap().to_str()?;
- if capsule_name != name {
- return Err(PyValueError::new_err(format!(
- "Expected name '{}' in PyCapsule, instead got '{}'",
- name, capsule_name
- )));
- }
-
- Ok(())
-}
diff --git a/src/utils.rs b/src/utils.rs
index 7fb23ca..7955897 100644
--- a/src/utils.rs
+++ b/src/utils.rs
@@ -18,7 +18,9 @@
use crate::errors::DataFusionError;
use crate::TokioRuntime;
use datafusion::logical_expr::Volatility;
+use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
+use pyo3::types::PyCapsule;
use std::future::Future;
use std::sync::OnceLock;
use tokio::runtime::Runtime;
@@ -58,3 +60,22 @@ pub(crate) fn parse_volatility(value: &str) -> Result<Volatility, DataFusionErro
}
})
}
+
+/// Check that `capsule` carries exactly the PyCapsule name `name`.
+///
+/// # Errors
+/// Returns a `PyValueError` if the capsule has no name or a different one;
+/// a name that is not valid UTF-8 is propagated as an error via `?`.
+pub(crate) fn validate_pycapsule(capsule: &Bound<PyCapsule>, name: &str) -> PyResult<()> {
+    let Some(capsule_name) = capsule.name()? else {
+        // Name the expected capsule instead of hard-coding "schema": this helper
+        // now also validates `datafusion_table_provider` capsules.
+        return Err(PyValueError::new_err(format!(
+            "Expected {name} PyCapsule to have name set."
+        )));
+    };
+
+    let capsule_name = capsule_name.to_str()?;
+    if capsule_name != name {
+        return Err(PyValueError::new_err(format!(
+            "Expected name '{name}' in PyCapsule, instead got '{capsule_name}'"
+        )));
+    }
+
+    Ok(())
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]