This is an automated email from the ASF dual-hosted git repository.

lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git


The following commit(s) were added to refs/heads/main by this push:
     new d411cabd9 fix(rust/driver/datafusion): using datafusion driver in async runtime (#3712)
d411cabd9 is described below

commit d411cabd9e34a534e13a06ea4f533ff9a326869e
Author: Pavel Agafonov <[email protected]>
AuthorDate: Fri Nov 14 09:29:08 2025 +0300

    fix(rust/driver/datafusion): using datafusion driver in async runtime (#3712)
    
    Closes #3711
    
    This is not an ideal solution; the driver's sync/async ergonomics
    still need to be revisited in the future.
    
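    Known limitation: this approach cannot be used from inside a
    single-threaded (`current_thread`) Tokio runtime, because
    `tokio::task::block_in_place` panics when called there. When calling
    the driver from async code, use a multi-threaded runtime.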
    
    ---------
    
    Signed-off-by: if0ne <[email protected]>
    Signed-off-by: Pavel Agafonov <[email protected]>
---
 rust/driver/datafusion/src/lib.rs               | 59 +++++++++++++++++++------
 rust/driver/datafusion/tests/test_datafusion.rs | 26 ++++++++---
 2 files changed, 65 insertions(+), 20 deletions(-)

diff --git a/rust/driver/datafusion/src/lib.rs 
b/rust/driver/datafusion/src/lib.rs
index ef61f4d24..c08ba28c7 100644
--- a/rust/driver/datafusion/src/lib.rs
+++ b/rust/driver/datafusion/src/lib.rs
@@ -25,9 +25,9 @@ use 
datafusion_substrait::logical_plan::consumer::from_substrait_plan;
 use datafusion_substrait::substrait::proto::Plan;
 use prost::Message;
 use std::fmt::Debug;
+use std::future::Future;
 use std::sync::Arc;
 use std::vec::IntoIter;
-use tokio::runtime::Runtime;
 
 use arrow_array::builder::{
     BooleanBuilder, Int32Builder, Int64Builder, ListBuilder, MapBuilder, 
MapFieldNames,
@@ -48,6 +48,31 @@ use adbc_core::{
     schemas, Connection, Database, Driver, Optionable, Statement,
 };
 
+pub enum Runtime {
+    Handle(tokio::runtime::Handle),
+    Tokio(tokio::runtime::Runtime),
+}
+
+impl Runtime {
+    pub fn new(handle: Option<tokio::runtime::Handle>) -> 
std::io::Result<Self> {
+        if let Some(handle) = handle {
+            Ok(Self::Handle(handle))
+        } else {
+            let runtime = tokio::runtime::Builder::new_multi_thread()
+                .enable_all()
+                .build()?;
+            Ok(Self::Tokio(runtime))
+        }
+    }
+
+    pub fn block_on<F: Future>(&self, future: F) -> F::Output {
+        match self {
+            Runtime::Handle(handle) => tokio::task::block_in_place(|| 
handle.block_on(future)),
+            Runtime::Tokio(runtime) => runtime.block_on(future),
+        }
+    }
+}
+
 #[derive(Debug)]
 pub struct SingleBatchReader {
     batch: Option<RecordBatch>,
@@ -109,13 +134,23 @@ impl RecordBatchReader for DataFusionReader {
 }
 
 #[derive(Default)]
-pub struct DataFusionDriver {}
+pub struct DataFusionDriver {
+    handle: Option<tokio::runtime::Handle>,
+}
+
+impl DataFusionDriver {
+    pub fn new(handle: Option<tokio::runtime::Handle>) -> Self {
+        Self { handle }
+    }
+}
 
 impl Driver for DataFusionDriver {
     type DatabaseType = DataFusionDatabase;
 
     fn new_database(&mut self) -> Result<Self::DatabaseType> {
-        Ok(Self::DatabaseType {})
+        Ok(Self::DatabaseType {
+            handle: self.handle.clone(),
+        })
     }
 
     fn new_database_with_opts(
@@ -127,7 +162,9 @@ impl Driver for DataFusionDriver {
             ),
         >,
     ) -> adbc_core::error::Result<Self::DatabaseType> {
-        let mut database = Self::DatabaseType {};
+        let mut database = Self::DatabaseType {
+            handle: self.handle.clone(),
+        };
         for (key, value) in opts {
             database.set_option(key, value)?;
         }
@@ -135,7 +172,9 @@ impl Driver for DataFusionDriver {
     }
 }
 
-pub struct DataFusionDatabase {}
+pub struct DataFusionDatabase {
+    handle: Option<tokio::runtime::Handle>,
+}
 
 impl Optionable for DataFusionDatabase {
     type Option = OptionDatabase;
@@ -186,10 +225,7 @@ impl Database for DataFusionDatabase {
     fn new_connection(&self) -> Result<Self::ConnectionType> {
         let ctx = SessionContext::new();
 
-        let runtime = tokio::runtime::Builder::new_multi_thread()
-            .enable_all()
-            .build()
-            .unwrap();
+        let runtime = Runtime::new(self.handle.clone()).unwrap();
 
         Ok(DataFusionConnection {
             runtime: Arc::new(runtime),
@@ -208,10 +244,7 @@ impl Database for DataFusionDatabase {
     ) -> adbc_core::error::Result<Self::ConnectionType> {
         let ctx = SessionContext::new();
 
-        let runtime = tokio::runtime::Builder::new_multi_thread()
-            .enable_all()
-            .build()
-            .unwrap();
+        let runtime = Runtime::new(self.handle.clone()).unwrap();
 
         let mut connection = DataFusionConnection {
             runtime: Arc::new(runtime),
diff --git a/rust/driver/datafusion/tests/test_datafusion.rs 
b/rust/driver/datafusion/tests/test_datafusion.rs
index 1ba496c9a..9123a480a 100644
--- a/rust/driver/datafusion/tests/test_datafusion.rs
+++ b/rust/driver/datafusion/tests/test_datafusion.rs
@@ -26,8 +26,8 @@ use 
datafusion_substrait::logical_plan::producer::to_substrait_plan;
 use datafusion_substrait::substrait::proto::Plan;
 use prost::Message;
 
-fn get_connection() -> DataFusionConnection {
-    let mut driver = DataFusionDriver::default();
+fn get_connection(handle: Option<tokio::runtime::Handle>) -> 
DataFusionConnection {
+    let mut driver = DataFusionDriver::new(handle);
     let database = driver.new_database().unwrap();
     database.new_connection().unwrap()
 }
@@ -80,7 +80,7 @@ fn execute_substrait(connection: &mut DataFusionConnection, 
plan: Plan) -> Recor
 
 #[test]
 fn test_connection_options() {
-    let mut connection = get_connection();
+    let mut connection = get_connection(None);
 
     let current_catalog = connection
         .get_option_string(OptionConnection::CurrentCatalog)
@@ -119,7 +119,7 @@ fn test_connection_options() {
 
 #[test]
 fn test_get_objects_database() {
-    let mut connection = get_connection();
+    let mut connection = get_connection(None);
 
     let objects = get_objects(&connection);
 
@@ -134,7 +134,7 @@ fn test_get_objects_database() {
 
 #[test]
 fn test_execute_sql() {
-    let mut connection = get_connection();
+    let mut connection = get_connection(None);
 
     execute_update(&mut connection, "CREATE TABLE IF NOT EXISTS 
datafusion.public.example (c1 INT, c2 VARCHAR) AS 
VALUES(1,'HELLO'),(2,'DATAFUSION'),(3,'!')");
 
@@ -146,7 +146,7 @@ fn test_execute_sql() {
 
 #[test]
 fn test_ingest() {
-    let mut connection = get_connection();
+    let mut connection = get_connection(None);
 
     execute_update(&mut connection, "CREATE TABLE IF NOT EXISTS 
datafusion.public.example (c1 INT, c2 VARCHAR) AS 
VALUES(1,'HELLO'),(2,'DATAFUSION'),(3,'!')");
 
@@ -172,7 +172,7 @@ fn test_ingest() {
 
 #[test]
 fn test_execute_substrait() {
-    let mut connection = get_connection();
+    let mut connection = get_connection(None);
 
     execute_update(&mut connection, "CREATE TABLE IF NOT EXISTS 
datafusion.public.example (c1 INT, c2 VARCHAR) AS 
VALUES(1,'HELLO'),(2,'DATAFUSION'),(3,'!')");
 
@@ -198,3 +198,15 @@ fn test_execute_substrait() {
     assert_eq!(batch.num_rows(), 3);
     assert_eq!(batch.num_columns(), 2);
 }
+
+#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
+async fn test_running_in_async() {
+    let mut connection = 
get_connection(Some(tokio::runtime::Handle::current()));
+
+    execute_update(&mut connection, "CREATE TABLE IF NOT EXISTS 
datafusion.public.example (c1 INT, c2 VARCHAR) AS 
VALUES(1,'HELLO'),(2,'DATAFUSION'),(3,'!')");
+
+    let batch = execute_sql_query(&mut connection, "SELECT * FROM 
datafusion.public.example");
+
+    assert_eq!(batch.num_rows(), 3);
+    assert_eq!(batch.num_columns(), 2);
+}

Reply via email to