This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 71ff27c05a feat: preserve metadata for Field and Schema in proto (#6865)
71ff27c05a is described below
commit 71ff27c05ac0b2492a2e4618dbe2288020ee8fea
Author: Jonah Gao <[email protected]>
AuthorDate: Fri Jul 7 04:34:17 2023 +0800
feat: preserve metadata for Field and Schema in proto (#6865)
---
datafusion/proto/proto/datafusion.proto | 2 ++
datafusion/proto/src/generated/pbjson.rs | 38 +++++++++++++++++++++++++
datafusion/proto/src/generated/prost.rs | 10 +++++++
datafusion/proto/src/logical_plan/from_proto.rs | 18 +++---------
datafusion/proto/src/logical_plan/mod.rs | 31 ++++++++++++++++++++
datafusion/proto/src/logical_plan/to_proto.rs | 3 ++
6 files changed, 88 insertions(+), 14 deletions(-)
diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto
index 00fa28906c..528c675570 100644
--- a/datafusion/proto/proto/datafusion.proto
+++ b/datafusion/proto/proto/datafusion.proto
@@ -740,6 +740,7 @@ message WindowFrameBound {
message Schema {
repeated Field columns = 1;
+ map<string, string> metadata = 2;
}
message Field {
@@ -749,6 +750,7 @@ message Field {
bool nullable = 3;
// for complex data types like structs, unions
repeated Field children = 4;
+ map<string, string> metadata = 5;
}
message FixedSizeBinary{
diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs
index 63303fc32c..d6a770159b 100644
--- a/datafusion/proto/src/generated/pbjson.rs
+++ b/datafusion/proto/src/generated/pbjson.rs
@@ -6049,6 +6049,9 @@ impl serde::Serialize for Field {
if !self.children.is_empty() {
len += 1;
}
+ if !self.metadata.is_empty() {
+ len += 1;
+ }
let mut struct_ser = serializer.serialize_struct("datafusion.Field",
len)?;
if !self.name.is_empty() {
struct_ser.serialize_field("name", &self.name)?;
@@ -6062,6 +6065,9 @@ impl serde::Serialize for Field {
if !self.children.is_empty() {
struct_ser.serialize_field("children", &self.children)?;
}
+ if !self.metadata.is_empty() {
+ struct_ser.serialize_field("metadata", &self.metadata)?;
+ }
struct_ser.end()
}
}
@@ -6077,6 +6083,7 @@ impl<'de> serde::Deserialize<'de> for Field {
"arrowType",
"nullable",
"children",
+ "metadata",
];
#[allow(clippy::enum_variant_names)]
@@ -6085,6 +6092,7 @@ impl<'de> serde::Deserialize<'de> for Field {
ArrowType,
Nullable,
Children,
+ Metadata,
}
impl<'de> serde::Deserialize<'de> for GeneratedField {
fn deserialize<D>(deserializer: D) ->
std::result::Result<GeneratedField, D::Error>
@@ -6110,6 +6118,7 @@ impl<'de> serde::Deserialize<'de> for Field {
"arrowType" | "arrow_type" =>
Ok(GeneratedField::ArrowType),
"nullable" => Ok(GeneratedField::Nullable),
"children" => Ok(GeneratedField::Children),
+ "metadata" => Ok(GeneratedField::Metadata),
_ => Err(serde::de::Error::unknown_field(value,
FIELDS)),
}
}
@@ -6133,6 +6142,7 @@ impl<'de> serde::Deserialize<'de> for Field {
let mut arrow_type__ = None;
let mut nullable__ = None;
let mut children__ = None;
+ let mut metadata__ = None;
while let Some(k) = map.next_key()? {
match k {
GeneratedField::Name => {
@@ -6159,6 +6169,14 @@ impl<'de> serde::Deserialize<'de> for Field {
}
children__ = Some(map.next_value()?);
}
+ GeneratedField::Metadata => {
+ if metadata__.is_some() {
+ return
Err(serde::de::Error::duplicate_field("metadata"));
+ }
+ metadata__ = Some(
+ map.next_value::<std::collections::HashMap<_,
_>>()?
+ );
+ }
}
}
Ok(Field {
@@ -6166,6 +6184,7 @@ impl<'de> serde::Deserialize<'de> for Field {
arrow_type: arrow_type__,
nullable: nullable__.unwrap_or_default(),
children: children__.unwrap_or_default(),
+ metadata: metadata__.unwrap_or_default(),
})
}
}
@@ -19493,10 +19512,16 @@ impl serde::Serialize for Schema {
if !self.columns.is_empty() {
len += 1;
}
+ if !self.metadata.is_empty() {
+ len += 1;
+ }
let mut struct_ser = serializer.serialize_struct("datafusion.Schema",
len)?;
if !self.columns.is_empty() {
struct_ser.serialize_field("columns", &self.columns)?;
}
+ if !self.metadata.is_empty() {
+ struct_ser.serialize_field("metadata", &self.metadata)?;
+ }
struct_ser.end()
}
}
@@ -19508,11 +19533,13 @@ impl<'de> serde::Deserialize<'de> for Schema {
{
const FIELDS: &[&str] = &[
"columns",
+ "metadata",
];
#[allow(clippy::enum_variant_names)]
enum GeneratedField {
Columns,
+ Metadata,
}
impl<'de> serde::Deserialize<'de> for GeneratedField {
fn deserialize<D>(deserializer: D) ->
std::result::Result<GeneratedField, D::Error>
@@ -19535,6 +19562,7 @@ impl<'de> serde::Deserialize<'de> for Schema {
{
match value {
"columns" => Ok(GeneratedField::Columns),
+ "metadata" => Ok(GeneratedField::Metadata),
_ => Err(serde::de::Error::unknown_field(value,
FIELDS)),
}
}
@@ -19555,6 +19583,7 @@ impl<'de> serde::Deserialize<'de> for Schema {
V: serde::de::MapAccess<'de>,
{
let mut columns__ = None;
+ let mut metadata__ = None;
while let Some(k) = map.next_key()? {
match k {
GeneratedField::Columns => {
@@ -19563,10 +19592,19 @@ impl<'de> serde::Deserialize<'de> for Schema {
}
columns__ = Some(map.next_value()?);
}
+ GeneratedField::Metadata => {
+ if metadata__.is_some() {
+ return
Err(serde::de::Error::duplicate_field("metadata"));
+ }
+ metadata__ = Some(
+ map.next_value::<std::collections::HashMap<_,
_>>()?
+ );
+ }
}
}
Ok(Schema {
columns: columns__.unwrap_or_default(),
+ metadata: metadata__.unwrap_or_default(),
})
}
}
diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs
index 00eea4d6ed..4e91fbab19 100644
--- a/datafusion/proto/src/generated/prost.rs
+++ b/datafusion/proto/src/generated/prost.rs
@@ -893,6 +893,11 @@ pub struct WindowFrameBound {
pub struct Schema {
#[prost(message, repeated, tag = "1")]
pub columns: ::prost::alloc::vec::Vec<Field>,
+ #[prost(map = "string, string", tag = "2")]
+ pub metadata: ::std::collections::HashMap<
+ ::prost::alloc::string::String,
+ ::prost::alloc::string::String,
+ >,
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
@@ -907,6 +912,11 @@ pub struct Field {
/// for complex data types like structs, unions
#[prost(message, repeated, tag = "4")]
pub children: ::prost::alloc::vec::Vec<Field>,
+ #[prost(map = "string, string", tag = "5")]
+ pub metadata: ::std::collections::HashMap<
+ ::prost::alloc::string::String,
+ ::prost::alloc::string::String,
+ >,
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs
index 4e2f59a118..1b48364ad4 100644
--- a/datafusion/proto/src/logical_plan/from_proto.rs
+++ b/datafusion/proto/src/logical_plan/from_proto.rs
@@ -365,8 +365,8 @@ impl TryFrom<&protobuf::Field> for Field {
type Error = Error;
fn try_from(field: &protobuf::Field) -> Result<Self, Self::Error> {
let datatype = field.arrow_type.as_deref().required("arrow_type")?;
-
- Ok(Self::new(field.name.as_str(), datatype, field.nullable))
+ Ok(Self::new(field.name.as_str(), datatype, field.nullable)
+ .with_metadata(field.metadata.clone()))
}
}
@@ -581,19 +581,9 @@ impl TryFrom<&protobuf::Schema> for Schema {
let fields = schema
.columns
.iter()
- .map(|c| {
- let pb_arrow_type_res = c
- .arrow_type
- .as_ref()
- .ok_or_else(|| proto_error("Protobuf deserialization
error: Field message was missing required field 'arrow_type'"));
- let pb_arrow_type: &protobuf::ArrowType = match
pb_arrow_type_res {
- Ok(res) => res,
- Err(e) => return Err(e),
- };
- Ok(Field::new(&c.name, pb_arrow_type.try_into()?, c.nullable))
- })
+ .map(Field::try_from)
.collect::<Result<Vec<_>, _>>()?;
- Ok(Self::new(fields))
+ Ok(Self::new_with_metadata(fields, schema.metadata.clone()))
}
}
diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs
index 7d0ddac484..ea293067b7 100644
--- a/datafusion/proto/src/logical_plan/mod.rs
+++ b/datafusion/proto/src/logical_plan/mod.rs
@@ -2341,6 +2341,37 @@ mod roundtrip_tests {
}
}
+ #[test]
+ fn roundtrip_field() {
+ let field =
+ Field::new("f", DataType::Int32,
true).with_metadata(HashMap::from([
+ (String::from("k1"), String::from("v1")),
+ (String::from("k2"), String::from("v2")),
+ ]));
+ let proto_field: super::protobuf::Field = (&field).try_into().unwrap();
+ let returned_field: Field = (&proto_field).try_into().unwrap();
+ assert_eq!(field, returned_field);
+ }
+
+ #[test]
+ fn roundtrip_schema() {
+ let schema = Schema::new_with_metadata(
+ vec![
+ Field::new("a", DataType::Int64, false),
+ Field::new("b", DataType::Decimal128(15, 2),
true).with_metadata(
+ HashMap::from([(String::from("k1"), String::from("v1"))]),
+ ),
+ ],
+ HashMap::from([
+ (String::from("k2"), String::from("v2")),
+ (String::from("k3"), String::from("v3")),
+ ]),
+ );
+ let proto_schema: super::protobuf::Schema =
(&schema).try_into().unwrap();
+ let returned_schema: Schema = (&proto_schema).try_into().unwrap();
+ assert_eq!(schema, returned_schema);
+ }
+
#[test]
fn roundtrip_not() {
let test_expr = Expr::Not(Box::new(lit(1.0_f32)));
diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs
index 4a4b16db80..8665ca00c3 100644
--- a/datafusion/proto/src/logical_plan/to_proto.rs
+++ b/datafusion/proto/src/logical_plan/to_proto.rs
@@ -117,6 +117,7 @@ impl TryFrom<&Field> for protobuf::Field {
arrow_type: Some(Box::new(arrow_type)),
nullable: field.is_nullable(),
children: Vec::new(),
+ metadata: field.metadata().clone(),
})
}
}
@@ -266,6 +267,7 @@ impl TryFrom<&Schema> for protobuf::Schema {
.iter()
.map(|f| f.as_ref().try_into())
.collect::<Result<Vec<_>, Error>>()?,
+ metadata: schema.metadata.clone(),
})
}
}
@@ -280,6 +282,7 @@ impl TryFrom<SchemaRef> for protobuf::Schema {
.iter()
.map(|f| f.as_ref().try_into())
.collect::<Result<Vec<_>, Error>>()?,
+ metadata: schema.metadata.clone(),
})
}
}