This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion-python.git


The following commit(s) were added to refs/heads/main by this push:
     new 545e93e  Store the Tokio Runtime in an _internal module level attribute and reuse it. (#341)
545e93e is described below

commit 545e93e1739a71ef77cee1afe9196c318aabb667
Author: Kyle Brooks <[email protected]>
AuthorDate: Mon Apr 24 19:48:04 2023 -0400

    Store the Tokio Runtime in an _internal module level attribute and reuse it. (#341)
---
 src/context.rs |  5 ++---
 src/lib.rs     |  9 +++++++++
 src/utils.rs   | 11 +++++++++--
 3 files changed, 20 insertions(+), 5 deletions(-)

diff --git a/src/context.rs b/src/context.rs
index 0ba1200..3dc8a8f 100644
--- a/src/context.rs
+++ b/src/context.rs
@@ -36,7 +36,7 @@ use crate::sql::logical::PyLogicalPlan;
 use crate::store::StorageContexts;
 use crate::udaf::PyAggregateUDF;
 use crate::udf::PyScalarUDF;
-use crate::utils::wait_for_future;
+use crate::utils::{get_tokio_runtime, wait_for_future};
 use datafusion::arrow::datatypes::{DataType, Schema};
 use datafusion::arrow::pyarrow::PyArrowType;
 use datafusion::arrow::record_batch::RecordBatch;
@@ -52,7 +52,6 @@ use datafusion::prelude::{
 };
 use datafusion_common::ScalarValue;
 use pyo3::types::PyTuple;
-use tokio::runtime::Runtime;
 use tokio::task::JoinHandle;
 
 /// Configuration options for a SessionContext
@@ -722,7 +721,7 @@ impl PySessionContext {
             Arc::new(RuntimeEnv::default()),
         );
         // create a Tokio runtime to run the async code
-        let rt = Runtime::new().unwrap();
+        let rt = &get_tokio_runtime(py).0;
         let plan = plan.plan.clone();
         let fut: JoinHandle<datafusion_common::Result<SendableRecordBatchStream>> =
             rt.spawn(async move { plan.execute(part, Arc::new(ctx)) });
diff --git a/src/lib.rs b/src/lib.rs
index 4a6574c..0bb4d9a 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -59,12 +59,21 @@ pub mod utils;
 #[global_allocator]
 static GLOBAL: MiMalloc = MiMalloc;
 
+// Used to define Tokio Runtime as a Python module attribute
+#[pyclass]
+pub(crate) struct TokioRuntime(tokio::runtime::Runtime);
+
 /// Low-level DataFusion internal package.
 ///
 /// The higher-level public API is defined in pure python files under the
 /// datafusion directory.
 #[pymodule]
 fn _internal(py: Python, m: &PyModule) -> PyResult<()> {
+    // Register the Tokio Runtime as a module attribute so we can reuse it
+    m.add(
+        "runtime",
+        TokioRuntime(tokio::runtime::Runtime::new().unwrap()),
+    )?;
     // Register the python classes
     m.add_class::<catalog::PyCatalog>()?;
     m.add_class::<catalog::PyDatabase>()?;
diff --git a/src/utils.rs b/src/utils.rs
index 4158b74..427a8a0 100644
--- a/src/utils.rs
+++ b/src/utils.rs
@@ -16,19 +16,26 @@
 // under the License.
 
 use crate::errors::DataFusionError;
+use crate::TokioRuntime;
 use datafusion_expr::Volatility;
 use pyo3::prelude::*;
 use std::future::Future;
 use tokio::runtime::Runtime;
 
+/// Utility to get the Tokio Runtime from Python
+pub(crate) fn get_tokio_runtime(py: Python) -> PyRef<TokioRuntime> {
+    let datafusion = py.import("datafusion._internal").unwrap();
+    datafusion.getattr("runtime").unwrap().extract().unwrap()
+}
+
 /// Utility to collect rust futures with GIL released
 pub fn wait_for_future<F: Future>(py: Python, f: F) -> F::Output
 where
     F: Send,
     F::Output: Send,
 {
-    let rt = Runtime::new().unwrap();
-    py.allow_threads(|| rt.block_on(f))
+    let runtime: &Runtime = &get_tokio_runtime(py).0;
+    py.allow_threads(|| runtime.block_on(f))
 }
 
 pub(crate) fn parse_volatility(value: &str) -> Result<Volatility, DataFusionError> {

Reply via email to