This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 0965455486 Return scalar result when all inputs are constants in `map`
and `make_map` (#11461)
0965455486 is described below
commit 0965455486b7dcbd8c9a5efa8d2370ca5460bb9f
Author: kamille <[email protected]>
AuthorDate: Tue Jul 16 02:06:38 2024 +0800
Return scalar result when all inputs are constants in `map` and `make_map`
(#11461)
* return a scalar result when all inputs are constants.
* support converting a map array to a scalar.
* disable const evaluation for the Map type until hashing is implemented for it.
* add tests in map.slt.
* improve the returned error.
* fix error.
* remove unused import.
* remove duplicated test case.
* remove inline.
---
datafusion/common/src/scalar/mod.rs | 5 +-
datafusion/functions/src/core/map.rs | 34 +++++++--
.../src/simplify_expressions/expr_simplifier.rs | 27 ++++++-
datafusion/sqllogictest/test_files/map.slt | 84 ++++++++++++++++++++++
4 files changed, 143 insertions(+), 7 deletions(-)
diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs
index 6c03e8698e..c891e85aa5 100644
--- a/datafusion/common/src/scalar/mod.rs
+++ b/datafusion/common/src/scalar/mod.rs
@@ -2678,7 +2678,10 @@ impl ScalarValue {
DataType::Duration(TimeUnit::Nanosecond) => {
typed_cast!(array, index, DurationNanosecondArray,
DurationNanosecond)?
}
-
+ DataType::Map(_, _) => {
+ let a = array.slice(index, 1);
+ Self::Map(Arc::new(a.as_map().to_owned()))
+ }
other => {
return _not_impl_err!(
"Can't create a scalar from array of type \"{other:?}\""
diff --git a/datafusion/functions/src/core/map.rs b/datafusion/functions/src/core/map.rs
index 8a8a19d7af..6626831c80 100644
--- a/datafusion/functions/src/core/map.rs
+++ b/datafusion/functions/src/core/map.rs
@@ -28,7 +28,21 @@ use datafusion_common::{exec_err, internal_err, ScalarValue};
use datafusion_common::{not_impl_err, Result};
use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
+/// Returns `true` when every argument is a [`ColumnarValue::Scalar`].
+///
+/// # Example
+/// ```sql
+/// SELECT make_map('type', 'test') from test
+/// ```
+/// All arguments are literals, so `make_map` can return a single scalar instead of an array.
+fn can_evaluate_to_const(args: &[ColumnarValue]) -> bool {
+ args.iter()
+ .all(|arg| matches!(arg, ColumnarValue::Scalar(_))) // note: `all` is also true for an empty `args` slice
+}
+
fn make_map(args: &[ColumnarValue]) -> Result<ColumnarValue> {
+ let can_evaluate_to_const = can_evaluate_to_const(args);
+
let (key, value): (Vec<_>, Vec<_>) = args
.chunks_exact(2)
.map(|chunk| {
@@ -58,7 +72,7 @@ fn make_map(args: &[ColumnarValue]) -> Result<ColumnarValue> {
Ok(value) => value,
Err(e) => return internal_err!("Error concatenating values: {}", e),
};
- make_map_batch_internal(key, value)
+ make_map_batch_internal(key, value, can_evaluate_to_const)
}
fn make_map_batch(args: &[ColumnarValue]) -> Result<ColumnarValue> {
@@ -68,9 +82,12 @@ fn make_map_batch(args: &[ColumnarValue]) -> Result<ColumnarValue> {
args.len()
);
}
+
+ let can_evaluate_to_const = can_evaluate_to_const(args);
+
let key = get_first_array_ref(&args[0])?;
let value = get_first_array_ref(&args[1])?;
- make_map_batch_internal(key, value)
+ make_map_batch_internal(key, value, can_evaluate_to_const)
}
fn get_first_array_ref(columnar_value: &ColumnarValue) -> Result<ArrayRef> {
@@ -85,7 +102,11 @@ fn get_first_array_ref(columnar_value: &ColumnarValue) -> Result<ArrayRef> {
}
}
-fn make_map_batch_internal(keys: ArrayRef, values: ArrayRef) ->
Result<ColumnarValue> {
+fn make_map_batch_internal(
+ keys: ArrayRef,
+ values: ArrayRef,
+ can_evaluate_to_const: bool,
+) -> Result<ColumnarValue> {
if keys.null_count() > 0 {
return exec_err!("map key cannot be null");
}
@@ -124,8 +145,13 @@ fn make_map_batch_internal(keys: ArrayRef, values: ArrayRef) -> Result<ColumnarV
.add_buffer(entry_offsets_buffer)
.add_child_data(entry_struct.to_data())
.build()?;
+ let map_array = Arc::new(MapArray::from(map_data));
- Ok(ColumnarValue::Array(Arc::new(MapArray::from(map_data))))
+ Ok(if can_evaluate_to_const {
+ ColumnarValue::Scalar(ScalarValue::try_from_array(map_array.as_ref(),
0)?)
+ } else {
+ ColumnarValue::Array(map_array)
+ })
}
#[derive(Debug)]
diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
index 17855e17be..8414f39f30 100644
--- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
+++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
@@ -656,12 +656,35 @@ impl<'a> ConstEvaluator<'a> {
} else {
// Non-ListArray
match ScalarValue::try_from_array(&a, 0) {
- Ok(s) => ConstSimplifyResult::Simplified(s),
+ Ok(s) => {
+ // TODO: support the optimization for `Map` type
after support impl hash for it
+ if matches!(&s, ScalarValue::Map(_)) {
+ ConstSimplifyResult::SimplifyRuntimeError(
+ DataFusionError::NotImplemented("Const
evaluate for Map type is still not supported".to_string()),
+ expr,
+ )
+ } else {
+ ConstSimplifyResult::Simplified(s)
+ }
+ }
Err(err) =>
ConstSimplifyResult::SimplifyRuntimeError(err, expr),
}
}
}
- ColumnarValue::Scalar(s) => ConstSimplifyResult::Simplified(s),
+ ColumnarValue::Scalar(s) => {
+ // TODO: support the optimization for `Map` type after support
impl hash for it
+ if matches!(&s, ScalarValue::Map(_)) {
+ ConstSimplifyResult::SimplifyRuntimeError(
+ DataFusionError::NotImplemented(
+ "Const evaluate for Map type is still not
supported"
+ .to_string(),
+ ),
+ expr,
+ )
+ } else {
+ ConstSimplifyResult::Simplified(s)
+ }
+ }
}
}
}
diff --git a/datafusion/sqllogictest/test_files/map.slt b/datafusion/sqllogictest/test_files/map.slt
index abf5b2ebbf..fb8917a5f4 100644
--- a/datafusion/sqllogictest/test_files/map.slt
+++ b/datafusion/sqllogictest/test_files/map.slt
@@ -212,3 +212,87 @@ SELECT map(column5, column6) FROM t;
# {k1:1, k2:2}
# {k3: 3}
# {k5: 5}
+
+query ?
+SELECT MAKE_MAP('POST', 41, 'HEAD', 33, 'PATCH', 30, 'OPTION', 29, 'GET', 27,
'PUT', 25, 'DELETE', 24) AS method_count from t;
+----
+{POST: 41, HEAD: 33, PATCH: 30, OPTION: 29, GET: 27, PUT: 25, DELETE: 24}
+{POST: 41, HEAD: 33, PATCH: 30, OPTION: 29, GET: 27, PUT: 25, DELETE: 24}
+{POST: 41, HEAD: 33, PATCH: 30, OPTION: 29, GET: 27, PUT: 25, DELETE: 24}
+
+query I
+SELECT MAKE_MAP('POST', 41, 'HEAD', 33)['POST'] from t;
+----
+41
+41
+41
+
+query ?
+SELECT MAKE_MAP('POST', 41, 'HEAD', 33, 'PATCH', null) from t;
+----
+{POST: 41, HEAD: 33, PATCH: }
+{POST: 41, HEAD: 33, PATCH: }
+{POST: 41, HEAD: 33, PATCH: }
+
+query ?
+SELECT MAKE_MAP('POST', null, 'HEAD', 33, 'PATCH', null) from t;
+----
+{POST: , HEAD: 33, PATCH: }
+{POST: , HEAD: 33, PATCH: }
+{POST: , HEAD: 33, PATCH: }
+
+query ?
+SELECT MAKE_MAP(1, null, 2, 33, 3, null) from t;
+----
+{1: , 2: 33, 3: }
+{1: , 2: 33, 3: }
+{1: , 2: 33, 3: }
+
+query ?
+SELECT MAKE_MAP([1,2], ['a', 'b'], [3,4], ['b']) from t;
+----
+{[1, 2]: [a, b], [3, 4]: [b]}
+{[1, 2]: [a, b], [3, 4]: [b]}
+{[1, 2]: [a, b], [3, 4]: [b]}
+
+query ?
+SELECT MAP(['POST', 'HEAD', 'PATCH'], [41, 33, 30]) from t;
+----
+{POST: 41, HEAD: 33, PATCH: 30}
+{POST: 41, HEAD: 33, PATCH: 30}
+{POST: 41, HEAD: 33, PATCH: 30}
+
+query ?
+SELECT MAP(['POST', 'HEAD', 'PATCH'], [41, 33, null]) from t;
+----
+{POST: 41, HEAD: 33, PATCH: }
+{POST: 41, HEAD: 33, PATCH: }
+{POST: 41, HEAD: 33, PATCH: }
+
+query ?
+SELECT MAP([[1,2], [3,4]], ['a', 'b']) from t;
+----
+{[1, 2]: a, [3, 4]: b}
+{[1, 2]: a, [3, 4]: b}
+{[1, 2]: a, [3, 4]: b}
+
+query ?
+SELECT MAP(make_array('POST', 'HEAD', 'PATCH'), make_array(41, 33, 30)) from t;
+----
+{POST: 41, HEAD: 33, PATCH: 30}
+{POST: 41, HEAD: 33, PATCH: 30}
+{POST: 41, HEAD: 33, PATCH: 30}
+
+query ?
+SELECT MAP(arrow_cast(make_array('POST', 'HEAD', 'PATCH'), 'FixedSizeList(3,
Utf8)'), arrow_cast(make_array(41, 33, 30), 'FixedSizeList(3, Int64)')) from t;
+----
+{POST: 41, HEAD: 33, PATCH: 30}
+{POST: 41, HEAD: 33, PATCH: 30}
+{POST: 41, HEAD: 33, PATCH: 30}
+
+query ?
+SELECT MAP(arrow_cast(make_array('POST', 'HEAD', 'PATCH'), 'LargeList(Utf8)'),
arrow_cast(make_array(41, 33, 30), 'LargeList(Int64)')) from t;
+----
+{POST: 41, HEAD: 33, PATCH: 30}
+{POST: 41, HEAD: 33, PATCH: 30}
+{POST: 41, HEAD: 33, PATCH: 30}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]