This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 9b78efaecc Add serde support for Arrow FileTypeWriterOptions (#8850)
9b78efaecc is described below

commit 9b78efaeccd57ab8b8e20c29174a121af8130376
Author: Tushushu <[email protected]>
AuthorDate: Fri Jan 19 04:24:48 2024 +0800

    Add serde support for Arrow FileTypeWriterOptions (#8850)
    
    * refactor
    
    * generated files
    
    * feat
    
    * feat
    
    * feat
    
    * feat
    
    * tests
    
    * clippy
    
    ---------
    
    Co-authored-by: Andrew Lamb <[email protected]>
---
 datafusion/common/src/file_options/arrow_writer.rs | 12 +++
 datafusion/proto/proto/datafusion.proto            |  3 +
 datafusion/proto/src/generated/pbjson.rs           | 85 ++++++++++++++++++++++
 datafusion/proto/src/generated/prost.rs            |  7 +-
 datafusion/proto/src/logical_plan/mod.rs           | 19 +++++
 datafusion/proto/src/physical_plan/from_proto.rs   |  5 ++
 .../proto/tests/cases/roundtrip_logical_plan.rs    | 40 ++++++++++
 7 files changed, 170 insertions(+), 1 deletion(-)

diff --git a/datafusion/common/src/file_options/arrow_writer.rs b/datafusion/common/src/file_options/arrow_writer.rs
index a30e6d800e..cb921535ab 100644
--- a/datafusion/common/src/file_options/arrow_writer.rs
+++ b/datafusion/common/src/file_options/arrow_writer.rs
@@ -27,6 +27,18 @@ use super::StatementOptions;
 #[derive(Clone, Debug)]
 pub struct ArrowWriterOptions {}
 
+impl ArrowWriterOptions {
+    pub fn new() -> Self {
+        Self {}
+    }
+}
+
+impl Default for ArrowWriterOptions {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
 impl TryFrom<(&ConfigOptions, &StatementOptions)> for ArrowWriterOptions {
     type Error = DataFusionError;
 
diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto
index 8bde0da133..d79879e57a 100644
--- a/datafusion/proto/proto/datafusion.proto
+++ b/datafusion/proto/proto/datafusion.proto
@@ -1213,6 +1213,7 @@ message FileTypeWriterOptions {
     JsonWriterOptions json_options = 1;
     ParquetWriterOptions parquet_options = 2;
     CsvWriterOptions csv_options = 3;
+    ArrowWriterOptions arrow_options = 4;
   }
 }
 
@@ -1243,6 +1244,8 @@ message CsvWriterOptions {
   string null_value = 8;
 }
 
+message ArrowWriterOptions {}
+
 message WriterProperties {
   uint64 data_page_size_limit = 1;
   uint64 dictionary_page_size_limit = 2;
diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs
index 528761136c..d7ad6fb03c 100644
--- a/datafusion/proto/src/generated/pbjson.rs
+++ b/datafusion/proto/src/generated/pbjson.rs
@@ -1929,6 +1929,77 @@ impl<'de> serde::Deserialize<'de> for ArrowType {
         deserializer.deserialize_struct("datafusion.ArrowType", FIELDS, GeneratedVisitor)
     }
 }
+impl serde::Serialize for ArrowWriterOptions {
+    #[allow(deprecated)]
+    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
+    where
+        S: serde::Serializer,
+    {
+        use serde::ser::SerializeStruct;
+        let len = 0;
+        let struct_ser = serializer.serialize_struct("datafusion.ArrowWriterOptions", len)?;
+        struct_ser.end()
+    }
+}
+impl<'de> serde::Deserialize<'de> for ArrowWriterOptions {
+    #[allow(deprecated)]
+    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
+    where
+        D: serde::Deserializer<'de>,
+    {
+        const FIELDS: &[&str] = &[
+        ];
+
+        #[allow(clippy::enum_variant_names)]
+        enum GeneratedField {
+        }
+        impl<'de> serde::Deserialize<'de> for GeneratedField {
+            fn deserialize<D>(deserializer: D) -> std::result::Result<GeneratedField, D::Error>
+            where
+                D: serde::Deserializer<'de>,
+            {
+                struct GeneratedVisitor;
+
+                impl<'de> serde::de::Visitor<'de> for GeneratedVisitor {
+                    type Value = GeneratedField;
+
+                    fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+                        write!(formatter, "expected one of: {:?}", &FIELDS)
+                    }
+
+                    #[allow(unused_variables)]
+                    fn visit_str<E>(self, value: &str) -> std::result::Result<GeneratedField, E>
+                    where
+                        E: serde::de::Error,
+                    {
+                            Err(serde::de::Error::unknown_field(value, FIELDS))
+                    }
+                }
+                deserializer.deserialize_identifier(GeneratedVisitor)
+            }
+        }
+        struct GeneratedVisitor;
+        impl<'de> serde::de::Visitor<'de> for GeneratedVisitor {
+            type Value = ArrowWriterOptions;
+
+            fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+                formatter.write_str("struct datafusion.ArrowWriterOptions")
+            }
+
+            fn visit_map<V>(self, mut map_: V) -> std::result::Result<ArrowWriterOptions, V::Error>
+                where
+                    V: serde::de::MapAccess<'de>,
+            {
+                while map_.next_key::<GeneratedField>()?.is_some() {
+                    let _ = map_.next_value::<serde::de::IgnoredAny>()?;
+                }
+                Ok(ArrowWriterOptions {
+                })
+            }
+        }
+        deserializer.deserialize_struct("datafusion.ArrowWriterOptions", FIELDS, GeneratedVisitor)
+    }
+}
 impl serde::Serialize for AvroFormat {
     #[allow(deprecated)]
     fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
@@ -8354,6 +8425,9 @@ impl serde::Serialize for FileTypeWriterOptions {
                 file_type_writer_options::FileType::CsvOptions(v) => {
                     struct_ser.serialize_field("csvOptions", v)?;
                 }
+                file_type_writer_options::FileType::ArrowOptions(v) => {
+                    struct_ser.serialize_field("arrowOptions", v)?;
+                }
             }
         }
         struct_ser.end()
@@ -8372,6 +8446,8 @@ impl<'de> serde::Deserialize<'de> for FileTypeWriterOptions {
             "parquetOptions",
             "csv_options",
             "csvOptions",
+            "arrow_options",
+            "arrowOptions",
         ];
 
         #[allow(clippy::enum_variant_names)]
@@ -8379,6 +8455,7 @@ impl<'de> serde::Deserialize<'de> for FileTypeWriterOptions {
             JsonOptions,
             ParquetOptions,
             CsvOptions,
+            ArrowOptions,
         }
         impl<'de> serde::Deserialize<'de> for GeneratedField {
             fn deserialize<D>(deserializer: D) -> std::result::Result<GeneratedField, D::Error>
@@ -8403,6 +8480,7 @@ impl<'de> serde::Deserialize<'de> for FileTypeWriterOptions {
                             "jsonOptions" | "json_options" => Ok(GeneratedField::JsonOptions),
                             "parquetOptions" | "parquet_options" => Ok(GeneratedField::ParquetOptions),
                             "csvOptions" | "csv_options" => Ok(GeneratedField::CsvOptions),
+                            "arrowOptions" | "arrow_options" => Ok(GeneratedField::ArrowOptions),
                             _ => Err(serde::de::Error::unknown_field(value, FIELDS)),
                         }
                     }
@@ -8444,6 +8522,13 @@ impl<'de> serde::Deserialize<'de> for FileTypeWriterOptions {
                                 return Err(serde::de::Error::duplicate_field("csvOptions"));
                             }
                             file_type__ = map_.next_value::<::std::option::Option<_>>()?.map(file_type_writer_options::FileType::CsvOptions)
+;
+                        }
+                        GeneratedField::ArrowOptions => {
+                            if file_type__.is_some() {
+                                return Err(serde::de::Error::duplicate_field("arrowOptions"));
+                            }
+                            file_type__ = map_.next_value::<::std::option::Option<_>>()?.map(file_type_writer_options::FileType::ArrowOptions)
 ;
                         }
                     }
diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs
index 9a0b7ab332..d594da9087 100644
--- a/datafusion/proto/src/generated/prost.rs
+++ b/datafusion/proto/src/generated/prost.rs
@@ -1646,7 +1646,7 @@ pub struct PartitionColumn {
 #[allow(clippy::derive_partial_eq_without_eq)]
 #[derive(Clone, PartialEq, ::prost::Message)]
 pub struct FileTypeWriterOptions {
-    #[prost(oneof = "file_type_writer_options::FileType", tags = "1, 2, 3")]
+    #[prost(oneof = "file_type_writer_options::FileType", tags = "1, 2, 3, 4")]
     pub file_type: ::core::option::Option<file_type_writer_options::FileType>,
 }
 /// Nested message and enum types in `FileTypeWriterOptions`.
@@ -1660,6 +1660,8 @@ pub mod file_type_writer_options {
         ParquetOptions(super::ParquetWriterOptions),
         #[prost(message, tag = "3")]
         CsvOptions(super::CsvWriterOptions),
+        #[prost(message, tag = "4")]
+        ArrowOptions(super::ArrowWriterOptions),
     }
 }
 #[allow(clippy::derive_partial_eq_without_eq)]
@@ -1704,6 +1706,9 @@ pub struct CsvWriterOptions {
 }
 #[allow(clippy::derive_partial_eq_without_eq)]
 #[derive(Clone, PartialEq, ::prost::Message)]
+pub struct ArrowWriterOptions {}
+#[allow(clippy::derive_partial_eq_without_eq)]
+#[derive(Clone, PartialEq, ::prost::Message)]
 pub struct WriterProperties {
     #[prost(uint64, tag = "1")]
     pub data_page_size_limit: u64,
diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs
index 6ca95519a9..f10f11c1c0 100644
--- a/datafusion/proto/src/logical_plan/mod.rs
+++ b/datafusion/proto/src/logical_plan/mod.rs
@@ -16,6 +16,7 @@
 // under the License.
 
 use arrow::csv::WriterBuilder;
+use datafusion_common::file_options::arrow_writer::ArrowWriterOptions;
 use std::collections::HashMap;
 use std::fmt::Debug;
 use std::str::FromStr;
@@ -858,6 +859,13 @@ impl AsLogicalPlan for LogicalPlanNode {
                     Some(copy_to_node::CopyOptions::WriterOptions(opt)) => {
                         match &opt.file_type {
                             Some(ft) => match ft {
+                                file_type_writer_options::FileType::ArrowOptions(_) => {
+                                    CopyOptions::WriterOptions(Box::new(
+                                        FileTypeWriterOptions::Arrow(
+                                            ArrowWriterOptions::new(),
+                                        ),
+                                    ))
+                                }
                                 file_type_writer_options::FileType::CsvOptions(
                                     writer_options,
                                 ) => {
@@ -1659,6 +1667,17 @@ impl AsLogicalPlan for LogicalPlanNode {
                         }
                         CopyOptions::WriterOptions(opt) => {
                             match opt.as_ref() {
+                                FileTypeWriterOptions::Arrow(_) => {
+                                    let arrow_writer_options =
+                                        file_type_writer_options::FileType::ArrowOptions(
+                                            protobuf::ArrowWriterOptions {},
+                                        );
+                                    Some(copy_to_node::CopyOptions::WriterOptions(
+                                        protobuf::FileTypeWriterOptions {
+                                            file_type: Some(arrow_writer_options),
+                                        },
+                                    ))
+                                }
                                 FileTypeWriterOptions::CSV(csv_opts) => {
                                     let csv_options = &csv_opts.writer_options;
                                     let csv_writer_options = csv_writer_options_to_proto(
diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs
index ea28eeee88..dc827d02bf 100644
--- a/datafusion/proto/src/physical_plan/from_proto.rs
+++ b/datafusion/proto/src/physical_plan/from_proto.rs
@@ -42,6 +42,7 @@ use datafusion::physical_plan::windows::create_window_expr;
 use datafusion::physical_plan::{
     functions, ColumnStatistics, Partitioning, PhysicalExpr, Statistics, WindowExpr,
 };
+use datafusion_common::file_options::arrow_writer::ArrowWriterOptions;
 use datafusion_common::file_options::csv_writer::CsvWriterOptions;
 use datafusion_common::file_options::json_writer::JsonWriterOptions;
 use datafusion_common::file_options::parquet_writer::ParquetWriterOptions;
@@ -834,6 +835,10 @@ impl TryFrom<&protobuf::FileTypeWriterOptions> for FileTypeWriterOptions {
             .ok_or_else(|| proto_error("Missing required file_type field in protobuf"))?;
 
         match file_type {
+            protobuf::file_type_writer_options::FileType::ArrowOptions(_) => {
+                Ok(Self::Arrow(ArrowWriterOptions::new()))
+            }
+
             protobuf::file_type_writer_options::FileType::JsonOptions(opts) => {
                 let compression: CompressionTypeVariant = opts.compression().into();
                 Ok(Self::JSON(JsonWriterOptions::new(compression)))
diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
index ed21124a9e..2d38cfd400 100644
--- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
+++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
@@ -27,6 +27,7 @@ use arrow::datatypes::{
     IntervalUnit, Schema, SchemaRef, TimeUnit, UnionFields, UnionMode,
 };
 
+use datafusion_common::file_options::arrow_writer::ArrowWriterOptions;
 use prost::Message;
 
 use datafusion::datasource::provider::TableProviderFactory;
@@ -394,6 +395,45 @@ async fn roundtrip_logical_plan_copy_to_writer_options() -> Result<()> {
     Ok(())
 }
 
+#[tokio::test]
+async fn roundtrip_logical_plan_copy_to_arrow() -> Result<()> {
+    let ctx = SessionContext::new();
+
+    let input = create_csv_scan(&ctx).await?;
+
+    let plan = LogicalPlan::Copy(CopyTo {
+        input: Arc::new(input),
+        output_url: "test.arrow".to_string(),
+        file_format: FileType::ARROW,
+        single_file_output: true,
+        copy_options: CopyOptions::WriterOptions(Box::new(FileTypeWriterOptions::Arrow(
+            ArrowWriterOptions::new(),
+        ))),
+    });
+
+    let bytes = logical_plan_to_bytes(&plan)?;
+    let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx)?;
+    assert_eq!(format!("{plan:?}"), format!("{logical_round_trip:?}"));
+
+    match logical_round_trip {
+        LogicalPlan::Copy(copy_to) => {
+            assert_eq!("test.arrow", copy_to.output_url);
+            assert_eq!(FileType::ARROW, copy_to.file_format);
+            assert!(copy_to.single_file_output);
+            match &copy_to.copy_options {
+                CopyOptions::WriterOptions(y) => match y.as_ref() {
+                    FileTypeWriterOptions::Arrow(_) => {}
+                    _ => panic!(),
+                },
+                _ => panic!(),
+            }
+        }
+        _ => panic!(),
+    }
+
+    Ok(())
+}
+
 #[tokio::test]
 async fn roundtrip_logical_plan_copy_to_csv() -> Result<()> {
     let ctx = SessionContext::new();

Reply via email to