This is an automated email from the ASF dual-hosted git repository.

lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git


The following commit(s) were added to refs/heads/main by this push:
     new d07b9f0a9 feat(rust/drivers/datafusion): add support for bulk ingest (#2279)
d07b9f0a9 is described below

commit d07b9f0a9934af3da2b988825bc7523619fee581
Author: Tornike Gurgenidze <[email protected]>
AuthorDate: Tue Nov 5 05:37:40 2024 +0400

    feat(rust/drivers/datafusion): add support for bulk ingest (#2279)
    
    - Replaces the `todo!()`s in the `Optionable` impls with partial option
    support: no options for Database (every key is rejected as unrecognized);
    CurrentCatalog and CurrentSchema for Connection; and the ingest target
    table (`TargetTable`) for Statement.
    - Adds support for binding a single record batch and executing a bulk
    insert of it. The bound batch is stored on the statement as an
    `Option<RecordBatch>` so that `bind` can hand it over to the
    `execute_update` method.
---
 rust/drivers/datafusion/src/lib.rs               | 245 ++++++++++++++++++-----
 rust/drivers/datafusion/tests/test_datafusion.rs |  69 ++++++-
 2 files changed, 257 insertions(+), 57 deletions(-)

diff --git a/rust/drivers/datafusion/src/lib.rs 
b/rust/drivers/datafusion/src/lib.rs
index 5141d856c..55adcc571 100644
--- a/rust/drivers/datafusion/src/lib.rs
+++ b/rust/drivers/datafusion/src/lib.rs
@@ -17,14 +17,16 @@
 
 #![allow(refining_impl_trait)]
 
+use adbc_core::ffi::constants;
+use datafusion::dataframe::DataFrameWriteOptions;
 use datafusion::datasource::TableType;
 use datafusion::prelude::*;
 use datafusion_substrait::logical_plan::consumer::from_substrait_plan;
 use datafusion_substrait::substrait::proto::Plan;
 use prost::Message;
+use std::fmt::Debug;
 use std::sync::Arc;
 use std::vec::IntoIter;
-use std::{collections::HashMap, fmt::Debug};
 use tokio::runtime::Runtime;
 
 use arrow_array::builder::{
@@ -113,9 +115,7 @@ impl Driver for DataFusionDriver {
     type DatabaseType = DataFusionDatabase;
 
     fn new_database(&mut self) -> Result<Self::DatabaseType> {
-        Ok(Self::DatabaseType {
-            options: HashMap::new(),
-        })
+        Ok(Self::DatabaseType {})
     }
 
     fn new_database_with_opts(
@@ -127,9 +127,7 @@ impl Driver for DataFusionDriver {
             ),
         >,
     ) -> adbc_core::error::Result<Self::DatabaseType> {
-        let mut database = Self::DatabaseType {
-            options: HashMap::new(),
-        };
+        let mut database = Self::DatabaseType {};
         for (key, value) in opts {
             database.set_option(key, value)?;
         }
@@ -137,36 +135,48 @@ impl Driver for DataFusionDriver {
     }
 }
 
-pub struct DataFusionDatabase {
-    options: HashMap<OptionDatabase, OptionValue>,
-}
+pub struct DataFusionDatabase {}
 
 impl Optionable for DataFusionDatabase {
     type Option = OptionDatabase;
 
     fn set_option(
         &mut self,
-        _key: Self::Option,
+        key: Self::Option,
         _value: adbc_core::options::OptionValue,
     ) -> adbc_core::error::Result<()> {
-        self.options.insert(_key, _value);
-        Ok(())
+        Err(Error::with_message_and_status(
+            format!("Unrecognized option: {key:?}"),
+            Status::NotFound,
+        ))
     }
 
-    fn get_option_string(&self, _key: Self::Option) -> 
adbc_core::error::Result<String> {
-        todo!()
+    fn get_option_string(&self, key: Self::Option) -> 
adbc_core::error::Result<String> {
+        Err(Error::with_message_and_status(
+            format!("Unrecognized option: {key:?}"),
+            Status::NotFound,
+        ))
     }
 
-    fn get_option_bytes(&self, _key: Self::Option) -> 
adbc_core::error::Result<Vec<u8>> {
-        todo!()
+    fn get_option_bytes(&self, key: Self::Option) -> 
adbc_core::error::Result<Vec<u8>> {
+        Err(Error::with_message_and_status(
+            format!("Unrecognized option: {key:?}"),
+            Status::NotFound,
+        ))
     }
 
-    fn get_option_int(&self, _key: Self::Option) -> 
adbc_core::error::Result<i64> {
-        todo!()
+    fn get_option_int(&self, key: Self::Option) -> 
adbc_core::error::Result<i64> {
+        Err(Error::with_message_and_status(
+            format!("Unrecognized option: {key:?}"),
+            Status::NotFound,
+        ))
     }
 
-    fn get_option_double(&self, _key: Self::Option) -> 
adbc_core::error::Result<f64> {
-        todo!()
+    fn get_option_double(&self, key: Self::Option) -> 
adbc_core::error::Result<f64> {
+        Err(Error::with_message_and_status(
+            format!("Unrecognized option: {key:?}"),
+            Status::NotFound,
+        ))
     }
 }
 
@@ -189,7 +199,7 @@ impl Database for DataFusionDatabase {
 
     fn new_connection_with_opts(
         &mut self,
-        _opts: impl IntoIterator<
+        opts: impl IntoIterator<
             Item = (
                 adbc_core::options::OptionConnection,
                 adbc_core::options::OptionValue,
@@ -203,10 +213,16 @@ impl Database for DataFusionDatabase {
             .build()
             .unwrap();
 
-        Ok(DataFusionConnection {
+        let mut connection = DataFusionConnection {
             runtime: Arc::new(runtime),
             ctx: Arc::new(ctx),
-        })
+        };
+
+        for (key, value) in opts {
+            connection.set_option(key, value)?;
+        }
+
+        Ok(connection)
     }
 }
 
@@ -220,26 +236,85 @@ impl Optionable for DataFusionConnection {
 
     fn set_option(
         &mut self,
-        _key: Self::Option,
-        _value: adbc_core::options::OptionValue,
+        key: Self::Option,
+        value: adbc_core::options::OptionValue,
     ) -> adbc_core::error::Result<()> {
-        todo!()
+        match key.as_ref() {
+            constants::ADBC_CONNECTION_OPTION_CURRENT_CATALOG => match value {
+                OptionValue::String(value) => {
+                    self.runtime.block_on(async {
+                        let query = format!("SET 
datafusion.catalog.default_catalog = {value}");
+                        self.ctx.sql(query.as_str()).await.unwrap();
+                    });
+                    Ok(())
+                }
+                _ => Err(Error::with_message_and_status(
+                    "CurrentCatalog value must be of type String",
+                    Status::InvalidArguments,
+                )),
+            },
+            constants::ADBC_CONNECTION_OPTION_CURRENT_DB_SCHEMA => match value 
{
+                OptionValue::String(value) => {
+                    self.runtime.block_on(async {
+                        let query = format!("SET 
datafusion.catalog.default_schema = {value}");
+                        self.ctx.sql(query.as_str()).await.unwrap();
+                    });
+                    Ok(())
+                }
+                _ => Err(Error::with_message_and_status(
+                    "CurrentSchema value must be of type String",
+                    Status::InvalidArguments,
+                )),
+            },
+            _ => Err(Error::with_message_and_status(
+                format!("Unrecognized option: {key:?}"),
+                Status::NotFound,
+            )),
+        }
     }
 
-    fn get_option_string(&self, _key: Self::Option) -> 
adbc_core::error::Result<String> {
-        todo!()
+    fn get_option_string(&self, key: Self::Option) -> 
adbc_core::error::Result<String> {
+        match key.as_ref() {
+            constants::ADBC_CONNECTION_OPTION_CURRENT_CATALOG => Ok(self
+                .ctx
+                .state()
+                .config_options()
+                .catalog
+                .default_catalog
+                .clone()),
+            constants::ADBC_CONNECTION_OPTION_CURRENT_DB_SCHEMA => Ok(self
+                .ctx
+                .state()
+                .config_options()
+                .catalog
+                .default_schema
+                .clone()),
+            _ => Err(Error::with_message_and_status(
+                format!("Unrecognized option: {key:?}"),
+                Status::NotFound,
+            )),
+        }
     }
 
-    fn get_option_bytes(&self, _key: Self::Option) -> 
adbc_core::error::Result<Vec<u8>> {
-        todo!()
+    fn get_option_bytes(&self, key: Self::Option) -> 
adbc_core::error::Result<Vec<u8>> {
+        Err(Error::with_message_and_status(
+            format!("Unrecognized option: {key:?}"),
+            Status::NotFound,
+        ))
     }
 
-    fn get_option_int(&self, _key: Self::Option) -> 
adbc_core::error::Result<i64> {
-        todo!()
+    fn get_option_int(&self, key: Self::Option) -> 
adbc_core::error::Result<i64> {
+        Err(Error::with_message_and_status(
+            format!("Unrecognized option: {key:?}"),
+            Status::NotFound,
+        ))
     }
 
-    fn get_option_double(&self, _key: Self::Option) -> 
adbc_core::error::Result<f64> {
-        todo!()
+    fn get_option_double(&self, key: Self::Option) -> 
adbc_core::error::Result<f64> {
+        Err(Error::with_message_and_status(
+            format!("Unrecognized option: {key:?}"),
+            Status::NotFound,
+        ))
     }
 }
 
@@ -622,6 +697,8 @@ impl Connection for DataFusionConnection {
             ctx: self.ctx.clone(),
             sql_query: None,
             substrait_plan: None,
+            bound_record_batch: None,
+            ingest_target_table: None,
         })
     }
 
@@ -707,6 +784,8 @@ pub struct DataFusionStatement {
     ctx: Arc<SessionContext>,
     sql_query: Option<String>,
     substrait_plan: Option<Plan>,
+    bound_record_batch: Option<RecordBatch>,
+    ingest_target_table: Option<String>,
 }
 
 impl Optionable for DataFusionStatement {
@@ -714,32 +793,72 @@ impl Optionable for DataFusionStatement {
 
     fn set_option(
         &mut self,
-        _key: Self::Option,
-        _value: adbc_core::options::OptionValue,
+        key: Self::Option,
+        value: adbc_core::options::OptionValue,
     ) -> adbc_core::error::Result<()> {
-        todo!()
+        match key.as_ref() {
+            constants::ADBC_INGEST_OPTION_TARGET_TABLE => match value {
+                OptionValue::String(value) => {
+                    self.ingest_target_table = Some(value);
+                    Ok(())
+                }
+                _ => Err(Error::with_message_and_status(
+                    "IngestOptionTargetTable value must be of type String",
+                    Status::InvalidArguments,
+                )),
+            },
+            _ => Err(Error::with_message_and_status(
+                format!("Unrecognized option: {key:?}"),
+                Status::NotFound,
+            )),
+        }
     }
 
-    fn get_option_string(&self, _key: Self::Option) -> 
adbc_core::error::Result<String> {
-        todo!()
+    fn get_option_string(&self, key: Self::Option) -> 
adbc_core::error::Result<String> {
+        match key.as_ref() {
+            constants::ADBC_INGEST_OPTION_TARGET_TABLE => {
+                let target_table = self.ingest_target_table.clone();
+                match target_table {
+                    Some(table) => Ok(table),
+                    None => Err(Error::with_message_and_status(
+                        format!("{key:?} has not been set"),
+                        Status::NotFound,
+                    )),
+                }
+            }
+            _ => Err(Error::with_message_and_status(
+                format!("Unrecognized option: {key:?}"),
+                Status::NotFound,
+            )),
+        }
     }
 
-    fn get_option_bytes(&self, _key: Self::Option) -> 
adbc_core::error::Result<Vec<u8>> {
-        todo!()
+    fn get_option_bytes(&self, key: Self::Option) -> 
adbc_core::error::Result<Vec<u8>> {
+        Err(Error::with_message_and_status(
+            format!("Unrecognized option: {key:?}"),
+            Status::NotFound,
+        ))
     }
 
-    fn get_option_int(&self, _key: Self::Option) -> 
adbc_core::error::Result<i64> {
-        todo!()
+    fn get_option_int(&self, key: Self::Option) -> 
adbc_core::error::Result<i64> {
+        Err(Error::with_message_and_status(
+            format!("Unrecognized option: {key:?}"),
+            Status::NotFound,
+        ))
     }
 
-    fn get_option_double(&self, _key: Self::Option) -> 
adbc_core::error::Result<f64> {
-        todo!()
+    fn get_option_double(&self, key: Self::Option) -> 
adbc_core::error::Result<f64> {
+        Err(Error::with_message_and_status(
+            format!("Unrecognized option: {key:?}"),
+            Status::NotFound,
+        ))
     }
 }
 
 impl Statement for DataFusionStatement {
-    fn bind(&mut self, _batch: arrow_array::RecordBatch) -> 
adbc_core::error::Result<()> {
-        todo!()
+    fn bind(&mut self, batch: arrow_array::RecordBatch) -> 
adbc_core::error::Result<()> {
+        self.bound_record_batch.replace(batch);
+        Ok(())
     }
 
     fn bind_stream(
@@ -768,13 +887,29 @@ impl Statement for DataFusionStatement {
     }
 
     fn execute_update(&mut self) -> adbc_core::error::Result<Option<i64>> {
-        self.runtime.block_on(async {
-            let _ = self
-                .ctx
-                .sql(&self.sql_query.clone().unwrap())
-                .await
-                .unwrap();
-        });
+        if self.sql_query.is_some() {
+            self.runtime.block_on(async {
+                let _ = self
+                    .ctx
+                    .sql(&self.sql_query.clone().unwrap())
+                    .await
+                    .unwrap();
+            });
+        } else if let Some(batch) = self.bound_record_batch.take() {
+            self.runtime.block_on(async {
+                let table = match self.ingest_target_table.clone() {
+                    Some(table) => table,
+                    None => todo!(),
+                };
+
+                self.ctx
+                    .read_batch(batch)
+                    .unwrap()
+                    .write_table(table.as_str(), DataFrameWriteOptions::new())
+                    .await
+                    .unwrap();
+            });
+        }
 
         Ok(Some(0))
     }
diff --git a/rust/drivers/datafusion/tests/test_datafusion.rs 
b/rust/drivers/datafusion/tests/test_datafusion.rs
index 4e1db4fa8..38764459a 100644
--- a/rust/drivers/datafusion/tests/test_datafusion.rs
+++ b/rust/drivers/datafusion/tests/test_datafusion.rs
@@ -16,11 +16,11 @@
 // under the License.
 
 use adbc_core::driver_manager::{ManagedConnection, ManagedDriver};
-use adbc_core::{Connection, Database, Driver, Statement};
+use adbc_core::{Connection, Database, Driver, Optionable, Statement};
 use arrow_array::RecordBatch;
 use datafusion::prelude::*;
 
-use adbc_core::options::AdbcVersion;
+use adbc_core::options::{AdbcVersion, OptionConnection, OptionStatement, 
OptionValue};
 use arrow_select::concat::concat_batches;
 use datafusion_substrait::logical_plan::producer::to_substrait_plan;
 use datafusion_substrait::substrait::proto::Plan;
@@ -85,6 +85,45 @@ fn execute_substrait(connection: &mut ManagedConnection, 
plan: Plan) -> RecordBa
     concat_batches(&schema, &batches).unwrap()
 }
 
+#[test]
+fn test_connection_options() {
+    let mut connection = get_connection();
+
+    let current_catalog = connection
+        .get_option_string(OptionConnection::CurrentCatalog)
+        .unwrap();
+
+    assert_eq!(current_catalog, "datafusion");
+
+    let _ = connection.set_option(
+        OptionConnection::CurrentCatalog,
+        OptionValue::String("datafusion2".to_string()),
+    );
+
+    let current_catalog = connection
+        .get_option_string(OptionConnection::CurrentCatalog)
+        .unwrap();
+
+    assert_eq!(current_catalog, "datafusion2");
+
+    let current_schema = connection
+        .get_option_string(OptionConnection::CurrentSchema)
+        .unwrap();
+
+    assert_eq!(current_schema, "public");
+
+    let _ = connection.set_option(
+        OptionConnection::CurrentSchema,
+        OptionValue::String("public2".to_string()),
+    );
+
+    let current_schema = connection
+        .get_option_string(OptionConnection::CurrentSchema)
+        .unwrap();
+
+    assert_eq!(current_schema, "public2");
+}
+
 #[test]
 fn test_get_objects_database() {
     let mut connection = get_connection();
@@ -112,6 +151,32 @@ fn test_execute_sql() {
     assert_eq!(batch.num_columns(), 2);
 }
 
+#[test]
+fn test_ingest() {
+    let mut connection = get_connection();
+
+    execute_update(&mut connection, "CREATE TABLE IF NOT EXISTS 
datafusion.public.example (c1 INT, c2 VARCHAR) AS 
VALUES(1,'HELLO'),(2,'DATAFUSION'),(3,'!')");
+
+    let batch = execute_sql_query(&mut connection, "SELECT * FROM 
datafusion.public.example");
+
+    assert_eq!(batch.num_rows(), 3);
+    assert_eq!(batch.num_columns(), 2);
+
+    let mut statement = connection.new_statement().unwrap();
+
+    let _ = statement.set_option(
+        OptionStatement::TargetTable,
+        OptionValue::String("example".to_string()),
+    );
+    let _ = statement.bind(batch);
+
+    let _ = statement.execute_update();
+
+    let batch = execute_sql_query(&mut connection, "SELECT * FROM 
datafusion.public.example");
+
+    assert_eq!(batch.num_rows(), 6);
+}
+
 #[test]
 fn test_execute_substrait() {
     let mut connection = get_connection();

Reply via email to