This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-ballista.git


The following commit(s) were added to refs/heads/main by this push:
     new 90b5cc74 Fix regression in DataFrame.write_xxx (#945)
90b5cc74 is described below

commit 90b5cc748c5f92762e9ede52d6908d62823f9563
Author: Andy Grove <[email protected]>
AuthorDate: Sat Dec 30 11:28:00 2023 -0700

    Fix regression in DataFrame.write_xxx (#945)
---
 Cargo.toml                                     |  6 +--
 ballista/cache/src/listener/loading_cache.rs   | 10 ++---
 ballista/client/src/context.rs                 | 53 ++++++++++++++++++++------
 ballista/core/src/client.rs                    |  2 +-
 ballista/core/src/error.rs                     |  2 +-
 ballista/core/src/serde/mod.rs                 |  6 +--
 ballista/scheduler/src/cluster/storage/etcd.rs |  4 +-
 ballista/scheduler/src/state/task_manager.rs   |  2 +-
 benchmarks/src/bin/tpch.rs                     |  2 -
 9 files changed, 57 insertions(+), 30 deletions(-)
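
The new tests below exercise DataFrame.write_parquet and DataFrame.write_csv
through a BallistaContext in standalone mode. As a rough illustration of the
write path this change restores (a minimal sketch based on those tests; the
output path is illustrative, the "standalone" feature is assumed, and the
crate paths are assumed from the client prelude):

    use ballista::prelude::*;
    use datafusion::dataframe::DataFrameWriteOptions;

    async fn write_example() -> ballista_core::error::Result<()> {
        // Spin up an in-process scheduler and executor (standalone mode).
        let ctx = BallistaContext::standalone(&BallistaConfig::new().unwrap(), 1).await?;
        let df = ctx.sql("SELECT 1;").await?;
        // Third argument is the writer-specific options; None = defaults.
        df.write_csv("/tmp/out.csv", DataFrameWriteOptions::default(), None).await?;
        Ok(())
    }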

diff --git a/Cargo.toml b/Cargo.toml
index 0deca9c4..0f2f91ff 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -34,9 +34,9 @@ arrow-flight = { version = "49.0.0", features = ["flight-sql-experimental"] }
 arrow-schema = { version = "49.0.0", default-features = false }
 configure_me = { version = "0.4.0" }
 configure_me_codegen = { version = "0.4.4" }
-datafusion = "34.0.0"
-datafusion-cli = "34.0.0"
-datafusion-proto = "34.0.0"
+datafusion = { git = "https://github.com/apache/arrow-datafusion", rev = "7fc663c2e40be2928778102386bbf76962dd2cdc" }
+datafusion-cli = { git = "https://github.com/apache/arrow-datafusion", rev = "7fc663c2e40be2928778102386bbf76962dd2cdc" }
+datafusion-proto = { git = "https://github.com/apache/arrow-datafusion", rev = "7fc663c2e40be2928778102386bbf76962dd2cdc" }
 object_store = "0.8.0"
 sqlparser = "0.40.0"
 tonic = { version = "0.10" }
diff --git a/ballista/cache/src/listener/loading_cache.rs b/ballista/cache/src/listener/loading_cache.rs
index e54e7d83..09f61c3b 100644
--- a/ballista/cache/src/listener/loading_cache.rs
+++ b/ballista/cache/src/listener/loading_cache.rs
@@ -145,7 +145,7 @@ where
     fn listen_on_get_if_present(&self, k: Self::K, v: Option<Self::V>) {
         if self.listeners.len() == 1 {
             self.listeners
-                .get(0)
+                .first()
                 .unwrap()
                 .listen_on_get_if_present(k, v);
         } else {
@@ -157,7 +157,7 @@ where
 
     fn listen_on_get(&self, k: Self::K, v: Self::V, status: CacheGetStatus) {
         if self.listeners.len() == 1 {
-            self.listeners.get(0).unwrap().listen_on_get(k, v, status);
+            self.listeners.first().unwrap().listen_on_get(k, v, status);
         } else {
             self.listeners.iter().for_each(|listener| {
                 listener.listen_on_get(k.clone(), v.clone(), status)
@@ -167,7 +167,7 @@ where
 
     fn listen_on_put(&self, k: Self::K, v: Self::V) {
         if self.listeners.len() == 1 {
-            self.listeners.get(0).unwrap().listen_on_put(k, v);
+            self.listeners.first().unwrap().listen_on_put(k, v);
         } else {
             self.listeners
                 .iter()
@@ -177,7 +177,7 @@ where
 
     fn listen_on_invalidate(&self, k: Self::K) {
         if self.listeners.len() == 1 {
-            self.listeners.get(0).unwrap().listen_on_invalidate(k);
+            self.listeners.first().unwrap().listen_on_invalidate(k);
         } else {
             self.listeners
                 .iter()
@@ -187,7 +187,7 @@ where
 
     fn listen_on_get_cancelling(&self, k: Self::K) {
         if self.listeners.len() == 1 {
-            self.listeners.get(0).unwrap().listen_on_get_cancelling(k);
+            self.listeners.first().unwrap().listen_on_get_cancelling(k);
         } else {
             self.listeners
                 .iter()
diff --git a/ballista/client/src/context.rs b/ballista/client/src/context.rs
index 28904487..9b1e64e6 100644
--- a/ballista/client/src/context.rs
+++ b/ballista/client/src/context.rs
@@ -405,7 +405,7 @@ impl BallistaContext {
                         schema
                             .field_with_name(col)
                             .map(|f| (f.name().to_owned(), f.data_type().to_owned()))
-                            .map_err(DataFusionError::ArrowError)
+                            .map_err(|e| DataFusionError::ArrowError(e, None))
                     })
                     .collect::<Result<Vec<_>>>()?;
 
@@ -458,12 +458,15 @@ impl BallistaContext {
 }
 
 #[cfg(test)]
-mod tests {
-    #[cfg(feature = "standalone")]
+#[cfg(feature = "standalone")]
+mod standalone_tests {
+    use ballista_core::error::Result;
+    use datafusion::dataframe::DataFrameWriteOptions;
     use datafusion::datasource::listing::ListingTableUrl;
+    use datafusion::parquet::file::properties::WriterProperties;
+    use tempfile::TempDir;
 
     #[tokio::test]
-    #[cfg(feature = "standalone")]
     async fn test_standalone_mode() {
         use super::*;
         let context = BallistaContext::standalone(&BallistaConfig::new().unwrap(), 1)
@@ -474,7 +477,40 @@ mod tests {
     }
 
     #[tokio::test]
-    #[cfg(feature = "standalone")]
+    async fn test_write_parquet() -> Result<()> {
+        use super::*;
+        let context =
+            BallistaContext::standalone(&BallistaConfig::new().unwrap(), 1).await?;
+        let df = context.sql("SELECT 1;").await?;
+        let tmp_dir = TempDir::new().unwrap();
+        let file_path = format!(
+            "{}",
+            tmp_dir.path().join("test_write_parquet.parquet").display()
+        );
+        df.write_parquet(
+            &file_path,
+            DataFrameWriteOptions::default(),
+            Some(WriterProperties::default()),
+        )
+        .await?;
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_write_csv() -> Result<()> {
+        use super::*;
+        let context =
+            BallistaContext::standalone(&BallistaConfig::new().unwrap(), 
1).await?;
+        let df = context.sql("SELECT 1;").await?;
+        let tmp_dir = TempDir::new().unwrap();
+        let file_path =
+            format!("{}", tmp_dir.path().join("test_write_csv.csv").display());
+        df.write_csv(&file_path, DataFrameWriteOptions::default(), None)
+            .await?;
+        Ok(())
+    }
+
+    #[tokio::test]
     async fn test_ballista_show_tables() {
         use super::*;
         use std::fs::File;
@@ -517,7 +553,6 @@ mod tests {
     }
 
     #[tokio::test]
-    #[cfg(feature = "standalone")]
     async fn test_show_tables_not_with_information_schema() {
         use super::*;
         use ballista_core::config::{
@@ -563,7 +598,6 @@ mod tests {
     }
 
     #[tokio::test]
-    #[cfg(feature = "standalone")]
     #[ignore]
     // Tracking: https://github.com/apache/arrow-datafusion/issues/1840
     async fn test_task_stuck_when_referenced_task_failed() {
@@ -611,9 +645,7 @@ mod tests {
                         collect_stat: x.collect_stat,
                         target_partitions: x.target_partitions,
                         file_sort_order: vec![],
-                        infinite_source: false,
                         file_type_write_options: None,
-                        single_file: false,
                     };
 
                     let table_paths = listing_table
@@ -644,7 +676,6 @@ mod tests {
     }
 
     #[tokio::test]
-    #[cfg(feature = "standalone")]
     async fn test_empty_exec_with_one_row() {
         use crate::context::BallistaContext;
         use ballista_core::config::{
@@ -664,7 +695,6 @@ mod tests {
     }
 
     #[tokio::test]
-    #[cfg(feature = "standalone")]
     async fn test_union_and_union_all() {
         use super::*;
         use ballista_core::config::{
@@ -723,7 +753,6 @@ mod tests {
     }
 
     #[tokio::test]
-    #[cfg(feature = "standalone")]
     async fn test_aggregate_func() {
         use crate::context::BallistaContext;
         use ballista_core::config::{
diff --git a/ballista/core/src/client.rs b/ballista/core/src/client.rs
index 4382dd05..d19e6766 100644
--- a/ballista/core/src/client.rs
+++ b/ballista/core/src/client.rs
@@ -220,7 +220,7 @@ impl Stream for FlightDataStream {
                             self.schema.clone(),
                             &self.dictionaries_by_id,
                         )
-                        .map_err(DataFusionError::ArrowError)
+                        .map_err(|e| DataFusionError::ArrowError(e, None))
                     });
                 Some(converted_chunk)
             }
diff --git a/ballista/core/src/error.rs b/ballista/core/src/error.rs
index 67f75283..b56107c9 100644
--- a/ballista/core/src/error.rs
+++ b/ballista/core/src/error.rs
@@ -101,7 +101,7 @@ impl From<parser::ParserError> for BallistaError {
 impl From<DataFusionError> for BallistaError {
     fn from(e: DataFusionError) -> Self {
         match e {
-            DataFusionError::ArrowError(e) => Self::from(e),
+            DataFusionError::ArrowError(e, _) => Self::from(e),
             _ => BallistaError::DataFusionError(e),
         }
     }
diff --git a/ballista/core/src/serde/mod.rs b/ballista/core/src/serde/mod.rs
index 7f74e5ae..e12dec8a 100644
--- a/ballista/core/src/serde/mod.rs
+++ b/ballista/core/src/serde/mod.rs
@@ -286,9 +286,9 @@ impl PhysicalExtensionCodec for BallistaPhysicalExtensionCodec {
 
             Ok(())
         } else {
-            Err(DataFusionError::Internal(
-                "unsupported plan type".to_string(),
-            ))
+            Err(DataFusionError::Internal(format!(
+                "unsupported plan type: {node:?}"
+            )))
         }
     }
 }
diff --git a/ballista/scheduler/src/cluster/storage/etcd.rs b/ballista/scheduler/src/cluster/storage/etcd.rs
index 528a9116..514511c9 100644
--- a/ballista/scheduler/src/cluster/storage/etcd.rs
+++ b/ballista/scheduler/src/cluster/storage/etcd.rs
@@ -57,7 +57,7 @@ impl KeyValueStore for EtcdClient {
             .await
             .map_err(|e| ballista_error(&format!("etcd error {e:?}")))?
             .kvs()
-            .get(0)
+            .first()
             .map(|kv| kv.value().to_owned())
             .unwrap_or_default())
     }
@@ -181,7 +181,7 @@ impl KeyValueStore for EtcdClient {
             .await
             .map_err(|e| ballista_error(&format!("etcd error {e:?}")))?
             .kvs()
-            .get(0)
+            .first()
             .map(|kv| kv.value().to_owned());
 
         if let Some(value) = current_value {
diff --git a/ballista/scheduler/src/state/task_manager.rs b/ballista/scheduler/src/state/task_manager.rs
index 02b90861..66714e6c 100644
--- a/ballista/scheduler/src/state/task_manager.rs
+++ b/ballista/scheduler/src/state/task_manager.rs
@@ -557,7 +557,7 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> TaskManager<T, U>
         &self,
         tasks: Vec<TaskDescription>,
     ) -> Result<Vec<MultiTaskDefinition>> {
-        if let Some(task) = tasks.get(0) {
+        if let Some(task) = tasks.first() {
             let session_id = task.session_id.clone();
             let job_id = task.partition.job_id.clone();
             let stage_id = task.partition.stage_id;
diff --git a/benchmarks/src/bin/tpch.rs b/benchmarks/src/bin/tpch.rs
index b2863a41..c0719175 100644
--- a/benchmarks/src/bin/tpch.rs
+++ b/benchmarks/src/bin/tpch.rs
@@ -844,9 +844,7 @@ async fn get_table(
         collect_stat: true,
         table_partition_cols: vec![],
         file_sort_order: vec![],
-        infinite_source: false,
         file_type_write_options: None,
-        single_file: false,
     };
 
     let url = ListingTableUrl::parse(path)?;
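
A note on the repeated error-mapping changes above: the pinned DataFusion
revision (post-34.0.0) adds a second field to DataFusionError::ArrowError, so
call sites that previously passed the variant constructor to map_err now wrap
the error in a closure and supply None for the new optional context. A minimal
sketch of the adapted pattern (the helper name is hypothetical):

    use datafusion::arrow::error::ArrowError;
    use datafusion::error::DataFusionError;

    // Hypothetical helper: wrap an ArrowError in the two-field variant.
    // The second field carries optional context; this commit passes None.
    fn to_df_error(e: ArrowError) -> DataFusionError {
        DataFusionError::ArrowError(e, None)
    }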
