This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 7fc663c2e4 Implement serde for CSV and Parquet FileSinkExec (#8646)
7fc663c2e4 is described below
commit 7fc663c2e40be2928778102386bbf76962dd2cdc
Author: Andy Grove <[email protected]>
AuthorDate: Fri Dec 29 16:53:31 2023 -0700
Implement serde for CSV and Parquet FileSinkExec (#8646)
* Add serde for Csv and Parquet sink
* Add tests
* parquet test passes
* save progress
* add compression type to csv serde
* remove hard-coded compression from CSV serde
---
datafusion/core/src/datasource/file_format/csv.rs | 11 +-
.../core/src/datasource/file_format/parquet.rs | 9 +-
datafusion/proto/proto/datafusion.proto | 40 +-
datafusion/proto/src/generated/pbjson.rs | 517 +++++++++++++++++++++
datafusion/proto/src/generated/prost.rs | 59 ++-
datafusion/proto/src/logical_plan/mod.rs | 43 +-
datafusion/proto/src/physical_plan/from_proto.rs | 38 +-
datafusion/proto/src/physical_plan/mod.rs | 91 ++++
datafusion/proto/src/physical_plan/to_proto.rs | 46 +-
.../proto/tests/cases/roundtrip_physical_plan.rs | 125 ++++-
10 files changed, 922 insertions(+), 57 deletions(-)
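For readers skimming the change: the sketch below is not part of the commit. It is a minimal usage example that paraphrases the roundtrip helpers added in datafusion/proto/tests/cases/roundtrip_physical_plan.rs and assumes the datafusion-proto API exercised by this diff (PhysicalPlanNode::try_from_physical_plan / try_into_physical_plan); it shows how a physical plan whose FileSinkExec wraps one of the new CsvSink / ParquetSink implementations can be encoded to protobuf and decoded back.

use std::sync::Arc;

use datafusion::physical_plan::ExecutionPlan;
use datafusion::prelude::SessionContext;
use datafusion_common::Result;
use datafusion_proto::physical_plan::{AsExecutionPlan, DefaultPhysicalExtensionCodec};
use datafusion_proto::protobuf;

// Encode a physical plan (e.g. a FileSinkExec wrapping the new CsvSink or
// ParquetSink) into its protobuf representation and decode it back.
fn roundtrip_plan(plan: Arc<dyn ExecutionPlan>) -> Result<Arc<dyn ExecutionPlan>> {
    let ctx = SessionContext::new();
    let codec = DefaultPhysicalExtensionCodec {};
    // Physical plan -> protobuf message; CsvSink/ParquetSink are supported here as of this commit.
    let proto = protobuf::PhysicalPlanNode::try_from_physical_plan(plan, &codec)?;
    // Protobuf message -> physical plan.
    let runtime = ctx.runtime_env();
    proto.try_into_physical_plan(&ctx, runtime.as_ref(), &codec)
}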
diff --git a/datafusion/core/src/datasource/file_format/csv.rs b/datafusion/core/src/datasource/file_format/csv.rs
index d4e63904bd..7a0af3ff08 100644
--- a/datafusion/core/src/datasource/file_format/csv.rs
+++ b/datafusion/core/src/datasource/file_format/csv.rs
@@ -437,7 +437,7 @@ impl BatchSerializer for CsvSerializer {
}
/// Implements [`DataSink`] for writing to a CSV file.
-struct CsvSink {
+pub struct CsvSink {
/// Config options for writing data
config: FileSinkConfig,
}
@@ -461,9 +461,16 @@ impl DisplayAs for CsvSink {
}
impl CsvSink {
- fn new(config: FileSinkConfig) -> Self {
+ /// Create from config.
+ pub fn new(config: FileSinkConfig) -> Self {
Self { config }
}
+
+ /// Retrieve the inner [`FileSinkConfig`].
+ pub fn config(&self) -> &FileSinkConfig {
+ &self.config
+ }
+
async fn multipartput_all(
&self,
data: SendableRecordBatchStream,
diff --git a/datafusion/core/src/datasource/file_format/parquet.rs b/datafusion/core/src/datasource/file_format/parquet.rs
index 7044acccd6..9729bfa163 100644
--- a/datafusion/core/src/datasource/file_format/parquet.rs
+++ b/datafusion/core/src/datasource/file_format/parquet.rs
@@ -621,7 +621,7 @@ async fn fetch_statistics(
}
/// Implements [`DataSink`] for writing to a parquet file.
-struct ParquetSink {
+pub struct ParquetSink {
/// Config options for writing data
config: FileSinkConfig,
}
@@ -645,10 +645,15 @@ impl DisplayAs for ParquetSink {
}
impl ParquetSink {
- fn new(config: FileSinkConfig) -> Self {
+ /// Create from config.
+ pub fn new(config: FileSinkConfig) -> Self {
Self { config }
}
+ /// Retrieve the inner [`FileSinkConfig`].
+ pub fn config(&self) -> &FileSinkConfig {
+ &self.config
+ }
/// Converts table schema to writer schema, which may differ in the case
/// of hive style partitioning where some columns are removed from the
/// underlying files.
diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto
index 59b82efcbb..d5f8397aa3 100644
--- a/datafusion/proto/proto/datafusion.proto
+++ b/datafusion/proto/proto/datafusion.proto
@@ -1187,6 +1187,8 @@ message PhysicalPlanNode {
SymmetricHashJoinExecNode symmetric_hash_join = 25;
InterleaveExecNode interleave = 26;
PlaceholderRowExecNode placeholder_row = 27;
+ CsvSinkExecNode csv_sink = 28;
+ ParquetSinkExecNode parquet_sink = 29;
}
}
@@ -1220,20 +1222,22 @@ message ParquetWriterOptions {
}
message CsvWriterOptions {
+ // Compression type
+ CompressionTypeVariant compression = 1;
// Optional column delimiter. Defaults to `b','`
- string delimiter = 1;
+ string delimiter = 2;
// Whether to write column names as file headers. Defaults to `true`
- bool has_header = 2;
+ bool has_header = 3;
// Optional date format for date arrays
- string date_format = 3;
+ string date_format = 4;
// Optional datetime format for datetime arrays
- string datetime_format = 4;
+ string datetime_format = 5;
// Optional timestamp format for timestamp arrays
- string timestamp_format = 5;
+ string timestamp_format = 6;
// Optional time format for time arrays
- string time_format = 6;
+ string time_format = 7;
// Optional value to represent null
- string null_value = 7;
+ string null_value = 8;
}
message WriterProperties {
@@ -1270,6 +1274,28 @@ message JsonSinkExecNode {
PhysicalSortExprNodeCollection sort_order = 4;
}
+message CsvSink {
+ FileSinkConfig config = 1;
+}
+
+message CsvSinkExecNode {
+ PhysicalPlanNode input = 1;
+ CsvSink sink = 2;
+ Schema sink_schema = 3;
+ PhysicalSortExprNodeCollection sort_order = 4;
+}
+
+message ParquetSink {
+ FileSinkConfig config = 1;
+}
+
+message ParquetSinkExecNode {
+ PhysicalPlanNode input = 1;
+ ParquetSink sink = 2;
+ Schema sink_schema = 3;
+ PhysicalSortExprNodeCollection sort_order = 4;
+}
+
message PhysicalExtensionNode {
bytes node = 1;
repeated PhysicalPlanNode inputs = 2;
diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs
index 956244ffdb..12e834d75a 100644
--- a/datafusion/proto/src/generated/pbjson.rs
+++ b/datafusion/proto/src/generated/pbjson.rs
@@ -5151,6 +5151,241 @@ impl<'de> serde::Deserialize<'de> for CsvScanExecNode {
deserializer.deserialize_struct("datafusion.CsvScanExecNode", FIELDS,
GeneratedVisitor)
}
}
+impl serde::Serialize for CsvSink {
+ #[allow(deprecated)]
+ fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
+ where
+ S: serde::Serializer,
+ {
+ use serde::ser::SerializeStruct;
+ let mut len = 0;
+ if self.config.is_some() {
+ len += 1;
+ }
+ let mut struct_ser = serializer.serialize_struct("datafusion.CsvSink",
len)?;
+ if let Some(v) = self.config.as_ref() {
+ struct_ser.serialize_field("config", v)?;
+ }
+ struct_ser.end()
+ }
+}
+impl<'de> serde::Deserialize<'de> for CsvSink {
+ #[allow(deprecated)]
+ fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
+ where
+ D: serde::Deserializer<'de>,
+ {
+ const FIELDS: &[&str] = &[
+ "config",
+ ];
+
+ #[allow(clippy::enum_variant_names)]
+ enum GeneratedField {
+ Config,
+ }
+ impl<'de> serde::Deserialize<'de> for GeneratedField {
+ fn deserialize<D>(deserializer: D) -> std::result::Result<GeneratedField, D::Error>
+ where
+ D: serde::Deserializer<'de>,
+ {
+ struct GeneratedVisitor;
+
+ impl<'de> serde::de::Visitor<'de> for GeneratedVisitor {
+ type Value = GeneratedField;
+
+ fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ write!(formatter, "expected one of: {:?}", &FIELDS)
+ }
+
+ #[allow(unused_variables)]
+ fn visit_str<E>(self, value: &str) -> std::result::Result<GeneratedField, E>
+ where
+ E: serde::de::Error,
+ {
+ match value {
+ "config" => Ok(GeneratedField::Config),
+ _ => Err(serde::de::Error::unknown_field(value, FIELDS)),
+ }
+ }
+ }
+ deserializer.deserialize_identifier(GeneratedVisitor)
+ }
+ }
+ struct GeneratedVisitor;
+ impl<'de> serde::de::Visitor<'de> for GeneratedVisitor {
+ type Value = CsvSink;
+
+ fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ formatter.write_str("struct datafusion.CsvSink")
+ }
+
+ fn visit_map<V>(self, mut map_: V) -> std::result::Result<CsvSink, V::Error>
+ where
+ V: serde::de::MapAccess<'de>,
+ {
+ let mut config__ = None;
+ while let Some(k) = map_.next_key()? {
+ match k {
+ GeneratedField::Config => {
+ if config__.is_some() {
+ return Err(serde::de::Error::duplicate_field("config"));
+ }
+ config__ = map_.next_value()?;
+ }
+ }
+ }
+ Ok(CsvSink {
+ config: config__,
+ })
+ }
+ }
+ deserializer.deserialize_struct("datafusion.CsvSink", FIELDS,
GeneratedVisitor)
+ }
+}
+impl serde::Serialize for CsvSinkExecNode {
+ #[allow(deprecated)]
+ fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
+ where
+ S: serde::Serializer,
+ {
+ use serde::ser::SerializeStruct;
+ let mut len = 0;
+ if self.input.is_some() {
+ len += 1;
+ }
+ if self.sink.is_some() {
+ len += 1;
+ }
+ if self.sink_schema.is_some() {
+ len += 1;
+ }
+ if self.sort_order.is_some() {
+ len += 1;
+ }
+ let mut struct_ser = serializer.serialize_struct("datafusion.CsvSinkExecNode", len)?;
+ if let Some(v) = self.input.as_ref() {
+ struct_ser.serialize_field("input", v)?;
+ }
+ if let Some(v) = self.sink.as_ref() {
+ struct_ser.serialize_field("sink", v)?;
+ }
+ if let Some(v) = self.sink_schema.as_ref() {
+ struct_ser.serialize_field("sinkSchema", v)?;
+ }
+ if let Some(v) = self.sort_order.as_ref() {
+ struct_ser.serialize_field("sortOrder", v)?;
+ }
+ struct_ser.end()
+ }
+}
+impl<'de> serde::Deserialize<'de> for CsvSinkExecNode {
+ #[allow(deprecated)]
+ fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
+ where
+ D: serde::Deserializer<'de>,
+ {
+ const FIELDS: &[&str] = &[
+ "input",
+ "sink",
+ "sink_schema",
+ "sinkSchema",
+ "sort_order",
+ "sortOrder",
+ ];
+
+ #[allow(clippy::enum_variant_names)]
+ enum GeneratedField {
+ Input,
+ Sink,
+ SinkSchema,
+ SortOrder,
+ }
+ impl<'de> serde::Deserialize<'de> for GeneratedField {
+ fn deserialize<D>(deserializer: D) -> std::result::Result<GeneratedField, D::Error>
+ where
+ D: serde::Deserializer<'de>,
+ {
+ struct GeneratedVisitor;
+
+ impl<'de> serde::de::Visitor<'de> for GeneratedVisitor {
+ type Value = GeneratedField;
+
+ fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ write!(formatter, "expected one of: {:?}", &FIELDS)
+ }
+
+ #[allow(unused_variables)]
+ fn visit_str<E>(self, value: &str) -> std::result::Result<GeneratedField, E>
+ where
+ E: serde::de::Error,
+ {
+ match value {
+ "input" => Ok(GeneratedField::Input),
+ "sink" => Ok(GeneratedField::Sink),
+ "sinkSchema" | "sink_schema" =>
Ok(GeneratedField::SinkSchema),
+ "sortOrder" | "sort_order" =>
Ok(GeneratedField::SortOrder),
+ _ => Err(serde::de::Error::unknown_field(value,
FIELDS)),
+ }
+ }
+ }
+ deserializer.deserialize_identifier(GeneratedVisitor)
+ }
+ }
+ struct GeneratedVisitor;
+ impl<'de> serde::de::Visitor<'de> for GeneratedVisitor {
+ type Value = CsvSinkExecNode;
+
+ fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ formatter.write_str("struct datafusion.CsvSinkExecNode")
+ }
+
+ fn visit_map<V>(self, mut map_: V) -> std::result::Result<CsvSinkExecNode, V::Error>
+ where
+ V: serde::de::MapAccess<'de>,
+ {
+ let mut input__ = None;
+ let mut sink__ = None;
+ let mut sink_schema__ = None;
+ let mut sort_order__ = None;
+ while let Some(k) = map_.next_key()? {
+ match k {
+ GeneratedField::Input => {
+ if input__.is_some() {
+ return Err(serde::de::Error::duplicate_field("input"));
+ }
+ input__ = map_.next_value()?;
+ }
+ GeneratedField::Sink => {
+ if sink__.is_some() {
+ return Err(serde::de::Error::duplicate_field("sink"));
+ }
+ sink__ = map_.next_value()?;
+ }
+ GeneratedField::SinkSchema => {
+ if sink_schema__.is_some() {
+ return Err(serde::de::Error::duplicate_field("sinkSchema"));
+ }
+ sink_schema__ = map_.next_value()?;
+ }
+ GeneratedField::SortOrder => {
+ if sort_order__.is_some() {
+ return Err(serde::de::Error::duplicate_field("sortOrder"));
+ }
+ sort_order__ = map_.next_value()?;
+ }
+ }
+ }
+ Ok(CsvSinkExecNode {
+ input: input__,
+ sink: sink__,
+ sink_schema: sink_schema__,
+ sort_order: sort_order__,
+ })
+ }
+ }
+ deserializer.deserialize_struct("datafusion.CsvSinkExecNode", FIELDS,
GeneratedVisitor)
+ }
+}
impl serde::Serialize for CsvWriterOptions {
#[allow(deprecated)]
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
@@ -5159,6 +5394,9 @@ impl serde::Serialize for CsvWriterOptions {
{
use serde::ser::SerializeStruct;
let mut len = 0;
+ if self.compression != 0 {
+ len += 1;
+ }
if !self.delimiter.is_empty() {
len += 1;
}
@@ -5181,6 +5419,11 @@ impl serde::Serialize for CsvWriterOptions {
len += 1;
}
let mut struct_ser = serializer.serialize_struct("datafusion.CsvWriterOptions", len)?;
+ if self.compression != 0 {
+ let v = CompressionTypeVariant::try_from(self.compression)
+ .map_err(|_| serde::ser::Error::custom(format!("Invalid
variant {}", self.compression)))?;
+ struct_ser.serialize_field("compression", &v)?;
+ }
if !self.delimiter.is_empty() {
struct_ser.serialize_field("delimiter", &self.delimiter)?;
}
@@ -5212,6 +5455,7 @@ impl<'de> serde::Deserialize<'de> for CsvWriterOptions {
D: serde::Deserializer<'de>,
{
const FIELDS: &[&str] = &[
+ "compression",
"delimiter",
"has_header",
"hasHeader",
@@ -5229,6 +5473,7 @@ impl<'de> serde::Deserialize<'de> for CsvWriterOptions {
#[allow(clippy::enum_variant_names)]
enum GeneratedField {
+ Compression,
Delimiter,
HasHeader,
DateFormat,
@@ -5257,6 +5502,7 @@ impl<'de> serde::Deserialize<'de> for CsvWriterOptions {
E: serde::de::Error,
{
match value {
+ "compression" => Ok(GeneratedField::Compression),
"delimiter" => Ok(GeneratedField::Delimiter),
"hasHeader" | "has_header" =>
Ok(GeneratedField::HasHeader),
"dateFormat" | "date_format" =>
Ok(GeneratedField::DateFormat),
@@ -5283,6 +5529,7 @@ impl<'de> serde::Deserialize<'de> for CsvWriterOptions {
where
V: serde::de::MapAccess<'de>,
{
+ let mut compression__ = None;
let mut delimiter__ = None;
let mut has_header__ = None;
let mut date_format__ = None;
@@ -5292,6 +5539,12 @@ impl<'de> serde::Deserialize<'de> for CsvWriterOptions {
let mut null_value__ = None;
while let Some(k) = map_.next_key()? {
match k {
+ GeneratedField::Compression => {
+ if compression__.is_some() {
+ return Err(serde::de::Error::duplicate_field("compression"));
+ }
+ compression__ = Some(map_.next_value::<CompressionTypeVariant>()? as i32);
+ }
GeneratedField::Delimiter => {
if delimiter__.is_some() {
return Err(serde::de::Error::duplicate_field("delimiter"));
@@ -5337,6 +5590,7 @@ impl<'de> serde::Deserialize<'de> for CsvWriterOptions {
}
}
Ok(CsvWriterOptions {
+ compression: compression__.unwrap_or_default(),
delimiter: delimiter__.unwrap_or_default(),
has_header: has_header__.unwrap_or_default(),
date_format: date_format__.unwrap_or_default(),
@@ -15398,6 +15652,241 @@ impl<'de> serde::Deserialize<'de> for ParquetScanExecNode {
deserializer.deserialize_struct("datafusion.ParquetScanExecNode", FIELDS, GeneratedVisitor)
}
}
+impl serde::Serialize for ParquetSink {
+ #[allow(deprecated)]
+ fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
+ where
+ S: serde::Serializer,
+ {
+ use serde::ser::SerializeStruct;
+ let mut len = 0;
+ if self.config.is_some() {
+ len += 1;
+ }
+ let mut struct_ser = serializer.serialize_struct("datafusion.ParquetSink", len)?;
+ if let Some(v) = self.config.as_ref() {
+ struct_ser.serialize_field("config", v)?;
+ }
+ struct_ser.end()
+ }
+}
+impl<'de> serde::Deserialize<'de> for ParquetSink {
+ #[allow(deprecated)]
+ fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
+ where
+ D: serde::Deserializer<'de>,
+ {
+ const FIELDS: &[&str] = &[
+ "config",
+ ];
+
+ #[allow(clippy::enum_variant_names)]
+ enum GeneratedField {
+ Config,
+ }
+ impl<'de> serde::Deserialize<'de> for GeneratedField {
+ fn deserialize<D>(deserializer: D) -> std::result::Result<GeneratedField, D::Error>
+ where
+ D: serde::Deserializer<'de>,
+ {
+ struct GeneratedVisitor;
+
+ impl<'de> serde::de::Visitor<'de> for GeneratedVisitor {
+ type Value = GeneratedField;
+
+ fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ write!(formatter, "expected one of: {:?}", &FIELDS)
+ }
+
+ #[allow(unused_variables)]
+ fn visit_str<E>(self, value: &str) -> std::result::Result<GeneratedField, E>
+ where
+ E: serde::de::Error,
+ {
+ match value {
+ "config" => Ok(GeneratedField::Config),
+ _ => Err(serde::de::Error::unknown_field(value, FIELDS)),
+ }
+ }
+ }
+ deserializer.deserialize_identifier(GeneratedVisitor)
+ }
+ }
+ struct GeneratedVisitor;
+ impl<'de> serde::de::Visitor<'de> for GeneratedVisitor {
+ type Value = ParquetSink;
+
+ fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ formatter.write_str("struct datafusion.ParquetSink")
+ }
+
+ fn visit_map<V>(self, mut map_: V) -> std::result::Result<ParquetSink, V::Error>
+ where
+ V: serde::de::MapAccess<'de>,
+ {
+ let mut config__ = None;
+ while let Some(k) = map_.next_key()? {
+ match k {
+ GeneratedField::Config => {
+ if config__.is_some() {
+ return Err(serde::de::Error::duplicate_field("config"));
+ }
+ config__ = map_.next_value()?;
+ }
+ }
+ }
+ Ok(ParquetSink {
+ config: config__,
+ })
+ }
+ }
+ deserializer.deserialize_struct("datafusion.ParquetSink", FIELDS,
GeneratedVisitor)
+ }
+}
+impl serde::Serialize for ParquetSinkExecNode {
+ #[allow(deprecated)]
+ fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
+ where
+ S: serde::Serializer,
+ {
+ use serde::ser::SerializeStruct;
+ let mut len = 0;
+ if self.input.is_some() {
+ len += 1;
+ }
+ if self.sink.is_some() {
+ len += 1;
+ }
+ if self.sink_schema.is_some() {
+ len += 1;
+ }
+ if self.sort_order.is_some() {
+ len += 1;
+ }
+ let mut struct_ser = serializer.serialize_struct("datafusion.ParquetSinkExecNode", len)?;
+ if let Some(v) = self.input.as_ref() {
+ struct_ser.serialize_field("input", v)?;
+ }
+ if let Some(v) = self.sink.as_ref() {
+ struct_ser.serialize_field("sink", v)?;
+ }
+ if let Some(v) = self.sink_schema.as_ref() {
+ struct_ser.serialize_field("sinkSchema", v)?;
+ }
+ if let Some(v) = self.sort_order.as_ref() {
+ struct_ser.serialize_field("sortOrder", v)?;
+ }
+ struct_ser.end()
+ }
+}
+impl<'de> serde::Deserialize<'de> for ParquetSinkExecNode {
+ #[allow(deprecated)]
+ fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
+ where
+ D: serde::Deserializer<'de>,
+ {
+ const FIELDS: &[&str] = &[
+ "input",
+ "sink",
+ "sink_schema",
+ "sinkSchema",
+ "sort_order",
+ "sortOrder",
+ ];
+
+ #[allow(clippy::enum_variant_names)]
+ enum GeneratedField {
+ Input,
+ Sink,
+ SinkSchema,
+ SortOrder,
+ }
+ impl<'de> serde::Deserialize<'de> for GeneratedField {
+ fn deserialize<D>(deserializer: D) -> std::result::Result<GeneratedField, D::Error>
+ where
+ D: serde::Deserializer<'de>,
+ {
+ struct GeneratedVisitor;
+
+ impl<'de> serde::de::Visitor<'de> for GeneratedVisitor {
+ type Value = GeneratedField;
+
+ fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ write!(formatter, "expected one of: {:?}", &FIELDS)
+ }
+
+ #[allow(unused_variables)]
+ fn visit_str<E>(self, value: &str) -> std::result::Result<GeneratedField, E>
+ where
+ E: serde::de::Error,
+ {
+ match value {
+ "input" => Ok(GeneratedField::Input),
+ "sink" => Ok(GeneratedField::Sink),
+ "sinkSchema" | "sink_schema" =>
Ok(GeneratedField::SinkSchema),
+ "sortOrder" | "sort_order" =>
Ok(GeneratedField::SortOrder),
+ _ => Err(serde::de::Error::unknown_field(value,
FIELDS)),
+ }
+ }
+ }
+ deserializer.deserialize_identifier(GeneratedVisitor)
+ }
+ }
+ struct GeneratedVisitor;
+ impl<'de> serde::de::Visitor<'de> for GeneratedVisitor {
+ type Value = ParquetSinkExecNode;
+
+ fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ formatter.write_str("struct datafusion.ParquetSinkExecNode")
+ }
+
+ fn visit_map<V>(self, mut map_: V) -> std::result::Result<ParquetSinkExecNode, V::Error>
+ where
+ V: serde::de::MapAccess<'de>,
+ {
+ let mut input__ = None;
+ let mut sink__ = None;
+ let mut sink_schema__ = None;
+ let mut sort_order__ = None;
+ while let Some(k) = map_.next_key()? {
+ match k {
+ GeneratedField::Input => {
+ if input__.is_some() {
+ return Err(serde::de::Error::duplicate_field("input"));
+ }
+ input__ = map_.next_value()?;
+ }
+ GeneratedField::Sink => {
+ if sink__.is_some() {
+ return Err(serde::de::Error::duplicate_field("sink"));
+ }
+ sink__ = map_.next_value()?;
+ }
+ GeneratedField::SinkSchema => {
+ if sink_schema__.is_some() {
+ return Err(serde::de::Error::duplicate_field("sinkSchema"));
+ }
+ sink_schema__ = map_.next_value()?;
+ }
+ GeneratedField::SortOrder => {
+ if sort_order__.is_some() {
+ return Err(serde::de::Error::duplicate_field("sortOrder"));
+ }
+ sort_order__ = map_.next_value()?;
+ }
+ }
+ }
+ Ok(ParquetSinkExecNode {
+ input: input__,
+ sink: sink__,
+ sink_schema: sink_schema__,
+ sort_order: sort_order__,
+ })
+ }
+ }
+ deserializer.deserialize_struct("datafusion.ParquetSinkExecNode",
FIELDS, GeneratedVisitor)
+ }
+}
impl serde::Serialize for ParquetWriterOptions {
#[allow(deprecated)]
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
@@ -18484,6 +18973,12 @@ impl serde::Serialize for PhysicalPlanNode {
physical_plan_node::PhysicalPlanType::PlaceholderRow(v) => {
struct_ser.serialize_field("placeholderRow", v)?;
}
+ physical_plan_node::PhysicalPlanType::CsvSink(v) => {
+ struct_ser.serialize_field("csvSink", v)?;
+ }
+ physical_plan_node::PhysicalPlanType::ParquetSink(v) => {
+ struct_ser.serialize_field("parquetSink", v)?;
+ }
}
}
struct_ser.end()
@@ -18535,6 +19030,10 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode {
"interleave",
"placeholder_row",
"placeholderRow",
+ "csv_sink",
+ "csvSink",
+ "parquet_sink",
+ "parquetSink",
];
#[allow(clippy::enum_variant_names)]
@@ -18565,6 +19064,8 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode {
SymmetricHashJoin,
Interleave,
PlaceholderRow,
+ CsvSink,
+ ParquetSink,
}
impl<'de> serde::Deserialize<'de> for GeneratedField {
fn deserialize<D>(deserializer: D) -> std::result::Result<GeneratedField, D::Error>
@@ -18612,6 +19113,8 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode {
"symmetricHashJoin" | "symmetric_hash_join" =>
Ok(GeneratedField::SymmetricHashJoin),
"interleave" => Ok(GeneratedField::Interleave),
"placeholderRow" | "placeholder_row" =>
Ok(GeneratedField::PlaceholderRow),
+ "csvSink" | "csv_sink" =>
Ok(GeneratedField::CsvSink),
+ "parquetSink" | "parquet_sink" =>
Ok(GeneratedField::ParquetSink),
_ => Err(serde::de::Error::unknown_field(value,
FIELDS)),
}
}
@@ -18814,6 +19317,20 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode {
return Err(serde::de::Error::duplicate_field("placeholderRow"));
}
physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::PlaceholderRow)
+;
+ }
+ GeneratedField::CsvSink => {
+ if physical_plan_type__.is_some() {
+ return Err(serde::de::Error::duplicate_field("csvSink"));
+ }
+ physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::CsvSink)
+;
+ }
+ GeneratedField::ParquetSink => {
+ if physical_plan_type__.is_some() {
+ return Err(serde::de::Error::duplicate_field("parquetSink"));
+ }
+ physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::ParquetSink)
;
}
}
diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs
index 32e892e663..4ee0b70325 100644
--- a/datafusion/proto/src/generated/prost.rs
+++ b/datafusion/proto/src/generated/prost.rs
@@ -1566,7 +1566,7 @@ pub mod owned_table_reference {
pub struct PhysicalPlanNode {
#[prost(
oneof = "physical_plan_node::PhysicalPlanType",
- tags = "1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
19, 20, 21, 22, 23, 24, 25, 26, 27"
+ tags = "1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29"
)]
pub physical_plan_type:
::core::option::Option<physical_plan_node::PhysicalPlanType>,
}
@@ -1629,6 +1629,10 @@ pub mod physical_plan_node {
Interleave(super::InterleaveExecNode),
#[prost(message, tag = "27")]
PlaceholderRow(super::PlaceholderRowExecNode),
+ #[prost(message, tag = "28")]
+ CsvSink(::prost::alloc::boxed::Box<super::CsvSinkExecNode>),
+ #[prost(message, tag = "29")]
+ ParquetSink(::prost::alloc::boxed::Box<super::ParquetSinkExecNode>),
}
}
#[allow(clippy::derive_partial_eq_without_eq)]
@@ -1673,26 +1677,29 @@ pub struct ParquetWriterOptions {
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct CsvWriterOptions {
+ /// Compression type
+ #[prost(enumeration = "CompressionTypeVariant", tag = "1")]
+ pub compression: i32,
/// Optional column delimiter. Defaults to `b','`
- #[prost(string, tag = "1")]
+ #[prost(string, tag = "2")]
pub delimiter: ::prost::alloc::string::String,
/// Whether to write column names as file headers. Defaults to `true`
- #[prost(bool, tag = "2")]
+ #[prost(bool, tag = "3")]
pub has_header: bool,
/// Optional date format for date arrays
- #[prost(string, tag = "3")]
+ #[prost(string, tag = "4")]
pub date_format: ::prost::alloc::string::String,
/// Optional datetime format for datetime arrays
- #[prost(string, tag = "4")]
+ #[prost(string, tag = "5")]
pub datetime_format: ::prost::alloc::string::String,
/// Optional timestamp format for timestamp arrays
- #[prost(string, tag = "5")]
+ #[prost(string, tag = "6")]
pub timestamp_format: ::prost::alloc::string::String,
/// Optional time format for time arrays
- #[prost(string, tag = "6")]
+ #[prost(string, tag = "7")]
pub time_format: ::prost::alloc::string::String,
/// Optional value to represent null
- #[prost(string, tag = "7")]
+ #[prost(string, tag = "8")]
pub null_value: ::prost::alloc::string::String,
}
#[allow(clippy::derive_partial_eq_without_eq)]
@@ -1753,6 +1760,42 @@ pub struct JsonSinkExecNode {
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct CsvSink {
+ #[prost(message, optional, tag = "1")]
+ pub config: ::core::option::Option<FileSinkConfig>,
+}
+#[allow(clippy::derive_partial_eq_without_eq)]
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct CsvSinkExecNode {
+ #[prost(message, optional, boxed, tag = "1")]
+ pub input: ::core::option::Option<::prost::alloc::boxed::Box<PhysicalPlanNode>>,
+ #[prost(message, optional, tag = "2")]
+ pub sink: ::core::option::Option<CsvSink>,
+ #[prost(message, optional, tag = "3")]
+ pub sink_schema: ::core::option::Option<Schema>,
+ #[prost(message, optional, tag = "4")]
+ pub sort_order: ::core::option::Option<PhysicalSortExprNodeCollection>,
+}
+#[allow(clippy::derive_partial_eq_without_eq)]
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct ParquetSink {
+ #[prost(message, optional, tag = "1")]
+ pub config: ::core::option::Option<FileSinkConfig>,
+}
+#[allow(clippy::derive_partial_eq_without_eq)]
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct ParquetSinkExecNode {
+ #[prost(message, optional, boxed, tag = "1")]
+ pub input: ::core::option::Option<::prost::alloc::boxed::Box<PhysicalPlanNode>>,
+ #[prost(message, optional, tag = "2")]
+ pub sink: ::core::option::Option<ParquetSink>,
+ #[prost(message, optional, tag = "3")]
+ pub sink_schema: ::core::option::Option<Schema>,
+ #[prost(message, optional, tag = "4")]
+ pub sort_order: ::core::option::Option<PhysicalSortExprNodeCollection>,
+}
+#[allow(clippy::derive_partial_eq_without_eq)]
+#[derive(Clone, PartialEq, ::prost::Message)]
pub struct PhysicalExtensionNode {
#[prost(bytes = "vec", tag = "1")]
pub node: ::prost::alloc::vec::Vec<u8>,
diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs
index dbed0252d0..5ee88c3d53 100644
--- a/datafusion/proto/src/logical_plan/mod.rs
+++ b/datafusion/proto/src/logical_plan/mod.rs
@@ -1648,28 +1648,10 @@ impl AsLogicalPlan for LogicalPlanNode {
match opt.as_ref() {
FileTypeWriterOptions::CSV(csv_opts) => {
let csv_options = &csv_opts.writer_options;
- let csv_writer_options = protobuf::CsvWriterOptions {
- delimiter: (csv_options.delimiter() as char)
- .to_string(),
- has_header: csv_options.header(),
- date_format: csv_options
- .date_format()
- .unwrap_or("")
- .to_owned(),
- datetime_format: csv_options
- .datetime_format()
- .unwrap_or("")
- .to_owned(),
- timestamp_format: csv_options
- .timestamp_format()
- .unwrap_or("")
- .to_owned(),
- time_format: csv_options
- .time_format()
- .unwrap_or("")
- .to_owned(),
- null_value: csv_options.null().to_owned(),
- };
+ let csv_writer_options = csv_writer_options_to_proto(
+ csv_options,
+ (&csv_opts.compression).into(),
+ );
let csv_options =
file_type_writer_options::FileType::CsvOptions(
csv_writer_options,
@@ -1724,6 +1706,23 @@ impl AsLogicalPlan for LogicalPlanNode {
}
}
+pub(crate) fn csv_writer_options_to_proto(
+ csv_options: &WriterBuilder,
+ compression: &CompressionTypeVariant,
+) -> protobuf::CsvWriterOptions {
+ let compression: protobuf::CompressionTypeVariant = compression.into();
+ protobuf::CsvWriterOptions {
+ compression: compression.into(),
+ delimiter: (csv_options.delimiter() as char).to_string(),
+ has_header: csv_options.header(),
+ date_format: csv_options.date_format().unwrap_or("").to_owned(),
+ datetime_format: csv_options.datetime_format().unwrap_or("").to_owned(),
+ timestamp_format: csv_options.timestamp_format().unwrap_or("").to_owned(),
+ time_format: csv_options.time_format().unwrap_or("").to_owned(),
+ null_value: csv_options.null().to_owned(),
+ }
+}
+
pub(crate) fn csv_writer_options_from_proto(
writer_options: &protobuf::CsvWriterOptions,
) -> Result<WriterBuilder> {
diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs
index 6f1e811510..8ad6d679df 100644
--- a/datafusion/proto/src/physical_plan/from_proto.rs
+++ b/datafusion/proto/src/physical_plan/from_proto.rs
@@ -22,7 +22,10 @@ use std::sync::Arc;
use arrow::compute::SortOptions;
use datafusion::arrow::datatypes::Schema;
+use datafusion::datasource::file_format::csv::CsvSink;
use datafusion::datasource::file_format::json::JsonSink;
+#[cfg(feature = "parquet")]
+use datafusion::datasource::file_format::parquet::ParquetSink;
use datafusion::datasource::listing::{FileRange, ListingTableUrl, PartitionedFile};
use datafusion::datasource::object_store::ObjectStoreUrl;
use datafusion::datasource::physical_plan::{FileScanConfig, FileSinkConfig};
@@ -713,6 +716,23 @@ impl TryFrom<&protobuf::JsonSink> for JsonSink {
}
}
+#[cfg(feature = "parquet")]
+impl TryFrom<&protobuf::ParquetSink> for ParquetSink {
+ type Error = DataFusionError;
+
+ fn try_from(value: &protobuf::ParquetSink) -> Result<Self, Self::Error> {
+ Ok(Self::new(convert_required!(value.config)?))
+ }
+}
+
+impl TryFrom<&protobuf::CsvSink> for CsvSink {
+ type Error = DataFusionError;
+
+ fn try_from(value: &protobuf::CsvSink) -> Result<Self, Self::Error> {
+ Ok(Self::new(convert_required!(value.config)?))
+ }
+}
+
impl TryFrom<&protobuf::FileSinkConfig> for FileSinkConfig {
type Error = DataFusionError;
@@ -768,16 +788,16 @@ impl TryFrom<&protobuf::FileTypeWriterOptions> for FileTypeWriterOptions {
.file_type
.as_ref()
.ok_or_else(|| proto_error("Missing required file_type field in
protobuf"))?;
+
match file_type {
- protobuf::file_type_writer_options::FileType::JsonOptions(opts) => Ok(
- Self::JSON(JsonWriterOptions::new(opts.compression().into())),
- ),
- protobuf::file_type_writer_options::FileType::CsvOptions(opt) => {
- let write_options = csv_writer_options_from_proto(opt)?;
- Ok(Self::CSV(CsvWriterOptions::new(
- write_options,
- CompressionTypeVariant::UNCOMPRESSED,
- )))
+ protobuf::file_type_writer_options::FileType::JsonOptions(opts) => {
+ let compression: CompressionTypeVariant = opts.compression().into();
+ Ok(Self::JSON(JsonWriterOptions::new(compression)))
+ }
+ protobuf::file_type_writer_options::FileType::CsvOptions(opts) => {
+ let write_options = csv_writer_options_from_proto(opts)?;
+ let compression: CompressionTypeVariant = opts.compression().into();
+ Ok(Self::CSV(CsvWriterOptions::new(write_options, compression)))
}
protobuf::file_type_writer_options::FileType::ParquetOptions(opt) => {
let props = opt.writer_properties.clone().unwrap_or_default();
diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs
index 24ede3fcaf..95becb3fe4 100644
--- a/datafusion/proto/src/physical_plan/mod.rs
+++ b/datafusion/proto/src/physical_plan/mod.rs
@@ -21,9 +21,12 @@ use std::sync::Arc;
use datafusion::arrow::compute::SortOptions;
use datafusion::arrow::datatypes::SchemaRef;
+use datafusion::datasource::file_format::csv::CsvSink;
use datafusion::datasource::file_format::file_compression_type::FileCompressionType;
use datafusion::datasource::file_format::json::JsonSink;
#[cfg(feature = "parquet")]
+use datafusion::datasource::file_format::parquet::ParquetSink;
+#[cfg(feature = "parquet")]
use datafusion::datasource::physical_plan::ParquetExec;
use datafusion::datasource::physical_plan::{AvroExec, CsvExec};
use datafusion::execution::runtime_env::RuntimeEnv;
@@ -921,6 +924,68 @@ impl AsExecutionPlan for PhysicalPlanNode {
sort_order,
)))
}
+ PhysicalPlanType::CsvSink(sink) => {
+ let input =
+ into_physical_plan(&sink.input, registry, runtime, extension_codec)?;
+
+ let data_sink: CsvSink = sink
+ .sink
+ .as_ref()
+ .ok_or_else(|| proto_error("Missing required field in
protobuf"))?
+ .try_into()?;
+ let sink_schema = convert_required!(sink.sink_schema)?;
+ let sort_order = sink
+ .sort_order
+ .as_ref()
+ .map(|collection| {
+ collection
+ .physical_sort_expr_nodes
+ .iter()
+ .map(|proto| {
+ parse_physical_sort_expr(proto, registry, &sink_schema)
+ .map(Into::into)
+ })
+ .collect::<Result<Vec<_>>>()
+ })
+ .transpose()?;
+ Ok(Arc::new(FileSinkExec::new(
+ input,
+ Arc::new(data_sink),
+ Arc::new(sink_schema),
+ sort_order,
+ )))
+ }
+ PhysicalPlanType::ParquetSink(sink) => {
+ let input =
+ into_physical_plan(&sink.input, registry, runtime, extension_codec)?;
+
+ let data_sink: ParquetSink = sink
+ .sink
+ .as_ref()
+ .ok_or_else(|| proto_error("Missing required field in
protobuf"))?
+ .try_into()?;
+ let sink_schema = convert_required!(sink.sink_schema)?;
+ let sort_order = sink
+ .sort_order
+ .as_ref()
+ .map(|collection| {
+ collection
+ .physical_sort_expr_nodes
+ .iter()
+ .map(|proto| {
+ parse_physical_sort_expr(proto, registry, &sink_schema)
+ .map(Into::into)
+ })
+ .collect::<Result<Vec<_>>>()
+ })
+ .transpose()?;
+ Ok(Arc::new(FileSinkExec::new(
+ input,
+ Arc::new(data_sink),
+ Arc::new(sink_schema),
+ sort_order,
+ )))
+ }
}
}
@@ -1678,6 +1743,32 @@ impl AsExecutionPlan for PhysicalPlanNode {
});
}
+ if let Some(sink) = exec.sink().as_any().downcast_ref::<CsvSink>() {
+ return Ok(protobuf::PhysicalPlanNode {
+ physical_plan_type: Some(PhysicalPlanType::CsvSink(Box::new(
+ protobuf::CsvSinkExecNode {
+ input: Some(Box::new(input)),
+ sink: Some(sink.try_into()?),
+ sink_schema: Some(exec.schema().as_ref().try_into()?),
+ sort_order,
+ },
+ ))),
+ });
+ }
+
+ if let Some(sink) = exec.sink().as_any().downcast_ref::<ParquetSink>() {
+ return Ok(protobuf::PhysicalPlanNode {
+ physical_plan_type: Some(PhysicalPlanType::ParquetSink(Box::new(
+ protobuf::ParquetSinkExecNode {
+ input: Some(Box::new(input)),
+ sink: Some(sink.try_into()?),
+ sink_schema: Some(exec.schema().as_ref().try_into()?),
+ sort_order,
+ },
+ ))),
+ });
+ }
+
// If unknown DataSink then let extension handle it
}
diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs
index e9cdb34cf1..f4e3f9e4dc 100644
--- a/datafusion/proto/src/physical_plan/to_proto.rs
+++ b/datafusion/proto/src/physical_plan/to_proto.rs
@@ -28,7 +28,12 @@ use crate::protobuf::{
ScalarValue,
};
+#[cfg(feature = "parquet")]
+use datafusion::datasource::file_format::parquet::ParquetSink;
+
+use crate::logical_plan::{csv_writer_options_to_proto, writer_properties_to_proto};
use datafusion::datasource::{
+ file_format::csv::CsvSink,
file_format::json::JsonSink,
listing::{FileRange, PartitionedFile},
physical_plan::FileScanConfig,
@@ -814,6 +819,27 @@ impl TryFrom<&JsonSink> for protobuf::JsonSink {
}
}
+impl TryFrom<&CsvSink> for protobuf::CsvSink {
+ type Error = DataFusionError;
+
+ fn try_from(value: &CsvSink) -> Result<Self, Self::Error> {
+ Ok(Self {
+ config: Some(value.config().try_into()?),
+ })
+ }
+}
+
+#[cfg(feature = "parquet")]
+impl TryFrom<&ParquetSink> for protobuf::ParquetSink {
+ type Error = DataFusionError;
+
+ fn try_from(value: &ParquetSink) -> Result<Self, Self::Error> {
+ Ok(Self {
+ config: Some(value.config().try_into()?),
+ })
+ }
+}
+
impl TryFrom<&FileSinkConfig> for protobuf::FileSinkConfig {
type Error = DataFusionError;
@@ -870,13 +896,21 @@ impl TryFrom<&FileTypeWriterOptions> for protobuf::FileTypeWriterOptions {
fn try_from(opts: &FileTypeWriterOptions) -> Result<Self, Self::Error> {
let file_type = match opts {
#[cfg(feature = "parquet")]
- FileTypeWriterOptions::Parquet(ParquetWriterOptions {
- writer_options: _,
- }) => return not_impl_err!("Parquet file sink protobuf
serialization"),
+ FileTypeWriterOptions::Parquet(ParquetWriterOptions {
writer_options }) => {
+ protobuf::file_type_writer_options::FileType::ParquetOptions(
+ protobuf::ParquetWriterOptions {
+ writer_properties: Some(writer_properties_to_proto(
+ writer_options,
+ )),
+ },
+ )
+ }
FileTypeWriterOptions::CSV(CsvWriterOptions {
- writer_options: _,
- compression: _,
- }) => return not_impl_err!("CSV file sink protobuf serialization"),
+ writer_options,
+ compression,
+ }) => protobuf::file_type_writer_options::FileType::CsvOptions(
+ csv_writer_options_to_proto(writer_options, compression),
+ ),
FileTypeWriterOptions::JSON(JsonWriterOptions { compression }) => {
let compression: protobuf::CompressionTypeVariant = compression.into();
protobuf::file_type_writer_options::FileType::JsonOptions(
diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
index 2eb04ab6cb..27ac5d122f 100644
--- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
+++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
@@ -15,13 +15,16 @@
// specific language governing permissions and limitations
// under the License.
+use arrow::csv::WriterBuilder;
use std::ops::Deref;
use std::sync::Arc;
use datafusion::arrow::array::ArrayRef;
use datafusion::arrow::compute::kernels::sort::SortOptions;
use datafusion::arrow::datatypes::{DataType, Field, Fields, IntervalUnit, Schema};
+use datafusion::datasource::file_format::csv::CsvSink;
use datafusion::datasource::file_format::json::JsonSink;
+use datafusion::datasource::file_format::parquet::ParquetSink;
use datafusion::datasource::listing::{ListingTableUrl, PartitionedFile};
use datafusion::datasource::object_store::ObjectStoreUrl;
use datafusion::datasource::physical_plan::{
@@ -31,6 +34,7 @@ use datafusion::execution::context::ExecutionProps;
use datafusion::logical_expr::{
create_udf, BuiltinScalarFunction, JoinType, Operator, Volatility,
};
+use datafusion::parquet::file::properties::WriterProperties;
use datafusion::physical_expr::window::SlidingAggregateWindowExpr;
use datafusion::physical_expr::{PhysicalSortRequirement, ScalarFunctionExpr};
use datafusion::physical_plan::aggregates::{
@@ -62,7 +66,9 @@ use datafusion::physical_plan::{
};
use datafusion::prelude::SessionContext;
use datafusion::scalar::ScalarValue;
+use datafusion_common::file_options::csv_writer::CsvWriterOptions;
use datafusion_common::file_options::json_writer::JsonWriterOptions;
+use datafusion_common::file_options::parquet_writer::ParquetWriterOptions;
use datafusion_common::parsers::CompressionTypeVariant;
use datafusion_common::stats::Precision;
use datafusion_common::{FileTypeWriterOptions, Result};
@@ -73,7 +79,23 @@ use datafusion_expr::{
use datafusion_proto::physical_plan::{AsExecutionPlan, DefaultPhysicalExtensionCodec};
use datafusion_proto::protobuf;
+/// Perform a serde roundtrip and assert that the string representation of the before and after plans
+/// are identical. Note that this often isn't sufficient to guarantee that no information is
+/// lost during serde because the string representation of a plan often only shows a subset of state.
fn roundtrip_test(exec_plan: Arc<dyn ExecutionPlan>) -> Result<()> {
+ let _ = roundtrip_test_and_return(exec_plan);
+ Ok(())
+}
+
+/// Perform a serde roundtrip and assert that the string representation of the before and after plans
+/// are identical. Note that this often isn't sufficient to guarantee that no information is
+/// lost during serde because the string representation of a plan often only shows a subset of state.
+///
+/// This version of the roundtrip_test method returns the final plan after serde so that it can be inspected
+/// farther in tests.
+fn roundtrip_test_and_return(
+ exec_plan: Arc<dyn ExecutionPlan>,
+) -> Result<Arc<dyn ExecutionPlan>> {
let ctx = SessionContext::new();
let codec = DefaultPhysicalExtensionCodec {};
let proto: protobuf::PhysicalPlanNode =
@@ -84,9 +106,15 @@ fn roundtrip_test(exec_plan: Arc<dyn ExecutionPlan>) -> Result<()> {
.try_into_physical_plan(&ctx, runtime.deref(), &codec)
.expect("from proto");
assert_eq!(format!("{exec_plan:?}"), format!("{result_exec_plan:?}"));
- Ok(())
+ Ok(result_exec_plan)
}
+/// Perform a serde roundtrip and assert that the string representation of the before and after plans
+/// are identical. Note that this often isn't sufficient to guarantee that no information is
+/// lost during serde because the string representation of a plan often only shows a subset of state.
+///
+/// This version of the roundtrip_test function accepts a SessionContext, which is required when
+/// performing serde on some plans.
fn roundtrip_test_with_context(
exec_plan: Arc<dyn ExecutionPlan>,
ctx: SessionContext,
@@ -755,6 +783,101 @@ fn roundtrip_json_sink() -> Result<()> {
)))
}
+#[test]
+fn roundtrip_csv_sink() -> Result<()> {
+ let field_a = Field::new("plan_type", DataType::Utf8, false);
+ let field_b = Field::new("plan", DataType::Utf8, false);
+ let schema = Arc::new(Schema::new(vec![field_a, field_b]));
+ let input = Arc::new(PlaceholderRowExec::new(schema.clone()));
+
+ let file_sink_config = FileSinkConfig {
+ object_store_url: ObjectStoreUrl::local_filesystem(),
+ file_groups: vec![PartitionedFile::new("/tmp".to_string(), 1)],
+ table_paths: vec![ListingTableUrl::parse("file:///")?],
+ output_schema: schema.clone(),
+ table_partition_cols: vec![("plan_type".to_string(), DataType::Utf8)],
+ single_file_output: true,
+ overwrite: true,
+ file_type_writer_options: FileTypeWriterOptions::CSV(CsvWriterOptions::new(
+ WriterBuilder::default(),
+ CompressionTypeVariant::ZSTD,
+ )),
+ };
+ let data_sink = Arc::new(CsvSink::new(file_sink_config));
+ let sort_order = vec![PhysicalSortRequirement::new(
+ Arc::new(Column::new("plan_type", 0)),
+ Some(SortOptions {
+ descending: true,
+ nulls_first: false,
+ }),
+ )];
+
+ let roundtrip_plan = roundtrip_test_and_return(Arc::new(FileSinkExec::new(
+ input,
+ data_sink,
+ schema.clone(),
+ Some(sort_order),
+ )))
+ .unwrap();
+
+ let roundtrip_plan = roundtrip_plan
+ .as_any()
+ .downcast_ref::<FileSinkExec>()
+ .unwrap();
+ let csv_sink = roundtrip_plan
+ .sink()
+ .as_any()
+ .downcast_ref::<CsvSink>()
+ .unwrap();
+ assert_eq!(
+ CompressionTypeVariant::ZSTD,
+ csv_sink
+ .config()
+ .file_type_writer_options
+ .try_into_csv()
+ .unwrap()
+ .compression
+ );
+
+ Ok(())
+}
+
+#[test]
+fn roundtrip_parquet_sink() -> Result<()> {
+ let field_a = Field::new("plan_type", DataType::Utf8, false);
+ let field_b = Field::new("plan", DataType::Utf8, false);
+ let schema = Arc::new(Schema::new(vec![field_a, field_b]));
+ let input = Arc::new(PlaceholderRowExec::new(schema.clone()));
+
+ let file_sink_config = FileSinkConfig {
+ object_store_url: ObjectStoreUrl::local_filesystem(),
+ file_groups: vec![PartitionedFile::new("/tmp".to_string(), 1)],
+ table_paths: vec![ListingTableUrl::parse("file:///")?],
+ output_schema: schema.clone(),
+ table_partition_cols: vec![("plan_type".to_string(), DataType::Utf8)],
+ single_file_output: true,
+ overwrite: true,
+ file_type_writer_options: FileTypeWriterOptions::Parquet(
+ ParquetWriterOptions::new(WriterProperties::default()),
+ ),
+ };
+ let data_sink = Arc::new(ParquetSink::new(file_sink_config));
+ let sort_order = vec![PhysicalSortRequirement::new(
+ Arc::new(Column::new("plan_type", 0)),
+ Some(SortOptions {
+ descending: true,
+ nulls_first: false,
+ }),
+ )];
+
+ roundtrip_test(Arc::new(FileSinkExec::new(
+ input,
+ data_sink,
+ schema.clone(),
+ Some(sort_order),
+ )))
+}
+
#[test]
fn roundtrip_sym_hash_join() -> Result<()> {
let field_a = Field::new("col", DataType::Int64, false);