This is an automated email from the ASF dual-hosted git repository.
thinkharderdev pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new e7858ff0ab Handle dictionary values in ScalarValue serde (#10563)
e7858ff0ab is described below
commit e7858ff0ab1c282ab46bd93cabc3dc83db583165
Author: Dan Harris <[email protected]>
AuthorDate: Fri May 17 17:09:55 2024 -0400
Handle dictionary values in ScalarValue serde (#10563)
* Handle dictionary values in ScalarValue serde
* Do not panic on failed physical expr decoding (#241)
* revert clippy change
---
datafusion/proto/proto/datafusion.proto | 6 +
datafusion/proto/src/generated/pbjson.rs | 133 +++++++++++++++++++++
datafusion/proto/src/generated/prost.rs | 13 ++
datafusion/proto/src/logical_plan/from_proto.rs | 49 +++++++-
datafusion/proto/src/logical_plan/to_proto.rs | 9 +-
datafusion/proto/src/physical_plan/mod.rs | 4 +-
.../proto/tests/cases/roundtrip_logical_plan.rs | 49 ++++++++
7 files changed, 259 insertions(+), 4 deletions(-)
diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto
index fd79345275..8d69b0bad5 100644
--- a/datafusion/proto/proto/datafusion.proto
+++ b/datafusion/proto/proto/datafusion.proto
@@ -797,9 +797,15 @@ message Union{
// Used for List/FixedSizeList/LargeList/Struct
message ScalarNestedValue {
+ message Dictionary {
+ bytes ipc_message = 1;
+ bytes arrow_data = 2;
+ }
+
bytes ipc_message = 1;
bytes arrow_data = 2;
Schema schema = 3;
+ repeated Dictionary dictionaries = 4;
}
message ScalarTime32Value {
diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs
index 01d9a6e0dd..8df0aeb851 100644
--- a/datafusion/proto/src/generated/pbjson.rs
+++ b/datafusion/proto/src/generated/pbjson.rs
@@ -22605,6 +22605,9 @@ impl serde::Serialize for ScalarNestedValue {
if self.schema.is_some() {
len += 1;
}
+ if !self.dictionaries.is_empty() {
+ len += 1;
+ }
let mut struct_ser =
serializer.serialize_struct("datafusion.ScalarNestedValue", len)?;
if !self.ipc_message.is_empty() {
#[allow(clippy::needless_borrow)]
@@ -22617,6 +22620,9 @@ impl serde::Serialize for ScalarNestedValue {
if let Some(v) = self.schema.as_ref() {
struct_ser.serialize_field("schema", v)?;
}
+ if !self.dictionaries.is_empty() {
+ struct_ser.serialize_field("dictionaries", &self.dictionaries)?;
+ }
struct_ser.end()
}
}
@@ -22632,6 +22638,7 @@ impl<'de> serde::Deserialize<'de> for ScalarNestedValue {
"arrow_data",
"arrowData",
"schema",
+ "dictionaries",
];
#[allow(clippy::enum_variant_names)]
@@ -22639,6 +22646,7 @@ impl<'de> serde::Deserialize<'de> for ScalarNestedValue {
IpcMessage,
ArrowData,
Schema,
+ Dictionaries,
}
impl<'de> serde::Deserialize<'de> for GeneratedField {
fn deserialize<D>(deserializer: D) ->
std::result::Result<GeneratedField, D::Error>
@@ -22663,6 +22671,7 @@ impl<'de> serde::Deserialize<'de> for ScalarNestedValue {
"ipcMessage" | "ipc_message" =>
Ok(GeneratedField::IpcMessage),
"arrowData" | "arrow_data" =>
Ok(GeneratedField::ArrowData),
"schema" => Ok(GeneratedField::Schema),
+ "dictionaries" => Ok(GeneratedField::Dictionaries),
_ => Err(serde::de::Error::unknown_field(value,
FIELDS)),
}
}
@@ -22685,6 +22694,7 @@ impl<'de> serde::Deserialize<'de> for ScalarNestedValue {
let mut ipc_message__ = None;
let mut arrow_data__ = None;
let mut schema__ = None;
+ let mut dictionaries__ = None;
while let Some(k) = map_.next_key()? {
match k {
GeneratedField::IpcMessage => {
@@ -22709,18 +22719,141 @@ impl<'de> serde::Deserialize<'de> for ScalarNestedValue {
}
schema__ = map_.next_value()?;
}
+ GeneratedField::Dictionaries => {
+ if dictionaries__.is_some() {
+ return
Err(serde::de::Error::duplicate_field("dictionaries"));
+ }
+ dictionaries__ = Some(map_.next_value()?);
+ }
}
}
Ok(ScalarNestedValue {
ipc_message: ipc_message__.unwrap_or_default(),
arrow_data: arrow_data__.unwrap_or_default(),
schema: schema__,
+ dictionaries: dictionaries__.unwrap_or_default(),
})
}
}
deserializer.deserialize_struct("datafusion.ScalarNestedValue",
FIELDS, GeneratedVisitor)
}
}
+impl serde::Serialize for scalar_nested_value::Dictionary {
+ #[allow(deprecated)]
+ fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok,
S::Error>
+ where
+ S: serde::Serializer,
+ {
+ use serde::ser::SerializeStruct;
+ let mut len = 0;
+ if !self.ipc_message.is_empty() {
+ len += 1;
+ }
+ if !self.arrow_data.is_empty() {
+ len += 1;
+ }
+ let mut struct_ser =
serializer.serialize_struct("datafusion.ScalarNestedValue.Dictionary", len)?;
+ if !self.ipc_message.is_empty() {
+ #[allow(clippy::needless_borrow)]
+ struct_ser.serialize_field("ipcMessage",
pbjson::private::base64::encode(&self.ipc_message).as_str())?;
+ }
+ if !self.arrow_data.is_empty() {
+ #[allow(clippy::needless_borrow)]
+ struct_ser.serialize_field("arrowData",
pbjson::private::base64::encode(&self.arrow_data).as_str())?;
+ }
+ struct_ser.end()
+ }
+}
+impl<'de> serde::Deserialize<'de> for scalar_nested_value::Dictionary {
+ #[allow(deprecated)]
+ fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
+ where
+ D: serde::Deserializer<'de>,
+ {
+ const FIELDS: &[&str] = &[
+ "ipc_message",
+ "ipcMessage",
+ "arrow_data",
+ "arrowData",
+ ];
+
+ #[allow(clippy::enum_variant_names)]
+ enum GeneratedField {
+ IpcMessage,
+ ArrowData,
+ }
+ impl<'de> serde::Deserialize<'de> for GeneratedField {
+ fn deserialize<D>(deserializer: D) ->
std::result::Result<GeneratedField, D::Error>
+ where
+ D: serde::Deserializer<'de>,
+ {
+ struct GeneratedVisitor;
+
+ impl<'de> serde::de::Visitor<'de> for GeneratedVisitor {
+ type Value = GeneratedField;
+
+ fn expecting(&self, formatter: &mut
std::fmt::Formatter<'_>) -> std::fmt::Result {
+ write!(formatter, "expected one of: {:?}", &FIELDS)
+ }
+
+ #[allow(unused_variables)]
+ fn visit_str<E>(self, value: &str) ->
std::result::Result<GeneratedField, E>
+ where
+ E: serde::de::Error,
+ {
+ match value {
+ "ipcMessage" | "ipc_message" =>
Ok(GeneratedField::IpcMessage),
+ "arrowData" | "arrow_data" =>
Ok(GeneratedField::ArrowData),
+ _ => Err(serde::de::Error::unknown_field(value,
FIELDS)),
+ }
+ }
+ }
+ deserializer.deserialize_identifier(GeneratedVisitor)
+ }
+ }
+ struct GeneratedVisitor;
+ impl<'de> serde::de::Visitor<'de> for GeneratedVisitor {
+ type Value = scalar_nested_value::Dictionary;
+
+ fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) ->
std::fmt::Result {
+ formatter.write_str("struct
datafusion.ScalarNestedValue.Dictionary")
+ }
+
+ fn visit_map<V>(self, mut map_: V) ->
std::result::Result<scalar_nested_value::Dictionary, V::Error>
+ where
+ V: serde::de::MapAccess<'de>,
+ {
+ let mut ipc_message__ = None;
+ let mut arrow_data__ = None;
+ while let Some(k) = map_.next_key()? {
+ match k {
+ GeneratedField::IpcMessage => {
+ if ipc_message__.is_some() {
+ return
Err(serde::de::Error::duplicate_field("ipcMessage"));
+ }
+ ipc_message__ =
+
Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0)
+ ;
+ }
+ GeneratedField::ArrowData => {
+ if arrow_data__.is_some() {
+ return
Err(serde::de::Error::duplicate_field("arrowData"));
+ }
+ arrow_data__ =
+
Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0)
+ ;
+ }
+ }
+ }
+ Ok(scalar_nested_value::Dictionary {
+ ipc_message: ipc_message__.unwrap_or_default(),
+ arrow_data: arrow_data__.unwrap_or_default(),
+ })
+ }
+ }
+
deserializer.deserialize_struct("datafusion.ScalarNestedValue.Dictionary",
FIELDS, GeneratedVisitor)
+ }
+}
impl serde::Serialize for ScalarTime32Value {
#[allow(deprecated)]
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok,
S::Error>
diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs
index 64e72ba038..b6b7687e6c 100644
--- a/datafusion/proto/src/generated/prost.rs
+++ b/datafusion/proto/src/generated/prost.rs
@@ -1133,6 +1133,19 @@ pub struct ScalarNestedValue {
pub arrow_data: ::prost::alloc::vec::Vec<u8>,
#[prost(message, optional, tag = "3")]
pub schema: ::core::option::Option<Schema>,
+ #[prost(message, repeated, tag = "4")]
+ pub dictionaries:
::prost::alloc::vec::Vec<scalar_nested_value::Dictionary>,
+}
+/// Nested message and enum types in `ScalarNestedValue`.
+pub mod scalar_nested_value {
+ #[allow(clippy::derive_partial_eq_without_eq)]
+ #[derive(Clone, PartialEq, ::prost::Message)]
+ pub struct Dictionary {
+ #[prost(bytes = "vec", tag = "1")]
+ pub ipc_message: ::prost::alloc::vec::Vec<u8>,
+ #[prost(bytes = "vec", tag = "2")]
+ pub arrow_data: ::prost::alloc::vec::Vec<u8>,
+ }
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs
index 585bcad7f3..5df8eb59e1 100644
--- a/datafusion/proto/src/logical_plan/from_proto.rs
+++ b/datafusion/proto/src/logical_plan/from_proto.rs
@@ -15,8 +15,10 @@
// specific language governing permissions and limitations
// under the License.
+use std::collections::HashMap;
use std::sync::Arc;
+use arrow::array::ArrayRef;
use arrow::{
array::AsArray,
buffer::Buffer,
@@ -522,6 +524,7 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue {
let protobuf::ScalarNestedValue {
ipc_message,
arrow_data,
+ dictionaries,
schema,
} = &v;
@@ -548,11 +551,55 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue {
)
})?;
+ let dict_by_id: HashMap<i64,ArrayRef> =
dictionaries.iter().map(|protobuf::scalar_nested_value::Dictionary {
ipc_message, arrow_data }| {
+ let message =
root_as_message(ipc_message.as_slice()).map_err(|e| {
+ Error::General(format!(
+ "Error IPC message while deserializing
ScalarValue::List dictionary message: {e}"
+ ))
+ })?;
+ let buffer = Buffer::from(arrow_data);
+
+ let dict_batch =
message.header_as_dictionary_batch().ok_or_else(|| {
+ Error::General(
+ "Unexpected message type deserializing
ScalarValue::List dictionary message"
+ .to_string(),
+ )
+ })?;
+
+ let id = dict_batch.id();
+
+ let fields_using_this_dictionary =
schema.fields_with_dict_id(id);
+ let first_field =
fields_using_this_dictionary.first().ok_or_else(|| {
+ Error::General("dictionary id not found in schema
while deserializing ScalarValue::List".to_string())
+ })?;
+
+ let values: ArrayRef = match first_field.data_type() {
+ DataType::Dictionary(_, ref value_type) => {
+ // Make a fake schema for the dictionary batch.
+ let value = value_type.as_ref().clone();
+ let schema = Schema::new(vec![Field::new("",
value, true)]);
+ // Read a single column
+ let record_batch = read_record_batch(
+ &buffer,
+ dict_batch.data().unwrap(),
+ Arc::new(schema),
+ &Default::default(),
+ None,
+ &message.version(),
+ )?;
+ Ok(record_batch.column(0).clone())
+ }
+ _ => Err(Error::General("dictionary id not found in
schema while deserializing ScalarValue::List".to_string())),
+ }?;
+
+ Ok((id,values))
+ }).collect::<Result<HashMap<_,_>>>()?;
+
let record_batch = read_record_batch(
&buffer,
ipc_batch,
Arc::new(schema),
- &Default::default(),
+ &dict_by_id,
None,
&message.version(),
)
diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs
index ecdbde6faf..52482c890a 100644
--- a/datafusion/proto/src/logical_plan/to_proto.rs
+++ b/datafusion/proto/src/logical_plan/to_proto.rs
@@ -1497,7 +1497,7 @@ fn encode_scalar_nested_value(
let gen = IpcDataGenerator {};
let mut dict_tracker = DictionaryTracker::new(false);
- let (_, encoded_message) = gen
+ let (encoded_dictionaries, encoded_message) = gen
.encoded_batch(&batch, &mut dict_tracker, &Default::default())
.map_err(|e| {
Error::General(format!("Error encoding ScalarValue::List as IPC:
{e}"))
@@ -1508,6 +1508,13 @@ fn encode_scalar_nested_value(
let scalar_list_value = protobuf::ScalarNestedValue {
ipc_message: encoded_message.ipc_message,
arrow_data: encoded_message.arrow_data,
+ dictionaries: encoded_dictionaries
+ .into_iter()
+ .map(|data| protobuf::scalar_nested_value::Dictionary {
+ ipc_message: data.ipc_message,
+ arrow_data: data.arrow_data,
+ })
+ .collect(),
schema: Some(schema),
};
diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs
index 4de0b7c06d..0515ed5006 100644
--- a/datafusion/proto/src/physical_plan/mod.rs
+++ b/datafusion/proto/src/physical_plan/mod.rs
@@ -494,9 +494,9 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode {
match expr_type {
ExprType::AggregateExpr(agg_node) => {
let input_phy_expr: Vec<Arc<dyn PhysicalExpr>>
= agg_node.expr.iter()
- .map(|e| parse_physical_expr(e, registry,
&physical_schema, extension_codec).unwrap()).collect();
+ .map(|e| parse_physical_expr(e, registry,
&physical_schema, extension_codec)).collect::<Result<Vec<_>>>()?;
let ordering_req: Vec<PhysicalSortExpr> =
agg_node.ordering_req.iter()
- .map(|e| parse_physical_sort_expr(e,
registry, &physical_schema, extension_codec).unwrap()).collect();
+ .map(|e| parse_physical_sort_expr(e,
registry, &physical_schema, extension_codec)).collect::<Result<Vec<_>>>()?;
agg_node.aggregate_function.as_ref().map(|func| {
match func {
AggregateFunction::AggrFunction(i) => {
diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
index b5b0b4c224..472d64905b 100644
--- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
+++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
@@ -1104,9 +1104,58 @@ fn round_trip_scalar_values() {
)
.build()
.unwrap(),
+ ScalarStructBuilder::new()
+ .with_scalar(
+ Field::new("a", DataType::Int32, true),
+ ScalarValue::from(23i32),
+ )
+ .with_scalar(
+ Field::new("b", DataType::Boolean, false),
+ ScalarValue::from(false),
+ )
+ .with_scalar(
+ Field::new(
+ "c",
+ DataType::Dictionary(
+ Box::new(DataType::UInt16),
+ Box::new(DataType::Utf8),
+ ),
+ false,
+ ),
+ ScalarValue::Dictionary(
+ Box::new(DataType::UInt16),
+ Box::new("value".into()),
+ ),
+ )
+ .build()
+ .unwrap(),
+ ScalarValue::try_from(&DataType::Struct(Fields::from(vec![
+ Field::new("a", DataType::Int32, true),
+ Field::new("b", DataType::Boolean, false),
+ ])))
+ .unwrap(),
ScalarValue::try_from(&DataType::Struct(Fields::from(vec![
Field::new("a", DataType::Int32, true),
Field::new("b", DataType::Boolean, false),
+ Field::new(
+ "c",
+ DataType::Dictionary(
+ Box::new(DataType::UInt16),
+ Box::new(DataType::Binary),
+ ),
+ false,
+ ),
+ Field::new(
+ "d",
+ DataType::new_list(
+ DataType::Dictionary(
+ Box::new(DataType::UInt16),
+ Box::new(DataType::Binary),
+ ),
+ false,
+ ),
+ false,
+ ),
])))
.unwrap(),
ScalarValue::FixedSizeBinary(b"bar".to_vec().len() as i32,
Some(b"bar".to_vec())),
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]