This is an automated email from the ASF dual-hosted git repository.

avantgardner pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 8074ca1e75 Support Union types in `ScalarValue`  (#9683)
8074ca1e75 is described below

commit 8074ca1e758470319699a562074290906003b312
Author: Brent Gardner <[email protected]>
AuthorDate: Tue Mar 19 12:14:13 2024 -0600

    Support Union types in `ScalarValue`  (#9683)
    
    Support Union types in `ScalarValue`  (#9683)
---
 datafusion/common/src/error.rs                  |   4 +-
 datafusion/common/src/scalar/mod.rs             |  82 +++++++
 datafusion/physical-plan/src/filter.rs          |  35 +++
 datafusion/proto/proto/datafusion.proto         |  15 ++
 datafusion/proto/src/generated/pbjson.rs        | 272 ++++++++++++++++++++++++
 datafusion/proto/src/generated/prost.rs         |  26 ++-
 datafusion/proto/src/logical_plan/from_proto.rs |  35 +++
 datafusion/proto/src/logical_plan/to_proto.rs   |  29 +++
 datafusion/sql/src/unparser/expr.rs             |   1 +
 9 files changed, 496 insertions(+), 3 deletions(-)

diff --git a/datafusion/common/src/error.rs b/datafusion/common/src/error.rs
index 1ecd5b62be..d1e47b4734 100644
--- a/datafusion/common/src/error.rs
+++ b/datafusion/common/src/error.rs
@@ -63,7 +63,7 @@ pub enum DataFusionError {
     IoError(io::Error),
     /// Error when SQL is syntactically incorrect.
     ///
-    /// 2nd argument is for optional backtrace    
+    /// 2nd argument is for optional backtrace
     SQL(ParserError, Option<String>),
     /// Error when a feature is not yet implemented.
     ///
@@ -101,7 +101,7 @@ pub enum DataFusionError {
     /// This error can be returned in cases such as when schema inference is 
not
     /// possible and when column names are not unique.
     ///
-    /// 2nd argument is for optional backtrace    
+    /// 2nd argument is for optional backtrace
     /// Boxing the optional backtrace to prevent 
<https://rust-lang.github.io/rust-clippy/master/index.html#/result_large_err>
     SchemaError(SchemaError, Box<Option<String>>),
     /// Error during execution of the query.
diff --git a/datafusion/common/src/scalar/mod.rs 
b/datafusion/common/src/scalar/mod.rs
index a2484e93e8..d33b8b6e14 100644
--- a/datafusion/common/src/scalar/mod.rs
+++ b/datafusion/common/src/scalar/mod.rs
@@ -53,6 +53,8 @@ use arrow::{
     },
 };
 use arrow_array::{ArrowNativeTypeOp, Scalar};
+use arrow_buffer::Buffer;
+use arrow_schema::{UnionFields, UnionMode};
 
 pub use struct_builder::ScalarStructBuilder;
 
@@ -275,6 +277,11 @@ pub enum ScalarValue {
     DurationMicrosecond(Option<i64>),
     /// Duration in nanoseconds
     DurationNanosecond(Option<i64>),
+    /// A nested datatype that can represent slots of differing types. 
Components:
+    /// `.0`: a tuple of union `type_id` and the single value held by this 
Scalar
+    /// `.1`: the list of fields, zero-to-one of which will be set in `.0`
+    /// `.2`: the physical storage of the source/destination UnionArray from 
which this Scalar came
+    Union(Option<(i8, Box<ScalarValue>)>, UnionFields, UnionMode),
     /// Dictionary type: index type and value
     Dictionary(Box<DataType>, Box<ScalarValue>),
 }
@@ -375,6 +382,10 @@ impl PartialEq for ScalarValue {
             (IntervalDayTime(_), _) => false,
             (IntervalMonthDayNano(v1), IntervalMonthDayNano(v2)) => v1.eq(v2),
             (IntervalMonthDayNano(_), _) => false,
+            (Union(val1, fields1, mode1), Union(val2, fields2, mode2)) => {
+                val1.eq(val2) && fields1.eq(fields2) && mode1.eq(mode2)
+            }
+            (Union(_, _, _), _) => false,
             (Dictionary(k1, v1), Dictionary(k2, v2)) => k1.eq(k2) && v1.eq(v2),
             (Dictionary(_, _), _) => false,
             (Null, Null) => true,
@@ -500,6 +511,14 @@ impl PartialOrd for ScalarValue {
             (DurationMicrosecond(_), _) => None,
             (DurationNanosecond(v1), DurationNanosecond(v2)) => 
v1.partial_cmp(v2),
             (DurationNanosecond(_), _) => None,
+            (Union(v1, t1, m1), Union(v2, t2, m2)) => {
+                if t1.eq(t2) && m1.eq(m2) {
+                    v1.partial_cmp(v2)
+                } else {
+                    None
+                }
+            }
+            (Union(_, _, _), _) => None,
             (Dictionary(k1, v1), Dictionary(k2, v2)) => {
                 // Don't compare if the key types don't match (it is 
effectively a different datatype)
                 if k1 == k2 {
@@ -663,6 +682,11 @@ impl std::hash::Hash for ScalarValue {
             IntervalYearMonth(v) => v.hash(state),
             IntervalDayTime(v) => v.hash(state),
             IntervalMonthDayNano(v) => v.hash(state),
+            Union(v, t, m) => {
+                v.hash(state);
+                t.hash(state);
+                m.hash(state);
+            }
             Dictionary(k, v) => {
                 k.hash(state);
                 v.hash(state);
@@ -1093,6 +1117,7 @@ impl ScalarValue {
             ScalarValue::DurationNanosecond(_) => {
                 DataType::Duration(TimeUnit::Nanosecond)
             }
+            ScalarValue::Union(_, fields, mode) => 
DataType::Union(fields.clone(), *mode),
             ScalarValue::Dictionary(k, v) => {
                 DataType::Dictionary(k.clone(), Box::new(v.data_type()))
             }
@@ -1292,6 +1317,7 @@ impl ScalarValue {
             ScalarValue::DurationMillisecond(v) => v.is_none(),
             ScalarValue::DurationMicrosecond(v) => v.is_none(),
             ScalarValue::DurationNanosecond(v) => v.is_none(),
+            ScalarValue::Union(v, _, _) => v.is_none(),
             ScalarValue::Dictionary(_, v) => v.is_null(),
         }
     }
@@ -2087,6 +2113,39 @@ impl ScalarValue {
                 e,
                 size
             ),
+            ScalarValue::Union(value, fields, _mode) => match value {
+                Some((v_id, value)) => {
+                    let mut field_type_ids = 
Vec::<i8>::with_capacity(fields.len());
+                    let mut child_arrays =
+                        Vec::<(Field, ArrayRef)>::with_capacity(fields.len());
+                    for (f_id, field) in fields.iter() {
+                        let ar = if f_id == *v_id {
+                            value.to_array_of_size(size)?
+                        } else {
+                            let dt = field.data_type();
+                            new_null_array(dt, size)
+                        };
+                        let field = (**field).clone();
+                        child_arrays.push((field, ar));
+                        field_type_ids.push(f_id);
+                    }
+                    let type_ids = 
repeat(*v_id).take(size).collect::<Vec<_>>();
+                    let type_ids = Buffer::from_slice_ref(type_ids);
+                    let value_offsets: Option<Buffer> = None;
+                    let ar = UnionArray::try_new(
+                        field_type_ids.as_slice(),
+                        type_ids,
+                        value_offsets,
+                        child_arrays,
+                    )
+                    .map_err(|e| DataFusionError::ArrowError(e, None))?;
+                    Arc::new(ar)
+                }
+                None => {
+                    let dt = self.data_type();
+                    new_null_array(&dt, size)
+                }
+            },
             ScalarValue::Dictionary(key_type, v) => {
                 // values array is one element long (the value)
                 match key_type.as_ref() {
@@ -2626,6 +2685,9 @@ impl ScalarValue {
             ScalarValue::DurationNanosecond(val) => {
                 eq_array_primitive!(array, index, DurationNanosecondArray, 
val)?
             }
+            ScalarValue::Union(_, _, _) => {
+                return _not_impl_err!("Union is not supported yet")
+            }
             ScalarValue::Dictionary(key_type, v) => {
                 let (values_array, values_index) = match key_type.as_ref() {
                     DataType::Int8 => get_dict_value::<Int8Type>(array, 
index)?,
@@ -2703,6 +2765,15 @@ impl ScalarValue {
                 ScalarValue::LargeList(arr) => arr.get_array_memory_size(),
                 ScalarValue::FixedSizeList(arr) => arr.get_array_memory_size(),
                 ScalarValue::Struct(arr) => arr.get_array_memory_size(),
+                ScalarValue::Union(vals, fields, _mode) => {
+                    vals.as_ref()
+                        .map(|(_id, sv)| sv.size() - std::mem::size_of_val(sv))
+                        .unwrap_or_default()
+                        // `fields` is boxed, so it is NOT already included in 
`self`
+                        + std::mem::size_of_val(fields)
+                        + (std::mem::size_of::<Field>() * fields.len())
+                        + fields.iter().map(|(_idx, field)| field.size() - 
std::mem::size_of_val(field)).sum::<usize>()
+                }
                 ScalarValue::Dictionary(dt, sv) => {
                     // `dt` and `sv` are boxed, so they are NOT already 
included in `self`
                     dt.size() + sv.size()
@@ -3048,6 +3119,9 @@ impl TryFrom<&DataType> for ScalarValue {
                     .to_owned()
                     .into(),
             ),
+            DataType::Union(fields, mode) => {
+                ScalarValue::Union(None, fields.clone(), *mode)
+            }
             DataType::Null => ScalarValue::Null,
             _ => {
                 return _not_impl_err!(
@@ -3164,6 +3238,10 @@ impl fmt::Display for ScalarValue {
                         .join(",")
                 )?
             }
+            ScalarValue::Union(val, _fields, _mode) => match val {
+                Some((id, val)) => write!(f, "{}:{}", id, val)?,
+                None => write!(f, "NULL")?,
+            },
             ScalarValue::Dictionary(_k, v) => write!(f, "{v}")?,
             ScalarValue::Null => write!(f, "NULL")?,
         };
@@ -3279,6 +3357,10 @@ impl fmt::Debug for ScalarValue {
             ScalarValue::DurationNanosecond(_) => {
                 write!(f, "DurationNanosecond(\"{self}\")")
             }
+            ScalarValue::Union(val, _fields, _mode) => match val {
+                Some((id, val)) => write!(f, "Union {}:{}", id, val),
+                None => write!(f, "Union(NULL)"),
+            },
             ScalarValue::Dictionary(k, v) => write!(f, "Dictionary({k:?}, 
{v:?})"),
             ScalarValue::Null => write!(f, "NULL"),
         }
diff --git a/datafusion/physical-plan/src/filter.rs 
b/datafusion/physical-plan/src/filter.rs
index 72f885a939..f44ade7106 100644
--- a/datafusion/physical-plan/src/filter.rs
+++ b/datafusion/physical-plan/src/filter.rs
@@ -441,7 +441,9 @@ mod tests {
     use crate::test::exec::StatisticsExec;
     use crate::ExecutionPlan;
 
+    use crate::empty::EmptyExec;
     use arrow::datatypes::{DataType, Field, Schema};
+    use arrow_schema::{UnionFields, UnionMode};
     use datafusion_common::{ColumnStatistics, ScalarValue};
     use datafusion_expr::Operator;
 
@@ -1090,4 +1092,37 @@ mod tests {
         assert_eq!(statistics.total_byte_size, Precision::Inexact(1600));
         Ok(())
     }
+
+    #[test]
+    fn test_equivalence_properties_union_type() -> Result<()> {
+        let union_type = DataType::Union(
+            UnionFields::new(
+                vec![0, 1],
+                vec![
+                    Field::new("f1", DataType::Int32, true),
+                    Field::new("f2", DataType::Utf8, true),
+                ],
+            ),
+            UnionMode::Sparse,
+        );
+
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("c1", DataType::Int32, true),
+            Field::new("c2", union_type, true),
+        ]));
+
+        let exec = FilterExec::try_new(
+            binary(
+                binary(col("c1", &schema)?, Operator::GtEq, lit(1i32), 
&schema)?,
+                Operator::And,
+                binary(col("c1", &schema)?, Operator::LtEq, lit(4i32), 
&schema)?,
+                &schema,
+            )?,
+            Arc::new(EmptyExec::new(schema.clone())),
+        )?;
+
+        exec.statistics().unwrap();
+
+        Ok(())
+    }
 }
diff --git a/datafusion/proto/proto/datafusion.proto 
b/datafusion/proto/proto/datafusion.proto
index 7a9b427ce7..10f79a2b8c 100644
--- a/datafusion/proto/proto/datafusion.proto
+++ b/datafusion/proto/proto/datafusion.proto
@@ -988,6 +988,20 @@ message IntervalMonthDayNanoValue {
   int64 nanos = 3;
 }
 
+message UnionField {
+  int32 field_id = 1;
+  Field field = 2;
+}
+
+message UnionValue {
+  // Note that a null union value must have one or more fields; a null
+  // UnionValue is encoded by leaving `value` unset (value_id is then 0)
+  int32 value_id = 1;
+  ScalarValue value = 2;
+  repeated UnionField fields = 3;
+  UnionMode mode = 4;
+}
+
 message ScalarFixedSizeBinary{
   bytes values = 1;
   int32 length = 2;
@@ -1042,6 +1056,7 @@ message ScalarValue{
     ScalarTime64Value time64_value = 30;
     IntervalMonthDayNanoValue interval_month_day_nano = 31;
     ScalarFixedSizeBinary fixed_size_binary_value = 34;
+    UnionValue union_value = 42;
   }
 }
 
diff --git a/datafusion/proto/src/generated/pbjson.rs 
b/datafusion/proto/src/generated/pbjson.rs
index fd27520b3b..7757a64ef3 100644
--- a/datafusion/proto/src/generated/pbjson.rs
+++ b/datafusion/proto/src/generated/pbjson.rs
@@ -24053,6 +24053,9 @@ impl serde::Serialize for ScalarValue {
                 scalar_value::Value::FixedSizeBinaryValue(v) => {
                     struct_ser.serialize_field("fixedSizeBinaryValue", v)?;
                 }
+                scalar_value::Value::UnionValue(v) => {
+                    struct_ser.serialize_field("unionValue", v)?;
+                }
             }
         }
         struct_ser.end()
@@ -24137,6 +24140,8 @@ impl<'de> serde::Deserialize<'de> for ScalarValue {
             "intervalMonthDayNano",
             "fixed_size_binary_value",
             "fixedSizeBinaryValue",
+            "union_value",
+            "unionValue",
         ];
 
         #[allow(clippy::enum_variant_names)]
@@ -24177,6 +24182,7 @@ impl<'de> serde::Deserialize<'de> for ScalarValue {
             Time64Value,
             IntervalMonthDayNano,
             FixedSizeBinaryValue,
+            UnionValue,
         }
         impl<'de> serde::Deserialize<'de> for GeneratedField {
             fn deserialize<D>(deserializer: D) -> 
std::result::Result<GeneratedField, D::Error>
@@ -24234,6 +24240,7 @@ impl<'de> serde::Deserialize<'de> for ScalarValue {
                             "time64Value" | "time64_value" => 
Ok(GeneratedField::Time64Value),
                             "intervalMonthDayNano" | "interval_month_day_nano" 
=> Ok(GeneratedField::IntervalMonthDayNano),
                             "fixedSizeBinaryValue" | "fixed_size_binary_value" 
=> Ok(GeneratedField::FixedSizeBinaryValue),
+                            "unionValue" | "union_value" => 
Ok(GeneratedField::UnionValue),
                             _ => Err(serde::de::Error::unknown_field(value, 
FIELDS)),
                         }
                     }
@@ -24483,6 +24490,13 @@ impl<'de> serde::Deserialize<'de> for ScalarValue {
                                 return 
Err(serde::de::Error::duplicate_field("fixedSizeBinaryValue"));
                             }
                             value__ = 
map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::FixedSizeBinaryValue)
+;
+                        }
+                        GeneratedField::UnionValue => {
+                            if value__.is_some() {
+                                return 
Err(serde::de::Error::duplicate_field("unionValue"));
+                            }
+                            value__ = 
map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::UnionValue)
 ;
                         }
                     }
@@ -26942,6 +26956,117 @@ impl<'de> serde::Deserialize<'de> for UnionExecNode {
         deserializer.deserialize_struct("datafusion.UnionExecNode", FIELDS, 
GeneratedVisitor)
     }
 }
+impl serde::Serialize for UnionField {
+    #[allow(deprecated)]
+    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, 
S::Error>
+    where
+        S: serde::Serializer,
+    {
+        use serde::ser::SerializeStruct;
+        let mut len = 0;
+        if self.field_id != 0 {
+            len += 1;
+        }
+        if self.field.is_some() {
+            len += 1;
+        }
+        let mut struct_ser = 
serializer.serialize_struct("datafusion.UnionField", len)?;
+        if self.field_id != 0 {
+            struct_ser.serialize_field("fieldId", &self.field_id)?;
+        }
+        if let Some(v) = self.field.as_ref() {
+            struct_ser.serialize_field("field", v)?;
+        }
+        struct_ser.end()
+    }
+}
+impl<'de> serde::Deserialize<'de> for UnionField {
+    #[allow(deprecated)]
+    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
+    where
+        D: serde::Deserializer<'de>,
+    {
+        const FIELDS: &[&str] = &[
+            "field_id",
+            "fieldId",
+            "field",
+        ];
+
+        #[allow(clippy::enum_variant_names)]
+        enum GeneratedField {
+            FieldId,
+            Field,
+        }
+        impl<'de> serde::Deserialize<'de> for GeneratedField {
+            fn deserialize<D>(deserializer: D) -> 
std::result::Result<GeneratedField, D::Error>
+            where
+                D: serde::Deserializer<'de>,
+            {
+                struct GeneratedVisitor;
+
+                impl<'de> serde::de::Visitor<'de> for GeneratedVisitor {
+                    type Value = GeneratedField;
+
+                    fn expecting(&self, formatter: &mut 
std::fmt::Formatter<'_>) -> std::fmt::Result {
+                        write!(formatter, "expected one of: {:?}", &FIELDS)
+                    }
+
+                    #[allow(unused_variables)]
+                    fn visit_str<E>(self, value: &str) -> 
std::result::Result<GeneratedField, E>
+                    where
+                        E: serde::de::Error,
+                    {
+                        match value {
+                            "fieldId" | "field_id" => 
Ok(GeneratedField::FieldId),
+                            "field" => Ok(GeneratedField::Field),
+                            _ => Err(serde::de::Error::unknown_field(value, 
FIELDS)),
+                        }
+                    }
+                }
+                deserializer.deserialize_identifier(GeneratedVisitor)
+            }
+        }
+        struct GeneratedVisitor;
+        impl<'de> serde::de::Visitor<'de> for GeneratedVisitor {
+            type Value = UnionField;
+
+            fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> 
std::fmt::Result {
+                formatter.write_str("struct datafusion.UnionField")
+            }
+
+            fn visit_map<V>(self, mut map_: V) -> 
std::result::Result<UnionField, V::Error>
+                where
+                    V: serde::de::MapAccess<'de>,
+            {
+                let mut field_id__ = None;
+                let mut field__ = None;
+                while let Some(k) = map_.next_key()? {
+                    match k {
+                        GeneratedField::FieldId => {
+                            if field_id__.is_some() {
+                                return 
Err(serde::de::Error::duplicate_field("fieldId"));
+                            }
+                            field_id__ = 
+                                
Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0)
+                            ;
+                        }
+                        GeneratedField::Field => {
+                            if field__.is_some() {
+                                return 
Err(serde::de::Error::duplicate_field("field"));
+                            }
+                            field__ = map_.next_value()?;
+                        }
+                    }
+                }
+                Ok(UnionField {
+                    field_id: field_id__.unwrap_or_default(),
+                    field: field__,
+                })
+            }
+        }
+        deserializer.deserialize_struct("datafusion.UnionField", FIELDS, 
GeneratedVisitor)
+    }
+}
 impl serde::Serialize for UnionMode {
     #[allow(deprecated)]
     fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, 
S::Error>
@@ -27104,6 +27229,153 @@ impl<'de> serde::Deserialize<'de> for UnionNode {
         deserializer.deserialize_struct("datafusion.UnionNode", FIELDS, 
GeneratedVisitor)
     }
 }
+impl serde::Serialize for UnionValue {
+    #[allow(deprecated)]
+    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, 
S::Error>
+    where
+        S: serde::Serializer,
+    {
+        use serde::ser::SerializeStruct;
+        let mut len = 0;
+        if self.value_id != 0 {
+            len += 1;
+        }
+        if self.value.is_some() {
+            len += 1;
+        }
+        if !self.fields.is_empty() {
+            len += 1;
+        }
+        if self.mode != 0 {
+            len += 1;
+        }
+        let mut struct_ser = 
serializer.serialize_struct("datafusion.UnionValue", len)?;
+        if self.value_id != 0 {
+            struct_ser.serialize_field("valueId", &self.value_id)?;
+        }
+        if let Some(v) = self.value.as_ref() {
+            struct_ser.serialize_field("value", v)?;
+        }
+        if !self.fields.is_empty() {
+            struct_ser.serialize_field("fields", &self.fields)?;
+        }
+        if self.mode != 0 {
+            let v = UnionMode::try_from(self.mode)
+                .map_err(|_| serde::ser::Error::custom(format!("Invalid 
variant {}", self.mode)))?;
+            struct_ser.serialize_field("mode", &v)?;
+        }
+        struct_ser.end()
+    }
+}
+impl<'de> serde::Deserialize<'de> for UnionValue {
+    #[allow(deprecated)]
+    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
+    where
+        D: serde::Deserializer<'de>,
+    {
+        const FIELDS: &[&str] = &[
+            "value_id",
+            "valueId",
+            "value",
+            "fields",
+            "mode",
+        ];
+
+        #[allow(clippy::enum_variant_names)]
+        enum GeneratedField {
+            ValueId,
+            Value,
+            Fields,
+            Mode,
+        }
+        impl<'de> serde::Deserialize<'de> for GeneratedField {
+            fn deserialize<D>(deserializer: D) -> 
std::result::Result<GeneratedField, D::Error>
+            where
+                D: serde::Deserializer<'de>,
+            {
+                struct GeneratedVisitor;
+
+                impl<'de> serde::de::Visitor<'de> for GeneratedVisitor {
+                    type Value = GeneratedField;
+
+                    fn expecting(&self, formatter: &mut 
std::fmt::Formatter<'_>) -> std::fmt::Result {
+                        write!(formatter, "expected one of: {:?}", &FIELDS)
+                    }
+
+                    #[allow(unused_variables)]
+                    fn visit_str<E>(self, value: &str) -> 
std::result::Result<GeneratedField, E>
+                    where
+                        E: serde::de::Error,
+                    {
+                        match value {
+                            "valueId" | "value_id" => 
Ok(GeneratedField::ValueId),
+                            "value" => Ok(GeneratedField::Value),
+                            "fields" => Ok(GeneratedField::Fields),
+                            "mode" => Ok(GeneratedField::Mode),
+                            _ => Err(serde::de::Error::unknown_field(value, 
FIELDS)),
+                        }
+                    }
+                }
+                deserializer.deserialize_identifier(GeneratedVisitor)
+            }
+        }
+        struct GeneratedVisitor;
+        impl<'de> serde::de::Visitor<'de> for GeneratedVisitor {
+            type Value = UnionValue;
+
+            fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> 
std::fmt::Result {
+                formatter.write_str("struct datafusion.UnionValue")
+            }
+
+            fn visit_map<V>(self, mut map_: V) -> 
std::result::Result<UnionValue, V::Error>
+                where
+                    V: serde::de::MapAccess<'de>,
+            {
+                let mut value_id__ = None;
+                let mut value__ = None;
+                let mut fields__ = None;
+                let mut mode__ = None;
+                while let Some(k) = map_.next_key()? {
+                    match k {
+                        GeneratedField::ValueId => {
+                            if value_id__.is_some() {
+                                return 
Err(serde::de::Error::duplicate_field("valueId"));
+                            }
+                            value_id__ = 
+                                
Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0)
+                            ;
+                        }
+                        GeneratedField::Value => {
+                            if value__.is_some() {
+                                return 
Err(serde::de::Error::duplicate_field("value"));
+                            }
+                            value__ = map_.next_value()?;
+                        }
+                        GeneratedField::Fields => {
+                            if fields__.is_some() {
+                                return 
Err(serde::de::Error::duplicate_field("fields"));
+                            }
+                            fields__ = Some(map_.next_value()?);
+                        }
+                        GeneratedField::Mode => {
+                            if mode__.is_some() {
+                                return 
Err(serde::de::Error::duplicate_field("mode"));
+                            }
+                            mode__ = Some(map_.next_value::<UnionMode>()? as 
i32);
+                        }
+                    }
+                }
+                Ok(UnionValue {
+                    value_id: value_id__.unwrap_or_default(),
+                    value: value__,
+                    fields: fields__.unwrap_or_default(),
+                    mode: mode__.unwrap_or_default(),
+                })
+            }
+        }
+        deserializer.deserialize_struct("datafusion.UnionValue", FIELDS, 
GeneratedVisitor)
+    }
+}
 impl serde::Serialize for UniqueConstraint {
     #[allow(deprecated)]
     fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, 
S::Error>
diff --git a/datafusion/proto/src/generated/prost.rs 
b/datafusion/proto/src/generated/prost.rs
index 16ad2b848d..ab0ddb14eb 100644
--- a/datafusion/proto/src/generated/prost.rs
+++ b/datafusion/proto/src/generated/prost.rs
@@ -1225,6 +1225,28 @@ pub struct IntervalMonthDayNanoValue {
 }
 #[allow(clippy::derive_partial_eq_without_eq)]
 #[derive(Clone, PartialEq, ::prost::Message)]
+pub struct UnionField {
+    #[prost(int32, tag = "1")]
+    pub field_id: i32,
+    #[prost(message, optional, tag = "2")]
+    pub field: ::core::option::Option<Field>,
+}
+#[allow(clippy::derive_partial_eq_without_eq)]
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct UnionValue {
+    /// Note that a null union value must have one or more fields; a null
+    /// UnionValue is encoded by leaving `value` unset (value_id is then 0)
+    #[prost(int32, tag = "1")]
+    pub value_id: i32,
+    #[prost(message, optional, boxed, tag = "2")]
+    pub value: ::core::option::Option<::prost::alloc::boxed::Box<ScalarValue>>,
+    #[prost(message, repeated, tag = "3")]
+    pub fields: ::prost::alloc::vec::Vec<UnionField>,
+    #[prost(enumeration = "UnionMode", tag = "4")]
+    pub mode: i32,
+}
+#[allow(clippy::derive_partial_eq_without_eq)]
+#[derive(Clone, PartialEq, ::prost::Message)]
 pub struct ScalarFixedSizeBinary {
     #[prost(bytes = "vec", tag = "1")]
     pub values: ::prost::alloc::vec::Vec<u8>,
@@ -1236,7 +1258,7 @@ pub struct ScalarFixedSizeBinary {
 pub struct ScalarValue {
     #[prost(
         oneof = "scalar_value::Value",
-        tags = "33, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 
18, 32, 20, 39, 21, 24, 25, 35, 36, 37, 38, 26, 27, 28, 29, 30, 31, 34"
+        tags = "33, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 
18, 32, 20, 39, 21, 24, 25, 35, 36, 37, 38, 26, 27, 28, 29, 30, 31, 34, 42"
     )]
     pub value: ::core::option::Option<scalar_value::Value>,
 }
@@ -1320,6 +1342,8 @@ pub mod scalar_value {
         IntervalMonthDayNano(super::IntervalMonthDayNanoValue),
         #[prost(message, tag = "34")]
         FixedSizeBinaryValue(super::ScalarFixedSizeBinary),
+        #[prost(message, tag = "42")]
+        UnionValue(::prost::alloc::boxed::Box<super::UnionValue>),
     }
 }
 #[allow(clippy::derive_partial_eq_without_eq)]
diff --git a/datafusion/proto/src/logical_plan/from_proto.rs 
b/datafusion/proto/src/logical_plan/from_proto.rs
index 06aab16edd..8581156e2b 100644
--- a/datafusion/proto/src/logical_plan/from_proto.rs
+++ b/datafusion/proto/src/logical_plan/from_proto.rs
@@ -768,6 +768,41 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue {
             Value::IntervalMonthDayNano(v) => Self::IntervalMonthDayNano(Some(
                 IntervalMonthDayNanoType::make_value(v.months, v.days, 
v.nanos),
             )),
+            Value::UnionValue(val) => {
+                let mode = match val.mode {
+                    0 => UnionMode::Sparse,
+                    1 => UnionMode::Dense,
+                    id => Err(Error::unknown("UnionMode", id))?,
+                };
+                let ids = val
+                    .fields
+                    .iter()
+                    .map(|f| f.field_id as i8)
+                    .collect::<Vec<_>>();
+                let fields = val
+                    .fields
+                    .iter()
+                    .map(|f| f.field.clone())
+                    .collect::<Option<Vec<_>>>();
+                let fields = fields.ok_or_else(|| 
Error::required("UnionField"))?;
+                let fields = fields
+                    .iter()
+                    .map(Field::try_from)
+                    .collect::<Result<Vec<_>, _>>()?;
+                let fields = UnionFields::new(ids, fields);
+                let v_id = val.value_id as i8;
+                let val = match &val.value {
+                    None => None,
+                    Some(val) => {
+                        let val: ScalarValue = val
+                            .as_ref()
+                            .try_into()
+                            .map_err(|_| Error::General("Invalid 
Scalar".to_string()))?;
+                        Some((v_id, Box::new(val)))
+                    }
+                };
+                Self::Union(val, fields, mode)
+            }
             Value::FixedSizeBinaryValue(v) => {
                 Self::FixedSizeBinary(v.length, Some(v.clone().values))
             }
diff --git a/datafusion/proto/src/logical_plan/to_proto.rs 
b/datafusion/proto/src/logical_plan/to_proto.rs
index 9201559490..05a29ff6d4 100644
--- a/datafusion/proto/src/logical_plan/to_proto.rs
+++ b/datafusion/proto/src/logical_plan/to_proto.rs
@@ -30,6 +30,7 @@ use crate::protobuf::{
     },
     AnalyzedLogicalPlanType, CubeNode, EmptyMessage, GroupingSetNode, 
LogicalExprList,
     OptimizedLogicalPlanType, OptimizedPhysicalPlanType, PlaceholderNode, 
RollupNode,
+    UnionField, UnionValue,
 };
 
 use arrow::{
@@ -1405,6 +1406,34 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue {
                 };
                 Ok(protobuf::ScalarValue { value: Some(value) })
             }
+
+            ScalarValue::Union(val, df_fields, mode) => {
+                let mut fields = 
Vec::<UnionField>::with_capacity(df_fields.len());
+                for (id, field) in df_fields.iter() {
+                    let field_id = id as i32;
+                    let field = Some(field.as_ref().try_into()?);
+                    let field = UnionField { field_id, field };
+                    fields.push(field);
+                }
+                let mode = match mode {
+                    UnionMode::Sparse => 0,
+                    UnionMode::Dense => 1,
+                };
+                let value = match val {
+                    None => None,
+                    Some((_id, v)) => Some(Box::new(v.as_ref().try_into()?)),
+                };
+                let val = UnionValue {
+                    value_id: val.as_ref().map(|(id, _v)| *id as 
i32).unwrap_or(0),
+                    value,
+                    fields,
+                    mode,
+                };
+                let val = Value::UnionValue(Box::new(val));
+                let val = protobuf::ScalarValue { value: Some(val) };
+                Ok(val)
+            }
+
             ScalarValue::Dictionary(index_type, val) => {
                 let value: protobuf::ScalarValue = val.as_ref().try_into()?;
                 Ok(protobuf::ScalarValue {
diff --git a/datafusion/sql/src/unparser/expr.rs 
b/datafusion/sql/src/unparser/expr.rs
index c26e8481ce..43f3e348dc 100644
--- a/datafusion/sql/src/unparser/expr.rs
+++ b/datafusion/sql/src/unparser/expr.rs
@@ -456,6 +456,7 @@ impl Unparser<'_> {
                 Ok(ast::Expr::Value(ast::Value::Null))
             }
             ScalarValue::Struct(_) => not_impl_err!("Unsupported scalar: 
{v:?}"),
+            ScalarValue::Union(..) => not_impl_err!("Unsupported scalar: 
{v:?}"),
             ScalarValue::Dictionary(..) => not_impl_err!("Unsupported scalar: 
{v:?}"),
         }
     }

Reply via email to