This is an automated email from the ASF dual-hosted git repository.
avantgardner pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 8074ca1e75 Support Union types in `ScalarValue` (#9683)
8074ca1e75 is described below
commit 8074ca1e758470319699a562074290906003b312
Author: Brent Gardner <[email protected]>
AuthorDate: Tue Mar 19 12:14:13 2024 -0600
Support Union types in `ScalarValue` (#9683)
Support Union types in `ScalarValue` (#9683)
---
datafusion/common/src/error.rs | 4 +-
datafusion/common/src/scalar/mod.rs | 82 +++++++
datafusion/physical-plan/src/filter.rs | 35 +++
datafusion/proto/proto/datafusion.proto | 15 ++
datafusion/proto/src/generated/pbjson.rs | 272 ++++++++++++++++++++++++
datafusion/proto/src/generated/prost.rs | 26 ++-
datafusion/proto/src/logical_plan/from_proto.rs | 35 +++
datafusion/proto/src/logical_plan/to_proto.rs | 29 +++
datafusion/sql/src/unparser/expr.rs | 1 +
9 files changed, 496 insertions(+), 3 deletions(-)
diff --git a/datafusion/common/src/error.rs b/datafusion/common/src/error.rs
index 1ecd5b62be..d1e47b4734 100644
--- a/datafusion/common/src/error.rs
+++ b/datafusion/common/src/error.rs
@@ -63,7 +63,7 @@ pub enum DataFusionError {
IoError(io::Error),
/// Error when SQL is syntactically incorrect.
///
- /// 2nd argument is for optional backtrace
+ /// 2nd argument is for optional backtrace
SQL(ParserError, Option<String>),
/// Error when a feature is not yet implemented.
///
@@ -101,7 +101,7 @@ pub enum DataFusionError {
/// This error can be returned in cases such as when schema inference is
not
/// possible and when column names are not unique.
///
- /// 2nd argument is for optional backtrace
+ /// 2nd argument is for optional backtrace
/// Boxing the optional backtrace to prevent
<https://rust-lang.github.io/rust-clippy/master/index.html#/result_large_err>
SchemaError(SchemaError, Box<Option<String>>),
/// Error during execution of the query.
diff --git a/datafusion/common/src/scalar/mod.rs
b/datafusion/common/src/scalar/mod.rs
index a2484e93e8..d33b8b6e14 100644
--- a/datafusion/common/src/scalar/mod.rs
+++ b/datafusion/common/src/scalar/mod.rs
@@ -53,6 +53,8 @@ use arrow::{
},
};
use arrow_array::{ArrowNativeTypeOp, Scalar};
+use arrow_buffer::Buffer;
+use arrow_schema::{UnionFields, UnionMode};
pub use struct_builder::ScalarStructBuilder;
@@ -275,6 +277,11 @@ pub enum ScalarValue {
DurationMicrosecond(Option<i64>),
/// Duration in nanoseconds
DurationNanosecond(Option<i64>),
+ /// A nested datatype that can represent slots of differing types.
Components:
+ /// `.0`: a tuple of union `type_id` and the single value held by this
Scalar
+ /// `.1`: the list of fields, zero-to-one of which will be set in `.0`
+ /// `.2`: the physical storage of the source/destination UnionArray from
which this Scalar came
+ Union(Option<(i8, Box<ScalarValue>)>, UnionFields, UnionMode),
/// Dictionary type: index type and value
Dictionary(Box<DataType>, Box<ScalarValue>),
}
@@ -375,6 +382,10 @@ impl PartialEq for ScalarValue {
(IntervalDayTime(_), _) => false,
(IntervalMonthDayNano(v1), IntervalMonthDayNano(v2)) => v1.eq(v2),
(IntervalMonthDayNano(_), _) => false,
+ (Union(val1, fields1, mode1), Union(val2, fields2, mode2)) => {
+ val1.eq(val2) && fields1.eq(fields2) && mode1.eq(mode2)
+ }
+ (Union(_, _, _), _) => false,
(Dictionary(k1, v1), Dictionary(k2, v2)) => k1.eq(k2) && v1.eq(v2),
(Dictionary(_, _), _) => false,
(Null, Null) => true,
@@ -500,6 +511,14 @@ impl PartialOrd for ScalarValue {
(DurationMicrosecond(_), _) => None,
(DurationNanosecond(v1), DurationNanosecond(v2)) =>
v1.partial_cmp(v2),
(DurationNanosecond(_), _) => None,
+ (Union(v1, t1, m1), Union(v2, t2, m2)) => {
+ if t1.eq(t2) && m1.eq(m2) {
+ v1.partial_cmp(v2)
+ } else {
+ None
+ }
+ }
+ (Union(_, _, _), _) => None,
(Dictionary(k1, v1), Dictionary(k2, v2)) => {
// Don't compare if the key types don't match (it is
effectively a different datatype)
if k1 == k2 {
@@ -663,6 +682,11 @@ impl std::hash::Hash for ScalarValue {
IntervalYearMonth(v) => v.hash(state),
IntervalDayTime(v) => v.hash(state),
IntervalMonthDayNano(v) => v.hash(state),
+ Union(v, t, m) => {
+ v.hash(state);
+ t.hash(state);
+ m.hash(state);
+ }
Dictionary(k, v) => {
k.hash(state);
v.hash(state);
@@ -1093,6 +1117,7 @@ impl ScalarValue {
ScalarValue::DurationNanosecond(_) => {
DataType::Duration(TimeUnit::Nanosecond)
}
+ ScalarValue::Union(_, fields, mode) =>
DataType::Union(fields.clone(), *mode),
ScalarValue::Dictionary(k, v) => {
DataType::Dictionary(k.clone(), Box::new(v.data_type()))
}
@@ -1292,6 +1317,7 @@ impl ScalarValue {
ScalarValue::DurationMillisecond(v) => v.is_none(),
ScalarValue::DurationMicrosecond(v) => v.is_none(),
ScalarValue::DurationNanosecond(v) => v.is_none(),
+ ScalarValue::Union(v, _, _) => v.is_none(),
ScalarValue::Dictionary(_, v) => v.is_null(),
}
}
@@ -2087,6 +2113,39 @@ impl ScalarValue {
e,
size
),
+ ScalarValue::Union(value, fields, _mode) => match value {
+ Some((v_id, value)) => {
+ let mut field_type_ids =
Vec::<i8>::with_capacity(fields.len());
+ let mut child_arrays =
+ Vec::<(Field, ArrayRef)>::with_capacity(fields.len());
+ for (f_id, field) in fields.iter() {
+ let ar = if f_id == *v_id {
+ value.to_array_of_size(size)?
+ } else {
+ let dt = field.data_type();
+ new_null_array(dt, size)
+ };
+ let field = (**field).clone();
+ child_arrays.push((field, ar));
+ field_type_ids.push(f_id);
+ }
+ let type_ids =
repeat(*v_id).take(size).collect::<Vec<_>>();
+ let type_ids = Buffer::from_slice_ref(type_ids);
+ let value_offsets: Option<Buffer> = None;
+ let ar = UnionArray::try_new(
+ field_type_ids.as_slice(),
+ type_ids,
+ value_offsets,
+ child_arrays,
+ )
+ .map_err(|e| DataFusionError::ArrowError(e, None))?;
+ Arc::new(ar)
+ }
+ None => {
+ let dt = self.data_type();
+ new_null_array(&dt, size)
+ }
+ },
ScalarValue::Dictionary(key_type, v) => {
// values array is one element long (the value)
match key_type.as_ref() {
@@ -2626,6 +2685,9 @@ impl ScalarValue {
ScalarValue::DurationNanosecond(val) => {
eq_array_primitive!(array, index, DurationNanosecondArray,
val)?
}
+ ScalarValue::Union(_, _, _) => {
+ return _not_impl_err!("Union is not supported yet")
+ }
ScalarValue::Dictionary(key_type, v) => {
let (values_array, values_index) = match key_type.as_ref() {
DataType::Int8 => get_dict_value::<Int8Type>(array,
index)?,
@@ -2703,6 +2765,15 @@ impl ScalarValue {
ScalarValue::LargeList(arr) => arr.get_array_memory_size(),
ScalarValue::FixedSizeList(arr) => arr.get_array_memory_size(),
ScalarValue::Struct(arr) => arr.get_array_memory_size(),
+ ScalarValue::Union(vals, fields, _mode) => {
+ vals.as_ref()
+ .map(|(_id, sv)| sv.size() - std::mem::size_of_val(sv))
+ .unwrap_or_default()
+ // `fields` is boxed, so it is NOT already included in
`self`
+ + std::mem::size_of_val(fields)
+ + (std::mem::size_of::<Field>() * fields.len())
+ + fields.iter().map(|(_idx, field)| field.size() -
std::mem::size_of_val(field)).sum::<usize>()
+ }
ScalarValue::Dictionary(dt, sv) => {
// `dt` and `sv` are boxed, so they are NOT already
included in `self`
dt.size() + sv.size()
@@ -3048,6 +3119,9 @@ impl TryFrom<&DataType> for ScalarValue {
.to_owned()
.into(),
),
+ DataType::Union(fields, mode) => {
+ ScalarValue::Union(None, fields.clone(), *mode)
+ }
DataType::Null => ScalarValue::Null,
_ => {
return _not_impl_err!(
@@ -3164,6 +3238,10 @@ impl fmt::Display for ScalarValue {
.join(",")
)?
}
+ ScalarValue::Union(val, _fields, _mode) => match val {
+ Some((id, val)) => write!(f, "{}:{}", id, val)?,
+ None => write!(f, "NULL")?,
+ },
ScalarValue::Dictionary(_k, v) => write!(f, "{v}")?,
ScalarValue::Null => write!(f, "NULL")?,
};
@@ -3279,6 +3357,10 @@ impl fmt::Debug for ScalarValue {
ScalarValue::DurationNanosecond(_) => {
write!(f, "DurationNanosecond(\"{self}\")")
}
+ ScalarValue::Union(val, _fields, _mode) => match val {
+ Some((id, val)) => write!(f, "Union {}:{}", id, val),
+ None => write!(f, "Union(NULL)"),
+ },
ScalarValue::Dictionary(k, v) => write!(f, "Dictionary({k:?},
{v:?})"),
ScalarValue::Null => write!(f, "NULL"),
}
diff --git a/datafusion/physical-plan/src/filter.rs
b/datafusion/physical-plan/src/filter.rs
index 72f885a939..f44ade7106 100644
--- a/datafusion/physical-plan/src/filter.rs
+++ b/datafusion/physical-plan/src/filter.rs
@@ -441,7 +441,9 @@ mod tests {
use crate::test::exec::StatisticsExec;
use crate::ExecutionPlan;
+ use crate::empty::EmptyExec;
use arrow::datatypes::{DataType, Field, Schema};
+ use arrow_schema::{UnionFields, UnionMode};
use datafusion_common::{ColumnStatistics, ScalarValue};
use datafusion_expr::Operator;
@@ -1090,4 +1092,37 @@ mod tests {
assert_eq!(statistics.total_byte_size, Precision::Inexact(1600));
Ok(())
}
+
+ #[test]
+ fn test_equivalence_properties_union_type() -> Result<()> {
+ let union_type = DataType::Union(
+ UnionFields::new(
+ vec![0, 1],
+ vec![
+ Field::new("f1", DataType::Int32, true),
+ Field::new("f2", DataType::Utf8, true),
+ ],
+ ),
+ UnionMode::Sparse,
+ );
+
+ let schema = Arc::new(Schema::new(vec![
+ Field::new("c1", DataType::Int32, true),
+ Field::new("c2", union_type, true),
+ ]));
+
+ let exec = FilterExec::try_new(
+ binary(
+ binary(col("c1", &schema)?, Operator::GtEq, lit(1i32),
&schema)?,
+ Operator::And,
+ binary(col("c1", &schema)?, Operator::LtEq, lit(4i32),
&schema)?,
+ &schema,
+ )?,
+ Arc::new(EmptyExec::new(schema.clone())),
+ )?;
+
+ exec.statistics().unwrap();
+
+ Ok(())
+ }
}
diff --git a/datafusion/proto/proto/datafusion.proto
b/datafusion/proto/proto/datafusion.proto
index 7a9b427ce7..10f79a2b8c 100644
--- a/datafusion/proto/proto/datafusion.proto
+++ b/datafusion/proto/proto/datafusion.proto
@@ -988,6 +988,20 @@ message IntervalMonthDayNanoValue {
int64 nanos = 3;
}
+message UnionField {
+ int32 field_id = 1;
+ Field field = 2;
+}
+
+message UnionValue {
+ // Note that a null union value must have one or more fields, so we
+ // encode a null UnionValue as one with value_id == 128
+ int32 value_id = 1;
+ ScalarValue value = 2;
+ repeated UnionField fields = 3;
+ UnionMode mode = 4;
+}
+
message ScalarFixedSizeBinary{
bytes values = 1;
int32 length = 2;
@@ -1042,6 +1056,7 @@ message ScalarValue{
ScalarTime64Value time64_value = 30;
IntervalMonthDayNanoValue interval_month_day_nano = 31;
ScalarFixedSizeBinary fixed_size_binary_value = 34;
+ UnionValue union_value = 42;
}
}
diff --git a/datafusion/proto/src/generated/pbjson.rs
b/datafusion/proto/src/generated/pbjson.rs
index fd27520b3b..7757a64ef3 100644
--- a/datafusion/proto/src/generated/pbjson.rs
+++ b/datafusion/proto/src/generated/pbjson.rs
@@ -24053,6 +24053,9 @@ impl serde::Serialize for ScalarValue {
scalar_value::Value::FixedSizeBinaryValue(v) => {
struct_ser.serialize_field("fixedSizeBinaryValue", v)?;
}
+ scalar_value::Value::UnionValue(v) => {
+ struct_ser.serialize_field("unionValue", v)?;
+ }
}
}
struct_ser.end()
@@ -24137,6 +24140,8 @@ impl<'de> serde::Deserialize<'de> for ScalarValue {
"intervalMonthDayNano",
"fixed_size_binary_value",
"fixedSizeBinaryValue",
+ "union_value",
+ "unionValue",
];
#[allow(clippy::enum_variant_names)]
@@ -24177,6 +24182,7 @@ impl<'de> serde::Deserialize<'de> for ScalarValue {
Time64Value,
IntervalMonthDayNano,
FixedSizeBinaryValue,
+ UnionValue,
}
impl<'de> serde::Deserialize<'de> for GeneratedField {
fn deserialize<D>(deserializer: D) ->
std::result::Result<GeneratedField, D::Error>
@@ -24234,6 +24240,7 @@ impl<'de> serde::Deserialize<'de> for ScalarValue {
"time64Value" | "time64_value" =>
Ok(GeneratedField::Time64Value),
"intervalMonthDayNano" | "interval_month_day_nano"
=> Ok(GeneratedField::IntervalMonthDayNano),
"fixedSizeBinaryValue" | "fixed_size_binary_value"
=> Ok(GeneratedField::FixedSizeBinaryValue),
+ "unionValue" | "union_value" =>
Ok(GeneratedField::UnionValue),
_ => Err(serde::de::Error::unknown_field(value,
FIELDS)),
}
}
@@ -24483,6 +24490,13 @@ impl<'de> serde::Deserialize<'de> for ScalarValue {
return
Err(serde::de::Error::duplicate_field("fixedSizeBinaryValue"));
}
value__ =
map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::FixedSizeBinaryValue)
+;
+ }
+ GeneratedField::UnionValue => {
+ if value__.is_some() {
+ return
Err(serde::de::Error::duplicate_field("unionValue"));
+ }
+ value__ =
map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::UnionValue)
;
}
}
@@ -26942,6 +26956,117 @@ impl<'de> serde::Deserialize<'de> for UnionExecNode {
deserializer.deserialize_struct("datafusion.UnionExecNode", FIELDS,
GeneratedVisitor)
}
}
+impl serde::Serialize for UnionField {
+ #[allow(deprecated)]
+ fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok,
S::Error>
+ where
+ S: serde::Serializer,
+ {
+ use serde::ser::SerializeStruct;
+ let mut len = 0;
+ if self.field_id != 0 {
+ len += 1;
+ }
+ if self.field.is_some() {
+ len += 1;
+ }
+ let mut struct_ser =
serializer.serialize_struct("datafusion.UnionField", len)?;
+ if self.field_id != 0 {
+ struct_ser.serialize_field("fieldId", &self.field_id)?;
+ }
+ if let Some(v) = self.field.as_ref() {
+ struct_ser.serialize_field("field", v)?;
+ }
+ struct_ser.end()
+ }
+}
+impl<'de> serde::Deserialize<'de> for UnionField {
+ #[allow(deprecated)]
+ fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
+ where
+ D: serde::Deserializer<'de>,
+ {
+ const FIELDS: &[&str] = &[
+ "field_id",
+ "fieldId",
+ "field",
+ ];
+
+ #[allow(clippy::enum_variant_names)]
+ enum GeneratedField {
+ FieldId,
+ Field,
+ }
+ impl<'de> serde::Deserialize<'de> for GeneratedField {
+ fn deserialize<D>(deserializer: D) ->
std::result::Result<GeneratedField, D::Error>
+ where
+ D: serde::Deserializer<'de>,
+ {
+ struct GeneratedVisitor;
+
+ impl<'de> serde::de::Visitor<'de> for GeneratedVisitor {
+ type Value = GeneratedField;
+
+ fn expecting(&self, formatter: &mut
std::fmt::Formatter<'_>) -> std::fmt::Result {
+ write!(formatter, "expected one of: {:?}", &FIELDS)
+ }
+
+ #[allow(unused_variables)]
+ fn visit_str<E>(self, value: &str) ->
std::result::Result<GeneratedField, E>
+ where
+ E: serde::de::Error,
+ {
+ match value {
+ "fieldId" | "field_id" =>
Ok(GeneratedField::FieldId),
+ "field" => Ok(GeneratedField::Field),
+ _ => Err(serde::de::Error::unknown_field(value,
FIELDS)),
+ }
+ }
+ }
+ deserializer.deserialize_identifier(GeneratedVisitor)
+ }
+ }
+ struct GeneratedVisitor;
+ impl<'de> serde::de::Visitor<'de> for GeneratedVisitor {
+ type Value = UnionField;
+
+ fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) ->
std::fmt::Result {
+ formatter.write_str("struct datafusion.UnionField")
+ }
+
+ fn visit_map<V>(self, mut map_: V) ->
std::result::Result<UnionField, V::Error>
+ where
+ V: serde::de::MapAccess<'de>,
+ {
+ let mut field_id__ = None;
+ let mut field__ = None;
+ while let Some(k) = map_.next_key()? {
+ match k {
+ GeneratedField::FieldId => {
+ if field_id__.is_some() {
+ return
Err(serde::de::Error::duplicate_field("fieldId"));
+ }
+ field_id__ =
+
Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0)
+ ;
+ }
+ GeneratedField::Field => {
+ if field__.is_some() {
+ return
Err(serde::de::Error::duplicate_field("field"));
+ }
+ field__ = map_.next_value()?;
+ }
+ }
+ }
+ Ok(UnionField {
+ field_id: field_id__.unwrap_or_default(),
+ field: field__,
+ })
+ }
+ }
+ deserializer.deserialize_struct("datafusion.UnionField", FIELDS,
GeneratedVisitor)
+ }
+}
impl serde::Serialize for UnionMode {
#[allow(deprecated)]
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok,
S::Error>
@@ -27104,6 +27229,153 @@ impl<'de> serde::Deserialize<'de> for UnionNode {
deserializer.deserialize_struct("datafusion.UnionNode", FIELDS,
GeneratedVisitor)
}
}
+impl serde::Serialize for UnionValue {
+ #[allow(deprecated)]
+ fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok,
S::Error>
+ where
+ S: serde::Serializer,
+ {
+ use serde::ser::SerializeStruct;
+ let mut len = 0;
+ if self.value_id != 0 {
+ len += 1;
+ }
+ if self.value.is_some() {
+ len += 1;
+ }
+ if !self.fields.is_empty() {
+ len += 1;
+ }
+ if self.mode != 0 {
+ len += 1;
+ }
+ let mut struct_ser =
serializer.serialize_struct("datafusion.UnionValue", len)?;
+ if self.value_id != 0 {
+ struct_ser.serialize_field("valueId", &self.value_id)?;
+ }
+ if let Some(v) = self.value.as_ref() {
+ struct_ser.serialize_field("value", v)?;
+ }
+ if !self.fields.is_empty() {
+ struct_ser.serialize_field("fields", &self.fields)?;
+ }
+ if self.mode != 0 {
+ let v = UnionMode::try_from(self.mode)
+ .map_err(|_| serde::ser::Error::custom(format!("Invalid
variant {}", self.mode)))?;
+ struct_ser.serialize_field("mode", &v)?;
+ }
+ struct_ser.end()
+ }
+}
+impl<'de> serde::Deserialize<'de> for UnionValue {
+ #[allow(deprecated)]
+ fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
+ where
+ D: serde::Deserializer<'de>,
+ {
+ const FIELDS: &[&str] = &[
+ "value_id",
+ "valueId",
+ "value",
+ "fields",
+ "mode",
+ ];
+
+ #[allow(clippy::enum_variant_names)]
+ enum GeneratedField {
+ ValueId,
+ Value,
+ Fields,
+ Mode,
+ }
+ impl<'de> serde::Deserialize<'de> for GeneratedField {
+ fn deserialize<D>(deserializer: D) ->
std::result::Result<GeneratedField, D::Error>
+ where
+ D: serde::Deserializer<'de>,
+ {
+ struct GeneratedVisitor;
+
+ impl<'de> serde::de::Visitor<'de> for GeneratedVisitor {
+ type Value = GeneratedField;
+
+ fn expecting(&self, formatter: &mut
std::fmt::Formatter<'_>) -> std::fmt::Result {
+ write!(formatter, "expected one of: {:?}", &FIELDS)
+ }
+
+ #[allow(unused_variables)]
+ fn visit_str<E>(self, value: &str) ->
std::result::Result<GeneratedField, E>
+ where
+ E: serde::de::Error,
+ {
+ match value {
+ "valueId" | "value_id" =>
Ok(GeneratedField::ValueId),
+ "value" => Ok(GeneratedField::Value),
+ "fields" => Ok(GeneratedField::Fields),
+ "mode" => Ok(GeneratedField::Mode),
+ _ => Err(serde::de::Error::unknown_field(value,
FIELDS)),
+ }
+ }
+ }
+ deserializer.deserialize_identifier(GeneratedVisitor)
+ }
+ }
+ struct GeneratedVisitor;
+ impl<'de> serde::de::Visitor<'de> for GeneratedVisitor {
+ type Value = UnionValue;
+
+ fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) ->
std::fmt::Result {
+ formatter.write_str("struct datafusion.UnionValue")
+ }
+
+ fn visit_map<V>(self, mut map_: V) ->
std::result::Result<UnionValue, V::Error>
+ where
+ V: serde::de::MapAccess<'de>,
+ {
+ let mut value_id__ = None;
+ let mut value__ = None;
+ let mut fields__ = None;
+ let mut mode__ = None;
+ while let Some(k) = map_.next_key()? {
+ match k {
+ GeneratedField::ValueId => {
+ if value_id__.is_some() {
+ return
Err(serde::de::Error::duplicate_field("valueId"));
+ }
+ value_id__ =
+
Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0)
+ ;
+ }
+ GeneratedField::Value => {
+ if value__.is_some() {
+ return
Err(serde::de::Error::duplicate_field("value"));
+ }
+ value__ = map_.next_value()?;
+ }
+ GeneratedField::Fields => {
+ if fields__.is_some() {
+ return
Err(serde::de::Error::duplicate_field("fields"));
+ }
+ fields__ = Some(map_.next_value()?);
+ }
+ GeneratedField::Mode => {
+ if mode__.is_some() {
+ return
Err(serde::de::Error::duplicate_field("mode"));
+ }
+ mode__ = Some(map_.next_value::<UnionMode>()? as
i32);
+ }
+ }
+ }
+ Ok(UnionValue {
+ value_id: value_id__.unwrap_or_default(),
+ value: value__,
+ fields: fields__.unwrap_or_default(),
+ mode: mode__.unwrap_or_default(),
+ })
+ }
+ }
+ deserializer.deserialize_struct("datafusion.UnionValue", FIELDS,
GeneratedVisitor)
+ }
+}
impl serde::Serialize for UniqueConstraint {
#[allow(deprecated)]
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok,
S::Error>
diff --git a/datafusion/proto/src/generated/prost.rs
b/datafusion/proto/src/generated/prost.rs
index 16ad2b848d..ab0ddb14eb 100644
--- a/datafusion/proto/src/generated/prost.rs
+++ b/datafusion/proto/src/generated/prost.rs
@@ -1225,6 +1225,28 @@ pub struct IntervalMonthDayNanoValue {
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct UnionField {
+ #[prost(int32, tag = "1")]
+ pub field_id: i32,
+ #[prost(message, optional, tag = "2")]
+ pub field: ::core::option::Option<Field>,
+}
+#[allow(clippy::derive_partial_eq_without_eq)]
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct UnionValue {
+ /// Note that a null union value must have one or more fields, so we
+ /// encode a null UnionValue as one with value_id == 128
+ #[prost(int32, tag = "1")]
+ pub value_id: i32,
+ #[prost(message, optional, boxed, tag = "2")]
+ pub value: ::core::option::Option<::prost::alloc::boxed::Box<ScalarValue>>,
+ #[prost(message, repeated, tag = "3")]
+ pub fields: ::prost::alloc::vec::Vec<UnionField>,
+ #[prost(enumeration = "UnionMode", tag = "4")]
+ pub mode: i32,
+}
+#[allow(clippy::derive_partial_eq_without_eq)]
+#[derive(Clone, PartialEq, ::prost::Message)]
pub struct ScalarFixedSizeBinary {
#[prost(bytes = "vec", tag = "1")]
pub values: ::prost::alloc::vec::Vec<u8>,
@@ -1236,7 +1258,7 @@ pub struct ScalarFixedSizeBinary {
pub struct ScalarValue {
#[prost(
oneof = "scalar_value::Value",
- tags = "33, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,
18, 32, 20, 39, 21, 24, 25, 35, 36, 37, 38, 26, 27, 28, 29, 30, 31, 34"
+ tags = "33, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,
18, 32, 20, 39, 21, 24, 25, 35, 36, 37, 38, 26, 27, 28, 29, 30, 31, 34, 42"
)]
pub value: ::core::option::Option<scalar_value::Value>,
}
@@ -1320,6 +1342,8 @@ pub mod scalar_value {
IntervalMonthDayNano(super::IntervalMonthDayNanoValue),
#[prost(message, tag = "34")]
FixedSizeBinaryValue(super::ScalarFixedSizeBinary),
+ #[prost(message, tag = "42")]
+ UnionValue(::prost::alloc::boxed::Box<super::UnionValue>),
}
}
#[allow(clippy::derive_partial_eq_without_eq)]
diff --git a/datafusion/proto/src/logical_plan/from_proto.rs
b/datafusion/proto/src/logical_plan/from_proto.rs
index 06aab16edd..8581156e2b 100644
--- a/datafusion/proto/src/logical_plan/from_proto.rs
+++ b/datafusion/proto/src/logical_plan/from_proto.rs
@@ -768,6 +768,41 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue {
Value::IntervalMonthDayNano(v) => Self::IntervalMonthDayNano(Some(
IntervalMonthDayNanoType::make_value(v.months, v.days,
v.nanos),
)),
+ Value::UnionValue(val) => {
+ let mode = match val.mode {
+ 0 => UnionMode::Sparse,
+ 1 => UnionMode::Dense,
+ id => Err(Error::unknown("UnionMode", id))?,
+ };
+ let ids = val
+ .fields
+ .iter()
+ .map(|f| f.field_id as i8)
+ .collect::<Vec<_>>();
+ let fields = val
+ .fields
+ .iter()
+ .map(|f| f.field.clone())
+ .collect::<Option<Vec<_>>>();
+ let fields = fields.ok_or_else(||
Error::required("UnionField"))?;
+ let fields = fields
+ .iter()
+ .map(Field::try_from)
+ .collect::<Result<Vec<_>, _>>()?;
+ let fields = UnionFields::new(ids, fields);
+ let v_id = val.value_id as i8;
+ let val = match &val.value {
+ None => None,
+ Some(val) => {
+ let val: ScalarValue = val
+ .as_ref()
+ .try_into()
+ .map_err(|_| Error::General("Invalid
Scalar".to_string()))?;
+ Some((v_id, Box::new(val)))
+ }
+ };
+ Self::Union(val, fields, mode)
+ }
Value::FixedSizeBinaryValue(v) => {
Self::FixedSizeBinary(v.length, Some(v.clone().values))
}
diff --git a/datafusion/proto/src/logical_plan/to_proto.rs
b/datafusion/proto/src/logical_plan/to_proto.rs
index 9201559490..05a29ff6d4 100644
--- a/datafusion/proto/src/logical_plan/to_proto.rs
+++ b/datafusion/proto/src/logical_plan/to_proto.rs
@@ -30,6 +30,7 @@ use crate::protobuf::{
},
AnalyzedLogicalPlanType, CubeNode, EmptyMessage, GroupingSetNode,
LogicalExprList,
OptimizedLogicalPlanType, OptimizedPhysicalPlanType, PlaceholderNode,
RollupNode,
+ UnionField, UnionValue,
};
use arrow::{
@@ -1405,6 +1406,34 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue {
};
Ok(protobuf::ScalarValue { value: Some(value) })
}
+
+ ScalarValue::Union(val, df_fields, mode) => {
+ let mut fields =
Vec::<UnionField>::with_capacity(df_fields.len());
+ for (id, field) in df_fields.iter() {
+ let field_id = id as i32;
+ let field = Some(field.as_ref().try_into()?);
+ let field = UnionField { field_id, field };
+ fields.push(field);
+ }
+ let mode = match mode {
+ UnionMode::Sparse => 0,
+ UnionMode::Dense => 1,
+ };
+ let value = match val {
+ None => None,
+ Some((_id, v)) => Some(Box::new(v.as_ref().try_into()?)),
+ };
+ let val = UnionValue {
+ value_id: val.as_ref().map(|(id, _v)| *id as
i32).unwrap_or(0),
+ value,
+ fields,
+ mode,
+ };
+ let val = Value::UnionValue(Box::new(val));
+ let val = protobuf::ScalarValue { value: Some(val) };
+ Ok(val)
+ }
+
ScalarValue::Dictionary(index_type, val) => {
let value: protobuf::ScalarValue = val.as_ref().try_into()?;
Ok(protobuf::ScalarValue {
diff --git a/datafusion/sql/src/unparser/expr.rs
b/datafusion/sql/src/unparser/expr.rs
index c26e8481ce..43f3e348dc 100644
--- a/datafusion/sql/src/unparser/expr.rs
+++ b/datafusion/sql/src/unparser/expr.rs
@@ -456,6 +456,7 @@ impl Unparser<'_> {
Ok(ast::Expr::Value(ast::Value::Null))
}
ScalarValue::Struct(_) => not_impl_err!("Unsupported scalar:
{v:?}"),
+ ScalarValue::Union(..) => not_impl_err!("Unsupported scalar:
{v:?}"),
ScalarValue::Dictionary(..) => not_impl_err!("Unsupported scalar:
{v:?}"),
}
}