This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new f80dde06ec feat: support Map literals in Substrait consumer and
producer (#11547)
f80dde06ec is described below
commit f80dde06ec2cbe06938e7335facb8e30100ddb9f
Author: Arttu <[email protected]>
AuthorDate: Tue Jul 23 18:36:51 2024 +0200
feat: support Map literals in Substrait consumer and producer (#11547)
* implement Map literals/nulls conversions in Substrait
* fix name handling for lists/maps containing structs
* add hashing for map scalars
* add a test for creating a map in VALUES
* fix clippy
* better test
* use MapBuilder in test
* fix hash test
* remove unnecessary type variation checks from maps
---
datafusion/common/src/hash_utils.rs | 102 ++++++++++++++-
datafusion/common/src/scalar/mod.rs | 2 +-
datafusion/sqllogictest/test_files/map.slt | 8 ++
datafusion/substrait/src/logical_plan/consumer.rs | 143 ++++++++++++++++++---
datafusion/substrait/src/logical_plan/producer.rs | 76 ++++++++++-
.../tests/cases/roundtrip_logical_plan.rs | 4 +-
6 files changed, 308 insertions(+), 27 deletions(-)
diff --git a/datafusion/common/src/hash_utils.rs
b/datafusion/common/src/hash_utils.rs
index c8adae34f6..010221b048 100644
--- a/datafusion/common/src/hash_utils.rs
+++ b/datafusion/common/src/hash_utils.rs
@@ -29,8 +29,8 @@ use arrow_buffer::IntervalMonthDayNano;
use crate::cast::{
as_boolean_array, as_fixed_size_list_array, as_generic_binary_array,
- as_large_list_array, as_list_array, as_primitive_array, as_string_array,
- as_struct_array,
+ as_large_list_array, as_list_array, as_map_array, as_primitive_array,
+ as_string_array, as_struct_array,
};
use crate::error::{Result, _internal_err};
@@ -236,6 +236,40 @@ fn hash_struct_array(
Ok(())
}
+fn hash_map_array(
+ array: &MapArray,
+ random_state: &RandomState,
+ hashes_buffer: &mut [u64],
+) -> Result<()> {
+ let nulls = array.nulls();
+ let offsets = array.offsets();
+
+ // Create hashes for each entry in each row
+ let mut values_hashes = vec![0u64; array.entries().len()];
+ create_hashes(array.entries().columns(), random_state, &mut
values_hashes)?;
+
+ // Combine the hashes for entries on each row with each other and previous
hash for that row
+ if let Some(nulls) = nulls {
+ for (i, (start, stop)) in
offsets.iter().zip(offsets.iter().skip(1)).enumerate() {
+ if nulls.is_valid(i) {
+ let hash = &mut hashes_buffer[i];
+ for values_hash in
&values_hashes[start.as_usize()..stop.as_usize()] {
+ *hash = combine_hashes(*hash, *values_hash);
+ }
+ }
+ }
+ } else {
+ for (i, (start, stop)) in
offsets.iter().zip(offsets.iter().skip(1)).enumerate() {
+ let hash = &mut hashes_buffer[i];
+ for values_hash in
&values_hashes[start.as_usize()..stop.as_usize()] {
+ *hash = combine_hashes(*hash, *values_hash);
+ }
+ }
+ }
+
+ Ok(())
+}
+
fn hash_list_array<OffsetSize>(
array: &GenericListArray<OffsetSize>,
random_state: &RandomState,
@@ -400,6 +434,10 @@ pub fn create_hashes<'a>(
let array = as_large_list_array(array)?;
hash_list_array(array, random_state, hashes_buffer)?;
}
+ DataType::Map(_, _) => {
+ let array = as_map_array(array)?;
+ hash_map_array(array, random_state, hashes_buffer)?;
+ }
DataType::FixedSizeList(_,_) => {
let array = as_fixed_size_list_array(array)?;
hash_fixed_list_array(array, random_state, hashes_buffer)?;
@@ -572,6 +610,7 @@ mod tests {
Some(vec![Some(3), None, Some(5)]),
None,
Some(vec![Some(0), Some(1), Some(2)]),
+ Some(vec![]),
];
let list_array =
Arc::new(ListArray::from_iter_primitive::<Int32Type, _, _>(data))
as ArrayRef;
@@ -581,6 +620,7 @@ mod tests {
assert_eq!(hashes[0], hashes[5]);
assert_eq!(hashes[1], hashes[4]);
assert_eq!(hashes[2], hashes[3]);
+ assert_eq!(hashes[1], hashes[6]); // null vs empty list
}
#[test]
@@ -692,6 +732,64 @@ mod tests {
assert_eq!(hashes[0], hashes[1]);
}
+ #[test]
+ // Tests actual values of hashes, which are different if forcing collisions
+ #[cfg(not(feature = "force_hash_collisions"))]
+ fn create_hashes_for_map_arrays() {
+ let mut builder =
+ MapBuilder::new(None, StringBuilder::new(), Int32Builder::new());
+ // Row 0
+ builder.keys().append_value("key1");
+ builder.keys().append_value("key2");
+ builder.values().append_value(1);
+ builder.values().append_value(2);
+ builder.append(true).unwrap();
+ // Row 1
+ builder.keys().append_value("key1");
+ builder.keys().append_value("key2");
+ builder.values().append_value(1);
+ builder.values().append_value(2);
+ builder.append(true).unwrap();
+ // Row 2
+ builder.keys().append_value("key1");
+ builder.keys().append_value("key2");
+ builder.values().append_value(1);
+ builder.values().append_value(3);
+ builder.append(true).unwrap();
+ // Row 3
+ builder.keys().append_value("key1");
+ builder.keys().append_value("key3");
+ builder.values().append_value(1);
+ builder.values().append_value(2);
+ builder.append(true).unwrap();
+ // Row 4
+ builder.keys().append_value("key1");
+ builder.values().append_value(1);
+ builder.append(true).unwrap();
+ // Row 5
+ builder.keys().append_value("key1");
+ builder.values().append_null();
+ builder.append(true).unwrap();
+ // Row 6
+ builder.append(true).unwrap();
+ // Row 7
+ builder.keys().append_value("key1");
+ builder.values().append_value(1);
+ builder.append(false).unwrap();
+
+ let array = Arc::new(builder.finish()) as ArrayRef;
+
+ let random_state = RandomState::with_seeds(0, 0, 0, 0);
+ let mut hashes = vec![0; array.len()];
+ create_hashes(&[array], &random_state, &mut hashes).unwrap();
+ assert_eq!(hashes[0], hashes[1]); // same value
+ assert_ne!(hashes[0], hashes[2]); // different value
+ assert_ne!(hashes[0], hashes[3]); // different key
+ assert_ne!(hashes[0], hashes[4]); // missing an entry
+ assert_ne!(hashes[4], hashes[5]); // filled vs null value
+ assert_eq!(hashes[6], hashes[7]); // empty vs null map
+ }
+
#[test]
// Tests actual values of hashes, which are different if forcing collisions
#[cfg(not(feature = "force_hash_collisions"))]
diff --git a/datafusion/common/src/scalar/mod.rs
b/datafusion/common/src/scalar/mod.rs
index 0651013901..92ed897e71 100644
--- a/datafusion/common/src/scalar/mod.rs
+++ b/datafusion/common/src/scalar/mod.rs
@@ -1770,6 +1770,7 @@ impl ScalarValue {
}
DataType::List(_)
| DataType::LargeList(_)
+ | DataType::Map(_, _)
| DataType::Struct(_)
| DataType::Union(_, _) => {
let arrays = scalars.map(|s|
s.to_array()).collect::<Result<Vec<_>>>()?;
@@ -1838,7 +1839,6 @@ impl ScalarValue {
| DataType::Time32(TimeUnit::Nanosecond)
| DataType::Time64(TimeUnit::Second)
| DataType::Time64(TimeUnit::Millisecond)
- | DataType::Map(_, _)
| DataType::RunEndEncoded(_, _)
| DataType::ListView(_)
| DataType::LargeListView(_) => {
diff --git a/datafusion/sqllogictest/test_files/map.slt
b/datafusion/sqllogictest/test_files/map.slt
index 26bfb4a592..e530e14df6 100644
--- a/datafusion/sqllogictest/test_files/map.slt
+++ b/datafusion/sqllogictest/test_files/map.slt
@@ -302,3 +302,11 @@ SELECT MAP(arrow_cast(make_array('POST', 'HEAD', 'PATCH'),
'LargeList(Utf8)'), a
{POST: 41, HEAD: 33, PATCH: 30}
{POST: 41, HEAD: 33, PATCH: 30}
{POST: 41, HEAD: 33, PATCH: 30}
+
+
+query ?
+VALUES (MAP(['a'], [1])), (MAP(['b'], [2])), (MAP(['c', 'a'], [3, 1]))
+----
+{a: 1}
+{b: 2}
+{c: 3, a: 1}
diff --git a/datafusion/substrait/src/logical_plan/consumer.rs
b/datafusion/substrait/src/logical_plan/consumer.rs
index 5768c44bbf..15c4471148 100644
--- a/datafusion/substrait/src/logical_plan/consumer.rs
+++ b/datafusion/substrait/src/logical_plan/consumer.rs
@@ -15,9 +15,9 @@
// specific language governing permissions and limitations
// under the License.
-use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano};
+use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano, OffsetBuffer};
use async_recursion::async_recursion;
-use datafusion::arrow::array::GenericListArray;
+use datafusion::arrow::array::{GenericListArray, MapArray};
use datafusion::arrow::datatypes::{
DataType, Field, FieldRef, Fields, IntervalUnit, Schema, TimeUnit,
};
@@ -51,6 +51,7 @@ use crate::variation_const::{
INTERVAL_DAY_TIME_TYPE_REF, INTERVAL_MONTH_DAY_NANO_TYPE_REF,
INTERVAL_YEAR_MONTH_TYPE_REF,
};
+use datafusion::arrow::array::{new_empty_array, AsArray};
use datafusion::common::scalar::ScalarStructBuilder;
use datafusion::logical_expr::expr::InList;
use datafusion::logical_expr::{
@@ -1449,21 +1450,14 @@ fn from_substrait_type(
from_substrait_type(value_type, extensions, dfs_names,
name_idx)?,
true,
));
- match map.type_variation_reference {
- DEFAULT_CONTAINER_TYPE_VARIATION_REF => {
- Ok(DataType::Map(
- Arc::new(Field::new_struct(
- "entries",
- [key_field, value_field],
- false, // The inner map field is always
non-nullable (Arrow #1697),
- )),
- false,
- ))
- }
- v => not_impl_err!(
- "Unsupported Substrait type variation {v} of type
{s_kind:?}"
- )?,
- }
+ Ok(DataType::Map(
+ Arc::new(Field::new_struct(
+ "entries",
+ [key_field, value_field],
+ false, // The inner map field is always non-nullable
(Arrow #1697),
+ )),
+ false, // whether keys are sorted
+ ))
}
r#type::Kind::Decimal(d) => match d.type_variation_reference {
DECIMAL_128_TYPE_VARIATION_REF => {
@@ -1743,11 +1737,23 @@ fn from_substrait_literal(
)
}
Some(LiteralType::List(l)) => {
+ // Each element should start the name index from the same value,
then we increase it
+ // once at the end
+ let mut element_name_idx = *name_idx;
let elements = l
.values
.iter()
- .map(|el| from_substrait_literal(el, extensions, dfs_names,
name_idx))
+ .map(|el| {
+ element_name_idx = *name_idx;
+ from_substrait_literal(
+ el,
+ extensions,
+ dfs_names,
+ &mut element_name_idx,
+ )
+ })
.collect::<Result<Vec<_>>>()?;
+ *name_idx = element_name_idx;
if elements.is_empty() {
return substrait_err!(
"Empty list must be encoded as EmptyList literal type, not
List"
@@ -1785,6 +1791,84 @@ fn from_substrait_literal(
}
}
}
+ Some(LiteralType::Map(m)) => {
+ // Each entry should start the name index from the same value,
then we increase it
+ // once at the end
+ let mut entry_name_idx = *name_idx;
+ let entries = m
+ .key_values
+ .iter()
+ .map(|kv| {
+ entry_name_idx = *name_idx;
+ let key_sv = from_substrait_literal(
+ kv.key.as_ref().unwrap(),
+ extensions,
+ dfs_names,
+ &mut entry_name_idx,
+ )?;
+ let value_sv = from_substrait_literal(
+ kv.value.as_ref().unwrap(),
+ extensions,
+ dfs_names,
+ &mut entry_name_idx,
+ )?;
+ ScalarStructBuilder::new()
+ .with_scalar(Field::new("key", key_sv.data_type(),
false), key_sv)
+ .with_scalar(
+ Field::new("value", value_sv.data_type(), true),
+ value_sv,
+ )
+ .build()
+ })
+ .collect::<Result<Vec<_>>>()?;
+ *name_idx = entry_name_idx;
+
+ if entries.is_empty() {
+ return substrait_err!(
+ "Empty map must be encoded as EmptyMap literal type, not
Map"
+ );
+ }
+
+ ScalarValue::Map(Arc::new(MapArray::new(
+ Arc::new(Field::new("entries", entries[0].data_type(), false)),
+ OffsetBuffer::new(vec![0, entries.len() as i32].into()),
+ ScalarValue::iter_to_array(entries)?.as_struct().to_owned(),
+ None,
+ false,
+ )))
+ }
+ Some(LiteralType::EmptyMap(m)) => {
+ let key = match &m.key {
+ Some(k) => Ok(k),
+ _ => plan_err!("Missing key type for empty map"),
+ }?;
+ let value = match &m.value {
+ Some(v) => Ok(v),
+ _ => plan_err!("Missing value type for empty map"),
+ }?;
+ let key_type = from_substrait_type(key, extensions, dfs_names,
name_idx)?;
+ let value_type = from_substrait_type(value, extensions, dfs_names,
name_idx)?;
+
+ // new_empty_array on a MapType creates a too empty array
+ // We want it to contain an empty struct array to align with an
empty MapBuilder one
+ let entries = Field::new_struct(
+ "entries",
+ vec![
+ Field::new("key", key_type, false),
+ Field::new("value", value_type, true),
+ ],
+ false,
+ );
+ let struct_array =
+ new_empty_array(entries.data_type()).as_struct().to_owned();
+ ScalarValue::Map(Arc::new(MapArray::new(
+ Arc::new(entries),
+ OffsetBuffer::new(vec![0, 0].into()),
+ struct_array,
+ None,
+ false,
+ )))
+ }
Some(LiteralType::Struct(s)) => {
let mut builder = ScalarStructBuilder::new();
for (i, field) in s.fields.iter().enumerate() {
@@ -2013,6 +2097,29 @@ fn from_substrait_null(
),
}
}
+ r#type::Kind::Map(map) => {
+ let key_type = map.key.as_ref().ok_or_else(|| {
+ substrait_datafusion_err!("Map type must have key type")
+ })?;
+ let value_type = map.value.as_ref().ok_or_else(|| {
+ substrait_datafusion_err!("Map type must have value type")
+ })?;
+
+ let key_type =
+ from_substrait_type(key_type, extensions, dfs_names,
name_idx)?;
+ let value_type =
+ from_substrait_type(value_type, extensions, dfs_names,
name_idx)?;
+ let entries_field = Arc::new(Field::new_struct(
+ "entries",
+ vec![
+ Field::new("key", key_type, false),
+ Field::new("value", value_type, true),
+ ],
+ false,
+ ));
+
+ DataType::Map(entries_field, false /* keys sorted
*/).try_into()
+ }
r#type::Kind::Struct(s) => {
let fields =
from_substrait_struct_type(s, extensions, dfs_names,
name_idx)?;
diff --git a/datafusion/substrait/src/logical_plan/producer.rs
b/datafusion/substrait/src/logical_plan/producer.rs
index 8f69cc5e21..8263209ffc 100644
--- a/datafusion/substrait/src/logical_plan/producer.rs
+++ b/datafusion/substrait/src/logical_plan/producer.rs
@@ -57,8 +57,10 @@ use datafusion::logical_expr::{expr, Between,
JoinConstraint, LogicalPlan, Opera
use datafusion::prelude::Expr;
use pbjson_types::Any as ProtoAny;
use substrait::proto::exchange_rel::{ExchangeKind, RoundRobin, ScatterFields};
+use substrait::proto::expression::literal::map::KeyValue;
use substrait::proto::expression::literal::{
- user_defined, IntervalDayToSecond, IntervalYearToMonth, List, Struct,
UserDefined,
+ user_defined, IntervalDayToSecond, IntervalYearToMonth, List, Map, Struct,
+ UserDefined,
};
use substrait::proto::expression::subquery::InPredicate;
use substrait::proto::expression::window_function::BoundsType;
@@ -1922,6 +1924,48 @@ fn to_substrait_literal(
convert_array_to_literal_list(l, extensions)?,
LARGE_CONTAINER_TYPE_VARIATION_REF,
),
+ ScalarValue::Map(m) => {
+ let map = if m.is_empty() || m.value(0).is_empty() {
+ let mt = to_substrait_type(m.data_type(), m.is_nullable(),
extensions)?;
+ let mt = match mt {
+ substrait::proto::Type {
+ kind: Some(r#type::Kind::Map(mt)),
+ } => Ok(mt.as_ref().to_owned()),
+ _ => exec_err!("Unexpected type for a map: {mt:?}"),
+ }?;
+ LiteralType::EmptyMap(mt)
+ } else {
+ let keys = (0..m.keys().len())
+ .map(|i| {
+ to_substrait_literal(
+ &ScalarValue::try_from_array(&m.keys(), i)?,
+ extensions,
+ )
+ })
+ .collect::<Result<Vec<_>>>()?;
+ let values = (0..m.values().len())
+ .map(|i| {
+ to_substrait_literal(
+ &ScalarValue::try_from_array(&m.values(), i)?,
+ extensions,
+ )
+ })
+ .collect::<Result<Vec<_>>>()?;
+
+ let key_values = keys
+ .into_iter()
+ .zip(values.into_iter())
+ .map(|(k, v)| {
+ Ok(KeyValue {
+ key: Some(k),
+ value: Some(v),
+ })
+ })
+ .collect::<Result<Vec<_>>>()?;
+ LiteralType::Map(Map { key_values })
+ };
+ (map, DEFAULT_CONTAINER_TYPE_VARIATION_REF)
+ }
ScalarValue::Struct(s) => (
LiteralType::Struct(Struct {
fields: s
@@ -1967,7 +2011,7 @@ fn convert_array_to_literal_list<T: OffsetSizeTrait>(
.collect::<Result<Vec<_>>>()?;
if values.is_empty() {
- let et = match to_substrait_type(
+ let lt = match to_substrait_type(
array.data_type(),
array.is_nullable(),
extensions,
@@ -1977,7 +2021,7 @@ fn convert_array_to_literal_list<T: OffsetSizeTrait>(
} => lt.as_ref().to_owned(),
_ => unreachable!(),
};
- Ok(LiteralType::EmptyList(et))
+ Ok(LiteralType::EmptyList(lt))
} else {
Ok(LiteralType::List(List { values }))
}
@@ -2094,7 +2138,9 @@ mod test {
from_substrait_literal_without_names,
from_substrait_type_without_names,
};
use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano};
- use datafusion::arrow::array::GenericListArray;
+ use datafusion::arrow::array::{
+ GenericListArray, Int64Builder, MapBuilder, StringBuilder,
+ };
use datafusion::arrow::datatypes::Field;
use datafusion::common::scalar::ScalarStructBuilder;
use std::collections::HashMap;
@@ -2160,6 +2206,28 @@ mod test {
),
)))?;
+ // Null map
+ let mut map_builder =
+ MapBuilder::new(None, StringBuilder::new(), Int64Builder::new());
+ map_builder.append(false)?;
+ round_trip_literal(ScalarValue::Map(Arc::new(map_builder.finish())))?;
+
+ // Empty map
+ let mut map_builder =
+ MapBuilder::new(None, StringBuilder::new(), Int64Builder::new());
+ map_builder.append(true)?;
+ round_trip_literal(ScalarValue::Map(Arc::new(map_builder.finish())))?;
+
+ // Valid map
+ let mut map_builder =
+ MapBuilder::new(None, StringBuilder::new(), Int64Builder::new());
+ map_builder.keys().append_value("key1");
+ map_builder.keys().append_value("key2");
+ map_builder.values().append_value(1);
+ map_builder.values().append_value(2);
+ map_builder.append(true)?;
+ round_trip_literal(ScalarValue::Map(Arc::new(map_builder.finish())))?;
+
let c0 = Field::new("c0", DataType::Boolean, true);
let c1 = Field::new("c1", DataType::Int32, true);
let c2 = Field::new("c2", DataType::Utf8, true);
diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
index 5b4389c832..439e3efa29 100644
--- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
+++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
@@ -749,7 +749,7 @@ async fn roundtrip_values() -> Result<()> {
[[-213.1, NULL, 5.5, 2.0, 1.0], []], \
arrow_cast([1,2,3], 'LargeList(Int64)'), \
STRUCT(true, 1 AS int_field, CAST(NULL AS STRING)), \
- [STRUCT(STRUCT('a' AS string_field) AS struct_field)]\
+ [STRUCT(STRUCT('a' AS string_field) AS struct_field),
STRUCT(STRUCT('b' AS string_field) AS struct_field)]\
), \
(NULL, NULL, NULL, NULL, NULL, NULL)",
"Values: \
@@ -759,7 +759,7 @@ async fn roundtrip_values() -> Result<()> {
List([[-213.1, , 5.5, 2.0, 1.0], []]), \
LargeList([1, 2, 3]), \
Struct({c0:true,int_field:1,c2:}), \
- List([{struct_field: {string_field: a}}])\
+ List([{struct_field: {string_field: a}}, {struct_field:
{string_field: b}}])\
), \
(Int64(NULL), Utf8(NULL), List(), LargeList(),
Struct({c0:,int_field:,c2:}), List())",
true).await
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]