This is an automated email from the ASF dual-hosted git repository.

thinkharderdev pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new e7858ff0ab Handle dictionary values in ScalarValue serde (#10563)
e7858ff0ab is described below

commit e7858ff0ab1c282ab46bd93cabc3dc83db583165
Author: Dan Harris <[email protected]>
AuthorDate: Fri May 17 17:09:55 2024 -0400

    Handle dictionary values in ScalarValue serde (#10563)
    
    * Handle dictionary values in ScalarValue serde
    
    * Do not panic on failed physical expr decoding (#241)
    
    * revert clippy change
---
 datafusion/proto/proto/datafusion.proto            |   6 +
 datafusion/proto/src/generated/pbjson.rs           | 133 +++++++++++++++++++++
 datafusion/proto/src/generated/prost.rs            |  13 ++
 datafusion/proto/src/logical_plan/from_proto.rs    |  49 +++++++-
 datafusion/proto/src/logical_plan/to_proto.rs      |   9 +-
 datafusion/proto/src/physical_plan/mod.rs          |   4 +-
 .../proto/tests/cases/roundtrip_logical_plan.rs    |  49 ++++++++
 7 files changed, 259 insertions(+), 4 deletions(-)

diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto
index fd79345275..8d69b0bad5 100644
--- a/datafusion/proto/proto/datafusion.proto
+++ b/datafusion/proto/proto/datafusion.proto
@@ -797,9 +797,15 @@ message Union{
 
 // Used for List/FixedSizeList/LargeList/Struct
 message ScalarNestedValue {
+  message Dictionary {
+    bytes ipc_message = 1;
+    bytes arrow_data = 2;
+  }
+
   bytes ipc_message = 1;
   bytes arrow_data = 2;
   Schema schema = 3;
+  repeated Dictionary dictionaries = 4;
 }
 
 message ScalarTime32Value {
diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs
index 01d9a6e0dd..8df0aeb851 100644
--- a/datafusion/proto/src/generated/pbjson.rs
+++ b/datafusion/proto/src/generated/pbjson.rs
@@ -22605,6 +22605,9 @@ impl serde::Serialize for ScalarNestedValue {
         if self.schema.is_some() {
             len += 1;
         }
+        if !self.dictionaries.is_empty() {
+            len += 1;
+        }
         let mut struct_ser = 
serializer.serialize_struct("datafusion.ScalarNestedValue", len)?;
         if !self.ipc_message.is_empty() {
             #[allow(clippy::needless_borrow)]
@@ -22617,6 +22620,9 @@ impl serde::Serialize for ScalarNestedValue {
         if let Some(v) = self.schema.as_ref() {
             struct_ser.serialize_field("schema", v)?;
         }
+        if !self.dictionaries.is_empty() {
+            struct_ser.serialize_field("dictionaries", &self.dictionaries)?;
+        }
         struct_ser.end()
     }
 }
@@ -22632,6 +22638,7 @@ impl<'de> serde::Deserialize<'de> for ScalarNestedValue 
{
             "arrow_data",
             "arrowData",
             "schema",
+            "dictionaries",
         ];
 
         #[allow(clippy::enum_variant_names)]
@@ -22639,6 +22646,7 @@ impl<'de> serde::Deserialize<'de> for ScalarNestedValue 
{
             IpcMessage,
             ArrowData,
             Schema,
+            Dictionaries,
         }
         impl<'de> serde::Deserialize<'de> for GeneratedField {
             fn deserialize<D>(deserializer: D) -> 
std::result::Result<GeneratedField, D::Error>
@@ -22663,6 +22671,7 @@ impl<'de> serde::Deserialize<'de> for ScalarNestedValue 
{
                             "ipcMessage" | "ipc_message" => 
Ok(GeneratedField::IpcMessage),
                             "arrowData" | "arrow_data" => 
Ok(GeneratedField::ArrowData),
                             "schema" => Ok(GeneratedField::Schema),
+                            "dictionaries" => Ok(GeneratedField::Dictionaries),
                             _ => Err(serde::de::Error::unknown_field(value, 
FIELDS)),
                         }
                     }
@@ -22685,6 +22694,7 @@ impl<'de> serde::Deserialize<'de> for ScalarNestedValue 
{
                 let mut ipc_message__ = None;
                 let mut arrow_data__ = None;
                 let mut schema__ = None;
+                let mut dictionaries__ = None;
                 while let Some(k) = map_.next_key()? {
                     match k {
                         GeneratedField::IpcMessage => {
@@ -22709,18 +22719,141 @@ impl<'de> serde::Deserialize<'de> for 
ScalarNestedValue {
                             }
                             schema__ = map_.next_value()?;
                         }
+                        GeneratedField::Dictionaries => {
+                            if dictionaries__.is_some() {
+                                return 
Err(serde::de::Error::duplicate_field("dictionaries"));
+                            }
+                            dictionaries__ = Some(map_.next_value()?);
+                        }
                     }
                 }
                 Ok(ScalarNestedValue {
                     ipc_message: ipc_message__.unwrap_or_default(),
                     arrow_data: arrow_data__.unwrap_or_default(),
                     schema: schema__,
+                    dictionaries: dictionaries__.unwrap_or_default(),
                 })
             }
         }
         deserializer.deserialize_struct("datafusion.ScalarNestedValue", 
FIELDS, GeneratedVisitor)
     }
 }
+impl serde::Serialize for scalar_nested_value::Dictionary {
+    #[allow(deprecated)]
+    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, 
S::Error>
+    where
+        S: serde::Serializer,
+    {
+        use serde::ser::SerializeStruct;
+        let mut len = 0;
+        if !self.ipc_message.is_empty() {
+            len += 1;
+        }
+        if !self.arrow_data.is_empty() {
+            len += 1;
+        }
+        let mut struct_ser = 
serializer.serialize_struct("datafusion.ScalarNestedValue.Dictionary", len)?;
+        if !self.ipc_message.is_empty() {
+            #[allow(clippy::needless_borrow)]
+            struct_ser.serialize_field("ipcMessage", 
pbjson::private::base64::encode(&self.ipc_message).as_str())?;
+        }
+        if !self.arrow_data.is_empty() {
+            #[allow(clippy::needless_borrow)]
+            struct_ser.serialize_field("arrowData", 
pbjson::private::base64::encode(&self.arrow_data).as_str())?;
+        }
+        struct_ser.end()
+    }
+}
+impl<'de> serde::Deserialize<'de> for scalar_nested_value::Dictionary {
+    #[allow(deprecated)]
+    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
+    where
+        D: serde::Deserializer<'de>,
+    {
+        const FIELDS: &[&str] = &[
+            "ipc_message",
+            "ipcMessage",
+            "arrow_data",
+            "arrowData",
+        ];
+
+        #[allow(clippy::enum_variant_names)]
+        enum GeneratedField {
+            IpcMessage,
+            ArrowData,
+        }
+        impl<'de> serde::Deserialize<'de> for GeneratedField {
+            fn deserialize<D>(deserializer: D) -> 
std::result::Result<GeneratedField, D::Error>
+            where
+                D: serde::Deserializer<'de>,
+            {
+                struct GeneratedVisitor;
+
+                impl<'de> serde::de::Visitor<'de> for GeneratedVisitor {
+                    type Value = GeneratedField;
+
+                    fn expecting(&self, formatter: &mut 
std::fmt::Formatter<'_>) -> std::fmt::Result {
+                        write!(formatter, "expected one of: {:?}", &FIELDS)
+                    }
+
+                    #[allow(unused_variables)]
+                    fn visit_str<E>(self, value: &str) -> 
std::result::Result<GeneratedField, E>
+                    where
+                        E: serde::de::Error,
+                    {
+                        match value {
+                            "ipcMessage" | "ipc_message" => 
Ok(GeneratedField::IpcMessage),
+                            "arrowData" | "arrow_data" => 
Ok(GeneratedField::ArrowData),
+                            _ => Err(serde::de::Error::unknown_field(value, 
FIELDS)),
+                        }
+                    }
+                }
+                deserializer.deserialize_identifier(GeneratedVisitor)
+            }
+        }
+        struct GeneratedVisitor;
+        impl<'de> serde::de::Visitor<'de> for GeneratedVisitor {
+            type Value = scalar_nested_value::Dictionary;
+
+            fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> 
std::fmt::Result {
+                formatter.write_str("struct 
datafusion.ScalarNestedValue.Dictionary")
+            }
+
+            fn visit_map<V>(self, mut map_: V) -> 
std::result::Result<scalar_nested_value::Dictionary, V::Error>
+                where
+                    V: serde::de::MapAccess<'de>,
+            {
+                let mut ipc_message__ = None;
+                let mut arrow_data__ = None;
+                while let Some(k) = map_.next_key()? {
+                    match k {
+                        GeneratedField::IpcMessage => {
+                            if ipc_message__.is_some() {
+                                return 
Err(serde::de::Error::duplicate_field("ipcMessage"));
+                            }
+                            ipc_message__ = 
+                                
Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0)
+                            ;
+                        }
+                        GeneratedField::ArrowData => {
+                            if arrow_data__.is_some() {
+                                return 
Err(serde::de::Error::duplicate_field("arrowData"));
+                            }
+                            arrow_data__ = 
+                                
Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0)
+                            ;
+                        }
+                    }
+                }
+                Ok(scalar_nested_value::Dictionary {
+                    ipc_message: ipc_message__.unwrap_or_default(),
+                    arrow_data: arrow_data__.unwrap_or_default(),
+                })
+            }
+        }
+        
deserializer.deserialize_struct("datafusion.ScalarNestedValue.Dictionary", 
FIELDS, GeneratedVisitor)
+    }
+}
 impl serde::Serialize for ScalarTime32Value {
     #[allow(deprecated)]
     fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, 
S::Error>
diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs
index 64e72ba038..b6b7687e6c 100644
--- a/datafusion/proto/src/generated/prost.rs
+++ b/datafusion/proto/src/generated/prost.rs
@@ -1133,6 +1133,19 @@ pub struct ScalarNestedValue {
     pub arrow_data: ::prost::alloc::vec::Vec<u8>,
     #[prost(message, optional, tag = "3")]
     pub schema: ::core::option::Option<Schema>,
+    #[prost(message, repeated, tag = "4")]
+    pub dictionaries: 
::prost::alloc::vec::Vec<scalar_nested_value::Dictionary>,
+}
+/// Nested message and enum types in `ScalarNestedValue`.
+pub mod scalar_nested_value {
+    #[allow(clippy::derive_partial_eq_without_eq)]
+    #[derive(Clone, PartialEq, ::prost::Message)]
+    pub struct Dictionary {
+        #[prost(bytes = "vec", tag = "1")]
+        pub ipc_message: ::prost::alloc::vec::Vec<u8>,
+        #[prost(bytes = "vec", tag = "2")]
+        pub arrow_data: ::prost::alloc::vec::Vec<u8>,
+    }
 }
 #[allow(clippy::derive_partial_eq_without_eq)]
 #[derive(Clone, PartialEq, ::prost::Message)]
diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs
index 585bcad7f3..5df8eb59e1 100644
--- a/datafusion/proto/src/logical_plan/from_proto.rs
+++ b/datafusion/proto/src/logical_plan/from_proto.rs
@@ -15,8 +15,10 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use std::collections::HashMap;
 use std::sync::Arc;
 
+use arrow::array::ArrayRef;
 use arrow::{
     array::AsArray,
     buffer::Buffer,
@@ -522,6 +524,7 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue {
                 let protobuf::ScalarNestedValue {
                     ipc_message,
                     arrow_data,
+                    dictionaries,
                     schema,
                 } = &v;
 
@@ -548,11 +551,55 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue {
                     )
                 })?;
 
+                let dict_by_id: HashMap<i64,ArrayRef> = 
dictionaries.iter().map(|protobuf::scalar_nested_value::Dictionary { 
ipc_message, arrow_data }| {
+                    let message = 
root_as_message(ipc_message.as_slice()).map_err(|e| {
+                        Error::General(format!(
+                            "Error IPC message while deserializing 
ScalarValue::List dictionary message: {e}"
+                        ))
+                    })?;
+                    let buffer = Buffer::from(arrow_data);
+
+                    let dict_batch = 
message.header_as_dictionary_batch().ok_or_else(|| {
+                        Error::General(
+                            "Unexpected message type deserializing 
ScalarValue::List dictionary message"
+                                .to_string(),
+                        )
+                    })?;
+
+                    let id = dict_batch.id();
+
+                    let fields_using_this_dictionary = 
schema.fields_with_dict_id(id);
+                    let first_field = 
fields_using_this_dictionary.first().ok_or_else(|| {
+                        Error::General("dictionary id not found in schema 
while deserializing ScalarValue::List".to_string())
+                    })?;
+
+                    let values: ArrayRef = match first_field.data_type() {
+                        DataType::Dictionary(_, ref value_type) => {
+                            // Make a fake schema for the dictionary batch.
+                            let value = value_type.as_ref().clone();
+                            let schema = Schema::new(vec![Field::new("", 
value, true)]);
+                            // Read a single column
+                            let record_batch = read_record_batch(
+                                &buffer,
+                                dict_batch.data().unwrap(),
+                                Arc::new(schema),
+                                &Default::default(),
+                                None,
+                                &message.version(),
+                            )?;
+                            Ok(record_batch.column(0).clone())
+                        }
+                        _ => Err(Error::General("dictionary id not found in 
schema while deserializing ScalarValue::List".to_string())),
+                    }?;
+
+                    Ok((id,values))
+                }).collect::<Result<HashMap<_,_>>>()?;
+
                 let record_batch = read_record_batch(
                     &buffer,
                     ipc_batch,
                     Arc::new(schema),
-                    &Default::default(),
+                    &dict_by_id,
                     None,
                     &message.version(),
                 )
diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs
index ecdbde6faf..52482c890a 100644
--- a/datafusion/proto/src/logical_plan/to_proto.rs
+++ b/datafusion/proto/src/logical_plan/to_proto.rs
@@ -1497,7 +1497,7 @@ fn encode_scalar_nested_value(
 
     let gen = IpcDataGenerator {};
     let mut dict_tracker = DictionaryTracker::new(false);
-    let (_, encoded_message) = gen
+    let (encoded_dictionaries, encoded_message) = gen
         .encoded_batch(&batch, &mut dict_tracker, &Default::default())
         .map_err(|e| {
             Error::General(format!("Error encoding ScalarValue::List as IPC: 
{e}"))
@@ -1508,6 +1508,13 @@ fn encode_scalar_nested_value(
     let scalar_list_value = protobuf::ScalarNestedValue {
         ipc_message: encoded_message.ipc_message,
         arrow_data: encoded_message.arrow_data,
+        dictionaries: encoded_dictionaries
+            .into_iter()
+            .map(|data| protobuf::scalar_nested_value::Dictionary {
+                ipc_message: data.ipc_message,
+                arrow_data: data.arrow_data,
+            })
+            .collect(),
         schema: Some(schema),
     };
 
diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs
index 4de0b7c06d..0515ed5006 100644
--- a/datafusion/proto/src/physical_plan/mod.rs
+++ b/datafusion/proto/src/physical_plan/mod.rs
@@ -494,9 +494,9 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode {
                         match expr_type {
                             ExprType::AggregateExpr(agg_node) => {
                                 let input_phy_expr: Vec<Arc<dyn PhysicalExpr>> 
= agg_node.expr.iter()
-                                    .map(|e| parse_physical_expr(e, registry, 
&physical_schema, extension_codec).unwrap()).collect();
+                                    .map(|e| parse_physical_expr(e, registry, 
&physical_schema, extension_codec)).collect::<Result<Vec<_>>>()?;
                                 let ordering_req: Vec<PhysicalSortExpr> = 
agg_node.ordering_req.iter()
-                                    .map(|e| parse_physical_sort_expr(e, 
registry, &physical_schema, extension_codec).unwrap()).collect();
+                                    .map(|e| parse_physical_sort_expr(e, 
registry, &physical_schema, extension_codec)).collect::<Result<Vec<_>>>()?;
                                 
agg_node.aggregate_function.as_ref().map(|func| {
                                     match func {
                                         AggregateFunction::AggrFunction(i) => {
diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
index b5b0b4c224..472d64905b 100644
--- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
+++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
@@ -1104,9 +1104,58 @@ fn round_trip_scalar_values() {
             )
             .build()
             .unwrap(),
+        ScalarStructBuilder::new()
+            .with_scalar(
+                Field::new("a", DataType::Int32, true),
+                ScalarValue::from(23i32),
+            )
+            .with_scalar(
+                Field::new("b", DataType::Boolean, false),
+                ScalarValue::from(false),
+            )
+            .with_scalar(
+                Field::new(
+                    "c",
+                    DataType::Dictionary(
+                        Box::new(DataType::UInt16),
+                        Box::new(DataType::Utf8),
+                    ),
+                    false,
+                ),
+                ScalarValue::Dictionary(
+                    Box::new(DataType::UInt16),
+                    Box::new("value".into()),
+                ),
+            )
+            .build()
+            .unwrap(),
+        ScalarValue::try_from(&DataType::Struct(Fields::from(vec![
+            Field::new("a", DataType::Int32, true),
+            Field::new("b", DataType::Boolean, false),
+        ])))
+        .unwrap(),
         ScalarValue::try_from(&DataType::Struct(Fields::from(vec![
             Field::new("a", DataType::Int32, true),
             Field::new("b", DataType::Boolean, false),
+            Field::new(
+                "c",
+                DataType::Dictionary(
+                    Box::new(DataType::UInt16),
+                    Box::new(DataType::Binary),
+                ),
+                false,
+            ),
+            Field::new(
+                "d",
+                DataType::new_list(
+                    DataType::Dictionary(
+                        Box::new(DataType::UInt16),
+                        Box::new(DataType::Binary),
+                    ),
+                    false,
+                ),
+                false,
+            ),
         ])))
         .unwrap(),
         ScalarValue::FixedSizeBinary(b"bar".to_vec().len() as i32, 
Some(b"bar".to_vec())),


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to