This is an automated email from the ASF dual-hosted git repository.
timsaucer pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-python.git
The following commit(s) were added to refs/heads/main by this push:
new 3dcf7c7e feat/making global context accessible for users (#1060)
3dcf7c7e is described below
commit 3dcf7c7e5c0af0eb3c5e3bdf9c6e33fd4541b070
Author: jsai28 <[email protected]>
AuthorDate: Thu Mar 13 04:09:03 2025 -0600
feat/making global context accessible for users (#1060)
* Rename _global_ctx to global_ctx
* Add global context to python wrapper code
* Update context.py
* singleton for global context
* formatting
* remove udf from import
* remove _global_instance
* formatting
* formatting
* unnecessary test
* fix test_io.py
* ran ruff
* ran ruff format
---
python/datafusion/context.py | 12 +++++++++
python/datafusion/io.py | 63 +++++++++++++++++++-------------------------
python/tests/test_context.py | 18 +++++++++++++
src/context.rs | 2 +-
4 files changed, 58 insertions(+), 37 deletions(-)
diff --git a/python/datafusion/context.py b/python/datafusion/context.py
index 0ab1a908..58ad9a94 100644
--- a/python/datafusion/context.py
+++ b/python/datafusion/context.py
@@ -496,6 +496,18 @@ class SessionContext:
self.ctx = SessionContextInternal(config, runtime)
+ @classmethod
+ def global_ctx(cls) -> SessionContext:
+ """Retrieve the global context as a `SessionContext` wrapper.
+
+ Returns:
+ A `SessionContext` object that wraps the global `SessionContextInternal`.
+ """
+ internal_ctx = SessionContextInternal.global_ctx()
+ wrapper = cls()
+ wrapper.ctx = internal_ctx
+ return wrapper
+
def enable_url_table(self) -> SessionContext:
"""Control if local files can be queried as tables.
diff --git a/python/datafusion/io.py b/python/datafusion/io.py
index 3e39703e..ef5ebf96 100644
--- a/python/datafusion/io.py
+++ b/python/datafusion/io.py
@@ -21,10 +21,9 @@ from __future__ import annotations
from typing import TYPE_CHECKING
+from datafusion.context import SessionContext
from datafusion.dataframe import DataFrame
-from ._internal import SessionContext as SessionContextInternal
-
if TYPE_CHECKING:
import pathlib
@@ -68,16 +67,14 @@ def read_parquet(
"""
if table_partition_cols is None:
table_partition_cols = []
- return DataFrame(
- SessionContextInternal._global_ctx().read_parquet(
- str(path),
- table_partition_cols,
- parquet_pruning,
- file_extension,
- skip_metadata,
- schema,
- file_sort_order,
- )
+ return SessionContext.global_ctx().read_parquet(
+ str(path),
+ table_partition_cols,
+ parquet_pruning,
+ file_extension,
+ skip_metadata,
+ schema,
+ file_sort_order,
)
@@ -110,15 +107,13 @@ def read_json(
"""
if table_partition_cols is None:
table_partition_cols = []
- return DataFrame(
- SessionContextInternal._global_ctx().read_json(
- str(path),
- schema,
- schema_infer_max_records,
- file_extension,
- table_partition_cols,
- file_compression_type,
- )
+ return SessionContext.global_ctx().read_json(
+ str(path),
+ schema,
+ schema_infer_max_records,
+ file_extension,
+ table_partition_cols,
+ file_compression_type,
)
@@ -161,17 +156,15 @@ def read_csv(
path = [str(p) for p in path] if isinstance(path, list) else str(path)
- return DataFrame(
- SessionContextInternal._global_ctx().read_csv(
- path,
- schema,
- has_header,
- delimiter,
- schema_infer_max_records,
- file_extension,
- table_partition_cols,
- file_compression_type,
- )
+ return SessionContext.global_ctx().read_csv(
+ path,
+ schema,
+ has_header,
+ delimiter,
+ schema_infer_max_records,
+ file_extension,
+ table_partition_cols,
+ file_compression_type,
)
@@ -198,8 +191,6 @@ def read_avro(
"""
if file_partition_cols is None:
file_partition_cols = []
- return DataFrame(
- SessionContextInternal._global_ctx().read_avro(
- str(path), schema, file_partition_cols, file_extension
- )
+ return SessionContext.global_ctx().read_avro(
+ str(path), schema, file_partition_cols, file_extension
)
diff --git a/python/tests/test_context.py b/python/tests/test_context.py
index 7a0a7aa0..4a15ac9c 100644
--- a/python/tests/test_context.py
+++ b/python/tests/test_context.py
@@ -632,3 +632,21 @@ def test_sql_with_options_no_statements(ctx):
options = SQLOptions().with_allow_statements(allow=False)
with pytest.raises(Exception, match="SetVariable"):
ctx.sql_with_options(sql, options=options)
+
+
[email protected]
+def batch():
+ return pa.RecordBatch.from_arrays(
+ [pa.array([4, 5, 6])],
+ names=["a"],
+ )
+
+
+def test_create_dataframe_with_global_ctx(batch):
+ ctx = SessionContext.global_ctx()
+
+ df = ctx.create_dataframe([[batch]])
+
+ result = df.collect()[0].column(0)
+
+ assert result == pa.array([4, 5, 6])
diff --git a/src/context.rs b/src/context.rs
index 9ba87eb8..0db0f4d7 100644
--- a/src/context.rs
+++ b/src/context.rs
@@ -308,7 +308,7 @@ impl PySessionContext {
#[classmethod]
#[pyo3(signature = ())]
- fn _global_ctx(_cls: &Bound<'_, PyType>) -> PyResult<Self> {
+ fn global_ctx(_cls: &Bound<'_, PyType>) -> PyResult<Self> {
Ok(Self {
ctx: get_global_ctx().clone(),
})
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]