This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new f80dde06ec feat: support Map literals in Substrait consumer and 
producer (#11547)
f80dde06ec is described below

commit f80dde06ec2cbe06938e7335facb8e30100ddb9f
Author: Arttu <[email protected]>
AuthorDate: Tue Jul 23 18:36:51 2024 +0200

    feat: support Map literals in Substrait consumer and producer (#11547)
    
    * implement Map literals/nulls conversions in Substrait
    
    * fix name handling for lists/maps containing structs
    
    * add hashing for map scalars
    
    * add a test for creating a map in VALUES
    
    * fix clippy
    
    * better test
    
    * use MapBuilder in test
    
    * fix hash test
    
    * remove unnecessary type variation checks from maps
---
 datafusion/common/src/hash_utils.rs                | 102 ++++++++++++++-
 datafusion/common/src/scalar/mod.rs                |   2 +-
 datafusion/sqllogictest/test_files/map.slt         |   8 ++
 datafusion/substrait/src/logical_plan/consumer.rs  | 143 ++++++++++++++++++---
 datafusion/substrait/src/logical_plan/producer.rs  |  76 ++++++++++-
 .../tests/cases/roundtrip_logical_plan.rs          |   4 +-
 6 files changed, 308 insertions(+), 27 deletions(-)

diff --git a/datafusion/common/src/hash_utils.rs 
b/datafusion/common/src/hash_utils.rs
index c8adae34f6..010221b048 100644
--- a/datafusion/common/src/hash_utils.rs
+++ b/datafusion/common/src/hash_utils.rs
@@ -29,8 +29,8 @@ use arrow_buffer::IntervalMonthDayNano;
 
 use crate::cast::{
     as_boolean_array, as_fixed_size_list_array, as_generic_binary_array,
-    as_large_list_array, as_list_array, as_primitive_array, as_string_array,
-    as_struct_array,
+    as_large_list_array, as_list_array, as_map_array, as_primitive_array,
+    as_string_array, as_struct_array,
 };
 use crate::error::{Result, _internal_err};
 
@@ -236,6 +236,40 @@ fn hash_struct_array(
     Ok(())
 }
 
+fn hash_map_array(
+    array: &MapArray,
+    random_state: &RandomState,
+    hashes_buffer: &mut [u64],
+) -> Result<()> {
+    let nulls = array.nulls();
+    let offsets = array.offsets();
+
+    // Create hashes for each entry in each row
+    let mut values_hashes = vec![0u64; array.entries().len()];
+    create_hashes(array.entries().columns(), random_state, &mut 
values_hashes)?;
+
+    // Combine the hashes for entries on each row with each other and previous 
hash for that row
+    if let Some(nulls) = nulls {
+        for (i, (start, stop)) in 
offsets.iter().zip(offsets.iter().skip(1)).enumerate() {
+            if nulls.is_valid(i) {
+                let hash = &mut hashes_buffer[i];
+                for values_hash in 
&values_hashes[start.as_usize()..stop.as_usize()] {
+                    *hash = combine_hashes(*hash, *values_hash);
+                }
+            }
+        }
+    } else {
+        for (i, (start, stop)) in 
offsets.iter().zip(offsets.iter().skip(1)).enumerate() {
+            let hash = &mut hashes_buffer[i];
+            for values_hash in 
&values_hashes[start.as_usize()..stop.as_usize()] {
+                *hash = combine_hashes(*hash, *values_hash);
+            }
+        }
+    }
+
+    Ok(())
+}
+
 fn hash_list_array<OffsetSize>(
     array: &GenericListArray<OffsetSize>,
     random_state: &RandomState,
@@ -400,6 +434,10 @@ pub fn create_hashes<'a>(
                 let array = as_large_list_array(array)?;
                 hash_list_array(array, random_state, hashes_buffer)?;
             }
+            DataType::Map(_, _) => {
+                let array = as_map_array(array)?;
+                hash_map_array(array, random_state, hashes_buffer)?;
+            }
             DataType::FixedSizeList(_,_) => {
                 let array = as_fixed_size_list_array(array)?;
                 hash_fixed_list_array(array, random_state, hashes_buffer)?;
@@ -572,6 +610,7 @@ mod tests {
             Some(vec![Some(3), None, Some(5)]),
             None,
             Some(vec![Some(0), Some(1), Some(2)]),
+            Some(vec![]),
         ];
         let list_array =
             Arc::new(ListArray::from_iter_primitive::<Int32Type, _, _>(data)) 
as ArrayRef;
@@ -581,6 +620,7 @@ mod tests {
         assert_eq!(hashes[0], hashes[5]);
         assert_eq!(hashes[1], hashes[4]);
         assert_eq!(hashes[2], hashes[3]);
+        assert_eq!(hashes[1], hashes[6]); // null vs empty list
     }
 
     #[test]
@@ -692,6 +732,64 @@ mod tests {
         assert_eq!(hashes[0], hashes[1]);
     }
 
+    #[test]
+    // Tests actual values of hashes, which are different if forcing collisions
+    #[cfg(not(feature = "force_hash_collisions"))]
+    fn create_hashes_for_map_arrays() {
+        let mut builder =
+            MapBuilder::new(None, StringBuilder::new(), Int32Builder::new());
+        // Row 0
+        builder.keys().append_value("key1");
+        builder.keys().append_value("key2");
+        builder.values().append_value(1);
+        builder.values().append_value(2);
+        builder.append(true).unwrap();
+        // Row 1
+        builder.keys().append_value("key1");
+        builder.keys().append_value("key2");
+        builder.values().append_value(1);
+        builder.values().append_value(2);
+        builder.append(true).unwrap();
+        // Row 2
+        builder.keys().append_value("key1");
+        builder.keys().append_value("key2");
+        builder.values().append_value(1);
+        builder.values().append_value(3);
+        builder.append(true).unwrap();
+        // Row 3
+        builder.keys().append_value("key1");
+        builder.keys().append_value("key3");
+        builder.values().append_value(1);
+        builder.values().append_value(2);
+        builder.append(true).unwrap();
+        // Row 4
+        builder.keys().append_value("key1");
+        builder.values().append_value(1);
+        builder.append(true).unwrap();
+        // Row 5
+        builder.keys().append_value("key1");
+        builder.values().append_null();
+        builder.append(true).unwrap();
+        // Row 6
+        builder.append(true).unwrap();
+        // Row 7
+        builder.keys().append_value("key1");
+        builder.values().append_value(1);
+        builder.append(false).unwrap();
+
+        let array = Arc::new(builder.finish()) as ArrayRef;
+
+        let random_state = RandomState::with_seeds(0, 0, 0, 0);
+        let mut hashes = vec![0; array.len()];
+        create_hashes(&[array], &random_state, &mut hashes).unwrap();
+        assert_eq!(hashes[0], hashes[1]); // same value
+        assert_ne!(hashes[0], hashes[2]); // different value
+        assert_ne!(hashes[0], hashes[3]); // different key
+        assert_ne!(hashes[0], hashes[4]); // missing an entry
+        assert_ne!(hashes[4], hashes[5]); // filled vs null value
+        assert_eq!(hashes[6], hashes[7]); // empty vs null map
+    }
+
     #[test]
     // Tests actual values of hashes, which are different if forcing collisions
     #[cfg(not(feature = "force_hash_collisions"))]
diff --git a/datafusion/common/src/scalar/mod.rs 
b/datafusion/common/src/scalar/mod.rs
index 0651013901..92ed897e71 100644
--- a/datafusion/common/src/scalar/mod.rs
+++ b/datafusion/common/src/scalar/mod.rs
@@ -1770,6 +1770,7 @@ impl ScalarValue {
             }
             DataType::List(_)
             | DataType::LargeList(_)
+            | DataType::Map(_, _)
             | DataType::Struct(_)
             | DataType::Union(_, _) => {
                 let arrays = scalars.map(|s| 
s.to_array()).collect::<Result<Vec<_>>>()?;
@@ -1838,7 +1839,6 @@ impl ScalarValue {
             | DataType::Time32(TimeUnit::Nanosecond)
             | DataType::Time64(TimeUnit::Second)
             | DataType::Time64(TimeUnit::Millisecond)
-            | DataType::Map(_, _)
             | DataType::RunEndEncoded(_, _)
             | DataType::ListView(_)
             | DataType::LargeListView(_) => {
diff --git a/datafusion/sqllogictest/test_files/map.slt 
b/datafusion/sqllogictest/test_files/map.slt
index 26bfb4a592..e530e14df6 100644
--- a/datafusion/sqllogictest/test_files/map.slt
+++ b/datafusion/sqllogictest/test_files/map.slt
@@ -302,3 +302,11 @@ SELECT MAP(arrow_cast(make_array('POST', 'HEAD', 'PATCH'), 
'LargeList(Utf8)'), a
 {POST: 41, HEAD: 33, PATCH: 30}
 {POST: 41, HEAD: 33, PATCH: 30}
 {POST: 41, HEAD: 33, PATCH: 30}
+
+
+query ?
+VALUES (MAP(['a'], [1])), (MAP(['b'], [2])), (MAP(['c', 'a'], [3, 1]))
+----
+{a: 1}
+{b: 2}
+{c: 3, a: 1}
diff --git a/datafusion/substrait/src/logical_plan/consumer.rs 
b/datafusion/substrait/src/logical_plan/consumer.rs
index 5768c44bbf..15c4471148 100644
--- a/datafusion/substrait/src/logical_plan/consumer.rs
+++ b/datafusion/substrait/src/logical_plan/consumer.rs
@@ -15,9 +15,9 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano};
+use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano, OffsetBuffer};
 use async_recursion::async_recursion;
-use datafusion::arrow::array::GenericListArray;
+use datafusion::arrow::array::{GenericListArray, MapArray};
 use datafusion::arrow::datatypes::{
     DataType, Field, FieldRef, Fields, IntervalUnit, Schema, TimeUnit,
 };
@@ -51,6 +51,7 @@ use crate::variation_const::{
     INTERVAL_DAY_TIME_TYPE_REF, INTERVAL_MONTH_DAY_NANO_TYPE_REF,
     INTERVAL_YEAR_MONTH_TYPE_REF,
 };
+use datafusion::arrow::array::{new_empty_array, AsArray};
 use datafusion::common::scalar::ScalarStructBuilder;
 use datafusion::logical_expr::expr::InList;
 use datafusion::logical_expr::{
@@ -1449,21 +1450,14 @@ fn from_substrait_type(
                     from_substrait_type(value_type, extensions, dfs_names, 
name_idx)?,
                     true,
                 ));
-                match map.type_variation_reference {
-                    DEFAULT_CONTAINER_TYPE_VARIATION_REF => {
-                        Ok(DataType::Map(
-                            Arc::new(Field::new_struct(
-                                "entries",
-                                [key_field, value_field],
-                                false, // The inner map field is always 
non-nullable (Arrow #1697),
-                            )),
-                            false,
-                        ))
-                    }
-                    v => not_impl_err!(
-                        "Unsupported Substrait type variation {v} of type 
{s_kind:?}"
-                    )?,
-                }
+                Ok(DataType::Map(
+                    Arc::new(Field::new_struct(
+                        "entries",
+                        [key_field, value_field],
+                        false, // The inner map field is always non-nullable 
(Arrow #1697),
+                    )),
+                    false, // whether keys are sorted
+                ))
             }
             r#type::Kind::Decimal(d) => match d.type_variation_reference {
                 DECIMAL_128_TYPE_VARIATION_REF => {
@@ -1743,11 +1737,23 @@ fn from_substrait_literal(
             )
         }
         Some(LiteralType::List(l)) => {
+            // Each element should start the name index from the same value, 
then we increase it
+            // once at the end
+            let mut element_name_idx = *name_idx;
             let elements = l
                 .values
                 .iter()
-                .map(|el| from_substrait_literal(el, extensions, dfs_names, 
name_idx))
+                .map(|el| {
+                    element_name_idx = *name_idx;
+                    from_substrait_literal(
+                        el,
+                        extensions,
+                        dfs_names,
+                        &mut element_name_idx,
+                    )
+                })
                 .collect::<Result<Vec<_>>>()?;
+            *name_idx = element_name_idx;
             if elements.is_empty() {
                 return substrait_err!(
                     "Empty list must be encoded as EmptyList literal type, not 
List"
@@ -1785,6 +1791,84 @@ fn from_substrait_literal(
                 }
             }
         }
+        Some(LiteralType::Map(m)) => {
+            // Each entry should start the name index from the same value, 
then we increase it
+            // once at the end
+            let mut entry_name_idx = *name_idx;
+            let entries = m
+                .key_values
+                .iter()
+                .map(|kv| {
+                    entry_name_idx = *name_idx;
+                    let key_sv = from_substrait_literal(
+                        kv.key.as_ref().unwrap(),
+                        extensions,
+                        dfs_names,
+                        &mut entry_name_idx,
+                    )?;
+                    let value_sv = from_substrait_literal(
+                        kv.value.as_ref().unwrap(),
+                        extensions,
+                        dfs_names,
+                        &mut entry_name_idx,
+                    )?;
+                    ScalarStructBuilder::new()
+                        .with_scalar(Field::new("key", key_sv.data_type(), 
false), key_sv)
+                        .with_scalar(
+                            Field::new("value", value_sv.data_type(), true),
+                            value_sv,
+                        )
+                        .build()
+                })
+                .collect::<Result<Vec<_>>>()?;
+            *name_idx = entry_name_idx;
+
+            if entries.is_empty() {
+                return substrait_err!(
+                    "Empty map must be encoded as EmptyMap literal type, not 
Map"
+                );
+            }
+
+            ScalarValue::Map(Arc::new(MapArray::new(
+                Arc::new(Field::new("entries", entries[0].data_type(), false)),
+                OffsetBuffer::new(vec![0, entries.len() as i32].into()),
+                ScalarValue::iter_to_array(entries)?.as_struct().to_owned(),
+                None,
+                false,
+            )))
+        }
+        Some(LiteralType::EmptyMap(m)) => {
+            let key = match &m.key {
+                Some(k) => Ok(k),
+                _ => plan_err!("Missing key type for empty map"),
+            }?;
+            let value = match &m.value {
+                Some(v) => Ok(v),
+                _ => plan_err!("Missing value type for empty map"),
+            }?;
+            let key_type = from_substrait_type(key, extensions, dfs_names, 
name_idx)?;
+            let value_type = from_substrait_type(value, extensions, dfs_names, 
name_idx)?;
+
+            // new_empty_array on a MapType creates a too empty array
+            // We want it to contain an empty struct array to align with an 
empty MapBuilder one
+            let entries = Field::new_struct(
+                "entries",
+                vec![
+                    Field::new("key", key_type, false),
+                    Field::new("value", value_type, true),
+                ],
+                false,
+            );
+            let struct_array =
+                new_empty_array(entries.data_type()).as_struct().to_owned();
+            ScalarValue::Map(Arc::new(MapArray::new(
+                Arc::new(entries),
+                OffsetBuffer::new(vec![0, 0].into()),
+                struct_array,
+                None,
+                false,
+            )))
+        }
         Some(LiteralType::Struct(s)) => {
             let mut builder = ScalarStructBuilder::new();
             for (i, field) in s.fields.iter().enumerate() {
@@ -2013,6 +2097,29 @@ fn from_substrait_null(
                     ),
                 }
             }
+            r#type::Kind::Map(map) => {
+                let key_type = map.key.as_ref().ok_or_else(|| {
+                    substrait_datafusion_err!("Map type must have key type")
+                })?;
+                let value_type = map.value.as_ref().ok_or_else(|| {
+                    substrait_datafusion_err!("Map type must have value type")
+                })?;
+
+                let key_type =
+                    from_substrait_type(key_type, extensions, dfs_names, 
name_idx)?;
+                let value_type =
+                    from_substrait_type(value_type, extensions, dfs_names, 
name_idx)?;
+                let entries_field = Arc::new(Field::new_struct(
+                    "entries",
+                    vec![
+                        Field::new("key", key_type, false),
+                        Field::new("value", value_type, true),
+                    ],
+                    false,
+                ));
+
+                DataType::Map(entries_field, false /* keys sorted 
*/).try_into()
+            }
             r#type::Kind::Struct(s) => {
                 let fields =
                     from_substrait_struct_type(s, extensions, dfs_names, 
name_idx)?;
diff --git a/datafusion/substrait/src/logical_plan/producer.rs 
b/datafusion/substrait/src/logical_plan/producer.rs
index 8f69cc5e21..8263209ffc 100644
--- a/datafusion/substrait/src/logical_plan/producer.rs
+++ b/datafusion/substrait/src/logical_plan/producer.rs
@@ -57,8 +57,10 @@ use datafusion::logical_expr::{expr, Between, 
JoinConstraint, LogicalPlan, Opera
 use datafusion::prelude::Expr;
 use pbjson_types::Any as ProtoAny;
 use substrait::proto::exchange_rel::{ExchangeKind, RoundRobin, ScatterFields};
+use substrait::proto::expression::literal::map::KeyValue;
 use substrait::proto::expression::literal::{
-    user_defined, IntervalDayToSecond, IntervalYearToMonth, List, Struct, 
UserDefined,
+    user_defined, IntervalDayToSecond, IntervalYearToMonth, List, Map, Struct,
+    UserDefined,
 };
 use substrait::proto::expression::subquery::InPredicate;
 use substrait::proto::expression::window_function::BoundsType;
@@ -1922,6 +1924,48 @@ fn to_substrait_literal(
             convert_array_to_literal_list(l, extensions)?,
             LARGE_CONTAINER_TYPE_VARIATION_REF,
         ),
+        ScalarValue::Map(m) => {
+            let map = if m.is_empty() || m.value(0).is_empty() {
+                let mt = to_substrait_type(m.data_type(), m.is_nullable(), 
extensions)?;
+                let mt = match mt {
+                    substrait::proto::Type {
+                        kind: Some(r#type::Kind::Map(mt)),
+                    } => Ok(mt.as_ref().to_owned()),
+                    _ => exec_err!("Unexpected type for a map: {mt:?}"),
+                }?;
+                LiteralType::EmptyMap(mt)
+            } else {
+                let keys = (0..m.keys().len())
+                    .map(|i| {
+                        to_substrait_literal(
+                            &ScalarValue::try_from_array(&m.keys(), i)?,
+                            extensions,
+                        )
+                    })
+                    .collect::<Result<Vec<_>>>()?;
+                let values = (0..m.values().len())
+                    .map(|i| {
+                        to_substrait_literal(
+                            &ScalarValue::try_from_array(&m.values(), i)?,
+                            extensions,
+                        )
+                    })
+                    .collect::<Result<Vec<_>>>()?;
+
+                let key_values = keys
+                    .into_iter()
+                    .zip(values.into_iter())
+                    .map(|(k, v)| {
+                        Ok(KeyValue {
+                            key: Some(k),
+                            value: Some(v),
+                        })
+                    })
+                    .collect::<Result<Vec<_>>>()?;
+                LiteralType::Map(Map { key_values })
+            };
+            (map, DEFAULT_CONTAINER_TYPE_VARIATION_REF)
+        }
         ScalarValue::Struct(s) => (
             LiteralType::Struct(Struct {
                 fields: s
@@ -1967,7 +2011,7 @@ fn convert_array_to_literal_list<T: OffsetSizeTrait>(
         .collect::<Result<Vec<_>>>()?;
 
     if values.is_empty() {
-        let et = match to_substrait_type(
+        let lt = match to_substrait_type(
             array.data_type(),
             array.is_nullable(),
             extensions,
@@ -1977,7 +2021,7 @@ fn convert_array_to_literal_list<T: OffsetSizeTrait>(
             } => lt.as_ref().to_owned(),
             _ => unreachable!(),
         };
-        Ok(LiteralType::EmptyList(et))
+        Ok(LiteralType::EmptyList(lt))
     } else {
         Ok(LiteralType::List(List { values }))
     }
@@ -2094,7 +2138,9 @@ mod test {
         from_substrait_literal_without_names, 
from_substrait_type_without_names,
     };
     use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano};
-    use datafusion::arrow::array::GenericListArray;
+    use datafusion::arrow::array::{
+        GenericListArray, Int64Builder, MapBuilder, StringBuilder,
+    };
     use datafusion::arrow::datatypes::Field;
     use datafusion::common::scalar::ScalarStructBuilder;
     use std::collections::HashMap;
@@ -2160,6 +2206,28 @@ mod test {
             ),
         )))?;
 
+        // Null map
+        let mut map_builder =
+            MapBuilder::new(None, StringBuilder::new(), Int64Builder::new());
+        map_builder.append(false)?;
+        round_trip_literal(ScalarValue::Map(Arc::new(map_builder.finish())))?;
+
+        // Empty map
+        let mut map_builder =
+            MapBuilder::new(None, StringBuilder::new(), Int64Builder::new());
+        map_builder.append(true)?;
+        round_trip_literal(ScalarValue::Map(Arc::new(map_builder.finish())))?;
+
+        // Valid map
+        let mut map_builder =
+            MapBuilder::new(None, StringBuilder::new(), Int64Builder::new());
+        map_builder.keys().append_value("key1");
+        map_builder.keys().append_value("key2");
+        map_builder.values().append_value(1);
+        map_builder.values().append_value(2);
+        map_builder.append(true)?;
+        round_trip_literal(ScalarValue::Map(Arc::new(map_builder.finish())))?;
+
         let c0 = Field::new("c0", DataType::Boolean, true);
         let c1 = Field::new("c1", DataType::Int32, true);
         let c2 = Field::new("c2", DataType::Utf8, true);
diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs 
b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
index 5b4389c832..439e3efa29 100644
--- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
+++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
@@ -749,7 +749,7 @@ async fn roundtrip_values() -> Result<()> {
                 [[-213.1, NULL, 5.5, 2.0, 1.0], []], \
                 arrow_cast([1,2,3], 'LargeList(Int64)'), \
                 STRUCT(true, 1 AS int_field, CAST(NULL AS STRING)), \
-                [STRUCT(STRUCT('a' AS string_field) AS struct_field)]\
+                [STRUCT(STRUCT('a' AS string_field) AS struct_field), 
STRUCT(STRUCT('b' AS string_field) AS struct_field)]\
             ), \
             (NULL, NULL, NULL, NULL, NULL, NULL)",
         "Values: \
@@ -759,7 +759,7 @@ async fn roundtrip_values() -> Result<()> {
                 List([[-213.1, , 5.5, 2.0, 1.0], []]), \
                 LargeList([1, 2, 3]), \
                 Struct({c0:true,int_field:1,c2:}), \
-                List([{struct_field: {string_field: a}}])\
+                List([{struct_field: {string_field: a}}, {struct_field: 
{string_field: b}}])\
             ), \
             (Int64(NULL), Utf8(NULL), List(), LargeList(), 
Struct({c0:,int_field:,c2:}), List())",
     true).await


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to