This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion-python.git
The following commit(s) were added to refs/heads/main by this push:
new 545e93e Store the Tokio Runtime in an _internal module level
attribute and reuse it. (#341)
545e93e is described below
commit 545e93e1739a71ef77cee1afe9196c318aabb667
Author: Kyle Brooks <[email protected]>
AuthorDate: Mon Apr 24 19:48:04 2023 -0400
Store the Tokio Runtime in an _internal module level attribute and reuse
it. (#341)
---
src/context.rs | 5 ++---
src/lib.rs | 9 +++++++++
src/utils.rs | 11 +++++++++--
3 files changed, 20 insertions(+), 5 deletions(-)
diff --git a/src/context.rs b/src/context.rs
index 0ba1200..3dc8a8f 100644
--- a/src/context.rs
+++ b/src/context.rs
@@ -36,7 +36,7 @@ use crate::sql::logical::PyLogicalPlan;
use crate::store::StorageContexts;
use crate::udaf::PyAggregateUDF;
use crate::udf::PyScalarUDF;
-use crate::utils::wait_for_future;
+use crate::utils::{get_tokio_runtime, wait_for_future};
use datafusion::arrow::datatypes::{DataType, Schema};
use datafusion::arrow::pyarrow::PyArrowType;
use datafusion::arrow::record_batch::RecordBatch;
@@ -52,7 +52,6 @@ use datafusion::prelude::{
};
use datafusion_common::ScalarValue;
use pyo3::types::PyTuple;
-use tokio::runtime::Runtime;
use tokio::task::JoinHandle;
/// Configuration options for a SessionContext
@@ -722,7 +721,7 @@ impl PySessionContext {
Arc::new(RuntimeEnv::default()),
);
// create a Tokio runtime to run the async code
- let rt = Runtime::new().unwrap();
+ let rt = &get_tokio_runtime(py).0;
let plan = plan.plan.clone();
let fut:
JoinHandle<datafusion_common::Result<SendableRecordBatchStream>> =
rt.spawn(async move { plan.execute(part, Arc::new(ctx)) });
diff --git a/src/lib.rs b/src/lib.rs
index 4a6574c..0bb4d9a 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -59,12 +59,21 @@ pub mod utils;
#[global_allocator]
static GLOBAL: MiMalloc = MiMalloc;
+// Used to define Tokio Runtime as a Python module attribute
+#[pyclass]
+pub(crate) struct TokioRuntime(tokio::runtime::Runtime);
+
/// Low-level DataFusion internal package.
///
/// The higher-level public API is defined in pure python files under the
/// datafusion directory.
#[pymodule]
fn _internal(py: Python, m: &PyModule) -> PyResult<()> {
+ // Register the Tokio Runtime as a module attribute so we can reuse it
+ m.add(
+ "runtime",
+ TokioRuntime(tokio::runtime::Runtime::new().unwrap()),
+ )?;
// Register the python classes
m.add_class::<catalog::PyCatalog>()?;
m.add_class::<catalog::PyDatabase>()?;
diff --git a/src/utils.rs b/src/utils.rs
index 4158b74..427a8a0 100644
--- a/src/utils.rs
+++ b/src/utils.rs
@@ -16,19 +16,26 @@
// under the License.
use crate::errors::DataFusionError;
+use crate::TokioRuntime;
use datafusion_expr::Volatility;
use pyo3::prelude::*;
use std::future::Future;
use tokio::runtime::Runtime;
+/// Utility to get the Tokio Runtime from Python
+pub(crate) fn get_tokio_runtime(py: Python) -> PyRef<TokioRuntime> {
+ let datafusion = py.import("datafusion._internal").unwrap();
+ datafusion.getattr("runtime").unwrap().extract().unwrap()
+}
+
/// Utility to collect rust futures with GIL released
pub fn wait_for_future<F: Future>(py: Python, f: F) -> F::Output
where
F: Send,
F::Output: Send,
{
- let rt = Runtime::new().unwrap();
- py.allow_threads(|| rt.block_on(f))
+ let runtime: &Runtime = &get_tokio_runtime(py).0;
+ py.allow_threads(|| runtime.block_on(f))
}
pub(crate) fn parse_volatility(value: &str) -> Result<Volatility,
DataFusionError> {