This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 0a2e422411 simplify `array_has` UDF to `InList` expr when haystack is
constant (#15354)
0a2e422411 is described below
commit 0a2e422411ae6134e00adc92bf42761fb330fe67
Author: David Hewitt <[email protected]>
AuthorDate: Tue Mar 25 20:42:46 2025 +0000
simplify `array_has` UDF to `InList` expr when haystack is constant (#15354)
* simplify `array_has` UDF to `InList` expr when haystack is constant
* add `.slt` tests, also simplify with `make_array`
* tweak comment
* add test for `make_array` arg simplification
---
datafusion/functions-nested/src/array_has.rs | 134 +++++++++++++++++++-
datafusion/sqllogictest/test_files/array.slt | 182 +++++++++++++++++++++++++++
2 files changed, 315 insertions(+), 1 deletion(-)
diff --git a/datafusion/functions-nested/src/array_has.rs
b/datafusion/functions-nested/src/array_has.rs
index 1857ead8c5..48ee341566 100644
--- a/datafusion/functions-nested/src/array_has.rs
+++ b/datafusion/functions-nested/src/array_has.rs
@@ -27,13 +27,16 @@ use datafusion_common::cast::as_generic_list_array;
use datafusion_common::utils::string_utils::string_array_to_vec;
use datafusion_common::utils::take_function_args;
use datafusion_common::{exec_err, Result, ScalarValue};
+use datafusion_expr::expr::{InList, ScalarFunction};
+use datafusion_expr::simplify::ExprSimplifyResult;
use datafusion_expr::{
- ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
+ ColumnarValue, Documentation, Expr, ScalarUDFImpl, Signature, Volatility,
};
use datafusion_macros::user_doc;
use datafusion_physical_expr_common::datum::compare_with_eq;
use itertools::Itertools;
+use crate::make_array::make_array_udf;
use crate::utils::make_scalar_function;
use std::any::Any;
@@ -121,6 +124,52 @@ impl ScalarUDFImpl for ArrayHas {
Ok(DataType::Boolean)
}
+ /// Rewrites `array_has(haystack, needle)` into `needle IN (...)` when the
+ /// elements of the haystack are known at planning time.
+ ///
+ /// Two haystack shapes are handled:
+ /// * a literal `ScalarValue::List`, whose elements become literal expressions;
+ /// * a call to `make_array`, whose argument expressions are reused as-is.
+ ///
+ /// Any other haystack (or a list literal that cannot be converted to scalar
+ /// values) leaves the call untouched and returns `ExprSimplifyResult::Original`.
+ ///
+ /// Panics if the list literal, or its converted form, does not hold exactly one row.
+ fn simplify(
+ &self,
+ mut args: Vec<Expr>,
+ _info: &dyn datafusion_expr::simplify::SimplifyInfo,
+ ) -> Result<ExprSimplifyResult> {
+ let [haystack, needle] = take_function_args(self.name(), &mut args)?;
+
+ // if the haystack is a constant list, we can use an inlist expression
which is more
+ // efficient because the haystack is not varying per-row
+ if let Expr::Literal(ScalarValue::List(array)) = haystack {
+ // TODO: support LargeList
+ // (not supported by `convert_array_to_scalar_vec`)
+ // (FixedSizeList not supported either, but seems to have worked
fine when attempting to
+ // build a reproducer)
+
+ assert_eq!(array.len(), 1); // guarantee of ScalarValue
+ if let Ok(scalar_values) =
+ ScalarValue::convert_array_to_scalar_vec(array.as_ref())
+ {
+ assert_eq!(scalar_values.len(), 1);
+ let list = scalar_values
+ .into_iter()
+ .flatten()
+ .map(Expr::Literal)
+ .collect();
+
+ // `args` is not returned on this path, so the needle is moved out
+ // of it rather than cloned
+ return Ok(ExprSimplifyResult::Simplified(Expr::InList(InList {
+ expr: Box::new(std::mem::take(needle)),
+ list,
+ negated: false,
+ })));
+ }
+ } else if let Expr::ScalarFunction(ScalarFunction { func, args }) =
haystack {
+ // make_array has a static set of arguments, so we can pull the
arguments out from it
+ if func == &make_array_udf() {
+ return Ok(ExprSimplifyResult::Simplified(Expr::InList(InList {
+ expr: Box::new(std::mem::take(needle)),
+ list: std::mem::take(args),
+ negated: false,
+ })));
+ }
+ }
+
+ // no rewrite applied: hand the arguments back unchanged
+ Ok(ExprSimplifyResult::Original(args))
+ }
+
fn invoke_with_args(
&self,
args: datafusion_expr::ScalarFunctionArgs,
@@ -542,3 +591,86 @@ fn general_array_has_all_and_any_kernel(
}),
}
}
+
+#[cfg(test)]
+mod tests {
+ use arrow::array::create_array;
+ use datafusion_common::utils::SingleRowListArrayBuilder;
+ use datafusion_expr::{
+ col, execution_props::ExecutionProps, lit,
simplify::ExprSimplifyResult, Expr,
+ ScalarUDFImpl,
+ };
+
+ use crate::expr_fn::make_array;
+
+ use super::ArrayHas;
+
+ // A literal `List` haystack `[1, 2, 3]` with needle column `c` must be
+ // rewritten into the non-negated `InList` expression `c IN (1, 2, 3)`.
+ #[test]
+ fn test_simplify_array_has_to_in_list() {
+ let haystack = lit(SingleRowListArrayBuilder::new(create_array!(
+ Int32,
+ [1, 2, 3]
+ ))
+ .build_list_scalar());
+ let needle = col("c");
+
+ let props = ExecutionProps::new();
+ let context = datafusion_expr::simplify::SimplifyContext::new(&props);
+
+ let Ok(ExprSimplifyResult::Simplified(Expr::InList(in_list))) =
+ ArrayHas::new().simplify(vec![haystack, needle.clone()], &context)
+ else {
+ panic!("Expected simplified expression");
+ };
+
+ assert_eq!(
+ in_list,
+ datafusion_expr::expr::InList {
+ expr: Box::new(needle),
+ list: vec![lit(1), lit(2), lit(3)],
+ negated: false,
+ }
+ );
+ }
+
+ // A `make_array(1, 2, 3)` haystack with needle column `c` must be rewritten
+ // into the non-negated `InList` expression `c IN (1, 2, 3)`, reusing the
+ // `make_array` arguments as the list elements.
+ #[test]
+ fn test_simplify_array_has_with_make_array_to_in_list() {
+ let haystack = make_array(vec![lit(1), lit(2), lit(3)]);
+ let needle = col("c");
+
+ let props = ExecutionProps::new();
+ let context = datafusion_expr::simplify::SimplifyContext::new(&props);
+
+ let Ok(ExprSimplifyResult::Simplified(Expr::InList(in_list))) =
+ ArrayHas::new().simplify(vec![haystack, needle.clone()], &context)
+ else {
+ panic!("Expected simplified expression");
+ };
+
+ assert_eq!(
+ in_list,
+ datafusion_expr::expr::InList {
+ expr: Box::new(needle),
+ list: vec![lit(1), lit(2), lit(3)],
+ negated: false,
+ }
+ );
+ }
+
+ // When the haystack is a plain column (neither a list literal nor a
+ // `make_array` call), `simplify` must return the arguments unchanged.
+ #[test]
+ fn test_array_has_complex_list_not_simplified() {
+ let haystack = col("c1");
+ let needle = col("c2");
+
+ let props = ExecutionProps::new();
+ let context = datafusion_expr::simplify::SimplifyContext::new(&props);
+
+ let Ok(ExprSimplifyResult::Original(args)) =
+ ArrayHas::new().simplify(vec![haystack, needle], &context)
+ else {
+ // this test expects the call to be left as-is, not simplified
+ panic!("Expected original (unsimplified) arguments");
+ };
+
+ assert_eq!(args, vec![col("c1"), col("c2")]);
+ }
+}
diff --git a/datafusion/sqllogictest/test_files/array.slt
b/datafusion/sqllogictest/test_files/array.slt
index 509c7c182a..352064fbe5 100644
--- a/datafusion/sqllogictest/test_files/array.slt
+++ b/datafusion/sqllogictest/test_files/array.slt
@@ -5969,6 +5969,188 @@ true false true false false false true true false false
true false true
#----
#true false true false false false true true false false true false true
+# rewrite various array_has operations to InList where the haystack is a
literal list
+# NB that `col in (a, b, c)` is simplified to OR if there are <= 3 elements,
so we make 4-element haystack lists
+
+query I
+with test AS (SELECT substr(md5(i)::text, 1, 32) as needle FROM
generate_series(1, 100000) t(i))
+select count(*) from test WHERE needle IN ('7f4b18de3cfeb9b4ac78c381ee2ad278',
'a', 'b', 'c');
+----
+1
+
+query TT
+explain with test AS (SELECT substr(md5(i)::text, 1, 32) as needle FROM
generate_series(1, 100000) t(i))
+select count(*) from test WHERE needle IN ('7f4b18de3cfeb9b4ac78c381ee2ad278',
'a', 'b', 'c');
+----
+logical_plan
+01)Projection: count(Int64(1)) AS count(*)
+02)--Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]]
+03)----SubqueryAlias: test
+04)------SubqueryAlias: t
+05)--------Projection:
+06)----------Filter: substr(CAST(md5(CAST(tmp_table.value AS Utf8)) AS Utf8),
Int64(1), Int64(32)) IN ([Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278"),
Utf8View("a"), Utf8View("b"), Utf8View("c")])
+07)------------TableScan: tmp_table projection=[value]
+physical_plan
+01)ProjectionExec: expr=[count(Int64(1))@0 as count(*)]
+02)--AggregateExec: mode=Final, gby=[], aggr=[count(Int64(1))]
+03)----CoalescePartitionsExec
+04)------AggregateExec: mode=Partial, gby=[], aggr=[count(Int64(1))]
+05)--------ProjectionExec: expr=[]
+06)----------CoalesceBatchesExec: target_batch_size=8192
+07)------------FilterExec: substr(md5(CAST(value@0 AS Utf8)), 1, 32) IN
([Literal { value: Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278") }, Literal {
value: Utf8View("a") }, Literal { value: Utf8View("b") }, Literal { value:
Utf8View("c") }])
+08)--------------RepartitionExec: partitioning=RoundRobinBatch(4),
input_partitions=1
+09)----------------LazyMemoryExec: partitions=1,
batch_generators=[generate_series: start=1, end=100000, batch_size=8192]
+
+query I
+with test AS (SELECT substr(md5(i)::text, 1, 32) as needle FROM
generate_series(1, 100000) t(i))
+select count(*) from test WHERE needle =
ANY(['7f4b18de3cfeb9b4ac78c381ee2ad278', 'a', 'b', 'c']);
+----
+1
+
+query TT
+explain with test AS (SELECT substr(md5(i)::text, 1, 32) as needle FROM
generate_series(1, 100000) t(i))
+select count(*) from test WHERE needle =
ANY(['7f4b18de3cfeb9b4ac78c381ee2ad278', 'a', 'b', 'c']);
+----
+logical_plan
+01)Projection: count(Int64(1)) AS count(*)
+02)--Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]]
+03)----SubqueryAlias: test
+04)------SubqueryAlias: t
+05)--------Projection:
+06)----------Filter: substr(CAST(md5(CAST(tmp_table.value AS Utf8)) AS Utf8),
Int64(1), Int64(32)) IN ([Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278"),
Utf8View("a"), Utf8View("b"), Utf8View("c")])
+07)------------TableScan: tmp_table projection=[value]
+physical_plan
+01)ProjectionExec: expr=[count(Int64(1))@0 as count(*)]
+02)--AggregateExec: mode=Final, gby=[], aggr=[count(Int64(1))]
+03)----CoalescePartitionsExec
+04)------AggregateExec: mode=Partial, gby=[], aggr=[count(Int64(1))]
+05)--------ProjectionExec: expr=[]
+06)----------CoalesceBatchesExec: target_batch_size=8192
+07)------------FilterExec: substr(md5(CAST(value@0 AS Utf8)), 1, 32) IN
([Literal { value: Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278") }, Literal {
value: Utf8View("a") }, Literal { value: Utf8View("b") }, Literal { value:
Utf8View("c") }])
+08)--------------RepartitionExec: partitioning=RoundRobinBatch(4),
input_partitions=1
+09)----------------LazyMemoryExec: partitions=1,
batch_generators=[generate_series: start=1, end=100000, batch_size=8192]
+
+query I
+with test AS (SELECT substr(md5(i)::text, 1, 32) as needle FROM
generate_series(1, 100000) t(i))
+select count(*) from test WHERE array_has(['7f4b18de3cfeb9b4ac78c381ee2ad278',
'a', 'b', 'c'], needle);
+----
+1
+
+query TT
+explain with test AS (SELECT substr(md5(i)::text, 1, 32) as needle FROM
generate_series(1, 100000) t(i))
+select count(*) from test WHERE array_has(['7f4b18de3cfeb9b4ac78c381ee2ad278',
'a', 'b', 'c'], needle);
+----
+logical_plan
+01)Projection: count(Int64(1)) AS count(*)
+02)--Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]]
+03)----SubqueryAlias: test
+04)------SubqueryAlias: t
+05)--------Projection:
+06)----------Filter: substr(CAST(md5(CAST(tmp_table.value AS Utf8)) AS Utf8),
Int64(1), Int64(32)) IN ([Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278"),
Utf8View("a"), Utf8View("b"), Utf8View("c")])
+07)------------TableScan: tmp_table projection=[value]
+physical_plan
+01)ProjectionExec: expr=[count(Int64(1))@0 as count(*)]
+02)--AggregateExec: mode=Final, gby=[], aggr=[count(Int64(1))]
+03)----CoalescePartitionsExec
+04)------AggregateExec: mode=Partial, gby=[], aggr=[count(Int64(1))]
+05)--------ProjectionExec: expr=[]
+06)----------CoalesceBatchesExec: target_batch_size=8192
+07)------------FilterExec: substr(md5(CAST(value@0 AS Utf8)), 1, 32) IN
([Literal { value: Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278") }, Literal {
value: Utf8View("a") }, Literal { value: Utf8View("b") }, Literal { value:
Utf8View("c") }])
+08)--------------RepartitionExec: partitioning=RoundRobinBatch(4),
input_partitions=1
+09)----------------LazyMemoryExec: partitions=1,
batch_generators=[generate_series: start=1, end=100000, batch_size=8192]
+
+# FIXME: the InList rewrite does not yet apply to LargeList haystacks (see the explain output below), so this query is _extremely_ slow to evaluate
+# query I
+# with test AS (SELECT substr(md5(i)::text, 1, 32) as needle FROM
generate_series(1, 100000) t(i))
+# select count(*) from test WHERE
array_has(arrow_cast(['7f4b18de3cfeb9b4ac78c381ee2ad278', 'a', 'b', 'c'],
'LargeList(Utf8View)'), needle);
+# ----
+# 1
+
+# FIXME: array_has with large list haystack not currently rewritten to InList
+query TT
+explain with test AS (SELECT substr(md5(i)::text, 1, 32) as needle FROM
generate_series(1, 100000) t(i))
+select count(*) from test WHERE
array_has(arrow_cast(['7f4b18de3cfeb9b4ac78c381ee2ad278', 'a', 'b', 'c'],
'LargeList(Utf8View)'), needle);
+----
+logical_plan
+01)Projection: count(Int64(1)) AS count(*)
+02)--Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]]
+03)----SubqueryAlias: test
+04)------SubqueryAlias: t
+05)--------Projection:
+06)----------Filter: array_has(LargeList([7f4b18de3cfeb9b4ac78c381ee2ad278, a,
b, c]), substr(CAST(md5(CAST(tmp_table.value AS Utf8)) AS Utf8), Int64(1),
Int64(32)))
+07)------------TableScan: tmp_table projection=[value]
+physical_plan
+01)ProjectionExec: expr=[count(Int64(1))@0 as count(*)]
+02)--AggregateExec: mode=Final, gby=[], aggr=[count(Int64(1))]
+03)----CoalescePartitionsExec
+04)------AggregateExec: mode=Partial, gby=[], aggr=[count(Int64(1))]
+05)--------ProjectionExec: expr=[]
+06)----------CoalesceBatchesExec: target_batch_size=8192
+07)------------FilterExec: array_has([7f4b18de3cfeb9b4ac78c381ee2ad278, a, b,
c], substr(md5(CAST(value@0 AS Utf8)), 1, 32))
+08)--------------RepartitionExec: partitioning=RoundRobinBatch(4),
input_partitions=1
+09)----------------LazyMemoryExec: partitions=1,
batch_generators=[generate_series: start=1, end=100000, batch_size=8192]
+
+query I
+with test AS (SELECT substr(md5(i)::text, 1, 32) as needle FROM
generate_series(1, 100000) t(i))
+select count(*) from test WHERE
array_has(arrow_cast(['7f4b18de3cfeb9b4ac78c381ee2ad278', 'a', 'b', 'c'],
'FixedSizeList(4, Utf8View)'), needle);
+----
+1
+
+query TT
+explain with test AS (SELECT substr(md5(i)::text, 1, 32) as needle FROM
generate_series(1, 100000) t(i))
+select count(*) from test WHERE
array_has(arrow_cast(['7f4b18de3cfeb9b4ac78c381ee2ad278', 'a', 'b', 'c'],
'FixedSizeList(4, Utf8View)'), needle);
+----
+logical_plan
+01)Projection: count(Int64(1)) AS count(*)
+02)--Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]]
+03)----SubqueryAlias: test
+04)------SubqueryAlias: t
+05)--------Projection:
+06)----------Filter: substr(CAST(md5(CAST(tmp_table.value AS Utf8)) AS Utf8),
Int64(1), Int64(32)) IN ([Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278"),
Utf8View("a"), Utf8View("b"), Utf8View("c")])
+07)------------TableScan: tmp_table projection=[value]
+physical_plan
+01)ProjectionExec: expr=[count(Int64(1))@0 as count(*)]
+02)--AggregateExec: mode=Final, gby=[], aggr=[count(Int64(1))]
+03)----CoalescePartitionsExec
+04)------AggregateExec: mode=Partial, gby=[], aggr=[count(Int64(1))]
+05)--------ProjectionExec: expr=[]
+06)----------CoalesceBatchesExec: target_batch_size=8192
+07)------------FilterExec: substr(md5(CAST(value@0 AS Utf8)), 1, 32) IN
([Literal { value: Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278") }, Literal {
value: Utf8View("a") }, Literal { value: Utf8View("b") }, Literal { value:
Utf8View("c") }])
+08)--------------RepartitionExec: partitioning=RoundRobinBatch(4),
input_partitions=1
+09)----------------LazyMemoryExec: partitions=1,
batch_generators=[generate_series: start=1, end=100000, batch_size=8192]
+
+query I
+with test AS (SELECT substr(md5(i)::text, 1, 32) as needle FROM
generate_series(1, 100000) t(i))
+select count(*) from test WHERE array_has([needle], needle);
+----
+100000
+
+# TODO: it should probably be possible to remove this filter completely, as it is always true?
+query TT
+explain with test AS (SELECT substr(md5(i)::text, 1, 32) as needle FROM
generate_series(1, 100000) t(i))
+select count(*) from test WHERE array_has([needle], needle);
+----
+logical_plan
+01)Projection: count(Int64(1)) AS count(*)
+02)--Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]]
+03)----SubqueryAlias: test
+04)------SubqueryAlias: t
+05)--------Projection:
+06)----------Filter: __common_expr_3 = __common_expr_3
+07)------------Projection: substr(CAST(md5(CAST(tmp_table.value AS Utf8)) AS
Utf8), Int64(1), Int64(32)) AS __common_expr_3
+08)--------------TableScan: tmp_table projection=[value]
+physical_plan
+01)ProjectionExec: expr=[count(Int64(1))@0 as count(*)]
+02)--AggregateExec: mode=Final, gby=[], aggr=[count(Int64(1))]
+03)----CoalescePartitionsExec
+04)------AggregateExec: mode=Partial, gby=[], aggr=[count(Int64(1))]
+05)--------ProjectionExec: expr=[]
+06)----------CoalesceBatchesExec: target_batch_size=8192
+07)------------FilterExec: __common_expr_3@0 = __common_expr_3@0
+08)--------------ProjectionExec: expr=[substr(md5(CAST(value@0 AS Utf8)), 1,
32) as __common_expr_3]
+09)----------------RepartitionExec: partitioning=RoundRobinBatch(4),
input_partitions=1
+10)------------------LazyMemoryExec: partitions=1,
batch_generators=[generate_series: start=1, end=100000, batch_size=8192]
+
# any operator
query ?
select column3 from arrays where 'L'=any(column3);
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]