This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 9b78efaecc Add serde support for Arrow FileTypeWriterOptions (#8850)
9b78efaecc is described below
commit 9b78efaeccd57ab8b8e20c29174a121af8130376
Author: Tushushu <[email protected]>
AuthorDate: Fri Jan 19 04:24:48 2024 +0800
Add serde support for Arrow FileTypeWriterOptions (#8850)
* refactor
* generated files
* feat
* feat
* feat
* feat
* tests
* clippy
---------
Co-authored-by: Andrew Lamb <[email protected]>
---
datafusion/common/src/file_options/arrow_writer.rs | 12 +++
datafusion/proto/proto/datafusion.proto | 3 +
datafusion/proto/src/generated/pbjson.rs | 85 ++++++++++++++++++++++
datafusion/proto/src/generated/prost.rs | 7 +-
datafusion/proto/src/logical_plan/mod.rs | 19 +++++
datafusion/proto/src/physical_plan/from_proto.rs | 5 ++
.../proto/tests/cases/roundtrip_logical_plan.rs | 40 ++++++++++
7 files changed, 170 insertions(+), 1 deletion(-)
diff --git a/datafusion/common/src/file_options/arrow_writer.rs b/datafusion/common/src/file_options/arrow_writer.rs
index a30e6d800e..cb921535ab 100644
--- a/datafusion/common/src/file_options/arrow_writer.rs
+++ b/datafusion/common/src/file_options/arrow_writer.rs
@@ -27,6 +27,18 @@ use super::StatementOptions;
#[derive(Clone, Debug)]
pub struct ArrowWriterOptions {}
+impl ArrowWriterOptions {
+ pub fn new() -> Self {
+ Self {}
+ }
+}
+
+impl Default for ArrowWriterOptions {
+ fn default() -> Self {
+ Self::new()
+ }
+}
+
impl TryFrom<(&ConfigOptions, &StatementOptions)> for ArrowWriterOptions {
type Error = DataFusionError;
diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto
index 8bde0da133..d79879e57a 100644
--- a/datafusion/proto/proto/datafusion.proto
+++ b/datafusion/proto/proto/datafusion.proto
@@ -1213,6 +1213,7 @@ message FileTypeWriterOptions {
JsonWriterOptions json_options = 1;
ParquetWriterOptions parquet_options = 2;
CsvWriterOptions csv_options = 3;
+ ArrowWriterOptions arrow_options = 4;
}
}
@@ -1243,6 +1244,8 @@ message CsvWriterOptions {
string null_value = 8;
}
+message ArrowWriterOptions {}
+
message WriterProperties {
uint64 data_page_size_limit = 1;
uint64 dictionary_page_size_limit = 2;
diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs
index 528761136c..d7ad6fb03c 100644
--- a/datafusion/proto/src/generated/pbjson.rs
+++ b/datafusion/proto/src/generated/pbjson.rs
@@ -1929,6 +1929,77 @@ impl<'de> serde::Deserialize<'de> for ArrowType {
deserializer.deserialize_struct("datafusion.ArrowType", FIELDS, GeneratedVisitor)
}
}
+impl serde::Serialize for ArrowWriterOptions {
+ #[allow(deprecated)]
+ fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
+ where
+ S: serde::Serializer,
+ {
+ use serde::ser::SerializeStruct;
+ let len = 0;
+ let struct_ser = serializer.serialize_struct("datafusion.ArrowWriterOptions", len)?;
+ struct_ser.end()
+ }
+}
+impl<'de> serde::Deserialize<'de> for ArrowWriterOptions {
+ #[allow(deprecated)]
+ fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
+ where
+ D: serde::Deserializer<'de>,
+ {
+ const FIELDS: &[&str] = &[
+ ];
+
+ #[allow(clippy::enum_variant_names)]
+ enum GeneratedField {
+ }
+ impl<'de> serde::Deserialize<'de> for GeneratedField {
+ fn deserialize<D>(deserializer: D) -> std::result::Result<GeneratedField, D::Error>
+ where
+ D: serde::Deserializer<'de>,
+ {
+ struct GeneratedVisitor;
+
+ impl<'de> serde::de::Visitor<'de> for GeneratedVisitor {
+ type Value = GeneratedField;
+
+ fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ write!(formatter, "expected one of: {:?}", &FIELDS)
+ }
+
+ #[allow(unused_variables)]
+ fn visit_str<E>(self, value: &str) -> std::result::Result<GeneratedField, E>
+ where
+ E: serde::de::Error,
+ {
+ Err(serde::de::Error::unknown_field(value, FIELDS))
+ }
+ }
+ deserializer.deserialize_identifier(GeneratedVisitor)
+ }
+ }
+ struct GeneratedVisitor;
+ impl<'de> serde::de::Visitor<'de> for GeneratedVisitor {
+ type Value = ArrowWriterOptions;
+
+ fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ formatter.write_str("struct datafusion.ArrowWriterOptions")
+ }
+
+ fn visit_map<V>(self, mut map_: V) -> std::result::Result<ArrowWriterOptions, V::Error>
+ where
+ V: serde::de::MapAccess<'de>,
+ {
+ while map_.next_key::<GeneratedField>()?.is_some() {
+ let _ = map_.next_value::<serde::de::IgnoredAny>()?;
+ }
+ Ok(ArrowWriterOptions {
+ })
+ }
+ }
+ deserializer.deserialize_struct("datafusion.ArrowWriterOptions", FIELDS, GeneratedVisitor)
+ }
+}
impl serde::Serialize for AvroFormat {
#[allow(deprecated)]
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
@@ -8354,6 +8425,9 @@ impl serde::Serialize for FileTypeWriterOptions {
file_type_writer_options::FileType::CsvOptions(v) => {
struct_ser.serialize_field("csvOptions", v)?;
}
+ file_type_writer_options::FileType::ArrowOptions(v) => {
+ struct_ser.serialize_field("arrowOptions", v)?;
+ }
}
}
struct_ser.end()
@@ -8372,6 +8446,8 @@ impl<'de> serde::Deserialize<'de> for FileTypeWriterOptions {
"parquetOptions",
"csv_options",
"csvOptions",
+ "arrow_options",
+ "arrowOptions",
];
#[allow(clippy::enum_variant_names)]
@@ -8379,6 +8455,7 @@ impl<'de> serde::Deserialize<'de> for FileTypeWriterOptions {
JsonOptions,
ParquetOptions,
CsvOptions,
+ ArrowOptions,
}
impl<'de> serde::Deserialize<'de> for GeneratedField {
fn deserialize<D>(deserializer: D) -> std::result::Result<GeneratedField, D::Error>
@@ -8403,6 +8480,7 @@ impl<'de> serde::Deserialize<'de> for FileTypeWriterOptions {
"jsonOptions" | "json_options" => Ok(GeneratedField::JsonOptions),
"parquetOptions" | "parquet_options" => Ok(GeneratedField::ParquetOptions),
"csvOptions" | "csv_options" => Ok(GeneratedField::CsvOptions),
+ "arrowOptions" | "arrow_options" => Ok(GeneratedField::ArrowOptions),
_ => Err(serde::de::Error::unknown_field(value, FIELDS)),
}
}
@@ -8444,6 +8522,13 @@ impl<'de> serde::Deserialize<'de> for FileTypeWriterOptions {
return Err(serde::de::Error::duplicate_field("csvOptions"));
}
file_type__ = map_.next_value::<::std::option::Option<_>>()?.map(file_type_writer_options::FileType::CsvOptions)
+;
+ }
+ GeneratedField::ArrowOptions => {
+ if file_type__.is_some() {
+ return Err(serde::de::Error::duplicate_field("arrowOptions"));
+ }
+ file_type__ = map_.next_value::<::std::option::Option<_>>()?.map(file_type_writer_options::FileType::ArrowOptions)
;
}
}
diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs
index 9a0b7ab332..d594da9087 100644
--- a/datafusion/proto/src/generated/prost.rs
+++ b/datafusion/proto/src/generated/prost.rs
@@ -1646,7 +1646,7 @@ pub struct PartitionColumn {
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct FileTypeWriterOptions {
- #[prost(oneof = "file_type_writer_options::FileType", tags = "1, 2, 3")]
+ #[prost(oneof = "file_type_writer_options::FileType", tags = "1, 2, 3, 4")]
pub file_type: ::core::option::Option<file_type_writer_options::FileType>,
}
/// Nested message and enum types in `FileTypeWriterOptions`.
@@ -1660,6 +1660,8 @@ pub mod file_type_writer_options {
ParquetOptions(super::ParquetWriterOptions),
#[prost(message, tag = "3")]
CsvOptions(super::CsvWriterOptions),
+ #[prost(message, tag = "4")]
+ ArrowOptions(super::ArrowWriterOptions),
}
}
#[allow(clippy::derive_partial_eq_without_eq)]
@@ -1704,6 +1706,9 @@ pub struct CsvWriterOptions {
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct ArrowWriterOptions {}
+#[allow(clippy::derive_partial_eq_without_eq)]
+#[derive(Clone, PartialEq, ::prost::Message)]
pub struct WriterProperties {
#[prost(uint64, tag = "1")]
pub data_page_size_limit: u64,
diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs
index 6ca95519a9..f10f11c1c0 100644
--- a/datafusion/proto/src/logical_plan/mod.rs
+++ b/datafusion/proto/src/logical_plan/mod.rs
@@ -16,6 +16,7 @@
// under the License.
use arrow::csv::WriterBuilder;
+use datafusion_common::file_options::arrow_writer::ArrowWriterOptions;
use std::collections::HashMap;
use std::fmt::Debug;
use std::str::FromStr;
@@ -858,6 +859,13 @@ impl AsLogicalPlan for LogicalPlanNode {
Some(copy_to_node::CopyOptions::WriterOptions(opt)) => {
match &opt.file_type {
Some(ft) => match ft {
+ file_type_writer_options::FileType::ArrowOptions(_) => {
+ CopyOptions::WriterOptions(Box::new(
+ FileTypeWriterOptions::Arrow(
+ ArrowWriterOptions::new(),
+ ),
+ ))
+ }
file_type_writer_options::FileType::CsvOptions(
writer_options,
) => {
@@ -1659,6 +1667,17 @@ impl AsLogicalPlan for LogicalPlanNode {
}
CopyOptions::WriterOptions(opt) => {
match opt.as_ref() {
+ FileTypeWriterOptions::Arrow(_) => {
+ let arrow_writer_options =
+ file_type_writer_options::FileType::ArrowOptions(
+ protobuf::ArrowWriterOptions {},
+ );
+ Some(copy_to_node::CopyOptions::WriterOptions(
+ protobuf::FileTypeWriterOptions {
+ file_type: Some(arrow_writer_options),
+ },
+ ))
+ }
FileTypeWriterOptions::CSV(csv_opts) => {
let csv_options = &csv_opts.writer_options;
let csv_writer_options = csv_writer_options_to_proto(
diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs
index ea28eeee88..dc827d02bf 100644
--- a/datafusion/proto/src/physical_plan/from_proto.rs
+++ b/datafusion/proto/src/physical_plan/from_proto.rs
@@ -42,6 +42,7 @@ use datafusion::physical_plan::windows::create_window_expr;
use datafusion::physical_plan::{
functions, ColumnStatistics, Partitioning, PhysicalExpr, Statistics, WindowExpr,
};
+use datafusion_common::file_options::arrow_writer::ArrowWriterOptions;
use datafusion_common::file_options::csv_writer::CsvWriterOptions;
use datafusion_common::file_options::json_writer::JsonWriterOptions;
use datafusion_common::file_options::parquet_writer::ParquetWriterOptions;
@@ -834,6 +835,10 @@ impl TryFrom<&protobuf::FileTypeWriterOptions> for FileTypeWriterOptions {
.ok_or_else(|| proto_error("Missing required file_type field in protobuf"))?;
match file_type {
+ protobuf::file_type_writer_options::FileType::ArrowOptions(_) => {
+ Ok(Self::Arrow(ArrowWriterOptions::new()))
+ }
+
protobuf::file_type_writer_options::FileType::JsonOptions(opts) => {
let compression: CompressionTypeVariant = opts.compression().into();
Ok(Self::JSON(JsonWriterOptions::new(compression)))
diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
index ed21124a9e..2d38cfd400 100644
--- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
+++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
@@ -27,6 +27,7 @@ use arrow::datatypes::{
IntervalUnit, Schema, SchemaRef, TimeUnit, UnionFields, UnionMode,
};
+use datafusion_common::file_options::arrow_writer::ArrowWriterOptions;
use prost::Message;
use datafusion::datasource::provider::TableProviderFactory;
@@ -394,6 +395,45 @@ async fn roundtrip_logical_plan_copy_to_writer_options() -> Result<()> {
Ok(())
}
+#[tokio::test]
+async fn roundtrip_logical_plan_copy_to_arrow() -> Result<()> {
+ let ctx = SessionContext::new();
+
+ let input = create_csv_scan(&ctx).await?;
+
+ let plan = LogicalPlan::Copy(CopyTo {
+ input: Arc::new(input),
+ output_url: "test.arrow".to_string(),
+ file_format: FileType::ARROW,
+ single_file_output: true,
+ copy_options: CopyOptions::WriterOptions(Box::new(FileTypeWriterOptions::Arrow(
+ ArrowWriterOptions::new(),
+ ))),
+ });
+
+ let bytes = logical_plan_to_bytes(&plan)?;
+ let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx)?;
+ assert_eq!(format!("{plan:?}"), format!("{logical_round_trip:?}"));
+
+ match logical_round_trip {
+ LogicalPlan::Copy(copy_to) => {
+ assert_eq!("test.arrow", copy_to.output_url);
+ assert_eq!(FileType::ARROW, copy_to.file_format);
+ assert!(copy_to.single_file_output);
+ match &copy_to.copy_options {
+ CopyOptions::WriterOptions(y) => match y.as_ref() {
+ FileTypeWriterOptions::Arrow(_) => {}
+ _ => panic!(),
+ },
+ _ => panic!(),
+ }
+ }
+ _ => panic!(),
+ }
+
+ Ok(())
+}
+
#[tokio::test]
async fn roundtrip_logical_plan_copy_to_csv() -> Result<()> {
let ctx = SessionContext::new();